完成知识库的消息流处理即知识库与处理服务器
This commit is contained in:
parent
497424cb95
commit
4f71c12e2d
|
|
@ -8,6 +8,7 @@ import (
|
|||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
"github.com/gorilla/websocket"
|
||||
|
|
@ -46,6 +47,7 @@ func SetUpIMGroup(router *gin.Engine) {
|
|||
imGroup.POST("/send_message", SendMessage)
|
||||
imGroup.GET("/ai_chat_ws", SRMessage)
|
||||
imGroup.POST("/get_message", GetMessage)
|
||||
imGroup.GET("/kbase_text_ws", KBaseTextWS)
|
||||
}
|
||||
|
||||
func GetMessage(c *gin.Context) {
|
||||
|
|
@ -261,6 +263,14 @@ func doReceiveGenChatMessage(userId int, sessionID *uint, data *proto.WSMessageR
|
|||
if data.IsImage || data.IsFile {
|
||||
status = proto.UserToModelImageMsgType //图片类型
|
||||
}
|
||||
//若为知识库消息,需要先根据用户消息获取知识库查询内容让ai进行最终结果生成
|
||||
if data.IsKBase {
|
||||
err2 := service.SelectKBaseContentByQuery(&imContext, data)
|
||||
if err2 != nil {
|
||||
log.Println("select kbase content error:", err2)
|
||||
return err2
|
||||
}
|
||||
}
|
||||
//将消息存入数据库
|
||||
err, userMsgID = service.CreateMessage(proto.UserToModelMsgType, int(*sessionID), userId, int(model.ID), data.Msg, status, int(funcs[0].ID))
|
||||
log.Println("create user message id:", userMsgID)
|
||||
|
|
@ -272,3 +282,76 @@ func doReceiveGenChatMessage(userId int, sessionID *uint, data *proto.WSMessageR
|
|||
err, _ = service.WSReceiveMessageServiceV2(modelParam, &imContext)
|
||||
return err
|
||||
}
|
||||
|
||||
func KBaseTextWS(c *gin.Context) {
|
||||
id, _ := c.Get("id")
|
||||
userID := int(id.(float64))
|
||||
//query参数
|
||||
kBaseSercerID := c.Query("kbase_server_id")
|
||||
// 升级HTTP连接为WebSocket连接
|
||||
ws, err1 := upgrader.Upgrade(c.Writer, c.Request, nil)
|
||||
if err1 != nil {
|
||||
log.Println("升级为WebSocket时发生错误:", err1)
|
||||
return
|
||||
}
|
||||
subscribeAndHandleKBaseTextMessages(ws, userID, kBaseSercerID)
|
||||
}
|
||||
|
||||
func subscribeAndHandleKBaseTextMessages(ws *websocket.Conn, userId int, kBaseSercerID string) {
|
||||
ctx := context.Background()
|
||||
chanel := fmt.Sprintf("kbase_server_id_%s", kBaseSercerID) // 频道名称
|
||||
pubsub := worker.RedisClient.Subscribe(ctx, chanel)
|
||||
defer pubsub.Close()
|
||||
defer ws.Close()
|
||||
ch := pubsub.Channel()
|
||||
|
||||
//接收客户端消息,发送消息
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer close(done)
|
||||
for {
|
||||
_, message, err := ws.ReadMessage()
|
||||
var data proto.KnowledgeBaseServerResponse
|
||||
//log.Println("rec message:", string(message))
|
||||
if err != nil {
|
||||
log.Println("Read error:", err)
|
||||
break
|
||||
}
|
||||
err1 := json.Unmarshal(message, &data)
|
||||
if err1 != nil {
|
||||
log.Println("Error parsing JSON:", err1)
|
||||
continue
|
||||
}
|
||||
service.CreateAIStreamMsg(userId, data.IMContext.ModelID, data.IMContext.SessionID, message, proto.KnowledgeBaseServerResponseType) //将消息存入数据库
|
||||
//这里主要将接收到的消息转到对应协程
|
||||
//目标chanel
|
||||
targetChanel := fmt.Sprintf("kbase_%s", data.KnowledgeBase.UUID)
|
||||
//将消息转入
|
||||
err = worker.RedisClient.Publish(ctx, targetChanel, message).Err()
|
||||
if err != nil {
|
||||
log.Println("KBase Text WS Publish error:", err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
//defer close(done)
|
||||
//从redis订阅消息
|
||||
for m := range ch {
|
||||
msg := m.Payload // 获取消息,消息格式为json
|
||||
if msg != "" {
|
||||
err2 := ws.WriteMessage(websocket.TextMessage, []byte(msg))
|
||||
if err2 != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-done:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
1
main.go
1
main.go
|
|
@ -211,6 +211,7 @@ func myTask() {
|
|||
}
|
||||
//其它定时任务-通用
|
||||
RunGeneralCron()
|
||||
service.AddKnowledgeBaseServer() //将配置文件中支持的知识库处理服务器添加到集合
|
||||
}
|
||||
|
||||
func ReadConfigToSetSystem() {
|
||||
|
|
|
|||
|
|
@ -64,26 +64,31 @@ type User struct {
|
|||
}
|
||||
|
||||
// ConfigStruct is the application configuration decoded from JSON.
// Each field is declared exactly once; the JSON keys are given by the tags.
type ConfigStruct struct {
	DB                   int           `json:"db"` // 0: mysql, 1: pg
	MYSQL_DSN            string        `json:"mysql_dsn"`
	PG_DSN               string        `json:"pg_dsn"`
	REDIS_ADDR           string        `json:"redis_addr"`
	TOKEN_USE_REDIS      bool          `json:"token_use_redis"`
	REDIS_User_PW        bool          `json:"redis_user_pw"` // whether a Redis password is used
	REDIS_PASSWORD       string        `json:"redis_password"`
	REDIS_DB             int           `json:"redis_db"`
	TOKEN_SECRET         string        `json:"token_secret"`
	CID_BASE_DIR         string        `json:"cid_base_dir"`
	FILE_BASE_DIR        string        `json:"file_base_dir"`
	MONITOR              bool          `json:"monitor"`              // status monitoring and e-mail notification
	SERVER_SQL_LOG       bool          `json:"server_sql_log"`       // server SQL logging
	SERVER_PORT          string        `json:"server_port"`          // service port
	LOG_SAVE_DAYS        int           `json:"log_save_days"`        // days to keep logs; -1 means do not keep, 0 means keep forever
	SERVER_USER_TYPE     string        `json:"user_type"`            // server role: "master" or "slave"; a slave periodically syncs data
	MASTER_SERVER_DOMAIN string        `json:"master_server_domain"` // domain of the master server
	USER_SYNC_TIME       int           `json:"user_sync_time"`       // user data sync interval, in seconds
	SERVER_NAME          string        `json:"server_name"`          // server name, used to tell servers apart
	SPARK_PPT_USAGE      bool          `json:"spark_ppt_usage"`      // whether the spark PPT feature is enabled
	KBASE_SERVER         []KBaseServer `json:"kbase_server"`         // list of knowledge-base servers
}

// KBaseServer describes one knowledge-base processing server entry of the
// configuration.
type KBaseServer struct {
	ServerID string `json:"server_id"` // server ID
}
|
||||
|
||||
// 读取配置文件
|
||||
|
|
|
|||
|
|
@ -27,6 +27,8 @@ type WSMessageReq struct {
|
|||
ModelID uint `json:"model_id"` //模型id
|
||||
IsImage bool `json:"is_image"` //是否为图片(图片消息),如果是图片消息,则msg为图片消息结构为{"img_url":"","text":""}
|
||||
IsFile bool `json:"is_file"` //是否为文件(文件消息),如果是文件消息,则msg为文件消息结构为{"file_url":"","text":""}
|
||||
IsKBase bool `json:"is_kbase"` //是否为知识库消息
|
||||
KBaseID uint `json:"kbase_id"` //知识库id,为知识库消息时必传
|
||||
}
|
||||
|
||||
type ModelParam struct {
|
||||
|
|
|
|||
|
|
@ -7,3 +7,46 @@ type KnowledgeBaseReq struct {
|
|||
Description string `json:"description" form:"description"` // 知识库描述
|
||||
FileIDS string `json:"file_ids" form:"file_ids"` // 文件ID列表,逗号分隔
|
||||
}
|
||||
|
||||
type KnowledgeBaseServerResponse struct {
|
||||
QuerySelect KnowledgeBaseServerResponseSelect `json:"query_select"` //查询结果
|
||||
KnowledgeBase KBase `json:"knowledge_base"` //知识库信息
|
||||
IMContext IMParamContext `json:"im_context"` //消息上下文
|
||||
}
|
||||
|
||||
// KBase is the Go counterpart of the TypeScript knowledge-base interface.
// Its JSON keys are capitalised, as given by the field tags.
type KBase struct {
	ID           uint   `json:"ID"`
	Name         string `json:"Name"`
	Description  string `json:"Description"`
	UUID         string `json:"UUID"`
	SessionID    uint   `json:"SessionID"`
	CreatedAt    string `json:"CreatedAt"`
	UpdatedAt    string `json:"UpdatedAt"`
	DeletedAt    string `json:"DeletedAt"`
	AuthID       int    `json:"AuthID"`
	FileNameList string `json:"FileNameList"`
	FileIDs      string `json:"FileIDs"`
}
|
||||
|
||||
// 知识库向量查询返回结果Response 定义 JSON 数据对应的结构体
|
||||
type KnowledgeBaseServerResponseSelect struct {
|
||||
IDs [][]string `json:"ids"`
|
||||
Embeddings interface{} `json:"embeddings"`
|
||||
Documents [][]string `json:"documents"`
|
||||
URIs interface{} `json:"uris"`
|
||||
Included []string `json:"included"`
|
||||
Data interface{} `json:"data"`
|
||||
Metadatas [][]Metadata `json:"metadatas"`
|
||||
Distances [][]float64 `json:"distances"`
|
||||
}
|
||||
|
||||
// Metadata is the metadata record attached to a query result entry.
type Metadata struct {
	Source string `json:"source"`
}
|
||||
|
||||
type KnowledgeBaseServerRequest struct {
|
||||
IMContext IMParamContext `json:"im_context"` //消息上下文
|
||||
KnowledgeBase KBase `json:"knowledge_base"` //知识库信息
|
||||
}
|
||||
|
|
|
|||
|
|
@ -130,9 +130,10 @@ const (
|
|||
|
||||
// Supported model types, plus the message type used for knowledge-base query
// responses. Each constant is declared exactly once.
const (
	ModelTypeSpark                  = "spark"
	ModelTypeDouBao                 = "doubao"
	ModelTypeOllama                 = "ollama"
	KnowledgeBaseServerResponseType = "kbase_query_resp"
)
|
||||
|
||||
// 其它
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ import (
|
|||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/ollama/ollama/api"
|
||||
"log"
|
||||
"time"
|
||||
|
|
@ -275,3 +276,66 @@ func CheckUserCreatePPTSessionPermission(userID int) error {
|
|||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// 知识库消息,从原始请求数据,上下文数据
|
||||
func SelectKBaseContentByQuery(ctx *proto.IMParamContext, data *proto.WSMessageReq) error {
|
||||
var req proto.KnowledgeBaseServerRequest
|
||||
req.IMContext = *ctx
|
||||
kbase, err := dao.GetKnowledgeBaseByID(data.KBaseID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
req.KnowledgeBase = proto.KBase{ID: kbase.ID, UUID: kbase.UUID, Name: kbase.Name, Description: kbase.Description, FileIDs: kbase.FileIDs, SessionID: kbase.SessionID}
|
||||
//选取合适的服务器进行查询
|
||||
//先查看改知识库是否已经创建
|
||||
key := fmt.Sprintf("kbase_%s", kbase.UUID)
|
||||
KBaseServerID := worker.GetRedis(key)
|
||||
//若未分配
|
||||
if KBaseServerID == "" {
|
||||
KBaseServerSet := "kbase_server_id_set"
|
||||
//进行分配
|
||||
KBaseServerID = worker.GetRedisSetRandomMember(KBaseServerSet)
|
||||
if KBaseServerID == "" {
|
||||
return errors.New("no kbase server id")
|
||||
}
|
||||
//记录分配的知识库服务器
|
||||
worker.SetRedis(key, KBaseServerID)
|
||||
}
|
||||
chanel := fmt.Sprintf("kbase_%s", kbase.UUID)
|
||||
pubsub := worker.RedisClient.Subscribe(context.Background(), chanel)
|
||||
defer pubsub.Close()
|
||||
ch := pubsub.Channel()
|
||||
//1min定时器,超时则返回错误
|
||||
timer := time.NewTimer(time.Minute)
|
||||
var resp proto.KnowledgeBaseServerResponse
|
||||
select {
|
||||
case m := <-ch:
|
||||
msg := m.Payload // 获取消息,消息格式为json
|
||||
err2 := json.Unmarshal([]byte(msg), &resp)
|
||||
if err2 != nil {
|
||||
log.Println("Error parsing JSON:", err2)
|
||||
}
|
||||
break
|
||||
case <-timer.C:
|
||||
log.Println("kbase query timeout")
|
||||
return errors.New("kbase query timeout")
|
||||
}
|
||||
//解析返回的查询消息
|
||||
text := "根据下面文档片段:\n"
|
||||
querySelect := resp.QuerySelect
|
||||
for i := 0; i < len(querySelect.Documents); i++ {
|
||||
text += "文件:"
|
||||
for j := 0; j < len(querySelect.Metadatas); j++ {
|
||||
text += querySelect.Metadatas[i][j].Source + ","
|
||||
}
|
||||
text += ":\n内容:"
|
||||
for j := 0; j < len(querySelect.Documents[i]); j++ {
|
||||
text += querySelect.Documents[i][j] + ","
|
||||
}
|
||||
text += "\n"
|
||||
}
|
||||
//用户问题
|
||||
text += "回答下面用户问题:" + data.Msg
|
||||
data.Msg = text
|
||||
return nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -3,8 +3,10 @@ package service
|
|||
import (
|
||||
"StuAcaWorksAI/dao"
|
||||
"StuAcaWorksAI/proto"
|
||||
"StuAcaWorksAI/worker"
|
||||
"errors"
|
||||
"github.com/google/uuid"
|
||||
"time"
|
||||
)
|
||||
|
||||
func CreateKnowledgeBase(userID int, req proto.KnowledgeBaseReq) (dao.KnowledgeBase, error) {
|
||||
|
|
@ -68,3 +70,13 @@ func DeleteKnowledgeBase(userID int, req proto.KnowledgeBaseReq) error {
|
|||
|
||||
return dao.DeleteKnowledgeBase(knowledgeBase.ID)
|
||||
}
|
||||
|
||||
// 将配置文件中支持的知识库处理服务器添加到集合
|
||||
func AddKnowledgeBaseServer() {
|
||||
set := "kbase_server_id_set"
|
||||
var serverIDs []string
|
||||
for _, server := range proto.Config.KBASE_SERVER {
|
||||
serverIDs = append(serverIDs, server.ServerID)
|
||||
}
|
||||
worker.SetRedisSet(set, serverIDs, time.Hour)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -542,3 +542,14 @@ func getKeyTypeAndData(key string) (string, string, error) {
|
|||
}
|
||||
return key_type, fmt.Sprintf("%v", val), err
|
||||
}
|
||||
|
||||
// 随机获取集合中的一个元素
|
||||
func GetRedisSetRandomMember(key string) string {
|
||||
ctx := context.Background()
|
||||
val, err := RedisClient.SRandMember(ctx, key).Result()
|
||||
if err != nil {
|
||||
fmt.Println("Error getting key: %v", err)
|
||||
return ""
|
||||
}
|
||||
return val
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue