From 7c5ee9c48ea381c64276892b2bf21b3bfaeb98e4 Mon Sep 17 00:00:00 2001 From: junleea <354425203@qq.com> Date: Sun, 27 Apr 2025 10:00:19 +0800 Subject: [PATCH] =?UTF-8?q?github=E5=9B=9E=E8=B0=83=E4=BF=A1=E6=81=AF?= =?UTF-8?q?=E5=AE=8C=E6=88=90?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- dao/db.go | 6 +++ dao/user.go | 50 ++++++++++++++++++++++++ handler/tool.go | 1 + proto/status.go | 3 ++ proto/tool.go | 8 ++++ service/toolService.go | 89 +++++++++++++++++++++++++++++++++++++++++- worker/github.go | 9 +++-- worker/qq.go | 46 ++++++++++++---------- 8 files changed, 186 insertions(+), 26 deletions(-) diff --git a/dao/db.go b/dao/db.go index d7d6b42..0c54f5c 100644 --- a/dao/db.go +++ b/dao/db.go @@ -96,6 +96,12 @@ func Init() error { return err } + err = db.AutoMigrate(&ThirdPartyUserInfo{}) + if err != nil { + fmt.Println("third party user info table:", err) + return err + } + err = db.AutoMigrate(&AIStreamMsg{}) if err != nil { fmt.Println("ai stream message table:", err) diff --git a/dao/user.go b/dao/user.go index 57fc820..8576b18 100644 --- a/dao/user.go +++ b/dao/user.go @@ -27,6 +27,17 @@ type User struct { UpdateTime string `gorm:"column:update_time"` } +// 存储第三方统一信息 +type ThirdPartyUserInfo struct { + gorm.Model + UserID int `json:"user_id"` // 用户ID,本系统的用户id + ThirdPartyID int `json:"third_party_id"` // 第三方用户ID + ThirdPartyPlatform string `json:"third_party_platform"` // 第三方平台名称,qq,github + ThirdPartyUserName string `json:"third_party_user_name"` // 第三方用户名 + ThirdPartyUserAvatar string `json:"third_party_user_avatar"` // 第三方用户头像 + ThirdPartyUserUrl string `json:"third_party_user_url"` // 第三方用户主页,可选 +} + func CreateUser(name, password, email, gender string, age int) uint { user := User{Name: name, Email: email, Password: password, Gender: gender, Age: age} res := DB.Create(&user) @@ -197,3 +208,42 @@ func FindUserNum() int64 { DB.Model(&User{}).Count(&count) return count } + +// 根据用户id获取第三方平台信息 +func FindThirdPartyUserInfoByUserID(userID int) []ThirdPartyUserInfo { + var thirdPartyUserInfos []ThirdPartyUserInfo + DB.Where("user_id = ?", userID).Find(&thirdPartyUserInfos) + return thirdPartyUserInfos +} + +// 根据平台用户id获取信息 +func FindThirdPartyUserInfoByThirdPartyID(thirdPartyID int) []ThirdPartyUserInfo { + var thirdPartyUserInfo []ThirdPartyUserInfo + DB.Where("third_party_id = ?", thirdPartyID).First(&thirdPartyUserInfo) + return thirdPartyUserInfo +} + +// 根据第三方平台名称和用户id获取信息 +func FindThirdPartyUserInfoByPlatformAndUserID(thirdPartyPlatform string, userID int) []ThirdPartyUserInfo { + var thirdPartyUserInfo []ThirdPartyUserInfo + DB.Where("third_party_platform = ? and user_id = ?", thirdPartyPlatform, userID).First(&thirdPartyUserInfo) + return thirdPartyUserInfo +} + +func CreateThirdPartyUserInfo(userID, thirdPartyID int, thirdPartyPlatform, thirdPartyUserName, thirdPartyUserAvatar, thirdPartyUserUrl string) uint { + thirdPartyUserInfo := ThirdPartyUserInfo{UserID: userID, ThirdPartyID: thirdPartyID, ThirdPartyPlatform: thirdPartyPlatform, ThirdPartyUserName: thirdPartyUserName, ThirdPartyUserAvatar: thirdPartyUserAvatar, ThirdPartyUserUrl: thirdPartyUserUrl} + res := DB.Create(&thirdPartyUserInfo) + if res.Error != nil { + return 0 + } + return thirdPartyUserInfo.ID +} + +// 删除 +func DeleteThirdPartyUserInfoByID(id int) int { + res := DB.Delete(&ThirdPartyUserInfo{}, id) + if res.Error != nil { + return 0 + } + return id +} diff --git a/handler/tool.go b/handler/tool.go index 1e31aed..7bc3552 100644 --- a/handler/tool.go +++ b/handler/tool.go @@ -655,6 +655,7 @@ func handleGithubCallback(c *gin.Context) { //json解析 var state proto.ThirdPartyLoginState err = json.Unmarshal([]byte(decodedStr), &state) + log.Println("handle github callback state:", decodedStr, "\tcode:", code) if err != nil { log.Println("json unmarshal error:", err) } diff --git a/proto/status.go b/proto/status.go index 1bbb890..991aef2 100644 --- a/proto/status.go +++ b/proto/status.go @@ -97,6 +97,9 @@ const ( //第三方登录 ThirdPartyLoginUUIDInvalid = 161 //第三方登录uuid失效 + ThirdPartyAddUserHasBinded = 162 //第三方登录用户已绑定 + //第三方用户未绑定 + ThirdPartyUserNotBinded = 163 //第三方登录用户未绑定 ) diff --git a/proto/tool.go b/proto/tool.go index 15049d8..f80fdf4 100644 --- a/proto/tool.go +++ b/proto/tool.go @@ -92,3 +92,11 @@ type GitHubOAuthRequest struct { Code string `json:"code"` RedirectURI string `json:"redirect_uri"` } + +// github返回用户信息 +type GitHubUserInfo struct { + LoginUserName string `json:"login"` // 用户名 + UserID int `json:"id"` // 用户ID + AvatarUrl string `json:"avatar_url"` //头像 + Url string `json:"url"` // 用户主页 +} diff --git a/service/toolService.go b/service/toolService.go index 14c1fea..f4c0136 100644 --- a/service/toolService.go +++ b/service/toolService.go @@ -9,6 +9,7 @@ import ( "github.com/golang-jwt/jwt" "log" "regexp" + "strconv" "time" ) @@ -205,7 +206,78 @@ func SetDashboardInfoToRedis() { func DoGithubCallBack(state *proto.ThirdPartyLoginState, code string) { //获取Access Token - + resp, err := worker.ExchangeCodeForAccessToken(proto.Config.GITHUB_CLIENT_ID, proto.Config.GITHUB_CLIENT_SECRET, code, "") + if err != nil { + log.Println("get github access token error:", err) + return + } + if resp.AccessToken == "" { + log.Println("get github access token is empty") + return + } + //获取用户信息 + userInfo, err := worker.GetGitHubUserInfo(resp.AccessToken) + if err != nil { + log.Println("get github user info error:", err) + return + } + var thirdPartyLoginStatus proto.ThirdPartyLoginStatus + thirdPartyLoginStatus.Type = state.Platform + if state.Type == "login" { + //根据第三方平台查找用户 + thirdPartyUserInfoList := dao.FindThirdPartyUserInfoByThirdPartyID(userInfo.UserID) + if thirdPartyUserInfoList == nil || len(thirdPartyUserInfoList) == 0 { + thirdPartyLoginStatus.Status = 2 //未绑定用户 + } else { + thirdPartyUserInfo := thirdPartyUserInfoList[0] + //获取用户信息 + user := GetUserByIDWithCache(thirdPartyUserInfo.UserID) + if user.ID == 0 { + thirdPartyLoginStatus.Status = 2 + log.Println("get user by id error") + } else { + //成功 + thirdPartyLoginStatus.Status = 0 + thirdPartyLoginStatus.UserInfo.UserID = int(user.ID) + thirdPartyLoginStatus.UserInfo.Username = user.Name + thirdPartyLoginStatus.UserInfo.Email = user.Email + thirdPartyLoginStatus.UserInfo.Token, _ = GenerateJWTToken(int(user.ID), user.Name) + } + } + } else if state.Type == "add" { + //根据第三方平台查找用户 + thirdPartyUserInfoList := dao.FindThirdPartyUserInfoByThirdPartyID(userInfo.UserID) + if thirdPartyUserInfoList != nil && len(thirdPartyUserInfoList) > 0 { + thirdPartyLoginStatus.Status = 3 //已绑定用户 + } else { + userIDStr := worker.GetRedis("user_add_platform_" + state.UUID) + if userIDStr == "" { + log.Println("user id is empty") + thirdPartyLoginStatus.Status = 2 //未绑定用户 + } else { + //字符串转int + userID, _ := strconv.Atoi(userIDStr) + //根据用户ID获取用户信息 + user := GetUserByIDWithCache(userID) + if user.ID == 0 { + thirdPartyLoginStatus.Status = 2 //未绑定用户 + log.Println("get user by id error") + } else { + //成功 + thirdPartyLoginStatus.Status = 0 + thirdPartyLoginStatus.UserInfo.UserID = int(user.ID) + thirdPartyLoginStatus.UserInfo.Username = user.Name + thirdPartyLoginStatus.UserInfo.Email = user.Email + thirdPartyLoginStatus.UserInfo.Token, _ = GenerateJWTToken(int(user.ID), user.Name) + } + } + } + } else { + log.Println("DoGithubCallBack state type error:", state.Type) + } + //更新redis中的第三方登录状态 + thirdPartyLoginStatusStr, _ := json.Marshal(thirdPartyLoginStatus) + worker.SetRedisWithExpire(state.UUID, string(thirdPartyLoginStatusStr), time.Minute*10) } // 解析jwt内容 @@ -233,3 +305,18 @@ func DecodeJWTToken(tokenStr string) (int, error) { } return int(userID), nil } + +// 生成token +func GenerateJWTToken(userID int, userName string) (string, error) { + //创建token + claims := jwt.MapClaims{ + "id": userID, + "username": userName, + } + token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + tokenString, err := token.SignedString(proto.SigningKey) + if err != nil { + return "", err + } + return tokenString, nil +} diff --git a/worker/github.go b/worker/github.go index adbb7be..c3d5092 100644 --- a/worker/github.go +++ b/worker/github.go @@ -4,7 +4,6 @@ import ( "StuAcaWorksAI/proto" "bytes" "encoding/json" - "fmt" "io" "net/http" ) @@ -53,15 +52,17 @@ func ExchangeCodeForAccessToken(clientID, clientSecret, code, redirectURI string } // 获取用户信息 -func GetGitHubUserInfo(accessToken string) { +func GetGitHubUserInfo(accessToken string) (proto.GitHubUserInfo, error) { url := "https://api.github.com/user" headers := map[string]string{ "Authorization": "Bearer " + accessToken, } err, data := DoGetRequest(url, headers) + var resp proto.GitHubUserInfo if err != nil { - return + return resp, err } - fmt.Println("User Info:", string(data)) + err = json.Unmarshal(data, &resp) + return resp, err } diff --git a/worker/qq.go b/worker/qq.go index 647ea37..b3f6ebf 100644 --- a/worker/qq.go +++ b/worker/qq.go @@ -2,7 +2,8 @@ package worker import ( "fmt" - "io/ioutil" + "io" + "log" "net/http" "net/url" "strings" @@ -68,8 +69,8 @@ type QQRefreshTokenRequest struct { } // 2. Get Access Token -func GetToken(w http.ResponseWriter, r *http.Request) { - code := r.FormValue("code") +func GetQQToken(code string) string { + var accessToken string params := url.Values{} params.Add("grant_type", "authorization_code") params.Add("client_id", AppId) @@ -80,11 +81,12 @@ func GetToken(w http.ResponseWriter, r *http.Request) { response, err := http.Get(loginURL) if err != nil { - w.Write([]byte(err.Error())) + log.Println("GetQQToken error:", err.Error()) + return accessToken } defer response.Body.Close() - bs, _ := ioutil.ReadAll(response.Body) + bs, _ := io.ReadAll(response.Body) body := string(bs) resultMap := convertToMap(body) @@ -93,41 +95,43 @@ func GetToken(w http.ResponseWriter, r *http.Request) { info.AccessToken = resultMap["access_token"] info.RefreshToken = resultMap["refresh_token"] info.ExpiresIn = resultMap["expires_in"] - - GetOpenId(info, w) + return info.AccessToken } -// 3. Get OpenId -func GetOpenId(info *PrivateInfo, w http.ResponseWriter) { - resp, err := http.Get(fmt.Sprintf("%s?access_token=%s", "https://graph.qq.com/oauth2.0/me", info.AccessToken)) +// 3. Get QQ OpenId +func GetOpenId(accessToken string) { + resp, err := http.Get(fmt.Sprintf("%s?access_token=%s", "https://graph.qq.com/oauth2.0/me", accessToken)) if err != nil { - w.Write([]byte(err.Error())) + return } defer resp.Body.Close() - bs, _ := ioutil.ReadAll(resp.Body) + bs, _ := io.ReadAll(resp.Body) body := string(bs) - info.OpenId = body[45:77] + openId := body[45:77] - GetUserInfo(info, w) + info, err := GetUserInfo(accessToken, openId) + if err != nil { + return + } + log.Println(info) } // 4. Get User info -func GetUserInfo(info *PrivateInfo, w http.ResponseWriter) { +func GetUserInfo(accessToken string, openID string) (string, error) { params := url.Values{} - params.Add("access_token", info.AccessToken) - params.Add("openid", info.OpenId) + params.Add("access_token", accessToken) + params.Add("openid", openID) params.Add("oauth_consumer_key", AppId) uri := fmt.Sprintf("https://graph.qq.com/user/get_user_info?%s", params.Encode()) resp, err := http.Get(uri) if err != nil { - w.Write([]byte(err.Error())) + return "", err } defer resp.Body.Close() - - bs, _ := ioutil.ReadAll(resp.Body) - w.Write(bs) + bs, _ := io.ReadAll(resp.Body) + return string(bs), nil } func convertToMap(str string) map[string]string {