添加知识库增删改查

This commit is contained in:
junleea 2025-04-08 15:58:29 +08:00
parent 6eee65aa32
commit 349ee3af05
9 changed files with 302 additions and 9 deletions

View File

@ -61,6 +61,12 @@ func Init() error {
return err
}
err = db.AutoMigrate(&KnowledgeBase{})
if err != nil {
fmt.Println("knowledge base table:", err)
return err
}
err = db.AutoMigrate(&Model{})
if err != nil {
fmt.Println("model table:", err)

View File

@ -12,7 +12,7 @@ type Session struct {
gorm.Model
UserID int `gorm:"column:user_id"` //只能由用户创建
Name string `gorm:"column:name"`
Type int `gorm:"column:type"` //会话类型:1为通用聊天会话2为提示词会话3为制作ppt会话
Type int `gorm:"column:type"` //会话类型:1为通用聊天会话2为提示词会话3为制作ppt会话,4为知识库会话
Context json.RawMessage `gorm:"column:context;type:json"` //会话上下文
}

99
dao/kbase.go Normal file
View File

@ -0,0 +1,99 @@
package dao
import (
"StuAcaWorksAI/proto"
"gorm.io/gorm"
)
// 知识库表结构,存储知识库的基本信息及文件信息
type KnowledgeBase struct {
gorm.Model
AuthID uint `gorm:"column:auth_id"` // 知识库创建者ID
Name string `gorm:"column:name"` // 知识库名称
Description string `gorm:"column:description"` // 知识库描述
SessionID uint `gorm:"column:session_id"` // 会话ID,创建知识库自动生成会话
UUID string `gorm:"column:uuid"` // 知识库UUID,用于向量数据库唯一标识
FileIDs string `gorm:"column:file_ids"` // 文件ID列表逗号分隔json '[{'file_id';1}]'
}
// 创建知识库
func CreateKnowledgeBase(base KnowledgeBase) (KnowledgeBase, error) {
var db2 *gorm.DB
var err error
if proto.Config.SERVER_SQL_LOG {
db2 = DB.Debug()
} else {
db2 = DB
}
err = db2.Create(&base).Error
if err != nil {
return KnowledgeBase{}, err
}
return base, nil
}
// 更新知识库
func UpdateKnowledgeBase(base KnowledgeBase) (KnowledgeBase, error) {
var db2 *gorm.DB
var err error
if proto.Config.SERVER_SQL_LOG {
db2 = DB.Debug()
} else {
db2 = DB
}
err = db2.Save(&base).Error
if err != nil {
return KnowledgeBase{}, err
}
return base, nil
}
// 获取知识库列表
func GetKnowledgeBaseList(authID int) ([]KnowledgeBase, error) {
var db2 *gorm.DB
var err error
if proto.Config.SERVER_SQL_LOG {
db2 = DB.Debug()
} else {
db2 = DB
}
var bases []KnowledgeBase
err = db2.Where("auth_id = ?", authID).Find(&bases).Error
if err != nil {
return nil, err
}
return bases, nil
}
// 根据id获取知识库
func GetKnowledgeBaseByID(id uint) (KnowledgeBase, error) {
var db2 *gorm.DB
var err error
if proto.Config.SERVER_SQL_LOG {
db2 = DB.Debug()
} else {
db2 = DB
}
var base KnowledgeBase
err = db2.Where("id = ?", id).First(&base).Error
if err != nil {
return KnowledgeBase{}, err
}
return base, nil
}
// 删除知识库
func DeleteKnowledgeBase(id uint) error {
var db2 *gorm.DB
var err error
if proto.Config.SERVER_SQL_LOG {
db2 = DB.Debug()
} else {
db2 = DB
}
err = db2.Where("id = ?", id).Delete(&KnowledgeBase{}).Error
if err != nil {
return err
}
return nil
}

101
handler/knowledgeBase.go Normal file
View File

@ -0,0 +1,101 @@
package handler
import (
"StuAcaWorksAI/proto"
"StuAcaWorksAI/service"
"github.com/gin-gonic/gin"
"net/http"
)
func SetUpKnowledgeBaseGroup(router *gin.Engine) {
knowledgeBaseGroup := router.Group("/knowledgeBase")
knowledgeBaseGroup.POST("/create", CreateKnowledgeBase) // 创建知识库
knowledgeBaseGroup.POST("/update", UpdateKnowledgeBase) // 更新知识库
knowledgeBaseGroup.POST("/list", GetKnowledgeBaseList) // 获取知识库列表
knowledgeBaseGroup.POST("/delete", DeleteKnowledgeBase) // 删除知识库
}
func CreateKnowledgeBase(c *gin.Context) {
id, _ := c.Get("id")
userID := int(id.(float64))
// 处理创建知识库的逻辑
var req proto.KnowledgeBaseReq
var resp proto.GenerateResp
if err := c.ShouldBind(&req); err == nil {
knowledgeBase, err2 := service.CreateKnowledgeBase(userID, req)
if err2 == nil {
resp.Code = proto.SuccessCode
resp.Message = "success"
resp.Data = knowledgeBase
} else {
resp.Code = proto.ParameterError
resp.Message = "内部错误:" + err2.Error()
}
} else {
resp.Code = proto.ParameterError
resp.Message = "参数错误:" + err.Error()
}
c.JSON(http.StatusOK, resp)
}
func UpdateKnowledgeBase(c *gin.Context) {
id, _ := c.Get("id")
userID := int(id.(float64))
// 处理更新知识库的逻辑
var req proto.KnowledgeBaseReq
var resp proto.GenerateResp
if err := c.ShouldBind(&req); err == nil {
knowledgeBase, err2 := service.UpdateKnowledgeBase(userID, req)
if err2 == nil {
resp.Code = proto.SuccessCode
resp.Message = "success"
resp.Data = knowledgeBase
} else {
resp.Code = proto.ParameterError
resp.Message = "内部错误:" + err2.Error()
}
} else {
resp.Code = proto.ParameterError
resp.Message = "参数错误:" + err.Error()
}
c.JSON(http.StatusOK, resp)
}
func GetKnowledgeBaseList(c *gin.Context) {
id, _ := c.Get("id")
userID := int(id.(float64))
// 处理获取知识库列表的逻辑
var resp proto.GenerateResp
knowledgeBases, err2 := service.GetKnowledgeBaseList(userID)
if err2 == nil {
resp.Code = proto.SuccessCode
resp.Message = "success"
resp.Data = knowledgeBases
} else {
resp.Code = proto.ParameterError
resp.Message = "内部错误:" + err2.Error()
}
c.JSON(http.StatusOK, resp)
}
func DeleteKnowledgeBase(c *gin.Context) {
id, _ := c.Get("id")
userID := int(id.(float64))
// 处理更新知识库的逻辑
var req proto.KnowledgeBaseReq
var resp proto.GenerateResp
if err := c.ShouldBind(&req); err == nil {
err2 := service.DeleteKnowledgeBase(userID, req)
if err2 == nil {
resp.Code = proto.SuccessCode
resp.Message = "success"
} else {
resp.Code = proto.ParameterError
resp.Message = "内部处理错误:" + err2.Error()
}
} else {
resp.Code = proto.ParameterError
resp.Message = "参数错误:" + err.Error()
}
}

17
main.go
View File

@ -31,14 +31,15 @@ func main() {
panic("failed to connect redis:" + err.Error())
}
r.Use(handler.CrosHandler())
r.Use(JWTAuthMiddleware()) // 使用 JWT 认证中间件
handler.SetUpUserGroup(r) // User
handler.SetUpToolGroup(r) // Tool
handler.SetUpFileGroup(r) // File
handler.SetUpSessionGroup(r) // Session
handler.SetUpModelGroup(r) // Model
handler.SetUpFuncModelGroup(r) // FuncModel
handler.SetUpIMGroup(r) // IM
r.Use(JWTAuthMiddleware()) // 使用 JWT 认证中间件
handler.SetUpUserGroup(r) // User
handler.SetUpToolGroup(r) // Tool
handler.SetUpFileGroup(r) // File
handler.SetUpKnowledgeBaseGroup(r) // KnowledgeBase
handler.SetUpSessionGroup(r) // Session
handler.SetUpModelGroup(r) // Model
handler.SetUpFuncModelGroup(r) // FuncModel
handler.SetUpIMGroup(r) // IM
defer dao.Close()
defer worker.CloseRedis()
//定时任务

9
proto/kbase.go Normal file
View File

@ -0,0 +1,9 @@
package proto
type KnowledgeBaseReq struct {
ID uint `json:"id"` // 知识库ID
UUID string `json:"uuid"` // 知识库UUID
Name string `json:"name"` // 知识库名称
Description string `json:"description"` // 知识库描述
FileIDS string `json:"file_ids"` // 文件ID列表逗号分隔
}

View File

@ -163,6 +163,7 @@ const (
SessionTypeUserWithModelGeneration = 1 // 用户与模型通用会话
SessionTypeUserPrompt = 2 // 用户与模型提示词
SessionTypeUserCreatePPT = 3 // 用户与模型制作PPT
SessionTypeKnowledgeBase = 4 // 用户与知识库
)
// 文件类型(文本、图片)

View File

@ -42,3 +42,9 @@ type FileContentResp struct {
Message string `json:"message"`
Data any `json:"data"`
}
type GenerateResp struct {
Code int `json:"code"`
Message string `json:"message"`
Data any `json:"data"`
}

View File

@ -0,0 +1,70 @@
package service
import (
"StuAcaWorksAI/dao"
"StuAcaWorksAI/proto"
"errors"
"github.com/google/uuid"
)
func CreateKnowledgeBase(userID int, req proto.KnowledgeBaseReq) (dao.KnowledgeBase, error) {
//生成uuid
uuid := uuid.New().String()
sessionName := "知识库会话:" + req.Name
//创建会话
err, sessionID := CreateSession(userID, proto.SessionTypeKnowledgeBase, sessionName)
if err != nil {
return dao.KnowledgeBase{}, err
}
// 创建知识库
base := dao.KnowledgeBase{AuthID: uint(userID), Name: req.Name, Description: req.Description, UUID: uuid, FileIDs: req.FileIDS, SessionID: sessionID}
knowledgeBase, err := dao.CreateKnowledgeBase(base)
if err != nil {
return dao.KnowledgeBase{}, err
}
// 创建向量数据库,后面完成
return knowledgeBase, err
}
// 更新知识库
func UpdateKnowledgeBase(userID int, req proto.KnowledgeBaseReq) (dao.KnowledgeBase, error) {
//查找知识库
knowledgeBase, err := dao.GetKnowledgeBaseByID(req.ID)
if err != nil {
return dao.KnowledgeBase{}, err
}
if knowledgeBase.AuthID != uint(userID) {
return dao.KnowledgeBase{}, errors.New("not authorized to update this knowledge base")
}
knowledgeBase.Name = req.Name
knowledgeBase.Description = req.Description
knowledgeBase.FileIDs = req.FileIDS
return dao.UpdateKnowledgeBase(knowledgeBase)
}
// 获取知识库列表
func GetKnowledgeBaseList(userID int) ([]dao.KnowledgeBase, error) {
knowledgeBases, err := dao.GetKnowledgeBaseList(userID)
if err != nil {
return nil, err
}
return knowledgeBases, nil
}
// 删除知识库
func DeleteKnowledgeBase(userID int, req proto.KnowledgeBaseReq) error {
//查找知识库
knowledgeBase, err := dao.GetKnowledgeBaseByID(req.ID)
if err != nil {
return err
}
if knowledgeBase.AuthID != uint(userID) {
return errors.New("not authorized to delete this knowledge base")
}
return dao.DeleteKnowledgeBase(knowledgeBase.ID)
}