diff --git a/README.md b/README.md index 373c0fdffdd2fde9b298e0656b8842a0051233f3..ac35709aeac1d0d9412c7d75c497a8e348554208 100644 --- a/README.md +++ b/README.md @@ -22,7 +22,7 @@ GoKu API Gateway CE,支持OpenAPI与微服务管理,支持私有云部署, 6. **IP黑白名单**:支持全局IP白名单、也可自定义某个接口的IP白名单。 -7. **数据整形**:支持参数的转换与绑定,支持formdata、raw数据、json。 +7. **数据整形**:支持参数的转换与绑定,支持formdata、raw数据。 8. **配置文件**:支持配置文件修改网关配置。 @@ -32,15 +32,13 @@ GoKu API Gateway CE,支持OpenAPI与微服务管理,支持私有云部署, 1. **UI界面**:支持通过UI界面修改网关配置。 -2. **API支持**:支持通过API对网关进行操作。 +2. **兼容eoLinker-AMS**:可与国内最大的接口管理平台打通。 -3. **兼容eoLinker-AMS**:可与国内最大的接口管理平台打通。 +3. **支持Restful**:支持rest路由。 -4. **支持Restful**:支持rest路由。 +4. **告警设置**:当系统达到预设告警条件时,邮件通知运维人员。 -6. **告警设置**:当系统达到预设告警条件时,邮件通知运维人员。 - -7. **超时设置**:配置访问超时时间,网关控制超时后立即返回,防止系统雪崩。 +5. **超时设置**:配置访问超时时间,网关控制超时后立即返回,防止系统雪崩。 **……** @@ -73,10 +71,20 @@ GoKu API Gateway CE,支持OpenAPI与微服务管理,支持私有云部署, ## 更新日志 +#### V2.0.3(2018/5/14) +新增: + +1. 支持form-data格式下文件传输; +2. 接口配置文件弃用proxy_body_type、proxy_body_desc字段,启用is_raw字段用于支持raw数据转发。 + +优化: + +1. 基于HttpRouter优化路由转发性能。 + #### V2.0.2(2018/5/7) 修复: -1. 修复请求路径带参数时,匹配路径失败; +1. 修复请求路径带query参数时,路径匹配失败的问题; 2. 修复proxy_method配置必须大写的问题,现支持不区分大小写。 #### V2.0.1(2018/5/4) diff --git a/release/goku_ce_2.0.1.zip b/release/goku_ce_2.0.1.zip deleted file mode 100644 index bf94fcb958d32b104415e8d697fa27b666a7009d..0000000000000000000000000000000000000000 Binary files a/release/goku_ce_2.0.1.zip and /dev/null differ diff --git a/release/goku_ce_2.0.2.zip b/release/goku_ce_2.0.2.zip deleted file mode 100644 index c72575954104a6d9a0189e8b054ee5b7d9bfaaf5..0000000000000000000000000000000000000000 Binary files a/release/goku_ce_2.0.2.zip and /dev/null differ diff --git a/release/goku_ce_2.0.3.zip b/release/goku_ce_2.0.3.zip new file mode 100644 index 0000000000000000000000000000000000000000..c38d78608d2e33acecbf6f74b1b6dd9c774dcca6 Binary files /dev/null and b/release/goku_ce_2.0.3.zip differ diff --git a/source_code/conf/conf.go b/source_code/conf/conf.go index 00c782f7ec4bf6046db84667bfc1650e9a0848a2..6d04cd01a3ad4071f7284705d55227ba5bb94460 100644 --- a/source_code/conf/conf.go +++ b/source_code/conf/conf.go @@ -33,6 +33,9 @@ type GatewayInfo struct { StrategyList Strategy ApiList Api BackendList Backend + UpdateTime string `json:"update_time" yaml:"update_time"` + CreateTime string `json:"create_time" yaml:"create_time"` + GroupList ApiGroup } type Strategy struct { @@ -50,6 +53,8 @@ type StrategyInfo struct { IPWhiteList []string `json:"ip_white_list" yaml:"ip_white_list"` IPBlackList []string `json:"ip_black_list" yaml:"ip_black_list"` RateLimitList []RateLimitInfo `json:"rate_limit_list" yaml:"rate_limit_list"` + UpdateTime string `json:"update_time" yaml:"update_time"` + CreateTime string `json:"create_time" yaml:"create_time"` } type RateLimitInfo struct { @@ -67,7 +72,7 @@ type ApiGroupInfo struct { } type ApiGroup struct { - Group ApiGroupInfo `json:"group" yaml:"group"` + Group []ApiGroupInfo `json:"group" yaml:"group"` } type Api struct { @@ -82,8 +87,7 @@ type ApiInfo struct { BackendID int `json:"backend_id" yaml:"backend_id"` ProxyURL string `json:"proxy_url" yaml:"proxy_url"` ProxyMethod string `json:"proxy_method" yaml:"proxy_method"` - ProxyBodyType string `json:"proxy_body_type" yaml:"proxy_body_type"` - ProxyBody string `json:"proxy_body" yaml:"proxy_body"` + IsRaw bool `json:"is_raw" yaml:"is_raw"` ProxyParams []Param `json:"proxy_params" yaml:"proxy_params"` ConstantParams []ConstantParam `json:"constant_params" yaml:"constant_params"` } diff --git a/source_code/conf/parse_config.go b/source_code/conf/parse_config.go index 1387751f6598b622dfbcae56e6c7cb1524b6ce87..cba029be11edfab41b5d7d06531f73cd91d61766 100644 --- a/source_code/conf/parse_config.go +++ b/source_code/conf/parse_config.go @@ -13,11 +13,11 @@ func ParseConfInfo() GlobalConfig { var g GlobalConfig err := yaml.Unmarshal([]byte(Configure),&g) if err != nil { - panic("Error global config!") + panic("Global Config Error!") } path,err := GetDir(g.GatewayConfPath) if err != nil { - panic("Error gateway config path!") + panic("Gateway Config Path Error!") } fmt.Println(path) gatewayList := getGatewayList(path) @@ -33,11 +33,11 @@ func getGatewayList(path []string) []GatewayInfo { c,err := ioutil.ReadFile(p + PthSep + "gateway.conf") if err != nil { - panic("Error gateway config path! Error path: " + p) + panic("Gateway Config Path Error! Error path: " + p) } err = yaml.Unmarshal(c,&gateway) if err != nil { - panic("Error gateway config! Error path: " + p) + panic("Gateway Config Error! Error path: " + p) } if gateway.GatewayStatus != "on" { continue @@ -54,11 +54,11 @@ func getStrategyList(path string) Strategy { var strategy Strategy c,err := ioutil.ReadFile(path) if err != nil { - panic("Error strategy config path! Error path: " + path) + panic("Strategy Config Path Error! Error path: " + path) } err = yaml.Unmarshal(c,&strategy) if err != nil { - panic("Error strategy config! Error path: " + path) + panic("Strategy Config Error! Error path: " + path) } return strategy } @@ -67,11 +67,11 @@ func getApiList(path string) Api { var api Api c,err := ioutil.ReadFile(path) if err != nil { - panic("Error api config path! Error path: " + path) + panic("Api Config Path Error! Error path: " + path) } err = yaml.Unmarshal(c,&api) if err != nil { - panic("Error api config! Error path: " + path) + panic("Api Config Error! Error path: " + path) } return api } @@ -80,7 +80,7 @@ func getBackendList(path string) Backend { var backend Backend c,err := ioutil.ReadFile(path) if err != nil { - panic("Error backend config path! Error path: " + path) + panic("Backend Config Path Error! Error path: " + path) } err = yaml.Unmarshal(c,&backend) if err != nil { diff --git a/source_code/config/gateway/test/api.conf b/source_code/config/gateway/test/api.conf index 7a91f6a620b9ae1d86b044ce1635ecf4d625c92b..7fea9a577bb29086e2c79b94e11ed8aa04de0315 100644 --- a/source_code/config/gateway/test/api.conf +++ b/source_code/config/gateway/test/api.conf @@ -1,6 +1,6 @@ apis: # api列表 - api_name: 全国油价 # 接口名称 - group_id: 0 # 接口所属分组,0代表无分组 + group_id: 0 # 接口所属分组 backend_id: 1 # 后端服务 request_url: /common/oil/getOilPriceToday # 网关请求路径 request_method: # 请求方法:get/post/put/delete/options/patch/head @@ -8,8 +8,7 @@ apis: # api列表 - post proxy_url: /common/oil/getOilPriceToday # 后端请求路径 proxy_method: post # 非数组,后端请求方法:get/post/put/delete/options/patch/head - proxy_body_type: formdata # body类型:formdata/raw/json - proxy_body: "" # raw内容 + is_raw: false # 请求body是否为raw数据:true/false proxy_params: # 请求参数映射列表 - key: province # 网关接口参数名 key_position: body # 网关参数位置:query/body/header diff --git a/source_code/config/gateway/test/backend.conf b/source_code/config/gateway/test/backend.conf index ad6d2ccd12c607dce745940a529d5e7cea50102b..519714009b42faaca224d18ca45450a94f493ffb 100644 --- a/source_code/config/gateway/test/backend.conf +++ b/source_code/config/gateway/test/backend.conf @@ -1,4 +1,4 @@ backend: # 后端服务列表 -- backend_id: 1 # 后端服务id,0代表无API分组 +- backend_id: 1 # 后端服务id backend_name: 测试后端 # 后端服务名称 backend_path: https://api.apishop.net # 后端服务地址 \ No newline at end of file diff --git a/source_code/goku-ce.go b/source_code/goku-ce.go index 96e9af573ff74e57862894ef37200a23f5651607..84894c634c312ad7485014a4fd62c289fbe8e604 100644 --- a/source_code/goku-ce.go +++ b/source_code/goku-ce.go @@ -8,10 +8,9 @@ import ( func main() { server := goku.New() - server.Use(middleware.Mapping) + server.RegisterRouter(server.ServiceConfig,middleware.Mapping) server.Listen() server.Run() } - diff --git a/source_code/goku/context.go b/source_code/goku/context.go new file mode 100644 index 0000000000000000000000000000000000000000..0b564c9df911a2233c1cb37e75b2e27d3be6d90e --- /dev/null +++ b/source_code/goku/context.go @@ -0,0 +1,41 @@ +package goku + +import ( + "goku-ce/conf" +) +type Context struct { + GatewayInfo Gateway + StrategyInfo Strategy + ApiInfo Api + Rate map[string]Rate +} + +type Gateway struct { + GatewayAlias string `json:"gateway_alias" yaml:"gateway_alias"` + GatewayStatus string `json:"gateway_status" yaml:"gateway_status"` + IPLimitType string `json:"ip_limit_type" yaml:"ip_limit_type"` + IPWhiteList []string `json:"ip_white_list" yaml:"ip_white_list"` + IPBlackList []string `json:"ip_black_list" yaml:"ip_black_list"` +} + +type Strategy struct { + StrategyID string `json:"strategy_id" yaml:"strategy_id"` + Auth string `json:"auth" yaml:"auth"` + BasicUserName string `json:"basic_user_name" yaml:"basic_user_name"` + BasicUserPassword string `json:"basic_user_password" yaml:"basic_user_password"` + ApiKey string `json:"api_key" yaml:"api_key"` + IPLimitType string `json:"ip_limit_type" yaml:"ip_limit_type"` + IPWhiteList []string `json:"ip_white_list" yaml:"ip_white_list"` + IPBlackList []string `json:"ip_black_list" yaml:"ip_black_list"` + RateLimitList []conf.RateLimitInfo `json:"rate_limit_list" yaml:"rate_limit_list"` +} + +type Api struct { + RequestURL string `json:"request_url" yaml:"request_url"` + BackendPath string `json:"backend_path" yaml:"backend_path"` + ProxyURL string `json:"proxy_url" yaml:"proxy_url"` + ProxyMethod string `json:"proxy_method" yaml:"proxy_method"` + IsRaw bool `json:"is_raw" yaml:"is_raw"` + ProxyParams []conf.Param `json:"proxy_params" yaml:"proxy_params"` + ConstantParams []conf.ConstantParam `json:"constant_params" yaml:"constant_params"` +} diff --git a/source_code/goku/goku.go b/source_code/goku/goku.go index da9b77113717839f81a8c1092dd3a1ce12e20c62..bc3c3bd68528bed1bfda43d37540cb485c792b98 100644 --- a/source_code/goku/goku.go +++ b/source_code/goku/goku.go @@ -8,55 +8,44 @@ import ( "net/http" "log" "os" - "sync/atomic" - "time" ) type GokuServer interface{ Run() error - Use(handler ...Handler) Address() string Listener() net.Listener Listen() error } -type Handler interface{} - -type Injector interface { - Get(reflect.Type) reflect.Value - Map(interface{}) Injector -} - type classicGoku struct { *Goku } + type Goku struct{ - handlers []Handler + *Router index int ServiceConfig conf.GlobalConfig logger *log.Logger listener *net.Listener address string values map[reflect.Type]reflect.Value - parent Injector cClose chan bool isStopping bool activeCount int32 - Rate map[string]Rate } -func (i *Goku) Map(val interface{}) Injector { - i.values[reflect.TypeOf(val)] = reflect.ValueOf(val) - return i -} // 启动一个Goku实例 -func New() GokuServer{ - g := &Goku{values: make(map[reflect.Type]reflect.Value),logger:log.New(os.Stdout, "[Goku]", 0),Rate:make(map[string]Rate)} +func New() *Goku{ + g := &Goku{ + Router:NewRouter(), + values: make(map[reflect.Type]reflect.Value), + logger:log.New(os.Stdout, "[Goku]", 0), + } + g.ServiceConfig = conf.ParseConfInfo() - g.Map(g) - return &classicGoku{g} + return g } func (g *Goku) Run() error{ @@ -106,100 +95,3 @@ func (g *Goku) Listener() net.Listener{ return *g.listener } -func (g *Goku) run() { - for g.index < len(g.handlers) { - handle := g.handlers[g.index] - _, err := g.Invoke(handle) - if err != nil { - panic(err) - } - g.index += 1 - - } - g.index = 0 - return -} - - - -func (g *Goku) Use(handler ...Handler) { - for _,h := range handler{ - ValidateHandler(h) - g.handlers = append(g.handlers,h) - } -} - -func (i *Goku) Get(t reflect.Type) reflect.Value { - val := i.values[t] - - if val.IsValid() { - return val - } - - if t.Kind() == reflect.Interface { - for k, v := range i.values { - if k.Implements(t) { - val = v - break - } - } - } - - // Still no type found, try to look it up on the parent - if !val.IsValid() && i.parent != nil { - val = i.parent.Get(t) - } - - return val - -} - -// 调用函数 -func (g *Goku) Invoke(handler Handler) ([]reflect.Value,error) { - ValidateHandler(handler) - t := reflect.TypeOf(handler) - var in = make([]reflect.Value, t.NumIn()) - for i := 0; i < t.NumIn(); i++ { - argType := t.In(i) - val := g.Get(argType) - if !val.IsValid() { - return nil, fmt.Errorf("Value not found for type %v", argType) - } - - in[i] = val - } - return reflect.ValueOf(handler).Call(in),nil -} - -func (c *Goku) IsStopped() bool { - return !(c.index < len(c.handlers)) -} - - -func (g *Goku) ServeHTTP(res http.ResponseWriter, req *http.Request) { - if(req.RequestURI == "/favicon.ico"){ - return - } - g.Map(res) - g.Map(req) - activeCount := atomic.AddInt32(&g.activeCount, -1) - if g.isStopping && activeCount == 0 { - time.Sleep(1) - g.cClose <- true - } - g.run() -} - -func InterfaceOf(value interface{}) reflect.Type { - t := reflect.TypeOf(value) - - for t.Kind() == reflect.Ptr { - t = t.Elem() - } - - if t.Kind() != reflect.Interface { - panic("Called inject.InterfaceOf with a value that is not a pointer to an interface. (*MyInterface)(nil)") - } - - return t -} diff --git a/source_code/goku/path.go b/source_code/goku/path.go new file mode 100644 index 0000000000000000000000000000000000000000..a070c0b5c0580c6da18ada86125f8ca8d328447d --- /dev/null +++ b/source_code/goku/path.go @@ -0,0 +1,87 @@ +package goku + +func CleanPath(p string) string { + if p == "" { + return "/" + } + + n := len(p) + var buf []byte + + r := 1 + w := 1 + + if p[0] != '/' { + r = 0 + buf = make([]byte, n+1) + buf[0] = '/' + } + + trailing := n > 1 && p[n-1] == '/' + + + for r < n { + switch { + case p[r] == '/': + r++ + + case p[r] == '.' && r+1 == n: + trailing = true + r++ + + case p[r] == '.' && p[r+1] == '/': + r++ + + case p[r] == '.' && p[r+1] == '.' && (r+2 == n || p[r+2] == '/'): + r += 2 + + if w > 1 { + w-- + + if buf == nil { + for w > 1 && p[w] != '/' { + w-- + } + } else { + for w > 1 && buf[w] != '/' { + w-- + } + } + } + + default: + if w > 1 { + bufApp(&buf, p, w, '/') + w++ + } + + for r < n && p[r] != '/' { + bufApp(&buf, p, w, p[r]) + w++ + r++ + } + } + } + + if trailing && w > 1 { + bufApp(&buf, p, w, '/') + w++ + } + + if buf == nil { + return p[:w] + } + return string(buf[:w]) +} + +func bufApp(buf *[]byte, s string, w int, c byte) { + if *buf == nil { + if s[w] == c { + return + } + + *buf = make([]byte, len(s)) + copy(*buf, s[:w]) + } + (*buf)[w] = c +} diff --git a/source_code/goku/router.go b/source_code/goku/router.go new file mode 100644 index 0000000000000000000000000000000000000000..d19afb5770dff8f3636c55f232331b3b2931af69 --- /dev/null +++ b/source_code/goku/router.go @@ -0,0 +1,305 @@ +package goku + +import ( + "net/http" + "goku-ce/conf" + "strings" +) + +// Handle是一个可以被注册到路由中去处理http请求,类似于HandlerFunc,但是有第三个参数值 +type Handle func(http.ResponseWriter, *http.Request, Params,*Context) +// Param is a single URL parameter, consisting of a key and a value. +type Param struct { + Key string + Value string +} + +// Params是一个参数切片,作为路由的返回结果,这个切片是有序的 +// 第一个URL参数会作为第一个切片值,因此通过索引来读值是安全的 +type Params []Param + +// ByName returns the value of the first Param which key matches the given name. +// If no matching Param is found, an empty string is returned. +func (ps Params) ByName(name string) string { + for i := range ps { + if ps[i].Key == name { + return ps[i].Value + } + } + return "" +} + +// Router是一个可以被用来调度请求去不同处理函数的Handler +type Router struct { + trees map[string]*node + + context *Context + + handle Handle + + RedirectTrailingSlash bool + + RedirectFixedPath bool + + HandleMethodNotAllowed bool + + HandleOPTIONS bool + + NotFound http.Handler + + MethodNotAllowed http.Handler + + PanicHandler func(http.ResponseWriter, *http.Request, interface{}) +} + +func NewRouter() *Router { + return &Router{ + RedirectTrailingSlash: true, + RedirectFixedPath: true, + HandleMethodNotAllowed: true, + HandleOPTIONS: true, + } +} + + +func (r *Router) Use(handle Handle) { + r.handle = handle +} + + +func (r *Router) Handle(method, path string, handle Handle,context Context) { + if path[0] != '/' { + panic("path must begin with '/' in path '" + path + "'") + } + + if r.trees == nil { + r.trees = make(map[string]*node) + } + + root := r.trees[method] + if root == nil { + root = new(node) + r.trees[method] = root + } + root.addRoute(path, handle,context) +} + +// // HandlerFunc 是一个适配器允许使用http.HandleFunc函数作为一个请求处理器 +// func (r *Router) HandlerFunc(method, path string, handler http.HandlerFunc) { +// r.Handler(method, path, handler) +// } + +func (r *Router) recv(w http.ResponseWriter, req *http.Request) { + if rcv := recover(); rcv != nil { + r.PanicHandler(w, req, rcv) + } +} +// 查找允许手动查找方法 + 路径组合。 +// 这对于构建围绕此路由器的框架非常有用。 +// 如果找到路径, 它将返回句柄函数和路径参数值 +// 否则, 第三个返回值指示是否应执行与附加/不带尾随斜线的同一路径的重定向 +func (r *Router) Lookup(method, path string) (Handle, Params, *Context, bool) { + if root := r.trees[method]; root != nil { + return root.getValue(path) + } + return nil, nil,&Context{}, false +} + +func (r *Router) allowed(path, reqMethod string) (allow string) { + if path == "*" { // server-wide + for method := range r.trees { + if method == "OPTIONS" { + continue + } + + // 将请求方法添加到允许的方法列表中 + if len(allow) == 0 { + allow = method + } else { + allow += ", " + method + } + + } + } else { // 特定路径 + for method := range r.trees { + // 跳过请求的方法-我们已经尝试过这一项 + if method == reqMethod || method == "OPTIONS" { + continue + } + + handle, _, _,_ := r.trees[method].getValue(path) + if handle != nil { + // 将请求方法添加到允许的方法列表中 + if len(allow) == 0 { + allow = method + } else { + allow += ", " + method + } + } + } + } + if len(allow) > 0 { + allow += ", OPTIONS" + } + return +} + +// ServeHTTP使用路由实现http.Handler接口 +func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) { + if r.PanicHandler != nil { + + defer r.recv(w, req) + } + // now := time.Now() + path := req.URL.Path + pathArray := strings.Split(path,"/") + + if len(pathArray) == 2 { + w.WriteHeader(500) + if pathArray[1] == "" { + w.Write([]byte("Missing Gateway Alias")) + } else { + w.Write([]byte("Missing StrategyID")) + } + return + } else if len(pathArray) == 3 { + w.WriteHeader(500) + if pathArray[2] == "" { + w.Write([]byte("Missing StrategyID")) + } else { + w.Write([]byte("Invalid URI")) + } + return + } + + if root := r.trees[req.Method]; root != nil { + handle, ps, context,tsr := root.getValue(path); + if handle != nil { + handle(w, req, ps,context) + return + } else if req.Method != "CONNECT" && path != "/" { + code := 301 + if req.Method != "GET" { + code = 307 + } + + if tsr && r.RedirectTrailingSlash { + if len(path) > 1 && path[len(path)-1] == '/' { + req.URL.Path = path[:len(path)-1] + } else { + req.URL.Path = path + "/" + } + http.Redirect(w, req, req.URL.String(), code) + return + } + // 尝试修复请求路径 + if r.RedirectFixedPath { + fixedPath, found := root.findCaseInsensitivePath( + CleanPath(path), + r.RedirectTrailingSlash, + ) + if found { + req.URL.Path = string(fixedPath) + http.Redirect(w, req, req.URL.String(), code) + return + } + } + } + } + + if req.Method == "OPTIONS" && r.HandleOPTIONS { + // Handle OPTIONS requests + if allow := r.allowed(path, req.Method); len(allow) > 0 { + w.Header().Set("Allow", allow) + return + } + } else { + // Handle 405 + if r.HandleMethodNotAllowed { + if allow := r.allowed(path, req.Method); len(allow) > 0 { + w.Header().Set("Allow", allow) + if r.MethodNotAllowed != nil { + r.MethodNotAllowed.ServeHTTP(w, req) + } else { + http.Error(w, + http.StatusText(http.StatusMethodNotAllowed), + http.StatusMethodNotAllowed, + ) + } + return + } + } + } + + // Handle 404 + if r.NotFound != nil { + r.NotFound.ServeHTTP(w, req) + } else { + w.WriteHeader(404) + w.Write([]byte("Invalid URI")) + } +} + +// 注册路由 +func (r *Router) RegisterRouter(c conf.GlobalConfig,handle Handle) { + r.handle = handle + for _, g := range c.GatewayList { + if g.GatewayStatus != "on" { + continue + } + gateway := Gateway{ + GatewayAlias: g.GatewayAlias, + GatewayStatus: g.GatewayStatus, + IPLimitType: g.IPLimitType, + IPWhiteList: g.IPWhiteList, + IPBlackList: g.IPBlackList, + } + for _, s := range g.StrategyList.Strategy { + strategy := Strategy{ + StrategyID: s.StrategyID, + Auth: s.Auth, + ApiKey: s.ApiKey, + BasicUserName: s.BasicUserName, + BasicUserPassword: s.BasicUserPassword, + IPLimitType: s.IPLimitType, + IPWhiteList: s.IPWhiteList, + IPBlackList: s.IPBlackList, + RateLimitList:s.RateLimitList, + } + for _, api := range g.ApiList.Apis { + path := "/" + g.GatewayAlias + "/" + s.StrategyID + api.RequestURL + backendPath := "" + flag := false + // 获取后端请求路径 + for _,b := range g.BackendList.Backend { + if b.BackendID == api.BackendID{ + backendPath = b.BackendPath + flag =true + break + } + } + if !flag && api.BackendID != -1{ + continue + } + apiInfo := Api{ + RequestURL: api.RequestURL, + BackendPath: backendPath, + ProxyURL: api.ProxyURL, + IsRaw:api.IsRaw, + ProxyMethod:api.ProxyMethod, + ProxyParams:api.ProxyParams, + ConstantParams:api.ConstantParams, + } + context := Context{ + GatewayInfo:gateway, + StrategyInfo:strategy, + ApiInfo:apiInfo, + Rate:make(map[string]Rate), + } + for _,method := range api.RequestMethod { + r.Handle(strings.ToUpper(method),path,r.handle,context) + } + } + } + } +} diff --git a/source_code/goku/tree.go b/source_code/goku/tree.go new file mode 100644 index 0000000000000000000000000000000000000000..e228bf1079d47d89a712f68c27a2e17d69a80002 --- /dev/null +++ b/source_code/goku/tree.go @@ -0,0 +1,619 @@ +package goku + +import ( + "strings" + "unicode" + "unicode/utf8" + "fmt" +) + +func min(a, b int) int { + if a <= b { + return a + } + return b +} + + +// 计算路径中参数数量 +func countParams(path string) uint8 { + var n uint + for i := 0; i < len(path); i++ { + if path[i] != ':' && path[i] != '*' { + continue + } + n++ + } + if n >= 255 { + return 255 + } + return uint8(n) +} + +type nodeType uint8 + +const ( + static nodeType = iota // default + root + param + catchAll +) + +type node struct { + path string + wildChild bool + nType nodeType + maxParams uint8 + indices string + children []*node + handle Handle + priority uint32 + context *Context +} + + +// 在必要时给定子项和排序的优先级 +func (n *node) incrementChildPrio(pos int) int { + n.children[pos].priority++ + prio := n.children[pos].priority + + newPos := pos + for newPos > 0 && n.children[newPos-1].priority < prio { + n.children[newPos-1], n.children[newPos] = n.children[newPos], n.children[newPos-1] + + newPos-- + } + + if newPos != pos { + n.indices = n.indices[:newPos] + // unchanged prefix, might be empty + n.indices[pos:pos+1] + // the index char we move + n.indices[newPos:pos] + n.indices[pos+1:] // rest without char at 'pos' + } + + return newPos +} + +// addRoute将具有给定句柄的节点添加到路径中 +func (n *node) addRoute(path string, handle Handle,context Context) { + fullPath := path + n.priority++ + numParams := countParams(path) + + if len(n.path) > 0 || len(n.children) > 0 { + walk: + for { + // 更新当前节点的 maxParams + if numParams > n.maxParams { + n.maxParams = numParams + } + + i := 0 + max := min(len(path), len(n.path)) + for i < max && path[i] == n.path[i] { + i++ + } + + if i < len(n.path) { + child := node{ + path: n.path[i:], + wildChild: n.wildChild, + nType: static, + indices: n.indices, + children: n.children, + handle: n.handle, + context: n.context, + priority: n.priority - 1, + } + + for i := range child.children { + if child.children[i].maxParams > child.maxParams { + child.maxParams = child.children[i].maxParams + } + } + + n.children = []*node{&child} + n.indices = string([]byte{n.path[i]}) + n.path = path[:i] + n.handle = nil + n.context = &Context{} + n.wildChild = false + } + + if i < len(path) { + path = path[i:] + + if n.wildChild { + n = n.children[0] + n.priority++ + + if numParams > n.maxParams { + n.maxParams = numParams + } + numParams-- + + if len(path) >= len(n.path) && n.path == path[:len(n.path)] && + (len(n.path) >= len(path) || path[len(n.path)] == '/') { + continue walk + } else { + var pathSeg string + if n.nType == catchAll { + pathSeg = path + } else { + pathSeg = strings.SplitN(path, "/", 2)[0] + } + prefix := fullPath[:strings.Index(fullPath, pathSeg)] + n.path + panic("'" + pathSeg + + "' in new path '" + fullPath + + "' conflicts with existing wildcard '" + n.path + + "' in existing prefix '" + prefix + + "'") + } + } + + c := path[0] + + if n.nType == param && c == '/' && len(n.children) == 1 { + n = n.children[0] + n.priority++ + continue walk + } + + for i := 0; i < len(n.indices); i++ { + if c == n.indices[i] { + i = n.incrementChildPrio(i) + n = n.children[i] + continue walk + } + } + + if c != ':' && c != '*' { + n.indices += string([]byte{c}) + child := &node{ + maxParams: numParams, + } + n.children = append(n.children, child) + n.incrementChildPrio(len(n.indices) - 1) + n = child + } + n.insertChild(numParams, path, fullPath, handle,context) + return + + } else if i == len(path) { + if n.handle != nil { + panic("a handle is already registered for path '" + fullPath + "'") + } + n.handle = handle + n.context = &context + } + return + } + } else { + n.insertChild(numParams, path, fullPath, handle,context) + n.nType = root + } +} + +func (n *node) insertChild(numParams uint8, path, fullPath string, handle Handle,context Context) { + var offset int // already handled bytes of the path + + for i, max := 0, len(path); numParams > 0; i++ { + c := path[i] + if c != ':' && c != '*' { + continue + } + + end := i + 1 + for end < max && path[end] != '/' { + switch path[end] { + case ':', '*': + panic("only one wildcard per path segment is allowed, has: '" + + path[i:] + "' in path '" + fullPath + "'") + default: + end++ + } + } + + if len(n.children) > 0 { + panic("wildcard route '" + path[i:end] + + "' conflicts with existing children in path '" + fullPath + "'") + } + + if end-i < 2 { + panic("wildcards must be named with a non-empty name in path '" + fullPath + "'") + } + + if c == ':' { // param + if i > 0 { + n.path = path[offset:i] + offset = i + } + + child := &node{ + nType: param, + maxParams: numParams, + } + n.children = []*node{child} + n.wildChild = true + n = child + n.priority++ + numParams-- + + if end < max { + n.path = path[offset:end] + offset = end + + child := &node{ + maxParams: numParams, + priority: 1, + } + n.children = []*node{child} + n = child + } + + } else { // catchAll + if end != max || numParams > 1 { + panic("catch-all routes are only allowed at the end of the path in path '" + fullPath + "'") + } + + if len(n.path) > 0 && n.path[len(n.path)-1] == '/' { + panic("catch-all conflicts with existing handle for the path segment root in path '" + fullPath + "'") + } + + i-- + if path[i] != '/' { + panic("no / before catch-all in path '" + fullPath + "'") + } + + n.path = path[offset:i] + + child := &node{ + wildChild: true, + nType: catchAll, + maxParams: 1, + } + n.children = []*node{child} + n.indices = string(path[i]) + n = child + n.priority++ + + child = &node{ + path: path[i:], + nType: catchAll, + maxParams: 1, + handle: handle, + context: &context, + priority: 1, + } + n.children = []*node{child} + + return + } + } + + n.path = path[offset:] + n.handle = handle + n.context = &context +} + +func (n *node) getValue(path string) (handle Handle, p Params, context *Context, tsr bool) { +walk: // outer loop for walking the tree + for { + if len(path) > len(n.path) { + if path[:len(n.path)] == n.path { + path = path[len(n.path):] + if !n.wildChild { + c := path[0] + for i := 0; i < len(n.indices); i++ { + if c == n.indices[i] { + n = n.children[i] + continue walk + } + } + tsr = (path == "/" && n.handle != nil) + fmt.Println(123) + return + } + + // handle wildcard child + n = n.children[0] + switch n.nType { + case param: + // find param end (either '/' or path end) + end := 0 + for end < len(path) && path[end] != '/' { + end++ + } + + // save param value + if p == nil { + // lazy allocation + p = make(Params, 0, n.maxParams) + } + i := len(p) + p = p[:i+1] // expand slice within preallocated capacity + p[i].Key = n.path[1:] + p[i].Value = path[:end] + + // we need to go deeper! + if end < len(path) { + if len(n.children) > 0 { + path = path[end:] + n = n.children[0] + continue walk + } + + // ... but we can't + tsr = (len(path) == end+1) + fmt.Println(456) + return + } + context = n.context + if handle = n.handle; handle != nil { + return + } else if len(n.children) == 1 { + // No handle found. Check if a handle for this path + a + // trailing slash exists for TSR recommendation + n = n.children[0] + tsr = (n.path == "/" && n.handle != nil) + } + + return + + case catchAll: + // save param value + if p == nil { + // lazy allocation + p = make(Params, 0, n.maxParams) + } + i := len(p) + p = p[:i+1] // expand slice within preallocated capacity + p[i].Key = n.path[2:] + p[i].Value = path + + handle = n.handle + context = n.context + return + + default: + panic("invalid node type") + } + } + } else if path == n.path { + // We should have reached the node containing the handle. + // Check if this node has a handle registered. + context = n.context + if handle = n.handle; handle != nil { + return + } + + if path == "/" && n.wildChild && n.nType != root { + tsr = true + return + } + + // No handle found. Check if a handle for this path + a + // trailing slash exists for trailing slash recommendation + for i := 0; i < len(n.indices); i++ { + if n.indices[i] == '/' { + n = n.children[i] + tsr = (len(n.path) == 1 && n.handle != nil) || + (n.nType == catchAll && n.children[0].handle != nil) + fmt.Println(541) + return + } + } + return + } + // Nothing found. We can recommend to redirect to the same URL with an + // extra trailing slash if a leaf exists for that path + tsr = (path == "/") || + (len(n.path) == len(path)+1 && n.path[len(path)] == '/' && + path == n.path[:len(n.path)-1] && n.handle != nil) + return + } +} + +// Makes a case-insensitive lookup of the given path and tries to find a handler. +// It can optionally also fix trailing slashes. +// It returns the case-corrected path and a bool indicating whether the lookup +// was successful. +func (n *node) findCaseInsensitivePath(path string, fixTrailingSlash bool) (ciPath []byte, found bool) { + return n.findCaseInsensitivePathRec( + path, + strings.ToLower(path), + make([]byte, 0, len(path)+1), // preallocate enough memory for new path + [4]byte{}, // empty rune buffer + fixTrailingSlash, + ) +} + +// shift bytes in array by n bytes left +func shiftNRuneBytes(rb [4]byte, n int) [4]byte { + switch n { + case 0: + return rb + case 1: + return [4]byte{rb[1], rb[2], rb[3], 0} + case 2: + return [4]byte{rb[2], rb[3]} + case 3: + return [4]byte{rb[3]} + default: + return [4]byte{} + } +} + +// recursive case-insensitive lookup function used by n.findCaseInsensitivePath +func (n *node) findCaseInsensitivePathRec(path, loPath string, ciPath []byte, rb [4]byte, fixTrailingSlash bool) ([]byte, bool) { + loNPath := strings.ToLower(n.path) + +walk: // outer loop for walking the tree + for len(loPath) >= len(loNPath) && (len(loNPath) == 0 || loPath[1:len(loNPath)] == loNPath[1:]) { + // add common path to result + ciPath = append(ciPath, n.path...) + + if path = path[len(n.path):]; len(path) > 0 { + loOld := loPath + loPath = loPath[len(loNPath):] + + // If this node does not have a wildcard (param or catchAll) child, + // we can just look up the next child node and continue to walk down + // the tree + if !n.wildChild { + // skip rune bytes already processed + rb = shiftNRuneBytes(rb, len(loNPath)) + + if rb[0] != 0 { + // old rune not finished + for i := 0; i < len(n.indices); i++ { + if n.indices[i] == rb[0] { + // continue with child node + n = n.children[i] + loNPath = strings.ToLower(n.path) + continue walk + } + } + } else { + // process a new rune + var rv rune + + // find rune start + // runes are up to 4 byte long, + // -4 would definitely be another rune + var off int + for max := min(len(loNPath), 3); off < max; off++ { + if i := len(loNPath) - off; utf8.RuneStart(loOld[i]) { + // read rune from cached lowercase path + rv, _ = utf8.DecodeRuneInString(loOld[i:]) + break + } + } + + // calculate lowercase bytes of current rune + utf8.EncodeRune(rb[:], rv) + // skipp already processed bytes + rb = shiftNRuneBytes(rb, off) + + for i := 0; i < len(n.indices); i++ { + // lowercase matches + if n.indices[i] == rb[0] { + // must use a recursive approach since both the + // uppercase byte and the lowercase byte might exist + // as an index + if out, found := n.children[i].findCaseInsensitivePathRec( + path, loPath, ciPath, rb, fixTrailingSlash, + ); found { + return out, true + } + break + } + } + + // same for uppercase rune, if it differs + if up := unicode.ToUpper(rv); up != rv { + utf8.EncodeRune(rb[:], up) + rb = shiftNRuneBytes(rb, off) + + for i := 0; i < len(n.indices); i++ { + // uppercase matches + if n.indices[i] == rb[0] { + // continue with child node + n = n.children[i] + loNPath = strings.ToLower(n.path) + continue walk + } + } + } + } + + // Nothing found. We can recommend to redirect to the same URL + // without a trailing slash if a leaf exists for that path + return ciPath, (fixTrailingSlash && path == "/" && n.handle != nil) + } + + n = n.children[0] + switch n.nType { + case param: + // find param end (either '/' or path end) + k := 0 + for k < len(path) && path[k] != '/' { + k++ + } + + // add param value to case insensitive path + ciPath = append(ciPath, path[:k]...) + + // we need to go deeper! + if k < len(path) { + if len(n.children) > 0 { + // continue with child node + n = n.children[0] + loNPath = strings.ToLower(n.path) + loPath = loPath[k:] + path = path[k:] + continue + } + + // ... but we can't + if fixTrailingSlash && len(path) == k+1 { + return ciPath, true + } + return ciPath, false + } + + if n.handle != nil { + return ciPath, true + } else if fixTrailingSlash && len(n.children) == 1 { + // No handle found. Check if a handle for this path + a + // trailing slash exists + n = n.children[0] + if n.path == "/" && n.handle != nil { + return append(ciPath, '/'), true + } + } + return ciPath, false + + case catchAll: + return append(ciPath, path...), true + + default: + panic("invalid node type") + } + } else { + // We should have reached the node containing the handle. + // Check if this node has a handle registered. + if n.handle != nil { + return ciPath, true + } + + // No handle found. + // Try to fix the path by adding a trailing slash + if fixTrailingSlash { + for i := 0; i < len(n.indices); i++ { + if n.indices[i] == '/' { + n = n.children[i] + if (len(n.path) == 1 && n.handle != nil) || + (n.nType == catchAll && n.children[0].handle != nil) { + return append(ciPath, '/'), true + } + return ciPath, false + } + } + } + return ciPath, false + } + } + + // Nothing found. + // Try to fix the path by adding / removing a trailing slash + if fixTrailingSlash { + if path == "/" { + return ciPath, true + } + if len(loPath)+1 == len(loNPath) && loNPath[len(loPath)] == '/' && + loPath[1:] == loNPath[1:len(loPath)] && n.handle != nil { + return append(ciPath, n.path...), true + } + } + return ciPath, false +} diff --git a/source_code/goku/utils.go b/source_code/goku/utils.go deleted file mode 100644 index 250c8dd5e6e1762563c46f264deda1edfaf0f29a..0000000000000000000000000000000000000000 --- a/source_code/goku/utils.go +++ /dev/null @@ -1,12 +0,0 @@ -package goku - -import ( - "reflect" -) - -// 判定handler是否是函数类型 -func ValidateHandler(handler Handler) { - if reflect.TypeOf(handler).Kind() != reflect.Func { - panic("goku handler must be a callable func") - } -} diff --git a/source_code/middleware/auth.go b/source_code/middleware/auth.go index 09cea97184d09db7a5e1490b46d8cf7002dcf00d..027ed47ab54a0de8c492b654aa7eaf8c830eb902 100644 --- a/source_code/middleware/auth.go +++ b/source_code/middleware/auth.go @@ -1,24 +1,25 @@ package middleware import ( - "goku-ce/conf" + "goku-ce/goku" "net/http" "strings" "encoding/base64" ) -func Auth(c conf.StrategyInfo,res http.ResponseWriter, req *http.Request) (bool,string) { - if c.Auth == "basic" { +func Auth(context *goku.Context,res http.ResponseWriter, req *http.Request) (bool,string) { + c := context.StrategyInfo + if strings.ToLower(c.Auth) == "basic" { authStr := []byte(c.BasicUserName + ":" + c.BasicUserPassword) authorization := "Basic " + base64.StdEncoding.EncodeToString(authStr) auth := strings.Join(req.Header["Authorization"],", ") if authorization != auth { - return false, "Error username or userpassword" + return false, "Username or UserPassword Error" } - } else if c.Auth == "apikey" { + } else if strings.ToLower(c.Auth) == "apikey" { apiKey := strings.Join(req.Header["Apikey"],", ") if c.ApiKey != apiKey { - return false,"Error apiKey" + return false,"Invalid ApiKey" } } return true,"" diff --git a/source_code/middleware/backend.go b/source_code/middleware/backend.go deleted file mode 100644 index e95cf8fbb9f3939486f354bd207fba23402d0703..0000000000000000000000000000000000000000 --- a/source_code/middleware/backend.go +++ /dev/null @@ -1,18 +0,0 @@ -package middleware - -import ( - "goku-ce/conf" -) - -func GetBackendInfo(backendID int,b conf.Backend) (bool,conf.BackendInfo) { - flag := false - var backendInfo conf.BackendInfo - for _,i := range b.Backend { - if i.BackendID == backendID { - flag = true - backendInfo = i - break - } - } - return flag,backendInfo -} \ No newline at end of file diff --git a/source_code/middleware/ip_limit.go b/source_code/middleware/ip_limit.go index c2500bd6f361cca11d22d06210fb2a7bfc12eff7..c2c352c5128bd3ff1ccfd7fc8a107eee5ea16791 100644 --- a/source_code/middleware/ip_limit.go +++ b/source_code/middleware/ip_limit.go @@ -1,35 +1,33 @@ package middleware import ( - "goku-ce/conf" + "goku-ce/goku" "net/http" "strings" ) -func IPLimit(g conf.GatewayInfo,d conf.StrategyInfo,res http.ResponseWriter, req *http.Request) (bool,string) { +func IPLimit(g *goku.Context,res http.ResponseWriter, req *http.Request) (bool,string) { remoteAddr := req.RemoteAddr remoteIP := InterceptIP(remoteAddr, ":") if !globalIPLimit(g,remoteIP){ - res.WriteHeader(404) - return false,"[global] Illegal ip" - } else if globalIPLimit(g,remoteIP) && !strategyIPLimit(d,remoteIP) { - res.WriteHeader(404) - return false,"[strategy] Illegal ip" + return false,"[Global] Illegal IP" + } else if globalIPLimit(g,remoteIP) && !strategyIPLimit(g,remoteIP) { + return false,"[Strategy] Illegal IP" } return true,"" } -func globalIPLimit(g conf.GatewayInfo,remoteIP string) bool{ - if g.IPLimitType == "black"{ - for _,ip := range g.IPBlackList{ +func globalIPLimit(g *goku.Context,remoteIP string) bool{ + if g.GatewayInfo.IPLimitType == "black"{ + for _,ip := range g.GatewayInfo.IPBlackList{ if ip == remoteIP { return false } } return true - } else if g.IPLimitType == "white" { - for _,ip := range g.IPWhiteList{ + } else if g.GatewayInfo.IPLimitType == "white" { + for _,ip := range g.GatewayInfo.IPWhiteList{ if ip == remoteIP { return true } @@ -39,16 +37,16 @@ func globalIPLimit(g conf.GatewayInfo,remoteIP string) bool{ return true } -func strategyIPLimit(d conf.StrategyInfo,remoteIP string) bool { - if d.IPLimitType == "black" { - for _,ip := range d.IPBlackList{ +func strategyIPLimit(g *goku.Context,remoteIP string) bool { + if g.StrategyInfo.IPLimitType == "black" { + for _,ip := range g.StrategyInfo.IPBlackList{ if ip == remoteIP { return false } } return true - } else if d.IPLimitType == "white" { - for _,ip := range d.IPWhiteList{ + } else if g.StrategyInfo.IPLimitType == "white" { + for _,ip := range g.StrategyInfo.IPWhiteList{ if ip == remoteIP { return true } diff --git a/source_code/middleware/rate.go b/source_code/middleware/rate.go index f75beb50ff69fa2c672e87a3d6dfd1d248f94de6..0f18b469c9048935165ce343958bad77822f5a5e 100644 --- a/source_code/middleware/rate.go +++ b/source_code/middleware/rate.go @@ -1,12 +1,14 @@ package middleware import ( + "fmt" "time" "goku-ce/conf" "goku-ce/goku" ) -func getStrategyRate(c conf.StrategyInfo) (bool,[]conf.RateLimitInfo) { +func getStrategyRate(context *goku.Context) (bool,[]conf.RateLimitInfo) { + c := context.StrategyInfo now := time.Now() flag := false rateLimitList := make([]conf.RateLimitInfo,0) @@ -63,16 +65,18 @@ func timeInPeriod(c conf.RateLimitInfo,now int) bool { return false } -func RateLimit(g *goku.Goku,c conf.StrategyInfo) (bool,string) { - value, ok := g.Rate[c.StrategyID] +func RateLimit(context *goku.Context) (bool,string) { + c := context.StrategyInfo + g := context.Rate + value, ok := g[c.StrategyID] if !ok { var w goku.Rate - g.Rate[c.StrategyID] = w + g[c.StrategyID] = w } if !value.IsInit{ - flag,r := getStrategyRate(c) + flag,r := getStrategyRate(context) if flag == false { - return false,"Don't allow visit!" + return false,"Forbidden Request" } for _,i := range r { if i.Period == "sec" { @@ -87,9 +91,9 @@ func RateLimit(g *goku.Goku,c conf.StrategyInfo) (bool,string) { } value.IsInit = true } else if value.SecLimit.IsNeedReset(){ - flag,r := getStrategyRate(c) + flag,r := getStrategyRate(context) if flag == false { - return false,"Don't allow visit!" + return false,"Forbidden Request" } for _,i := range r { if i.Period == "sec" { @@ -103,58 +107,58 @@ func RateLimit(g *goku.Goku,c conf.StrategyInfo) (bool,string) { } } } - + fmt.Println(time.Now().Format("2006-01-02 15:04:05")) if value.Limit == "day" { if !value.DayLimit.DayLimit() { value.Limit = "day" - g.Rate[c.StrategyID] = value - return false,"Day visit limit exceeded" + g[c.StrategyID] = value + return false,"API Rate Limit of Day Exceeded" } value.Limit = "" } else if value.Limit == "hour" { if !value.HourLimit.HourLimit() { value.Limit = "hour" - g.Rate[c.StrategyID] = value - return false,"Hour visit limit exceeded" + g[c.StrategyID] = value + return false,"API Rate Limit of Hour Exceeded" }else if !value.DayLimit.DayLimit() { value.Limit = "day" - g.Rate[c.StrategyID] = value - return false,"Day visit limit exceeded" + g[c.StrategyID] = value + return false,"API Rate Limit of Day Exceeded" } value.Limit = "" } else if value.Limit == "minute" { if !value.MinuteLimit.MinLimit() { value.Limit = "minute" - g.Rate[c.StrategyID] = value - return false,"Minute visit limit exceeded" + g[c.StrategyID] = value + return false,"API Rate Limit of Minute Exceeded" }else if !value.HourLimit.HourLimit() { value.Limit = "hour" - g.Rate[c.StrategyID] = value - return false,"Hour visit limit exceeded" + g[c.StrategyID] = value + return false,"API Rate Limit of Hour Exceeded" }else if !value.DayLimit.DayLimit() { value.Limit = "day" - g.Rate[c.StrategyID] = value - return false,"Day visit limit exceeded" + g[c.StrategyID] = value + return false,"API Rate Limit of Day Exceeded" } value.Limit = "" } else { if !value.SecLimit.SecLimit() { - g.Rate[c.StrategyID] = value - return false,"Second visit limit exceeded" + g[c.StrategyID] = value + return false,"API Rate Limit of Second Exceeded" }else if !value.MinuteLimit.MinLimit() { value.Limit = "minute" - g.Rate[c.StrategyID] = value - return false,"Minute visit limit exceeded" + g[c.StrategyID] = value + return false,"API Rate Limit of Minute Exceeded" }else if !value.HourLimit.HourLimit() { value.Limit = "hour" - g.Rate[c.StrategyID] = value - return false,"Hour visit limit exceeded" + g[c.StrategyID] = value + return false,"API Rate Limit of Hour Exceeded" }else if !value.DayLimit.DayLimit() { value.Limit = "day" - g.Rate[c.StrategyID] = value - return false,"Day visit limit exceeded" + g[c.StrategyID] = value + return false,"API Rate Limit of Day Exceeded" } } - g.Rate[c.StrategyID] = value + g[c.StrategyID] = value return true,"" } \ No newline at end of file diff --git a/source_code/middleware/request_mapping.go b/source_code/middleware/request_mapping.go index 2ff38ba598f6f52180bdf30895136bff5a6a85f5..5ea40790d9d879c4666da4f8ef634d4df00449ff 100644 --- a/source_code/middleware/request_mapping.go +++ b/source_code/middleware/request_mapping.go @@ -1,135 +1,90 @@ package middleware import ( - "fmt" "net/http" "strings" "goku-ce/goku" - "goku-ce/conf" "goku-ce/request" + "io/ioutil" ) -func Mapping(g *goku.Goku,res http.ResponseWriter, req *http.Request) (bool,string){ - url := InterceptURL(req.RequestURI,"?") - requestURI := strings.Split(url,"/") - if len(requestURI) == 2 { - if requestURI[1] == "" { - res.WriteHeader(404) - res.Write([]byte("Lack gatewayAlias")) - return false,"Lack gatewayAlias" - } else { - res.WriteHeader(404) - res.Write([]byte("Lack StrategyID")) - return false,"Lack StrategyID" - } +func Mapping(res http.ResponseWriter, req *http.Request,param goku.Params,context *goku.Context) { + // 验证IP是否合法 + f,s := IPLimit(context,res,req) + if !f { + res.WriteHeader(403) + res.Write([]byte(s)) + return } - fmt.Println(url) - gatewayAlias := requestURI[1] - StrategyID := requestURI[2] - urlLen := len(gatewayAlias) + len(StrategyID) + 2 - flag := false - for _,m := range g.ServiceConfig.GatewayList{ - if m.GatewayAlias == gatewayAlias{ - for _,i := range m.StrategyList.Strategy{ - if i.StrategyID == StrategyID{ - flag = true - f,r := IPLimit(m,i,res,req) - if !f { - res.Write([]byte(r)) - return false,r - } - - f,r = Auth(i,res,req) - if !f { - res.Write([]byte(r)) - return false,r - } - - f,r = RateLimit(g,i) - if !f { - res.Write([]byte(r)) - return false,r - } - break - } - } - } - if flag { - for _,i := range m.ApiList.Apis{ - if i.RequestURL == url[urlLen:]{ - // 验证请求 - if !validateRequest(i,req){ - res.WriteHeader(404) - res.Write([]byte("Error Request Method!")) - return false,"Error Request Method!" - } - - // 验证后端信息是否存在 - f,r := GetBackendInfo(i.BackendID,m.BackendList) - if !f { - res.WriteHeader(404) - res.Write([]byte("Backend config is not exist!")) - return false,"Backend config is not exist!" - } - - - _,response,httpResponseHeader := CreateRequest(i,r,req,res) - for key, values := range httpResponseHeader { - for _, value := range values { - res.Header().Add(key,value) - } - } - res.Write(response) - return true,string(response) - } - } - } - } - res.Write([]byte("URI Not Found")) - return false,"URI Not Found" -} - -// 验证协议及请求参数 -func validateRequest(api conf.ApiInfo, req *http.Request) bool{ - flag := false - for _,method := range api.RequestMethod{ - if !(strings.ToUpper(method) == req.Method){ - flag = true - break + f,s = Auth(context,res,req) + if !f { + res.WriteHeader(403) + res.Write([]byte(s)) + return + } + f,s = RateLimit(context) + if !f { + res.WriteHeader(403) + res.Write([]byte(s)) + return + } + statusCode,body,headers := CreateRequest(context,req,res) + for key,values := range headers { + for _,value := range values { + res.Header().Add(key,value) } } - return flag + res.WriteHeader(statusCode) + res.Write(body) + return } // 将请求参数写入请求中 -func CreateRequest(api conf.ApiInfo,i conf.BackendInfo,httpRequest *http.Request,httpResponse http.ResponseWriter) (int,[]byte,map[string][]string) { +func CreateRequest(g *goku.Context,httpRequest *http.Request,httpResponse http.ResponseWriter) (int,[]byte,map[string][]string) { + api := g.ApiInfo var backendHeaders map[string][]string = make(map[string][]string) var backendQueryParams map[string][]string = make(map[string][]string) var backendFormParams map[string][]string = make(map[string][]string) err := httpRequest.ParseForm() if err != nil { - return 500,[]byte("Fail to Parse Args"),make(map[string][]string) + return 500,[]byte("Parsing Arguments Fail"),make(map[string][]string) } backendMethod := strings.ToUpper(api.ProxyMethod) - backenDomain := i.BackendPath + api.ProxyURL + backenDomain := api.BackendPath + api.ProxyURL requ,err := request.Method(backendMethod,backenDomain) for _, reqParam := range api.ProxyParams { - var param []string + var param []string + isFile := false switch reqParam.KeyPosition { - case "header": - param = httpRequest.Header[reqParam.Key] + case "header": + key := parseHeader(reqParam.Key) + param = httpRequest.Header[key] case "body": if httpRequest.Method == "POST" || httpRequest.Method == "PUT" || httpRequest.Method == "PATCH" { param = httpRequest.PostForm[reqParam.Key] + if param == nil { + f,fh,err := httpRequest.FormFile(reqParam.Key) + if err != nil { + panic(err) + } + defer f.Close() + body,err := ioutil.ReadAll(f) + if err != nil { + panic(err) + } + requ.AddFile(reqParam.ProxyKey,fh.Filename,body) + isFile = true + } } else { continue } case "query": param = httpRequest.Form[reqParam.Key] } + if param == nil { - if reqParam.NotEmpty { + if reqParam.NotEmpty && !isFile { return 400, []byte("Missing required parameters"),make(map[string][]string) } else { continue @@ -137,8 +92,9 @@ func CreateRequest(api conf.ApiInfo,i conf.BackendInfo,httpRequest *http.Request } switch reqParam.ProxyKeyPosition { - case "header": - backendHeaders[reqParam.ProxyKey] = param + case "header": + key := parseHeader(reqParam.ProxyKey) + backendHeaders[key] = param case "body": if backendMethod == "POST" || backendMethod == "PUT" || backendMethod == "PATCH" { backendFormParams[reqParam.ProxyKey] = param @@ -150,16 +106,17 @@ func CreateRequest(api conf.ApiInfo,i conf.BackendInfo,httpRequest *http.Request for _, constParam := range api.ConstantParams { switch constParam.Position { - case "header": - backendHeaders[constParam.Key] = []string{constParam.Key} + case "header": + backendHeaders[constParam.Key] = []string{constParam.Value} case "body": - if backendMethod == "POST" || backendMethod == "PUT" || backendMethod == "PATCH" { backendFormParams[constParam.Key] = []string{constParam.Value} } else { backendQueryParams[constParam.Key] = []string{constParam.Value} - } - } + } + case "query": + backendQueryParams[constParam.Key] = []string{constParam.Value} + } } if err != nil{ @@ -170,28 +127,20 @@ func CreateRequest(api conf.ApiInfo,i conf.BackendInfo,httpRequest *http.Request requ.SetHeader(key, values...) } for key, values := range backendQueryParams { - fmt.Println(key) - fmt.Println(values) requ.SetQueryParam(key, values...) } for key, values := range backendFormParams { - fmt.Println(key) - fmt.Println(values) requ.SetFormParam(key, values...) } - if api.ProxyBodyType == "raw" { - requ.SetRawBody([]byte(api.ProxyBody)) - } else if api.ProxyBodyType == "json" { - requ.SetJSON(api.ProxyBody) - } + if api.IsRaw { + body,_ := ioutil.ReadAll(httpRequest.Body) + requ.SetRawBody([]byte(body)) + } - cookies := make(map[string]string) for _, cookie := range httpRequest.Cookies() { cookies[cookie.Name] = cookie.Value } - // requ.SetHeader("Cookie",cookies) - res, err := requ.Send() if err != nil { return 500,[]byte(""),make(map[string][]string) @@ -208,13 +157,16 @@ func CreateRequest(api conf.ApiInfo,i conf.BackendInfo,httpRequest *http.Request return res.StatusCode(), res.Body(),httpResponseHeader } -func InterceptURL(str, substr string) string { - result := strings.Index(str, substr) - var rs string - if result != -1{ - rs = str[:result] - }else { - rs = str +// 修饰请求头 +func parseHeader(header string) string { + headerArray := strings.Split(header,"-") + result := "" + for i,h := range headerArray { + h = strings.Replace(h,"_","",-1) + result += strings.ToUpper(h[0:1]) + strings.ToLower(h[1:]) + if i + 1 < len(headerArray) { + result += "-" + } } - return rs + return result } \ No newline at end of file diff --git a/source_code/request/request.go b/source_code/request/request.go index 24f821b41720c09d59df081ab015c7a73e65f6e1..c694ff6257992a41cdacbe1da9da904fef78f320 100644 --- a/source_code/request/request.go +++ b/source_code/request/request.go @@ -78,7 +78,7 @@ func (this *request) SetQueryParam(key string, values ...string) Request { func Method(method string, urlPath string) (Request, error) { if method != "GET" && method != "POST" && method != "PUT" && method != "DELETE" && method != "HEAD" && method != "OPTIONS" && method != "PATCH" { - return nil, errors.New("method not supported") + return nil, errors.New("Unsupported Request Method") } return newRequest(method, urlPath) } @@ -228,16 +228,14 @@ func (this *request) Send() (res Response, err error) { fmt.Println(err) return } + req.Header.Set("Accept-Encoding", "gzip") req.Header = parseHeaders(this.headers) httpResponse, err := this.client.Do(req) if err != nil { - httpResponse.Body.Close() - fmt.Println(err) return } res, err = newResponse(httpResponse) if err != nil { - fmt.Println(err) return } return diff --git a/source_code/request/response.go b/source_code/request/response.go index 3a94c729f89716d9df4859de1afd044dbb0fb5cf..3ec6f6cb3cfbd705a5024f8e6d9e62f6b8fa2177 100644 --- a/source_code/request/response.go +++ b/source_code/request/response.go @@ -3,6 +3,8 @@ package request import ( "io/ioutil" "net/http" + "io" + "compress/gzip" ) type Response interface { @@ -24,7 +26,15 @@ type response struct { func newResponse(httpResponse *http.Response) (Response, error) { defer httpResponse.Body.Close() var headers map[string][]string = httpResponse.Header - body, err := ioutil.ReadAll(httpResponse.Body) + var reader io.ReadCloser + switch httpResponse.Header.Get("Content-Encoding") { + case "gzip": + reader, _ = gzip.NewReader(httpResponse.Body) + defer reader.Close() + default: + reader = httpResponse.Body + } + body, err := ioutil.ReadAll(reader) content_length := int64(len(body)) if err != nil { return nil, err