From 269bdea1c72986f1cb2cb55ca7dd6a62e76b585d Mon Sep 17 00:00:00 2001 From: junleea <354425203@qq.com> Date: Mon, 24 Mar 2025 14:47:53 +0800 Subject: [PATCH] =?UTF-8?q?=E9=87=8D=E6=9E=84ws=E6=B6=88=E6=81=AF=E8=8E=B7?= =?UTF-8?q?=E5=8F=96=EF=BC=8C=E6=B7=BB=E5=8A=A0=E6=A0=B9=E6=8D=AE=E5=8A=9F?= =?UTF-8?q?=E8=83=BD=E6=9F=A5=E6=89=BE=E5=AF=B9=E5=BA=94=E6=A8=A1=E5=9E=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- dao/model.go | 10 ++++++ handler/im.go | 67 ++++++++++++++++++++++++++--------------- proto/im.go | 12 +++++--- proto/status.go | 6 ++++ service/imService.go | 12 ++++---- service/modelService.go | 12 ++++++++ worker/spark.go | 8 ++--- 7 files changed, 88 insertions(+), 39 deletions(-) diff --git a/dao/model.go b/dao/model.go index 8b6a415..75c6fb4 100644 --- a/dao/model.go +++ b/dao/model.go @@ -135,3 +135,13 @@ func DeleteFunctionModelByID(id int) error { } return res.Error } + +func FindFunctionModelByFunction(function string) []FunctionModel { + var models []FunctionModel + if proto.Config.SERVER_SQL_LOG { + DB.Debug().Where("function = ?", function).Find(&models) + } else { + DB.Where("function = ?", function).Find(&models) + } + return models +} diff --git a/handler/im.go b/handler/im.go index ec2aa0f..cce9bc3 100644 --- a/handler/im.go +++ b/handler/im.go @@ -1,11 +1,13 @@ package handler import ( + "StuAcaWorksAI/dao" "StuAcaWorksAI/proto" "StuAcaWorksAI/service" "StuAcaWorksAI/worker" "context" "encoding/json" + "errors" "github.com/gin-gonic/gin" "github.com/gorilla/websocket" "log" @@ -105,7 +107,6 @@ func subscribeAndHandleIMMessages(ws *websocket.Conn, userId int) { //defer ticker.Stop() var sessionID uint - var userMsgID uint //接收客户端消息,发送消息 done := make(chan struct{}) @@ -123,31 +124,9 @@ func subscribeAndHandleIMMessages(ws *websocket.Conn, userId int) { if err1 != nil { log.Println("Error parsing JSON:", err1) } - if sessionID == 0 { - //创建会话 - //会话名字为前10个字符,如果 - //var sessionName string - //if len(data.Msg) > 10 { - // sessionName = data.Msg[:10] - //} else { - // sessionName = data.Msg - //} - err, sessionID = service.CreateSession(userId, data.Msg) - log.Println("create session id:", sessionID) - if err != nil { - log.Println("Create session error:", err) - return - } - } - //将消息存入数据库 - err, userMsgID = service.CreateMessage(1, int(sessionID), userId, 2, data.Msg, 1) - log.Println("create user message id:", userMsgID) - if err != nil { - return - } //机器人id为2,消息类型为1,状态为1 - //需要后续处理 - err, _ = service.WSReceiveMessageService(userId, int(sessionID), chanel, data) + err = doReceiveGenChatMessage(userId, sessionID, &data) if err != nil { + log.Println("do receive gen chat message error:", err) return } } @@ -202,3 +181,41 @@ func subscribeAndHandleIMMessages(ws *websocket.Conn, userId int) { } } + +func doReceiveGenChatMessage(userId int, sessionID uint, data *proto.WSMessageReq) error { + var err error + chanel := "user_" + strconv.Itoa(userId) + "_ai_chat_msg" + if sessionID == 0 { + //创建会话 + err, sessionID = service.CreateSession(userId, data.Msg) + log.Println("create session id:", sessionID) + if err != nil { + log.Println("Create session error:", err) + } + } + var model dao.Model + //查看请求功能类型 + if data.Function == "gen-ai-chat" { + model = service.FindFuncModelByFunction(data.Function, userId)[0] + } else { + return errors.New("function not exist") + } + //解析模型参数 + var modelParam proto.ModelParam + err = json.Unmarshal([]byte(model.Parameter), &modelParam) + if err != nil { + return err + } + modelParam.Url = model.Url + + var userMsgID uint + //将消息存入数据库 + err, userMsgID = service.CreateMessage(proto.UserAndModelMsgType, int(sessionID), userId, int(model.ID), data.Msg, proto.MsgHasRead) + log.Println("create user message id:", userMsgID) + if err != nil { + return err + } //机器人id为2,消息类型为1,状态为1 + //需要后续处理 + err, _ = service.WSReceiveMessageService(userId, int(sessionID), chanel, *data, model, modelParam) + return err +} diff --git a/proto/im.go b/proto/im.go index a21c301..da2ec67 100644 --- a/proto/im.go +++ b/proto/im.go @@ -15,10 +15,6 @@ type WSMessage struct { ToID int `json:"to_id"` //接收者id } -type ModelParam struct { - Model string `json:"model"` //模型名称 -} - type WSMessageReq struct { Type string `json:"type"` //接收及发送消息类型 Function string `json:"function"` //功能名称 @@ -26,3 +22,11 @@ type WSMessageReq struct { SessionID int `json:"session_id"` //应用层会话id ToID int `json:"to_id"` //接收者id } + +type ModelParam struct { + Model string `json:"model"` //模型名称 + Url string `json:"url"` //模型地址 + APPID string `json:"appid"` //应用id + APISecret string `json:"apiSecret"` //应用密钥 + APIKey string `json:"apiKey"` //应用key +} diff --git a/proto/status.go b/proto/status.go index e7adbf6..8377717 100644 --- a/proto/status.go +++ b/proto/status.go @@ -89,3 +89,9 @@ const ( FuncModelDeleteFailed = 132 // 删除功能模型失败 FuncModelSearchFailed = 133 // 获取功能模型失败 ) + +const ( + // 代码中使用常量定义 + UserAndModelMsgType = 2 // 用户与模型消息类型 + MsgHasRead = 1 // 消息已读 +) diff --git a/service/imService.go b/service/imService.go index c1fcee2..ec0d797 100644 --- a/service/imService.go +++ b/service/imService.go @@ -79,15 +79,16 @@ func CreateGeneralMessageService(fromID, toID, msgType, sessionID int, msg strin var client *api.Client var SessionContextMap map[int][]int -func WSReceiveMessageService(userID, sessionID int, channel string, msg proto.WSMessageReq) (error, uint) { +func WSReceiveMessageService(userID, sessionID int, channel string, msg proto.WSMessageReq, model dao.Model, modelParam proto.ModelParam) (error, uint) { var resErr error var resID uint //处理消息 if msg.Type == "ollama" { resErr, resID = ReceiveOllamaSession(userID, sessionID, channel, msg) } else if msg.Type == "spark" { - resErr, resID = ReceiveSparkSession(userID, sessionID, channel, msg) - + resErr, resID = ReceiveSparkSession(userID, sessionID, channel, msg, modelParam) + } else { + return errors.New("model not exist"), 0 } return resErr, resID } @@ -150,10 +151,9 @@ func ReceiveOllamaSession(userID, sessionID int, channel string, msg proto.WSMes return resErr, resID } -func ReceiveSparkSession(userID, sessionID int, channel string, msg proto.WSMessageReq) (error, uint) { +func ReceiveSparkSession(userID, sessionID int, channel string, msg proto.WSMessageReq, modelParam proto.ModelParam) (error, uint) { var resErr error var resID uint - //查找模型 - + worker.Spark(modelParam, msg.Msg, channel, sessionID, userID) return resErr, resID } diff --git a/service/modelService.go b/service/modelService.go index f31a47a..e5038e3 100644 --- a/service/modelService.go +++ b/service/modelService.go @@ -106,3 +106,15 @@ func UpdateFuncModelByID(id int, userID, modelID uint, name, info, function stri } return dao.UpdateFunctionModelByID(id, userID, modelID, name, info, function) } + +// 根据功能查找对应功能 +func FindFuncModelByFunction(function string, userID int) []dao.Model { + //先查找对应功能 + funcModels := dao.FindFunctionModelByFunction(function) + if len(funcModels) == 0 { + return nil + } + modelID := funcModels[0].ModelID + //再查找对应模型 + return dao.FindModelByID(int(modelID), userID) +} diff --git a/worker/spark.go b/worker/spark.go index 614bd5a..503c60f 100644 --- a/worker/spark.go +++ b/worker/spark.go @@ -105,13 +105,13 @@ type SparkResponse struct { /***************** Spark响应结构 ***********************/ -func Spark(hostUrl string, apiKey, apiSecret, appid string, question, channel string, SessionID, UserID int) { +func Spark(modelParam proto.ModelParam, question, channel string, SessionID, UserID int) { //建立websocket连接 dialer := websocket.Dialer{ HandshakeTimeout: 5 * time.Second, } //建立连接 - conn, resp, err := dialer.Dial(assembleAuthUrl(hostUrl, apiKey, apiSecret), nil) + conn, resp, err := dialer.Dial(assembleAuthUrl(modelParam.Url, modelParam.APIKey, modelParam.APISecret), nil) if err != nil { panic(readResp(resp) + err.Error()) return @@ -121,7 +121,7 @@ func Spark(hostUrl string, apiKey, apiSecret, appid string, question, channel st //发送消息 go func() { - data := genSparkParams(appid, question) + data := genSparkParams(modelParam.APPID, question) //将数据转换为json //dataByte, err := json.Marshal(data) if err != nil { @@ -172,7 +172,7 @@ func Spark(hostUrl string, apiKey, apiSecret, appid string, question, channel st } } -func SparkToGeneralMassageAndSendMsgQueue(data *SparkResponsem, channel string, SessionID int, UserID int) { +func SparkToGeneralMassageAndSendMsgQueue(data *SparkResponse, channel string, SessionID int, UserID int) { var aiMsg proto.AIQueueMessage var wsMsg proto.WSMessage var apiMsg api.GenerateResponse