saw-go/handler/im.go

275 lines
7.4 KiB
Go
Raw Normal View History

package handler
import (
"StuAcaWorksAI/dao"
"StuAcaWorksAI/proto"
"StuAcaWorksAI/service"
"StuAcaWorksAI/worker"
"context"
"encoding/json"
"errors"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"github.com/gorilla/websocket"
"log"
"net/http"
"strconv"
"sync"
)
// SMessage is the request payload bound by the /im HTTP handlers; the tags allow
// it to be filled from either a JSON body or form fields.
type SMessage struct {
	Type      int    `json:"type" form:"type"`             // message type forwarded to the service layer
	Msg       string `json:"msg" form:"msg"`               // message content
	ToID      int    `json:"to_id" form:"to_id"`           // destination id forwarded to CreateGeneralMessageService
	SessionID int    `json:"session_id" form:"session_id"` // chat session the request refers to
}
// upgrader converts incoming HTTP requests into WebSocket connections using
// 1024-byte read and write buffers.
//
// NOTE(review): CheckOrigin accepts every Origin header, which leaves the
// endpoint open to cross-site WebSocket hijacking; consider restricting it to
// the known front-end origins.
var upgrader = websocket.Upgrader{
	ReadBufferSize:  1024,
	WriteBufferSize: 1024,
	// Accept connections from any origin.
	CheckOrigin: func(*http.Request) bool { return true },
}
// clients records the currently open WebSocket connections. Access to it is
// guarded by clientsMux.
var clients = make(map[*websocket.Conn]bool)

