2025-03-21 14:44:40 +08:00
|
|
|
|
package handler
|
|
|
|
|
|
|
|
|
|
|
|
import (
|
|
|
|
|
|
"StuAcaWorksAI/proto"
|
|
|
|
|
|
"StuAcaWorksAI/service"
|
2025-03-22 15:39:13 +08:00
|
|
|
|
"StuAcaWorksAI/worker"
|
|
|
|
|
|
"context"
|
|
|
|
|
|
"encoding/json"
|
2025-03-21 14:44:40 +08:00
|
|
|
|
"github.com/gin-gonic/gin"
|
2025-03-22 15:39:13 +08:00
|
|
|
|
"github.com/gorilla/websocket"
|
|
|
|
|
|
"log"
|
2025-03-21 14:44:40 +08:00
|
|
|
|
"net/http"
|
2025-03-22 15:39:13 +08:00
|
|
|
|
"strconv"
|
|
|
|
|
|
"sync"
|
|
|
|
|
|
"time"
|
2025-03-21 14:44:40 +08:00
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
// SMessage is the request payload accepted by SendMessage. It can be bound
// from either a JSON body or form fields; the tag names below are the wire
// names in both cases.
type SMessage struct {
	// Type is forwarded unchanged to service.CreateGeneralMessageService.
	Type int `json:"type" form:"type"`
	// Msg is the message text.
	Msg string `json:"msg" form:"msg"`
	// ToID identifies the recipient of the message.
	ToID int `json:"to_id" form:"to_id"`
	// SessionID identifies the session the message belongs to.
	SessionID int `json:"session_id" form:"session_id"`
}
|
|
|
|
|
|
|
2025-03-22 15:39:13 +08:00
|
|
|
|
var (
|
|
|
|
|
|
upgrader = websocket.Upgrader{
|
|
|
|
|
|
ReadBufferSize: 1024,
|
|
|
|
|
|
WriteBufferSize: 1024,
|
|
|
|
|
|
CheckOrigin: func(r *http.Request) bool {
|
|
|
|
|
|
// 允许所有来源的连接
|
|
|
|
|
|
return true
|
|
|
|
|
|
},
|
|
|
|
|
|
}
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
// 创建一个用于存储WebSocket连接的map和互斥锁
|
|
|
|
|
|
var (
|
|
|
|
|
|
clients = make(map[*websocket.Conn]bool)
|
|
|
|
|
|
clientsMux sync.Mutex
|
|
|
|
|
|
)
|
|
|
|
|
|
|
2025-03-21 14:44:40 +08:00
|
|
|
|
func SetUpIMGroup(router *gin.Engine) {
|
|
|
|
|
|
imGroup := router.Group("/im")
|
|
|
|
|
|
imGroup.POST("/send_message", SendMessage)
|
2025-03-22 15:39:13 +08:00
|
|
|
|
imGroup.GET("/ai_chat_ws", SRMessage)
|
2025-03-21 14:44:40 +08:00
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
func SendMessage(c *gin.Context) {
|
|
|
|
|
|
var req SMessage
|
|
|
|
|
|
id, _ := c.Get("id")
|
|
|
|
|
|
userID := int(id.(float64))
|
|
|
|
|
|
if err := c.ShouldBind(&req); err == nil {
|
|
|
|
|
|
var err2 error
|
|
|
|
|
|
var mid uint
|
|
|
|
|
|
err2, mid = service.CreateGeneralMessageService(userID, req.ToID, req.Type, req.SessionID, req.Msg)
|
|
|
|
|
|
if err2 == nil {
|
|
|
|
|
|
c.JSON(http.StatusOK, gin.H{"code": proto.SuccessCode, "message": "success", "data": mid})
|
|
|
|
|
|
} else {
|
|
|
|
|
|
c.JSON(http.StatusOK, gin.H{"error": err2.Error(), "code": proto.MsgSendFailed, "message": "failed"})
|
|
|
|
|
|
}
|
|
|
|
|
|
} else {
|
|
|
|
|
|
c.JSON(http.StatusOK, gin.H{"error": err.Error(), "code": proto.ParameterError, "message": "failed"})
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2025-03-22 15:39:13 +08:00
|
|
|
|
// ws长连接
|
|
|
|
|
|
func SRMessage(c *gin.Context) {
|
|
|
|
|
|
id, _ := c.Get("id")
|
|
|
|
|
|
userID := int(id.(float64))
|
|
|
|
|
|
// 升级HTTP连接为WebSocket连接
|
|
|
|
|
|
ws, err1 := upgrader.Upgrade(c.Writer, c.Request, nil)
|
|
|
|
|
|
clients[ws] = true
|
|
|
|
|
|
if err1 != nil {
|
|
|
|
|
|
log.Println("升级为WebSocket时发生错误:", err1)
|
|
|
|
|
|
return
|
|
|
|
|
|
}
|
|
|
|
|
|
subscribeAndHandleIMMessages(ws, userID)
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
func subscribeAndHandleIMMessages(ws *websocket.Conn, userId int) {
|
|
|
|
|
|
ctx := context.Background()
|
|
|
|
|
|
chanel := "user_" + strconv.Itoa(userId) + "_ai_chat_msg"
|
|
|
|
|
|
pubsub := worker.RedisClient.Subscribe(ctx, chanel)
|
|
|
|
|
|
defer pubsub.Close()
|
|
|
|
|
|
defer ws.Close()
|
|
|
|
|
|
ch := pubsub.Channel()
|
|
|
|
|
|
//定时器,每3秒发送一次心跳
|
|
|
|
|
|
ticker := time.NewTicker(3 * time.Second)
|
|
|
|
|
|
defer ticker.Stop()
|
|
|
|
|
|
|
|
|
|
|
|
var sessionID uint
|
|
|
|
|
|
|
|
|
|
|
|
//接收客户端消息,发送消息
|
|
|
|
|
|
done := make(chan struct{})
|
|
|
|
|
|
go func() {
|
|
|
|
|
|
defer close(done)
|
|
|
|
|
|
for {
|
|
|
|
|
|
_, message, err := ws.ReadMessage()
|
2025-03-22 16:26:09 +08:00
|
|
|
|
log.Println("rec message:", string(message))
|
2025-03-22 15:39:13 +08:00
|
|
|
|
if err != nil {
|
|
|
|
|
|
log.Println("Read error:", err)
|
|
|
|
|
|
}
|
|
|
|
|
|
var data proto.WSMessage
|
|
|
|
|
|
err1 := json.Unmarshal(message, &data)
|
|
|
|
|
|
if err1 != nil {
|
|
|
|
|
|
log.Println("Error parsing JSON:", err1)
|
|
|
|
|
|
}
|
|
|
|
|
|
if sessionID == 0 {
|
|
|
|
|
|
//创建会话
|
2025-03-22 16:20:20 +08:00
|
|
|
|
//会话名字为前10个字符,如果
|
|
|
|
|
|
var sessionName string
|
|
|
|
|
|
if len(data.Msg) > 10 {
|
|
|
|
|
|
sessionName = data.Msg[:10]
|
|
|
|
|
|
} else {
|
|
|
|
|
|
sessionName = data.Msg
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2025-03-22 15:39:13 +08:00
|
|
|
|
err, sessionID = service.CreateSession(userId, sessionName)
|
2025-03-22 16:26:09 +08:00
|
|
|
|
log.Println("create session id:", sessionID)
|
2025-03-22 15:39:13 +08:00
|
|
|
|
if err != nil {
|
|
|
|
|
|
log.Println("Create session error:", err)
|
|
|
|
|
|
return
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
//需要后续处理
|
|
|
|
|
|
err, _ = service.WSReceiveMessageService(userId, int(sessionID), chanel, data)
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
return
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
}()
|
|
|
|
|
|
|
|
|
|
|
|
go func() {
|
|
|
|
|
|
var data proto.WSMessage
|
|
|
|
|
|
//从redis订阅消息
|
|
|
|
|
|
for m := range ch {
|
|
|
|
|
|
msg := m.Payload // 获取消息,消息格式为json
|
|
|
|
|
|
if msg != "" {
|
|
|
|
|
|
data.Msg = msg
|
|
|
|
|
|
data.Type = "msg"
|
|
|
|
|
|
data.SessionID = int(sessionID)
|
|
|
|
|
|
data.ToID = 0
|
|
|
|
|
|
msg_, _ := json.Marshal(data)
|
|
|
|
|
|
err2 := ws.WriteMessage(websocket.TextMessage, msg_)
|
|
|
|
|
|
if err2 != nil {
|
|
|
|
|
|
// 发生错误,删除连接
|
|
|
|
|
|
clientsMux.Lock()
|
|
|
|
|
|
delete(clients, ws)
|
|
|
|
|
|
clientsMux.Unlock()
|
|
|
|
|
|
break
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
}()
|
|
|
|
|
|
|
|
|
|
|
|
//发送心跳包
|
|
|
|
|
|
for {
|
|
|
|
|
|
select {
|
|
|
|
|
|
case <-done:
|
|
|
|
|
|
//客户端断开连接
|
|
|
|
|
|
return
|
|
|
|
|
|
case t := <-ticker.C:
|
|
|
|
|
|
//发送心跳
|
|
|
|
|
|
err := ws.WriteMessage(websocket.PingMessage, nil)
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
log.Println("Ping error:", err, t)
|
|
|
|
|
|
return
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|