diff --git a/proto/status.go b/proto/status.go index c87ef48..2567129 100644 --- a/proto/status.go +++ b/proto/status.go @@ -110,3 +110,10 @@ const ( FinishReasonContentFilter = "content_filter" FinishReasonNull = "null" ) + +// spark 角色 +const ( + SparkRoleUser = "user" + SparkRoleAssistant = "assistant" + SparRoleSystem = "system" +) diff --git a/service/doubao.go b/service/doubao.go index 5214aa5..b6802f1 100644 --- a/service/doubao.go +++ b/service/doubao.go @@ -28,7 +28,8 @@ func doubao(modelParam proto.ModelParam, question, channel string, SessionID, Us if err != nil { log.Println("get doubao session history message error:", err) } - log.Println("doubao messages:", messages) + messagesStr, _ := json.Marshal(messages) + log.Println("doubao messages:", string(messagesStr)) req := model.CreateChatCompletionRequest{ Model: modelParam.Model, diff --git a/service/spark.go b/service/spark.go index cc4862b..ec89783 100644 --- a/service/spark.go +++ b/service/spark.go @@ -124,7 +124,7 @@ func Spark(modelParam proto.ModelParam, question, channel string, SessionID, Use //发送消息 go func() { - data := genSparkParams(question, modelParam.APPID, modelParam.Domain) + data := genSparkParams(question, modelParam.APPID, modelParam.Domain, SessionID, "") //将数据转换为json //dataByte, err := json.Marshal(data) if err != nil { @@ -202,10 +202,17 @@ func SparkToGeneralMassageAndSendMsgQueue(data *SparkResponse, channel string, S } // 生成参数 -func genSparkParams(question, appid string, domain string) SparkRequest { +func genSparkParams(question, appid string, domain string, sessionID int, systemPrompt string) SparkRequest { messages := []MessageContent{ {Role: "user", Content: question}, } + messages2 := make([]MessageContent, 0) + err := GetSparkSessionHistoryMsg(sessionID, systemPrompt, &messages2, question) + if err != nil { + log.Println("get spark session history message error:", err) + } + messages2Str, _ := json.Marshal(messages2) + log.Println("spark messages2:", string(messages2Str)) data := SparkRequest{ Header: SparkRequestHeader{ @@ -276,3 +283,33 @@ func readResp(resp *http.Response) string { } return fmt.Sprintf("code=%d,body=%s", resp.StatusCode, string(b)) } + +// 获取spark模型会话的历史消息 +func GetSparkSessionHistoryMsg(sessionID int, systemPrompt string, messages *[]MessageContent, prompt string) error { + //获取会话的历史消息 + msgs := dao.FindMessageBySessionIDV2(sessionID) + //添加系统消息 + if systemPrompt != "" { + var message MessageContent + message.Role = "system" + message.Content = systemPrompt + *messages = append(*messages, message) + } + for _, msg := range msgs { + var message MessageContent + if msg.Type == proto.UserToModelMsgType { + message.Role = "user" + message.Content = msg.Msg + } else { + message.Role = "assistant" + message.Content = msg.Msg + } + *messages = append(*messages, message) + } + //添加本次请求消息 + var message MessageContent + message.Role = "user" + message.Content = prompt + *messages = append(*messages, message) + return nil +}