spark、豆包模型测试历史消息
This commit is contained in:
parent
370caca0ae
commit
5e8be228cf
|
|
@ -110,3 +110,10 @@ const (
|
|||
FinishReasonContentFilter = "content_filter"
|
||||
FinishReasonNull = "null"
|
||||
)
|
||||
|
||||
// spark 角色
|
||||
const (
|
||||
SparkRoleUser = "user"
|
||||
SparkRoleAssistant = "assistant"
|
||||
SparRoleSystem = "system"
|
||||
)
|
||||
|
|
|
|||
|
|
@ -28,7 +28,8 @@ func doubao(modelParam proto.ModelParam, question, channel string, SessionID, Us
|
|||
if err != nil {
|
||||
log.Println("get doubao session history message error:", err)
|
||||
}
|
||||
log.Println("doubao messages:", messages)
|
||||
messagesStr, _ := json.Marshal(messages)
|
||||
log.Println("doubao messages:", string(messagesStr))
|
||||
|
||||
req := model.CreateChatCompletionRequest{
|
||||
Model: modelParam.Model,
|
||||
|
|
|
|||
|
|
@ -124,7 +124,7 @@ func Spark(modelParam proto.ModelParam, question, channel string, SessionID, Use
|
|||
|
||||
//发送消息
|
||||
go func() {
|
||||
data := genSparkParams(question, modelParam.APPID, modelParam.Domain)
|
||||
data := genSparkParams(question, modelParam.APPID, modelParam.Domain, SessionID, "")
|
||||
//将数据转换为json
|
||||
//dataByte, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
|
|
@ -202,10 +202,17 @@ func SparkToGeneralMassageAndSendMsgQueue(data *SparkResponse, channel string, S
|
|||
}
|
||||
|
||||
// 生成参数
|
||||
func genSparkParams(question, appid string, domain string) SparkRequest {
|
||||
func genSparkParams(question, appid string, domain string, sessionID int, systemPrompt string) SparkRequest {
|
||||
messages := []MessageContent{
|
||||
{Role: "user", Content: question},
|
||||
}
|
||||
messages2 := make([]MessageContent, 0)
|
||||
err := GetSparkSessionHistoryMsg(sessionID, systemPrompt, &messages2, question)
|
||||
if err != nil {
|
||||
log.Println("get spark session history message error:", err)
|
||||
}
|
||||
messages2Str, _ := json.Marshal(messages2)
|
||||
log.Println("spark messages2:", string(messages2Str))
|
||||
|
||||
data := SparkRequest{
|
||||
Header: SparkRequestHeader{
|
||||
|
|
@ -276,3 +283,33 @@ func readResp(resp *http.Response) string {
|
|||
}
|
||||
return fmt.Sprintf("code=%d,body=%s", resp.StatusCode, string(b))
|
||||
}
|
||||
|
||||
// 获取spark模型会话的历史消息
|
||||
func GetSparkSessionHistoryMsg(sessionID int, systemPrompt string, messages *[]MessageContent, prompt string) error {
|
||||
//获取会话的历史消息
|
||||
msgs := dao.FindMessageBySessionIDV2(sessionID)
|
||||
//添加系统消息
|
||||
if systemPrompt != "" {
|
||||
var message MessageContent
|
||||
message.Role = "system"
|
||||
message.Content = systemPrompt
|
||||
*messages = append(*messages, message)
|
||||
}
|
||||
for _, msg := range msgs {
|
||||
var message MessageContent
|
||||
if msg.Type == proto.UserToModelMsgType {
|
||||
message.Role = "user"
|
||||
message.Content = msg.Msg
|
||||
} else {
|
||||
message.Role = "assistant"
|
||||
message.Content = msg.Msg
|
||||
}
|
||||
*messages = append(*messages, message)
|
||||
}
|
||||
//添加本次请求消息
|
||||
var message MessageContent
|
||||
message.Role = "user"
|
||||
message.Content = prompt
|
||||
*messages = append(*messages, message)
|
||||
return nil
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue