消息表添加会话ID索引、消息类型区分,豆包模型测试历史消息
This commit is contained in:
parent
c91dcaf57c
commit
370caca0ae
22
dao/im.go
22
dao/im.go
|
|
@ -16,12 +16,12 @@ type Session struct {
|
|||
|
||||
type Message struct {
|
||||
gorm.Model
|
||||
Type int `gorm:"column:type"` //1用户之间的消息,2为与模型消息
|
||||
SessionID int `gorm:"column:session_id"`
|
||||
FromID int `gorm:"column:from_id"` //发送者,可以为用户或者模型
|
||||
ToID int `gorm:"column:to_id"` //接收者,可以为用户或者模型,如果为模型,则为模型id,根据type判断
|
||||
Msg string `gorm:"column:msg"` //消息内容
|
||||
Status int `gorm:"column:status"` //0为未读,1为已读
|
||||
Type int `gorm:"column:type"` //1用户之间的消息,2为与模型消息
|
||||
SessionID int `gorm:"column:session_id;index"` //添加索引
|
||||
FromID int `gorm:"column:from_id"` //发送者,可以为用户或者模型
|
||||
ToID int `gorm:"column:to_id"` //接收者,可以为用户或者模型,如果为模型,则为模型id,根据type判断
|
||||
Msg string `gorm:"column:msg"` //消息内容
|
||||
Status int `gorm:"column:status"` //0为未读,1为已读
|
||||
}
|
||||
|
||||
func CreateSession(userID int, name string) (error, uint) {
|
||||
|
|
@ -106,6 +106,16 @@ func FindMessageBySessionID(sessionID int) []Message {
|
|||
return messages
|
||||
}
|
||||
|
||||
func FindMessageBySessionIDV2(sessionID int) []Message {
|
||||
var messages []Message
|
||||
if proto.Config.SERVER_SQL_LOG {
|
||||
DB.Debug().Where("session_id = ?", sessionID).Find(&messages)
|
||||
} else {
|
||||
DB.Where("session_id = ?", sessionID).Find(&messages)
|
||||
}
|
||||
return messages
|
||||
}
|
||||
|
||||
type SessionMessageCount struct {
|
||||
SessionID int `gorm:"column:session_id"`
|
||||
Count int `gorm:"column:count"`
|
||||
|
|
|
|||
|
|
@ -211,7 +211,7 @@ func doReceiveGenChatMessage(userId int, sessionID *uint, data *proto.WSMessageR
|
|||
|
||||
var userMsgID uint
|
||||
//将消息存入数据库
|
||||
err, userMsgID = service.CreateMessage(proto.UserAndModelMsgType, int(*sessionID), userId, int(model.ID), data.Msg, proto.MsgHasRead)
|
||||
err, userMsgID = service.CreateMessage(proto.UserToModelMsgType, int(*sessionID), userId, int(model.ID), data.Msg, proto.MsgHasRead)
|
||||
log.Println("create user message id:", userMsgID)
|
||||
if err != nil {
|
||||
return err
|
||||
|
|
|
|||
|
|
@ -94,6 +94,11 @@ const (
|
|||
// 代码中使用常量定义
|
||||
UserAndModelMsgType = 2 // 用户与模型消息类型
|
||||
MsgHasRead = 1 // 消息已读
|
||||
|
||||
//用户发到模型
|
||||
UserToModelMsgType = 3
|
||||
//模型发到用户
|
||||
ModelToUserMsgType = 4
|
||||
)
|
||||
|
||||
// 豆包返回的数据停止原因
|
||||
|
|
|
|||
|
|
@ -23,6 +23,12 @@ func doubao(modelParam proto.ModelParam, question, channel string, SessionID, Us
|
|||
arkruntime.WithBaseUrl(modelParam.Url),
|
||||
)
|
||||
ctx := context.Background()
|
||||
messages := make([]model.ChatCompletionMessage, 0)
|
||||
err := GetDouBaoSessionHistoryMsg(SessionID, "", &messages, question)
|
||||
if err != nil {
|
||||
log.Println("get doubao session history message error:", err)
|
||||
}
|
||||
log.Println("doubao messages:", messages)
|
||||
|
||||
req := model.CreateChatCompletionRequest{
|
||||
Model: modelParam.Model,
|
||||
|
|
@ -76,7 +82,7 @@ func doubao(modelParam proto.ModelParam, question, channel string, SessionID, Us
|
|||
}
|
||||
}
|
||||
//将消息存入数据库
|
||||
err, _ = dao.CreateMessage(SessionID, modelID, UserID, proto.UserAndModelMsgType, proto.MsgHasRead, answer)
|
||||
err, _ = dao.CreateMessage(SessionID, modelID, UserID, proto.ModelToUserMsgType, proto.MsgHasRead, answer)
|
||||
if err != nil {
|
||||
log.Println("spark create message error:", err)
|
||||
}
|
||||
|
|
@ -108,3 +114,37 @@ func doubaoToGeneralMassageAndSendMsgQueue(data *model.ChatCompletionStreamRespo
|
|||
wsMsgStr, _ := json.Marshal(wsMsg)
|
||||
worker.Publish(channel, string(wsMsgStr), time.Second*60)
|
||||
}
|
||||
|
||||
func GetDouBaoSessionHistoryMsg(sessionID int, systemPrompt string, messages *[]model.ChatCompletionMessage, prompt string) error {
|
||||
var err error
|
||||
//添加系统
|
||||
if systemPrompt != "" {
|
||||
var message model.ChatCompletionMessage
|
||||
message.Role = model.ChatMessageRoleSystem
|
||||
message.Content = &model.ChatCompletionMessageContent{
|
||||
StringValue: volcengine.String(systemPrompt),
|
||||
}
|
||||
*messages = append(*messages, message)
|
||||
}
|
||||
//获取该会话的历史消息
|
||||
msgs := dao.FindMessageBySessionIDV2(sessionID)
|
||||
for _, v := range msgs {
|
||||
var message model.ChatCompletionMessage
|
||||
if v.Type == proto.UserToModelMsgType {
|
||||
message.Role = model.ChatMessageRoleUser
|
||||
} else {
|
||||
message.Role = model.ChatMessageRoleAssistant
|
||||
}
|
||||
message.Content = &model.ChatCompletionMessageContent{
|
||||
StringValue: volcengine.String(v.Msg),
|
||||
}
|
||||
*messages = append(*messages, message)
|
||||
}
|
||||
//添加本次请求消息
|
||||
var message model.ChatCompletionMessage
|
||||
message.Role = model.ChatMessageRoleUser
|
||||
message.Content = &model.ChatCompletionMessageContent{
|
||||
StringValue: volcengine.String(prompt),
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
|
|
|||
|
|
@ -164,7 +164,7 @@ func Spark(modelParam proto.ModelParam, question, channel string, SessionID, Use
|
|||
} else {
|
||||
answer += data.Payload.Choices.Text[0].Content
|
||||
}
|
||||
go CreateAIStreamMsg(UserID, 4, SessionID, data, modelParam.Model)
|
||||
go CreateAIStreamMsg(UserID, modelID, SessionID, data, modelParam.Model)
|
||||
SparkToGeneralMassageAndSendMsgQueue(&data, channel, SessionID, UserID)
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue