添加会话上下文保存
This commit is contained in:
parent
08d134ac41
commit
142463a092
11
dao/im.go
11
dao/im.go
|
|
@ -7,8 +7,9 @@ import (
|
||||||
|
|
||||||
type Session struct {
|
type Session struct {
|
||||||
gorm.Model
|
gorm.Model
|
||||||
UserID int `gorm:"column:user_id"` //只能由用户创建
|
UserID int `gorm:"column:user_id"` //只能由用户创建
|
||||||
Name string `gorm:"column:name"`
|
Name string `gorm:"column:name"`
|
||||||
|
Context []int `gorm:"column:context"` //会话上下文
|
||||||
}
|
}
|
||||||
|
|
||||||
type Message struct {
|
type Message struct {
|
||||||
|
|
@ -54,12 +55,12 @@ func FindSessionByUserID(userID int) []Session {
|
||||||
}
|
}
|
||||||
|
|
||||||
// 更新会话的名字
|
// 更新会话的名字
|
||||||
func UpdateSessionByID(id int, userId int, name string) error {
|
func UpdateSessionByID(id int, userId int, name string, context []int) error {
|
||||||
var res *gorm.DB
|
var res *gorm.DB
|
||||||
if proto.Config.SERVER_SQL_LOG {
|
if proto.Config.SERVER_SQL_LOG {
|
||||||
res = DB.Debug().Model(&Session{}).Where("id = ? and user_id = ?", id, userId).Update("name", name)
|
res = DB.Debug().Model(&Session{}).Where("id = ? AND user_id = ?", id, userId).Updates(Session{Name: name, Context: context})
|
||||||
} else {
|
} else {
|
||||||
res = DB.Model(&Session{}).Where("id = ? and user_id = ?", id, userId).Update("name", name)
|
res = DB.Model(&Session{}).Where("id = ? AND user_id = ?", id, userId).Updates(Session{Name: name, Context: context})
|
||||||
}
|
}
|
||||||
return res.Error
|
return res.Error
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -176,6 +176,18 @@ func subscribeAndHandleIMMessages(ws *websocket.Conn, userId int) {
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-done:
|
case <-done:
|
||||||
|
//结束会话将context保存
|
||||||
|
ctx2, ok := service.SessionContextMap[int(sessionID)]
|
||||||
|
if !ok {
|
||||||
|
log.Println("session context not found")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err := service.UpdateSessionByID(int(sessionID), userId, "", ctx2)
|
||||||
|
if err != nil {
|
||||||
|
log.Println("update session context error:", err)
|
||||||
|
}
|
||||||
|
//保存的会话
|
||||||
|
delete(service.SessionContextMap, int(sessionID))
|
||||||
//客户端断开连接
|
//客户端断开连接
|
||||||
return
|
return
|
||||||
//case t := <-ticker.C:
|
//case t := <-ticker.C:
|
||||||
|
|
@ -188,4 +200,5 @@ func subscribeAndHandleIMMessages(ws *websocket.Conn, userId int) {
|
||||||
//}
|
//}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -70,7 +70,7 @@ func UpdateSession(c *gin.Context) {
|
||||||
var req Session
|
var req Session
|
||||||
if err := c.ShouldBind(&req); err == nil {
|
if err := c.ShouldBind(&req); err == nil {
|
||||||
// 修改会话
|
// 修改会话
|
||||||
err2 := service.UpdateSessionByID(req.ID, userID, req.Name)
|
err2 := service.UpdateSessionByID(req.ID, userID, req.Name, nil)
|
||||||
if err2 == nil {
|
if err2 == nil {
|
||||||
c.JSON(http.StatusOK, gin.H{"code": proto.SuccessCode, "message": "success"})
|
c.JSON(http.StatusOK, gin.H{"code": proto.SuccessCode, "message": "success"})
|
||||||
} else {
|
} else {
|
||||||
|
|
|
||||||
|
|
@ -43,8 +43,8 @@ func DeleteSessionByID(id, userID int) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// 更新会话名字
|
// 更新会话名字
|
||||||
func UpdateSessionByID(id int, userId int, name string) error {
|
func UpdateSessionByID(id int, userId int, name string, context []int) error {
|
||||||
err := dao.UpdateSessionByID(id, userId, name)
|
err := dao.UpdateSessionByID(id, userId, name, context)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -77,7 +77,7 @@ func CreateGeneralMessageService(fromID, toID, msgType, sessionID int, msg strin
|
||||||
}
|
}
|
||||||
|
|
||||||
var client *api.Client
|
var client *api.Client
|
||||||
var actxMap map[int][]int
|
var SessionContextMap map[int][]int
|
||||||
|
|
||||||
func WSReceiveMessageService(userID, sessionID int, channel string, msg proto.WSMessageReq) (error, uint) {
|
func WSReceiveMessageService(userID, sessionID int, channel string, msg proto.WSMessageReq) (error, uint) {
|
||||||
var resErr error
|
var resErr error
|
||||||
|
|
@ -92,7 +92,7 @@ func WSReceiveMessageService(userID, sessionID int, channel string, msg proto.WS
|
||||||
log.Println("get ollama client error:", err)
|
log.Println("get ollama client error:", err)
|
||||||
}
|
}
|
||||||
prompt := msg.Msg
|
prompt := msg.Msg
|
||||||
actx, ok := actxMap[sessionID]
|
actx, ok := SessionContextMap[sessionID]
|
||||||
if ok {
|
if ok {
|
||||||
log.Println("actx has get will use:", len(actx))
|
log.Println("actx has get will use:", len(actx))
|
||||||
}
|
}
|
||||||
|
|
@ -111,11 +111,11 @@ func WSReceiveMessageService(userID, sessionID int, channel string, msg proto.WS
|
||||||
robotMsg += resp.Response
|
robotMsg += resp.Response
|
||||||
if resp.Done { //该消息完成
|
if resp.Done { //该消息完成
|
||||||
actx = resp.Context
|
actx = resp.Context
|
||||||
if actxMap == nil {
|
if SessionContextMap == nil {
|
||||||
actxMap = make(map[int][]int)
|
SessionContextMap = make(map[int][]int)
|
||||||
}
|
}
|
||||||
actxMap[sessionID] = actx
|
SessionContextMap[sessionID] = actx
|
||||||
log.Println("actx has save:", len(actxMap[sessionID]))
|
log.Println("actx has save:", len(SessionContextMap[sessionID]))
|
||||||
//创建消息
|
//创建消息
|
||||||
err2, msgID := CreateMessage(2, sessionID, 2, userID, robotMsg, 1) //机器人id为2,消息类型为2,状态为1
|
err2, msgID := CreateMessage(2, sessionID, 2, userID, robotMsg, 1) //机器人id为2,消息类型为2,状态为1
|
||||||
if err2 != nil {
|
if err2 != nil {
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue