From 23fa37b08b8cc3fe3a716cdd1d9b612c2eaecb50 Mon Sep 17 00:00:00 2001 From: Sebastien Deleuze Date: Fri, 13 Feb 2015 16:56:09 +0100 Subject: [PATCH] Change SockJS and Websocket default allowedOrigins to same origin This commit adds support for a same origin check that compares Origin header to Host header. It also changes the default setting from all origins allowed to only same origin allowed. Issues: SPR-12697, SPR-12685 (cherry picked from commit 6062e15) --- .../springframework/web/util/WebUtils.java | 48 ++++++++++- .../web/util/WebUtilsTests.java | 62 ++++++++++++++- .../config/HandlersBeanDefinitionParser.java | 8 +- .../MessageBrokerBeanDefinitionParser.java | 8 +- .../config/WebSocketNamespaceUtils.java | 10 +-- .../AbstractWebSocketHandlerRegistration.java | 13 +-- .../annotation/SockJsServiceRegistration.java | 28 +++---- .../StompWebSocketEndpointRegistration.java | 4 +- ...MvcStompWebSocketEndpointRegistration.java | 13 ++- .../WebSocketHandlerRegistration.java | 4 +- .../support/OriginHandshakeInterceptor.java | 26 +++--- .../sockjs/support/AbstractSockJsService.java | 19 +++-- .../sockjs/transport/TransportType.java | 6 +- .../HandlersBeanDefinitionParserTests.java | 10 ++- ...ompWebSocketEndpointRegistrationTests.java | 24 ++++-- .../WebSocketHandlerRegistrationTests.java | 29 +++++-- .../OriginHandshakeInterceptorTests.java | 48 +++++++---- .../sockjs/support/SockJsServiceTests.java | 18 +++++ .../handler/DefaultSockJsServiceTests.java | 29 ++++--- src/asciidoc/index.adoc | 79 +++++++++++++++++-- 20 files changed, 363 insertions(+), 123 deletions(-) diff --git a/spring-web/src/main/java/org/springframework/web/util/WebUtils.java b/spring-web/src/main/java/org/springframework/web/util/WebUtils.java index fbb626bfd1..61cd603c6e 100644 --- a/spring-web/src/main/java/org/springframework/web/util/WebUtils.java +++ b/spring-web/src/main/java/org/springframework/web/util/WebUtils.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2014 the original author or authors. + * Copyright 2002-2015 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -19,6 +19,7 @@ package org.springframework.web.util; import java.io.File; import java.io.FileNotFoundException; import java.util.Enumeration; +import java.util.List; import java.util.Map; import java.util.StringTokenizer; import java.util.TreeMap; @@ -32,6 +33,7 @@ import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; import javax.servlet.http.HttpSession; +import org.springframework.http.server.ServerHttpRequest; import org.springframework.util.Assert; import org.springframework.util.LinkedMultiValueMap; import org.springframework.util.MultiValueMap; @@ -43,6 +45,7 @@ import org.springframework.util.StringUtils; * * @author Rod Johnson * @author Juergen Hoeller + * @author Sebastien Deleuze */ public abstract class WebUtils { @@ -765,4 +768,47 @@ public abstract class WebUtils { } return result; } + + /** + * Check the given request origin against a list of allowed origins. + * A list containing "*" means that all origins are allowed. + * An empty list means only same origin is allowed. + * + * @return true if the request origin is valid, false otherwise + * @since 4.1.5 + * @see RFC 6454: The Web Origin Concept + */ + public static boolean isValidOrigin(ServerHttpRequest request, List allowedOrigins) { + Assert.notNull(request, "Request must not be null"); + Assert.notNull(allowedOrigins, "Allowed origins must not be null"); + + String origin = request.getHeaders().getOrigin(); + if (origin == null || allowedOrigins.contains("*")) { + return true; + } + else if (allowedOrigins.isEmpty()) { + UriComponents originComponents = UriComponentsBuilder.fromHttpUrl(origin).build(); + UriComponents requestComponents = UriComponentsBuilder.fromHttpRequest(request).build(); + int originPort = getPort(originComponents); + int requestPort = getPort(requestComponents); + return originComponents.getHost().equals(requestComponents.getHost()) && (originPort == requestPort); + } + else { + return allowedOrigins.contains(origin); + } + } + + private static int getPort(UriComponents component) { + int port = component.getPort(); + if (port == -1) { + if ("http".equals(component.getScheme())) { + port = 80; + } + else if ("https".equals(component.getScheme())) { + port = 443; + } + } + return port; + } + } diff --git a/spring-web/src/test/java/org/springframework/web/util/WebUtilsTests.java b/spring-web/src/test/java/org/springframework/web/util/WebUtilsTests.java index de6ff52dc3..e0a66ce26e 100644 --- a/spring-web/src/test/java/org/springframework/web/util/WebUtilsTests.java +++ b/spring-web/src/test/java/org/springframework/web/util/WebUtilsTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2008 the original author or authors. + * Copyright 2002-2015 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,12 +16,18 @@ package org.springframework.web.util; +import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; +import java.util.List; import java.util.Map; import org.junit.Test; +import org.springframework.http.HttpHeaders; +import org.springframework.http.server.ServerHttpRequest; +import org.springframework.http.server.ServletServerHttpRequest; +import org.springframework.mock.web.test.MockHttpServletRequest; import org.springframework.util.MultiValueMap; import static org.junit.Assert.*; @@ -30,6 +36,7 @@ import static org.junit.Assert.*; * @author Juergen Hoeller * @author Arjen Poutsma * @author Rossen Stoyanchev + * @author Sebastien Deleuze */ public class WebUtilsTests { @@ -98,4 +105,57 @@ public class WebUtilsTests { assertEquals(Arrays.asList("red", "blue", "green"), variables.get("colors")); } + @Test + public void isValidOrigin() { + List allowedOrigins = new ArrayList<>(); + MockHttpServletRequest servletRequest = new MockHttpServletRequest(); + ServerHttpRequest request = new ServletServerHttpRequest(servletRequest); + + servletRequest.setServerName("mydomain1.com"); + request.getHeaders().set(HttpHeaders.ORIGIN, "http://mydomain1.com"); + assertTrue(WebUtils.isValidOrigin(request, allowedOrigins)); + + servletRequest.setServerName("mydomain1.com"); + request.getHeaders().set(HttpHeaders.ORIGIN, "http://mydomain1.com:80"); + assertTrue(WebUtils.isValidOrigin(request, allowedOrigins)); + + servletRequest.setServerName("mydomain1.com"); + servletRequest.setServerPort(443); + request.getHeaders().set(HttpHeaders.ORIGIN, "https://mydomain1.com"); + assertTrue(WebUtils.isValidOrigin(request, allowedOrigins)); + + servletRequest.setServerName("mydomain1.com"); + servletRequest.setServerPort(443); + request.getHeaders().set(HttpHeaders.ORIGIN, "https://mydomain1.com:443"); + assertTrue(WebUtils.isValidOrigin(request, allowedOrigins)); + + servletRequest.setServerName("mydomain1.com"); + servletRequest.setServerPort(123); + request.getHeaders().set(HttpHeaders.ORIGIN, "http://mydomain1.com:123"); + assertTrue(WebUtils.isValidOrigin(request, allowedOrigins)); + + servletRequest.setServerName("mydomain1.com"); + request.getHeaders().set(HttpHeaders.ORIGIN, "http://mydomain2.com"); + assertFalse(WebUtils.isValidOrigin(request, allowedOrigins)); + + servletRequest.setServerName("mydomain1.com"); + request.getHeaders().set(HttpHeaders.ORIGIN, "https://mydomain1.com"); + assertFalse(WebUtils.isValidOrigin(request, allowedOrigins)); + + allowedOrigins = Arrays.asList("*"); + servletRequest.setServerName("mydomain1.com"); + request.getHeaders().set(HttpHeaders.ORIGIN, "http://mydomain2.com"); + assertTrue(WebUtils.isValidOrigin(request, allowedOrigins)); + + allowedOrigins = Arrays.asList("http://mydomain1.com"); + servletRequest.setServerName("mydomain2.com"); + request.getHeaders().set(HttpHeaders.ORIGIN, "http://mydomain1.com"); + assertTrue(WebUtils.isValidOrigin(request, allowedOrigins)); + + allowedOrigins = Arrays.asList("http://mydomain1.com"); + servletRequest.setServerName("mydomain2.com"); + request.getHeaders().set(HttpHeaders.ORIGIN, "http://mydomain3.com"); + assertFalse(WebUtils.isValidOrigin(request, allowedOrigins)); + } + } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/config/HandlersBeanDefinitionParser.java b/spring-websocket/src/main/java/org/springframework/web/socket/config/HandlersBeanDefinitionParser.java index c8747da141..1f4e7f6587 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/config/HandlersBeanDefinitionParser.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/config/HandlersBeanDefinitionParser.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2014 the original author or authors. + * Copyright 2002-2015 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -83,11 +83,7 @@ class HandlersBeanDefinitionParser implements BeanDefinitionParser { ManagedList interceptors = WebSocketNamespaceUtils.parseBeanSubElements(interceptorsElement, context); String allowedOriginsAttribute = element.getAttribute("allowed-origins"); List allowedOrigins = Arrays.asList(StringUtils.tokenizeToStringArray(allowedOriginsAttribute, ",")); - if (!allowedOrigins.isEmpty()) { - OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor(); - interceptor.setAllowedOrigins(allowedOrigins); - interceptors.add(interceptor); - } + interceptors.add(new OriginHandshakeInterceptor(allowedOrigins)); strategy = new WebSocketHandlerMappingStrategy(handshakeHandler, interceptors); } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/config/MessageBrokerBeanDefinitionParser.java b/spring-websocket/src/main/java/org/springframework/web/socket/config/MessageBrokerBeanDefinitionParser.java index 068e75a778..cbb9353d98 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/config/MessageBrokerBeanDefinitionParser.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/config/MessageBrokerBeanDefinitionParser.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2014 the original author or authors. + * Copyright 2002-2015 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -288,11 +288,7 @@ class MessageBrokerBeanDefinitionParser implements BeanDefinitionParser { ManagedList interceptors = WebSocketNamespaceUtils.parseBeanSubElements(interceptorsElement, context); String allowedOriginsAttribute = element.getAttribute("allowed-origins"); List allowedOrigins = Arrays.asList(StringUtils.tokenizeToStringArray(allowedOriginsAttribute, ",")); - if (!allowedOrigins.isEmpty()) { - OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor(); - interceptor.setAllowedOrigins(allowedOrigins); - interceptors.add(interceptor); - } + interceptors.add(new OriginHandshakeInterceptor(allowedOrigins)); ConstructorArgumentValues cavs = new ConstructorArgumentValues(); cavs.addIndexedArgumentValue(0, subProtoHandler); if (handshakeHandler != null) { diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/config/WebSocketNamespaceUtils.java b/spring-websocket/src/main/java/org/springframework/web/socket/config/WebSocketNamespaceUtils.java index 8756126ceb..d8867b07d1 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/config/WebSocketNamespaceUtils.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/config/WebSocketNamespaceUtils.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2014 the original author or authors. + * Copyright 2002-2015 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -105,12 +105,8 @@ class WebSocketNamespaceUtils { ManagedList interceptors = WebSocketNamespaceUtils.parseBeanSubElements(interceptorsElement, context); String allowedOriginsAttribute = element.getAttribute("allowed-origins"); List allowedOrigins = Arrays.asList(StringUtils.tokenizeToStringArray(allowedOriginsAttribute, ",")); - if (!allowedOrigins.isEmpty()) { - sockJsServiceDef.getPropertyValues().add("allowedOrigins", allowedOrigins); - OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor(); - interceptor.setAllowedOrigins(allowedOrigins); - interceptors.add(interceptor); - } + sockJsServiceDef.getPropertyValues().add("allowedOrigins", allowedOrigins); + interceptors.add(new OriginHandshakeInterceptor(allowedOrigins)); sockJsServiceDef.getPropertyValues().add("handshakeInterceptors", interceptors); String attrValue = sockJsElement.getAttribute("name"); diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/AbstractWebSocketHandlerRegistration.java b/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/AbstractWebSocketHandlerRegistration.java index b326590d4f..58c4ef1f49 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/AbstractWebSocketHandlerRegistration.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/AbstractWebSocketHandlerRegistration.java @@ -88,11 +88,10 @@ public abstract class AbstractWebSocketHandlerRegistration implements WebSock } @Override - public WebSocketHandlerRegistration setAllowedOrigins(String... origins) { - Assert.notEmpty(origins, "No allowed origin specified"); + public WebSocketHandlerRegistration setAllowedOrigins(String... allowedOrigins) { this.allowedOrigins.clear(); - if (!ObjectUtils.isEmpty(origins)) { - this.allowedOrigins.addAll(Arrays.asList(origins)); + if (!ObjectUtils.isEmpty(allowedOrigins)) { + this.allowedOrigins.addAll(Arrays.asList(allowedOrigins)); } return this; } @@ -117,11 +116,7 @@ public abstract class AbstractWebSocketHandlerRegistration implements WebSock protected HandshakeInterceptor[] getInterceptors() { List interceptors = new ArrayList(); interceptors.addAll(this.interceptors); - if (!this.allowedOrigins.isEmpty()) { - OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor(); - interceptor.setAllowedOrigins(this.allowedOrigins); - interceptors.add(interceptor); - } + interceptors.add(new OriginHandshakeInterceptor(this.allowedOrigins)); return interceptors.toArray(new HandshakeInterceptor[interceptors.size()]); } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/SockJsServiceRegistration.java b/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/SockJsServiceRegistration.java index 901db2235c..101296b2a0 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/SockJsServiceRegistration.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/SockJsServiceRegistration.java @@ -206,6 +206,17 @@ public class SockJsServiceRegistration { return this; } + /** + * @since 4.1.2 + */ + protected SockJsServiceRegistration setAllowedOrigins(String... allowedOrigins) { + this.allowedOrigins.clear(); + if (!ObjectUtils.isEmpty(allowedOrigins)) { + this.allowedOrigins.addAll(Arrays.asList(allowedOrigins)); + } + return this; + } + /** * This option can be used to disable automatic addition of CORS headers for * SockJS requests. @@ -229,17 +240,6 @@ public class SockJsServiceRegistration { return this; } - /** - * @since 4.1.2 - */ - protected SockJsServiceRegistration setAllowedOrigins(String... origins) { - this.allowedOrigins.clear(); - if (!ObjectUtils.isEmpty(origins)) { - this.allowedOrigins.addAll(Arrays.asList(origins)); - } - return this; - } - protected SockJsService getSockJsService() { TransportHandlingSockJsService service = createSockJsService(); service.setHandshakeInterceptors(this.interceptors); @@ -264,12 +264,12 @@ public class SockJsServiceRegistration { if (this.webSocketEnabled != null) { service.setWebSocketEnabled(this.webSocketEnabled); } + if (this.allowedOrigins != null) { + service.setAllowedOrigins(this.allowedOrigins); + } if (this.suppressCors != null) { service.setSuppressCors(this.suppressCors); } - if (!this.allowedOrigins.isEmpty()) { - service.setAllowedOrigins(this.allowedOrigins); - } if (this.messageCodec != null) { service.setMessageCodec(this.messageCodec); } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/StompWebSocketEndpointRegistration.java b/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/StompWebSocketEndpointRegistration.java index 877eedb814..497a3842a4 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/StompWebSocketEndpointRegistration.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/StompWebSocketEndpointRegistration.java @@ -52,8 +52,8 @@ public interface StompWebSocketEndpointRegistration { * As a consequence, IE 6 to 9 are not supported when origins are restricted. * *

Each provided allowed origin must start by "http://", "https://" or be "*" - * (means that all origins are allowed). Empty allowed origin list is not supported. - * By default, all origins are allowed. + * (means that all origins are allowed). By default, only same origin requests are + * allowed (empty list). * * @since 4.1.2 * @see RFC 6454: The Web Origin Concept diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/WebMvcStompWebSocketEndpointRegistration.java b/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/WebMvcStompWebSocketEndpointRegistration.java index 467dee2d24..ecac159632 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/WebMvcStompWebSocketEndpointRegistration.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/WebMvcStompWebSocketEndpointRegistration.java @@ -85,10 +85,11 @@ public class WebMvcStompWebSocketEndpointRegistration implements StompWebSocketE } @Override - public StompWebSocketEndpointRegistration setAllowedOrigins(String... origins) { - Assert.notEmpty(origins, "No allowed origin specified"); + public StompWebSocketEndpointRegistration setAllowedOrigins(String... allowedOrigins) { this.allowedOrigins.clear(); - this.allowedOrigins.addAll(Arrays.asList(origins)); + if (!ObjectUtils.isEmpty(allowedOrigins)) { + this.allowedOrigins.addAll(Arrays.asList(allowedOrigins)); + } return this; } @@ -112,11 +113,7 @@ public class WebMvcStompWebSocketEndpointRegistration implements StompWebSocketE protected HandshakeInterceptor[] getInterceptors() { List interceptors = new ArrayList(); interceptors.addAll(this.interceptors); - if (!this.allowedOrigins.isEmpty()) { - OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor(); - interceptor.setAllowedOrigins(this.allowedOrigins); - interceptors.add(interceptor); - } + interceptors.add(new OriginHandshakeInterceptor(this.allowedOrigins)); return interceptors.toArray(new HandshakeInterceptor[interceptors.size()]); } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/WebSocketHandlerRegistration.java b/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/WebSocketHandlerRegistration.java index eb7eee0eb1..c5acb0e3a0 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/WebSocketHandlerRegistration.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/WebSocketHandlerRegistration.java @@ -54,8 +54,8 @@ public interface WebSocketHandlerRegistration { * As a consequence, IE 6 to 9 are not supported when origins are restricted. * *

Each provided allowed origin must start by "http://", "https://" or be "*" - * (means that all origins are allowed). Empty allowed origin list is not supported. - * By default, all origins are allowed. + * (means that all origins are allowed). By default, only same origin requests are + * allowed (empty list). * * @since 4.1.2 * @see RFC 6454: The Web Origin Concept diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/server/support/OriginHandshakeInterceptor.java b/spring-websocket/src/main/java/org/springframework/web/socket/server/support/OriginHandshakeInterceptor.java index 0f9ee0c829..178fe69882 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/server/support/OriginHandshakeInterceptor.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/server/support/OriginHandshakeInterceptor.java @@ -31,6 +31,7 @@ import org.springframework.http.server.ServerHttpResponse; import org.springframework.util.Assert; import org.springframework.web.socket.WebSocketHandler; import org.springframework.web.socket.server.HandshakeInterceptor; +import org.springframework.web.util.WebUtils; /** * An interceptor to check request {@code Origin} header value against a collection of @@ -47,12 +48,22 @@ public class OriginHandshakeInterceptor implements HandshakeInterceptor { /** - * Default constructor with no origin allowed. + * Default constructor with only same origin requests allowed. */ public OriginHandshakeInterceptor() { this.allowedOrigins = new ArrayList(); } + /** + * Constructor using the specified allowed origin values. + * + * @see #setAllowedOrigins(Collection) + */ + public OriginHandshakeInterceptor(Collection allowedOrigins) { + this(); + setAllowedOrigins(allowedOrigins); + } + /** * Configure allowed {@code Origin} header values. This check is mostly designed for * browser clients. There is nothing preventing other types of client to modify the @@ -85,7 +96,7 @@ public class OriginHandshakeInterceptor implements HandshakeInterceptor { @Override public boolean beforeHandshake(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler wsHandler, Map attributes) throws Exception { - if (!isValidOrigin(request)) { + if (!WebUtils.isValidOrigin(request, this.allowedOrigins)) { response.setStatusCode(HttpStatus.FORBIDDEN); if (logger.isDebugEnabled()) { logger.debug("Handshake request rejected, Origin header value " @@ -96,17 +107,6 @@ public class OriginHandshakeInterceptor implements HandshakeInterceptor { return true; } - protected boolean isValidOrigin(ServerHttpRequest request) { - String origin = request.getHeaders().getOrigin(); - if (origin == null) { - return true; - } - if (this.allowedOrigins.contains("*")) { - return true; - } - return this.allowedOrigins.contains(origin); - } - @Override public void afterHandshake(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler wsHandler, Exception exception) { diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/support/AbstractSockJsService.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/support/AbstractSockJsService.java index 0b01c24563..1fef45cd23 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/support/AbstractSockJsService.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/support/AbstractSockJsService.java @@ -46,12 +46,16 @@ import org.springframework.util.StringUtils; import org.springframework.web.socket.WebSocketHandler; import org.springframework.web.socket.sockjs.SockJsException; import org.springframework.web.socket.sockjs.SockJsService; +import org.springframework.web.util.WebUtils; /** * An abstract base class for {@link SockJsService} implementations that provides SockJS * path resolution and handling of static SockJS requests (e.g. "/info", "/iframe.html", * etc). Sub-classes must handle session URLs (i.e. transport-specific requests). * + * By default, only same origin requests are allowed. Use {@link #setAllowedOrigins(List)} + * to specify a list of allowed origins (a list containing "*" will allow all origins). + * * @author Rossen Stoyanchev * @author Sebastien Deleuze * @since 4.0 @@ -64,6 +68,8 @@ public abstract class AbstractSockJsService implements SockJsService { private static final Random random = new Random(); + private static final String XFRAME_OPTIONS_HEADER = "X-Frame-Options"; + protected final Log logger = LogFactory.getLog(getClass()); @@ -85,7 +91,7 @@ public abstract class AbstractSockJsService implements SockJsService { private boolean webSocketEnabled = true; - private final List allowedOrigins = new ArrayList(Arrays.asList("*")); + private final List allowedOrigins = new ArrayList(); private boolean suppressCors = false; @@ -275,15 +281,14 @@ public abstract class AbstractSockJsService implements SockJsService { * As a consequence, IE 6 to 9 are not supported when origins are restricted. * *

Each provided allowed origin must start by "http://", "https://" or be "*" - * (means that all origins are allowed). Empty allowed origin list is not supported. - * By default, all origins are allowed. + * (means that all origins are allowed). * * @since 4.1.2 * @see RFC 6454: The Web Origin Concept * @see SockJS supported transports by browser */ public void setAllowedOrigins(List allowedOrigins) { - Assert.notEmpty(allowedOrigins, "Allowed origin List must not be empty"); + Assert.notNull(allowedOrigins, "Allowed origin List must not be null"); for (String allowedOrigin : allowedOrigins) { Assert.isTrue( allowedOrigin.equals("*") || allowedOrigin.startsWith("http://") || @@ -360,6 +365,9 @@ public abstract class AbstractSockJsService implements SockJsService { response.setStatusCode(HttpStatus.NOT_FOUND); return; } + if (this.allowedOrigins.isEmpty()) { + response.getHeaders().add(XFRAME_OPTIONS_HEADER, "SAMEORIGIN"); + } logger.debug(requestInfo); this.iframeHandler.handle(request, response); } @@ -438,13 +446,12 @@ public abstract class AbstractSockJsService implements SockJsService { HttpHeaders requestHeaders = request.getHeaders(); HttpHeaders responseHeaders = response.getHeaders(); String origin = requestHeaders.getOrigin(); - String host = requestHeaders.getFirst(HttpHeaders.HOST); if (origin == null) { return true; } - if (!this.allowedOrigins.contains("*") && !this.allowedOrigins.contains(origin)) { + if (!WebUtils.isValidOrigin(request, this.allowedOrigins)) { logger.debug("Request rejected, Origin header value " + origin + " not allowed"); response.setStatusCode(HttpStatus.FORBIDDEN); return false; diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/TransportType.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/TransportType.java index 04e7f30f6b..e3da359f91 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/TransportType.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/TransportType.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2014 the original author or authors. + * Copyright 2002-2015 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -45,9 +45,9 @@ public enum TransportType { XHR_STREAMING("xhr_streaming", HttpMethod.POST, "cors", "jsessionid", "no_cache"), - EVENT_SOURCE("eventsource", HttpMethod.GET, "jsessionid", "no_cache"), + EVENT_SOURCE("eventsource", HttpMethod.GET, "origin", "jsessionid", "no_cache"), - HTML_FILE("htmlfile", HttpMethod.GET, "jsessionid", "no_cache"); + HTML_FILE("htmlfile", HttpMethod.GET, "cors", "jsessionid", "no_cache"); private final String value; diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/config/HandlersBeanDefinitionParserTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/config/HandlersBeanDefinitionParserTests.java index 97c0e1326d..811cb21a93 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/config/HandlersBeanDefinitionParserTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/config/HandlersBeanDefinitionParserTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2014 the original author or authors. + * Copyright 2002-2015 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -106,7 +106,8 @@ public class HandlersBeanDefinitionParserTests { HandshakeHandler handshakeHandler = handler.getHandshakeHandler(); assertNotNull(handshakeHandler); assertTrue(handshakeHandler instanceof DefaultHandshakeHandler); - assertTrue(handler.getHandshakeInterceptors().isEmpty()); + assertFalse(handler.getHandshakeInterceptors().isEmpty()); + assertTrue(handler.getHandshakeInterceptors().get(0) instanceof OriginHandshakeInterceptor); } else { assertThat(shm.getUrlMap().keySet(), contains("/test")); @@ -116,7 +117,8 @@ public class HandlersBeanDefinitionParserTests { HandshakeHandler handshakeHandler = handler.getHandshakeHandler(); assertNotNull(handshakeHandler); assertTrue(handshakeHandler instanceof DefaultHandshakeHandler); - assertTrue(handler.getHandshakeInterceptors().isEmpty()); + assertFalse(handler.getHandshakeInterceptors().isEmpty()); + assertTrue(handler.getHandshakeInterceptors().get(0) instanceof OriginHandshakeInterceptor); } } } @@ -196,7 +198,7 @@ public class HandlersBeanDefinitionParserTests { assertEquals(TestHandshakeHandler.class, handler.getHandshakeHandler().getClass()); List interceptors = defaultSockJsService.getHandshakeInterceptors(); - assertThat(interceptors, contains(instanceOf(FooTestInterceptor.class), instanceOf(BarTestInterceptor.class))); + assertThat(interceptors, contains(instanceOf(FooTestInterceptor.class), instanceOf(BarTestInterceptor.class), instanceOf(OriginHandshakeInterceptor.class))); } @Test diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/config/annotation/WebMvcStompWebSocketEndpointRegistrationTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/config/annotation/WebMvcStompWebSocketEndpointRegistrationTests.java index a149176893..d323e2b843 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/config/annotation/WebMvcStompWebSocketEndpointRegistrationTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/config/annotation/WebMvcStompWebSocketEndpointRegistrationTests.java @@ -71,7 +71,7 @@ public class WebMvcStompWebSocketEndpointRegistrationTests { Map.Entry> entry = mappings.entrySet().iterator().next(); assertNotNull(((WebSocketHttpRequestHandler) entry.getKey()).getWebSocketHandler()); - assertTrue(((WebSocketHttpRequestHandler) entry.getKey()).getHandshakeInterceptors().isEmpty()); + assertEquals(1, ((WebSocketHttpRequestHandler) entry.getKey()).getHandshakeInterceptors().size()); assertEquals(Arrays.asList("/foo"), entry.getValue()); } @@ -80,7 +80,7 @@ public class WebMvcStompWebSocketEndpointRegistrationTests { WebMvcStompWebSocketEndpointRegistration registration = new WebMvcStompWebSocketEndpointRegistration(new String[] {"/foo"}, this.handler, this.scheduler); - registration.setAllowedOrigins("http://mydomain.com"); + registration.setAllowedOrigins(); MultiValueMap mappings = registration.getMappings(); assertEquals(1, mappings.size()); @@ -90,10 +90,18 @@ public class WebMvcStompWebSocketEndpointRegistrationTests { assertEquals(OriginHandshakeInterceptor.class, requestHandler.getHandshakeInterceptors().get(0).getClass()); } - @Test(expected = IllegalArgumentException.class) - public void noAllowedOrigin() { + @Test + public void sameOrigin() { WebMvcStompWebSocketEndpointRegistration registration = new WebMvcStompWebSocketEndpointRegistration(new String[] {"/foo"}, this.handler, this.scheduler); + registration.setAllowedOrigins(); + + MultiValueMap mappings = registration.getMappings(); + assertEquals(1, mappings.size()); + WebSocketHttpRequestHandler requestHandler = (WebSocketHttpRequestHandler)mappings.entrySet().iterator().next().getKey(); + assertNotNull(requestHandler.getWebSocketHandler()); + assertEquals(1, requestHandler.getHandshakeInterceptors().size()); + assertEquals(OriginHandshakeInterceptor.class, requestHandler.getHandshakeInterceptors().get(0).getClass()); } @Test @@ -158,7 +166,9 @@ public class WebMvcStompWebSocketEndpointRegistrationTests { WebSocketHttpRequestHandler requestHandler = (WebSocketHttpRequestHandler) entry.getKey(); assertNotNull(requestHandler.getWebSocketHandler()); assertSame(handshakeHandler, requestHandler.getHandshakeHandler()); - assertEquals(Arrays.asList(interceptor), requestHandler.getHandshakeInterceptors()); + assertEquals(2, requestHandler.getHandshakeInterceptors().size()); + assertEquals(interceptor, requestHandler.getHandshakeInterceptors().get(0)); + assertEquals(OriginHandshakeInterceptor.class, requestHandler.getHandshakeInterceptors().get(1).getClass()); } @Test @@ -210,7 +220,9 @@ public class WebMvcStompWebSocketEndpointRegistrationTests { Map handlers = sockJsService.getTransportHandlers(); WebSocketTransportHandler transportHandler = (WebSocketTransportHandler) handlers.get(TransportType.WEBSOCKET); assertSame(handshakeHandler, transportHandler.getHandshakeHandler()); - assertEquals(Arrays.asList(interceptor), sockJsService.getHandshakeInterceptors()); + assertEquals(2, sockJsService.getHandshakeInterceptors().size()); + assertEquals(interceptor, sockJsService.getHandshakeInterceptors().get(0)); + assertEquals(OriginHandshakeInterceptor.class, sockJsService.getHandshakeInterceptors().get(1).getClass()); } @Test diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/config/annotation/WebSocketHandlerRegistrationTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/config/annotation/WebSocketHandlerRegistrationTests.java index 751b268b6e..7591558e10 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/config/annotation/WebSocketHandlerRegistrationTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/config/annotation/WebSocketHandlerRegistrationTests.java @@ -69,12 +69,14 @@ public class WebSocketHandlerRegistrationTests { Mapping m1 = mappings.get(0); assertEquals(handler, m1.webSocketHandler); assertEquals("/foo", m1.path); - assertEquals(0, m1.interceptors.length); + assertEquals(1, m1.interceptors.length); + assertEquals(OriginHandshakeInterceptor.class, m1.interceptors[0].getClass()); Mapping m2 = mappings.get(1); assertEquals(handler, m2.webSocketHandler); assertEquals("/bar", m2.path); - assertEquals(0, m2.interceptors.length); + assertEquals(1, m2.interceptors.length); + assertEquals(OriginHandshakeInterceptor.class, m2.interceptors[0].getClass()); } @Test @@ -90,12 +92,27 @@ public class WebSocketHandlerRegistrationTests { Mapping mapping = mappings.get(0); assertEquals(handler, mapping.webSocketHandler); assertEquals("/foo", mapping.path); - assertArrayEquals(new HandshakeInterceptor[] {interceptor}, mapping.interceptors); + assertEquals(2, mapping.interceptors.length); + assertEquals(interceptor, mapping.interceptors[0]); + assertEquals(OriginHandshakeInterceptor.class, mapping.interceptors[1].getClass()); } - @Test(expected = IllegalArgumentException.class) - public void noAllowedOrigin() { - this.registration.addHandler(Mockito.mock(WebSocketHandler.class), "/foo").setAllowedOrigins(); + @Test + public void emptyAllowedOrigin() { + WebSocketHandler handler = new TextWebSocketHandler(); + HttpSessionHandshakeInterceptor interceptor = new HttpSessionHandshakeInterceptor(); + + this.registration.addHandler(handler, "/foo").addInterceptors(interceptor).setAllowedOrigins(); + + List mappings = this.registration.getMappings(); + assertEquals(1, mappings.size()); + + Mapping mapping = mappings.get(0); + assertEquals(handler, mapping.webSocketHandler); + assertEquals("/foo", mapping.path); + assertEquals(2, mapping.interceptors.length); + assertEquals(interceptor, mapping.interceptors[0]); + assertEquals(OriginHandshakeInterceptor.class, mapping.interceptors[1].getClass()); } @Test diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/server/support/OriginHandshakeInterceptorTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/server/support/OriginHandshakeInterceptorTests.java index a970949820..866e9569dc 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/server/support/OriginHandshakeInterceptorTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/server/support/OriginHandshakeInterceptorTests.java @@ -39,20 +39,22 @@ public class OriginHandshakeInterceptorTests extends AbstractHttpRequestTests { @Test(expected = IllegalArgumentException.class) public void nullAllowedOriginList() { - OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor(); - interceptor.setAllowedOrigins(null); + new OriginHandshakeInterceptor(null); } @Test(expected = IllegalArgumentException.class) public void invalidAllowedOrigin() { - OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor(); - interceptor.setAllowedOrigins(Arrays.asList("domain.com")); + new OriginHandshakeInterceptor(Arrays.asList("domain.com")); + } + + @Test + public void emtpyAllowedOriginList() { + new OriginHandshakeInterceptor(Arrays.asList()); } @Test public void validAllowedOrigins() { - OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor(); - interceptor.setAllowedOrigins(Arrays.asList("http://domain.com", "https://domain.com", "*")); + new OriginHandshakeInterceptor(Arrays.asList("http://domain.com", "https://domain.com", "*")); } @Test @@ -60,8 +62,7 @@ public class OriginHandshakeInterceptorTests extends AbstractHttpRequestTests { Map attributes = new HashMap(); WebSocketHandler wsHandler = Mockito.mock(WebSocketHandler.class); setOrigin("http://mydomain1.com"); - OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor(); - interceptor.setAllowedOrigins(Arrays.asList("http://mydomain1.com")); + OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor(Arrays.asList("http://mydomain1.com")); assertTrue(interceptor.beforeHandshake(request, response, wsHandler, attributes)); assertNotEquals(servletResponse.getStatus(), HttpStatus.FORBIDDEN.value()); } @@ -71,8 +72,7 @@ public class OriginHandshakeInterceptorTests extends AbstractHttpRequestTests { Map attributes = new HashMap(); WebSocketHandler wsHandler = Mockito.mock(WebSocketHandler.class); setOrigin("http://mydomain1.com"); - OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor(); - interceptor.setAllowedOrigins(Arrays.asList("http://mydomain2.com")); + OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor(Arrays.asList("http://mydomain2.com")); assertFalse(interceptor.beforeHandshake(request, response, wsHandler, attributes)); assertEquals(servletResponse.getStatus(), HttpStatus.FORBIDDEN.value()); } @@ -82,8 +82,7 @@ public class OriginHandshakeInterceptorTests extends AbstractHttpRequestTests { Map attributes = new HashMap(); WebSocketHandler wsHandler = Mockito.mock(WebSocketHandler.class); setOrigin("http://mydomain2.com"); - OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor(); - interceptor.setAllowedOrigins(Arrays.asList("http://mydomain1.com", "http://mydomain2.com", "http://mydomain3.com")); + OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor(Arrays.asList("http://mydomain1.com", "http://mydomain2.com", "http://mydomain3.com")); assertTrue(interceptor.beforeHandshake(request, response, wsHandler, attributes)); assertNotEquals(servletResponse.getStatus(), HttpStatus.FORBIDDEN.value()); } @@ -93,8 +92,7 @@ public class OriginHandshakeInterceptorTests extends AbstractHttpRequestTests { Map attributes = new HashMap(); WebSocketHandler wsHandler = Mockito.mock(WebSocketHandler.class); setOrigin("http://mydomain4.com"); - OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor(); - interceptor.setAllowedOrigins(Arrays.asList("http://mydomain1.com", "http://mydomain2.com", "http://mydomain3.com")); + OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor(Arrays.asList("http://mydomain1.com", "http://mydomain2.com", "http://mydomain3.com")); assertFalse(interceptor.beforeHandshake(request, response, wsHandler, attributes)); assertEquals(servletResponse.getStatus(), HttpStatus.FORBIDDEN.value()); } @@ -123,4 +121,26 @@ public class OriginHandshakeInterceptorTests extends AbstractHttpRequestTests { assertNotEquals(servletResponse.getStatus(), HttpStatus.FORBIDDEN.value()); } + @Test + public void sameOriginMatch() throws Exception { + Map attributes = new HashMap(); + WebSocketHandler wsHandler = Mockito.mock(WebSocketHandler.class); + setOrigin("http://mydomain2.com"); + this.servletRequest.setServerName("mydomain2.com"); + OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor(Arrays.asList()); + assertTrue(interceptor.beforeHandshake(request, response, wsHandler, attributes)); + assertNotEquals(servletResponse.getStatus(), HttpStatus.FORBIDDEN.value()); + } + + @Test + public void sameOriginNoMatch() throws Exception { + Map attributes = new HashMap(); + WebSocketHandler wsHandler = Mockito.mock(WebSocketHandler.class); + setOrigin("http://mydomain3.com"); + this.servletRequest.setServerName("mydomain2.com"); + OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor(Arrays.asList()); + assertFalse(interceptor.beforeHandshake(request, response, wsHandler, attributes)); + assertEquals(servletResponse.getStatus(), HttpStatus.FORBIDDEN.value()); + } + } diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/support/SockJsServiceTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/support/SockJsServiceTests.java index f2a1126d6f..c29a5fe2d4 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/support/SockJsServiceTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/support/SockJsServiceTests.java @@ -110,6 +110,7 @@ public class SockJsServiceTests extends AbstractHttpRequestTests { @Test // SPR-12226 and SPR-12660 public void handleInfoGetWithOrigin() throws Exception { + this.servletRequest.setServerName("mydomain2.com"); setOrigin("http://mydomain2.com"); resetResponseAndHandleRequest("GET", "/echo/info", HttpStatus.OK); @@ -135,6 +136,12 @@ public class SockJsServiceTests extends AbstractHttpRequestTests { assertEquals("http://mydomain2.com", this.servletResponse.getHeader("Access-Control-Allow-Origin")); assertEquals("true", this.servletResponse.getHeader("Access-Control-Allow-Credentials")); assertEquals("Origin", this.servletResponse.getHeader("Vary")); + + this.service.setAllowedOrigins(Arrays.asList("*")); + resetResponseAndHandleRequest("GET", "/echo/info", HttpStatus.OK); + assertEquals("http://mydomain2.com", this.servletResponse.getHeader("Access-Control-Allow-Origin")); + assertEquals("true", this.servletResponse.getHeader("Access-Control-Allow-Credentials")); + assertEquals("Origin", this.servletResponse.getHeader("Vary")); } @Test // SPR-11443 @@ -186,6 +193,7 @@ public class SockJsServiceTests extends AbstractHttpRequestTests { @Test // SPR-12226 and SPR-12660 public void handleInfoOptionsWithOrigin() throws Exception { + this.servletRequest.setServerName("mydomain2.com"); setOrigin("http://mydomain2.com"); this.request.getHeaders().add("Access-Control-Request-Headers", "Last-Modified"); resetResponseAndHandleRequest("OPTIONS", "/echo/info", HttpStatus.NO_CONTENT); @@ -216,6 +224,16 @@ public class SockJsServiceTests extends AbstractHttpRequestTests { assertEquals("OPTIONS, GET", this.servletResponse.getHeader("Access-Control-Allow-Methods")); assertEquals("31536000", this.servletResponse.getHeader("Access-Control-Max-Age")); assertEquals("Origin", this.servletResponse.getHeader("Vary")); + + this.service.setAllowedOrigins(Arrays.asList("*")); + resetResponseAndHandleRequest("OPTIONS", "/echo/info", HttpStatus.NO_CONTENT); + this.response.flush(); + assertEquals("http://mydomain2.com", this.servletResponse.getHeader("Access-Control-Allow-Origin")); + assertEquals("true", this.servletResponse.getHeader("Access-Control-Allow-Credentials")); + assertEquals("Last-Modified", this.servletResponse.getHeader("Access-Control-Allow-Headers")); + assertEquals("OPTIONS, GET", this.servletResponse.getHeader("Access-Control-Allow-Methods")); + assertEquals("31536000", this.servletResponse.getHeader("Access-Control-Max-Age")); + assertEquals("Origin", this.servletResponse.getHeader("Vary")); } @Test // SPR-12283 diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/handler/DefaultSockJsServiceTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/handler/DefaultSockJsServiceTests.java index 35c049d143..dc1dce9609 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/handler/DefaultSockJsServiceTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/handler/DefaultSockJsServiceTests.java @@ -122,19 +122,15 @@ public class DefaultSockJsServiceTests extends AbstractHttpRequestTests { assertSame(xhrHandler, handlers.get(xhrHandler.getTransportType())); } - @Test - public void defaultAllowedOrigin() { - assertThat(this.service.getAllowedOrigins(), Matchers.contains("*")); - } - @Test(expected = IllegalArgumentException.class) public void nullAllowedOriginList() { this.service.setAllowedOrigins(null); } - @Test(expected = IllegalArgumentException.class) + @Test public void emptyAllowedOriginList() { this.service.setAllowedOrigins(Arrays.asList()); + assertThat(this.service.getAllowedOrigins(), Matchers.empty()); } @Test(expected = IllegalArgumentException.class) @@ -271,13 +267,19 @@ public class DefaultSockJsServiceTests extends AbstractHttpRequestTests { String sockJsPath = sessionUrlPrefix+ "jsonp"; setRequest("GET", sockJsPrefix + sockJsPath); jsonpService.handleRequest(this.request, this.response, sockJsPath, this.wsHandler); - assertNotEquals(404, this.servletResponse.getStatus()); + assertEquals(404, this.servletResponse.getStatus()); resetRequestAndResponse(); jsonpService.setAllowedOrigins(Arrays.asList("http://mydomain1.com")); setRequest("GET", sockJsPrefix + sockJsPath); jsonpService.handleRequest(this.request, this.response, sockJsPath, this.wsHandler); assertEquals(404, this.servletResponse.getStatus()); + + resetRequestAndResponse(); + jsonpService.setAllowedOrigins(Arrays.asList("*")); + setRequest("GET", sockJsPrefix + sockJsPath); + jsonpService.handleRequest(this.request, this.response, sockJsPath, this.wsHandler); + assertNotEquals(404, this.servletResponse.getStatus()); } @Test @@ -289,8 +291,7 @@ public class DefaultSockJsServiceTests extends AbstractHttpRequestTests { assertNotEquals(403, this.servletResponse.getStatus()); resetRequestAndResponse(); - OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor(); - interceptor.setAllowedOrigins(Arrays.asList("http://mydomain1.com")); + OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor(Arrays.asList("http://mydomain1.com")); wsService.setHandshakeInterceptors(Arrays.asList(interceptor)); setRequest("GET", sockJsPrefix + sockJsPath); setOrigin("http://mydomain1.com"); @@ -310,13 +311,21 @@ public class DefaultSockJsServiceTests extends AbstractHttpRequestTests { setRequest("GET", sockJsPrefix + sockJsPath); this.service.handleRequest(this.request, this.response, sockJsPath, this.wsHandler); assertNotEquals(404, this.servletResponse.getStatus()); - assertNull(this.servletResponse.getHeader("X-Frame-Options")); + assertEquals("SAMEORIGIN", this.servletResponse.getHeader("X-Frame-Options")); resetRequestAndResponse(); setRequest("GET", sockJsPrefix + sockJsPath); this.service.setAllowedOrigins(Arrays.asList("http://mydomain1.com")); this.service.handleRequest(this.request, this.response, sockJsPath, this.wsHandler); assertEquals(404, this.servletResponse.getStatus()); + assertNull(this.servletResponse.getHeader("X-Frame-Options")); + + resetRequestAndResponse(); + setRequest("GET", sockJsPrefix + sockJsPath); + this.service.setAllowedOrigins(Arrays.asList("*")); + this.service.handleRequest(this.request, this.response, sockJsPath, this.wsHandler); + assertNotEquals(404, this.servletResponse.getStatus()); + assertNull(this.servletResponse.getHeader("X-Frame-Options")); } diff --git a/src/asciidoc/index.adoc b/src/asciidoc/index.adoc index e8b86991b8..89ca09acf0 100644 --- a/src/asciidoc/index.adoc +++ b/src/asciidoc/index.adoc @@ -39465,7 +39465,76 @@ or WebSocket XML namespace: ---- +[[websocket-server-allowed-origins]] +==== Configuring allowed origins +As of Spring Framework 4.1.5, Websocket and SockJS default behavior is to accept only same +origin requests. It is also possible to allow all or a specified list of origins. +This check is mostly designed for browser clients. There is nothing preventing other types +of client to modify the `Origin` header value (see +https://tools.ietf.org/html/rfc6454[RFC 6454: The Web Origin Concept] for more details). + +The 3 possible behaviors are: + + * Allow only same origin requests (default): in this mode, when SockJS is enabled, the + Iframe HTTP response header `X-Frame-Options` is set to `SAMEORIGIN`, and JSONP + transport is disabled since it does not allow to check the origin of a request. + As a consequence, IE6 and IE7 are not supported when this mode is enabled. + * Allow a specified list of origins: each provided allowed origin must start by `http://` + or `https://`. In this mode, when SockJS is enabled, both IFrame and JSONP based + transports are disabled. As a consequence, IE6 up to IE9 are not supported when this + mode is enabled. + * Allow all origins: to enable this mode, you should provide `*` as allowed origin. In this + mode, all transports are available. + +Websocket and SockJS allowed origins can be configured as shown bellow: + +[source,java,indent=0] +[subs="verbatim,quotes"] +---- + import org.springframework.web.socket.config.annotation.EnableWebSocket; + import org.springframework.web.socket.config.annotation.WebSocketConfigurer; + import org.springframework.web.socket.config.annotation.WebSocketHandlerRegistry; + + @Configuration + @EnableWebSocket + public class WebSocketConfig implements WebSocketConfigurer { + + @Override + public void registerWebSocketHandlers(WebSocketHandlerRegistry registry) { + registry.addHandler(myHandler(), "/myHandler").setAllowedOrigins("http://mydomain.com"); + } + + @Bean + public WebSocketHandler myHandler() { + return new MyHandler(); + } + + } +---- + +XML configuration equivalent: + +[source,xml,indent=0] +[subs="verbatim,quotes,attributes"] +---- + + + + + + + + + +---- [[websocket-fallback]] @@ -39732,11 +39801,11 @@ log category to TRACE. [[websocket-fallback-cors]] ==== CORS Headers for SockJS -The SockJS protocol uses CORS for cross-domain support in the XHR streaming and -polling transports. Therefore CORS headers are added automatically unless the -presence of CORS headers in the response is detected. So if an application is -already configured to provide CORS support, e.g. through a Servlet Filter, -Spring's SockJsService will skip this part. +If you allow cross-origin requests (see <>), the SockJS protocol +uses CORS for cross-domain support in the XHR streaming and polling transports. Therefore +CORS headers are added automatically unless the presence of CORS headers in the response +is detected. So if an application is already configured to provide CORS support, e.g. +through a Servlet Filter, Spring's SockJsService will skip this part. It is also possible to disable the addition of these CORS headers thanks to the `suppressCors` property in Spring's SockJsService. -- GitLab