saw-go/service/generalModel/openai.go

419 lines
14 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package generalModel
import (
"StuAcaWorksAI/dao"
"StuAcaWorksAI/proto"
"StuAcaWorksAI/worker"
"context"
"encoding/json"
"github.com/ollama/ollama/api"
"github.com/openai/openai-go"
"github.com/openai/openai-go/option"
"log"
"time"
)
// OpenAI answers the latest user turn of imCtx.SessionID with a streaming chat
// completion from a generic OpenAI-compatible endpoint.
//
// The request contains the optional system prompt plus the stored session
// history (built by GetOpenAIHistoryMessages). Every received chunk is stored
// via dao.InsertAIStreamMsgToDB and, when it carries at least one choice, is
// published to imCtx.Channel. After the stream ends (normally, or with an error
// that is only logged) the accumulated answer and the last non-zero token usage
// are saved with DOSaveMessageAndToken under the tag "openai_"+platform.
// It returns early, saving nothing, if the API key is empty or the history
// cannot be loaded.
func OpenAI(modelParam proto.ModelParam, imCtx *proto.IMParamContext, platform string) {
	apiKey := modelParam.APIKey
	if apiKey == "" {
		log.Printf("openai_%s apiKey is empty\n", platform)
		return
	}
	openaiClient := openai.NewClient(
		option.WithAPIKey(apiKey),
		option.WithBaseURL(modelParam.Url),
	)
	var messages []openai.ChatCompletionMessageParamUnion
	if err := GetOpenAIHistoryMessages(imCtx.SessionID, modelParam.System, &messages, imCtx.UserID); err != nil {
		log.Println("GetOpenAIHistoryMessages error:", err)
		return
	}
	if proto.Config.LOG_OUTPUT {
		messagesBytes, _ := json.Marshal(messages) // debug output only; a marshal failure just logs an empty body
		log.Println("openai messages:", string(messagesBytes))
	}
	stream := openaiClient.Chat.Completions.NewStreaming(context.TODO(), openai.ChatCompletionNewParams{
		Messages: messages,
		Model:    modelParam.Model,
	})
	// Release the underlying HTTP response, whichever way the loop exits.
	defer stream.Close()

	acc := openai.ChatCompletionAccumulator{}
	var answer []byte // grown with append to avoid quadratic string concatenation
	var tokens, promptTokens int64
	for stream.Next() {
		chunk := stream.Current()
		if msgStr, err := json.Marshal(chunk); err != nil {
			log.Println("openai stream msg Error parsing JSON:", err)
		} else {
			go dao.InsertAIStreamMsgToDB(imCtx.UserID, imCtx.SessionID, uint(imCtx.ModelID), string(msgStr), modelParam.Model) // store the raw chunk
		}
		acc.AddChunk(chunk)
		if _, ok := acc.JustFinishedToolCall(); ok {
			// There is no explicit end marker in this case, so a negative Created
			// tells OpenAIToGeneralMassageAndSendMsgQueue to mark the message done.
			chunk.Created = -1
		}
		// Usage may come on a chunk that also has choices or, with some providers,
		// on a separate chunk without choices; keep the latest non-zero report.
		if chunk.Usage.TotalTokens > 0 {
			tokens = chunk.Usage.TotalTokens
			promptTokens = chunk.Usage.PromptTokens
		}
		if len(chunk.Choices) == 0 {
			continue // nothing to append or publish; the publisher reads Choices[0]
		}
		content := chunk.Choices[0].Delta.Content
		if proto.Config.LOG_OUTPUT {
			log.Printf("openai_%s stream msg:%s\n", platform, content)
		}
		answer = append(answer, content...)
		OpenAIToGeneralMassageAndSendMsgQueue(&chunk, imCtx.Channel, imCtx.SessionID, imCtx.UserID)
	}
	if err := stream.Err(); err != nil {
		log.Printf("openai_%s stream err:%v\n", platform, err)
	}
	// Persist whatever was received, even if the stream ended with an error.
	DOSaveMessageAndToken(imCtx, string(answer), "openai_"+platform, tokens, promptTokens)
}
// DOSaveMessageAndToken persists the model's final answer for imCtx's session
// and records the token usage that goes with it.
//
// Four writes are attempted in order:
//  1. the answer message itself (dao.CreateMessageV2);
//  2. a message-token row for PromptTokens, keyed by the new message id;
//  3. a message-token row for tokens, keyed by the same id;
//  4. a model-token row for tokens.
//
// A failure in any step is only logged and the remaining steps still run.
// platform prefixes every log line.
func DOSaveMessageAndToken(imCtx *proto.IMParamContext, answer, platform string, tokens, PromptTokens int64) {
	answerMsg := dao.Message{
		SessionID:  imCtx.SessionID,
		FromID:     imCtx.UserID,
		ToID:       imCtx.ModelID,
		Msg:        answer,
		Type:       proto.ModelToUserMsgType,
		FunctionID: imCtx.FunctionID,
		Token:      int(tokens),
		// NOTE(review): Status is set to a message *type* constant; confirm intended.
		Status: proto.UserToModelMsgType,
	}
	createErr, msgID := dao.CreateMessageV2(answerMsg)
	if createErr == nil {
		log.Printf("%s create message success,message id:%d\n", platform, msgID)
	} else {
		log.Printf("%s create message error:%v\n", platform, createErr)
	}

	// Tokens consumed by the user's prompt.
	promptErr, promptTokenID := dao.CreateMessageToken(imCtx.SessionID, imCtx.UserID, int(msgID), int(PromptTokens))
	if promptErr == nil {
		log.Printf("%s user prompt create message token id:%d\n", platform, promptTokenID)
	} else {
		log.Printf("%s user prompt create message token error:%v\n", platform, promptErr)
	}

	// Tokens reported for the returned message.
	replyErr, replyTokenID := dao.CreateMessageToken(imCtx.SessionID, imCtx.UserID, int(msgID), int(tokens))
	if replyErr == nil {
		log.Printf("%s create message token id:%d\n", platform, replyTokenID)
	} else {
		log.Printf("%s create message token error:%v\n", platform, replyErr)
	}

	// Per-model token accounting.
	modelErr, modelTokenID := dao.CreateModelToken(uint(imCtx.UserID), uint(imCtx.ModelID), uint(imCtx.SessionID), uint(tokens))
	if modelErr == nil {
		log.Printf("%s create model tokens success,message id:%d\n", platform, modelTokenID)
	} else {
		log.Printf("%s create model tokens error:%v\n", platform, modelErr)
	}
}
// GetOpenAIHistoryMessages loads the stored history of sessionID and appends it
// to *messages in OpenAI chat format. An optional system message comes first,
// then the user and assistant turns in stored order.
//
// At most the last proto.SparkContextLength stored messages are used. User
// messages whose Status is proto.UserToModelFileMsgType are skipped (file
// decoding is not implemented here), as are messages of any other Type.
// userID is currently unused. The returned error is always nil.
func GetOpenAIHistoryMessages(sessionID int, systemPrompt string, messages *[]openai.ChatCompletionMessageParamUnion, userID int) error {
	msgs := dao.FindMessageBySessionIDV2(sessionID)
	if len(msgs) > proto.SparkContextLength {
		// Keep exactly the last SparkContextLength messages; slicing from
		// len-N-1 kept one message more than intended.
		msgs = msgs[len(msgs)-proto.SparkContextLength:]
	}
	if systemPrompt != "" {
		*messages = append(*messages, openai.SystemMessage(systemPrompt))
	}
	for _, msg := range msgs {
		if msg.Type == proto.UserToModelMsgType {
			if msg.Status == proto.UserToModelFileMsgType {
				// TODO: file messages are not converted for the model yet and are dropped.
				continue
			}
			*messages = append(*messages, openai.UserMessage(msg.Msg))
		} else if msg.Type == proto.ModelToUserMsgType {
			*messages = append(*messages, openai.AssistantMessage(msg.Msg))
		}
	}
	return nil
}
// OpenAIToGeneralMassageAndSendMsgQueue converts one OpenAI stream chunk into
// the generic websocket message and publishes it on channel (60s timeout).
//
// The payload reuses the ollama api.GenerateResponse shape, with Model/Type set
// to proto.ModelTypeSpark. Done is true when data.Created is negative (the
// caller's end marker) or when the first choice has a finish reason.
// Chunks without choices are skipped unless they carry the negative-Created end
// marker; this avoids a panic on Choices[0]. A nil chunk is ignored. A marshal
// failure is logged and nothing is published.
func OpenAIToGeneralMassageAndSendMsgQueue(data *openai.ChatCompletionChunk, channel string, SessionID int, UserID int) {
	if data == nil {
		return
	}
	var content string
	done := data.Created < 0 // set by callers when no explicit end marker arrives
	if len(data.Choices) > 0 {
		content = data.Choices[0].Delta.Content
		done = done || data.Choices[0].FinishReason != ""
	} else if !done {
		return // nothing to forward
	}

	var apiMsg api.GenerateResponse
	apiMsg.Model = proto.ModelTypeSpark
	apiMsg.Response = content
	apiMsg.Done = done

	// Message-queue envelope.
	var aiMsg proto.AIQueueMessage
	aiMsg.Type = proto.ModelTypeSpark
	aiMsg.Msg = apiMsg

	// Websocket envelope addressed to the user.
	var wsMsg proto.WSMessage
	wsMsg.Msg = aiMsg
	wsMsg.SessionID = SessionID
	wsMsg.ToID = UserID
	wsMsg.Type = proto.ModelTypeSpark

	wsMsgStr, err := json.Marshal(wsMsg)
	if err != nil {
		log.Println("openai send msg queue marshal error:", err)
		return
	}
	worker.Publish(channel, string(wsMsgStr), time.Second*60)
}
// SingleTurnConversationText sends prompt as a single user message to the
// OpenAI-compatible endpoint described by param. It returns the text of the
// first choice.
//
// When param.APIKey or param.Url is empty, it logs and returns ("", nil)
// without calling the API. NOTE(review): callers cannot tell that case apart
// from an empty answer.
// API errors are logged and returned unchanged. A response without choices
// yields an openAINoChoicesError instead of panicking on Choices[0].
func SingleTurnConversationText(prompt string, param proto.ModelParam) (string, error) {
	if param.APIKey == "" || param.Url == "" {
		log.Printf("openai apiKey or url is empty\n")
		return "", nil
	}
	openaiClient := openai.NewClient(
		option.WithAPIKey(param.APIKey),
		option.WithBaseURL(param.Url),
	)
	chatCompletion, err := openaiClient.Chat.Completions.New(context.TODO(), openai.ChatCompletionNewParams{
		Messages: []openai.ChatCompletionMessageParamUnion{
			openai.UserMessage(prompt),
		},
		Model: param.Model,
	})
	if err != nil {
		log.Printf("openai single turn conversation error:%v\n", err)
		return "", err
	}
	if len(chatCompletion.Choices) == 0 {
		noChoices := openAINoChoicesError{}
		log.Printf("openai single turn conversation error:%v\n", noChoices)
		return "", noChoices
	}
	return chatCompletion.Choices[0].Message.Content, nil
}

