diff --git a/service/generalModel/openai.go b/service/generalModel/openai.go new file mode 100644 index 0000000..07f4c04 --- /dev/null +++ b/service/generalModel/openai.go @@ -0,0 +1,155 @@ +package generalModel + +import ( + "StuAcaWorksAI/dao" + "StuAcaWorksAI/proto" + "StuAcaWorksAI/service" + "StuAcaWorksAI/worker" + "context" + "encoding/json" + "github.com/ollama/ollama/api" + "github.com/openai/openai-go" + "github.com/openai/openai-go/option" + "log" + "time" +) + +// 通义模型 +func OpenAI(modelParam proto.ModelParam, imCtx *proto.IMParamContext, platform string) { + apiKey := modelParam.APIKey + if apiKey == "" { + log.Printf("openai_%s apiKey is empty\n", platform) + return + } + openaiClient := openai.NewClient( + option.WithAPIKey(apiKey), + option.WithBaseURL(modelParam.Url), + ) + var messages []openai.ChatCompletionMessageParamUnion + err := GetOpenAIHistoryMessages(imCtx.SessionID, modelParam.System, &messages, imCtx.UserID) + if err != nil { + log.Println("GetOpenAIHistoryMessages error:", err) + return + } + if proto.Config.LOG_OUTPUT { + messagesBytes, _ := json.Marshal(messages) + log.Println("openai messages:", string(messagesBytes)) + } + + stream := openaiClient.Chat.Completions.NewStreaming(context.TODO(), openai.ChatCompletionNewParams{ + Messages: messages, + Model: modelParam.Model, + }) + acc := openai.ChatCompletionAccumulator{} + answer := "" + var tokens int64 + var PromptTokens int64 //用户问题token + + for stream.Next() { + chunk := stream.Current() + go service.CreateAIStreamMsg(imCtx.UserID, imCtx.ModelID, imCtx.SessionID, chunk, modelParam.Model) + acc.AddChunk(chunk) + if _, ok := acc.JustFinishedToolCall(); ok { + chunk.Created = -1 // 这里由于没有返回结束标志,所以用该字段来判断 + } + if len(chunk.Choices) > 0 { + println(chunk.Choices[0].Delta.Content) + answer += chunk.Choices[0].Delta.Content + tokens = chunk.Usage.TotalTokens + PromptTokens = chunk.Usage.PromptTokens + } + OpenAIToGeneralMassageAndSendMsgQueue(&chunk, imCtx.Channel, imCtx.SessionID, imCtx.UserID) + } + if stream.Err() != nil { + log.Printf("openai_%s stream err:%v\n", platform, stream.Err()) + } + //将消息存入数据库 + DOSaveMessageAndToken(imCtx, answer, "openai_"+platform, tokens, PromptTokens) +} + +func DOSaveMessageAndToken(imCtx *proto.IMParamContext, answer, platform string, tokens, PromptTokens int64) { + msg := dao.Message{SessionID: imCtx.SessionID, FromID: imCtx.UserID, ToID: imCtx.ModelID, Msg: answer, Type: proto.ModelToUserMsgType, FunctionID: imCtx.FunctionID, Token: int(tokens), Status: proto.UserToModelMsgType} + err2, msgID := dao.CreateMessageV2(msg) + if err2 != nil { + log.Printf("%s create message error:%v\n", platform, err2) + } else { + log.Printf("%s create message success,message id:%d\n", platform, msgID) + } + //保存用户问题消耗token + err, messageTokenID := dao.CreateMessageToken(imCtx.SessionID, imCtx.UserID, int(msgID), int(PromptTokens)) + if err != nil { + log.Printf("%s user prompt create message token error:%v\n", platform, err) + } else { + log.Printf("%s user prompt create message token id:%d\n", platform, messageTokenID) + } + + //保存返回的消息token消耗 + err, messageTokenID = dao.CreateMessageToken(imCtx.SessionID, imCtx.UserID, int(msgID), int(tokens)) + if err != nil { + log.Printf("%s create message token error:%v\n", platform, err) + } else { + log.Printf("%s create message token id:%d\n", platform, messageTokenID) + } + //存入tokens消耗信息 + err3, modelTokenID := dao.CreateModelToken(uint(imCtx.UserID), uint(imCtx.ModelID), uint(imCtx.SessionID), uint(tokens)) + if err3 != nil { + log.Printf("%s create model tokens error:%v\n", platform, err3) + } else { + log.Printf("%s create model tokens success,message id:%d\n", platform, modelTokenID) + } +} + +func GetOpenAIHistoryMessages(sessionID int, systemPrompt string, messages *[]openai.ChatCompletionMessageParamUnion, userID int) error { + //获取会话的历史消息 + msgs := dao.FindMessageBySessionIDV2(sessionID) + if len(msgs) > proto.SparkContextLength { + //只取最后的 SparkContextLength 条消息 + msgs = msgs[len(msgs)-proto.SparkContextLength-1:] + } + //添加系统消息 + if systemPrompt != "" { + *messages = append(*messages, openai.SystemMessage(systemPrompt)) + } + for _, msg := range msgs { + if msg.Type == proto.UserToModelMsgType { + if msg.Status == proto.UserToModelFileMsgType { + //err2 := decodeFileMessageToSparkMessage(msg.Msg, userID, messages) + //if err2 != nil { + // log.Println("spark decode file message error:", err2) + //} + } else { + *messages = append(*messages, openai.UserMessage(msg.Msg)) + } + } else if msg.Type == proto.ModelToUserMsgType { + *messages = append(*messages, openai.AssistantMessage(msg.Msg)) + } else { + continue + } + } + return nil +} + +func OpenAIToGeneralMassageAndSendMsgQueue(data *openai.ChatCompletionChunk, channel string, SessionID int, UserID int) { + var aiMsg proto.AIQueueMessage + var wsMsg proto.WSMessage + var apiMsg api.GenerateResponse + apiMsg.Model = proto.ModelTypeSpark + apiMsg.Response = data.Choices[0].Delta.Content + if data.Created < 0 { //这里由于没有返回结束标志,所以用该字段来判断 + apiMsg.Done = true + } else { + apiMsg.Done = false + } + //消息队列部分 + aiMsg.Type = proto.ModelTypeSpark + aiMsg.Msg = apiMsg + //ws发送消息部分 + wsMsg.Msg = aiMsg + wsMsg.SessionID = SessionID + wsMsg.ToID = UserID + wsMsg.Type = proto.ModelTypeSpark + + //发送消息 + wsMsgStr, _ := json.Marshal(wsMsg) + worker.Publish(channel, string(wsMsgStr), time.Second*60) +}