添加知识库增删改查
This commit is contained in:
parent
6eee65aa32
commit
349ee3af05
|
|
@ -61,6 +61,12 @@ func Init() error {
|
|||
return err
|
||||
}
|
||||
|
||||
err = db.AutoMigrate(&KnowledgeBase{})
|
||||
if err != nil {
|
||||
fmt.Println("knowledge base table:", err)
|
||||
return err
|
||||
}
|
||||
|
||||
err = db.AutoMigrate(&Model{})
|
||||
if err != nil {
|
||||
fmt.Println("model table:", err)
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ type Session struct {
|
|||
gorm.Model
|
||||
UserID int `gorm:"column:user_id"` //只能由用户创建
|
||||
Name string `gorm:"column:name"`
|
||||
Type int `gorm:"column:type"` //会话类型:1为通用聊天会话,2为提示词会话,3为制作ppt会话
|
||||
Type int `gorm:"column:type"` //会话类型:1为通用聊天会话,2为提示词会话,3为制作ppt会话,4为知识库会话
|
||||
Context json.RawMessage `gorm:"column:context;type:json"` //会话上下文
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,99 @@
|
|||
package dao
|
||||
|
||||
import (
|
||||
"StuAcaWorksAI/proto"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// 知识库表结构,存储知识库的基本信息及文件信息
|
||||
type KnowledgeBase struct {
|
||||
gorm.Model
|
||||
AuthID uint `gorm:"column:auth_id"` // 知识库创建者ID
|
||||
Name string `gorm:"column:name"` // 知识库名称
|
||||
Description string `gorm:"column:description"` // 知识库描述
|
||||
SessionID uint `gorm:"column:session_id"` // 会话ID,创建知识库自动生成会话
|
||||
UUID string `gorm:"column:uuid"` // 知识库UUID,用于向量数据库唯一标识
|
||||
FileIDs string `gorm:"column:file_ids"` // 文件ID列表,逗号分隔json '[{'file_id';1}]'
|
||||
}
|
||||
|
||||
// 创建知识库
|
||||
func CreateKnowledgeBase(base KnowledgeBase) (KnowledgeBase, error) {
|
||||
var db2 *gorm.DB
|
||||
var err error
|
||||
if proto.Config.SERVER_SQL_LOG {
|
||||
db2 = DB.Debug()
|
||||
} else {
|
||||
db2 = DB
|
||||
}
|
||||
err = db2.Create(&base).Error
|
||||
if err != nil {
|
||||
return KnowledgeBase{}, err
|
||||
}
|
||||
return base, nil
|
||||
}
|
||||
|
||||
// 更新知识库
|
||||
func UpdateKnowledgeBase(base KnowledgeBase) (KnowledgeBase, error) {
|
||||
var db2 *gorm.DB
|
||||
var err error
|
||||
if proto.Config.SERVER_SQL_LOG {
|
||||
db2 = DB.Debug()
|
||||
} else {
|
||||
db2 = DB
|
||||
}
|
||||
err = db2.Save(&base).Error
|
||||
if err != nil {
|
||||
return KnowledgeBase{}, err
|
||||
}
|
||||
return base, nil
|
||||
}
|
||||
|
||||
// 获取知识库列表
|
||||
func GetKnowledgeBaseList(authID int) ([]KnowledgeBase, error) {
|
||||
var db2 *gorm.DB
|
||||
var err error
|
||||
if proto.Config.SERVER_SQL_LOG {
|
||||
db2 = DB.Debug()
|
||||
} else {
|
||||
db2 = DB
|
||||
}
|
||||
var bases []KnowledgeBase
|
||||
err = db2.Where("auth_id = ?", authID).Find(&bases).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return bases, nil
|
||||
}
|
||||
|
||||
// 根据id获取知识库
|
||||
func GetKnowledgeBaseByID(id uint) (KnowledgeBase, error) {
|
||||
var db2 *gorm.DB
|
||||
var err error
|
||||
if proto.Config.SERVER_SQL_LOG {
|
||||
db2 = DB.Debug()
|
||||
} else {
|
||||
db2 = DB
|
||||
}
|
||||
var base KnowledgeBase
|
||||
err = db2.Where("id = ?", id).First(&base).Error
|
||||
if err != nil {
|
||||
return KnowledgeBase{}, err
|
||||
}
|
||||
return base, nil
|
||||
}
|
||||
|
||||
// 删除知识库
|
||||
func DeleteKnowledgeBase(id uint) error {
|
||||
var db2 *gorm.DB
|
||||
var err error
|
||||
if proto.Config.SERVER_SQL_LOG {
|
||||
db2 = DB.Debug()
|
||||
} else {
|
||||
db2 = DB
|
||||
}
|
||||
err = db2.Where("id = ?", id).Delete(&KnowledgeBase{}).Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
@ -0,0 +1,101 @@
|
|||
package handler
|
||||
|
||||
import (
|
||||
"StuAcaWorksAI/proto"
|
||||
"StuAcaWorksAI/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
func SetUpKnowledgeBaseGroup(router *gin.Engine) {
|
||||
knowledgeBaseGroup := router.Group("/knowledgeBase")
|
||||
knowledgeBaseGroup.POST("/create", CreateKnowledgeBase) // 创建知识库
|
||||
knowledgeBaseGroup.POST("/update", UpdateKnowledgeBase) // 更新知识库
|
||||
knowledgeBaseGroup.POST("/list", GetKnowledgeBaseList) // 获取知识库列表
|
||||
knowledgeBaseGroup.POST("/delete", DeleteKnowledgeBase) // 删除知识库
|
||||
}
|
||||
|
||||
func CreateKnowledgeBase(c *gin.Context) {
|
||||
id, _ := c.Get("id")
|
||||
userID := int(id.(float64))
|
||||
// 处理创建知识库的逻辑
|
||||
var req proto.KnowledgeBaseReq
|
||||
var resp proto.GenerateResp
|
||||
if err := c.ShouldBind(&req); err == nil {
|
||||
knowledgeBase, err2 := service.CreateKnowledgeBase(userID, req)
|
||||
if err2 == nil {
|
||||
resp.Code = proto.SuccessCode
|
||||
resp.Message = "success"
|
||||
resp.Data = knowledgeBase
|
||||
} else {
|
||||
resp.Code = proto.ParameterError
|
||||
resp.Message = "内部错误:" + err2.Error()
|
||||
}
|
||||
} else {
|
||||
resp.Code = proto.ParameterError
|
||||
resp.Message = "参数错误:" + err.Error()
|
||||
}
|
||||
c.JSON(http.StatusOK, resp)
|
||||
}
|
||||
|
||||
func UpdateKnowledgeBase(c *gin.Context) {
|
||||
id, _ := c.Get("id")
|
||||
userID := int(id.(float64))
|
||||
// 处理更新知识库的逻辑
|
||||
var req proto.KnowledgeBaseReq
|
||||
var resp proto.GenerateResp
|
||||
if err := c.ShouldBind(&req); err == nil {
|
||||
knowledgeBase, err2 := service.UpdateKnowledgeBase(userID, req)
|
||||
if err2 == nil {
|
||||
resp.Code = proto.SuccessCode
|
||||
resp.Message = "success"
|
||||
resp.Data = knowledgeBase
|
||||
} else {
|
||||
resp.Code = proto.ParameterError
|
||||
resp.Message = "内部错误:" + err2.Error()
|
||||
}
|
||||
} else {
|
||||
resp.Code = proto.ParameterError
|
||||
resp.Message = "参数错误:" + err.Error()
|
||||
}
|
||||
c.JSON(http.StatusOK, resp)
|
||||
}
|
||||
|
||||
func GetKnowledgeBaseList(c *gin.Context) {
|
||||
id, _ := c.Get("id")
|
||||
userID := int(id.(float64))
|
||||
// 处理获取知识库列表的逻辑
|
||||
var resp proto.GenerateResp
|
||||
knowledgeBases, err2 := service.GetKnowledgeBaseList(userID)
|
||||
if err2 == nil {
|
||||
resp.Code = proto.SuccessCode
|
||||
resp.Message = "success"
|
||||
resp.Data = knowledgeBases
|
||||
} else {
|
||||
resp.Code = proto.ParameterError
|
||||
resp.Message = "内部错误:" + err2.Error()
|
||||
}
|
||||
c.JSON(http.StatusOK, resp)
|
||||
}
|
||||
|
||||
func DeleteKnowledgeBase(c *gin.Context) {
|
||||
id, _ := c.Get("id")
|
||||
userID := int(id.(float64))
|
||||
// 处理更新知识库的逻辑
|
||||
var req proto.KnowledgeBaseReq
|
||||
var resp proto.GenerateResp
|
||||
if err := c.ShouldBind(&req); err == nil {
|
||||
err2 := service.DeleteKnowledgeBase(userID, req)
|
||||
if err2 == nil {
|
||||
resp.Code = proto.SuccessCode
|
||||
resp.Message = "success"
|
||||
} else {
|
||||
resp.Code = proto.ParameterError
|
||||
resp.Message = "内部处理错误:" + err2.Error()
|
||||
}
|
||||
} else {
|
||||
resp.Code = proto.ParameterError
|
||||
resp.Message = "参数错误:" + err.Error()
|
||||
}
|
||||
|
||||
}
|
||||
17
main.go
17
main.go
|
|
@ -31,14 +31,15 @@ func main() {
|
|||
panic("failed to connect redis:" + err.Error())
|
||||
}
|
||||
r.Use(handler.CrosHandler())
|
||||
r.Use(JWTAuthMiddleware()) // 使用 JWT 认证中间件
|
||||
handler.SetUpUserGroup(r) // User
|
||||
handler.SetUpToolGroup(r) // Tool
|
||||
handler.SetUpFileGroup(r) // File
|
||||
handler.SetUpSessionGroup(r) // Session
|
||||
handler.SetUpModelGroup(r) // Model
|
||||
handler.SetUpFuncModelGroup(r) // FuncModel
|
||||
handler.SetUpIMGroup(r) // IM
|
||||
r.Use(JWTAuthMiddleware()) // 使用 JWT 认证中间件
|
||||
handler.SetUpUserGroup(r) // User
|
||||
handler.SetUpToolGroup(r) // Tool
|
||||
handler.SetUpFileGroup(r) // File
|
||||
handler.SetUpKnowledgeBaseGroup(r) // KnowledgeBase
|
||||
handler.SetUpSessionGroup(r) // Session
|
||||
handler.SetUpModelGroup(r) // Model
|
||||
handler.SetUpFuncModelGroup(r) // FuncModel
|
||||
handler.SetUpIMGroup(r) // IM
|
||||
defer dao.Close()
|
||||
defer worker.CloseRedis()
|
||||
//定时任务
|
||||
|
|
|
|||
|
|
@ -0,0 +1,9 @@
|
|||
package proto
|
||||
|
||||
type KnowledgeBaseReq struct {
|
||||
ID uint `json:"id"` // 知识库ID
|
||||
UUID string `json:"uuid"` // 知识库UUID
|
||||
Name string `json:"name"` // 知识库名称
|
||||
Description string `json:"description"` // 知识库描述
|
||||
FileIDS string `json:"file_ids"` // 文件ID列表,逗号分隔
|
||||
}
|
||||
|
|
@ -163,6 +163,7 @@ const (
|
|||
SessionTypeUserWithModelGeneration = 1 // 用户与模型通用会话
|
||||
SessionTypeUserPrompt = 2 // 用户与模型提示词
|
||||
SessionTypeUserCreatePPT = 3 // 用户与模型制作PPT
|
||||
SessionTypeKnowledgeBase = 4 // 用户与知识库
|
||||
)
|
||||
|
||||
// 文件类型(文本、图片)
|
||||
|
|
|
|||
|
|
@ -42,3 +42,9 @@ type FileContentResp struct {
|
|||
Message string `json:"message"`
|
||||
Data any `json:"data"`
|
||||
}
|
||||
|
||||
type GenerateResp struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Data any `json:"data"`
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,70 @@
|
|||
package service
|
||||
|
||||
import (
|
||||
"StuAcaWorksAI/dao"
|
||||
"StuAcaWorksAI/proto"
|
||||
"errors"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
func CreateKnowledgeBase(userID int, req proto.KnowledgeBaseReq) (dao.KnowledgeBase, error) {
|
||||
//生成uuid
|
||||
uuid := uuid.New().String()
|
||||
sessionName := "知识库会话:" + req.Name
|
||||
//创建会话
|
||||
err, sessionID := CreateSession(userID, proto.SessionTypeKnowledgeBase, sessionName)
|
||||
if err != nil {
|
||||
return dao.KnowledgeBase{}, err
|
||||
}
|
||||
|
||||
// 创建知识库
|
||||
base := dao.KnowledgeBase{AuthID: uint(userID), Name: req.Name, Description: req.Description, UUID: uuid, FileIDs: req.FileIDS, SessionID: sessionID}
|
||||
knowledgeBase, err := dao.CreateKnowledgeBase(base)
|
||||
if err != nil {
|
||||
return dao.KnowledgeBase{}, err
|
||||
}
|
||||
// 创建向量数据库,后面完成
|
||||
|
||||
return knowledgeBase, err
|
||||
}
|
||||
|
||||
// 更新知识库
|
||||
func UpdateKnowledgeBase(userID int, req proto.KnowledgeBaseReq) (dao.KnowledgeBase, error) {
|
||||
//查找知识库
|
||||
knowledgeBase, err := dao.GetKnowledgeBaseByID(req.ID)
|
||||
if err != nil {
|
||||
return dao.KnowledgeBase{}, err
|
||||
}
|
||||
if knowledgeBase.AuthID != uint(userID) {
|
||||
return dao.KnowledgeBase{}, errors.New("not authorized to update this knowledge base")
|
||||
}
|
||||
|
||||
knowledgeBase.Name = req.Name
|
||||
knowledgeBase.Description = req.Description
|
||||
knowledgeBase.FileIDs = req.FileIDS
|
||||
|
||||
return dao.UpdateKnowledgeBase(knowledgeBase)
|
||||
}
|
||||
|
||||
// 获取知识库列表
|
||||
func GetKnowledgeBaseList(userID int) ([]dao.KnowledgeBase, error) {
|
||||
knowledgeBases, err := dao.GetKnowledgeBaseList(userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return knowledgeBases, nil
|
||||
}
|
||||
|
||||
// 删除知识库
|
||||
func DeleteKnowledgeBase(userID int, req proto.KnowledgeBaseReq) error {
|
||||
//查找知识库
|
||||
knowledgeBase, err := dao.GetKnowledgeBaseByID(req.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if knowledgeBase.AuthID != uint(userID) {
|
||||
return errors.New("not authorized to delete this knowledge base")
|
||||
}
|
||||
|
||||
return dao.DeleteKnowledgeBase(knowledgeBase.ID)
|
||||
}
|
||||
Loading…
Reference in New Issue