添加会话上下文保存

This commit is contained in:
junleea 2025-03-23 13:17:03 +08:00
parent 08d134ac41
commit 142463a092
4 changed files with 28 additions and 14 deletions

View File

@ -7,8 +7,9 @@ import (
type Session struct { type Session struct {
gorm.Model gorm.Model
UserID int `gorm:"column:user_id"` //只能由用户创建 UserID int `gorm:"column:user_id"` //只能由用户创建
Name string `gorm:"column:name"` Name string `gorm:"column:name"`
Context []int `gorm:"column:context"` //会话上下文
} }
type Message struct { type Message struct {
@ -54,12 +55,12 @@ func FindSessionByUserID(userID int) []Session {
} }
// 更新会话的名字 // 更新会话的名字
func UpdateSessionByID(id int, userId int, name string) error { func UpdateSessionByID(id int, userId int, name string, context []int) error {
var res *gorm.DB var res *gorm.DB
if proto.Config.SERVER_SQL_LOG { if proto.Config.SERVER_SQL_LOG {
res = DB.Debug().Model(&Session{}).Where("id = ? and user_id = ?", id, userId).Update("name", name) res = DB.Debug().Model(&Session{}).Where("id = ? AND user_id = ?", id, userId).Updates(Session{Name: name, Context: context})
} else { } else {
res = DB.Model(&Session{}).Where("id = ? and user_id = ?", id, userId).Update("name", name) res = DB.Model(&Session{}).Where("id = ? AND user_id = ?", id, userId).Updates(Session{Name: name, Context: context})
} }
return res.Error return res.Error
} }

View File

@ -176,6 +176,18 @@ func subscribeAndHandleIMMessages(ws *websocket.Conn, userId int) {
for { for {
select { select {
case <-done: case <-done:
//结束会话将context保存
ctx2, ok := service.SessionContextMap[int(sessionID)]
if !ok {
log.Println("session context not found")
return
}
err := service.UpdateSessionByID(int(sessionID), userId, "", ctx2)
if err != nil {
log.Println("update session context error:", err)
}
//保存的会话
delete(service.SessionContextMap, int(sessionID))
//客户端断开连接 //客户端断开连接
return return
//case t := <-ticker.C: //case t := <-ticker.C:
@ -188,4 +200,5 @@ func subscribeAndHandleIMMessages(ws *websocket.Conn, userId int) {
//} //}
} }
} }
} }

View File

@ -70,7 +70,7 @@ func UpdateSession(c *gin.Context) {
var req Session var req Session
if err := c.ShouldBind(&req); err == nil { if err := c.ShouldBind(&req); err == nil {
// 修改会话 // 修改会话
err2 := service.UpdateSessionByID(req.ID, userID, req.Name) err2 := service.UpdateSessionByID(req.ID, userID, req.Name, nil)
if err2 == nil { if err2 == nil {
c.JSON(http.StatusOK, gin.H{"code": proto.SuccessCode, "message": "success"}) c.JSON(http.StatusOK, gin.H{"code": proto.SuccessCode, "message": "success"})
} else { } else {

View File

@ -43,8 +43,8 @@ func DeleteSessionByID(id, userID int) error {
} }
// 更新会话名字 // 更新会话名字
func UpdateSessionByID(id int, userId int, name string) error { func UpdateSessionByID(id int, userId int, name string, context []int) error {
err := dao.UpdateSessionByID(id, userId, name) err := dao.UpdateSessionByID(id, userId, name, context)
return err return err
} }
@ -77,7 +77,7 @@ func CreateGeneralMessageService(fromID, toID, msgType, sessionID int, msg strin
} }
var client *api.Client var client *api.Client
var actxMap map[int][]int var SessionContextMap map[int][]int
func WSReceiveMessageService(userID, sessionID int, channel string, msg proto.WSMessageReq) (error, uint) { func WSReceiveMessageService(userID, sessionID int, channel string, msg proto.WSMessageReq) (error, uint) {
var resErr error var resErr error
@ -92,7 +92,7 @@ func WSReceiveMessageService(userID, sessionID int, channel string, msg proto.WS
log.Println("get ollama client error:", err) log.Println("get ollama client error:", err)
} }
prompt := msg.Msg prompt := msg.Msg
actx, ok := actxMap[sessionID] actx, ok := SessionContextMap[sessionID]
if ok { if ok {
log.Println("actx has get will use:", len(actx)) log.Println("actx has get will use:", len(actx))
} }
@ -111,11 +111,11 @@ func WSReceiveMessageService(userID, sessionID int, channel string, msg proto.WS
robotMsg += resp.Response robotMsg += resp.Response
if resp.Done { //该消息完成 if resp.Done { //该消息完成
actx = resp.Context actx = resp.Context
if actxMap == nil { if SessionContextMap == nil {
actxMap = make(map[int][]int) SessionContextMap = make(map[int][]int)
} }
actxMap[sessionID] = actx SessionContextMap[sessionID] = actx
log.Println("actx has save:", len(actxMap[sessionID])) log.Println("actx has save:", len(SessionContextMap[sessionID]))
//创建消息 //创建消息
err2, msgID := CreateMessage(2, sessionID, 2, userID, robotMsg, 1) //机器人id为2,消息类型为2,状态为1 err2, msgID := CreateMessage(2, sessionID, 2, userID, robotMsg, 1) //机器人id为2,消息类型为2,状态为1
if err2 != nil { if err2 != nil {