diff --git a/handler/im.go b/handler/im.go index b8333a8..44056da 100644 --- a/handler/im.go +++ b/handler/im.go @@ -263,28 +263,34 @@ func doReceiveGenChatMessage(userId int, sessionID *uint, data *proto.WSMessageR if data.IsImage || data.IsFile { status = proto.UserToModelImageMsgType //图片类型 } - //若为知识库消息,需要先根据用户消息获取知识库查询内容让ai进行最终结果生成 - if data.IsKBase { - //根据会话ID获取对应知识库消息 - kbaseMsg, err2 := service.FindKnowledgeBaseBySessionID(*sessionID, uint(userId)) - if err2 != nil { - log.Println("find knowledge base by session id error:", err2) - return err2 - } - data.KBaseID = kbaseMsg.ID - err2 = service.SelectKBaseContentByQuery(&imContext, data) - if err2 != nil { - log.Println("select kbase content error:", err2) - return err2 - } - return nil - } //将消息存入数据库 err, userMsgID = service.CreateMessage(proto.UserToModelMsgType, int(*sessionID), userId, int(model.ID), data.Msg, status, int(funcs[0].ID)) log.Println("create user message id:", userMsgID) if err != nil { return err } //机器人id为2,消息类型为1,状态为1 + + //若为知识库消息,需要先根据用户消息获取知识库查询内容让ai进行最终结果生成 + if data.IsKBase { + //根据会话ID获取对应知识库消息 + kbaseMsg, err2 := service.FindKnowledgeBaseBySessionID(*sessionID, uint(userId)) + if err2 != nil { + log.Println("find knowledge base by session id error:", err2) + service.WSReturnErrorMessage(imContext.Channel, "find knowledge base error:"+err2.Error(), proto.WSKBaseSessionError) + return err2 + } + data.KBaseID = kbaseMsg.ID + err2 = service.SelectKBaseContentByQuery(&imContext, data) + if err2 != nil { + log.Println("select kbase content error:", err2) + service.WSReturnErrorMessage(imContext.Channel, "select kbase error:"+err2.Error(), proto.WSKBaseServerError) + return err2 + } + if data.KBaseType == "query" { + //若为查询直接返回 + return nil + } + } //需要后续处理 //err, _ = service.WSReceiveMessageService(userId, int(*sessionID), chanel, *data, model, modelParam) err, _ = service.WSReceiveMessageServiceV2(modelParam, &imContext) @@ -309,12 +315,20 @@ func KBaseTextWS(c *gin.Context) { subscribeAndHandleKBaseTextMessages(ws, userID, kBaseSercerID) } -func subscribeAndHandleKBaseTextMessages(ws *websocket.Conn, userId int, kBaseSercerID string) { +func subscribeAndHandleKBaseTextMessages(ws *websocket.Conn, userId int, kBaseServerID string) { ctx := context.Background() - chanel := fmt.Sprintf("kbase_server_id_%s", kBaseSercerID) // 频道名称 + chanel := fmt.Sprintf("kbase_server_id_%s", kBaseServerID) // 频道名称 + //设置KBaseServerID可用状态,加入可用集合 + setKey := "kbase_server_id_set_is_available" + worker.SetRedisSetAdd(setKey, kBaseServerID) pubsub := worker.RedisClient.Subscribe(ctx, chanel) - defer pubsub.Close() - defer ws.Close() + defer func() { + //设置KBaseServerID不可用状态,删除可用集合 + worker.SetRedisSetRemove(setKey, kBaseServerID) + pubsub.Close() + ws.Close() + }() + ch := pubsub.Channel() //接收客户端消息,发送消息 diff --git a/proto/im.go b/proto/im.go index 0b2fd4b..5762408 100644 --- a/proto/im.go +++ b/proto/im.go @@ -12,6 +12,7 @@ type AIQueueMessage struct { } type WSMessage struct { + Code int `json:"code"` //返回码,0表示成功, 其它状态码 Type string `json:"type"` //接收及发送消息类型 Msg AIQueueMessage `json:"msg"` //消息内容,只进行转发,不做处理 SessionID int `json:"session_id"` //应用层会话id @@ -29,6 +30,7 @@ type WSMessageReq struct { IsFile bool `json:"is_file"` //是否为文件(文件消息),如果是文件消息,则msg为文件消息结构为{"file_url":"","text":""} IsKBase bool `json:"is_kbase"` //是否为知识库消息 KBaseID uint `json:"kbase_id"` //知识库id,为知识库消息时必传 + KBaseType string `json:"kbase_type"` //知识库模式,1为query 2为chat } type ModelParam struct { diff --git a/proto/status.go b/proto/status.go index 1b21046..8cc8a78 100644 --- a/proto/status.go +++ b/proto/status.go @@ -90,6 +90,10 @@ const ( FuncModelSearchFailed = 133 // 获取功能模型失败 GetSparkCreatePPTStatusFailed = 140 // 获取spark创建ppt状态失败 + + //下面是ws消息错误码 + WSKBaseServerError = 150 // ws知识库服务器错误 + WSKBaseSessionError = 151 // ws知识库会话错误 ) const ( diff --git a/service/doubao.go b/service/doubao.go index 0109480..7a68dff 100644 --- a/service/doubao.go +++ b/service/doubao.go @@ -148,7 +148,7 @@ func DouBaoV2(modelParam proto.ModelParam, imCtx *proto.IMParamContext) { //将消息存入数据库 err, _ = dao.CreateMessage(imCtx.SessionID, imCtx.ModelID, imCtx.UserID, proto.ModelToUserMsgType, proto.MsgHasRead, answer, imCtx.FunctionID) if err != nil { - log.Println("spark create message error:", err) + log.Println("doubao create message error:", err) } //将本次请求的tokens存入数据库 err2, _ := dao.CreateModelToken(uint(imCtx.UserID), uint(imCtx.ModelID), uint(imCtx.SessionID), uint(tokens)) diff --git a/service/imService.go b/service/imService.go index 0da3b65..84484e4 100644 --- a/service/imService.go +++ b/service/imService.go @@ -278,9 +278,9 @@ func CheckUserCreatePPTSessionPermission(userID int) error { } // 知识库消息,从原始请求数据,上下文数据 -func SelectKBaseContentByQuery(ctx *proto.IMParamContext, data *proto.WSMessageReq) error { +func SelectKBaseContentByQuery(imCtx *proto.IMParamContext, data *proto.WSMessageReq) error { var req proto.KnowledgeBaseServerRequest - req.IMContext = *ctx + req.IMContext = *imCtx kbase, err := dao.GetKnowledgeBaseByID(data.KBaseID) if err != nil { return err @@ -302,6 +302,13 @@ func SelectKBaseContentByQuery(ctx *proto.IMParamContext, data *proto.WSMessageR res := worker.SetRedisWithExpire(key, KBaseServerID, time.Hour*24*7) log.Println("分配的知识库服务器ID:", KBaseServerID, "\t key:", key, "\tset:", res) } + setKey := "kbase_server_id_set_is_available" //可用状态集合 + //查看KBaseServerID是否可用 + isAvailable := worker.SetRedisSetIsMember(setKey, KBaseServerID) + if isAvailable == false { + log.Println("当前知识库服务器不可用,重新分配") + return errors.New("kbase server id is not available") + } log.Println("KBaseServerID:", KBaseServerID, " KBaseUUID:", kbase.UUID, "\tName:", kbase.Name) reqChanel := fmt.Sprintf("kbase_server_id_%s", KBaseServerID) // 频道名称 reqStr, _ := json.Marshal(req) @@ -313,11 +320,12 @@ func SelectKBaseContentByQuery(ctx *proto.IMParamContext, data *proto.WSMessageR defer pubsub.Close() ch := pubsub.Channel() //1min定时器,超时则返回错误 - timer := time.NewTimer(time.Second * 10) + timer := time.NewTimer(time.Second * 30) var resp proto.KnowledgeBaseServerResponse select { case m := <-ch: msg := m.Payload // 获取消息,消息格式为json + //log.Println("resp:", msg) err2 := json.Unmarshal([]byte(msg), &resp) if err2 != nil { log.Println("Error parsing JSON:", err2) @@ -330,7 +338,34 @@ func SelectKBaseContentByQuery(ctx *proto.IMParamContext, data *proto.WSMessageR //解析返回的查询消息 text := "根据下面文档片段:\n" querySelect := resp.QuerySelect + qText := getSelectedDocument(querySelect) + log.Println("kbase type:", data.KBaseType) + if data.KBaseType == "query" { + //则将query为模型返回结果 + //将消息存入数据库 + err, _ = dao.CreateMessage(imCtx.SessionID, imCtx.ModelID, imCtx.UserID, proto.ModelToUserMsgType, proto.MsgHasRead, qText, imCtx.FunctionID) + if err != nil { + log.Println("kbase query message save error:", err) + } + kbaseQueryToGeneralMassageAndSendMsgQueue(qText, imCtx.Channel, imCtx.SessionID, imCtx.UserID) //将消息发送到消息队列 + } else if data.KBaseType == "chat" { + //用户问题 + text += qText + text += "回答用户问题:" + data.Msg + data.Msg = text + log.Println("text:", text) + } else { + return errors.New("unknow kbase type") + } + return nil +} + +func getSelectedDocument(querySelect proto.KnowledgeBaseServerResponseSelect) string { + text := "" for i := 0; i < len(querySelect.Documents); i++ { + if len(querySelect.Documents[i]) == 0 { + break + } text += "文件:" for j := 0; j < len(querySelect.Metadatas); j++ { text += querySelect.Metadatas[i][j].Source + "," @@ -341,8 +376,38 @@ func SelectKBaseContentByQuery(ctx *proto.IMParamContext, data *proto.WSMessageR } text += "\n" } - //用户问题 - text += "回答下面用户问题:" + data.Msg - data.Msg = text - return nil + return text +} + +func WSReturnErrorMessage(channel string, msg string, code int) { + var data proto.WSMessage + data.Code = code + data.Type = msg //错误时type为错误信息 + data.SessionID = 0 + data.ToID = 0 + errStr, _ := json.Marshal(data) + worker.Publish(channel, string(errStr), time.Second*60) +} + +func kbaseQueryToGeneralMassageAndSendMsgQueue(answer, channel string, SessionID, UserID int) { + var aiMsg proto.AIQueueMessage + var wsMsg proto.WSMessage + var apiMsg api.GenerateResponse + + apiMsg.Done = true + + apiMsg.Model = proto.KnowledgeBaseServerResponseType + apiMsg.Response = answer + //消息队列部分 + aiMsg.Type = proto.KnowledgeBaseServerResponseType + aiMsg.Msg = apiMsg + //ws发送消息部分 + wsMsg.Msg = aiMsg + wsMsg.SessionID = SessionID + wsMsg.ToID = UserID + wsMsg.Type = proto.KnowledgeBaseServerResponseType + + //发送消息 + wsMsgStr, _ := json.Marshal(wsMsg) + worker.Publish(channel, string(wsMsgStr), time.Second*60) } diff --git a/worker/redis.go b/worker/redis.go index 2568d33..28349b3 100644 --- a/worker/redis.go +++ b/worker/redis.go @@ -553,3 +553,13 @@ func GetRedisSetRandomMember(key string) string { } return val } + +func SetRedisSetIsMember(setKey, memberKey string) bool { + ctx := context.Background() + val, err := RedisClient.SIsMember(ctx, setKey, memberKey).Result() + if err != nil { + fmt.Println("Error getting key: %v", err) + return false + } + return val +}