未验证 提交 8ef8fd26 编写于 作者: M Ming Deng 提交者: GitHub

Merge pull request #4036 from astaxie/develop

V1.12.2
...@@ -36,8 +36,7 @@ install: ...@@ -36,8 +36,7 @@ install:
- go get github.com/beego/goyaml2 - go get github.com/beego/goyaml2
- go get gopkg.in/yaml.v2 - go get gopkg.in/yaml.v2
- go get github.com/belogik/goes - go get github.com/belogik/goes
- go get github.com/siddontang/ledisdb/config - go get github.com/ledisdb/ledisdb
- go get github.com/siddontang/ledisdb/ledis
- go get github.com/ssdb/gossdb/ssdb - go get github.com/ssdb/gossdb/ssdb
- go get github.com/cloudflare/golz4 - go get github.com/cloudflare/golz4
- go get github.com/gogo/protobuf/proto - go get github.com/gogo/protobuf/proto
...@@ -49,7 +48,7 @@ install: ...@@ -49,7 +48,7 @@ install:
- go get -u honnef.co/go/tools/cmd/staticcheck - go get -u honnef.co/go/tools/cmd/staticcheck
- go get -u github.com/mdempsky/unconvert - go get -u github.com/mdempsky/unconvert
- go get -u github.com/gordonklaus/ineffassign - go get -u github.com/gordonklaus/ineffassign
- go get -u github.com/golang/lint/golint - go get -u golang.org/x/lint/golint
- go get -u github.com/go-redis/redis - go get -u github.com/go-redis/redis
before_script: before_script:
- psql --version - psql --version
......
...@@ -33,6 +33,8 @@ Congratulations! You've just built your first **beego** app. ...@@ -33,6 +33,8 @@ Congratulations! You've just built your first **beego** app.
###### Please see [Documentation](http://beego.me/docs) for more. ###### Please see [Documentation](http://beego.me/docs) for more.
###### [beego-example](https://github.com/beego-dev/beego-example)
## Features ## Features
* RESTful support * RESTful support
......
...@@ -24,6 +24,8 @@ import ( ...@@ -24,6 +24,8 @@ import (
"text/template" "text/template"
"time" "time"
"github.com/prometheus/client_golang/prometheus/promhttp"
"github.com/astaxie/beego/grace" "github.com/astaxie/beego/grace"
"github.com/astaxie/beego/logs" "github.com/astaxie/beego/logs"
"github.com/astaxie/beego/toolbox" "github.com/astaxie/beego/toolbox"
...@@ -55,12 +57,14 @@ func init() { ...@@ -55,12 +57,14 @@ func init() {
beeAdminApp = &adminApp{ beeAdminApp = &adminApp{
routers: make(map[string]http.HandlerFunc), routers: make(map[string]http.HandlerFunc),
} }
// keep in mind that all data should be html escaped to avoid XSS attack
beeAdminApp.Route("/", adminIndex) beeAdminApp.Route("/", adminIndex)
beeAdminApp.Route("/qps", qpsIndex) beeAdminApp.Route("/qps", qpsIndex)
beeAdminApp.Route("/prof", profIndex) beeAdminApp.Route("/prof", profIndex)
beeAdminApp.Route("/healthcheck", healthcheck) beeAdminApp.Route("/healthcheck", healthcheck)
beeAdminApp.Route("/task", taskStatus) beeAdminApp.Route("/task", taskStatus)
beeAdminApp.Route("/listconf", listConf) beeAdminApp.Route("/listconf", listConf)
beeAdminApp.Route("/metrics", promhttp.Handler().ServeHTTP)
FilterMonitorFunc = func(string, string, time.Duration, string, int) bool { return true } FilterMonitorFunc = func(string, string, time.Duration, string, int) bool { return true }
} }
...@@ -105,8 +109,8 @@ func listConf(rw http.ResponseWriter, r *http.Request) { ...@@ -105,8 +109,8 @@ func listConf(rw http.ResponseWriter, r *http.Request) {
case "conf": case "conf":
m := make(M) m := make(M)
list("BConfig", BConfig, m) list("BConfig", BConfig, m)
m["AppConfigPath"] = appConfigPath m["AppConfigPath"] = template.HTMLEscapeString(appConfigPath)
m["AppConfigProvider"] = appConfigProvider m["AppConfigProvider"] = template.HTMLEscapeString(appConfigProvider)
tmpl := template.Must(template.New("dashboard").Parse(dashboardTpl)) tmpl := template.Must(template.New("dashboard").Parse(dashboardTpl))
tmpl = template.Must(tmpl.Parse(configTpl)) tmpl = template.Must(tmpl.Parse(configTpl))
tmpl = template.Must(tmpl.Parse(defaultScriptsTpl)) tmpl = template.Must(tmpl.Parse(defaultScriptsTpl))
...@@ -151,8 +155,9 @@ func listConf(rw http.ResponseWriter, r *http.Request) { ...@@ -151,8 +155,9 @@ func listConf(rw http.ResponseWriter, r *http.Request) {
resultList := new([][]string) resultList := new([][]string)
for _, f := range bf { for _, f := range bf {
var result = []string{ var result = []string{
f.pattern, // void xss
utils.GetFuncName(f.filterFunc), template.HTMLEscapeString(f.pattern),
template.HTMLEscapeString(utils.GetFuncName(f.filterFunc)),
} }
*resultList = append(*resultList, result) *resultList = append(*resultList, result)
} }
...@@ -207,8 +212,8 @@ func PrintTree() M { ...@@ -207,8 +212,8 @@ func PrintTree() M {
printTree(resultList, t) printTree(resultList, t)
methods = append(methods, method) methods = append(methods, template.HTMLEscapeString(method))
methodsData[method] = resultList methodsData[template.HTMLEscapeString(method)] = resultList
} }
content["Data"] = methodsData content["Data"] = methodsData
...@@ -227,21 +232,21 @@ func printTree(resultList *[][]string, t *Tree) { ...@@ -227,21 +232,21 @@ func printTree(resultList *[][]string, t *Tree) {
if v, ok := l.runObject.(*ControllerInfo); ok { if v, ok := l.runObject.(*ControllerInfo); ok {
if v.routerType == routerTypeBeego { if v.routerType == routerTypeBeego {
var result = []string{ var result = []string{
v.pattern, template.HTMLEscapeString(v.pattern),
fmt.Sprintf("%s", v.methods), template.HTMLEscapeString(fmt.Sprintf("%s", v.methods)),
v.controllerType.String(), template.HTMLEscapeString(v.controllerType.String()),
} }
*resultList = append(*resultList, result) *resultList = append(*resultList, result)
} else if v.routerType == routerTypeRESTFul { } else if v.routerType == routerTypeRESTFul {
var result = []string{ var result = []string{
v.pattern, template.HTMLEscapeString(v.pattern),
fmt.Sprintf("%s", v.methods), template.HTMLEscapeString(fmt.Sprintf("%s", v.methods)),
"", "",
} }
*resultList = append(*resultList, result) *resultList = append(*resultList, result)
} else if v.routerType == routerTypeHandler { } else if v.routerType == routerTypeHandler {
var result = []string{ var result = []string{
v.pattern, template.HTMLEscapeString(v.pattern),
"", "",
"", "",
} }
...@@ -266,7 +271,7 @@ func profIndex(rw http.ResponseWriter, r *http.Request) { ...@@ -266,7 +271,7 @@ func profIndex(rw http.ResponseWriter, r *http.Request) {
result bytes.Buffer result bytes.Buffer
) )
toolbox.ProcessInput(command, &result) toolbox.ProcessInput(command, &result)
data["Content"] = result.String() data["Content"] = template.HTMLEscapeString(result.String())
if format == "json" && command == "gc summary" { if format == "json" && command == "gc summary" {
dataJSON, err := json.Marshal(data) dataJSON, err := json.Marshal(data)
...@@ -280,7 +285,7 @@ func profIndex(rw http.ResponseWriter, r *http.Request) { ...@@ -280,7 +285,7 @@ func profIndex(rw http.ResponseWriter, r *http.Request) {
return return
} }
data["Title"] = command data["Title"] = template.HTMLEscapeString(command)
defaultTpl := defaultScriptsTpl defaultTpl := defaultScriptsTpl
if command == "gc summary" { if command == "gc summary" {
defaultTpl = gcAjaxTpl defaultTpl = gcAjaxTpl
...@@ -304,13 +309,13 @@ func healthcheck(rw http.ResponseWriter, _ *http.Request) { ...@@ -304,13 +309,13 @@ func healthcheck(rw http.ResponseWriter, _ *http.Request) {
if err := h.Check(); err != nil { if err := h.Check(); err != nil {
result = []string{ result = []string{
"error", "error",
name, template.HTMLEscapeString(name),
err.Error(), template.HTMLEscapeString(err.Error()),
} }
} else { } else {
result = []string{ result = []string{
"success", "success",
name, template.HTMLEscapeString(name),
"OK", "OK",
} }
} }
...@@ -334,11 +339,11 @@ func taskStatus(rw http.ResponseWriter, req *http.Request) { ...@@ -334,11 +339,11 @@ func taskStatus(rw http.ResponseWriter, req *http.Request) {
if taskname != "" { if taskname != "" {
if t, ok := toolbox.AdminTaskList[taskname]; ok { if t, ok := toolbox.AdminTaskList[taskname]; ok {
if err := t.Run(); err != nil { if err := t.Run(); err != nil {
data["Message"] = []string{"error", fmt.Sprintf("%s", err)} data["Message"] = []string{"error", template.HTMLEscapeString(fmt.Sprintf("%s", err))}
} }
data["Message"] = []string{"success", fmt.Sprintf("%s run success,Now the Status is <br>%s", taskname, t.GetStatus())} data["Message"] = []string{"success", template.HTMLEscapeString(fmt.Sprintf("%s run success,Now the Status is <br>%s", taskname, t.GetStatus()))}
} else { } else {
data["Message"] = []string{"warning", fmt.Sprintf("there's no task which named: %s", taskname)} data["Message"] = []string{"warning", template.HTMLEscapeString(fmt.Sprintf("there's no task which named: %s", taskname))}
} }
} }
...@@ -354,10 +359,10 @@ func taskStatus(rw http.ResponseWriter, req *http.Request) { ...@@ -354,10 +359,10 @@ func taskStatus(rw http.ResponseWriter, req *http.Request) {
} }
for tname, tk := range toolbox.AdminTaskList { for tname, tk := range toolbox.AdminTaskList {
result := []string{ result := []string{
tname, template.HTMLEscapeString(tname),
tk.GetSpec(), template.HTMLEscapeString(tk.GetSpec()),
tk.GetStatus(), template.HTMLEscapeString(tk.GetStatus()),
tk.GetPrev().String(), template.HTMLEscapeString(tk.GetPrev().String()),
} }
*resultList = append(*resultList, result) *resultList = append(*resultList, result)
} }
......
...@@ -52,6 +52,8 @@ func oldMap() M { ...@@ -52,6 +52,8 @@ func oldMap() M {
m["BConfig.WebConfig.DirectoryIndex"] = BConfig.WebConfig.DirectoryIndex m["BConfig.WebConfig.DirectoryIndex"] = BConfig.WebConfig.DirectoryIndex
m["BConfig.WebConfig.StaticDir"] = BConfig.WebConfig.StaticDir m["BConfig.WebConfig.StaticDir"] = BConfig.WebConfig.StaticDir
m["BConfig.WebConfig.StaticExtensionsToGzip"] = BConfig.WebConfig.StaticExtensionsToGzip m["BConfig.WebConfig.StaticExtensionsToGzip"] = BConfig.WebConfig.StaticExtensionsToGzip
m["BConfig.WebConfig.StaticCacheFileSize"] = BConfig.WebConfig.StaticCacheFileSize
m["BConfig.WebConfig.StaticCacheFileNum"] = BConfig.WebConfig.StaticCacheFileNum
m["BConfig.WebConfig.TemplateLeft"] = BConfig.WebConfig.TemplateLeft m["BConfig.WebConfig.TemplateLeft"] = BConfig.WebConfig.TemplateLeft
m["BConfig.WebConfig.TemplateRight"] = BConfig.WebConfig.TemplateRight m["BConfig.WebConfig.TemplateRight"] = BConfig.WebConfig.TemplateRight
m["BConfig.WebConfig.ViewsPath"] = BConfig.WebConfig.ViewsPath m["BConfig.WebConfig.ViewsPath"] = BConfig.WebConfig.ViewsPath
......
...@@ -123,14 +123,13 @@ func (app *App) Run(mws ...MiddleWare) { ...@@ -123,14 +123,13 @@ func (app *App) Run(mws ...MiddleWare) {
httpsAddr = fmt.Sprintf("%s:%d", BConfig.Listen.HTTPSAddr, BConfig.Listen.HTTPSPort) httpsAddr = fmt.Sprintf("%s:%d", BConfig.Listen.HTTPSAddr, BConfig.Listen.HTTPSPort)
app.Server.Addr = httpsAddr app.Server.Addr = httpsAddr
} }
server := grace.NewServer(httpsAddr, app.Handlers) server := grace.NewServer(httpsAddr, app.Server.Handler)
server.Server.ReadTimeout = app.Server.ReadTimeout server.Server.ReadTimeout = app.Server.ReadTimeout
server.Server.WriteTimeout = app.Server.WriteTimeout server.Server.WriteTimeout = app.Server.WriteTimeout
if BConfig.Listen.EnableMutualHTTPS { if BConfig.Listen.EnableMutualHTTPS {
if err := server.ListenAndServeMutualTLS(BConfig.Listen.HTTPSCertFile, BConfig.Listen.HTTPSKeyFile, BConfig.Listen.TrustCaFile); err != nil { if err := server.ListenAndServeMutualTLS(BConfig.Listen.HTTPSCertFile, BConfig.Listen.HTTPSKeyFile, BConfig.Listen.TrustCaFile); err != nil {
logs.Critical("ListenAndServeTLS: ", err, fmt.Sprintf("%d", os.Getpid())) logs.Critical("ListenAndServeTLS: ", err, fmt.Sprintf("%d", os.Getpid()))
time.Sleep(100 * time.Microsecond) time.Sleep(100 * time.Microsecond)
endRunning <- true
} }
} else { } else {
if BConfig.Listen.AutoTLS { if BConfig.Listen.AutoTLS {
...@@ -145,14 +144,14 @@ func (app *App) Run(mws ...MiddleWare) { ...@@ -145,14 +144,14 @@ func (app *App) Run(mws ...MiddleWare) {
if err := server.ListenAndServeTLS(BConfig.Listen.HTTPSCertFile, BConfig.Listen.HTTPSKeyFile); err != nil { if err := server.ListenAndServeTLS(BConfig.Listen.HTTPSCertFile, BConfig.Listen.HTTPSKeyFile); err != nil {
logs.Critical("ListenAndServeTLS: ", err, fmt.Sprintf("%d", os.Getpid())) logs.Critical("ListenAndServeTLS: ", err, fmt.Sprintf("%d", os.Getpid()))
time.Sleep(100 * time.Microsecond) time.Sleep(100 * time.Microsecond)
endRunning <- true
} }
} }
endRunning <- true
}() }()
} }
if BConfig.Listen.EnableHTTP { if BConfig.Listen.EnableHTTP {
go func() { go func() {
server := grace.NewServer(addr, app.Handlers) server := grace.NewServer(addr, app.Server.Handler)
server.Server.ReadTimeout = app.Server.ReadTimeout server.Server.ReadTimeout = app.Server.ReadTimeout
server.Server.WriteTimeout = app.Server.WriteTimeout server.Server.WriteTimeout = app.Server.WriteTimeout
if BConfig.Listen.ListenTCP4 { if BConfig.Listen.ListenTCP4 {
...@@ -161,8 +160,8 @@ func (app *App) Run(mws ...MiddleWare) { ...@@ -161,8 +160,8 @@ func (app *App) Run(mws ...MiddleWare) {
if err := server.ListenAndServe(); err != nil { if err := server.ListenAndServe(); err != nil {
logs.Critical("ListenAndServe: ", err, fmt.Sprintf("%d", os.Getpid())) logs.Critical("ListenAndServe: ", err, fmt.Sprintf("%d", os.Getpid()))
time.Sleep(100 * time.Microsecond) time.Sleep(100 * time.Microsecond)
endRunning <- true
} }
endRunning <- true
}() }()
} }
<-endRunning <-endRunning
......
...@@ -23,7 +23,7 @@ import ( ...@@ -23,7 +23,7 @@ import (
const ( const (
// VERSION represent beego web framework version. // VERSION represent beego web framework version.
VERSION = "1.12.1" VERSION = "1.12.2"
// DEV is for develop // DEV is for develop
DEV = "dev" DEV = "dev"
......
// Copyright 2020 astaxie
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package beego
var (
BuildVersion string
BuildGitRevision string
BuildStatus string
BuildTag string
BuildTime string
GoVersion string
GitBranch string
)
...@@ -218,9 +218,12 @@ func (bc *MemoryCache) vacuum() { ...@@ -218,9 +218,12 @@ func (bc *MemoryCache) vacuum() {
} }
for { for {
<-time.After(bc.dur) <-time.After(bc.dur)
bc.RLock()
if bc.items == nil { if bc.items == nil {
bc.RUnlock()
return return
} }
bc.RUnlock()
if keys := bc.expiredKeys(); len(keys) != 0 { if keys := bc.expiredKeys(); len(keys) != 0 {
bc.clearItems(keys) bc.clearItems(keys)
} }
......
...@@ -55,6 +55,9 @@ type Cache struct { ...@@ -55,6 +55,9 @@ type Cache struct {
key string key string
password string password string
maxIdle int maxIdle int
//the timeout to a value less than the redis server's timeout.
timeout time.Duration
} }
// NewRedisCache create new redis cache with default collection name. // NewRedisCache create new redis cache with default collection name.
...@@ -137,12 +140,12 @@ func (rc *Cache) Decr(key string) error { ...@@ -137,12 +140,12 @@ func (rc *Cache) Decr(key string) error {
// ClearAll clean all cache in redis. delete this redis collection. // ClearAll clean all cache in redis. delete this redis collection.
func (rc *Cache) ClearAll() error { func (rc *Cache) ClearAll() error {
c := rc.p.Get() cachedKeys, err := rc.Scan(rc.key + ":*")
defer c.Close()
cachedKeys, err := redis.Strings(c.Do("KEYS", rc.key+":*"))
if err != nil { if err != nil {
return err return err
} }
c := rc.p.Get()
defer c.Close()
for _, str := range cachedKeys { for _, str := range cachedKeys {
if _, err = c.Do("DEL", str); err != nil { if _, err = c.Do("DEL", str); err != nil {
return err return err
...@@ -151,6 +154,35 @@ func (rc *Cache) ClearAll() error { ...@@ -151,6 +154,35 @@ func (rc *Cache) ClearAll() error {
return err return err
} }
// Scan scan all keys matching the pattern. a better choice than `keys`
func (rc *Cache) Scan(pattern string) (keys []string, err error) {
c := rc.p.Get()
defer c.Close()
var (
cursor uint64 = 0 // start
result []interface{}
list []string
)
for {
result, err = redis.Values(c.Do("SCAN", cursor, "MATCH", pattern, "COUNT", 1024))
if err != nil {
return
}
list, err = redis.Strings(result[1], nil)
if err != nil {
return
}
keys = append(keys, list...)
cursor, err = redis.Uint64(result[0], nil)
if err != nil {
return
}
if cursor == 0 { // over
return
}
}
}
// StartAndGC start redis cache adapter. // StartAndGC start redis cache adapter.
// config is like {"key":"collection key","conn":"connection info","dbNum":"0"} // config is like {"key":"collection key","conn":"connection info","dbNum":"0"}
// the cache item in redis are stored forever, // the cache item in redis are stored forever,
...@@ -182,12 +214,21 @@ func (rc *Cache) StartAndGC(config string) error { ...@@ -182,12 +214,21 @@ func (rc *Cache) StartAndGC(config string) error {
if _, ok := cf["maxIdle"]; !ok { if _, ok := cf["maxIdle"]; !ok {
cf["maxIdle"] = "3" cf["maxIdle"] = "3"
} }
if _, ok := cf["timeout"]; !ok {
cf["timeout"] = "180s"
}
rc.key = cf["key"] rc.key = cf["key"]
rc.conninfo = cf["conn"] rc.conninfo = cf["conn"]
rc.dbNum, _ = strconv.Atoi(cf["dbNum"]) rc.dbNum, _ = strconv.Atoi(cf["dbNum"])
rc.password = cf["password"] rc.password = cf["password"]
rc.maxIdle, _ = strconv.Atoi(cf["maxIdle"]) rc.maxIdle, _ = strconv.Atoi(cf["maxIdle"])
if v, err := time.ParseDuration(cf["timeout"]); err == nil {
rc.timeout = v
} else {
rc.timeout = 180 * time.Second
}
rc.connectInit() rc.connectInit()
c := rc.p.Get() c := rc.p.Get()
...@@ -221,7 +262,7 @@ func (rc *Cache) connectInit() { ...@@ -221,7 +262,7 @@ func (rc *Cache) connectInit() {
// initialize a new pool // initialize a new pool
rc.p = &redis.Pool{ rc.p = &redis.Pool{
MaxIdle: rc.maxIdle, MaxIdle: rc.maxIdle,
IdleTimeout: 180 * time.Second, IdleTimeout: rc.timeout,
Dial: dialFunc, Dial: dialFunc,
} }
} }
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
package redis package redis
import ( import (
"fmt"
"testing" "testing"
"time" "time"
...@@ -104,3 +105,40 @@ func TestRedisCache(t *testing.T) { ...@@ -104,3 +105,40 @@ func TestRedisCache(t *testing.T) {
t.Error("clear all err") t.Error("clear all err")
} }
} }
func TestCache_Scan(t *testing.T) {
timeoutDuration := 10 * time.Second
// init
bm, err := cache.NewCache("redis", `{"conn": "127.0.0.1:6379"}`)
if err != nil {
t.Error("init err")
}
// insert all
for i := 0; i < 10000; i++ {
if err = bm.Put(fmt.Sprintf("astaxie%d", i), fmt.Sprintf("author%d", i), timeoutDuration); err != nil {
t.Error("set Error", err)
}
}
// scan all for the first time
keys, err := bm.(*Cache).Scan(DefaultKey + ":*")
if err != nil {
t.Error("scan Error", err)
}
if len(keys) != 10000 {
t.Error("scan all err")
}
// clear all
if err = bm.ClearAll(); err != nil {
t.Error("clear all err")
}
// scan all for the second time
keys, err = bm.(*Cache).Scan(DefaultKey + ":*")
if err != nil {
t.Error("scan Error", err)
}
if len(keys) != 0 {
t.Error("scan all err")
}
}
...@@ -81,6 +81,8 @@ type WebConfig struct { ...@@ -81,6 +81,8 @@ type WebConfig struct {
DirectoryIndex bool DirectoryIndex bool
StaticDir map[string]string StaticDir map[string]string
StaticExtensionsToGzip []string StaticExtensionsToGzip []string
StaticCacheFileSize int
StaticCacheFileNum int
TemplateLeft string TemplateLeft string
TemplateRight string TemplateRight string
ViewsPath string ViewsPath string
...@@ -129,6 +131,8 @@ var ( ...@@ -129,6 +131,8 @@ var (
appConfigPath string appConfigPath string
// appConfigProvider is the provider for the config, default is ini // appConfigProvider is the provider for the config, default is ini
appConfigProvider = "ini" appConfigProvider = "ini"
// WorkPath is the absolute path to project root directory
WorkPath string
) )
func init() { func init() {
...@@ -137,7 +141,7 @@ func init() { ...@@ -137,7 +141,7 @@ func init() {
if AppPath, err = filepath.Abs(filepath.Dir(os.Args[0])); err != nil { if AppPath, err = filepath.Abs(filepath.Dir(os.Args[0])); err != nil {
panic(err) panic(err)
} }
workPath, err := os.Getwd() WorkPath, err = os.Getwd()
if err != nil { if err != nil {
panic(err) panic(err)
} }
...@@ -145,7 +149,7 @@ func init() { ...@@ -145,7 +149,7 @@ func init() {
if os.Getenv("BEEGO_RUNMODE") != "" { if os.Getenv("BEEGO_RUNMODE") != "" {
filename = os.Getenv("BEEGO_RUNMODE") + ".app.conf" filename = os.Getenv("BEEGO_RUNMODE") + ".app.conf"
} }
appConfigPath = filepath.Join(workPath, "conf", filename) appConfigPath = filepath.Join(WorkPath, "conf", filename)
if !utils.FileExists(appConfigPath) { if !utils.FileExists(appConfigPath) {
appConfigPath = filepath.Join(AppPath, "conf", filename) appConfigPath = filepath.Join(AppPath, "conf", filename)
if !utils.FileExists(appConfigPath) { if !utils.FileExists(appConfigPath) {
...@@ -236,6 +240,8 @@ func newBConfig() *Config { ...@@ -236,6 +240,8 @@ func newBConfig() *Config {
DirectoryIndex: false, DirectoryIndex: false,
StaticDir: map[string]string{"/static": "static"}, StaticDir: map[string]string{"/static": "static"},
StaticExtensionsToGzip: []string{".css", ".js"}, StaticExtensionsToGzip: []string{".css", ".js"},
StaticCacheFileSize: 1024 * 100,
StaticCacheFileNum: 1000,
TemplateLeft: "{{", TemplateLeft: "{{",
TemplateRight: "}}", TemplateRight: "}}",
ViewsPath: "views", ViewsPath: "views",
...@@ -317,6 +323,14 @@ func assignConfig(ac config.Configer) error { ...@@ -317,6 +323,14 @@ func assignConfig(ac config.Configer) error {
} }
} }
if sfs, err := ac.Int("StaticCacheFileSize"); err == nil {
BConfig.WebConfig.StaticCacheFileSize = sfs
}
if sfn, err := ac.Int("StaticCacheFileNum"); err == nil {
BConfig.WebConfig.StaticCacheFileNum = sfn
}
if lo := ac.String("LogOutputs"); lo != "" { if lo := ac.String("LogOutputs"); lo != "" {
// if lo is not nil or empty // if lo is not nil or empty
// means user has set his own LogOutputs // means user has set his own LogOutputs
...@@ -408,9 +422,9 @@ func newAppConfig(appConfigProvider, appConfigPath string) (*beegoAppConfig, err ...@@ -408,9 +422,9 @@ func newAppConfig(appConfigProvider, appConfigPath string) (*beegoAppConfig, err
func (b *beegoAppConfig) Set(key, val string) error { func (b *beegoAppConfig) Set(key, val string) error {
if err := b.innerConfig.Set(BConfig.RunMode+"::"+key, val); err != nil { if err := b.innerConfig.Set(BConfig.RunMode+"::"+key, val); err != nil {
return err return b.innerConfig.Set(key, val)
} }
return b.innerConfig.Set(key, val) return nil
} }
func (b *beegoAppConfig) String(key string) string { func (b *beegoAppConfig) String(key string) string {
......
...@@ -20,6 +20,7 @@ import ( ...@@ -20,6 +20,7 @@ import (
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"os" "os"
"strconv"
"strings" "strings"
"sync" "sync"
) )
...@@ -94,8 +95,10 @@ func (c *JSONConfigContainer) Int(key string) (int, error) { ...@@ -94,8 +95,10 @@ func (c *JSONConfigContainer) Int(key string) (int, error) {
if val != nil { if val != nil {
if v, ok := val.(float64); ok { if v, ok := val.(float64); ok {
return int(v), nil return int(v), nil
} else if v, ok := val.(string); ok {
return strconv.Atoi(v)
} }
return 0, errors.New("not int value") return 0, errors.New("not valid value")
} }
return 0, errors.New("not exist key:" + key) return 0, errors.New("not exist key:" + key)
} }
......
...@@ -115,6 +115,8 @@ func TestAssignConfig_03(t *testing.T) { ...@@ -115,6 +115,8 @@ func TestAssignConfig_03(t *testing.T) {
ac.Set("RunMode", "online") ac.Set("RunMode", "online")
ac.Set("StaticDir", "download:down download2:down2") ac.Set("StaticDir", "download:down download2:down2")
ac.Set("StaticExtensionsToGzip", ".css,.js,.html,.jpg,.png") ac.Set("StaticExtensionsToGzip", ".css,.js,.html,.jpg,.png")
ac.Set("StaticCacheFileSize", "87456")
ac.Set("StaticCacheFileNum", "1254")
assignConfig(ac) assignConfig(ac)
t.Logf("%#v", BConfig) t.Logf("%#v", BConfig)
...@@ -132,6 +134,12 @@ func TestAssignConfig_03(t *testing.T) { ...@@ -132,6 +134,12 @@ func TestAssignConfig_03(t *testing.T) {
if BConfig.WebConfig.StaticDir["/download2"] != "down2" { if BConfig.WebConfig.StaticDir["/download2"] != "down2" {
t.FailNow() t.FailNow()
} }
if BConfig.WebConfig.StaticCacheFileSize != 87456 {
t.FailNow()
}
if BConfig.WebConfig.StaticCacheFileNum != 1254 {
t.FailNow()
}
if len(BConfig.WebConfig.StaticExtensionsToGzip) != 5 { if len(BConfig.WebConfig.StaticExtensionsToGzip) != 5 {
t.FailNow() t.FailNow()
} }
......
...@@ -71,7 +71,9 @@ func (input *BeegoInput) Reset(ctx *Context) { ...@@ -71,7 +71,9 @@ func (input *BeegoInput) Reset(ctx *Context) {
input.CruSession = nil input.CruSession = nil
input.pnames = input.pnames[:0] input.pnames = input.pnames[:0]
input.pvalues = input.pvalues[:0] input.pvalues = input.pvalues[:0]
input.dataLock.Lock()
input.data = nil input.data = nil
input.dataLock.Unlock()
input.RequestBody = []byte{} input.RequestBody = []byte{}
} }
...@@ -87,7 +89,7 @@ func (input *BeegoInput) URI() string { ...@@ -87,7 +89,7 @@ func (input *BeegoInput) URI() string {
// URL returns request url path (without query string, fragment). // URL returns request url path (without query string, fragment).
func (input *BeegoInput) URL() string { func (input *BeegoInput) URL() string {
return input.Context.Request.URL.Path return input.Context.Request.URL.EscapedPath()
} }
// Site returns base site url as scheme://domain type. // Site returns base site url as scheme://domain type.
...@@ -282,6 +284,11 @@ func (input *BeegoInput) ParamsLen() int { ...@@ -282,6 +284,11 @@ func (input *BeegoInput) ParamsLen() int {
func (input *BeegoInput) Param(key string) string { func (input *BeegoInput) Param(key string) string {
for i, v := range input.pnames { for i, v := range input.pnames {
if v == key && i <= len(input.pvalues) { if v == key && i <= len(input.pvalues) {
// we cannot use url.PathEscape(input.pvalues[i])
// for example, if the value is /a/b
// after url.PathEscape(input.pvalues[i]), the value is %2Fa%2Fb
// However, the value is used in ControllerRegister.ServeHTTP
// and split by "/", so function crash...
return input.pvalues[i] return input.pvalues[i]
} }
} }
......
...@@ -2,36 +2,35 @@ module github.com/astaxie/beego ...@@ -2,36 +2,35 @@ module github.com/astaxie/beego
require ( require (
github.com/Knetic/govaluate v3.0.0+incompatible // indirect github.com/Knetic/govaluate v3.0.0+incompatible // indirect
github.com/OwnLocal/goes v1.0.0
github.com/beego/goyaml2 v0.0.0-20130207012346-5545475820dd github.com/beego/goyaml2 v0.0.0-20130207012346-5545475820dd
github.com/beego/x2j v0.0.0-20131220205130-a0352aadc542 github.com/beego/x2j v0.0.0-20131220205130-a0352aadc542
github.com/bradfitz/gomemcache v0.0.0-20180710155616-bc664df96737 github.com/bradfitz/gomemcache v0.0.0-20180710155616-bc664df96737
github.com/casbin/casbin v1.7.0 github.com/casbin/casbin v1.7.0
github.com/cloudflare/golz4 v0.0.0-20150217214814-ef862a3cdc58 github.com/cloudflare/golz4 v0.0.0-20150217214814-ef862a3cdc58
github.com/couchbase/go-couchbase v0.0.0-20181122212707-3e9b6e1258bb github.com/couchbase/go-couchbase v0.0.0-20200519150804-63f3cdb75e0d
github.com/couchbase/gomemcached v0.0.0-20181122193126-5125a94a666c // indirect github.com/couchbase/gomemcached v0.0.0-20200526233749-ec430f949808 // indirect
github.com/couchbase/goutils v0.0.0-20180530154633-e865a1461c8a // indirect github.com/couchbase/goutils v0.0.0-20180530154633-e865a1461c8a // indirect
github.com/cupcake/rdb v0.0.0-20161107195141-43ba34106c76 // indirect github.com/elastic/go-elasticsearch/v6 v6.8.5
github.com/edsrzf/mmap-go v0.0.0-20170320065105-0bce6a688712 // indirect
github.com/elazarl/go-bindata-assetfs v1.0.0 github.com/elazarl/go-bindata-assetfs v1.0.0
github.com/go-redis/redis v6.14.2+incompatible github.com/go-redis/redis v6.14.2+incompatible
github.com/go-sql-driver/mysql v1.4.1 github.com/go-sql-driver/mysql v1.5.0
github.com/gogo/protobuf v1.1.1 github.com/gogo/protobuf v1.1.1
github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db // indirect github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db // indirect
github.com/gomodule/redigo v2.0.0+incompatible github.com/gomodule/redigo v2.0.0+incompatible
github.com/hashicorp/golang-lru v0.5.4
github.com/ledisdb/ledisdb v0.0.0-20200510135210-d35789ec47e6
github.com/lib/pq v1.0.0 github.com/lib/pq v1.0.0
github.com/mattn/go-sqlite3 v1.10.0 github.com/mattn/go-sqlite3 v2.0.3+incompatible
github.com/pelletier/go-toml v1.2.0 // indirect github.com/pelletier/go-toml v1.2.0 // indirect
github.com/pkg/errors v0.8.0 // indirect github.com/prometheus/client_golang v1.7.0
github.com/siddontang/go v0.0.0-20180604090527-bdc77568d726 // indirect github.com/shiena/ansicolor v0.0.0-20151119151921-a422bbe96644
github.com/siddontang/ledisdb v0.0.0-20181029004158-becf5f38d373
github.com/siddontang/rdb v0.0.0-20150307021120-fc89ed2e418d // indirect
github.com/ssdb/gossdb v0.0.0-20180723034631-88f6b59b84ec github.com/ssdb/gossdb v0.0.0-20180723034631-88f6b59b84ec
github.com/stretchr/testify v1.4.0
github.com/syndtr/goleveldb v0.0.0-20181127023241-353a9fca669c // indirect github.com/syndtr/goleveldb v0.0.0-20181127023241-353a9fca669c // indirect
github.com/wendal/errors v0.0.0-20130201093226-f66c77a7882b // indirect github.com/wendal/errors v0.0.0-20130201093226-f66c77a7882b // indirect
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550 golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550
golang.org/x/tools v0.0.0-20200117065230-39095c1d176c golang.org/x/tools v0.0.0-20200117065230-39095c1d176c
gopkg.in/yaml.v2 v2.2.1 gopkg.in/yaml.v2 v2.2.8
) )
replace golang.org/x/crypto v0.0.0-20181127143415-eb0de9b17e85 => github.com/golang/crypto v0.0.0-20181127143415-eb0de9b17e85 replace golang.org/x/crypto v0.0.0-20181127143415-eb0de9b17e85 => github.com/golang/crypto v0.0.0-20181127143415-eb0de9b17e85
......
此差异已折叠。
...@@ -46,7 +46,10 @@ func (srv *Server) Serve() (err error) { ...@@ -46,7 +46,10 @@ func (srv *Server) Serve() (err error) {
log.Println(syscall.Getpid(), srv.ln.Addr(), "Listener closed.") log.Println(syscall.Getpid(), srv.ln.Addr(), "Listener closed.")
// wait for Shutdown to return // wait for Shutdown to return
return <-srv.terminalChan if shutdownErr := <-srv.terminalChan; shutdownErr != nil {
return shutdownErr
}
return
} }
// ListenAndServe listens on the TCP network address srv.Addr and then calls Serve // ListenAndServe listens on the TCP network address srv.Addr and then calls Serve
...@@ -180,7 +183,7 @@ func (srv *Server) ListenAndServeMutualTLS(certFile, keyFile, trustFile string) ...@@ -180,7 +183,7 @@ func (srv *Server) ListenAndServeMutualTLS(certFile, keyFile, trustFile string)
log.Println(err) log.Println(err)
return err return err
} }
err = process.Kill() err = process.Signal(syscall.SIGTERM)
if err != nil { if err != nil {
return err return err
} }
......
...@@ -407,6 +407,7 @@ func (b *BeegoHTTPRequest) buildURL(paramBody string) { ...@@ -407,6 +407,7 @@ func (b *BeegoHTTPRequest) buildURL(paramBody string) {
}() }()
b.Header("Content-Type", bodyWriter.FormDataContentType()) b.Header("Content-Type", bodyWriter.FormDataContentType())
b.req.Body = ioutil.NopCloser(pr) b.req.Body = ioutil.NopCloser(pr)
b.Header("Transfer-Encoding", "chunked")
return return
} }
......
...@@ -16,6 +16,7 @@ package logs ...@@ -16,6 +16,7 @@ package logs
import ( import (
"testing" "testing"
"time"
) )
// Try each log level in decreasing order of priority. // Try each log level in decreasing order of priority.
...@@ -49,3 +50,15 @@ func TestConsoleNoColor(t *testing.T) { ...@@ -49,3 +50,15 @@ func TestConsoleNoColor(t *testing.T) {
log.SetLogger("console", `{"color":false}`) log.SetLogger("console", `{"color":false}`)
testConsoleCalls(log) testConsoleCalls(log)
} }
// Test console async
func TestConsoleAsync(t *testing.T) {
log := NewLogger(100)
log.SetLogger("console")
log.Async()
//log.Close()
testConsoleCalls(log)
for len(log.msgChan) != 0 {
time.Sleep(1 * time.Millisecond)
}
}
package es package es
import ( import (
"context"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"net"
"net/url" "net/url"
"strings"
"time" "time"
"github.com/OwnLocal/goes" "github.com/elastic/go-elasticsearch/v6"
"github.com/elastic/go-elasticsearch/v6/esapi"
"github.com/astaxie/beego/logs" "github.com/astaxie/beego/logs"
) )
...@@ -20,8 +23,14 @@ func NewES() logs.Logger { ...@@ -20,8 +23,14 @@ func NewES() logs.Logger {
return cw return cw
} }
// esLogger will log msg into ES
// before you using this implementation,
// please import this package
// usually means that you can import this package in your main package
// for example, anonymous:
// import _ "github.com/astaxie/beego/logs/es"
type esLogger struct { type esLogger struct {
*goes.Client *elasticsearch.Client
DSN string `json:"dsn"` DSN string `json:"dsn"`
Level int `json:"level"` Level int `json:"level"`
} }
...@@ -38,10 +47,13 @@ func (el *esLogger) Init(jsonconfig string) error { ...@@ -38,10 +47,13 @@ func (el *esLogger) Init(jsonconfig string) error {
return err return err
} else if u.Path == "" { } else if u.Path == "" {
return errors.New("missing prefix") return errors.New("missing prefix")
} else if host, port, err := net.SplitHostPort(u.Host); err != nil {
return err
} else { } else {
conn := goes.NewClient(host, port) conn, err := elasticsearch.NewClient(elasticsearch.Config{
Addresses: []string{el.DSN},
})
if err != nil {
return err
}
el.Client = conn el.Client = conn
} }
return nil return nil
...@@ -53,21 +65,26 @@ func (el *esLogger) WriteMsg(when time.Time, msg string, level int) error { ...@@ -53,21 +65,26 @@ func (el *esLogger) WriteMsg(when time.Time, msg string, level int) error {
return nil return nil
} }
vals := make(map[string]interface{}) idx := LogDocument{
vals["@timestamp"] = when.Format(time.RFC3339) Timestamp: when.Format(time.RFC3339),
vals["@msg"] = msg Msg: msg,
d := goes.Document{ }
Index: fmt.Sprintf("%04d.%02d.%02d", when.Year(), when.Month(), when.Day()),
Type: "logs", body, err := json.Marshal(idx)
Fields: vals, if err != nil {
return err
} }
_, err := el.Index(d, nil) req := esapi.IndexRequest{
Index: fmt.Sprintf("%04d.%02d.%02d", when.Year(), when.Month(), when.Day()),
DocumentType: "logs",
Body: strings.NewReader(string(body)),
}
_, err = req.Do(context.Background(), el.Client)
return err return err
} }
// Destroy is a empty method // Destroy is a empty method
func (el *esLogger) Destroy() { func (el *esLogger) Destroy() {
} }
// Flush is a empty method // Flush is a empty method
...@@ -75,7 +92,11 @@ func (el *esLogger) Flush() { ...@@ -75,7 +92,11 @@ func (el *esLogger) Flush() {
} }
type LogDocument struct {
Timestamp string `json:"timestamp"`
Msg string `json:"msg"`
}
func init() { func init() {
logs.Register(logs.AdapterEs, NewES) logs.Register(logs.AdapterEs, NewES)
} }
...@@ -359,6 +359,10 @@ RESTART_LOGGER: ...@@ -359,6 +359,10 @@ RESTART_LOGGER:
func (w *fileLogWriter) deleteOldLog() { func (w *fileLogWriter) deleteOldLog() {
dir := filepath.Dir(w.Filename) dir := filepath.Dir(w.Filename)
absolutePath, err := filepath.EvalSymlinks(w.Filename)
if err == nil {
dir = filepath.Dir(absolutePath)
}
filepath.Walk(dir, func(path string, info os.FileInfo, err error) (returnErr error) { filepath.Walk(dir, func(path string, info os.FileInfo, err error) (returnErr error) {
defer func() { defer func() {
if r := recover(); r != nil { if r := recover(); r != nil {
......
...@@ -295,7 +295,11 @@ func (bl *BeeLogger) writeMsg(logLevel int, msg string, v ...interface{}) error ...@@ -295,7 +295,11 @@ func (bl *BeeLogger) writeMsg(logLevel int, msg string, v ...interface{}) error
lm.level = logLevel lm.level = logLevel
lm.msg = msg lm.msg = msg
lm.when = when lm.when = when
bl.msgChan <- lm if bl.outputs != nil {
bl.msgChan <- lm
} else {
logMsgPool.Put(lm)
}
} else { } else {
bl.writeToLoggers(when, msg, logLevel) bl.writeToLoggers(when, msg, logLevel)
} }
......
// Copyright 2020 astaxie
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package metric
import (
"net/http"
"reflect"
"strconv"
"strings"
"time"
"github.com/prometheus/client_golang/prometheus"
"github.com/astaxie/beego"
"github.com/astaxie/beego/logs"
)
func PrometheusMiddleWare(next http.Handler) http.Handler {
summaryVec := prometheus.NewSummaryVec(prometheus.SummaryOpts{
Name: "beego",
Subsystem: "http_request",
ConstLabels: map[string]string{
"server": beego.BConfig.ServerName,
"env": beego.BConfig.RunMode,
"appname": beego.BConfig.AppName,
},
Help: "The statics info for http request",
}, []string{"pattern", "method", "status", "duration"})
prometheus.MustRegister(summaryVec)
registerBuildInfo()
return http.HandlerFunc(func(writer http.ResponseWriter, q *http.Request) {
start := time.Now()
next.ServeHTTP(writer, q)
end := time.Now()
go report(end.Sub(start), writer, q, summaryVec)
})
}
func registerBuildInfo() {
buildInfo := prometheus.NewGaugeVec(prometheus.GaugeOpts{
Name: "beego",
Subsystem: "build_info",
Help: "The building information",
ConstLabels: map[string]string{
"appname": beego.BConfig.AppName,
"build_version": beego.BuildVersion,
"build_revision": beego.BuildGitRevision,
"build_status": beego.BuildStatus,
"build_tag": beego.BuildTag,
"build_time": strings.Replace(beego.BuildTime, "--", " ", 1),
"go_version": beego.GoVersion,
"git_branch": beego.GitBranch,
"start_time": time.Now().Format("2006-01-02 15:04:05"),
},
}, []string{})
prometheus.MustRegister(buildInfo)
buildInfo.WithLabelValues().Set(1)
}
func report(dur time.Duration, writer http.ResponseWriter, q *http.Request, vec *prometheus.SummaryVec) {
ctrl := beego.BeeApp.Handlers
ctx := ctrl.GetContext()
ctx.Reset(writer, q)
defer ctrl.GiveBackContext(ctx)
// We cannot read the status code from q.Response.StatusCode
// since the http server does not set q.Response. So q.Response is nil
// Thus, we use reflection to read the status from writer whose concrete type is http.response
responseVal := reflect.ValueOf(writer).Elem()
field := responseVal.FieldByName("status")
status := -1
if field.IsValid() && field.Kind() == reflect.Int {
status = int(field.Int())
}
ptn := "UNKNOWN"
if rt, found := ctrl.FindRouter(ctx); found {
ptn = rt.GetPattern()
} else {
logs.Warn("we can not find the router info for this request, so request will be recorded as UNKNOWN: " + q.URL.String())
}
ms := dur / time.Millisecond
vec.WithLabelValues(ptn, q.Method, strconv.Itoa(status), strconv.Itoa(int(ms))).Observe(float64(ms))
}
// Copyright 2020 astaxie
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package metric
import (
"net/http"
"net/url"
"testing"
"time"
"github.com/prometheus/client_golang/prometheus"
"github.com/astaxie/beego/context"
)
func TestPrometheusMiddleWare(t *testing.T) {
middleware := PrometheusMiddleWare(http.HandlerFunc(func(http.ResponseWriter, *http.Request) {}))
writer := &context.Response{}
request := &http.Request{
URL: &url.URL{
Host: "localhost",
RawPath: "/a/b/c",
},
Method: "POST",
}
vec := prometheus.NewSummaryVec(prometheus.SummaryOpts{}, []string{"pattern", "method", "status", "duration"})
report(time.Second, writer, request, vec)
middleware.ServeHTTP(writer, request)
}
...@@ -470,7 +470,7 @@ func (d *dbBase) InsertValue(q dbQuerier, mi *modelInfo, isMulti bool, names []s ...@@ -470,7 +470,7 @@ func (d *dbBase) InsertValue(q dbQuerier, mi *modelInfo, isMulti bool, names []s
multi := len(values) / len(names) multi := len(values) / len(names)
if isMulti { if isMulti && multi > 1 {
qmarks = strings.Repeat(qmarks+"), (", multi-1) + qmarks qmarks = strings.Repeat(qmarks+"), (", multi-1) + qmarks
} }
...@@ -770,6 +770,16 @@ func (d *dbBase) UpdateBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con ...@@ -770,6 +770,16 @@ func (d *dbBase) UpdateBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
cols = append(cols, col+" = "+col+" * ?") cols = append(cols, col+" = "+col+" * ?")
case ColExcept: case ColExcept:
cols = append(cols, col+" = "+col+" / ?") cols = append(cols, col+" = "+col+" / ?")
case ColBitAnd:
cols = append(cols, col+" = "+col+" & ?")
case ColBitRShift:
cols = append(cols, col+" = "+col+" >> ?")
case ColBitLShift:
cols = append(cols, col+" = "+col+" << ?")
case ColBitXOR:
cols = append(cols, col+" = "+col+" ^ ?")
case ColBitOr:
cols = append(cols, col+" = "+col+" | ?")
} }
values[i] = c.value values[i] = c.value
} else { } else {
......
...@@ -18,6 +18,7 @@ import ( ...@@ -18,6 +18,7 @@ import (
"context" "context"
"database/sql" "database/sql"
"fmt" "fmt"
lru "github.com/hashicorp/golang-lru"
"reflect" "reflect"
"sync" "sync"
"time" "time"
...@@ -106,8 +107,8 @@ func (ac *_dbCache) getDefault() (al *alias) { ...@@ -106,8 +107,8 @@ func (ac *_dbCache) getDefault() (al *alias) {
type DB struct { type DB struct {
*sync.RWMutex *sync.RWMutex
DB *sql.DB DB *sql.DB
stmts map[string]*sql.Stmt stmtDecorators *lru.Cache
} }
func (d *DB) Begin() (*sql.Tx, error) { func (d *DB) Begin() (*sql.Tx, error) {
...@@ -118,22 +119,36 @@ func (d *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) ...@@ -118,22 +119,36 @@ func (d *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error)
return d.DB.BeginTx(ctx, opts) return d.DB.BeginTx(ctx, opts)
} }
func (d *DB) getStmt(query string) (*sql.Stmt, error) { //su must call release to release *sql.Stmt after using
func (d *DB) getStmtDecorator(query string) (*stmtDecorator, error) {
d.RLock() d.RLock()
if stmt, ok := d.stmts[query]; ok { c, ok := d.stmtDecorators.Get(query)
if ok {
c.(*stmtDecorator).acquire()
d.RUnlock() d.RUnlock()
return stmt, nil return c.(*stmtDecorator), nil
} }
d.RUnlock() d.RUnlock()
d.Lock()
c, ok = d.stmtDecorators.Get(query)
if ok {
c.(*stmtDecorator).acquire()
d.Unlock()
return c.(*stmtDecorator), nil
}
stmt, err := d.Prepare(query) stmt, err := d.Prepare(query)
if err != nil { if err != nil {
d.Unlock()
return nil, err return nil, err
} }
d.Lock() sd := newStmtDecorator(stmt)
d.stmts[query] = stmt sd.acquire()
d.stmtDecorators.Add(query, sd)
d.Unlock() d.Unlock()
return stmt, nil
return sd, nil
} }
func (d *DB) Prepare(query string) (*sql.Stmt, error) { func (d *DB) Prepare(query string) (*sql.Stmt, error) {
...@@ -145,52 +160,63 @@ func (d *DB) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error ...@@ -145,52 +160,63 @@ func (d *DB) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error
} }
func (d *DB) Exec(query string, args ...interface{}) (sql.Result, error) { func (d *DB) Exec(query string, args ...interface{}) (sql.Result, error) {
stmt, err := d.getStmt(query) sd, err := d.getStmtDecorator(query)
if err != nil { if err != nil {
return nil, err return nil, err
} }
stmt := sd.getStmt()
defer sd.release()
return stmt.Exec(args...) return stmt.Exec(args...)
} }
func (d *DB) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { func (d *DB) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
stmt, err := d.getStmt(query) sd, err := d.getStmtDecorator(query)
if err != nil { if err != nil {
return nil, err return nil, err
} }
stmt := sd.getStmt()
defer sd.release()
return stmt.ExecContext(ctx, args...) return stmt.ExecContext(ctx, args...)
} }
func (d *DB) Query(query string, args ...interface{}) (*sql.Rows, error) { func (d *DB) Query(query string, args ...interface{}) (*sql.Rows, error) {
stmt, err := d.getStmt(query) sd, err := d.getStmtDecorator(query)
if err != nil { if err != nil {
return nil, err return nil, err
} }
stmt := sd.getStmt()
defer sd.release()
return stmt.Query(args...) return stmt.Query(args...)
} }
func (d *DB) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) { func (d *DB) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
stmt, err := d.getStmt(query) sd, err := d.getStmtDecorator(query)
if err != nil { if err != nil {
return nil, err return nil, err
} }
stmt := sd.getStmt()
defer sd.release()
return stmt.QueryContext(ctx, args...) return stmt.QueryContext(ctx, args...)
} }
func (d *DB) QueryRow(query string, args ...interface{}) *sql.Row { func (d *DB) QueryRow(query string, args ...interface{}) *sql.Row {
stmt, err := d.getStmt(query) sd, err := d.getStmtDecorator(query)
if err != nil { if err != nil {
panic(err) panic(err)
} }
stmt := sd.getStmt()
defer sd.release()
return stmt.QueryRow(args...) return stmt.QueryRow(args...)
} }
func (d *DB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row { func (d *DB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
sd, err := d.getStmtDecorator(query)
stmt, err := d.getStmt(query)
if err != nil { if err != nil {
panic(err) panic(err)
} }
stmt := sd.getStmt()
defer sd.release()
return stmt.QueryRowContext(ctx, args) return stmt.QueryRowContext(ctx, args)
} }
...@@ -268,9 +294,9 @@ func addAliasWthDB(aliasName, driverName string, db *sql.DB) (*alias, error) { ...@@ -268,9 +294,9 @@ func addAliasWthDB(aliasName, driverName string, db *sql.DB) (*alias, error) {
al.Name = aliasName al.Name = aliasName
al.DriverName = driverName al.DriverName = driverName
al.DB = &DB{ al.DB = &DB{
RWMutex: new(sync.RWMutex), RWMutex: new(sync.RWMutex),
DB: db, DB: db,
stmts: make(map[string]*sql.Stmt), stmtDecorators: newStmtDecoratorLruWithEvict(),
} }
if dr, ok := drivers[driverName]; ok { if dr, ok := drivers[driverName]; ok {
...@@ -374,6 +400,7 @@ func SetMaxIdleConns(aliasName string, maxIdleConns int) { ...@@ -374,6 +400,7 @@ func SetMaxIdleConns(aliasName string, maxIdleConns int) {
func SetMaxOpenConns(aliasName string, maxOpenConns int) { func SetMaxOpenConns(aliasName string, maxOpenConns int) {
al := getDbAlias(aliasName) al := getDbAlias(aliasName)
al.MaxOpenConns = maxOpenConns al.MaxOpenConns = maxOpenConns
al.DB.DB.SetMaxOpenConns(maxOpenConns)
// for tip go 1.2 // for tip go 1.2
if fun := reflect.ValueOf(al.DB).MethodByName("SetMaxOpenConns"); fun.IsValid() { if fun := reflect.ValueOf(al.DB).MethodByName("SetMaxOpenConns"); fun.IsValid() {
fun.Call([]reflect.Value{reflect.ValueOf(maxOpenConns)}) fun.Call([]reflect.Value{reflect.ValueOf(maxOpenConns)})
...@@ -395,3 +422,44 @@ func GetDB(aliasNames ...string) (*sql.DB, error) { ...@@ -395,3 +422,44 @@ func GetDB(aliasNames ...string) (*sql.DB, error) {
} }
return nil, fmt.Errorf("DataBase of alias name `%s` not found", name) return nil, fmt.Errorf("DataBase of alias name `%s` not found", name)
} }
type stmtDecorator struct {
wg sync.WaitGroup
lastUse int64
stmt *sql.Stmt
}
func (s *stmtDecorator) getStmt() *sql.Stmt {
return s.stmt
}
func (s *stmtDecorator) acquire() {
s.wg.Add(1)
s.lastUse = time.Now().Unix()
}
func (s *stmtDecorator) release() {
s.wg.Done()
}
//garbage recycle for stmt
func (s *stmtDecorator) destroy() {
go func() {
s.wg.Wait()
_ = s.stmt.Close()
}()
}
func newStmtDecorator(sqlStmt *sql.Stmt) *stmtDecorator {
return &stmtDecorator{
stmt: sqlStmt,
lastUse: time.Now().Unix(),
}
}
func newStmtDecoratorLruWithEvict() *lru.Cache {
cache, _ := lru.NewWithEvict(1000, func(key interface{}, value interface{}) {
value.(*stmtDecorator).destroy()
})
return cache
}
...@@ -17,6 +17,8 @@ package orm ...@@ -17,6 +17,8 @@ package orm
import ( import (
"database/sql" "database/sql"
"fmt" "fmt"
"reflect"
"time"
) )
// sqlite operators. // sqlite operators.
...@@ -66,6 +68,14 @@ type dbBaseSqlite struct { ...@@ -66,6 +68,14 @@ type dbBaseSqlite struct {
var _ dbBaser = new(dbBaseSqlite) var _ dbBaser = new(dbBaseSqlite)
// override base db read for update behavior as SQlite does not support syntax
func (d *dbBaseSqlite) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string, isForUpdate bool) error {
if isForUpdate {
DebugLog.Println("[WARN] SQLite does not support SELECT FOR UPDATE query, isForUpdate param is ignored and always as false to do the work")
}
return d.dbBase.Read(q, mi, ind, tz, cols, false)
}
// get sqlite operator. // get sqlite operator.
func (d *dbBaseSqlite) OperatorSQL(operator string) string { func (d *dbBaseSqlite) OperatorSQL(operator string) string {
return sqliteOperators[operator] return sqliteOperators[operator]
......
...@@ -18,6 +18,7 @@ import ( ...@@ -18,6 +18,7 @@ import (
"fmt" "fmt"
"os" "os"
"reflect" "reflect"
"runtime/debug"
"strings" "strings"
) )
...@@ -298,6 +299,7 @@ func bootStrap() { ...@@ -298,6 +299,7 @@ func bootStrap() {
end: end:
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
debug.PrintStack()
os.Exit(2) os.Exit(2)
} }
} }
......
...@@ -559,9 +559,9 @@ func NewOrmWithDB(driverName, aliasName string, db *sql.DB) (Ormer, error) { ...@@ -559,9 +559,9 @@ func NewOrmWithDB(driverName, aliasName string, db *sql.DB) (Ormer, error) {
al.Name = aliasName al.Name = aliasName
al.DriverName = driverName al.DriverName = driverName
al.DB = &DB{ al.DB = &DB{
RWMutex: new(sync.RWMutex), RWMutex: new(sync.RWMutex),
DB: db, DB: db,
stmts: make(map[string]*sql.Stmt), stmtDecorators: newStmtDecoratorLruWithEvict(),
} }
detectTZ(al) detectTZ(al)
......
...@@ -32,6 +32,11 @@ const ( ...@@ -32,6 +32,11 @@ const (
ColMinus ColMinus
ColMultiply ColMultiply
ColExcept ColExcept
ColBitAnd
ColBitRShift
ColBitLShift
ColBitXOR
ColBitOr
) )
// ColValue do the field raw changes. e.g Nums = Nums + 10. usage: // ColValue do the field raw changes. e.g Nums = Nums + 10. usage:
...@@ -40,7 +45,8 @@ const ( ...@@ -40,7 +45,8 @@ const (
// } // }
func ColValue(opt operator, value interface{}) interface{} { func ColValue(opt operator, value interface{}) interface{} {
switch opt { switch opt {
case ColAdd, ColMinus, ColMultiply, ColExcept: case ColAdd, ColMinus, ColMultiply, ColExcept, ColBitAnd, ColBitRShift,
ColBitLShift, ColBitXOR, ColBitOr:
default: default:
panic(fmt.Errorf("orm.ColValue wrong operator")) panic(fmt.Errorf("orm.ColValue wrong operator"))
} }
......
...@@ -500,7 +500,7 @@ func genRouterCode(pkgRealpath string) { ...@@ -500,7 +500,7 @@ func genRouterCode(pkgRealpath string) {
beego.GlobalControllerRouter["` + k + `"] = append(beego.GlobalControllerRouter["` + k + `"], beego.GlobalControllerRouter["` + k + `"] = append(beego.GlobalControllerRouter["` + k + `"],
beego.ControllerComments{ beego.ControllerComments{
Method: "` + strings.TrimSpace(c.Method) + `", Method: "` + strings.TrimSpace(c.Method) + `",
` + "Router: `" + c.Router + "`" + `, ` + `Router: "` + c.Router + `"` + `,
AllowHTTPMethods: ` + allmethod + `, AllowHTTPMethods: ` + allmethod + `,
MethodParams: ` + methodParams + `, MethodParams: ` + methodParams + `,
Filters: ` + filters + `, Filters: ` + filters + `,
......
...@@ -18,6 +18,7 @@ import ( ...@@ -18,6 +18,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"net/http" "net/http"
"os"
"path" "path"
"path/filepath" "path/filepath"
"reflect" "reflect"
...@@ -121,6 +122,10 @@ type ControllerInfo struct { ...@@ -121,6 +122,10 @@ type ControllerInfo struct {
methodParams []*param.MethodParam methodParams []*param.MethodParam
} }
func (c *ControllerInfo) GetPattern() string {
return c.pattern
}
// ControllerRegister containers registered router rules, controller handlers and filters. // ControllerRegister containers registered router rules, controller handlers and filters.
type ControllerRegister struct { type ControllerRegister struct {
routers map[string]*Tree routers map[string]*Tree
...@@ -249,25 +254,39 @@ func (p *ControllerRegister) addToRouter(method, pattern string, r *ControllerIn ...@@ -249,25 +254,39 @@ func (p *ControllerRegister) addToRouter(method, pattern string, r *ControllerIn
func (p *ControllerRegister) Include(cList ...ControllerInterface) { func (p *ControllerRegister) Include(cList ...ControllerInterface) {
if BConfig.RunMode == DEV { if BConfig.RunMode == DEV {
skip := make(map[string]bool, 10) skip := make(map[string]bool, 10)
wgopath := utils.GetGOPATHs()
go111module := os.Getenv(`GO111MODULE`)
for _, c := range cList { for _, c := range cList {
reflectVal := reflect.ValueOf(c) reflectVal := reflect.ValueOf(c)
t := reflect.Indirect(reflectVal).Type() t := reflect.Indirect(reflectVal).Type()
wgopath := utils.GetGOPATHs() // for go modules
if len(wgopath) == 0 { if go111module == `on` {
panic("you are in dev mode. So please set gopath") pkgpath := filepath.Join(WorkPath, "..", t.PkgPath())
} if utils.FileExists(pkgpath) {
pkgpath := "" if pkgpath != "" {
for _, wg := range wgopath { if _, ok := skip[pkgpath]; !ok {
wg, _ = filepath.EvalSymlinks(filepath.Join(wg, "src", t.PkgPath())) skip[pkgpath] = true
if utils.FileExists(wg) { parserPkg(pkgpath, t.PkgPath())
pkgpath = wg }
break }
} }
} } else {
if pkgpath != "" { if len(wgopath) == 0 {
if _, ok := skip[pkgpath]; !ok { panic("you are in dev mode. So please set gopath")
skip[pkgpath] = true }
parserPkg(pkgpath, t.PkgPath()) pkgpath := ""
for _, wg := range wgopath {
wg, _ = filepath.EvalSymlinks(filepath.Join(wg, "src", t.PkgPath()))
if utils.FileExists(wg) {
pkgpath = wg
break
}
}
if pkgpath != "" {
if _, ok := skip[pkgpath]; !ok {
skip[pkgpath] = true
parserPkg(pkgpath, t.PkgPath())
}
} }
} }
} }
...@@ -288,6 +307,21 @@ func (p *ControllerRegister) Include(cList ...ControllerInterface) { ...@@ -288,6 +307,21 @@ func (p *ControllerRegister) Include(cList ...ControllerInterface) {
} }
} }
// GetContext returns a context from pool, so usually you should remember to call Reset function to clean the context
// And don't forget to give back context to pool
// example:
// ctx := p.GetContext()
// ctx.Reset(w, q)
// defer p.GiveBackContext(ctx)
func (p *ControllerRegister) GetContext() *beecontext.Context {
return p.pool.Get().(*beecontext.Context)
}
// GiveBackContext put the ctx into pool so that it could be reuse
func (p *ControllerRegister) GiveBackContext(ctx *beecontext.Context) {
p.pool.Put(ctx)
}
// Get add get method // Get add get method
// usage: // usage:
// Get("/", func(ctx *context.Context){ // Get("/", func(ctx *context.Context){
...@@ -667,10 +701,11 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request) ...@@ -667,10 +701,11 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request)
routerInfo *ControllerInfo routerInfo *ControllerInfo
isRunnable bool isRunnable bool
) )
context := p.pool.Get().(*beecontext.Context) context := p.GetContext()
context.Reset(rw, r) context.Reset(rw, r)
defer p.pool.Put(context) defer p.GiveBackContext(context)
if BConfig.RecoverFunc != nil { if BConfig.RecoverFunc != nil {
defer BConfig.RecoverFunc(context) defer BConfig.RecoverFunc(context)
} }
...@@ -739,7 +774,7 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request) ...@@ -739,7 +774,7 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request)
routerInfo, findRouter = p.FindRouter(context) routerInfo, findRouter = p.FindRouter(context)
} }
//if no matches to url, throw a not found exception // if no matches to url, throw a not found exception
if !findRouter { if !findRouter {
exception("404", context) exception("404", context)
goto Admin goto Admin
...@@ -750,19 +785,22 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request) ...@@ -750,19 +785,22 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request)
} }
} }
//execute middleware filters if routerInfo != nil {
// store router pattern into context
context.Input.SetData("RouterPattern", routerInfo.pattern)
}
// execute middleware filters
if len(p.filters[BeforeExec]) > 0 && p.execFilter(context, urlPath, BeforeExec) { if len(p.filters[BeforeExec]) > 0 && p.execFilter(context, urlPath, BeforeExec) {
goto Admin goto Admin
} }
//check policies // check policies
if p.execPolicy(context, urlPath) { if p.execPolicy(context, urlPath) {
goto Admin goto Admin
} }
if routerInfo != nil { if routerInfo != nil {
//store router pattern into context
context.Input.SetData("RouterPattern", routerInfo.pattern)
if routerInfo.routerType == routerTypeRESTFul { if routerInfo.routerType == routerTypeRESTFul {
if _, ok := routerInfo.methods[r.Method]; ok { if _, ok := routerInfo.methods[r.Method]; ok {
isRunnable = true isRunnable = true
...@@ -796,7 +834,7 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request) ...@@ -796,7 +834,7 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request)
// also defined runRouter & runMethod from filter // also defined runRouter & runMethod from filter
if !isRunnable { if !isRunnable {
//Invoke the request handler // Invoke the request handler
var execController ControllerInterface var execController ControllerInterface
if routerInfo != nil && routerInfo.initialize != nil { if routerInfo != nil && routerInfo.initialize != nil {
execController = routerInfo.initialize() execController = routerInfo.initialize()
...@@ -809,13 +847,13 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request) ...@@ -809,13 +847,13 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request)
} }
} }
//call the controller init function // call the controller init function
execController.Init(context, runRouter.Name(), runMethod, execController) execController.Init(context, runRouter.Name(), runMethod, execController)
//call prepare function // call prepare function
execController.Prepare() execController.Prepare()
//if XSRF is Enable then check cookie where there has any cookie in the request's cookie _csrf // if XSRF is Enable then check cookie where there has any cookie in the request's cookie _csrf
if BConfig.WebConfig.EnableXSRF { if BConfig.WebConfig.EnableXSRF {
execController.XSRFToken() execController.XSRFToken()
if r.Method == http.MethodPost || r.Method == http.MethodDelete || r.Method == http.MethodPut || if r.Method == http.MethodPost || r.Method == http.MethodDelete || r.Method == http.MethodPut ||
...@@ -827,7 +865,7 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request) ...@@ -827,7 +865,7 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request)
execController.URLMapping() execController.URLMapping()
if !context.ResponseWriter.Started { if !context.ResponseWriter.Started {
//exec main logic // exec main logic
switch runMethod { switch runMethod {
case http.MethodGet: case http.MethodGet:
execController.Get() execController.Get()
...@@ -852,14 +890,14 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request) ...@@ -852,14 +890,14 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request)
in := param.ConvertParams(methodParams, method.Type(), context) in := param.ConvertParams(methodParams, method.Type(), context)
out := method.Call(in) out := method.Call(in)
//For backward compatibility we only handle response if we had incoming methodParams // For backward compatibility we only handle response if we had incoming methodParams
if methodParams != nil { if methodParams != nil {
p.handleParamResponse(context, execController, out) p.handleParamResponse(context, execController, out)
} }
} }
} }
//render template // render template
if !context.ResponseWriter.Started && context.Output.Status == 0 { if !context.ResponseWriter.Started && context.Output.Status == 0 {
if BConfig.WebConfig.AutoRender { if BConfig.WebConfig.AutoRender {
if err := execController.Render(); err != nil { if err := execController.Render(); err != nil {
...@@ -873,7 +911,7 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request) ...@@ -873,7 +911,7 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request)
execController.Finish() execController.Finish()
} }
//execute middleware filters // execute middleware filters
if len(p.filters[AfterExec]) > 0 && p.execFilter(context, urlPath, AfterExec) { if len(p.filters[AfterExec]) > 0 && p.execFilter(context, urlPath, AfterExec) {
goto Admin goto Admin
} }
...@@ -883,7 +921,7 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request) ...@@ -883,7 +921,7 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request)
} }
Admin: Admin:
//admin module record QPS // admin module record QPS
statusCode := context.ResponseWriter.Status statusCode := context.ResponseWriter.Status
if statusCode == 0 { if statusCode == 0 {
...@@ -931,7 +969,7 @@ Admin: ...@@ -931,7 +969,7 @@ Admin:
} }
func (p *ControllerRegister) handleParamResponse(context *beecontext.Context, execController ControllerInterface, results []reflect.Value) { func (p *ControllerRegister) handleParamResponse(context *beecontext.Context, execController ControllerInterface, results []reflect.Value) {
//looping in reverse order for the case when both error and value are returned and error sets the response status code // looping in reverse order for the case when both error and value are returned and error sets the response status code
for i := len(results) - 1; i >= 0; i-- { for i := len(results) - 1; i >= 0; i-- {
result := results[i] result := results[i]
if result.Kind() != reflect.Interface || !result.IsNil() { if result.Kind() != reflect.Interface || !result.IsNil() {
...@@ -973,11 +1011,11 @@ func toURL(params map[string]string) string { ...@@ -973,11 +1011,11 @@ func toURL(params map[string]string) string {
// LogAccess logging info HTTP Access // LogAccess logging info HTTP Access
func LogAccess(ctx *beecontext.Context, startTime *time.Time, statusCode int) { func LogAccess(ctx *beecontext.Context, startTime *time.Time, statusCode int) {
//Skip logging if AccessLogs config is false // Skip logging if AccessLogs config is false
if !BConfig.Log.AccessLogs { if !BConfig.Log.AccessLogs {
return return
} }
//Skip logging static requests unless EnableStaticLogs config is true // Skip logging static requests unless EnableStaticLogs config is true
if !BConfig.Log.EnableStaticLogs && DefaultAccessLogFilter.Filter(ctx) { if !BConfig.Log.EnableStaticLogs && DefaultAccessLogFilter.Filter(ctx) {
return return
} }
...@@ -1002,7 +1040,7 @@ func LogAccess(ctx *beecontext.Context, startTime *time.Time, statusCode int) { ...@@ -1002,7 +1040,7 @@ func LogAccess(ctx *beecontext.Context, startTime *time.Time, statusCode int) {
HTTPReferrer: r.Header.Get("Referer"), HTTPReferrer: r.Header.Get("Referer"),
HTTPUserAgent: r.Header.Get("User-Agent"), HTTPUserAgent: r.Header.Get("User-Agent"),
RemoteUser: r.Header.Get("Remote-User"), RemoteUser: r.Header.Get("Remote-User"),
BodyBytesSent: 0, //@todo this one is missing! BodyBytesSent: 0, // @todo this one is missing!
} }
logs.AccessLog(record, BConfig.Log.AccessLogsFormat) logs.AccessLog(record, BConfig.Log.AccessLogsFormat)
} }
#!/bin/bash
# WARNING: DO NOT EDIT, THIS FILE IS PROBABLY A COPY
#
# The original version of this file is located in the https://github.com/istio/common-files repo.
# If you're looking at this file in a different repo and want to make a change, please go to the
# common-files repo, make the change there and check it in. Then come back to this repo and run
# "make update-common".
# Copyright Istio Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This script builds and version stamps the output
# adatp to beego
VERBOSE=${VERBOSE:-"0"}
V=""
if [[ "${VERBOSE}" == "1" ]];then
V="-x"
set -x
fi
SCRIPTPATH="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
OUT=${1:?"output path"}
shift
set -e
BUILD_GOOS=${GOOS:-linux}
BUILD_GOARCH=${GOARCH:-amd64}
GOBINARY=${GOBINARY:-go}
GOPKG="$GOPATH/pkg"
BUILDINFO=${BUILDINFO:-""}
STATIC=${STATIC:-1}
LDFLAGS=${LDFLAGS:--extldflags -static}
GOBUILDFLAGS=${GOBUILDFLAGS:-""}
# Split GOBUILDFLAGS by spaces into an array called GOBUILDFLAGS_ARRAY.
IFS=' ' read -r -a GOBUILDFLAGS_ARRAY <<< "$GOBUILDFLAGS"
GCFLAGS=${GCFLAGS:-}
export CGO_ENABLED=0
if [[ "${STATIC}" != "1" ]];then
LDFLAGS=""
fi
# gather buildinfo if not already provided
# For a release build BUILDINFO should be produced
# at the beginning of the build and used throughout
if [[ -z ${BUILDINFO} ]];then
BUILDINFO=$(mktemp)
"${SCRIPTPATH}/report_build_info.sh" > "${BUILDINFO}"
fi
# BUILD LD_EXTRAFLAGS
LD_EXTRAFLAGS=""
while read -r line; do
LD_EXTRAFLAGS="${LD_EXTRAFLAGS} -X ${line}"
done < "${BUILDINFO}"
# verify go version before build
# NB. this was copied verbatim from Kubernetes hack
minimum_go_version=go1.13 # supported patterns: go1.x, go1.x.x (x should be a number)
IFS=" " read -ra go_version <<< "$(${GOBINARY} version)"
if [[ "${minimum_go_version}" != $(echo -e "${minimum_go_version}\n${go_version[2]}" | sort -s -t. -k 1,1 -k 2,2n -k 3,3n | head -n1) && "${go_version[2]}" != "devel" ]]; then
echo "Warning: Detected that you are using an older version of the Go compiler. Beego requires ${minimum_go_version} or greater."
fi
CURRENT_BRANCH=$(git branch | grep '*')
CURRENT_BRANCH=${CURRENT_BRANCH:2}
BUILD_TIME=$(date +%Y-%m-%d--%T)
LD_EXTRAFLAGS="${LD_EXTRAFLAGS} -X github.com/astaxie/beego.GoVersion=${go_version[2]:2}"
LD_EXTRAFLAGS="${LD_EXTRAFLAGS} -X github.com/astaxie/beego.GitBranch=${CURRENT_BRANCH}"
LD_EXTRAFLAGS="${LD_EXTRAFLAGS} -X github.com/astaxie/beego.BuildTime=$BUILD_TIME"
OPTIMIZATION_FLAGS="-trimpath"
if [ "${DEBUG}" == "1" ]; then
OPTIMIZATION_FLAGS=""
fi
echo "BUILD_GOARCH: $BUILD_GOARCH"
echo "GOPKG: $GOPKG"
echo "LD_EXTRAFLAGS: $LD_EXTRAFLAGS"
echo "GO_VERSION: ${go_version[2]}"
echo "BRANCH: $CURRENT_BRANCH"
echo "BUILD_TIME: $BUILD_TIME"
time GOOS=${BUILD_GOOS} GOARCH=${BUILD_GOARCH} ${GOBINARY} build \
${V} "${GOBUILDFLAGS_ARRAY[@]}" ${GCFLAGS:+-gcflags "${GCFLAGS}"} \
-o "${OUT}" \
${OPTIMIZATION_FLAGS} \
-pkgdir="${GOPKG}/${BUILD_GOOS}_${BUILD_GOARCH}" \
-ldflags "${LDFLAGS} ${LD_EXTRAFLAGS}" "${@}"
\ No newline at end of file
#!/bin/bash
# WARNING: DO NOT EDIT, THIS FILE IS PROBABLY A COPY
#
# The original version of this file is located in the https://github.com/istio/common-files repo.
# If you're looking at this file in a different repo and want to make a change, please go to the
# common-files repo, make the change there and check it in. Then come back to this repo and run
# "make update-common".
# Copyright Istio Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# adapt to beego
if BUILD_GIT_REVISION=$(git rev-parse HEAD 2> /dev/null); then
if [[ -n "$(git status --porcelain 2>/dev/null)" ]]; then
BUILD_GIT_REVISION=${BUILD_GIT_REVISION}"-dirty"
fi
else
BUILD_GIT_REVISION=unknown
fi
# Check for local changes
if git diff-index --quiet HEAD --; then
tree_status="Clean"
else
tree_status="Modified"
fi
# security wanted VERSION='unknown'
VERSION="${BUILD_GIT_REVISION}"
if [[ -n ${BEEGO_VERSION} ]]; then
VERSION="${BEEGO_VERSION}"
fi
GIT_DESCRIBE_TAG=$(git describe --tags)
echo "github.com/astaxie/beego.BuildVersion=${VERSION}"
echo "github.com/astaxie/beego.BuildGitRevision=${BUILD_GIT_REVISION}"
echo "github.com/astaxie/beego.BuildStatus=${tree_status}"
echo "github.com/astaxie/beego.BuildTag=${GIT_DESCRIBE_TAG}"
\ No newline at end of file
...@@ -7,9 +7,10 @@ import ( ...@@ -7,9 +7,10 @@ import (
"strings" "strings"
"sync" "sync"
"github.com/ledisdb/ledisdb/config"
"github.com/ledisdb/ledisdb/ledis"
"github.com/astaxie/beego/session" "github.com/astaxie/beego/session"
"github.com/siddontang/ledisdb/config"
"github.com/siddontang/ledisdb/ledis"
) )
var ( var (
......
...@@ -74,7 +74,9 @@ func (st *CookieSessionStore) SessionID() string { ...@@ -74,7 +74,9 @@ func (st *CookieSessionStore) SessionID() string {
// SessionRelease Write cookie session to http response cookie // SessionRelease Write cookie session to http response cookie
func (st *CookieSessionStore) SessionRelease(w http.ResponseWriter) { func (st *CookieSessionStore) SessionRelease(w http.ResponseWriter) {
st.lock.Lock()
encodedCookie, err := encodeCookie(cookiepder.block, cookiepder.config.SecurityKey, cookiepder.config.SecurityName, st.values) encodedCookie, err := encodeCookie(cookiepder.block, cookiepder.config.SecurityKey, cookiepder.config.SecurityName, st.values)
st.lock.Unlock()
if err == nil { if err == nil {
cookie := &http.Cookie{Name: cookiepder.config.CookieName, cookie := &http.Cookie{Name: cookiepder.config.CookieName,
Value: url.QueryEscape(encodedCookie), Value: url.QueryEscape(encodedCookie),
......
...@@ -129,8 +129,9 @@ func (fp *FileProvider) SessionInit(maxlifetime int64, savePath string) error { ...@@ -129,8 +129,9 @@ func (fp *FileProvider) SessionInit(maxlifetime int64, savePath string) error {
// if file is not exist, create it. // if file is not exist, create it.
// the file path is generated from sid string. // the file path is generated from sid string.
func (fp *FileProvider) SessionRead(sid string) (Store, error) { func (fp *FileProvider) SessionRead(sid string) (Store, error) {
if strings.ContainsAny(sid, "./") { invalidChars := "./"
return nil, nil if strings.ContainsAny(sid, invalidChars) {
return nil, errors.New("the sid shouldn't have following characters: " + invalidChars)
} }
if len(sid) < 2 { if len(sid) < 2 {
return nil, errors.New("length of the sid is less than 2") return nil, errors.New("length of the sid is less than 2")
...@@ -138,7 +139,7 @@ func (fp *FileProvider) SessionRead(sid string) (Store, error) { ...@@ -138,7 +139,7 @@ func (fp *FileProvider) SessionRead(sid string) (Store, error) {
filepder.lock.Lock() filepder.lock.Lock()
defer filepder.lock.Unlock() defer filepder.lock.Unlock()
err := os.MkdirAll(path.Join(fp.savePath, string(sid[0]), string(sid[1])), 0777) err := os.MkdirAll(path.Join(fp.savePath, string(sid[0]), string(sid[1])), 0755)
if err != nil { if err != nil {
SLogger.Println(err.Error()) SLogger.Println(err.Error())
} }
...@@ -231,7 +232,7 @@ func (fp *FileProvider) SessionRegenerate(oldsid, sid string) (Store, error) { ...@@ -231,7 +232,7 @@ func (fp *FileProvider) SessionRegenerate(oldsid, sid string) (Store, error) {
return nil, fmt.Errorf("newsid %s exist", newSidFile) return nil, fmt.Errorf("newsid %s exist", newSidFile)
} }
err = os.MkdirAll(newPath, 0777) err = os.MkdirAll(newPath, 0755)
if err != nil { if err != nil {
SLogger.Println(err.Error()) SLogger.Println(err.Error())
} }
......
...@@ -28,6 +28,7 @@ import ( ...@@ -28,6 +28,7 @@ import (
"github.com/astaxie/beego/context" "github.com/astaxie/beego/context"
"github.com/astaxie/beego/logs" "github.com/astaxie/beego/logs"
"github.com/hashicorp/golang-lru"
) )
var errNotStaticRequest = errors.New("request not a static file request") var errNotStaticRequest = errors.New("request not a static file request")
...@@ -67,6 +68,10 @@ func serverStaticRouter(ctx *context.Context) { ...@@ -67,6 +68,10 @@ func serverStaticRouter(ctx *context.Context) {
http.ServeFile(ctx.ResponseWriter, ctx.Request, filePath) http.ServeFile(ctx.ResponseWriter, ctx.Request, filePath)
} }
return return
} else if fileInfo.Size() > int64(BConfig.WebConfig.StaticCacheFileSize) {
//over size file serve with http module
http.ServeFile(ctx.ResponseWriter, ctx.Request, filePath)
return
} }
var enableCompress = BConfig.EnableGzip && isStaticCompress(filePath) var enableCompress = BConfig.EnableGzip && isStaticCompress(filePath)
...@@ -93,10 +98,11 @@ func serverStaticRouter(ctx *context.Context) { ...@@ -93,10 +98,11 @@ func serverStaticRouter(ctx *context.Context) {
} }
type serveContentHolder struct { type serveContentHolder struct {
data []byte data []byte
modTime time.Time modTime time.Time
size int64 size int64
encoding string originSize int64 //original file size:to judge file changed
encoding string
} }
type serveContentReader struct { type serveContentReader struct {
...@@ -104,22 +110,36 @@ type serveContentReader struct { ...@@ -104,22 +110,36 @@ type serveContentReader struct {
} }
var ( var (
staticFileMap = make(map[string]*serveContentHolder) staticFileLruCache *lru.Cache
mapLock sync.RWMutex lruLock sync.RWMutex
) )
func openFile(filePath string, fi os.FileInfo, acceptEncoding string) (bool, string, *serveContentHolder, *serveContentReader, error) { func openFile(filePath string, fi os.FileInfo, acceptEncoding string) (bool, string, *serveContentHolder, *serveContentReader, error) {
if staticFileLruCache == nil {
//avoid lru cache error
if BConfig.WebConfig.StaticCacheFileNum >= 1 {
staticFileLruCache, _ = lru.New(BConfig.WebConfig.StaticCacheFileNum)
} else {
staticFileLruCache, _ = lru.New(1)
}
}
mapKey := acceptEncoding + ":" + filePath mapKey := acceptEncoding + ":" + filePath
mapLock.RLock() lruLock.RLock()
mapFile := staticFileMap[mapKey] var mapFile *serveContentHolder
mapLock.RUnlock() if cacheItem, ok := staticFileLruCache.Get(mapKey); ok {
mapFile = cacheItem.(*serveContentHolder)
}
lruLock.RUnlock()
if isOk(mapFile, fi) { if isOk(mapFile, fi) {
reader := &serveContentReader{Reader: bytes.NewReader(mapFile.data)} reader := &serveContentReader{Reader: bytes.NewReader(mapFile.data)}
return mapFile.encoding != "", mapFile.encoding, mapFile, reader, nil return mapFile.encoding != "", mapFile.encoding, mapFile, reader, nil
} }
mapLock.Lock() lruLock.Lock()
defer mapLock.Unlock() defer lruLock.Unlock()
if mapFile = staticFileMap[mapKey]; !isOk(mapFile, fi) { if cacheItem, ok := staticFileLruCache.Get(mapKey); ok {
mapFile = cacheItem.(*serveContentHolder)
}
if !isOk(mapFile, fi) {
file, err := os.Open(filePath) file, err := os.Open(filePath)
if err != nil { if err != nil {
return false, "", nil, nil, err return false, "", nil, nil, err
...@@ -130,8 +150,10 @@ func openFile(filePath string, fi os.FileInfo, acceptEncoding string) (bool, str ...@@ -130,8 +150,10 @@ func openFile(filePath string, fi os.FileInfo, acceptEncoding string) (bool, str
if err != nil { if err != nil {
return false, "", nil, nil, err return false, "", nil, nil, err
} }
mapFile = &serveContentHolder{data: bufferWriter.Bytes(), modTime: fi.ModTime(), size: int64(bufferWriter.Len()), encoding: n} mapFile = &serveContentHolder{data: bufferWriter.Bytes(), modTime: fi.ModTime(), size: int64(bufferWriter.Len()), originSize: fi.Size(), encoding: n}
staticFileMap[mapKey] = mapFile if isOk(mapFile, fi) {
staticFileLruCache.Add(mapKey, mapFile)
}
} }
reader := &serveContentReader{Reader: bytes.NewReader(mapFile.data)} reader := &serveContentReader{Reader: bytes.NewReader(mapFile.data)}
...@@ -141,8 +163,10 @@ func openFile(filePath string, fi os.FileInfo, acceptEncoding string) (bool, str ...@@ -141,8 +163,10 @@ func openFile(filePath string, fi os.FileInfo, acceptEncoding string) (bool, str
func isOk(s *serveContentHolder, fi os.FileInfo) bool { func isOk(s *serveContentHolder, fi os.FileInfo) bool {
if s == nil { if s == nil {
return false return false
} else if s.size > int64(BConfig.WebConfig.StaticCacheFileSize) {
return false
} }
return s.modTime == fi.ModTime() && s.size == fi.Size() return s.modTime == fi.ModTime() && s.originSize == fi.Size()
} }
// isStaticCompress detect static files // isStaticCompress detect static files
......
...@@ -4,6 +4,7 @@ import ( ...@@ -4,6 +4,7 @@ import (
"bytes" "bytes"
"compress/gzip" "compress/gzip"
"compress/zlib" "compress/zlib"
"fmt"
"io" "io"
"io/ioutil" "io/ioutil"
"os" "os"
...@@ -53,6 +54,31 @@ func TestOpenStaticFileDeflate_1(t *testing.T) { ...@@ -53,6 +54,31 @@ func TestOpenStaticFileDeflate_1(t *testing.T) {
testOpenFile("deflate", content, t) testOpenFile("deflate", content, t)
} }
func TestStaticCacheWork(t *testing.T) {
encodings := []string{"", "gzip", "deflate"}
fi, _ := os.Stat(licenseFile)
for _, encoding := range encodings {
_, _, first, _, err := openFile(licenseFile, fi, encoding)
if err != nil {
t.Error(err)
continue
}
_, _, second, _, err := openFile(licenseFile, fi, encoding)
if err != nil {
t.Error(err)
continue
}
address1 := fmt.Sprintf("%p", first)
address2 := fmt.Sprintf("%p", second)
if address1 != address2 {
t.Errorf("encoding '%v' can not hit cache", encoding)
}
}
}
func assetOpenFileAndContent(sch *serveContentHolder, reader *serveContentReader, content []byte, t *testing.T) { func assetOpenFileAndContent(sch *serveContentHolder, reader *serveContentReader, content []byte, t *testing.T) {
t.Log(sch.size, len(content)) t.Log(sch.size, len(content))
if sch.size != int64(len(content)) { if sch.size != int64(len(content)) {
...@@ -66,7 +92,7 @@ func assetOpenFileAndContent(sch *serveContentHolder, reader *serveContentReader ...@@ -66,7 +92,7 @@ func assetOpenFileAndContent(sch *serveContentHolder, reader *serveContentReader
t.Fail() t.Fail()
} }
} }
if len(staticFileMap) == 0 { if staticFileLruCache.Len() == 0 {
t.Log("men map is empty") t.Log("men map is empty")
t.Fail() t.Fail()
} }
......
...@@ -117,8 +117,8 @@ func (m *URLMap) GetMap() map[string]interface{} { ...@@ -117,8 +117,8 @@ func (m *URLMap) GetMap() map[string]interface{} {
// GetMapData return all mapdata // GetMapData return all mapdata
func (m *URLMap) GetMapData() []map[string]interface{} { func (m *URLMap) GetMapData() []map[string]interface{} {
m.lock.Lock() m.lock.RLock()
defer m.lock.Unlock() defer m.lock.RUnlock()
var resultLists []map[string]interface{} var resultLists []map[string]interface{}
......
...@@ -33,7 +33,7 @@ type bounds struct { ...@@ -33,7 +33,7 @@ type bounds struct {
// The bounds for each field. // The bounds for each field.
var ( var (
AdminTaskList map[string]Tasker AdminTaskList map[string]Tasker
taskLock sync.Mutex taskLock sync.RWMutex
stop chan bool stop chan bool
changed chan bool changed chan bool
isstart bool isstart bool
...@@ -408,7 +408,10 @@ func run() { ...@@ -408,7 +408,10 @@ func run() {
} }
for { for {
// we only use RLock here because NewMapSorter copy the reference, do not change any thing
taskLock.RLock()
sortList := NewMapSorter(AdminTaskList) sortList := NewMapSorter(AdminTaskList)
taskLock.RUnlock()
sortList.Sort() sortList.Sort()
var effective time.Time var effective time.Time
if len(AdminTaskList) == 0 || sortList.Vals[0].GetNext().IsZero() { if len(AdminTaskList) == 0 || sortList.Vals[0].GetNext().IsZero() {
...@@ -432,9 +435,11 @@ func run() { ...@@ -432,9 +435,11 @@ func run() {
continue continue
case <-changed: case <-changed:
now = time.Now().Local() now = time.Now().Local()
taskLock.Lock()
for _, t := range AdminTaskList { for _, t := range AdminTaskList {
t.SetNext(now) t.SetNext(now)
} }
taskLock.Unlock()
continue continue
case <-stop: case <-stop:
return return
......
...@@ -224,7 +224,7 @@ func parseFunc(vfunc, key string, label string) (v ValidFunc, err error) { ...@@ -224,7 +224,7 @@ func parseFunc(vfunc, key string, label string) (v ValidFunc, err error) {
func numIn(name string) (num int, err error) { func numIn(name string) (num int, err error) {
fn, ok := funcs[name] fn, ok := funcs[name]
if !ok { if !ok {
err = fmt.Errorf("doesn't exsits %s valid function", name) err = fmt.Errorf("doesn't exists %s valid function", name)
return return
} }
// sub *Validation obj and key // sub *Validation obj and key
...@@ -236,7 +236,7 @@ func trim(name, key string, s []string) (ts []interface{}, err error) { ...@@ -236,7 +236,7 @@ func trim(name, key string, s []string) (ts []interface{}, err error) {
ts = make([]interface{}, len(s), len(s)+1) ts = make([]interface{}, len(s), len(s)+1)
fn, ok := funcs[name] fn, ok := funcs[name]
if !ok { if !ok {
err = fmt.Errorf("doesn't exsits %s valid function", name) err = fmt.Errorf("doesn't exists %s valid function", name)
return return
} }
for i := 0; i < len(s); i++ { for i := 0; i < len(s); i++ {
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
package validation package validation
import ( import (
"log"
"reflect" "reflect"
"testing" "testing"
) )
...@@ -23,7 +24,7 @@ type user struct { ...@@ -23,7 +24,7 @@ type user struct {
ID int ID int
Tag string `valid:"Maxx(aa)"` Tag string `valid:"Maxx(aa)"`
Name string `valid:"Required;"` Name string `valid:"Required;"`
Age int `valid:"Required;Range(1, 140)"` Age int `valid:"Required; Range(1, 140)"`
match string `valid:"Required; Match(/^(test)?\\w*@(/test/);com$/);Max(2)"` match string `valid:"Required; Match(/^(test)?\\w*@(/test/);com$/);Max(2)"`
} }
...@@ -42,7 +43,7 @@ func TestGetValidFuncs(t *testing.T) { ...@@ -42,7 +43,7 @@ func TestGetValidFuncs(t *testing.T) {
} }
f, _ = tf.FieldByName("Tag") f, _ = tf.FieldByName("Tag")
if _, err = getValidFuncs(f); err.Error() != "doesn't exsits Maxx valid function" { if _, err = getValidFuncs(f); err.Error() != "doesn't exists Maxx valid function" {
t.Fatal(err) t.Fatal(err)
} }
...@@ -80,6 +81,33 @@ func TestGetValidFuncs(t *testing.T) { ...@@ -80,6 +81,33 @@ func TestGetValidFuncs(t *testing.T) {
} }
} }
type User struct {
Name string `valid:"Required;MaxSize(5)" `
Sex string `valid:"Required;" label:"sex_label"`
Age int `valid:"Required;Range(1, 140);" label:"age_label"`
}
func TestValidation(t *testing.T) {
u := User{"man1238888456", "", 1140}
valid := Validation{}
b, err := valid.Valid(&u)
if err != nil {
// handle error
}
if !b {
// validation does not pass
// blabla...
for _, err := range valid.Errors {
log.Println(err.Key, err.Message)
}
if len(valid.Errors) != 3 {
t.Error("must be has 3 error")
}
} else {
t.Error("must be has 3 error")
}
}
func TestCall(t *testing.T) { func TestCall(t *testing.T) {
u := user{Name: "test", Age: 180} u := user{Name: "test", Age: 180}
tf := reflect.TypeOf(u) tf := reflect.TypeOf(u)
......
...@@ -273,10 +273,13 @@ func (v *Validation) apply(chk Validator, obj interface{}) *Result { ...@@ -273,10 +273,13 @@ func (v *Validation) apply(chk Validator, obj interface{}) *Result {
Field = parts[0] Field = parts[0]
Name = parts[1] Name = parts[1]
Label = parts[2] Label = parts[2]
if len(Label) == 0 {
Label = Field
}
} }
err := &Error{ err := &Error{
Message: Label + chk.DefaultMessage(), Message: Label + " " + chk.DefaultMessage(),
Key: key, Key: key,
Name: Name, Name: Name,
Field: Field, Field: Field,
...@@ -293,19 +296,25 @@ func (v *Validation) apply(chk Validator, obj interface{}) *Result { ...@@ -293,19 +296,25 @@ func (v *Validation) apply(chk Validator, obj interface{}) *Result {
} }
} }
// key must like aa.bb.cc or aa.bb.
// AddError adds independent error message for the provided key // AddError adds independent error message for the provided key
func (v *Validation) AddError(key, message string) { func (v *Validation) AddError(key, message string) {
Name := key Name := key
Field := "" Field := ""
Label := ""
parts := strings.Split(key, ".") parts := strings.Split(key, ".")
if len(parts) == 3 { if len(parts) == 3 {
Field = parts[0] Field = parts[0]
Name = parts[1] Name = parts[1]
Label = parts[2]
if len(Label) == 0 {
Label = Field
}
} }
err := &Error{ err := &Error{
Message: message, Message: Label + " " + message,
Key: key, Key: key,
Name: Name, Name: Name,
Field: Field, Field: Field,
...@@ -381,7 +390,6 @@ func (v *Validation) Valid(obj interface{}) (b bool, err error) { ...@@ -381,7 +390,6 @@ func (v *Validation) Valid(obj interface{}) (b bool, err error) {
} }
} }
chk := Required{""}.IsSatisfied(currentField) chk := Required{""}.IsSatisfied(currentField)
if !hasRequired && v.RequiredFirst && !chk { if !hasRequired && v.RequiredFirst && !chk {
if _, ok := CanSkipFuncs[vf.Name]; ok { if _, ok := CanSkipFuncs[vf.Name]; ok {
......
...@@ -253,44 +253,68 @@ func TestBase64(t *testing.T) { ...@@ -253,44 +253,68 @@ func TestBase64(t *testing.T) {
func TestMobile(t *testing.T) { func TestMobile(t *testing.T) {
valid := Validation{} valid := Validation{}
if valid.Mobile("19800008888", "mobile").Ok { validMobiles := []string{
t.Error("\"19800008888\" is a valid mobile phone number should be false") "19800008888",
} "18800008888",
if !valid.Mobile("18800008888", "mobile").Ok { "18000008888",
t.Error("\"18800008888\" is a valid mobile phone number should be true") "8618300008888",
} "+8614700008888",
if !valid.Mobile("18000008888", "mobile").Ok { "17300008888",
t.Error("\"18000008888\" is a valid mobile phone number should be true") "+8617100008888",
} "8617500008888",
if !valid.Mobile("8618300008888", "mobile").Ok { "8617400008888",
t.Error("\"8618300008888\" is a valid mobile phone number should be true") "16200008888",
} "16500008888",
if !valid.Mobile("+8614700008888", "mobile").Ok { "16600008888",
t.Error("\"+8614700008888\" is a valid mobile phone number should be true") "16700008888",
} "13300008888",
if !valid.Mobile("17300008888", "mobile").Ok { "14900008888",
t.Error("\"17300008888\" is a valid mobile phone number should be true") "15300008888",
} "17300008888",
if !valid.Mobile("+8617100008888", "mobile").Ok { "17700008888",
t.Error("\"+8617100008888\" is a valid mobile phone number should be true") "18000008888",
} "18900008888",
if !valid.Mobile("8617500008888", "mobile").Ok { "19100008888",
t.Error("\"8617500008888\" is a valid mobile phone number should be true") "19900008888",
} "19300008888",
if valid.Mobile("8617400008888", "mobile").Ok { "13000008888",
t.Error("\"8617400008888\" is a valid mobile phone number should be false") "13100008888",
} "13200008888",
if !valid.Mobile("16200008888", "mobile").Ok { "14500008888",
t.Error("\"16200008888\" is a valid mobile phone number should be true") "15500008888",
} "15600008888",
if !valid.Mobile("16500008888", "mobile").Ok { "16600008888",
t.Error("\"16500008888\" is a valid mobile phone number should be true") "17100008888",
} "17500008888",
if !valid.Mobile("16600008888", "mobile").Ok { "17600008888",
t.Error("\"16600008888\" is a valid mobile phone number should be true") "18500008888",
} "18600008888",
if !valid.Mobile("16700008888", "mobile").Ok { "13400008888",
t.Error("\"16700008888\" is a valid mobile phone number should be true") "13500008888",
"13600008888",
"13700008888",
"13800008888",
"13900008888",
"14700008888",
"15000008888",
"15100008888",
"15200008888",
"15800008888",
"15900008888",
"17200008888",
"17800008888",
"18200008888",
"18300008888",
"18400008888",
"18700008888",
"18800008888",
"19800008888",
}
for _, m := range validMobiles {
if !valid.Mobile(m, "mobile").Ok {
t.Error(m + " is a valid mobile phone number should be true")
}
} }
} }
...@@ -381,8 +405,8 @@ func TestValid(t *testing.T) { ...@@ -381,8 +405,8 @@ func TestValid(t *testing.T) {
if len(valid.Errors) != 1 { if len(valid.Errors) != 1 {
t.Fatalf("valid errors len should be 1 but got %d", len(valid.Errors)) t.Fatalf("valid errors len should be 1 but got %d", len(valid.Errors))
} }
if valid.Errors[0].Key != "Age.Range" { if valid.Errors[0].Key != "Age.Range." {
t.Errorf("Message key should be `Name.Match` but got %s", valid.Errors[0].Key) t.Errorf("Message key should be `Age.Range` but got %s", valid.Errors[0].Key)
} }
} }
......
...@@ -16,9 +16,11 @@ package validation ...@@ -16,9 +16,11 @@ package validation
import ( import (
"fmt" "fmt"
"github.com/astaxie/beego/logs"
"reflect" "reflect"
"regexp" "regexp"
"strings" "strings"
"sync"
"time" "time"
"unicode/utf8" "unicode/utf8"
) )
...@@ -57,6 +59,8 @@ var MessageTmpls = map[string]string{ ...@@ -57,6 +59,8 @@ var MessageTmpls = map[string]string{
"ZipCode": "Must be valid zipcode", "ZipCode": "Must be valid zipcode",
} }
var once sync.Once
// SetDefaultMessage set default messages // SetDefaultMessage set default messages
// if not set, the default messages are // if not set, the default messages are
// "Required": "Can not be empty", // "Required": "Can not be empty",
...@@ -84,9 +88,12 @@ func SetDefaultMessage(msg map[string]string) { ...@@ -84,9 +88,12 @@ func SetDefaultMessage(msg map[string]string) {
return return
} }
for name := range msg { once.Do(func() {
MessageTmpls[name] = msg[name] for name := range msg {
} MessageTmpls[name] = msg[name]
}
})
logs.Warn(`you must SetDefaultMessage at once`)
} }
// Validator interface // Validator interface
...@@ -632,7 +639,7 @@ func (b Base64) GetLimitValue() interface{} { ...@@ -632,7 +639,7 @@ func (b Base64) GetLimitValue() interface{} {
} }
// just for chinese mobile phone number // just for chinese mobile phone number
var mobilePattern = regexp.MustCompile(`^((\+86)|(86))?(1(([35][0-9])|[8][0-9]|[7][01356789]|[4][579]|[6][2567]))\d{8}$`) var mobilePattern = regexp.MustCompile(`^((\+86)|(86))?1([356789][0-9]|4[579]|6[67]|7[0135678]|9[189])[0-9]{8}$`)
// Mobile check struct // Mobile check struct
type Mobile struct { type Mobile struct {
......
Copyright (c) 2009 The Go Authors. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
* Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above
copyright notice, this list of conditions and the following disclaimer
in the documentation and/or other materials provided with the
distribution.
* Neither the name of Google Inc. nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
Additional IP Rights Grant (Patents)
"This implementation" means the copyrightable works distributed by
Google as part of the Go project.
Google hereby grants to You a perpetual, worldwide, non-exclusive,
no-charge, royalty-free, irrevocable (except as stated in this section)
patent license to make, have made, use, offer to sell, sell, import,
transfer and otherwise run, modify and propagate the contents of this
implementation of Go, where such license applies only to those patent
claims, both currently owned or controlled by Google and acquired in
the future, licensable by Google that are necessarily infringed by this
implementation of Go. This grant does not include claims that would be
infringed only as a consequence of further modification of this
implementation. If you or your agent or exclusive licensee institute or
order or agree to the institution of patent litigation against any
entity (including a cross-claim or counterclaim in a lawsuit) alleging
that this implementation of Go or any code incorporated within this
implementation of Go constitutes direct or contributory patent
infringement, or inducement of patent infringement, then any patent
rights granted to you under this License for this implementation of Go
shall terminate as of the date such litigation is filed.
此差异已折叠。
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package autocert
import (
"context"
"errors"
"io/ioutil"
"os"
"path/filepath"
)
// ErrCacheMiss is returned when a certificate is not found in cache.
var ErrCacheMiss = errors.New("acme/autocert: certificate cache miss")
// Cache is used by Manager to store and retrieve previously obtained certificates
// and other account data as opaque blobs.
//
// Cache implementations should not rely on the key naming pattern. Keys can
// include any printable ASCII characters, except the following: \/:*?"<>|
type Cache interface {
// Get returns a certificate data for the specified key.
// If there's no such key, Get returns ErrCacheMiss.
Get(ctx context.Context, key string) ([]byte, error)
// Put stores the data in the cache under the specified key.
// Underlying implementations may use any data storage format,
// as long as the reverse operation, Get, results in the original data.
Put(ctx context.Context, key string, data []byte) error
// Delete removes a certificate data from the cache under the specified key.
// If there's no such key in the cache, Delete returns nil.
Delete(ctx context.Context, key string) error
}
// DirCache implements Cache using a directory on the local filesystem.
// If the directory does not exist, it will be created with 0700 permissions.
type DirCache string
// Get reads a certificate data from the specified file name.
func (d DirCache) Get(ctx context.Context, name string) ([]byte, error) {
name = filepath.Join(string(d), name)
var (
data []byte
err error
done = make(chan struct{})
)
go func() {
data, err = ioutil.ReadFile(name)
close(done)
}()
select {
case <-ctx.Done():
return nil, ctx.Err()
case <-done:
}
if os.IsNotExist(err) {
return nil, ErrCacheMiss
}
return data, err
}
// Put writes the certificate data to the specified file name.
// The file will be created with 0600 permissions.
func (d DirCache) Put(ctx context.Context, name string, data []byte) error {
if err := os.MkdirAll(string(d), 0700); err != nil {
return err
}
done := make(chan struct{})
var err error
go func() {
defer close(done)
var tmp string
if tmp, err = d.writeTempFile(name, data); err != nil {
return
}
select {
case <-ctx.Done():
// Don't overwrite the file if the context was canceled.
default:
newName := filepath.Join(string(d), name)
err = os.Rename(tmp, newName)
}
}()
select {
case <-ctx.Done():
return ctx.Err()
case <-done:
}
return err
}
// Delete removes the specified file name.
func (d DirCache) Delete(ctx context.Context, name string) error {
name = filepath.Join(string(d), name)
var (
err error
done = make(chan struct{})
)
go func() {
err = os.Remove(name)
close(done)
}()
select {
case <-ctx.Done():
return ctx.Err()
case <-done:
}
if err != nil && !os.IsNotExist(err) {
return err
}
return nil
}
// writeTempFile writes b to a temporary file, closes the file and returns its path.
func (d DirCache) writeTempFile(prefix string, b []byte) (string, error) {
// TempFile uses 0600 permissions
f, err := ioutil.TempFile(string(d), prefix)
if err != nil {
return "", err
}
if _, err := f.Write(b); err != nil {
f.Close()
return "", err
}
return f.Name(), f.Close()
}
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package autocert
import (
"crypto/tls"
"log"
"net"
"os"
"path/filepath"
"runtime"
"time"
)
// NewListener returns a net.Listener that listens on the standard TLS
// port (443) on all interfaces and returns *tls.Conn connections with
// LetsEncrypt certificates for the provided domain or domains.
//
// It enables one-line HTTPS servers:
//
// log.Fatal(http.Serve(autocert.NewListener("example.com"), handler))
//
// NewListener is a convenience function for a common configuration.
// More complex or custom configurations can use the autocert.Manager
// type instead.
//
// Use of this function implies acceptance of the LetsEncrypt Terms of
// Service. If domains is not empty, the provided domains are passed
// to HostWhitelist. If domains is empty, the listener will do
// LetsEncrypt challenges for any requested domain, which is not
// recommended.
//
// Certificates are cached in a "golang-autocert" directory under an
// operating system-specific cache or temp directory. This may not
// be suitable for servers spanning multiple machines.
//
// The returned listener uses a *tls.Config that enables HTTP/2, and
// should only be used with servers that support HTTP/2.
//
// The returned Listener also enables TCP keep-alives on the accepted
// connections. The returned *tls.Conn are returned before their TLS
// handshake has completed.
func NewListener(domains ...string) net.Listener {
m := &Manager{
Prompt: AcceptTOS,
}
if len(domains) > 0 {
m.HostPolicy = HostWhitelist(domains...)
}
dir := cacheDir()
if err := os.MkdirAll(dir, 0700); err != nil {
log.Printf("warning: autocert.NewListener not using a cache: %v", err)
} else {
m.Cache = DirCache(dir)
}
return m.Listener()
}
// Listener listens on the standard TLS port (443) on all interfaces
// and returns a net.Listener returning *tls.Conn connections.
//
// The returned listener uses a *tls.Config that enables HTTP/2, and
// should only be used with servers that support HTTP/2.
//
// The returned Listener also enables TCP keep-alives on the accepted
// connections. The returned *tls.Conn are returned before their TLS
// handshake has completed.
//
// Unlike NewListener, it is the caller's responsibility to initialize
// the Manager m's Prompt, Cache, HostPolicy, and other desired options.
func (m *Manager) Listener() net.Listener {
ln := &listener{
m: m,
conf: m.TLSConfig(),
}
ln.tcpListener, ln.tcpListenErr = net.Listen("tcp", ":443")
return ln
}
type listener struct {
m *Manager
conf *tls.Config
tcpListener net.Listener
tcpListenErr error
}
func (ln *listener) Accept() (net.Conn, error) {
if ln.tcpListenErr != nil {
return nil, ln.tcpListenErr
}
conn, err := ln.tcpListener.Accept()
if err != nil {
return nil, err
}
tcpConn := conn.(*net.TCPConn)
// Because Listener is a convenience function, help out with
// this too. This is not possible for the caller to set once
// we return a *tcp.Conn wrapping an inaccessible net.Conn.
// If callers don't want this, they can do things the manual
// way and tweak as needed. But this is what net/http does
// itself, so copy that. If net/http changes, we can change
// here too.
tcpConn.SetKeepAlive(true)
tcpConn.SetKeepAlivePeriod(3 * time.Minute)
return tls.Server(tcpConn, ln.conf), nil
}
func (ln *listener) Addr() net.Addr {
if ln.tcpListener != nil {
return ln.tcpListener.Addr()
}
// net.Listen failed. Return something non-nil in case callers
// call Addr before Accept:
return &net.TCPAddr{IP: net.IP{0, 0, 0, 0}, Port: 443}
}
func (ln *listener) Close() error {
if ln.tcpListenErr != nil {
return ln.tcpListenErr
}
return ln.tcpListener.Close()
}
func homeDir() string {
if runtime.GOOS == "windows" {
return os.Getenv("HOMEDRIVE") + os.Getenv("HOMEPATH")
}
if h := os.Getenv("HOME"); h != "" {
return h
}
return "/"
}
func cacheDir() string {
const base = "golang-autocert"
switch runtime.GOOS {
case "darwin":
return filepath.Join(homeDir(), "Library", "Caches", base)
case "windows":
for _, ev := range []string{"APPDATA", "CSIDL_APPDATA", "TEMP", "TMP"} {
if v := os.Getenv(ev); v != "" {
return filepath.Join(v, base)
}
}
// Worst case:
return filepath.Join(homeDir(), base)
}
if xdg := os.Getenv("XDG_CACHE_HOME"); xdg != "" {
return filepath.Join(xdg, base)
}
return filepath.Join(homeDir(), ".cache", base)
}
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package autocert
import (
"context"
"crypto"
"sync"
"time"
)
// renewJitter is the maximum deviation from Manager.RenewBefore.
const renewJitter = time.Hour
// domainRenewal tracks the state used by the periodic timers
// renewing a single domain's cert.
type domainRenewal struct {
m *Manager
ck certKey
key crypto.Signer
timerMu sync.Mutex
timer *time.Timer
}
// start starts a cert renewal timer at the time
// defined by the certificate expiration time exp.
//
// If the timer is already started, calling start is a noop.
func (dr *domainRenewal) start(exp time.Time) {
dr.timerMu.Lock()
defer dr.timerMu.Unlock()
if dr.timer != nil {
return
}
dr.timer = time.AfterFunc(dr.next(exp), dr.renew)
}
// stop stops the cert renewal timer.
// If the timer is already stopped, calling stop is a noop.
func (dr *domainRenewal) stop() {
dr.timerMu.Lock()
defer dr.timerMu.Unlock()
if dr.timer == nil {
return
}
dr.timer.Stop()
dr.timer = nil
}
// renew is called periodically by a timer.
// The first renew call is kicked off by dr.start.
func (dr *domainRenewal) renew() {
dr.timerMu.Lock()
defer dr.timerMu.Unlock()
if dr.timer == nil {
return
}
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute)
defer cancel()
// TODO: rotate dr.key at some point?
next, err := dr.do(ctx)
if err != nil {
next = renewJitter / 2
next += time.Duration(pseudoRand.int63n(int64(next)))
}
dr.timer = time.AfterFunc(next, dr.renew)
testDidRenewLoop(next, err)
}
// updateState locks and replaces the relevant Manager.state item with the given
// state. It additionally updates dr.key with the given state's key.
func (dr *domainRenewal) updateState(state *certState) {
dr.m.stateMu.Lock()
defer dr.m.stateMu.Unlock()
dr.key = state.key
dr.m.state[dr.ck] = state
}
// do is similar to Manager.createCert but it doesn't lock a Manager.state item.
// Instead, it requests a new certificate independently and, upon success,
// replaces dr.m.state item with a new one and updates cache for the given domain.
//
// It may lock and update the Manager.state if the expiration date of the currently
// cached cert is far enough in the future.
//
// The returned value is a time interval after which the renewal should occur again.
func (dr *domainRenewal) do(ctx context.Context) (time.Duration, error) {
// a race is likely unavoidable in a distributed environment
// but we try nonetheless
if tlscert, err := dr.m.cacheGet(ctx, dr.ck); err == nil {
next := dr.next(tlscert.Leaf.NotAfter)
if next > dr.m.renewBefore()+renewJitter {
signer, ok := tlscert.PrivateKey.(crypto.Signer)
if ok {
state := &certState{
key: signer,
cert: tlscert.Certificate,
leaf: tlscert.Leaf,
}
dr.updateState(state)
return next, nil
}
}
}
der, leaf, err := dr.m.authorizedCert(ctx, dr.key, dr.ck)
if err != nil {
return 0, err
}
state := &certState{
key: dr.key,
cert: der,
leaf: leaf,
}
tlscert, err := state.tlscert()
if err != nil {
return 0, err
}
if err := dr.m.cachePut(ctx, dr.ck, tlscert); err != nil {
return 0, err
}
dr.updateState(state)
return dr.next(leaf.NotAfter), nil
}
func (dr *domainRenewal) next(expiry time.Time) time.Duration {
d := expiry.Sub(timeNow()) - dr.m.renewBefore()
// add a bit of randomness to renew deadline
n := pseudoRand.int63n(int64(renewJitter))
d -= time.Duration(n)
if d < 0 {
return 0
}
return d
}
var testDidRenewLoop = func(next time.Duration, err error) {}
此差异已折叠。
// Copyright 2015 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package acme
import (
"crypto"
"crypto/ecdsa"
"crypto/rand"
"crypto/rsa"
"crypto/sha256"
_ "crypto/sha512" // need for EC keys
"encoding/base64"
"encoding/json"
"fmt"
"math/big"
)
// jwsEncodeJSON signs claimset using provided key and a nonce.
// The result is serialized in JSON format.
// See https://tools.ietf.org/html/rfc7515#section-7.
func jwsEncodeJSON(claimset interface{}, key crypto.Signer, nonce string) ([]byte, error) {
jwk, err := jwkEncode(key.Public())
if err != nil {
return nil, err
}
alg, sha := jwsHasher(key)
if alg == "" || !sha.Available() {
return nil, ErrUnsupportedKey
}
phead := fmt.Sprintf(`{"alg":%q,"jwk":%s,"nonce":%q}`, alg, jwk, nonce)
phead = base64.RawURLEncoding.EncodeToString([]byte(phead))
cs, err := json.Marshal(claimset)
if err != nil {
return nil, err
}
payload := base64.RawURLEncoding.EncodeToString(cs)
hash := sha.New()
hash.Write([]byte(phead + "." + payload))
sig, err := jwsSign(key, sha, hash.Sum(nil))
if err != nil {
return nil, err
}
enc := struct {
Protected string `json:"protected"`
Payload string `json:"payload"`
Sig string `json:"signature"`
}{
Protected: phead,
Payload: payload,
Sig: base64.RawURLEncoding.EncodeToString(sig),
}
return json.Marshal(&enc)
}
// jwkEncode encodes public part of an RSA or ECDSA key into a JWK.
// The result is also suitable for creating a JWK thumbprint.
// https://tools.ietf.org/html/rfc7517
func jwkEncode(pub crypto.PublicKey) (string, error) {
switch pub := pub.(type) {
case *rsa.PublicKey:
// https://tools.ietf.org/html/rfc7518#section-6.3.1
n := pub.N
e := big.NewInt(int64(pub.E))
// Field order is important.
// See https://tools.ietf.org/html/rfc7638#section-3.3 for details.
return fmt.Sprintf(`{"e":"%s","kty":"RSA","n":"%s"}`,
base64.RawURLEncoding.EncodeToString(e.Bytes()),
base64.RawURLEncoding.EncodeToString(n.Bytes()),
), nil
case *ecdsa.PublicKey:
// https://tools.ietf.org/html/rfc7518#section-6.2.1
p := pub.Curve.Params()
n := p.BitSize / 8
if p.BitSize%8 != 0 {
n++
}
x := pub.X.Bytes()
if n > len(x) {
x = append(make([]byte, n-len(x)), x...)
}
y := pub.Y.Bytes()
if n > len(y) {
y = append(make([]byte, n-len(y)), y...)
}
// Field order is important.
// See https://tools.ietf.org/html/rfc7638#section-3.3 for details.
return fmt.Sprintf(`{"crv":"%s","kty":"EC","x":"%s","y":"%s"}`,
p.Name,
base64.RawURLEncoding.EncodeToString(x),
base64.RawURLEncoding.EncodeToString(y),
), nil
}
return "", ErrUnsupportedKey
}
// jwsSign signs the digest using the given key.
// It returns ErrUnsupportedKey if the key type is unknown.
// The hash is used only for RSA keys.
func jwsSign(key crypto.Signer, hash crypto.Hash, digest []byte) ([]byte, error) {
switch key := key.(type) {
case *rsa.PrivateKey:
return key.Sign(rand.Reader, digest, hash)
case *ecdsa.PrivateKey:
r, s, err := ecdsa.Sign(rand.Reader, key, digest)
if err != nil {
return nil, err
}
rb, sb := r.Bytes(), s.Bytes()
size := key.Params().BitSize / 8
if size%8 > 0 {
size++
}
sig := make([]byte, size*2)
copy(sig[size-len(rb):], rb)
copy(sig[size*2-len(sb):], sb)
return sig, nil
}
return nil, ErrUnsupportedKey
}
// jwsHasher indicates suitable JWS algorithm name and a hash function
// to use for signing a digest with the provided key.
// It returns ("", 0) if the key is not supported.
func jwsHasher(key crypto.Signer) (string, crypto.Hash) {
switch key := key.(type) {
case *rsa.PrivateKey:
return "RS256", crypto.SHA256
case *ecdsa.PrivateKey:
switch key.Params().Name {
case "P-256":
return "ES256", crypto.SHA256
case "P-384":
return "ES384", crypto.SHA384
case "P-521":
return "ES512", crypto.SHA512
}
}
return "", 0
}
// JWKThumbprint creates a JWK thumbprint out of pub
// as specified in https://tools.ietf.org/html/rfc7638.
func JWKThumbprint(pub crypto.PublicKey) (string, error) {
jwk, err := jwkEncode(pub)
if err != nil {
return "", err
}
b := sha256.Sum256([]byte(jwk))
return base64.RawURLEncoding.EncodeToString(b[:]), nil
}
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册