修改知识库服务器id状态,添加知识库查询、聊天选择功能
This commit is contained in:
parent
1960155b19
commit
87e859cfca
|
|
@ -263,28 +263,34 @@ func doReceiveGenChatMessage(userId int, sessionID *uint, data *proto.WSMessageR
|
|||
if data.IsImage || data.IsFile {
|
||||
status = proto.UserToModelImageMsgType //图片类型
|
||||
}
|
||||
//若为知识库消息,需要先根据用户消息获取知识库查询内容让ai进行最终结果生成
|
||||
if data.IsKBase {
|
||||
//根据会话ID获取对应知识库消息
|
||||
kbaseMsg, err2 := service.FindKnowledgeBaseBySessionID(*sessionID, uint(userId))
|
||||
if err2 != nil {
|
||||
log.Println("find knowledge base by session id error:", err2)
|
||||
return err2
|
||||
}
|
||||
data.KBaseID = kbaseMsg.ID
|
||||
err2 = service.SelectKBaseContentByQuery(&imContext, data)
|
||||
if err2 != nil {
|
||||
log.Println("select kbase content error:", err2)
|
||||
return err2
|
||||
}
|
||||
return nil
|
||||
}
|
||||
//将消息存入数据库
|
||||
err, userMsgID = service.CreateMessage(proto.UserToModelMsgType, int(*sessionID), userId, int(model.ID), data.Msg, status, int(funcs[0].ID))
|
||||
log.Println("create user message id:", userMsgID)
|
||||
if err != nil {
|
||||
return err
|
||||
} //机器人id为2,消息类型为1,状态为1
|
||||
|
||||
//若为知识库消息,需要先根据用户消息获取知识库查询内容让ai进行最终结果生成
|
||||
if data.IsKBase {
|
||||
//根据会话ID获取对应知识库消息
|
||||
kbaseMsg, err2 := service.FindKnowledgeBaseBySessionID(*sessionID, uint(userId))
|
||||
if err2 != nil {
|
||||
log.Println("find knowledge base by session id error:", err2)
|
||||
service.WSReturnErrorMessage(imContext.Channel, "find knowledge base error:"+err2.Error(), proto.WSKBaseSessionError)
|
||||
return err2
|
||||
}
|
||||
data.KBaseID = kbaseMsg.ID
|
||||
err2 = service.SelectKBaseContentByQuery(&imContext, data)
|
||||
if err2 != nil {
|
||||
log.Println("select kbase content error:", err2)
|
||||
service.WSReturnErrorMessage(imContext.Channel, "select kbase error:"+err2.Error(), proto.WSKBaseServerError)
|
||||
return err2
|
||||
}
|
||||
if data.KBaseType == "query" {
|
||||
//若为查询直接返回
|
||||
return nil
|
||||
}
|
||||
}
|
||||
//需要后续处理
|
||||
//err, _ = service.WSReceiveMessageService(userId, int(*sessionID), chanel, *data, model, modelParam)
|
||||
err, _ = service.WSReceiveMessageServiceV2(modelParam, &imContext)
|
||||
|
|
@ -309,12 +315,20 @@ func KBaseTextWS(c *gin.Context) {
|
|||
subscribeAndHandleKBaseTextMessages(ws, userID, kBaseSercerID)
|
||||
}
|
||||
|
||||
func subscribeAndHandleKBaseTextMessages(ws *websocket.Conn, userId int, kBaseSercerID string) {
|
||||
func subscribeAndHandleKBaseTextMessages(ws *websocket.Conn, userId int, kBaseServerID string) {
|
||||
ctx := context.Background()
|
||||
chanel := fmt.Sprintf("kbase_server_id_%s", kBaseSercerID) // 频道名称
|
||||
chanel := fmt.Sprintf("kbase_server_id_%s", kBaseServerID) // 频道名称
|
||||
//设置KBaseServerID可用状态,加入可用集合
|
||||
setKey := "kbase_server_id_set_is_available"
|
||||
worker.SetRedisSetAdd(setKey, kBaseServerID)
|
||||
pubsub := worker.RedisClient.Subscribe(ctx, chanel)
|
||||
defer pubsub.Close()
|
||||
defer ws.Close()
|
||||
defer func() {
|
||||
//设置KBaseServerID不可用状态,删除可用集合
|
||||
worker.SetRedisSetRemove(setKey, kBaseServerID)
|
||||
pubsub.Close()
|
||||
ws.Close()
|
||||
}()
|
||||
|
||||
ch := pubsub.Channel()
|
||||
|
||||
//接收客户端消息,发送消息
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ type AIQueueMessage struct {
|
|||
}
|
||||
|
||||
type WSMessage struct {
|
||||
Code int `json:"code"` //返回码,0表示成功, 其它状态码
|
||||
Type string `json:"type"` //接收及发送消息类型
|
||||
Msg AIQueueMessage `json:"msg"` //消息内容,只进行转发,不做处理
|
||||
SessionID int `json:"session_id"` //应用层会话id
|
||||
|
|
@ -29,6 +30,7 @@ type WSMessageReq struct {
|
|||
IsFile bool `json:"is_file"` //是否为文件(文件消息),如果是文件消息,则msg为文件消息结构为{"file_url":"","text":""}
|
||||
IsKBase bool `json:"is_kbase"` //是否为知识库消息
|
||||
KBaseID uint `json:"kbase_id"` //知识库id,为知识库消息时必传
|
||||
KBaseType string `json:"kbase_type"` //知识库模式,1为query 2为chat
|
||||
}
|
||||
|
||||
type ModelParam struct {
|
||||
|
|
|
|||
|
|
@ -90,6 +90,10 @@ const (
|
|||
FuncModelSearchFailed = 133 // 获取功能模型失败
|
||||
|
||||
GetSparkCreatePPTStatusFailed = 140 // 获取spark创建ppt状态失败
|
||||
|
||||
//下面是ws消息错误码
|
||||
WSKBaseServerError = 150 // ws知识库服务器错误
|
||||
WSKBaseSessionError = 151 // ws知识库会话错误
|
||||
)
|
||||
|
||||
const (
|
||||
|
|
|
|||
|
|
@ -148,7 +148,7 @@ func DouBaoV2(modelParam proto.ModelParam, imCtx *proto.IMParamContext) {
|
|||
//将消息存入数据库
|
||||
err, _ = dao.CreateMessage(imCtx.SessionID, imCtx.ModelID, imCtx.UserID, proto.ModelToUserMsgType, proto.MsgHasRead, answer, imCtx.FunctionID)
|
||||
if err != nil {
|
||||
log.Println("spark create message error:", err)
|
||||
log.Println("doubao create message error:", err)
|
||||
}
|
||||
//将本次请求的tokens存入数据库
|
||||
err2, _ := dao.CreateModelToken(uint(imCtx.UserID), uint(imCtx.ModelID), uint(imCtx.SessionID), uint(tokens))
|
||||
|
|
|
|||
|
|
@ -278,9 +278,9 @@ func CheckUserCreatePPTSessionPermission(userID int) error {
|
|||
}
|
||||
|
||||
// 知识库消息,从原始请求数据,上下文数据
|
||||
func SelectKBaseContentByQuery(ctx *proto.IMParamContext, data *proto.WSMessageReq) error {
|
||||
func SelectKBaseContentByQuery(imCtx *proto.IMParamContext, data *proto.WSMessageReq) error {
|
||||
var req proto.KnowledgeBaseServerRequest
|
||||
req.IMContext = *ctx
|
||||
req.IMContext = *imCtx
|
||||
kbase, err := dao.GetKnowledgeBaseByID(data.KBaseID)
|
||||
if err != nil {
|
||||
return err
|
||||
|
|
@ -302,6 +302,13 @@ func SelectKBaseContentByQuery(ctx *proto.IMParamContext, data *proto.WSMessageR
|
|||
res := worker.SetRedisWithExpire(key, KBaseServerID, time.Hour*24*7)
|
||||
log.Println("分配的知识库服务器ID:", KBaseServerID, "\t key:", key, "\tset:", res)
|
||||
}
|
||||
setKey := "kbase_server_id_set_is_available" //可用状态集合
|
||||
//查看KBaseServerID是否可用
|
||||
isAvailable := worker.SetRedisSetIsMember(setKey, KBaseServerID)
|
||||
if isAvailable == false {
|
||||
log.Println("当前知识库服务器不可用,重新分配")
|
||||
return errors.New("kbase server id is not available")
|
||||
}
|
||||
log.Println("KBaseServerID:", KBaseServerID, " KBaseUUID:", kbase.UUID, "\tName:", kbase.Name)
|
||||
reqChanel := fmt.Sprintf("kbase_server_id_%s", KBaseServerID) // 频道名称
|
||||
reqStr, _ := json.Marshal(req)
|
||||
|
|
@ -313,11 +320,12 @@ func SelectKBaseContentByQuery(ctx *proto.IMParamContext, data *proto.WSMessageR
|
|||
defer pubsub.Close()
|
||||
ch := pubsub.Channel()
|
||||
//1min定时器,超时则返回错误
|
||||
timer := time.NewTimer(time.Second * 10)
|
||||
timer := time.NewTimer(time.Second * 30)
|
||||
var resp proto.KnowledgeBaseServerResponse
|
||||
select {
|
||||
case m := <-ch:
|
||||
msg := m.Payload // 获取消息,消息格式为json
|
||||
//log.Println("resp:", msg)
|
||||
err2 := json.Unmarshal([]byte(msg), &resp)
|
||||
if err2 != nil {
|
||||
log.Println("Error parsing JSON:", err2)
|
||||
|
|
@ -330,7 +338,34 @@ func SelectKBaseContentByQuery(ctx *proto.IMParamContext, data *proto.WSMessageR
|
|||
//解析返回的查询消息
|
||||
text := "根据下面文档片段:\n"
|
||||
querySelect := resp.QuerySelect
|
||||
qText := getSelectedDocument(querySelect)
|
||||
log.Println("kbase type:", data.KBaseType)
|
||||
if data.KBaseType == "query" {
|
||||
//则将query为模型返回结果
|
||||
//将消息存入数据库
|
||||
err, _ = dao.CreateMessage(imCtx.SessionID, imCtx.ModelID, imCtx.UserID, proto.ModelToUserMsgType, proto.MsgHasRead, qText, imCtx.FunctionID)
|
||||
if err != nil {
|
||||
log.Println("kbase query message save error:", err)
|
||||
}
|
||||
kbaseQueryToGeneralMassageAndSendMsgQueue(qText, imCtx.Channel, imCtx.SessionID, imCtx.UserID) //将消息发送到消息队列
|
||||
} else if data.KBaseType == "chat" {
|
||||
//用户问题
|
||||
text += qText
|
||||
text += "回答用户问题:" + data.Msg
|
||||
data.Msg = text
|
||||
log.Println("text:", text)
|
||||
} else {
|
||||
return errors.New("unknow kbase type")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func getSelectedDocument(querySelect proto.KnowledgeBaseServerResponseSelect) string {
|
||||
text := ""
|
||||
for i := 0; i < len(querySelect.Documents); i++ {
|
||||
if len(querySelect.Documents[i]) == 0 {
|
||||
break
|
||||
}
|
||||
text += "文件:"
|
||||
for j := 0; j < len(querySelect.Metadatas); j++ {
|
||||
text += querySelect.Metadatas[i][j].Source + ","
|
||||
|
|
@ -341,8 +376,38 @@ func SelectKBaseContentByQuery(ctx *proto.IMParamContext, data *proto.WSMessageR
|
|||
}
|
||||
text += "\n"
|
||||
}
|
||||
//用户问题
|
||||
text += "回答下面用户问题:" + data.Msg
|
||||
data.Msg = text
|
||||
return nil
|
||||
return text
|
||||
}
|
||||
|
||||
func WSReturnErrorMessage(channel string, msg string, code int) {
|
||||
var data proto.WSMessage
|
||||
data.Code = code
|
||||
data.Type = msg //错误时type为错误信息
|
||||
data.SessionID = 0
|
||||
data.ToID = 0
|
||||
errStr, _ := json.Marshal(data)
|
||||
worker.Publish(channel, string(errStr), time.Second*60)
|
||||
}
|
||||
|
||||
func kbaseQueryToGeneralMassageAndSendMsgQueue(answer, channel string, SessionID, UserID int) {
|
||||
var aiMsg proto.AIQueueMessage
|
||||
var wsMsg proto.WSMessage
|
||||
var apiMsg api.GenerateResponse
|
||||
|
||||
apiMsg.Done = true
|
||||
|
||||
apiMsg.Model = proto.KnowledgeBaseServerResponseType
|
||||
apiMsg.Response = answer
|
||||
//消息队列部分
|
||||
aiMsg.Type = proto.KnowledgeBaseServerResponseType
|
||||
aiMsg.Msg = apiMsg
|
||||
//ws发送消息部分
|
||||
wsMsg.Msg = aiMsg
|
||||
wsMsg.SessionID = SessionID
|
||||
wsMsg.ToID = UserID
|
||||
wsMsg.Type = proto.KnowledgeBaseServerResponseType
|
||||
|
||||
//发送消息
|
||||
wsMsgStr, _ := json.Marshal(wsMsg)
|
||||
worker.Publish(channel, string(wsMsgStr), time.Second*60)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -553,3 +553,13 @@ func GetRedisSetRandomMember(key string) string {
|
|||
}
|
||||
return val
|
||||
}
|
||||
|
||||
func SetRedisSetIsMember(setKey, memberKey string) bool {
|
||||
ctx := context.Background()
|
||||
val, err := RedisClient.SIsMember(ctx, setKey, memberKey).Result()
|
||||
if err != nil {
|
||||
fmt.Println("Error getting key: %v", err)
|
||||
return false
|
||||
}
|
||||
return val
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue