提交 d27b5d0a 编写于 作者: S Sebastien Deleuze

Improve CORS handling

This commit improves CORS support by:
 - Using CORS processing only for CORS-enabled endpoints
 - Skipping CORS processing for same-origin requests
 - Adding Vary headers for non-CORS requests

It introduces an AbstractHandlerMapping#hasCorsConfigurationSource
method in order to be able to check CORS endpoints efficiently.

Closes gh-22273
Closes gh-22496
上级 87147101
/*
* Copyright 2002-2015 the original author or authors.
* Copyright 2002-2019 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
......@@ -20,6 +20,10 @@ import javax.servlet.http.HttpServletRequest;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.lang.Nullable;
import org.springframework.util.ObjectUtils;
import org.springframework.web.util.UriComponents;
import org.springframework.web.util.UriComponentsBuilder;
/**
* Utility class for CORS request handling based on the
......@@ -31,17 +35,43 @@ import org.springframework.http.HttpMethod;
public abstract class CorsUtils {
/**
* Returns {@code true} if the request is a valid CORS one.
* Returns {@code true} if the request is a valid CORS one by checking {@code Origin}
* header presence and ensuring that origins are different.
*/
public static boolean isCorsRequest(HttpServletRequest request) {
return (request.getHeader(HttpHeaders.ORIGIN) != null);
String origin = request.getHeader(HttpHeaders.ORIGIN);
if (origin == null) {
return false;
}
UriComponents originUrl = UriComponentsBuilder.fromOriginHeader(origin).build();
String scheme = request.getScheme();
String host = request.getServerName();
int port = request.getServerPort();
return !(ObjectUtils.nullSafeEquals(scheme, originUrl.getScheme()) &&
ObjectUtils.nullSafeEquals(host, originUrl.getHost()) &&
getPort(scheme, port) == getPort(originUrl.getScheme(), originUrl.getPort()));
}
private static int getPort(@Nullable String scheme, int port) {
if (port == -1) {
if ("http".equals(scheme) || "ws".equals(scheme)) {
port = 80;
}
else if ("https".equals(scheme) || "wss".equals(scheme)) {
port = 443;
}
}
return port;
}
/**
* Returns {@code true} if the request is a valid CORS pre-flight one.
* To be used in combination with {@link #isCorsRequest(HttpServletRequest)} since
* regular CORS checks are not invoked here for performance reasons.
*/
public static boolean isPreFlightRequest(HttpServletRequest request) {
return (isCorsRequest(request) && HttpMethod.OPTIONS.matches(request.getMethod()) &&
return (HttpMethod.OPTIONS.matches(request.getMethod()) &&
request.getHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD) != null);
}
......
/*
* Copyright 2002-2018 the original author or authors.
* Copyright 2002-2019 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
......@@ -19,7 +19,6 @@ package org.springframework.web.cors;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
......@@ -36,7 +35,6 @@ import org.springframework.http.server.ServletServerHttpRequest;
import org.springframework.http.server.ServletServerHttpResponse;
import org.springframework.lang.Nullable;
import org.springframework.util.CollectionUtils;
import org.springframework.web.util.WebUtils;
/**
* The default implementation of {@link CorsProcessor}, as defined by the
......@@ -45,8 +43,7 @@ import org.springframework.web.util.WebUtils;
* <p>Note that when input {@link CorsConfiguration} is {@code null}, this
* implementation does not reject simple or actual requests outright but simply
* avoid adding CORS headers to the response. CORS processing is also skipped
* if the response already contains CORS headers, or if the request is detected
* as a same-origin one.
* if the response already contains CORS headers.
*
* @author Sebastien Deleuze
* @author Rossen Stoyanchev
......@@ -62,26 +59,23 @@ public class DefaultCorsProcessor implements CorsProcessor {
public boolean processRequest(@Nullable CorsConfiguration config, HttpServletRequest request,
HttpServletResponse response) throws IOException {
response.addHeader(HttpHeaders.VARY, HttpHeaders.ORIGIN);
response.addHeader(HttpHeaders.VARY, HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD);
response.addHeader(HttpHeaders.VARY, HttpHeaders.ACCESS_CONTROL_REQUEST_HEADERS);
if (!CorsUtils.isCorsRequest(request)) {
return true;
}
ServletServerHttpResponse serverResponse = new ServletServerHttpResponse(response);
if (responseHasCors(serverResponse)) {
if (response.getHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN) != null) {
logger.trace("Skip: response already contains \"Access-Control-Allow-Origin\"");
return true;
}
ServletServerHttpRequest serverRequest = new ServletServerHttpRequest(request);
if (WebUtils.isSameOrigin(serverRequest)) {
logger.trace("Skip: request is from same origin");
return true;
}
boolean preFlightRequest = CorsUtils.isPreFlightRequest(request);
if (config == null) {
if (preFlightRequest) {
rejectRequest(serverResponse);
rejectRequest(new ServletServerHttpResponse(response));
return false;
}
else {
......@@ -89,17 +83,7 @@ public class DefaultCorsProcessor implements CorsProcessor {
}
}
return handleInternal(serverRequest, serverResponse, config, preFlightRequest);
}
private boolean responseHasCors(ServerHttpResponse response) {
try {
return (response.getHeaders().getAccessControlAllowOrigin() != null);
}
catch (NullPointerException npe) {
// SPR-11919 and https://issues.jboss.org/browse/WFLY-3474
return false;
}
return handleInternal(new ServletServerHttpRequest(request), new ServletServerHttpResponse(response), config, preFlightRequest);
}
/**
......@@ -110,6 +94,7 @@ public class DefaultCorsProcessor implements CorsProcessor {
protected void rejectRequest(ServerHttpResponse response) throws IOException {
response.setStatusCode(HttpStatus.FORBIDDEN);
response.getBody().write("Invalid CORS request".getBytes(StandardCharsets.UTF_8));
response.flush();
}
/**
......@@ -122,9 +107,6 @@ public class DefaultCorsProcessor implements CorsProcessor {
String allowOrigin = checkOrigin(config, requestOrigin);
HttpHeaders responseHeaders = response.getHeaders();
responseHeaders.addAll(HttpHeaders.VARY, Arrays.asList(HttpHeaders.ORIGIN,
HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, HttpHeaders.ACCESS_CONTROL_REQUEST_HEADERS));
if (allowOrigin == null) {
logger.debug("Reject: '" + requestOrigin + "' origin is not allowed");
rejectRequest(response);
......
/*
* Copyright 2002-2018 the original author or authors.
* Copyright 2002-2019 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
......@@ -36,18 +36,21 @@ import org.springframework.web.util.UriComponentsBuilder;
public abstract class CorsUtils {
/**
* Returns {@code true} if the request is a valid CORS one.
* Returns {@code true} if the request is a valid CORS one by checking {@code Origin}
* header presence and ensuring that origins are different via {@link #isSameOrigin}.
*/
@SuppressWarnings("deprecation")
public static boolean isCorsRequest(ServerHttpRequest request) {
return (request.getHeaders().get(HttpHeaders.ORIGIN) != null);
return request.getHeaders().containsKey(HttpHeaders.ORIGIN) && !isSameOrigin(request);
}
/**
* Returns {@code true} if the request is a valid CORS pre-flight one.
* To be used in combination with {@link #isCorsRequest(ServerHttpRequest)} since
* regular CORS checks are not invoked here for performance reasons.
*/
public static boolean isPreFlightRequest(ServerHttpRequest request) {
return (request.getMethod() == HttpMethod.OPTIONS && isCorsRequest(request) &&
request.getHeaders().get(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD) != null);
return (request.getMethod() == HttpMethod.OPTIONS && request.getHeaders().containsKey(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD));
}
/**
......@@ -61,7 +64,9 @@ public abstract class CorsUtils {
*
* @return {@code true} if the request is a same-origin one, {@code false} in case
* of a cross-origin request
* @deprecated as of 5.2, same-origin checks are performed directly by {@link #isCorsRequest}
*/
@Deprecated
public static boolean isSameOrigin(ServerHttpRequest request) {
String origin = request.getHeaders().getOrigin();
if (origin == null) {
......
/*
* Copyright 2002-2018 the original author or authors.
* Copyright 2002-2019 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
......@@ -75,14 +75,10 @@ public class CorsWebFilter implements WebFilter {
@Override
public Mono<Void> filter(ServerWebExchange exchange, WebFilterChain chain) {
ServerHttpRequest request = exchange.getRequest();
if (CorsUtils.isCorsRequest(request)) {
CorsConfiguration corsConfiguration = this.configSource.getCorsConfiguration(exchange);
if (corsConfiguration != null) {
boolean isValid = this.processor.process(corsConfiguration, exchange);
if (!isValid || CorsUtils.isPreFlightRequest(request)) {
return Mono.empty();
}
}
CorsConfiguration corsConfiguration = this.configSource.getCorsConfiguration(exchange);
boolean isValid = this.processor.process(corsConfiguration, exchange);
if (!isValid || CorsUtils.isPreFlightRequest(request)) {
return Mono.empty();
}
return chain.filter(exchange);
}
......
/*
* Copyright 2002-2018 the original author or authors.
* Copyright 2002-2019 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
......@@ -40,8 +40,7 @@ import org.springframework.web.server.ServerWebExchange;
* <p>Note that when input {@link CorsConfiguration} is {@code null}, this
* implementation does not reject simple or actual requests outright but simply
* avoid adding CORS headers to the response. CORS processing is also skipped
* if the response already contains CORS headers, or if the request is detected
* as a same-origin one.
* if the response already contains CORS headers.
*
* @author Sebastien Deleuze
* @author Rossen Stoyanchev
......@@ -51,27 +50,26 @@ public class DefaultCorsProcessor implements CorsProcessor {
private static final Log logger = LogFactory.getLog(DefaultCorsProcessor.class);
private static final List<String> VARY_HEADERS = Arrays.asList(
HttpHeaders.ORIGIN, HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, HttpHeaders.ACCESS_CONTROL_REQUEST_HEADERS);
@Override
public boolean process(@Nullable CorsConfiguration config, ServerWebExchange exchange) {
ServerHttpRequest request = exchange.getRequest();
ServerHttpResponse response = exchange.getResponse();
response.getHeaders().addAll(HttpHeaders.VARY, VARY_HEADERS);
if (!CorsUtils.isCorsRequest(request)) {
return true;
}
if (responseHasCors(response)) {
if (response.getHeaders().getFirst(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN) != null) {
logger.trace("Skip: response already contains \"Access-Control-Allow-Origin\"");
return true;
}
if (CorsUtils.isSameOrigin(request)) {
logger.trace("Skip: request is from same origin");
return true;
}
boolean preFlightRequest = CorsUtils.isPreFlightRequest(request);
if (config == null) {
if (preFlightRequest) {
......@@ -86,10 +84,6 @@ public class DefaultCorsProcessor implements CorsProcessor {
return handleInternal(exchange, config, preFlightRequest);
}
private boolean responseHasCors(ServerHttpResponse response) {
return response.getHeaders().getFirst(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN) != null;
}
/**
* Invoked when one of the CORS checks failed.
*/
......@@ -107,9 +101,6 @@ public class DefaultCorsProcessor implements CorsProcessor {
ServerHttpResponse response = exchange.getResponse();
HttpHeaders responseHeaders = response.getHeaders();
response.getHeaders().addAll(HttpHeaders.VARY, Arrays.asList(HttpHeaders.ORIGIN,
HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, HttpHeaders.ACCESS_CONTROL_REQUEST_HEADERS));
String requestOrigin = request.getHeaders().getOrigin();
String allowOrigin = checkOrigin(config, requestOrigin);
if (allowOrigin == null) {
......
/*
* Copyright 2002-2016 the original author or authors.
* Copyright 2002-2019 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
......@@ -83,16 +83,11 @@ public class CorsFilter extends OncePerRequestFilter {
protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response,
FilterChain filterChain) throws ServletException, IOException {
if (CorsUtils.isCorsRequest(request)) {
CorsConfiguration corsConfiguration = this.configSource.getCorsConfiguration(request);
if (corsConfiguration != null) {
boolean isValid = this.processor.processRequest(corsConfiguration, request, response);
if (!isValid || CorsUtils.isPreFlightRequest(request)) {
return;
}
}
CorsConfiguration corsConfiguration = this.configSource.getCorsConfiguration(request);
boolean isValid = this.processor.processRequest(corsConfiguration, request, response);
if (!isValid || CorsUtils.isPreFlightRequest(request)) {
return;
}
filterChain.doFilter(request, response);
}
......
/*
* Copyright 2002-2015 the original author or authors.
* Copyright 2002-2019 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
......@@ -62,11 +62,6 @@ public class CorsUtilsTests {
request.setMethod(HttpMethod.OPTIONS.name());
request.addHeader(HttpHeaders.ORIGIN, "https://domain.com");
assertFalse(CorsUtils.isPreFlightRequest(request));
request = new MockHttpServletRequest();
request.setMethod(HttpMethod.OPTIONS.name());
request.addHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "GET");
assertFalse(CorsUtils.isPreFlightRequest(request));
}
}
......@@ -51,13 +51,35 @@ public class DefaultCorsProcessorTests {
public void setup() {
this.request = new MockHttpServletRequest();
this.request.setRequestURI("/test.html");
this.request.setRemoteHost("domain1.com");
this.request.setServerName("domain1.com");
this.conf = new CorsConfiguration();
this.response = new MockHttpServletResponse();
this.response.setStatus(HttpServletResponse.SC_OK);
this.processor = new DefaultCorsProcessor();
}
@Test
public void requestWithoutOriginHeader() throws Exception {
this.request.setMethod(HttpMethod.GET.name());
this.processor.processRequest(this.conf, this.request, this.response);
assertFalse(this.response.containsHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN));
assertThat(this.response.getHeaders(HttpHeaders.VARY), contains(HttpHeaders.ORIGIN,
HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, HttpHeaders.ACCESS_CONTROL_REQUEST_HEADERS));
assertEquals(HttpServletResponse.SC_OK, this.response.getStatus());
}
@Test
public void sameOriginRequest() throws Exception {
this.request.setMethod(HttpMethod.GET.name());
this.request.addHeader(HttpHeaders.ORIGIN, "http://domain1.com");
this.processor.processRequest(this.conf, this.request, this.response);
assertFalse(this.response.containsHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN));
assertThat(this.response.getHeaders(HttpHeaders.VARY), contains(HttpHeaders.ORIGIN,
HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, HttpHeaders.ACCESS_CONTROL_REQUEST_HEADERS));
assertEquals(HttpServletResponse.SC_OK, this.response.getStatus());
}
@Test
public void actualRequestWithOriginHeader() throws Exception {
......
/*
* Copyright 2002-2018 the original author or authors.
* Copyright 2002-2019 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
......@@ -39,7 +39,7 @@ public class CorsUtilsTests {
@Test
public void isCorsRequest() {
ServerHttpRequest request = get("/").header(HttpHeaders.ORIGIN, "https://domain.com").build();
ServerHttpRequest request = get("http://domain.com/").header(HttpHeaders.ORIGIN, "https://domain.com").build();
assertTrue(CorsUtils.isCorsRequest(request));
}
......@@ -65,9 +65,6 @@ public class CorsUtilsTests {
request = options("/").header(HttpHeaders.ORIGIN, "https://domain.com").build();
assertFalse(CorsUtils.isPreFlightRequest(request));
request = options("/").header(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "GET").build();
assertFalse(CorsUtils.isPreFlightRequest(request));
}
@Test // SPR-16262
......
/*
* Copyright 2002-2017 the original author or authors.
* Copyright 2002-2019 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
......@@ -63,6 +63,46 @@ public class CorsWebFilterTests {
filter = new CorsWebFilter(r -> config);
}
@Test
public void nonCorsRequest() {
WebFilterChain filterChain = (filterExchange) -> {
try {
HttpHeaders headers = filterExchange.getResponse().getHeaders();
assertNull(headers.getFirst(ACCESS_CONTROL_ALLOW_ORIGIN));
assertNull(headers.getFirst(ACCESS_CONTROL_EXPOSE_HEADERS));
} catch (AssertionError ex) {
return Mono.error(ex);
}
return Mono.empty();
};
MockServerWebExchange exchange = MockServerWebExchange.from(
MockServerHttpRequest
.get("https://domain1.com/test.html")
.header(HOST, "domain1.com"));
this.filter.filter(exchange, filterChain).block();
}
@Test
public void sameOriginRequest() {
WebFilterChain filterChain = (filterExchange) -> {
try {
HttpHeaders headers = filterExchange.getResponse().getHeaders();
assertNull(headers.getFirst(ACCESS_CONTROL_ALLOW_ORIGIN));
assertNull(headers.getFirst(ACCESS_CONTROL_EXPOSE_HEADERS));
} catch (AssertionError ex) {
return Mono.error(ex);
}
return Mono.empty();
};
MockServerWebExchange exchange = MockServerWebExchange.from(
MockServerHttpRequest
.get("https://domain1.com/test.html")
.header(ORIGIN, "https://domain1.com"));
this.filter.filter(exchange, filterChain).block();
}
@Test
public void validActualRequest() {
WebFilterChain filterChain = (filterExchange) -> {
......@@ -82,7 +122,7 @@ public class CorsWebFilterTests {
.header(HOST, "domain1.com")
.header(ORIGIN, "https://domain2.com")
.header("header2", "foo"));
this.filter.filter(exchange, filterChain);
this.filter.filter(exchange, filterChain).block();
}
@Test
......@@ -96,8 +136,7 @@ public class CorsWebFilterTests {
WebFilterChain filterChain = (filterExchange) -> Mono.error(
new AssertionError("Invalid requests must not be forwarded to the filter chain"));
filter.filter(exchange, filterChain);
filter.filter(exchange, filterChain).block();
assertNull(exchange.getResponse().getHeaders().getFirst(ACCESS_CONTROL_ALLOW_ORIGIN));
}
......@@ -115,7 +154,7 @@ public class CorsWebFilterTests {
WebFilterChain filterChain = (filterExchange) -> Mono.error(
new AssertionError("Preflight requests must not be forwarded to the filter chain"));
filter.filter(exchange, filterChain);
filter.filter(exchange, filterChain).block();
HttpHeaders headers = exchange.getResponse().getHeaders();
assertEquals("https://domain2.com", headers.getFirst(ACCESS_CONTROL_ALLOW_ORIGIN));
......@@ -138,7 +177,7 @@ public class CorsWebFilterTests {
WebFilterChain filterChain = (filterExchange) -> Mono.error(
new AssertionError("Preflight requests must not be forwarded to the filter chain"));
filter.filter(exchange, filterChain);
filter.filter(exchange, filterChain).block();
assertNull(exchange.getResponse().getHeaders().getFirst(ACCESS_CONTROL_ALLOW_ORIGIN));
}
......
......@@ -60,6 +60,37 @@ public class DefaultCorsProcessorTests {
}
@Test
public void requestWithoutOriginHeader() throws Exception {
MockServerHttpRequest request = MockServerHttpRequest
.method(HttpMethod.GET, "http://domain1.com/test.html")
.build();
ServerWebExchange exchange = MockServerWebExchange.from(request);
this.processor.process(this.conf, exchange);
ServerHttpResponse response = exchange.getResponse();
assertFalse(response.getHeaders().containsKey(ACCESS_CONTROL_ALLOW_ORIGIN));
assertThat(response.getHeaders().get(VARY), contains(ORIGIN,
ACCESS_CONTROL_REQUEST_METHOD, ACCESS_CONTROL_REQUEST_HEADERS));
assertNull(response.getStatusCode());
}
@Test
public void sameOriginRequest() throws Exception {
MockServerHttpRequest request = MockServerHttpRequest
.method(HttpMethod.GET, "http://domain1.com/test.html")
.header(HttpHeaders.ORIGIN, "http://domain1.com")
.build();
ServerWebExchange exchange = MockServerWebExchange.from(request);
this.processor.process(this.conf, exchange);
ServerHttpResponse response = exchange.getResponse();
assertFalse(response.getHeaders().containsKey(ACCESS_CONTROL_ALLOW_ORIGIN));
assertThat(response.getHeaders().get(VARY), contains(ORIGIN,
ACCESS_CONTROL_REQUEST_METHOD, ACCESS_CONTROL_REQUEST_HEADERS));
assertNull(response.getStatusCode());
}
@Test
public void actualRequestWithOriginHeader() throws Exception {
ServerWebExchange exchange = actualRequest();
......
/*
* Copyright 2002-2015 the original author or authors.
* Copyright 2002-2019 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
......@@ -52,6 +52,36 @@ public class CorsFilterTests {
filter = new CorsFilter(r -> config);
}
@Test
public void nonCorsRequest() throws ServletException, IOException {
MockHttpServletRequest request = new MockHttpServletRequest(HttpMethod.GET.name(), "/test.html");
MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = (filterRequest, filterResponse) -> {
assertNull(response.getHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN));
assertNull(response.getHeader(HttpHeaders.ACCESS_CONTROL_EXPOSE_HEADERS));
};
filter.doFilter(request, response, filterChain);
}
@Test
public void sameOriginRequest() throws ServletException, IOException {
MockHttpServletRequest request = new MockHttpServletRequest(HttpMethod.GET.name(), "https://domain1.com/test.html");
request.addHeader(HttpHeaders.ORIGIN, "https://domain1.com");
request.setScheme("https");
request.setServerName("domain1.com");
request.setServerPort(443);
MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = (filterRequest, filterResponse) -> {
assertNull(response.getHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN));
assertNull(response.getHeader(HttpHeaders.ACCESS_CONTROL_EXPOSE_HEADERS));
};
filter.doFilter(request, response, filterChain);
}
@Test
public void validActualRequest() throws ServletException, IOException {
......
/*
* Copyright 2002-2018 the original author or authors.
* Copyright 2002-2019 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
......@@ -23,6 +23,7 @@ import reactor.core.publisher.Mono;
import org.springframework.beans.factory.BeanNameAware;
import org.springframework.context.support.ApplicationObjectSupport;
import org.springframework.core.Ordered;
import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
import org.springframework.web.cors.CorsConfiguration;
......@@ -53,6 +54,7 @@ public abstract class AbstractHandlerMapping extends ApplicationObjectSupport
private final PathPatternParser patternParser;
@Nullable
private CorsConfigurationSource corsConfigurationSource;
private CorsProcessor corsProcessor = new DefaultCorsProcessor();
......@@ -65,7 +67,6 @@ public abstract class AbstractHandlerMapping extends ApplicationObjectSupport
public AbstractHandlerMapping() {
this.patternParser = new PathPatternParser();
this.corsConfigurationSource = new UrlBasedCorsConfigurationSource(this.patternParser);
}
......@@ -113,8 +114,14 @@ public abstract class AbstractHandlerMapping extends ApplicationObjectSupport
*/
public void setCorsConfigurations(Map<String, CorsConfiguration> corsConfigurations) {
Assert.notNull(corsConfigurations, "corsConfigurations must not be null");
this.corsConfigurationSource = new UrlBasedCorsConfigurationSource(this.patternParser);
((UrlBasedCorsConfigurationSource) this.corsConfigurationSource).setCorsConfigurations(corsConfigurations);
if (!corsConfigurations.isEmpty()) {
UrlBasedCorsConfigurationSource source = new UrlBasedCorsConfigurationSource(this.patternParser);
source.setCorsConfigurations(corsConfigurations);
this.corsConfigurationSource = source;
}
else {
this.corsConfigurationSource = null;
}
}
/**
......@@ -175,12 +182,12 @@ public abstract class AbstractHandlerMapping extends ApplicationObjectSupport
if (logger.isDebugEnabled()) {
logger.debug(exchange.getLogPrefix() + "Mapped to " + handler);
}
if (CorsUtils.isCorsRequest(exchange.getRequest())) {
CorsConfiguration configA = this.corsConfigurationSource.getCorsConfiguration(exchange);
CorsConfiguration configB = getCorsConfiguration(handler, exchange);
CorsConfiguration config = (configA != null ? configA.combine(configB) : configB);
if (!getCorsProcessor().process(config, exchange) ||
CorsUtils.isPreFlightRequest(exchange.getRequest())) {
if (hasCorsConfigurationSource(handler)) {
ServerHttpRequest request = exchange.getRequest();
CorsConfiguration config = (this.corsConfigurationSource != null ? this.corsConfigurationSource.getCorsConfiguration(exchange) : null);
CorsConfiguration handlerConfig = getCorsConfiguration(handler, exchange);
config = (config != null ? config.combine(handlerConfig) : handlerConfig);
if (!this.corsProcessor.process(config, exchange) || CorsUtils.isPreFlightRequest(request)) {
return REQUEST_HANDLED_HANDLER;
}
}
......@@ -200,6 +207,14 @@ public abstract class AbstractHandlerMapping extends ApplicationObjectSupport
*/
protected abstract Mono<?> getHandlerInternal(ServerWebExchange exchange);
/**
* Return {@code true} if there is a {@link CorsConfigurationSource} for this handler.
* @since 5.2
*/
protected boolean hasCorsConfigurationSource(Object handler) {
return handler instanceof CorsConfigurationSource || this.corsConfigurationSource != null;
}
/**
* Retrieve the CORS configuration for the given handler.
* @param handler the handler to check (never {@code null})
......
......@@ -370,6 +370,13 @@ public abstract class AbstractHandlerMethodMapping<T> extends AbstractHandlerMap
return null;
}
@Override
protected boolean hasCorsConfigurationSource(Object handler) {
return super.hasCorsConfigurationSource(handler) ||
(handler instanceof HandlerMethod && this.mappingRegistry.getCorsConfiguration((HandlerMethod) handler) != null) ||
handler.equals(PREFLIGHT_AMBIGUOUS_MATCH);
}
@Override
protected CorsConfiguration getCorsConfiguration(Object handler, ServerWebExchange exchange) {
CorsConfiguration corsConfig = super.getCorsConfiguration(handler, exchange);
......@@ -451,6 +458,7 @@ public abstract class AbstractHandlerMethodMapping<T> extends AbstractHandlerMap
/**
* Return CORS configuration. Thread-safe for concurrent use.
*/
@Nullable
public CorsConfiguration getCorsConfiguration(HandlerMethod handlerMethod) {
HandlerMethod original = handlerMethod.getResolvedFromHandlerMethod();
return this.corsLookup.get(original != null ? original : handlerMethod);
......
/*
* Copyright 2002-2018 the original author or authors.
* Copyright 2002-2019 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
......@@ -31,7 +31,6 @@ import org.springframework.web.server.ServerWebExchange;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNotSame;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertSame;
/**
......@@ -74,8 +73,7 @@ public class CorsUrlHandlerMappingTests {
Object actual = this.handlerMapping.getHandler(exchange).block();
assertNotNull(actual);
assertNotSame(this.welcomeController, actual);
assertNull(exchange.getResponse().getHeaders().getFirst(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN));
assertSame(this.welcomeController, actual);
}
@Test
......
/*
* Copyright 2002-2018 the original author or authors.
* Copyright 2002-2019 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
......@@ -81,7 +81,8 @@ public abstract class AbstractHandlerMapping extends WebApplicationObjectSupport
private final List<HandlerInterceptor> adaptedInterceptors = new ArrayList<>();
private CorsConfigurationSource corsConfigurationSource = new UrlBasedCorsConfigurationSource();
@Nullable
private CorsConfigurationSource corsConfigurationSource;
private CorsProcessor corsProcessor = new DefaultCorsProcessor();
......@@ -206,11 +207,16 @@ public abstract class AbstractHandlerMapping extends WebApplicationObjectSupport
*/
public void setCorsConfigurations(Map<String, CorsConfiguration> corsConfigurations) {
Assert.notNull(corsConfigurations, "corsConfigurations must not be null");
UrlBasedCorsConfigurationSource source = new UrlBasedCorsConfigurationSource();
source.setCorsConfigurations(corsConfigurations);
source.setPathMatcher(this.pathMatcher);
source.setUrlPathHelper(this.urlPathHelper);
this.corsConfigurationSource = source;
if (!corsConfigurations.isEmpty()) {
UrlBasedCorsConfigurationSource source = new UrlBasedCorsConfigurationSource();
source.setCorsConfigurations(corsConfigurations);
source.setPathMatcher(this.pathMatcher);
source.setUrlPathHelper(this.urlPathHelper);
this.corsConfigurationSource = source;
}
else {
this.corsConfigurationSource = null;
}
}
/**
......@@ -420,10 +426,10 @@ public abstract class AbstractHandlerMapping extends WebApplicationObjectSupport
logger.debug("Mapped to " + executionChain.getHandler());
}
if (CorsUtils.isCorsRequest(request)) {
CorsConfiguration globalConfig = this.corsConfigurationSource.getCorsConfiguration(request);
if (hasCorsConfigurationSource(handler)) {
CorsConfiguration config = (this.corsConfigurationSource != null ? this.corsConfigurationSource.getCorsConfiguration(request) : null);
CorsConfiguration handlerConfig = getCorsConfiguration(handler, request);
CorsConfiguration config = (globalConfig != null ? globalConfig.combine(handlerConfig) : handlerConfig);
config = (config != null ? config.combine(handlerConfig) : handlerConfig);
executionChain = getCorsHandlerExecutionChain(request, executionChain, config);
}
......@@ -488,6 +494,14 @@ public abstract class AbstractHandlerMapping extends WebApplicationObjectSupport
return chain;
}
/**
* Return {@code true} if there is a {@link CorsConfigurationSource} for this handler.
* @since 5.2
*/
protected boolean hasCorsConfigurationSource(Object handler) {
return handler instanceof CorsConfigurationSource || this.corsConfigurationSource != null;
}
/**
* Retrieve the CORS configuration for the given handler.
* @param handler the handler to check (never {@code null}).
......
......@@ -448,6 +448,13 @@ public abstract class AbstractHandlerMethodMapping<T> extends AbstractHandlerMap
return null;
}
@Override
protected boolean hasCorsConfigurationSource(Object handler) {
return super.hasCorsConfigurationSource(handler) ||
(handler instanceof HandlerMethod && this.mappingRegistry.getCorsConfiguration((HandlerMethod) handler) != null) ||
handler.equals(PREFLIGHT_AMBIGUOUS_MATCH);
}
@Override
protected CorsConfiguration getCorsConfiguration(Object handler, HttpServletRequest request) {
CorsConfiguration corsConfig = super.getCorsConfiguration(handler, request);
......@@ -555,6 +562,7 @@ public abstract class AbstractHandlerMethodMapping<T> extends AbstractHandlerMap
/**
* Return CORS configuration. Thread-safe for concurrent use.
*/
@Nullable
public CorsConfiguration getCorsConfiguration(HandlerMethod handlerMethod) {
HandlerMethod original = handlerMethod.getResolvedFromHandlerMethod();
return this.corsLookup.get(original != null ? original : handlerMethod);
......
/*
* Copyright 2002-2018 the original author or authors.
* Copyright 2002-2019 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
......@@ -138,7 +138,7 @@ public class WebMvcConfigurationSupportExtensionTests {
HandlerExecutionChain chain = rmHandlerMapping.getHandler(new MockHttpServletRequest("GET", "/"));
assertNotNull(chain);
assertNotNull(chain.getInterceptors());
assertEquals(3, chain.getInterceptors().length);
assertEquals(4, chain.getInterceptors().length);
assertEquals(LocaleChangeInterceptor.class, chain.getInterceptors()[0].getClass());
assertEquals(ConversionServiceExposingInterceptor.class, chain.getInterceptors()[1].getClass());
assertEquals(ResourceUrlProviderExposingInterceptor.class, chain.getInterceptors()[2].getClass());
......@@ -177,7 +177,7 @@ public class WebMvcConfigurationSupportExtensionTests {
chain = handlerMapping.getHandler(new MockHttpServletRequest("GET", "/resources/foo.gif"));
assertNotNull(chain);
assertNotNull(chain.getHandler());
assertEquals(Arrays.toString(chain.getInterceptors()), 4, chain.getInterceptors().length);
assertEquals(Arrays.toString(chain.getInterceptors()), 5, chain.getInterceptors().length);
// PathExposingHandlerInterceptor at chain.getInterceptors()[0]
assertEquals(LocaleChangeInterceptor.class, chain.getInterceptors()[1].getClass());
assertEquals(ConversionServiceExposingInterceptor.class, chain.getInterceptors()[2].getClass());
......
/*
* Copyright 2002-2018 the original author or authors.
* Copyright 2002-2019 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
......@@ -81,8 +81,7 @@ public class CorsAbstractHandlerMappingTests {
this.request.addHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "GET");
HandlerExecutionChain chain = handlerMapping.getHandler(this.request);
assertNotNull(chain);
assertNotNull(chain.getHandler());
assertTrue(chain.getHandler().getClass().getSimpleName().equals("PreFlightHandler"));
assertTrue(chain.getHandler() instanceof SimpleHandler);
}
@Test
......
......@@ -33,6 +33,7 @@ import javax.servlet.http.HttpServletRequest;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.HttpStatus;
import org.springframework.http.InvalidMediaTypeException;
......@@ -48,7 +49,6 @@ import org.springframework.util.ObjectUtils;
import org.springframework.util.StringUtils;
import org.springframework.web.cors.CorsConfiguration;
import org.springframework.web.cors.CorsConfigurationSource;
import org.springframework.web.cors.CorsUtils;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.sockjs.SockJsException;
import org.springframework.web.socket.sockjs.SockJsService;
......@@ -495,7 +495,7 @@ public abstract class AbstractSockJsService implements SockJsService, CorsConfig
@Override
@Nullable
public CorsConfiguration getCorsConfiguration(HttpServletRequest request) {
if (!this.suppressCors && CorsUtils.isCorsRequest(request)) {
if (!this.suppressCors && (request.getHeader(HttpHeaders.ORIGIN) != null)) {
CorsConfiguration config = new CorsConfiguration();
config.setAllowedOrigins(new ArrayList<>(this.allowedOrigins));
config.addAllowedMethod("*");
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册