添加国外服务器中转ai请求数据,添加gemini

This commit is contained in:
junleea 2025-05-13 13:08:23 +08:00
parent 4a7c7efe19
commit 5188f9c3c6
3 changed files with 127 additions and 10 deletions

View File

@ -120,10 +120,10 @@ type UserFileMessage struct {
}
type SendMessageForeignAIRRequest struct {
ModelParam ModelParam `json:"model_param" form:"model_param"` //模型参数
IMParamContext IMParamContext `json:"im_param_context" form:"im_param_context"` //上下文参数
ChatCompletionMessageParamUnion []openai.ChatCompletionMessageParamUnion `json:"chat_completion_message_param_union" form:"chat_completion_message_param_union"` //消息内容
PlatForm string `json:"platform" form:"platform"` //平台名称
ModelParam ModelParam `json:"model_param" form:"model_param"` //模型参数
IMParamContext IMParamContext `json:"im_param_context" form:"im_param_context"` //上下文参数
MessageContent []MessageContent `json:"message_content" form:"message_content"` //消息内容
PlatForm string `json:"platform" form:"platform"` //平台名称
}
type ReceiveMessageForeignAIRRequest struct {
@ -132,3 +132,9 @@ type ReceiveMessageForeignAIRRequest struct {
PlatForm string `json:"platform" form:"platform"` //平台名称
ChatCompletionChunk openai.ChatCompletionChunk `json:"chat_completion_chunk" form:"chat_completion_chunk"` //消息内容
}
type MessageContent struct {
Role string `json:"role"`
Content string `json:"content"`
ContentType string `json:"content_type,omitempty"`
}

View File

@ -3,6 +3,7 @@ package generalModel
import (
"StuAcaWorksAI/dao"
"StuAcaWorksAI/proto"
"StuAcaWorksAI/service"
"StuAcaWorksAI/worker"
"context"
"encoding/json"
@ -188,7 +189,7 @@ func SingleTurnConversationText(prompt string, param proto.ModelParam) (string,
}
// 只进行与模型信息交互不进行保存,返回消息给国内服务器
func OpenAIForeignAI(modelParam proto.ModelParam, imCtx *proto.IMParamContext, platform string, messages []openai.ChatCompletionMessageParamUnion) {
func OpenAIForeignAI(modelParam proto.ModelParam, imCtx *proto.IMParamContext, platform string, originMessages []proto.MessageContent) {
apiKey := modelParam.APIKey
if apiKey == "" {
log.Printf("openai_%s apiKey is empty\n", platform)
@ -204,6 +205,20 @@ func OpenAIForeignAI(modelParam proto.ModelParam, imCtx *proto.IMParamContext, p
// log.Println("GetOpenAIHistoryMessages error:", err)
// return
//}
var messages []openai.ChatCompletionMessageParamUnion
for _, msg := range originMessages {
switch msg.Role {
case "user":
messages = append(messages, openai.UserMessage(msg.Content))
case "assistant":
messages = append(messages, openai.AssistantMessage(msg.Content))
case "system":
messages = append(messages, openai.SystemMessage(msg.Content))
default:
messages = append(messages, openai.UserMessage(msg.Content))
}
}
if proto.Config.LOG_OUTPUT {
messagesBytes, _ := json.Marshal(messages)
log.Println("openai messages:", string(messagesBytes))
@ -315,13 +330,13 @@ func SendOpenAIForeignAI(modelParam proto.ModelParam, imCtx *proto.IMParamContex
req.IMParamContext = *imCtx
req.ModelParam = modelParam
req.PlatForm = platform
var messages []openai.ChatCompletionMessageParamUnion
err1 := GetOpenAIHistoryMessages(imCtx.SessionID, modelParam.System, &messages, imCtx.UserID)
var messages []proto.MessageContent
err1 := GetGeneralSessionHistoryMsg(imCtx.SessionID, modelParam.System, &messages, imCtx.UserID)
if err1 != nil {
log.Println("GetOpenAIHistoryMessages error:", err1)
return
}
req.ChatCompletionMessageParamUnion = messages
req.MessageContent = messages
reqBytes, err := json.Marshal(req)
if err != nil {
log.Println("send openai foreign ai request error:", err)
@ -354,3 +369,47 @@ func SendOpenAIForeignAI(modelParam proto.ModelParam, imCtx *proto.IMParamContex
return
}
}
// 获取spark模型会话的历史消息
func GetGeneralSessionHistoryMsg(sessionID int, systemPrompt string, messages *[]proto.MessageContent, userID int) error {
//获取会话的历史消息
msgs := dao.FindMessageBySessionIDV2(sessionID)
if len(msgs) > proto.SparkContextLength {
//只取最后的 SparkContextLength 条消息
msgs = msgs[len(msgs)-proto.SparkContextLength-1:]
}
//添加系统消息
if systemPrompt != "" {
var message proto.MessageContent
message.Role = "system"
message.Content = systemPrompt
*messages = append(*messages, message)
}
for _, msg := range msgs {
var message proto.MessageContent
if msg.Type == proto.UserToModelMsgType {
message.Role = proto.SparkRoleUser
if msg.Status == proto.UserToModelFileMsgType {
err2 := service.DecodeFileMessageToGeneralMessage(msg.Msg, userID, messages)
if err2 != nil {
log.Println("spark decode file message error:", err2)
}
} else {
message.Content = msg.Msg
*messages = append(*messages, message)
}
} else if msg.Type == proto.ModelToUserMsgType {
message.Role = proto.SparkRoleAssistant
message.Content = msg.Msg
*messages = append(*messages, message)
} else {
continue
}
}
//添加本次请求消息(上面获取到的已包含本次请求消息)
//var message MessageContent
//message.Role = "user"
//message.Content = prompt
//*messages = append(*messages, message)
return nil
}

View File

@ -465,7 +465,7 @@ func GetSparkSessionHistoryMsgV2(sessionID int, systemPrompt string, messages *[
if msg.Type == proto.UserToModelMsgType {
message.Role = proto.SparkRoleUser
if msg.Status == proto.UserToModelFileMsgType {
err2 := decodeFileMessageToSparkMessage(msg.Msg, userID, messages)
err2 := DecodeFileMessageToSparkMessage(msg.Msg, userID, messages)
if err2 != nil {
log.Println("spark decode file message error:", err2)
}
@ -489,7 +489,7 @@ func GetSparkSessionHistoryMsgV2(sessionID int, systemPrompt string, messages *[
return nil
}
func decodeFileMessageToSparkMessage(msg string, userID int, message *[]MessageContent) error {
func DecodeFileMessageToSparkMessage(msg string, userID int, message *[]MessageContent) error {
var fileMsgContent proto.UserFileMessage
err := json.Unmarshal([]byte(msg), &fileMsgContent)
if err != nil {
@ -540,3 +540,55 @@ func decodeFileMessageToSparkMessage(msg string, userID int, message *[]MessageC
*message = append(*message, messageContent)
return nil
}
func DecodeFileMessageToGeneralMessage(msg string, userID int, message *[]proto.MessageContent) error {
var fileMsgContent proto.UserFileMessage
err := json.Unmarshal([]byte(msg), &fileMsgContent)
if err != nil {
log.Println("general unmarshal file message error:", err)
return err
}
var isRemoveSystem bool
log.Println("general,decode file message:", fileMsgContent)
for _, file := range fileMsgContent.FileContent {
//文件消息为文本类型,则获取文本内容,若为图片则添加图片链接
if file.FileType == proto.FileTypeText {
fileContents, err2 := FindFileContentV2(file.FileContent.FileID, userID)
if err2 != nil {
log.Println("doubao find file content error:", err2, "\t,file id:", file.FileContent.FileID, "\t,file name:", file.FileContent.UserFileName)
continue
}
fileContent := fileContents[0].FileContent
//文件文本内容
var messageContent proto.MessageContent
messageContent.Role = "user"
messageContent.Content = fileContent
messageContent.ContentType = "text"
*message = append(*message, messageContent)
} else if file.FileType == proto.FileTypeImage {
imgBase64, err2 := FindImageFileContent(file.FileContent.FileID, userID)
if err2 != nil {
continue
}
//图片文件内容
var messageContent proto.MessageContent
messageContent.Role = "user"
messageContent.Content = imgBase64
messageContent.ContentType = "image"
*message = append(*message, messageContent)
if !isRemoveSystem {
//移除*message数组首个元素
*message = (*message)[1:]
isRemoveSystem = true
}
} else {
log.Println("file type is not support:", file.FileType, "file name:", file.FileContent.UserFileName, "file id:", file.FileContent.FileID)
}
}
var messageContent proto.MessageContent
messageContent.Role = "user"
messageContent.Content = fileMsgContent.Text
messageContent.ContentType = "text"
*message = append(*message, messageContent)
return nil
}