完成知识库的消息流处理即知识库与处理服务器,用户文件支持单个查询

This commit is contained in:
junleea 2025-04-09 17:37:24 +08:00
parent 4f71c12e2d
commit 1960155b19
7 changed files with 80 additions and 4 deletions

View File

@ -97,3 +97,20 @@ func DeleteKnowledgeBase(id uint) error {
}
return nil
}
// 根据会话id查找知识库
func GetKnowledgeBaseBySessionID(sessionID uint) (KnowledgeBase, error) {
var db2 *gorm.DB
var err error
if proto.Config.SERVER_SQL_LOG {
db2 = DB.Debug()
} else {
db2 = DB
}
var base KnowledgeBase
err = db2.Where("session_id = ?", sessionID).First(&base).Error
if err != nil {
return KnowledgeBase{}, err
}
return base, nil
}

View File

@ -44,6 +44,18 @@ func GetUserFileList(c *gin.Context) {
} else if req.Type == "search" {
fileList := dao.FindFileByUserFileName(req.FileName, userId)
c.JSON(http.StatusOK, gin.H{"code": proto.SuccessCode, "msg": "success", "data": fileList})
} else if req.Type == "one" {
if req.FileID == 0 {
c.JSON(http.StatusOK, gin.H{"error": "file_id is empty", "code": proto.ParameterError, "message": "failed"})
return
}
userFile, err2 := service.FindUserFileByID(req.FileID, userId)
if err2 != nil {
c.JSON(http.StatusOK, gin.H{"error": "find user file failed", "code": proto.ParameterError, "message": "failed"})
return
} else {
c.JSON(http.StatusOK, gin.H{"code": proto.SuccessCode, "msg": "success", "data": userFile})
}
} else {
c.JSON(http.StatusOK, gin.H{"error": "search file type error", "code": proto.ParameterError, "message": "failed"})
}

View File

@ -265,11 +265,19 @@ func doReceiveGenChatMessage(userId int, sessionID *uint, data *proto.WSMessageR
}
//若为知识库消息需要先根据用户消息获取知识库查询内容让ai进行最终结果生成
if data.IsKBase {
err2 := service.SelectKBaseContentByQuery(&imContext, data)
//根据会话ID获取对应知识库消息
kbaseMsg, err2 := service.FindKnowledgeBaseBySessionID(*sessionID, uint(userId))
if err2 != nil {
log.Println("find knowledge base by session id error:", err2)
return err2
}
data.KBaseID = kbaseMsg.ID
err2 = service.SelectKBaseContentByQuery(&imContext, data)
if err2 != nil {
log.Println("select kbase content error:", err2)
return err2
}
return nil
}
//将消息存入数据库
err, userMsgID = service.CreateMessage(proto.UserToModelMsgType, int(*sessionID), userId, int(model.ID), data.Msg, status, int(funcs[0].ID))
@ -288,6 +296,10 @@ func KBaseTextWS(c *gin.Context) {
userID := int(id.(float64))
//query参数
kBaseSercerID := c.Query("kbase_server_id")
if kBaseSercerID == "" {
c.JSON(http.StatusOK, gin.H{"code": proto.ParameterError, "message": "kbase_server_id不能为空"})
return
}
// 升级HTTP连接为WebSocket连接
ws, err1 := upgrader.Upgrade(c.Writer, c.Request, nil)
if err1 != nil {

View File

@ -157,6 +157,8 @@ const (
UserFileTypeFile = "file" // 通用文件
UserFileTypeConfig = "config" // 配置文件
UserMaxUploadSize = 1024 * 1024 * 100
KnowledgeBaseFunction = "kbase-chat" // 知识库功能
)
// 会话类型

View File

@ -279,6 +279,20 @@ func FindUserFileList(userID int, uploadType string) []dao.UserFileListResp {
return res
}
func FindUserFileByID(id, userID int) (dao.UserFileListResp, error) {
fileAuth := dao.FindFileAuthByID(id)
if fileAuth.ID == 0 {
return dao.UserFileListResp{}, errors.New("file auth not found")
}
if fileAuth.AuthID != userID {
return dao.UserFileListResp{}, errors.New("no permission")
}
file := dao.FindFileByID(fileAuth.FileID, userID)
fileStoreName := file.FileStoreName
res := dao.UserFileListResp{FileAuth: fileAuth, FileStoreName: fileStoreName}
return res, nil
}
func CreateUserFileAfterUnique() {
}

View File

@ -299,14 +299,21 @@ func SelectKBaseContentByQuery(ctx *proto.IMParamContext, data *proto.WSMessageR
return errors.New("no kbase server id")
}
//记录分配的知识库服务器
worker.SetRedis(key, KBaseServerID)
res := worker.SetRedisWithExpire(key, KBaseServerID, time.Hour*24*7)
log.Println("分配的知识库服务器ID:", KBaseServerID, "\t key:", key, "\tset:", res)
}
chanel := fmt.Sprintf("kbase_%s", kbase.UUID)
log.Println("KBaseServerID:", KBaseServerID, " KBaseUUID:", kbase.UUID, "\tName:", kbase.Name)
reqChanel := fmt.Sprintf("kbase_server_id_%s", KBaseServerID) // 频道名称
reqStr, _ := json.Marshal(req)
//将消息转发到对应频道
worker.Publish(reqChanel, string(reqStr), time.Second*60)
chanel := fmt.Sprintf("kbase_%s", kbase.UUID) //获取返回消息频道
pubsub := worker.RedisClient.Subscribe(context.Background(), chanel)
defer pubsub.Close()
ch := pubsub.Channel()
//1min定时器超时则返回错误
timer := time.NewTimer(time.Minute)
timer := time.NewTimer(time.Second * 10)
var resp proto.KnowledgeBaseServerResponse
select {
case m := <-ch:

View File

@ -80,3 +80,15 @@ func AddKnowledgeBaseServer() {
}
worker.SetRedisSet(set, serverIDs, time.Hour)
}
// 根据会话id查找知识库
func FindKnowledgeBaseBySessionID(sessionID uint, userID uint) (dao.KnowledgeBase, error) {
knowledgeBase, err := dao.GetKnowledgeBaseBySessionID(sessionID)
if err != nil {
return dao.KnowledgeBase{}, err
}
if knowledgeBase.AuthID != userID {
return dao.KnowledgeBase{}, errors.New("not authorized to access this knowledge base")
}
return knowledgeBase, nil
}