完成知识库的消息流处理即知识库与处理服务器,用户文件支持单个查询
This commit is contained in:
parent
4f71c12e2d
commit
1960155b19
17
dao/kbase.go
17
dao/kbase.go
|
|
@ -97,3 +97,20 @@ func DeleteKnowledgeBase(id uint) error {
|
|||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// 根据会话id查找知识库
|
||||
func GetKnowledgeBaseBySessionID(sessionID uint) (KnowledgeBase, error) {
|
||||
var db2 *gorm.DB
|
||||
var err error
|
||||
if proto.Config.SERVER_SQL_LOG {
|
||||
db2 = DB.Debug()
|
||||
} else {
|
||||
db2 = DB
|
||||
}
|
||||
var base KnowledgeBase
|
||||
err = db2.Where("session_id = ?", sessionID).First(&base).Error
|
||||
if err != nil {
|
||||
return KnowledgeBase{}, err
|
||||
}
|
||||
return base, nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -44,6 +44,18 @@ func GetUserFileList(c *gin.Context) {
|
|||
} else if req.Type == "search" {
|
||||
fileList := dao.FindFileByUserFileName(req.FileName, userId)
|
||||
c.JSON(http.StatusOK, gin.H{"code": proto.SuccessCode, "msg": "success", "data": fileList})
|
||||
} else if req.Type == "one" {
|
||||
if req.FileID == 0 {
|
||||
c.JSON(http.StatusOK, gin.H{"error": "file_id is empty", "code": proto.ParameterError, "message": "failed"})
|
||||
return
|
||||
}
|
||||
userFile, err2 := service.FindUserFileByID(req.FileID, userId)
|
||||
if err2 != nil {
|
||||
c.JSON(http.StatusOK, gin.H{"error": "find user file failed", "code": proto.ParameterError, "message": "failed"})
|
||||
return
|
||||
} else {
|
||||
c.JSON(http.StatusOK, gin.H{"code": proto.SuccessCode, "msg": "success", "data": userFile})
|
||||
}
|
||||
} else {
|
||||
c.JSON(http.StatusOK, gin.H{"error": "search file type error", "code": proto.ParameterError, "message": "failed"})
|
||||
}
|
||||
|
|
|
|||
|
|
@ -265,11 +265,19 @@ func doReceiveGenChatMessage(userId int, sessionID *uint, data *proto.WSMessageR
|
|||
}
|
||||
//若为知识库消息,需要先根据用户消息获取知识库查询内容让ai进行最终结果生成
|
||||
if data.IsKBase {
|
||||
err2 := service.SelectKBaseContentByQuery(&imContext, data)
|
||||
//根据会话ID获取对应知识库消息
|
||||
kbaseMsg, err2 := service.FindKnowledgeBaseBySessionID(*sessionID, uint(userId))
|
||||
if err2 != nil {
|
||||
log.Println("find knowledge base by session id error:", err2)
|
||||
return err2
|
||||
}
|
||||
data.KBaseID = kbaseMsg.ID
|
||||
err2 = service.SelectKBaseContentByQuery(&imContext, data)
|
||||
if err2 != nil {
|
||||
log.Println("select kbase content error:", err2)
|
||||
return err2
|
||||
}
|
||||
return nil
|
||||
}
|
||||
//将消息存入数据库
|
||||
err, userMsgID = service.CreateMessage(proto.UserToModelMsgType, int(*sessionID), userId, int(model.ID), data.Msg, status, int(funcs[0].ID))
|
||||
|
|
@ -288,6 +296,10 @@ func KBaseTextWS(c *gin.Context) {
|
|||
userID := int(id.(float64))
|
||||
//query参数
|
||||
kBaseSercerID := c.Query("kbase_server_id")
|
||||
if kBaseSercerID == "" {
|
||||
c.JSON(http.StatusOK, gin.H{"code": proto.ParameterError, "message": "kbase_server_id不能为空"})
|
||||
return
|
||||
}
|
||||
// 升级HTTP连接为WebSocket连接
|
||||
ws, err1 := upgrader.Upgrade(c.Writer, c.Request, nil)
|
||||
if err1 != nil {
|
||||
|
|
|
|||
|
|
@ -157,6 +157,8 @@ const (
|
|||
UserFileTypeFile = "file" // 通用文件
|
||||
UserFileTypeConfig = "config" // 配置文件
|
||||
UserMaxUploadSize = 1024 * 1024 * 100
|
||||
|
||||
KnowledgeBaseFunction = "kbase-chat" // 知识库功能
|
||||
)
|
||||
|
||||
// 会话类型
|
||||
|
|
|
|||
|
|
@ -279,6 +279,20 @@ func FindUserFileList(userID int, uploadType string) []dao.UserFileListResp {
|
|||
return res
|
||||
}
|
||||
|
||||
func FindUserFileByID(id, userID int) (dao.UserFileListResp, error) {
|
||||
fileAuth := dao.FindFileAuthByID(id)
|
||||
if fileAuth.ID == 0 {
|
||||
return dao.UserFileListResp{}, errors.New("file auth not found")
|
||||
}
|
||||
if fileAuth.AuthID != userID {
|
||||
return dao.UserFileListResp{}, errors.New("no permission")
|
||||
}
|
||||
file := dao.FindFileByID(fileAuth.FileID, userID)
|
||||
fileStoreName := file.FileStoreName
|
||||
res := dao.UserFileListResp{FileAuth: fileAuth, FileStoreName: fileStoreName}
|
||||
return res, nil
|
||||
}
|
||||
|
||||
func CreateUserFileAfterUnique() {
|
||||
|
||||
}
|
||||
|
|
|
|||
|
|
@ -299,14 +299,21 @@ func SelectKBaseContentByQuery(ctx *proto.IMParamContext, data *proto.WSMessageR
|
|||
return errors.New("no kbase server id")
|
||||
}
|
||||
//记录分配的知识库服务器
|
||||
worker.SetRedis(key, KBaseServerID)
|
||||
res := worker.SetRedisWithExpire(key, KBaseServerID, time.Hour*24*7)
|
||||
log.Println("分配的知识库服务器ID:", KBaseServerID, "\t key:", key, "\tset:", res)
|
||||
}
|
||||
chanel := fmt.Sprintf("kbase_%s", kbase.UUID)
|
||||
log.Println("KBaseServerID:", KBaseServerID, " KBaseUUID:", kbase.UUID, "\tName:", kbase.Name)
|
||||
reqChanel := fmt.Sprintf("kbase_server_id_%s", KBaseServerID) // 频道名称
|
||||
reqStr, _ := json.Marshal(req)
|
||||
//将消息转发到对应频道
|
||||
worker.Publish(reqChanel, string(reqStr), time.Second*60)
|
||||
|
||||
chanel := fmt.Sprintf("kbase_%s", kbase.UUID) //获取返回消息频道
|
||||
pubsub := worker.RedisClient.Subscribe(context.Background(), chanel)
|
||||
defer pubsub.Close()
|
||||
ch := pubsub.Channel()
|
||||
//1min定时器,超时则返回错误
|
||||
timer := time.NewTimer(time.Minute)
|
||||
timer := time.NewTimer(time.Second * 10)
|
||||
var resp proto.KnowledgeBaseServerResponse
|
||||
select {
|
||||
case m := <-ch:
|
||||
|
|
|
|||
|
|
@ -80,3 +80,15 @@ func AddKnowledgeBaseServer() {
|
|||
}
|
||||
worker.SetRedisSet(set, serverIDs, time.Hour)
|
||||
}
|
||||
|
||||
// 根据会话id查找知识库
|
||||
func FindKnowledgeBaseBySessionID(sessionID uint, userID uint) (dao.KnowledgeBase, error) {
|
||||
knowledgeBase, err := dao.GetKnowledgeBaseBySessionID(sessionID)
|
||||
if err != nil {
|
||||
return dao.KnowledgeBase{}, err
|
||||
}
|
||||
if knowledgeBase.AuthID != userID {
|
||||
return dao.KnowledgeBase{}, errors.New("not authorized to access this knowledge base")
|
||||
}
|
||||
return knowledgeBase, nil
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue