package generalModel

import (
	"context"
	"encoding/json"
	"errors"
	"log"
	"strings"
	"time"

	"StuAcaWorksAI/dao"
	"StuAcaWorksAI/proto"
	"StuAcaWorksAI/worker"

	"github.com/ollama/ollama/api"
	"github.com/openai/openai-go"
	"github.com/openai/openai-go/option"
)

// OpenAI streams a chat completion from a generic OpenAI-compatible endpoint
// (modelParam.Url / modelParam.APIKey) for the session described by imCtx.
//
// Each streamed chunk is stored asynchronously via dao.InsertAIStreamMsgToDB
// and published to imCtx.Channel. When the stream ends, the accumulated answer
// and the reported token usage are saved with DOSaveMessageAndToken.
// platform only tags log lines and the saved records ("openai_"+platform).
// Errors are logged, not returned.
func OpenAI(modelParam proto.ModelParam, imCtx *proto.IMParamContext, platform string) {
	apiKey := modelParam.APIKey
	if apiKey == "" {
		log.Printf("openai_%s apiKey is empty\n", platform)
		return
	}
	openaiClient := openai.NewClient(
		option.WithAPIKey(apiKey),
		option.WithBaseURL(modelParam.Url),
	)

	var messages []openai.ChatCompletionMessageParamUnion
	if err := GetOpenAIHistoryMessages(imCtx.SessionID, modelParam.System, &messages, imCtx.UserID); err != nil {
		log.Println("GetOpenAIHistoryMessages error:", err)
		return
	}
	if proto.Config.LOG_OUTPUT {
		// Debug dump only; a marshal failure just logs an empty payload.
		messagesBytes, _ := json.Marshal(messages)
		log.Println("openai messages:", string(messagesBytes))
	}

	stream := openaiClient.Chat.Completions.NewStreaming(context.TODO(), openai.ChatCompletionNewParams{
		Messages: messages,
		Model:    modelParam.Model,
	})
	// Release the underlying HTTP response body however the stream ends.
	defer stream.Close()

	acc := openai.ChatCompletionAccumulator{}
	var answer strings.Builder
	var tokens int64       // total tokens reported by the provider
	var promptTokens int64 // tokens spent on the user's prompt
	for stream.Next() {
		chunk := stream.Current()

		msgStr, err := json.Marshal(chunk)
		if err != nil {
			log.Println("openai stream msg Error parsing JSON:", err)
		}
		// Persist the raw chunk in the background.
		go dao.InsertAIStreamMsgToDB(imCtx.UserID, imCtx.SessionID, uint(imCtx.ModelID), string(msgStr), modelParam.Model)

		acc.AddChunk(chunk)
		if _, ok := acc.JustFinishedToolCall(); ok {
			// The queue message has no dedicated end flag, so a negative
			// Created value is used as the completion marker.
			chunk.Created = -1
		}

		if len(chunk.Choices) > 0 {
			if proto.Config.LOG_OUTPUT {
				log.Printf("openai_%s stream msg:%s\n", platform, chunk.Choices[0].Delta.Content)
			}
			answer.WriteString(chunk.Choices[0].Delta.Content)
		}
		// Usage may arrive on a chunk without choices (e.g. OpenAI's trailing
		// usage chunk), so record it whenever the provider reports it.
		if chunk.Usage.TotalTokens > 0 {
			tokens = chunk.Usage.TotalTokens
			promptTokens = chunk.Usage.PromptTokens
		}

		OpenAIToGeneralMassageAndSendMsgQueue(&chunk, imCtx.Channel, imCtx.SessionID, imCtx.UserID)
	}
	if err := stream.Err(); err != nil {
		log.Printf("openai_%s stream err:%v\n", platform, err)
	}

	// Save the final answer and its token usage.
	DOSaveMessageAndToken(imCtx, answer.String(), "openai_"+platform, tokens, promptTokens)
}

// DOSaveMessageAndToken persists a finished model reply for imCtx's session and
// records its token usage. It performs these writes:
//   - the reply message (dao.CreateMessageV2);
//   - a message-token row for promptTokens, linked to the new message id;
//   - a message-token row for tokens, linked to the new message id;
//   - a model-token row for tokens (dao.CreateModelToken).
//
// platform only prefixes log lines. Each write is attempted regardless of
// earlier failures; errors are logged, not returned.
func DOSaveMessageAndToken(imCtx *proto.IMParamContext, answer, platform string, tokens, promptTokens int64) {
	// NOTE(review): a model->user message is stored with Status
	// proto.UserToModelMsgType; kept as-is, but confirm this is intended.
	msg := dao.Message{
		SessionID:  imCtx.SessionID,
		FromID:     imCtx.UserID,
		ToID:       imCtx.ModelID,
		Msg:        answer,
		Type:       proto.ModelToUserMsgType,
		FunctionID: imCtx.FunctionID,
		Token:      int(tokens),
		Status:     proto.UserToModelMsgType,
	}
	err2, msgID := dao.CreateMessageV2(msg)
	if err2 != nil {
		log.Printf("%s create message error:%v\n", platform, err2)
	} else {
		log.Printf("%s create message success,message id:%d\n", platform, msgID)
	}

	// Token cost of the user's prompt.
	err, messageTokenID := dao.CreateMessageToken(imCtx.SessionID, imCtx.UserID, int(msgID), int(promptTokens))
	if err != nil {
		log.Printf("%s user prompt create message token error:%v\n", platform, err)
	} else {
		log.Printf("%s user prompt create message token id:%d\n", platform, messageTokenID)
	}

	// Token cost reported for the whole exchange.
	err, messageTokenID = dao.CreateMessageToken(imCtx.SessionID, imCtx.UserID, int(msgID), int(tokens))
	if err != nil {
		log.Printf("%s create message token error:%v\n", platform, err)
	} else {
		log.Printf("%s create message token id:%d\n", platform, messageTokenID)
	}

	// Per-model token consumption.
	err3, modelTokenID := dao.CreateModelToken(uint(imCtx.UserID), uint(imCtx.ModelID), uint(imCtx.SessionID), uint(tokens))
	if err3 != nil {
		log.Printf("%s create model tokens error:%v\n", platform, err3)
	} else {
		log.Printf("%s create model tokens success,message id:%d\n", platform, modelTokenID)
	}
}

// GetOpenAIHistoryMessages appends the history of sessionID to *messages as
// OpenAI chat messages, in this order:
//   - systemPrompt as a system message, when it is non-empty;
//   - stored user messages as user turns;
//   - stored model replies as assistant turns.
//
// Once the history exceeds proto.SparkContextLength entries only its tail is
// used. User file messages are currently skipped, and userID is unused while
// file decoding is disabled. The returned error is always nil.
func GetOpenAIHistoryMessages(sessionID int, systemPrompt string, messages *[]openai.ChatCompletionMessageParamUnion, userID int) error {
	msgs := dao.FindMessageBySessionIDV2(sessionID)
	if len(msgs) > proto.SparkContextLength {
		// NOTE(review): this keeps the last SparkContextLength+1 messages, not
		// SparkContextLength. It may be deliberate (e.g. so the window still
		// begins with a user turn); confirm before changing.
		msgs = msgs[len(msgs)-proto.SparkContextLength-1:]
	}
	if systemPrompt != "" {
		*messages = append(*messages, openai.SystemMessage(systemPrompt))
	}
	for _, msg := range msgs {
		switch {
		case msg.Type == proto.UserToModelMsgType && msg.Status == proto.UserToModelFileMsgType:
			// TODO: file messages are not forwarded yet (decoding is disabled).
		case msg.Type == proto.UserToModelMsgType:
			*messages = append(*messages, openai.UserMessage(msg.Msg))
		case msg.Type == proto.ModelToUserMsgType:
			*messages = append(*messages, openai.AssistantMessage(msg.Msg))
		}
	}
	return nil
}

// OpenAIToGeneralMassageAndSendMsgQueue converts one streamed OpenAI chunk
// into the generic WebSocket message format and publishes it on channel via
// worker.Publish.
//
// The payload carries the first choice's delta content, or "" when the chunk
// has no choices. Done is set when data.Created < 0, the completion marker
// used by the streaming loops in this file.
// NOTE(review): messages are tagged proto.ModelTypeSpark even though the
// source is an OpenAI-compatible model; presumably clients rely on this.
func OpenAIToGeneralMassageAndSendMsgQueue(data *openai.ChatCompletionChunk, channel string, sessionID int, userID int) {
	var apiMsg api.GenerateResponse
	apiMsg.Model = proto.ModelTypeSpark
	// Usage-only chunks carry no choices; indexing Choices[0] would panic.
	if len(data.Choices) > 0 {
		apiMsg.Response = data.Choices[0].Delta.Content
	}
	apiMsg.Done = data.Created < 0

	// Message-queue envelope.
	var aiMsg proto.AIQueueMessage
	aiMsg.Type = proto.ModelTypeSpark
	aiMsg.Msg = apiMsg

	// WebSocket envelope.
	var wsMsg proto.WSMessage
	wsMsg.Msg = aiMsg
	wsMsg.SessionID = sessionID
	wsMsg.ToID = userID
	wsMsg.Type = proto.ModelTypeSpark

	wsMsgStr, err := json.Marshal(wsMsg)
	if err != nil {
		log.Println("openai queue message marshal error:", err)
		return
	}
	worker.Publish(channel, string(wsMsgStr), time.Second*60)
}

// SingleTurnConversationText sends prompt as a single user message (no
// history) to the OpenAI-compatible endpoint described by param. It returns
// the text of the first choice.
//
// A missing API key or URL is logged and yields ("", nil). A response without
// choices is reported as an error rather than causing a panic.
func SingleTurnConversationText(prompt string, param proto.ModelParam) (string, error) {
	if param.APIKey == "" || param.Url == "" {
		log.Printf("openai apiKey or url is empty\n")
		return "", nil
	}
	openaiClient := openai.NewClient(
		option.WithAPIKey(param.APIKey),
		option.WithBaseURL(param.Url),
	)
	chatCompletion, err := openaiClient.Chat.Completions.New(context.TODO(), openai.ChatCompletionNewParams{
		Messages: []openai.ChatCompletionMessageParamUnion{
			openai.UserMessage(prompt),
		},
		Model: param.Model,
	})
	if err != nil {
		log.Printf("openai single turn conversation error:%v\n", err)
		return "", err
	}
	if len(chatCompletion.Choices) == 0 {
		return "", errors.New("openai single turn conversation: response has no choices")
	}
	return chatCompletion.Choices[0].Message.Content, nil
}

