diff --git a/handler/tool.go b/handler/tool.go index 298f68d..c81d3aa 100644 --- a/handler/tool.go +++ b/handler/tool.go @@ -855,6 +855,14 @@ func GetThirdPartyAuthUrl(c *gin.Context) { params.Add("state", stateID) params.Add("scope", "user") respUrl = fmt.Sprintf("%s?%s", baseUrl, params.Encode()) + case "microsoft": + params := url.Values{} + params.Add("client_id", worker.MicroSoftClientID) + params.Add("redirect_uri", "https://pm.ljsea.top/tool/third_party_callback") + params.Add("response_type", "code") //返回code + params.Add("state", stateID) + params.Add("scope", "User.Read Mail.Read") + respUrl = fmt.Sprintf("https://login.microsoftonline.com/%s/oauth2/v2.0/authorize?%s", worker.MicroSoftTenantID, params.Encode()) default: log.Println("platform not support:", platform) } diff --git a/proto/conf.go b/proto/conf.go index a8b6173..3a2f9be 100644 --- a/proto/conf.go +++ b/proto/conf.go @@ -64,31 +64,32 @@ type User struct { } type ConfigStruct struct { - DB int `json:"db"` // 0: mysql, 1: pg - MYSQL_DSN string `json:"mysql_dsn"` - PG_DSN string `json:"pg_dsn"` - REDIS_ADDR string `json:"redis_addr"` - TOKEN_USE_REDIS bool `json:"token_use_redis"` - REDIS_User_PW bool `json:"redis_user_pw"` // 是否使用密码 - REDIS_PASSWORD string `json:"redis_password"` - REDIS_DB int `json:"redis_db"` - TOKEN_SECRET string `json:"token_secret"` - CID_BASE_DIR string `json:"cid_base_dir"` - FILE_BASE_DIR string `json:"file_base_dir"` - MONITOR bool `json:"monitor"` // 状态监控及邮件通知 - SERVER_SQL_LOG bool `json:"server_sql_log"` // 服务器sql日志 - SERVER_PORT string `json:"server_port"` // 服务端口 - LOG_SAVE_DAYS int `json:"log_save_days"` // 日志保存天数,-1表示不保存,0表示永久保存 - SERVER_USER_TYPE string `json:"user_type"` // 服务器用户类型,master: 主服务器,slave: 从服务器,从服务器会定时同步数据 - MASTER_SERVER_DOMAIN string `json:"master_server_domain"` // 主服务器域名 - USER_SYNC_TIME int `json:"user_sync_time"` // 用户数据同步时间,单位秒 - SERVER_NAME string `json:"server_name"` // 服务器名称,用于区分不同服务器 - SPARK_PPT_USAGE bool `json:"spark_ppt_usage"` // 是否使用spark ppt功能 - KBASE_SERVER []KBaseServer `json:"kbase_server"` // 知识库服务器列表 - GITHUB_CLIENT_ID string `json:"github_client_id"` // github client id - GITHUB_CLIENT_SECRET string `json:"github_client_secret"` // github client secret - GITEE_CLIENT_ID string `json:"gitee_client_id"` // gitee client id - GITEE_CLIENT_SECRET string `json:"gitee_client_secret"` // gitee client secret + DB int `json:"db"` // 0: mysql, 1: pg + MYSQL_DSN string `json:"mysql_dsn"` + PG_DSN string `json:"pg_dsn"` + REDIS_ADDR string `json:"redis_addr"` + TOKEN_USE_REDIS bool `json:"token_use_redis"` + REDIS_User_PW bool `json:"redis_user_pw"` // 是否使用密码 + REDIS_PASSWORD string `json:"redis_password"` + REDIS_DB int `json:"redis_db"` + TOKEN_SECRET string `json:"token_secret"` + CID_BASE_DIR string `json:"cid_base_dir"` + FILE_BASE_DIR string `json:"file_base_dir"` + MONITOR bool `json:"monitor"` // 状态监控及邮件通知 + SERVER_SQL_LOG bool `json:"server_sql_log"` // 服务器sql日志 + SERVER_PORT string `json:"server_port"` // 服务端口 + LOG_SAVE_DAYS int `json:"log_save_days"` // 日志保存天数,-1表示不保存,0表示永久保存 + SERVER_USER_TYPE string `json:"user_type"` // 服务器用户类型,master: 主服务器,slave: 从服务器,从服务器会定时同步数据 + MASTER_SERVER_DOMAIN string `json:"master_server_domain"` // 主服务器域名 + USER_SYNC_TIME int `json:"user_sync_time"` // 用户数据同步时间,单位秒 + SERVER_NAME string `json:"server_name"` // 服务器名称,用于区分不同服务器 + SPARK_PPT_USAGE bool `json:"spark_ppt_usage"` // 是否使用spark ppt功能 + KBASE_SERVER []KBaseServer `json:"kbase_server"` // 知识库服务器列表 + GITHUB_CLIENT_ID string `json:"github_client_id"` // github client id + GITHUB_CLIENT_SECRET string `json:"github_client_secret"` // github client secret + GITEE_CLIENT_ID string `json:"gitee_client_id"` // gitee client id + GITEE_CLIENT_SECRET string `json:"gitee_client_secret"` // gitee client secret + MICROSOFT_CLIENT_SECRET string `json:"microsoft_client_secret"` // microsoft client secret } type KBaseServer struct { diff --git a/proto/tool.go b/proto/tool.go index 3c1df5e..aefa66f 100644 --- a/proto/tool.go +++ b/proto/tool.go @@ -342,3 +342,27 @@ type GiteaUserInfo struct { Picture string `json:"picture"` Groups interface{} `json:"groups"` } + +/**************************microsoft***********************/ +type MicrosoftOAuthResponse struct { + AccessToken string `json:"access_token"` + TokenType string `json:"token_type"` + ExpiresIn int `json:"expires_in"` + RefreshToken string `json:"refresh_token"` +} + +type MicrosoftOAuthRequest struct { + ClientID string `json:"client_id"` + ClientSecret string `json:"client_secret"` + Code string `json:"code"` + RedirectURI string `json:"redirect_uri"` + GrantType string `json:"grant_type"` // authorization_code +} +type MicrosoftUserInfo struct { + Sub string `json:"sub"` + Name string `json:"name"` + PreferredUsername string `json:"preferred_username"` + Email string `json:"email"` + Picture string `json:"picture"` + Groups interface{} `json:"groups"` +} diff --git a/service/toolService.go b/service/toolService.go index 6778e12..34aadd9 100644 --- a/service/toolService.go +++ b/service/toolService.go @@ -491,6 +491,8 @@ func DoThirdPartyCallBack(state *proto.ThirdPartyLoginState, code string) { DoStackoverflowCallBack(state, code) case "my_gitea", "gitea": DoGiteaCallBack(state, code) + case "microsoft": + DoMicroSoftCallBack(state, code) default: log.Println("DoThirdPartyCallBack platform error:", state.Platform) } @@ -704,3 +706,36 @@ func DoGiteaCallBack(state *proto.ThirdPartyLoginState, code string) { log.Printf("do handle %s callback success, third party login status: %v\n", state.Platform, thirdPartyLoginStatus) worker.SetRedisWithExpire(state.UUID, string(thirdPartyLoginStatusStr), time.Minute*10) } + +func DoMicroSoftCallBack(state *proto.ThirdPartyLoginState, code string) { + var thirdPartyLoginStatus proto.ThirdPartyLoginStatus + thirdPartyLoginStatus.Type = state.Platform + //根据code获取Access Token + tokenResp, _ := worker.GetMicroSoftAccessTokenByCode(code, "https://pm.ljsea.top/tool/third_party_callback", worker.MicroSoftClientID, proto.Config.MICROSOFT_CLIENT_SECRET) + //if err != nil { + // log.Printf("get %s access token error:%v\n", state.Platform, err) + // thirdPartyLoginStatus.Status = proto.ParameterError + // return + //} + if tokenResp.AccessToken == "" { + log.Printf("get %s access token is empty,token resp:%v\n", state.Platform, tokenResp) + thirdPartyLoginStatus.Status = proto.ParameterError + } else { + log.Printf("get %s access token:%v\n", state.Platform, tokenResp.AccessToken) + //获取用户信息 + userInfoResp, err2 := worker.GetMicroSoftUserInfo(tokenResp.AccessToken) + if err2 != nil { + log.Printf("get %s user info error:%v\n", state.Platform, err2) + thirdPartyLoginStatus.Status = proto.ParameterError + } else { + log.Printf("get %s user info:%v\n", state.Platform, userInfoResp) + thirdPartyUserInfo := proto.ThirdPartyUserInfo{UserID: userInfoResp.Sub, Name: userInfoResp.Name, Avatar: userInfoResp.Picture, Email: userInfoResp.Email} + HandleThirdPartyLoginStatusV2(state, &thirdPartyLoginStatus, &thirdPartyUserInfo) + thirdPartyLoginStatus.Status = proto.SuccessCode + } + } + //更新redis中的第三方登录状态 + thirdPartyLoginStatusStr, _ := json.Marshal(thirdPartyLoginStatus) + log.Printf("do handle %s callback success, third party login status: %v\n", state.Platform, thirdPartyLoginStatus) + worker.SetRedisWithExpire(state.UUID, string(thirdPartyLoginStatusStr), time.Minute*10) +} diff --git a/worker/thirdParty.go b/worker/thirdParty.go index c394453..8971bc5 100644 --- a/worker/thirdParty.go +++ b/worker/thirdParty.go @@ -569,3 +569,53 @@ func GetGiteaUserInfo(baseDomain, accessToken string) (proto.GiteaUserInfo, erro return resp, nil } + +// microsoft登录 +const ( + MicroSoftClientID = "53ce40d3-260e-4256-a500-201b30203e80" + MicroSoftTenantID = "df0fa05b-820a-48c3-8ebd-f159845bf0b2" +) + +func GetMicroSoftAccessTokenByCode(code string, redirectURI string, clientID string, clientSecret string) (proto.MicrosoftOAuthResponse, error) { + var resp proto.MicrosoftOAuthResponse + + url := fmt.Sprintf("https://login.microsoftonline.com/%s/oauth2/v2.0/token", MicroSoftTenantID) + req := proto.MicrosoftOAuthRequest{ + ClientID: clientID, + ClientSecret: clientSecret, + Code: code, + RedirectURI: redirectURI, + GrantType: "authorization_code", + } + reqData, err := json.Marshal(req) + + log.Println("microsoft request url:", url, "reqData:", string(reqData)) + err2, respBytes := DoPostRequestJSON(url, reqData, nil) + if err2 != nil { + log.Println("microsoft response err:", err2) + return resp, err2 + } + err = json.Unmarshal(respBytes, &resp) + if err != nil { + log.Println("microsoft response decode err:", err, " response content:", string(respBytes)) + return resp, err + } + return resp, nil +} + +func GetMicroSoftUserInfo(accessToken string) (proto.MicrosoftUserInfo, error) { + url := "https://graph.microsoft.com/v1.0/me" + headers := map[string]string{ + "Authorization": "Bearer " + accessToken, + } + var resp proto.MicrosoftUserInfo + err2, respBytes := DoGetRequest(url, headers) + if err2 != nil { + return resp, err2 + } + err := json.Unmarshal(respBytes, &resp) + if err != nil { + return resp, err + } + return resp, nil +}