提交 68ecb92d 编写于 作者: R Rossen Stoyanchev

Allow "ws" and "wss" for isValidCorsOrigin checks

Issue: SPR-12956
上级 222f6998
......@@ -317,6 +317,29 @@ public class UriComponentsBuilder implements Cloneable {
}
/**
* Create an instance by parsing the "origin" header of an HTTP request.
*/
public static UriComponentsBuilder fromOriginHeader(String origin) {
UriComponentsBuilder builder = UriComponentsBuilder.newInstance();
if (StringUtils.hasText(origin)) {
int schemaIdx = origin.indexOf("://");
String schema = (schemaIdx != -1 ? origin.substring(0, schemaIdx) : "http");
builder.scheme(schema);
String hostString = (schemaIdx != -1 ? origin.substring(schemaIdx + 3) : origin);
if (hostString.contains(":")) {
String[] hostAndPort = StringUtils.split(hostString, ":");
builder.host(hostAndPort[0]);
builder.port(Integer.parseInt(hostAndPort[1]));
}
else {
builder.host(hostString);
}
}
return builder;
}
// build methods
/**
......
......@@ -23,6 +23,7 @@ import java.util.Enumeration;
import java.util.Map;
import java.util.StringTokenizer;
import java.util.TreeMap;
import javax.servlet.ServletContext;
import javax.servlet.ServletRequest;
import javax.servlet.ServletRequestWrapper;
......@@ -38,6 +39,7 @@ import org.apache.commons.logging.LogFactory;
import org.springframework.http.HttpRequest;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import org.springframework.util.StringUtils;
......@@ -790,21 +792,10 @@ public abstract class WebUtils {
if (origin == null || allowedOrigins.contains("*")) {
return true;
}
else if (allowedOrigins.isEmpty()) {
UriComponents originComponents;
try {
originComponents = UriComponentsBuilder.fromHttpUrl(origin).build();
}
catch (IllegalArgumentException ex) {
if (logger.isWarnEnabled()) {
logger.warn("Failed to parse Origin header value [" + origin + "]");
}
return false;
}
UriComponents requestComponents = UriComponentsBuilder.fromHttpRequest(request).build();
int originPort = getPort(originComponents);
int requestPort = getPort(requestComponents);
return (originComponents.getHost().equals(requestComponents.getHost()) && originPort == requestPort);
else if (CollectionUtils.isEmpty(allowedOrigins)) {
UriComponents actualUrl = UriComponentsBuilder.fromHttpRequest(request).build();
UriComponents originUrl = UriComponentsBuilder.fromOriginHeader(origin).build();
return (actualUrl.getHost().equals(originUrl.getHost()) && getPort(actualUrl) == getPort(originUrl));
}
else {
return allowedOrigins.contains(origin);
......@@ -814,10 +805,10 @@ public abstract class WebUtils {
private static int getPort(UriComponents component) {
int port = component.getPort();
if (port == -1) {
if ("http".equals(component.getScheme())) {
if ("http".equals(component.getScheme()) || "ws".equals(component.getScheme())) {
port = 80;
}
else if ("https".equals(component.getScheme())) {
else if ("https".equals(component.getScheme()) || "wss".equals(component.getScheme())) {
port = 443;
}
}
......
......@@ -16,8 +16,8 @@
package org.springframework.web.util;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
......@@ -106,60 +106,45 @@ public class WebUtilsTests {
}
@Test
public void isValidOrigin() {
List<String> allowedOrigins = new ArrayList<>();
public void isValidOriginSuccess() {
List<String> allowed = Collections.emptyList();
assertTrue(checkOrigin("mydomain1.com", -1, "http://mydomain1.com", allowed));
assertTrue(checkOrigin("mydomain1.com", -1, "http://mydomain1.com:80", allowed));
assertTrue(checkOrigin("mydomain1.com", 443, "https://mydomain1.com", allowed));
assertTrue(checkOrigin("mydomain1.com", 443, "https://mydomain1.com:443", allowed));
assertTrue(checkOrigin("mydomain1.com", 123, "http://mydomain1.com:123", allowed));
assertTrue(checkOrigin("mydomain1.com", -1, "ws://mydomain1.com", allowed));
assertTrue(checkOrigin("mydomain1.com", 443, "wss://mydomain1.com", allowed));
allowed = Collections.singletonList("*");
assertTrue(checkOrigin("mydomain1.com", -1, "http://mydomain2.com", allowed));
allowed = Collections.singletonList("http://mydomain1.com");
assertTrue(checkOrigin("mydomain2.com", -1, "http://mydomain1.com", allowed));
}
@Test
public void isValidOriginFailure() {
List<String> allowed = Collections.emptyList();
assertFalse(checkOrigin("mydomain1.com", -1, "http://mydomain2.com", allowed));
assertFalse(checkOrigin("mydomain1.com", -1, "https://mydomain1.com", allowed));
assertFalse(checkOrigin("mydomain1.com", -1, "invalid-origin", allowed));
allowed = Collections.singletonList("http://mydomain1.com");
assertFalse(checkOrigin("mydomain2.com", -1, "http://mydomain3.com", allowed));
}
private boolean checkOrigin(String serverName, int port, String originHeader, List<String> allowed) {
MockHttpServletRequest servletRequest = new MockHttpServletRequest();
ServerHttpRequest request = new ServletServerHttpRequest(servletRequest);
servletRequest.setServerName("mydomain1.com");
request.getHeaders().set(HttpHeaders.ORIGIN, "http://mydomain1.com");
assertTrue(WebUtils.isValidOrigin(request, allowedOrigins));
servletRequest.setServerName("mydomain1.com");
request.getHeaders().set(HttpHeaders.ORIGIN, "http://mydomain1.com:80");
assertTrue(WebUtils.isValidOrigin(request, allowedOrigins));
servletRequest.setServerName("mydomain1.com");
servletRequest.setServerPort(443);
request.getHeaders().set(HttpHeaders.ORIGIN, "https://mydomain1.com");
assertTrue(WebUtils.isValidOrigin(request, allowedOrigins));
servletRequest.setServerName("mydomain1.com");
servletRequest.setServerPort(443);
request.getHeaders().set(HttpHeaders.ORIGIN, "https://mydomain1.com:443");
assertTrue(WebUtils.isValidOrigin(request, allowedOrigins));
servletRequest.setServerName("mydomain1.com");
servletRequest.setServerPort(123);
request.getHeaders().set(HttpHeaders.ORIGIN, "http://mydomain1.com:123");
assertTrue(WebUtils.isValidOrigin(request, allowedOrigins));
servletRequest.setServerName("mydomain1.com");
request.getHeaders().set(HttpHeaders.ORIGIN, "http://mydomain2.com");
assertFalse(WebUtils.isValidOrigin(request, allowedOrigins));
servletRequest.setServerName("mydomain1.com");
request.getHeaders().set(HttpHeaders.ORIGIN, "https://mydomain1.com");
assertFalse(WebUtils.isValidOrigin(request, allowedOrigins));
servletRequest.setServerName("invalid-origin");
request.getHeaders().set(HttpHeaders.ORIGIN, "invalid-origin");
assertFalse(WebUtils.isValidOrigin(request, allowedOrigins));
allowedOrigins = Arrays.asList("*");
servletRequest.setServerName("mydomain1.com");
request.getHeaders().set(HttpHeaders.ORIGIN, "http://mydomain2.com");
assertTrue(WebUtils.isValidOrigin(request, allowedOrigins));
allowedOrigins = Arrays.asList("http://mydomain1.com");
servletRequest.setServerName("mydomain2.com");
request.getHeaders().set(HttpHeaders.ORIGIN, "http://mydomain1.com");
assertTrue(WebUtils.isValidOrigin(request, allowedOrigins));
allowedOrigins = Arrays.asList("http://mydomain1.com");
servletRequest.setServerName("mydomain2.com");
request.getHeaders().set(HttpHeaders.ORIGIN, "http://mydomain3.com");
assertFalse(WebUtils.isValidOrigin(request, allowedOrigins));
servletRequest.setServerName(serverName);
if (port != -1) {
servletRequest.setServerPort(port);
}
request.getHeaders().set(HttpHeaders.ORIGIN, originHeader);
return WebUtils.isValidOrigin(request, allowed);
}
}
......@@ -65,22 +65,18 @@ public class OriginHandshakeInterceptor implements HandshakeInterceptor {
}
/**
* Configure allowed {@code Origin} header values. This check is mostly designed for
* browser clients. There is nothing preventing other types of client to modify the
* {@code Origin} header value.
* Configure allowed {@code Origin} header values. This check is mostly
* designed for browsers. There is nothing preventing other types of client
* to modify the {@code Origin} header value.
*
* <p>Each provided allowed origin must start by "http://", "https://" or be "*"
* (means that all origins are allowed).
* <p>Each provided allowed origin must have a scheme, and optionally a port
* (e.g. "http://example.org", "http://example.org:9090"). An allowed origin
* string may also be "*" in which case all origins are allowed.
*
* @see <a href="https://tools.ietf.org/html/rfc6454">RFC 6454: The Web Origin Concept</a>
*/
public void setAllowedOrigins(Collection<String> allowedOrigins) {
Assert.notNull(allowedOrigins, "Allowed origin Collection must not be null");
for (String allowedOrigin : allowedOrigins) {
Assert.isTrue(allowedOrigin.equals("*") || allowedOrigin.startsWith("http://") ||
allowedOrigin.startsWith("https://"), "Invalid allowed origin provided: \"" +
allowedOrigin + "\". It must start with \"http://\", \"https://\" or be \"*\"");
}
this.allowedOrigins.clear();
this.allowedOrigins.addAll(allowedOrigins);
}
......@@ -93,6 +89,7 @@ public class OriginHandshakeInterceptor implements HandshakeInterceptor {
return Collections.unmodifiableList(this.allowedOrigins);
}
@Override
public boolean beforeHandshake(ServerHttpRequest request, ServerHttpResponse response,
WebSocketHandler wsHandler, Map<String, Object> attributes) throws Exception {
......
......@@ -276,16 +276,18 @@ public abstract class AbstractSockJsService implements SockJsService, CorsConfig
}
/**
* Configure allowed {@code Origin} header values. This check is mostly designed for
* browser clients. There is nothing preventing other types of client to modify the
* {@code Origin} header value.
* Configure allowed {@code Origin} header values. This check is mostly
* designed for browsers. There is nothing preventing other types of client
* to modify the {@code Origin} header value.
*
* <p>When SockJS is enabled and origins are restricted, transport types that do not
* allow to check request origin (JSONP and Iframe based transports) are disabled.
* As a consequence, IE 6 to 9 are not supported when origins are restricted.
* <p>When SockJS is enabled and origins are restricted, transport types
* that do not allow to check request origin (JSONP and Iframe based
* transports) are disabled. As a consequence, IE 6 to 9 are not supported
* when origins are restricted.
*
* <p>Each provided allowed origin must start by "http://", "https://" or be "*"
* (means that all origins are allowed).
* <p>Each provided allowed origin must have a scheme, and optionally a port
* (e.g. "http://example.org", "http://example.org:9090"). An allowed origin
* string may also be "*" in which case all origins are allowed.
*
* @since 4.1.2
* @see <a href="https://tools.ietf.org/html/rfc6454">RFC 6454: The Web Origin Concept</a>
......@@ -293,14 +295,6 @@ public abstract class AbstractSockJsService implements SockJsService, CorsConfig
*/
public void setAllowedOrigins(List<String> allowedOrigins) {
Assert.notNull(allowedOrigins, "Allowed origin List must not be null");
for (String allowedOrigin : allowedOrigins) {
Assert.isTrue(
allowedOrigin.equals("*") || allowedOrigin.startsWith("http://") ||
allowedOrigin.startsWith("https://"),
"Invalid allowed origin provided: \"" +
allowedOrigin +
"\". It must start with \"http://\", \"https://\" or be \"*\"");
}
this.allowedOrigins.clear();
this.allowedOrigins.addAll(allowedOrigins);
}
......@@ -451,7 +445,9 @@ public abstract class AbstractSockJsService implements SockJsService, CorsConfig
protected abstract void handleTransportRequest(ServerHttpRequest request, ServerHttpResponse response,
WebSocketHandler webSocketHandler, String sessionId, String transport) throws SockJsException;
protected boolean checkOrigin(ServerHttpRequest request, ServerHttpResponse response, HttpMethod... httpMethods) throws IOException {
protected boolean checkOrigin(ServerHttpRequest request, ServerHttpResponse response,
HttpMethod... httpMethods) throws IOException {
String origin = request.getHeaders().getOrigin();
if (origin == null) {
......@@ -514,7 +510,8 @@ public abstract class AbstractSockJsService implements SockJsService, CorsConfig
addNoCacheHeaders(response);
if (checkOrigin(request, response)) {
response.getHeaders().setContentType(new MediaType("application", "json", UTF8_CHARSET));
String content = String.format(INFO_CONTENT, random.nextInt(), isSessionCookieNeeded(), isWebSocketEnabled());
String content = String.format(INFO_CONTENT, random.nextInt(),
isSessionCookieNeeded(), isWebSocketEnabled());
response.getBody().write(content.getBytes());
}
......
......@@ -17,7 +17,9 @@
package org.springframework.web.socket.server.support;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentSkipListSet;
......@@ -39,31 +41,17 @@ import org.springframework.web.socket.WebSocketHandler;
public class OriginHandshakeInterceptorTests extends AbstractHttpRequestTests {
@Test(expected = IllegalArgumentException.class)
public void nullAllowedOriginList() {
public void invalidInput() {
new OriginHandshakeInterceptor(null);
}
@Test(expected = IllegalArgumentException.class)
public void invalidAllowedOrigin() {
new OriginHandshakeInterceptor(Arrays.asList("domain.com"));
}
@Test
public void emtpyAllowedOriginList() {
new OriginHandshakeInterceptor(Arrays.asList());
}
@Test
public void validAllowedOrigins() {
new OriginHandshakeInterceptor(Arrays.asList("http://domain.com", "https://domain.com", "*"));
}
@Test
public void originValueMatch() throws Exception {
Map<String, Object> attributes = new HashMap<String, Object>();
WebSocketHandler wsHandler = Mockito.mock(WebSocketHandler.class);
this.servletRequest.addHeader(HttpHeaders.ORIGIN, "http://mydomain1.com");
OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor(Arrays.asList("http://mydomain1.com"));
List<String> allowed = Collections.singletonList("http://mydomain1.com");
OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor(allowed);
assertTrue(interceptor.beforeHandshake(request, response, wsHandler, attributes));
assertNotEquals(servletResponse.getStatus(), HttpStatus.FORBIDDEN.value());
}
......@@ -73,7 +61,8 @@ public class OriginHandshakeInterceptorTests extends AbstractHttpRequestTests {
Map<String, Object> attributes = new HashMap<String, Object>();
WebSocketHandler wsHandler = Mockito.mock(WebSocketHandler.class);
this.servletRequest.addHeader(HttpHeaders.ORIGIN, "http://mydomain1.com");
OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor(Arrays.asList("http://mydomain2.com"));
List<String> allowed = Collections.singletonList("http://mydomain2.com");
OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor(allowed);
assertFalse(interceptor.beforeHandshake(request, response, wsHandler, attributes));
assertEquals(servletResponse.getStatus(), HttpStatus.FORBIDDEN.value());
}
......@@ -83,7 +72,8 @@ public class OriginHandshakeInterceptorTests extends AbstractHttpRequestTests {
Map<String, Object> attributes = new HashMap<String, Object>();
WebSocketHandler wsHandler = Mockito.mock(WebSocketHandler.class);
this.servletRequest.addHeader(HttpHeaders.ORIGIN, "http://mydomain2.com");
OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor(Arrays.asList("http://mydomain1.com", "http://mydomain2.com", "http://mydomain3.com"));
List<String> allowed = Arrays.asList("http://mydomain1.com", "http://mydomain2.com", "http://mydomain3.com");
OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor(allowed);
assertTrue(interceptor.beforeHandshake(request, response, wsHandler, attributes));
assertNotEquals(servletResponse.getStatus(), HttpStatus.FORBIDDEN.value());
}
......@@ -93,7 +83,8 @@ public class OriginHandshakeInterceptorTests extends AbstractHttpRequestTests {
Map<String, Object> attributes = new HashMap<String, Object>();
WebSocketHandler wsHandler = Mockito.mock(WebSocketHandler.class);
this.servletRequest.addHeader(HttpHeaders.ORIGIN, "http://mydomain4.com");
OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor(Arrays.asList("http://mydomain1.com", "http://mydomain2.com", "http://mydomain3.com"));
List<String> allowed = Arrays.asList("http://mydomain1.com", "http://mydomain2.com", "http://mydomain3.com");
OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor(allowed);
assertFalse(interceptor.beforeHandshake(request, response, wsHandler, attributes));
assertEquals(servletResponse.getStatus(), HttpStatus.FORBIDDEN.value());
}
......@@ -117,7 +108,7 @@ public class OriginHandshakeInterceptorTests extends AbstractHttpRequestTests {
WebSocketHandler wsHandler = Mockito.mock(WebSocketHandler.class);
this.servletRequest.addHeader(HttpHeaders.ORIGIN, "http://mydomain1.com");
OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor();
interceptor.setAllowedOrigins(Arrays.asList("*"));
interceptor.setAllowedOrigins(Collections.singletonList("*"));
assertTrue(interceptor.beforeHandshake(request, response, wsHandler, attributes));
assertNotEquals(servletResponse.getStatus(), HttpStatus.FORBIDDEN.value());
}
......@@ -128,7 +119,7 @@ public class OriginHandshakeInterceptorTests extends AbstractHttpRequestTests {
WebSocketHandler wsHandler = Mockito.mock(WebSocketHandler.class);
this.servletRequest.addHeader(HttpHeaders.ORIGIN, "http://mydomain2.com");
this.servletRequest.setServerName("mydomain2.com");
OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor(Arrays.asList());
OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor(Collections.emptyList());
assertTrue(interceptor.beforeHandshake(request, response, wsHandler, attributes));
assertNotEquals(servletResponse.getStatus(), HttpStatus.FORBIDDEN.value());
}
......@@ -139,7 +130,7 @@ public class OriginHandshakeInterceptorTests extends AbstractHttpRequestTests {
WebSocketHandler wsHandler = Mockito.mock(WebSocketHandler.class);
this.servletRequest.addHeader(HttpHeaders.ORIGIN, "http://mydomain3.com");
this.servletRequest.setServerName("mydomain2.com");
OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor(Arrays.asList());
OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor(Collections.emptyList());
assertFalse(interceptor.beforeHandshake(request, response, wsHandler, attributes));
assertEquals(servletResponse.getStatus(), HttpStatus.FORBIDDEN.value());
}
......
......@@ -16,11 +16,14 @@
package org.springframework.web.socket.sockjs.transport.handler;
import static org.junit.Assert.*;
import static org.mockito.BDDMockito.*;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import org.hamcrest.Matchers;
import org.junit.Before;
import org.junit.Test;
import org.mockito.Mock;
......@@ -41,9 +44,6 @@ import org.springframework.web.socket.sockjs.transport.TransportType;
import org.springframework.web.socket.sockjs.transport.session.StubSockJsServiceConfig;
import org.springframework.web.socket.sockjs.transport.session.TestSockJsSession;
import static org.junit.Assert.*;
import static org.mockito.BDDMockito.*;
/**
* Test fixture for {@link org.springframework.web.socket.sockjs.transport.handler.DefaultSockJsService}.
*
......@@ -125,26 +125,10 @@ public class DefaultSockJsServiceTests extends AbstractHttpRequestTests {
}
@Test(expected = IllegalArgumentException.class)
public void nullAllowedOriginList() {
public void invalidAllowedOrigins() {
this.service.setAllowedOrigins(null);
}
@Test
public void emptyAllowedOriginList() {
this.service.setAllowedOrigins(Arrays.asList());
assertThat(this.service.getAllowedOrigins(), Matchers.empty());
}
@Test(expected = IllegalArgumentException.class)
public void invalidAllowedOrigin() {
this.service.setAllowedOrigins(Arrays.asList("domain.com"));
}
@Test
public void validAllowedOrigins() {
this.service.setAllowedOrigins(Arrays.asList("http://domain.com", "https://domain.com", "*"));
}
@Test
public void customizedTransportHandlerList() {
TransportHandlingSockJsService service = new TransportHandlingSockJsService(
......@@ -268,13 +252,13 @@ public class DefaultSockJsServiceTests extends AbstractHttpRequestTests {
assertEquals(404, this.servletResponse.getStatus());
resetRequestAndResponse();
jsonpService.setAllowedOrigins(Arrays.asList("http://mydomain1.com"));
jsonpService.setAllowedOrigins(Collections.singletonList("http://mydomain1.com"));
setRequest("GET", sockJsPrefix + sockJsPath);
jsonpService.handleRequest(this.request, this.response, sockJsPath, this.wsHandler);
assertEquals(404, this.servletResponse.getStatus());
resetRequestAndResponse();
jsonpService.setAllowedOrigins(Arrays.asList("*"));
jsonpService.setAllowedOrigins(Collections.singletonList("*"));
setRequest("GET", sockJsPrefix + sockJsPath);
jsonpService.handleRequest(this.request, this.response, sockJsPath, this.wsHandler);
assertNotEquals(404, this.servletResponse.getStatus());
......@@ -289,8 +273,9 @@ public class DefaultSockJsServiceTests extends AbstractHttpRequestTests {
assertNotEquals(403, this.servletResponse.getStatus());
resetRequestAndResponse();
OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor(Arrays.asList("http://mydomain1.com"));
wsService.setHandshakeInterceptors(Arrays.asList(interceptor));
List<String> allowed = Collections.singletonList("http://mydomain1.com");
OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor(allowed);
wsService.setHandshakeInterceptors(Collections.singletonList(interceptor));
setRequest("GET", sockJsPrefix + sockJsPath);
this.servletRequest.addHeader(HttpHeaders.ORIGIN, "http://mydomain1.com");
wsService.handleRequest(this.request, this.response, sockJsPath, this.wsHandler);
......@@ -313,14 +298,14 @@ public class DefaultSockJsServiceTests extends AbstractHttpRequestTests {
resetRequestAndResponse();
setRequest("GET", sockJsPrefix + sockJsPath);
this.service.setAllowedOrigins(Arrays.asList("http://mydomain1.com"));
this.service.setAllowedOrigins(Collections.singletonList("http://mydomain1.com"));
this.service.handleRequest(this.request, this.response, sockJsPath, this.wsHandler);
assertEquals(404, this.servletResponse.getStatus());
assertNull(this.servletResponse.getHeader("X-Frame-Options"));
resetRequestAndResponse();
setRequest("GET", sockJsPrefix + sockJsPath);
this.service.setAllowedOrigins(Arrays.asList("*"));
this.service.setAllowedOrigins(Collections.singletonList("*"));
this.service.handleRequest(this.request, this.response, sockJsPath, this.wsHandler);
assertNotEquals(404, this.servletResponse.getStatus());
assertNull(this.servletResponse.getHeader("X-Frame-Options"));
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册