完成知识库的消息流处理即知识库与处理服务器

This commit is contained in:
junleea 2025-04-09 14:53:16 +08:00
parent 497424cb95
commit 4f71c12e2d
9 changed files with 245 additions and 23 deletions

View File

@ -8,6 +8,7 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
@ -46,6 +47,7 @@ func SetUpIMGroup(router *gin.Engine) {
imGroup.POST("/send_message", SendMessage) imGroup.POST("/send_message", SendMessage)
imGroup.GET("/ai_chat_ws", SRMessage) imGroup.GET("/ai_chat_ws", SRMessage)
imGroup.POST("/get_message", GetMessage) imGroup.POST("/get_message", GetMessage)
imGroup.GET("/kbase_text_ws", KBaseTextWS)
} }
func GetMessage(c *gin.Context) { func GetMessage(c *gin.Context) {
@ -261,6 +263,14 @@ func doReceiveGenChatMessage(userId int, sessionID *uint, data *proto.WSMessageR
if data.IsImage || data.IsFile { if data.IsImage || data.IsFile {
status = proto.UserToModelImageMsgType //图片类型 status = proto.UserToModelImageMsgType //图片类型
} }
//若为知识库消息需要先根据用户消息获取知识库查询内容让ai进行最终结果生成
if data.IsKBase {
err2 := service.SelectKBaseContentByQuery(&imContext, data)
if err2 != nil {
log.Println("select kbase content error:", err2)
return err2
}
}
//将消息存入数据库 //将消息存入数据库
err, userMsgID = service.CreateMessage(proto.UserToModelMsgType, int(*sessionID), userId, int(model.ID), data.Msg, status, int(funcs[0].ID)) err, userMsgID = service.CreateMessage(proto.UserToModelMsgType, int(*sessionID), userId, int(model.ID), data.Msg, status, int(funcs[0].ID))
log.Println("create user message id:", userMsgID) log.Println("create user message id:", userMsgID)
@ -272,3 +282,76 @@ func doReceiveGenChatMessage(userId int, sessionID *uint, data *proto.WSMessageR
err, _ = service.WSReceiveMessageServiceV2(modelParam, &imContext) err, _ = service.WSReceiveMessageServiceV2(modelParam, &imContext)
return err return err
} }
func KBaseTextWS(c *gin.Context) {
id, _ := c.Get("id")
userID := int(id.(float64))
//query参数
kBaseSercerID := c.Query("kbase_server_id")
// 升级HTTP连接为WebSocket连接
ws, err1 := upgrader.Upgrade(c.Writer, c.Request, nil)
if err1 != nil {
log.Println("升级为WebSocket时发生错误:", err1)
return
}
subscribeAndHandleKBaseTextMessages(ws, userID, kBaseSercerID)
}
func subscribeAndHandleKBaseTextMessages(ws *websocket.Conn, userId int, kBaseSercerID string) {
ctx := context.Background()
chanel := fmt.Sprintf("kbase_server_id_%s", kBaseSercerID) // 频道名称
pubsub := worker.RedisClient.Subscribe(ctx, chanel)
defer pubsub.Close()
defer ws.Close()
ch := pubsub.Channel()
//接收客户端消息,发送消息
done := make(chan struct{})
go func() {
defer close(done)
for {
_, message, err := ws.ReadMessage()
var data proto.KnowledgeBaseServerResponse
//log.Println("rec message:", string(message))
if err != nil {
log.Println("Read error:", err)
break
}
err1 := json.Unmarshal(message, &data)
if err1 != nil {
log.Println("Error parsing JSON:", err1)
continue
}
service.CreateAIStreamMsg(userId, data.IMContext.ModelID, data.IMContext.SessionID, message, proto.KnowledgeBaseServerResponseType) //将消息存入数据库
//这里主要将接收到的消息转到对应协程
//目标chanel
targetChanel := fmt.Sprintf("kbase_%s", data.KnowledgeBase.UUID)
//将消息转入
err = worker.RedisClient.Publish(ctx, targetChanel, message).Err()
if err != nil {
log.Println("KBase Text WS Publish error:", err)
}
}
}()
go func() {
//defer close(done)
//从redis订阅消息
for m := range ch {
msg := m.Payload // 获取消息,消息格式为json
if msg != "" {
err2 := ws.WriteMessage(websocket.TextMessage, []byte(msg))
if err2 != nil {
break
}
}
}
}()
for {
select {
case <-done:
return
}
}
}

View File

@ -211,6 +211,7 @@ func myTask() {
} }
//其它定时任务-通用 //其它定时任务-通用
RunGeneralCron() RunGeneralCron()
service.AddKnowledgeBaseServer() //将配置文件中支持的知识库处理服务器添加到集合
} }
func ReadConfigToSetSystem() { func ReadConfigToSetSystem() {

View File

@ -64,26 +64,31 @@ type User struct {
} }
type ConfigStruct struct { type ConfigStruct struct {
DB int `json:"db"` // 0: mysql, 1: pg DB int `json:"db"` // 0: mysql, 1: pg
MYSQL_DSN string `json:"mysql_dsn"` MYSQL_DSN string `json:"mysql_dsn"`
PG_DSN string `json:"pg_dsn"` PG_DSN string `json:"pg_dsn"`
REDIS_ADDR string `json:"redis_addr"` REDIS_ADDR string `json:"redis_addr"`
TOKEN_USE_REDIS bool `json:"token_use_redis"` TOKEN_USE_REDIS bool `json:"token_use_redis"`
REDIS_User_PW bool `json:"redis_user_pw"` // 是否使用密码 REDIS_User_PW bool `json:"redis_user_pw"` // 是否使用密码
REDIS_PASSWORD string `json:"redis_password"` REDIS_PASSWORD string `json:"redis_password"`
REDIS_DB int `json:"redis_db"` REDIS_DB int `json:"redis_db"`
TOKEN_SECRET string `json:"token_secret"` TOKEN_SECRET string `json:"token_secret"`
CID_BASE_DIR string `json:"cid_base_dir"` CID_BASE_DIR string `json:"cid_base_dir"`
FILE_BASE_DIR string `json:"file_base_dir"` FILE_BASE_DIR string `json:"file_base_dir"`
MONITOR bool `json:"monitor"` // 状态监控及邮件通知 MONITOR bool `json:"monitor"` // 状态监控及邮件通知
SERVER_SQL_LOG bool `json:"server_sql_log"` // 服务器sql日志 SERVER_SQL_LOG bool `json:"server_sql_log"` // 服务器sql日志
SERVER_PORT string `json:"server_port"` // 服务端口 SERVER_PORT string `json:"server_port"` // 服务端口
LOG_SAVE_DAYS int `json:"log_save_days"` // 日志保存天数,-1表示不保存0表示永久保存 LOG_SAVE_DAYS int `json:"log_save_days"` // 日志保存天数,-1表示不保存0表示永久保存
SERVER_USER_TYPE string `json:"user_type"` // 服务器用户类型master: 主服务器slave: 从服务器,从服务器会定时同步数据 SERVER_USER_TYPE string `json:"user_type"` // 服务器用户类型master: 主服务器slave: 从服务器,从服务器会定时同步数据
MASTER_SERVER_DOMAIN string `json:"master_server_domain"` // 主服务器域名 MASTER_SERVER_DOMAIN string `json:"master_server_domain"` // 主服务器域名
USER_SYNC_TIME int `json:"user_sync_time"` // 用户数据同步时间,单位秒 USER_SYNC_TIME int `json:"user_sync_time"` // 用户数据同步时间,单位秒
SERVER_NAME string `json:"server_name"` // 服务器名称,用于区分不同服务器 SERVER_NAME string `json:"server_name"` // 服务器名称,用于区分不同服务器
SPARK_PPT_USAGE bool `json:"spark_ppt_usage"` // 是否使用spark ppt功能 SPARK_PPT_USAGE bool `json:"spark_ppt_usage"` // 是否使用spark ppt功能
KBASE_SERVER []KBaseServer `json:"kbase_server"` // 知识库服务器列表
}
type KBaseServer struct {
ServerID string `json:"server_id"` // 服务器ID
} }
// 读取配置文件 // 读取配置文件

View File

@ -27,6 +27,8 @@ type WSMessageReq struct {
ModelID uint `json:"model_id"` //模型id ModelID uint `json:"model_id"` //模型id
IsImage bool `json:"is_image"` //是否为图片(图片消息)如果是图片消息则msg为图片消息结构为{"img_url":"","text":""} IsImage bool `json:"is_image"` //是否为图片(图片消息)如果是图片消息则msg为图片消息结构为{"img_url":"","text":""}
IsFile bool `json:"is_file"` //是否为文件(文件消息)如果是文件消息则msg为文件消息结构为{"file_url":"","text":""} IsFile bool `json:"is_file"` //是否为文件(文件消息)如果是文件消息则msg为文件消息结构为{"file_url":"","text":""}
IsKBase bool `json:"is_kbase"` //是否为知识库消息
KBaseID uint `json:"kbase_id"` //知识库id,为知识库消息时必传
} }
type ModelParam struct { type ModelParam struct {

View File

@ -7,3 +7,46 @@ type KnowledgeBaseReq struct {
Description string `json:"description" form:"description"` // 知识库描述 Description string `json:"description" form:"description"` // 知识库描述
FileIDS string `json:"file_ids" form:"file_ids"` // 文件ID列表逗号分隔 FileIDS string `json:"file_ids" form:"file_ids"` // 文件ID列表逗号分隔
} }
type KnowledgeBaseServerResponse struct {
QuerySelect KnowledgeBaseServerResponseSelect `json:"query_select"` //查询结果
KnowledgeBase KBase `json:"knowledge_base"` //知识库信息
IMContext IMParamContext `json:"im_context"` //消息上下文
}
// KBase 定义与 TypeScript 接口对应的 Go 结构体
type KBase struct {
ID uint `json:"ID"`
Name string `json:"Name"`
Description string `json:"Description"`
UUID string `json:"UUID"`
SessionID uint `json:"SessionID"`
CreatedAt string `json:"CreatedAt"`
UpdatedAt string `json:"UpdatedAt"`
DeletedAt string `json:"DeletedAt"`
AuthID int `json:"AuthID"`
FileNameList string `json:"FileNameList"`
FileIDs string `json:"FileIDs"`
}
// 知识库向量查询返回结果Response 定义 JSON 数据对应的结构体
type KnowledgeBaseServerResponseSelect struct {
IDs [][]string `json:"ids"`
Embeddings interface{} `json:"embeddings"`
Documents [][]string `json:"documents"`
URIs interface{} `json:"uris"`
Included []string `json:"included"`
Data interface{} `json:"data"`
Metadatas [][]Metadata `json:"metadatas"`
Distances [][]float64 `json:"distances"`
}
// Metadata 定义元数据结构体
type Metadata struct {
Source string `json:"source"`
}
type KnowledgeBaseServerRequest struct {
IMContext IMParamContext `json:"im_context"` //消息上下文
KnowledgeBase KBase `json:"knowledge_base"` //知识库信息
}

View File

@ -130,9 +130,10 @@ const (
// 支持模型类型 // 支持模型类型
const ( const (
ModelTypeSpark = "spark" ModelTypeSpark = "spark"
ModelTypeDouBao = "doubao" ModelTypeDouBao = "doubao"
ModelTypeOllama = "ollama" ModelTypeOllama = "ollama"
KnowledgeBaseServerResponseType = "kbase_query_resp"
) )
// 其它 // 其它

View File

@ -7,6 +7,7 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt"
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
"log" "log"
"time" "time"
@ -275,3 +276,66 @@ func CheckUserCreatePPTSessionPermission(userID int) error {
} }
return err return err
} }
// 知识库消息,从原始请求数据,上下文数据
func SelectKBaseContentByQuery(ctx *proto.IMParamContext, data *proto.WSMessageReq) error {
var req proto.KnowledgeBaseServerRequest
req.IMContext = *ctx
kbase, err := dao.GetKnowledgeBaseByID(data.KBaseID)
if err != nil {
return err
}
req.KnowledgeBase = proto.KBase{ID: kbase.ID, UUID: kbase.UUID, Name: kbase.Name, Description: kbase.Description, FileIDs: kbase.FileIDs, SessionID: kbase.SessionID}
//选取合适的服务器进行查询
//先查看改知识库是否已经创建
key := fmt.Sprintf("kbase_%s", kbase.UUID)
KBaseServerID := worker.GetRedis(key)
//若未分配
if KBaseServerID == "" {
KBaseServerSet := "kbase_server_id_set"
//进行分配
KBaseServerID = worker.GetRedisSetRandomMember(KBaseServerSet)
if KBaseServerID == "" {
return errors.New("no kbase server id")
}
//记录分配的知识库服务器
worker.SetRedis(key, KBaseServerID)
}
chanel := fmt.Sprintf("kbase_%s", kbase.UUID)
pubsub := worker.RedisClient.Subscribe(context.Background(), chanel)
defer pubsub.Close()
ch := pubsub.Channel()
//1min定时器超时则返回错误
timer := time.NewTimer(time.Minute)
var resp proto.KnowledgeBaseServerResponse
select {
case m := <-ch:
msg := m.Payload // 获取消息,消息格式为json
err2 := json.Unmarshal([]byte(msg), &resp)
if err2 != nil {
log.Println("Error parsing JSON:", err2)
}
break
case <-timer.C:
log.Println("kbase query timeout")
return errors.New("kbase query timeout")
}
//解析返回的查询消息
text := "根据下面文档片段:\n"
querySelect := resp.QuerySelect
for i := 0; i < len(querySelect.Documents); i++ {
text += "文件:"
for j := 0; j < len(querySelect.Metadatas); j++ {
text += querySelect.Metadatas[i][j].Source + ""
}
text += ":\n内容"
for j := 0; j < len(querySelect.Documents[i]); j++ {
text += querySelect.Documents[i][j] + ""
}
text += "\n"
}
//用户问题
text += "回答下面用户问题:" + data.Msg
data.Msg = text
return nil
}

