saw-go/handler/im.go

172 lines
4.0 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package handler
import (
"StuAcaWorksAI/proto"
"StuAcaWorksAI/service"
"StuAcaWorksAI/worker"
"context"
"encoding/json"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"log"
"net/http"
"strconv"
"sync"
"time"
)
// SMessage is the request payload bound by SendMessage. Each field carries
// both a json and a form tag, so it can be filled from a JSON body or from
// form/query parameters via gin's ShouldBind.
type SMessage struct {
	// Type is the message type code, forwarded unchanged to the service layer.
	Type int `json:"type" form:"type"`
	// Msg is the message text.
	Msg string `json:"msg" form:"msg"`
	// ToID is forwarded as the second argument of
	// service.CreateGeneralMessageService (presumably the recipient id).
	ToID int `json:"to_id" form:"to_id"`
	// SessionID identifies the session the message is stored under.
	SessionID int `json:"session_id" form:"session_id"`
}
// upgrader promotes incoming HTTP requests to WebSocket connections, using
// 1024-byte read and write buffers.
//
// NOTE(review): CheckOrigin accepts every Origin, so any web page can open a
// WebSocket to these endpoints from a visitor's browser. If authentication
// relies on cookies, restrict this to the known frontend origins.
var upgrader = websocket.Upgrader{
	ReadBufferSize:  1024,
	WriteBufferSize: 1024,
	// Allow connections from any origin.
	CheckOrigin: func(*http.Request) bool { return true },
}
// clients is the set of WebSocket connections currently registered by this
// package, and clientsMux guards it. Hold clientsMux while reading or
// writing clients, because connections are added and removed from different
// goroutines.
var (
	clientsMux sync.Mutex
	clients    = make(map[*websocket.Conn]bool)
)
// SetUpIMGroup mounts the instant-messaging routes on router under "/im":
//
//	POST /im/send_message -> SendMessage
//	GET  /im/ai_chat_ws   -> SRMessage (WebSocket upgrade)
func SetUpIMGroup(router *gin.Engine) {
	group := router.Group("/im")
	{
		group.POST("/send_message", SendMessage)
		group.GET("/ai_chat_ws", SRMessage)
	}
}
// SendMessage handles POST /im/send_message.
//
// It reads the caller's user id from the gin context key "id". That value is
// expected to be a float64, presumably set by an auth middleware (TODO
// confirm). It then binds the request into an SMessage and stores it via
// service.CreateGeneralMessageService.
//
// Every outcome is answered with HTTP 200; the JSON "code" field tells
// success (proto.SuccessCode) from failure:
//   - missing/invalid user id or bad request -> proto.ParameterError
//   - storage failure                         -> proto.MsgSendFailed
//   - success                                 -> "data" holds the new message id
func SendMessage(c *gin.Context) {
	// Check the assertion instead of letting a missing or non-float id panic.
	rawID, _ := c.Get("id")
	uid, ok := rawID.(float64)
	if !ok {
		c.JSON(http.StatusOK, gin.H{"error": "invalid user id", "code": proto.ParameterError, "message": "failed"})
		return
	}
	userID := int(uid)

	var req SMessage
	if err := c.ShouldBind(&req); err != nil {
		c.JSON(http.StatusOK, gin.H{"error": err.Error(), "code": proto.ParameterError, "message": "failed"})
		return
	}

	// The service returns (error, id) — error first.
	err, mid := service.CreateGeneralMessageService(userID, req.ToID, req.Type, req.SessionID, req.Msg)
	if err != nil {
		c.JSON(http.StatusOK, gin.H{"error": err.Error(), "code": proto.MsgSendFailed, "message": "failed"})
		return
	}
	c.JSON(http.StatusOK, gin.H{"code": proto.SuccessCode, "message": "success", "data": mid})
}
// SRMessage handles GET /im/ai_chat_ws, the long-lived AI-chat WebSocket.
//
// It reads the caller's user id from the gin context key "id". That value is
// expected to be a float64, presumably set by an auth middleware (TODO
// confirm). It then upgrades the request to a WebSocket and registers the
// connection in clients for as long as subscribeAndHandleIMMessages runs.
func SRMessage(c *gin.Context) {
	// Check the assertion instead of letting a missing or non-float id panic.
	rawID, _ := c.Get("id")
	uid, ok := rawID.(float64)
	if !ok {
		c.JSON(http.StatusOK, gin.H{"error": "invalid user id", "code": proto.ParameterError, "message": "failed"})
		return
	}
	userID := int(uid)

	// Upgrade the HTTP connection to a WebSocket. On failure Upgrade has
	// already replied to the client with an HTTP error, and ws is nil, so it
	// must not be registered.
	ws, err := upgrader.Upgrade(c.Writer, c.Request, nil)
	if err != nil {
		log.Println("升级为WebSocket时发生错误:", err)
		return
	}

	// clients is shared with other goroutines, so always mutate it under the
	// lock, and unregister the connection once the session ends.
	clientsMux.Lock()
	clients[ws] = true
	clientsMux.Unlock()
	defer func() {
		clientsMux.Lock()
		delete(clients, ws)
		clientsMux.Unlock()
	}()

	subscribeAndHandleIMMessages(ws, userID)
}
// subscribeAndHandleIMMessages drives one AI-chat WebSocket connection for
// userID. It returns when the connection should be torn down; its deferred
// calls then close the WebSocket and the Redis subscription.
//
// It subscribes to the Redis channel "user_<userID>_ai_chat_msg" and runs
// two goroutines:
//   - reader: decodes each client frame into a proto.WSMessage. On the first
//     message it creates a session named after the message's first 10
//     characters. It passes every message to service.WSReceiveMessageService.
//   - forwarder: wraps each non-empty Redis payload in a proto.WSMessage of
//     type "msg" and writes it to the client as JSON.
//
// The calling goroutine pings the client every 3 seconds. It returns as soon
// as the reader stops, the forwarder stops, or a ping fails. Closing the
// socket and the subscription then unblocks and ends both goroutines.
func subscribeAndHandleIMMessages(ws *websocket.Conn, userID int) {
	const heartbeatInterval = 3 * time.Second

	ctx := context.Background()
	channel := "user_" + strconv.Itoa(userID) + "_ai_chat_msg"
	pubsub := worker.RedisClient.Subscribe(ctx, channel)
	defer pubsub.Close()
	defer ws.Close()
	ch := pubsub.Channel()

	ticker := time.NewTicker(heartbeatInterval)
	defer ticker.Stop()

	// sessionID is written by the reader and read by the forwarder, so every
	// access goes through sessionMu to avoid a data race.
	var (
		sessionMu sync.Mutex
		sessionID uint // 0 until the first client message creates a session
	)
	currentSession := func() uint {
		sessionMu.Lock()
		defer sessionMu.Unlock()
		return sessionID
	}

	// Reader: client -> service.
	readDone := make(chan struct{})
	go func() {
		defer close(readDone)
		for {
			_, message, err := ws.ReadMessage()
			if err != nil {
				// Read errors are permanent on a gorilla/websocket Conn: every
				// later read fails as well, and gorilla panics after repeated
				// failed reads. Treat this as a disconnect.
				log.Println("Read error:", err)
				return
			}
			log.Println("rec message:", string(message))

			var data proto.WSMessage
			if err := json.Unmarshal(message, &data); err != nil {
				// Skip malformed frames rather than acting on a zero-value message.
				log.Println("Error parsing JSON:", err)
				continue
			}

			sid := currentSession()
			if sid == 0 {
				// First message: create the session, named after its first
				// 10 characters. Slice runes, not bytes, so multi-byte text
				// (e.g. Chinese) is never cut in the middle of a character.
				sessionName := data.Msg
				if runes := []rune(sessionName); len(runes) > 10 {
					sessionName = string(runes[:10])
				}
				err, sid = service.CreateSession(userID, sessionName)
				log.Println("create session id:", sid)
				if err != nil {
					log.Println("Create session error:", err)
					return
				}
				sessionMu.Lock()
				sessionID = sid
				sessionMu.Unlock()
			}

			if err, _ = service.WSReceiveMessageService(userID, int(sid), channel, data); err != nil {
				log.Println("WS receive message error:", err)
				return
			}
		}
	}()

	// Forwarder: Redis -> client. This is the only goroutine that writes data
	// frames. The heartbeat below uses WriteControl, which gorilla/websocket
	// permits concurrently with the single regular writer.
	sendDone := make(chan struct{})
	go func() {
		defer close(sendDone)
		var data proto.WSMessage
		for m := range ch {
			if m.Payload == "" {
				continue
			}
			data.Msg = m.Payload
			data.Type = "msg"
			data.SessionID = int(currentSession())
			data.ToID = 0
			out, err := json.Marshal(data)
			if err != nil {
				log.Println("Marshal error:", err)
				continue
			}
			if err := ws.WriteMessage(websocket.TextMessage, out); err != nil {
				// The connection is unusable: unregister it and stop forwarding.
				clientsMux.Lock()
				delete(clients, ws)
				clientsMux.Unlock()
				return
			}
		}
	}()

	// Heartbeat: ping until either goroutine stops or a ping fails.
	for {
		select {
		case <-readDone:
			// Client disconnected or message handling failed.
			return
		case <-sendDone:
			// Writing to the client failed, or the subscription channel closed.
			return
		case t := <-ticker.C:
			deadline := time.Now().Add(heartbeatInterval)
			if err := ws.WriteControl(websocket.PingMessage, nil, deadline); err != nil {
				log.Println("Ping error:", err, t)
				return
			}
		}
	}
}