diff --git a/main.go b/main.go index 7e70fbc..5ae6252 100644 --- a/main.go +++ b/main.go @@ -107,41 +107,32 @@ func JWTAuthMiddleware() gin.HandlerFunc { } // 从请求头中获取 JWT 令牌 tokenString := c.Request.Header.Get("token") - //请求方式为get时,从url中获取token if tokenString == "" { tokenString = c.Query("token") } - //如果请求为login或register,则不需要验证token - for k, _ := range proto.Url_map { - if strings.Contains(c.Request.URL.Path, k) { - log.Println("need not check token:", c.Request.URL.Path) - c.Next() - return - } + //for k, _ := range proto.Url_map { + // if strings.Contains(c.Request.URL.Path, k) { + // log.Println("need not check token:", c.Request.URL.Path) + // c.Next() + // return + // } + //} + if proto.Url_map[c.Request.URL.Path] == true { //查看是否在不需要token的url中 + c.Next() + return } if tokenString == "" { - //c.AbortWithStatus(200) - c.JSON(200, gin.H{ - "message": "Unauthorized", - "error": "token is empty", - "code": proto.TokenIsNull, - }) + c.AbortWithStatusJSON(http.StatusOK, gin.H{"message": "unauthorized", "error": "token is empty", "code": proto.TokenIsNull}) return } if proto.Config.TOKEN_USE_REDIS { redisToken := worker.GetRedis(tokenString) if redisToken == "" { - c.AbortWithStatus(200) - c.JSON(200, gin.H{ - "message": "NOT_LOGIN", - "error": "server token is empty", - "code": proto.TokenIsNull, - }) + c.AbortWithStatusJSON(http.StatusOK, gin.H{"message": "NOT_LOGIN", "error": "server token is empty", "code": proto.TokenIsNull}) return } } - //查看token是否在超级token中 if worker.IsContainSet("super_permission_tokens", tokenString) { sId := c.Request.Header.Get("super_id") @@ -149,12 +140,7 @@ func JWTAuthMiddleware() gin.HandlerFunc { sId = c.Query("super_id") } if sId == "" { - c.AbortWithStatus(200) - c.JSON(200, gin.H{ - "message": "NOT_LOGIN", - "error": "super_id is empty", - "code": proto.TokenIsNull, - }) + c.AbortWithStatusJSON(http.StatusOK, gin.H{"message": "unauthorized", "error": "super_id is empty", "code": proto.TokenIsNull}) return } id, _ := strconv.Atoi(sId) @@ -167,19 +153,33 @@ func JWTAuthMiddleware() gin.HandlerFunc { } // 使用加密secret 解析 JWT 令牌 - token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) { + //token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) { + // return proto.SigningKey, nil + //}) + + claims := &proto.CustomClaims{} + token, err := jwt.ParseWithClaims(tokenString, claims, func(t *jwt.Token) (interface{}, error) { + // 验证签名算法 + if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok { + return nil, jwt.ErrSignatureInvalid + } return proto.SigningKey, nil }) - - // 验证令牌 - if err != nil || !token.Valid { - c.AbortWithStatus(200) - c.JSON(200, gin.H{ - "message": "NOT_LOGIN", - "error": "Invalid token", - "code": proto.TokenExpired, - }) - return + // 错误处理 + if err != nil { + if ve, ok := err.(*jwt.ValidationError); ok { + switch { + case ve.Errors&jwt.ValidationErrorMalformed != 0: + c.AbortWithStatusJSON(http.StatusOK, gin.H{"error": "Malformed token", "code": proto.TokenInvalid}) + case ve.Errors&jwt.ValidationErrorExpired != 0: + c.AbortWithStatusJSON(http.StatusOK, gin.H{"error": "Token expired", "code": proto.TokenExpired}) + case ve.Errors&jwt.ValidationErrorNotValidYet != 0: + c.AbortWithStatusJSON(http.StatusOK, gin.H{"error": "Token not active yet", "code": proto.TokenInvalid}) + default: + c.AbortWithStatusJSON(http.StatusOK, gin.H{"error": "Invalid token", "code": proto.TokenInvalid}) + } + return + } } // 将用户信息添加到上下文中 @@ -190,18 +190,11 @@ func JWTAuthMiddleware() gin.HandlerFunc { c.Set("user_id", userID) if UserFuncIntercept(int(token.Claims.(jwt.MapClaims)["id"].(float64)), c.Request.URL.Path) { - c.AbortWithStatus(200) - c.JSON(http.StatusOK, gin.H{ - "message": "no function permission", - "error": "no permission", - "code": proto.NoPermission, - }) + c.AbortWithStatusJSON(http.StatusOK, gin.H{"message": "unauthorized", "error": "no function permission", "code": proto.NoPermission}) return } - // 继续处理请求 c.Next() - //log.Println("JWT token is valid, user ID:", token.Claims.(jwt.MapClaims)["id"], " path:", c.Request.URL.Path) } } @@ -365,7 +358,7 @@ func RunGeneralCron() { // 用户功能拦截,返回true表示拦截,false表示不拦截 func UserFuncIntercept(id int, url string) bool { //先查看是否有权限 - user := dao.FindUserByUserID(id) + user := service.GetUserByIDWithCache(id) //如果用户有权限,则不拦截 for k, v := range proto.Per_menu_map { if strings.Contains(url, k) { diff --git a/proto/conf.go b/proto/conf.go index d70aec3..58a43b8 100644 --- a/proto/conf.go +++ b/proto/conf.go @@ -3,6 +3,7 @@ package proto import ( "encoding/json" "fmt" + "github.com/golang-jwt/jwt/v5" "gorm.io/gorm" "log" "os" @@ -65,35 +66,35 @@ type User struct { } type ConfigStruct struct { - DB int `json:"db"` // 0: mysql, 1: pg - MYSQL_DSN string `json:"mysql_dsn"` - PG_DSN string `json:"pg_dsn"` - SlowQueryThreshold int `json:"slow_query_threshold"` // 慢查询阈值,单位ms - REDIS_ADDR string `json:"redis_addr"` - TOKEN_USE_REDIS bool `json:"token_use_redis"` - REDIS_User_PW bool `json:"redis_user_pw"` // 是否使用密码 - REDIS_PASSWORD string `json:"redis_password"` - REDIS_DB int `json:"redis_db"` - TOKEN_SECRET string `json:"token_secret"` - CID_BASE_DIR string `json:"cid_base_dir"` - FILE_BASE_DIR string `json:"file_base_dir"` - LOG_OUTPUT bool `json:"log_output"` - AISessionNameModelID int `json:"ai_session_name_model_id"` // 用于ai总结会话名称的模型id - MONITOR bool `json:"monitor"` // 状态监控及邮件通知 - SERVER_SQL_LOG bool `json:"server_sql_log"` // 服务器sql日志 - SERVER_PORT string `json:"server_port"` // 服务端口 - LOG_SAVE_DAYS int `json:"log_save_days"` // 日志保存天数,-1表示不保存,0表示永久保存 - SERVER_USER_TYPE string `json:"user_type"` // 服务器用户类型,master: 主服务器,slave: 从服务器,从服务器会定时同步数据 - MASTER_SERVER_DOMAIN string `json:"master_server_domain"` // 主服务器域名 - USER_SYNC_TIME int `json:"user_sync_time"` // 用户数据同步时间,单位秒 - SERVER_NAME string `json:"server_name"` // 服务器名称,用于区分不同服务器 - SPARK_PPT_USAGE bool `json:"spark_ppt_usage"` // 是否使用spark ppt功能 - KBASE_SERVER []KBaseServer `json:"kbase_server"` // 知识库服务器列表 - GITHUB_CLIENT_ID string `json:"github_client_id"` // github client id - GITHUB_CLIENT_SECRET string `json:"github_client_secret"` // github client secret - GITEE_CLIENT_ID string `json:"gitee_client_id"` // gitee client id - GITEE_CLIENT_SECRET string `json:"gitee_client_secret"` // gitee client secret - MICROSOFT_CLIENT_SECRET string `json:"microsoft_client_secret"` // microsoft client secret + DB int `json:"db"` // 0: mysql, 1: pg + MYSQL_DSN string `json:"mysql_dsn"` + PG_DSN string `json:"pg_dsn"` + SlowQueryThreshold int `json:"slow_query_threshold"` // 慢查询阈值,单位ms + REDIS_ADDR string `json:"redis_addr"` + TOKEN_USE_REDIS bool `json:"token_use_redis"` + REDIS_User_PW bool `json:"redis_user_pw"` // 是否使用密码 + REDIS_PASSWORD string `json:"redis_password"` + REDIS_DB int `json:"redis_db"` + TOKEN_SECRET string `json:"token_secret"` + CID_BASE_DIR string `json:"cid_base_dir"` + FILE_BASE_DIR string `json:"file_base_dir"` + LOG_OUTPUT bool `json:"log_output"` + AISessionNameModelID int `json:"ai_session_name_model_id"` // 用于ai总结会话名称的模型id + MONITOR bool `json:"monitor"` // 状态监控及邮件通知 + SERVER_SQL_LOG bool `json:"server_sql_log"` // 服务器sql日志 + SERVER_PORT string `json:"server_port"` // 服务端口 + LOG_SAVE_DAYS int `json:"log_save_days"` // 日志保存天数,-1表示不保存,0表示永久保存 + SERVER_USER_TYPE string `json:"user_type"` // 服务器用户类型,master: 主服务器,slave: 从服务器,从服务器会定时同步数据 + MASTER_SERVER_DOMAIN string `json:"master_server_domain"` // 主服务器域名 + USER_SYNC_TIME int `json:"user_sync_time"` // 用户数据同步时间,单位秒 + SERVER_NAME string `json:"server_name"` // 服务器名称,用于区分不同服务器 + SPARK_PPT_USAGE bool `json:"spark_ppt_usage"` // 是否使用spark ppt功能 + KBASE_SERVER []KBaseServer `json:"kbase_server"` // 知识库服务器列表 + GITHUB_CLIENT_ID string `json:"github_client_id"` // github client id + GITHUB_CLIENT_SECRET string `json:"github_client_secret"` // github client secret + GITEE_CLIENT_ID string `json:"gitee_client_id"` // gitee client id + GITEE_CLIENT_SECRET string `json:"gitee_client_secret"` // gitee client secret + MICROSOFT_CLIENT_SECRET string `json:"microsoft_client_secret"` // microsoft client secret } type KBaseServer struct { @@ -174,3 +175,15 @@ func DefaultConfig() { Config.SERVER_NAME = "default" Config.SPARK_PPT_USAGE = false } + +type CustomClaims struct { + ID string `json:"id"` + Username string `json:"username"` + UserID int `json:"user_id"` + jwt.RegisteredClaims +} + +func (c CustomClaims) Valid() error { + //TODO implement me + panic("implement me") +}