github回调信息完成

This commit is contained in:
junleea 2025-04-27 10:00:19 +08:00
parent 2ac3e0175e
commit 7c5ee9c48e
8 changed files with 186 additions and 26 deletions

View File

@ -96,6 +96,12 @@ func Init() error {
return err return err
} }
err = db.AutoMigrate(&ThirdPartyUserInfo{})
if err != nil {
fmt.Println("third party user info table:", err)
return err
}
err = db.AutoMigrate(&AIStreamMsg{}) err = db.AutoMigrate(&AIStreamMsg{})
if err != nil { if err != nil {
fmt.Println("ai stream message table:", err) fmt.Println("ai stream message table:", err)

View File

@ -27,6 +27,17 @@ type User struct {
UpdateTime string `gorm:"column:update_time"` UpdateTime string `gorm:"column:update_time"`
} }
// 存储第三方统一信息
type ThirdPartyUserInfo struct {
gorm.Model
UserID int `json:"user_id"` // 用户ID,本系统的用户id
ThirdPartyID int `json:"third_party_id"` // 第三方用户ID
ThirdPartyPlatform string `json:"third_party_platform"` // 第三方平台名称,qq,github
ThirdPartyUserName string `json:"third_party_user_name"` // 第三方用户名
ThirdPartyUserAvatar string `json:"third_party_user_avatar"` // 第三方用户头像
ThirdPartyUserUrl string `json:"third_party_user_url"` // 第三方用户主页,可选
}
func CreateUser(name, password, email, gender string, age int) uint { func CreateUser(name, password, email, gender string, age int) uint {
user := User{Name: name, Email: email, Password: password, Gender: gender, Age: age} user := User{Name: name, Email: email, Password: password, Gender: gender, Age: age}
res := DB.Create(&user) res := DB.Create(&user)
@ -197,3 +208,42 @@ func FindUserNum() int64 {
DB.Model(&User{}).Count(&count) DB.Model(&User{}).Count(&count)
return count return count
} }
// 根据用户id获取第三方平台信息
func FindThirdPartyUserInfoByUserID(userID int) []ThirdPartyUserInfo {
var thirdPartyUserInfos []ThirdPartyUserInfo
DB.Where("user_id = ?", userID).Find(&thirdPartyUserInfos)
return thirdPartyUserInfos
}
// 根据平台用户id获取信息
func FindThirdPartyUserInfoByThirdPartyID(thirdPartyID int) []ThirdPartyUserInfo {
var thirdPartyUserInfo []ThirdPartyUserInfo
DB.Where("third_party_id = ?", thirdPartyID).First(&thirdPartyUserInfo)
return thirdPartyUserInfo
}
// 根据第三方平台名称和用户id获取信息
func FindThirdPartyUserInfoByPlatformAndUserID(thirdPartyPlatform string, userID int) []ThirdPartyUserInfo {
var thirdPartyUserInfo []ThirdPartyUserInfo
DB.Where("third_party_platform = ? and user_id = ?", thirdPartyPlatform, userID).First(&thirdPartyUserInfo)
return thirdPartyUserInfo
}
func CreateThirdPartyUserInfo(userID, thirdPartyID int, thirdPartyPlatform, thirdPartyUserName, thirdPartyUserAvatar, thirdPartyUserUrl string) uint {
thirdPartyUserInfo := ThirdPartyUserInfo{UserID: userID, ThirdPartyID: thirdPartyID, ThirdPartyPlatform: thirdPartyPlatform, ThirdPartyUserName: thirdPartyUserName, ThirdPartyUserAvatar: thirdPartyUserAvatar, ThirdPartyUserUrl: thirdPartyUserUrl}
res := DB.Create(&thirdPartyUserInfo)
if res.Error != nil {
return 0
}
return thirdPartyUserInfo.ID
}
// 删除
func DeleteThirdPartyUserInfoByID(id int) int {
res := DB.Delete(&ThirdPartyUserInfo{}, id)
if res.Error != nil {
return 0
}
return id
}

View File

@ -655,6 +655,7 @@ func handleGithubCallback(c *gin.Context) {
//json解析 //json解析
var state proto.ThirdPartyLoginState var state proto.ThirdPartyLoginState
err = json.Unmarshal([]byte(decodedStr), &state) err = json.Unmarshal([]byte(decodedStr), &state)
log.Println("handle github callback state:", decodedStr, "\tcode:", code)
if err != nil { if err != nil {
log.Println("json unmarshal error:", err) log.Println("json unmarshal error:", err)
} }

View File

@ -97,6 +97,9 @@ const (
//第三方登录 //第三方登录
ThirdPartyLoginUUIDInvalid = 161 //第三方登录uuid失效 ThirdPartyLoginUUIDInvalid = 161 //第三方登录uuid失效
ThirdPartyAddUserHasBinded = 162 //第三方登录用户已绑定
//第三方用户未绑定
ThirdPartyUserNotBinded = 163 //第三方登录用户未绑定
) )

View File

@ -92,3 +92,11 @@ type GitHubOAuthRequest struct {
Code string `json:"code"` Code string `json:"code"`
RedirectURI string `json:"redirect_uri"` RedirectURI string `json:"redirect_uri"`
} }
// github返回用户信息
type GitHubUserInfo struct {
LoginUserName string `json:"login"` // 用户名
UserID int `json:"id"` // 用户ID
AvatarUrl string `json:"avatar_url"` //头像
Url string `json:"url"` // 用户主页
}

View File

@ -9,6 +9,7 @@ import (
"github.com/golang-jwt/jwt" "github.com/golang-jwt/jwt"
"log" "log"
"regexp" "regexp"
"strconv"
"time" "time"
) )
@ -205,7 +206,78 @@ func SetDashboardInfoToRedis() {
func DoGithubCallBack(state *proto.ThirdPartyLoginState, code string) { func DoGithubCallBack(state *proto.ThirdPartyLoginState, code string) {
//获取Access Token //获取Access Token
resp, err := worker.ExchangeCodeForAccessToken(proto.Config.GITHUB_CLIENT_ID, proto.Config.GITHUB_CLIENT_SECRET, code, "")
if err != nil {
log.Println("get github access token error:", err)
return
}
if resp.AccessToken == "" {
log.Println("get github access token is empty")
return
}
//获取用户信息
userInfo, err := worker.GetGitHubUserInfo(resp.AccessToken)
if err != nil {
log.Println("get github user info error:", err)
return
}
var thirdPartyLoginStatus proto.ThirdPartyLoginStatus
thirdPartyLoginStatus.Type = state.Platform
if state.Type == "login" {
//根据第三方平台查找用户
thirdPartyUserInfoList := dao.FindThirdPartyUserInfoByThirdPartyID(userInfo.UserID)
if thirdPartyUserInfoList == nil || len(thirdPartyUserInfoList) == 0 {
thirdPartyLoginStatus.Status = 2 //未绑定用户
} else {
thirdPartyUserInfo := thirdPartyUserInfoList[0]
//获取用户信息
user := GetUserByIDWithCache(thirdPartyUserInfo.UserID)
if user.ID == 0 {
thirdPartyLoginStatus.Status = 2
log.Println("get user by id error")
} else {
//成功
thirdPartyLoginStatus.Status = 0
thirdPartyLoginStatus.UserInfo.UserID = int(user.ID)
thirdPartyLoginStatus.UserInfo.Username = user.Name
thirdPartyLoginStatus.UserInfo.Email = user.Email
thirdPartyLoginStatus.UserInfo.Token, _ = GenerateJWTToken(int(user.ID), user.Name)
}
}
} else if state.Type == "add" {
//根据第三方平台查找用户
thirdPartyUserInfoList := dao.FindThirdPartyUserInfoByThirdPartyID(userInfo.UserID)
if thirdPartyUserInfoList != nil && len(thirdPartyUserInfoList) > 0 {
thirdPartyLoginStatus.Status = 3 //已绑定用户
} else {
userIDStr := worker.GetRedis("user_add_platform_" + state.UUID)
if userIDStr == "" {
log.Println("user id is empty")
thirdPartyLoginStatus.Status = 2 //未绑定用户
} else {
//字符串转int
userID, _ := strconv.Atoi(userIDStr)
//根据用户ID获取用户信息
user := GetUserByIDWithCache(userID)
if user.ID == 0 {
thirdPartyLoginStatus.Status = 2 //未绑定用户
log.Println("get user by id error")
} else {
//成功
thirdPartyLoginStatus.Status = 0
thirdPartyLoginStatus.UserInfo.UserID = int(user.ID)
thirdPartyLoginStatus.UserInfo.Username = user.Name
thirdPartyLoginStatus.UserInfo.Email = user.Email
thirdPartyLoginStatus.UserInfo.Token, _ = GenerateJWTToken(int(user.ID), user.Name)
}
}
}
} else {
log.Println("DoGithubCallBack state type error:", state.Type)
}
//更新redis中的第三方登录状态
thirdPartyLoginStatusStr, _ := json.Marshal(thirdPartyLoginStatus)
worker.SetRedisWithExpire(state.UUID, string(thirdPartyLoginStatusStr), time.Minute*10)
} }
// 解析jwt内容 // 解析jwt内容
@ -233,3 +305,18 @@ func DecodeJWTToken(tokenStr string) (int, error) {
} }
return int(userID), nil return int(userID), nil
} }
// 生成token
func GenerateJWTToken(userID int, userName string) (string, error) {
//创建token
claims := jwt.MapClaims{
"id": userID,
"username": userName,
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
tokenString, err := token.SignedString(proto.SigningKey)
if err != nil {
return "", err
}
return tokenString, nil
}

View File

@ -4,7 +4,6 @@ import (
"StuAcaWorksAI/proto" "StuAcaWorksAI/proto"
"bytes" "bytes"
"encoding/json" "encoding/json"
"fmt"
"io" "io"
"net/http" "net/http"
) )
@ -53,15 +52,17 @@ func ExchangeCodeForAccessToken(clientID, clientSecret, code, redirectURI string
} }
// 获取用户信息 // 获取用户信息
func GetGitHubUserInfo(accessToken string) { func GetGitHubUserInfo(accessToken string) (proto.GitHubUserInfo, error) {
url := "https://api.github.com/user" url := "https://api.github.com/user"
headers := map[string]string{ headers := map[string]string{
"Authorization": "Bearer " + accessToken, "Authorization": "Bearer " + accessToken,
} }
err, data := DoGetRequest(url, headers) err, data := DoGetRequest(url, headers)
var resp proto.GitHubUserInfo
if err != nil { if err != nil {
return return resp, err
} }
fmt.Println("User Info:", string(data)) err = json.Unmarshal(data, &resp)
return resp, err
} }

