github回调信息完成
This commit is contained in:
parent
2ac3e0175e
commit
7c5ee9c48e
|
|
@ -96,6 +96,12 @@ func Init() error {
|
|||
return err
|
||||
}
|
||||
|
||||
err = db.AutoMigrate(&ThirdPartyUserInfo{})
|
||||
if err != nil {
|
||||
fmt.Println("third party user info table:", err)
|
||||
return err
|
||||
}
|
||||
|
||||
err = db.AutoMigrate(&AIStreamMsg{})
|
||||
if err != nil {
|
||||
fmt.Println("ai stream message table:", err)
|
||||
|
|
|
|||
50
dao/user.go
50
dao/user.go
|
|
@ -27,6 +27,17 @@ type User struct {
|
|||
UpdateTime string `gorm:"column:update_time"`
|
||||
}
|
||||
|
||||
// 存储第三方统一信息
|
||||
type ThirdPartyUserInfo struct {
|
||||
gorm.Model
|
||||
UserID int `json:"user_id"` // 用户ID,本系统的用户id
|
||||
ThirdPartyID int `json:"third_party_id"` // 第三方用户ID
|
||||
ThirdPartyPlatform string `json:"third_party_platform"` // 第三方平台名称,qq,github
|
||||
ThirdPartyUserName string `json:"third_party_user_name"` // 第三方用户名
|
||||
ThirdPartyUserAvatar string `json:"third_party_user_avatar"` // 第三方用户头像
|
||||
ThirdPartyUserUrl string `json:"third_party_user_url"` // 第三方用户主页,可选
|
||||
}
|
||||
|
||||
func CreateUser(name, password, email, gender string, age int) uint {
|
||||
user := User{Name: name, Email: email, Password: password, Gender: gender, Age: age}
|
||||
res := DB.Create(&user)
|
||||
|
|
@ -197,3 +208,42 @@ func FindUserNum() int64 {
|
|||
DB.Model(&User{}).Count(&count)
|
||||
return count
|
||||
}
|
||||
|
||||
// 根据用户id获取第三方平台信息
|
||||
func FindThirdPartyUserInfoByUserID(userID int) []ThirdPartyUserInfo {
|
||||
var thirdPartyUserInfos []ThirdPartyUserInfo
|
||||
DB.Where("user_id = ?", userID).Find(&thirdPartyUserInfos)
|
||||
return thirdPartyUserInfos
|
||||
}
|
||||
|
||||
// 根据平台用户id获取信息
|
||||
func FindThirdPartyUserInfoByThirdPartyID(thirdPartyID int) []ThirdPartyUserInfo {
|
||||
var thirdPartyUserInfo []ThirdPartyUserInfo
|
||||
DB.Where("third_party_id = ?", thirdPartyID).First(&thirdPartyUserInfo)
|
||||
return thirdPartyUserInfo
|
||||
}
|
||||
|
||||
// 根据第三方平台名称和用户id获取信息
|
||||
func FindThirdPartyUserInfoByPlatformAndUserID(thirdPartyPlatform string, userID int) []ThirdPartyUserInfo {
|
||||
var thirdPartyUserInfo []ThirdPartyUserInfo
|
||||
DB.Where("third_party_platform = ? and user_id = ?", thirdPartyPlatform, userID).First(&thirdPartyUserInfo)
|
||||
return thirdPartyUserInfo
|
||||
}
|
||||
|
||||
func CreateThirdPartyUserInfo(userID, thirdPartyID int, thirdPartyPlatform, thirdPartyUserName, thirdPartyUserAvatar, thirdPartyUserUrl string) uint {
|
||||
thirdPartyUserInfo := ThirdPartyUserInfo{UserID: userID, ThirdPartyID: thirdPartyID, ThirdPartyPlatform: thirdPartyPlatform, ThirdPartyUserName: thirdPartyUserName, ThirdPartyUserAvatar: thirdPartyUserAvatar, ThirdPartyUserUrl: thirdPartyUserUrl}
|
||||
res := DB.Create(&thirdPartyUserInfo)
|
||||
if res.Error != nil {
|
||||
return 0
|
||||
}
|
||||
return thirdPartyUserInfo.ID
|
||||
}
|
||||
|
||||
// 删除
|
||||
func DeleteThirdPartyUserInfoByID(id int) int {
|
||||
res := DB.Delete(&ThirdPartyUserInfo{}, id)
|
||||
if res.Error != nil {
|
||||
return 0
|
||||
}
|
||||
return id
|
||||
}
|
||||
|
|
|
|||
|
|
@ -655,6 +655,7 @@ func handleGithubCallback(c *gin.Context) {
|
|||
//json解析
|
||||
var state proto.ThirdPartyLoginState
|
||||
err = json.Unmarshal([]byte(decodedStr), &state)
|
||||
log.Println("handle github callback state:", decodedStr, "\tcode:", code)
|
||||
if err != nil {
|
||||
log.Println("json unmarshal error:", err)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -97,6 +97,9 @@ const (
|
|||
|
||||
//第三方登录
|
||||
ThirdPartyLoginUUIDInvalid = 161 //第三方登录uuid失效
|
||||
ThirdPartyAddUserHasBinded = 162 //第三方登录用户已绑定
|
||||
//第三方用户未绑定
|
||||
ThirdPartyUserNotBinded = 163 //第三方登录用户未绑定
|
||||
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -92,3 +92,11 @@ type GitHubOAuthRequest struct {
|
|||
Code string `json:"code"`
|
||||
RedirectURI string `json:"redirect_uri"`
|
||||
}
|
||||
|
||||
// github返回用户信息
|
||||
type GitHubUserInfo struct {
|
||||
LoginUserName string `json:"login"` // 用户名
|
||||
UserID int `json:"id"` // 用户ID
|
||||
AvatarUrl string `json:"avatar_url"` //头像
|
||||
Url string `json:"url"` // 用户主页
|
||||
}
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ import (
|
|||
"github.com/golang-jwt/jwt"
|
||||
"log"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"time"
|
||||
)
|
||||
|
||||
|
|
@ -205,7 +206,78 @@ func SetDashboardInfoToRedis() {
|
|||
|
||||
func DoGithubCallBack(state *proto.ThirdPartyLoginState, code string) {
|
||||
//获取Access Token
|
||||
|
||||
resp, err := worker.ExchangeCodeForAccessToken(proto.Config.GITHUB_CLIENT_ID, proto.Config.GITHUB_CLIENT_SECRET, code, "")
|
||||
if err != nil {
|
||||
log.Println("get github access token error:", err)
|
||||
return
|
||||
}
|
||||
if resp.AccessToken == "" {
|
||||
log.Println("get github access token is empty")
|
||||
return
|
||||
}
|
||||
//获取用户信息
|
||||
userInfo, err := worker.GetGitHubUserInfo(resp.AccessToken)
|
||||
if err != nil {
|
||||
log.Println("get github user info error:", err)
|
||||
return
|
||||
}
|
||||
var thirdPartyLoginStatus proto.ThirdPartyLoginStatus
|
||||
thirdPartyLoginStatus.Type = state.Platform
|
||||
if state.Type == "login" {
|
||||
//根据第三方平台查找用户
|
||||
thirdPartyUserInfoList := dao.FindThirdPartyUserInfoByThirdPartyID(userInfo.UserID)
|
||||
if thirdPartyUserInfoList == nil || len(thirdPartyUserInfoList) == 0 {
|
||||
thirdPartyLoginStatus.Status = 2 //未绑定用户
|
||||
} else {
|
||||
thirdPartyUserInfo := thirdPartyUserInfoList[0]
|
||||
//获取用户信息
|
||||
user := GetUserByIDWithCache(thirdPartyUserInfo.UserID)
|
||||
if user.ID == 0 {
|
||||
thirdPartyLoginStatus.Status = 2
|
||||
log.Println("get user by id error")
|
||||
} else {
|
||||
//成功
|
||||
thirdPartyLoginStatus.Status = 0
|
||||
thirdPartyLoginStatus.UserInfo.UserID = int(user.ID)
|
||||
thirdPartyLoginStatus.UserInfo.Username = user.Name
|
||||
thirdPartyLoginStatus.UserInfo.Email = user.Email
|
||||
thirdPartyLoginStatus.UserInfo.Token, _ = GenerateJWTToken(int(user.ID), user.Name)
|
||||
}
|
||||
}
|
||||
} else if state.Type == "add" {
|
||||
//根据第三方平台查找用户
|
||||
thirdPartyUserInfoList := dao.FindThirdPartyUserInfoByThirdPartyID(userInfo.UserID)
|
||||
if thirdPartyUserInfoList != nil && len(thirdPartyUserInfoList) > 0 {
|
||||
thirdPartyLoginStatus.Status = 3 //已绑定用户
|
||||
} else {
|
||||
userIDStr := worker.GetRedis("user_add_platform_" + state.UUID)
|
||||
if userIDStr == "" {
|
||||
log.Println("user id is empty")
|
||||
thirdPartyLoginStatus.Status = 2 //未绑定用户
|
||||
} else {
|
||||
//字符串转int
|
||||
userID, _ := strconv.Atoi(userIDStr)
|
||||
//根据用户ID获取用户信息
|
||||
user := GetUserByIDWithCache(userID)
|
||||
if user.ID == 0 {
|
||||
thirdPartyLoginStatus.Status = 2 //未绑定用户
|
||||
log.Println("get user by id error")
|
||||
} else {
|
||||
//成功
|
||||
thirdPartyLoginStatus.Status = 0
|
||||
thirdPartyLoginStatus.UserInfo.UserID = int(user.ID)
|
||||
thirdPartyLoginStatus.UserInfo.Username = user.Name
|
||||
thirdPartyLoginStatus.UserInfo.Email = user.Email
|
||||
thirdPartyLoginStatus.UserInfo.Token, _ = GenerateJWTToken(int(user.ID), user.Name)
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
log.Println("DoGithubCallBack state type error:", state.Type)
|
||||
}
|
||||
//更新redis中的第三方登录状态
|
||||
thirdPartyLoginStatusStr, _ := json.Marshal(thirdPartyLoginStatus)
|
||||
worker.SetRedisWithExpire(state.UUID, string(thirdPartyLoginStatusStr), time.Minute*10)
|
||||
}
|
||||
|
||||
// 解析jwt内容
|
||||
|
|
@ -233,3 +305,18 @@ func DecodeJWTToken(tokenStr string) (int, error) {
|
|||
}
|
||||
return int(userID), nil
|
||||
}
|
||||
|
||||
// 生成token
|
||||
func GenerateJWTToken(userID int, userName string) (string, error) {
|
||||
//创建token
|
||||
claims := jwt.MapClaims{
|
||||
"id": userID,
|
||||
"username": userName,
|
||||
}
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
tokenString, err := token.SignedString(proto.SigningKey)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return tokenString, nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -4,7 +4,6 @@ import (
|
|||
"StuAcaWorksAI/proto"
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
)
|
||||
|
|
@ -53,15 +52,17 @@ func ExchangeCodeForAccessToken(clientID, clientSecret, code, redirectURI string
|
|||
}
|
||||
|
||||
// 获取用户信息
|
||||
func GetGitHubUserInfo(accessToken string) {
|
||||
func GetGitHubUserInfo(accessToken string) (proto.GitHubUserInfo, error) {
|
||||
|
||||
url := "https://api.github.com/user"
|
||||
headers := map[string]string{
|
||||
"Authorization": "Bearer " + accessToken,
|
||||
}
|
||||
err, data := DoGetRequest(url, headers)
|
||||
var resp proto.GitHubUserInfo
|
||||
if err != nil {
|
||||
return
|
||||
return resp, err
|
||||
}
|
||||
fmt.Println("User Info:", string(data))
|
||||
err = json.Unmarshal(data, &resp)
|
||||
return resp, err
|
||||
}
|
||||
|
|
|
|||
46
worker/qq.go
46
worker/qq.go
|
|
@ -2,7 +2,8 @@ package worker
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
|
@ -68,8 +69,8 @@ type QQRefreshTokenRequest struct {
|
|||
}
|
||||
|
||||
// 2. Get Access Token
|
||||
func GetToken(w http.ResponseWriter, r *http.Request) {
|
||||
code := r.FormValue("code")
|
||||
func GetQQToken(code string) string {
|
||||
var accessToken string
|
||||
params := url.Values{}
|
||||
params.Add("grant_type", "authorization_code")
|
||||
params.Add("client_id", AppId)
|
||||
|
|
@ -80,11 +81,12 @@ func GetToken(w http.ResponseWriter, r *http.Request) {
|
|||
|
||||
response, err := http.Get(loginURL)
|
||||
if err != nil {
|
||||
w.Write([]byte(err.Error()))
|
||||
log.Println("GetQQToken error:", err.Error())
|
||||
return accessToken
|
||||
}
|
||||
defer response.Body.Close()
|
||||
|
||||
bs, _ := ioutil.ReadAll(response.Body)
|
||||
bs, _ := io.ReadAll(response.Body)
|
||||
body := string(bs)
|
||||
|
||||
resultMap := convertToMap(body)
|
||||
|
|
@ -93,41 +95,43 @@ func GetToken(w http.ResponseWriter, r *http.Request) {
|
|||
info.AccessToken = resultMap["access_token"]
|
||||
info.RefreshToken = resultMap["refresh_token"]
|
||||
info.ExpiresIn = resultMap["expires_in"]
|
||||
|
||||
GetOpenId(info, w)
|
||||
return info.AccessToken
|
||||
}
|
||||
|
||||
// 3. Get OpenId
|
||||
func GetOpenId(info *PrivateInfo, w http.ResponseWriter) {
|
||||
resp, err := http.Get(fmt.Sprintf("%s?access_token=%s", "https://graph.qq.com/oauth2.0/me", info.AccessToken))
|
||||
// 3. Get QQ OpenId
|
||||
func GetOpenId(accessToken string) {
|
||||
resp, err := http.Get(fmt.Sprintf("%s?access_token=%s", "https://graph.qq.com/oauth2.0/me", accessToken))
|
||||
if err != nil {
|
||||
w.Write([]byte(err.Error()))
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
bs, _ := ioutil.ReadAll(resp.Body)
|
||||
bs, _ := io.ReadAll(resp.Body)
|
||||
body := string(bs)
|
||||
info.OpenId = body[45:77]
|
||||
openId := body[45:77]
|
||||
|
||||
GetUserInfo(info, w)
|
||||
info, err := GetUserInfo(accessToken, openId)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
log.Println(info)
|
||||
}
|
||||
|
||||
// 4. Get User info
|
||||
func GetUserInfo(info *PrivateInfo, w http.ResponseWriter) {
|
||||
func GetUserInfo(accessToken string, openID string) (string, error) {
|
||||
params := url.Values{}
|
||||
params.Add("access_token", info.AccessToken)
|
||||
params.Add("openid", info.OpenId)
|
||||
params.Add("access_token", accessToken)
|
||||
params.Add("openid", openID)
|
||||
params.Add("oauth_consumer_key", AppId)
|
||||
|
||||
uri := fmt.Sprintf("https://graph.qq.com/user/get_user_info?%s", params.Encode())
|
||||
resp, err := http.Get(uri)
|
||||
if err != nil {
|
||||
w.Write([]byte(err.Error()))
|
||||
return "", err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
bs, _ := ioutil.ReadAll(resp.Body)
|
||||
w.Write(bs)
|
||||
bs, _ := io.ReadAll(resp.Body)
|
||||
return string(bs), nil
|
||||
}
|
||||
|
||||
func convertToMap(str string) map[string]string {
|
||||
|
|
|
|||
Loading…
Reference in New Issue