From 1960155b1952e777a4dd67dcdf16523801cc7f7e Mon Sep 17 00:00:00 2001 From: junleea <354425203@qq.com> Date: Wed, 9 Apr 2025 17:37:24 +0800 Subject: [PATCH] =?UTF-8?q?=E5=AE=8C=E6=88=90=E7=9F=A5=E8=AF=86=E5=BA=93?= =?UTF-8?q?=E7=9A=84=E6=B6=88=E6=81=AF=E6=B5=81=E5=A4=84=E7=90=86=E5=8D=B3?= =?UTF-8?q?=E7=9F=A5=E8=AF=86=E5=BA=93=E4=B8=8E=E5=A4=84=E7=90=86=E6=9C=8D?= =?UTF-8?q?=E5=8A=A1=E5=99=A8,=E7=94=A8=E6=88=B7=E6=96=87=E4=BB=B6?= =?UTF-8?q?=E6=94=AF=E6=8C=81=E5=8D=95=E4=B8=AA=E6=9F=A5=E8=AF=A2?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- dao/kbase.go | 17 +++++++++++++++++ handler/file.go | 12 ++++++++++++ handler/im.go | 14 +++++++++++++- proto/status.go | 2 ++ service/fileService.go | 14 ++++++++++++++ service/imService.go | 13 ++++++++++--- service/knowledgeBaseService.go | 12 ++++++++++++ 7 files changed, 80 insertions(+), 4 deletions(-) diff --git a/dao/kbase.go b/dao/kbase.go index 8f50a2c..2cd13f9 100644 --- a/dao/kbase.go +++ b/dao/kbase.go @@ -97,3 +97,20 @@ func DeleteKnowledgeBase(id uint) error { } return nil } + +// 根据会话id查找知识库 +func GetKnowledgeBaseBySessionID(sessionID uint) (KnowledgeBase, error) { + var db2 *gorm.DB + var err error + if proto.Config.SERVER_SQL_LOG { + db2 = DB.Debug() + } else { + db2 = DB + } + var base KnowledgeBase + err = db2.Where("session_id = ?", sessionID).First(&base).Error + if err != nil { + return KnowledgeBase{}, err + } + return base, nil +} diff --git a/handler/file.go b/handler/file.go index 1251bff..65fb733 100644 --- a/handler/file.go +++ b/handler/file.go @@ -44,6 +44,18 @@ func GetUserFileList(c *gin.Context) { } else if req.Type == "search" { fileList := dao.FindFileByUserFileName(req.FileName, userId) c.JSON(http.StatusOK, gin.H{"code": proto.SuccessCode, "msg": "success", "data": fileList}) + } else if req.Type == "one" { + if req.FileID == 0 { + c.JSON(http.StatusOK, gin.H{"error": "file_id is empty", "code": proto.ParameterError, "message": "failed"}) + return + } + userFile, err2 := service.FindUserFileByID(req.FileID, userId) + if err2 != nil { + c.JSON(http.StatusOK, gin.H{"error": "find user file failed", "code": proto.ParameterError, "message": "failed"}) + return + } else { + c.JSON(http.StatusOK, gin.H{"code": proto.SuccessCode, "msg": "success", "data": userFile}) + } } else { c.JSON(http.StatusOK, gin.H{"error": "search file type error", "code": proto.ParameterError, "message": "failed"}) } diff --git a/handler/im.go b/handler/im.go index ad198a0..b8333a8 100644 --- a/handler/im.go +++ b/handler/im.go @@ -265,11 +265,19 @@ func doReceiveGenChatMessage(userId int, sessionID *uint, data *proto.WSMessageR } //若为知识库消息,需要先根据用户消息获取知识库查询内容让ai进行最终结果生成 if data.IsKBase { - err2 := service.SelectKBaseContentByQuery(&imContext, data) + //根据会话ID获取对应知识库消息 + kbaseMsg, err2 := service.FindKnowledgeBaseBySessionID(*sessionID, uint(userId)) + if err2 != nil { + log.Println("find knowledge base by session id error:", err2) + return err2 + } + data.KBaseID = kbaseMsg.ID + err2 = service.SelectKBaseContentByQuery(&imContext, data) if err2 != nil { log.Println("select kbase content error:", err2) return err2 } + return nil } //将消息存入数据库 err, userMsgID = service.CreateMessage(proto.UserToModelMsgType, int(*sessionID), userId, int(model.ID), data.Msg, status, int(funcs[0].ID)) @@ -288,6 +296,10 @@ func KBaseTextWS(c *gin.Context) { userID := int(id.(float64)) //query参数 kBaseSercerID := c.Query("kbase_server_id") + if kBaseSercerID == "" { + c.JSON(http.StatusOK, gin.H{"code": proto.ParameterError, "message": "kbase_server_id不能为空"}) + return + } // 升级HTTP连接为WebSocket连接 ws, err1 := upgrader.Upgrade(c.Writer, c.Request, nil) if err1 != nil { diff --git a/proto/status.go b/proto/status.go index f747ee4..1b21046 100644 --- a/proto/status.go +++ b/proto/status.go @@ -157,6 +157,8 @@ const ( UserFileTypeFile = "file" // 通用文件 UserFileTypeConfig = "config" // 配置文件 UserMaxUploadSize = 1024 * 1024 * 100 + + KnowledgeBaseFunction = "kbase-chat" // 知识库功能 ) // 会话类型 diff --git a/service/fileService.go b/service/fileService.go index eb667f6..15d3ec6 100644 --- a/service/fileService.go +++ b/service/fileService.go @@ -279,6 +279,20 @@ func FindUserFileList(userID int, uploadType string) []dao.UserFileListResp { return res } +func FindUserFileByID(id, userID int) (dao.UserFileListResp, error) { + fileAuth := dao.FindFileAuthByID(id) + if fileAuth.ID == 0 { + return dao.UserFileListResp{}, errors.New("file auth not found") + } + if fileAuth.AuthID != userID { + return dao.UserFileListResp{}, errors.New("no permission") + } + file := dao.FindFileByID(fileAuth.FileID, userID) + fileStoreName := file.FileStoreName + res := dao.UserFileListResp{FileAuth: fileAuth, FileStoreName: fileStoreName} + return res, nil +} + func CreateUserFileAfterUnique() { } diff --git a/service/imService.go b/service/imService.go index e7ff739..0da3b65 100644 --- a/service/imService.go +++ b/service/imService.go @@ -299,14 +299,21 @@ func SelectKBaseContentByQuery(ctx *proto.IMParamContext, data *proto.WSMessageR return errors.New("no kbase server id") } //记录分配的知识库服务器 - worker.SetRedis(key, KBaseServerID) + res := worker.SetRedisWithExpire(key, KBaseServerID, time.Hour*24*7) + log.Println("分配的知识库服务器ID:", KBaseServerID, "\t key:", key, "\tset:", res) } - chanel := fmt.Sprintf("kbase_%s", kbase.UUID) + log.Println("KBaseServerID:", KBaseServerID, " KBaseUUID:", kbase.UUID, "\tName:", kbase.Name) + reqChanel := fmt.Sprintf("kbase_server_id_%s", KBaseServerID) // 频道名称 + reqStr, _ := json.Marshal(req) + //将消息转发到对应频道 + worker.Publish(reqChanel, string(reqStr), time.Second*60) + + chanel := fmt.Sprintf("kbase_%s", kbase.UUID) //获取返回消息频道 pubsub := worker.RedisClient.Subscribe(context.Background(), chanel) defer pubsub.Close() ch := pubsub.Channel() //1min定时器,超时则返回错误 - timer := time.NewTimer(time.Minute) + timer := time.NewTimer(time.Second * 10) var resp proto.KnowledgeBaseServerResponse select { case m := <-ch: diff --git a/service/knowledgeBaseService.go b/service/knowledgeBaseService.go index 42d6768..df1cd42 100644 --- a/service/knowledgeBaseService.go +++ b/service/knowledgeBaseService.go @@ -80,3 +80,15 @@ func AddKnowledgeBaseServer() { } worker.SetRedisSet(set, serverIDs, time.Hour) } + +// 根据会话id查找知识库 +func FindKnowledgeBaseBySessionID(sessionID uint, userID uint) (dao.KnowledgeBase, error) { + knowledgeBase, err := dao.GetKnowledgeBaseBySessionID(sessionID) + if err != nil { + return dao.KnowledgeBase{}, err + } + if knowledgeBase.AuthID != userID { + return dao.KnowledgeBase{}, errors.New("not authorized to access this knowledge base") + } + return knowledgeBase, nil +}