256 lines
7.5 KiB
Go
256 lines
7.5 KiB
Go
package service
|
|
|
|
import (
|
|
"StuAcaWorksAI/dao"
|
|
"StuAcaWorksAI/proto"
|
|
"StuAcaWorksAI/worker"
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"github.com/ollama/ollama/api"
|
|
"github.com/volcengine/volcengine-go-sdk/service/arkruntime"
|
|
"github.com/volcengine/volcengine-go-sdk/service/arkruntime/model"
|
|
"github.com/volcengine/volcengine-go-sdk/service/arkruntime/utils"
|
|
"github.com/volcengine/volcengine-go-sdk/volcengine"
|
|
"io"
|
|
"log"
|
|
"time"
|
|
)
|
|
|
|
func doubao(modelParam proto.ModelParam, question, channel string, SessionID, UserID int, modelID int) {
|
|
doubaoClient := arkruntime.NewClientWithApiKey(
|
|
modelParam.APIKey,
|
|
arkruntime.WithBaseUrl(modelParam.Url),
|
|
)
|
|
ctx := context.Background()
|
|
messages := make([]*model.ChatCompletionMessage, 0)
|
|
err := GetDouBaoSessionHistoryMsg(SessionID, modelParam.System, &messages, question)
|
|
if err != nil {
|
|
log.Println("get doubao session history message error:", err)
|
|
}
|
|
//messagesStr, _ := json.Marshal(messages)
|
|
//log.Println("doubao messages:", string(messagesStr))
|
|
|
|
req := model.CreateChatCompletionRequest{
|
|
Model: modelParam.Model,
|
|
Messages: messages,
|
|
StreamOptions: &model.StreamOptions{
|
|
IncludeUsage: true,
|
|
},
|
|
}
|
|
|
|
stream, err := doubaoClient.CreateChatCompletionStream(ctx, req)
|
|
if err != nil {
|
|
fmt.Printf("stream chat error: %v\n", err)
|
|
return
|
|
}
|
|
defer func(stream *utils.ChatCompletionStreamReader) {
|
|
err2 := stream.Close()
|
|
if err2 != nil {
|
|
log.Println("close stream error:", err2)
|
|
}
|
|
}(stream)
|
|
|
|
answer := ""
|
|
for {
|
|
recv, err3 := stream.Recv()
|
|
if err3 == io.EOF {
|
|
log.Println("doubao stream end:", err3)
|
|
return
|
|
}
|
|
if err3 != nil {
|
|
fmt.Printf("doubao Stream chat error: %v\n", err)
|
|
return
|
|
}
|
|
go CreateAIStreamMsg(UserID, modelID, SessionID, recv, modelParam.Model)
|
|
doubaoToGeneralMassageAndSendMsgQueue(&recv, channel, SessionID, UserID)
|
|
choices := recv.Choices[0]
|
|
if choices.FinishReason == proto.FinishReasonStop {
|
|
answer += choices.Delta.Content
|
|
break
|
|
} else {
|
|
answer += choices.Delta.Content
|
|
}
|
|
}
|
|
//将消息存入数据库
|
|
err, _ = dao.CreateMessage(SessionID, modelID, UserID, proto.ModelToUserMsgType, proto.MsgHasRead, answer, 2)
|
|
if err != nil {
|
|
log.Println("spark create message error:", err)
|
|
}
|
|
}
|
|
|
|
func DouBaoV2(modelParam proto.ModelParam, imCtx *proto.IMParamContext) {
|
|
doubaoClient := arkruntime.NewClientWithApiKey(
|
|
modelParam.APIKey,
|
|
arkruntime.WithBaseUrl(modelParam.Url),
|
|
)
|
|
ctx := context.Background()
|
|
messages := make([]*model.ChatCompletionMessage, 0)
|
|
err := GetDouBaoSessionHistoryMsg(imCtx.SessionID, modelParam.System, &messages, imCtx.Question)
|
|
if err != nil {
|
|
log.Println("get doubao session history message error:", err)
|
|
}
|
|
messagesStr, _ := json.Marshal(messages)
|
|
log.Println("doubao messages:", string(messagesStr))
|
|
|
|
req := model.CreateChatCompletionRequest{
|
|
Model: modelParam.Model,
|
|
Messages: messages,
|
|
StreamOptions: &model.StreamOptions{
|
|
IncludeUsage: true,
|
|
},
|
|
}
|
|
|
|
stream, err := doubaoClient.CreateChatCompletionStream(ctx, req)
|
|
if err != nil {
|
|
fmt.Printf("stream chat error: %v\n", err)
|
|
return
|
|
}
|
|
defer func(stream *utils.ChatCompletionStreamReader) {
|
|
err2 := stream.Close()
|
|
if err2 != nil {
|
|
log.Println("close stream error:", err2)
|
|
}
|
|
}(stream)
|
|
|
|
answer := ""
|
|
tokens := 0
|
|
for {
|
|
recv, err3 := stream.Recv()
|
|
if err3 == io.EOF {
|
|
log.Println("doubao stream end:", err3)
|
|
return
|
|
}
|
|
if err3 != nil {
|
|
fmt.Printf("doubao Stream chat error: %v\n", err)
|
|
return
|
|
}
|
|
aiStream_id, err3 := CreateAIStreamMsg(imCtx.UserID, imCtx.ModelID, imCtx.SessionID, recv, modelParam.Model)
|
|
if err3 != nil {
|
|
log.Println("create ai stream message error:", err3)
|
|
}
|
|
doubaoToGeneralMassageAndSendMsgQueue(&recv, imCtx.Channel, imCtx.SessionID, imCtx.UserID)
|
|
if recv.Choices == nil {
|
|
log.Println("doubao stream recv choices is nil:", recv, "\t aiStream_id:", aiStream_id)
|
|
continue
|
|
}
|
|
choices := recv.Choices[0]
|
|
if choices.FinishReason == proto.FinishReasonStop {
|
|
answer += choices.Delta.Content
|
|
if recv.Usage != nil {
|
|
tokens += recv.Usage.TotalTokens
|
|
}
|
|
break
|
|
} else {
|
|
answer += choices.Delta.Content
|
|
}
|
|
}
|
|
//将消息存入数据库
|
|
err, _ = dao.CreateMessage(imCtx.SessionID, imCtx.ModelID, imCtx.UserID, proto.ModelToUserMsgType, proto.MsgHasRead, answer, imCtx.FunctionID)
|
|
if err != nil {
|
|
log.Println("spark create message error:", err)
|
|
}
|
|
//将本次请求的tokens存入数据库
|
|
err2, _ := dao.CreateModelToken(uint(imCtx.UserID), uint(imCtx.ModelID), uint(imCtx.SessionID), uint(tokens))
|
|
if err2 != nil {
|
|
log.Println("create model token error:", err2)
|
|
}
|
|
|
|
}
|
|
|
|
func doubaoToGeneralMassageAndSendMsgQueue(data *model.ChatCompletionStreamResponse, channel string, SessionID, UserID int) {
|
|
var aiMsg proto.AIQueueMessage
|
|
var wsMsg proto.WSMessage
|
|
var apiMsg api.GenerateResponse
|
|
|
|
apiMsg.Model = proto.ModelTypeDouBao
|
|
choices := data.Choices[0]
|
|
apiMsg.Response = choices.Delta.Content
|
|
if choices.FinishReason == proto.FinishReasonStop {
|
|
apiMsg.Done = true
|
|
} else {
|
|
apiMsg.Done = false
|
|
}
|
|
//消息队列部分
|
|
aiMsg.Type = proto.ModelTypeDouBao
|
|
aiMsg.Msg = apiMsg
|
|
//ws发送消息部分
|
|
wsMsg.Msg = aiMsg
|
|
wsMsg.SessionID = SessionID
|
|
wsMsg.ToID = UserID
|
|
wsMsg.Type = proto.ModelTypeDouBao
|
|
|
|
//发送消息
|
|
wsMsgStr, _ := json.Marshal(wsMsg)
|
|
worker.Publish(channel, string(wsMsgStr), time.Second*60)
|
|
}
|
|
|
|
// 当前消息包含在messages中
|
|
func GetDouBaoSessionHistoryMsg(sessionID int, systemPrompt string, messages *[]*model.ChatCompletionMessage, prompt string) error {
|
|
var err error
|
|
//添加系统
|
|
if systemPrompt != "" {
|
|
var message model.ChatCompletionMessage
|
|
message.Role = model.ChatMessageRoleSystem
|
|
message.Content = &model.ChatCompletionMessageContent{
|
|
StringValue: volcengine.String(systemPrompt),
|
|
}
|
|
*messages = append(*messages, &message)
|
|
}
|
|
//获取该会话的历史消息
|
|
msgs := dao.FindMessageBySessionIDV2(sessionID)
|
|
if len(msgs) > proto.DouBaoContextLength {
|
|
//只取最后的 SparkContextLength 条消息
|
|
msgs = msgs[len(msgs)-proto.DouBaoContextLength-1:]
|
|
}
|
|
for _, v := range msgs {
|
|
var message model.ChatCompletionMessage
|
|
if v.Type == proto.UserToModelMsgType {
|
|
message.Role = model.ChatMessageRoleUser
|
|
//用户消息为有图片类型
|
|
if v.Status == proto.UserToModelImageMsgType {
|
|
var imgMsg proto.UserImageMsg
|
|
err2 := json.Unmarshal([]byte(v.Msg), &imgMsg)
|
|
if err2 != nil {
|
|
log.Println("unmarshal user image message error:", err2)
|
|
return err2
|
|
}
|
|
//用户消息为图片类型
|
|
listValue := []*model.ChatCompletionMessageContentPart{}
|
|
listValue = append(listValue, &model.ChatCompletionMessageContentPart{
|
|
Type: model.ChatCompletionMessageContentPartTypeText,
|
|
Text: imgMsg.Text,
|
|
})
|
|
//图片链接是可以是多个
|
|
for _, img := range imgMsg.ImageContent {
|
|
listValue = append(listValue, &model.ChatCompletionMessageContentPart{
|
|
Type: model.ChatCompletionMessageContentPartTypeImageURL,
|
|
ImageURL: &model.ChatMessageImageURL{
|
|
URL: img.ImgUrl,
|
|
},
|
|
})
|
|
}
|
|
message.Content = &model.ChatCompletionMessageContent{
|
|
ListValue: listValue,
|
|
}
|
|
} else {
|
|
//用户消息为文本类型
|
|
message.Content = &model.ChatCompletionMessageContent{
|
|
StringValue: volcengine.String(v.Msg),
|
|
}
|
|
}
|
|
} else if v.Type == proto.ModelToUserMsgType {
|
|
message.Role = model.ChatMessageRoleAssistant
|
|
message.Content = &model.ChatCompletionMessageContent{
|
|
StringValue: volcengine.String(v.Msg),
|
|
}
|
|
} else {
|
|
continue
|
|
}
|
|
|
|
*messages = append(*messages, &message)
|
|
}
|
|
//添加本次请求消息(本次消息已在上面添加)
|
|
return err
|
|
}
|