未验证 提交 87281677 编写于 作者: Y Yening Qin 提交者: GitHub

feat: sso support skip tls verify (#1685)

* refactor: sso support skip tls verify

* fix: update oidc

* fix: cas init enable
上级 6e80a63b
......@@ -515,10 +515,23 @@ type SsoConfigOutput struct {
}
func (rt *Router) ssoConfigNameGet(c *gin.Context) {
var oidcDisplayName, casDisplayName, oauthDisplayName string
if rt.Sso.OIDC != nil {
oidcDisplayName = rt.Sso.OIDC.GetDisplayName()
}
if rt.Sso.CAS != nil {
casDisplayName = rt.Sso.CAS.GetDisplayName()
}
if rt.Sso.OAuth2 != nil {
oauthDisplayName = rt.Sso.OAuth2.GetDisplayName()
}
ginx.NewRender(c).Data(SsoConfigOutput{
OidcDisplayName: rt.Sso.OIDC.GetDisplayName(),
CasDisplayName: rt.Sso.CAS.GetDisplayName(),
OauthDisplayName: rt.Sso.OAuth2.GetDisplayName(),
OidcDisplayName: oidcDisplayName,
CasDisplayName: casDisplayName,
OauthDisplayName: oauthDisplayName,
}, nil)
}
......@@ -543,8 +556,7 @@ func (rt *Router) ssoConfigUpdate(c *gin.Context) {
var config oidcx.Config
err := toml.Unmarshal([]byte(f.Content), &config)
ginx.Dangerous(err)
err = rt.Sso.OIDC.Reload(config)
rt.Sso.OIDC, err = oidcx.New(config)
ginx.Dangerous(err)
case "CAS":
var config cas.Config
......
......@@ -3,6 +3,8 @@ package cas
import (
"bytes"
"context"
"crypto/tls"
"net/http"
"net/url"
"strings"
"sync"
......@@ -22,6 +24,7 @@ type Config struct {
RedirectURL string
DisplayName string
CoverAttributes bool
SkipTlsVerify bool
Attributes struct {
Nickname string
Phone string
......@@ -43,6 +46,7 @@ type SsoClient struct {
}
DefaultRoles []string
CoverAttributes bool
HTTPClient *http.Client
sync.RWMutex
}
......@@ -52,6 +56,7 @@ func New(cf Config) *SsoClient {
return &cli
}
cli.Enable = cf.Enable
cli.Config = cf
cli.SsoAddr = cf.SsoAddr
cli.CallbackAddr = cf.RedirectURL
......@@ -62,6 +67,14 @@ func New(cf Config) *SsoClient {
cli.DefaultRoles = cf.DefaultRoles
cli.CoverAttributes = cf.CoverAttributes
if cf.SkipTlsVerify {
transport := &http.Transport{
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
}
cli.HTTPClient = &http.Client{Transport: transport}
}
return &cli
}
......@@ -83,6 +96,14 @@ func (s *SsoClient) Reload(cf Config) {
s.Attributes.Email = cf.Attributes.Email
s.DefaultRoles = cf.DefaultRoles
s.CoverAttributes = cf.CoverAttributes
if cf.SkipTlsVerify {
transport := &http.Transport{
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
}
s.HTTPClient = &http.Client{Transport: transport}
}
}
func (s *SsoClient) GetDisplayName() string {
......@@ -180,6 +201,11 @@ func (s *SsoClient) ValidateServiceTicket(ctx context.Context, ticket, state str
CasURL: casUrl,
ServiceURL: serviceUrl,
}
if s.HTTPClient != nil {
resOptions.Client = s.HTTPClient
}
resCli := cas.NewRestClient(resOptions)
authRet, err := resCli.ValidateServiceTicket(cas.ServiceTicket(ticket))
if err != nil {
......
......@@ -3,6 +3,7 @@ package oauth2x
import (
"bytes"
"context"
"crypto/tls"
"fmt"
"io/ioutil"
"net/http"
......@@ -37,6 +38,7 @@ type SsoClient struct {
UserinfoPrefix string
DefaultRoles []string
Ctx context.Context
sync.RWMutex
}
......@@ -51,6 +53,7 @@ type Config struct {
ClientId string
ClientSecret string
CoverAttributes bool
SkipTlsVerify bool
Attributes struct {
Username string
Nickname string
......@@ -95,6 +98,18 @@ func (s *SsoClient) Reload(cf Config) {
s.UserinfoPrefix = cf.UserinfoPrefix
s.DefaultRoles = cf.DefaultRoles
s.Ctx = context.Background()
if cf.SkipTlsVerify {
transport := &http.Transport{
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
}
// Create an HTTP client that uses our custom transport
client := &http.Client{Transport: transport}
s.Ctx = context.WithValue(s.Ctx, oauth2.HTTPClient, client)
}
s.Config = oauth2.Config{
ClientID: cf.ClientId,
ClientSecret: cf.ClientSecret,
......@@ -176,13 +191,12 @@ func (s *SsoClient) exchangeUser(code string) (*CallbackOutput, error) {
s.RLock()
defer s.RUnlock()
ctx := context.Background()
oauth2Token, err := s.Config.Exchange(ctx, code)
oauth2Token, err := s.Config.Exchange(s.Ctx, code)
if err != nil {
return nil, fmt.Errorf("failed to exchange token: %s", err)
}
userInfo, err := getUserInfo(s.UserInfoAddr, oauth2Token.AccessToken, s.TranTokenMethod)
userInfo, err := s.getUserInfo(s.UserInfoAddr, oauth2Token.AccessToken, s.TranTokenMethod)
if err != nil {
logger.Errorf("failed to get user info: %s", err)
return nil, fmt.Errorf("failed to get user info: %s", err)
......@@ -197,7 +211,7 @@ func (s *SsoClient) exchangeUser(code string) (*CallbackOutput, error) {
}, nil
}
func getUserInfo(UserInfoAddr, accessToken string, TranTokenMethod string) ([]byte, error) {
func (s *SsoClient) getUserInfo(UserInfoAddr, accessToken string, TranTokenMethod string) ([]byte, error) {
var req *http.Request
if TranTokenMethod == "formdata" {
body := bytes.NewBuffer([]byte("access_token=" + accessToken))
......@@ -222,7 +236,14 @@ func getUserInfo(UserInfoAddr, accessToken string, TranTokenMethod string) ([]by
r.Header.Add("Authorization", "Bearer "+accessToken)
req = r
}
resp, err := http.DefaultClient.Do(req)
client := http.DefaultClient
c := s.Ctx.Value(oauth2.HTTPClient)
if c != nil {
client = c.(*http.Client)
}
resp, err := client.Do(req)
if err != nil {
return nil, err
}
......
......@@ -2,7 +2,9 @@ package oidcx
import (
"context"
"crypto/tls"
"fmt"
"net/http"
"sync"
"time"
......@@ -30,6 +32,7 @@ type SsoClient struct {
}
DefaultRoles []string
Ctx context.Context
sync.RWMutex
}
......@@ -41,6 +44,7 @@ type Config struct {
ClientId string
ClientSecret string
CoverAttributes bool
SkipTlsVerify bool
Attributes struct {
Nickname string
Username string
......@@ -81,7 +85,19 @@ func (s *SsoClient) Reload(cf Config) error {
s.Attributes.Email = cf.Attributes.Email
s.DisplayName = cf.DisplayName
s.DefaultRoles = cf.DefaultRoles
provider, err := oidc.NewProvider(context.Background(), cf.SsoAddr)
s.Ctx = context.Background()
if cf.SkipTlsVerify {
transport := &http.Transport{
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
}
// Create an HTTP client that uses our custom transport
client := &http.Client{Transport: transport}
s.Ctx = context.WithValue(s.Ctx, oauth2.HTTPClient, client)
}
provider, err := oidc.NewProvider(s.Ctx, cf.SsoAddr)
if err != nil {
return err
}
......@@ -171,8 +187,7 @@ func (s *SsoClient) exchangeUser(code string) (*CallbackOutput, error) {
s.RLock()
defer s.RUnlock()
ctx := context.Background()
oauth2Token, err := s.Config.Exchange(ctx, code)
oauth2Token, err := s.Config.Exchange(s.Ctx, code)
if err != nil {
return nil, fmt.Errorf("failed to exchange token: %v", err)
}
......@@ -182,7 +197,7 @@ func (s *SsoClient) exchangeUser(code string) (*CallbackOutput, error) {
return nil, fmt.Errorf("no id_token field in oauth2 token. ")
}
idToken, err := s.Verifier.Verify(ctx, rawIDToken)
idToken, err := s.Verifier.Verify(s.Ctx, rawIDToken)
if err != nil {
return nil, fmt.Errorf("failed to verify ID Token: %v", err)
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册