github回调信息完成
This commit is contained in:
parent
2ac3e0175e
commit
7c5ee9c48e
|
|
@ -96,6 +96,12 @@ func Init() error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
err = db.AutoMigrate(&ThirdPartyUserInfo{})
|
||||||
|
if err != nil {
|
||||||
|
fmt.Println("third party user info table:", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
err = db.AutoMigrate(&AIStreamMsg{})
|
err = db.AutoMigrate(&AIStreamMsg{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Println("ai stream message table:", err)
|
fmt.Println("ai stream message table:", err)
|
||||||
|
|
|
||||||
50
dao/user.go
50
dao/user.go
|
|
@ -27,6 +27,17 @@ type User struct {
|
||||||
UpdateTime string `gorm:"column:update_time"`
|
UpdateTime string `gorm:"column:update_time"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 存储第三方统一信息
|
||||||
|
type ThirdPartyUserInfo struct {
|
||||||
|
gorm.Model
|
||||||
|
UserID int `json:"user_id"` // 用户ID,本系统的用户id
|
||||||
|
ThirdPartyID int `json:"third_party_id"` // 第三方用户ID
|
||||||
|
ThirdPartyPlatform string `json:"third_party_platform"` // 第三方平台名称,qq,github
|
||||||
|
ThirdPartyUserName string `json:"third_party_user_name"` // 第三方用户名
|
||||||
|
ThirdPartyUserAvatar string `json:"third_party_user_avatar"` // 第三方用户头像
|
||||||
|
ThirdPartyUserUrl string `json:"third_party_user_url"` // 第三方用户主页,可选
|
||||||
|
}
|
||||||
|
|
||||||
func CreateUser(name, password, email, gender string, age int) uint {
|
func CreateUser(name, password, email, gender string, age int) uint {
|
||||||
user := User{Name: name, Email: email, Password: password, Gender: gender, Age: age}
|
user := User{Name: name, Email: email, Password: password, Gender: gender, Age: age}
|
||||||
res := DB.Create(&user)
|
res := DB.Create(&user)
|
||||||
|
|
@ -197,3 +208,42 @@ func FindUserNum() int64 {
|
||||||
DB.Model(&User{}).Count(&count)
|
DB.Model(&User{}).Count(&count)
|
||||||
return count
|
return count
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 根据用户id获取第三方平台信息
|
||||||
|
func FindThirdPartyUserInfoByUserID(userID int) []ThirdPartyUserInfo {
|
||||||
|
var thirdPartyUserInfos []ThirdPartyUserInfo
|
||||||
|
DB.Where("user_id = ?", userID).Find(&thirdPartyUserInfos)
|
||||||
|
return thirdPartyUserInfos
|
||||||
|
}
|
||||||
|
|
||||||
|
// 根据平台用户id获取信息
|
||||||
|
func FindThirdPartyUserInfoByThirdPartyID(thirdPartyID int) []ThirdPartyUserInfo {
|
||||||
|
var thirdPartyUserInfo []ThirdPartyUserInfo
|
||||||
|
DB.Where("third_party_id = ?", thirdPartyID).First(&thirdPartyUserInfo)
|
||||||
|
return thirdPartyUserInfo
|
||||||
|
}
|
||||||
|
|
||||||
|
// 根据第三方平台名称和用户id获取信息
|
||||||
|
func FindThirdPartyUserInfoByPlatformAndUserID(thirdPartyPlatform string, userID int) []ThirdPartyUserInfo {
|
||||||
|
var thirdPartyUserInfo []ThirdPartyUserInfo
|
||||||
|
DB.Where("third_party_platform = ? and user_id = ?", thirdPartyPlatform, userID).First(&thirdPartyUserInfo)
|
||||||
|
return thirdPartyUserInfo
|
||||||
|
}
|
||||||
|
|
||||||
|
func CreateThirdPartyUserInfo(userID, thirdPartyID int, thirdPartyPlatform, thirdPartyUserName, thirdPartyUserAvatar, thirdPartyUserUrl string) uint {
|
||||||
|
thirdPartyUserInfo := ThirdPartyUserInfo{UserID: userID, ThirdPartyID: thirdPartyID, ThirdPartyPlatform: thirdPartyPlatform, ThirdPartyUserName: thirdPartyUserName, ThirdPartyUserAvatar: thirdPartyUserAvatar, ThirdPartyUserUrl: thirdPartyUserUrl}
|
||||||
|
res := DB.Create(&thirdPartyUserInfo)
|
||||||
|
if res.Error != nil {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return thirdPartyUserInfo.ID
|
||||||
|
}
|
||||||
|
|
||||||
|
// 删除
|
||||||
|
func DeleteThirdPartyUserInfoByID(id int) int {
|
||||||
|
res := DB.Delete(&ThirdPartyUserInfo{}, id)
|
||||||
|
if res.Error != nil {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return id
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -655,6 +655,7 @@ func handleGithubCallback(c *gin.Context) {
|
||||||
//json解析
|
//json解析
|
||||||
var state proto.ThirdPartyLoginState
|
var state proto.ThirdPartyLoginState
|
||||||
err = json.Unmarshal([]byte(decodedStr), &state)
|
err = json.Unmarshal([]byte(decodedStr), &state)
|
||||||
|
log.Println("handle github callback state:", decodedStr, "\tcode:", code)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Println("json unmarshal error:", err)
|
log.Println("json unmarshal error:", err)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -97,6 +97,9 @@ const (
|
||||||
|
|
||||||
//第三方登录
|
//第三方登录
|
||||||
ThirdPartyLoginUUIDInvalid = 161 //第三方登录uuid失效
|
ThirdPartyLoginUUIDInvalid = 161 //第三方登录uuid失效
|
||||||
|
ThirdPartyAddUserHasBinded = 162 //第三方登录用户已绑定
|
||||||
|
//第三方用户未绑定
|
||||||
|
ThirdPartyUserNotBinded = 163 //第三方登录用户未绑定
|
||||||
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -92,3 +92,11 @@ type GitHubOAuthRequest struct {
|
||||||
Code string `json:"code"`
|
Code string `json:"code"`
|
||||||
RedirectURI string `json:"redirect_uri"`
|
RedirectURI string `json:"redirect_uri"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// github返回用户信息
|
||||||
|
type GitHubUserInfo struct {
|
||||||
|
LoginUserName string `json:"login"` // 用户名
|
||||||
|
UserID int `json:"id"` // 用户ID
|
||||||
|
AvatarUrl string `json:"avatar_url"` //头像
|
||||||
|
Url string `json:"url"` // 用户主页
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -9,6 +9,7 @@ import (
|
||||||
"github.com/golang-jwt/jwt"
|
"github.com/golang-jwt/jwt"
|
||||||
"log"
|
"log"
|
||||||
"regexp"
|
"regexp"
|
||||||
|
"strconv"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -205,7 +206,78 @@ func SetDashboardInfoToRedis() {
|
||||||
|
|
||||||
func DoGithubCallBack(state *proto.ThirdPartyLoginState, code string) {
|
func DoGithubCallBack(state *proto.ThirdPartyLoginState, code string) {
|
||||||
//获取Access Token
|
//获取Access Token
|
||||||
|
resp, err := worker.ExchangeCodeForAccessToken(proto.Config.GITHUB_CLIENT_ID, proto.Config.GITHUB_CLIENT_SECRET, code, "")
|
||||||
|
if err != nil {
|
||||||
|
log.Println("get github access token error:", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if resp.AccessToken == "" {
|
||||||
|
log.Println("get github access token is empty")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
//获取用户信息
|
||||||
|
userInfo, err := worker.GetGitHubUserInfo(resp.AccessToken)
|
||||||
|
if err != nil {
|
||||||
|
log.Println("get github user info error:", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
var thirdPartyLoginStatus proto.ThirdPartyLoginStatus
|
||||||
|
thirdPartyLoginStatus.Type = state.Platform
|
||||||
|
if state.Type == "login" {
|
||||||
|
//根据第三方平台查找用户
|
||||||
|
thirdPartyUserInfoList := dao.FindThirdPartyUserInfoByThirdPartyID(userInfo.UserID)
|
||||||
|
if thirdPartyUserInfoList == nil || len(thirdPartyUserInfoList) == 0 {
|
||||||
|
thirdPartyLoginStatus.Status = 2 //未绑定用户
|
||||||
|
} else {
|
||||||
|
thirdPartyUserInfo := thirdPartyUserInfoList[0]
|
||||||
|
//获取用户信息
|
||||||
|
user := GetUserByIDWithCache(thirdPartyUserInfo.UserID)
|
||||||
|
if user.ID == 0 {
|
||||||
|
thirdPartyLoginStatus.Status = 2
|
||||||
|
log.Println("get user by id error")
|
||||||
|
} else {
|
||||||
|
//成功
|
||||||
|
thirdPartyLoginStatus.Status = 0
|
||||||
|
thirdPartyLoginStatus.UserInfo.UserID = int(user.ID)
|
||||||
|
thirdPartyLoginStatus.UserInfo.Username = user.Name
|
||||||
|
thirdPartyLoginStatus.UserInfo.Email = user.Email
|
||||||
|
thirdPartyLoginStatus.UserInfo.Token, _ = GenerateJWTToken(int(user.ID), user.Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if state.Type == "add" {
|
||||||
|
//根据第三方平台查找用户
|
||||||
|
thirdPartyUserInfoList := dao.FindThirdPartyUserInfoByThirdPartyID(userInfo.UserID)
|
||||||
|
if thirdPartyUserInfoList != nil && len(thirdPartyUserInfoList) > 0 {
|
||||||
|
thirdPartyLoginStatus.Status = 3 //已绑定用户
|
||||||
|
} else {
|
||||||
|
userIDStr := worker.GetRedis("user_add_platform_" + state.UUID)
|
||||||
|
if userIDStr == "" {
|
||||||
|
log.Println("user id is empty")
|
||||||
|
thirdPartyLoginStatus.Status = 2 //未绑定用户
|
||||||
|
} else {
|
||||||
|
//字符串转int
|
||||||
|
userID, _ := strconv.Atoi(userIDStr)
|
||||||
|
//根据用户ID获取用户信息
|
||||||
|
user := GetUserByIDWithCache(userID)
|
||||||
|
if user.ID == 0 {
|
||||||
|
thirdPartyLoginStatus.Status = 2 //未绑定用户
|
||||||
|
log.Println("get user by id error")
|
||||||
|
} else {
|
||||||
|
//成功
|
||||||
|
thirdPartyLoginStatus.Status = 0
|
||||||
|
thirdPartyLoginStatus.UserInfo.UserID = int(user.ID)
|
||||||
|
thirdPartyLoginStatus.UserInfo.Username = user.Name
|
||||||
|
thirdPartyLoginStatus.UserInfo.Email = user.Email
|
||||||
|
thirdPartyLoginStatus.UserInfo.Token, _ = GenerateJWTToken(int(user.ID), user.Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
log.Println("DoGithubCallBack state type error:", state.Type)
|
||||||
|
}
|
||||||
|
//更新redis中的第三方登录状态
|
||||||
|
thirdPartyLoginStatusStr, _ := json.Marshal(thirdPartyLoginStatus)
|
||||||
|
worker.SetRedisWithExpire(state.UUID, string(thirdPartyLoginStatusStr), time.Minute*10)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 解析jwt内容
|
// 解析jwt内容
|
||||||
|
|
@ -233,3 +305,18 @@ func DecodeJWTToken(tokenStr string) (int, error) {
|
||||||
}
|
}
|
||||||
return int(userID), nil
|
return int(userID), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 生成token
|
||||||
|
func GenerateJWTToken(userID int, userName string) (string, error) {
|
||||||
|
//创建token
|
||||||
|
claims := jwt.MapClaims{
|
||||||
|
"id": userID,
|
||||||
|
"username": userName,
|
||||||
|
}
|
||||||
|
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||||
|
tokenString, err := token.SignedString(proto.SigningKey)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return tokenString, nil
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,6 @@ import (
|
||||||
"StuAcaWorksAI/proto"
|
"StuAcaWorksAI/proto"
|
||||||
"bytes"
|
"bytes"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
)
|
)
|
||||||
|
|
@ -53,15 +52,17 @@ func ExchangeCodeForAccessToken(clientID, clientSecret, code, redirectURI string
|
||||||
}
|
}
|
||||||
|
|
||||||
// 获取用户信息
|
// 获取用户信息
|
||||||
func GetGitHubUserInfo(accessToken string) {
|
func GetGitHubUserInfo(accessToken string) (proto.GitHubUserInfo, error) {
|
||||||
|
|
||||||
url := "https://api.github.com/user"
|
url := "https://api.github.com/user"
|
||||||
headers := map[string]string{
|
headers := map[string]string{
|
||||||
"Authorization": "Bearer " + accessToken,
|
"Authorization": "Bearer " + accessToken,
|
||||||
}
|
}
|
||||||
err, data := DoGetRequest(url, headers)
|
err, data := DoGetRequest(url, headers)
|
||||||
|
var resp proto.GitHubUserInfo
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return resp, err
|
||||||
}
|
}
|
||||||
fmt.Println("User Info:", string(data))
|
err = json.Unmarshal(data, &resp)
|
||||||
|
return resp, err
|
||||||
}
|
}
|
||||||
|
|
|
||||||
46
worker/qq.go
46
worker/qq.go
|
|
@ -2,7 +2,8 @@ package worker
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"io/ioutil"
|
"io"
|
||||||
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
@ -68,8 +69,8 @@ type QQRefreshTokenRequest struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
// 2. Get Access Token
|
// 2. Get Access Token
|
||||||
func GetToken(w http.ResponseWriter, r *http.Request) {
|
func GetQQToken(code string) string {
|
||||||
code := r.FormValue("code")
|
var accessToken string
|
||||||
params := url.Values{}
|
params := url.Values{}
|
||||||
params.Add("grant_type", "authorization_code")
|
params.Add("grant_type", "authorization_code")
|
||||||
params.Add("client_id", AppId)
|
params.Add("client_id", AppId)
|
||||||
|
|
@ -80,11 +81,12 @@ func GetToken(w http.ResponseWriter, r *http.Request) {
|
||||||
|
|
||||||
response, err := http.Get(loginURL)
|
response, err := http.Get(loginURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
w.Write([]byte(err.Error()))
|
log.Println("GetQQToken error:", err.Error())
|
||||||
|
return accessToken
|
||||||
}
|
}
|
||||||
defer response.Body.Close()
|
defer response.Body.Close()
|
||||||
|
|
||||||
bs, _ := ioutil.ReadAll(response.Body)
|
bs, _ := io.ReadAll(response.Body)
|
||||||
body := string(bs)
|
body := string(bs)
|
||||||
|
|
||||||
resultMap := convertToMap(body)
|
resultMap := convertToMap(body)
|
||||||
|
|
@ -93,41 +95,43 @@ func GetToken(w http.ResponseWriter, r *http.Request) {
|
||||||
info.AccessToken = resultMap["access_token"]
|
info.AccessToken = resultMap["access_token"]
|
||||||
info.RefreshToken = resultMap["refresh_token"]
|
info.RefreshToken = resultMap["refresh_token"]
|
||||||
info.ExpiresIn = resultMap["expires_in"]
|
info.ExpiresIn = resultMap["expires_in"]
|
||||||
|
return info.AccessToken
|
||||||
GetOpenId(info, w)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// 3. Get OpenId
|
// 3. Get QQ OpenId
|
||||||
func GetOpenId(info *PrivateInfo, w http.ResponseWriter) {
|
func GetOpenId(accessToken string) {
|
||||||
resp, err := http.Get(fmt.Sprintf("%s?access_token=%s", "https://graph.qq.com/oauth2.0/me", info.AccessToken))
|
resp, err := http.Get(fmt.Sprintf("%s?access_token=%s", "https://graph.qq.com/oauth2.0/me", accessToken))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
w.Write([]byte(err.Error()))
|
return
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
|
|
||||||
bs, _ := ioutil.ReadAll(resp.Body)
|
bs, _ := io.ReadAll(resp.Body)
|
||||||
body := string(bs)
|
body := string(bs)
|
||||||
info.OpenId = body[45:77]
|
openId := body[45:77]
|
||||||
|
|
||||||
GetUserInfo(info, w)
|
info, err := GetUserInfo(accessToken, openId)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
log.Println(info)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 4. Get User info
|
// 4. Get User info
|
||||||
func GetUserInfo(info *PrivateInfo, w http.ResponseWriter) {
|
func GetUserInfo(accessToken string, openID string) (string, error) {
|
||||||
params := url.Values{}
|
params := url.Values{}
|
||||||
params.Add("access_token", info.AccessToken)
|
params.Add("access_token", accessToken)
|
||||||
params.Add("openid", info.OpenId)
|
params.Add("openid", openID)
|
||||||
params.Add("oauth_consumer_key", AppId)
|
params.Add("oauth_consumer_key", AppId)
|
||||||
|
|
||||||
uri := fmt.Sprintf("https://graph.qq.com/user/get_user_info?%s", params.Encode())
|
uri := fmt.Sprintf("https://graph.qq.com/user/get_user_info?%s", params.Encode())
|
||||||
resp, err := http.Get(uri)
|
resp, err := http.Get(uri)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
w.Write([]byte(err.Error()))
|
return "", err
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
|
bs, _ := io.ReadAll(resp.Body)
|
||||||
bs, _ := ioutil.ReadAll(resp.Body)
|
return string(bs), nil
|
||||||
w.Write(bs)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func convertToMap(str string) map[string]string {
|
func convertToMap(str string) map[string]string {
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue