2025-03-21 14:44:40 +08:00
|
|
|
|
package service
|
|
|
|
|
|
|
|
|
|
|
|
import (
|
|
|
|
|
|
"StuAcaWorksAI/dao"
|
2025-03-22 15:39:13 +08:00
|
|
|
|
"StuAcaWorksAI/proto"
|
2025-05-12 12:11:44 +08:00
|
|
|
|
"StuAcaWorksAI/service/generalModel"
|
2025-03-22 15:39:13 +08:00
|
|
|
|
"StuAcaWorksAI/worker"
|
|
|
|
|
|
"context"
|
|
|
|
|
|
"encoding/json"
|
2025-03-21 14:44:40 +08:00
|
|
|
|
"errors"
|
2025-04-09 14:53:16 +08:00
|
|
|
|
"fmt"
|
2025-03-22 15:39:13 +08:00
|
|
|
|
"github.com/ollama/ollama/api"
|
|
|
|
|
|
"log"
|
|
|
|
|
|
"time"
|
2025-03-21 14:44:40 +08:00
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
// 创建会话服务
|
2025-04-03 16:13:22 +08:00
|
|
|
|
func CreateSession(userID, Type int, name string) (error, uint) {
|
|
|
|
|
|
err, id := dao.CreateSession(userID, Type, name)
|
2025-03-21 14:44:40 +08:00
|
|
|
|
return err, id
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2025-05-12 12:57:40 +08:00
|
|
|
|
// 将用户的提问总计为会话名称并更新会话
|
|
|
|
|
|
func UpdateSessionNameByAIAndID(id, userID int, originSessionName string) error {
|
|
|
|
|
|
var err error
|
|
|
|
|
|
//获取ai总结的会话名称
|
|
|
|
|
|
prompt := "请根据下面的对话内容,提取为20字以内的主题作为会话名称:\n" + originSessionName
|
|
|
|
|
|
//ai总结模型id
|
2025-05-12 12:59:57 +08:00
|
|
|
|
modelID := proto.Config.AISessionNameModelID //固定为混元lite模型,免费
|
2025-05-12 12:57:40 +08:00
|
|
|
|
models := FindModelByID(modelID, 1)
|
|
|
|
|
|
if models == nil {
|
|
|
|
|
|
log.Println("UpdateSessionNameByAIAndID model not exist")
|
|
|
|
|
|
return errors.New("session name ai model not exist")
|
|
|
|
|
|
}
|
|
|
|
|
|
model := models[0]
|
|
|
|
|
|
var modelParam proto.ModelParam
|
|
|
|
|
|
err = json.Unmarshal([]byte(model.Parameter), &modelParam)
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
log.Println("UpdateSessionNameByAIAndID model param error:", err)
|
|
|
|
|
|
return errors.New("session name ai model param error")
|
|
|
|
|
|
}
|
|
|
|
|
|
newSessionName, err2 := generalModel.SingleTurnConversationText(prompt, modelParam)
|
|
|
|
|
|
if err2 != nil {
|
|
|
|
|
|
return err
|
|
|
|
|
|
}
|
|
|
|
|
|
//更新会话名称
|
|
|
|
|
|
err = UpdateSessionByID(id, userID, newSessionName, nil)
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
log.Println("UpdateSessionNameByAIAndAI update session name error:", err, "session id:", id, "user id:", userID, "name:", newSessionName)
|
|
|
|
|
|
return err
|
|
|
|
|
|
}
|
|
|
|
|
|
return nil
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2025-03-21 14:44:40 +08:00
|
|
|
|
// 查找用户的会话列表
|
2025-04-03 19:16:14 +08:00
|
|
|
|
func FindSessionByUserID(userID, sessionType int) []proto.SessionResponse {
|
2025-04-06 16:16:57 +08:00
|
|
|
|
var sessions []dao.Session
|
|
|
|
|
|
if sessionType == -1 {
|
|
|
|
|
|
sessions = dao.FindSessionByUserIDV2(userID)
|
|
|
|
|
|
} else {
|
|
|
|
|
|
sessions = dao.FindSessionByUserID(userID, sessionType)
|
|
|
|
|
|
}
|
2025-03-25 13:26:51 +08:00
|
|
|
|
sessionsMsgCounts := dao.FindSessionMessageCountByUserID(userID)
|
2025-03-30 14:33:52 +08:00
|
|
|
|
sessionsTokens := dao.FindModelTotalTokensBySessionID(userID)
|
2025-03-25 13:26:51 +08:00
|
|
|
|
smap := make(map[int]int)
|
2025-03-30 14:33:52 +08:00
|
|
|
|
st := make(map[int]uint)
|
2025-03-25 13:26:51 +08:00
|
|
|
|
for _, v := range sessionsMsgCounts {
|
|
|
|
|
|
smap[v.SessionID] = v.Count
|
|
|
|
|
|
}
|
2025-03-30 14:33:52 +08:00
|
|
|
|
for _, v := range sessionsTokens {
|
|
|
|
|
|
st[int(v.SessionID)] = v.Token
|
|
|
|
|
|
}
|
2025-03-25 13:26:51 +08:00
|
|
|
|
var res []proto.SessionResponse
|
|
|
|
|
|
for _, v := range sessions {
|
|
|
|
|
|
var session proto.SessionResponse
|
|
|
|
|
|
session.ID = v.ID
|
2025-03-25 13:30:57 +08:00
|
|
|
|
session.CreatedAt = v.CreatedAt
|
|
|
|
|
|
session.UpdatedAt = v.UpdatedAt
|
2025-03-25 13:26:51 +08:00
|
|
|
|
session.Name = v.Name
|
|
|
|
|
|
session.UserID = v.UserID
|
|
|
|
|
|
session.MsgCount = smap[int(v.ID)]
|
2025-03-30 14:33:52 +08:00
|
|
|
|
session.TokenUsage = st[int(v.ID)]
|
2025-03-25 13:26:51 +08:00
|
|
|
|
res = append(res, session)
|
|
|
|
|
|
}
|
|
|
|
|
|
return res
|
2025-03-21 14:44:40 +08:00
|
|
|
|
}
|
2025-04-03 19:16:14 +08:00
|
|
|
|
func FindSessionByID(id, userID, sessionType int) (error, []dao.Session) {
|
2025-03-21 16:20:24 +08:00
|
|
|
|
session := dao.FindSessionByID(id)
|
|
|
|
|
|
var res []dao.Session
|
|
|
|
|
|
if session.UserID != userID {
|
|
|
|
|
|
return errors.New("session not belong to user"), res
|
|
|
|
|
|
}
|
|
|
|
|
|
res = append(res, session)
|
|
|
|
|
|
return nil, res
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
func DeleteSessionByID(id, userID int) error {
|
|
|
|
|
|
session := dao.FindSessionByID(id)
|
|
|
|
|
|
if session.UserID != userID {
|
|
|
|
|
|
return errors.New("session not belong to user")
|
|
|
|
|
|
}
|
|
|
|
|
|
err := dao.DeleteSessionByID(id)
|
|
|
|
|
|
return err
|
|
|
|
|
|
}
|
2025-03-21 14:44:40 +08:00
|
|
|
|
|
|
|
|
|
|
// 更新会话名字
|
2025-03-23 13:17:03 +08:00
|
|
|
|
func UpdateSessionByID(id int, userId int, name string, context []int) error {
|
|
|
|
|
|
err := dao.UpdateSessionByID(id, userId, name, context)
|
2025-03-21 14:44:40 +08:00
|
|
|
|
return err
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// 创建消息服务
|
2025-03-28 11:23:45 +08:00
|
|
|
|
func CreateMessage(msgType int, sessionID int, fromID int, toID int, msg string, status int, functionId int) (error, uint) {
|
2025-03-21 14:44:40 +08:00
|
|
|
|
//sessionID, fromID, toID, msgType, status int, msg string
|
2025-03-28 11:23:45 +08:00
|
|
|
|
err, id := dao.CreateMessage(sessionID, fromID, toID, msgType, status, msg, functionId)
|
2025-03-21 14:44:40 +08:00
|
|
|
|
return err, id
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// 通过会话id查询消息
|
|
|
|
|
|
func FindMessageBySessionID(sessionID, userID int) (error, []dao.Message) {
|
|
|
|
|
|
//查看会话是否属于该用户
|
|
|
|
|
|
session := dao.FindSessionByID(sessionID)
|
|
|
|
|
|
if session.UserID != userID {
|
|
|
|
|
|
return errors.New("session not belong to user"), nil
|
|
|
|
|
|
}
|
2025-03-26 17:34:45 +08:00
|
|
|
|
msgs := dao.FindMessageBySessionIDV2(sessionID)
|
2025-03-21 14:44:40 +08:00
|
|
|
|
return nil, msgs
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
func CreateGeneralMessageService(fromID, toID, msgType, sessionID int, msg string) (error, uint) {
|
|
|
|
|
|
//查看会话是否属于该用户
|
|
|
|
|
|
session := dao.FindSessionByID(sessionID)
|
|
|
|
|
|
if session.UserID != fromID {
|
|
|
|
|
|
return errors.New("session not belong to user"), 0
|
|
|
|
|
|
}
|
2025-03-28 11:23:45 +08:00
|
|
|
|
err, mid := CreateMessage(msgType, sessionID, fromID, toID, msg, 0, 0)
|
2025-03-21 14:44:40 +08:00
|
|
|
|
return err, mid
|
|
|
|
|
|
}
|
2025-03-22 15:39:13 +08:00
|
|
|
|
|
|
|
|
|
|
var client *api.Client
|
2025-03-23 13:17:03 +08:00
|
|
|
|
var SessionContextMap map[int][]int
|
2025-03-22 15:39:13 +08:00
|
|
|
|
|
2025-03-24 14:47:53 +08:00
|
|
|
|
func WSReceiveMessageService(userID, sessionID int, channel string, msg proto.WSMessageReq, model dao.Model, modelParam proto.ModelParam) (error, uint) {
|
2025-03-22 15:39:13 +08:00
|
|
|
|
var resErr error
|
|
|
|
|
|
var resID uint
|
|
|
|
|
|
//处理消息
|
2025-03-26 20:28:22 +08:00
|
|
|
|
if model.Type == proto.ModelTypeOllama {
|
2025-03-25 14:34:50 +08:00
|
|
|
|
resErr, resID = ReceiveOllamaSession(userID, sessionID, channel, msg, modelParam, &model)
|
2025-03-26 20:28:22 +08:00
|
|
|
|
} else if model.Type == proto.ModelTypeSpark {
|
2025-03-25 15:51:05 +08:00
|
|
|
|
resErr, resID = ReceiveSparkSession(userID, sessionID, channel, msg, modelParam, &model)
|
2025-03-26 20:28:22 +08:00
|
|
|
|
} else if model.Type == proto.ModelTypeDouBao {
|
2025-03-25 15:48:56 +08:00
|
|
|
|
doubao(modelParam, msg.Msg, channel, sessionID, userID, int(model.ID))
|
2025-03-24 14:47:53 +08:00
|
|
|
|
} else {
|
|
|
|
|
|
return errors.New("model not exist"), 0
|
2025-03-23 16:03:17 +08:00
|
|
|
|
}
|
|
|
|
|
|
return resErr, resID
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2025-03-28 11:23:45 +08:00
|
|
|
|
func WSReceiveMessageServiceV2(modelParam proto.ModelParam, ctx *proto.IMParamContext) (error, uint) {
|
|
|
|
|
|
var resErr error
|
|
|
|
|
|
var resID uint
|
|
|
|
|
|
//处理消息
|
2025-05-10 14:05:22 +08:00
|
|
|
|
switch ctx.ModelType {
|
|
|
|
|
|
case proto.ModelTypeOllama:
|
2025-03-28 11:23:45 +08:00
|
|
|
|
resErr, resID = ReceiveOllamaSessionV2(modelParam, ctx)
|
2025-05-10 14:05:22 +08:00
|
|
|
|
case proto.ModelTypeSpark:
|
2025-03-28 11:23:45 +08:00
|
|
|
|
SparkV2(modelParam, ctx)
|
2025-05-10 14:05:22 +08:00
|
|
|
|
case proto.ModelTypeDouBao:
|
2025-03-28 11:23:45 +08:00
|
|
|
|
DouBaoV2(modelParam, ctx)
|
2025-05-10 14:05:22 +08:00
|
|
|
|
case proto.ModelTypeQianfan:
|
|
|
|
|
|
Qianfan(modelParam, ctx)
|
2025-05-11 20:31:39 +08:00
|
|
|
|
case proto.ModelTypeTongyi:
|
|
|
|
|
|
Tongyi(modelParam, ctx)
|
2025-05-12 12:11:44 +08:00
|
|
|
|
case proto.ModelTypeHuanyun:
|
|
|
|
|
|
generalModel.OpenAI(modelParam, ctx, ctx.ModelType)
|
2025-05-10 14:05:22 +08:00
|
|
|
|
default:
|
2025-05-10 15:37:53 +08:00
|
|
|
|
resErr, resID = errors.New("model type not exist:"+ctx.ModelType), 0
|
2025-03-28 11:23:45 +08:00
|
|
|
|
}
|
|
|
|
|
|
return resErr, resID
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2025-03-25 14:34:50 +08:00
|
|
|
|
func ReceiveOllamaSession(userID, sessionID int, channel string, msg proto.WSMessageReq, modelParam proto.ModelParam, m *dao.Model) (error, uint) {
|
2025-03-23 16:03:17 +08:00
|
|
|
|
var resErr error
|
|
|
|
|
|
var resID uint
|
|
|
|
|
|
var err error
|
|
|
|
|
|
if client == nil {
|
|
|
|
|
|
client, err = api.ClientFromEnvironment()
|
|
|
|
|
|
}
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
log.Println("get ollama client error:", err)
|
|
|
|
|
|
}
|
|
|
|
|
|
prompt := msg.Msg
|
|
|
|
|
|
actx, ok := SessionContextMap[sessionID]
|
|
|
|
|
|
if ok {
|
|
|
|
|
|
log.Println("actx has get will use:", len(actx))
|
|
|
|
|
|
}
|
|
|
|
|
|
req := &api.GenerateRequest{
|
2025-03-24 21:45:16 +08:00
|
|
|
|
Model: modelParam.Model,
|
2025-03-23 16:03:17 +08:00
|
|
|
|
Context: actx,
|
|
|
|
|
|
Prompt: prompt,
|
|
|
|
|
|
}
|
|
|
|
|
|
ctx := context.Background()
|
|
|
|
|
|
robotMsg := ""
|
|
|
|
|
|
var aiMsg proto.AIQueueMessage
|
|
|
|
|
|
var data proto.WSMessage
|
|
|
|
|
|
respFunc := func(resp api.GenerateResponse) error {
|
2025-03-26 20:28:22 +08:00
|
|
|
|
aiMsg.Type = proto.ModelTypeOllama
|
2025-03-23 16:03:17 +08:00
|
|
|
|
aiMsg.Msg = resp
|
|
|
|
|
|
robotMsg += resp.Response
|
|
|
|
|
|
if resp.Done { //该消息完成
|
|
|
|
|
|
actx = resp.Context
|
|
|
|
|
|
if SessionContextMap == nil {
|
|
|
|
|
|
SessionContextMap = make(map[int][]int)
|
2025-03-22 15:39:13 +08:00
|
|
|
|
}
|
2025-03-23 16:03:17 +08:00
|
|
|
|
SessionContextMap[sessionID] = actx
|
|
|
|
|
|
log.Println("actx has save:", len(SessionContextMap[sessionID]))
|
|
|
|
|
|
//创建消息
|
2025-03-28 11:23:45 +08:00
|
|
|
|
err2, msgID := CreateMessage(2, sessionID, 2, userID, robotMsg, 1, 0) //机器人id为2,消息类型为2,状态为1
|
2025-03-23 16:03:17 +08:00
|
|
|
|
if err2 != nil {
|
|
|
|
|
|
log.Println("create robot message error:", err2)
|
|
|
|
|
|
}
|
|
|
|
|
|
aiMsg.ID = msgID
|
|
|
|
|
|
resID = msgID
|
2025-03-22 15:39:13 +08:00
|
|
|
|
}
|
2025-03-23 16:03:17 +08:00
|
|
|
|
//发送消息
|
|
|
|
|
|
data.Msg = aiMsg
|
|
|
|
|
|
data.SessionID = sessionID
|
|
|
|
|
|
data.ToID = userID
|
|
|
|
|
|
aiMsgStr, _ := json.Marshal(data)
|
|
|
|
|
|
worker.Publish(channel, string(aiMsgStr), time.Second*60)
|
2025-03-25 14:34:50 +08:00
|
|
|
|
go CreateAIStreamMsg(userID, int(m.ID), sessionID, resp, modelParam.Model)
|
2025-03-23 16:03:17 +08:00
|
|
|
|
return nil
|
|
|
|
|
|
}
|
|
|
|
|
|
err = client.Generate(ctx, req, respFunc)
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
log.Fatal(err)
|
2025-03-22 15:39:13 +08:00
|
|
|
|
}
|
|
|
|
|
|
return resErr, resID
|
|
|
|
|
|
}
|
2025-03-23 16:03:17 +08:00
|
|
|
|
|
2025-03-28 11:23:45 +08:00
|
|
|
|
func ReceiveOllamaSessionV2(modelParam proto.ModelParam, imCtx *proto.IMParamContext) (error, uint) {
|
|
|
|
|
|
var resErr error
|
|
|
|
|
|
var resID uint
|
|
|
|
|
|
var err error
|
|
|
|
|
|
if client == nil {
|
|
|
|
|
|
client, err = api.ClientFromEnvironment()
|
|
|
|
|
|
}
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
log.Println("get ollama client error:", err)
|
|
|
|
|
|
}
|
|
|
|
|
|
prompt := imCtx.Question
|
|
|
|
|
|
actx, ok := SessionContextMap[int(imCtx.SessionID)]
|
|
|
|
|
|
if ok {
|
|
|
|
|
|
log.Println("actx has get will use:", len(actx))
|
|
|
|
|
|
}
|
|
|
|
|
|
req := &api.GenerateRequest{
|
|
|
|
|
|
Model: modelParam.Model,
|
|
|
|
|
|
Context: actx,
|
|
|
|
|
|
Prompt: prompt,
|
|
|
|
|
|
}
|
|
|
|
|
|
ctx := context.Background()
|
|
|
|
|
|
robotMsg := ""
|
|
|
|
|
|
var aiMsg proto.AIQueueMessage
|
|
|
|
|
|
var data proto.WSMessage
|
|
|
|
|
|
respFunc := func(resp api.GenerateResponse) error {
|
|
|
|
|
|
aiMsg.Type = proto.ModelTypeOllama
|
|
|
|
|
|
aiMsg.Msg = resp
|
|
|
|
|
|
robotMsg += resp.Response
|
|
|
|
|
|
if resp.Done { //该消息完成
|
|
|
|
|
|
actx = resp.Context
|
|
|
|
|
|
if SessionContextMap == nil {
|
|
|
|
|
|
SessionContextMap = make(map[int][]int)
|
|
|
|
|
|
}
|
|
|
|
|
|
SessionContextMap[imCtx.SessionID] = actx
|
|
|
|
|
|
log.Println("actx has save:", len(SessionContextMap[imCtx.SessionID]))
|
|
|
|
|
|
//创建消息
|
2025-04-10 15:27:56 +08:00
|
|
|
|
err2, msgID := CreateMessage(proto.ModelToUserMsgType, imCtx.SessionID, imCtx.ModelID, imCtx.UserID, robotMsg, 1, imCtx.FunctionID) //机器人id为2,消息类型为2,状态为1
|
2025-03-28 11:23:45 +08:00
|
|
|
|
if err2 != nil {
|
|
|
|
|
|
log.Println("create robot message error:", err2)
|
|
|
|
|
|
}
|
|
|
|
|
|
aiMsg.ID = msgID
|
|
|
|
|
|
resID = msgID
|
|
|
|
|
|
}
|
|
|
|
|
|
//发送消息
|
|
|
|
|
|
data.Msg = aiMsg
|
|
|
|
|
|
data.SessionID = int(imCtx.SessionID)
|
|
|
|
|
|
data.ToID = imCtx.UserID
|
|
|
|
|
|
aiMsgStr, _ := json.Marshal(data)
|
|
|
|
|
|
worker.Publish(imCtx.Channel, string(aiMsgStr), time.Second*60)
|
|
|
|
|
|
go CreateAIStreamMsg(imCtx.UserID, imCtx.ModelID, imCtx.SessionID, resp, modelParam.Model)
|
|
|
|
|
|
return nil
|
|
|
|
|
|
}
|
|
|
|
|
|
err = client.Generate(ctx, req, respFunc)
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
log.Fatal(err)
|
|
|
|
|
|
}
|
|
|
|
|
|
return resErr, resID
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2025-03-25 15:51:05 +08:00
|
|
|
|
func ReceiveSparkSession(userID, sessionID int, channel string, msg proto.WSMessageReq, modelParam proto.ModelParam, model *dao.Model) (error, uint) {
|
2025-03-24 13:36:42 +08:00
|
|
|
|
var resErr error
|
|
|
|
|
|
var resID uint
|
2025-03-25 15:51:05 +08:00
|
|
|
|
Spark(modelParam, msg.Msg, channel, sessionID, userID, int(model.ID))
|
2025-03-24 13:36:42 +08:00
|
|
|
|
return resErr, resID
|
2025-03-23 16:03:17 +08:00
|
|
|
|
}
|
2025-04-05 19:57:50 +08:00
|
|
|
|
|
|
|
|
|
|
func CheckUserCreatePPTSessionPermission(userID int) error {
|
|
|
|
|
|
sessionCount := dao.FindUserSessionCount(userID, proto.SessionTypeUserCreatePPT)
|
|
|
|
|
|
var err error
|
|
|
|
|
|
user := GetUserByIDWithCache(userID)
|
|
|
|
|
|
if user.Role == "admin" {
|
|
|
|
|
|
return nil
|
|
|
|
|
|
}
|
|
|
|
|
|
if sessionCount > 3 {
|
|
|
|
|
|
err = errors.New("create ppt session count exceed limit")
|
|
|
|
|
|
}
|
|
|
|
|
|
return err
|
|
|
|
|
|
}
|
2025-04-09 14:53:16 +08:00
|
|
|
|
|
|
|
|
|
|
// 知识库消息,从原始请求数据,上下文数据
|
2025-04-10 15:06:26 +08:00
|
|
|
|
func SelectKBaseContentByQuery(imCtx *proto.IMParamContext, data *proto.WSMessageReq) error {
|
2025-04-09 14:53:16 +08:00
|
|
|
|
var req proto.KnowledgeBaseServerRequest
|
2025-04-10 15:06:26 +08:00
|
|
|
|
req.IMContext = *imCtx
|
2025-04-09 14:53:16 +08:00
|
|
|
|
kbase, err := dao.GetKnowledgeBaseByID(data.KBaseID)
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
return err
|
|
|
|
|
|
}
|
|
|
|
|
|
req.KnowledgeBase = proto.KBase{ID: kbase.ID, UUID: kbase.UUID, Name: kbase.Name, Description: kbase.Description, FileIDs: kbase.FileIDs, SessionID: kbase.SessionID}
|
|
|
|
|
|
//选取合适的服务器进行查询
|
|
|
|
|
|
//先查看改知识库是否已经创建
|
|
|
|
|
|
key := fmt.Sprintf("kbase_%s", kbase.UUID)
|
|
|
|
|
|
KBaseServerID := worker.GetRedis(key)
|
2025-04-11 14:24:01 +08:00
|
|
|
|
//若未分配,从可用的知识库服务器中进行分配
|
|
|
|
|
|
setKey := "kbase_server_id_set_is_available" //可用状态集合
|
2025-04-09 14:53:16 +08:00
|
|
|
|
if KBaseServerID == "" {
|
2025-04-11 14:24:01 +08:00
|
|
|
|
//KBaseServerSet := "kbase_server_id_set"
|
2025-04-09 14:53:16 +08:00
|
|
|
|
//进行分配
|
2025-04-11 14:24:01 +08:00
|
|
|
|
KBaseServerID = worker.GetRedisSetRandomMember(setKey)
|
2025-04-09 14:53:16 +08:00
|
|
|
|
if KBaseServerID == "" {
|
|
|
|
|
|
return errors.New("no kbase server id")
|
|
|
|
|
|
}
|
|
|
|
|
|
//记录分配的知识库服务器
|
2025-04-09 17:37:24 +08:00
|
|
|
|
res := worker.SetRedisWithExpire(key, KBaseServerID, time.Hour*24*7)
|
|
|
|
|
|
log.Println("分配的知识库服务器ID:", KBaseServerID, "\t key:", key, "\tset:", res)
|
2025-04-09 14:53:16 +08:00
|
|
|
|
}
|
2025-04-10 15:06:26 +08:00
|
|
|
|
//查看KBaseServerID是否可用
|
|
|
|
|
|
isAvailable := worker.SetRedisSetIsMember(setKey, KBaseServerID)
|
|
|
|
|
|
if isAvailable == false {
|
|
|
|
|
|
log.Println("当前知识库服务器不可用,重新分配")
|
|
|
|
|
|
return errors.New("kbase server id is not available")
|
|
|
|
|
|
}
|
2025-04-09 17:37:24 +08:00
|
|
|
|
log.Println("KBaseServerID:", KBaseServerID, " KBaseUUID:", kbase.UUID, "\tName:", kbase.Name)
|
|
|
|
|
|
reqChanel := fmt.Sprintf("kbase_server_id_%s", KBaseServerID) // 频道名称
|
|
|
|
|
|
reqStr, _ := json.Marshal(req)
|
|
|
|
|
|
//将消息转发到对应频道
|
2025-04-11 14:24:01 +08:00
|
|
|
|
worker.Publish(reqChanel, string(reqStr), time.Minute*3)
|
2025-04-09 17:37:24 +08:00
|
|
|
|
|
|
|
|
|
|
chanel := fmt.Sprintf("kbase_%s", kbase.UUID) //获取返回消息频道
|
2025-04-09 14:53:16 +08:00
|
|
|
|
pubsub := worker.RedisClient.Subscribe(context.Background(), chanel)
|
|
|
|
|
|
defer pubsub.Close()
|
|
|
|
|
|
ch := pubsub.Channel()
|
|
|
|
|
|
//1min定时器,超时则返回错误
|
2025-04-11 14:24:01 +08:00
|
|
|
|
timer := time.NewTimer(time.Minute * 3)
|
2025-04-09 14:53:16 +08:00
|
|
|
|
var resp proto.KnowledgeBaseServerResponse
|
|
|
|
|
|
select {
|
|
|
|
|
|
case m := <-ch:
|
|
|
|
|
|
msg := m.Payload // 获取消息,消息格式为json
|
2025-04-10 15:06:26 +08:00
|
|
|
|
//log.Println("resp:", msg)
|
2025-04-09 14:53:16 +08:00
|
|
|
|
err2 := json.Unmarshal([]byte(msg), &resp)
|
|
|
|
|
|
if err2 != nil {
|
|
|
|
|
|
log.Println("Error parsing JSON:", err2)
|
|
|
|
|
|
}
|
|
|
|
|
|
break
|
|
|
|
|
|
case <-timer.C:
|
|
|
|
|
|
log.Println("kbase query timeout")
|
|
|
|
|
|
return errors.New("kbase query timeout")
|
|
|
|
|
|
}
|
|
|
|
|
|
//解析返回的查询消息
|
|
|
|
|
|
text := "根据下面文档片段:\n"
|
|
|
|
|
|
querySelect := resp.QuerySelect
|
2025-04-10 15:06:26 +08:00
|
|
|
|
qText := getSelectedDocument(querySelect)
|
|
|
|
|
|
log.Println("kbase type:", data.KBaseType)
|
|
|
|
|
|
if data.KBaseType == "query" {
|
|
|
|
|
|
//则将query为模型返回结果
|
|
|
|
|
|
//将消息存入数据库
|
|
|
|
|
|
err, _ = dao.CreateMessage(imCtx.SessionID, imCtx.ModelID, imCtx.UserID, proto.ModelToUserMsgType, proto.MsgHasRead, qText, imCtx.FunctionID)
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
log.Println("kbase query message save error:", err)
|
|
|
|
|
|
}
|
|
|
|
|
|
kbaseQueryToGeneralMassageAndSendMsgQueue(qText, imCtx.Channel, imCtx.SessionID, imCtx.UserID) //将消息发送到消息队列
|
|
|
|
|
|
} else if data.KBaseType == "chat" {
|
|
|
|
|
|
//用户问题
|
|
|
|
|
|
text += qText
|
|
|
|
|
|
text += "回答用户问题:" + data.Msg
|
|
|
|
|
|
data.Msg = text
|
|
|
|
|
|
log.Println("text:", text)
|
|
|
|
|
|
} else {
|
|
|
|
|
|
return errors.New("unknow kbase type")
|
|
|
|
|
|
}
|
|
|
|
|
|
return nil
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
func getSelectedDocument(querySelect proto.KnowledgeBaseServerResponseSelect) string {
|
|
|
|
|
|
text := ""
|
2025-04-09 14:53:16 +08:00
|
|
|
|
for i := 0; i < len(querySelect.Documents); i++ {
|
2025-04-10 15:06:26 +08:00
|
|
|
|
if len(querySelect.Documents[i]) == 0 {
|
|
|
|
|
|
break
|
|
|
|
|
|
}
|
2025-04-09 14:53:16 +08:00
|
|
|
|
text += "文件:"
|
|
|
|
|
|
for j := 0; j < len(querySelect.Metadatas); j++ {
|
|
|
|
|
|
text += querySelect.Metadatas[i][j].Source + ","
|
|
|
|
|
|
}
|
|
|
|
|
|
text += ":\n内容:"
|
|
|
|
|
|
for j := 0; j < len(querySelect.Documents[i]); j++ {
|
|
|
|
|
|
text += querySelect.Documents[i][j] + ","
|
|
|
|
|
|
}
|
|
|
|
|
|
text += "\n"
|
|
|
|
|
|
}
|
2025-04-10 15:06:26 +08:00
|
|
|
|
return text
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
func WSReturnErrorMessage(channel string, msg string, code int) {
|
|
|
|
|
|
var data proto.WSMessage
|
|
|
|
|
|
data.Code = code
|
|
|
|
|
|
data.Type = msg //错误时type为错误信息
|
|
|
|
|
|
data.SessionID = 0
|
|
|
|
|
|
data.ToID = 0
|
|
|
|
|
|
errStr, _ := json.Marshal(data)
|
|
|
|
|
|
worker.Publish(channel, string(errStr), time.Second*60)
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
func kbaseQueryToGeneralMassageAndSendMsgQueue(answer, channel string, SessionID, UserID int) {
|
|
|
|
|
|
var aiMsg proto.AIQueueMessage
|
|
|
|
|
|
var wsMsg proto.WSMessage
|
|
|
|
|
|
var apiMsg api.GenerateResponse
|
|
|
|
|
|
|
|
|
|
|
|
apiMsg.Done = true
|
|
|
|
|
|
|
|
|
|
|
|
apiMsg.Model = proto.KnowledgeBaseServerResponseType
|
|
|
|
|
|
apiMsg.Response = answer
|
|
|
|
|
|
//消息队列部分
|
|
|
|
|
|
aiMsg.Type = proto.KnowledgeBaseServerResponseType
|
|
|
|
|
|
aiMsg.Msg = apiMsg
|
|
|
|
|
|
//ws发送消息部分
|
|
|
|
|
|
wsMsg.Msg = aiMsg
|
|
|
|
|
|
wsMsg.SessionID = SessionID
|
|
|
|
|
|
wsMsg.ToID = UserID
|
|
|
|
|
|
wsMsg.Type = proto.KnowledgeBaseServerResponseType
|
|
|
|
|
|
|
|
|
|
|
|
//发送消息
|
|
|
|
|
|
wsMsgStr, _ := json.Marshal(wsMsg)
|
|
|
|
|
|
worker.Publish(channel, string(wsMsgStr), time.Second*60)
|
2025-04-09 14:53:16 +08:00
|
|
|
|
}
|