package service

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"log"
	"strings"
	"time"

	"github.com/ollama/ollama/api"
	"github.com/volcengine/volcengine-go-sdk/service/arkruntime"
	"github.com/volcengine/volcengine-go-sdk/service/arkruntime/model"
	"github.com/volcengine/volcengine-go-sdk/service/arkruntime/utils"
	"github.com/volcengine/volcengine-go-sdk/volcengine"

	"StuAcaWorksAI/dao"
	"StuAcaWorksAI/proto"
	"StuAcaWorksAI/worker"
)

// doubao streams a chat completion from the DouBao (Volcengine Ark) model
// described by modelParam for the given session.
//
// The request context is built by GetDouBaoSessionHistoryMsg (the current user
// message is expected to already be stored in the session, so question itself
// is not appended again). Every received chunk is recorded asynchronously via
// CreateAIStreamMsg and published to channel; once the stream ends the full
// answer is stored as a model-to-user message with function ID 2.
// If the stream cannot be opened or a receive error occurs, nothing is stored.
func doubao(modelParam proto.ModelParam, question, channel string, SessionID, UserID int, modelID int) {
	doubaoClient := arkruntime.NewClientWithApiKey(
		modelParam.APIKey,
		arkruntime.WithBaseUrl(modelParam.Url),
	)
	ctx := context.Background()

	messages := make([]*model.ChatCompletionMessage, 0)
	err := GetDouBaoSessionHistoryMsg(SessionID, modelParam.System, &messages, question)
	if err != nil {
		// Best effort: continue with whatever history was collected before the failure.
		log.Println("get doubao session history message error:", err)
	}

	req := model.CreateChatCompletionRequest{
		Model:    modelParam.Model,
		Messages: messages,
		StreamOptions: &model.StreamOptions{
			IncludeUsage: true,
		},
	}
	stream, err := doubaoClient.CreateChatCompletionStream(ctx, req)
	if err != nil {
		fmt.Printf("stream chat error: %v\n", err)
		return
	}

	answer, _, err := drainDouBaoStream(stream, func(resp *model.ChatCompletionStreamResponse) {
		// The argument is copied when the go statement runs, so the goroutine owns its chunk.
		go CreateAIStreamMsg(UserID, modelID, SessionID, *resp, modelParam.Model)
		doubaoToGeneralMassageAndSendMsgQueue(resp, channel, SessionID, UserID)
	})
	if err != nil {
		fmt.Printf("doubao Stream chat error: %v\n", err)
		return
	}

	// Persist the complete answer.
	err, _ = dao.CreateMessage(SessionID, modelID, UserID, proto.ModelToUserMsgType, proto.MsgHasRead, answer, 2)
	if err != nil {
		log.Println("doubao create message error:", err)
	}
}

// DouBaoV2 streams a chat completion from the DouBao (Volcengine Ark) model for
// the session described by imCtx.
//
// The context is built by GetDouBaoSessionHistoryMsgV2 and the request uses
// modelParam's Temperature and TopP with a completion cap of 10240 tokens.
// Each chunk is recorded via CreateAIStreamMsg and published to imCtx.Channel.
// When the stream ends, the full answer is stored with imCtx.FunctionID and the
// request's total token usage is recorded via dao.CreateModelToken.
// If the stream cannot be opened or a receive error occurs, nothing is stored.
func DouBaoV2(modelParam proto.ModelParam, imCtx *proto.IMParamContext) {
	doubaoClient := arkruntime.NewClientWithApiKey(
		modelParam.APIKey,
		arkruntime.WithBaseUrl(modelParam.Url),
	)
	ctx := context.Background()

	messages := make([]*model.ChatCompletionMessage, 0)
	err := GetDouBaoSessionHistoryMsgV2(imCtx.SessionID, modelParam.System, &messages, imCtx.UserID)
	if err != nil {
		// Best effort: continue with whatever history was collected before the failure.
		log.Println("get doubao session history message error:", err)
	}

	maxTokens := 10240 // cap the completion at roughly 10k tokens
	req := model.CreateChatCompletionRequest{
		Model:       modelParam.Model,
		Temperature: &modelParam.Temperature,
		TopP:        &modelParam.TopP,
		MaxTokens:   &maxTokens,
		Messages:    messages,
		StreamOptions: &model.StreamOptions{
			IncludeUsage: true,
		},
	}
	// Debug logging only; a marshal failure merely produces an empty log body.
	messagesStr, _ := json.Marshal(req)
	log.Println("doubao req messages:", string(messagesStr))

	stream, err := doubaoClient.CreateChatCompletionStream(ctx, req)
	if err != nil {
		fmt.Printf("stream chat error: %v\n", err)
		return
	}

	answer, tokens, err := drainDouBaoStream(stream, func(resp *model.ChatCompletionStreamResponse) {
		if _, err := CreateAIStreamMsg(imCtx.UserID, imCtx.ModelID, imCtx.SessionID, *resp, modelParam.Model); err != nil {
			log.Println("create ai stream message error:", err)
		}
		doubaoToGeneralMassageAndSendMsgQueue(resp, imCtx.Channel, imCtx.SessionID, imCtx.UserID)
	})
	if err != nil {
		fmt.Printf("doubao Stream chat error: %v\n", err)
		return
	}

	// Persist the complete answer.
	err, _ = dao.CreateMessage(imCtx.SessionID, imCtx.ModelID, imCtx.UserID, proto.ModelToUserMsgType, proto.MsgHasRead, answer, imCtx.FunctionID)
	if err != nil {
		log.Println("doubao create message error:", err)
	}
	// Record the tokens consumed by this request.
	tokenErr, _ := dao.CreateModelToken(uint(imCtx.UserID), uint(imCtx.ModelID), uint(imCtx.SessionID), uint(tokens))
	if tokenErr != nil {
		log.Println("create model token error:", tokenErr)
	}
}

// drainDouBaoStream reads stream until it reports io.EOF, then closes it.
//
// Every chunk that carries at least one choice is passed to onChunk, and the
// text of choices[0].Delta.Content is accumulated into the returned answer.
// Chunks without choices are not forwarded. With StreamOptions.IncludeUsage
// the Ark API documents such a trailing chunk that carries only the usage
// block, so reading continues past the "stop" chunk until EOF. The returned
// token count is Usage.TotalTokens from the last chunk that had a non-nil
// usage (0 if none did).
//
// A receive error other than io.EOF is returned together with the partial
// answer collected so far.
func drainDouBaoStream(stream *utils.ChatCompletionStreamReader, onChunk func(resp *model.ChatCompletionStreamResponse)) (string, int, error) {
	defer func() {
		if err := stream.Close(); err != nil {
			log.Println("close stream error:", err)
		}
	}()

	var answer strings.Builder
	tokens := 0
	for {
		resp, err := stream.Recv()
		if errors.Is(err, io.EOF) {
			log.Println("doubao stream end:", err)
			return answer.String(), tokens, nil
		}
		if err != nil {
			return answer.String(), tokens, err
		}
		if resp.Usage != nil {
			// Assign rather than add so a repeated usage block is not double-counted.
			tokens = resp.Usage.TotalTokens
		}
		if len(resp.Choices) == 0 {
			if resp.Usage == nil {
				log.Println("doubao stream recv choices is nil:", resp)
			}
			continue
		}
		onChunk(&resp)
		answer.WriteString(resp.Choices[0].Delta.Content)
	}
}

// doubaoToGeneralMassageAndSendMsgQueue converts one DouBao stream chunk into
// the generic ollama-style api.GenerateResponse, wraps it in a
// proto.AIQueueMessage and proto.WSMessage addressed to UserID in SessionID,
// and publishes the JSON to channel with a 60-second timeout argument.
//
// The response text is the delta content or, when that is empty, the
// reasoning ("deep thinking") content. Done is set when the finish reason is
// stop or length. Chunks with no choices are ignored.
func doubaoToGeneralMassageAndSendMsgQueue(data *model.ChatCompletionStreamResponse, channel string, SessionID, UserID int) {
	if data == nil || len(data.Choices) == 0 {
		// Nothing to forward; indexing Choices[0] would panic.
		return
	}
	choices := data.Choices[0]

	var apiMsg api.GenerateResponse
	apiMsg.Model = proto.ModelTypeDouBao
	if choices.Delta.Content != "" {
		apiMsg.Response = choices.Delta.Content
	} else if choices.Delta.ReasoningContent != nil {
		apiMsg.Response = *choices.Delta.ReasoningContent // reasoning ("deep thinking") part
	}
	// The stream is finished when the model stopped normally or hit the length limit.
	apiMsg.Done = choices.FinishReason == proto.FinishReasonStop || choices.FinishReason == proto.FinishReasonLength

	// Message-queue payload.
	var aiMsg proto.AIQueueMessage
	aiMsg.Type = proto.ModelTypeDouBao
	aiMsg.Msg = apiMsg

	// WebSocket envelope.
	var wsMsg proto.WSMessage
	wsMsg.Msg = aiMsg
	wsMsg.SessionID = SessionID
	wsMsg.ToID = UserID
	wsMsg.Type = proto.ModelTypeDouBao

	wsMsgStr, err := json.Marshal(wsMsg)
	if err != nil {
		log.Println("doubao marshal ws message error:", err)
		return
	}
	worker.Publish(channel, string(wsMsgStr), time.Second*60)
}

// appendDouBaoSystemPrompt appends systemPrompt to messages as a system-role
// message. An empty prompt appends nothing.
func appendDouBaoSystemPrompt(messages *[]*model.ChatCompletionMessage, systemPrompt string) {
	if systemPrompt == "" {
		return
	}
	*messages = append(*messages, &model.ChatCompletionMessage{
		Role:    model.ChatMessageRoleSystem,
		Content: douBaoTextContent(systemPrompt),
	})
}

// douBaoTextContent wraps text as plain-string message content.
func douBaoTextContent(text string) *model.ChatCompletionMessageContent {
	return &model.ChatCompletionMessageContent{
		StringValue: volcengine.String(text),
	}
}

// GetDouBaoSessionHistoryMsg appends the request context for sessionID to
// messages: the system prompt (if non-empty) followed by the most recent
// stored messages of the session. The current user message is expected to be
// stored already and is therefore included as part of the history; prompt is
// not read.
//
// User messages whose Status is proto.UserToModelImageMsgType are decoded from
// proto.UserImageMsg JSON into a text part plus one image_url part per image.
// Other user messages and model messages are sent as plain text; messages of
// any other Type are skipped. If an image message fails to decode, the error
// is returned and messages keeps everything appended before the failure.
func GetDouBaoSessionHistoryMsg(sessionID int, systemPrompt string, messages *[]*model.ChatCompletionMessage, prompt string) error {
	appendDouBaoSystemPrompt(messages, systemPrompt)

	msgs := dao.FindMessageBySessionIDV2(sessionID)
	if len(msgs) > proto.DouBaoContextLength {
		// Keeps the most recent DouBaoContextLength+1 messages.
		// NOTE(review): the original comment said "last N" but the slice keeps N+1
		// (possibly N history + the current message) — confirm which is intended.
		msgs = msgs[len(msgs)-proto.DouBaoContextLength-1:]
	}

	for _, v := range msgs {
		var message model.ChatCompletionMessage
		switch v.Type {
		case proto.UserToModelMsgType:
			message.Role = model.ChatMessageRoleUser
			if v.Status != proto.UserToModelImageMsgType {
				message.Content = douBaoTextContent(v.Msg)
				break
			}
			// The user message carries images: decode the stored JSON payload.
			var imgMsg proto.UserImageMsg
			err := json.Unmarshal([]byte(v.Msg), &imgMsg)
			log.Println("user image message:", imgMsg)
			if err != nil {
				log.Println("unmarshal user image message error:", err)
				return err
			}
			listValue := []*model.ChatCompletionMessageContentPart{{
				Type: model.ChatCompletionMessageContentPartTypeText,
				Text: imgMsg.Text,
			}}
			// A message may reference several images.
			for _, img := range imgMsg.ImageContent {
				listValue = append(listValue, &model.ChatCompletionMessageContentPart{
					Type: model.ChatCompletionMessageContentPartTypeImageURL,
					ImageURL: &model.ChatMessageImageURL{
						URL: img.ImgUrl,
					},
				})
			}
			message.Content = &model.ChatCompletionMessageContent{
				ListValue: listValue,
			}
		case proto.ModelToUserMsgType:
			message.Role = model.ChatMessageRoleAssistant
			message.Content = douBaoTextContent(v.Msg)
		default:
			continue
		}
		*messages = append(*messages, &message)
	}
	return nil
}