// openAINoChoicesError reports a successful chat completion response that
// contained no choices.
type openAINoChoicesError struct{}

func (openAINoChoicesError) Error() string {
	return "openai chat completion returned no choices"
}
// OpenAIForeignAI runs a streaming chat completion on behalf of the domestic
// server. It only talks to the model and relays every chunk back through
// DoOpenAIForeignAICallBack; nothing is persisted here.
//
// originMessages is converted role by role: "assistant" and "system" keep their
// role, and "user" or any other role is sent as a user message. The function
// returns early if the API key is empty. Stream errors are only logged.
func OpenAIForeignAI(modelParam proto.ModelParam, imCtx *proto.IMParamContext, platform string, originMessages []proto.MessageContent) {
	apiKey := modelParam.APIKey
	if apiKey == "" {
		log.Printf("openai_%s apiKey is empty\n", platform)
		return
	}
	openaiClient := openai.NewClient(
		option.WithAPIKey(apiKey),
		option.WithBaseURL(modelParam.Url),
	)
	// History is supplied by the caller (originMessages) rather than read from the local DB.
	messages := make([]openai.ChatCompletionMessageParamUnion, 0, len(originMessages))
	for _, msg := range originMessages {
		switch msg.Role {
		case "assistant":
			messages = append(messages, openai.AssistantMessage(msg.Content))
		case "system":
			messages = append(messages, openai.SystemMessage(msg.Content))
		default: // "user" and any unrecognised role
			messages = append(messages, openai.UserMessage(msg.Content))
		}
	}
	if proto.Config.LOG_OUTPUT {
		messagesBytes, _ := json.Marshal(messages) // debug output only
		log.Println("openai messages:", string(messagesBytes))
	}
	stream := openaiClient.Chat.Completions.NewStreaming(context.TODO(), openai.ChatCompletionNewParams{
		Messages: messages,
		Model:    modelParam.Model,
	})
	// Release the underlying HTTP response, whichever way the loop exits.
	defer stream.Close()

	acc := openai.ChatCompletionAccumulator{}
	for stream.Next() {
		chunk := stream.Current()
		acc.AddChunk(chunk)
		if _, ok := acc.JustFinishedToolCall(); ok {
			// No explicit end marker here; a negative Created signals "done" downstream.
			chunk.Created = -1
		}
		DoOpenAIForeignAICallBack(chunk, modelParam, imCtx, platform)
	}
	if err := stream.Err(); err != nil {
		log.Printf("openai_%s stream err:%v\n", platform, err)
	}
}
// DoOpenAIForeignAICallBack relays one stream chunk to the domestic server. The
// POST body carries the chunk together with the request context, the model
// parameters and the platform.
//
// The request is authenticated with the first member of the Redis set
// "super_permission_tokens". Every failure is logged and the chunk is dropped:
// a marshal error, a missing token, an HTTP error, an undecodable body or a
// non-success code.
func DoOpenAIForeignAICallBack(chunk openai.ChatCompletionChunk, modelParam proto.ModelParam, imCtx *proto.IMParamContext, platform string) {
	const endpoint = "https://pm.ljsea.top/im/receive_message_foreign_ai?super_id=1"

	var req proto.ReceiveMessageForeignAIRRequest
	req.IMParamContext = *imCtx
	req.ModelParam = modelParam
	req.ChatCompletionChunk = chunk
	req.PlatForm = platform

	body, marshalErr := json.Marshal(req)
	if marshalErr != nil {
		log.Println("openai foreign ai request error:", marshalErr)
		return
	}

	tokens := worker.GetRedisSetMembers("super_permission_tokens")
	if len(tokens) == 0 {
		log.Println("openai foreign ai super permission tokens is empty")
		return
	}

	postErr, respBody := worker.DoPostRequestJSON(endpoint, body, map[string]string{
		"token":    tokens[0],
		"super_id": "1",
	})
	if postErr != nil {
		log.Println("openai foreign ai request error:", postErr)
		return
	}

	var resp proto.GenerateResp
	if err := json.Unmarshal(respBody, &resp); err != nil {
		log.Println("openai foreign ai response error:", err)
		return
	}
	if resp.Code != proto.SuccessCode {
		log.Println("openai foreign ai response error:", resp.Message)
	}
}
// OpenAIForeignAIResponseToGeneralMassageAndSendMsgQueue handles one model
// chunk relayed back by the foreign server.
//
// It stores the raw chunk with dao.InsertAIStreamMsgToDB and publishes it to
// the session channel. Between calls it keeps the partial answer in Redis under
// "<channel>_stream_msg" with a 30s expiry. A chunk marks the end of the stream
// when Created < 0 or the first choice has a finish reason. On that chunk:
//   - the full answer and the chunk's token usage are saved with
//     DOSaveMessageAndToken;
//   - the Redis buffer is emptied, so a later stream on the same channel does
//     not start with this answer.
//
// Chunks without choices neither panic nor change the buffered answer.
func OpenAIForeignAIResponseToGeneralMassageAndSendMsgQueue(data *proto.ReceiveMessageForeignAIRRequest) {
	chunk := data.ChatCompletionChunk
	imCtx := &data.IMParamContext
	if msgStr, err := json.Marshal(chunk); err != nil {
		log.Println("openai foreign ai stream msg marshal error:", err)
	} else {
		go dao.InsertAIStreamMsgToDB(imCtx.UserID, imCtx.SessionID, uint(imCtx.ModelID), string(msgStr), data.ModelParam.Model) // store the raw chunk
	}
	// The queue publisher reads Choices[0], so only forward chunks that have one.
	if len(chunk.Choices) > 0 {
		OpenAIToGeneralMassageAndSendMsgQueue(&data.ChatCompletionChunk, imCtx.Channel, imCtx.SessionID, imCtx.UserID)
	}

	// Streaming answers are buffered in Redis because each chunk arrives in a separate request.
	streamKey := imCtx.Channel + "_stream_msg"
	answer := worker.GetRedis(streamKey) // presumably "" when the key is absent — confirm
	done := chunk.Created < 0            // negative Created is the sender's end marker
	if len(chunk.Choices) > 0 {
		if proto.Config.LOG_OUTPUT {
			log.Printf("openai_%s stream msg:%s\n", data.PlatForm, chunk.Choices[0].Delta.Content)
		}
		answer += chunk.Choices[0].Delta.Content
		done = done || chunk.Choices[0].FinishReason != ""
	}
	if !done {
		worker.SetRedisWithExpire(streamKey, answer, time.Second*30)
		return
	}
	DOSaveMessageAndToken(imCtx, answer, "openai_"+data.PlatForm, chunk.Usage.TotalTokens, chunk.Usage.PromptTokens)
	// Reset the buffer so the next stream on this channel starts empty.
	worker.SetRedisWithExpire(streamKey, "", time.Second*30)
}
// SendOpenAIForeignAI hands a chat request off to the foreign server, which
// runs the model and relays the chunks back.
//
// The request carries imCtx, modelParam, platform and the session history built
// by GetGeneralSessionHistoryMsg. It is authenticated with the first member of
// the Redis set "super_permission_tokens". Any failure is logged and aborts the
// hand-off; nothing is reported back to the caller.
func SendOpenAIForeignAI(modelParam proto.ModelParam, imCtx *proto.IMParamContext, platform string) {
	const endpoint = "https://vis.ljsea.top/im/send_message_foreign_ai?super_id=1"

	var req proto.SendMessageForeignAIRRequest
	req.IMParamContext = *imCtx
	req.ModelParam = modelParam
	req.PlatForm = platform

	var history []proto.MessageContent
	if histErr := GetGeneralSessionHistoryMsg(imCtx.SessionID, modelParam.System, &history, imCtx.UserID); histErr != nil {
		log.Println("GetOpenAIHistoryMessages error:", histErr)
		return
	}
	req.MessageContent = history

	body, marshalErr := json.Marshal(req)
	if marshalErr != nil {
		log.Println("send openai foreign ai request error:", marshalErr)
		return
	}

	tokens := worker.GetRedisSetMembers("super_permission_tokens")
	if len(tokens) == 0 {
		log.Println("send openai foreign ai super permission tokens is empty")
		return
	}

	postErr, respBody := worker.DoPostRequestJSON(endpoint, body, map[string]string{
		"token":    tokens[0],
		"super_id": "1",
	})
	if postErr != nil {
		log.Println("send openai foreign ai request error:", postErr)
		return
	}

	var resp proto.GenerateResp
	if err := json.Unmarshal(respBody, &resp); err != nil {
		log.Println("send openai foreign ai response error:", err)
		return
	}
	if resp.Code != proto.SuccessCode {
		log.Println("send openai foreign ai response error:", resp.Message)
	}
}
// GetGeneralSessionHistoryMsg loads the stored history of sessionID and appends
// it to *messages in the provider-neutral proto.MessageContent format. The
// order is:
//   - an optional "system" message;
//   - user (proto.SparkRoleUser) and assistant (proto.SparkRoleAssistant)
//     turns, in stored order.
//
// At most the last proto.SparkContextLength stored messages are used. The
// stored history is expected to already contain the current request message,
// so nothing is appended for it separately. User messages whose Status is
// proto.UserToModelFileMsgType are skipped, as are messages of any other Type.
// userID is currently unused. The returned error is always nil.
func GetGeneralSessionHistoryMsg(sessionID int, systemPrompt string, messages *[]proto.MessageContent, userID int) error {
	msgs := dao.FindMessageBySessionIDV2(sessionID)
	if len(msgs) > proto.SparkContextLength {
		// Keep exactly the last SparkContextLength messages; slicing from
		// len-N-1 kept one message more than intended.
		msgs = msgs[len(msgs)-proto.SparkContextLength:]
	}
	if systemPrompt != "" {
		var message proto.MessageContent
		message.Role = "system"
		message.Content = systemPrompt
		*messages = append(*messages, message)
	}
	for _, msg := range msgs {
		var message proto.MessageContent
		if msg.Type == proto.UserToModelMsgType {
			if msg.Status == proto.UserToModelFileMsgType {
				// TODO: file messages are not converted yet and are dropped.
				continue
			}
			message.Role = proto.SparkRoleUser
		} else if msg.Type == proto.ModelToUserMsgType {
			message.Role = proto.SparkRoleAssistant
		} else {
			continue
		}
		message.Content = msg.Msg
		*messages = append(*messages, message)
	}
	return nil
}