添加模型管理功能,会话及对话交流部分完成

This commit is contained in:
junleea 2025-03-21 14:44:40 +08:00
parent 9fb52655a9
commit 98239a05d6
10 changed files with 435 additions and 0 deletions

View File

@ -53,6 +53,24 @@ func Init() error {
fmt.Println("file auth table:", err) fmt.Println("file auth table:", err)
return err return err
} }
err = db.AutoMigrate(&Model{})
if err != nil {
fmt.Println("model table:", err)
return err
}
err = db.AutoMigrate(&Session{})
if err != nil {
fmt.Println("session table:", err)
return err
}
err = db.AutoMigrate(&Message{})
if err != nil {
fmt.Println("message table:", err)
return err
}
DB = db DB = db
return err return err
} }

99
dao/im.go Normal file
View File

@ -0,0 +1,99 @@
package dao
import (
"StuAcaWorksAI/proto"
"gorm.io/gorm"
)
type Session struct {
gorm.Model
UserID int `gorm:"column:user_id"` //只能由用户创建
Name string `gorm:"column:name"`
}
type Message struct {
gorm.Model
Type int `gorm:"column:type"` //1用户之间的消息2为与模型消息
SessionID int `gorm:"column:session_id"`
FromID int `gorm:"column:from_id"` //发送者,可以为用户或者模型
ToID int `gorm:"column:to_id"` //接收者,可以为用户或者模型,如果为模型则为模型id根据type判断
Msg string `gorm:"column:msg"` //消息内容
Status int `gorm:"column:status"` //0为未读1为已读
}
func CreateSession(userID int, name string) (error, uint) {
session := Session{UserID: userID, Name: name}
var res *gorm.DB
if proto.Config.SERVER_SQL_LOG {
res = DB.Debug().Create(&session)
} else {
res = DB.Create(&session)
}
return res.Error, session.ID
}
func FindSessionByID(id int) Session {
var session Session
if proto.Config.SERVER_SQL_LOG {
DB.Debug().Where("id = ?", id).First(&session)
} else {
DB.Where("id = ?", id).First(&session)
}
return session
}
// 获取用户最新的创建时间会话列表 50个
func FindSessionByUserID(userID int) []Session {
var sessions []Session
if proto.Config.SERVER_SQL_LOG {
DB.Debug().Where("user_id = ?", userID).Order("created_at DESC").Limit(50).Find(&sessions)
} else {
DB.Where("user_id = ?", userID).Order("created_at DESC").Limit(50).Find(&sessions)
}
return sessions
}
// 更新会话的名字
func UpdateSessionByID(id int, userId int, name string) error {
var res *gorm.DB
if proto.Config.SERVER_SQL_LOG {
res = DB.Debug().Model(&Session{}).Where("id = ? and user_id = ?", id, userId).Update("name", name)
} else {
res = DB.Model(&Session{}).Where("id = ? and user_id = ?", id, userId).Update("name", name)
}
return res.Error
}
// 删除会话
func DeleteSessionByID(id int, userId int) error {
var res *gorm.DB
if proto.Config.SERVER_SQL_LOG {
res = DB.Debug().Delete(&Session{}, id)
} else {
res = DB.Delete(&Session{}, id)
}
return res.Error
}
// 创建消息
func CreateMessage(sessionID, fromID, toID, msgType, status int, msg string) (error, uint) {
message := Message{SessionID: sessionID, FromID: fromID, ToID: toID, Type: msgType, Status: status, Msg: msg}
var res *gorm.DB
if proto.Config.SERVER_SQL_LOG {
res = DB.Debug().Create(&message)
} else {
res = DB.Create(&message)
}
return res.Error, message.ID
}
// 根据会话id获取消息
func FindMessageBySessionID(sessionID int) []Message {
var messages []Message
if proto.Config.SERVER_SQL_LOG {
DB.Debug().Where("session_id = ?", sessionID).Order("created_at DESC").Find(&messages)
} else {
DB.Where("session_id = ?", sessionID).Order("created_at DESC").Find(&messages)
}
return messages
}

71
dao/model.go Normal file
View File

@ -0,0 +1,71 @@
package dao
import (
"StuAcaWorksAI/proto"
"gorm.io/gorm"
)
type Model struct {
gorm.Model
Type string `gorm:"column:type"` //模型类型
UserID uint `gorm:"column:user_id"` //用户id
Url string `gorm:"column:url"` //模型地址
Parameter string `gorm:"column:parameter"` //模型参数,存储json
}
// 创建模型
func CreateModel(userID uint, modelType, url, parameter string) (error, uint) {
model := Model{UserID: userID, Type: modelType, Url: url, Parameter: parameter}
var res *gorm.DB
if proto.Config.SERVER_SQL_LOG {
res = DB.Debug().Create(&model)
} else {
res = DB.Create(&model)
}
return res.Error, model.ID
}
// 根据id查找模型
func FindModelByID(id, userID int) []Model {
var model []Model
if proto.Config.SERVER_SQL_LOG {
DB.Debug().Where("id = ? and user_id = ?", id, userID).Find(&model)
} else {
DB.Where("id = ? and user_id = ?", id, userID).Find(&model)
}
return model
}
// 根据用户id查找模型
func FindModelByUserID(userID int) []Model {
var models []Model
if proto.Config.SERVER_SQL_LOG {
DB.Debug().Where("user_id = ?", userID).Find(&models)
} else {
DB.Where("user_id = ?", userID).Find(&models)
}
return models
}
// 根据id删除模型
func DeleteModelByID(id int) error {
var res *gorm.DB
if proto.Config.SERVER_SQL_LOG {
res = DB.Debug().Delete(&Model{}, id)
} else {
res = DB.Delete(&Model{}, id)
}
return res.Error
}
// 根据id更新模型
func UpdateModelByID(id int, userID uint, modelType, url, parameter string) error {
model := Model{UserID: userID, Type: modelType, Url: url, Parameter: parameter}
var res *gorm.DB
if proto.Config.SERVER_SQL_LOG {
res = DB.Debug().Model(&Model{}).Where("id = ?", id).Updates(&model)
} else {
res = DB.Model(&Model{}).Where("id = ?", id).Updates(&model)
}
return res.Error
}

1
go.mod
View File

