From b334b8185debbe208189c4329015bc193da634fe Mon Sep 17 00:00:00 2001 From: junleea <354425203@qq.com> Date: Tue, 13 May 2025 11:02:09 +0800 Subject: [PATCH] =?UTF-8?q?=E6=B7=BB=E5=8A=A0=E5=9B=BD=E5=A4=96=E6=9C=8D?= =?UTF-8?q?=E5=8A=A1=E5=99=A8=E4=B8=AD=E8=BD=ACai=E8=AF=B7=E6=B1=82?= =?UTF-8?q?=E6=95=B0=E6=8D=AE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- handler/im.go | 66 ++++++++++++++++++ main.go | 1 + proto/im.go | 14 ++++ service/generalModel/openai.go | 123 ++++++++++++++++++++++++++++++++- 4 files changed, 203 insertions(+), 1 deletion(-) diff --git a/handler/im.go b/handler/im.go index 3cda178..4be2917 100644 --- a/handler/im.go +++ b/handler/im.go @@ -4,6 +4,7 @@ import ( "StuAcaWorksAI/dao" "StuAcaWorksAI/proto" "StuAcaWorksAI/service" + "StuAcaWorksAI/service/generalModel" "StuAcaWorksAI/worker" "context" "encoding/json" @@ -49,6 +50,10 @@ func SetUpIMGroup(router *gin.Engine) { imGroup.GET("/ai_chat_ws", SRMessage) imGroup.POST("/get_message", GetMessage) imGroup.GET("/kbase_text_ws", KBaseTextWS) + + //添加国外ai处理接口,服务之间调用 + imGroup.POST("/send_message_foreign_ai", SendMessageForeignAI) + imGroup.POST("/receive_message_foreign_ai", ReceiveMessageForeignAI) } func GetMessage(c *gin.Context) { @@ -429,3 +434,64 @@ func checkAndSetModelTemperatureParam(modelParam *proto.ModelParam, data *proto. modelParam.SetDefaultParams() } } + +func SendMessageForeignAI(c *gin.Context) { + id, _ := c.Get("user_id") + userID := id.(int) + user := service.GetUserByIDWithCache(userID) + var resp proto.GenerateResp + var req proto.SendMessageForeignAIRRequest + if user.Role == "admin" { + //解析请求参数 + if err := c.ShouldBind(&req); err == nil { + //校验参数 + if req.ModelParam.Model == "" || req.ModelParam.Url == "" || req.ModelParam.APIKey == "" { + resp.Code = proto.ParameterError + resp.Message = "模型必要参数不能为空" + } else { + //处理 + generalModel.OpenAIForeignAI(req.ModelParam, &req.IMParamContext, req.PlatForm) + //返回结果 + resp.Code = proto.SuccessCode + resp.Message = "success" + } + } else { + resp.Code = proto.ParameterError + resp.Message = "参数错误" + } + } else { + resp.Code = proto.PermissionDenied + resp.Message = "没有权限" + } + c.JSON(http.StatusOK, resp) +} + +func ReceiveMessageForeignAI(c *gin.Context) { + id, _ := c.Get("user_id") + userID := id.(int) + user := service.GetUserByIDWithCache(userID) + var resp proto.GenerateResp + var req proto.ReceiveMessageForeignAIRRequest + if user.Role == "admin" { + //解析请求参数 + if err := c.ShouldBind(&req); err == nil { + //校验参数 + if req.IMParamContext.SessionID == 0 || req.ChatCompletionChunk.Choices == nil { + resp.Code = proto.ParameterError + resp.Message = "模型必要参数不能为空" + } else { + generalModel.OpenAIForeignAIResponseToGeneralMassageAndSendMsgQueue(&req) + resp.Code = proto.SuccessCode + resp.Message = "success" + } + } else { + resp.Code = proto.ParameterError + resp.Message = "参数错误" + } + + } else { + resp.Code = proto.PermissionDenied + resp.Message = "没有权限" + } + c.JSON(http.StatusOK, resp) +} diff --git a/main.go b/main.go index db92a05..7e70fbc 100644 --- a/main.go +++ b/main.go @@ -161,6 +161,7 @@ func JWTAuthMiddleware() gin.HandlerFunc { idFloat64 := float64(id) //查看s_id类型 c.Set("id", idFloat64) + c.Set("user_id", id) c.Next() return } diff --git a/proto/im.go b/proto/im.go index d50eea3..23845b5 100644 --- a/proto/im.go +++ b/proto/im.go @@ -2,6 +2,7 @@ package proto import ( "github.com/ollama/ollama/api" + "github.com/openai/openai-go" "gorm.io/gorm" ) @@ -117,3 +118,16 @@ type UserFileMessage struct { FileContent []FileMessageContent `json:"file_content"` //文件内容 Text string `json:"text"` //问题文本 } + +type SendMessageForeignAIRRequest struct { + ModelParam ModelParam `json:"model_param" form:"model_param"` //模型参数 + IMParamContext IMParamContext `json:"im_param_context" form:"im_param_context"` //上下文参数 + PlatForm string `json:"platform" form:"platform"` //平台名称 +} + +type ReceiveMessageForeignAIRRequest struct { + ModelParam ModelParam `json:"model_param" form:"model_param"` //模型参数 + IMParamContext IMParamContext `json:"im_param_context" form:"im_param_context"` //上下文参数 + PlatForm string `json:"platform" form:"platform"` //平台名称 + ChatCompletionChunk openai.ChatCompletionChunk `json:"chat_completion_chunk" form:"chat_completion_chunk"` //消息内容 +} diff --git a/service/generalModel/openai.go b/service/generalModel/openai.go index 568c1cb..6445a5b 100644 --- a/service/generalModel/openai.go +++ b/service/generalModel/openai.go @@ -13,7 +13,7 @@ import ( "time" ) -// 通义模型 +// 通用openai模型 func OpenAI(modelParam proto.ModelParam, imCtx *proto.IMParamContext, platform string) { apiKey := modelParam.APIKey if apiKey == "" { @@ -186,3 +186,124 @@ func SingleTurnConversationText(prompt string, param proto.ModelParam) (string, answer = chatCompletion.Choices[0].Message.Content return answer, nil } + +// 只进行与模型信息交互不进行保存,返回消息给国内服务器 +func OpenAIForeignAI(modelParam proto.ModelParam, imCtx *proto.IMParamContext, platform string) { + apiKey := modelParam.APIKey + if apiKey == "" { + log.Printf("openai_%s apiKey is empty\n", platform) + return + } + openaiClient := openai.NewClient( + option.WithAPIKey(apiKey), + option.WithBaseURL(modelParam.Url), + ) + var messages []openai.ChatCompletionMessageParamUnion + err := GetOpenAIHistoryMessages(imCtx.SessionID, modelParam.System, &messages, imCtx.UserID) + if err != nil { + log.Println("GetOpenAIHistoryMessages error:", err) + return + } + if proto.Config.LOG_OUTPUT { + messagesBytes, _ := json.Marshal(messages) + log.Println("openai messages:", string(messagesBytes)) + } + + stream := openaiClient.Chat.Completions.NewStreaming(context.TODO(), openai.ChatCompletionNewParams{ + Messages: messages, + Model: modelParam.Model, + }) + acc := openai.ChatCompletionAccumulator{} + + for stream.Next() { + chunk := stream.Current() + acc.AddChunk(chunk) + if _, ok := acc.JustFinishedToolCall(); ok { + chunk.Created = -1 // 这里由于没有返回结束标志,所以用该字段来判断 + } + DoOpenAIForeignAICallBack(chunk, modelParam, imCtx, platform) + } + if stream.Err() != nil { + log.Printf("openai_%s stream err:%v\n", platform, stream.Err()) + } +} + +// 返回消息给国内服务器 +func DoOpenAIForeignAICallBack(chunk openai.ChatCompletionChunk, modelParam proto.ModelParam, imCtx *proto.IMParamContext, platform string) { + //post请求 + url := "https://pm.ljsea.top/im/receive_message_foreign_ai?super_id=1" + var req proto.ReceiveMessageForeignAIRRequest + req.IMParamContext = *imCtx + req.ModelParam = modelParam + req.ChatCompletionChunk = chunk + req.PlatForm = platform + reqBytes, err := json.Marshal(req) + if err != nil { + log.Println("openai foreign ai request error:", err) + return + } + superTokens := worker.GetRedisSetMembers("super_permission_tokens") + if len(superTokens) == 0 { + log.Println("openai foreign ai super permission tokens is empty") + return + } + //设置请求头 + headers := map[string]string{ + "token": superTokens[0], + "super_id": "1", + } + + err, respBytes := worker.DoPostRequestJSON(url, reqBytes, headers) + if err != nil { + log.Println("openai foreign ai request error:", err) + return + } + var resp proto.GenerateResp + err = json.Unmarshal(respBytes, &resp) + if err != nil { + log.Println("openai foreign ai response error:", err) + return + } + if resp.Code != proto.SuccessCode { + log.Println("openai foreign ai response error:", resp.Message) + return + } +} + +// 处理国外服务器返回的模型消息 +func OpenAIForeignAIResponseToGeneralMassageAndSendMsgQueue(data *proto.ReceiveMessageForeignAIRRequest) { + //转入队列 + OpenAIToGeneralMassageAndSendMsgQueue(&data.ChatCompletionChunk, data.IMParamContext.Channel, data.IMParamContext.SessionID, data.IMParamContext.UserID) + //保存消息,流式需要暂存消息使用redis + answer := worker.GetRedis(data.IMParamContext.Channel + "_stream_msg") + //var tokens int64 + //var PromptTokens int64 //用户问题token + //tokensStr := worker.GetRedis(data.IMParamContext.Channel + "_total_tokens") + //if tokensStr != "" { + // tokens, _ = strconv.ParseInt(tokensStr, 10, 64) + //} + //PromptTokensStr := worker.GetRedis(data.IMParamContext.Channel + "_prompt_tokens") + //if PromptTokensStr != "" { + // PromptTokens, _ = strconv.ParseInt(PromptTokensStr, 10, 64) + //} + chunk := data.ChatCompletionChunk + if len(chunk.Choices) > 0 { + if proto.Config.LOG_OUTPUT { + log.Printf("openai_%s stream msg:%s\n", data.PlatForm, chunk.Choices[0].Delta.Content) + } + answer += chunk.Choices[0].Delta.Content + //tokens = chunk.Usage.TotalTokens + //PromptTokens = chunk.Usage.PromptTokens + } + //将消息存入数据库 + + if chunk.Created < 0 { //这里由于没有返回结束标志,所以用该字段来判断 + tokens := chunk.Usage.TotalTokens + PromptTokens := chunk.Usage.PromptTokens + //将消息存入数据库 + DOSaveMessageAndToken(&data.IMParamContext, answer, "openai_"+data.PlatForm, tokens, PromptTokens) + } else { + //流式消息,暂存到redis + worker.SetRedis(data.IMParamContext.Channel+"_stream_msg", answer) + } +}