添加国外服务器中转ai请求数据
This commit is contained in:
parent
623a590b2d
commit
b334b8185d
|
|
@ -4,6 +4,7 @@ import (
|
|||
"StuAcaWorksAI/dao"
|
||||
"StuAcaWorksAI/proto"
|
||||
"StuAcaWorksAI/service"
|
||||
"StuAcaWorksAI/service/generalModel"
|
||||
"StuAcaWorksAI/worker"
|
||||
"context"
|
||||
"encoding/json"
|
||||
|
|
@ -49,6 +50,10 @@ func SetUpIMGroup(router *gin.Engine) {
|
|||
imGroup.GET("/ai_chat_ws", SRMessage)
|
||||
imGroup.POST("/get_message", GetMessage)
|
||||
imGroup.GET("/kbase_text_ws", KBaseTextWS)
|
||||
|
||||
//添加国外ai处理接口,服务之间调用
|
||||
imGroup.POST("/send_message_foreign_ai", SendMessageForeignAI)
|
||||
imGroup.POST("/receive_message_foreign_ai", ReceiveMessageForeignAI)
|
||||
}
|
||||
|
||||
func GetMessage(c *gin.Context) {
|
||||
|
|
@ -429,3 +434,64 @@ func checkAndSetModelTemperatureParam(modelParam *proto.ModelParam, data *proto.
|
|||
modelParam.SetDefaultParams()
|
||||
}
|
||||
}
|
||||
|
||||
func SendMessageForeignAI(c *gin.Context) {
|
||||
id, _ := c.Get("user_id")
|
||||
userID := id.(int)
|
||||
user := service.GetUserByIDWithCache(userID)
|
||||
var resp proto.GenerateResp
|
||||
var req proto.SendMessageForeignAIRRequest
|
||||
if user.Role == "admin" {
|
||||
//解析请求参数
|
||||
if err := c.ShouldBind(&req); err == nil {
|
||||
//校验参数
|
||||
if req.ModelParam.Model == "" || req.ModelParam.Url == "" || req.ModelParam.APIKey == "" {
|
||||
resp.Code = proto.ParameterError
|
||||
resp.Message = "模型必要参数不能为空"
|
||||
} else {
|
||||
//处理
|
||||
generalModel.OpenAIForeignAI(req.ModelParam, &req.IMParamContext, req.PlatForm)
|
||||
//返回结果
|
||||
resp.Code = proto.SuccessCode
|
||||
resp.Message = "success"
|
||||
}
|
||||
} else {
|
||||
resp.Code = proto.ParameterError
|
||||
resp.Message = "参数错误"
|
||||
}
|
||||
} else {
|
||||
resp.Code = proto.PermissionDenied
|
||||
resp.Message = "没有权限"
|
||||
}
|
||||
c.JSON(http.StatusOK, resp)
|
||||
}
|
||||
|
||||
func ReceiveMessageForeignAI(c *gin.Context) {
|
||||
id, _ := c.Get("user_id")
|
||||
userID := id.(int)
|
||||
user := service.GetUserByIDWithCache(userID)
|
||||
var resp proto.GenerateResp
|
||||
var req proto.ReceiveMessageForeignAIRRequest
|
||||
if user.Role == "admin" {
|
||||
//解析请求参数
|
||||
if err := c.ShouldBind(&req); err == nil {
|
||||
//校验参数
|
||||
if req.IMParamContext.SessionID == 0 || req.ChatCompletionChunk.Choices == nil {
|
||||
resp.Code = proto.ParameterError
|
||||
resp.Message = "模型必要参数不能为空"
|
||||
} else {
|
||||
generalModel.OpenAIForeignAIResponseToGeneralMassageAndSendMsgQueue(&req)
|
||||
resp.Code = proto.SuccessCode
|
||||
resp.Message = "success"
|
||||
}
|
||||
} else {
|
||||
resp.Code = proto.ParameterError
|
||||
resp.Message = "参数错误"
|
||||
}
|
||||
|
||||
} else {
|
||||
resp.Code = proto.PermissionDenied
|
||||
resp.Message = "没有权限"
|
||||
}
|
||||
c.JSON(http.StatusOK, resp)
|
||||
}
|
||||
|
|
|
|||
1
main.go
1
main.go
|
|
@ -161,6 +161,7 @@ func JWTAuthMiddleware() gin.HandlerFunc {
|
|||
idFloat64 := float64(id)
|
||||
//查看s_id类型
|
||||
c.Set("id", idFloat64)
|
||||
c.Set("user_id", id)
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
|
|
|||
14
proto/im.go
14
proto/im.go
|
|
@ -2,6 +2,7 @@ package proto
|
|||
|
||||
import (
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/openai/openai-go"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
|
|
@ -117,3 +118,16 @@ type UserFileMessage struct {
|
|||
FileContent []FileMessageContent `json:"file_content"` //文件内容
|
||||
Text string `json:"text"` //问题文本
|
||||
}
|
||||
|
||||
type SendMessageForeignAIRRequest struct {
|
||||
ModelParam ModelParam `json:"model_param" form:"model_param"` //模型参数
|
||||
IMParamContext IMParamContext `json:"im_param_context" form:"im_param_context"` //上下文参数
|
||||
PlatForm string `json:"platform" form:"platform"` //平台名称
|
||||
}
|
||||
|
||||
type ReceiveMessageForeignAIRRequest struct {
|
||||
ModelParam ModelParam `json:"model_param" form:"model_param"` //模型参数
|
||||
IMParamContext IMParamContext `json:"im_param_context" form:"im_param_context"` //上下文参数
|
||||
PlatForm string `json:"platform" form:"platform"` //平台名称
|
||||
ChatCompletionChunk openai.ChatCompletionChunk `json:"chat_completion_chunk" form:"chat_completion_chunk"` //消息内容
|
||||
}
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ import (
|
|||
"time"
|
||||
)
|
||||
|
||||
// 通义模型
|
||||
// 通用openai模型
|
||||
func OpenAI(modelParam proto.ModelParam, imCtx *proto.IMParamContext, platform string) {
|
||||
apiKey := modelParam.APIKey
|
||||
if apiKey == "" {
|
||||
|
|
@ -186,3 +186,124 @@ func SingleTurnConversationText(prompt string, param proto.ModelParam) (string,
|
|||
answer = chatCompletion.Choices[0].Message.Content
|
||||
return answer, nil
|
||||
}
|
||||
|
||||
// 只进行与模型信息交互不进行保存,返回消息给国内服务器
|
||||
func OpenAIForeignAI(modelParam proto.ModelParam, imCtx *proto.IMParamContext, platform string) {
|
||||
apiKey := modelParam.APIKey
|
||||
if apiKey == "" {
|
||||
log.Printf("openai_%s apiKey is empty\n", platform)
|
||||
return
|
||||
}
|
||||
openaiClient := openai.NewClient(
|
||||
option.WithAPIKey(apiKey),
|
||||
option.WithBaseURL(modelParam.Url),
|
||||
)
|
||||
var messages []openai.ChatCompletionMessageParamUnion
|
||||
err := GetOpenAIHistoryMessages(imCtx.SessionID, modelParam.System, &messages, imCtx.UserID)
|
||||
if err != nil {
|
||||
log.Println("GetOpenAIHistoryMessages error:", err)
|
||||
return
|
||||
}
|
||||
if proto.Config.LOG_OUTPUT {
|
||||
messagesBytes, _ := json.Marshal(messages)
|
||||
log.Println("openai messages:", string(messagesBytes))
|
||||
}
|
||||
|
||||
stream := openaiClient.Chat.Completions.NewStreaming(context.TODO(), openai.ChatCompletionNewParams{
|
||||
Messages: messages,
|
||||
Model: modelParam.Model,
|
||||
})
|
||||
acc := openai.ChatCompletionAccumulator{}
|
||||
|
||||
for stream.Next() {
|
||||
chunk := stream.Current()
|
||||
acc.AddChunk(chunk)
|
||||
if _, ok := acc.JustFinishedToolCall(); ok {
|
||||
chunk.Created = -1 // 这里由于没有返回结束标志,所以用该字段来判断
|
||||
}
|
||||
DoOpenAIForeignAICallBack(chunk, modelParam, imCtx, platform)
|
||||
}
|
||||
if stream.Err() != nil {
|
||||
log.Printf("openai_%s stream err:%v\n", platform, stream.Err())
|
||||
}
|
||||
}
|
||||
|
||||
// 返回消息给国内服务器
|
||||
func DoOpenAIForeignAICallBack(chunk openai.ChatCompletionChunk, modelParam proto.ModelParam, imCtx *proto.IMParamContext, platform string) {
|
||||
//post请求
|
||||
url := "https://pm.ljsea.top/im/receive_message_foreign_ai?super_id=1"
|
||||
var req proto.ReceiveMessageForeignAIRRequest
|
||||
req.IMParamContext = *imCtx
|
||||
req.ModelParam = modelParam
|
||||
req.ChatCompletionChunk = chunk
|
||||
req.PlatForm = platform
|
||||
reqBytes, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
log.Println("openai foreign ai request error:", err)
|
||||
return
|
||||
}
|
||||
superTokens := worker.GetRedisSetMembers("super_permission_tokens")
|
||||
if len(superTokens) == 0 {
|
||||
log.Println("openai foreign ai super permission tokens is empty")
|
||||
return
|
||||
}
|
||||
//设置请求头
|
||||
headers := map[string]string{
|
||||
"token": superTokens[0],
|
||||
"super_id": "1",
|
||||
}
|
||||
|
||||
err, respBytes := worker.DoPostRequestJSON(url, reqBytes, headers)
|
||||
if err != nil {
|
||||
log.Println("openai foreign ai request error:", err)
|
||||
return
|
||||
}
|
||||
var resp proto.GenerateResp
|
||||
err = json.Unmarshal(respBytes, &resp)
|
||||
if err != nil {
|
||||
log.Println("openai foreign ai response error:", err)
|
||||
return
|
||||
}
|
||||
if resp.Code != proto.SuccessCode {
|
||||
log.Println("openai foreign ai response error:", resp.Message)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 处理国外服务器返回的模型消息
|
||||
func OpenAIForeignAIResponseToGeneralMassageAndSendMsgQueue(data *proto.ReceiveMessageForeignAIRRequest) {
|
||||
//转入队列
|
||||
OpenAIToGeneralMassageAndSendMsgQueue(&data.ChatCompletionChunk, data.IMParamContext.Channel, data.IMParamContext.SessionID, data.IMParamContext.UserID)
|
||||
//保存消息,流式需要暂存消息使用redis
|
||||
answer := worker.GetRedis(data.IMParamContext.Channel + "_stream_msg")
|
||||
//var tokens int64
|
||||
//var PromptTokens int64 //用户问题token
|
||||
//tokensStr := worker.GetRedis(data.IMParamContext.Channel + "_total_tokens")
|
||||
//if tokensStr != "" {
|
||||
// tokens, _ = strconv.ParseInt(tokensStr, 10, 64)
|
||||
//}
|
||||
//PromptTokensStr := worker.GetRedis(data.IMParamContext.Channel + "_prompt_tokens")
|
||||
//if PromptTokensStr != "" {
|
||||
// PromptTokens, _ = strconv.ParseInt(PromptTokensStr, 10, 64)
|
||||
//}
|
||||
chunk := data.ChatCompletionChunk
|
||||
if len(chunk.Choices) > 0 {
|
||||
if proto.Config.LOG_OUTPUT {
|
||||
log.Printf("openai_%s stream msg:%s\n", data.PlatForm, chunk.Choices[0].Delta.Content)
|
||||
}
|
||||
answer += chunk.Choices[0].Delta.Content
|
||||
//tokens = chunk.Usage.TotalTokens
|
||||
//PromptTokens = chunk.Usage.PromptTokens
|
||||
}
|
||||
//将消息存入数据库
|
||||
|
||||
if chunk.Created < 0 { //这里由于没有返回结束标志,所以用该字段来判断
|
||||
tokens := chunk.Usage.TotalTokens
|
||||
PromptTokens := chunk.Usage.PromptTokens
|
||||
//将消息存入数据库
|
||||
DOSaveMessageAndToken(&data.IMParamContext, answer, "openai_"+data.PlatForm, tokens, PromptTokens)
|
||||
} else {
|
||||
//流式消息,暂存到redis
|
||||
worker.SetRedis(data.IMParamContext.Channel+"_stream_msg", answer)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue