修改知识库服务器id状态,添加知识库查询、聊天选择功能

This commit is contained in:
junleea 2025-04-10 15:06:26 +08:00
parent 1960155b19
commit 87e859cfca
6 changed files with 123 additions and 28 deletions

View File

@ -263,28 +263,34 @@ func doReceiveGenChatMessage(userId int, sessionID *uint, data *proto.WSMessageR
if data.IsImage || data.IsFile { if data.IsImage || data.IsFile {
status = proto.UserToModelImageMsgType //图片类型 status = proto.UserToModelImageMsgType //图片类型
} }
//若为知识库消息需要先根据用户消息获取知识库查询内容让ai进行最终结果生成
if data.IsKBase {
//根据会话ID获取对应知识库消息
kbaseMsg, err2 := service.FindKnowledgeBaseBySessionID(*sessionID, uint(userId))
if err2 != nil {
log.Println("find knowledge base by session id error:", err2)
return err2
}
data.KBaseID = kbaseMsg.ID
err2 = service.SelectKBaseContentByQuery(&imContext, data)
if err2 != nil {
log.Println("select kbase content error:", err2)
return err2
}
return nil
}
//将消息存入数据库 //将消息存入数据库
err, userMsgID = service.CreateMessage(proto.UserToModelMsgType, int(*sessionID), userId, int(model.ID), data.Msg, status, int(funcs[0].ID)) err, userMsgID = service.CreateMessage(proto.UserToModelMsgType, int(*sessionID), userId, int(model.ID), data.Msg, status, int(funcs[0].ID))
log.Println("create user message id:", userMsgID) log.Println("create user message id:", userMsgID)
if err != nil { if err != nil {
return err return err
} //机器人id为2,消息类型为1,状态为1 } //机器人id为2,消息类型为1,状态为1
//若为知识库消息需要先根据用户消息获取知识库查询内容让ai进行最终结果生成
if data.IsKBase {
//根据会话ID获取对应知识库消息
kbaseMsg, err2 := service.FindKnowledgeBaseBySessionID(*sessionID, uint(userId))
if err2 != nil {
log.Println("find knowledge base by session id error:", err2)
service.WSReturnErrorMessage(imContext.Channel, "find knowledge base error:"+err2.Error(), proto.WSKBaseSessionError)
return err2
}
data.KBaseID = kbaseMsg.ID
err2 = service.SelectKBaseContentByQuery(&imContext, data)
if err2 != nil {
log.Println("select kbase content error:", err2)
service.WSReturnErrorMessage(imContext.Channel, "select kbase error:"+err2.Error(), proto.WSKBaseServerError)
return err2
}
if data.KBaseType == "query" {
//若为查询直接返回
return nil
}
}
//需要后续处理 //需要后续处理
//err, _ = service.WSReceiveMessageService(userId, int(*sessionID), chanel, *data, model, modelParam) //err, _ = service.WSReceiveMessageService(userId, int(*sessionID), chanel, *data, model, modelParam)
err, _ = service.WSReceiveMessageServiceV2(modelParam, &imContext) err, _ = service.WSReceiveMessageServiceV2(modelParam, &imContext)
@ -309,12 +315,20 @@ func KBaseTextWS(c *gin.Context) {
subscribeAndHandleKBaseTextMessages(ws, userID, kBaseSercerID) subscribeAndHandleKBaseTextMessages(ws, userID, kBaseSercerID)
} }
func subscribeAndHandleKBaseTextMessages(ws *websocket.Conn, userId int, kBaseSercerID string) { func subscribeAndHandleKBaseTextMessages(ws *websocket.Conn, userId int, kBaseServerID string) {
ctx := context.Background() ctx := context.Background()
chanel := fmt.Sprintf("kbase_server_id_%s", kBaseSercerID) // 频道名称 chanel := fmt.Sprintf("kbase_server_id_%s", kBaseServerID) // 频道名称
//设置KBaseServerID可用状态加入可用集合
setKey := "kbase_server_id_set_is_available"
worker.SetRedisSetAdd(setKey, kBaseServerID)
pubsub := worker.RedisClient.Subscribe(ctx, chanel) pubsub := worker.RedisClient.Subscribe(ctx, chanel)
defer pubsub.Close() defer func() {
defer ws.Close() //设置KBaseServerID不可用状态删除可用集合
worker.SetRedisSetRemove(setKey, kBaseServerID)
pubsub.Close()
ws.Close()
}()
ch := pubsub.Channel() ch := pubsub.Channel()
//接收客户端消息,发送消息 //接收客户端消息,发送消息

View File

@ -12,6 +12,7 @@ type AIQueueMessage struct {
} }
type WSMessage struct { type WSMessage struct {
Code int `json:"code"` //返回码,0表示成功 其它状态码
Type string `json:"type"` //接收及发送消息类型 Type string `json:"type"` //接收及发送消息类型
Msg AIQueueMessage `json:"msg"` //消息内容,只进行转发,不做处理 Msg AIQueueMessage `json:"msg"` //消息内容,只进行转发,不做处理
SessionID int `json:"session_id"` //应用层会话id SessionID int `json:"session_id"` //应用层会话id
@ -29,6 +30,7 @@ type WSMessageReq struct {
IsFile bool `json:"is_file"` //是否为文件(文件消息)如果是文件消息则msg为文件消息结构为{"file_url":"","text":""} IsFile bool `json:"is_file"` //是否为文件(文件消息)如果是文件消息则msg为文件消息结构为{"file_url":"","text":""}
IsKBase bool `json:"is_kbase"` //是否为知识库消息 IsKBase bool `json:"is_kbase"` //是否为知识库消息
KBaseID uint `json:"kbase_id"` //知识库id,为知识库消息时必传 KBaseID uint `json:"kbase_id"` //知识库id,为知识库消息时必传
KBaseType string `json:"kbase_type"` //知识库模式1为query 2为chat
} }
type ModelParam struct { type ModelParam struct {

View File

@ -90,6 +90,10 @@ const (
FuncModelSearchFailed = 133 // 获取功能模型失败 FuncModelSearchFailed = 133 // 获取功能模型失败
GetSparkCreatePPTStatusFailed = 140 // 获取spark创建ppt状态失败 GetSparkCreatePPTStatusFailed = 140 // 获取spark创建ppt状态失败
//下面是ws消息错误码
WSKBaseServerError = 150 // ws知识库服务器错误
WSKBaseSessionError = 151 // ws知识库会话错误
) )
const ( const (

View File

@ -148,7 +148,7 @@ func DouBaoV2(modelParam proto.ModelParam, imCtx *proto.IMParamContext) {
//将消息存入数据库 //将消息存入数据库
err, _ = dao.CreateMessage(imCtx.SessionID, imCtx.ModelID, imCtx.UserID, proto.ModelToUserMsgType, proto.MsgHasRead, answer, imCtx.FunctionID) err, _ = dao.CreateMessage(imCtx.SessionID, imCtx.ModelID, imCtx.UserID, proto.ModelToUserMsgType, proto.MsgHasRead, answer, imCtx.FunctionID)
if err != nil { if err != nil {
log.Println("spark create message error:", err) log.Println("doubao create message error:", err)
} }
//将本次请求的tokens存入数据库 //将本次请求的tokens存入数据库
err2, _ := dao.CreateModelToken(uint(imCtx.UserID), uint(imCtx.ModelID), uint(imCtx.SessionID), uint(tokens)) err2, _ := dao.CreateModelToken(uint(imCtx.UserID), uint(imCtx.ModelID), uint(imCtx.SessionID), uint(tokens))

View File

@ -278,9 +278,9 @@ func CheckUserCreatePPTSessionPermission(userID int) error {
} }
// 知识库消息,从原始请求数据,上下文数据 // 知识库消息,从原始请求数据,上下文数据
func SelectKBaseContentByQuery(ctx *proto.IMParamContext, data *proto.WSMessageReq) error { func SelectKBaseContentByQuery(imCtx *proto.IMParamContext, data *proto.WSMessageReq) error {
var req proto.KnowledgeBaseServerRequest var req proto.KnowledgeBaseServerRequest
req.IMContext = *ctx req.IMContext = *imCtx
kbase, err := dao.GetKnowledgeBaseByID(data.KBaseID) kbase, err := dao.GetKnowledgeBaseByID(data.KBaseID)
if err != nil { if err != nil {
return err return err
@ -302,6 +302,13 @@ func SelectKBaseContentByQuery(ctx *proto.IMParamContext, data *proto.WSMessageR
res := worker.SetRedisWithExpire(key, KBaseServerID, time.Hour*24*7) res := worker.SetRedisWithExpire(key, KBaseServerID, time.Hour*24*7)
log.Println("分配的知识库服务器ID:", KBaseServerID, "\t key:", key, "\tset:", res) log.Println("分配的知识库服务器ID:", KBaseServerID, "\t key:", key, "\tset:", res)
} }
setKey := "kbase_server_id_set_is_available" //可用状态集合
//查看KBaseServerID是否可用
isAvailable := worker.SetRedisSetIsMember(setKey, KBaseServerID)
if isAvailable == false {
log.Println("当前知识库服务器不可用,重新分配")
return errors.New("kbase server id is not available")
}
log.Println("KBaseServerID:", KBaseServerID, " KBaseUUID:", kbase.UUID, "\tName:", kbase.Name) log.Println("KBaseServerID:", KBaseServerID, " KBaseUUID:", kbase.UUID, "\tName:", kbase.Name)
reqChanel := fmt.Sprintf("kbase_server_id_%s", KBaseServerID) // 频道名称 reqChanel := fmt.Sprintf("kbase_server_id_%s", KBaseServerID) // 频道名称
reqStr, _ := json.Marshal(req) reqStr, _ := json.Marshal(req)
@ -313,11 +320,12 @@ func SelectKBaseContentByQuery(ctx *proto.IMParamContext, data *proto.WSMessageR
defer pubsub.Close() defer pubsub.Close()
ch := pubsub.Channel() ch := pubsub.Channel()
//1min定时器超时则返回错误 //1min定时器超时则返回错误
timer := time.NewTimer(time.Second * 10) timer := time.NewTimer(time.Second * 30)
var resp proto.KnowledgeBaseServerResponse var resp proto.KnowledgeBaseServerResponse
select { select {
case m := <-ch: case m := <-ch:
msg := m.Payload // 获取消息,消息格式为json msg := m.Payload // 获取消息,消息格式为json
//log.Println("resp:", msg)
err2 := json.Unmarshal([]byte(msg), &resp) err2 := json.Unmarshal([]byte(msg), &resp)
if err2 != nil { if err2 != nil {
log.Println("Error parsing JSON:", err2) log.Println("Error parsing JSON:", err2)
@ -330,7 +338,34 @@ func SelectKBaseContentByQuery(ctx *proto.IMParamContext, data *proto.WSMessageR
//解析返回的查询消息 //解析返回的查询消息
text := "根据下面文档片段:\n" text := "根据下面文档片段:\n"
querySelect := resp.QuerySelect querySelect := resp.QuerySelect
qText := getSelectedDocument(querySelect)
log.Println("kbase type:", data.KBaseType)
if data.KBaseType == "query" {
//则将query为模型返回结果
//将消息存入数据库
err, _ = dao.CreateMessage(imCtx.SessionID, imCtx.ModelID, imCtx.UserID, proto.ModelToUserMsgType, proto.MsgHasRead, qText, imCtx.FunctionID)
if err != nil {
log.Println("kbase query message save error:", err)
}
kbaseQueryToGeneralMassageAndSendMsgQueue(qText, imCtx.Channel, imCtx.SessionID, imCtx.UserID) //将消息发送到消息队列
} else if data.KBaseType == "chat" {
//用户问题
text += qText
text += "回答用户问题:" + data.Msg
data.Msg = text
log.Println("text:", text)
} else {
return errors.New("unknow kbase type")
}
return nil
}
func getSelectedDocument(querySelect proto.KnowledgeBaseServerResponseSelect) string {
text := ""
for i := 0; i < len(querySelect.Documents); i++ { for i := 0; i < len(querySelect.Documents); i++ {
if len(querySelect.Documents[i]) == 0 {
break
}
text += "文件:" text += "文件:"
for j := 0; j < len(querySelect.Metadatas); j++ { for j := 0; j < len(querySelect.Metadatas); j++ {
text += querySelect.Metadatas[i][j].Source + "" text += querySelect.Metadatas[i][j].Source + ""
@ -341,8 +376,38 @@ func SelectKBaseContentByQuery(ctx *proto.IMParamContext, data *proto.WSMessageR
} }
text += "\n" text += "\n"
} }
//用户问题 return text
text += "回答下面用户问题:" + data.Msg }
data.Msg = text
return nil func WSReturnErrorMessage(channel string, msg string, code int) {
var data proto.WSMessage
data.Code = code
data.Type = msg //错误时type为错误信息
data.SessionID = 0
data.ToID = 0
errStr, _ := json.Marshal(data)
worker.Publish(channel, string(errStr), time.Second*60)
}
func kbaseQueryToGeneralMassageAndSendMsgQueue(answer, channel string, SessionID, UserID int) {
var aiMsg proto.AIQueueMessage
var wsMsg proto.WSMessage
var apiMsg api.GenerateResponse
apiMsg.Done = true
apiMsg.Model = proto.KnowledgeBaseServerResponseType
apiMsg.Response = answer
//消息队列部分
aiMsg.Type = proto.KnowledgeBaseServerResponseType
aiMsg.Msg = apiMsg
//ws发送消息部分
wsMsg.Msg = aiMsg
wsMsg.SessionID = SessionID
wsMsg.ToID = UserID
wsMsg.Type = proto.KnowledgeBaseServerResponseType
//发送消息
wsMsgStr, _ := json.Marshal(wsMsg)
worker.Publish(channel, string(wsMsgStr), time.Second*60)
} }

View File

@ -553,3 +553,13 @@ func GetRedisSetRandomMember(key string) string {
} }
return val return val
} }
func SetRedisSetIsMember(setKey, memberKey string) bool {
ctx := context.Background()
val, err := RedisClient.SIsMember(ctx, setKey, memberKey).Result()
if err != nil {
fmt.Println("Error getting key: %v", err)
return false
}
return val
}