package handler

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"log"
	"net/http"
	"strconv"
	"sync"
	"time"

	"github.com/gin-gonic/gin"
	"github.com/google/uuid"
	"github.com/gorilla/websocket"

	"StuAcaWorksAI/dao"
	"StuAcaWorksAI/proto"
	"StuAcaWorksAI/service"
	"StuAcaWorksAI/service/generalModel"
	"StuAcaWorksAI/worker"
)

// SMessage is the request payload of the HTTP message endpoints
// (send_message / get_message); it can be bound from JSON or form data.
type SMessage struct {
	Type      int    `json:"type" form:"type"`
	Msg       string `json:"msg" form:"msg"`
	ToID      int    `json:"to_id" form:"to_id"`
	SessionID int    `json:"session_id" form:"session_id"`
}

var (
	// upgrader turns HTTP requests into WebSocket connections.
	upgrader = websocket.Upgrader{
		ReadBufferSize:  1024,
		WriteBufferSize: 1024,
		// Accepts connections from every origin.
		// NOTE(review): if authentication relies on cookies this allows
		// cross-site WebSocket hijacking — confirm the auth scheme before
		// keeping an allow-all origin check.
		CheckOrigin: func(r *http.Request) bool {
			return true
		},
	}
)

// clients tracks the live AI-chat WebSocket connections; clientsMux guards it.
var (
	clients    = make(map[*websocket.Conn]bool)
	clientsMux sync.Mutex
)

// registerAIChatClient adds ws to clients while holding clientsMux.
func registerAIChatClient(ws *websocket.Conn) {
	clientsMux.Lock()
	defer clientsMux.Unlock()
	clients[ws] = true
}

// unregisterAIChatClient removes ws from clients while holding clientsMux.
// Removing a connection that is not present is a no-op.
func unregisterAIChatClient(ws *websocket.Conn) {
	clientsMux.Lock()
	defer clientsMux.Unlock()
	delete(clients, ws)
}

// SetUpIMGroup registers the /im routes on router.
func SetUpIMGroup(router *gin.Engine) {
	imGroup := router.Group("/im")
	imGroup.POST("/send_message", SendMessage)
	imGroup.GET("/ai_chat_ws", SRMessage)
	imGroup.POST("/get_message", GetMessage)
	imGroup.GET("/kbase_text_ws", KBaseTextWS)
	// Endpoints for the foreign-AI relay, called between services.
	imGroup.POST("/send_message_foreign_ai", SendMessageForeignAI)
	imGroup.POST("/receive_message_foreign_ai", ReceiveMessageForeignAI)
}

// GetMessage returns the messages of req.SessionID visible to the caller.
// The caller's ID is read from the "id" context key (set upstream as float64).
func GetMessage(c *gin.Context) {
	id, _ := c.Get("id")
	userID := int(id.(float64))

	var req SMessage
	if err := c.ShouldBind(&req); err != nil {
		c.JSON(http.StatusOK, gin.H{"error": err.Error(), "code": proto.ParameterError, "message": "failed"})
		return
	}

	err, messages := service.FindMessageBySessionID(req.SessionID, userID)
	if err != nil {
		c.JSON(http.StatusOK, gin.H{"error": err.Error(), "code": proto.MsgSendFailed, "message": "failed"})
		return
	}
	c.JSON(http.StatusOK, gin.H{"code": proto.SuccessCode, "message": "success", "data": messages})
}

// SendMessage stores a general message from the caller to req.ToID and
// responds with the new message ID.
func SendMessage(c *gin.Context) {
	id, _ := c.Get("id")
	userID := int(id.(float64))

	var req SMessage
	if err := c.ShouldBind(&req); err != nil {
		c.JSON(http.StatusOK, gin.H{"error": err.Error(), "code": proto.ParameterError, "message": "failed"})
		return
	}

	err, mid := service.CreateGeneralMessageService(userID, req.ToID, req.Type, req.SessionID, req.Msg)
	if err != nil {
		c.JSON(http.StatusOK, gin.H{"error": err.Error(), "code": proto.MsgSendFailed, "message": "failed"})
		return
	}
	c.JSON(http.StatusOK, gin.H{"code": proto.SuccessCode, "message": "success", "data": mid})
}

// SRMessage upgrades the request to the long-lived AI-chat WebSocket and
// serves it until the connection ends.
func SRMessage(c *gin.Context) {
	id, _ := c.Get("id")
	userID := int(id.(float64))

	ws, err := upgrader.Upgrade(c.Writer, c.Request, nil)
	if err != nil {
		log.Println("升级为WebSocket时发生错误:", err)
		return
	}
	// Register only after a successful upgrade: on failure ws is nil.
	registerAIChatClient(ws)
	subscribeAndHandleIMMessages(ws, userID)
}

// subscribeAndHandleIMMessages serves one AI-chat connection.
//
// A reader goroutine handles client requests via doReceiveGenChatMessage.
// A writer goroutine forwards everything published on a per-connection Redis
// channel to the client. When the reader stops, the session context
// collected for the last session is persisted and all resources are released.
// No ping/heartbeat is sent on this connection.
func subscribeAndHandleIMMessages(ws *websocket.Conn, userID int) {
	ctx := context.Background()
	// The random suffix makes the channel unique to this connection.
	channel := "user_" + strconv.Itoa(userID) + "_ai_chat_msg_" + uuid.NewString()
	pubsub := worker.RedisClient.Subscribe(ctx, channel)
	defer pubsub.Close()
	defer ws.Close()
	defer unregisterAIChatClient(ws)
	ch := pubsub.Channel()

	// Written only by the reader goroutine; read below after done is closed,
	// which orders the accesses.
	var sessionID uint

	// publishError tells the client that its request could not be processed.
	publishError := func() {
		errMsg := proto.WSMessage{Code: proto.OperationFailed, ErrorMsg: "服务器处理错误", SessionID: int(sessionID)}
		errMsgStr, _ := json.Marshal(errMsg)
		worker.Publish(channel, string(errMsgStr), time.Second*60)
	}

	done := make(chan struct{})
	go func() {
		defer close(done)
		for {
			_, message, err := ws.ReadMessage()
			if err != nil {
				log.Println("Read error:", err)
				return
			}
			log.Println("rec message:", string(message))

			var data proto.WSMessageReq
			if err := json.Unmarshal(message, &data); err != nil {
				// Do not process a zero-valued request; report and wait for the next one.
				log.Println("Error parsing JSON:", err)
				publishError()
				continue
			}
			if err := doReceiveGenChatMessage(userID, &sessionID, &data, channel); err != nil {
				log.Println("do receive gen chat message error:", err)
				publishError()
				return
			}
		}
	}()

	go func() {
		// The range presumably ends when pubsub is closed on return (go-redis
		// closes the channel) — TODO confirm against the go-redis version used.
		for m := range ch {
			msg := m.Payload // JSON-encoded message
			if msg == "" {
				continue
			}
			if err := ws.WriteMessage(websocket.TextMessage, []byte(msg)); err != nil {
				unregisterAIChatClient(ws)
				break
			}
		}
	}()

	// Block until the client disconnects or a request fails.
	<-done
	saveAIChatSessionContext(int(sessionID), userID)
}

// saveAIChatSessionContext persists the in-memory context of sessionID via
// service.UpdateSessionByID and then removes it from service.SessionContextMap.
// NOTE(review): SessionContextMap is a package-level map accessed here without
// a lock; confirm it is safe when several connections end concurrently.
func saveAIChatSessionContext(sessionID, userID int) {
	sessionCtx, ok := service.SessionContextMap[sessionID]
	if !ok {
		log.Println("session context not found")
		return
	}
	if err := service.UpdateSessionByID(sessionID, userID, "", sessionCtx); err != nil {
		log.Println("update session context error:", err)
	}
	delete(service.SessionContextMap, sessionID)
}

// doReceiveGenChatMessage handles one request from the AI-chat WebSocket. In order, it:
//  1. resolves the function and model for data.Function / data.ModelID,
//     falling back to the first model when the ID does not match;
//  2. creates a new session when data.SessionID is 0;
//  3. stores the user's message;
//  4. for knowledge-base requests, fetches the knowledge-base content first;
//  5. hands the request to service.WSReceiveMessageServiceV2.
//
// *sessionID is set to the session the request belongs to, so the caller can
// persist its context later. channel is carried in the IM context; error
// notifications for knowledge-base failures are published on it (model
// replies presumably are too — confirm in the service layer).
func doReceiveGenChatMessage(userID int, sessionID *uint, data *proto.WSMessageReq, channel string) error {
	models, funcs, mferr := service.FindFuncModelListByFunctionV2(data.Function)
	if mferr != nil {
		return mferr
	}
	if len(funcs) == 0 {
		return errors.New("function not exist")
	}
	if len(models) == 0 {
		return errors.New("model not exist")
	}

	var err error
	sessionName := data.Msg
	if data.IsImage || data.IsFile {
		// Image/file messages carry a JSON UserFileMessage; its Text names the session.
		var fileMsg proto.UserFileMessage
		if err = json.Unmarshal([]byte(data.Msg), &fileMsg); err != nil {
			log.Println("Error parsing JSON:", err)
		}
		sessionName = fileMsg.Text
	}

	// A SessionID of 0 from the client means "start a new session".
	*sessionID = uint(data.SessionID)
	if *sessionID == 0 {
		if sessionName == "" {
			sessionName = "新会话"
		}
		sessionType := proto.SessionTypeUserWithModelGeneration // general chat by default
		if data.Function == "gen-prompt" {
			sessionType = proto.SessionTypeUserPrompt
		}
		err, *sessionID = service.CreateSession(userID, sessionType, sessionName)
		if err != nil {
			// Do not store messages under a session that was never created.
			log.Println("Create session error:", err)
			return err
		}
		log.Println("create session id:", *sessionID)

		// Long names (len counts bytes) are summarised by AI, for general chat only.
		if len(sessionName) > 30 && data.Function == "gen-chat" {
			// Capture the ID by value: *sessionID is overwritten by the next
			// request on this connection while the goroutine may still run.
			newSessionID := int(*sessionID)
			go func() {
				if sessionErr := service.UpdateSessionNameByAIAndID(newSessionID, userID, sessionName); sessionErr != nil {
					log.Println("update ai session name error:", sessionErr)
				}
			}()
		}
	}

	var model dao.Model
	for _, m := range models {
		if m.ID == data.ModelID {
			model = m
			break
		}
	}
	if model.ID == 0 {
		// Unknown model ID from the client: use the first available model.
		model = models[0]
	}

	var modelParam proto.ModelParam
	if err = json.Unmarshal([]byte(model.Parameter), &modelParam); err != nil {
		return err
	}
	modelParam.Url = model.Url
	modelParam.System = funcs[0].Info // system prompt of the function
	checkAndSetModelTemperatureParam(&modelParam, data, &model)

	imContext := proto.IMParamContext{UserID: userID, SessionID: int(*sessionID), FunctionID: int(funcs[0].ID), ModelID: int(model.ID), Question: data.Msg, Channel: channel, ModelType: model.Type}

	status := proto.MsgHasRead
	if data.IsImage || data.IsFile {
		status = proto.UserToModelImageMsgType
	}
	var userMsgID uint
	err, userMsgID = service.CreateMessage(proto.UserToModelMsgType, int(*sessionID), userID, int(model.ID), data.Msg, status, int(funcs[0].ID))
	log.Println("create user message id:", userMsgID)
	if err != nil {
		return err
	}
	imContext.QuestionMessageID = int(userMsgID)

	// Knowledge-base requests first look up the knowledge base bound to this
	// session and query it with the user's message.
	if data.IsKBase {
		kbaseMsg, err2 := service.FindKnowledgeBaseBySessionID(*sessionID, uint(userID))
		if err2 != nil {
			log.Println("find knowledge base by session id error:", err2)
			service.WSReturnErrorMessage(imContext.Channel, "find knowledge base error:"+err2.Error(), proto.WSKBaseSessionError)
			return err2
		}
		data.KBaseID = kbaseMsg.ID
		if err2 = service.SelectKBaseContentByQuery(&imContext, data); err2 != nil {
			log.Println("select kbase content error:", err2)
			service.WSReturnErrorMessage(imContext.Channel, "select kbase error:"+err2.Error(), proto.WSKBaseServerError)
			return err2
		}
		if data.KBaseType == "query" {
			// Pure query: no model generation needed.
			return nil
		}
	}

	err, _ = service.WSReceiveMessageServiceV2(modelParam, &imContext)
	return err
}

