saw-go/handler/im.go

256 lines
6.8 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package handler
import (
"StuAcaWorksAI/dao"
"StuAcaWorksAI/proto"
"StuAcaWorksAI/service"
"StuAcaWorksAI/worker"
"context"
"encoding/json"
"errors"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"github.com/gorilla/websocket"
"log"
"net/http"
"strconv"
"sync"
)
// SMessage is the request payload of the /im HTTP endpoints. It can be bound
// from either a JSON body or form fields.
//
// SendMessage forwards Type, Msg, ToID and SessionID to
// service.CreateGeneralMessageService; GetMessage only uses SessionID.
type SMessage struct {
	Type      int    `json:"type" form:"type"`
	Msg       string `json:"msg" form:"msg"`
	ToID      int    `json:"to_id" form:"to_id"`
	SessionID int    `json:"session_id" form:"session_id"`
}
// upgrader converts incoming HTTP requests into WebSocket connections using
// 1024-byte read and write buffers.
//
// NOTE(review): CheckOrigin accepts every request regardless of its Origin
// header, which replaces gorilla/websocket's default same-origin check and
// leaves the endpoint open to cross-site WebSocket hijacking if auth relies on
// cookies. Confirm this is intentional.
var upgrader = websocket.Upgrader{
	ReadBufferSize:  1024,
	WriteBufferSize: 1024,
	// Accept connections from any origin.
	CheckOrigin: func(*http.Request) bool { return true },
}
// clients is the set of open WebSocket connections, and clientsMux is the
// mutex that should be held whenever clients is read or modified.
var (
	clients    = map[*websocket.Conn]bool{}
	clientsMux sync.Mutex
)
// SetUpIMGroup mounts the instant-messaging routes on router:
//
//	POST /im/send_message -> SendMessage
//	GET  /im/ai_chat_ws   -> SRMessage (WebSocket upgrade)
//	POST /im/get_message  -> GetMessage
func SetUpIMGroup(router *gin.Engine) {
	im := router.Group("/im")
	{
		im.POST("/send_message", SendMessage)
		im.GET("/ai_chat_ws", SRMessage)
		im.POST("/get_message", GetMessage)
	}
}
// GetMessage handles POST /im/get_message. It binds an SMessage from the
// request and returns the messages of req.SessionID for the calling user.
//
// Every outcome is sent with HTTP 200; the "code" field tells them apart:
// proto.SuccessCode with the messages in "data", proto.MsgSendFailed when the
// lookup fails, or proto.ParameterError when binding fails or the caller's ID
// is unavailable.
//
// The caller's ID is read from the gin context key "id" as a float64
// (presumably stored there by an auth middleware from a JWT claim - TODO confirm).
func GetMessage(c *gin.Context) {
	// Comma-ok assertion: a missing or mistyped "id" would otherwise panic.
	id, _ := c.Get("id")
	rawID, ok := id.(float64)
	if !ok {
		c.JSON(http.StatusOK, gin.H{"error": "user id missing from request context", "code": proto.ParameterError, "message": "failed"})
		return
	}
	userID := int(rawID)

	var req SMessage
	if err := c.ShouldBind(&req); err != nil {
		c.JSON(http.StatusOK, gin.H{"error": err.Error(), "code": proto.ParameterError, "message": "failed"})
		return
	}

	// The service returns (error, result), in that order.
	err, messages := service.FindMessageBySessionID(req.SessionID, userID)
	if err != nil {
		c.JSON(http.StatusOK, gin.H{"error": err.Error(), "code": proto.MsgSendFailed, "message": "failed"})
		return
	}
	c.JSON(http.StatusOK, gin.H{"code": proto.SuccessCode, "message": "success", "data": messages})
}
// SendMessage handles POST /im/send_message. It binds an SMessage and stores it
// through service.CreateGeneralMessageService on behalf of the calling user.
//
// Every outcome is sent with HTTP 200; the "code" field tells them apart:
// proto.SuccessCode with the new message ID in "data", proto.MsgSendFailed
// when the service fails, or proto.ParameterError when binding fails or the
// caller's ID is unavailable.
//
// The caller's ID is read from the gin context key "id" as a float64
// (presumably stored there by an auth middleware from a JWT claim - TODO confirm).
func SendMessage(c *gin.Context) {
	// Comma-ok assertion: a missing or mistyped "id" would otherwise panic.
	id, _ := c.Get("id")
	rawID, ok := id.(float64)
	if !ok {
		c.JSON(http.StatusOK, gin.H{"error": "user id missing from request context", "code": proto.ParameterError, "message": "failed"})
		return
	}
	userID := int(rawID)

	var req SMessage
	if err := c.ShouldBind(&req); err != nil {
		c.JSON(http.StatusOK, gin.H{"error": err.Error(), "code": proto.ParameterError, "message": "failed"})
		return
	}

	// The service returns (error, messageID), in that order.
	err, mid := service.CreateGeneralMessageService(userID, req.ToID, req.Type, req.SessionID, req.Msg)
	if err != nil {
		c.JSON(http.StatusOK, gin.H{"error": err.Error(), "code": proto.MsgSendFailed, "message": "failed"})
		return
	}
	c.JSON(http.StatusOK, gin.H{"code": proto.SuccessCode, "message": "success", "data": mid})
}
// SRMessage handles GET /im/ai_chat_ws, the long-lived AI-chat WebSocket.
//
// It upgrades the request, registers the connection in clients for as long as
// the handler runs, and delegates to subscribeAndHandleIMMessages, which
// blocks until the session ends.
//
// The caller's ID is read from the gin context key "id" as a float64
// (presumably stored there by an auth middleware - TODO confirm). If it is
// missing, the request is rejected before the upgrade so the client still gets
// a plain HTTP reply.
func SRMessage(c *gin.Context) {
	id, _ := c.Get("id")
	rawID, ok := id.(float64)
	if !ok {
		c.JSON(http.StatusOK, gin.H{"error": "user id missing from request context", "code": proto.ParameterError, "message": "failed"})
		return
	}
	userID := int(rawID)

	ws, err := upgrader.Upgrade(c.Writer, c.Request, nil)
	if err != nil {
		// Upgrade has already written an HTTP error response, and ws is nil,
		// so it must not be registered.
		log.Println("升级为WebSocket时发生错误:", err)
		return
	}

	// clients is shared with other goroutines, so all access goes through clientsMux.
	clientsMux.Lock()
	clients[ws] = true
	clientsMux.Unlock()
	defer func() {
		clientsMux.Lock()
		delete(clients, ws)
		clientsMux.Unlock()
	}()

	subscribeAndHandleIMMessages(ws, userID)
}
// subscribeAndHandleIMMessages runs one AI-chat WebSocket session for userId
// and blocks until it ends.
//
// It subscribes to a Redis channel named "user_<userId>_ai_chat_msg_<uuid>" and
// starts two goroutines:
//   - reader: decodes each client frame as proto.WSMessageReq and passes it,
//     along with the channel name, to doReceiveGenChatMessage. Malformed frames
//     are skipped. A read error or a processing error ends the session.
//   - writer: forwards every non-empty Redis payload to the client as a text
//     frame. On a write error it removes ws from clients and stops.
//
// When the reader stops, the function saves the entry of
// service.SessionContextMap for the last processed session ID via
// service.UpdateSessionByID and deletes it. The deferred Close calls then tear
// down the socket and the subscription.
func subscribeAndHandleIMMessages(ws *websocket.Conn, userId int) {
	ctx := context.Background()
	// The random suffix gives each connection its own channel, even for the same user.
	channel := "user_" + strconv.Itoa(userId) + "_ai_chat_msg_" + uuid.NewString()
	pubsub := worker.RedisClient.Subscribe(ctx, channel)
	defer pubsub.Close()
	defer ws.Close()
	ch := pubsub.Channel()

	// Session ID of the most recent request. It stays 0 until a request is processed.
	var sessionID uint

	done := make(chan struct{})

	// Reader: client -> model.
	go func() {
		defer close(done)
		for {
			_, message, err := ws.ReadMessage()
			if err != nil {
				log.Println("Read error:", err)
				return
			}
			log.Println("rec message:", string(message))

			var data proto.WSMessageReq
			if err := json.Unmarshal(message, &data); err != nil {
				// Do not act on a zero-value request built from a bad frame.
				log.Println("Error parsing JSON:", err)
				continue
			}
			if err := doReceiveGenChatMessage(userId, &sessionID, &data, channel); err != nil {
				log.Println("do receive gen chat message error:", err)
				return
			}
		}
	}()

	// Writer: Redis -> client. The range ends once ch is closed, which is
	// presumably when pubsub.Close runs (go-redis behavior - confirm for the
	// version in use).
	go func() {
		for m := range ch {
			if m.Payload == "" {
				continue
			}
			if err := ws.WriteMessage(websocket.TextMessage, []byte(m.Payload)); err != nil {
				// The client is unreachable; stop tracking the connection.
				clientsMux.Lock()
				delete(clients, ws)
				clientsMux.Unlock()
				return
			}
		}
	}()

	// Wait for the reader to finish. close(done) happens after the reader's last
	// write to sessionID, so reading it below is race-free.
	<-done

	// Persist the session context and drop it from the in-memory map.
	// NOTE(review): service.SessionContextMap is a package-level map accessed here
	// without a lock, while other connections may access it concurrently. Verify
	// that the service package synchronizes it.
	sessCtx, ok := service.SessionContextMap[int(sessionID)]
	if !ok {
		log.Println("session context not found")
		return
	}
	if err := service.UpdateSessionByID(int(sessionID), userId, "", sessCtx); err != nil {
		log.Println("update session context error:", err)
	}
	delete(service.SessionContextMap, int(sessionID))
}
// doReceiveGenChatMessage processes one chat request received on the WebSocket.
//
// Steps:
//  1. Load the candidate models and function definitions for data.Function via
//     service.FindFuncModelListByFunctionV2. Fail if either list is empty.
//  2. Set *sessionID to data.SessionID. When that is 0, create a new session
//     with service.CreateSession(userId, data.Msg) and store its ID instead.
//  3. Pick the model whose ID equals data.ModelID, falling back to models[0].
//  4. Decode model.Parameter (JSON) into a proto.ModelParam, then set its Url
//     from the model and its System field from funcs[0].Info.
//  5. Store the user's message with service.CreateMessage.
//  6. Pass the parameters and an IMParamContext, which carries channel, to
//     service.WSReceiveMessageServiceV2.
//
// It returns the first error encountered. *sessionID is written in step 2 even
// if a later step fails.
func doReceiveGenChatMessage(userId int, sessionID *uint, data *proto.WSMessageReq, channel string) error {
	models, funcs, err := service.FindFuncModelListByFunctionV2(data.Function)
	if err != nil {
		return err
	}
	if len(funcs) == 0 {
		return errors.New("function not exist")
	}
	if len(models) == 0 {
		return errors.New("model not exist")
	}

	// A session ID of 0 from the client means "start a new session".
	*sessionID = uint(data.SessionID)
	if *sessionID == 0 {
		err, *sessionID = service.CreateSession(userId, data.Msg)
		log.Println("create session id:", *sessionID)
		if err != nil {
			// Without a session, the message below would be stored under session 0.
			return err
		}
	}

	// Use the requested model if it is a candidate; otherwise use the first one.
	var model dao.Model
	for _, m := range models {
		if m.ID == data.ModelID {
			model = m
			break
		}
	}
	if model.ID == 0 {
		model = models[0]
	}

	var modelParam proto.ModelParam
	if err := json.Unmarshal([]byte(model.Parameter), &modelParam); err != nil {
		return err
	}
	modelParam.Url = model.Url
	modelParam.System = funcs[0].Info

	imContext := proto.IMParamContext{
		UserID:     userId,
		SessionID:  int(*sessionID),
		FunctionID: int(funcs[0].ID),
		ModelID:    int(model.ID),
		Question:   data.Msg,
		Channel:    channel,
		ModelType:  model.Type,
	}

	status := proto.MsgHasRead
	if data.IsImage {
		// NOTE(review): this assigns a message-*type* constant to the status
		// argument. Confirm that is intended rather than a dedicated image status.
		status = proto.UserToModelImageMsgType
	}

	// Store the user's message before asking the model to respond.
	err, userMsgID := service.CreateMessage(proto.UserToModelMsgType, int(*sessionID), userId, int(model.ID), data.Msg, status, int(funcs[0].ID))
	log.Println("create user message id:", userMsgID)
	if err != nil {
		return err
	}

	err, _ = service.WSReceiveMessageServiceV2(modelParam, &imContext)
	return err
}