From acc286d00d05fb126ee9dc1a91fb62596b7bf58e Mon Sep 17 00:00:00 2001 From: junleea <354425203@qq.com> Date: Mon, 31 Mar 2025 13:26:06 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E6=94=B9ws=E9=95=BF=E8=BF=9E=E6=8E=A5?= =?UTF-8?q?=E4=B8=AD=E6=8E=A5=E6=94=B6=E6=B6=88=E6=81=AF=E5=90=8E=E9=80=89?= =?UTF-8?q?=E6=8B=A9=E6=A8=A1=E5=9E=8B=E7=9A=84=E5=8A=9F=E8=83=BD=E5=8F=8A?= =?UTF-8?q?=E6=A8=A1=E5=9E=8B=E5=8A=9F=E8=83=BD=E3=80=82?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- handler/im.go | 20 ++++++++++++++++++-- proto/im.go | 1 + service/modelService.go | 21 +++++++++++++++++++++ 3 files changed, 40 insertions(+), 2 deletions(-) diff --git a/handler/im.go b/handler/im.go index 248c2cd..8fc6514 100644 --- a/handler/im.go +++ b/handler/im.go @@ -189,7 +189,12 @@ func subscribeAndHandleIMMessages(ws *websocket.Conn, userId int) { func doReceiveGenChatMessage(userId int, sessionID *uint, data *proto.WSMessageReq, chanel string) error { var model dao.Model //查看请求功能类型 - models, funcs := service.FindFuncModelByFunction(data.Function, userId) + //models, funcs := service.FindFuncModelByFunction(data.Function, userId) + models, funcs, mferr := service.FindFuncModelListByFunctionV2(data.Function) + if mferr != nil { + return mferr + } + //log.Println("find function model by function:", models, funcs) //log.Println("funcs:", len(funcs), "\tmodels:", len(models)) if len(funcs) == 0 { @@ -210,7 +215,18 @@ func doReceiveGenChatMessage(userId int, sessionID *uint, data *proto.WSMessageR log.Println("Create session error:", err) } } - model = models[0] + //获取模型 + for _, v := range models { + if v.ID == data.ModelID { + model = v + break + } + } + if model.ID == 0 { + //说明前端传入的模型id不正确,取第一个模型 + model = models[0] + } + //解析模型参数 var modelParam proto.ModelParam err = json.Unmarshal([]byte(model.Parameter), &modelParam) diff --git a/proto/im.go b/proto/im.go index 0e942aa..67cb94e 100644 --- a/proto/im.go +++ b/proto/im.go @@ -24,6 +24,7 @@ type WSMessageReq struct { Msg string `json:"msg"` //消息内容,只进行转发,不做处理 SessionID int `json:"session_id"` //应用层会话id ToID int `json:"to_id"` //接收者id + ModelID int `json:"model_id"` //模型id } type ModelParam struct { diff --git a/service/modelService.go b/service/modelService.go index e83f23b..6b8df79 100644 --- a/service/modelService.go +++ b/service/modelService.go @@ -174,3 +174,24 @@ func FindFuncModelListByFunction(function string) ([]dao.Model, error) { } return models, nil } + +func FindFuncModelListByFunctionV2(function string) ([]dao.Model, []dao.FunctionModel, error) { + funcModels := dao.FindFunctionModelByFunction(function) + + var functionModelIDs []proto.FunctionModelIDs + err := json.Unmarshal([]byte(funcModels[0].ModelIDS), &functionModelIDs) + if err != nil { + log.Println("FindFuncModelListByFunction json unmarshal error:", err) + return nil, funcModels, err + } + modelIDs := map[int]bool{} + for _, v := range functionModelIDs { + modelIDs[v.ID] = true + } + var models []dao.Model + for k := range modelIDs { + models_ := dao.FindModelByIDV2(k) + models = append(models, models_...) + } + return models, funcModels, nil +}