View File

@ -2,7 +2,8 @@ package worker
import ( import (
"fmt" "fmt"
"io/ioutil" "io"
"log"
"net/http" "net/http"
"net/url" "net/url"
"strings" "strings"
@ -68,8 +69,8 @@ type QQRefreshTokenRequest struct {
} }
// 2. Get Access Token // 2. Get Access Token
func GetToken(w http.ResponseWriter, r *http.Request) { func GetQQToken(code string) string {
code := r.FormValue("code") var accessToken string
params := url.Values{} params := url.Values{}
params.Add("grant_type", "authorization_code") params.Add("grant_type", "authorization_code")
params.Add("client_id", AppId) params.Add("client_id", AppId)
@ -80,11 +81,12 @@ func GetToken(w http.ResponseWriter, r *http.Request) {
response, err := http.Get(loginURL) response, err := http.Get(loginURL)
if err != nil { if err != nil {
w.Write([]byte(err.Error())) log.Println("GetQQToken error:", err.Error())
return accessToken
} }
defer response.Body.Close() defer response.Body.Close()
bs, _ := ioutil.ReadAll(response.Body) bs, _ := io.ReadAll(response.Body)
body := string(bs) body := string(bs)
resultMap := convertToMap(body) resultMap := convertToMap(body)
@ -93,41 +95,43 @@ func GetToken(w http.ResponseWriter, r *http.Request) {
info.AccessToken = resultMap["access_token"] info.AccessToken = resultMap["access_token"]
info.RefreshToken = resultMap["refresh_token"] info.RefreshToken = resultMap["refresh_token"]
info.ExpiresIn = resultMap["expires_in"] info.ExpiresIn = resultMap["expires_in"]
return info.AccessToken
GetOpenId(info, w)
} }
// 3. Get OpenId // 3. Get QQ OpenId
func GetOpenId(info *PrivateInfo, w http.ResponseWriter) { func GetOpenId(accessToken string) {
resp, err := http.Get(fmt.Sprintf("%s?access_token=%s", "https://graph.qq.com/oauth2.0/me", info.AccessToken)) resp, err := http.Get(fmt.Sprintf("%s?access_token=%s", "https://graph.qq.com/oauth2.0/me", accessToken))
if err != nil { if err != nil {
w.Write([]byte(err.Error())) return
} }
defer resp.Body.Close() defer resp.Body.Close()
bs, _ := ioutil.ReadAll(resp.Body) bs, _ := io.ReadAll(resp.Body)
body := string(bs) body := string(bs)
info.OpenId = body[45:77] openId := body[45:77]
GetUserInfo(info, w) info, err := GetUserInfo(accessToken, openId)
if err != nil {
return
}
log.Println(info)
} }
// 4. Get User info // 4. Get User info
func GetUserInfo(info *PrivateInfo, w http.ResponseWriter) { func GetUserInfo(accessToken string, openID string) (string, error) {
params := url.Values{} params := url.Values{}
params.Add("access_token", info.AccessToken) params.Add("access_token", accessToken)
params.Add("openid", info.OpenId) params.Add("openid", openID)
params.Add("oauth_consumer_key", AppId) params.Add("oauth_consumer_key", AppId)
uri := fmt.Sprintf("https://graph.qq.com/user/get_user_info?%s", params.Encode()) uri := fmt.Sprintf("https://graph.qq.com/user/get_user_info?%s", params.Encode())
resp, err := http.Get(uri) resp, err := http.Get(uri)
if err != nil { if err != nil {
w.Write([]byte(err.Error())) return "", err
} }
defer resp.Body.Close() defer resp.Body.Close()
bs, _ := io.ReadAll(resp.Body)
bs, _ := ioutil.ReadAll(resp.Body) return string(bs), nil
w.Write(bs)
} }
func convertToMap(str string) map[string]string { func convertToMap(str string) map[string]string {