package generalModel

import (
	"StuAcaWorksAI/dao"
	"StuAcaWorksAI/proto"
	"StuAcaWorksAI/worker"
	"context"
	"encoding/json"

	"github.com/ollama/ollama/api"
	"github.com/openai/openai-go"
	"github.com/openai/openai-go/option"

	"log"
	"time"
)

// OpenAI is the generic handler for OpenAI-compatible chat models: it streams the
// completion, forwards each chunk to the message queue, and persists the final
// answer together with its token usage.
func OpenAI(modelParam proto.ModelParam, imCtx *proto.IMParamContext, platform string) {
	apiKey := modelParam.APIKey
	if apiKey == "" {
		log.Printf("openai_%s apiKey is empty\n", platform)
		return
	}
	openaiClient := openai.NewClient(
		option.WithAPIKey(apiKey),
		option.WithBaseURL(modelParam.Url),
	)
	var messages []openai.ChatCompletionMessageParamUnion
	err := GetOpenAIHistoryMessages(imCtx.SessionID, modelParam.System, &messages, imCtx.UserID)
	if err != nil {
		log.Println("GetOpenAIHistoryMessages error:", err)
		return
	}
	if proto.Config.LOG_OUTPUT {
		messagesBytes, _ := json.Marshal(messages)
		log.Println("openai messages:", string(messagesBytes))
	}

	stream := openaiClient.Chat.Completions.NewStreaming(context.TODO(), openai.ChatCompletionNewParams{
		Messages: messages,
		Model:    modelParam.Model,
	})
	acc := openai.ChatCompletionAccumulator{}
	answer := ""
	var tokens int64
	var PromptTokens int64 // tokens consumed by the user's prompt

	for stream.Next() {
		chunk := stream.Current()
		msgStr, err3 := json.Marshal(chunk)
		if err3 != nil {
			log.Println("openai stream msg json marshal error:", err3)
		}
		//go service.CreateAIStreamMsg(imCtx.UserID, imCtx.ModelID, imCtx.SessionID, chunk, modelParam.Model)
		go dao.InsertAIStreamMsgToDB(imCtx.UserID, imCtx.SessionID, uint(imCtx.ModelID), string(msgStr), modelParam.Model) // persist the raw chunk to the database
		acc.AddChunk(chunk)
		if _, ok := acc.JustFinishedToolCall(); ok {
			chunk.Created = -1 // the stream returns no explicit finish flag, so this field marks the end
		}
		if len(chunk.Choices) > 0 {
			if proto.Config.LOG_OUTPUT {
				log.Printf("openai_%s stream msg:%s\n", platform, chunk.Choices[0].Delta.Content)
			}
			answer += chunk.Choices[0].Delta.Content
			tokens = chunk.Usage.TotalTokens
			PromptTokens = chunk.Usage.PromptTokens
		}
		OpenAIToGeneralMassageAndSendMsgQueue(&chunk, imCtx.Channel, imCtx.SessionID, imCtx.UserID)
	}
	if stream.Err() != nil {
		log.Printf("openai_%s stream err:%v\n", platform, stream.Err())
	}
	// persist the final answer and token usage to the database
	DOSaveMessageAndToken(imCtx, answer, "openai_"+platform, tokens, PromptTokens)
}
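// Illustrative usage sketch (not part of the original source; all values below are
// hypothetical placeholders showing how a caller might start a streamed session;
// only the fields this package actually reads are shown):
//
//	modelParam := proto.ModelParam{
//		APIKey: "sk-placeholder",               // assumed placeholder key
//		Url:    "https://api.example.com/v1",   // assumed placeholder base URL
//		Model:  "gpt-4o-mini",                  // assumed placeholder model name
//		System: "You are a helpful assistant.", // optional system prompt
//	}
//	imCtx := &proto.IMParamContext{SessionID: 1, UserID: 2, ModelID: 3, Channel: "session_1_channel"}
//	go OpenAI(modelParam, imCtx, "demo")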

// DOSaveMessageAndToken persists the model's answer and records prompt, completion,
// and model-level token usage for the session.
func DOSaveMessageAndToken(imCtx *proto.IMParamContext, answer, platform string, tokens, PromptTokens int64) {
	msg := dao.Message{SessionID: imCtx.SessionID, FromID: imCtx.UserID, ToID: imCtx.ModelID, Msg: answer, Type: proto.ModelToUserMsgType, FunctionID: imCtx.FunctionID, Token: int(tokens), Status: proto.UserToModelMsgType}
	err2, msgID := dao.CreateMessageV2(msg)
	if err2 != nil {
		log.Printf("%s create message error:%v\n", platform, err2)
	} else {
		log.Printf("%s create message success, message id:%d\n", platform, msgID)
	}
	// save the token cost of the user's prompt
	err, messageTokenID := dao.CreateMessageToken(imCtx.SessionID, imCtx.UserID, int(msgID), int(PromptTokens))
	if err != nil {
		log.Printf("%s user prompt create message token error:%v\n", platform, err)
	} else {
		log.Printf("%s user prompt create message token id:%d\n", platform, messageTokenID)
	}

	// save the token cost of the returned message
	err, messageTokenID = dao.CreateMessageToken(imCtx.SessionID, imCtx.UserID, int(msgID), int(tokens))
	if err != nil {
		log.Printf("%s create message token error:%v\n", platform, err)
	} else {
		log.Printf("%s create message token id:%d\n", platform, messageTokenID)
	}
	// record the model-level token consumption
	err3, modelTokenID := dao.CreateModelToken(uint(imCtx.UserID), uint(imCtx.ModelID), uint(imCtx.SessionID), uint(tokens))
	if err3 != nil {
		log.Printf("%s create model tokens error:%v\n", platform, err3)
	} else {
		log.Printf("%s create model tokens success, model token id:%d\n", platform, modelTokenID)
	}
}

// GetOpenAIHistoryMessages builds the OpenAI chat history for a session from the stored messages.
func GetOpenAIHistoryMessages(sessionID int, systemPrompt string, messages *[]openai.ChatCompletionMessageParamUnion, userID int) error {
	// fetch the session's message history
	msgs := dao.FindMessageBySessionIDV2(sessionID)
	if len(msgs) > proto.SparkContextLength {
		// keep only the most recent messages within the context window
		msgs = msgs[len(msgs)-proto.SparkContextLength-1:]
	}
	// add the system message
	if systemPrompt != "" {
		*messages = append(*messages, openai.SystemMessage(systemPrompt))
	}
	for _, msg := range msgs {
		if msg.Type == proto.UserToModelMsgType {
			if msg.Status == proto.UserToModelFileMsgType {
				//err2 := decodeFileMessageToSparkMessage(msg.Msg, userID, messages)
				//if err2 != nil {
				//	log.Println("spark decode file message error:", err2)
				//}
			} else {
				*messages = append(*messages, openai.UserMessage(msg.Msg))
			}
		} else if msg.Type == proto.ModelToUserMsgType {
			*messages = append(*messages, openai.AssistantMessage(msg.Msg))
		} else {
			continue
		}
	}
	return nil
}
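// For illustration only (an assumed two-turn session, not taken from real data),
// the slice built by GetOpenAIHistoryMessages would look roughly like:
//
//	[]openai.ChatCompletionMessageParamUnion{
//		openai.SystemMessage("You are a helpful assistant."), // only when systemPrompt != ""
//		openai.UserMessage("first question"),
//		openai.AssistantMessage("first answer"),
//		openai.UserMessage("second question"),
//	}
//
// i.e. an optional system message followed by the stored history in chronological
// order; user file messages are currently skipped (their decoding is commented out).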

// OpenAIToGeneralMassageAndSendMsgQueue converts an OpenAI stream chunk into the
// general websocket message format and publishes it to the session channel.
func OpenAIToGeneralMassageAndSendMsgQueue(data *openai.ChatCompletionChunk, channel string, SessionID int, UserID int) {
	var aiMsg proto.AIQueueMessage
	var wsMsg proto.WSMessage
	var apiMsg api.GenerateResponse
	apiMsg.Model = proto.ModelTypeSpark
	if len(data.Choices) > 0 { // guard against chunks that carry no choices
		apiMsg.Response = data.Choices[0].Delta.Content
	}
	if data.Created < 0 { // the stream returns no explicit finish flag, so this field marks the end
		apiMsg.Done = true
	} else {
		apiMsg.Done = false
	}
	// message queue payload
	aiMsg.Type = proto.ModelTypeSpark
	aiMsg.Msg = apiMsg
	// websocket message payload
	wsMsg.Msg = aiMsg
	wsMsg.SessionID = SessionID
	wsMsg.ToID = UserID
	wsMsg.Type = proto.ModelTypeSpark

	// publish the message to the session channel
	wsMsgStr, _ := json.Marshal(wsMsg)
	worker.Publish(channel, string(wsMsgStr), time.Second*60)
}
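// Note on the payload published by OpenAIToGeneralMassageAndSendMsgQueue above:
// the JSON sent to the channel nests as
//
//	proto.WSMessage{Type: ..., SessionID: ..., ToID: ..., Msg: proto.AIQueueMessage{Type: ..., Msg: api.GenerateResponse{Response: ..., Done: ...}}}
//
// and is handed to worker.Publish together with a 60-second duration (presumably
// a TTL or timeout, depending on how worker.Publish is implemented).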

// SingleTurnConversationText performs a single-turn, text-only conversation against an OpenAI-compatible endpoint.
func SingleTurnConversationText(prompt string, param proto.ModelParam) (string, error) {
	var answer string
	var err error
	apiKey := param.APIKey
	if apiKey == "" || param.Url == "" {
		log.Printf("openai apiKey or url is empty\n")
		return "", nil
	}
	openaiClient := openai.NewClient(
		option.WithAPIKey(apiKey),
		option.WithBaseURL(param.Url),
	)
	chatCompletion, err := openaiClient.Chat.Completions.New(context.TODO(), openai.ChatCompletionNewParams{
		Messages: []openai.ChatCompletionMessageParamUnion{
			openai.UserMessage(prompt),
		},
		Model: param.Model,
	})
	if err != nil {
		log.Printf("openai single turn conversation error:%v\n", err)
		return "", err
	}
	answer = chatCompletion.Choices[0].Message.Content
	return answer, nil
}
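// Illustrative usage sketch (assumed placeholder values, not from the original source):
//
//	answer, err := SingleTurnConversationText("Summarize this paragraph ...", proto.ModelParam{
//		APIKey: "sk-placeholder",
//		Url:    "https://api.example.com/v1",
//		Model:  "gpt-4o-mini",
//	})
//	if err != nil {
//		log.Println("single-turn error:", err)
//		return
//	}
//	log.Println("answer:", answer)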

// OpenAIForeignAI only talks to the model and does not persist anything locally;
// each stream chunk is forwarded back to the domestic server.
func OpenAIForeignAI(modelParam proto.ModelParam, imCtx *proto.IMParamContext, platform string) {
	apiKey := modelParam.APIKey
	if apiKey == "" {
		log.Printf("openai_%s apiKey is empty\n", platform)
		return
	}
	openaiClient := openai.NewClient(
		option.WithAPIKey(apiKey),
		option.WithBaseURL(modelParam.Url),
	)
	var messages []openai.ChatCompletionMessageParamUnion
	err := GetOpenAIHistoryMessages(imCtx.SessionID, modelParam.System, &messages, imCtx.UserID)
	if err != nil {
		log.Println("GetOpenAIHistoryMessages error:", err)
		return
	}
	if proto.Config.LOG_OUTPUT {
		messagesBytes, _ := json.Marshal(messages)
		log.Println("openai messages:", string(messagesBytes))
	}

	stream := openaiClient.Chat.Completions.NewStreaming(context.TODO(), openai.ChatCompletionNewParams{
		Messages: messages,
		Model:    modelParam.Model,
	})
	acc := openai.ChatCompletionAccumulator{}

	for stream.Next() {
		chunk := stream.Current()
		acc.AddChunk(chunk)
		if _, ok := acc.JustFinishedToolCall(); ok {
			chunk.Created = -1 // the stream returns no explicit finish flag, so this field marks the end
		}
		DoOpenAIForeignAICallBack(chunk, modelParam, imCtx, platform)
	}
	if stream.Err() != nil {
		log.Printf("openai_%s stream err:%v\n", platform, stream.Err())
	}
}

// DoOpenAIForeignAICallBack forwards a stream chunk back to the domestic server via an authenticated POST request.
func DoOpenAIForeignAICallBack(chunk openai.ChatCompletionChunk, modelParam proto.ModelParam, imCtx *proto.IMParamContext, platform string) {
	// POST request target
	url := "https://pm.ljsea.top/im/receive_message_foreign_ai?super_id=1"
	var req proto.ReceiveMessageForeignAIRRequest
	req.IMParamContext = *imCtx
	req.ModelParam = modelParam
	req.ChatCompletionChunk = chunk
	req.PlatForm = platform
	reqBytes, err := json.Marshal(req)
	if err != nil {
		log.Println("openai foreign ai request marshal error:", err)
		return
	}
	superTokens := worker.GetRedisSetMembers("super_permission_tokens")
	if len(superTokens) == 0 {
		log.Println("openai foreign ai super permission tokens is empty")
		return
	}
	// set request headers
	headers := map[string]string{
		"token":    superTokens[0],
		"super_id": "1",
	}

	err, respBytes := worker.DoPostRequestJSON(url, reqBytes, headers)
	if err != nil {
		log.Println("openai foreign ai request error:", err)
		return
	}
	var resp proto.GenerateResp
	err = json.Unmarshal(respBytes, &resp)
	if err != nil {
		log.Println("openai foreign ai response error:", err)
		return
	}
	if resp.Code != proto.SuccessCode {
		log.Println("openai foreign ai response error:", resp.Message)
		return
	}
}

// OpenAIForeignAIResponseToGeneralMassageAndSendMsgQueue handles model messages returned from the overseas server.
func OpenAIForeignAIResponseToGeneralMassageAndSendMsgQueue(data *proto.ReceiveMessageForeignAIRRequest) {
	// forward the chunk into the message queue
	OpenAIToGeneralMassageAndSendMsgQueue(&data.ChatCompletionChunk, data.IMParamContext.Channel, data.IMParamContext.SessionID, data.IMParamContext.UserID)
	// save the message; the streamed answer is buffered in Redis between chunks
	answer := worker.GetRedis(data.IMParamContext.Channel + "_stream_msg")
	//var tokens int64
	//var PromptTokens int64 // tokens consumed by the user's prompt
	//tokensStr := worker.GetRedis(data.IMParamContext.Channel + "_total_tokens")
	//if tokensStr != "" {
	//	tokens, _ = strconv.ParseInt(tokensStr, 10, 64)
	//}
	//PromptTokensStr := worker.GetRedis(data.IMParamContext.Channel + "_prompt_tokens")
	//if PromptTokensStr != "" {
	//	PromptTokens, _ = strconv.ParseInt(PromptTokensStr, 10, 64)
	//}
	chunk := data.ChatCompletionChunk
	if len(chunk.Choices) > 0 {
		if proto.Config.LOG_OUTPUT {
			log.Printf("openai_%s stream msg:%s\n", data.PlatForm, chunk.Choices[0].Delta.Content)
		}
		answer += chunk.Choices[0].Delta.Content
		//tokens = chunk.Usage.TotalTokens
		//PromptTokens = chunk.Usage.PromptTokens
	}
	if chunk.Created < 0 { // the stream returns no explicit finish flag, so this field marks the end
		tokens := chunk.Usage.TotalTokens
		PromptTokens := chunk.Usage.PromptTokens
		// persist the full message and token usage to the database
		DOSaveMessageAndToken(&data.IMParamContext, answer, "openai_"+data.PlatForm, tokens, PromptTokens)
	} else {
		// streaming chunk: stash the partial answer in Redis
		worker.SetRedisWithExpire(data.IMParamContext.Channel+"_stream_msg", answer, time.Second*30)
	}
}
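// Streaming lifecycle summary (descriptive note derived from the code and comments
// above): the foreign-deployed instance runs OpenAIForeignAI, which forwards every
// chunk to the domestic server via DoOpenAIForeignAICallBack; the domestic server
// then calls OpenAIForeignAIResponseToGeneralMassageAndSendMsgQueue for each chunk,
// appending its delta to the answer buffered under "<channel>_stream_msg" in Redis,
// and only when the final chunk arrives (Created < 0) does it persist the
// accumulated answer and token usage through DOSaveMessageAndToken.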