完成知识库的消息流处理即知识库与处理服务器
This commit is contained in:
parent
497424cb95
commit
4f71c12e2d
|
|
@ -8,6 +8,7 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"github.com/gorilla/websocket"
|
"github.com/gorilla/websocket"
|
||||||
|
|
@ -46,6 +47,7 @@ func SetUpIMGroup(router *gin.Engine) {
|
||||||
imGroup.POST("/send_message", SendMessage)
|
imGroup.POST("/send_message", SendMessage)
|
||||||
imGroup.GET("/ai_chat_ws", SRMessage)
|
imGroup.GET("/ai_chat_ws", SRMessage)
|
||||||
imGroup.POST("/get_message", GetMessage)
|
imGroup.POST("/get_message", GetMessage)
|
||||||
|
imGroup.GET("/kbase_text_ws", KBaseTextWS)
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetMessage(c *gin.Context) {
|
func GetMessage(c *gin.Context) {
|
||||||
|
|
@ -261,6 +263,14 @@ func doReceiveGenChatMessage(userId int, sessionID *uint, data *proto.WSMessageR
|
||||||
if data.IsImage || data.IsFile {
|
if data.IsImage || data.IsFile {
|
||||||
status = proto.UserToModelImageMsgType //图片类型
|
status = proto.UserToModelImageMsgType //图片类型
|
||||||
}
|
}
|
||||||
|
//若为知识库消息,需要先根据用户消息获取知识库查询内容让ai进行最终结果生成
|
||||||
|
if data.IsKBase {
|
||||||
|
err2 := service.SelectKBaseContentByQuery(&imContext, data)
|
||||||
|
if err2 != nil {
|
||||||
|
log.Println("select kbase content error:", err2)
|
||||||
|
return err2
|
||||||
|
}
|
||||||
|
}
|
||||||
//将消息存入数据库
|
//将消息存入数据库
|
||||||
err, userMsgID = service.CreateMessage(proto.UserToModelMsgType, int(*sessionID), userId, int(model.ID), data.Msg, status, int(funcs[0].ID))
|
err, userMsgID = service.CreateMessage(proto.UserToModelMsgType, int(*sessionID), userId, int(model.ID), data.Msg, status, int(funcs[0].ID))
|
||||||
log.Println("create user message id:", userMsgID)
|
log.Println("create user message id:", userMsgID)
|
||||||
|
|
@ -272,3 +282,76 @@ func doReceiveGenChatMessage(userId int, sessionID *uint, data *proto.WSMessageR
|
||||||
err, _ = service.WSReceiveMessageServiceV2(modelParam, &imContext)
|
err, _ = service.WSReceiveMessageServiceV2(modelParam, &imContext)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func KBaseTextWS(c *gin.Context) {
|
||||||
|
id, _ := c.Get("id")
|
||||||
|
userID := int(id.(float64))
|
||||||
|
//query参数
|
||||||
|
kBaseSercerID := c.Query("kbase_server_id")
|
||||||
|
// 升级HTTP连接为WebSocket连接
|
||||||
|
ws, err1 := upgrader.Upgrade(c.Writer, c.Request, nil)
|
||||||
|
if err1 != nil {
|
||||||
|
log.Println("升级为WebSocket时发生错误:", err1)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
subscribeAndHandleKBaseTextMessages(ws, userID, kBaseSercerID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func subscribeAndHandleKBaseTextMessages(ws *websocket.Conn, userId int, kBaseSercerID string) {
|
||||||
|
ctx := context.Background()
|
||||||
|
chanel := fmt.Sprintf("kbase_server_id_%s", kBaseSercerID) // 频道名称
|
||||||
|
pubsub := worker.RedisClient.Subscribe(ctx, chanel)
|
||||||
|
defer pubsub.Close()
|
||||||
|
defer ws.Close()
|
||||||
|
ch := pubsub.Channel()
|
||||||
|
|
||||||
|
//接收客户端消息,发送消息
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
defer close(done)
|
||||||
|
for {
|
||||||
|
_, message, err := ws.ReadMessage()
|
||||||
|
var data proto.KnowledgeBaseServerResponse
|
||||||
|
//log.Println("rec message:", string(message))
|
||||||
|
if err != nil {
|
||||||
|
log.Println("Read error:", err)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
err1 := json.Unmarshal(message, &data)
|
||||||
|
if err1 != nil {
|
||||||
|
log.Println("Error parsing JSON:", err1)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
service.CreateAIStreamMsg(userId, data.IMContext.ModelID, data.IMContext.SessionID, message, proto.KnowledgeBaseServerResponseType) //将消息存入数据库
|
||||||
|
//这里主要将接收到的消息转到对应协程
|
||||||
|
//目标chanel
|
||||||
|
targetChanel := fmt.Sprintf("kbase_%s", data.KnowledgeBase.UUID)
|
||||||
|
//将消息转入
|
||||||
|
err = worker.RedisClient.Publish(ctx, targetChanel, message).Err()
|
||||||
|
if err != nil {
|
||||||
|
log.Println("KBase Text WS Publish error:", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
//defer close(done)
|
||||||
|
//从redis订阅消息
|
||||||
|
for m := range ch {
|
||||||
|
msg := m.Payload // 获取消息,消息格式为json
|
||||||
|
if msg != "" {
|
||||||
|
err2 := ws.WriteMessage(websocket.TextMessage, []byte(msg))
|
||||||
|
if err2 != nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
||||||
1
main.go
1
main.go
|
|
@ -211,6 +211,7 @@ func myTask() {
|
||||||
}
|
}
|
||||||
//其它定时任务-通用
|
//其它定时任务-通用
|
||||||
RunGeneralCron()
|
RunGeneralCron()
|
||||||
|
service.AddKnowledgeBaseServer() //将配置文件中支持的知识库处理服务器添加到集合
|
||||||
}
|
}
|
||||||
|
|
||||||
func ReadConfigToSetSystem() {
|
func ReadConfigToSetSystem() {
|
||||||
|
|
|
||||||
|
|
@ -84,6 +84,11 @@ type ConfigStruct struct {
|
||||||
USER_SYNC_TIME int `json:"user_sync_time"` // 用户数据同步时间,单位秒
|
USER_SYNC_TIME int `json:"user_sync_time"` // 用户数据同步时间,单位秒
|
||||||
SERVER_NAME string `json:"server_name"` // 服务器名称,用于区分不同服务器
|
SERVER_NAME string `json:"server_name"` // 服务器名称,用于区分不同服务器
|
||||||
SPARK_PPT_USAGE bool `json:"spark_ppt_usage"` // 是否使用spark ppt功能
|
SPARK_PPT_USAGE bool `json:"spark_ppt_usage"` // 是否使用spark ppt功能
|
||||||
|
KBASE_SERVER []KBaseServer `json:"kbase_server"` // 知识库服务器列表
|
||||||
|
}
|
||||||
|
|
||||||
|
type KBaseServer struct {
|
||||||
|
ServerID string `json:"server_id"` // 服务器ID
|
||||||
}
|
}
|
||||||
|
|
||||||
// 读取配置文件
|
// 读取配置文件
|
||||||
|
|
|
||||||
|
|
@ -27,6 +27,8 @@ type WSMessageReq struct {
|
||||||
ModelID uint `json:"model_id"` //模型id
|
ModelID uint `json:"model_id"` //模型id
|
||||||
IsImage bool `json:"is_image"` //是否为图片(图片消息),如果是图片消息,则msg为图片消息结构为{"img_url":"","text":""}
|
IsImage bool `json:"is_image"` //是否为图片(图片消息),如果是图片消息,则msg为图片消息结构为{"img_url":"","text":""}
|
||||||
IsFile bool `json:"is_file"` //是否为文件(文件消息),如果是文件消息,则msg为文件消息结构为{"file_url":"","text":""}
|
IsFile bool `json:"is_file"` //是否为文件(文件消息),如果是文件消息,则msg为文件消息结构为{"file_url":"","text":""}
|
||||||
|
IsKBase bool `json:"is_kbase"` //是否为知识库消息
|
||||||
|
KBaseID uint `json:"kbase_id"` //知识库id,为知识库消息时必传
|
||||||
}
|
}
|
||||||
|
|
||||||
type ModelParam struct {
|
type ModelParam struct {
|
||||||
|
|
|
||||||
|
|
@ -7,3 +7,46 @@ type KnowledgeBaseReq struct {
|
||||||
Description string `json:"description" form:"description"` // 知识库描述
|
Description string `json:"description" form:"description"` // 知识库描述
|
||||||
FileIDS string `json:"file_ids" form:"file_ids"` // 文件ID列表,逗号分隔
|
FileIDS string `json:"file_ids" form:"file_ids"` // 文件ID列表,逗号分隔
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type KnowledgeBaseServerResponse struct {
|
||||||
|
QuerySelect KnowledgeBaseServerResponseSelect `json:"query_select"` //查询结果
|
||||||
|
KnowledgeBase KBase `json:"knowledge_base"` //知识库信息
|
||||||
|
IMContext IMParamContext `json:"im_context"` //消息上下文
|
||||||
|
}
|
||||||
|
|
||||||
|
// KBase 定义与 TypeScript 接口对应的 Go 结构体
|
||||||
|
type KBase struct {
|
||||||
|
ID uint `json:"ID"`
|
||||||
|
Name string `json:"Name"`
|
||||||
|
Description string `json:"Description"`
|
||||||
|
UUID string `json:"UUID"`
|
||||||
|
SessionID uint `json:"SessionID"`
|
||||||
|
CreatedAt string `json:"CreatedAt"`
|
||||||
|
UpdatedAt string `json:"UpdatedAt"`
|
||||||
|
DeletedAt string `json:"DeletedAt"`
|
||||||
|
AuthID int `json:"AuthID"`
|
||||||
|
FileNameList string `json:"FileNameList"`
|
||||||
|
FileIDs string `json:"FileIDs"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// 知识库向量查询返回结果Response 定义 JSON 数据对应的结构体
|
||||||
|
type KnowledgeBaseServerResponseSelect struct {
|
||||||
|
IDs [][]string `json:"ids"`
|
||||||
|
Embeddings interface{} `json:"embeddings"`
|
||||||
|
Documents [][]string `json:"documents"`
|
||||||
|
URIs interface{} `json:"uris"`
|
||||||
|
Included []string `json:"included"`
|
||||||
|
Data interface{} `json:"data"`
|
||||||
|
Metadatas [][]Metadata `json:"metadatas"`
|
||||||
|
Distances [][]float64 `json:"distances"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Metadata 定义元数据结构体
|
||||||
|
type Metadata struct {
|
||||||
|
Source string `json:"source"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type KnowledgeBaseServerRequest struct {
|
||||||
|
IMContext IMParamContext `json:"im_context"` //消息上下文
|
||||||
|
KnowledgeBase KBase `json:"knowledge_base"` //知识库信息
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -133,6 +133,7 @@ const (
|
||||||
ModelTypeSpark = "spark"
|
ModelTypeSpark = "spark"
|
||||||
ModelTypeDouBao = "doubao"
|
ModelTypeDouBao = "doubao"
|
||||||
ModelTypeOllama = "ollama"
|
ModelTypeOllama = "ollama"
|
||||||
|
KnowledgeBaseServerResponseType = "kbase_query_resp"
|
||||||
)
|
)
|
||||||
|
|
||||||
// 其它
|
// 其它
|
||||||
|
|
|
||||||
|
|
@ -7,6 +7,7 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
"log"
|
"log"
|
||||||
"time"
|
"time"
|
||||||
|
|
@ -275,3 +276,66 @@ func CheckUserCreatePPTSessionPermission(userID int) error {
|
||||||
}
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 知识库消息,从原始请求数据,上下文数据
|
||||||
|
func SelectKBaseContentByQuery(ctx *proto.IMParamContext, data *proto.WSMessageReq) error {
|
||||||
|
var req proto.KnowledgeBaseServerRequest
|
||||||
|
req.IMContext = *ctx
|
||||||
|
kbase, err := dao.GetKnowledgeBaseByID(data.KBaseID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
req.KnowledgeBase = proto.KBase{ID: kbase.ID, UUID: kbase.UUID, Name: kbase.Name, Description: kbase.Description, FileIDs: kbase.FileIDs, SessionID: kbase.SessionID}
|
||||||
|
//选取合适的服务器进行查询
|
||||||
|
//先查看改知识库是否已经创建
|
||||||
|
key := fmt.Sprintf("kbase_%s", kbase.UUID)
|
||||||
|
KBaseServerID := worker.GetRedis(key)
|
||||||
|
//若未分配
|
||||||
|
if KBaseServerID == "" {
|
||||||
|
KBaseServerSet := "kbase_server_id_set"
|
||||||
|
//进行分配
|
||||||
|
KBaseServerID = worker.GetRedisSetRandomMember(KBaseServerSet)
|
||||||
|
if KBaseServerID == "" {
|
||||||
|
return errors.New("no kbase server id")
|
||||||
|
}
|
||||||
|
//记录分配的知识库服务器
|
||||||
|
worker.SetRedis(key, KBaseServerID)
|
||||||
|
}
|
||||||
|
chanel := fmt.Sprintf("kbase_%s", kbase.UUID)
|
||||||
|
pubsub := worker.RedisClient.Subscribe(context.Background(), chanel)
|
||||||
|
defer pubsub.Close()
|
||||||
|
ch := pubsub.Channel()
|
||||||
|
//1min定时器,超时则返回错误
|
||||||
|
timer := time.NewTimer(time.Minute)
|
||||||
|
var resp proto.KnowledgeBaseServerResponse
|
||||||
|
select {
|
||||||
|
case m := <-ch:
|
||||||
|
msg := m.Payload // 获取消息,消息格式为json
|
||||||
|
err2 := json.Unmarshal([]byte(msg), &resp)
|
||||||
|
if err2 != nil {
|
||||||
|
log.Println("Error parsing JSON:", err2)
|
||||||
|
}
|
||||||
|
break
|
||||||
|
case <-timer.C:
|
||||||
|
log.Println("kbase query timeout")
|
||||||
|
return errors.New("kbase query timeout")
|
||||||
|
}
|
||||||
|
//解析返回的查询消息
|
||||||
|
text := "根据下面文档片段:\n"
|
||||||
|
querySelect := resp.QuerySelect
|
||||||
|
for i := 0; i < len(querySelect.Documents); i++ {
|
||||||
|
text += "文件:"
|
||||||
|
for j := 0; j < len(querySelect.Metadatas); j++ {
|
||||||
|
text += querySelect.Metadatas[i][j].Source + ","
|
||||||
|
}
|
||||||
|
text += ":\n内容:"
|
||||||
|
for j := 0; j < len(querySelect.Documents[i]); j++ {
|
||||||
|
text += querySelect.Documents[i][j] + ","
|
||||||
|
}
|
||||||
|
text += "\n"
|
||||||
|
}
|
||||||
|
//用户问题
|
||||||
|
text += "回答下面用户问题:" + data.Msg
|
||||||
|
data.Msg = text
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -3,8 +3,10 @@ package service
|
||||||
import (
|
import (
|
||||||
"StuAcaWorksAI/dao"
|
"StuAcaWorksAI/dao"
|
||||||
"StuAcaWorksAI/proto"
|
"StuAcaWorksAI/proto"
|
||||||
|
"StuAcaWorksAI/worker"
|
||||||
"errors"
|
"errors"
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
func CreateKnowledgeBase(userID int, req proto.KnowledgeBaseReq) (dao.KnowledgeBase, error) {
|
func CreateKnowledgeBase(userID int, req proto.KnowledgeBaseReq) (dao.KnowledgeBase, error) {
|
||||||
|
|
@ -68,3 +70,13 @@ func DeleteKnowledgeBase(userID int, req proto.KnowledgeBaseReq) error {
|
||||||
|
|
||||||
return dao.DeleteKnowledgeBase(knowledgeBase.ID)
|
return dao.DeleteKnowledgeBase(knowledgeBase.ID)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 将配置文件中支持的知识库处理服务器添加到集合
|
||||||
|
func AddKnowledgeBaseServer() {
|
||||||
|
set := "kbase_server_id_set"
|
||||||
|
var serverIDs []string
|
||||||
|
for _, server := range proto.Config.KBASE_SERVER {
|
||||||
|
serverIDs = append(serverIDs, server.ServerID)
|
||||||
|
}
|
||||||
|
worker.SetRedisSet(set, serverIDs, time.Hour)
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -542,3 +542,14 @@ func getKeyTypeAndData(key string) (string, string, error) {
|
||||||
}
|
}
|
||||||
return key_type, fmt.Sprintf("%v", val), err
|
return key_type, fmt.Sprintf("%v", val), err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 随机获取集合中的一个元素
|
||||||
|
func GetRedisSetRandomMember(key string) string {
|
||||||
|
ctx := context.Background()
|
||||||
|
val, err := RedisClient.SRandMember(ctx, key).Result()
|
||||||
|
if err != nil {
|
||||||
|
fmt.Println("Error getting key: %v", err)
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return val
|
||||||
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue