package handler import ( "StuAcaWorksAI/proto" "StuAcaWorksAI/service" "StuAcaWorksAI/worker" "context" "encoding/json" "github.com/gin-gonic/gin" "github.com/gorilla/websocket" "log" "net/http" "strconv" "sync" "time" ) type SMessage struct { Type int `json:"type" form:"type"` Msg string `json:"msg" form:"msg"` ToID int `json:"to_id" form:"to_id"` SessionID int `json:"session_id" form:"session_id"` } var ( upgrader = websocket.Upgrader{ ReadBufferSize: 1024, WriteBufferSize: 1024, CheckOrigin: func(r *http.Request) bool { // 允许所有来源的连接 return true }, } ) // 创建一个用于存储WebSocket连接的map和互斥锁 var ( clients = make(map[*websocket.Conn]bool) clientsMux sync.Mutex ) func SetUpIMGroup(router *gin.Engine) { imGroup := router.Group("/im") imGroup.POST("/send_message", SendMessage) imGroup.GET("/ai_chat_ws", SRMessage) } func SendMessage(c *gin.Context) { var req SMessage id, _ := c.Get("id") userID := int(id.(float64)) if err := c.ShouldBind(&req); err == nil { var err2 error var mid uint err2, mid = service.CreateGeneralMessageService(userID, req.ToID, req.Type, req.SessionID, req.Msg) if err2 == nil { c.JSON(http.StatusOK, gin.H{"code": proto.SuccessCode, "message": "success", "data": mid}) } else { c.JSON(http.StatusOK, gin.H{"error": err2.Error(), "code": proto.MsgSendFailed, "message": "failed"}) } } else { c.JSON(http.StatusOK, gin.H{"error": err.Error(), "code": proto.ParameterError, "message": "failed"}) } } // ws长连接 func SRMessage(c *gin.Context) { id, _ := c.Get("id") userID := int(id.(float64)) // 升级HTTP连接为WebSocket连接 ws, err1 := upgrader.Upgrade(c.Writer, c.Request, nil) clients[ws] = true if err1 != nil { log.Println("升级为WebSocket时发生错误:", err1) return } subscribeAndHandleIMMessages(ws, userID) } func subscribeAndHandleIMMessages(ws *websocket.Conn, userId int) { ctx := context.Background() chanel := "user_" + strconv.Itoa(userId) + "_ai_chat_msg" pubsub := worker.RedisClient.Subscribe(ctx, chanel) defer pubsub.Close() defer ws.Close() ch := pubsub.Channel() //定时器,每3秒发送一次心跳 ticker := time.NewTicker(3 * time.Second) defer ticker.Stop() var sessionID uint //接收客户端消息,发送消息 done := make(chan struct{}) go func() { defer close(done) for { _, message, err := ws.ReadMessage() log.Println("rec message:", string(message)) if err != nil { log.Println("Read error:", err) } var data proto.WSMessage err1 := json.Unmarshal(message, &data) if err1 != nil { log.Println("Error parsing JSON:", err1) } if sessionID == 0 { //创建会话 //会话名字为前10个字符,如果 var sessionName string if len(data.Msg) > 10 { sessionName = data.Msg[:10] } else { sessionName = data.Msg } err, sessionID = service.CreateSession(userId, sessionName) log.Println("create session id:", sessionID) if err != nil { log.Println("Create session error:", err) return } } //需要后续处理 err, _ = service.WSReceiveMessageService(userId, int(sessionID), chanel, data) if err != nil { return } } }() go func() { var data proto.WSMessage //从redis订阅消息 for m := range ch { msg := m.Payload // 获取消息,消息格式为json if msg != "" { data.Msg = msg data.Type = "msg" data.SessionID = int(sessionID) data.ToID = 0 msg_, _ := json.Marshal(data) err2 := ws.WriteMessage(websocket.TextMessage, msg_) if err2 != nil { // 发生错误,删除连接 clientsMux.Lock() delete(clients, ws) clientsMux.Unlock() break } } } }() //发送心跳包 for { select { case <-done: //客户端断开连接 return case t := <-ticker.C: //发送心跳 err := ws.WriteMessage(websocket.PingMessage, nil) if err != nil { log.Println("Ping error:", err, t) return } } } }