package service import ( "StuAcaWorksAI/dao" "StuAcaWorksAI/proto" "StuAcaWorksAI/worker" "context" "encoding/json" "fmt" "github.com/ollama/ollama/api" "github.com/volcengine/volcengine-go-sdk/service/arkruntime" "github.com/volcengine/volcengine-go-sdk/service/arkruntime/model" "github.com/volcengine/volcengine-go-sdk/service/arkruntime/utils" "github.com/volcengine/volcengine-go-sdk/volcengine" "io" "log" "time" ) func doubao(modelParam proto.ModelParam, question, channel string, SessionID, UserID int, modelID int) { doubaoClient := arkruntime.NewClientWithApiKey( modelParam.APIKey, arkruntime.WithBaseUrl(modelParam.Url), ) ctx := context.Background() messages := make([]*model.ChatCompletionMessage, 0) err := GetDouBaoSessionHistoryMsg(SessionID, modelParam.System, &messages, question) if err != nil { log.Println("get doubao session history message error:", err) } //messagesStr, _ := json.Marshal(messages) //log.Println("doubao messages:", string(messagesStr)) req := model.CreateChatCompletionRequest{ Model: modelParam.Model, Messages: messages, } stream, err := doubaoClient.CreateChatCompletionStream(ctx, req) if err != nil { fmt.Printf("stream chat error: %v\n", err) return } defer func(stream *utils.ChatCompletionStreamReader) { err2 := stream.Close() if err2 != nil { log.Println("close stream error:", err2) } }(stream) answer := "" for { recv, err3 := stream.Recv() if err3 == io.EOF { log.Println("doubao stream end:", err3) return } if err3 != nil { fmt.Printf("doubao Stream chat error: %v\n", err) return } go CreateAIStreamMsg(UserID, modelID, SessionID, recv, modelParam.Model) doubaoToGeneralMassageAndSendMsgQueue(&recv, channel, SessionID, UserID) choices := recv.Choices[0] if choices.FinishReason == proto.FinishReasonStop { answer += choices.Delta.Content break } else { answer += choices.Delta.Content } } //将消息存入数据库 err, _ = dao.CreateMessage(SessionID, modelID, UserID, proto.ModelToUserMsgType, proto.MsgHasRead, answer, 2) if err != nil { log.Println("spark create message error:", err) } } func DouBaoV2(modelParam proto.ModelParam, imCtx *proto.IMParamContext) { doubaoClient := arkruntime.NewClientWithApiKey( modelParam.APIKey, arkruntime.WithBaseUrl(modelParam.Url), ) ctx := context.Background() messages := make([]*model.ChatCompletionMessage, 0) err := GetDouBaoSessionHistoryMsg(imCtx.SessionID, modelParam.System, &messages, imCtx.Question) if err != nil { log.Println("get doubao session history message error:", err) } //messagesStr, _ := json.Marshal(messages) //log.Println("doubao messages:", string(messagesStr)) req := model.CreateChatCompletionRequest{ Model: modelParam.Model, Messages: messages, StreamOptions: &model.StreamOptions{ IncludeUsage: true, }, } stream, err := doubaoClient.CreateChatCompletionStream(ctx, req) if err != nil { fmt.Printf("stream chat error: %v\n", err) return } defer func(stream *utils.ChatCompletionStreamReader) { err2 := stream.Close() if err2 != nil { log.Println("close stream error:", err2) } }(stream) answer := "" tokens := 0 for { recv, err3 := stream.Recv() if err3 == io.EOF { log.Println("doubao stream end:", err3) return } if err3 != nil { fmt.Printf("doubao Stream chat error: %v\n", err) return } go CreateAIStreamMsg(imCtx.UserID, imCtx.ModelID, imCtx.SessionID, recv, modelParam.Model) doubaoToGeneralMassageAndSendMsgQueue(&recv, imCtx.Channel, imCtx.SessionID, imCtx.UserID) choices := recv.Choices[0] if choices.FinishReason == proto.FinishReasonStop { answer += choices.Delta.Content tokens += recv.Usage.TotalTokens break } else { answer += choices.Delta.Content } } //将消息存入数据库 err, _ = dao.CreateMessage(imCtx.SessionID, imCtx.ModelID, imCtx.UserID, proto.ModelToUserMsgType, proto.MsgHasRead, answer, imCtx.FunctionID) if err != nil { log.Println("spark create message error:", err) } //将本次请求的tokens存入数据库 err2, _ := dao.CreateModelToken(uint(imCtx.UserID), uint(imCtx.ModelID), uint(imCtx.SessionID), uint(tokens)) if err2 != nil { log.Println("create model token error:", err2) } } func doubaoToGeneralMassageAndSendMsgQueue(data *model.ChatCompletionStreamResponse, channel string, SessionID, UserID int) { var aiMsg proto.AIQueueMessage var wsMsg proto.WSMessage var apiMsg api.GenerateResponse apiMsg.Model = proto.ModelTypeDouBao choices := data.Choices[0] apiMsg.Response = choices.Delta.Content if choices.FinishReason == proto.FinishReasonStop { apiMsg.Done = true } else { apiMsg.Done = false } //消息队列部分 aiMsg.Type = proto.ModelTypeDouBao aiMsg.Msg = apiMsg //ws发送消息部分 wsMsg.Msg = aiMsg wsMsg.SessionID = SessionID wsMsg.ToID = UserID wsMsg.Type = proto.ModelTypeDouBao //发送消息 wsMsgStr, _ := json.Marshal(wsMsg) worker.Publish(channel, string(wsMsgStr), time.Second*60) } func GetDouBaoSessionHistoryMsg(sessionID int, systemPrompt string, messages *[]*model.ChatCompletionMessage, prompt string) error { var err error //添加系统 if systemPrompt != "" { var message model.ChatCompletionMessage message.Role = model.ChatMessageRoleSystem message.Content = &model.ChatCompletionMessageContent{ StringValue: volcengine.String(systemPrompt), } *messages = append(*messages, &message) } //获取该会话的历史消息 msgs := dao.FindMessageBySessionIDV2(sessionID) if len(msgs) > proto.DouBaoContextLength { //只取最后的 SparkContextLength 条消息 msgs = msgs[len(msgs)-proto.DouBaoContextLength-1:] } for _, v := range msgs { var message model.ChatCompletionMessage if v.Type == proto.UserToModelMsgType { message.Role = model.ChatMessageRoleUser } else if v.Type == proto.ModelToUserMsgType { message.Role = model.ChatMessageRoleAssistant } else { continue } message.Content = &model.ChatCompletionMessageContent{ StringValue: volcengine.String(v.Msg), } *messages = append(*messages, &message) } //添加本次请求消息(本次消息已在上面添加) return err }