From 4f71c12e2d97bdd7fad0c371bca55a3dddd48394 Mon Sep 17 00:00:00 2001 From: junleea <354425203@qq.com> Date: Wed, 9 Apr 2025 14:53:16 +0800 Subject: [PATCH] =?UTF-8?q?=E5=AE=8C=E6=88=90=E7=9F=A5=E8=AF=86=E5=BA=93?= =?UTF-8?q?=E7=9A=84=E6=B6=88=E6=81=AF=E6=B5=81=E5=A4=84=E7=90=86=E5=8D=B3?= =?UTF-8?q?=E7=9F=A5=E8=AF=86=E5=BA=93=E4=B8=8E=E5=A4=84=E7=90=86=E6=9C=8D?= =?UTF-8?q?=E5=8A=A1=E5=99=A8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- handler/im.go | 83 +++++++++++++++++++++++++++++++++ main.go | 1 + proto/conf.go | 45 ++++++++++-------- proto/im.go | 2 + proto/kbase.go | 43 +++++++++++++++++ proto/status.go | 7 +-- service/imService.go | 64 +++++++++++++++++++++++++ service/knowledgeBaseService.go | 12 +++++ worker/redis.go | 11 +++++ 9 files changed, 245 insertions(+), 23 deletions(-) diff --git a/handler/im.go b/handler/im.go index 4fccc5f..ad198a0 100644 --- a/handler/im.go +++ b/handler/im.go @@ -8,6 +8,7 @@ import ( "context" "encoding/json" "errors" + "fmt" "github.com/gin-gonic/gin" "github.com/google/uuid" "github.com/gorilla/websocket" @@ -46,6 +47,7 @@ func SetUpIMGroup(router *gin.Engine) { imGroup.POST("/send_message", SendMessage) imGroup.GET("/ai_chat_ws", SRMessage) imGroup.POST("/get_message", GetMessage) + imGroup.GET("/kbase_text_ws", KBaseTextWS) } func GetMessage(c *gin.Context) { @@ -261,6 +263,14 @@ func doReceiveGenChatMessage(userId int, sessionID *uint, data *proto.WSMessageR if data.IsImage || data.IsFile { status = proto.UserToModelImageMsgType //图片类型 } + //若为知识库消息,需要先根据用户消息获取知识库查询内容让ai进行最终结果生成 + if data.IsKBase { + err2 := service.SelectKBaseContentByQuery(&imContext, data) + if err2 != nil { + log.Println("select kbase content error:", err2) + return err2 + } + } //将消息存入数据库 err, userMsgID = service.CreateMessage(proto.UserToModelMsgType, int(*sessionID), userId, int(model.ID), data.Msg, status, int(funcs[0].ID)) log.Println("create user message id:", userMsgID) @@ -272,3 +282,76 @@ func doReceiveGenChatMessage(userId int, sessionID *uint, data *proto.WSMessageR err, _ = service.WSReceiveMessageServiceV2(modelParam, &imContext) return err } + +func KBaseTextWS(c *gin.Context) { + id, _ := c.Get("id") + userID := int(id.(float64)) + //query参数 + kBaseSercerID := c.Query("kbase_server_id") + // 升级HTTP连接为WebSocket连接 + ws, err1 := upgrader.Upgrade(c.Writer, c.Request, nil) + if err1 != nil { + log.Println("升级为WebSocket时发生错误:", err1) + return + } + subscribeAndHandleKBaseTextMessages(ws, userID, kBaseSercerID) +} + +func subscribeAndHandleKBaseTextMessages(ws *websocket.Conn, userId int, kBaseSercerID string) { + ctx := context.Background() + chanel := fmt.Sprintf("kbase_server_id_%s", kBaseSercerID) // 频道名称 + pubsub := worker.RedisClient.Subscribe(ctx, chanel) + defer pubsub.Close() + defer ws.Close() + ch := pubsub.Channel() + + //接收客户端消息,发送消息 + done := make(chan struct{}) + go func() { + defer close(done) + for { + _, message, err := ws.ReadMessage() + var data proto.KnowledgeBaseServerResponse + //log.Println("rec message:", string(message)) + if err != nil { + log.Println("Read error:", err) + break + } + err1 := json.Unmarshal(message, &data) + if err1 != nil { + log.Println("Error parsing JSON:", err1) + continue + } + service.CreateAIStreamMsg(userId, data.IMContext.ModelID, data.IMContext.SessionID, message, proto.KnowledgeBaseServerResponseType) //将消息存入数据库 + //这里主要将接收到的消息转到对应协程 + //目标chanel + targetChanel := fmt.Sprintf("kbase_%s", data.KnowledgeBase.UUID) + //将消息转入 + err = worker.RedisClient.Publish(ctx, targetChanel, message).Err() + if err != nil { + log.Println("KBase Text WS Publish error:", err) + } + } + }() + + go func() { + //defer close(done) + //从redis订阅消息 + for m := range ch { + msg := m.Payload // 获取消息,消息格式为json + if msg != "" { + err2 := ws.WriteMessage(websocket.TextMessage, []byte(msg)) + if err2 != nil { + break + } + } + } + }() + + for { + select { + case <-done: + return + } + } +} diff --git a/main.go b/main.go index 6f5c354..791cb71 100644 --- a/main.go +++ b/main.go @@ -211,6 +211,7 @@ func myTask() { } //其它定时任务-通用 RunGeneralCron() + service.AddKnowledgeBaseServer() //将配置文件中支持的知识库处理服务器添加到集合 } func ReadConfigToSetSystem() { diff --git a/proto/conf.go b/proto/conf.go index 404ff6f..b3e0f65 100644 --- a/proto/conf.go +++ b/proto/conf.go @@ -64,26 +64,31 @@ type User struct { } type ConfigStruct struct { - DB int `json:"db"` // 0: mysql, 1: pg - MYSQL_DSN string `json:"mysql_dsn"` - PG_DSN string `json:"pg_dsn"` - REDIS_ADDR string `json:"redis_addr"` - TOKEN_USE_REDIS bool `json:"token_use_redis"` - REDIS_User_PW bool `json:"redis_user_pw"` // 是否使用密码 - REDIS_PASSWORD string `json:"redis_password"` - REDIS_DB int `json:"redis_db"` - TOKEN_SECRET string `json:"token_secret"` - CID_BASE_DIR string `json:"cid_base_dir"` - FILE_BASE_DIR string `json:"file_base_dir"` - MONITOR bool `json:"monitor"` // 状态监控及邮件通知 - SERVER_SQL_LOG bool `json:"server_sql_log"` // 服务器sql日志 - SERVER_PORT string `json:"server_port"` // 服务端口 - LOG_SAVE_DAYS int `json:"log_save_days"` // 日志保存天数,-1表示不保存,0表示永久保存 - SERVER_USER_TYPE string `json:"user_type"` // 服务器用户类型,master: 主服务器,slave: 从服务器,从服务器会定时同步数据 - MASTER_SERVER_DOMAIN string `json:"master_server_domain"` // 主服务器域名 - USER_SYNC_TIME int `json:"user_sync_time"` // 用户数据同步时间,单位秒 - SERVER_NAME string `json:"server_name"` // 服务器名称,用于区分不同服务器 - SPARK_PPT_USAGE bool `json:"spark_ppt_usage"` // 是否使用spark ppt功能 + DB int `json:"db"` // 0: mysql, 1: pg + MYSQL_DSN string `json:"mysql_dsn"` + PG_DSN string `json:"pg_dsn"` + REDIS_ADDR string `json:"redis_addr"` + TOKEN_USE_REDIS bool `json:"token_use_redis"` + REDIS_User_PW bool `json:"redis_user_pw"` // 是否使用密码 + REDIS_PASSWORD string `json:"redis_password"` + REDIS_DB int `json:"redis_db"` + TOKEN_SECRET string `json:"token_secret"` + CID_BASE_DIR string `json:"cid_base_dir"` + FILE_BASE_DIR string `json:"file_base_dir"` + MONITOR bool `json:"monitor"` // 状态监控及邮件通知 + SERVER_SQL_LOG bool `json:"server_sql_log"` // 服务器sql日志 + SERVER_PORT string `json:"server_port"` // 服务端口 + LOG_SAVE_DAYS int `json:"log_save_days"` // 日志保存天数,-1表示不保存,0表示永久保存 + SERVER_USER_TYPE string `json:"user_type"` // 服务器用户类型,master: 主服务器,slave: 从服务器,从服务器会定时同步数据 + MASTER_SERVER_DOMAIN string `json:"master_server_domain"` // 主服务器域名 + USER_SYNC_TIME int `json:"user_sync_time"` // 用户数据同步时间,单位秒 + SERVER_NAME string `json:"server_name"` // 服务器名称,用于区分不同服务器 + SPARK_PPT_USAGE bool `json:"spark_ppt_usage"` // 是否使用spark ppt功能 + KBASE_SERVER []KBaseServer `json:"kbase_server"` // 知识库服务器列表 +} + +type KBaseServer struct { + ServerID string `json:"server_id"` // 服务器ID } // 读取配置文件 diff --git a/proto/im.go b/proto/im.go index 7df3b52..0b2fd4b 100644 --- a/proto/im.go +++ b/proto/im.go @@ -27,6 +27,8 @@ type WSMessageReq struct { ModelID uint `json:"model_id"` //模型id IsImage bool `json:"is_image"` //是否为图片(图片消息),如果是图片消息,则msg为图片消息结构为{"img_url":"","text":""} IsFile bool `json:"is_file"` //是否为文件(文件消息),如果是文件消息,则msg为文件消息结构为{"file_url":"","text":""} + IsKBase bool `json:"is_kbase"` //是否为知识库消息 + KBaseID uint `json:"kbase_id"` //知识库id,为知识库消息时必传 } type ModelParam struct { diff --git a/proto/kbase.go b/proto/kbase.go index 0e2eec1..c32e6e7 100644 --- a/proto/kbase.go +++ b/proto/kbase.go @@ -7,3 +7,46 @@ type KnowledgeBaseReq struct { Description string `json:"description" form:"description"` // 知识库描述 FileIDS string `json:"file_ids" form:"file_ids"` // 文件ID列表,逗号分隔 } + +type KnowledgeBaseServerResponse struct { + QuerySelect KnowledgeBaseServerResponseSelect `json:"query_select"` //查询结果 + KnowledgeBase KBase `json:"knowledge_base"` //知识库信息 + IMContext IMParamContext `json:"im_context"` //消息上下文 +} + +// KBase 定义与 TypeScript 接口对应的 Go 结构体 +type KBase struct { + ID uint `json:"ID"` + Name string `json:"Name"` + Description string `json:"Description"` + UUID string `json:"UUID"` + SessionID uint `json:"SessionID"` + CreatedAt string `json:"CreatedAt"` + UpdatedAt string `json:"UpdatedAt"` + DeletedAt string `json:"DeletedAt"` + AuthID int `json:"AuthID"` + FileNameList string `json:"FileNameList"` + FileIDs string `json:"FileIDs"` +} + +// 知识库向量查询返回结果Response 定义 JSON 数据对应的结构体 +type KnowledgeBaseServerResponseSelect struct { + IDs [][]string `json:"ids"` + Embeddings interface{} `json:"embeddings"` + Documents [][]string `json:"documents"` + URIs interface{} `json:"uris"` + Included []string `json:"included"` + Data interface{} `json:"data"` + Metadatas [][]Metadata `json:"metadatas"` + Distances [][]float64 `json:"distances"` +} + +// Metadata 定义元数据结构体 +type Metadata struct { + Source string `json:"source"` +} + +type KnowledgeBaseServerRequest struct { + IMContext IMParamContext `json:"im_context"` //消息上下文 + KnowledgeBase KBase `json:"knowledge_base"` //知识库信息 +} diff --git a/proto/status.go b/proto/status.go index a8f9c0a..f747ee4 100644 --- a/proto/status.go +++ b/proto/status.go @@ -130,9 +130,10 @@ const ( // 支持模型类型 const ( - ModelTypeSpark = "spark" - ModelTypeDouBao = "doubao" - ModelTypeOllama = "ollama" + ModelTypeSpark = "spark" + ModelTypeDouBao = "doubao" + ModelTypeOllama = "ollama" + KnowledgeBaseServerResponseType = "kbase_query_resp" ) // 其它 diff --git a/service/imService.go b/service/imService.go index a830010..e7ff739 100644 --- a/service/imService.go +++ b/service/imService.go @@ -7,6 +7,7 @@ import ( "context" "encoding/json" "errors" + "fmt" "github.com/ollama/ollama/api" "log" "time" @@ -275,3 +276,66 @@ func CheckUserCreatePPTSessionPermission(userID int) error { } return err } + +// 知识库消息,从原始请求数据,上下文数据 +func SelectKBaseContentByQuery(ctx *proto.IMParamContext, data *proto.WSMessageReq) error { + var req proto.KnowledgeBaseServerRequest + req.IMContext = *ctx + kbase, err := dao.GetKnowledgeBaseByID(data.KBaseID) + if err != nil { + return err + } + req.KnowledgeBase = proto.KBase{ID: kbase.ID, UUID: kbase.UUID, Name: kbase.Name, Description: kbase.Description, FileIDs: kbase.FileIDs, SessionID: kbase.SessionID} + //选取合适的服务器进行查询 + //先查看改知识库是否已经创建 + key := fmt.Sprintf("kbase_%s", kbase.UUID) + KBaseServerID := worker.GetRedis(key) + //若未分配 + if KBaseServerID == "" { + KBaseServerSet := "kbase_server_id_set" + //进行分配 + KBaseServerID = worker.GetRedisSetRandomMember(KBaseServerSet) + if KBaseServerID == "" { + return errors.New("no kbase server id") + } + //记录分配的知识库服务器 + worker.SetRedis(key, KBaseServerID) + } + chanel := fmt.Sprintf("kbase_%s", kbase.UUID) + pubsub := worker.RedisClient.Subscribe(context.Background(), chanel) + defer pubsub.Close() + ch := pubsub.Channel() + //1min定时器,超时则返回错误 + timer := time.NewTimer(time.Minute) + var resp proto.KnowledgeBaseServerResponse + select { + case m := <-ch: + msg := m.Payload // 获取消息,消息格式为json + err2 := json.Unmarshal([]byte(msg), &resp) + if err2 != nil { + log.Println("Error parsing JSON:", err2) + } + break + case <-timer.C: + log.Println("kbase query timeout") + return errors.New("kbase query timeout") + } + //解析返回的查询消息 + text := "根据下面文档片段:\n" + querySelect := resp.QuerySelect + for i := 0; i < len(querySelect.Documents); i++ { + text += "文件:" + for j := 0; j < len(querySelect.Metadatas); j++ { + text += querySelect.Metadatas[i][j].Source + "," + } + text += ":\n内容:" + for j := 0; j < len(querySelect.Documents[i]); j++ { + text += querySelect.Documents[i][j] + "," + } + text += "\n" + } + //用户问题 + text += "回答下面用户问题:" + data.Msg + data.Msg = text + return nil +} diff --git a/service/knowledgeBaseService.go b/service/knowledgeBaseService.go index 419c028..42d6768 100644 --- a/service/knowledgeBaseService.go +++ b/service/knowledgeBaseService.go @@ -3,8 +3,10 @@ package service import ( "StuAcaWorksAI/dao" "StuAcaWorksAI/proto" + "StuAcaWorksAI/worker" "errors" "github.com/google/uuid" + "time" ) func CreateKnowledgeBase(userID int, req proto.KnowledgeBaseReq) (dao.KnowledgeBase, error) { @@ -68,3 +70,13 @@ func DeleteKnowledgeBase(userID int, req proto.KnowledgeBaseReq) error { return dao.DeleteKnowledgeBase(knowledgeBase.ID) } + +// 将配置文件中支持的知识库处理服务器添加到集合 +func AddKnowledgeBaseServer() { + set := "kbase_server_id_set" + var serverIDs []string + for _, server := range proto.Config.KBASE_SERVER { + serverIDs = append(serverIDs, server.ServerID) + } + worker.SetRedisSet(set, serverIDs, time.Hour) +} diff --git a/worker/redis.go b/worker/redis.go index 5b0ac7c..2568d33 100644 --- a/worker/redis.go +++ b/worker/redis.go @@ -542,3 +542,14 @@ func getKeyTypeAndData(key string) (string, string, error) { } return key_type, fmt.Sprintf("%v", val), err } + +// 随机获取集合中的一个元素 +func GetRedisSetRandomMember(key string) string { + ctx := context.Background() + val, err := RedisClient.SRandMember(ctx, key).Result() + if err != nil { + fmt.Println("Error getting key: %v", err) + return "" + } + return val +}