diff --git a/dao/db.go b/dao/db.go index 5bf9c7f..b25d46a 100644 --- a/dao/db.go +++ b/dao/db.go @@ -9,28 +9,33 @@ import ( var DB *gorm.DB -func Init() { +func Init() error { dsn := proto.MYSQL_DSN db, err := gorm.Open(mysql.Open(dsn), &gorm.Config{}) if err != nil { panic("failed to connect database") + return err } err = db.AutoMigrate(&User{}) if err != nil { fmt.Println("user table:", err) + return err } // 自动迁移,创建表,如果表已经存在,会自动更新表结构,不会删除表,只会创建不存在的表 err = db.AutoMigrate(&Video{}) if err != nil { fmt.Println("video table:", err) + return err } // 自动迁移,创建表,如果表已经存在,会自动更新表结构,不会删除表,只会创建不存在的表 err = db.AutoMigrate(&Device{}) if err != nil { fmt.Println("device table:", err) + return err } // 自动迁移,创建表,如果表已经存在,会自动更新表结构,不会删除表,只会创建不存在的表 err = db.AutoMigrate(&Logger{}) if err != nil { fmt.Println("logger table:", err) + return err } // 自动迁移,创建表,如果表已经存在,会自动更新表结构,不会删除表,只会创建不存在的表 err = db.AutoMigrate(&CID{}) if err != nil { @@ -40,7 +45,30 @@ func Init() { if err != nil { fmt.Println("cidrunlog table:", err) } // 自动迁移,创建表,如果表已经存在,会自动更新表结构,不会删除表,只会创建不存在的表 + + err = db.AutoMigrate(&Message{}) + if err != nil { + fmt.Println("message table:", err) + } + + err = db.AutoMigrate(&Group{}) + if err != nil { + fmt.Println("usergroup table:", err) + } + + err = db.AutoMigrate(&GroupUser{}) + if err != nil { + fmt.Println("groupuser table:", err) + return err + } + + err = db.AutoMigrate(&Friend{}) + if err != nil { + fmt.Println("friend table:", err) + } + DB = db + return err } func Close() { diff --git a/dao/im.go b/dao/im.go new file mode 100644 index 0000000..1e83c43 --- /dev/null +++ b/dao/im.go @@ -0,0 +1,178 @@ +package dao + +import ( + "gorm.io/gorm" + "videoplayer/proto" +) + +type Message struct { + gorm.Model + FromUserID int `gorm:"column:from_user_id"` + ToUserID int `gorm:"column:to_user_id"` + GroupID int `gorm:"column:group_id"` + Msg string `gorm:"column:msg"` + Status int `gorm:"column:status"` //单聊时才有0,已读,2未读 + Type int `gorm:"column:type"` //1为单聊,2为群聊,3为系统消息,4为好友请求,5为加入群聊请求 +} + +type Group struct { + gorm.Model + GroupName string `gorm:"column:group_name"` + GroupInfo string `gorm:"column:group_info"` + GroupType int `gorm:"column:group_type"` + GroupIcon string `gorm:"column:group_icon"` +} + +type GroupUser struct { + gorm.Model + GroupID int `gorm:"column:group_id"` + UserID int `gorm:"column:user_id"` + IsMaster bool `gorm:"column:is_master"` +} + +type Friend struct { + gorm.Model + UserID int `gorm:"column:user_id"` + FriendID int `gorm:"column:friend_id"` +} + +// 创建单聊消息 +func CreateSimpleMessage(from_user_id, to_user_id int, message string) (error, uint) { + msg := Message{FromUserID: from_user_id, ToUserID: to_user_id, Msg: message, Type: proto.MSG_TYPE_SIMPLE, Status: proto.MSG_STATUS_UNREAD} + res := DB.Debug().Create(&msg) + return res.Error, msg.ID +} + +func CreateGeneralMessage(from_user_id, to_user_id, msg_type, status, group_id int, message string) (error error, id uint) { + msg := Message{FromUserID: from_user_id, ToUserID: to_user_id, Msg: message, Type: msg_type, Status: status, GroupID: group_id} + res := DB.Debug().Create(&msg) + return res.Error, msg.ID +} + +// 每20个消息一组,请求index*20 +func GetMsgUserByIndex(from_user_id, to_user_id, msg_type, index, status int) ([]Message, error) { + var msgs []Message + var res *gorm.DB + if msg_type == 4 { + res = DB.Debug().Where("from_user_id = ? and to_user_id = ? and type = ? and status = ? ", from_user_id, to_user_id, msg_type, status).Find(&msgs).Order("created_at DESC").Limit(20 * index) + } else if msg_type == 1 { + //单聊,只有两个人的消息 + res = DB.Debug().Where("((from_user_id = ? and to_user_id = ?) or (from_user_id = ? and to_user_id = ?)) and type = ? ", from_user_id, to_user_id, to_user_id, from_user_id, msg_type).Find(&msgs).Order("created_at DESC").Limit(20 * index) + } else { + res = DB.Debug().Where("from_user_id = ? and to_user_id = ? and type = ? ", from_user_id, to_user_id, msg_type).Find(&msgs).Order("created_at DESC").Limit(20 * index) + } + return msgs, res.Error +} + +// 获取邀请消息 +func GetFriendGroupReq(user_id int) ([]Message, error) { + var msgs []Message + res := DB.Debug().Where("to_user_id = ? and type = ?", user_id, proto.MSG_TYPE_FRIEND).Find(&msgs) + return msgs, res.Error +} + +// 修改信息 +func UpdateMessage() { + +} + +// 添加好友 +func AddFriend(user_id, friend_id int) error { + friend := Friend{UserID: user_id, FriendID: friend_id} + friend2 := Friend{UserID: friend_id, FriendID: user_id} + + //开启事务 + tx := DB.Begin() + if err := tx.Create(&friend); err.Error != nil { + tx.Rollback() + return err.Error + } + if err := tx.Create(&friend2); err.Error != nil { + tx.Rollback() + return err.Error + } + //提交事务 + if err := tx.Commit().Error; err != nil { + return err + } + return error(nil) +} + +// 删除好友,删除后双方都不是好友 +func DeleteFriend(user_id, friend_id int) error { + //开启事务 + tx := DB.Begin() + if err := tx.Delete(&Friend{}, "user_id = ? and friend_id = ?", user_id, friend_id); err.Error != nil { + tx.Rollback() + return err.Error + } + if err := tx.Delete(&Friend{}, "user_id = ? and friend_id = ?", friend_id, user_id); err.Error != nil { + tx.Rollback() + return err.Error + } + //提交事务 + if err := tx.Commit().Error; err != nil { + return err + } + return error(nil) +} + +// 通过id查找消息 +func FindMessageByID(id uint) []Message { + var msgs []Message + DB.Debug().Where("id = ?", id).Find(&msgs) + return msgs +} + +// 更新消息状态 +func UpdateMessageStatus(id uint, status int) error { + res := DB.Debug().Model(&Message{}).Where("id = ?", id).Update("status", status) + return res.Error +} + +// 创建群聊,需要事务 +func CreateGroup(groupName, groupInfo, groupType, groupIcon string, user_id int) (error, uint) { + group := Group{GroupName: groupName, GroupInfo: groupInfo, GroupType: proto.MSG_TYPE_GROUP, GroupIcon: groupIcon} + //开启事务 + tx := DB.Begin() + if err := tx.Create(&group); err.Error != nil { + tx.Rollback() + return err.Error, 0 + } + groupUser := GroupUser{GroupID: int(group.ID), UserID: user_id, IsMaster: true} + if err := tx.Create(&groupUser); err.Error != nil { + tx.Rollback() + return err.Error, 0 + } + //提交事务 + if err := tx.Commit().Error; err != nil { + return err, 0 + } + return nil, group.ID +} + +// 查找用户是否在群聊 +func FindGroupUser(user_id, group_id int) []GroupUser { + var groupUsers []GroupUser + DB.Debug().Where("user_id = ? and group_id = ?", user_id, group_id).Find(&groupUsers) + return groupUsers +} + +// 加入群聊 +func JoinGroup(group_id, user_id int) error { + groupUser := GroupUser{GroupID: group_id, UserID: user_id} + res := DB.Debug().Create(&groupUser) + return res.Error +} + +// 退出群聊 +func QuitGroup(group_id, user_id int) error { + res := DB.Debug().Delete(&GroupUser{}, "group_id = ? and user_id = ?", group_id, user_id) + return res.Error +} + +func FindFriend(from_user_id, to_user_id int) []Friend { + var friends []Friend + DB.Debug().Where("user_id = ? and friend_id = ?", from_user_id, to_user_id).Find(&friends) + return friends +} diff --git a/handler/im.go b/handler/im.go index 4607094..f4e2a94 100644 --- a/handler/im.go +++ b/handler/im.go @@ -13,9 +13,31 @@ import ( "sync" "time" "videoplayer/proto" + "videoplayer/service" "videoplayer/worker" ) +type SMessage struct { + To_user_id int `json:"to_user_id" form:"to_user_id"` + Type int `json:"type" form:"type"` + Msg string `json:"msg" form:"msg"` +} + +type Message struct { + ID int `json:"id"` + To_user_id int `json:"to_user_id"` + From_user_id int `json:"from_user_id"` + Index int `json:"index"` + Type int `json:"type"` +} + +type CGroup struct { + Group_name string `json:"group_name" form:"group_name"` + Group_info string `json:"group_info" form:"group_info"` + Group_type string `json:"group_type" form:"group_type"` + Group_icon string `json:"group_icon" form:"group_icon"` +} + var ( upgrader = websocket.Upgrader{ ReadBufferSize: 1024, @@ -37,7 +59,11 @@ func SetUpIMGroup(router *gin.Engine) { imGroup := router.Group("/im") imGroup.POST("/get_imKey", GetImKey) imGroup.GET("/ws", SRMessage) - + imGroup.POST("/send_message", SendMessage) + imGroup.POST("/get_message", GetMessage) + //接受邀请,确认好友关系 + imGroup.POST("/accept_invite", AcceptInvite) + imGroup.POST("/create_group", CreateGroup) } func generateRandomHexString(length int) (string, error) { bytes := make([]byte, length/2) // 16字节的字符串需要32个十六进制字符,即16个字节 @@ -46,6 +72,72 @@ func generateRandomHexString(length int) (string, error) { } return hex.EncodeToString(bytes), nil } +func GetMessage(c *gin.Context) { + var req Message + user_id, _ := c.Get("id") + id := int(user_id.(float64)) + if err := c.ShouldBind(&req); err == nil { + fmt.Println(req) + msgs, err2 := service.GetMsgUserByIndexService(id, req.To_user_id, req.Index, req.Type, req.From_user_id) + if err2 == nil { + c.JSON(http.StatusOK, gin.H{"code": proto.SuccessCode, "data": msgs, "message": "success"}) + } else { + c.JSON(http.StatusOK, gin.H{"error": err2.Error(), "code": proto.OperationFailed, "message": "failed"}) + } + } else { + c.JSON(http.StatusOK, gin.H{"error": err.Error(), "code": proto.ParameterError, "message": "failed"}) + } +} + +func CreateGroup(c *gin.Context) { + var req CGroup + id, _ := c.Get("id") + user_id := int(id.(float64)) + if err := c.ShouldBind(&req); err == nil { + err2, id := service.CreateGroup(req.Group_name, req.Group_info, req.Group_type, req.Group_icon, user_id) + if err2 == nil { + c.JSON(http.StatusOK, gin.H{"code": proto.SuccessCode, "data": id, "message": "success"}) + } else { + c.JSON(http.StatusOK, gin.H{"error": err2.Error(), "code": proto.OperationFailed, "message": "failed"}) + } + } else { + c.JSON(http.StatusOK, gin.H{"error": err.Error(), "code": proto.ParameterError, "message": "failed"}) + } +} + +func AcceptInvite(c *gin.Context) { + var req Message + user_id, _ := c.Get("id") + cid := int(user_id.(float64)) + if err := c.ShouldBind(&req); err == nil { + err2 := service.AddFriendService(req.ID, cid, req.To_user_id) + if err2 == nil { + c.JSON(http.StatusOK, gin.H{"code": proto.SuccessCode, "message": "success"}) + } else { + c.JSON(http.StatusOK, gin.H{"error": err2.Error(), "code": proto.OperationFailed, "message": "failed"}) + } + } else { + c.JSON(http.StatusOK, gin.H{"error": err.Error(), "code": proto.ParameterError, "message": "failed"}) + } +} + +func SendMessage(c *gin.Context) { + var req SMessage + user_id, _ := c.Get("id") + id := int(user_id.(float64)) + if err := c.ShouldBind(&req); err == nil { + var err2 error + var mid uint + err2, mid = service.CreateGeneralMessageService(id, req.To_user_id, req.Type, req.Msg) + if err2 == nil { + c.JSON(http.StatusOK, gin.H{"code": proto.SuccessCode, "message": "success", "data": mid}) + } else { + c.JSON(http.StatusOK, gin.H{"error": err2.Error(), "code": proto.MsgSendFailed, "message": "failed"}) + } + } else { + c.JSON(http.StatusOK, gin.H{"error": err.Error(), "code": proto.ParameterError, "message": "failed"}) + } +} func GetImKey(c *gin.Context) { id, _ := c.Get("id") diff --git a/main.go b/main.go index f0afcfc..6034cf4 100644 --- a/main.go +++ b/main.go @@ -18,8 +18,14 @@ func main() { r := gin.Default() gin.SetMode(gin.ReleaseMode) - dao.Init() - worker.InitRedis() + err := dao.Init() + if err != nil { + panic("failed to connect database:" + err.Error()) + } + err = worker.InitRedis() + if err != nil { + panic("failed to connect redis:" + err.Error()) + } r.Use(handler.CrosHandler()) r.Use(JWTAuthMiddleware()) // 使用 JWT 认证中间件 handler.SetUpVideoGroup(r) // Video diff --git a/proto/conf.go b/proto/conf.go index 5130b5f..f3c0f8f 100644 --- a/proto/conf.go +++ b/proto/conf.go @@ -11,7 +11,7 @@ const ( MYSQL_DSN = MYSQL_USER + ":" + MYSQL_PASSWORD + "@tcp(" + MYSQL_HOST + ":" + MYSQL_PORT + ")/" + MYSQL_DB + "?charset=utf8mb4&parseTime=True&loc=Local" REDIS_ADDR = "127.0.0.1:6379" - REDIS_PASSWORD = "lj502138" + REDIS_PASSWORD = "" REIDS_DB = 2 TOKEN_SECRET = "mfjurnc_32ndj9dfhj" @@ -20,6 +20,19 @@ const ( CID_BASE_DIR = "/home/lijun/cid/" ) +const ( + // 以下是消息类型 + MSG_TYPE_SIMPLE = 1 // 单聊 + MSG_TYPE_GROUP = 2 // 群聊 + MSG_TYPE_SYSTEM = 3 // 系统消息 + MSG_TYPE_FRIEND = 4 // 好友请求 + MSG_TYPE_GROUP_ADD = 5 // 加入群聊请求 + + // 以下是消息状态 + MSG_STATUS_READ = 1 // 已读 + MSG_STATUS_UNREAD = 0 // 未读 +) + type User struct { gorm.Model Name string `gorm:"column:name"` diff --git a/proto/status.go b/proto/status.go index 42268d1..a0f71cb 100644 --- a/proto/status.go +++ b/proto/status.go @@ -46,4 +46,7 @@ const ( NoRedisPermissions = 51 NoRunPermissions = 52 NoDevicePermissions = 53 + + //消息错误码 + MsgSendFailed = 61 // 消息发送失败 ) diff --git a/service/imService.go b/service/imService.go new file mode 100644 index 0000000..7c038d3 --- /dev/null +++ b/service/imService.go @@ -0,0 +1,102 @@ +package service + +import ( + "errors" + "videoplayer/dao" +) + +func CreateGeneralMessageService(from_id, to_id, msg_type int, content string) (error, uint) { + // 业务逻辑 + var err error + var id uint + switch msg_type { + case 1: + //判断是否是好友 + friend := dao.FindFriend(from_id, to_id) + if len(friend) == 0 { + return errors.New("not a friend"), 0 + } + err, id = dao.CreateSimpleMessage(from_id, to_id, content) + case 2: + err, id = dao.CreateGeneralMessage(from_id, to_id, msg_type, 0, 0, content) + case 3: + //user := dao.FindUserByID(to_id) + // 系统消息,需要管理员权限 + err, id = dao.CreateGeneralMessage(from_id, to_id, msg_type, 0, 0, content) + case 4: + res, _ := dao.GetMsgUserByIndex(from_id, to_id, 4, 1, 0) + if len(res) > 0 { + // 已经有会话记录 + return errors.New("already have a conversation"), 0 + } + err, id = dao.CreateGeneralMessage(from_id, to_id, msg_type, 0, 0, content) + case 5: + res, _ := dao.GetMsgUserByIndex(from_id, to_id, 5, 1, 0) + if len(res) > 0 { + // 已经有会话记录 + return errors.New("already have a conversation"), 0 + } + //邀请加入群聊请求 + err, id = dao.CreateGeneralMessage(from_id, to_id, msg_type, 0, from_id, content) + default: + // 未知消息类型 + err = errors.New("unknown message type") + } + return err, id +} + +func GetMsgUserByIndexService(from_id, to_id, index, msq_type, from_user_id int) ([]dao.Message, error) { + // 业务逻辑 + if index <= 0 || index > 100 { + return nil, errors.New("index out of range") + } + if msq_type == 4 { + from_id = from_user_id + } + msgs, err := dao.GetMsgUserByIndex(from_id, to_id, msq_type, index, 0) + return msgs, err +} + +// AddFriendService 通过消息id添加好友,和加入群聊 +// id 消息id +// from_user_id 发送消息的用户id +// to_user_id 接收消息的用户id,及接受用户id +func AddFriendService(id, from_user_id, to_user_id int) error { + // 业务逻辑 + res := dao.FindMessageByID(uint(id)) + if res[0].FromUserID == to_user_id && res[0].ToUserID == from_user_id { + friend := dao.FindFriend(from_user_id, to_user_id) + if len(friend) > 0 { + return errors.New("already a friend") + } + dao.UpdateMessageStatus(res[0].ID, 1) + return dao.AddFriend(from_user_id, to_user_id) + } else if res[0].ToUserID == from_user_id && res[0].GroupID == to_user_id { + //加入群聊 + //查看是否已经加入 + groupUser := dao.FindGroupUser(from_user_id, to_user_id) + if len(groupUser) > 0 { + return errors.New("already in the group") + } + err := dao.JoinGroup(groupUser[0].GroupID, to_user_id) + if err != nil { + return err + } + err = dao.UpdateMessageStatus(res[0].ID, 1) + return err + } else { + return errors.New("no such message,cannot add friend") + } +} + +// CreateGroup 创建群聊 +// groupName 群聊名称 +// groupInfo 群聊信息 +// groupType 群聊类型 +// groupIcon 群聊图标 +// user_id 创建群的用户id +func CreateGroup(groupName, groupInfo, groupType, groupIcon string, user_id int) (error, uint) { + // 业务逻辑 + err, id := dao.CreateGroup(groupName, groupInfo, groupType, groupIcon, user_id) + return err, id +} diff --git a/worker/redis.go b/worker/redis.go index a05fc21..bc19b5f 100644 --- a/worker/redis.go +++ b/worker/redis.go @@ -13,7 +13,7 @@ import ( ) var redisClient *redis.Client // Redis 客户端, 用于连接 Redis 服务器 -func InitRedis() { +func InitRedis() error { ctx := context.Background() // 连接redis redisClient = redis.NewClient(&redis.Options{ @@ -27,6 +27,7 @@ func InitRedis() { if err != nil { fmt.Println("Error connecting to Redis: %v", err) } + return err } func CloseRedis() {