提交 23fa37b0 编写于 作者: S Sebastien Deleuze 提交者: Juergen Hoeller

Change SockJS and Websocket default allowedOrigins to same origin

This commit adds support for a same origin check that compares
Origin header to Host header. It also changes the default setting
from all origins allowed to only same origin allowed.

Issues: SPR-12697, SPR-12685
(cherry picked from commit 6062e155)
上级 cc78d40c
/*
* Copyright 2002-2014 the original author or authors.
* Copyright 2002-2015 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
......@@ -19,6 +19,7 @@ package org.springframework.web.util;
import java.io.File;
import java.io.FileNotFoundException;
import java.util.Enumeration;
import java.util.List;
import java.util.Map;
import java.util.StringTokenizer;
import java.util.TreeMap;
......@@ -32,6 +33,7 @@ import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import javax.servlet.http.HttpSession;
import org.springframework.http.server.ServerHttpRequest;
import org.springframework.util.Assert;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
......@@ -43,6 +45,7 @@ import org.springframework.util.StringUtils;
*
* @author Rod Johnson
* @author Juergen Hoeller
* @author Sebastien Deleuze
*/
public abstract class WebUtils {
......@@ -765,4 +768,47 @@ public abstract class WebUtils {
}
return result;
}
/**
* Check the given request origin against a list of allowed origins.
* A list containing "*" means that all origins are allowed.
* An empty list means only same origin is allowed.
*
* @return true if the request origin is valid, false otherwise
* @since 4.1.5
* @see <a href="https://tools.ietf.org/html/rfc6454">RFC 6454: The Web Origin Concept</a>
*/
public static boolean isValidOrigin(ServerHttpRequest request, List<String> allowedOrigins) {
Assert.notNull(request, "Request must not be null");
Assert.notNull(allowedOrigins, "Allowed origins must not be null");
String origin = request.getHeaders().getOrigin();
if (origin == null || allowedOrigins.contains("*")) {
return true;
}
else if (allowedOrigins.isEmpty()) {
UriComponents originComponents = UriComponentsBuilder.fromHttpUrl(origin).build();
UriComponents requestComponents = UriComponentsBuilder.fromHttpRequest(request).build();
int originPort = getPort(originComponents);
int requestPort = getPort(requestComponents);
return originComponents.getHost().equals(requestComponents.getHost()) && (originPort == requestPort);
}
else {
return allowedOrigins.contains(origin);
}
}
private static int getPort(UriComponents component) {
int port = component.getPort();
if (port == -1) {
if ("http".equals(component.getScheme())) {
port = 80;
}
else if ("https".equals(component.getScheme())) {
port = 443;
}
}
return port;
}
}
/*
* Copyright 2002-2008 the original author or authors.
* Copyright 2002-2015 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
......@@ -16,12 +16,18 @@
package org.springframework.web.util;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.junit.Test;
import org.springframework.http.HttpHeaders;
import org.springframework.http.server.ServerHttpRequest;
import org.springframework.http.server.ServletServerHttpRequest;
import org.springframework.mock.web.test.MockHttpServletRequest;
import org.springframework.util.MultiValueMap;
import static org.junit.Assert.*;
......@@ -30,6 +36,7 @@ import static org.junit.Assert.*;
* @author Juergen Hoeller
* @author Arjen Poutsma
* @author Rossen Stoyanchev
* @author Sebastien Deleuze
*/
public class WebUtilsTests {
......@@ -98,4 +105,57 @@ public class WebUtilsTests {
assertEquals(Arrays.asList("red", "blue", "green"), variables.get("colors"));
}
@Test
public void isValidOrigin() {
List<String> allowedOrigins = new ArrayList<>();
MockHttpServletRequest servletRequest = new MockHttpServletRequest();
ServerHttpRequest request = new ServletServerHttpRequest(servletRequest);
servletRequest.setServerName("mydomain1.com");
request.getHeaders().set(HttpHeaders.ORIGIN, "http://mydomain1.com");
assertTrue(WebUtils.isValidOrigin(request, allowedOrigins));
servletRequest.setServerName("mydomain1.com");
request.getHeaders().set(HttpHeaders.ORIGIN, "http://mydomain1.com:80");
assertTrue(WebUtils.isValidOrigin(request, allowedOrigins));
servletRequest.setServerName("mydomain1.com");
servletRequest.setServerPort(443);
request.getHeaders().set(HttpHeaders.ORIGIN, "https://mydomain1.com");
assertTrue(WebUtils.isValidOrigin(request, allowedOrigins));
servletRequest.setServerName("mydomain1.com");
servletRequest.setServerPort(443);
request.getHeaders().set(HttpHeaders.ORIGIN, "https://mydomain1.com:443");
assertTrue(WebUtils.isValidOrigin(request, allowedOrigins));
servletRequest.setServerName("mydomain1.com");
servletRequest.setServerPort(123);
request.getHeaders().set(HttpHeaders.ORIGIN, "http://mydomain1.com:123");
assertTrue(WebUtils.isValidOrigin(request, allowedOrigins));
servletRequest.setServerName("mydomain1.com");
request.getHeaders().set(HttpHeaders.ORIGIN, "http://mydomain2.com");
assertFalse(WebUtils.isValidOrigin(request, allowedOrigins));
servletRequest.setServerName("mydomain1.com");
request.getHeaders().set(HttpHeaders.ORIGIN, "https://mydomain1.com");
assertFalse(WebUtils.isValidOrigin(request, allowedOrigins));
allowedOrigins = Arrays.asList("*");
servletRequest.setServerName("mydomain1.com");
request.getHeaders().set(HttpHeaders.ORIGIN, "http://mydomain2.com");
assertTrue(WebUtils.isValidOrigin(request, allowedOrigins));
allowedOrigins = Arrays.asList("http://mydomain1.com");
servletRequest.setServerName("mydomain2.com");
request.getHeaders().set(HttpHeaders.ORIGIN, "http://mydomain1.com");
assertTrue(WebUtils.isValidOrigin(request, allowedOrigins));
allowedOrigins = Arrays.asList("http://mydomain1.com");
servletRequest.setServerName("mydomain2.com");
request.getHeaders().set(HttpHeaders.ORIGIN, "http://mydomain3.com");
assertFalse(WebUtils.isValidOrigin(request, allowedOrigins));
}
}
/*
* Copyright 2002-2014 the original author or authors.
* Copyright 2002-2015 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
......@@ -83,11 +83,7 @@ class HandlersBeanDefinitionParser implements BeanDefinitionParser {
ManagedList<? super Object> interceptors = WebSocketNamespaceUtils.parseBeanSubElements(interceptorsElement, context);
String allowedOriginsAttribute = element.getAttribute("allowed-origins");
List<String> allowedOrigins = Arrays.asList(StringUtils.tokenizeToStringArray(allowedOriginsAttribute, ","));
if (!allowedOrigins.isEmpty()) {
OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor();
interceptor.setAllowedOrigins(allowedOrigins);
interceptors.add(interceptor);
}
interceptors.add(new OriginHandshakeInterceptor(allowedOrigins));
strategy = new WebSocketHandlerMappingStrategy(handshakeHandler, interceptors);
}
......
/*
* Copyright 2002-2014 the original author or authors.
* Copyright 2002-2015 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
......@@ -288,11 +288,7 @@ class MessageBrokerBeanDefinitionParser implements BeanDefinitionParser {
ManagedList<? super Object> interceptors = WebSocketNamespaceUtils.parseBeanSubElements(interceptorsElement, context);
String allowedOriginsAttribute = element.getAttribute("allowed-origins");
List<String> allowedOrigins = Arrays.asList(StringUtils.tokenizeToStringArray(allowedOriginsAttribute, ","));
if (!allowedOrigins.isEmpty()) {
OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor();
interceptor.setAllowedOrigins(allowedOrigins);
interceptors.add(interceptor);
}
interceptors.add(new OriginHandshakeInterceptor(allowedOrigins));
ConstructorArgumentValues cavs = new ConstructorArgumentValues();
cavs.addIndexedArgumentValue(0, subProtoHandler);
if (handshakeHandler != null) {
......
/*
* Copyright 2002-2014 the original author or authors.
* Copyright 2002-2015 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
......@@ -105,12 +105,8 @@ class WebSocketNamespaceUtils {
ManagedList<? super Object> interceptors = WebSocketNamespaceUtils.parseBeanSubElements(interceptorsElement, context);
String allowedOriginsAttribute = element.getAttribute("allowed-origins");
List<String> allowedOrigins = Arrays.asList(StringUtils.tokenizeToStringArray(allowedOriginsAttribute, ","));
if (!allowedOrigins.isEmpty()) {
sockJsServiceDef.getPropertyValues().add("allowedOrigins", allowedOrigins);
OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor();
interceptor.setAllowedOrigins(allowedOrigins);
interceptors.add(interceptor);
}
sockJsServiceDef.getPropertyValues().add("allowedOrigins", allowedOrigins);
interceptors.add(new OriginHandshakeInterceptor(allowedOrigins));
sockJsServiceDef.getPropertyValues().add("handshakeInterceptors", interceptors);
String attrValue = sockJsElement.getAttribute("name");
......
......@@ -88,11 +88,10 @@ public abstract class AbstractWebSocketHandlerRegistration<M> implements WebSock
}
@Override
public WebSocketHandlerRegistration setAllowedOrigins(String... origins) {
Assert.notEmpty(origins, "No allowed origin specified");
public WebSocketHandlerRegistration setAllowedOrigins(String... allowedOrigins) {
this.allowedOrigins.clear();
if (!ObjectUtils.isEmpty(origins)) {
this.allowedOrigins.addAll(Arrays.asList(origins));
if (!ObjectUtils.isEmpty(allowedOrigins)) {
this.allowedOrigins.addAll(Arrays.asList(allowedOrigins));
}
return this;
}
......@@ -117,11 +116,7 @@ public abstract class AbstractWebSocketHandlerRegistration<M> implements WebSock
protected HandshakeInterceptor[] getInterceptors() {
List<HandshakeInterceptor> interceptors = new ArrayList<HandshakeInterceptor>();
interceptors.addAll(this.interceptors);
if (!this.allowedOrigins.isEmpty()) {
OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor();
interceptor.setAllowedOrigins(this.allowedOrigins);
interceptors.add(interceptor);
}
interceptors.add(new OriginHandshakeInterceptor(this.allowedOrigins));
return interceptors.toArray(new HandshakeInterceptor[interceptors.size()]);
}
......
......@@ -206,6 +206,17 @@ public class SockJsServiceRegistration {
return this;
}
/**
* @since 4.1.2
*/
protected SockJsServiceRegistration setAllowedOrigins(String... allowedOrigins) {
this.allowedOrigins.clear();
if (!ObjectUtils.isEmpty(allowedOrigins)) {
this.allowedOrigins.addAll(Arrays.asList(allowedOrigins));
}
return this;
}
/**
* This option can be used to disable automatic addition of CORS headers for
* SockJS requests.
......@@ -229,17 +240,6 @@ public class SockJsServiceRegistration {
return this;
}
/**
* @since 4.1.2
*/
protected SockJsServiceRegistration setAllowedOrigins(String... origins) {
this.allowedOrigins.clear();
if (!ObjectUtils.isEmpty(origins)) {
this.allowedOrigins.addAll(Arrays.asList(origins));
}
return this;
}
protected SockJsService getSockJsService() {
TransportHandlingSockJsService service = createSockJsService();
service.setHandshakeInterceptors(this.interceptors);
......@@ -264,12 +264,12 @@ public class SockJsServiceRegistration {
if (this.webSocketEnabled != null) {
service.setWebSocketEnabled(this.webSocketEnabled);
}
if (this.allowedOrigins != null) {
service.setAllowedOrigins(this.allowedOrigins);
}
if (this.suppressCors != null) {
service.setSuppressCors(this.suppressCors);
}
if (!this.allowedOrigins.isEmpty()) {
service.setAllowedOrigins(this.allowedOrigins);
}
if (this.messageCodec != null) {
service.setMessageCodec(this.messageCodec);
}
......
......@@ -52,8 +52,8 @@ public interface StompWebSocketEndpointRegistration {
* As a consequence, IE 6 to 9 are not supported when origins are restricted.
*
* <p>Each provided allowed origin must start by "http://", "https://" or be "*"
* (means that all origins are allowed). Empty allowed origin list is not supported.
* By default, all origins are allowed.
* (means that all origins are allowed). By default, only same origin requests are
* allowed (empty list).
*
* @since 4.1.2
* @see <a href="https://tools.ietf.org/html/rfc6454">RFC 6454: The Web Origin Concept</a>
......
......@@ -85,10 +85,11 @@ public class WebMvcStompWebSocketEndpointRegistration implements StompWebSocketE
}
@Override
public StompWebSocketEndpointRegistration setAllowedOrigins(String... origins) {
Assert.notEmpty(origins, "No allowed origin specified");
public StompWebSocketEndpointRegistration setAllowedOrigins(String... allowedOrigins) {
this.allowedOrigins.clear();
this.allowedOrigins.addAll(Arrays.asList(origins));
if (!ObjectUtils.isEmpty(allowedOrigins)) {
this.allowedOrigins.addAll(Arrays.asList(allowedOrigins));
}
return this;
}
......@@ -112,11 +113,7 @@ public class WebMvcStompWebSocketEndpointRegistration implements StompWebSocketE
protected HandshakeInterceptor[] getInterceptors() {
List<HandshakeInterceptor> interceptors = new ArrayList<HandshakeInterceptor>();
interceptors.addAll(this.interceptors);
if (!this.allowedOrigins.isEmpty()) {
OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor();
interceptor.setAllowedOrigins(this.allowedOrigins);
interceptors.add(interceptor);
}
interceptors.add(new OriginHandshakeInterceptor(this.allowedOrigins));
return interceptors.toArray(new HandshakeInterceptor[interceptors.size()]);
}
......
......@@ -54,8 +54,8 @@ public interface WebSocketHandlerRegistration {
* As a consequence, IE 6 to 9 are not supported when origins are restricted.
*
* <p>Each provided allowed origin must start by "http://", "https://" or be "*"
* (means that all origins are allowed). Empty allowed origin list is not supported.
* By default, all origins are allowed.
* (means that all origins are allowed). By default, only same origin requests are
* allowed (empty list).
*
* @since 4.1.2
* @see <a href="https://tools.ietf.org/html/rfc6454">RFC 6454: The Web Origin Concept</a>
......
......@@ -31,6 +31,7 @@ import org.springframework.http.server.ServerHttpResponse;
import org.springframework.util.Assert;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.server.HandshakeInterceptor;
import org.springframework.web.util.WebUtils;
/**
* An interceptor to check request {@code Origin} header value against a collection of
......@@ -47,12 +48,22 @@ public class OriginHandshakeInterceptor implements HandshakeInterceptor {
/**
* Default constructor with no origin allowed.
* Default constructor with only same origin requests allowed.
*/
public OriginHandshakeInterceptor() {
this.allowedOrigins = new ArrayList<String>();
}
/**
* Constructor using the specified allowed origin values.
*
* @see #setAllowedOrigins(Collection)
*/
public OriginHandshakeInterceptor(Collection<String> allowedOrigins) {
this();
setAllowedOrigins(allowedOrigins);
}
/**
* Configure allowed {@code Origin} header values. This check is mostly designed for
* browser clients. There is nothing preventing other types of client to modify the
......@@ -85,7 +96,7 @@ public class OriginHandshakeInterceptor implements HandshakeInterceptor {
@Override
public boolean beforeHandshake(ServerHttpRequest request, ServerHttpResponse response,
WebSocketHandler wsHandler, Map<String, Object> attributes) throws Exception {
if (!isValidOrigin(request)) {
if (!WebUtils.isValidOrigin(request, this.allowedOrigins)) {
response.setStatusCode(HttpStatus.FORBIDDEN);
if (logger.isDebugEnabled()) {
logger.debug("Handshake request rejected, Origin header value "
......@@ -96,17 +107,6 @@ public class OriginHandshakeInterceptor implements HandshakeInterceptor {
return true;
}
protected boolean isValidOrigin(ServerHttpRequest request) {
String origin = request.getHeaders().getOrigin();
if (origin == null) {
return true;
}
if (this.allowedOrigins.contains("*")) {
return true;
}
return this.allowedOrigins.contains(origin);
}
@Override
public void afterHandshake(ServerHttpRequest request, ServerHttpResponse response,
WebSocketHandler wsHandler, Exception exception) {
......
......@@ -46,12 +46,16 @@ import org.springframework.util.StringUtils;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.sockjs.SockJsException;
import org.springframework.web.socket.sockjs.SockJsService;
import org.springframework.web.util.WebUtils;
/**
* An abstract base class for {@link SockJsService} implementations that provides SockJS
* path resolution and handling of static SockJS requests (e.g. "/info", "/iframe.html",
* etc). Sub-classes must handle session URLs (i.e. transport-specific requests).
*
* By default, only same origin requests are allowed. Use {@link #setAllowedOrigins(List)}
* to specify a list of allowed origins (a list containing "*" will allow all origins).
*
* @author Rossen Stoyanchev
* @author Sebastien Deleuze
* @since 4.0
......@@ -64,6 +68,8 @@ public abstract class AbstractSockJsService implements SockJsService {
private static final Random random = new Random();
private static final String XFRAME_OPTIONS_HEADER = "X-Frame-Options";
protected final Log logger = LogFactory.getLog(getClass());
......@@ -85,7 +91,7 @@ public abstract class AbstractSockJsService implements SockJsService {
private boolean webSocketEnabled = true;
private final List<String> allowedOrigins = new ArrayList<String>(Arrays.asList("*"));
private final List<String> allowedOrigins = new ArrayList<String>();
private boolean suppressCors = false;
......@@ -275,15 +281,14 @@ public abstract class AbstractSockJsService implements SockJsService {
* As a consequence, IE 6 to 9 are not supported when origins are restricted.
*
* <p>Each provided allowed origin must start by "http://", "https://" or be "*"
* (means that all origins are allowed). Empty allowed origin list is not supported.
* By default, all origins are allowed.
* (means that all origins are allowed).
*
* @since 4.1.2
* @see <a href="https://tools.ietf.org/html/rfc6454">RFC 6454: The Web Origin Concept</a>
* @see <a href="https://github.com/sockjs/sockjs-client#supported-transports-by-browser-html-served-from-http-or-https">SockJS supported transports by browser</a>
*/
public void setAllowedOrigins(List<String> allowedOrigins) {
Assert.notEmpty(allowedOrigins, "Allowed origin List must not be empty");
Assert.notNull(allowedOrigins, "Allowed origin List must not be null");
for (String allowedOrigin : allowedOrigins) {
Assert.isTrue(
allowedOrigin.equals("*") || allowedOrigin.startsWith("http://") ||
......@@ -360,6 +365,9 @@ public abstract class AbstractSockJsService implements SockJsService {
response.setStatusCode(HttpStatus.NOT_FOUND);
return;
}
if (this.allowedOrigins.isEmpty()) {
response.getHeaders().add(XFRAME_OPTIONS_HEADER, "SAMEORIGIN");
}
logger.debug(requestInfo);
this.iframeHandler.handle(request, response);
}
......@@ -438,13 +446,12 @@ public abstract class AbstractSockJsService implements SockJsService {
HttpHeaders requestHeaders = request.getHeaders();
HttpHeaders responseHeaders = response.getHeaders();
String origin = requestHeaders.getOrigin();
String host = requestHeaders.getFirst(HttpHeaders.HOST);
if (origin == null) {
return true;
}
if (!this.allowedOrigins.contains("*") && !this.allowedOrigins.contains(origin)) {
if (!WebUtils.isValidOrigin(request, this.allowedOrigins)) {
logger.debug("Request rejected, Origin header value " + origin + " not allowed");
response.setStatusCode(HttpStatus.FORBIDDEN);
return false;
......
/*
* Copyright 2002-2014 the original author or authors.
* Copyright 2002-2015 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
......@@ -45,9 +45,9 @@ public enum TransportType {
XHR_STREAMING("xhr_streaming", HttpMethod.POST, "cors", "jsessionid", "no_cache"),
EVENT_SOURCE("eventsource", HttpMethod.GET, "jsessionid", "no_cache"),
EVENT_SOURCE("eventsource", HttpMethod.GET, "origin", "jsessionid", "no_cache"),
HTML_FILE("htmlfile", HttpMethod.GET, "jsessionid", "no_cache");
HTML_FILE("htmlfile", HttpMethod.GET, "cors", "jsessionid", "no_cache");
private final String value;
......
/*
* Copyright 2002-2014 the original author or authors.
* Copyright 2002-2015 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
......@@ -106,7 +106,8 @@ public class HandlersBeanDefinitionParserTests {
HandshakeHandler handshakeHandler = handler.getHandshakeHandler();
assertNotNull(handshakeHandler);
assertTrue(handshakeHandler instanceof DefaultHandshakeHandler);
assertTrue(handler.getHandshakeInterceptors().isEmpty());
assertFalse(handler.getHandshakeInterceptors().isEmpty());
assertTrue(handler.getHandshakeInterceptors().get(0) instanceof OriginHandshakeInterceptor);
}
else {
assertThat(shm.getUrlMap().keySet(), contains("/test"));
......@@ -116,7 +117,8 @@ public class HandlersBeanDefinitionParserTests {
HandshakeHandler handshakeHandler = handler.getHandshakeHandler();
assertNotNull(handshakeHandler);
assertTrue(handshakeHandler instanceof DefaultHandshakeHandler);
assertTrue(handler.getHandshakeInterceptors().isEmpty());
assertFalse(handler.getHandshakeInterceptors().isEmpty());
assertTrue(handler.getHandshakeInterceptors().get(0) instanceof OriginHandshakeInterceptor);
}
}
}
......@@ -196,7 +198,7 @@ public class HandlersBeanDefinitionParserTests {
assertEquals(TestHandshakeHandler.class, handler.getHandshakeHandler().getClass());
List<HandshakeInterceptor> interceptors = defaultSockJsService.getHandshakeInterceptors();
assertThat(interceptors, contains(instanceOf(FooTestInterceptor.class), instanceOf(BarTestInterceptor.class)));
assertThat(interceptors, contains(instanceOf(FooTestInterceptor.class), instanceOf(BarTestInterceptor.class), instanceOf(OriginHandshakeInterceptor.class)));
}
@Test
......
......@@ -71,7 +71,7 @@ public class WebMvcStompWebSocketEndpointRegistrationTests {
Map.Entry<HttpRequestHandler, List<String>> entry = mappings.entrySet().iterator().next();
assertNotNull(((WebSocketHttpRequestHandler) entry.getKey()).getWebSocketHandler());
assertTrue(((WebSocketHttpRequestHandler) entry.getKey()).getHandshakeInterceptors().isEmpty());
assertEquals(1, ((WebSocketHttpRequestHandler) entry.getKey()).getHandshakeInterceptors().size());
assertEquals(Arrays.asList("/foo"), entry.getValue());
}
......@@ -80,7 +80,7 @@ public class WebMvcStompWebSocketEndpointRegistrationTests {
WebMvcStompWebSocketEndpointRegistration registration =
new WebMvcStompWebSocketEndpointRegistration(new String[] {"/foo"}, this.handler, this.scheduler);
registration.setAllowedOrigins("http://mydomain.com");
registration.setAllowedOrigins();
MultiValueMap<HttpRequestHandler, String> mappings = registration.getMappings();
assertEquals(1, mappings.size());
......@@ -90,10 +90,18 @@ public class WebMvcStompWebSocketEndpointRegistrationTests {
assertEquals(OriginHandshakeInterceptor.class, requestHandler.getHandshakeInterceptors().get(0).getClass());
}
@Test(expected = IllegalArgumentException.class)
public void noAllowedOrigin() {
@Test
public void sameOrigin() {
WebMvcStompWebSocketEndpointRegistration registration = new WebMvcStompWebSocketEndpointRegistration(new String[] {"/foo"}, this.handler, this.scheduler);
registration.setAllowedOrigins();
MultiValueMap<HttpRequestHandler, String> mappings = registration.getMappings();
assertEquals(1, mappings.size());
WebSocketHttpRequestHandler requestHandler = (WebSocketHttpRequestHandler)mappings.entrySet().iterator().next().getKey();
assertNotNull(requestHandler.getWebSocketHandler());
assertEquals(1, requestHandler.getHandshakeInterceptors().size());
assertEquals(OriginHandshakeInterceptor.class, requestHandler.getHandshakeInterceptors().get(0).getClass());
}
@Test
......@@ -158,7 +166,9 @@ public class WebMvcStompWebSocketEndpointRegistrationTests {
WebSocketHttpRequestHandler requestHandler = (WebSocketHttpRequestHandler) entry.getKey();
assertNotNull(requestHandler.getWebSocketHandler());
assertSame(handshakeHandler, requestHandler.getHandshakeHandler());
assertEquals(Arrays.asList(interceptor), requestHandler.getHandshakeInterceptors());
assertEquals(2, requestHandler.getHandshakeInterceptors().size());
assertEquals(interceptor, requestHandler.getHandshakeInterceptors().get(0));
assertEquals(OriginHandshakeInterceptor.class, requestHandler.getHandshakeInterceptors().get(1).getClass());
}
@Test
......@@ -210,7 +220,9 @@ public class WebMvcStompWebSocketEndpointRegistrationTests {
Map<TransportType, TransportHandler> handlers = sockJsService.getTransportHandlers();
WebSocketTransportHandler transportHandler = (WebSocketTransportHandler) handlers.get(TransportType.WEBSOCKET);
assertSame(handshakeHandler, transportHandler.getHandshakeHandler());
assertEquals(Arrays.asList(interceptor), sockJsService.getHandshakeInterceptors());
assertEquals(2, sockJsService.getHandshakeInterceptors().size());
assertEquals(interceptor, sockJsService.getHandshakeInterceptors().get(0));
assertEquals(OriginHandshakeInterceptor.class, sockJsService.getHandshakeInterceptors().get(1).getClass());
}
@Test
......
......@@ -69,12 +69,14 @@ public class WebSocketHandlerRegistrationTests {
Mapping m1 = mappings.get(0);
assertEquals(handler, m1.webSocketHandler);
assertEquals("/foo", m1.path);
assertEquals(0, m1.interceptors.length);
assertEquals(1, m1.interceptors.length);
assertEquals(OriginHandshakeInterceptor.class, m1.interceptors[0].getClass());
Mapping m2 = mappings.get(1);
assertEquals(handler, m2.webSocketHandler);
assertEquals("/bar", m2.path);
assertEquals(0, m2.interceptors.length);
assertEquals(1, m2.interceptors.length);
assertEquals(OriginHandshakeInterceptor.class, m2.interceptors[0].getClass());
}
@Test
......@@ -90,12 +92,27 @@ public class WebSocketHandlerRegistrationTests {
Mapping mapping = mappings.get(0);
assertEquals(handler, mapping.webSocketHandler);
assertEquals("/foo", mapping.path);
assertArrayEquals(new HandshakeInterceptor[] {interceptor}, mapping.interceptors);
assertEquals(2, mapping.interceptors.length);
assertEquals(interceptor, mapping.interceptors[0]);
assertEquals(OriginHandshakeInterceptor.class, mapping.interceptors[1].getClass());
}
@Test(expected = IllegalArgumentException.class)
public void noAllowedOrigin() {
this.registration.addHandler(Mockito.mock(WebSocketHandler.class), "/foo").setAllowedOrigins();
@Test
public void emptyAllowedOrigin() {
WebSocketHandler handler = new TextWebSocketHandler();
HttpSessionHandshakeInterceptor interceptor = new HttpSessionHandshakeInterceptor();
this.registration.addHandler(handler, "/foo").addInterceptors(interceptor).setAllowedOrigins();
List<Mapping> mappings = this.registration.getMappings();
assertEquals(1, mappings.size());
Mapping mapping = mappings.get(0);
assertEquals(handler, mapping.webSocketHandler);
assertEquals("/foo", mapping.path);
assertEquals(2, mapping.interceptors.length);
assertEquals(interceptor, mapping.interceptors[0]);
assertEquals(OriginHandshakeInterceptor.class, mapping.interceptors[1].getClass());
}
@Test
......
......@@ -39,20 +39,22 @@ public class OriginHandshakeInterceptorTests extends AbstractHttpRequestTests {
@Test(expected = IllegalArgumentException.class)
public void nullAllowedOriginList() {
OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor();
interceptor.setAllowedOrigins(null);
new OriginHandshakeInterceptor(null);
}
@Test(expected = IllegalArgumentException.class)
public void invalidAllowedOrigin() {
OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor();
interceptor.setAllowedOrigins(Arrays.asList("domain.com"));
new OriginHandshakeInterceptor(Arrays.asList("domain.com"));
}
@Test
public void emtpyAllowedOriginList() {
new OriginHandshakeInterceptor(Arrays.asList());
}
@Test
public void validAllowedOrigins() {
OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor();
interceptor.setAllowedOrigins(Arrays.asList("http://domain.com", "https://domain.com", "*"));
new OriginHandshakeInterceptor(Arrays.asList("http://domain.com", "https://domain.com", "*"));
}
@Test
......@@ -60,8 +62,7 @@ public class OriginHandshakeInterceptorTests extends AbstractHttpRequestTests {
Map<String, Object> attributes = new HashMap<String, Object>();
WebSocketHandler wsHandler = Mockito.mock(WebSocketHandler.class);
setOrigin("http://mydomain1.com");
OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor();
interceptor.setAllowedOrigins(Arrays.asList("http://mydomain1.com"));
OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor(Arrays.asList("http://mydomain1.com"));
assertTrue(interceptor.beforeHandshake(request, response, wsHandler, attributes));
assertNotEquals(servletResponse.getStatus(), HttpStatus.FORBIDDEN.value());
}
......@@ -71,8 +72,7 @@ public class OriginHandshakeInterceptorTests extends AbstractHttpRequestTests {
Map<String, Object> attributes = new HashMap<String, Object>();
WebSocketHandler wsHandler = Mockito.mock(WebSocketHandler.class);
setOrigin("http://mydomain1.com");
OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor();
interceptor.setAllowedOrigins(Arrays.asList("http://mydomain2.com"));
OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor(Arrays.asList("http://mydomain2.com"));
assertFalse(interceptor.beforeHandshake(request, response, wsHandler, attributes));
assertEquals(servletResponse.getStatus(), HttpStatus.FORBIDDEN.value());
}
......@@ -82,8 +82,7 @@ public class OriginHandshakeInterceptorTests extends AbstractHttpRequestTests {
Map<String, Object> attributes = new HashMap<String, Object>();
WebSocketHandler wsHandler = Mockito.mock(WebSocketHandler.class);
setOrigin("http://mydomain2.com");
OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor();
interceptor.setAllowedOrigins(Arrays.asList("http://mydomain1.com", "http://mydomain2.com", "http://mydomain3.com"));
OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor(Arrays.asList("http://mydomain1.com", "http://mydomain2.com", "http://mydomain3.com"));
assertTrue(interceptor.beforeHandshake(request, response, wsHandler, attributes));
assertNotEquals(servletResponse.getStatus(), HttpStatus.FORBIDDEN.value());
}
......@@ -93,8 +92,7 @@ public class OriginHandshakeInterceptorTests extends AbstractHttpRequestTests {
Map<String, Object> attributes = new HashMap<String, Object>();
WebSocketHandler wsHandler = Mockito.mock(WebSocketHandler.class);
setOrigin("http://mydomain4.com");
OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor();
interceptor.setAllowedOrigins(Arrays.asList("http://mydomain1.com", "http://mydomain2.com", "http://mydomain3.com"));
OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor(Arrays.asList("http://mydomain1.com", "http://mydomain2.com", "http://mydomain3.com"));
assertFalse(interceptor.beforeHandshake(request, response, wsHandler, attributes));
assertEquals(servletResponse.getStatus(), HttpStatus.FORBIDDEN.value());
}
......@@ -123,4 +121,26 @@ public class OriginHandshakeInterceptorTests extends AbstractHttpRequestTests {
assertNotEquals(servletResponse.getStatus(), HttpStatus.FORBIDDEN.value());
}
@Test
public void sameOriginMatch() throws Exception {
Map<String, Object> attributes = new HashMap<String, Object>();
WebSocketHandler wsHandler = Mockito.mock(WebSocketHandler.class);
setOrigin("http://mydomain2.com");
this.servletRequest.setServerName("mydomain2.com");
OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor(Arrays.asList());
assertTrue(interceptor.beforeHandshake(request, response, wsHandler, attributes));
assertNotEquals(servletResponse.getStatus(), HttpStatus.FORBIDDEN.value());
}
@Test
public void sameOriginNoMatch() throws Exception {
Map<String, Object> attributes = new HashMap<String, Object>();
WebSocketHandler wsHandler = Mockito.mock(WebSocketHandler.class);
setOrigin("http://mydomain3.com");
this.servletRequest.setServerName("mydomain2.com");
OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor(Arrays.asList());
assertFalse(interceptor.beforeHandshake(request, response, wsHandler, attributes));
assertEquals(servletResponse.getStatus(), HttpStatus.FORBIDDEN.value());
}
}
......@@ -110,6 +110,7 @@ public class SockJsServiceTests extends AbstractHttpRequestTests {
@Test // SPR-12226 and SPR-12660
public void handleInfoGetWithOrigin() throws Exception {
this.servletRequest.setServerName("mydomain2.com");
setOrigin("http://mydomain2.com");
resetResponseAndHandleRequest("GET", "/echo/info", HttpStatus.OK);
......@@ -135,6 +136,12 @@ public class SockJsServiceTests extends AbstractHttpRequestTests {
assertEquals("http://mydomain2.com", this.servletResponse.getHeader("Access-Control-Allow-Origin"));
assertEquals("true", this.servletResponse.getHeader("Access-Control-Allow-Credentials"));
assertEquals("Origin", this.servletResponse.getHeader("Vary"));
this.service.setAllowedOrigins(Arrays.asList("*"));
resetResponseAndHandleRequest("GET", "/echo/info", HttpStatus.OK);
assertEquals("http://mydomain2.com", this.servletResponse.getHeader("Access-Control-Allow-Origin"));
assertEquals("true", this.servletResponse.getHeader("Access-Control-Allow-Credentials"));
assertEquals("Origin", this.servletResponse.getHeader("Vary"));
}
@Test // SPR-11443
......@@ -186,6 +193,7 @@ public class SockJsServiceTests extends AbstractHttpRequestTests {
@Test // SPR-12226 and SPR-12660
public void handleInfoOptionsWithOrigin() throws Exception {
this.servletRequest.setServerName("mydomain2.com");
setOrigin("http://mydomain2.com");
this.request.getHeaders().add("Access-Control-Request-Headers", "Last-Modified");
resetResponseAndHandleRequest("OPTIONS", "/echo/info", HttpStatus.NO_CONTENT);
......@@ -216,6 +224,16 @@ public class SockJsServiceTests extends AbstractHttpRequestTests {
assertEquals("OPTIONS, GET", this.servletResponse.getHeader("Access-Control-Allow-Methods"));
assertEquals("31536000", this.servletResponse.getHeader("Access-Control-Max-Age"));
assertEquals("Origin", this.servletResponse.getHeader("Vary"));
this.service.setAllowedOrigins(Arrays.asList("*"));
resetResponseAndHandleRequest("OPTIONS", "/echo/info", HttpStatus.NO_CONTENT);
this.response.flush();
assertEquals("http://mydomain2.com", this.servletResponse.getHeader("Access-Control-Allow-Origin"));
assertEquals("true", this.servletResponse.getHeader("Access-Control-Allow-Credentials"));
assertEquals("Last-Modified", this.servletResponse.getHeader("Access-Control-Allow-Headers"));
assertEquals("OPTIONS, GET", this.servletResponse.getHeader("Access-Control-Allow-Methods"));
assertEquals("31536000", this.servletResponse.getHeader("Access-Control-Max-Age"));
assertEquals("Origin", this.servletResponse.getHeader("Vary"));
}
@Test // SPR-12283
......
......@@ -122,19 +122,15 @@ public class DefaultSockJsServiceTests extends AbstractHttpRequestTests {
assertSame(xhrHandler, handlers.get(xhrHandler.getTransportType()));
}
@Test
public void defaultAllowedOrigin() {
assertThat(this.service.getAllowedOrigins(), Matchers.contains("*"));
}
@Test(expected = IllegalArgumentException.class)
public void nullAllowedOriginList() {
this.service.setAllowedOrigins(null);
}
@Test(expected = IllegalArgumentException.class)
@Test
public void emptyAllowedOriginList() {
this.service.setAllowedOrigins(Arrays.asList());
assertThat(this.service.getAllowedOrigins(), Matchers.empty());
}
@Test(expected = IllegalArgumentException.class)
......@@ -271,13 +267,19 @@ public class DefaultSockJsServiceTests extends AbstractHttpRequestTests {
String sockJsPath = sessionUrlPrefix+ "jsonp";
setRequest("GET", sockJsPrefix + sockJsPath);
jsonpService.handleRequest(this.request, this.response, sockJsPath, this.wsHandler);
assertNotEquals(404, this.servletResponse.getStatus());
assertEquals(404, this.servletResponse.getStatus());
resetRequestAndResponse();
jsonpService.setAllowedOrigins(Arrays.asList("http://mydomain1.com"));
setRequest("GET", sockJsPrefix + sockJsPath);
jsonpService.handleRequest(this.request, this.response, sockJsPath, this.wsHandler);
assertEquals(404, this.servletResponse.getStatus());
resetRequestAndResponse();
jsonpService.setAllowedOrigins(Arrays.asList("*"));
setRequest("GET", sockJsPrefix + sockJsPath);
jsonpService.handleRequest(this.request, this.response, sockJsPath, this.wsHandler);
assertNotEquals(404, this.servletResponse.getStatus());
}
@Test
......@@ -289,8 +291,7 @@ public class DefaultSockJsServiceTests extends AbstractHttpRequestTests {
assertNotEquals(403, this.servletResponse.getStatus());
resetRequestAndResponse();
OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor();
interceptor.setAllowedOrigins(Arrays.asList("http://mydomain1.com"));
OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor(Arrays.asList("http://mydomain1.com"));
wsService.setHandshakeInterceptors(Arrays.asList(interceptor));
setRequest("GET", sockJsPrefix + sockJsPath);
setOrigin("http://mydomain1.com");
......@@ -310,13 +311,21 @@ public class DefaultSockJsServiceTests extends AbstractHttpRequestTests {
setRequest("GET", sockJsPrefix + sockJsPath);
this.service.handleRequest(this.request, this.response, sockJsPath, this.wsHandler);
assertNotEquals(404, this.servletResponse.getStatus());
assertNull(this.servletResponse.getHeader("X-Frame-Options"));
assertEquals("SAMEORIGIN", this.servletResponse.getHeader("X-Frame-Options"));
resetRequestAndResponse();
setRequest("GET", sockJsPrefix + sockJsPath);
this.service.setAllowedOrigins(Arrays.asList("http://mydomain1.com"));
this.service.handleRequest(this.request, this.response, sockJsPath, this.wsHandler);
assertEquals(404, this.servletResponse.getStatus());
assertNull(this.servletResponse.getHeader("X-Frame-Options"));
resetRequestAndResponse();
setRequest("GET", sockJsPrefix + sockJsPath);
this.service.setAllowedOrigins(Arrays.asList("*"));
this.service.handleRequest(this.request, this.response, sockJsPath, this.wsHandler);
assertNotEquals(404, this.servletResponse.getStatus());
assertNull(this.servletResponse.getHeader("X-Frame-Options"));
}
......
......@@ -39465,7 +39465,76 @@ or WebSocket XML namespace:
</beans>
----
[[websocket-server-allowed-origins]]
==== Configuring allowed origins
As of Spring Framework 4.1.5, Websocket and SockJS default behavior is to accept only same
origin requests. It is also possible to allow all or a specified list of origins.
This check is mostly designed for browser clients. There is nothing preventing other types
of client to modify the `Origin` header value (see
https://tools.ietf.org/html/rfc6454[RFC 6454: The Web Origin Concept] for more details).
The 3 possible behaviors are:
* Allow only same origin requests (default): in this mode, when SockJS is enabled, the
Iframe HTTP response header `X-Frame-Options` is set to `SAMEORIGIN`, and JSONP
transport is disabled since it does not allow to check the origin of a request.
As a consequence, IE6 and IE7 are not supported when this mode is enabled.
* Allow a specified list of origins: each provided allowed origin must start by `http://`
or `https://`. In this mode, when SockJS is enabled, both IFrame and JSONP based
transports are disabled. As a consequence, IE6 up to IE9 are not supported when this
mode is enabled.
* Allow all origins: to enable this mode, you should provide `*` as allowed origin. In this
mode, all transports are available.
Websocket and SockJS allowed origins can be configured as shown bellow:
[source,java,indent=0]
[subs="verbatim,quotes"]
----
import org.springframework.web.socket.config.annotation.EnableWebSocket;
import org.springframework.web.socket.config.annotation.WebSocketConfigurer;
import org.springframework.web.socket.config.annotation.WebSocketHandlerRegistry;
@Configuration
@EnableWebSocket
public class WebSocketConfig implements WebSocketConfigurer {
@Override
public void registerWebSocketHandlers(WebSocketHandlerRegistry registry) {
registry.addHandler(myHandler(), "/myHandler").setAllowedOrigins("http://mydomain.com");
}
@Bean
public WebSocketHandler myHandler() {
return new MyHandler();
}
}
----
XML configuration equivalent:
[source,xml,indent=0]
[subs="verbatim,quotes,attributes"]
----
<beans xmlns="http://www.springframework.org/schema/beans"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xmlns:websocket="http://www.springframework.org/schema/websocket"
xsi:schemaLocation="
http://www.springframework.org/schema/beans
http://www.springframework.org/schema/beans/spring-beans.xsd
http://www.springframework.org/schema/websocket
http://www.springframework.org/schema/websocket/spring-websocket.xsd">
<websocket:handlers allowed-origins="http://mydomain.com">
<websocket:mapping path="/myHandler" handler="myHandler" />
</websocket:handlers>
<bean id="myHandler" class="org.springframework.samples.MyHandler"/>
</beans>
----
[[websocket-fallback]]
......@@ -39732,11 +39801,11 @@ log category to TRACE.
[[websocket-fallback-cors]]
==== CORS Headers for SockJS
The SockJS protocol uses CORS for cross-domain support in the XHR streaming and
polling transports. Therefore CORS headers are added automatically unless the
presence of CORS headers in the response is detected. So if an application is
already configured to provide CORS support, e.g. through a Servlet Filter,
Spring's SockJsService will skip this part.
If you allow cross-origin requests (see <<websocket-server-allowed-origins>>), the SockJS protocol
uses CORS for cross-domain support in the XHR streaming and polling transports. Therefore
CORS headers are added automatically unless the presence of CORS headers in the response
is detected. So if an application is already configured to provide CORS support, e.g.
through a Servlet Filter, Spring's SockJsService will skip this part.
It is also possible to disable the addition of these CORS headers thanks to the
`suppressCors` property in Spring's SockJsService.
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册