// OpenAIForeignAI only talks to the model and saves nothing locally. It
// streams a chat completion for originMessages, which is the already-assembled
// history, and forwards every chunk to the domestic server via
// DoOpenAIForeignAICallBack.
// Roles other than "assistant" and "system" are sent as user messages.
func OpenAIForeignAI(modelParam proto.ModelParam, imCtx *proto.IMParamContext, platform string, originMessages []proto.MessageContent) {
	apiKey := modelParam.APIKey
	if apiKey == "" {
		log.Printf("openai_%s apiKey is empty\n", platform)
		return
	}
	openaiClient := openai.NewClient(
		option.WithAPIKey(apiKey),
		option.WithBaseURL(modelParam.Url),
	)

	var messages []openai.ChatCompletionMessageParamUnion
	for _, msg := range originMessages {
		switch msg.Role {
		case "assistant":
			messages = append(messages, openai.AssistantMessage(msg.Content))
		case "system":
			messages = append(messages, openai.SystemMessage(msg.Content))
		default: // "user" and any unrecognized role
			messages = append(messages, openai.UserMessage(msg.Content))
		}
	}
	if proto.Config.LOG_OUTPUT {
		// Debug dump only; a marshal failure just logs an empty payload.
		messagesBytes, _ := json.Marshal(messages)
		log.Println("openai messages:", string(messagesBytes))
	}

	stream := openaiClient.Chat.Completions.NewStreaming(context.TODO(), openai.ChatCompletionNewParams{
		Messages: messages,
		Model:    modelParam.Model,
	})
	// Release the underlying HTTP response body however the stream ends.
	defer stream.Close()

	acc := openai.ChatCompletionAccumulator{}
	for stream.Next() {
		chunk := stream.Current()
		acc.AddChunk(chunk)
		if _, ok := acc.JustFinishedToolCall(); ok {
			// Negative Created is the completion marker (no dedicated end flag).
			chunk.Created = -1
		}
		DoOpenAIForeignAICallBack(chunk, modelParam, imCtx, platform)
	}
	if err := stream.Err(); err != nil {
		log.Printf("openai_%s stream err:%v\n", platform, err)
	}
}

// DoOpenAIForeignAICallBack POSTs one streamed chunk, together with its
// request context, back to the domestic server. Failures are logged only.
func DoOpenAIForeignAICallBack(chunk openai.ChatCompletionChunk, modelParam proto.ModelParam, imCtx *proto.IMParamContext, platform string) {
	url := "https://pm.ljsea.top/im/receive_message_foreign_ai?super_id=1"
	var req proto.ReceiveMessageForeignAIRRequest
	req.IMParamContext = *imCtx
	req.ModelParam = modelParam
	req.ChatCompletionChunk = chunk
	req.PlatForm = platform
	postForeignAIRequest(url, req, "openai foreign ai")
}

// postForeignAIRequest sends payload as JSON to url and checks the reply.
// The request is authenticated with the first member of the Redis set
// "super_permission_tokens" plus a super_id header.
// Every failure is logged with logPrefix and swallowed: both callers treat the
// cross-server call as best-effort.
func postForeignAIRequest(url string, payload any, logPrefix string) {
	reqBytes, err := json.Marshal(payload)
	if err != nil {
		log.Println(logPrefix+" request error:", err)
		return
	}
	superTokens := worker.GetRedisSetMembers("super_permission_tokens")
	if len(superTokens) == 0 {
		log.Println(logPrefix + " super permission tokens is empty")
		return
	}
	headers := map[string]string{
		"token":    superTokens[0],
		"super_id": "1",
	}
	err, respBytes := worker.DoPostRequestJSON(url, reqBytes, headers)
	if err != nil {
		log.Println(logPrefix+" request error:", err)
		return
	}
	var resp proto.GenerateResp
	if err := json.Unmarshal(respBytes, &resp); err != nil {
		log.Println(logPrefix+" response error:", err)
		return
	}
	if resp.Code != proto.SuccessCode {
		log.Println(logPrefix+" response error:", resp.Message)
	}
}

