saw-go/service/generalModel/openai.go

189 lines
6.3 KiB
Go

package generalModel
import (
"StuAcaWorksAI/dao"
"StuAcaWorksAI/proto"
"StuAcaWorksAI/worker"
"context"
"encoding/json"
"github.com/ollama/ollama/api"
"github.com/openai/openai-go"
"github.com/openai/openai-go/option"
"log"
"time"
)
// OpenAI streams a chat completion from an OpenAI-compatible endpoint (used,
// for example, for the Tongyi models) and relays it to the client.
//
// It performs these steps:
//   - loads the session history via GetOpenAIHistoryMessages;
//   - sends it to modelParam.Url with modelParam.APIKey and modelParam.Model;
//   - archives every raw chunk with dao.InsertAIStreamMsgToDB;
//   - forwards chunks that carry choices to imCtx.Channel;
//   - persists the accumulated answer and token usage through
//     DOSaveMessageAndToken under the label "openai_"+platform.
//
// Errors are logged, not returned.
func OpenAI(modelParam proto.ModelParam, imCtx *proto.IMParamContext, platform string) {
	apiKey := modelParam.APIKey
	if apiKey == "" {
		log.Printf("openai_%s apiKey is empty\n", platform)
		return
	}
	openaiClient := openai.NewClient(
		option.WithAPIKey(apiKey),
		option.WithBaseURL(modelParam.Url),
	)
	var messages []openai.ChatCompletionMessageParamUnion
	err := GetOpenAIHistoryMessages(imCtx.SessionID, modelParam.System, &messages, imCtx.UserID)
	if err != nil {
		log.Println("GetOpenAIHistoryMessages error:", err)
		return
	}
	if proto.Config.LOG_OUTPUT {
		messagesBytes, _ := json.Marshal(messages) // best-effort debug output only
		log.Println("openai messages:", string(messagesBytes))
	}
	stream := openaiClient.Chat.Completions.NewStreaming(context.TODO(), openai.ChatCompletionNewParams{
		Messages: messages,
		Model:    modelParam.Model,
	})
	// Release the underlying HTTP response however the loop below ends.
	defer stream.Close()

	acc := openai.ChatCompletionAccumulator{}
	answer := ""
	var tokens int64
	var PromptTokens int64 // tokens consumed by the user's prompt
	for stream.Next() {
		chunk := stream.Current()
		if msgStr, err3 := json.Marshal(chunk); err3 != nil {
			log.Println("openai stream msg Error parsing JSON:", err3)
		} else {
			// Raw chunk archived asynchronously, one goroutine per chunk.
			// NOTE(review): insertion order across goroutines is not guaranteed.
			go dao.InsertAIStreamMsgToDB(imCtx.UserID, imCtx.SessionID, uint(imCtx.ModelID), string(msgStr), modelParam.Model)
		}
		acc.AddChunk(chunk)

		// The end of the reply is flagged by setting Created to -1, which
		// OpenAIToGeneralMassageAndSendMsgQueue turns into Done=true. A reply
		// ends either with a finished tool call or with a finish_reason on the
		// chunk's first choice.
		_, toolCallDone := acc.JustFinishedToolCall()
		if toolCallDone || (len(chunk.Choices) > 0 && chunk.Choices[0].FinishReason != "") {
			chunk.Created = -1
		}

		if len(chunk.Choices) > 0 {
			delta := chunk.Choices[0].Delta.Content
			if proto.Config.LOG_OUTPUT {
				log.Printf("openai_%s stream msg:%s\n", platform, delta)
			}
			answer += delta
			// Only chunks with choices carry content worth forwarding.
			OpenAIToGeneralMassageAndSendMsgQueue(&chunk, imCtx.Channel, imCtx.SessionID, imCtx.UserID)
		}

		// Usage is commonly reported on a trailing chunk that has no choices,
		// so read it independently of Choices and keep the last non-zero value.
		// NOTE(review): some providers only send usage when
		// stream_options.include_usage is requested — confirm for each platform.
		if chunk.Usage.TotalTokens > 0 {
			tokens = chunk.Usage.TotalTokens
			PromptTokens = chunk.Usage.PromptTokens
		}
	}
	if streamErr := stream.Err(); streamErr != nil {
		log.Printf("openai_%s stream err:%v\n", platform, streamErr)
	}
	// Persist whatever was received, even if the stream ended with an error.
	DOSaveMessageAndToken(imCtx, answer, "openai_"+platform, tokens, PromptTokens)
}
// DOSaveMessageAndToken persists the outcome of one model reply.
//
// It performs four writes, in order:
//   - the reply itself, as a dao.Message from imCtx.ModelID to imCtx.UserID in
//     imCtx.SessionID, with Token set to tokens;
//   - a message-token record for the user prompt (PromptTokens), linked to the
//     new message ID;
//   - a message-token record for tokens, linked to the same message ID;
//   - a model-token record for (user, model, session) using tokens.
//
// Each failure is only logged (prefixed with platform) and the remaining
// writes still run.
// NOTE(review): the token records use msgID even when CreateMessageV2 failed —
// confirm what ID the dao returns in that case.
func DOSaveMessageAndToken(imCtx *proto.IMParamContext, answer, platform string, tokens, PromptTokens int64) {
	reply := dao.Message{
		SessionID:  imCtx.SessionID,
		FromID:     imCtx.UserID,
		ToID:       imCtx.ModelID,
		Msg:        answer,
		Type:       proto.ModelToUserMsgType,
		FunctionID: imCtx.FunctionID,
		Token:      int(tokens),
		Status:     proto.UserToModelMsgType,
	}
	msgErr, msgID := dao.CreateMessageV2(reply)
	if msgErr == nil {
		log.Printf("%s create message success,message id:%d\n", platform, msgID)
	} else {
		log.Printf("%s create message error:%v\n", platform, msgErr)
	}

	// Token usage of the user's prompt.
	if promptErr, promptTokenID := dao.CreateMessageToken(imCtx.SessionID, imCtx.UserID, int(msgID), int(PromptTokens)); promptErr == nil {
		log.Printf("%s user prompt create message token id:%d\n", platform, promptTokenID)
	} else {
		log.Printf("%s user prompt create message token error:%v\n", platform, promptErr)
	}

	// Token usage recorded for the returned message.
	if replyErr, replyTokenID := dao.CreateMessageToken(imCtx.SessionID, imCtx.UserID, int(msgID), int(tokens)); replyErr == nil {
		log.Printf("%s create message token id:%d\n", platform, replyTokenID)
	} else {
		log.Printf("%s create message token error:%v\n", platform, replyErr)
	}

	// Per-model token consumption.
	if modelErr, modelTokenID := dao.CreateModelToken(uint(imCtx.UserID), uint(imCtx.ModelID), uint(imCtx.SessionID), uint(tokens)); modelErr == nil {
		log.Printf("%s create model tokens success,message id:%d\n", platform, modelTokenID)
	} else {
		log.Printf("%s create model tokens error:%v\n", platform, modelErr)
	}
}
// GetOpenAIHistoryMessages appends the stored chat history of sessionID to
// *messages in OpenAI chat format.
//
// How messages are built:
//   - If systemPrompt is non-empty, it is appended first as a system message.
//   - Only the most recent proto.SparkContextLength stored messages are used.
//   - Stored user messages become user messages and model replies become
//     assistant messages; any other message type is ignored.
//   - User file messages (Status == proto.UserToModelFileMsgType) are
//     currently skipped because file decoding is disabled for this path.
//
// userID is only needed by that disabled file path. The function currently
// always returns nil.
func GetOpenAIHistoryMessages(sessionID int, systemPrompt string, messages *[]openai.ChatCompletionMessageParamUnion, userID int) error {
	msgs := dao.FindMessageBySessionIDV2(sessionID)
	if len(msgs) > proto.SparkContextLength {
		// Keep exactly the last SparkContextLength messages. The previous start
		// index (len-SparkContextLength-1) kept one message too many.
		msgs = msgs[len(msgs)-proto.SparkContextLength:]
	}
	if systemPrompt != "" {
		*messages = append(*messages, openai.SystemMessage(systemPrompt))
	}
	for _, msg := range msgs {
		switch msg.Type {
		case proto.UserToModelMsgType:
			if msg.Status == proto.UserToModelFileMsgType {
				// File messages are not forwarded yet.
				//err2 := decodeFileMessageToSparkMessage(msg.Msg, userID, messages)
				//if err2 != nil {
				// log.Println("spark decode file message error:", err2)
				//}
				continue
			}
			*messages = append(*messages, openai.UserMessage(msg.Msg))
		case proto.ModelToUserMsgType:
			*messages = append(*messages, openai.AssistantMessage(msg.Msg))
		}
	}
	return nil
}
// OpenAIToGeneralMassageAndSendMsgQueue converts one OpenAI streaming chunk
// into the project's WebSocket message envelope and publishes it on channel
// via worker.Publish. The 60-second argument is presumably a TTL — TODO
// confirm against worker.Publish.
//
// The caller marks the final chunk by setting data.Created to a negative
// value, which is published as Done=true. The response text is the delta
// content of the first choice. Chunks that have no choices and no end marker
// carry nothing to forward and are skipped, instead of panicking on
// data.Choices[0].
func OpenAIToGeneralMassageAndSendMsgQueue(data *openai.ChatCompletionChunk, channel string, SessionID int, UserID int) {
	if data == nil {
		return
	}
	done := data.Created < 0 // no end flag is relied on from the stream; see caller
	if len(data.Choices) == 0 && !done {
		return
	}

	var apiMsg api.GenerateResponse
	// NOTE(review): labelled as Spark although this is the OpenAI path; kept
	// unchanged because message consumers may key on this value.
	apiMsg.Model = proto.ModelTypeSpark
	if len(data.Choices) > 0 {
		apiMsg.Response = data.Choices[0].Delta.Content
	}
	apiMsg.Done = done

	// Message-queue payload.
	var aiMsg proto.AIQueueMessage
	aiMsg.Type = proto.ModelTypeSpark
	aiMsg.Msg = apiMsg

	// WebSocket envelope.
	var wsMsg proto.WSMessage
	wsMsg.Msg = aiMsg
	wsMsg.SessionID = SessionID
	wsMsg.ToID = UserID
	wsMsg.Type = proto.ModelTypeSpark

	wsMsgStr, err := json.Marshal(wsMsg)
	if err != nil {
		log.Println("openai ws message marshal error:", err)
		return
	}
	worker.Publish(channel, string(wsMsgStr), time.Second*60)
}
// SingleTurnConversationText sends prompt as a single user message (no history,
// no system prompt) to the OpenAI-compatible endpoint described by param and
// returns the text content of the first choice.
//
// It returns ("", nil) when param.APIKey or param.Url is empty, and also when
// the endpoint replies without any choices; both cases are logged. Request
// errors are logged and returned.
func SingleTurnConversationText(prompt string, param proto.ModelParam) (string, error) {
	if param.APIKey == "" || param.Url == "" {
		log.Printf("openai apiKey or url is empty\n")
		return "", nil
	}
	openaiClient := openai.NewClient(
		option.WithAPIKey(param.APIKey),
		option.WithBaseURL(param.Url),
	)
	chatCompletion, err := openaiClient.Chat.Completions.New(context.TODO(), openai.ChatCompletionNewParams{
		Messages: []openai.ChatCompletionMessageParamUnion{
			openai.UserMessage(prompt),
		},
		Model: param.Model,
	})
	if err != nil {
		log.Printf("openai single turn conversation error:%v\n", err)
		return "", err
	}
	// Guard against an empty choices array instead of panicking on Choices[0].
	if len(chatCompletion.Choices) == 0 {
		log.Printf("openai single turn conversation returned no choices\n")
		return "", nil
	}
	return chatCompletion.Choices[0].Message.Content, nil
}