diff --git a/dao/im.go b/dao/im.go index 50bbb68..84d9b7e 100644 --- a/dao/im.go +++ b/dao/im.go @@ -16,12 +16,12 @@ type Session struct { type Message struct { gorm.Model - Type int `gorm:"column:type"` //1用户之间的消息,2为与模型消息 - SessionID int `gorm:"column:session_id"` - FromID int `gorm:"column:from_id"` //发送者,可以为用户或者模型 - ToID int `gorm:"column:to_id"` //接收者,可以为用户或者模型,如果为模型,则为模型id,根据type判断 - Msg string `gorm:"column:msg"` //消息内容 - Status int `gorm:"column:status"` //0为未读,1为已读 + Type int `gorm:"column:type"` //1用户之间的消息,2为与模型消息 + SessionID int `gorm:"column:session_id;index"` //添加索引 + FromID int `gorm:"column:from_id"` //发送者,可以为用户或者模型 + ToID int `gorm:"column:to_id"` //接收者,可以为用户或者模型,如果为模型,则为模型id,根据type判断 + Msg string `gorm:"column:msg"` //消息内容 + Status int `gorm:"column:status"` //0为未读,1为已读 } func CreateSession(userID int, name string) (error, uint) { @@ -106,6 +106,16 @@ func FindMessageBySessionID(sessionID int) []Message { return messages } +func FindMessageBySessionIDV2(sessionID int) []Message { + var messages []Message + if proto.Config.SERVER_SQL_LOG { + DB.Debug().Where("session_id = ?", sessionID).Find(&messages) + } else { + DB.Where("session_id = ?", sessionID).Find(&messages) + } + return messages +} + type SessionMessageCount struct { SessionID int `gorm:"column:session_id"` Count int `gorm:"column:count"` diff --git a/handler/im.go b/handler/im.go index 10c2510..9d76a1c 100644 --- a/handler/im.go +++ b/handler/im.go @@ -211,7 +211,7 @@ func doReceiveGenChatMessage(userId int, sessionID *uint, data *proto.WSMessageR var userMsgID uint //将消息存入数据库 - err, userMsgID = service.CreateMessage(proto.UserAndModelMsgType, int(*sessionID), userId, int(model.ID), data.Msg, proto.MsgHasRead) + err, userMsgID = service.CreateMessage(proto.UserToModelMsgType, int(*sessionID), userId, int(model.ID), data.Msg, proto.MsgHasRead) log.Println("create user message id:", userMsgID) if err != nil { return err diff --git a/proto/status.go b/proto/status.go index bc6341f..c87ef48 100644 --- a/proto/status.go +++ b/proto/status.go @@ -94,6 +94,11 @@ const ( // 代码中使用常量定义 UserAndModelMsgType = 2 // 用户与模型消息类型 MsgHasRead = 1 // 消息已读 + + //用户发到模型 + UserToModelMsgType = 3 + //模型发到用户 + ModelToUserMsgType = 4 ) // 豆包返回的数据停止原因 diff --git a/service/doubao.go b/service/doubao.go index 2b99ff0..5214aa5 100644 --- a/service/doubao.go +++ b/service/doubao.go @@ -23,6 +23,12 @@ func doubao(modelParam proto.ModelParam, question, channel string, SessionID, Us arkruntime.WithBaseUrl(modelParam.Url), ) ctx := context.Background() + messages := make([]model.ChatCompletionMessage, 0) + err := GetDouBaoSessionHistoryMsg(SessionID, "", &messages, question) + if err != nil { + log.Println("get doubao session history message error:", err) + } + log.Println("doubao messages:", messages) req := model.CreateChatCompletionRequest{ Model: modelParam.Model, @@ -76,7 +82,7 @@ func doubao(modelParam proto.ModelParam, question, channel string, SessionID, Us } } //将消息存入数据库 - err, _ = dao.CreateMessage(SessionID, modelID, UserID, proto.UserAndModelMsgType, proto.MsgHasRead, answer) + err, _ = dao.CreateMessage(SessionID, modelID, UserID, proto.ModelToUserMsgType, proto.MsgHasRead, answer) if err != nil { log.Println("spark create message error:", err) } @@ -108,3 +114,37 @@ func doubaoToGeneralMassageAndSendMsgQueue(data *model.ChatCompletionStreamRespo wsMsgStr, _ := json.Marshal(wsMsg) worker.Publish(channel, string(wsMsgStr), time.Second*60) } + +func GetDouBaoSessionHistoryMsg(sessionID int, systemPrompt string, messages *[]model.ChatCompletionMessage, prompt string) error { + var err error + //添加系统 + if systemPrompt != "" { + var message model.ChatCompletionMessage + message.Role = model.ChatMessageRoleSystem + message.Content = &model.ChatCompletionMessageContent{ + StringValue: volcengine.String(systemPrompt), + } + *messages = append(*messages, message) + } + //获取该会话的历史消息 + msgs := dao.FindMessageBySessionIDV2(sessionID) + for _, v := range msgs { + var message model.ChatCompletionMessage + if v.Type == proto.UserToModelMsgType { + message.Role = model.ChatMessageRoleUser + } else { + message.Role = model.ChatMessageRoleAssistant + } + message.Content = &model.ChatCompletionMessageContent{ + StringValue: volcengine.String(v.Msg), + } + *messages = append(*messages, message) + } + //添加本次请求消息 + var message model.ChatCompletionMessage + message.Role = model.ChatMessageRoleUser + message.Content = &model.ChatCompletionMessageContent{ + StringValue: volcengine.String(prompt), + } + return err +} diff --git a/service/spark.go b/service/spark.go index 6b835da..cc4862b 100644 --- a/service/spark.go +++ b/service/spark.go @@ -164,7 +164,7 @@ func Spark(modelParam proto.ModelParam, question, channel string, SessionID, Use } else { answer += data.Payload.Choices.Text[0].Content } - go CreateAIStreamMsg(UserID, 4, SessionID, data, modelParam.Model) + go CreateAIStreamMsg(UserID, modelID, SessionID, data, modelParam.Model) SparkToGeneralMassageAndSendMsgQueue(&data, channel, SessionID, UserID) }