From 98239a05d6ac06593c5cf00d8bb7d4e76284713c Mon Sep 17 00:00:00 2001 From: junleea <354425203@qq.com> Date: Fri, 21 Mar 2025 14:44:40 +0800 Subject: [PATCH] =?UTF-8?q?=E6=B7=BB=E5=8A=A0=E6=A8=A1=E5=9E=8B=E7=AE=A1?= =?UTF-8?q?=E7=90=86=E5=8A=9F=E8=83=BD=EF=BC=8C=E4=BC=9A=E8=AF=9D=E5=8F=8A?= =?UTF-8?q?=E5=AF=B9=E8=AF=9D=E4=BA=A4=E6=B5=81=E9=83=A8=E5=88=86=E5=AE=8C?= =?UTF-8?q?=E6=88=90?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- dao/db.go | 18 ++++++++ dao/im.go | 99 +++++++++++++++++++++++++++++++++++++++++ dao/model.go | 71 +++++++++++++++++++++++++++++ go.mod | 1 + go.sum | 2 + handler/im.go | 40 +++++++++++++++++ handler/model.go | 98 ++++++++++++++++++++++++++++++++++++++++ proto/status.go | 5 +++ service/imService.go | 52 ++++++++++++++++++++++ service/modelService.go | 49 ++++++++++++++++++++ 10 files changed, 435 insertions(+) create mode 100644 dao/im.go create mode 100644 dao/model.go create mode 100644 handler/im.go create mode 100644 handler/model.go create mode 100644 service/imService.go create mode 100644 service/modelService.go diff --git a/dao/db.go b/dao/db.go index 51c5b26..008a88e 100644 --- a/dao/db.go +++ b/dao/db.go @@ -53,6 +53,24 @@ func Init() error { fmt.Println("file auth table:", err) return err } + + err = db.AutoMigrate(&Model{}) + if err != nil { + fmt.Println("model table:", err) + return err + } + + err = db.AutoMigrate(&Session{}) + if err != nil { + fmt.Println("session table:", err) + return err + } + + err = db.AutoMigrate(&Message{}) + if err != nil { + fmt.Println("message table:", err) + return err + } DB = db return err } diff --git a/dao/im.go b/dao/im.go new file mode 100644 index 0000000..663e4f9 --- /dev/null +++ b/dao/im.go @@ -0,0 +1,99 @@ +package dao + +import ( + "StuAcaWorksAI/proto" + "gorm.io/gorm" +) + +type Session struct { + gorm.Model + UserID int `gorm:"column:user_id"` //只能由用户创建 + Name string `gorm:"column:name"` +} + +type Message struct { + gorm.Model + Type int `gorm:"column:type"` //1用户之间的消息,2为与模型消息 + SessionID int `gorm:"column:session_id"` + FromID int `gorm:"column:from_id"` //发送者,可以为用户或者模型 + ToID int `gorm:"column:to_id"` //接收者,可以为用户或者模型,如果为模型,则为模型id,根据type判断 + Msg string `gorm:"column:msg"` //消息内容 + Status int `gorm:"column:status"` //0为未读,1为已读 +} + +func CreateSession(userID int, name string) (error, uint) { + session := Session{UserID: userID, Name: name} + var res *gorm.DB + if proto.Config.SERVER_SQL_LOG { + res = DB.Debug().Create(&session) + } else { + res = DB.Create(&session) + } + return res.Error, session.ID +} + +func FindSessionByID(id int) Session { + var session Session + if proto.Config.SERVER_SQL_LOG { + DB.Debug().Where("id = ?", id).First(&session) + } else { + DB.Where("id = ?", id).First(&session) + } + return session +} + +// 获取用户最新的创建时间会话列表 50个 +func FindSessionByUserID(userID int) []Session { + var sessions []Session + if proto.Config.SERVER_SQL_LOG { + DB.Debug().Where("user_id = ?", userID).Order("created_at DESC").Limit(50).Find(&sessions) + } else { + DB.Where("user_id = ?", userID).Order("created_at DESC").Limit(50).Find(&sessions) + } + return sessions +} + +// 更新会话的名字 +func UpdateSessionByID(id int, userId int, name string) error { + var res *gorm.DB + if proto.Config.SERVER_SQL_LOG { + res = DB.Debug().Model(&Session{}).Where("id = ? and user_id = ?", id, userId).Update("name", name) + } else { + res = DB.Model(&Session{}).Where("id = ? and user_id = ?", id, userId).Update("name", name) + } + return res.Error +} + +// 删除会话 +func DeleteSessionByID(id int, userId int) error { + var res *gorm.DB + if proto.Config.SERVER_SQL_LOG { + res = DB.Debug().Delete(&Session{}, id) + } else { + res = DB.Delete(&Session{}, id) + } + return res.Error +} + +// 创建消息 +func CreateMessage(sessionID, fromID, toID, msgType, status int, msg string) (error, uint) { + message := Message{SessionID: sessionID, FromID: fromID, ToID: toID, Type: msgType, Status: status, Msg: msg} + var res *gorm.DB + if proto.Config.SERVER_SQL_LOG { + res = DB.Debug().Create(&message) + } else { + res = DB.Create(&message) + } + return res.Error, message.ID +} + +// 根据会话id获取消息 +func FindMessageBySessionID(sessionID int) []Message { + var messages []Message + if proto.Config.SERVER_SQL_LOG { + DB.Debug().Where("session_id = ?", sessionID).Order("created_at DESC").Find(&messages) + } else { + DB.Where("session_id = ?", sessionID).Order("created_at DESC").Find(&messages) + } + return messages +} diff --git a/dao/model.go b/dao/model.go new file mode 100644 index 0000000..ef680e7 --- /dev/null +++ b/dao/model.go @@ -0,0 +1,71 @@ +package dao + +import ( + "StuAcaWorksAI/proto" + "gorm.io/gorm" +) + +type Model struct { + gorm.Model + Type string `gorm:"column:type"` //模型类型 + UserID uint `gorm:"column:user_id"` //用户id + Url string `gorm:"column:url"` //模型地址 + Parameter string `gorm:"column:parameter"` //模型参数,存储json +} + +// 创建模型 +func CreateModel(userID uint, modelType, url, parameter string) (error, uint) { + model := Model{UserID: userID, Type: modelType, Url: url, Parameter: parameter} + var res *gorm.DB + if proto.Config.SERVER_SQL_LOG { + res = DB.Debug().Create(&model) + } else { + res = DB.Create(&model) + } + return res.Error, model.ID +} + +// 根据id查找模型 +func FindModelByID(id, userID int) []Model { + var model []Model + if proto.Config.SERVER_SQL_LOG { + DB.Debug().Where("id = ? and user_id = ?", id, userID).Find(&model) + } else { + DB.Where("id = ? and user_id = ?", id, userID).Find(&model) + } + return model +} + +// 根据用户id查找模型 +func FindModelByUserID(userID int) []Model { + var models []Model + if proto.Config.SERVER_SQL_LOG { + DB.Debug().Where("user_id = ?", userID).Find(&models) + } else { + DB.Where("user_id = ?", userID).Find(&models) + } + return models +} + +// 根据id删除模型 +func DeleteModelByID(id int) error { + var res *gorm.DB + if proto.Config.SERVER_SQL_LOG { + res = DB.Debug().Delete(&Model{}, id) + } else { + res = DB.Delete(&Model{}, id) + } + return res.Error +} + +// 根据id更新模型 +func UpdateModelByID(id int, userID uint, modelType, url, parameter string) error { + model := Model{UserID: userID, Type: modelType, Url: url, Parameter: parameter} + var res *gorm.DB + if proto.Config.SERVER_SQL_LOG { + res = DB.Debug().Model(&Model{}).Where("id = ?", id).Updates(&model) + } else { + res = DB.Model(&Model{}).Where("id = ?", id).Updates(&model) + } + return res.Error +} diff --git a/go.mod b/go.mod index 1348be1..f9c02ad 100644 --- a/go.mod +++ b/go.mod @@ -7,6 +7,7 @@ require ( github.com/go-redis/redis/v8 v8.11.5 github.com/golang-jwt/jwt v3.2.2+incompatible github.com/google/uuid v1.6.0 + github.com/gorilla/websocket v1.5.3 github.com/robfig/cron/v3 v3.0.1 gorm.io/driver/mysql v1.5.7 gorm.io/driver/postgres v1.5.11 diff --git a/go.sum b/go.sum index fbf8a54..7fc7e68 100644 --- a/go.sum +++ b/go.sum @@ -43,6 +43,8 @@ github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/ github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= +github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk= diff --git a/handler/im.go b/handler/im.go new file mode 100644 index 0000000..5b7b17f --- /dev/null +++ b/handler/im.go @@ -0,0 +1,40 @@ +package handler + +import ( + "StuAcaWorksAI/proto" + "StuAcaWorksAI/service" + "github.com/gin-gonic/gin" + "net/http" +) + +type SMessage struct { + Type int `json:"type" form:"type"` + Msg string `json:"msg" form:"msg"` + ToID int `json:"to_id" form:"to_id"` + SessionID int `json:"session_id" form:"session_id"` +} + +func SetUpIMGroup(router *gin.Engine) { + imGroup := router.Group("/im") + imGroup.POST("/send_message", SendMessage) +} + +func SendMessage(c *gin.Context) { + var req SMessage + id, _ := c.Get("id") + userID := int(id.(float64)) + if err := c.ShouldBind(&req); err == nil { + var err2 error + var mid uint + err2, mid = service.CreateGeneralMessageService(userID, req.ToID, req.Type, req.SessionID, req.Msg) + if err2 == nil { + c.JSON(http.StatusOK, gin.H{"code": proto.SuccessCode, "message": "success", "data": mid}) + } else { + c.JSON(http.StatusOK, gin.H{"error": err2.Error(), "code": proto.MsgSendFailed, "message": "failed"}) + } + } else { + c.JSON(http.StatusOK, gin.H{"error": err.Error(), "code": proto.ParameterError, "message": "failed"}) + } +} + +//ws长连接 diff --git a/handler/model.go b/handler/model.go new file mode 100644 index 0000000..8b5156f --- /dev/null +++ b/handler/model.go @@ -0,0 +1,98 @@ +package handler + +import ( + "StuAcaWorksAI/proto" + "StuAcaWorksAI/service" + "github.com/gin-gonic/gin" + "net/http" +) + +func SetUpModelGroup(router *gin.Engine) { + modelGroup := router.Group("/model") + modelGroup.POST("/create", CreateModel) + modelGroup.POST("/find", FindModel) + modelGroup.DELETE("/delete", DeleteModel) + modelGroup.PUT("/update", UpdateModel) +} + +type Model struct { + ID int `json:"id" form:"id"` + UserID int `json:"user_id" form:"user_id"` + Type string `json:"type" form:"type"` + Url string `json:"url" form:"url"` + Parameter string `json:"parameter" form:"parameter"` +} + +func CreateModel(c *gin.Context) { + id, _ := c.Get("id") + userID := int(id.(float64)) + var req Model + if err := c.ShouldBind(&req); err == nil { + // 创建模型信息 + err2, mid := service.CreateModel(uint(userID), req.Type, req.Url, req.Parameter) + if err2 == nil { + c.JSON(http.StatusOK, gin.H{"code": proto.SuccessCode, "message": "success", "data": mid}) + } else { + c.JSON(http.StatusOK, gin.H{"error": err2.Error(), "code": proto.ModelCreateFailed, "message": "failed"}) + } + } else { + c.JSON(http.StatusOK, gin.H{"error": err.Error(), "code": proto.ParameterError, "message": "failed"}) + } +} + +func FindModel(c *gin.Context) { + id, _ := c.Get("id") + userID := int(id.(float64)) + var req Model + if err := c.ShouldBind(&req); err == nil { + if req.Type == "ID" { + // 根据id查找模型 + model := service.FindModelByID(req.ID, userID) + c.JSON(http.StatusOK, gin.H{"code": proto.SuccessCode, "message": "success", "data": model}) + } else if req.Type == "UserID" { + // 根据用户id查找模型 + models := service.FindModelByUserID(userID) + c.JSON(http.StatusOK, gin.H{"code": proto.SuccessCode, "message": "success", "data": models}) + } else { + c.JSON(http.StatusOK, gin.H{"error": "type error", "code": proto.ModelSearchFailed, "message": "failed"}) + } + } else { + c.JSON(http.StatusOK, gin.H{"error": err.Error(), "code": proto.ParameterError, "message": "failed"}) + } +} + +// 修改 +func UpdateModel(c *gin.Context) { + id, _ := c.Get("id") + userID := int(id.(float64)) + var req Model + if err := c.ShouldBind(&req); err == nil { + // 更新模型信息 + err2 := service.UpdateModelByID(req.ID, uint(userID), req.Type, req.Url, req.Parameter) + if err2 == nil { + c.JSON(http.StatusOK, gin.H{"code": proto.SuccessCode, "message": "success"}) + } else { + c.JSON(http.StatusOK, gin.H{"error": err2.Error(), "code": proto.ModelUpdateFailed, "message": "failed"}) + } + } else { + c.JSON(http.StatusOK, gin.H{"error": err.Error(), "code": proto.ParameterError, "message": "failed"}) + } +} + +// 删除 +func DeleteModel(c *gin.Context) { + id, _ := c.Get("id") + userID := int(id.(float64)) + var req Model + if err := c.ShouldBind(&req); err == nil { + // 删除模型 + err2 := service.DeleteModelByID(req.ID, userID) + if err2 == nil { + c.JSON(http.StatusOK, gin.H{"code": proto.SuccessCode, "message": "success"}) + } else { + c.JSON(http.StatusOK, gin.H{"error": err2.Error(), "code": proto.ModelDeleteFailed, "message": "failed"}) + } + } else { + c.JSON(http.StatusOK, gin.H{"error": err.Error(), "code": proto.ParameterError, "message": "failed"}) + } +} diff --git a/proto/status.go b/proto/status.go index a4e005a..8af508b 100644 --- a/proto/status.go +++ b/proto/status.go @@ -73,4 +73,9 @@ const ( ShellUpdateFailed = 101 // 更新shell失败 ShellDeleteFailed = 102 // 删除shell失败 ShellSearchFailed = 103 // 获取shell失败 + + ModelCreateFailed = 110 // 创建模型失败 + ModelUpdateFailed = 111 // 更新模型失败 + ModelDeleteFailed = 112 // 删除模型失败 + ModelSearchFailed = 113 // 获取模型失败 ) diff --git a/service/imService.go b/service/imService.go new file mode 100644 index 0000000..bebb99e --- /dev/null +++ b/service/imService.go @@ -0,0 +1,52 @@ +package service + +import ( + "StuAcaWorksAI/dao" + "errors" +) + +// 创建会话服务 +func CreateSession(userID int, name string) (error, uint) { + err, id := dao.CreateSession(userID, name) + return err, id +} + +// 查找用户的会话列表 +func FindSessionByUserID(userID int) []dao.Session { + sessions := dao.FindSessionByUserID(userID) + return sessions +} + +// 更新会话名字 +func UpdateSessionByID(id int, userId int, name string) error { + err := dao.UpdateSessionByID(id, userId, name) + return err +} + +// 创建消息服务 +func CreateMessage(msgType int, sessionID int, fromID int, toID int, msg string, status int) (error, uint) { + //sessionID, fromID, toID, msgType, status int, msg string + err, id := dao.CreateMessage(sessionID, fromID, toID, msgType, status, msg) + return err, id +} + +// 通过会话id查询消息 +func FindMessageBySessionID(sessionID, userID int) (error, []dao.Message) { + //查看会话是否属于该用户 + session := dao.FindSessionByID(sessionID) + if session.UserID != userID { + return errors.New("session not belong to user"), nil + } + msgs := dao.FindMessageBySessionID(sessionID) + return nil, msgs +} + +func CreateGeneralMessageService(fromID, toID, msgType, sessionID int, msg string) (error, uint) { + //查看会话是否属于该用户 + session := dao.FindSessionByID(sessionID) + if session.UserID != fromID { + return errors.New("session not belong to user"), 0 + } + err, mid := CreateMessage(msgType, sessionID, fromID, toID, msg, 0) + return err, mid +} diff --git a/service/modelService.go b/service/modelService.go new file mode 100644 index 0000000..72af7bf --- /dev/null +++ b/service/modelService.go @@ -0,0 +1,49 @@ +package service + +import ( + "StuAcaWorksAI/dao" + "errors" +) + +// 创建模型信息 +func CreateModel(userID uint, modelType, url, parameter string) (error, uint) { + //查看用户是否有权限创建模型 + user := GetUserByIDWithCache(int(userID)) + if user.ID == 0 { + return errors.New("user not exist"), 0 + } + if user.Role != "admin" { + return errors.New("user not admin,no permission"), 0 + } + return dao.CreateModel(userID, modelType, url, parameter) +} + +// 根据id查找模型 +func FindModelByID(id, userID int) []dao.Model { + return dao.FindModelByID(id, userID) +} + +// 根据用户id查找模型 +func FindModelByUserID(userID int) []dao.Model { + return dao.FindModelByUserID(userID) +} + +// 根据id删除模型 +func DeleteModelByID(id, userID int) error { + //查看用户是否有权限删除模型 + user := GetUserByIDWithCache(userID) + model := dao.FindModelByID(id, userID) + if user.ID == 0 { + return errors.New("user not exist") + } + if len(model) == 0 && user.Role != "admin" { + return errors.New("model not exist or no permission") + } + return dao.DeleteModelByID(id) +} + +// 更新模型信息 +func UpdateModelByID(id int, userID uint, modelType, url, parameter string) error { + + return dao.UpdateModelByID(id, userID, modelType, url, parameter) +}