From 5188f9c3c6bd47d74babb360bddd835a6f77a429 Mon Sep 17 00:00:00 2001 From: junleea <354425203@qq.com> Date: Tue, 13 May 2025 13:08:23 +0800 Subject: [PATCH] =?UTF-8?q?=E6=B7=BB=E5=8A=A0=E5=9B=BD=E5=A4=96=E6=9C=8D?= =?UTF-8?q?=E5=8A=A1=E5=99=A8=E4=B8=AD=E8=BD=ACai=E8=AF=B7=E6=B1=82?= =?UTF-8?q?=E6=95=B0=E6=8D=AE,=E6=B7=BB=E5=8A=A0gemini?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- proto/im.go | 14 +++++-- service/generalModel/openai.go | 67 ++++++++++++++++++++++++++++++++-- service/spark.go | 56 +++++++++++++++++++++++++++- 3 files changed, 127 insertions(+), 10 deletions(-) diff --git a/proto/im.go b/proto/im.go index 664213a..b32f6a6 100644 --- a/proto/im.go +++ b/proto/im.go @@ -120,10 +120,10 @@ type UserFileMessage struct { } type SendMessageForeignAIRRequest struct { - ModelParam ModelParam `json:"model_param" form:"model_param"` //模型参数 - IMParamContext IMParamContext `json:"im_param_context" form:"im_param_context"` //上下文参数 - ChatCompletionMessageParamUnion []openai.ChatCompletionMessageParamUnion `json:"chat_completion_message_param_union" form:"chat_completion_message_param_union"` //消息内容 - PlatForm string `json:"platform" form:"platform"` //平台名称 + ModelParam ModelParam `json:"model_param" form:"model_param"` //模型参数 + IMParamContext IMParamContext `json:"im_param_context" form:"im_param_context"` //上下文参数 + MessageContent []MessageContent `json:"message_content" form:"message_content"` //消息内容 + PlatForm string `json:"platform" form:"platform"` //平台名称 } type ReceiveMessageForeignAIRRequest struct { @@ -132,3 +132,9 @@ type ReceiveMessageForeignAIRRequest struct { PlatForm string `json:"platform" form:"platform"` //平台名称 ChatCompletionChunk openai.ChatCompletionChunk `json:"chat_completion_chunk" form:"chat_completion_chunk"` //消息内容 } + +type MessageContent struct { + Role string `json:"role"` + Content string `json:"content"` + ContentType string `json:"content_type,omitempty"` +} diff --git a/service/generalModel/openai.go b/service/generalModel/openai.go index 947e580..f79d46b 100644 --- a/service/generalModel/openai.go +++ b/service/generalModel/openai.go @@ -3,6 +3,7 @@ package generalModel import ( "StuAcaWorksAI/dao" "StuAcaWorksAI/proto" + "StuAcaWorksAI/service" "StuAcaWorksAI/worker" "context" "encoding/json" @@ -188,7 +189,7 @@ func SingleTurnConversationText(prompt string, param proto.ModelParam) (string, } // 只进行与模型信息交互不进行保存,返回消息给国内服务器 -func OpenAIForeignAI(modelParam proto.ModelParam, imCtx *proto.IMParamContext, platform string, messages []openai.ChatCompletionMessageParamUnion) { +func OpenAIForeignAI(modelParam proto.ModelParam, imCtx *proto.IMParamContext, platform string, originMessages []proto.MessageContent) { apiKey := modelParam.APIKey if apiKey == "" { log.Printf("openai_%s apiKey is empty\n", platform) @@ -204,6 +205,20 @@ func OpenAIForeignAI(modelParam proto.ModelParam, imCtx *proto.IMParamContext, p // log.Println("GetOpenAIHistoryMessages error:", err) // return //} + + var messages []openai.ChatCompletionMessageParamUnion + for _, msg := range originMessages { + switch msg.Role { + case "user": + messages = append(messages, openai.UserMessage(msg.Content)) + case "assistant": + messages = append(messages, openai.AssistantMessage(msg.Content)) + case "system": + messages = append(messages, openai.SystemMessage(msg.Content)) + default: + messages = append(messages, openai.UserMessage(msg.Content)) + } + } if proto.Config.LOG_OUTPUT { messagesBytes, _ := json.Marshal(messages) log.Println("openai messages:", string(messagesBytes)) @@ -315,13 +330,13 @@ func SendOpenAIForeignAI(modelParam proto.ModelParam, imCtx *proto.IMParamContex req.IMParamContext = *imCtx req.ModelParam = modelParam req.PlatForm = platform - var messages []openai.ChatCompletionMessageParamUnion - err1 := GetOpenAIHistoryMessages(imCtx.SessionID, modelParam.System, &messages, imCtx.UserID) + var messages []proto.MessageContent + err1 := GetGeneralSessionHistoryMsg(imCtx.SessionID, modelParam.System, &messages, imCtx.UserID) if err1 != nil { log.Println("GetOpenAIHistoryMessages error:", err1) return } - req.ChatCompletionMessageParamUnion = messages + req.MessageContent = messages reqBytes, err := json.Marshal(req) if err != nil { log.Println("send openai foreign ai request error:", err) @@ -354,3 +369,47 @@ func SendOpenAIForeignAI(modelParam proto.ModelParam, imCtx *proto.IMParamContex return } } + +// 获取spark模型会话的历史消息 +func GetGeneralSessionHistoryMsg(sessionID int, systemPrompt string, messages *[]proto.MessageContent, userID int) error { + //获取会话的历史消息 + msgs := dao.FindMessageBySessionIDV2(sessionID) + if len(msgs) > proto.SparkContextLength { + //只取最后的 SparkContextLength 条消息 + msgs = msgs[len(msgs)-proto.SparkContextLength-1:] + } + //添加系统消息 + if systemPrompt != "" { + var message proto.MessageContent + message.Role = "system" + message.Content = systemPrompt + *messages = append(*messages, message) + } + for _, msg := range msgs { + var message proto.MessageContent + if msg.Type == proto.UserToModelMsgType { + message.Role = proto.SparkRoleUser + if msg.Status == proto.UserToModelFileMsgType { + err2 := service.DecodeFileMessageToGeneralMessage(msg.Msg, userID, messages) + if err2 != nil { + log.Println("spark decode file message error:", err2) + } + } else { + message.Content = msg.Msg + *messages = append(*messages, message) + } + } else if msg.Type == proto.ModelToUserMsgType { + message.Role = proto.SparkRoleAssistant + message.Content = msg.Msg + *messages = append(*messages, message) + } else { + continue + } + } + //添加本次请求消息(上面获取到的已包含本次请求消息) + //var message MessageContent + //message.Role = "user" + //message.Content = prompt + //*messages = append(*messages, message) + return nil +} diff --git a/service/spark.go b/service/spark.go index 0397b53..0a9488d 100644 --- a/service/spark.go +++ b/service/spark.go @@ -465,7 +465,7 @@ func GetSparkSessionHistoryMsgV2(sessionID int, systemPrompt string, messages *[ if msg.Type == proto.UserToModelMsgType { message.Role = proto.SparkRoleUser if msg.Status == proto.UserToModelFileMsgType { - err2 := decodeFileMessageToSparkMessage(msg.Msg, userID, messages) + err2 := DecodeFileMessageToSparkMessage(msg.Msg, userID, messages) if err2 != nil { log.Println("spark decode file message error:", err2) } @@ -489,7 +489,7 @@ func GetSparkSessionHistoryMsgV2(sessionID int, systemPrompt string, messages *[ return nil } -func decodeFileMessageToSparkMessage(msg string, userID int, message *[]MessageContent) error { +func DecodeFileMessageToSparkMessage(msg string, userID int, message *[]MessageContent) error { var fileMsgContent proto.UserFileMessage err := json.Unmarshal([]byte(msg), &fileMsgContent) if err != nil { @@ -540,3 +540,55 @@ func decodeFileMessageToSparkMessage(msg string, userID int, message *[]MessageC *message = append(*message, messageContent) return nil } + +func DecodeFileMessageToGeneralMessage(msg string, userID int, message *[]proto.MessageContent) error { + var fileMsgContent proto.UserFileMessage + err := json.Unmarshal([]byte(msg), &fileMsgContent) + if err != nil { + log.Println("general unmarshal file message error:", err) + return err + } + var isRemoveSystem bool + log.Println("general,decode file message:", fileMsgContent) + for _, file := range fileMsgContent.FileContent { + //文件消息为文本类型,则获取文本内容,若为图片则添加图片链接 + if file.FileType == proto.FileTypeText { + fileContents, err2 := FindFileContentV2(file.FileContent.FileID, userID) + if err2 != nil { + log.Println("doubao find file content error:", err2, "\t,file id:", file.FileContent.FileID, "\t,file name:", file.FileContent.UserFileName) + continue + } + fileContent := fileContents[0].FileContent + //文件文本内容 + var messageContent proto.MessageContent + messageContent.Role = "user" + messageContent.Content = fileContent + messageContent.ContentType = "text" + *message = append(*message, messageContent) + } else if file.FileType == proto.FileTypeImage { + imgBase64, err2 := FindImageFileContent(file.FileContent.FileID, userID) + if err2 != nil { + continue + } + //图片文件内容 + var messageContent proto.MessageContent + messageContent.Role = "user" + messageContent.Content = imgBase64 + messageContent.ContentType = "image" + *message = append(*message, messageContent) + if !isRemoveSystem { + //移除*message数组首个元素 + *message = (*message)[1:] + isRemoveSystem = true + } + } else { + log.Println("file type is not support:", file.FileType, "file name:", file.FileContent.UserFileName, "file id:", file.FileContent.FileID) + } + } + var messageContent proto.MessageContent + messageContent.Role = "user" + messageContent.Content = fileMsgContent.Text + messageContent.ContentType = "text" + *message = append(*message, messageContent) + return nil +}