@ -7,6 +7,7 @@ require (
github.com/go-redis/redis/v8 v8.11.5 github.com/go-redis/redis/v8 v8.11.5
github.com/golang-jwt/jwt v3.2.2+incompatible github.com/golang-jwt/jwt v3.2.2+incompatible
github.com/google/uuid v1.6.0 github.com/google/uuid v1.6.0
github.com/gorilla/websocket v1.5.3
github.com/robfig/cron/v3 v3.0.1 github.com/robfig/cron/v3 v3.0.1
gorm.io/driver/mysql v1.5.7 gorm.io/driver/mysql v1.5.7
gorm.io/driver/postgres v1.5.11 gorm.io/driver/postgres v1.5.11

2
go.sum
View File

@ -43,6 +43,8 @@ github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk= github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk=

40
handler/im.go Normal file
View File

@ -0,0 +1,40 @@
package handler
import (
"StuAcaWorksAI/proto"
"StuAcaWorksAI/service"
"github.com/gin-gonic/gin"
"net/http"
)
type SMessage struct {
Type int `json:"type" form:"type"`
Msg string `json:"msg" form:"msg"`
ToID int `json:"to_id" form:"to_id"`
SessionID int `json:"session_id" form:"session_id"`
}
func SetUpIMGroup(router *gin.Engine) {
imGroup := router.Group("/im")
imGroup.POST("/send_message", SendMessage)
}
func SendMessage(c *gin.Context) {
var req SMessage
id, _ := c.Get("id")
userID := int(id.(float64))
if err := c.ShouldBind(&req); err == nil {
var err2 error
var mid uint
err2, mid = service.CreateGeneralMessageService(userID, req.ToID, req.Type, req.SessionID, req.Msg)
if err2 == nil {
c.JSON(http.StatusOK, gin.H{"code": proto.SuccessCode, "message": "success", "data": mid})
} else {
c.JSON(http.StatusOK, gin.H{"error": err2.Error(), "code": proto.MsgSendFailed, "message": "failed"})
}
} else {
c.JSON(http.StatusOK, gin.H{"error": err.Error(), "code": proto.ParameterError, "message": "failed"})
}
}
//ws长连接

98
handler/model.go Normal file
View File

@ -0,0 +1,98 @@
package handler
import (
"StuAcaWorksAI/proto"
"StuAcaWorksAI/service"
"github.com/gin-gonic/gin"
"net/http"
)
func SetUpModelGroup(router *gin.Engine) {
modelGroup := router.Group("/model")
modelGroup.POST("/create", CreateModel)
modelGroup.POST("/find", FindModel)
modelGroup.DELETE("/delete", DeleteModel)
modelGroup.PUT("/update", UpdateModel)
}
type Model struct {
ID int `json:"id" form:"id"`
UserID int `json:"user_id" form:"user_id"`
Type string `json:"type" form:"type"`
Url string `json:"url" form:"url"`
Parameter string `json:"parameter" form:"parameter"`
}
func CreateModel(c *gin.Context) {
id, _ := c.Get("id")
userID := int(id.(float64))
var req Model
if err := c.ShouldBind(&req); err == nil {
// 创建模型信息
err2, mid := service.CreateModel(uint(userID), req.Type, req.Url, req.Parameter)
if err2 == nil {
c.JSON(http.StatusOK, gin.H{"code": proto.SuccessCode, "message": "success", "data": mid})
} else {
c.JSON(http.StatusOK, gin.H{"error": err2.Error(), "code": proto.ModelCreateFailed, "message": "failed"})
}
} else {
c.JSON(http.StatusOK, gin.H{"error": err.Error(), "code": proto.ParameterError, "message": "failed"})
}
}
func FindModel(c *gin.Context) {
id, _ := c.Get("id")
userID := int(id.(float64))
var req Model
if err := c.ShouldBind(&req); err == nil {
if req.Type == "ID" {
// 根据id查找模型
model := service.FindModelByID(req.ID, userID)
c.JSON(http.StatusOK, gin.H{"code": proto.SuccessCode, "message": "success", "data": model})
} else if req.Type == "UserID" {
// 根据用户id查找模型
models := service.FindModelByUserID(userID)
c.JSON(http.StatusOK, gin.H{"code": proto.SuccessCode, "message": "success", "data": models})
} else {
c.JSON(http.StatusOK, gin.H{"error": "type error", "code": proto.ModelSearchFailed, "message": "failed"})
}
} else {
c.JSON(http.StatusOK, gin.H{"error": err.Error(), "code": proto.ParameterError, "message": "failed"})
}
}
// 修改
func UpdateModel(c *gin.Context) {
id, _ := c.Get("id")
userID := int(id.(float64))
var req Model
if err := c.ShouldBind(&req); err == nil {
// 更新模型信息
err2 := service.UpdateModelByID(req.ID, uint(userID), req.Type, req.Url, req.Parameter)
if err2 == nil {
c.JSON(http.StatusOK, gin.H{"code": proto.SuccessCode, "message": "success"})
} else {
c.JSON(http.StatusOK, gin.H{"error": err2.Error(), "code": proto.ModelUpdateFailed, "message": "failed"})
}
} else {
c.JSON(http.StatusOK, gin.H{"error": err.Error(), "code": proto.ParameterError, "message": "failed"})
}
}
// 删除
func DeleteModel(c *gin.Context) {
id, _ := c.Get("id")
userID := int(id.(float64))
var req Model
if err := c.ShouldBind(&req); err == nil {
// 删除模型
err2 := service.DeleteModelByID(req.ID, userID)
if err2 == nil {
c.JSON(http.StatusOK, gin.H{"code": proto.SuccessCode, "message": "success"})
} else {
c.JSON(http.StatusOK, gin.H{"error": err2.Error(), "code": proto.ModelDeleteFailed, "message": "failed"})
}
} else {
c.JSON(http.StatusOK, gin.H{"error": err.Error(), "code": proto.ParameterError, "message": "failed"})
}
}

View File

@ -73,4 +73,9 @@ const (
ShellUpdateFailed = 101 // 更新shell失败 ShellUpdateFailed = 101 // 更新shell失败
ShellDeleteFailed = 102 // 删除shell失败 ShellDeleteFailed = 102 // 删除shell失败
ShellSearchFailed = 103 // 获取shell失败 ShellSearchFailed = 103 // 获取shell失败
ModelCreateFailed = 110 // 创建模型失败
ModelUpdateFailed = 111 // 更新模型失败
ModelDeleteFailed = 112 // 删除模型失败
ModelSearchFailed = 113 // 获取模型失败
) )

52
service/imService.go Normal file
View File

@ -0,0 +1,52 @@
package service
import (
"StuAcaWorksAI/dao"
"errors"
)
// 创建会话服务
func CreateSession(userID int, name string) (error, uint) {
err, id := dao.CreateSession(userID, name)
return err, id
}
// 查找用户的会话列表
func FindSessionByUserID(userID int) []dao.Session {
sessions := dao.FindSessionByUserID(userID)
return sessions
}
// 更新会话名字
func UpdateSessionByID(id int, userId int, name string) error {
err := dao.UpdateSessionByID(id, userId, name)
return err
}
// 创建消息服务
func CreateMessage(msgType int, sessionID int, fromID int, toID int, msg string, status int) (error, uint) {
//sessionID, fromID, toID, msgType, status int, msg string
err, id := dao.CreateMessage(sessionID, fromID, toID, msgType, status, msg)
return err, id
}
// 通过会话id查询消息
func FindMessageBySessionID(sessionID, userID int) (error, []dao.Message) {
//查看会话是否属于该用户
session := dao.FindSessionByID(sessionID)
if session.UserID != userID {
return errors.New("session not belong to user"), nil
}
msgs := dao.FindMessageBySessionID(sessionID)
return nil, msgs
}
func CreateGeneralMessageService(fromID, toID, msgType, sessionID int, msg string) (error, uint) {
//查看会话是否属于该用户
session := dao.FindSessionByID(sessionID)
if session.UserID != fromID {
return errors.New("session not belong to user"), 0
}
err, mid := CreateMessage(msgType, sessionID, fromID, toID, msg, 0)
return err, mid
}

49
service/modelService.go Normal file
View File

@ -0,0 +1,49 @@
package service
import (
"StuAcaWorksAI/dao"
"errors"
)
// 创建模型信息
func CreateModel(userID uint, modelType, url, parameter string) (error, uint) {
//查看用户是否有权限创建模型
user := GetUserByIDWithCache(int(userID))
if user.ID == 0 {
return errors.New("user not exist"), 0
}
if user.Role != "admin" {
return errors.New("user not admin,no permission"), 0
}
return dao.CreateModel(userID, modelType, url, parameter)
}
// 根据id查找模型
func FindModelByID(id, userID int) []dao.Model {
return dao.FindModelByID(id, userID)
}
// 根据用户id查找模型
func FindModelByUserID(userID int) []dao.Model {
return dao.FindModelByUserID(userID)
}
// 根据id删除模型
func DeleteModelByID(id, userID int) error {
//查看用户是否有权限删除模型
user := GetUserByIDWithCache(userID)
model := dao.FindModelByID(id, userID)
if user.ID == 0 {
return errors.New("user not exist")
}
if len(model) == 0 && user.Role != "admin" {
return errors.New("model not exist or no permission")
}
return dao.DeleteModelByID(id)
}
// 更新模型信息
func UpdateModelByID(id int, userID uint, modelType, url, parameter string) error {
return dao.UpdateModelByID(id, userID, modelType, url, parameter)
}