saw-go/service/qianfan.go

149 lines
4.4 KiB
Go
Raw Permalink Normal View History

2025-05-10 14:05:22 +08:00
package service
import (
"StuAcaWorksAI/dao"
"StuAcaWorksAI/proto"
"StuAcaWorksAI/worker"
"context"
"encoding/json"
"fmt"
"github.com/baidubce/bce-qianfan-sdk/go/qianfan"
"github.com/ollama/ollama/api"
"log"
"time"
)
func Qianfan(modelParam proto.ModelParam, imCtx *proto.IMParamContext) {
2025-05-10 14:12:12 +08:00
log.Println("qianfan model param:", modelParam)
log.Println("qianfan question:", imCtx.Question)
2025-05-10 15:27:56 +08:00
if modelParam.AccessKey == "" || modelParam.SecretKey == "" {
log.Println("qianfan AccessKey or SecretKey is empty")
2025-05-10 14:12:12 +08:00
return
}
qianfan.GetConfig().AccessKey = modelParam.AccessKey
2025-05-10 15:27:56 +08:00
qianfan.GetConfig().SecretKey = modelParam.SecretKey
2025-05-10 14:05:22 +08:00
chat := qianfan.NewChatCompletion(qianfan.WithModel(modelParam.Model))
var messages []qianfan.ChatCompletionMessage
err := GetQianfanHistoryMessages(imCtx.SessionID, modelParam.System, &messages, imCtx.UserID)
if err != nil {
log.Println("GetQianfanHistoryMessages error:", err)
return
}
if proto.Config.LOG_OUTPUT {
messagesBytes, _ := json.Marshal(messages)
log.Println("qianfan messages:", string(messagesBytes))
}
2025-05-10 14:05:22 +08:00
answer := ""
tokens := 0
resp, err2 := chat.Stream( // Stream 启用流式返回
context.TODO(),
&qianfan.ChatCompletionRequest{
Messages: messages,
},
)
if err2 != nil {
fmt.Println("qianfan stream err2:", err2)
return
}
for {
r, err3 := resp.Recv()
if err3 != nil {
fmt.Print("get qianfan stream err3:", err3)
break
}
if resp.IsEnd { // 判断是否结束
break
}
answer += r.Result
tokens = r.Usage.TotalTokens
go CreateAIStreamMsg(imCtx.UserID, imCtx.ModelID, imCtx.SessionID, r, modelParam.Model)
QianfanToGeneralMassageAndSendMsgQueue(r, imCtx.Channel, imCtx.SessionID, imCtx.UserID)
}
//将消息存入数据库
err2, msgID := dao.CreateMessage(imCtx.SessionID, imCtx.ModelID, imCtx.UserID, proto.ModelToUserMsgType, proto.MsgHasRead, answer, imCtx.FunctionID)
if err2 != nil {
log.Println("qianfan create message error:", err2)
} else {
log.Println("qianfan create message success,message id:", msgID)
}
//存入tokens消耗信息
err3, modelTokenID := dao.CreateModelToken(uint(imCtx.UserID), uint(imCtx.ModelID), uint(imCtx.SessionID), uint(tokens))
if err3 != nil {
log.Println("qianfan create message error:", err3)
} else {
log.Println("qianfan create model tokens id:", modelTokenID)
}
}
func GetQianfanHistoryMessages(sessionID int, systemPrompt string, messages *[]qianfan.ChatCompletionMessage, userID int) error {
//获取会话的历史消息
msgs := dao.FindMessageBySessionIDV2(sessionID)
if len(msgs) > proto.SparkContextLength {
//只取最后的 SparkContextLength 条消息
msgs = msgs[len(msgs)-proto.SparkContextLength-1:]
}
//添加系统消息,千帆不支持系统参数
//if systemPrompt != "" {
// var message qianfan.ChatCompletionMessage
// message.Role = "system"
// message.Content = systemPrompt
// *messages = append(*messages, message)
//}
2025-05-10 14:05:22 +08:00
for _, msg := range msgs {
var message qianfan.ChatCompletionMessage
if msg.Type == proto.UserToModelMsgType {
message.Role = proto.QianfanRoleUser
if msg.Status == proto.UserToModelFileMsgType {
//err2 := decodeFileMessageToSparkMessage(msg.Msg, userID, messages)
//if err2 != nil {
// log.Println("spark decode file message error:", err2)
//}
} else {
message.Content = msg.Msg
*messages = append(*messages, message)
}
} else if msg.Type == proto.ModelToUserMsgType {
message.Role = proto.QianfanAssistant
message.Content = msg.Msg
*messages = append(*messages, message)
} else {
continue
}
}
//添加本次请求消息(上面获取到的已包含本次请求消息)
//var message MessageContent
//message.Role = "user"
//message.Content = prompt
//*messages = append(*messages, message)
return nil
}
// QianfanToGeneralMassageAndSendMsgQueue converts one Qianfan stream chunk into
// the generic AI message envelope and publishes it as JSON to channel. The
// envelope is an api.GenerateResponse wrapped in a proto.AIQueueMessage, which
// is in turn wrapped in a proto.WSMessage.
//
// Parameters:
//   - data: the chunk returned by the Qianfan stream; a nil chunk is logged and skipped.
//   - channel: the worker channel the JSON message is published to.
//   - SessionID: the session the chunk belongs to.
//   - UserID: the recipient, stored as WSMessage.ToID.
//
// The message is passed to worker.Publish with a 60-second duration argument
// (presumably an expiry/timeout — confirm in the worker package). If JSON
// encoding fails, the error is logged and nothing is published.
func QianfanToGeneralMassageAndSendMsgQueue(data *qianfan.ModelResponse, channel string, SessionID int, UserID int) {
	if data == nil {
		log.Println("qianfan stream chunk is nil, skip publishing")
		return
	}

	// NOTE(review): the type/model fields below use proto.ModelTypeSpark even
	// though this is the Qianfan path (likely copied from the Spark handler).
	// Kept unchanged because consumers may dispatch on this value — confirm.
	var apiMsg api.GenerateResponse
	apiMsg.Model = proto.ModelTypeSpark
	apiMsg.Response = data.Result
	apiMsg.Done = data.IsEnd

	// Message-queue payload.
	var aiMsg proto.AIQueueMessage
	aiMsg.Type = proto.ModelTypeSpark
	aiMsg.Msg = apiMsg

	// WebSocket envelope.
	var wsMsg proto.WSMessage
	wsMsg.Msg = aiMsg
	wsMsg.SessionID = SessionID
	wsMsg.ToID = UserID
	wsMsg.Type = proto.ModelTypeSpark

	wsMsgStr, err := json.Marshal(wsMsg)
	if err != nil {
		// Publishing an empty payload would only confuse the client; drop the chunk instead.
		log.Println("qianfan marshal ws message error:", err)
		return
	}
	worker.Publish(channel, string(wsMsgStr), time.Second*60)
}