diff --git a/cmd/ks-iam/app/options/options.go b/cmd/ks-iam/app/options/options.go index d001257ad9bf6f5b665af5d87d3b5c3c73233f56..d986db33409eeac07fb566bcf7bffdba9333d9f8 100644 --- a/cmd/ks-iam/app/options/options.go +++ b/cmd/ks-iam/app/options/options.go @@ -27,6 +27,7 @@ import ( "kubesphere.io/kubesphere/pkg/simple/client/mysql" "kubesphere.io/kubesphere/pkg/simple/client/redis" "strings" + "time" ) type ServerRunOptions struct { @@ -37,7 +38,7 @@ type ServerRunOptions struct { MySQLOptions *mysql.MySQLOptions AdminEmail string AdminPassword string - TokenIdleTimeout string + TokenIdleTimeout time.Duration JWTSecret string AuthRateLimit string EnableMultiLogin bool @@ -61,7 +62,7 @@ func (s *ServerRunOptions) Flags() (fss cliflag.NamedFlagSets) { s.GenericServerRunOptions.AddFlags(fs) fs.StringVar(&s.AdminEmail, "admin-email", "admin@kubesphere.io", "default administrator's email") fs.StringVar(&s.AdminPassword, "admin-password", "passw0rd", "default administrator's password") - fs.StringVar(&s.TokenIdleTimeout, "token-idle-timeout", "30m", "tokens that are idle beyond that time will expire,0s means the token has no expiration time. valid time units are \"ns\",\"us\",\"ms\",\"s\",\"m\",\"h\"") + fs.DurationVar(&s.TokenIdleTimeout, "token-idle-timeout", 30*time.Minute, "tokens that are idle beyond that time will expire,0s means the token has no expiration time. valid time units are \"ns\",\"us\",\"ms\",\"s\",\"m\",\"h\"") fs.StringVar(&s.JWTSecret, "jwt-secret", "", "jwt secret") fs.StringVar(&s.AuthRateLimit, "auth-rate-limit", "5/30m", "specifies the maximum number of authentication attempts permitted and time interval,valid time units are \"s\",\"m\",\"h\"") fs.BoolVar(&s.EnableMultiLogin, "enable-multi-login", false, "allow one account to have multiple sessions") diff --git a/cmd/ks-iam/app/server.go b/cmd/ks-iam/app/server.go index 8a6b52794c5e66849c996aac4cab2db9f526f958..5027148607fd690f1126da4d62f25a27bd9d5a89 100644 --- a/cmd/ks-iam/app/server.go +++ b/cmd/ks-iam/app/server.go @@ -94,7 +94,7 @@ func Run(s *options.ServerRunOptions, stopChan <-chan struct{}) error { waitForResourceSync(stopChan) - err := iam.Init(s.AdminEmail, s.AdminPassword, s.TokenIdleTimeout, s.AuthRateLimit, s.EnableMultiLogin) + err := iam.Init(s.AdminEmail, s.AdminPassword, s.AuthRateLimit, s.TokenIdleTimeout, s.EnableMultiLogin) jwtutil.Setup(s.JWTSecret) diff --git a/pkg/apigateway/caddy-plugin/authenticate/authenticate.go b/pkg/apigateway/caddy-plugin/authenticate/authenticate.go index e30362461d097f72e062570869c05d1eed8dac63..4adb6d730c7d24bc06e64f6ba5ebcb8a68631dda 100644 --- a/pkg/apigateway/caddy-plugin/authenticate/authenticate.go +++ b/pkg/apigateway/caddy-plugin/authenticate/authenticate.go @@ -20,11 +20,11 @@ package authenticate import ( "errors" "fmt" - "github.com/go-redis/redis" "k8s.io/apimachinery/pkg/util/sets" "k8s.io/apiserver/pkg/authentication/user" "k8s.io/apiserver/pkg/endpoints/request" "k8s.io/klog" + "kubesphere.io/kubesphere/pkg/simple/client/redis" "log" "net/http" "strconv" @@ -43,9 +43,9 @@ type Auth struct { type Rule struct { Secret []byte Path string - RedisOptions *redis.Options + RedisOptions *redis.RedisOptions TokenIdleTimeout time.Duration - RedisClient *redis.Client + RedisClient *redis.RedisClient ExceptedPath []string } @@ -187,13 +187,13 @@ func (h Auth) Validate(uToken string) (*jwt.Token, error) { } if _, ok = payload["exp"]; ok { - // allow static token when contain expiration time + // allow static token has expiration time return token, nil } tokenKey := fmt.Sprintf("kubesphere:users:%s:token:%s", username, uToken) - exist, err := h.Rule.RedisClient.Exists(tokenKey).Result() + exist, err := h.Rule.RedisClient.Redis().Exists(tokenKey).Result() if err != nil { klog.Error(err) return nil, err @@ -201,7 +201,7 @@ func (h Auth) Validate(uToken string) (*jwt.Token, error) { if exist == 1 { // reset expiration time if token exist - h.Rule.RedisClient.Expire(tokenKey, h.Rule.TokenIdleTimeout) + h.Rule.RedisClient.Redis().Expire(tokenKey, h.Rule.TokenIdleTimeout) return token, nil } else { return nil, errors.New("illegal token") diff --git a/pkg/apigateway/caddy-plugin/authenticate/auto_load.go b/pkg/apigateway/caddy-plugin/authenticate/auto_load.go index 9cfcea063726e02b13627000919cbfa9bb1d3f61..3cb2f4c79586a65486c7184fe25a08b15c323b3c 100644 --- a/pkg/apigateway/caddy-plugin/authenticate/auto_load.go +++ b/pkg/apigateway/caddy-plugin/authenticate/auto_load.go @@ -19,7 +19,8 @@ package authenticate import ( "fmt" - "github.com/go-redis/redis" + "kubesphere.io/kubesphere/pkg/simple/client/redis" + "strings" "time" @@ -35,10 +36,9 @@ func Setup(c *caddy.Controller) error { return err } - rule.RedisClient = redis.NewClient(rule.RedisOptions) - c.OnStartup(func() error { - if err := rule.RedisClient.Ping().Err(); err != nil { + rule.RedisClient, err = redis.NewRedisClient(rule.RedisOptions, nil) + if err != nil { return err } fmt.Println("Authenticate middleware is initiated") @@ -46,7 +46,7 @@ func Setup(c *caddy.Controller) error { }) c.OnShutdown(func() error { - return rule.RedisClient.Close() + return rule.RedisClient.Redis().Close() }) httpserver.GetConfig(c).AddMiddleware(func(next httpserver.Handler) httpserver.Handler { @@ -95,10 +95,12 @@ func parse(c *caddy.Controller) (Rule, error) { return rule, c.ArgErr() } - if redisOptions, err := redis.ParseURL(c.Val()); err != nil { + options := &redis.RedisOptions{RedisURL: c.Val()} + + if err := options.Validate(); len(err) > 0 { return rule, c.ArgErr() } else { - rule.RedisOptions = redisOptions + rule.RedisOptions = options } if c.NextArg() { diff --git a/pkg/apis/iam/v1alpha2/register.go b/pkg/apis/iam/v1alpha2/register.go index 6ce498af21c49de3800bd8eec8507c3da826df9d..3bbd13af80b55a9cb918ee57cf94c5fa2790380a 100644 --- a/pkg/apis/iam/v1alpha2/register.go +++ b/pkg/apis/iam/v1alpha2/register.go @@ -130,7 +130,13 @@ func addWebService(c *restful.Container) error { To(iam.Login). Doc("KubeSphere APIs support token-based authentication via the Authtoken request header. The POST Login API is used to retrieve the authentication token. After the authentication token is obtained, it must be inserted into the Authtoken header for all requests."). Reads(iam.LoginRequest{}). - Returns(http.StatusOK, ok, models.Token{}). + Returns(http.StatusOK, ok, models.AuthGrantResponse{}). + Metadata(restfulspec.KeyOpenAPITags, []string{constants.IdentityManagementTag})) + ws.Route(ws.POST("/token"). + To(iam.OAuth). + Doc("OAuth API,only support resource owner password credentials grant"). + Reads(iam.LoginRequest{}). + Returns(http.StatusOK, ok, models.AuthGrantResponse{}). Metadata(restfulspec.KeyOpenAPITags, []string{constants.IdentityManagementTag})) ws.Route(ws.GET("/users/{user}"). To(iam.DescribeUser). diff --git a/pkg/apiserver/iam/auth.go b/pkg/apiserver/iam/auth.go index f93855a17c1b1f34d0e25ece0c7f2c99cb248b8c..07892b8aa5eed0e278144270e3bf6385710bf176 100644 --- a/pkg/apiserver/iam/auth.go +++ b/pkg/apiserver/iam/auth.go @@ -18,15 +18,16 @@ package iam import ( + "fmt" "github.com/dgrijalva/jwt-go" "github.com/emicklei/go-restful" "k8s.io/klog" + "kubesphere.io/kubesphere/pkg/models" + "kubesphere.io/kubesphere/pkg/models/iam" + "kubesphere.io/kubesphere/pkg/server/errors" "kubesphere.io/kubesphere/pkg/utils/iputil" "kubesphere.io/kubesphere/pkg/utils/jwtutil" "net/http" - - "kubesphere.io/kubesphere/pkg/models/iam" - "kubesphere.io/kubesphere/pkg/server/errors" ) type Spec struct { @@ -50,8 +51,14 @@ type LoginRequest struct { Password string `json:"password" description:"password"` } +type OAuthRequest struct { + GrantType string `json:"grant_type"` + Username string `json:"username,omitempty" description:"username"` + Password string `json:"password,omitempty" description:"password"` + RefreshToken string `json:"refresh_token,omitempty"` +} + const ( - APIVersion = "authentication.k8s.io/v1beta1" KindTokenReview = "TokenReview" ) @@ -81,6 +88,39 @@ func Login(req *restful.Request, resp *restful.Response) { resp.WriteAsJson(token) } +func OAuth(req *restful.Request, resp *restful.Response) { + + authRequest := &OAuthRequest{} + + err := req.ReadEntity(authRequest) + + if err != nil { + resp.WriteHeaderAndEntity(http.StatusBadRequest, errors.Wrap(err)) + return + } + var result *models.AuthGrantResponse + switch authRequest.GrantType { + case "refresh_token": + result, err = iam.RefreshToken(authRequest.RefreshToken) + case "password": + ip := iputil.RemoteIp(req.Request) + result, err = iam.PasswordCredentialGrant(authRequest.Username, authRequest.Password, ip) + default: + resp.Header().Set("WWW-Authenticate", "grant_type is not supported") + resp.WriteHeaderAndEntity(http.StatusUnauthorized, errors.Wrap(fmt.Errorf("grant_type is not supported"))) + return + } + + if err != nil { + resp.Header().Set("WWW-Authenticate", err.Error()) + resp.WriteHeaderAndEntity(http.StatusUnauthorized, errors.Wrap(err)) + return + } + + resp.WriteEntity(result) + +} + // k8s token review func TokenReviewHandler(req *restful.Request, resp *restful.Response) { var tokenReview TokenReview @@ -103,7 +143,7 @@ func TokenReviewHandler(req *restful.Request, resp *restful.Response) { if err != nil { klog.Errorln("token review failed", uToken, err) - failed := TokenReview{APIVersion: APIVersion, + failed := TokenReview{APIVersion: tokenReview.APIVersion, Kind: KindTokenReview, Status: &Status{ Authenticated: false, @@ -138,7 +178,7 @@ func TokenReviewHandler(req *restful.Request, resp *restful.Response) { user.Groups = groups - success := TokenReview{APIVersion: APIVersion, + success := TokenReview{APIVersion: tokenReview.APIVersion, Kind: KindTokenReview, Status: &Status{ Authenticated: true, diff --git a/pkg/models/iam/im.go b/pkg/models/iam/im.go index d7444927d99bbaffbe53e05b0a7004045d0f7a5c..ec98d7f2d24ea6fc00647ac853ad4bd7861de4c2 100644 --- a/pkg/models/iam/im.go +++ b/pkg/models/iam/im.go @@ -70,13 +70,12 @@ const ( authRateLimitRegex = `(\d+)/(\d+[s|m|h])` defaultMaxAuthFailed = 5 defaultAuthTimeInterval = 30 * time.Minute - defaultTokenIdleTimeout = 30 * time.Minute ) -func Init(email, password, idleTimeout, authRateLimit string, multiLogin bool) error { +func Init(email, password, authRateLimit string, idleTimeout time.Duration, multiLogin bool) error { adminEmail = email adminPassword = password - tokenIdleTimeout = parseTokenIdleTimeout(idleTimeout) + tokenIdleTimeout = idleTimeout maxAuthFailed, authTimeInterval = parseAuthRateLimit(authRateLimit) enableMultiLogin = multiLogin @@ -97,15 +96,6 @@ func Init(email, password, idleTimeout, authRateLimit string, multiLogin bool) e return nil } -func parseTokenIdleTimeout(tokenExpirationTime string) time.Duration { - duration, err := time.ParseDuration(tokenExpirationTime) - if err != nil { - return defaultTokenIdleTimeout - } else { - return duration - } -} - func parseAuthRateLimit(authRateLimit string) (int, time.Duration) { regex := regexp.MustCompile(authRateLimitRegex) groups := regex.FindStringSubmatch(authRateLimit) @@ -255,8 +245,151 @@ func createGroupsBaseDN() error { return conn.Add(groupsCreateRequest) } +func RefreshToken(refreshToken string) (*models.AuthGrantResponse, error) { + validRefreshToken, err := jwtutil.ValidateToken(refreshToken) + if err != nil { + klog.Error(err) + return nil, err + } + + payload, ok := validRefreshToken.Claims.(jwt.MapClaims) + + if !ok { + err = errors.New("invalid payload") + klog.Error(err) + return nil, err + } + + claims := jwt.MapClaims{} + + // token with expiration time will not auto sliding + claims["username"] = payload["username"] + claims["email"] = payload["email"] + claims["iat"] = time.Now().Unix() + claims["exp"] = time.Now().Add(tokenIdleTimeout * 4).Unix() + + token := jwtutil.MustSigned(claims) + + claims = jwt.MapClaims{} + claims["username"] = payload["username"] + claims["email"] = payload["email"] + claims["iat"] = time.Now().Unix() + claims["type"] = "refresh_token" + claims["exp"] = time.Now().Add(tokenIdleTimeout * 5).Unix() + + refreshToken = jwtutil.MustSigned(claims) + + return &models.AuthGrantResponse{TokenType: "jwt", Token: token, RefreshToken: refreshToken, ExpiresIn: (tokenIdleTimeout * 4).Seconds()}, nil +} + +func PasswordCredentialGrant(username, password, ip string) (*models.AuthGrantResponse, error) { + redisClient, err := clientset.ClientSets().Redis() + if err != nil { + return nil, err + } + + records, err := redisClient.Keys(fmt.Sprintf("kubesphere:authfailed:%s:*", username)).Result() + + if err != nil { + klog.Error(err) + return nil, err + } + + if len(records) >= maxAuthFailed { + return nil, restful.NewError(http.StatusTooManyRequests, "auth rate limit exceeded") + } + + client, err := clientset.ClientSets().Ldap() + if err != nil { + return nil, err + } + conn, err := client.NewConn() + if err != nil { + return nil, err + } + defer conn.Close() + + userSearchRequest := ldap.NewSearchRequest( + client.UserSearchBase(), + ldap.ScopeWholeSubtree, ldap.NeverDerefAliases, 0, 0, false, + fmt.Sprintf("(&(objectClass=inetOrgPerson)(|(uid=%s)(mail=%s)))", username, username), + []string{"uid", "mail"}, + nil, + ) + + result, err := conn.Search(userSearchRequest) + + if err != nil { + return nil, err + } + + if len(result.Entries) != 1 { + return nil, ldap.NewError(ldap.LDAPResultInvalidCredentials, errors.New("incorrect password")) + } + + uid := result.Entries[0].GetAttributeValue("uid") + email := result.Entries[0].GetAttributeValue("mail") + dn := result.Entries[0].DN + + // bind as the user to verify their password + err = conn.Bind(dn, password) + + if err != nil { + klog.Infoln("auth failed", username, err) + + if ldap.IsErrorWithCode(err, ldap.LDAPResultInvalidCredentials) { + loginFailedRecord := fmt.Sprintf("kubesphere:authfailed:%s:%d", uid, time.Now().UnixNano()) + redisClient.Set(loginFailedRecord, "", authTimeInterval) + } + + return nil, err + } + + claims := jwt.MapClaims{} + + // token with expiration time will not auto sliding + claims["username"] = uid + claims["email"] = email + claims["iat"] = time.Now().Unix() + claims["exp"] = time.Now().Add(tokenIdleTimeout * 4).Unix() + + token := jwtutil.MustSigned(claims) + + if !enableMultiLogin { + // multi login not allowed, remove the previous token + sessions, err := redisClient.Keys(fmt.Sprintf("kubesphere:users:%s:token:*", uid)).Result() + + if err != nil { + klog.Errorln(err) + return nil, err + } + + if len(sessions) > 0 { + klog.V(4).Infoln("revoke token", sessions) + err = redisClient.Del(sessions...).Err() + if err != nil { + klog.Errorln(err) + return nil, err + } + } + } + + claims = jwt.MapClaims{} + claims["username"] = uid + claims["email"] = email + claims["iat"] = time.Now().Unix() + claims["type"] = "refresh_token" + claims["exp"] = time.Now().Add(tokenIdleTimeout * 5).Unix() + + refreshToken := jwtutil.MustSigned(claims) + + loginLog(uid, ip) + + return &models.AuthGrantResponse{TokenType: "jwt", Token: token, RefreshToken: refreshToken, ExpiresIn: (tokenIdleTimeout * 4).Seconds()}, nil +} + // User login -func Login(username string, password string, ip string) (*models.Token, error) { +func Login(username, password, ip string) (*models.AuthGrantResponse, error) { redisClient, err := clientset.ClientSets().Redis() if err != nil { @@ -322,7 +455,7 @@ func Login(username string, password string, ip string) (*models.Token, error) { claims := jwt.MapClaims{} - // do not set expiration time + // token without expiration time will auto sliding claims["username"] = uid claims["email"] = email claims["iat"] = time.Now().Unix() @@ -356,7 +489,7 @@ func Login(username string, password string, ip string) (*models.Token, error) { loginLog(uid, ip) - return &models.Token{Token: token}, nil + return &models.AuthGrantResponse{Token: token}, nil } func loginLog(uid, ip string) { diff --git a/pkg/models/types.go b/pkg/models/types.go index 3916d6cb79b9d022a2dc7eaaeea7246feadd4a11..9144a20fb730b0116b7d8d324b69f44123885900 100644 --- a/pkg/models/types.go +++ b/pkg/models/types.go @@ -107,8 +107,11 @@ type PodInfo struct { Container string `json:"container" description:"container name"` } -type Token struct { - Token string `json:"access_token" description:"access token"` +type AuthGrantResponse struct { + TokenType string `json:"token_type,omitempty"` + Token string `json:"access_token" description:"access token"` + ExpiresIn float64 `json:"expires_in,omitempty"` + RefreshToken string `json:"refresh_token,omitempty"` } type ResourceQuota struct { diff --git a/pkg/server/config/config.go b/pkg/server/config/config.go index e1a06412577a1bc42ad02ef88634b6a4308aed05..dd22e5b675a47fbfd33fac558b16abdadb1169a6 100644 --- a/pkg/server/config/config.go +++ b/pkg/server/config/config.go @@ -229,7 +229,7 @@ func (c *Config) stripEmptyOptions() { c.MySQLOptions = nil } - if c.RedisOptions != nil && c.RedisOptions.Host == "" { + if c.RedisOptions != nil && c.RedisOptions.RedisURL == "" { c.RedisOptions = nil } diff --git a/pkg/server/config/config_test.go b/pkg/server/config/config_test.go index 611236b780e62dd8a0888090609ae35d25b6682b..0d62b72bb388d5a37b6c9371645e02e642c884c2 100644 --- a/pkg/server/config/config_test.go +++ b/pkg/server/config/config_test.go @@ -63,10 +63,7 @@ func newTestConfig() *Config { GroupSearchBase: "ou=Groups,dc=example,dc=org", }, RedisOptions: &redis.RedisOptions{ - Host: "10.10.111.110", - Port: 6379, - Password: "", - DB: 0, + RedisURL: "redis://:qwerty@localhost:6379/1", }, S3Options: &s2is3.S3Options{ Endpoint: "http://minio.openpitrix-system.svc", diff --git a/pkg/simple/client/factory.go b/pkg/simple/client/factory.go index 3d9f273a9748ecd31198602d56995be7271b2f4d..6fadb2c589f91ee207e861a7cf96d4934b52fc7e 100644 --- a/pkg/simple/client/factory.go +++ b/pkg/simple/client/factory.go @@ -181,7 +181,7 @@ func (cs *ClientSet) MySQL() (*mysql.Database, error) { func (cs *ClientSet) Redis() (*goredis.Client, error) { var err error - if cs.csoptions.redisOptions == nil || cs.csoptions.redisOptions.Host == "" { + if cs.csoptions.redisOptions == nil || cs.csoptions.redisOptions.RedisURL == "" { return nil, ClientSetNotEnabledError{} } diff --git a/pkg/simple/client/redis/options.go b/pkg/simple/client/redis/options.go index a89b5bac9e3b8a3008869ef1bdbbd75d7e018a64..a3e9d25debe4eefba1965bfefa62228284f0fac3 100644 --- a/pkg/simple/client/redis/options.go +++ b/pkg/simple/client/redis/options.go @@ -1,27 +1,20 @@ package redis import ( - "fmt" + "github.com/go-redis/redis" "github.com/spf13/pflag" - "kubesphere.io/kubesphere/pkg/utils/net" "kubesphere.io/kubesphere/pkg/utils/reflectutils" ) type RedisOptions struct { - Host string - Port int - Password string - DB int + RedisURL string } // NewRedisOptions returns options points to nowhere, // because redis is not required for some components func NewRedisOptions() *RedisOptions { return &RedisOptions{ - Host: "", - Port: 6379, - Password: "", - DB: 0, + RedisURL: "", } } @@ -29,14 +22,10 @@ func NewRedisOptions() *RedisOptions { func (r *RedisOptions) Validate() []error { errors := make([]error, 0) - if r.Host != "" { - if !net.IsValidPort(r.Port) { - errors = append(errors, fmt.Errorf("--redis-port is out of range")) - } - } + _, err := redis.ParseURL(r.RedisURL) - if r.DB < 0 { - errors = append(errors, fmt.Errorf("--redis-db is less than 0")) + if err != nil { + errors = append(errors, err) } return errors @@ -44,7 +33,7 @@ func (r *RedisOptions) Validate() []error { // ApplyTo apply to another options if it's a enabled option(non empty host) func (r *RedisOptions) ApplyTo(options *RedisOptions) { - if r.Host != "" { + if r.RedisURL != "" { reflectutils.Override(options, r) } } @@ -52,16 +41,6 @@ func (r *RedisOptions) ApplyTo(options *RedisOptions) { // AddFlags add option flags to command line flags, // if redis-host left empty, the following options will be ignored. func (r *RedisOptions) AddFlags(fs *pflag.FlagSet) { - fs.StringVar(&r.Host, "redis-host", r.Host, ""+ - "Redis service host address. If left blank, means redis is unnecessary, "+ - "redis will be disabled") - - fs.IntVar(&r.Port, "redis-port", r.Port, ""+ - "Redis service port number.") - - fs.StringVar(&r.Password, "redis-password", r.Password, ""+ - "Redis service password if necessary, default to empty") - - fs.IntVar(&r.DB, "redis-db", r.DB, ""+ - "Redis service database index, default to 0.") + fs.StringVar(&r.RedisURL, "redis-url", "", "Redis connection URL. If left blank, means redis is unnecessary, "+ + "redis will be disabled. e.g. redis://:password@host:port/db") } diff --git a/pkg/simple/client/redis/redis.go b/pkg/simple/client/redis/redis.go index aafade8d500d14c6986ae8e552ce6ad48de6a783..4a1fb83b33d944eab1d1062074feeeea05fb1800 100644 --- a/pkg/simple/client/redis/redis.go +++ b/pkg/simple/client/redis/redis.go @@ -18,7 +18,6 @@ package redis import ( - "fmt" "github.com/go-redis/redis" "k8s.io/klog" ) @@ -39,11 +38,14 @@ func NewRedisClientOrDie(options *RedisOptions, stopCh <-chan struct{}) *RedisCl func NewRedisClient(option *RedisOptions, stopCh <-chan struct{}) (*RedisClient, error) { var r RedisClient - r.client = redis.NewClient(&redis.Options{ - Addr: fmt.Sprintf("%s:%d", option.Host, option.Port), - Password: option.Password, - DB: option.DB, - }) + options, err := redis.ParseURL(option.RedisURL) + + if err != nil { + klog.Error(err) + return nil, err + } + + r.client = redis.NewClient(options) if err := r.client.Ping().Err(); err != nil { klog.Error("unable to reach redis host", err) @@ -51,12 +53,14 @@ func NewRedisClient(option *RedisOptions, stopCh <-chan struct{}) (*RedisClient, return nil, err } - go func() { - <-stopCh - if err := r.client.Close(); err != nil { - klog.Error(err) - } - }() + if stopCh != nil { + go func() { + <-stopCh + if err := r.client.Close(); err != nil { + klog.Error(err) + } + }() + } return &r, nil }