diff --git a/handler/im.go b/handler/im.go index 248c2cd..8fc6514 100644 --- a/handler/im.go +++ b/handler/im.go @@ -189,7 +189,12 @@ func subscribeAndHandleIMMessages(ws *websocket.Conn, userId int) { func doReceiveGenChatMessage(userId int, sessionID *uint, data *proto.WSMessageReq, chanel string) error { var model dao.Model //查看请求功能类型 - models, funcs := service.FindFuncModelByFunction(data.Function, userId) + //models, funcs := service.FindFuncModelByFunction(data.Function, userId) + models, funcs, mferr := service.FindFuncModelListByFunctionV2(data.Function) + if mferr != nil { + return mferr + } + //log.Println("find function model by function:", models, funcs) //log.Println("funcs:", len(funcs), "\tmodels:", len(models)) if len(funcs) == 0 { @@ -210,7 +215,18 @@ func doReceiveGenChatMessage(userId int, sessionID *uint, data *proto.WSMessageR log.Println("Create session error:", err) } } - model = models[0] + //获取模型 + for _, v := range models { + if v.ID == data.ModelID { + model = v + break + } + } + if model.ID == 0 { + //说明前端传入的模型id不正确,取第一个模型 + model = models[0] + } + //解析模型参数 var modelParam proto.ModelParam err = json.Unmarshal([]byte(model.Parameter), &modelParam) diff --git a/proto/im.go b/proto/im.go index 0e942aa..67cb94e 100644 --- a/proto/im.go +++ b/proto/im.go @@ -24,6 +24,7 @@ type WSMessageReq struct { Msg string `json:"msg"` //消息内容,只进行转发,不做处理 SessionID int `json:"session_id"` //应用层会话id ToID int `json:"to_id"` //接收者id + ModelID int `json:"model_id"` //模型id } type ModelParam struct { diff --git a/service/modelService.go b/service/modelService.go index e83f23b..6b8df79 100644 --- a/service/modelService.go +++ b/service/modelService.go @@ -174,3 +174,24 @@ func FindFuncModelListByFunction(function string) ([]dao.Model, error) { } return models, nil } + +func FindFuncModelListByFunctionV2(function string) ([]dao.Model, []dao.FunctionModel, error) { + funcModels := dao.FindFunctionModelByFunction(function) + + var functionModelIDs []proto.FunctionModelIDs + err := json.Unmarshal([]byte(funcModels[0].ModelIDS), &functionModelIDs) + if err != nil { + log.Println("FindFuncModelListByFunction json unmarshal error:", err) + return nil, funcModels, err + } + modelIDs := map[int]bool{} + for _, v := range functionModelIDs { + modelIDs[v.ID] = true + } + var models []dao.Model + for k := range modelIDs { + models_ := dao.FindModelByIDV2(k) + models = append(models, models_...) + } + return models, funcModels, nil +}