diff --git a/dao/db.go b/dao/db.go index bf432cc..13b1e58 100644 --- a/dao/db.go +++ b/dao/db.go @@ -9,7 +9,7 @@ import ( var DB *gorm.DB -func Init() { +func Init() error { dsn := proto.MYSQL_DSN db, err := gorm.Open(mysql.Open(dsn), &gorm.Config{}) @@ -19,18 +19,22 @@ func Init() { err = db.AutoMigrate(&User{}) if err != nil { fmt.Println("user table:", err) + return err } // 自动迁移,创建表,如果表已经存在,会自动更新表结构,不会删除表,只会创建不存在的表 err = db.AutoMigrate(&Video{}) if err != nil { fmt.Println("video table:", err) + return err } // 自动迁移,创建表,如果表已经存在,会自动更新表结构,不会删除表,只会创建不存在的表 err = db.AutoMigrate(&Device{}) if err != nil { fmt.Println("device table:", err) + return err } // 自动迁移,创建表,如果表已经存在,会自动更新表结构,不会删除表,只会创建不存在的表 err = db.AutoMigrate(&Logger{}) if err != nil { fmt.Println("logger table:", err) + return err } // 自动迁移,创建表,如果表已经存在,会自动更新表结构,不会删除表,只会创建不存在的表 err = db.AutoMigrate(&CID{}) if err != nil { @@ -44,8 +48,26 @@ func Init() { err = db.AutoMigrate(&Message{}) if err != nil { fmt.Println("message table:", err) - } // 自动迁移,创建表,如果表已经存在,会自动更新表结构,不会删除表,只会创建不存在的表 + } + + err = db.AutoMigrate(&Group{}) + if err != nil { + fmt.Println("usergroup table:", err) + } + + err = db.AutoMigrate(&GroupUser{}) + if err != nil { + fmt.Println("groupuser table:", err) + return err + } + + err = db.AutoMigrate(&Friend{}) + if err != nil { + fmt.Println("friend table:", err) + } + DB = db + return err } func Close() { diff --git a/main.go b/main.go index 60a2d38..81b5adf 100644 --- a/main.go +++ b/main.go @@ -18,8 +18,14 @@ func main() { r := gin.Default() gin.SetMode(gin.ReleaseMode) - dao.Init() - worker.InitRedis() + err := dao.Init() + if err != nil { + panic("failed to connect database:" + err.Error()) + } + err = worker.InitRedis() + if err != nil { + panic("failed to connect redis:" + err.Error()) + } r.Use(handler.CrosHandler()) r.Use(JWTAuthMiddleware()) // 使用 JWT 认证中间件 handler.SetUpVideoGroup(r) // Video diff --git a/worker/redis.go b/worker/redis.go index a05fc21..bc19b5f 100644 --- a/worker/redis.go +++ b/worker/redis.go @@ -13,7 +13,7 @@ import ( ) var redisClient *redis.Client // Redis 客户端, 用于连接 Redis 服务器 -func InitRedis() { +func InitRedis() error { ctx := context.Background() // 连接redis redisClient = redis.NewClient(&redis.Options{ @@ -27,6 +27,7 @@ func InitRedis() { if err != nil { fmt.Println("Error connecting to Redis: %v", err) } + return err } func CloseRedis() {