From 349ee3af05891bade835a856a6fbe5cc1ce04c3c Mon Sep 17 00:00:00 2001 From: junleea <354425203@qq.com> Date: Tue, 8 Apr 2025 15:58:29 +0800 Subject: [PATCH] =?UTF-8?q?=E6=B7=BB=E5=8A=A0=E7=9F=A5=E8=AF=86=E5=BA=93?= =?UTF-8?q?=E5=A2=9E=E5=88=A0=E6=94=B9=E6=9F=A5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- dao/db.go | 6 ++ dao/im.go | 2 +- dao/kbase.go | 99 +++++++++++++++++++++++++++++++ handler/knowledgeBase.go | 101 ++++++++++++++++++++++++++++++++ main.go | 17 +++--- proto/kbase.go | 9 +++ proto/status.go | 1 + proto/tool.go | 6 ++ service/knowledgeBaseService.go | 70 ++++++++++++++++++++++ 9 files changed, 302 insertions(+), 9 deletions(-) create mode 100644 dao/kbase.go create mode 100644 handler/knowledgeBase.go create mode 100644 proto/kbase.go create mode 100644 service/knowledgeBaseService.go diff --git a/dao/db.go b/dao/db.go index 7561ab5..d7d6b42 100644 --- a/dao/db.go +++ b/dao/db.go @@ -61,6 +61,12 @@ func Init() error { return err } + err = db.AutoMigrate(&KnowledgeBase{}) + if err != nil { + fmt.Println("knowledge base table:", err) + return err + } + err = db.AutoMigrate(&Model{}) if err != nil { fmt.Println("model table:", err) diff --git a/dao/im.go b/dao/im.go index 7edfc49..29e02d3 100644 --- a/dao/im.go +++ b/dao/im.go @@ -12,7 +12,7 @@ type Session struct { gorm.Model UserID int `gorm:"column:user_id"` //只能由用户创建 Name string `gorm:"column:name"` - Type int `gorm:"column:type"` //会话类型:1为通用聊天会话,2为提示词会话,3为制作ppt会话 + Type int `gorm:"column:type"` //会话类型:1为通用聊天会话,2为提示词会话,3为制作ppt会话,4为知识库会话 Context json.RawMessage `gorm:"column:context;type:json"` //会话上下文 } diff --git a/dao/kbase.go b/dao/kbase.go new file mode 100644 index 0000000..8f50a2c --- /dev/null +++ b/dao/kbase.go @@ -0,0 +1,99 @@ +package dao + +import ( + "StuAcaWorksAI/proto" + "gorm.io/gorm" +) + +// 知识库表结构,存储知识库的基本信息及文件信息 +type KnowledgeBase struct { + gorm.Model + AuthID uint `gorm:"column:auth_id"` // 知识库创建者ID + Name string `gorm:"column:name"` // 知识库名称 + Description string `gorm:"column:description"` // 知识库描述 + SessionID uint `gorm:"column:session_id"` // 会话ID,创建知识库自动生成会话 + UUID string `gorm:"column:uuid"` // 知识库UUID,用于向量数据库唯一标识 + FileIDs string `gorm:"column:file_ids"` // 文件ID列表,逗号分隔json '[{'file_id';1}]' +} + +// 创建知识库 +func CreateKnowledgeBase(base KnowledgeBase) (KnowledgeBase, error) { + var db2 *gorm.DB + var err error + if proto.Config.SERVER_SQL_LOG { + db2 = DB.Debug() + } else { + db2 = DB + } + err = db2.Create(&base).Error + if err != nil { + return KnowledgeBase{}, err + } + return base, nil +} + +// 更新知识库 +func UpdateKnowledgeBase(base KnowledgeBase) (KnowledgeBase, error) { + var db2 *gorm.DB + var err error + if proto.Config.SERVER_SQL_LOG { + db2 = DB.Debug() + } else { + db2 = DB + } + err = db2.Save(&base).Error + if err != nil { + return KnowledgeBase{}, err + } + return base, nil +} + +// 获取知识库列表 +func GetKnowledgeBaseList(authID int) ([]KnowledgeBase, error) { + var db2 *gorm.DB + var err error + if proto.Config.SERVER_SQL_LOG { + db2 = DB.Debug() + } else { + db2 = DB + } + var bases []KnowledgeBase + err = db2.Where("auth_id = ?", authID).Find(&bases).Error + if err != nil { + return nil, err + } + return bases, nil +} + +// 根据id获取知识库 +func GetKnowledgeBaseByID(id uint) (KnowledgeBase, error) { + var db2 *gorm.DB + var err error + if proto.Config.SERVER_SQL_LOG { + db2 = DB.Debug() + } else { + db2 = DB + } + var base KnowledgeBase + err = db2.Where("id = ?", id).First(&base).Error + if err != nil { + return KnowledgeBase{}, err + } + return base, nil +} + +// 删除知识库 +func DeleteKnowledgeBase(id uint) error { + var db2 *gorm.DB + var err error + if proto.Config.SERVER_SQL_LOG { + db2 = DB.Debug() + } else { + db2 = DB + } + err = db2.Where("id = ?", id).Delete(&KnowledgeBase{}).Error + if err != nil { + return err + } + return nil +} diff --git a/handler/knowledgeBase.go b/handler/knowledgeBase.go new file mode 100644 index 0000000..ad23c7d --- /dev/null +++ b/handler/knowledgeBase.go @@ -0,0 +1,101 @@ +package handler + +import ( + "StuAcaWorksAI/proto" + "StuAcaWorksAI/service" + "github.com/gin-gonic/gin" + "net/http" +) + +func SetUpKnowledgeBaseGroup(router *gin.Engine) { + knowledgeBaseGroup := router.Group("/knowledgeBase") + knowledgeBaseGroup.POST("/create", CreateKnowledgeBase) // 创建知识库 + knowledgeBaseGroup.POST("/update", UpdateKnowledgeBase) // 更新知识库 + knowledgeBaseGroup.POST("/list", GetKnowledgeBaseList) // 获取知识库列表 + knowledgeBaseGroup.POST("/delete", DeleteKnowledgeBase) // 删除知识库 +} + +func CreateKnowledgeBase(c *gin.Context) { + id, _ := c.Get("id") + userID := int(id.(float64)) + // 处理创建知识库的逻辑 + var req proto.KnowledgeBaseReq + var resp proto.GenerateResp + if err := c.ShouldBind(&req); err == nil { + knowledgeBase, err2 := service.CreateKnowledgeBase(userID, req) + if err2 == nil { + resp.Code = proto.SuccessCode + resp.Message = "success" + resp.Data = knowledgeBase + } else { + resp.Code = proto.ParameterError + resp.Message = "内部错误:" + err2.Error() + } + } else { + resp.Code = proto.ParameterError + resp.Message = "参数错误:" + err.Error() + } + c.JSON(http.StatusOK, resp) +} + +func UpdateKnowledgeBase(c *gin.Context) { + id, _ := c.Get("id") + userID := int(id.(float64)) + // 处理更新知识库的逻辑 + var req proto.KnowledgeBaseReq + var resp proto.GenerateResp + if err := c.ShouldBind(&req); err == nil { + knowledgeBase, err2 := service.UpdateKnowledgeBase(userID, req) + if err2 == nil { + resp.Code = proto.SuccessCode + resp.Message = "success" + resp.Data = knowledgeBase + } else { + resp.Code = proto.ParameterError + resp.Message = "内部错误:" + err2.Error() + } + } else { + resp.Code = proto.ParameterError + resp.Message = "参数错误:" + err.Error() + } + c.JSON(http.StatusOK, resp) +} + +func GetKnowledgeBaseList(c *gin.Context) { + id, _ := c.Get("id") + userID := int(id.(float64)) + // 处理获取知识库列表的逻辑 + var resp proto.GenerateResp + knowledgeBases, err2 := service.GetKnowledgeBaseList(userID) + if err2 == nil { + resp.Code = proto.SuccessCode + resp.Message = "success" + resp.Data = knowledgeBases + } else { + resp.Code = proto.ParameterError + resp.Message = "内部错误:" + err2.Error() + } + c.JSON(http.StatusOK, resp) +} + +func DeleteKnowledgeBase(c *gin.Context) { + id, _ := c.Get("id") + userID := int(id.(float64)) + // 处理更新知识库的逻辑 + var req proto.KnowledgeBaseReq + var resp proto.GenerateResp + if err := c.ShouldBind(&req); err == nil { + err2 := service.DeleteKnowledgeBase(userID, req) + if err2 == nil { + resp.Code = proto.SuccessCode + resp.Message = "success" + } else { + resp.Code = proto.ParameterError + resp.Message = "内部处理错误:" + err2.Error() + } + } else { + resp.Code = proto.ParameterError + resp.Message = "参数错误:" + err.Error() + } + +} diff --git a/main.go b/main.go index 3612b2d..6f5c354 100644 --- a/main.go +++ b/main.go @@ -31,14 +31,15 @@ func main() { panic("failed to connect redis:" + err.Error()) } r.Use(handler.CrosHandler()) - r.Use(JWTAuthMiddleware()) // 使用 JWT 认证中间件 - handler.SetUpUserGroup(r) // User - handler.SetUpToolGroup(r) // Tool - handler.SetUpFileGroup(r) // File - handler.SetUpSessionGroup(r) // Session - handler.SetUpModelGroup(r) // Model - handler.SetUpFuncModelGroup(r) // FuncModel - handler.SetUpIMGroup(r) // IM + r.Use(JWTAuthMiddleware()) // 使用 JWT 认证中间件 + handler.SetUpUserGroup(r) // User + handler.SetUpToolGroup(r) // Tool + handler.SetUpFileGroup(r) // File + handler.SetUpKnowledgeBaseGroup(r) // KnowledgeBase + handler.SetUpSessionGroup(r) // Session + handler.SetUpModelGroup(r) // Model + handler.SetUpFuncModelGroup(r) // FuncModel + handler.SetUpIMGroup(r) // IM defer dao.Close() defer worker.CloseRedis() //定时任务 diff --git a/proto/kbase.go b/proto/kbase.go new file mode 100644 index 0000000..f1b645e --- /dev/null +++ b/proto/kbase.go @@ -0,0 +1,9 @@ +package proto + +type KnowledgeBaseReq struct { + ID uint `json:"id"` // 知识库ID + UUID string `json:"uuid"` // 知识库UUID + Name string `json:"name"` // 知识库名称 + Description string `json:"description"` // 知识库描述 + FileIDS string `json:"file_ids"` // 文件ID列表,逗号分隔 +} diff --git a/proto/status.go b/proto/status.go index 9d9dfbd..a8f9c0a 100644 --- a/proto/status.go +++ b/proto/status.go @@ -163,6 +163,7 @@ const ( SessionTypeUserWithModelGeneration = 1 // 用户与模型通用会话 SessionTypeUserPrompt = 2 // 用户与模型提示词 SessionTypeUserCreatePPT = 3 // 用户与模型制作PPT + SessionTypeKnowledgeBase = 4 // 用户与知识库 ) // 文件类型(文本、图片) diff --git a/proto/tool.go b/proto/tool.go index 9d29023..aeb8627 100644 --- a/proto/tool.go +++ b/proto/tool.go @@ -42,3 +42,9 @@ type FileContentResp struct { Message string `json:"message"` Data any `json:"data"` } + +type GenerateResp struct { + Code int `json:"code"` + Message string `json:"message"` + Data any `json:"data"` +} diff --git a/service/knowledgeBaseService.go b/service/knowledgeBaseService.go new file mode 100644 index 0000000..419c028 --- /dev/null +++ b/service/knowledgeBaseService.go @@ -0,0 +1,70 @@ +package service + +import ( + "StuAcaWorksAI/dao" + "StuAcaWorksAI/proto" + "errors" + "github.com/google/uuid" +) + +func CreateKnowledgeBase(userID int, req proto.KnowledgeBaseReq) (dao.KnowledgeBase, error) { + //生成uuid + uuid := uuid.New().String() + sessionName := "知识库会话:" + req.Name + //创建会话 + err, sessionID := CreateSession(userID, proto.SessionTypeKnowledgeBase, sessionName) + if err != nil { + return dao.KnowledgeBase{}, err + } + + // 创建知识库 + base := dao.KnowledgeBase{AuthID: uint(userID), Name: req.Name, Description: req.Description, UUID: uuid, FileIDs: req.FileIDS, SessionID: sessionID} + knowledgeBase, err := dao.CreateKnowledgeBase(base) + if err != nil { + return dao.KnowledgeBase{}, err + } + // 创建向量数据库,后面完成 + + return knowledgeBase, err +} + +// 更新知识库 +func UpdateKnowledgeBase(userID int, req proto.KnowledgeBaseReq) (dao.KnowledgeBase, error) { + //查找知识库 + knowledgeBase, err := dao.GetKnowledgeBaseByID(req.ID) + if err != nil { + return dao.KnowledgeBase{}, err + } + if knowledgeBase.AuthID != uint(userID) { + return dao.KnowledgeBase{}, errors.New("not authorized to update this knowledge base") + } + + knowledgeBase.Name = req.Name + knowledgeBase.Description = req.Description + knowledgeBase.FileIDs = req.FileIDS + + return dao.UpdateKnowledgeBase(knowledgeBase) +} + +// 获取知识库列表 +func GetKnowledgeBaseList(userID int) ([]dao.KnowledgeBase, error) { + knowledgeBases, err := dao.GetKnowledgeBaseList(userID) + if err != nil { + return nil, err + } + return knowledgeBases, nil +} + +// 删除知识库 +func DeleteKnowledgeBase(userID int, req proto.KnowledgeBaseReq) error { + //查找知识库 + knowledgeBase, err := dao.GetKnowledgeBaseByID(req.ID) + if err != nil { + return err + } + if knowledgeBase.AuthID != uint(userID) { + return errors.New("not authorized to delete this knowledge base") + } + + return dao.DeleteKnowledgeBase(knowledgeBase.ID) +}