View File

@ -3,8 +3,10 @@ package service
import ( import (
"StuAcaWorksAI/dao" "StuAcaWorksAI/dao"
"StuAcaWorksAI/proto" "StuAcaWorksAI/proto"
"StuAcaWorksAI/worker"
"errors" "errors"
"github.com/google/uuid" "github.com/google/uuid"
"time"
) )
func CreateKnowledgeBase(userID int, req proto.KnowledgeBaseReq) (dao.KnowledgeBase, error) { func CreateKnowledgeBase(userID int, req proto.KnowledgeBaseReq) (dao.KnowledgeBase, error) {
@ -68,3 +70,13 @@ func DeleteKnowledgeBase(userID int, req proto.KnowledgeBaseReq) error {
return dao.DeleteKnowledgeBase(knowledgeBase.ID) return dao.DeleteKnowledgeBase(knowledgeBase.ID)
} }
// 将配置文件中支持的知识库处理服务器添加到集合
func AddKnowledgeBaseServer() {
set := "kbase_server_id_set"
var serverIDs []string
for _, server := range proto.Config.KBASE_SERVER {
serverIDs = append(serverIDs, server.ServerID)
}
worker.SetRedisSet(set, serverIDs, time.Hour)
}

View File

@ -542,3 +542,14 @@ func getKeyTypeAndData(key string) (string, string, error) {
} }
return key_type, fmt.Sprintf("%v", val), err return key_type, fmt.Sprintf("%v", val), err
} }
// 随机获取集合中的一个元素
func GetRedisSetRandomMember(key string) string {
ctx := context.Background()
val, err := RedisClient.SRandMember(ctx, key).Result()
if err != nil {
fmt.Println("Error getting key: %v", err)
return ""
}
return val
}