添加会话上下文保存

This commit is contained in:
junleea 2025-03-23 13:17:03 +08:00
parent 08d134ac41
commit 142463a092
4 changed files with 28 additions and 14 deletions

View File

@ -7,8 +7,9 @@ import (
type Session struct {
gorm.Model
UserID int `gorm:"column:user_id"` //只能由用户创建
Name string `gorm:"column:name"`
UserID int `gorm:"column:user_id"` //只能由用户创建
Name string `gorm:"column:name"`
Context []int `gorm:"column:context"` //会话上下文
}
type Message struct {
@ -54,12 +55,12 @@ func FindSessionByUserID(userID int) []Session {
}
// 更新会话的名字
func UpdateSessionByID(id int, userId int, name string) error {
func UpdateSessionByID(id int, userId int, name string, context []int) error {
var res *gorm.DB
if proto.Config.SERVER_SQL_LOG {
res = DB.Debug().Model(&Session{}).Where("id = ? and user_id = ?", id, userId).Update("name", name)
res = DB.Debug().Model(&Session{}).Where("id = ? AND user_id = ?", id, userId).Updates(Session{Name: name, Context: context})
} else {
res = DB.Model(&Session{}).Where("id = ? and user_id = ?", id, userId).Update("name", name)
res = DB.Model(&Session{}).Where("id = ? AND user_id = ?", id, userId).Updates(Session{Name: name, Context: context})
}
return res.Error
}

View File

@ -176,6 +176,18 @@ func subscribeAndHandleIMMessages(ws *websocket.Conn, userId int) {
for {
select {
case <-done:
//结束会话将context保存
ctx2, ok := service.SessionContextMap[int(sessionID)]
if !ok {
log.Println("session context not found")
return
}
err := service.UpdateSessionByID(int(sessionID), userId, "", ctx2)
if err != nil {
log.Println("update session context error:", err)
}
//保存的会话
delete(service.SessionContextMap, int(sessionID))
//客户端断开连接
return
//case t := <-ticker.C:
@ -188,4 +200,5 @@ func subscribeAndHandleIMMessages(ws *websocket.Conn, userId int) {
//}
}
}
}

View File

@ -70,7 +70,7 @@ func UpdateSession(c *gin.Context) {
var req Session
if err := c.ShouldBind(&req); err == nil {
// 修改会话
err2 := service.UpdateSessionByID(req.ID, userID, req.Name)
err2 := service.UpdateSessionByID(req.ID, userID, req.Name, nil)
if err2 == nil {
c.JSON(http.StatusOK, gin.H{"code": proto.SuccessCode, "message": "success"})
} else {

View File

@ -43,8 +43,8 @@ func DeleteSessionByID(id, userID int) error {
}
// 更新会话名字
func UpdateSessionByID(id int, userId int, name string) error {
err := dao.UpdateSessionByID(id, userId, name)
func UpdateSessionByID(id int, userId int, name string, context []int) error {
err := dao.UpdateSessionByID(id, userId, name, context)
return err
}
@ -77,7 +77,7 @@ func CreateGeneralMessageService(fromID, toID, msgType, sessionID int, msg strin
}
var client *api.Client
var actxMap map[int][]int
var SessionContextMap map[int][]int
func WSReceiveMessageService(userID, sessionID int, channel string, msg proto.WSMessageReq) (error, uint) {
var resErr error
@ -92,7 +92,7 @@ func WSReceiveMessageService(userID, sessionID int, channel string, msg proto.WS
log.Println("get ollama client error:", err)
}
prompt := msg.Msg
actx, ok := actxMap[sessionID]
actx, ok := SessionContextMap[sessionID]
if ok {
log.Println("actx has get will use:", len(actx))
}
@ -111,11 +111,11 @@ func WSReceiveMessageService(userID, sessionID int, channel string, msg proto.WS
robotMsg += resp.Response
if resp.Done { //该消息完成
actx = resp.Context
if actxMap == nil {
actxMap = make(map[int][]int)
if SessionContextMap == nil {
SessionContextMap = make(map[int][]int)
}
actxMap[sessionID] = actx
log.Println("actx has save:", len(actxMap[sessionID]))
SessionContextMap[sessionID] = actx
log.Println("actx has save:", len(SessionContextMap[sessionID]))
//创建消息
err2, msgID := CreateMessage(2, sessionID, 2, userID, robotMsg, 1) //机器人id为2,消息类型为2,状态为1
if err2 != nil {