添加会话上下文保存
This commit is contained in:
parent
08d134ac41
commit
142463a092
11
dao/im.go
11
dao/im.go
|
|
@ -7,8 +7,9 @@ import (
|
|||
|
||||
type Session struct {
|
||||
gorm.Model
|
||||
UserID int `gorm:"column:user_id"` //只能由用户创建
|
||||
Name string `gorm:"column:name"`
|
||||
UserID int `gorm:"column:user_id"` //只能由用户创建
|
||||
Name string `gorm:"column:name"`
|
||||
Context []int `gorm:"column:context"` //会话上下文
|
||||
}
|
||||
|
||||
type Message struct {
|
||||
|
|
@ -54,12 +55,12 @@ func FindSessionByUserID(userID int) []Session {
|
|||
}
|
||||
|
||||
// 更新会话的名字
|
||||
func UpdateSessionByID(id int, userId int, name string) error {
|
||||
func UpdateSessionByID(id int, userId int, name string, context []int) error {
|
||||
var res *gorm.DB
|
||||
if proto.Config.SERVER_SQL_LOG {
|
||||
res = DB.Debug().Model(&Session{}).Where("id = ? and user_id = ?", id, userId).Update("name", name)
|
||||
res = DB.Debug().Model(&Session{}).Where("id = ? AND user_id = ?", id, userId).Updates(Session{Name: name, Context: context})
|
||||
} else {
|
||||
res = DB.Model(&Session{}).Where("id = ? and user_id = ?", id, userId).Update("name", name)
|
||||
res = DB.Model(&Session{}).Where("id = ? AND user_id = ?", id, userId).Updates(Session{Name: name, Context: context})
|
||||
}
|
||||
return res.Error
|
||||
}
|
||||
|
|
|
|||
|
|
@ -176,6 +176,18 @@ func subscribeAndHandleIMMessages(ws *websocket.Conn, userId int) {
|
|||
for {
|
||||
select {
|
||||
case <-done:
|
||||
//结束会话将context保存
|
||||
ctx2, ok := service.SessionContextMap[int(sessionID)]
|
||||
if !ok {
|
||||
log.Println("session context not found")
|
||||
return
|
||||
}
|
||||
err := service.UpdateSessionByID(int(sessionID), userId, "", ctx2)
|
||||
if err != nil {
|
||||
log.Println("update session context error:", err)
|
||||
}
|
||||
//保存的会话
|
||||
delete(service.SessionContextMap, int(sessionID))
|
||||
//客户端断开连接
|
||||
return
|
||||
//case t := <-ticker.C:
|
||||
|
|
@ -188,4 +200,5 @@ func subscribeAndHandleIMMessages(ws *websocket.Conn, userId int) {
|
|||
//}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
|||
|
|
@ -70,7 +70,7 @@ func UpdateSession(c *gin.Context) {
|
|||
var req Session
|
||||
if err := c.ShouldBind(&req); err == nil {
|
||||
// 修改会话
|
||||
err2 := service.UpdateSessionByID(req.ID, userID, req.Name)
|
||||
err2 := service.UpdateSessionByID(req.ID, userID, req.Name, nil)
|
||||
if err2 == nil {
|
||||
c.JSON(http.StatusOK, gin.H{"code": proto.SuccessCode, "message": "success"})
|
||||
} else {
|
||||
|
|
|
|||
|
|
@ -43,8 +43,8 @@ func DeleteSessionByID(id, userID int) error {
|
|||
}
|
||||
|
||||
// 更新会话名字
|
||||
func UpdateSessionByID(id int, userId int, name string) error {
|
||||
err := dao.UpdateSessionByID(id, userId, name)
|
||||
func UpdateSessionByID(id int, userId int, name string, context []int) error {
|
||||
err := dao.UpdateSessionByID(id, userId, name, context)
|
||||
return err
|
||||
}
|
||||
|
||||
|
|
@ -77,7 +77,7 @@ func CreateGeneralMessageService(fromID, toID, msgType, sessionID int, msg strin
|
|||
}
|
||||
|
||||
var client *api.Client
|
||||
var actxMap map[int][]int
|
||||
var SessionContextMap map[int][]int
|
||||
|
||||
func WSReceiveMessageService(userID, sessionID int, channel string, msg proto.WSMessageReq) (error, uint) {
|
||||
var resErr error
|
||||
|
|
@ -92,7 +92,7 @@ func WSReceiveMessageService(userID, sessionID int, channel string, msg proto.WS
|
|||
log.Println("get ollama client error:", err)
|
||||
}
|
||||
prompt := msg.Msg
|
||||
actx, ok := actxMap[sessionID]
|
||||
actx, ok := SessionContextMap[sessionID]
|
||||
if ok {
|
||||
log.Println("actx has get will use:", len(actx))
|
||||
}
|
||||
|
|
@ -111,11 +111,11 @@ func WSReceiveMessageService(userID, sessionID int, channel string, msg proto.WS
|
|||
robotMsg += resp.Response
|
||||
if resp.Done { //该消息完成
|
||||
actx = resp.Context
|
||||
if actxMap == nil {
|
||||
actxMap = make(map[int][]int)
|
||||
if SessionContextMap == nil {
|
||||
SessionContextMap = make(map[int][]int)
|
||||
}
|
||||
actxMap[sessionID] = actx
|
||||
log.Println("actx has save:", len(actxMap[sessionID]))
|
||||
SessionContextMap[sessionID] = actx
|
||||
log.Println("actx has save:", len(SessionContextMap[sessionID]))
|
||||
//创建消息
|
||||
err2, msgID := CreateMessage(2, sessionID, 2, userID, robotMsg, 1) //机器人id为2,消息类型为2,状态为1
|
||||
if err2 != nil {
|
||||
|
|
|
|||
Loading…
Reference in New Issue