添加国外服务器中转ai请求数据,添加gemini
This commit is contained in:
parent
4a7c7efe19
commit
5188f9c3c6
|
|
@ -122,7 +122,7 @@ type UserFileMessage struct {
|
||||||
type SendMessageForeignAIRRequest struct {
|
type SendMessageForeignAIRRequest struct {
|
||||||
ModelParam ModelParam `json:"model_param" form:"model_param"` //模型参数
|
ModelParam ModelParam `json:"model_param" form:"model_param"` //模型参数
|
||||||
IMParamContext IMParamContext `json:"im_param_context" form:"im_param_context"` //上下文参数
|
IMParamContext IMParamContext `json:"im_param_context" form:"im_param_context"` //上下文参数
|
||||||
ChatCompletionMessageParamUnion []openai.ChatCompletionMessageParamUnion `json:"chat_completion_message_param_union" form:"chat_completion_message_param_union"` //消息内容
|
MessageContent []MessageContent `json:"message_content" form:"message_content"` //消息内容
|
||||||
PlatForm string `json:"platform" form:"platform"` //平台名称
|
PlatForm string `json:"platform" form:"platform"` //平台名称
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -132,3 +132,9 @@ type ReceiveMessageForeignAIRRequest struct {
|
||||||
PlatForm string `json:"platform" form:"platform"` //平台名称
|
PlatForm string `json:"platform" form:"platform"` //平台名称
|
||||||
ChatCompletionChunk openai.ChatCompletionChunk `json:"chat_completion_chunk" form:"chat_completion_chunk"` //消息内容
|
ChatCompletionChunk openai.ChatCompletionChunk `json:"chat_completion_chunk" form:"chat_completion_chunk"` //消息内容
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type MessageContent struct {
|
||||||
|
Role string `json:"role"`
|
||||||
|
Content string `json:"content"`
|
||||||
|
ContentType string `json:"content_type,omitempty"`
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -3,6 +3,7 @@ package generalModel
|
||||||
import (
|
import (
|
||||||
"StuAcaWorksAI/dao"
|
"StuAcaWorksAI/dao"
|
||||||
"StuAcaWorksAI/proto"
|
"StuAcaWorksAI/proto"
|
||||||
|
"StuAcaWorksAI/service"
|
||||||
"StuAcaWorksAI/worker"
|
"StuAcaWorksAI/worker"
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
|
@ -188,7 +189,7 @@ func SingleTurnConversationText(prompt string, param proto.ModelParam) (string,
|
||||||
}
|
}
|
||||||
|
|
||||||
// 只进行与模型信息交互不进行保存,返回消息给国内服务器
|
// 只进行与模型信息交互不进行保存,返回消息给国内服务器
|
||||||
func OpenAIForeignAI(modelParam proto.ModelParam, imCtx *proto.IMParamContext, platform string, messages []openai.ChatCompletionMessageParamUnion) {
|
func OpenAIForeignAI(modelParam proto.ModelParam, imCtx *proto.IMParamContext, platform string, originMessages []proto.MessageContent) {
|
||||||
apiKey := modelParam.APIKey
|
apiKey := modelParam.APIKey
|
||||||
if apiKey == "" {
|
if apiKey == "" {
|
||||||
log.Printf("openai_%s apiKey is empty\n", platform)
|
log.Printf("openai_%s apiKey is empty\n", platform)
|
||||||
|
|
@ -204,6 +205,20 @@ func OpenAIForeignAI(modelParam proto.ModelParam, imCtx *proto.IMParamContext, p
|
||||||
// log.Println("GetOpenAIHistoryMessages error:", err)
|
// log.Println("GetOpenAIHistoryMessages error:", err)
|
||||||
// return
|
// return
|
||||||
//}
|
//}
|
||||||
|
|
||||||
|
var messages []openai.ChatCompletionMessageParamUnion
|
||||||
|
for _, msg := range originMessages {
|
||||||
|
switch msg.Role {
|
||||||
|
case "user":
|
||||||
|
messages = append(messages, openai.UserMessage(msg.Content))
|
||||||
|
case "assistant":
|
||||||
|
messages = append(messages, openai.AssistantMessage(msg.Content))
|
||||||
|
case "system":
|
||||||
|
messages = append(messages, openai.SystemMessage(msg.Content))
|
||||||
|
default:
|
||||||
|
messages = append(messages, openai.UserMessage(msg.Content))
|
||||||
|
}
|
||||||
|
}
|
||||||
if proto.Config.LOG_OUTPUT {
|
if proto.Config.LOG_OUTPUT {
|
||||||
messagesBytes, _ := json.Marshal(messages)
|
messagesBytes, _ := json.Marshal(messages)
|
||||||
log.Println("openai messages:", string(messagesBytes))
|
log.Println("openai messages:", string(messagesBytes))
|
||||||
|
|
@ -315,13 +330,13 @@ func SendOpenAIForeignAI(modelParam proto.ModelParam, imCtx *proto.IMParamContex
|
||||||
req.IMParamContext = *imCtx
|
req.IMParamContext = *imCtx
|
||||||
req.ModelParam = modelParam
|
req.ModelParam = modelParam
|
||||||
req.PlatForm = platform
|
req.PlatForm = platform
|
||||||
var messages []openai.ChatCompletionMessageParamUnion
|
var messages []proto.MessageContent
|
||||||
err1 := GetOpenAIHistoryMessages(imCtx.SessionID, modelParam.System, &messages, imCtx.UserID)
|
err1 := GetGeneralSessionHistoryMsg(imCtx.SessionID, modelParam.System, &messages, imCtx.UserID)
|
||||||
if err1 != nil {
|
if err1 != nil {
|
||||||
log.Println("GetOpenAIHistoryMessages error:", err1)
|
log.Println("GetOpenAIHistoryMessages error:", err1)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
req.ChatCompletionMessageParamUnion = messages
|
req.MessageContent = messages
|
||||||
reqBytes, err := json.Marshal(req)
|
reqBytes, err := json.Marshal(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Println("send openai foreign ai request error:", err)
|
log.Println("send openai foreign ai request error:", err)
|
||||||
|
|
@ -354,3 +369,47 @@ func SendOpenAIForeignAI(modelParam proto.ModelParam, imCtx *proto.IMParamContex
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 获取spark模型会话的历史消息
|
||||||
|
func GetGeneralSessionHistoryMsg(sessionID int, systemPrompt string, messages *[]proto.MessageContent, userID int) error {
|
||||||
|
//获取会话的历史消息
|
||||||
|
msgs := dao.FindMessageBySessionIDV2(sessionID)
|
||||||
|
if len(msgs) > proto.SparkContextLength {
|
||||||
|
//只取最后的 SparkContextLength 条消息
|
||||||
|
msgs = msgs[len(msgs)-proto.SparkContextLength-1:]
|
||||||
|
}
|
||||||
|
//添加系统消息
|
||||||
|
if systemPrompt != "" {
|
||||||
|
var message proto.MessageContent
|
||||||
|
message.Role = "system"
|
||||||
|
message.Content = systemPrompt
|
||||||
|
*messages = append(*messages, message)
|
||||||
|
}
|
||||||
|
for _, msg := range msgs {
|
||||||
|
var message proto.MessageContent
|
||||||
|
if msg.Type == proto.UserToModelMsgType {
|
||||||
|
message.Role = proto.SparkRoleUser
|
||||||
|
if msg.Status == proto.UserToModelFileMsgType {
|
||||||
|
err2 := service.DecodeFileMessageToGeneralMessage(msg.Msg, userID, messages)
|
||||||
|
if err2 != nil {
|
||||||
|
log.Println("spark decode file message error:", err2)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
message.Content = msg.Msg
|
||||||
|
*messages = append(*messages, message)
|
||||||
|
}
|
||||||
|
} else if msg.Type == proto.ModelToUserMsgType {
|
||||||
|
message.Role = proto.SparkRoleAssistant
|
||||||
|
message.Content = msg.Msg
|
||||||
|
*messages = append(*messages, message)
|
||||||
|
} else {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
//添加本次请求消息(上面获取到的已包含本次请求消息)
|
||||||
|
//var message MessageContent
|
||||||
|
//message.Role = "user"
|
||||||
|
//message.Content = prompt
|
||||||
|
//*messages = append(*messages, message)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -465,7 +465,7 @@ func GetSparkSessionHistoryMsgV2(sessionID int, systemPrompt string, messages *[
|
||||||
if msg.Type == proto.UserToModelMsgType {
|
if msg.Type == proto.UserToModelMsgType {
|
||||||
message.Role = proto.SparkRoleUser
|
message.Role = proto.SparkRoleUser
|
||||||
if msg.Status == proto.UserToModelFileMsgType {
|
if msg.Status == proto.UserToModelFileMsgType {
|
||||||
err2 := decodeFileMessageToSparkMessage(msg.Msg, userID, messages)
|
err2 := DecodeFileMessageToSparkMessage(msg.Msg, userID, messages)
|
||||||
if err2 != nil {
|
if err2 != nil {
|
||||||
log.Println("spark decode file message error:", err2)
|
log.Println("spark decode file message error:", err2)
|
||||||
}
|
}
|
||||||
|
|
@ -489,7 +489,7 @@ func GetSparkSessionHistoryMsgV2(sessionID int, systemPrompt string, messages *[
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func decodeFileMessageToSparkMessage(msg string, userID int, message *[]MessageContent) error {
|
func DecodeFileMessageToSparkMessage(msg string, userID int, message *[]MessageContent) error {
|
||||||
var fileMsgContent proto.UserFileMessage
|
var fileMsgContent proto.UserFileMessage
|
||||||
err := json.Unmarshal([]byte(msg), &fileMsgContent)
|
err := json.Unmarshal([]byte(msg), &fileMsgContent)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -540,3 +540,55 @@ func decodeFileMessageToSparkMessage(msg string, userID int, message *[]MessageC
|
||||||
*message = append(*message, messageContent)
|
*message = append(*message, messageContent)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func DecodeFileMessageToGeneralMessage(msg string, userID int, message *[]proto.MessageContent) error {
|
||||||
|
var fileMsgContent proto.UserFileMessage
|
||||||
|
err := json.Unmarshal([]byte(msg), &fileMsgContent)
|
||||||
|
if err != nil {
|
||||||
|
log.Println("general unmarshal file message error:", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
var isRemoveSystem bool
|
||||||
|
log.Println("general,decode file message:", fileMsgContent)
|
||||||
|
for _, file := range fileMsgContent.FileContent {
|
||||||
|
//文件消息为文本类型,则获取文本内容,若为图片则添加图片链接
|
||||||
|
if file.FileType == proto.FileTypeText {
|
||||||
|
fileContents, err2 := FindFileContentV2(file.FileContent.FileID, userID)
|
||||||
|
if err2 != nil {
|
||||||
|
log.Println("doubao find file content error:", err2, "\t,file id:", file.FileContent.FileID, "\t,file name:", file.FileContent.UserFileName)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
fileContent := fileContents[0].FileContent
|
||||||
|
//文件文本内容
|
||||||
|
var messageContent proto.MessageContent
|
||||||
|
messageContent.Role = "user"
|
||||||
|
messageContent.Content = fileContent
|
||||||
|
messageContent.ContentType = "text"
|
||||||
|
*message = append(*message, messageContent)
|
||||||
|
} else if file.FileType == proto.FileTypeImage {
|
||||||
|
imgBase64, err2 := FindImageFileContent(file.FileContent.FileID, userID)
|
||||||
|
if err2 != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
//图片文件内容
|
||||||
|
var messageContent proto.MessageContent
|
||||||
|
messageContent.Role = "user"
|
||||||
|
messageContent.Content = imgBase64
|
||||||
|
messageContent.ContentType = "image"
|
||||||
|
*message = append(*message, messageContent)
|
||||||
|
if !isRemoveSystem {
|
||||||
|
//移除*message数组首个元素
|
||||||
|
*message = (*message)[1:]
|
||||||
|
isRemoveSystem = true
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
log.Println("file type is not support:", file.FileType, "file name:", file.FileContent.UserFileName, "file id:", file.FileContent.FileID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
var messageContent proto.MessageContent
|
||||||
|
messageContent.Role = "user"
|
||||||
|
messageContent.Content = fileMsgContent.Text
|
||||||
|
messageContent.ContentType = "text"
|
||||||
|
*message = append(*message, messageContent)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue