diff --git a/service/doubao.go b/service/doubao.go index b6802f1..68599b0 100644 --- a/service/doubao.go +++ b/service/doubao.go @@ -23,30 +23,17 @@ func doubao(modelParam proto.ModelParam, question, channel string, SessionID, Us arkruntime.WithBaseUrl(modelParam.Url), ) ctx := context.Background() - messages := make([]model.ChatCompletionMessage, 0) - err := GetDouBaoSessionHistoryMsg(SessionID, "", &messages, question) + messages := make([]*model.ChatCompletionMessage, 0) + err := GetDouBaoSessionHistoryMsg(SessionID, "你是豆包,是由字节跳动开发的 AI 人工智能助手", &messages, question) if err != nil { log.Println("get doubao session history message error:", err) } - messagesStr, _ := json.Marshal(messages) - log.Println("doubao messages:", string(messagesStr)) + //messagesStr, _ := json.Marshal(messages) + //log.Println("doubao messages:", string(messagesStr)) req := model.CreateChatCompletionRequest{ - Model: modelParam.Model, - Messages: []*model.ChatCompletionMessage{ - { - Role: model.ChatMessageRoleSystem, - Content: &model.ChatCompletionMessageContent{ - StringValue: volcengine.String("你是豆包,是由字节跳动开发的 AI 人工智能助手"), - }, - }, - { - Role: model.ChatMessageRoleUser, - Content: &model.ChatCompletionMessageContent{ - StringValue: volcengine.String(question), - }, - }, - }, + Model: modelParam.Model, + Messages: messages, } stream, err := doubaoClient.CreateChatCompletionStream(ctx, req) @@ -116,7 +103,7 @@ func doubaoToGeneralMassageAndSendMsgQueue(data *model.ChatCompletionStreamRespo worker.Publish(channel, string(wsMsgStr), time.Second*60) } -func GetDouBaoSessionHistoryMsg(sessionID int, systemPrompt string, messages *[]model.ChatCompletionMessage, prompt string) error { +func GetDouBaoSessionHistoryMsg(sessionID int, systemPrompt string, messages *[]*model.ChatCompletionMessage, prompt string) error { var err error //添加系统 if systemPrompt != "" { @@ -125,7 +112,7 @@ func GetDouBaoSessionHistoryMsg(sessionID int, systemPrompt string, messages *[] message.Content = &model.ChatCompletionMessageContent{ StringValue: volcengine.String(systemPrompt), } - *messages = append(*messages, message) + *messages = append(*messages, &message) } //获取该会话的历史消息 msgs := dao.FindMessageBySessionIDV2(sessionID) @@ -133,19 +120,16 @@ func GetDouBaoSessionHistoryMsg(sessionID int, systemPrompt string, messages *[] var message model.ChatCompletionMessage if v.Type == proto.UserToModelMsgType { message.Role = model.ChatMessageRoleUser - } else { + } else if v.Type == proto.ModelToUserMsgType { message.Role = model.ChatMessageRoleAssistant + } else { + continue } message.Content = &model.ChatCompletionMessageContent{ StringValue: volcengine.String(v.Msg), } - *messages = append(*messages, message) - } - //添加本次请求消息 - var message model.ChatCompletionMessage - message.Role = model.ChatMessageRoleUser - message.Content = &model.ChatCompletionMessageContent{ - StringValue: volcengine.String(prompt), + *messages = append(*messages, &message) } + //添加本次请求消息(本次消息已在上面添加) return err } diff --git a/service/spark.go b/service/spark.go index ec89783..caeff96 100644 --- a/service/spark.go +++ b/service/spark.go @@ -203,16 +203,16 @@ func SparkToGeneralMassageAndSendMsgQueue(data *SparkResponse, channel string, S // 生成参数 func genSparkParams(question, appid string, domain string, sessionID int, systemPrompt string) SparkRequest { - messages := []MessageContent{ - {Role: "user", Content: question}, - } + //messages := []MessageContent{ + // {Role: "user", Content: question}, + //} messages2 := make([]MessageContent, 0) err := GetSparkSessionHistoryMsg(sessionID, systemPrompt, &messages2, question) if err != nil { log.Println("get spark session history message error:", err) } - messages2Str, _ := json.Marshal(messages2) - log.Println("spark messages2:", string(messages2Str)) + //messages2Str, _ := json.Marshal(messages2) + //log.Println("spark messages2:", string(messages2Str)) data := SparkRequest{ Header: SparkRequestHeader{ @@ -227,7 +227,7 @@ func genSparkParams(question, appid string, domain string, sessionID int, system }, Payload: SparkRequestPayload{ Message: SparkMessage{ - Text: messages, + Text: messages2, }, }, } @@ -298,18 +298,19 @@ func GetSparkSessionHistoryMsg(sessionID int, systemPrompt string, messages *[]M for _, msg := range msgs { var message MessageContent if msg.Type == proto.UserToModelMsgType { - message.Role = "user" - message.Content = msg.Msg + message.Role = proto.SparkRoleUser + } else if msg.Type == proto.ModelToUserMsgType { + message.Role = proto.SparkRoleAssistant } else { - message.Role = "assistant" - message.Content = msg.Msg + continue } + message.Content = msg.Msg *messages = append(*messages, message) } - //添加本次请求消息 - var message MessageContent - message.Role = "user" - message.Content = prompt - *messages = append(*messages, message) + //添加本次请求消息(上面获取到的已包含本次请求消息) + //var message MessageContent + //message.Role = "user" + //message.Content = prompt + //*messages = append(*messages, message) return nil }