From 84138abfd1b436dd7046933e6cecd76fe37cd1ce Mon Sep 17 00:00:00 2001 From: Sebastien Deleuze Date: Mon, 13 Jul 2015 10:59:19 +0200 Subject: [PATCH] Avoid rejecting same-origin requests detected as CORS requests Browsers like Chrome or Safari include an Origin header for same-origin POST/PUT/DELETE requests, not only for cross-origin requests. Before this commit, these same-origin requests would have been detected as potential cross-origin requests, and rejected if the same-origin domain is not part of the configured allowedOrigins. This commit avoid to reject same-origin requests by reusing the logic introduced in Spring 4.1 for detecting reliably Websocket/SockJS same-origin requests with the WebUtils.isValidOrigin() method. This logic has been extracted in a new WebUtils.isSameOrigin() method. Issue: SPR-13206 --- .../web/cors/DefaultCorsProcessor.java | 14 ++++-- .../springframework/web/util/WebUtils.java | 21 ++++++-- .../web/util/WebUtilsTests.java | 50 +++++++++++-------- 3 files changed, 57 insertions(+), 28 deletions(-) diff --git a/spring-web/src/main/java/org/springframework/web/cors/DefaultCorsProcessor.java b/spring-web/src/main/java/org/springframework/web/cors/DefaultCorsProcessor.java index 3abd2facdb..ceec1dbc20 100644 --- a/spring-web/src/main/java/org/springframework/web/cors/DefaultCorsProcessor.java +++ b/spring-web/src/main/java/org/springframework/web/cors/DefaultCorsProcessor.java @@ -35,6 +35,7 @@ import org.springframework.http.server.ServerHttpResponse; import org.springframework.http.server.ServletServerHttpRequest; import org.springframework.http.server.ServletServerHttpResponse; import org.springframework.util.CollectionUtils; +import org.springframework.web.util.WebUtils; /** * Default implementation of {@link CorsProcessor}, as defined by the @@ -42,7 +43,9 @@ import org.springframework.util.CollectionUtils; * *

Note that when input {@link CorsConfiguration} is {@code null}, this * implementation does not reject simple or actual requests outright but simply - * avoid adding CORS headers to the response. + * avoid adding CORS headers to the response. CORS processing is also skipped + * if the response already contains CORS headers, or if the request is detected + * as a same-origin one. * * @author Sebastien Deleuze * @author Rossen Stoyanhcev @@ -66,12 +69,16 @@ public class DefaultCorsProcessor implements CorsProcessor { ServletServerHttpResponse serverResponse = new ServletServerHttpResponse(response); ServletServerHttpRequest serverRequest = new ServletServerHttpRequest(request); + if (WebUtils.isSameOrigin(serverRequest)) { + logger.debug("Skip CORS processing, request is a same-origin one"); + return true; + } if (responseHasCors(serverResponse)) { + logger.debug("Skip CORS processing, response already contains \"Access-Control-Allow-Origin\" header"); return true; } boolean preFlightRequest = CorsUtils.isPreFlightRequest(request); - if (config == null) { if (preFlightRequest) { rejectRequest(serverResponse); @@ -93,9 +100,6 @@ public class DefaultCorsProcessor implements CorsProcessor { catch (NullPointerException npe) { // SPR-11919 and https://issues.jboss.org/browse/WFLY-3474 } - if (hasAllowOrigin) { - logger.debug("Skip adding CORS headers, response already contains \"Access-Control-Allow-Origin\""); - } return hasAllowOrigin; } diff --git a/spring-web/src/main/java/org/springframework/web/util/WebUtils.java b/spring-web/src/main/java/org/springframework/web/util/WebUtils.java index f2b4863dce..9a263bbed1 100644 --- a/spring-web/src/main/java/org/springframework/web/util/WebUtils.java +++ b/spring-web/src/main/java/org/springframework/web/util/WebUtils.java @@ -789,15 +789,30 @@ public abstract class WebUtils { return true; } else if (CollectionUtils.isEmpty(allowedOrigins)) { - UriComponents actualUrl = UriComponentsBuilder.fromHttpRequest(request).build(); - UriComponents originUrl = UriComponentsBuilder.fromOriginHeader(origin).build(); - return (actualUrl.getHost().equals(originUrl.getHost()) && getPort(actualUrl) == getPort(originUrl)); + return isSameOrigin(request); } else { return allowedOrigins.contains(origin); } } + /** + * Check if the request is a same-origin one, based on {@code Origin}, {@code Host}, + * {@code Forwarded} and {@code X-Forwarded-Host} headers. + * @return {@code true} if the request is a same-origin one, {@code false} in case + * of cross-origin request. + * @since 4.2 + */ + public static boolean isSameOrigin(HttpRequest request) { + String origin = request.getHeaders().getOrigin(); + if (origin == null) { + return true; + } + UriComponents actualUrl = UriComponentsBuilder.fromHttpRequest(request).build(); + UriComponents originUrl = UriComponentsBuilder.fromOriginHeader(origin).build(); + return (actualUrl.getHost().equals(originUrl.getHost()) && getPort(actualUrl) == getPort(originUrl)); + } + private static int getPort(UriComponents component) { int port = component.getPort(); if (port == -1) { diff --git a/spring-web/src/test/java/org/springframework/web/util/WebUtilsTests.java b/spring-web/src/test/java/org/springframework/web/util/WebUtilsTests.java index 29169d2a64..b82da3709b 100644 --- a/spring-web/src/test/java/org/springframework/web/util/WebUtilsTests.java +++ b/spring-web/src/test/java/org/springframework/web/util/WebUtilsTests.java @@ -106,37 +106,47 @@ public class WebUtilsTests { } @Test - public void isValidOriginSuccess() { - + public void isValidOrigin() { List allowed = Collections.emptyList(); - assertTrue(checkOrigin("mydomain1.com", -1, "http://mydomain1.com", allowed)); - assertTrue(checkOrigin("mydomain1.com", -1, "http://mydomain1.com:80", allowed)); - assertTrue(checkOrigin("mydomain1.com", 443, "https://mydomain1.com", allowed)); - assertTrue(checkOrigin("mydomain1.com", 443, "https://mydomain1.com:443", allowed)); - assertTrue(checkOrigin("mydomain1.com", 123, "http://mydomain1.com:123", allowed)); - assertTrue(checkOrigin("mydomain1.com", -1, "ws://mydomain1.com", allowed)); - assertTrue(checkOrigin("mydomain1.com", 443, "wss://mydomain1.com", allowed)); + assertTrue(checkValidOrigin("mydomain1.com", -1, "http://mydomain1.com", allowed)); + assertFalse(checkValidOrigin("mydomain1.com", -1, "http://mydomain2.com", allowed)); allowed = Collections.singletonList("*"); - assertTrue(checkOrigin("mydomain1.com", -1, "http://mydomain2.com", allowed)); + assertTrue(checkValidOrigin("mydomain1.com", -1, "http://mydomain2.com", allowed)); allowed = Collections.singletonList("http://mydomain1.com"); - assertTrue(checkOrigin("mydomain2.com", -1, "http://mydomain1.com", allowed)); + assertTrue(checkValidOrigin("mydomain2.com", -1, "http://mydomain1.com", allowed)); + assertFalse(checkValidOrigin("mydomain2.com", -1, "http://mydomain3.com", allowed)); } @Test - public void isValidOriginFailure() { + public void isSameOrigin() { + assertTrue(checkSameOrigin("mydomain1.com", -1, "http://mydomain1.com")); + assertTrue(checkSameOrigin("mydomain1.com", -1, "http://mydomain1.com:80")); + assertTrue(checkSameOrigin("mydomain1.com", 443, "https://mydomain1.com")); + assertTrue(checkSameOrigin("mydomain1.com", 443, "https://mydomain1.com:443")); + assertTrue(checkSameOrigin("mydomain1.com", 123, "http://mydomain1.com:123")); + assertTrue(checkSameOrigin("mydomain1.com", -1, "ws://mydomain1.com")); + assertTrue(checkSameOrigin("mydomain1.com", 443, "wss://mydomain1.com")); + + assertFalse(checkSameOrigin("mydomain1.com", -1, "http://mydomain2.com")); + assertFalse(checkSameOrigin("mydomain1.com", -1, "https://mydomain1.com")); + assertFalse(checkSameOrigin("mydomain1.com", -1, "invalid-origin")); + } - List allowed = Collections.emptyList(); - assertFalse(checkOrigin("mydomain1.com", -1, "http://mydomain2.com", allowed)); - assertFalse(checkOrigin("mydomain1.com", -1, "https://mydomain1.com", allowed)); - assertFalse(checkOrigin("mydomain1.com", -1, "invalid-origin", allowed)); - allowed = Collections.singletonList("http://mydomain1.com"); - assertFalse(checkOrigin("mydomain2.com", -1, "http://mydomain3.com", allowed)); + private boolean checkValidOrigin(String serverName, int port, String originHeader, List allowed) { + MockHttpServletRequest servletRequest = new MockHttpServletRequest(); + ServerHttpRequest request = new ServletServerHttpRequest(servletRequest); + servletRequest.setServerName(serverName); + if (port != -1) { + servletRequest.setServerPort(port); + } + request.getHeaders().set(HttpHeaders.ORIGIN, originHeader); + return WebUtils.isValidOrigin(request, allowed); } - private boolean checkOrigin(String serverName, int port, String originHeader, List allowed) { + private boolean checkSameOrigin(String serverName, int port, String originHeader) { MockHttpServletRequest servletRequest = new MockHttpServletRequest(); ServerHttpRequest request = new ServletServerHttpRequest(servletRequest); servletRequest.setServerName(serverName); @@ -144,7 +154,7 @@ public class WebUtilsTests { servletRequest.setServerPort(port); } request.getHeaders().set(HttpHeaders.ORIGIN, originHeader); - return WebUtils.isValidOrigin(request, allowed); + return WebUtils.isSameOrigin(request); } } -- GitLab