saw-go/service/doubao.go

350 lines
11 KiB
Go

package service
import (
"StuAcaWorksAI/dao"
"StuAcaWorksAI/proto"
"StuAcaWorksAI/worker"
"context"
"encoding/json"
"fmt"
"github.com/ollama/ollama/api"
"github.com/volcengine/volcengine-go-sdk/service/arkruntime"
"github.com/volcengine/volcengine-go-sdk/service/arkruntime/model"
"github.com/volcengine/volcengine-go-sdk/service/arkruntime/utils"
"github.com/volcengine/volcengine-go-sdk/volcengine"
"io"
"log"
"time"
)
// doubao streams a chat completion from the DouBao (Volcengine Ark) endpoint
// configured in modelParam and relays every content chunk to channel via
// doubaoToGeneralMassageAndSendMsgQueue. Each received chunk is also recorded
// with CreateAIStreamMsg in its own goroutine. When the stream finishes
// (FinishReasonStop or io.EOF), the assembled answer is stored with
// dao.CreateMessage.
//
// The chat context comes from GetDouBaoSessionHistoryMsg; question is passed
// through to it. Errors are logged rather than returned. If the stream fails
// mid-way, nothing is persisted.
func doubao(modelParam proto.ModelParam, question, channel string, SessionID, UserID int, modelID int) {
	doubaoClient := arkruntime.NewClientWithApiKey(
		modelParam.APIKey,
		arkruntime.WithBaseUrl(modelParam.Url),
	)
	ctx := context.Background()
	messages := make([]*model.ChatCompletionMessage, 0)
	err := GetDouBaoSessionHistoryMsg(SessionID, modelParam.System, &messages, question)
	if err != nil {
		log.Println("get doubao session history message error:", err)
	}
	req := model.CreateChatCompletionRequest{
		Model:    modelParam.Model,
		Messages: messages,
		StreamOptions: &model.StreamOptions{
			IncludeUsage: true,
		},
	}
	stream, err := doubaoClient.CreateChatCompletionStream(ctx, req)
	if err != nil {
		fmt.Printf("stream chat error: %v\n", err)
		return
	}
	defer func(stream *utils.ChatCompletionStreamReader) {
		if closeErr := stream.Close(); closeErr != nil {
			log.Println("close stream error:", closeErr)
		}
	}(stream)

	answer := ""
	for {
		recv, recvErr := stream.Recv()
		if recvErr == io.EOF {
			// Normal end of stream: leave the loop so the answer is still saved.
			log.Println("doubao stream end:", recvErr)
			break
		}
		if recvErr != nil {
			fmt.Printf("doubao Stream chat error: %v\n", recvErr)
			return
		}
		go CreateAIStreamMsg(UserID, modelID, SessionID, recv, modelParam.Model)
		// With IncludeUsage enabled the service may send a chunk that carries
		// only usage data and no choices; indexing Choices[0] would panic.
		if len(recv.Choices) == 0 {
			continue
		}
		doubaoToGeneralMassageAndSendMsgQueue(&recv, channel, SessionID, UserID)
		choice := recv.Choices[0]
		answer += choice.Delta.Content
		if choice.FinishReason == proto.FinishReasonStop {
			break
		}
	}
	// Persist the assistant reply. The trailing 2 is presumably a function ID
	// (DouBaoV2 passes imCtx.FunctionID in this position) — confirm.
	err, _ = dao.CreateMessage(SessionID, modelID, UserID, proto.ModelToUserMsgType, proto.MsgHasRead, answer, 2)
	if err != nil {
		log.Println("doubao create message error:", err)
	}
}
// DouBaoV2 streams a DouBao (Volcengine Ark) chat completion for the request
// described by imCtx. It relays each content chunk to imCtx.Channel and
// records every chunk with CreateAIStreamMsg. Once the stream ends, it stores
// the assembled answer (dao.CreateMessage) and the reported token usage
// (dao.CreateModelToken).
//
// The request enables StreamOptions.IncludeUsage, so usage can arrive in a
// chunk that has no choices, after the finish chunk. The loop therefore reads
// until the stream reports io.EOF instead of stopping at FinishReasonStop.
// Otherwise the usage would be missed and 0 tokens recorded. Errors are
// logged rather than returned. If the stream fails mid-way, nothing is
// persisted.
func DouBaoV2(modelParam proto.ModelParam, imCtx *proto.IMParamContext) {
	doubaoClient := arkruntime.NewClientWithApiKey(
		modelParam.APIKey,
		arkruntime.WithBaseUrl(modelParam.Url),
	)
	ctx := context.Background()
	messages := make([]*model.ChatCompletionMessage, 0)
	err := GetDouBaoSessionHistoryMsg(imCtx.SessionID, modelParam.System, &messages, imCtx.Question)
	if err != nil {
		log.Println("get doubao session history message error:", err)
	}
	req := model.CreateChatCompletionRequest{
		Model:    modelParam.Model,
		Messages: messages,
		StreamOptions: &model.StreamOptions{
			IncludeUsage: true,
		},
	}
	// Debug dump of the outgoing request; a marshal failure only affects this log line.
	messagesStr, _ := json.Marshal(req)
	log.Println("doubao req messages:", string(messagesStr))
	stream, err := doubaoClient.CreateChatCompletionStream(ctx, req)
	if err != nil {
		fmt.Printf("stream chat error: %v\n", err)
		return
	}
	defer func(stream *utils.ChatCompletionStreamReader) {
		if closeErr := stream.Close(); closeErr != nil {
			log.Println("close stream error:", closeErr)
		}
	}(stream)

	answer := ""
	tokens := 0
	for {
		recv, recvErr := stream.Recv()
		if recvErr == io.EOF {
			// Normal end of stream: leave the loop so answer and tokens are saved.
			log.Println("doubao stream end:", recvErr)
			break
		}
		if recvErr != nil {
			fmt.Printf("doubao Stream chat error: %v\n", recvErr)
			return
		}
		aiStreamID, createErr := CreateAIStreamMsg(imCtx.UserID, imCtx.ModelID, imCtx.SessionID, recv, modelParam.Model)
		if createErr != nil {
			log.Println("create ai stream message error:", createErr)
		}
		// Usage describes the whole request, so keep the latest value rather
		// than summing (summing would double count if it is repeated).
		if recv.Usage != nil {
			tokens = recv.Usage.TotalTokens
		}
		// Usage-only chunks have no choices; indexing them would panic.
		if len(recv.Choices) == 0 {
			log.Println("doubao stream recv choices is nil:", recv, "\t aiStream_id:", aiStreamID)
			continue
		}
		doubaoToGeneralMassageAndSendMsgQueue(&recv, imCtx.Channel, imCtx.SessionID, imCtx.UserID)
		answer += recv.Choices[0].Delta.Content
	}
	// Persist the assistant reply.
	err, _ = dao.CreateMessage(imCtx.SessionID, imCtx.ModelID, imCtx.UserID, proto.ModelToUserMsgType, proto.MsgHasRead, answer, imCtx.FunctionID)
	if err != nil {
		log.Println("doubao create message error:", err)
	}
	// Persist the token usage reported for this request.
	err2, _ := dao.CreateModelToken(uint(imCtx.UserID), uint(imCtx.ModelID), uint(imCtx.SessionID), uint(tokens))
	if err2 != nil {
		log.Println("create model token error:", err2)
	}
}
// doubaoToGeneralMassageAndSendMsgQueue converts one DouBao stream chunk into
// an api.GenerateResponse and wraps it in a proto.AIQueueMessage inside a
// proto.WSMessage addressed to UserID in SessionID. It publishes the JSON
// encoding to channel via worker.Publish, passing a 60-second duration
// through.
//
// Only the first choice is forwarded. Done is true when its finish reason is
// proto.FinishReasonStop. A nil chunk or a chunk without choices (such as a
// usage-only chunk) is ignored instead of panicking. A JSON encoding failure
// is logged and nothing is published.
func doubaoToGeneralMassageAndSendMsgQueue(data *model.ChatCompletionStreamResponse, channel string, SessionID, UserID int) {
	if data == nil || len(data.Choices) == 0 {
		return
	}
	choice := data.Choices[0]

	var apiMsg api.GenerateResponse
	apiMsg.Model = proto.ModelTypeDouBao
	apiMsg.Response = choice.Delta.Content
	apiMsg.Done = choice.FinishReason == proto.FinishReasonStop

	// Message-queue payload.
	var aiMsg proto.AIQueueMessage
	aiMsg.Type = proto.ModelTypeDouBao
	aiMsg.Msg = apiMsg

	// WebSocket envelope.
	var wsMsg proto.WSMessage
	wsMsg.Msg = aiMsg
	wsMsg.SessionID = SessionID
	wsMsg.ToID = UserID
	wsMsg.Type = proto.ModelTypeDouBao

	wsMsgStr, err := json.Marshal(wsMsg)
	if err != nil {
		log.Println("doubao marshal ws message error:", err)
		return
	}
	worker.Publish(channel, string(wsMsgStr), time.Second*60)
}
// GetDouBaoSessionHistoryMsg appends the DouBao chat context of sessionID to
// *messages. The context is an optional system prompt (when systemPrompt is
// non-empty) followed by the session's stored user and assistant messages,
// converted to Ark ChatCompletionMessage values. Messages of any other Type
// are skipped.
//
// The current user message is expected to be stored in the session already,
// so it is picked up from the history. prompt is not used.
//
// A user message whose Status is proto.UserToModelImageMsgType is decoded as
// JSON into proto.UserImageMsg. It becomes one text part plus one image_url
// part per image.
//
// It returns the JSON decode error of such a message. In that case *messages
// may already hold the entries appended before the failing one.
func GetDouBaoSessionHistoryMsg(sessionID int, systemPrompt string, messages *[]*model.ChatCompletionMessage, prompt string) error {
	if systemPrompt != "" {
		*messages = append(*messages, &model.ChatCompletionMessage{
			Role: model.ChatMessageRoleSystem,
			Content: &model.ChatCompletionMessageContent{
				StringValue: volcengine.String(systemPrompt),
			},
		})
	}

	msgs := dao.FindMessageBySessionIDV2(sessionID)
	if len(msgs) > proto.DouBaoContextLength {
		// Keep only the most recent messages.
		// NOTE(review): this keeps DouBaoContextLength+1 entries, presumably
		// the limit plus the current message — confirm this is intended.
		msgs = msgs[len(msgs)-proto.DouBaoContextLength-1:]
	}

	for _, v := range msgs {
		var message model.ChatCompletionMessage
		switch v.Type {
		case proto.UserToModelMsgType:
			message.Role = model.ChatMessageRoleUser
			if v.Status == proto.UserToModelImageMsgType {
				// User message with images: a text part plus one part per image URL.
				var imgMsg proto.UserImageMsg
				if err := json.Unmarshal([]byte(v.Msg), &imgMsg); err != nil {
					log.Println("unmarshal user image message error:", err)
					return err
				}
				log.Println("user image message:", imgMsg)
				listValue := make([]*model.ChatCompletionMessageContentPart, 0, len(imgMsg.ImageContent)+1)
				listValue = append(listValue, &model.ChatCompletionMessageContentPart{
					Type: model.ChatCompletionMessageContentPartTypeText,
					Text: imgMsg.Text,
				})
				for _, img := range imgMsg.ImageContent {
					listValue = append(listValue, &model.ChatCompletionMessageContentPart{
						Type: model.ChatCompletionMessageContentPartTypeImageURL,
						ImageURL: &model.ChatMessageImageURL{
							URL: img.ImgUrl,
						},
					})
				}
				message.Content = &model.ChatCompletionMessageContent{
					ListValue: listValue,
				}
			} else {
				// Plain-text user message.
				message.Content = &model.ChatCompletionMessageContent{
					StringValue: volcengine.String(v.Msg),
				}
			}
		case proto.ModelToUserMsgType:
			message.Role = model.ChatMessageRoleAssistant
			message.Content = &model.ChatCompletionMessageContent{
				StringValue: volcengine.String(v.Msg),
			}
		default:
			// Other message types are not part of the chat context.
			continue
		}
		*messages = append(*messages, &message)
	}
	return nil
}
// GetDouBaoSessionHistoryMsgV2 appends the DouBao chat context of sessionID to
// *messages. The context is an optional system prompt (when systemPrompt is
// non-empty) followed by the session's stored user and assistant messages,
// converted to Ark ChatCompletionMessage values. Messages of any other Type
// are skipped.
//
// The current user message is expected to be stored in the session already,
// so it is picked up from the history.
//
// A user message whose Status is proto.UserToModelFileMsgType is converted
// with decodeFileMessageToDouBaoMessage, resolving file contents for userID.
//
// It returns the first decode error. In that case *messages may already hold
// the entries appended before the failing one.
func GetDouBaoSessionHistoryMsgV2(sessionID int, systemPrompt string, messages *[]*model.ChatCompletionMessage, userID int) error {
	if systemPrompt != "" {
		*messages = append(*messages, &model.ChatCompletionMessage{
			Role: model.ChatMessageRoleSystem,
			Content: &model.ChatCompletionMessageContent{
				StringValue: volcengine.String(systemPrompt),
			},
		})
	}

	msgs := dao.FindMessageBySessionIDV2(sessionID)
	if len(msgs) > proto.DouBaoContextLength {
		// Keep only the most recent messages.
		// NOTE(review): this keeps DouBaoContextLength+1 entries, presumably
		// the limit plus the current message — confirm this is intended.
		msgs = msgs[len(msgs)-proto.DouBaoContextLength-1:]
	}

	for _, v := range msgs {
		var message model.ChatCompletionMessage
		switch v.Type {
		case proto.UserToModelMsgType:
			message.Role = model.ChatMessageRoleUser
			if v.Status == proto.UserToModelFileMsgType {
				// User message with attached files (text and/or images).
				listValue, err := decodeFileMessageToDouBaoMessage(v.Msg, userID)
				if err != nil {
					log.Println("doubao decode user file message error:", err)
					return err
				}
				message.Content = &model.ChatCompletionMessageContent{
					ListValue: listValue,
				}
			} else {
				// Plain-text user message.
				message.Content = &model.ChatCompletionMessageContent{
					StringValue: volcengine.String(v.Msg),
				}
			}
		case proto.ModelToUserMsgType:
			message.Role = model.ChatMessageRoleAssistant
			message.Content = &model.ChatCompletionMessageContent{
				StringValue: volcengine.String(v.Msg),
			}
		default:
			// Other message types are not part of the chat context.
			continue
		}
		*messages = append(*messages, &message)
	}
	return nil
}
// decodeFileMessageToDouBaoMessage decodes a JSON-encoded proto.UserFileMessage
// into Ark content parts. The parts are:
//   - for each text file, a text part containing the file name and its stored
//     content (looked up with FindFileContentV2 for userID);
//   - for each image file, an image_url part pointing at
//     proto.DownloadFileBaseURL + the stored file name;
//   - finally, a text part with the message's own text.
//
// Files of other types are logged and skipped.
//
// It returns an error if the message is not valid JSON, if a text file's
// content cannot be looked up, or if the lookup returns no content. Without
// that last check, indexing the empty result would panic.
func decodeFileMessageToDouBaoMessage(msg string, userID int) ([]*model.ChatCompletionMessageContentPart, error) {
	var fileMsgContent proto.UserFileMessage
	if err := json.Unmarshal([]byte(msg), &fileMsgContent); err != nil {
		log.Println("DouBao unmarshal file message error:", err)
		return nil, err
	}

	// One part per file plus the trailing text part.
	listValue := make([]*model.ChatCompletionMessageContentPart, 0, len(fileMsgContent.FileContent)+1)
	for _, file := range fileMsgContent.FileContent {
		switch file.FileType {
		case proto.FileTypeText:
			fileContents, err2 := FindFileContentV2(file.FileContent.FileID, userID)
			if err2 != nil {
				log.Println("doubao find file content error:", err2)
				return nil, err2
			}
			if len(fileContents) == 0 {
				emptyErr := fmt.Errorf("doubao: no content found for file %v (id %v)", file.FileContent.UserFileName, file.FileContent.FileID)
				log.Println(emptyErr)
				return nil, emptyErr
			}
			listValue = append(listValue, &model.ChatCompletionMessageContentPart{
				Type: model.ChatCompletionMessageContentPartTypeText,
				Text: "文件" + file.FileContent.UserFileName + ",内容:" + fileContents[0].FileContent,
			})
		case proto.FileTypeImage:
			listValue = append(listValue, &model.ChatCompletionMessageContentPart{
				Type: model.ChatCompletionMessageContentPartTypeImageURL,
				ImageURL: &model.ChatMessageImageURL{
					URL: proto.DownloadFileBaseURL + file.FileContent.FileStoreName,
				},
			})
		default:
			log.Println("file type is not support:", file.FileType, "file name:", file.FileContent.UserFileName, "file id:", file.FileContent.FileID)
		}
	}

	listValue = append(listValue, &model.ChatCompletionMessageContentPart{
		Type: model.ChatCompletionMessageContentPartTypeText,
		Text: fileMsgContent.Text,
	})
	return listValue, nil
}