2025-03-21 14:44:40 +08:00
|
|
|
|
package handler
|
|
|
|
|
|
|
|
|
|
|
|
import (
|
2025-03-24 14:47:53 +08:00
|
|
|
|
"StuAcaWorksAI/dao"
|
2025-03-21 14:44:40 +08:00
|
|
|
|
"StuAcaWorksAI/proto"
|
|
|
|
|
|
"StuAcaWorksAI/service"
|
2025-05-13 11:02:09 +08:00
|
|
|
|
"StuAcaWorksAI/service/generalModel"
|
2025-03-22 15:39:13 +08:00
|
|
|
|
"StuAcaWorksAI/worker"
|
|
|
|
|
|
"context"
|
|
|
|
|
|
"encoding/json"
|
2025-03-24 14:47:53 +08:00
|
|
|
|
"errors"
|
2025-04-09 14:53:16 +08:00
|
|
|
|
"fmt"
|
2025-03-21 14:44:40 +08:00
|
|
|
|
"github.com/gin-gonic/gin"
|
2025-03-28 11:23:45 +08:00
|
|
|
|
"github.com/google/uuid"
|
2025-03-22 15:39:13 +08:00
|
|
|
|
"github.com/gorilla/websocket"
|
|
|
|
|
|
"log"
|
2025-03-21 14:44:40 +08:00
|
|
|
|
"net/http"
|
2025-03-22 15:39:13 +08:00
|
|
|
|
"strconv"
|
|
|
|
|
|
"sync"
|
2025-05-12 13:57:38 +08:00
|
|
|
|
"time"
|
2025-03-21 14:44:40 +08:00
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
// SMessage is the request body bound by the SendMessage and GetMessage
// handlers. Fields may be supplied as JSON or as form values (see the tags).
type SMessage struct {
	// Type is the message type forwarded to service.CreateGeneralMessageService.
	Type int `json:"type" form:"type"`
	// Msg is the message content.
	Msg string `json:"msg" form:"msg"`
	// ToID identifies the recipient; forwarded to service.CreateGeneralMessageService.
	ToID int `json:"to_id" form:"to_id"`
	// SessionID selects the session to write to (SendMessage) or read from (GetMessage).
	SessionID int `json:"session_id" form:"session_id"`
}
|
|
|
|
|
|
|
2025-03-22 15:39:13 +08:00
|
|
|
|
var (
|
|
|
|
|
|
upgrader = websocket.Upgrader{
|
|
|
|
|
|
ReadBufferSize: 1024,
|
|
|
|
|
|
WriteBufferSize: 1024,
|
|
|
|
|
|
CheckOrigin: func(r *http.Request) bool {
|
|
|
|
|
|
// 允许所有来源的连接
|
|
|
|
|
|
return true
|
|
|
|
|
|
},
|
|
|
|
|
|
}
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
// 创建一个用于存储WebSocket连接的map和互斥锁
|
|
|
|
|
|
var (
|
|
|
|
|
|
clients = make(map[*websocket.Conn]bool)
|
|
|
|
|
|
clientsMux sync.Mutex
|
|
|
|
|
|
)
|
|
|
|
|
|
|
2025-03-21 14:44:40 +08:00
|
|
|
|
func SetUpIMGroup(router *gin.Engine) {
|
|
|
|
|
|
imGroup := router.Group("/im")
|
|
|
|
|
|
imGroup.POST("/send_message", SendMessage)
|
2025-03-22 15:39:13 +08:00
|
|
|
|
imGroup.GET("/ai_chat_ws", SRMessage)
|
2025-03-22 17:50:41 +08:00
|
|
|
|
imGroup.POST("/get_message", GetMessage)
|
2025-04-09 14:53:16 +08:00
|
|
|
|
imGroup.GET("/kbase_text_ws", KBaseTextWS)
|
2025-05-13 11:02:09 +08:00
|
|
|
|
|
|
|
|
|
|
//添加国外ai处理接口,服务之间调用
|
|
|
|
|
|
imGroup.POST("/send_message_foreign_ai", SendMessageForeignAI)
|
|
|
|
|
|
imGroup.POST("/receive_message_foreign_ai", ReceiveMessageForeignAI)
|
2025-03-22 17:50:41 +08:00
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
func GetMessage(c *gin.Context) {
|
|
|
|
|
|
var req SMessage
|
|
|
|
|
|
id, _ := c.Get("id")
|
|
|
|
|
|
userID := int(id.(float64))
|
|
|
|
|
|
if err := c.ShouldBind(&req); err == nil {
|
|
|
|
|
|
err2, messages := service.FindMessageBySessionID(req.SessionID, userID)
|
|
|
|
|
|
if err2 == nil {
|
|
|
|
|
|
c.JSON(http.StatusOK, gin.H{"code": proto.SuccessCode, "message": "success", "data": messages})
|
|
|
|
|
|
} else {
|
|
|
|
|
|
c.JSON(http.StatusOK, gin.H{"error": err2.Error(), "code": proto.MsgSendFailed, "message": "failed"})
|
|
|
|
|
|
}
|
|
|
|
|
|
} else {
|
|
|
|
|
|
c.JSON(http.StatusOK, gin.H{"error": err.Error(), "code": proto.ParameterError, "message": "failed"})
|
|
|
|
|
|
}
|
2025-03-21 14:44:40 +08:00
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
func SendMessage(c *gin.Context) {
|
|
|
|
|
|
var req SMessage
|
|
|
|
|
|
id, _ := c.Get("id")
|
|
|
|
|
|
userID := int(id.(float64))
|
|
|
|
|
|
if err := c.ShouldBind(&req); err == nil {
|
|
|
|
|
|
var err2 error
|
|
|
|
|
|
var mid uint
|
|
|
|
|
|
err2, mid = service.CreateGeneralMessageService(userID, req.ToID, req.Type, req.SessionID, req.Msg)
|
|
|
|
|
|
if err2 == nil {
|
|
|
|
|
|
c.JSON(http.StatusOK, gin.H{"code": proto.SuccessCode, "message": "success", "data": mid})
|
|
|
|
|
|
} else {
|
|
|
|
|
|
c.JSON(http.StatusOK, gin.H{"error": err2.Error(), "code": proto.MsgSendFailed, "message": "failed"})
|
|
|
|
|
|
}
|
|
|
|
|
|
} else {
|
|
|
|
|
|
c.JSON(http.StatusOK, gin.H{"error": err.Error(), "code": proto.ParameterError, "message": "failed"})
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2025-03-22 15:39:13 +08:00
|
|
|
|
// ws长连接
|
|
|
|
|
|
func SRMessage(c *gin.Context) {
|
|
|
|
|
|
id, _ := c.Get("id")
|
|
|
|
|
|
userID := int(id.(float64))
|
|
|
|
|
|
// 升级HTTP连接为WebSocket连接
|
|
|
|
|
|
ws, err1 := upgrader.Upgrade(c.Writer, c.Request, nil)
|
|
|
|
|
|
clients[ws] = true
|
|
|
|
|
|
if err1 != nil {
|
|
|
|
|
|
log.Println("升级为WebSocket时发生错误:", err1)
|
|
|
|
|
|
return
|
|
|
|
|
|
}
|
|
|
|
|
|
subscribeAndHandleIMMessages(ws, userID)
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
func subscribeAndHandleIMMessages(ws *websocket.Conn, userId int) {
|
|
|
|
|
|
ctx := context.Background()
|
2025-03-28 11:23:45 +08:00
|
|
|
|
//随机字符串
|
|
|
|
|
|
strRandom := uuid.NewString()
|
|
|
|
|
|
chanel := "user_" + strconv.Itoa(userId) + "_ai_chat_msg_" + strRandom
|
2025-03-22 15:39:13 +08:00
|
|
|
|
pubsub := worker.RedisClient.Subscribe(ctx, chanel)
|
|
|
|
|
|
defer pubsub.Close()
|
|
|
|
|
|
defer ws.Close()
|
|
|
|
|
|
ch := pubsub.Channel()
|
|
|
|
|
|
//定时器,每3秒发送一次心跳
|
2025-03-22 17:38:44 +08:00
|
|
|
|
//ticker := time.NewTicker(3 * time.Second)
|
|
|
|
|
|
//defer ticker.Stop()
|
2025-03-22 15:39:13 +08:00
|
|
|
|
|
|
|
|
|
|
var sessionID uint
|
2025-03-25 14:55:33 +08:00
|
|
|
|
sessionID = 0
|
2025-03-22 15:39:13 +08:00
|
|
|
|
|
|
|
|
|
|
//接收客户端消息,发送消息
|
|
|
|
|
|
done := make(chan struct{})
|
|
|
|
|
|
go func() {
|
|
|
|
|
|
defer close(done)
|
|
|
|
|
|
for {
|
|
|
|
|
|
_, message, err := ws.ReadMessage()
|
2025-03-22 16:26:09 +08:00
|
|
|
|
log.Println("rec message:", string(message))
|
2025-03-22 15:39:13 +08:00
|
|
|
|
if err != nil {
|
|
|
|
|
|
log.Println("Read error:", err)
|
2025-03-22 17:00:43 +08:00
|
|
|
|
break
|
2025-03-22 15:39:13 +08:00
|
|
|
|
}
|
2025-03-22 17:11:42 +08:00
|
|
|
|
var data proto.WSMessageReq
|
2025-03-22 15:39:13 +08:00
|
|
|
|
err1 := json.Unmarshal(message, &data)
|
|
|
|
|
|
if err1 != nil {
|
|
|
|
|
|
log.Println("Error parsing JSON:", err1)
|
|
|
|
|
|
}
|
2025-03-28 11:23:45 +08:00
|
|
|
|
err = doReceiveGenChatMessage(userId, &sessionID, &data, chanel)
|
2025-03-22 15:39:13 +08:00
|
|
|
|
if err != nil {
|
2025-03-24 14:47:53 +08:00
|
|
|
|
log.Println("do receive gen chat message error:", err)
|
2025-05-12 13:57:38 +08:00
|
|
|
|
//返回错误消息
|
|
|
|
|
|
errMsg := proto.WSMessage{Code: proto.OperationFailed, ErrorMsg: "服务器处理错误", SessionID: int(sessionID)}
|
|
|
|
|
|
errMsgStr, _ := json.Marshal(errMsg)
|
|
|
|
|
|
worker.Publish(chanel, string(errMsgStr), time.Second*60)
|
2025-03-22 15:39:13 +08:00
|
|
|
|
return
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
}()
|
|
|
|
|
|
|
|
|
|
|
|
go func() {
|
2025-03-22 17:56:39 +08:00
|
|
|
|
//defer close(done)
|
2025-03-22 15:39:13 +08:00
|
|
|
|
//从redis订阅消息
|
|
|
|
|
|
for m := range ch {
|
|
|
|
|
|
msg := m.Payload // 获取消息,消息格式为json
|
|
|
|
|
|
if msg != "" {
|
2025-03-22 17:11:42 +08:00
|
|
|
|
err2 := ws.WriteMessage(websocket.TextMessage, []byte(msg))
|
2025-03-22 15:39:13 +08:00
|
|
|
|
if err2 != nil {
|
|
|
|
|
|
// 发生错误,删除连接
|
|
|
|
|
|
clientsMux.Lock()
|
|
|
|
|
|
delete(clients, ws)
|
|
|
|
|
|
clientsMux.Unlock()
|
|
|
|
|
|
break
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
2025-03-22 17:56:39 +08:00
|
|
|
|
//查看done是否关闭
|
2025-03-22 15:39:13 +08:00
|
|
|
|
}()
|
|
|
|
|
|
|
|
|
|
|
|
//发送心跳包
|
2025-03-22 17:42:57 +08:00
|
|
|
|
for {
|
|
|
|
|
|
select {
|
|
|
|
|
|
case <-done:
|
2025-03-23 13:17:03 +08:00
|
|
|
|
//结束会话将context保存
|
|
|
|
|
|
ctx2, ok := service.SessionContextMap[int(sessionID)]
|
|
|
|
|
|
if !ok {
|
|
|
|
|
|
log.Println("session context not found")
|
|
|
|
|
|
return
|
|
|
|
|
|
}
|
|
|
|
|
|
err := service.UpdateSessionByID(int(sessionID), userId, "", ctx2)
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
log.Println("update session context error:", err)
|
|
|
|
|
|
}
|
|
|
|
|
|
//保存的会话
|
|
|
|
|
|
delete(service.SessionContextMap, int(sessionID))
|
2025-03-22 17:42:57 +08:00
|
|
|
|
//客户端断开连接
|
|
|
|
|
|
return
|
|
|
|
|
|
//case t := <-ticker.C:
|
|
|
|
|
|
// //发送心跳
|
|
|
|
|
|
// err := ws.WriteMessage(websocket.PingMessage, nil)
|
|
|
|
|
|
// if err != nil {
|
|
|
|
|
|
// log.Println("Ping error:", err, t)
|
|
|
|
|
|
// return
|
|
|
|
|
|
// }
|
|
|
|
|
|
//}
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
2025-03-23 13:17:03 +08:00
|
|
|
|
|
2025-03-22 15:39:13 +08:00
|
|
|
|
}
|
2025-03-24 14:47:53 +08:00
|
|
|
|
|
2025-03-28 11:23:45 +08:00
|
|
|
|
func doReceiveGenChatMessage(userId int, sessionID *uint, data *proto.WSMessageReq, chanel string) error {
|
2025-03-24 14:47:53 +08:00
|
|
|
|
var model dao.Model
|
|
|
|
|
|
//查看请求功能类型
|
2025-03-31 13:26:06 +08:00
|
|
|
|
//models, funcs := service.FindFuncModelByFunction(data.Function, userId)
|
|
|
|
|
|
models, funcs, mferr := service.FindFuncModelListByFunctionV2(data.Function)
|
|
|
|
|
|
if mferr != nil {
|
|
|
|
|
|
return mferr
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2025-03-28 15:33:01 +08:00
|
|
|
|
//log.Println("find function model by function:", models, funcs)
|
|
|
|
|
|
//log.Println("funcs:", len(funcs), "\tmodels:", len(models))
|
2025-03-28 15:21:57 +08:00
|
|
|
|
if len(funcs) == 0 {
|
2025-03-24 14:47:53 +08:00
|
|
|
|
return errors.New("function not exist")
|
|
|
|
|
|
}
|
2025-03-28 15:21:57 +08:00
|
|
|
|
if len(models) == 0 {
|
2025-03-28 15:33:01 +08:00
|
|
|
|
//log.Println("doReceiveGenChatMessage models:", models)
|
2025-03-27 14:36:21 +08:00
|
|
|
|
return errors.New("model not exist")
|
|
|
|
|
|
}
|
2025-03-28 15:33:01 +08:00
|
|
|
|
var err error
|
2025-04-01 15:00:59 +08:00
|
|
|
|
sessionName := ""
|
2025-04-07 18:00:24 +08:00
|
|
|
|
if data.IsImage || data.IsFile {
|
|
|
|
|
|
var fileMsg proto.UserFileMessage
|
|
|
|
|
|
err = json.Unmarshal([]byte(data.Msg), &fileMsg)
|
2025-04-01 15:00:59 +08:00
|
|
|
|
if err != nil {
|
|
|
|
|
|
log.Println("Error parsing JSON:", err)
|
|
|
|
|
|
}
|
2025-04-07 18:00:24 +08:00
|
|
|
|
sessionName = fileMsg.Text
|
2025-04-01 15:00:59 +08:00
|
|
|
|
} else {
|
|
|
|
|
|
sessionName = data.Msg
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2025-03-28 17:11:49 +08:00
|
|
|
|
//前端传入的sessionID为0,则新会话
|
|
|
|
|
|
*sessionID = uint(data.SessionID)
|
2025-03-28 15:33:01 +08:00
|
|
|
|
if *sessionID == 0 {
|
|
|
|
|
|
//创建会话
|
2025-04-01 19:36:29 +08:00
|
|
|
|
if sessionName == "" {
|
|
|
|
|
|
sessionName = "新会话"
|
|
|
|
|
|
}
|
2025-04-03 16:13:22 +08:00
|
|
|
|
sessionType := proto.SessionTypeUserWithModelGeneration //默认通用聊天
|
|
|
|
|
|
if data.Function == "gen-prompt" {
|
|
|
|
|
|
sessionType = proto.SessionTypeUserPrompt //提示词会话
|
|
|
|
|
|
}
|
|
|
|
|
|
err, *sessionID = service.CreateSession(userId, sessionType, sessionName)
|
2025-03-28 15:33:01 +08:00
|
|
|
|
if err != nil {
|
|
|
|
|
|
log.Println("Create session error:", err)
|
|
|
|
|
|
}
|
2025-05-12 12:57:40 +08:00
|
|
|
|
log.Println("create session id:", *sessionID)
|
|
|
|
|
|
//会话名称过长,使用ai进行总结
|
2025-05-12 13:57:38 +08:00
|
|
|
|
if len(sessionName) > 30 && data.Function == "gen-chat" { //只有在通用聊天时才进行ai总结
|
2025-05-12 12:57:40 +08:00
|
|
|
|
//通过ai总结会话名称
|
|
|
|
|
|
go func() {
|
|
|
|
|
|
sessionErr := service.UpdateSessionNameByAIAndID(int(*sessionID), userId, sessionName)
|
|
|
|
|
|
if sessionErr != nil {
|
|
|
|
|
|
log.Println("update ai session name error:", sessionErr)
|
|
|
|
|
|
return
|
|
|
|
|
|
} //更新会话名称
|
|
|
|
|
|
}()
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2025-03-28 15:33:01 +08:00
|
|
|
|
}
|
2025-03-31 13:26:06 +08:00
|
|
|
|
//获取模型
|
|
|
|
|
|
for _, v := range models {
|
|
|
|
|
|
if v.ID == data.ModelID {
|
|
|
|
|
|
model = v
|
|
|
|
|
|
break
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
if model.ID == 0 {
|
|
|
|
|
|
//说明前端传入的模型id不正确,取第一个模型
|
|
|
|
|
|
model = models[0]
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2025-03-24 14:47:53 +08:00
|
|
|
|
//解析模型参数
|
|
|
|
|
|
var modelParam proto.ModelParam
|
|
|
|
|
|
err = json.Unmarshal([]byte(model.Parameter), &modelParam)
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
return err
|
|
|
|
|
|
}
|
|
|
|
|
|
modelParam.Url = model.Url
|
2025-03-27 14:36:21 +08:00
|
|
|
|
modelParam.System = funcs[0].Info //系统功能
|
2025-04-22 16:46:19 +08:00
|
|
|
|
//设置模型temperature及其他参数
|
|
|
|
|
|
checkAndSetModelTemperatureParam(&modelParam, data, &model)
|
2025-03-28 15:33:01 +08:00
|
|
|
|
imContext := proto.IMParamContext{UserID: userId, SessionID: int(*sessionID), FunctionID: int(funcs[0].ID), ModelID: int(model.ID), Question: data.Msg, Channel: chanel, ModelType: model.Type}
|
2025-03-24 14:47:53 +08:00
|
|
|
|
|
|
|
|
|
|
var userMsgID uint
|
2025-03-31 19:30:46 +08:00
|
|
|
|
status := proto.MsgHasRead
|
2025-04-07 18:00:24 +08:00
|
|
|
|
if data.IsImage || data.IsFile {
|
2025-03-31 19:30:46 +08:00
|
|
|
|
status = proto.UserToModelImageMsgType //图片类型
|
|
|
|
|
|
}
|
2025-04-10 15:06:26 +08:00
|
|
|
|
//将消息存入数据库
|
|
|
|
|
|
err, userMsgID = service.CreateMessage(proto.UserToModelMsgType, int(*sessionID), userId, int(model.ID), data.Msg, status, int(funcs[0].ID))
|
|
|
|
|
|
log.Println("create user message id:", userMsgID)
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
return err
|
|
|
|
|
|
} //机器人id为2,消息类型为1,状态为1
|
2025-05-11 20:31:39 +08:00
|
|
|
|
imContext.QuestionMessageID = int(userMsgID)
|
2025-04-10 15:06:26 +08:00
|
|
|
|
|
2025-04-09 14:53:16 +08:00
|
|
|
|
//若为知识库消息,需要先根据用户消息获取知识库查询内容让ai进行最终结果生成
|
|
|
|
|
|
if data.IsKBase {
|
2025-04-09 17:37:24 +08:00
|
|
|
|
//根据会话ID获取对应知识库消息
|
|
|
|
|
|
kbaseMsg, err2 := service.FindKnowledgeBaseBySessionID(*sessionID, uint(userId))
|
|
|
|
|
|
if err2 != nil {
|
|
|
|
|
|
log.Println("find knowledge base by session id error:", err2)
|
2025-04-10 15:06:26 +08:00
|
|
|
|
service.WSReturnErrorMessage(imContext.Channel, "find knowledge base error:"+err2.Error(), proto.WSKBaseSessionError)
|
2025-04-09 17:37:24 +08:00
|
|
|
|
return err2
|
|
|
|
|
|
}
|
|
|
|
|
|
data.KBaseID = kbaseMsg.ID
|
|
|
|
|
|
err2 = service.SelectKBaseContentByQuery(&imContext, data)
|
2025-04-09 14:53:16 +08:00
|
|
|
|
if err2 != nil {
|
|
|
|
|
|
log.Println("select kbase content error:", err2)
|
2025-04-10 15:06:26 +08:00
|
|
|
|
service.WSReturnErrorMessage(imContext.Channel, "select kbase error:"+err2.Error(), proto.WSKBaseServerError)
|
2025-04-09 14:53:16 +08:00
|
|
|
|
return err2
|
|
|
|
|
|
}
|
2025-04-10 15:06:26 +08:00
|
|
|
|
if data.KBaseType == "query" {
|
|
|
|
|
|
//若为查询直接返回
|
|
|
|
|
|
return nil
|
|
|
|
|
|
}
|
2025-04-09 14:53:16 +08:00
|
|
|
|
}
|
2025-03-24 14:47:53 +08:00
|
|
|
|
//需要后续处理
|
2025-03-28 11:23:45 +08:00
|
|
|
|
//err, _ = service.WSReceiveMessageService(userId, int(*sessionID), chanel, *data, model, modelParam)
|
|
|
|
|
|
err, _ = service.WSReceiveMessageServiceV2(modelParam, &imContext)
|
2025-03-24 14:47:53 +08:00
|
|
|
|
return err
|
|
|
|
|
|
}
|
2025-04-09 14:53:16 +08:00
|
|
|
|
|
|
|
|
|
|
func KBaseTextWS(c *gin.Context) {
|
|
|
|
|
|
id, _ := c.Get("id")
|
|
|
|
|
|
userID := int(id.(float64))
|
|
|
|
|
|
//query参数
|
|
|
|
|
|
kBaseSercerID := c.Query("kbase_server_id")
|
2025-04-09 17:37:24 +08:00
|
|
|
|
if kBaseSercerID == "" {
|
|
|
|
|
|
c.JSON(http.StatusOK, gin.H{"code": proto.ParameterError, "message": "kbase_server_id不能为空"})
|
|
|
|
|
|
return
|
|
|
|
|
|
}
|
2025-04-09 14:53:16 +08:00
|
|
|
|
// 升级HTTP连接为WebSocket连接
|
|
|
|
|
|
ws, err1 := upgrader.Upgrade(c.Writer, c.Request, nil)
|
|
|
|
|
|
if err1 != nil {
|
|
|
|
|
|
log.Println("升级为WebSocket时发生错误:", err1)
|
|
|
|
|
|
return
|
|
|
|
|
|
}
|
|
|
|
|
|
subscribeAndHandleKBaseTextMessages(ws, userID, kBaseSercerID)
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2025-04-10 15:06:26 +08:00
|
|
|
|
func subscribeAndHandleKBaseTextMessages(ws *websocket.Conn, userId int, kBaseServerID string) {
|
2025-04-09 14:53:16 +08:00
|
|
|
|
ctx := context.Background()
|
2025-04-10 15:06:26 +08:00
|
|
|
|
chanel := fmt.Sprintf("kbase_server_id_%s", kBaseServerID) // 频道名称
|
|
|
|
|
|
//设置KBaseServerID可用状态,加入可用集合
|
|
|
|
|
|
setKey := "kbase_server_id_set_is_available"
|
|
|
|
|
|
worker.SetRedisSetAdd(setKey, kBaseServerID)
|
2025-04-09 14:53:16 +08:00
|
|
|
|
pubsub := worker.RedisClient.Subscribe(ctx, chanel)
|
2025-04-10 15:06:26 +08:00
|
|
|
|
defer func() {
|
|
|
|
|
|
//设置KBaseServerID不可用状态,删除可用集合
|
|
|
|
|
|
worker.SetRedisSetRemove(setKey, kBaseServerID)
|
|
|
|
|
|
pubsub.Close()
|
|
|
|
|
|
ws.Close()
|
|
|
|
|
|
}()
|
|
|
|
|
|
|
2025-04-09 14:53:16 +08:00
|
|
|
|
ch := pubsub.Channel()
|
|
|
|
|
|
|
|
|
|
|
|
//接收客户端消息,发送消息
|
|
|
|
|
|
done := make(chan struct{})
|
|
|
|
|
|
go func() {
|
|
|
|
|
|
defer close(done)
|
|
|
|
|
|
for {
|
|
|
|
|
|
_, message, err := ws.ReadMessage()
|
|
|
|
|
|
var data proto.KnowledgeBaseServerResponse
|
|
|
|
|
|
//log.Println("rec message:", string(message))
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
log.Println("Read error:", err)
|
|
|
|
|
|
break
|
|
|
|
|
|
}
|
|
|
|
|
|
err1 := json.Unmarshal(message, &data)
|
|
|
|
|
|
if err1 != nil {
|
|
|
|
|
|
log.Println("Error parsing JSON:", err1)
|
|
|
|
|
|
continue
|
|
|
|
|
|
}
|
2025-04-11 16:35:36 +08:00
|
|
|
|
service.CreateAIStreamMsg(userId, data.IMContext.ModelID, data.IMContext.SessionID, data, proto.KnowledgeBaseServerResponseType) //将消息存入数据库
|
2025-04-09 14:53:16 +08:00
|
|
|
|
//这里主要将接收到的消息转到对应协程
|
|
|
|
|
|
//目标chanel
|
|
|
|
|
|
targetChanel := fmt.Sprintf("kbase_%s", data.KnowledgeBase.UUID)
|
|
|
|
|
|
//将消息转入
|
|
|
|
|
|
err = worker.RedisClient.Publish(ctx, targetChanel, message).Err()
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
log.Println("KBase Text WS Publish error:", err)
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
}()
|
|
|
|
|
|
|
|
|
|
|
|
go func() {
|
|
|
|
|
|
//defer close(done)
|
|
|
|
|
|
//从redis订阅消息
|
|
|
|
|
|
for m := range ch {
|
|
|
|
|
|
msg := m.Payload // 获取消息,消息格式为json
|
|
|
|
|
|
if msg != "" {
|
|
|
|
|
|
err2 := ws.WriteMessage(websocket.TextMessage, []byte(msg))
|
|
|
|
|
|
if err2 != nil {
|
|
|
|
|
|
break
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
}()
|
|
|
|
|
|
|
|
|
|
|
|
for {
|
|
|
|
|
|
select {
|
|
|
|
|
|
case <-done:
|
|
|
|
|
|
return
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
2025-04-22 16:46:19 +08:00
|
|
|
|
|
|
|
|
|
|
func checkAndSetModelTemperatureParam(modelParam *proto.ModelParam, data *proto.WSMessageReq, model *dao.Model) {
|
|
|
|
|
|
isErrParam := false
|
|
|
|
|
|
if data.Temperature < 0 || data.Temperature > 2 {
|
|
|
|
|
|
isErrParam = true
|
|
|
|
|
|
} else {
|
|
|
|
|
|
modelParam.Temperature = data.Temperature
|
|
|
|
|
|
}
|
|
|
|
|
|
if model.Type == proto.ModelTypeDouBao {
|
|
|
|
|
|
if data.TopP < 0 || data.TopP > 1 {
|
|
|
|
|
|
isErrParam = true
|
|
|
|
|
|
} else {
|
|
|
|
|
|
modelParam.TopP = data.TopP
|
|
|
|
|
|
}
|
|
|
|
|
|
} else if model.Type == proto.ModelTypeSpark {
|
|
|
|
|
|
if data.TopK < 0 || data.TopK > 1 {
|
|
|
|
|
|
isErrParam = true
|
|
|
|
|
|
} else {
|
|
|
|
|
|
modelParam.TopK = data.TopK
|
|
|
|
|
|
}
|
|
|
|
|
|
} else {
|
|
|
|
|
|
isErrParam = true
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
if isErrParam {
|
|
|
|
|
|
modelParam.SetDefaultParams()
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
2025-05-13 11:02:09 +08:00
|
|
|
|
|
|
|
|
|
|
func SendMessageForeignAI(c *gin.Context) {
|
|
|
|
|
|
id, _ := c.Get("user_id")
|
|
|
|
|
|
userID := id.(int)
|
|
|
|
|
|
user := service.GetUserByIDWithCache(userID)
|
|
|
|
|
|
var resp proto.GenerateResp
|
|
|
|
|
|
var req proto.SendMessageForeignAIRRequest
|
|
|
|
|
|
if user.Role == "admin" {
|
|
|
|
|
|
//解析请求参数
|
|
|
|
|
|
if err := c.ShouldBind(&req); err == nil {
|
2025-05-13 12:05:53 +08:00
|
|
|
|
reqBytes, _ := json.Marshal(req)
|
|
|
|
|
|
log.Println("SendMessageForeignAI req:", string(reqBytes))
|
2025-05-13 11:02:09 +08:00
|
|
|
|
//校验参数
|
|
|
|
|
|
if req.ModelParam.Model == "" || req.ModelParam.Url == "" || req.ModelParam.APIKey == "" {
|
|
|
|
|
|
resp.Code = proto.ParameterError
|
|
|
|
|
|
resp.Message = "模型必要参数不能为空"
|
|
|
|
|
|
} else {
|
|
|
|
|
|
//处理
|
2025-05-13 13:14:20 +08:00
|
|
|
|
generalModel.OpenAIForeignAI(req.ModelParam, &req.IMParamContext, req.PlatForm, req.MessageContent)
|
2025-05-13 11:02:09 +08:00
|
|
|
|
//返回结果
|
|
|
|
|
|
resp.Code = proto.SuccessCode
|
|
|
|
|
|
resp.Message = "success"
|
|
|
|
|
|
}
|
|
|
|
|
|
} else {
|
|
|
|
|
|
resp.Code = proto.ParameterError
|
|
|
|
|
|
resp.Message = "参数错误"
|
|
|
|
|
|
}
|
|
|
|
|
|
} else {
|
|
|
|
|
|
resp.Code = proto.PermissionDenied
|
|
|
|
|
|
resp.Message = "没有权限"
|
|
|
|
|
|
}
|
|
|
|
|
|
c.JSON(http.StatusOK, resp)
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
func ReceiveMessageForeignAI(c *gin.Context) {
|
|
|
|
|
|
id, _ := c.Get("user_id")
|
|
|
|
|
|
userID := id.(int)
|
|
|
|
|
|
user := service.GetUserByIDWithCache(userID)
|
|
|
|
|
|
var resp proto.GenerateResp
|
|
|
|
|
|
var req proto.ReceiveMessageForeignAIRRequest
|
|
|
|
|
|
if user.Role == "admin" {
|
|
|
|
|
|
//解析请求参数
|
|
|
|
|
|
if err := c.ShouldBind(&req); err == nil {
|
|
|
|
|
|
//校验参数
|
|
|
|
|
|
if req.IMParamContext.SessionID == 0 || req.ChatCompletionChunk.Choices == nil {
|
|
|
|
|
|
resp.Code = proto.ParameterError
|
|
|
|
|
|
resp.Message = "模型必要参数不能为空"
|
|
|
|
|
|
} else {
|
|
|
|
|
|
generalModel.OpenAIForeignAIResponseToGeneralMassageAndSendMsgQueue(&req)
|
|
|
|
|
|
resp.Code = proto.SuccessCode
|
|
|
|
|
|
resp.Message = "success"
|
|
|
|
|
|
}
|
|
|
|
|
|
} else {
|
|
|
|
|
|
resp.Code = proto.ParameterError
|
|
|
|
|
|
resp.Message = "参数错误"
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
} else {
|
|
|
|
|
|
resp.Code = proto.PermissionDenied
|
|
|
|
|
|
resp.Message = "没有权限"
|
|
|
|
|
|
}
|
|
|
|
|
|
c.JSON(http.StatusOK, resp)
|
|
|
|
|
|
}
|