144 lines
4.2 KiB
Go
144 lines
4.2 KiB
Go
package service
|
|
|
|
import (
	"context"
	"encoding/json"
	"fmt"
	"log"
	"strings"
	"time"

	"StuAcaWorksAI/dao"
	"StuAcaWorksAI/proto"
	"StuAcaWorksAI/worker"

	"github.com/baidubce/bce-qianfan-sdk/go/qianfan"
	"github.com/ollama/ollama/api"
)
|
|
|
|
func Qianfan(modelParam proto.ModelParam, imCtx *proto.IMParamContext) {
|
|
log.Println("qianfan model param:", modelParam)
|
|
log.Println("qianfan question:", imCtx.Question)
|
|
if modelParam.AccessKey == "" || modelParam.SecretKey == "" {
|
|
log.Println("qianfan AccessKey or SecretKey is empty")
|
|
return
|
|
}
|
|
qianfan.GetConfig().AccessKey = modelParam.AccessKey
|
|
qianfan.GetConfig().SecretKey = modelParam.SecretKey
|
|
|
|
chat := qianfan.NewChatCompletion(qianfan.WithModel(modelParam.Model))
|
|
var messages []qianfan.ChatCompletionMessage
|
|
err := GetQianfanHistoryMessages(imCtx.SessionID, modelParam.System, &messages, imCtx.UserID)
|
|
if err != nil {
|
|
log.Println("GetQianfanHistoryMessages error:", err)
|
|
return
|
|
}
|
|
answer := ""
|
|
tokens := 0
|
|
|
|
resp, err2 := chat.Stream( // Stream 启用流式返回
|
|
context.TODO(),
|
|
&qianfan.ChatCompletionRequest{
|
|
Messages: messages,
|
|
},
|
|
)
|
|
if err2 != nil {
|
|
fmt.Println("qianfan stream err2:", err2)
|
|
return
|
|
}
|
|
for {
|
|
r, err3 := resp.Recv()
|
|
if err3 != nil {
|
|
fmt.Print("get qianfan stream err3:", err3)
|
|
break
|
|
}
|
|
if resp.IsEnd { // 判断是否结束
|
|
break
|
|
}
|
|
answer += r.Result
|
|
tokens = r.Usage.TotalTokens
|
|
go CreateAIStreamMsg(imCtx.UserID, imCtx.ModelID, imCtx.SessionID, r, modelParam.Model)
|
|
QianfanToGeneralMassageAndSendMsgQueue(r, imCtx.Channel, imCtx.SessionID, imCtx.UserID)
|
|
}
|
|
|
|
//将消息存入数据库
|
|
err2, msgID := dao.CreateMessage(imCtx.SessionID, imCtx.ModelID, imCtx.UserID, proto.ModelToUserMsgType, proto.MsgHasRead, answer, imCtx.FunctionID)
|
|
if err2 != nil {
|
|
log.Println("qianfan create message error:", err2)
|
|
} else {
|
|
log.Println("qianfan create message success,message id:", msgID)
|
|
}
|
|
//存入tokens消耗信息
|
|
err3, modelTokenID := dao.CreateModelToken(uint(imCtx.UserID), uint(imCtx.ModelID), uint(imCtx.SessionID), uint(tokens))
|
|
if err3 != nil {
|
|
log.Println("qianfan create message error:", err3)
|
|
} else {
|
|
log.Println("qianfan create model tokens id:", modelTokenID)
|
|
}
|
|
}
|
|
|
|
func GetQianfanHistoryMessages(sessionID int, systemPrompt string, messages *[]qianfan.ChatCompletionMessage, userID int) error {
|
|
//获取会话的历史消息
|
|
msgs := dao.FindMessageBySessionIDV2(sessionID)
|
|
if len(msgs) > proto.SparkContextLength {
|
|
//只取最后的 SparkContextLength 条消息
|
|
msgs = msgs[len(msgs)-proto.SparkContextLength-1:]
|
|
}
|
|
//添加系统消息
|
|
if systemPrompt != "" {
|
|
var message qianfan.ChatCompletionMessage
|
|
message.Role = "system"
|
|
message.Content = systemPrompt
|
|
*messages = append(*messages, message)
|
|
}
|
|
for _, msg := range msgs {
|
|
var message qianfan.ChatCompletionMessage
|
|
if msg.Type == proto.UserToModelMsgType {
|
|
message.Role = proto.QianfanRoleUser
|
|
if msg.Status == proto.UserToModelFileMsgType {
|
|
//err2 := decodeFileMessageToSparkMessage(msg.Msg, userID, messages)
|
|
//if err2 != nil {
|
|
// log.Println("spark decode file message error:", err2)
|
|
//}
|
|
} else {
|
|
message.Content = msg.Msg
|
|
*messages = append(*messages, message)
|
|
}
|
|
} else if msg.Type == proto.ModelToUserMsgType {
|
|
message.Role = proto.QianfanAssistant
|
|
message.Content = msg.Msg
|
|
*messages = append(*messages, message)
|
|
} else {
|
|
continue
|
|
}
|
|
}
|
|
//添加本次请求消息(上面获取到的已包含本次请求消息)
|
|
//var message MessageContent
|
|
//message.Role = "user"
|
|
//message.Content = prompt
|
|
//*messages = append(*messages, message)
|
|
return nil
|
|
}
|
|
|
|
func QianfanToGeneralMassageAndSendMsgQueue(data *qianfan.ModelResponse, channel string, SessionID int, UserID int) {
|
|
var aiMsg proto.AIQueueMessage
|
|
var wsMsg proto.WSMessage
|
|
var apiMsg api.GenerateResponse
|
|
apiMsg.Model = proto.ModelTypeSpark
|
|
apiMsg.Response = data.Result
|
|
if data.IsEnd == true {
|
|
apiMsg.Done = true
|
|
} else {
|
|
apiMsg.Done = false
|
|
}
|
|
//消息队列部分
|
|
aiMsg.Type = proto.ModelTypeSpark
|
|
aiMsg.Msg = apiMsg
|
|
//ws发送消息部分
|
|
wsMsg.Msg = aiMsg
|
|
wsMsg.SessionID = SessionID
|
|
wsMsg.ToID = UserID
|
|
wsMsg.Type = proto.ModelTypeSpark
|
|
|
|
//发送消息
|
|
wsMsgStr, _ := json.Marshal(wsMsg)
|
|
worker.Publish(channel, string(wsMsgStr), time.Second*60)
|
|
}
|