// OpenAIForeignAIResponseToGeneralMassageAndSendMsgQueue handles one chunk
// relayed by the foreign server. It runs in these steps:
//  1. Publish the chunk to the client channel.
//  2. Append its text to the partial answer buffered in Redis under
//     Channel+"_stream_msg".
//  3. When the chunk marks completion, persist the full answer with the
//     chunk's usage. Completion means Created < 0 or a non-empty
//     FinishReason.
//
// After saving, the buffer is reset, so a later stream on the same channel
// does not start with this answer.
func OpenAIForeignAIResponseToGeneralMassageAndSendMsgQueue(data *proto.ReceiveMessageForeignAIRRequest) {
	imCtx := &data.IMParamContext
	OpenAIToGeneralMassageAndSendMsgQueue(&data.ChatCompletionChunk, imCtx.Channel, imCtx.SessionID, imCtx.UserID)

	// The answer is streamed across separate requests, so the partial text
	// is buffered in Redis between chunks.
	streamKey := imCtx.Channel + "_stream_msg"
	answer := worker.GetRedis(streamKey)

	chunk := data.ChatCompletionChunk
	finished := chunk.Created < 0
	if len(chunk.Choices) > 0 {
		if proto.Config.LOG_OUTPUT {
			log.Printf("openai_%s stream msg:%s\n", data.PlatForm, chunk.Choices[0].Delta.Content)
		}
		answer += chunk.Choices[0].Delta.Content
		finished = finished || chunk.Choices[0].FinishReason != ""
	}

	if !finished {
		worker.SetRedisWithExpire(streamKey, answer, time.Second*30)
		return
	}
	DOSaveMessageAndToken(imCtx, answer, "openai_"+data.PlatForm, chunk.Usage.TotalTokens, chunk.Usage.PromptTokens)
	// The saved answer is consumed; reset the buffer for the next stream.
	worker.SetRedisWithExpire(streamKey, "", time.Second*30)
}

// SendOpenAIForeignAI hands a chat task to the foreign server. The payload
// contains the model parameters, the request context and the session
// history, built by GetGeneralSessionHistoryMsg. Failures are logged only.
func SendOpenAIForeignAI(modelParam proto.ModelParam, imCtx *proto.IMParamContext, platform string) {
	url := "https://vis.ljsea.top/im/send_message_foreign_ai?super_id=1"

	var messages []proto.MessageContent
	if err := GetGeneralSessionHistoryMsg(imCtx.SessionID, modelParam.System, &messages, imCtx.UserID); err != nil {
		log.Println("GetOpenAIHistoryMessages error:", err)
		return
	}

	var req proto.SendMessageForeignAIRRequest
	req.IMParamContext = *imCtx
	req.ModelParam = modelParam
	req.PlatForm = platform
	req.MessageContent = messages
	postForeignAIRequest(url, req, "send openai foreign ai")
}

// GetGeneralSessionHistoryMsg appends the history of sessionID to *messages in
// the provider-neutral proto.MessageContent form, in this order:
//   - systemPrompt as a "system" entry, when it is non-empty;
//   - stored user messages with role proto.SparkRoleUser;
//   - stored model replies with role proto.SparkRoleAssistant.
//
// The fetched history already includes the current request's message. Once it
// exceeds proto.SparkContextLength entries only its tail is used. User file
// messages are currently skipped, and userID is unused while file decoding is
// disabled. The returned error is always nil.
func GetGeneralSessionHistoryMsg(sessionID int, systemPrompt string, messages *[]proto.MessageContent, userID int) error {
	msgs := dao.FindMessageBySessionIDV2(sessionID)
	if len(msgs) > proto.SparkContextLength {
		// NOTE(review): keeps the last SparkContextLength+1 messages, not
		// SparkContextLength; confirm whether the extra entry is intended.
		msgs = msgs[len(msgs)-proto.SparkContextLength-1:]
	}
	if systemPrompt != "" {
		*messages = append(*messages, proto.MessageContent{Role: "system", Content: systemPrompt})
	}
	for _, msg := range msgs {
		switch {
		case msg.Type == proto.UserToModelMsgType && msg.Status == proto.UserToModelFileMsgType:
			// TODO: file messages are not forwarded yet (decoding is disabled).
		case msg.Type == proto.UserToModelMsgType:
			*messages = append(*messages, proto.MessageContent{Role: proto.SparkRoleUser, Content: msg.Msg})
		case msg.Type == proto.ModelToUserMsgType:
			*messages = append(*messages, proto.MessageContent{Role: proto.SparkRoleAssistant, Content: msg.Msg})
		}
	}
	return nil
}