spark、豆包模型完成历史消息
This commit is contained in:
parent
5e8be228cf
commit
3433a964b0
|
|
@ -23,30 +23,17 @@ func doubao(modelParam proto.ModelParam, question, channel string, SessionID, Us
|
||||||
arkruntime.WithBaseUrl(modelParam.Url),
|
arkruntime.WithBaseUrl(modelParam.Url),
|
||||||
)
|
)
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
messages := make([]model.ChatCompletionMessage, 0)
|
messages := make([]*model.ChatCompletionMessage, 0)
|
||||||
err := GetDouBaoSessionHistoryMsg(SessionID, "", &messages, question)
|
err := GetDouBaoSessionHistoryMsg(SessionID, "你是豆包,是由字节跳动开发的 AI 人工智能助手", &messages, question)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Println("get doubao session history message error:", err)
|
log.Println("get doubao session history message error:", err)
|
||||||
}
|
}
|
||||||
messagesStr, _ := json.Marshal(messages)
|
//messagesStr, _ := json.Marshal(messages)
|
||||||
log.Println("doubao messages:", string(messagesStr))
|
//log.Println("doubao messages:", string(messagesStr))
|
||||||
|
|
||||||
req := model.CreateChatCompletionRequest{
|
req := model.CreateChatCompletionRequest{
|
||||||
Model: modelParam.Model,
|
Model: modelParam.Model,
|
||||||
Messages: []*model.ChatCompletionMessage{
|
Messages: messages,
|
||||||
{
|
|
||||||
Role: model.ChatMessageRoleSystem,
|
|
||||||
Content: &model.ChatCompletionMessageContent{
|
|
||||||
StringValue: volcengine.String("你是豆包,是由字节跳动开发的 AI 人工智能助手"),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Role: model.ChatMessageRoleUser,
|
|
||||||
Content: &model.ChatCompletionMessageContent{
|
|
||||||
StringValue: volcengine.String(question),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
stream, err := doubaoClient.CreateChatCompletionStream(ctx, req)
|
stream, err := doubaoClient.CreateChatCompletionStream(ctx, req)
|
||||||
|
|
@ -116,7 +103,7 @@ func doubaoToGeneralMassageAndSendMsgQueue(data *model.ChatCompletionStreamRespo
|
||||||
worker.Publish(channel, string(wsMsgStr), time.Second*60)
|
worker.Publish(channel, string(wsMsgStr), time.Second*60)
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetDouBaoSessionHistoryMsg(sessionID int, systemPrompt string, messages *[]model.ChatCompletionMessage, prompt string) error {
|
func GetDouBaoSessionHistoryMsg(sessionID int, systemPrompt string, messages *[]*model.ChatCompletionMessage, prompt string) error {
|
||||||
var err error
|
var err error
|
||||||
//添加系统
|
//添加系统
|
||||||
if systemPrompt != "" {
|
if systemPrompt != "" {
|
||||||
|
|
@ -125,7 +112,7 @@ func GetDouBaoSessionHistoryMsg(sessionID int, systemPrompt string, messages *[]
|
||||||
message.Content = &model.ChatCompletionMessageContent{
|
message.Content = &model.ChatCompletionMessageContent{
|
||||||
StringValue: volcengine.String(systemPrompt),
|
StringValue: volcengine.String(systemPrompt),
|
||||||
}
|
}
|
||||||
*messages = append(*messages, message)
|
*messages = append(*messages, &message)
|
||||||
}
|
}
|
||||||
//获取该会话的历史消息
|
//获取该会话的历史消息
|
||||||
msgs := dao.FindMessageBySessionIDV2(sessionID)
|
msgs := dao.FindMessageBySessionIDV2(sessionID)
|
||||||
|
|
@ -133,19 +120,16 @@ func GetDouBaoSessionHistoryMsg(sessionID int, systemPrompt string, messages *[]
|
||||||
var message model.ChatCompletionMessage
|
var message model.ChatCompletionMessage
|
||||||
if v.Type == proto.UserToModelMsgType {
|
if v.Type == proto.UserToModelMsgType {
|
||||||
message.Role = model.ChatMessageRoleUser
|
message.Role = model.ChatMessageRoleUser
|
||||||
} else {
|
} else if v.Type == proto.ModelToUserMsgType {
|
||||||
message.Role = model.ChatMessageRoleAssistant
|
message.Role = model.ChatMessageRoleAssistant
|
||||||
|
} else {
|
||||||
|
continue
|
||||||
}
|
}
|
||||||
message.Content = &model.ChatCompletionMessageContent{
|
message.Content = &model.ChatCompletionMessageContent{
|
||||||
StringValue: volcengine.String(v.Msg),
|
StringValue: volcengine.String(v.Msg),
|
||||||
}
|
}
|
||||||
*messages = append(*messages, message)
|
*messages = append(*messages, &message)
|
||||||
}
|
|
||||||
//添加本次请求消息
|
|
||||||
var message model.ChatCompletionMessage
|
|
||||||
message.Role = model.ChatMessageRoleUser
|
|
||||||
message.Content = &model.ChatCompletionMessageContent{
|
|
||||||
StringValue: volcengine.String(prompt),
|
|
||||||
}
|
}
|
||||||
|
//添加本次请求消息(本次消息已在上面添加)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -203,16 +203,16 @@ func SparkToGeneralMassageAndSendMsgQueue(data *SparkResponse, channel string, S
|
||||||
|
|
||||||
// 生成参数
|
// 生成参数
|
||||||
func genSparkParams(question, appid string, domain string, sessionID int, systemPrompt string) SparkRequest {
|
func genSparkParams(question, appid string, domain string, sessionID int, systemPrompt string) SparkRequest {
|
||||||
messages := []MessageContent{
|
//messages := []MessageContent{
|
||||||
{Role: "user", Content: question},
|
// {Role: "user", Content: question},
|
||||||
}
|
//}
|
||||||
messages2 := make([]MessageContent, 0)
|
messages2 := make([]MessageContent, 0)
|
||||||
err := GetSparkSessionHistoryMsg(sessionID, systemPrompt, &messages2, question)
|
err := GetSparkSessionHistoryMsg(sessionID, systemPrompt, &messages2, question)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Println("get spark session history message error:", err)
|
log.Println("get spark session history message error:", err)
|
||||||
}
|
}
|
||||||
messages2Str, _ := json.Marshal(messages2)
|
//messages2Str, _ := json.Marshal(messages2)
|
||||||
log.Println("spark messages2:", string(messages2Str))
|
//log.Println("spark messages2:", string(messages2Str))
|
||||||
|
|
||||||
data := SparkRequest{
|
data := SparkRequest{
|
||||||
Header: SparkRequestHeader{
|
Header: SparkRequestHeader{
|
||||||
|
|
@ -227,7 +227,7 @@ func genSparkParams(question, appid string, domain string, sessionID int, system
|
||||||
},
|
},
|
||||||
Payload: SparkRequestPayload{
|
Payload: SparkRequestPayload{
|
||||||
Message: SparkMessage{
|
Message: SparkMessage{
|
||||||
Text: messages,
|
Text: messages2,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
@ -298,18 +298,19 @@ func GetSparkSessionHistoryMsg(sessionID int, systemPrompt string, messages *[]M
|
||||||
for _, msg := range msgs {
|
for _, msg := range msgs {
|
||||||
var message MessageContent
|
var message MessageContent
|
||||||
if msg.Type == proto.UserToModelMsgType {
|
if msg.Type == proto.UserToModelMsgType {
|
||||||
message.Role = "user"
|
message.Role = proto.SparkRoleUser
|
||||||
message.Content = msg.Msg
|
} else if msg.Type == proto.ModelToUserMsgType {
|
||||||
|
message.Role = proto.SparkRoleAssistant
|
||||||
} else {
|
} else {
|
||||||
message.Role = "assistant"
|
continue
|
||||||
message.Content = msg.Msg
|
|
||||||
}
|
}
|
||||||
|
message.Content = msg.Msg
|
||||||
*messages = append(*messages, message)
|
*messages = append(*messages, message)
|
||||||
}
|
}
|
||||||
//添加本次请求消息
|
//添加本次请求消息(上面获取到的已包含本次请求消息)
|
||||||
var message MessageContent
|
//var message MessageContent
|
||||||
message.Role = "user"
|
//message.Role = "user"
|
||||||
message.Content = prompt
|
//message.Content = prompt
|
||||||
*messages = append(*messages, message)
|
//*messages = append(*messages, message)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue