container.go 15.0 KB
Newer Older
J
jeff 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99
package restful

// Copyright 2013 Ernest Micklei. All rights reserved.
// Use of this source code is governed by a license
// that can be found in the LICENSE file.

import (
	"bytes"
	"errors"
	"fmt"
	"net/http"
	"os"
	"runtime"
	"strings"
	"sync"

	"github.com/emicklei/go-restful/log"
)

// Container holds a collection of WebServices and a http.ServeMux to dispatch http requests.
// The requests are further dispatched to routes of WebServices using a RouteSelector.
// Use NewContainer to obtain a Container with its defaults filled in.
type Container struct {
	// webServicesLock guards webServices; Add and Remove hold it for writing, and Remove
	// also replaces ServeMux and isRegisteredOnRoot while holding it.
	webServicesLock sync.RWMutex
	webServices     []*WebService

	// ServeMux receives incoming requests; addHandler maps WebService root paths on it to dispatch.
	ServeMux           *http.ServeMux
	isRegisteredOnRoot bool // true once dispatch has been mapped on "/"

	// containerFilters run before any WebService or Route filter.
	containerFilters []FilterFunction

	// Panic and routing-error handling.
	doNotRecover           bool // default is true
	recoverHandleFunc      RecoverHandleFunction
	serviceErrorHandleFunc ServiceErrorHandleFunction

	router                 RouteSelector // default is a CurlyRouter (RouterJSR311 is a slower alternative)
	contentEncodingEnabled bool          // default is false
}

// NewContainer returns a Container that owns a fresh ServeMux and routes with a CurlyRouter.
// Panic recovery is off (doNotRecover is true), the default recover and service-error
// handlers are installed, no WebServices or filters are registered and content encoding is disabled.
func NewContainer() *Container {
	c := &Container{
		ServeMux:               http.NewServeMux(),
		webServices:            []*WebService{},
		containerFilters:       []FilterFunction{},
		router:                 CurlyRouter{},
		recoverHandleFunc:      logStackOnRecover,
		serviceErrorHandleFunc: writeServiceError,
	}
	// isRegisteredOnRoot and contentEncodingEnabled keep their zero value (false).
	c.doNotRecover = true
	return c
}

// RecoverHandleFunction declares functions that can be used to handle a panic situation.
// The first argument is what recover() returns. The second must be used to communicate an error response.
type RecoverHandleFunction func(interface{}, http.ResponseWriter)

// RecoverHandler changes the default function (logStackOnRecover) to be called
// when a panic is detected.
// The handler is only used while panic recovery is enabled, i.e. after calling
// DoNotRecover(false); a Container created by NewContainer has recovery disabled
// (DoNotRecover defaults to true).
func (c *Container) RecoverHandler(handler RecoverHandleFunction) {
	c.recoverHandleFunc = handler
}

// ServiceErrorHandleFunction is the signature of functions that turn a ServiceError into a
// response: they receive the error, the Request that led to it, and the Response on which
// the error response must be written.
type ServiceErrorHandleFunction func(ServiceError, *Request, *Response)

// ServiceErrorHandler replaces the function invoked when a ServiceError is detected.
// NewContainer installs writeServiceError as the default.
func (c *Container) ServiceErrorHandler(handler ServiceErrorHandleFunction) {
	c.serviceErrorHandleFunc = handler
}

// DoNotRecover controls whether panics will be caught to return HTTP 500.
// Passing true leaves Route functions responsible for handling any error situation;
// passing false makes dispatch recover panics and hand them to the RecoverHandler.
// Default value is true.
func (c *Container) DoNotRecover(skipRecovery bool) {
	c.doNotRecover = skipRecovery
}

// Router replaces the RouteSelector used to match requests to routes
// (NewContainer installs a CurlyRouter).
func (c *Container) Router(selector RouteSelector) {
	c.router = selector
}

// EnableContentEncoding switches GZIP or DEFLATE encoding of responses on or off (default=false).
// A Route may override this container-wide setting for its own responses.
func (c *Container) EnableContentEncoding(on bool) {
	c.contentEncodingEnabled = on
}

// Add registers service with the Container and returns the Container so calls can be chained.
// A service without a root path gets "/" assigned. If another registered WebService already
// uses the same root path, the conflict is logged and the process exits with status 1.
func (c *Container) Add(service *WebService) *Container {
	c.webServicesLock.Lock()
	defer c.webServicesLock.Unlock()

	// lazily default the root path
	if service.rootPath == "" {
		service.Path("/")
	}

	// root paths must be unique within the container
	for _, existing := range c.webServices {
		if existing.RootPath() != service.RootPath() {
			continue
		}
		log.Printf("WebService with duplicate root path detected:['%v']", existing)
		os.Exit(1)
	}

	// once dispatch owns "/", no per-service mapping is needed anymore
	if !c.isRegisteredOnRoot {
		c.isRegisteredOnRoot = c.addHandler(service, c.ServeMux)
	}
	c.webServices = append(c.webServices, service)
	return c
}

// addHandler maps the fixed prefix of service's root path on serveMux to c.dispatch.
// It must run inside the critical region protected by webServicesLock.
// It returns true when dispatch was registered on root ("/"), which happens when the prefix
// is empty or "/"; otherwise it returns false.
// No mapping is added when a service in c.webServices already has the same root path
// (this includes service itself if it is already in that list).
func (c *Container) addHandler(service *WebService, serveMux *http.ServeMux) bool {
	pattern := fixedPrefixPath(service.RootPath())

	if pattern == "" || pattern == "/" {
		serveMux.HandleFunc("/", c.dispatch)
		return true
	}

	for _, known := range c.webServices {
		if known.RootPath() == service.RootPath() {
			return false
		}
	}

	serveMux.HandleFunc(pattern, c.dispatch)
	if !strings.HasSuffix(pattern, "/") {
		// a ServeMux pattern ending in "/" also matches everything below it
		serveMux.HandleFunc(pattern+"/", c.dispatch)
	}
	return false
}

// Remove unregisters every WebService whose root path equals that of ws.
// It builds a brand-new ServeMux, re-registers the remaining WebServices on it and swaps it in.
// Handlers registered directly on the previous ServeMux (for example via Handle or
// HandleWithFilter) are not carried over to the new one.
// If the Container uses http.DefaultServeMux, nothing is changed and an error is returned.
func (c *Container) Remove(ws *WebService) error {
	if c.ServeMux == http.DefaultServeMux {
		errMsg := fmt.Sprintf("cannot remove a WebService from a Container using the DefaultServeMux: ['%v']", ws)
		log.Print(errMsg)
		return errors.New(errMsg)
	}
	c.webServicesLock.Lock()
	defer c.webServicesLock.Unlock()

	// build a new ServeMux and re-register all remaining WebServices
	oldServices := c.webServices
	newServeMux := http.NewServeMux()
	newServices := make([]*WebService, 0, len(oldServices))
	newIsRegisteredOnRoot := false

	// addHandler skips a mapping when c.webServices already contains a service with the
	// same root path. If c.webServices still held the old list, every surviving service
	// would match itself and never be registered on newServeMux. So expose only the
	// services re-registered so far, and restore the old list should a registration panic
	// before the swap below (this defer runs before the Unlock above).
	committed := false
	defer func() {
		if !committed {
			c.webServices = oldServices
		}
	}()
	for _, each := range oldServices {
		if each.rootPath == ws.rootPath {
			continue
		}
		// If not registered on root then add specific mapping
		if !newIsRegisteredOnRoot {
			c.webServices = newServices
			newIsRegisteredOnRoot = c.addHandler(each, newServeMux)
		}
		newServices = append(newServices, each)
	}
	c.webServices, c.ServeMux, c.isRegisteredOnRoot = newServices, newServeMux, newIsRegisteredOnRoot
	committed = true
	return nil
}

// logStackOnRecover is the default RecoverHandleFunction. It is used when panic recovery is
// enabled (DoNotRecover(false)) and no other RecoverHandler was configured.
// It logs the panic value followed by the caller stack and writes the same text with
// status 500 to the response. Exposing file paths and line numbers to clients may be a
// security issue.
func logStackOnRecover(panicReason interface{}, httpWriter http.ResponseWriter) {
	var trace bytes.Buffer
	fmt.Fprintf(&trace, "recover from panic situation: - %v\r\n", panicReason)
	// Skip frame 0 (this function) and frame 1 (its direct caller).
	for depth := 2; ; depth++ {
		_, file, line, ok := runtime.Caller(depth)
		if !ok {
			break
		}
		fmt.Fprintf(&trace, "    %s:%d\r\n", file, line)
	}
	log.Print(trace.String())
	httpWriter.WriteHeader(http.StatusInternalServerError)
	httpWriter.Write(trace.Bytes())
}

// writeServiceError is the default ServiceErrorHandleFunction and is called
// when a ServiceError is returned during route selection. Default implementation
// calls resp.WriteErrorString(err.Code, err.Message)
func writeServiceError(err ServiceError, req *Request, resp *Response) {
H
hongming 已提交
188 189 190 191 192
	for header, values := range err.Header {
		for _, value := range values {
			resp.Header().Add(header, value)
		}
	}
J
jeff 已提交
193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208
	resp.WriteErrorString(err.Code, err.Message)
}

// Dispatch routes the incoming Http Request to a matching WebService.
// It panics when either argument is nil.
func (c *Container) Dispatch(httpWriter http.ResponseWriter, httpRequest *http.Request) {
	switch {
	case httpWriter == nil:
		panic("httpWriter cannot be nil")
	case httpRequest == nil:
		panic("httpRequest cannot be nil")
	}
	c.dispatch(httpWriter, httpRequest)
}

// dispatch routes the incoming Http Request to the matching Route of a registered WebService.
//
// It performs these steps:
//   - installs panic recovery, unless doNotRecover is set;
//   - selects the route under the read lock;
//   - installs a CompressingResponseWriter when content encoding is enabled for the container
//     (or overridden by the route) and the request asks for it;
//   - extracts the path parameters;
//   - runs the container, WebService and Route filters, in that order, around the route function.
//
// If route selection fails, only the container filters run, and the error response is
// produced by serviceErrorHandleFunc.
func (c *Container) dispatch(httpWriter http.ResponseWriter, httpRequest *http.Request) {
	// writer starts as httpWriter so we can assign a compressing one later
	writer := httpWriter

	// A CompressingResponseWriter must be closed after all operations are done.
	// Registered first, so it runs after the panic recovery below.
	defer func() {
		if compressWriter, ok := writer.(*CompressingResponseWriter); ok {
			compressWriter.Close()
		}
	}()

	// Install panic recovery unless told otherwise (catch-all for a 500 response).
	if !c.doNotRecover {
		defer func() {
			if r := recover(); r != nil {
				c.recoverHandleFunc(r, writer)
			}
		}()
	}

	// Find best match Route; err is non nil if no match was found.
	var webService *WebService
	var route *Route
	var err error
	func() {
		c.webServicesLock.RLock()
		defer c.webServicesLock.RUnlock()
		webService, route, err = c.router.SelectRoute(c.webServices, httpRequest)
	}()
	if err != nil {
		// Nothing has been written yet; the chain target below writes the error response.
		// Run the container filters anyway; they should not touch the response.
		chain := FilterChain{Filters: c.containerFilters, Target: func(req *Request, resp *Response) {
			serviceErr, ok := err.(ServiceError)
			if !ok {
				// Without this, an error of any other type would yield an empty 200 response.
				serviceErr = ServiceError{Code: http.StatusInternalServerError, Message: err.Error()}
			}
			c.serviceErrorHandleFunc(serviceErr, req, resp)
		}}
		chain.ProcessFilter(NewRequest(httpRequest), NewResponse(writer))
		return
	}

	// Unless httpWriter is already a CompressingResponseWriter, see if we need to install one.
	if _, isCompressing := httpWriter.(*CompressingResponseWriter); !isCompressing {
		// The container setting applies unless the route overrides it.
		contentEncodingEnabled := c.contentEncodingEnabled
		if route != nil && route.contentEncodingEnabled != nil {
			contentEncodingEnabled = *route.contentEncodingEnabled
		}
		if contentEncodingEnabled {
			if doCompress, encoding := wantsCompressedResponse(httpRequest); doCompress {
				compressWriter, compressErr := NewCompressingResponseWriter(httpWriter, encoding)
				if compressErr != nil {
					log.Print("unable to install compressor: ", compressErr)
					httpWriter.WriteHeader(http.StatusInternalServerError)
					return
				}
				// Adopt the writer only on success. Storing a nil *CompressingResponseWriter
				// in the interface-typed writer would make the deferred type assertion
				// succeed and call Close on a nil pointer.
				writer = compressWriter
			}
		}
	}

	pathProcessor, routerProcessesPath := c.router.(PathProcessor)
	if !routerProcessesPath {
		pathProcessor = defaultPathProcessor{}
	}
	pathParams := pathProcessor.ExtractParameters(route, webService, httpRequest.URL.Path)
	wrappedRequest, wrappedResponse := route.wrapRequestResponse(writer, httpRequest, pathParams)

	// pass through filters (if any): container, then WebService, then Route filters
	size := len(c.containerFilters) + len(webService.filters) + len(route.Filters)
	if size == 0 {
		// no filters, handle request by route
		route.Function(wrappedRequest, wrappedResponse)
		return
	}
	allFilters := make([]FilterFunction, 0, size)
	allFilters = append(allFilters, c.containerFilters...)
	allFilters = append(allFilters, webService.filters...)
	allFilters = append(allFilters, route.Filters...)
	chain := FilterChain{Filters: allFilters, Target: route.Function}
	chain.ProcessFilter(wrappedRequest, wrappedResponse)
}

// fixedPrefixPath returns the part of pathspec that precedes its first template
// variable ("{"); a pathspec without template variables is returned unchanged.
func fixedPrefixPath(pathspec string) string {
	if varBegin := strings.Index(pathspec, "{"); varBegin >= 0 {
		return pathspec[:varBegin]
	}
	return pathspec
}

// ServeHTTP implements net/http.Handler therefore a Container can be a Handler in a http.Server
H
hongming 已提交
308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341
func (c *Container) ServeHTTP(httpWriter http.ResponseWriter, httpRequest *http.Request) {
	// Skip, if content encoding is disabled
	if !c.contentEncodingEnabled {
		c.ServeMux.ServeHTTP(httpWriter, httpRequest)
		return
	}
	// content encoding is enabled

	// Skip, if httpWriter is already an CompressingResponseWriter
	if _, ok := httpWriter.(*CompressingResponseWriter); ok {
		c.ServeMux.ServeHTTP(httpWriter, httpRequest)
		return
	}

	writer := httpWriter
	// CompressingResponseWriter should be closed after all operations are done
	defer func() {
		if compressWriter, ok := writer.(*CompressingResponseWriter); ok {
			compressWriter.Close()
		}
	}()

	doCompress, encoding := wantsCompressedResponse(httpRequest)
	if doCompress {
		var err error
		writer, err = NewCompressingResponseWriter(httpWriter, encoding)
		if err != nil {
			log.Print("unable to install compressor: ", err)
			httpWriter.WriteHeader(http.StatusInternalServerError)
			return
		}
	}

	c.ServeMux.ServeHTTP(writer, httpRequest)
J
jeff 已提交
342 343 344 345
}

// Handle registers the handler for the given pattern. If a handler already exists for pattern, Handle panics.
func (c *Container) Handle(pattern string, handler http.Handler) {
H
hongming 已提交
346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376
	c.ServeMux.Handle(pattern, http.HandlerFunc(func(httpWriter http.ResponseWriter, httpRequest *http.Request) {
		// Skip, if httpWriter is already an CompressingResponseWriter
		if _, ok := httpWriter.(*CompressingResponseWriter); ok {
			handler.ServeHTTP(httpWriter, httpRequest)
			return
		}

		writer := httpWriter

		// CompressingResponseWriter should be closed after all operations are done
		defer func() {
			if compressWriter, ok := writer.(*CompressingResponseWriter); ok {
				compressWriter.Close()
			}
		}()

		if c.contentEncodingEnabled {
			doCompress, encoding := wantsCompressedResponse(httpRequest)
			if doCompress {
				var err error
				writer, err = NewCompressingResponseWriter(httpWriter, encoding)
				if err != nil {
					log.Print("unable to install compressor: ", err)
					httpWriter.WriteHeader(http.StatusInternalServerError)
					return
				}
			}
		}

		handler.ServeHTTP(writer, httpRequest)
	}))
J
jeff 已提交
377 378 379 380 381 382 383 384 385 386 387 388 389
}

// HandleWithFilter registers the handler for the given pattern, running the Container's
// filters in front of it. The filters are looked up on every request, so filters added
// later also apply. If a handler already exists for pattern, HandleWithFilter panics.
func (c *Container) HandleWithFilter(pattern string, handler http.Handler) {
	c.Handle(pattern, http.HandlerFunc(func(httpResponse http.ResponseWriter, httpRequest *http.Request) {
		if len(c.containerFilters) > 0 {
			target := func(req *Request, resp *Response) {
				handler.ServeHTTP(resp, req.Request)
			}
			chain := FilterChain{Filters: c.containerFilters, Target: target}
			chain.ProcessFilter(NewRequest(httpRequest), NewResponse(httpResponse))
			return
		}
		// no filters: call the handler directly
		handler.ServeHTTP(httpResponse, httpRequest)
	}))
}

// Filter adds fn to the container-wide filters. Container filters are applied before any
// WebService or Route filter when dispatching a request, and also in front of handlers
// registered with HandleWithFilter.
func (c *Container) Filter(fn FilterFunction) {
	c.containerFilters = append(c.containerFilters, fn)
}

// RegisteredWebServices returns a snapshot of the added WebServices. The returned slice is a
// copy, so callers may modify it without affecting the Container.
func (c *Container) RegisteredWebServices() []*WebService {
	c.webServicesLock.RLock()
	defer c.webServicesLock.RUnlock()
	snapshot := make([]*WebService, len(c.webServices))
	copy(snapshot, c.webServices)
	return snapshot
}

// computeAllowedMethods returns the HTTP methods of all routes, across the registered
// WebServices, whose path expressions match the request path.
// The result is never nil and may contain the same method more than once.
func (c *Container) computeAllowedMethods(req *Request) []string {
	methods := []string{}
	requestPath := req.Request.URL.Path
	for _, ws := range c.RegisteredWebServices() {
		wsMatch := ws.pathExpr.Matcher.FindStringSubmatch(requestPath)
		if wsMatch == nil {
			continue
		}
		// NOTE(review): the last submatch is presumably the path remainder below the
		// WebService root; confirm against how pathExpr is built.
		remainder := wsMatch[len(wsMatch)-1]
		for _, rt := range ws.Routes() {
			routeMatch := rt.pathExpr.Matcher.FindStringSubmatch(remainder)
			if routeMatch == nil {
				continue
			}
			// include the route only if its last submatch is empty or a single "/"
			if tail := routeMatch[len(routeMatch)-1]; tail == "" || tail == "/" {
				methods = append(methods, rt.Method)
			}
		}
	}
	// methods = append(methods, "OPTIONS")  not sure about this
	return methods
}

// newBasicRequestResponse wraps the net/http request and writer into a Request and Response.
// Only the Accept header of the request is copied onto the Response; no path parameters or
// produced content types are set, hence "basic".
func newBasicRequestResponse(httpWriter http.ResponseWriter, httpRequest *http.Request) (*Request, *Response) {
	response := NewResponse(httpWriter)
	response.requestAccept = httpRequest.Header.Get(HEADER_Accept)
	request := NewRequest(httpRequest)
	return request, response
}