2025-05-11 20:31:39 +08:00
|
|
|
package service
|
|
|
|
|
|
|
|
|
|
import (
|
|
|
|
|
"StuAcaWorksAI/dao"
|
|
|
|
|
"StuAcaWorksAI/proto"
|
|
|
|
|
"StuAcaWorksAI/worker"
|
|
|
|
|
"context"
|
|
|
|
|
"encoding/json"
|
|
|
|
|
"github.com/ollama/ollama/api"
|
|
|
|
|
"github.com/openai/openai-go"
|
|
|
|
|
"github.com/openai/openai-go/option"
|
|
|
|
|
"log"
|
|
|
|
|
"time"
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
// 通义模型
|
|
|
|
|
func Tongyi(modelParam proto.ModelParam, imCtx *proto.IMParamContext) {
|
|
|
|
|
tongyiApiKey := modelParam.APIKey
|
|
|
|
|
tongyiClient := openai.NewClient(
|
|
|
|
|
option.WithAPIKey(tongyiApiKey),
|
|
|
|
|
option.WithBaseURL(modelParam.Url),
|
|
|
|
|
)
|
|
|
|
|
var messages []openai.ChatCompletionMessageParamUnion
|
|
|
|
|
err := GetTongyiHistoryMessages(imCtx.SessionID, modelParam.System, &messages, imCtx.UserID)
|
|
|
|
|
if err != nil {
|
|
|
|
|
println("GetTongyiHistoryMessages error:", err)
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
if proto.Config.LOG_OUTPUT {
|
|
|
|
|
messagesBytes, _ := json.Marshal(messages)
|
|
|
|
|
println("tongyi messages:", string(messagesBytes))
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
stream := tongyiClient.Chat.Completions.NewStreaming(context.TODO(), openai.ChatCompletionNewParams{
|
|
|
|
|
Messages: messages,
|
|
|
|
|
Model: modelParam.Model,
|
|
|
|
|
})
|
|
|
|
|
acc := openai.ChatCompletionAccumulator{}
|
|
|
|
|
answer := ""
|
|
|
|
|
var tokens int64
|
|
|
|
|
var PromptTokens int64 //用户问题token
|
|
|
|
|
|
|
|
|
|
for stream.Next() {
|
|
|
|
|
chunk := stream.Current()
|
|
|
|
|
go CreateAIStreamMsg(imCtx.UserID, imCtx.ModelID, imCtx.SessionID, chunk, modelParam.Model)
|
|
|
|
|
acc.AddChunk(chunk)
|
|
|
|
|
if _, ok := acc.JustFinishedToolCall(); ok {
|
|
|
|
|
chunk.Created = -1 // 这里由于没有返回结束标志,所以用该字段来判断
|
|
|
|
|
}
|
|
|
|
|
if len(chunk.Choices) > 0 {
|
2025-05-12 12:22:31 +08:00
|
|
|
if proto.Config.LOG_OUTPUT {
|
|
|
|
|
log.Printf("tongyi stream msg:%s\n", chunk.Choices[0].Delta.Content)
|
|
|
|
|
}
|
2025-05-11 20:31:39 +08:00
|
|
|
answer += chunk.Choices[0].Delta.Content
|
|
|
|
|
tokens = chunk.Usage.TotalTokens
|
|
|
|
|
PromptTokens = chunk.Usage.PromptTokens
|
|
|
|
|
}
|
|
|
|
|
TongyiToGeneralMassageAndSendMsgQueue(&chunk, imCtx.Channel, imCtx.SessionID, imCtx.UserID)
|
|
|
|
|
}
|
|
|
|
|
if stream.Err() != nil {
|
|
|
|
|
log.Println("tongyi stream err:", stream.Err())
|
|
|
|
|
}
|
|
|
|
|
//将消息存入数据库
|
2025-06-04 23:36:25 +08:00
|
|
|
DoSaveMessageAndToken(imCtx, answer, "tongyi", tokens, PromptTokens)
|
2025-05-11 20:31:39 +08:00
|
|
|
}
|
|
|
|
|
|
2025-06-04 23:36:25 +08:00
|
|
|
func DoSaveMessageAndToken(imCtx *proto.IMParamContext, answer, platform string, tokens, PromptTokens int64) {
|
2025-05-11 20:31:39 +08:00
|
|
|
msg := dao.Message{SessionID: imCtx.SessionID, FromID: imCtx.UserID, ToID: imCtx.ModelID, Msg: answer, Type: proto.ModelToUserMsgType, FunctionID: imCtx.FunctionID, Token: int(tokens), Status: proto.UserToModelMsgType}
|
|
|
|
|
err2, msgID := dao.CreateMessageV2(msg)
|
|
|
|
|
if err2 != nil {
|
|
|
|
|
log.Printf("%s create message error:%v\n", platform, err2)
|
|
|
|
|
} else {
|
|
|
|
|
log.Printf("%s create message success,message id:%d\n", platform, msgID)
|
|
|
|
|
}
|
|
|
|
|
//保存用户问题消耗token
|
|
|
|
|
err, messageTokenID := dao.CreateMessageToken(imCtx.SessionID, imCtx.UserID, int(msgID), int(PromptTokens))
|
|
|
|
|
if err != nil {
|
|
|
|
|
log.Printf("%s user prompt create message token error:%v\n", platform, err)
|
|
|
|
|
} else {
|
|
|
|
|
log.Printf("%s user prompt create message token id:%d\n", platform, messageTokenID)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
//保存返回的消息token消耗
|
|
|
|
|
err, messageTokenID = dao.CreateMessageToken(imCtx.SessionID, imCtx.UserID, int(msgID), int(tokens))
|
|
|
|
|
if err != nil {
|
|
|
|
|
log.Printf("%s create message token error:%v\n", platform, err)
|
|
|
|
|
} else {
|
|
|
|
|
log.Printf("%s create message token id:%d\n", platform, messageTokenID)
|
|
|
|
|
}
|
|
|
|
|
//存入tokens消耗信息
|
|
|
|
|
err3, modelTokenID := dao.CreateModelToken(uint(imCtx.UserID), uint(imCtx.ModelID), uint(imCtx.SessionID), uint(tokens))
|
|
|
|
|
if err3 != nil {
|
|
|
|
|
log.Printf("%s create model tokens error:%v\n", platform, err3)
|
|
|
|
|
} else {
|
|
|
|
|
log.Printf("%s create model tokens success,message id:%d\n", platform, modelTokenID)
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func GetTongyiHistoryMessages(sessionID int, systemPrompt string, messages *[]openai.ChatCompletionMessageParamUnion, userID int) error {
|
|
|
|
|
//获取会话的历史消息
|
|
|
|
|
msgs := dao.FindMessageBySessionIDV2(sessionID)
|
|
|
|
|
if len(msgs) > proto.SparkContextLength {
|
|
|
|
|
//只取最后的 SparkContextLength 条消息
|
|
|
|
|
msgs = msgs[len(msgs)-proto.SparkContextLength-1:]
|
|
|
|
|
}
|
|
|
|
|
//添加系统消息
|
|
|
|
|
if systemPrompt != "" {
|
|
|
|
|
*messages = append(*messages, openai.SystemMessage(systemPrompt))
|
|
|
|
|
}
|
|
|
|
|
for _, msg := range msgs {
|
|
|
|
|
if msg.Type == proto.UserToModelMsgType {
|
|
|
|
|
if msg.Status == proto.UserToModelFileMsgType {
|
|
|
|
|
//err2 := decodeFileMessageToSparkMessage(msg.Msg, userID, messages)
|
|
|
|
|
//if err2 != nil {
|
|
|
|
|
// log.Println("spark decode file message error:", err2)
|
|
|
|
|
//}
|
|
|
|
|
} else {
|
|
|
|
|
*messages = append(*messages, openai.UserMessage(msg.Msg))
|
|
|
|
|
}
|
|
|
|
|
} else if msg.Type == proto.ModelToUserMsgType {
|
|
|
|
|
*messages = append(*messages, openai.AssistantMessage(msg.Msg))
|
|
|
|
|
} else {
|
|
|
|
|
continue
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func TongyiToGeneralMassageAndSendMsgQueue(data *openai.ChatCompletionChunk, channel string, SessionID int, UserID int) {
|
|
|
|
|
var aiMsg proto.AIQueueMessage
|
|
|
|
|
var wsMsg proto.WSMessage
|
|
|
|
|
var apiMsg api.GenerateResponse
|
|
|
|
|
apiMsg.Model = proto.ModelTypeSpark
|
|
|
|
|
apiMsg.Response = data.Choices[0].Delta.Content
|
2025-06-04 23:36:25 +08:00
|
|
|
if data.Created < 0 || data.Choices[0].FinishReason != "" { //这里由于没有返回结束标志,所以用该字段来判断
|
2025-05-11 20:31:39 +08:00
|
|
|
apiMsg.Done = true
|
|
|
|
|
} else {
|
|
|
|
|
apiMsg.Done = false
|
|
|
|
|
}
|
|
|
|
|
//消息队列部分
|
|
|
|
|
aiMsg.Type = proto.ModelTypeSpark
|
|
|
|
|
aiMsg.Msg = apiMsg
|
|
|
|
|
//ws发送消息部分
|
|
|
|
|
wsMsg.Msg = aiMsg
|
|
|
|
|
wsMsg.SessionID = SessionID
|
|
|
|
|
wsMsg.ToID = UserID
|
|
|
|
|
wsMsg.Type = proto.ModelTypeSpark
|
|
|
|
|
|
|
|
|
|
//发送消息
|
|
|
|
|
wsMsgStr, _ := json.Marshal(wsMsg)
|
|
|
|
|
worker.Publish(channel, string(wsMsgStr), time.Second*60)
|
|
|
|
|
}
|