// GetDouBaoSessionHistoryMsgV2 appends the request context for sessionID to
// messages: the system prompt (if non-empty) followed by the most recent
// stored messages of the session (the current user message is expected to be
// stored already and is included as part of the history).
//
// Only the newest message, when it is a user file message
// (proto.UserToModelFileMsgType), is expanded into multi-part content via
// decodeFileMessageToDouBaoMessage using userID. Older messages, including
// older file messages, are sent as their stored text. Messages of any other
// Type are skipped. If the file message fails to decode, the error is returned
// and messages keeps everything appended before the failure.
func GetDouBaoSessionHistoryMsgV2(sessionID int, systemPrompt string, messages *[]*model.ChatCompletionMessage, userID int) error {
	appendDouBaoSystemPrompt(messages, systemPrompt)

	msgs := dao.FindMessageBySessionIDV2(sessionID)
	if len(msgs) > proto.DouBaoContextLength {
		// Keeps the most recent DouBaoContextLength+1 messages.
		// NOTE(review): the original comment said "last N" but the slice keeps N+1
		// (possibly N history + the current message) — confirm which is intended.
		msgs = msgs[len(msgs)-proto.DouBaoContextLength-1:]
	}

	last := len(msgs) - 1
	for i, v := range msgs {
		var message model.ChatCompletionMessage
		switch v.Type {
		case proto.UserToModelMsgType:
			message.Role = model.ChatMessageRoleUser
			if v.Status == proto.UserToModelFileMsgType && i == last {
				listValue, err := decodeFileMessageToDouBaoMessage(v.Msg, userID)
				if err != nil {
					log.Println("decode user file message error:", err)
					return err
				}
				message.Content = &model.ChatCompletionMessageContent{
					ListValue: listValue,
				}
			} else {
				message.Content = douBaoTextContent(v.Msg)
			}
		case proto.ModelToUserMsgType:
			message.Role = model.ChatMessageRoleAssistant
			message.Content = douBaoTextContent(v.Msg)
		default:
			continue
		}
		*messages = append(*messages, &message)
	}
	return nil
}

// decodeFileMessageToDouBaoMessage decodes a stored proto.UserFileMessage JSON
// payload into DouBao multi-part content.
//
// Text files become a text part "文件<name>,内容:<content>" built from the
// first entry returned by FindFileContentV2 for userID. Image files become an
// image_url part pointing at proto.DownloadFileBaseURL + the stored file name.
// Files whose content cannot be loaded, or whose type is unsupported, are
// logged and skipped. The message's own text is always appended as the final
// part. Only a JSON decode failure is returned as an error.
func decodeFileMessageToDouBaoMessage(msg string, userID int) ([]*model.ChatCompletionMessageContentPart, error) {
	var fileMsgContent proto.UserFileMessage
	if err := json.Unmarshal([]byte(msg), &fileMsgContent); err != nil {
		log.Println("DouBao unmarshal file message error:", err)
		return nil, err
	}
	log.Println("doubaov2,decode file message:", fileMsgContent)

	var listValue []*model.ChatCompletionMessageContentPart
	for _, file := range fileMsgContent.FileContent {
		switch file.FileType {
		case proto.FileTypeText:
			// Inline the file's extracted text content.
			fileContents, err := FindFileContentV2(file.FileContent.FileID, userID)
			if err != nil {
				log.Println("doubao find file content error:", err, "\t,file id:", file.FileContent.FileID, "\t,file name:", file.FileContent.UserFileName)
				continue
			}
			if len(fileContents) == 0 {
				// Guard against an empty result; indexing [0] would panic.
				log.Println("doubao find file content empty, file id:", file.FileContent.FileID, "\t,file name:", file.FileContent.UserFileName)
				continue
			}
			listValue = append(listValue, &model.ChatCompletionMessageContentPart{
				Type: model.ChatCompletionMessageContentPartTypeText,
				Text: "文件" + file.FileContent.UserFileName + ",内容:" + fileContents[0].FileContent,
			})
		case proto.FileTypeImage:
			// Reference the image by its download URL.
			listValue = append(listValue, &model.ChatCompletionMessageContentPart{
				Type: model.ChatCompletionMessageContentPartTypeImageURL,
				ImageURL: &model.ChatMessageImageURL{
					URL: proto.DownloadFileBaseURL + file.FileContent.FileStoreName,
				},
			})
		default:
			log.Println("file type is not support:", file.FileType, "file name:", file.FileContent.UserFileName, "file id:", file.FileContent.FileID)
		}
	}

	listValue = append(listValue, &model.ChatCompletionMessageContentPart{
		Type: model.ChatCompletionMessageContentPartTypeText,
		Text: fileMsgContent.Text,
	})
	return listValue, nil
}