diff --git a/etc/rdb.yml b/etc/rdb.yml index e2b31aba57349d218d671198882d72b343e957fb..5c1de106cc16a7a79c54e24dad29a7f5b7dee63c 100644 --- a/etc/rdb.yml +++ b/etc/rdb.yml @@ -92,3 +92,5 @@ wechat: corp_id: "xxxxxxxxxxxxx" agent_id: 1000000 secret: "xxxxxxxxxxxxxxxxx" + +captcha: false diff --git a/src/models/login_code.go b/src/models/login_code.go index 6e911d5e3d57edfa35f277b59ffcd05d4dfda2f3..0ad2413ac570c8e9c50c5cc0ea0b8f276458bbe6 100644 --- a/src/models/login_code.go +++ b/src/models/login_code.go @@ -1,5 +1,7 @@ package models +import "errors" + type LoginCode struct { Username string `json:"username"` Code string `json:"code"` @@ -7,6 +9,10 @@ type LoginCode struct { CreatedAt int64 `json:"created_at"` } +var ( + errLoginCode = errors.New("invalid login code") +) + func LoginCodeGet(where string, args ...interface{}) (*LoginCode, error) { var obj LoginCode has, err := DB["rdb"].Where(where, args...).Get(&obj) @@ -15,7 +21,7 @@ func LoginCodeGet(where string, args ...interface{}) (*LoginCode, error) { } if !has { - return nil, nil + return nil, errLoginCode } return &obj, nil diff --git a/src/models/user.go b/src/models/user.go index 5d2e568eabd2f9869c7e4fc0cc11f2f533e5b0d2..3d9cdb8938ed11155b90ab932d9171fc65b0a24d 100644 --- a/src/models/user.go +++ b/src/models/user.go @@ -18,6 +18,15 @@ import ( "github.com/didi/nightingale/src/modules/rdb/config" ) +const ( + LOGIN_T_SMS = "sms-code" + LOGIN_T_EMAIL = "email-code" + LOGIN_T_RST = "rst-code" + LOGIN_T_PWD = "password" + LOGIN_T_LDAP = "ldap" + LOGIN_EXPIRES_IN = 300 +) + type User struct { Id int64 `json:"id"` UUID string `json:"-" xorm:"'uuid'"` @@ -82,18 +91,16 @@ func InitRooter() { log.Println("user root init done") } -func LdapLogin(user, pass, clientIP string) error { +func LdapLogin(user, pass string) (*User, error) { sr, err := ldapReq(user, pass) if err != nil { - return err + return nil, err } - go LoginLogNew(user, clientIP, "in") - var u User has, err := DB["rdb"].Where("username=?", user).Get(&u) if err != nil { - return err + return nil, err } u.CopyLdapAttr(sr) @@ -101,9 +108,9 @@ func LdapLogin(user, pass, clientIP string) error { if has { if config.Config.LDAP.CoverAttributes { _, err := DB["rdb"].Where("id=?", u.Id).Update(u) - return err + return nil, err } else { - return nil + return &u, err } } @@ -111,34 +118,76 @@ func LdapLogin(user, pass, clientIP string) error { u.Password = "******" u.UUID = GenUUIDForUser(user) _, err = DB["rdb"].Insert(u) - return err + return &u, nil } -func PassLogin(user, pass, clientIP string) error { +func PassLogin(user, pass string) (*User, error) { var u User - has, err := DB["rdb"].Where("username=?", user).Cols("password").Get(&u) + has, err := DB["rdb"].Where("username=?", user).Get(&u) if err != nil { - return err + return nil, err } if !has { logger.Infof("password auth fail, no such user: %s", user) - return fmt.Errorf("login fail, check your username and password") + return nil, fmt.Errorf("login fail, check your username and password") } loginPass, err := CryptoPass(pass) if err != nil { - return err + return nil, err } if loginPass != u.Password { logger.Infof("password auth fail, password error, user: %s", user) - return fmt.Errorf("login fail, check your username and password") + return nil, fmt.Errorf("login fail, check your username and password") } - go LoginLogNew(user, clientIP, "in") + return &u, nil +} - return nil +func SmsCodeLogin(phone, code string) (*User, error) { + user, _ := UserGet("phone=?", phone) + if user == nil { + return nil, fmt.Errorf("phone %s dose not exist", phone) + } + + lc, err := LoginCodeGet("username=? and code=? and login_type=?", user.Username, code, LOGIN_T_SMS) + if err != nil { + logger.Infof("sms-code auth fail, user: %s", user.Username) + return nil, fmt.Errorf("login fail, check your sms-code") + } + + if time.Now().Unix()-lc.CreatedAt > LOGIN_EXPIRES_IN { + logger.Infof("sms-code auth expired, user: %s", user.Username) + return nil, fmt.Errorf("login fail, the code has expired") + } + + lc.Del() + + return user, nil +} + +func EmailCodeLogin(email, code string) (*User, error) { + user, _ := UserGet("email=?", email) + if user == nil { + return nil, fmt.Errorf("email %s dose not exist", email) + } + + lc, err := LoginCodeGet("username=? and code=? and login_type=?", user.Username, code, LOGIN_T_EMAIL) + if err != nil { + logger.Infof("email-code auth fail, user: %s", user.Username) + return nil, fmt.Errorf("login fail, check your email-code") + } + + if time.Now().Unix()-lc.CreatedAt > LOGIN_EXPIRES_IN { + logger.Infof("email-code auth expired, user: %s", user.Username) + return nil, fmt.Errorf("login fail, the code has expired") + } + + lc.Del() + + return user, nil } func UserGet(where string, args ...interface{}) (*User, error) { diff --git a/src/modules/rdb/http/router.go b/src/modules/rdb/http/router.go index df6b5b472af15f8dc06208581e975b399bd72efe..dc18670d8e0f89505841aede7f10cb070fa70a9a 100644 --- a/src/modules/rdb/http/router.go +++ b/src/modules/rdb/http/router.go @@ -22,6 +22,8 @@ func Config(r *gin.Engine) { notLogin.GET("/auth/v2/callback", authCallbackV2) notLogin.GET("/auth/v2/logout", logoutV2) + notLogin.POST("/auth/send-login-code-by-sms", v1SendLoginCodeBySms) + notLogin.POST("/auth/send-login-code-by-email", v1SendLoginCodeByEmail) notLogin.POST("/auth/send-rst-code-by-sms", sendRstCodeBySms) notLogin.POST("/auth/rst-password", rstPassword) notLogin.GET("/auth/captcha", captchaGet) diff --git a/src/modules/rdb/http/router_auth.go b/src/modules/rdb/http/router_auth.go index b7298bc2953bd684d8b9854ecabb666ff7af9e42..f0a3d0fb46fd8840c1df446a1fe7421b0013f0a5 100644 --- a/src/modules/rdb/http/router_auth.go +++ b/src/modules/rdb/http/router_auth.go @@ -27,6 +27,7 @@ var ( loginCodeSmsTpl *template.Template loginCodeEmailTpl *template.Template errUnsupportCaptcha = errors.New("unsupported captcha") + errInvalidAnswer = errors.New("Invalid captcha answer") // https://captcha.mojotv.cn captchaDirver = base64Captcha.DriverString{ @@ -39,55 +40,60 @@ var ( } ) +func getConfigFile(name, ext string) (string, error) { + if p := path.Join(path.Join(file.SelfDir(), "etc", name+".local."+ext)); file.IsExist(p) { + return p, nil + } + if p := path.Join(path.Join(file.SelfDir(), "etc", name+"."+ext)); file.IsExist(p) { + return p, nil + } else { + return "", fmt.Errorf("file %s not found", p) + } + +} + func init() { - var err error - filename := path.Join(file.SelfDir(), "etc", "login-code-sms.tpl") - loginCodeSmsTpl, err = template.ParseFiles(filename) + filename, err := getConfigFile("login-code-sms", "tpl") if err != nil { - log.Fatalf("open %s err: %s", filename, err) + log.Fatal(err) } - filename = path.Join(file.SelfDir(), "etc", "login-code-email.tpl") - loginCodeEmailTpl, err = template.ParseFiles(filename) + loginCodeSmsTpl, err = template.ParseFiles(filename) if err != nil { log.Fatalf("open %s err: %s", filename, err) } -} - -type loginForm struct { - Username string `json:"username" binding:"required"` - Password string `json:"password" binding:"required"` - IsLDAP int `json:"is_ldap"` - RemoteAddr string `json:"remote_addr"` -} -func (f *loginForm) validate() { - if str.Dangerous(f.Username) { - bomb("%s invalid", f.Username) + filename, err = getConfigFile("login-code-email", "tpl") + if err != nil { + log.Fatal(err) } - - if len(f.Username) > 64 { - bomb("%s too long", f.Username) + loginCodeEmailTpl, err = template.ParseFiles(filename) + if err != nil { + log.Fatalf("open %s err: %s", filename, err) } } func login(c *gin.Context) { - var f loginForm + var f loginInput bind(c, &f) f.validate() - if f.IsLDAP == 1 { - dangerous(models.LdapLogin(f.Username, f.Password, c.ClientIP())) - } else { - dangerous(models.PassLogin(f.Username, f.Password, c.ClientIP())) + if config.Config.Captcha { + c, err := models.CaptchaGet("captcha_id=?", f.CaptchaId) + dangerous(err) + if strings.ToLower(c.Answer) != strings.ToLower(f.Answer) { + dangerous(errInvalidAnswer) + } } - user, err := models.UserGet("username=?", f.Username) + user, err := authLogin(f) dangerous(err) writeCookieUser(c, user.UUID) renderMessage(c, "") + + go models.LoginLogNew(user.Username, c.ClientIP(), "in") } func logout(c *gin.Context) { @@ -105,14 +111,14 @@ func logout(c *gin.Context) { writeCookieUser(c, "") - go models.LoginLogNew(username, c.ClientIP(), "out") - if config.Config.SSO.Enable { redirect := queryStr(c, "redirect", "/") c.Redirect(302, ssoc.LogoutLocation(redirect)) } else { c.String(200, "logout successfully") } + + go models.LoginLogNew(username, c.ClientIP(), "out") } type authRedirect struct { @@ -181,15 +187,16 @@ func logoutV2(c *gin.Context) { writeCookieUser(c, "") ret.Msg = "logout successfully" - go models.LoginLogNew(username, c.ClientIP(), "out") - if config.Config.SSO.Enable { if redirect == "" { redirect = "/" } ret.Redirect = ssoc.LogoutLocation(redirect) } + renderData(c, ret, nil) + + go models.LoginLogNew(username, c.ClientIP(), "out") } type loginInput struct { @@ -198,51 +205,53 @@ type loginInput struct { Phone string `json:"phone"` Email string `json:"email"` Code string `json:"code"` - Type string `json:"type"` - RemoteAddr string `json:"remote_addr"` + CaptchaId string `json:"captcha_id"` + Answer string `json:"answer" description:"captcha answer"` + Type string `json:"type" description:"sms-code|email-code|password|ldap"` + RemoteAddr string `json:"remote_addr" description:"use for server account(v1)"` + IsLDAP int `json:"is_ldap" description:"deprecated"` } -const ( - LOGIN_T_SMS = "sms-code" - LOGIN_T_EMAIL = "email-code" - LOGIN_T_RST = "rst-code" - LOGIN_T_PWD = "password" - LOGIN_T_LDAP = "ldap" - LOGIN_EXPIRES_IN = 300 -) +func (f *loginInput) validate() { + if f.IsLDAP == 1 { + f.Type = models.LOGIN_T_LDAP + } + if f.Type == "" { + f.Type = models.LOGIN_T_PWD + } + if f.Type == models.LOGIN_T_PWD { + if str.Dangerous(f.Username) { + bomb("%s invalid", f.Username) + } + if len(f.Username) > 64 { + bomb("%s too long", f.Username) + } + } +} // v1Login called by sso.rdb module func v1Login(c *gin.Context) { var f loginInput bind(c, &f) - user, err := func() (*models.User, error) { - switch strings.ToLower(f.Type) { - case LOGIN_T_LDAP: - err := models.LdapLogin(f.Username, f.Password, c.ClientIP()) - if err != nil { - return nil, err - } - return models.UserGet("username=?", f.Username) - case LOGIN_T_PWD: - err := models.PassLogin(f.Username, f.Password, c.ClientIP()) - if err != nil { - return nil, err - } - return models.UserGet("username=?", f.Username) - case LOGIN_T_SMS: - return smsCodeVerify(f.Phone, f.Code) - case LOGIN_T_EMAIL: - return emailCodeVerify(f.Email, f.Code) - default: - return nil, fmt.Errorf("invalid login type %s", f.Type) - } - }() - - // TODO: implement remote address access control - go models.LoginLogNew(f.Username, f.RemoteAddr, "in") + user, err := authLogin(f) + renderData(c, *user, err) +} - renderData(c, user, err) +// authLogin called by /v1/rdb/login, /api/rdb/auth/login +func authLogin(in loginInput) (user *models.User, err error) { + switch strings.ToLower(in.Type) { + case models.LOGIN_T_LDAP: + return models.LdapLogin(in.Username, in.Password) + case models.LOGIN_T_PWD: + return models.PassLogin(in.Username, in.Password) + case models.LOGIN_T_SMS: + return models.SmsCodeLogin(in.Phone, in.Code) + case models.LOGIN_T_EMAIL: + return models.EmailCodeLogin(in.Email, in.Code) + default: + return nil, fmt.Errorf("invalid login type %s", in.Type) + } } type v1SendLoginCodeBySmsInput struct { @@ -269,7 +278,7 @@ func v1SendLoginCodeBySms(c *gin.Context) { loginCode := &models.LoginCode{ Username: user.Username, Code: code, - LoginType: LOGIN_T_SMS, + LoginType: models.LOGIN_T_SMS, CreatedAt: time.Now().Unix(), } @@ -298,26 +307,6 @@ func v1SendLoginCodeBySms(c *gin.Context) { renderData(c, msg, err) } -func smsCodeVerify(phone, code string) (*models.User, error) { - user, _ := models.UserGet("phone=?", phone) - if user == nil { - return nil, fmt.Errorf("phone %s dose not exist", phone) - } - - lc, err := models.LoginCodeGet("username=? and code=? and login_type=?", user.Username, code, LOGIN_T_SMS) - if err != nil { - return nil, fmt.Errorf("invalid code", phone) - } - - if time.Now().Unix()-lc.CreatedAt > LOGIN_EXPIRES_IN { - return nil, fmt.Errorf("the code has expired", phone) - } - - lc.Del() - - return user, nil -} - type v1SendLoginCodeByEmailInput struct { Email string `json:"email"` } @@ -342,7 +331,7 @@ func v1SendLoginCodeByEmail(c *gin.Context) { loginCode := &models.LoginCode{ Username: user.Username, Code: code, - LoginType: LOGIN_T_EMAIL, + LoginType: models.LOGIN_T_EMAIL, CreatedAt: time.Now().Unix(), } @@ -369,26 +358,6 @@ func v1SendLoginCodeByEmail(c *gin.Context) { renderData(c, msg, err) } -func emailCodeVerify(email, code string) (*models.User, error) { - user, _ := models.UserGet("email=?", email) - if user == nil { - return nil, fmt.Errorf("email %s dose not exist", email) - } - - lc, err := models.LoginCodeGet("username=? and code=? and login_type=?", user.Username, code, LOGIN_T_EMAIL) - if err != nil { - return nil, fmt.Errorf("invalid code", email) - } - - if time.Now().Unix()-lc.CreatedAt > LOGIN_EXPIRES_IN { - return nil, fmt.Errorf("the code has expired", email) - } - - lc.Del() - - return user, nil -} - type sendRstCodeBySmsInput struct { Phone string `json:"phone"` } @@ -413,7 +382,7 @@ func sendRstCodeBySms(c *gin.Context) { loginCode := &models.LoginCode{ Username: user.Username, Code: code, - LoginType: LOGIN_T_RST, + LoginType: models.LOGIN_T_RST, CreatedAt: time.Now().Unix(), } @@ -459,12 +428,12 @@ func rstPassword(c *gin.Context) { } lc, err := models.LoginCodeGet("username=? and code=? and login_type=?", - user.Username, in.Code, LOGIN_T_RST) + user.Username, in.Code, models.LOGIN_T_RST) if err != nil { return fmt.Errorf("invalid code", in.Phone) } - if time.Now().Unix()-lc.CreatedAt > LOGIN_EXPIRES_IN { + if time.Now().Unix()-lc.CreatedAt > models.LOGIN_EXPIRES_IN { return fmt.Errorf("the code has expired", in.Phone) }