// KBaseTextWS upgrades the request to the WebSocket used by a knowledge-base
// server identified by the required kbase_server_id query parameter.
func KBaseTextWS(c *gin.Context) {
	id, _ := c.Get("id")
	userID := int(id.(float64))

	kBaseServerID := c.Query("kbase_server_id")
	if kBaseServerID == "" {
		c.JSON(http.StatusOK, gin.H{"code": proto.ParameterError, "message": "kbase_server_id不能为空"})
		return
	}

	ws, err := upgrader.Upgrade(c.Writer, c.Request, nil)
	if err != nil {
		log.Println("升级为WebSocket时发生错误:", err)
		return
	}
	subscribeAndHandleKBaseTextMessages(ws, userID, kBaseServerID)
}

// subscribeAndHandleKBaseTextMessages serves one knowledge-base server connection.
//
// While connected, kBaseServerID is a member of the Redis set
// "kbase_server_id_set_is_available". Frames from the server are stored via
// service.CreateAIStreamMsg and republished on "kbase_<knowledge base UUID>".
// Messages published on "kbase_server_id_<kBaseServerID>" are forwarded to
// the server.
func subscribeAndHandleKBaseTextMessages(ws *websocket.Conn, userID int, kBaseServerID string) {
	ctx := context.Background()
	channel := fmt.Sprintf("kbase_server_id_%s", kBaseServerID)

	// Mark the server as available for as long as this connection lives.
	setKey := "kbase_server_id_set_is_available"
	worker.SetRedisSetAdd(setKey, kBaseServerID)
	pubsub := worker.RedisClient.Subscribe(ctx, channel)
	defer func() {
		worker.SetRedisSetRemove(setKey, kBaseServerID)
		pubsub.Close()
		ws.Close()
	}()
	ch := pubsub.Channel()

	done := make(chan struct{})
	go func() {
		defer close(done)
		for {
			_, message, err := ws.ReadMessage()
			if err != nil {
				log.Println("Read error:", err)
				return
			}
			var data proto.KnowledgeBaseServerResponse
			if err := json.Unmarshal(message, &data); err != nil {
				log.Println("Error parsing JSON:", err)
				continue
			}
			// Store the frame, then route it to the subscriber of this knowledge base.
			service.CreateAIStreamMsg(userID, data.IMContext.ModelID, data.IMContext.SessionID, data, proto.KnowledgeBaseServerResponseType)
			targetChannel := fmt.Sprintf("kbase_%s", data.KnowledgeBase.UUID)
			if err := worker.RedisClient.Publish(ctx, targetChannel, message).Err(); err != nil {
				log.Println("KBase Text WS Publish error:", err)
			}
		}
	}()

	go func() {
		for m := range ch {
			msg := m.Payload // JSON-encoded message
			if msg == "" {
				continue
			}
			if err := ws.WriteMessage(websocket.TextMessage, []byte(msg)); err != nil {
				break
			}
		}
	}()

	// Block until the server disconnects.
	<-done
}

// checkAndSetModelTemperatureParam copies the client's sampling parameters
// into modelParam when they are in range.
//
// Temperature must be in [0, 2]. TopP (DouBao) and TopK (Spark) must be in
// [0, 1]. Other model types always use the defaults. If any value is out of
// range, modelParam.SetDefaultParams() is applied.
func checkAndSetModelTemperatureParam(modelParam *proto.ModelParam, data *proto.WSMessageReq, model *dao.Model) {
	isErrParam := false
	if data.Temperature < 0 || data.Temperature > 2 {
		isErrParam = true
	} else {
		modelParam.Temperature = data.Temperature
	}

	switch model.Type {
	case proto.ModelTypeDouBao:
		if data.TopP < 0 || data.TopP > 1 {
			isErrParam = true
		} else {
			modelParam.TopP = data.TopP
		}
	case proto.ModelTypeSpark:
		if data.TopK < 0 || data.TopK > 1 {
			isErrParam = true
		} else {
			modelParam.TopK = data.TopK
		}
	default:
		isErrParam = true
	}

	if isErrParam {
		modelParam.SetDefaultParams()
	}
}

// SendMessageForeignAI (admin only, service-to-service) runs a foreign AI
// request through generalModel.OpenAIForeignAI. The caller's ID is read from
// the "user_id" context key as an int.
func SendMessageForeignAI(c *gin.Context) {
	id, _ := c.Get("user_id")
	userID := id.(int)
	user := service.GetUserByIDWithCache(userID)

	var resp proto.GenerateResp
	var req proto.SendMessageForeignAIRRequest
	if user.Role != "admin" {
		resp.Code = proto.PermissionDenied
		resp.Message = "没有权限"
	} else if err := c.ShouldBind(&req); err != nil {
		resp.Code = proto.ParameterError
		resp.Message = "参数错误"
	} else {
		// Log identifying fields only: the full request contains the model API key.
		log.Println("SendMessageForeignAI req: model:", req.ModelParam.Model, "platform:", req.PlatForm, "session_id:", req.IMParamContext.SessionID)
		if req.ModelParam.Model == "" || req.ModelParam.Url == "" || req.ModelParam.APIKey == "" {
			resp.Code = proto.ParameterError
			resp.Message = "模型必要参数不能为空"
		} else {
			generalModel.OpenAIForeignAI(req.ModelParam, &req.IMParamContext, req.PlatForm, req.ChatCompletionMessageParamUnion)
			resp.Code = proto.SuccessCode
			resp.Message = "success"
		}
	}
	c.JSON(http.StatusOK, resp)
}

// ReceiveMessageForeignAI (admin only, service-to-service) accepts a
// completion chunk from the foreign AI relay and hands it to
// generalModel.OpenAIForeignAIResponseToGeneralMassageAndSendMsgQueue.
func ReceiveMessageForeignAI(c *gin.Context) {
	id, _ := c.Get("user_id")
	userID := id.(int)
	user := service.GetUserByIDWithCache(userID)

	var resp proto.GenerateResp
	var req proto.ReceiveMessageForeignAIRRequest
	if user.Role != "admin" {
		resp.Code = proto.PermissionDenied
		resp.Message = "没有权限"
	} else if err := c.ShouldBind(&req); err != nil {
		resp.Code = proto.ParameterError
		resp.Message = "参数错误"
	} else if req.IMParamContext.SessionID == 0 || req.ChatCompletionChunk.Choices == nil {
		// A nil Choices means the field was absent; an empty list is accepted.
		resp.Code = proto.ParameterError
		resp.Message = "模型必要参数不能为空"
	} else {
		generalModel.OpenAIForeignAIResponseToGeneralMassageAndSendMsgQueue(&req)
		resp.Code = proto.SuccessCode
		resp.Message = "success"
	}
	c.JSON(http.StatusOK, resp)
}