From 5e8be228cfedcca297c9e91ff4b835e1d2c6ae62 Mon Sep 17 00:00:00 2001 From: junleea <354425203@qq.com> Date: Wed, 26 Mar 2025 14:07:27 +0800 Subject: [PATCH] =?UTF-8?q?spark=E3=80=81=E8=B1=86=E5=8C=85=E6=A8=A1?= =?UTF-8?q?=E5=9E=8B=E6=B5=8B=E8=AF=95=E5=8E=86=E5=8F=B2=E6=B6=88=E6=81=AF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- proto/status.go | 7 +++++++ service/doubao.go | 3 ++- service/spark.go | 41 +++++++++++++++++++++++++++++++++++++++-- 3 files changed, 48 insertions(+), 3 deletions(-) diff --git a/proto/status.go b/proto/status.go index c87ef48..2567129 100644 --- a/proto/status.go +++ b/proto/status.go @@ -110,3 +110,10 @@ const ( FinishReasonContentFilter = "content_filter" FinishReasonNull = "null" ) + +// spark 角色 +const ( + SparkRoleUser = "user" + SparkRoleAssistant = "assistant" + SparRoleSystem = "system" +) diff --git a/service/doubao.go b/service/doubao.go index 5214aa5..b6802f1 100644 --- a/service/doubao.go +++ b/service/doubao.go @@ -28,7 +28,8 @@ func doubao(modelParam proto.ModelParam, question, channel string, SessionID, Us if err != nil { log.Println("get doubao session history message error:", err) } - log.Println("doubao messages:", messages) + messagesStr, _ := json.Marshal(messages) + log.Println("doubao messages:", string(messagesStr)) req := model.CreateChatCompletionRequest{ Model: modelParam.Model, diff --git a/service/spark.go b/service/spark.go index cc4862b..ec89783 100644 --- a/service/spark.go +++ b/service/spark.go @@ -124,7 +124,7 @@ func Spark(modelParam proto.ModelParam, question, channel string, SessionID, Use //发送消息 go func() { - data := genSparkParams(question, modelParam.APPID, modelParam.Domain) + data := genSparkParams(question, modelParam.APPID, modelParam.Domain, SessionID, "") //将数据转换为json //dataByte, err := json.Marshal(data) if err != nil { @@ -202,10 +202,17 @@ func SparkToGeneralMassageAndSendMsgQueue(data *SparkResponse, channel string, S } // 生成参数 -func genSparkParams(question, appid string, domain string) SparkRequest { +func genSparkParams(question, appid string, domain string, sessionID int, systemPrompt string) SparkRequest { messages := []MessageContent{ {Role: "user", Content: question}, } + messages2 := make([]MessageContent, 0) + err := GetSparkSessionHistoryMsg(sessionID, systemPrompt, &messages2, question) + if err != nil { + log.Println("get spark session history message error:", err) + } + messages2Str, _ := json.Marshal(messages2) + log.Println("spark messages2:", string(messages2Str)) data := SparkRequest{ Header: SparkRequestHeader{ @@ -276,3 +283,33 @@ func readResp(resp *http.Response) string { } return fmt.Sprintf("code=%d,body=%s", resp.StatusCode, string(b)) } + +// 获取spark模型会话的历史消息 +func GetSparkSessionHistoryMsg(sessionID int, systemPrompt string, messages *[]MessageContent, prompt string) error { + //获取会话的历史消息 + msgs := dao.FindMessageBySessionIDV2(sessionID) + //添加系统消息 + if systemPrompt != "" { + var message MessageContent + message.Role = "system" + message.Content = systemPrompt + *messages = append(*messages, message) + } + for _, msg := range msgs { + var message MessageContent + if msg.Type == proto.UserToModelMsgType { + message.Role = "user" + message.Content = msg.Msg + } else { + message.Role = "assistant" + message.Content = msg.Msg + } + *messages = append(*messages, message) + } + //添加本次请求消息 + var message MessageContent + message.Role = "user" + message.Content = prompt + *messages = append(*messages, message) + return nil +}