// clientsMux serializes access to clients.
var clientsMux sync.Mutex
// SetUpIMGroup registers the instant-messaging routes under /im:
//   - POST /im/send_message -> SendMessage
//   - GET  /im/ai_chat_ws   -> SRMessage (WebSocket endpoint)
//   - POST /im/get_message  -> GetMessage
func SetUpIMGroup(router *gin.Engine) {
	router.Group("/im").
		POST("/send_message", SendMessage).
		GET("/ai_chat_ws", SRMessage).
		POST("/get_message", GetMessage)
}
// GetMessage returns the messages of the session named by the bound
// session_id, on behalf of the authenticated user.
//
// The user id is read from the gin context key "id" and is expected to be a
// float64 (presumably a numeric JWT claim set by upstream middleware — confirm).
// Every outcome is reported with HTTP 200 and a "code" field:
//   - proto.SuccessCode with the messages in "data" on success;
//   - proto.ParameterError when the user id is missing or the request cannot be bound;
//   - proto.MsgSendFailed when the message lookup fails.
func GetMessage(c *gin.Context) {
	rawID, _ := c.Get("id")
	idValue, ok := rawID.(float64)
	if !ok {
		// A checked assertion: a missing or non-float64 id must not panic the handler.
		c.JSON(http.StatusOK, gin.H{"error": "user id not found", "code": proto.ParameterError, "message": "failed"})
		return
	}
	userID := int(idValue)

	var req SMessage
	if err := c.ShouldBind(&req); err != nil {
		c.JSON(http.StatusOK, gin.H{"error": err.Error(), "code": proto.ParameterError, "message": "failed"})
		return
	}

	err, messages := service.FindMessageBySessionID(req.SessionID, userID)
	if err != nil {
		// NOTE(review): MsgSendFailed is reused for a read failure; a dedicated
		// code may be clearer for clients.
		c.JSON(http.StatusOK, gin.H{"error": err.Error(), "code": proto.MsgSendFailed, "message": "failed"})
		return
	}
	c.JSON(http.StatusOK, gin.H{"code": proto.SuccessCode, "message": "success", "data": messages})
}
// SendMessage stores a general message sent by the authenticated user through
// service.CreateGeneralMessageService.
//
// The user id is read from the gin context key "id" and is expected to be a
// float64 (presumably a numeric JWT claim set by upstream middleware — confirm).
// Every outcome is reported with HTTP 200 and a "code" field:
//   - proto.SuccessCode with the new message id in "data" on success;
//   - proto.ParameterError when the user id is missing or the request cannot be bound;
//   - proto.MsgSendFailed when the message cannot be created.
func SendMessage(c *gin.Context) {
	rawID, _ := c.Get("id")
	idValue, ok := rawID.(float64)
	if !ok {
		// A checked assertion: a missing or non-float64 id must not panic the handler.
		c.JSON(http.StatusOK, gin.H{"error": "user id not found", "code": proto.ParameterError, "message": "failed"})
		return
	}
	userID := int(idValue)

	var req SMessage
	if err := c.ShouldBind(&req); err != nil {
		c.JSON(http.StatusOK, gin.H{"error": err.Error(), "code": proto.ParameterError, "message": "failed"})
		return
	}

	err, mid := service.CreateGeneralMessageService(userID, req.ToID, req.Type, req.SessionID, req.Msg)
	if err != nil {
		c.JSON(http.StatusOK, gin.H{"error": err.Error(), "code": proto.MsgSendFailed, "message": "failed"})
		return
	}
	c.JSON(http.StatusOK, gin.H{"code": proto.SuccessCode, "message": "success", "data": mid})
}
// SRMessage upgrades the request to a long-lived WebSocket connection and runs
// the AI-chat message loop for the authenticated user until it ends.
//
// The user id is read from the gin context key "id" and is expected to be a
// float64; if it is missing, a JSON error (HTTP 200, proto.ParameterError) is
// returned instead of upgrading. The connection is registered in clients, under
// clientsMux, only after a successful upgrade, and it is removed again when the
// loop returns.
func SRMessage(c *gin.Context) {
	rawID, _ := c.Get("id")
	idValue, ok := rawID.(float64)
	if !ok {
		c.JSON(http.StatusOK, gin.H{"error": "user id not found", "code": proto.ParameterError, "message": "failed"})
		return
	}
	userID := int(idValue)

	ws, err := upgrader.Upgrade(c.Writer, c.Request, nil)
	if err != nil {
		// Upgrade has already replied with an HTTP error, and ws is nil here, so it
		// must not be registered.
		log.Println("升级为WebSocket时发生错误:", err)
		return
	}

	clientsMux.Lock()
	clients[ws] = true
	clientsMux.Unlock()
	defer func() {
		// Drop the entry on every exit path, not only on write errors.
		clientsMux.Lock()
		delete(clients, ws)
		clientsMux.Unlock()
	}()

	subscribeAndHandleIMMessages(ws, userID)
}
// subscribeAndHandleIMMessages bridges one WebSocket connection with a Redis
// Pub/Sub channel unique to that connection.
//
// The channel is named "user_<userId>_ai_chat_msg_<uuid>". Its name is passed to
// doReceiveGenChatMessage for every request, and every non-empty payload
// published on it is forwarded to the client. Two goroutines run concurrently:
//   - the reader decodes client frames as proto.WSMessageReq and hands them to
//     doReceiveGenChatMessage. It stops on a read error or a handling error.
//     Frames that are not valid JSON are logged and skipped.
//   - the writer forwards Redis payloads to the client as text frames. On a write
//     failure it unregisters the connection and closes it, which also ends the
//     reader.
//
// Once the reader stops, the session context stored in
// service.SessionContextMap for the last session id (if any) is persisted with
// service.UpdateSessionByID and removed from the map. Returning closes the
// WebSocket and the subscription. Closing the subscription closes its channel,
// which ends the writer goroutine.
func subscribeAndHandleIMMessages(ws *websocket.Conn, userId int) {
	ctx := context.Background()
	channel := "user_" + strconv.Itoa(userId) + "_ai_chat_msg_" + uuid.NewString()
	pubsub := worker.RedisClient.Subscribe(ctx, channel)
	defer pubsub.Close()
	defer ws.Close()
	ch := pubsub.Channel()

	// sessionID is written only by the reader goroutine (through
	// doReceiveGenChatMessage). It is read below only after done is closed, and
	// that close orders the accesses.
	var sessionID uint

	done := make(chan struct{})
	go func() {
		defer close(done)
		for {
			_, message, err := ws.ReadMessage()
			if err != nil {
				log.Println("Read error:", err)
				return
			}
			log.Println("rec message:", string(message))

			var data proto.WSMessageReq
			if err := json.Unmarshal(message, &data); err != nil {
				// Do not process a zero-valued request built from a malformed frame.
				log.Println("Error parsing JSON:", err)
				continue
			}
			if err := doReceiveGenChatMessage(userId, &sessionID, &data, channel); err != nil {
				log.Println("do receive gen chat message error:", err)
				return
			}
		}
	}()

	go func() {
		for m := range ch {
			if m.Payload == "" {
				continue
			}
			if err := ws.WriteMessage(websocket.TextMessage, []byte(m.Payload)); err != nil {
				clientsMux.Lock()
				delete(clients, ws)
				clientsMux.Unlock()
				// The connection is unusable for writes now. Closing it (safe to call
				// concurrently with the reader) makes ReadMessage fail, so cleanup runs.
				ws.Close()
				return
			}
		}
	}()

	// Block until the client side of the conversation ends. Heartbeat pings are
	// currently disabled.
	<-done

	// NOTE(review): service.SessionContextMap is read and modified here without
	// any visible locking, while other connections may use it concurrently.
	// Confirm that the service package serializes access.
	sessionCtx, ok := service.SessionContextMap[int(sessionID)]
	if !ok {
		log.Println("session context not found")
		return
	}
	if err := service.UpdateSessionByID(int(sessionID), userId, "", sessionCtx); err != nil {
		log.Println("update session context error:", err)
	}
	delete(service.SessionContextMap, int(sessionID))
}
// doReceiveGenChatMessage handles one chat request received over the WebSocket.
//
// Steps:
//  1. Look up the functions and candidate models for data.Function.
//  2. Set *sessionID to data.SessionID. A value of 0 creates a new session, whose
//     id is then stored in *sessionID.
//  3. Pick the model whose ID equals data.ModelID, falling back to the first one.
//  4. Decode that model's JSON Parameter into proto.ModelParam.
//  5. Store the user's message via service.CreateMessage.
//  6. Pass the request to service.WSReceiveMessageServiceV2, with chanel recorded
//     as the context's Channel.
//
// It returns an error when:
//   - the lookup fails or no function or model is configured;
//   - session creation fails;
//   - the model parameters are not valid JSON;
//   - storing or processing the message fails.
func doReceiveGenChatMessage(userId int, sessionID *uint, data *proto.WSMessageReq, chanel string) error {
	models, funcs, err := service.FindFuncModelListByFunctionV2(data.Function)
	if err != nil {
		return err
	}
	if len(funcs) == 0 {
		return errors.New("function not exist")
	}
	if len(models) == 0 {
		return errors.New("model not exist")
	}

	// A session id of 0 from the front end means "start a new session".
	*sessionID = uint(data.SessionID)
	if *sessionID == 0 {
		sessionName := data.Msg
		if data.IsImage || data.IsFile {
			// Image/file messages carry a JSON-encoded proto.UserFileMessage. Only its
			// text part is used as the session name, and a parse failure falls back to
			// the default name below.
			var fileMsg proto.UserFileMessage
			if jerr := json.Unmarshal([]byte(data.Msg), &fileMsg); jerr != nil {
				log.Println("Error parsing JSON:", jerr)
			}
			sessionName = fileMsg.Text
		}
		if sessionName == "" {
			sessionName = "新会话"
		}
		sessionType := proto.SessionTypeUserWithModelGeneration // general chat by default
		if data.Function == "gen-prompt" {
			sessionType = proto.SessionTypeUserPrompt // prompt-generation session
		}
		err, *sessionID = service.CreateSession(userId, sessionType, sessionName)
		if err != nil {
			// Stop here: continuing would store the message against session 0.
			return err
		}
		log.Println("create session id:", *sessionID)
	}

	// Use the model requested by the front end. If data.ModelID matches none of
	// the candidates, fall back to the first configured model.
	var model dao.Model
	for _, candidate := range models {
		if candidate.ID == data.ModelID {
			model = candidate
			break
		}
	}
	if model.ID == 0 {
		model = models[0]
	}

	var modelParam proto.ModelParam
	if err := json.Unmarshal([]byte(model.Parameter), &modelParam); err != nil {
		return err
	}
	modelParam.Url = model.Url
	modelParam.System = funcs[0].Info // system instruction taken from the matched function

	imContext := proto.IMParamContext{
		UserID:     userId,
		SessionID:  int(*sessionID),
		FunctionID: int(funcs[0].ID),
		ModelID:    int(model.ID),
		Question:   data.Msg,
		Channel:    chanel,
		ModelType:  model.Type,
	}

	status := proto.MsgHasRead
	if data.IsImage || data.IsFile {
		// NOTE(review): a message-type constant is used as the status of image/file
		// messages. Confirm that this is intended.
		status = proto.UserToModelImageMsgType
	}

	// Persist the user's message before handing the request to the model service.
	err, userMsgID := service.CreateMessage(proto.UserToModelMsgType, int(*sessionID), userId, int(model.ID), data.Msg, status, int(funcs[0].ID))
	if err != nil {
		return err
	}
	log.Println("create user message id:", userMsgID)

	err, _ = service.WSReceiveMessageServiceV2(modelParam, &imContext)
	return err
}