消息表添加会话ID索引、消息类型区分,豆包模型测试历史消息
This commit is contained in:
parent
c91dcaf57c
commit
370caca0ae
22
dao/im.go
22
dao/im.go
|
|
@ -16,12 +16,12 @@ type Session struct {
|
||||||
|
|
||||||
type Message struct {
|
type Message struct {
|
||||||
gorm.Model
|
gorm.Model
|
||||||
Type int `gorm:"column:type"` //1用户之间的消息,2为与模型消息
|
Type int `gorm:"column:type"` //1用户之间的消息,2为与模型消息
|
||||||
SessionID int `gorm:"column:session_id"`
|
SessionID int `gorm:"column:session_id;index"` //添加索引
|
||||||
FromID int `gorm:"column:from_id"` //发送者,可以为用户或者模型
|
FromID int `gorm:"column:from_id"` //发送者,可以为用户或者模型
|
||||||
ToID int `gorm:"column:to_id"` //接收者,可以为用户或者模型,如果为模型,则为模型id,根据type判断
|
ToID int `gorm:"column:to_id"` //接收者,可以为用户或者模型,如果为模型,则为模型id,根据type判断
|
||||||
Msg string `gorm:"column:msg"` //消息内容
|
Msg string `gorm:"column:msg"` //消息内容
|
||||||
Status int `gorm:"column:status"` //0为未读,1为已读
|
Status int `gorm:"column:status"` //0为未读,1为已读
|
||||||
}
|
}
|
||||||
|
|
||||||
func CreateSession(userID int, name string) (error, uint) {
|
func CreateSession(userID int, name string) (error, uint) {
|
||||||
|
|
@ -106,6 +106,16 @@ func FindMessageBySessionID(sessionID int) []Message {
|
||||||
return messages
|
return messages
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func FindMessageBySessionIDV2(sessionID int) []Message {
|
||||||
|
var messages []Message
|
||||||
|
if proto.Config.SERVER_SQL_LOG {
|
||||||
|
DB.Debug().Where("session_id = ?", sessionID).Find(&messages)
|
||||||
|
} else {
|
||||||
|
DB.Where("session_id = ?", sessionID).Find(&messages)
|
||||||
|
}
|
||||||
|
return messages
|
||||||
|
}
|
||||||
|
|
||||||
type SessionMessageCount struct {
|
type SessionMessageCount struct {
|
||||||
SessionID int `gorm:"column:session_id"`
|
SessionID int `gorm:"column:session_id"`
|
||||||
Count int `gorm:"column:count"`
|
Count int `gorm:"column:count"`
|
||||||
|
|
|
||||||
|
|
@ -211,7 +211,7 @@ func doReceiveGenChatMessage(userId int, sessionID *uint, data *proto.WSMessageR
|
||||||
|
|
||||||
var userMsgID uint
|
var userMsgID uint
|
||||||
//将消息存入数据库
|
//将消息存入数据库
|
||||||
err, userMsgID = service.CreateMessage(proto.UserAndModelMsgType, int(*sessionID), userId, int(model.ID), data.Msg, proto.MsgHasRead)
|
err, userMsgID = service.CreateMessage(proto.UserToModelMsgType, int(*sessionID), userId, int(model.ID), data.Msg, proto.MsgHasRead)
|
||||||
log.Println("create user message id:", userMsgID)
|
log.Println("create user message id:", userMsgID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
|
|
||||||
|
|
@ -94,6 +94,11 @@ const (
|
||||||
// 代码中使用常量定义
|
// 代码中使用常量定义
|
||||||
UserAndModelMsgType = 2 // 用户与模型消息类型
|
UserAndModelMsgType = 2 // 用户与模型消息类型
|
||||||
MsgHasRead = 1 // 消息已读
|
MsgHasRead = 1 // 消息已读
|
||||||
|
|
||||||
|
//用户发到模型
|
||||||
|
UserToModelMsgType = 3
|
||||||
|
//模型发到用户
|
||||||
|
ModelToUserMsgType = 4
|
||||||
)
|
)
|
||||||
|
|
||||||
// 豆包返回的数据停止原因
|
// 豆包返回的数据停止原因
|
||||||
|
|
|
||||||
|
|
@ -23,6 +23,12 @@ func doubao(modelParam proto.ModelParam, question, channel string, SessionID, Us
|
||||||
arkruntime.WithBaseUrl(modelParam.Url),
|
arkruntime.WithBaseUrl(modelParam.Url),
|
||||||
)
|
)
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
messages := make([]model.ChatCompletionMessage, 0)
|
||||||
|
err := GetDouBaoSessionHistoryMsg(SessionID, "", &messages, question)
|
||||||
|
if err != nil {
|
||||||
|
log.Println("get doubao session history message error:", err)
|
||||||
|
}
|
||||||
|
log.Println("doubao messages:", messages)
|
||||||
|
|
||||||
req := model.CreateChatCompletionRequest{
|
req := model.CreateChatCompletionRequest{
|
||||||
Model: modelParam.Model,
|
Model: modelParam.Model,
|
||||||
|
|
@ -76,7 +82,7 @@ func doubao(modelParam proto.ModelParam, question, channel string, SessionID, Us
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
//将消息存入数据库
|
//将消息存入数据库
|
||||||
err, _ = dao.CreateMessage(SessionID, modelID, UserID, proto.UserAndModelMsgType, proto.MsgHasRead, answer)
|
err, _ = dao.CreateMessage(SessionID, modelID, UserID, proto.ModelToUserMsgType, proto.MsgHasRead, answer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Println("spark create message error:", err)
|
log.Println("spark create message error:", err)
|
||||||
}
|
}
|
||||||
|
|
@ -108,3 +114,37 @@ func doubaoToGeneralMassageAndSendMsgQueue(data *model.ChatCompletionStreamRespo
|
||||||
wsMsgStr, _ := json.Marshal(wsMsg)
|
wsMsgStr, _ := json.Marshal(wsMsg)
|
||||||
worker.Publish(channel, string(wsMsgStr), time.Second*60)
|
worker.Publish(channel, string(wsMsgStr), time.Second*60)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func GetDouBaoSessionHistoryMsg(sessionID int, systemPrompt string, messages *[]model.ChatCompletionMessage, prompt string) error {
|
||||||
|
var err error
|
||||||
|
//添加系统
|
||||||
|
if systemPrompt != "" {
|
||||||
|
var message model.ChatCompletionMessage
|
||||||
|
message.Role = model.ChatMessageRoleSystem
|
||||||
|
message.Content = &model.ChatCompletionMessageContent{
|
||||||
|
StringValue: volcengine.String(systemPrompt),
|
||||||
|
}
|
||||||
|
*messages = append(*messages, message)
|
||||||
|
}
|
||||||
|
//获取该会话的历史消息
|
||||||
|
msgs := dao.FindMessageBySessionIDV2(sessionID)
|
||||||
|
for _, v := range msgs {
|
||||||
|
var message model.ChatCompletionMessage
|
||||||
|
if v.Type == proto.UserToModelMsgType {
|
||||||
|
message.Role = model.ChatMessageRoleUser
|
||||||
|
} else {
|
||||||
|
message.Role = model.ChatMessageRoleAssistant
|
||||||
|
}
|
||||||
|
message.Content = &model.ChatCompletionMessageContent{
|
||||||
|
StringValue: volcengine.String(v.Msg),
|
||||||
|
}
|
||||||
|
*messages = append(*messages, message)
|
||||||
|
}
|
||||||
|
//添加本次请求消息
|
||||||
|
var message model.ChatCompletionMessage
|
||||||
|
message.Role = model.ChatMessageRoleUser
|
||||||
|
message.Content = &model.ChatCompletionMessageContent{
|
||||||
|
StringValue: volcengine.String(prompt),
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -164,7 +164,7 @@ func Spark(modelParam proto.ModelParam, question, channel string, SessionID, Use
|
||||||
} else {
|
} else {
|
||||||
answer += data.Payload.Choices.Text[0].Content
|
answer += data.Payload.Choices.Text[0].Content
|
||||||
}
|
}
|
||||||
go CreateAIStreamMsg(UserID, 4, SessionID, data, modelParam.Model)
|
go CreateAIStreamMsg(UserID, modelID, SessionID, data, modelParam.Model)
|
||||||
SparkToGeneralMassageAndSendMsgQueue(&data, channel, SessionID, UserID)
|
SparkToGeneralMassageAndSendMsgQueue(&data, channel, SessionID, UserID)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue