github回调信息完成

This commit is contained in:
junleea 2025-04-27 10:00:19 +08:00
parent 2ac3e0175e
commit 7c5ee9c48e
8 changed files with 186 additions and 26 deletions

View File

@ -96,6 +96,12 @@ func Init() error {
return err
}
err = db.AutoMigrate(&ThirdPartyUserInfo{})
if err != nil {
fmt.Println("third party user info table:", err)
return err
}
err = db.AutoMigrate(&AIStreamMsg{})
if err != nil {
fmt.Println("ai stream message table:", err)

View File

@ -27,6 +27,17 @@ type User struct {
UpdateTime string `gorm:"column:update_time"`
}
// 存储第三方统一信息
type ThirdPartyUserInfo struct {
gorm.Model
UserID int `json:"user_id"` // 用户ID,本系统的用户id
ThirdPartyID int `json:"third_party_id"` // 第三方用户ID
ThirdPartyPlatform string `json:"third_party_platform"` // 第三方平台名称,qq,github
ThirdPartyUserName string `json:"third_party_user_name"` // 第三方用户名
ThirdPartyUserAvatar string `json:"third_party_user_avatar"` // 第三方用户头像
ThirdPartyUserUrl string `json:"third_party_user_url"` // 第三方用户主页,可选
}
func CreateUser(name, password, email, gender string, age int) uint {
user := User{Name: name, Email: email, Password: password, Gender: gender, Age: age}
res := DB.Create(&user)
@ -197,3 +208,42 @@ func FindUserNum() int64 {
DB.Model(&User{}).Count(&count)
return count
}
// 根据用户id获取第三方平台信息
func FindThirdPartyUserInfoByUserID(userID int) []ThirdPartyUserInfo {
var thirdPartyUserInfos []ThirdPartyUserInfo
DB.Where("user_id = ?", userID).Find(&thirdPartyUserInfos)
return thirdPartyUserInfos
}
// 根据平台用户id获取信息
func FindThirdPartyUserInfoByThirdPartyID(thirdPartyID int) []ThirdPartyUserInfo {
var thirdPartyUserInfo []ThirdPartyUserInfo
DB.Where("third_party_id = ?", thirdPartyID).First(&thirdPartyUserInfo)
return thirdPartyUserInfo
}
// 根据第三方平台名称和用户id获取信息
func FindThirdPartyUserInfoByPlatformAndUserID(thirdPartyPlatform string, userID int) []ThirdPartyUserInfo {
var thirdPartyUserInfo []ThirdPartyUserInfo
DB.Where("third_party_platform = ? and user_id = ?", thirdPartyPlatform, userID).First(&thirdPartyUserInfo)
return thirdPartyUserInfo
}
func CreateThirdPartyUserInfo(userID, thirdPartyID int, thirdPartyPlatform, thirdPartyUserName, thirdPartyUserAvatar, thirdPartyUserUrl string) uint {
thirdPartyUserInfo := ThirdPartyUserInfo{UserID: userID, ThirdPartyID: thirdPartyID, ThirdPartyPlatform: thirdPartyPlatform, ThirdPartyUserName: thirdPartyUserName, ThirdPartyUserAvatar: thirdPartyUserAvatar, ThirdPartyUserUrl: thirdPartyUserUrl}
res := DB.Create(&thirdPartyUserInfo)
if res.Error != nil {
return 0
}
return thirdPartyUserInfo.ID
}
// 删除
func DeleteThirdPartyUserInfoByID(id int) int {
res := DB.Delete(&ThirdPartyUserInfo{}, id)
if res.Error != nil {
return 0
}
return id
}

View File

@ -655,6 +655,7 @@ func handleGithubCallback(c *gin.Context) {
//json解析
var state proto.ThirdPartyLoginState
err = json.Unmarshal([]byte(decodedStr), &state)
log.Println("handle github callback state:", decodedStr, "\tcode:", code)
if err != nil {
log.Println("json unmarshal error:", err)
}

View File

@ -97,6 +97,9 @@ const (
//第三方登录
ThirdPartyLoginUUIDInvalid = 161 //第三方登录uuid失效
ThirdPartyAddUserHasBinded = 162 //第三方登录用户已绑定
//第三方用户未绑定
ThirdPartyUserNotBinded = 163 //第三方登录用户未绑定
)

View File

@ -92,3 +92,11 @@ type GitHubOAuthRequest struct {
Code string `json:"code"`
RedirectURI string `json:"redirect_uri"`
}
// github返回用户信息
type GitHubUserInfo struct {
LoginUserName string `json:"login"` // 用户名
UserID int `json:"id"` // 用户ID
AvatarUrl string `json:"avatar_url"` //头像
Url string `json:"url"` // 用户主页
}

View File

@ -9,6 +9,7 @@ import (
"github.com/golang-jwt/jwt"
"log"
"regexp"
"strconv"
"time"
)
@ -205,7 +206,78 @@ func SetDashboardInfoToRedis() {
func DoGithubCallBack(state *proto.ThirdPartyLoginState, code string) {
//获取Access Token
resp, err := worker.ExchangeCodeForAccessToken(proto.Config.GITHUB_CLIENT_ID, proto.Config.GITHUB_CLIENT_SECRET, code, "")
if err != nil {
log.Println("get github access token error:", err)
return
}
if resp.AccessToken == "" {
log.Println("get github access token is empty")
return
}
//获取用户信息
userInfo, err := worker.GetGitHubUserInfo(resp.AccessToken)
if err != nil {
log.Println("get github user info error:", err)
return
}
var thirdPartyLoginStatus proto.ThirdPartyLoginStatus
thirdPartyLoginStatus.Type = state.Platform
if state.Type == "login" {
//根据第三方平台查找用户
thirdPartyUserInfoList := dao.FindThirdPartyUserInfoByThirdPartyID(userInfo.UserID)
if thirdPartyUserInfoList == nil || len(thirdPartyUserInfoList) == 0 {
thirdPartyLoginStatus.Status = 2 //未绑定用户
} else {
thirdPartyUserInfo := thirdPartyUserInfoList[0]
//获取用户信息
user := GetUserByIDWithCache(thirdPartyUserInfo.UserID)
if user.ID == 0 {
thirdPartyLoginStatus.Status = 2
log.Println("get user by id error")
} else {
//成功
thirdPartyLoginStatus.Status = 0
thirdPartyLoginStatus.UserInfo.UserID = int(user.ID)
thirdPartyLoginStatus.UserInfo.Username = user.Name
thirdPartyLoginStatus.UserInfo.Email = user.Email
thirdPartyLoginStatus.UserInfo.Token, _ = GenerateJWTToken(int(user.ID), user.Name)
}
}
} else if state.Type == "add" {
//根据第三方平台查找用户
thirdPartyUserInfoList := dao.FindThirdPartyUserInfoByThirdPartyID(userInfo.UserID)
if thirdPartyUserInfoList != nil && len(thirdPartyUserInfoList) > 0 {
thirdPartyLoginStatus.Status = 3 //已绑定用户
} else {
userIDStr := worker.GetRedis("user_add_platform_" + state.UUID)
if userIDStr == "" {
log.Println("user id is empty")
thirdPartyLoginStatus.Status = 2 //未绑定用户
} else {
//字符串转int
userID, _ := strconv.Atoi(userIDStr)
//根据用户ID获取用户信息
user := GetUserByIDWithCache(userID)
if user.ID == 0 {
thirdPartyLoginStatus.Status = 2 //未绑定用户
log.Println("get user by id error")
} else {
//成功
thirdPartyLoginStatus.Status = 0
thirdPartyLoginStatus.UserInfo.UserID = int(user.ID)
thirdPartyLoginStatus.UserInfo.Username = user.Name
thirdPartyLoginStatus.UserInfo.Email = user.Email
thirdPartyLoginStatus.UserInfo.Token, _ = GenerateJWTToken(int(user.ID), user.Name)
}
}
}
} else {
log.Println("DoGithubCallBack state type error:", state.Type)
}
//更新redis中的第三方登录状态
thirdPartyLoginStatusStr, _ := json.Marshal(thirdPartyLoginStatus)
worker.SetRedisWithExpire(state.UUID, string(thirdPartyLoginStatusStr), time.Minute*10)
}
// 解析jwt内容
@ -233,3 +305,18 @@ func DecodeJWTToken(tokenStr string) (int, error) {
}
return int(userID), nil
}
// 生成token
func GenerateJWTToken(userID int, userName string) (string, error) {
//创建token
claims := jwt.MapClaims{
"id": userID,
"username": userName,
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
tokenString, err := token.SignedString(proto.SigningKey)
if err != nil {
return "", err
}
return tokenString, nil
}

View File

@ -4,7 +4,6 @@ import (
"StuAcaWorksAI/proto"
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
)
@ -53,15 +52,17 @@ func ExchangeCodeForAccessToken(clientID, clientSecret, code, redirectURI string
}
// 获取用户信息
func GetGitHubUserInfo(accessToken string) {
func GetGitHubUserInfo(accessToken string) (proto.GitHubUserInfo, error) {
url := "https://api.github.com/user"
headers := map[string]string{
"Authorization": "Bearer " + accessToken,
}
err, data := DoGetRequest(url, headers)
var resp proto.GitHubUserInfo
if err != nil {
return
return resp, err
}
fmt.Println("User Info:", string(data))
err = json.Unmarshal(data, &resp)
return resp, err
}

View File

@ -2,7 +2,8 @@ package worker
import (
"fmt"
"io/ioutil"
"io"
"log"
"net/http"
"net/url"
"strings"
@ -68,8 +69,8 @@ type QQRefreshTokenRequest struct {
}
// 2. Get Access Token
func GetToken(w http.ResponseWriter, r *http.Request) {
code := r.FormValue("code")
func GetQQToken(code string) string {
var accessToken string
params := url.Values{}
params.Add("grant_type", "authorization_code")
params.Add("client_id", AppId)
@ -80,11 +81,12 @@ func GetToken(w http.ResponseWriter, r *http.Request) {
response, err := http.Get(loginURL)
if err != nil {
w.Write([]byte(err.Error()))
log.Println("GetQQToken error:", err.Error())
return accessToken
}
defer response.Body.Close()
bs, _ := ioutil.ReadAll(response.Body)
bs, _ := io.ReadAll(response.Body)
body := string(bs)
resultMap := convertToMap(body)
@ -93,41 +95,43 @@ func GetToken(w http.ResponseWriter, r *http.Request) {
info.AccessToken = resultMap["access_token"]
info.RefreshToken = resultMap["refresh_token"]
info.ExpiresIn = resultMap["expires_in"]
GetOpenId(info, w)
return info.AccessToken
}
// 3. Get OpenId
func GetOpenId(info *PrivateInfo, w http.ResponseWriter) {
resp, err := http.Get(fmt.Sprintf("%s?access_token=%s", "https://graph.qq.com/oauth2.0/me", info.AccessToken))
// 3. Get QQ OpenId
func GetOpenId(accessToken string) {
resp, err := http.Get(fmt.Sprintf("%s?access_token=%s", "https://graph.qq.com/oauth2.0/me", accessToken))
if err != nil {
w.Write([]byte(err.Error()))
return
}
defer resp.Body.Close()
bs, _ := ioutil.ReadAll(resp.Body)
bs, _ := io.ReadAll(resp.Body)
body := string(bs)
info.OpenId = body[45:77]
openId := body[45:77]
GetUserInfo(info, w)
info, err := GetUserInfo(accessToken, openId)
if err != nil {
return
}
log.Println(info)
}
// 4. Get User info
func GetUserInfo(info *PrivateInfo, w http.ResponseWriter) {
func GetUserInfo(accessToken string, openID string) (string, error) {
params := url.Values{}
params.Add("access_token", info.AccessToken)
params.Add("openid", info.OpenId)
params.Add("access_token", accessToken)
params.Add("openid", openID)
params.Add("oauth_consumer_key", AppId)
uri := fmt.Sprintf("https://graph.qq.com/user/get_user_info?%s", params.Encode())
resp, err := http.Get(uri)
if err != nil {
w.Write([]byte(err.Error()))
return "", err
}
defer resp.Body.Close()
bs, _ := ioutil.ReadAll(resp.Body)
w.Write(bs)
bs, _ := io.ReadAll(resp.Body)
return string(bs), nil
}
func convertToMap(str string) map[string]string {