From 68ecb92d1f3f9599d54dd48a1feff7e7af07544f Mon Sep 17 00:00:00 2001 From: Rossen Stoyanchev Date: Sun, 3 May 2015 10:23:13 +0200 Subject: [PATCH] Allow "ws" and "wss" for isValidCorsOrigin checks Issue: SPR-12956 --- .../web/util/UriComponentsBuilder.java | 23 +++++ .../springframework/web/util/WebUtils.java | 25 ++--- .../web/util/WebUtilsTests.java | 91 ++++++++----------- .../support/OriginHandshakeInterceptor.java | 17 ++-- .../sockjs/support/AbstractSockJsService.java | 33 +++---- .../OriginHandshakeInterceptorTests.java | 37 +++----- .../handler/DefaultSockJsServiceTests.java | 39 +++----- 7 files changed, 117 insertions(+), 148 deletions(-) diff --git a/spring-web/src/main/java/org/springframework/web/util/UriComponentsBuilder.java b/spring-web/src/main/java/org/springframework/web/util/UriComponentsBuilder.java index e2251a3ece..95ec54bdb6 100644 --- a/spring-web/src/main/java/org/springframework/web/util/UriComponentsBuilder.java +++ b/spring-web/src/main/java/org/springframework/web/util/UriComponentsBuilder.java @@ -317,6 +317,29 @@ public class UriComponentsBuilder implements Cloneable { } + /** + * Create an instance by parsing the "origin" header of an HTTP request. + */ + public static UriComponentsBuilder fromOriginHeader(String origin) { + UriComponentsBuilder builder = UriComponentsBuilder.newInstance(); + if (StringUtils.hasText(origin)) { + int schemaIdx = origin.indexOf("://"); + String schema = (schemaIdx != -1 ? origin.substring(0, schemaIdx) : "http"); + builder.scheme(schema); + String hostString = (schemaIdx != -1 ? origin.substring(schemaIdx + 3) : origin); + if (hostString.contains(":")) { + String[] hostAndPort = StringUtils.split(hostString, ":"); + builder.host(hostAndPort[0]); + builder.port(Integer.parseInt(hostAndPort[1])); + } + else { + builder.host(hostString); + } + } + return builder; + } + + // build methods /** diff --git a/spring-web/src/main/java/org/springframework/web/util/WebUtils.java b/spring-web/src/main/java/org/springframework/web/util/WebUtils.java index 572e457ae3..f2e2a76755 100644 --- a/spring-web/src/main/java/org/springframework/web/util/WebUtils.java +++ b/spring-web/src/main/java/org/springframework/web/util/WebUtils.java @@ -23,6 +23,7 @@ import java.util.Enumeration; import java.util.Map; import java.util.StringTokenizer; import java.util.TreeMap; + import javax.servlet.ServletContext; import javax.servlet.ServletRequest; import javax.servlet.ServletRequestWrapper; @@ -38,6 +39,7 @@ import org.apache.commons.logging.LogFactory; import org.springframework.http.HttpRequest; import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; import org.springframework.util.LinkedMultiValueMap; import org.springframework.util.MultiValueMap; import org.springframework.util.StringUtils; @@ -790,21 +792,10 @@ public abstract class WebUtils { if (origin == null || allowedOrigins.contains("*")) { return true; } - else if (allowedOrigins.isEmpty()) { - UriComponents originComponents; - try { - originComponents = UriComponentsBuilder.fromHttpUrl(origin).build(); - } - catch (IllegalArgumentException ex) { - if (logger.isWarnEnabled()) { - logger.warn("Failed to parse Origin header value [" + origin + "]"); - } - return false; - } - UriComponents requestComponents = UriComponentsBuilder.fromHttpRequest(request).build(); - int originPort = getPort(originComponents); - int requestPort = getPort(requestComponents); - return (originComponents.getHost().equals(requestComponents.getHost()) && originPort == requestPort); + else if (CollectionUtils.isEmpty(allowedOrigins)) { + UriComponents actualUrl = UriComponentsBuilder.fromHttpRequest(request).build(); + UriComponents originUrl = UriComponentsBuilder.fromOriginHeader(origin).build(); + return (actualUrl.getHost().equals(originUrl.getHost()) && getPort(actualUrl) == getPort(originUrl)); } else { return allowedOrigins.contains(origin); @@ -814,10 +805,10 @@ public abstract class WebUtils { private static int getPort(UriComponents component) { int port = component.getPort(); if (port == -1) { - if ("http".equals(component.getScheme())) { + if ("http".equals(component.getScheme()) || "ws".equals(component.getScheme())) { port = 80; } - else if ("https".equals(component.getScheme())) { + else if ("https".equals(component.getScheme()) || "wss".equals(component.getScheme())) { port = 443; } } diff --git a/spring-web/src/test/java/org/springframework/web/util/WebUtilsTests.java b/spring-web/src/test/java/org/springframework/web/util/WebUtilsTests.java index be8a2cec53..29169d2a64 100644 --- a/spring-web/src/test/java/org/springframework/web/util/WebUtilsTests.java +++ b/spring-web/src/test/java/org/springframework/web/util/WebUtilsTests.java @@ -16,8 +16,8 @@ package org.springframework.web.util; -import java.util.ArrayList; import java.util.Arrays; +import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -106,60 +106,45 @@ public class WebUtilsTests { } @Test - public void isValidOrigin() { - List allowedOrigins = new ArrayList<>(); + public void isValidOriginSuccess() { + + List allowed = Collections.emptyList(); + assertTrue(checkOrigin("mydomain1.com", -1, "http://mydomain1.com", allowed)); + assertTrue(checkOrigin("mydomain1.com", -1, "http://mydomain1.com:80", allowed)); + assertTrue(checkOrigin("mydomain1.com", 443, "https://mydomain1.com", allowed)); + assertTrue(checkOrigin("mydomain1.com", 443, "https://mydomain1.com:443", allowed)); + assertTrue(checkOrigin("mydomain1.com", 123, "http://mydomain1.com:123", allowed)); + assertTrue(checkOrigin("mydomain1.com", -1, "ws://mydomain1.com", allowed)); + assertTrue(checkOrigin("mydomain1.com", 443, "wss://mydomain1.com", allowed)); + + allowed = Collections.singletonList("*"); + assertTrue(checkOrigin("mydomain1.com", -1, "http://mydomain2.com", allowed)); + + allowed = Collections.singletonList("http://mydomain1.com"); + assertTrue(checkOrigin("mydomain2.com", -1, "http://mydomain1.com", allowed)); + } + + @Test + public void isValidOriginFailure() { + + List allowed = Collections.emptyList(); + assertFalse(checkOrigin("mydomain1.com", -1, "http://mydomain2.com", allowed)); + assertFalse(checkOrigin("mydomain1.com", -1, "https://mydomain1.com", allowed)); + assertFalse(checkOrigin("mydomain1.com", -1, "invalid-origin", allowed)); + + allowed = Collections.singletonList("http://mydomain1.com"); + assertFalse(checkOrigin("mydomain2.com", -1, "http://mydomain3.com", allowed)); + } + + private boolean checkOrigin(String serverName, int port, String originHeader, List allowed) { MockHttpServletRequest servletRequest = new MockHttpServletRequest(); ServerHttpRequest request = new ServletServerHttpRequest(servletRequest); - - servletRequest.setServerName("mydomain1.com"); - request.getHeaders().set(HttpHeaders.ORIGIN, "http://mydomain1.com"); - assertTrue(WebUtils.isValidOrigin(request, allowedOrigins)); - - servletRequest.setServerName("mydomain1.com"); - request.getHeaders().set(HttpHeaders.ORIGIN, "http://mydomain1.com:80"); - assertTrue(WebUtils.isValidOrigin(request, allowedOrigins)); - - servletRequest.setServerName("mydomain1.com"); - servletRequest.setServerPort(443); - request.getHeaders().set(HttpHeaders.ORIGIN, "https://mydomain1.com"); - assertTrue(WebUtils.isValidOrigin(request, allowedOrigins)); - - servletRequest.setServerName("mydomain1.com"); - servletRequest.setServerPort(443); - request.getHeaders().set(HttpHeaders.ORIGIN, "https://mydomain1.com:443"); - assertTrue(WebUtils.isValidOrigin(request, allowedOrigins)); - - servletRequest.setServerName("mydomain1.com"); - servletRequest.setServerPort(123); - request.getHeaders().set(HttpHeaders.ORIGIN, "http://mydomain1.com:123"); - assertTrue(WebUtils.isValidOrigin(request, allowedOrigins)); - - servletRequest.setServerName("mydomain1.com"); - request.getHeaders().set(HttpHeaders.ORIGIN, "http://mydomain2.com"); - assertFalse(WebUtils.isValidOrigin(request, allowedOrigins)); - - servletRequest.setServerName("mydomain1.com"); - request.getHeaders().set(HttpHeaders.ORIGIN, "https://mydomain1.com"); - assertFalse(WebUtils.isValidOrigin(request, allowedOrigins)); - - servletRequest.setServerName("invalid-origin"); - request.getHeaders().set(HttpHeaders.ORIGIN, "invalid-origin"); - assertFalse(WebUtils.isValidOrigin(request, allowedOrigins)); - - allowedOrigins = Arrays.asList("*"); - servletRequest.setServerName("mydomain1.com"); - request.getHeaders().set(HttpHeaders.ORIGIN, "http://mydomain2.com"); - assertTrue(WebUtils.isValidOrigin(request, allowedOrigins)); - - allowedOrigins = Arrays.asList("http://mydomain1.com"); - servletRequest.setServerName("mydomain2.com"); - request.getHeaders().set(HttpHeaders.ORIGIN, "http://mydomain1.com"); - assertTrue(WebUtils.isValidOrigin(request, allowedOrigins)); - - allowedOrigins = Arrays.asList("http://mydomain1.com"); - servletRequest.setServerName("mydomain2.com"); - request.getHeaders().set(HttpHeaders.ORIGIN, "http://mydomain3.com"); - assertFalse(WebUtils.isValidOrigin(request, allowedOrigins)); + servletRequest.setServerName(serverName); + if (port != -1) { + servletRequest.setServerPort(port); + } + request.getHeaders().set(HttpHeaders.ORIGIN, originHeader); + return WebUtils.isValidOrigin(request, allowed); } } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/server/support/OriginHandshakeInterceptor.java b/spring-websocket/src/main/java/org/springframework/web/socket/server/support/OriginHandshakeInterceptor.java index 178fe69882..60b3c34f78 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/server/support/OriginHandshakeInterceptor.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/server/support/OriginHandshakeInterceptor.java @@ -65,22 +65,18 @@ public class OriginHandshakeInterceptor implements HandshakeInterceptor { } /** - * Configure allowed {@code Origin} header values. This check is mostly designed for - * browser clients. There is nothing preventing other types of client to modify the - * {@code Origin} header value. + * Configure allowed {@code Origin} header values. This check is mostly + * designed for browsers. There is nothing preventing other types of client + * to modify the {@code Origin} header value. * - *

Each provided allowed origin must start by "http://", "https://" or be "*" - * (means that all origins are allowed). + *

Each provided allowed origin must have a scheme, and optionally a port + * (e.g. "http://example.org", "http://example.org:9090"). An allowed origin + * string may also be "*" in which case all origins are allowed. * * @see RFC 6454: The Web Origin Concept */ public void setAllowedOrigins(Collection allowedOrigins) { Assert.notNull(allowedOrigins, "Allowed origin Collection must not be null"); - for (String allowedOrigin : allowedOrigins) { - Assert.isTrue(allowedOrigin.equals("*") || allowedOrigin.startsWith("http://") || - allowedOrigin.startsWith("https://"), "Invalid allowed origin provided: \"" + - allowedOrigin + "\". It must start with \"http://\", \"https://\" or be \"*\""); - } this.allowedOrigins.clear(); this.allowedOrigins.addAll(allowedOrigins); } @@ -93,6 +89,7 @@ public class OriginHandshakeInterceptor implements HandshakeInterceptor { return Collections.unmodifiableList(this.allowedOrigins); } + @Override public boolean beforeHandshake(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler wsHandler, Map attributes) throws Exception { diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/support/AbstractSockJsService.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/support/AbstractSockJsService.java index 2e7532c947..e58878fb70 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/support/AbstractSockJsService.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/support/AbstractSockJsService.java @@ -276,16 +276,18 @@ public abstract class AbstractSockJsService implements SockJsService, CorsConfig } /** - * Configure allowed {@code Origin} header values. This check is mostly designed for - * browser clients. There is nothing preventing other types of client to modify the - * {@code Origin} header value. + * Configure allowed {@code Origin} header values. This check is mostly + * designed for browsers. There is nothing preventing other types of client + * to modify the {@code Origin} header value. * - *

When SockJS is enabled and origins are restricted, transport types that do not - * allow to check request origin (JSONP and Iframe based transports) are disabled. - * As a consequence, IE 6 to 9 are not supported when origins are restricted. + *

When SockJS is enabled and origins are restricted, transport types + * that do not allow to check request origin (JSONP and Iframe based + * transports) are disabled. As a consequence, IE 6 to 9 are not supported + * when origins are restricted. * - *

Each provided allowed origin must start by "http://", "https://" or be "*" - * (means that all origins are allowed). + *

Each provided allowed origin must have a scheme, and optionally a port + * (e.g. "http://example.org", "http://example.org:9090"). An allowed origin + * string may also be "*" in which case all origins are allowed. * * @since 4.1.2 * @see RFC 6454: The Web Origin Concept @@ -293,14 +295,6 @@ public abstract class AbstractSockJsService implements SockJsService, CorsConfig */ public void setAllowedOrigins(List allowedOrigins) { Assert.notNull(allowedOrigins, "Allowed origin List must not be null"); - for (String allowedOrigin : allowedOrigins) { - Assert.isTrue( - allowedOrigin.equals("*") || allowedOrigin.startsWith("http://") || - allowedOrigin.startsWith("https://"), - "Invalid allowed origin provided: \"" + - allowedOrigin + - "\". It must start with \"http://\", \"https://\" or be \"*\""); - } this.allowedOrigins.clear(); this.allowedOrigins.addAll(allowedOrigins); } @@ -451,7 +445,9 @@ public abstract class AbstractSockJsService implements SockJsService, CorsConfig protected abstract void handleTransportRequest(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler webSocketHandler, String sessionId, String transport) throws SockJsException; - protected boolean checkOrigin(ServerHttpRequest request, ServerHttpResponse response, HttpMethod... httpMethods) throws IOException { + protected boolean checkOrigin(ServerHttpRequest request, ServerHttpResponse response, + HttpMethod... httpMethods) throws IOException { + String origin = request.getHeaders().getOrigin(); if (origin == null) { @@ -514,7 +510,8 @@ public abstract class AbstractSockJsService implements SockJsService, CorsConfig addNoCacheHeaders(response); if (checkOrigin(request, response)) { response.getHeaders().setContentType(new MediaType("application", "json", UTF8_CHARSET)); - String content = String.format(INFO_CONTENT, random.nextInt(), isSessionCookieNeeded(), isWebSocketEnabled()); + String content = String.format(INFO_CONTENT, random.nextInt(), + isSessionCookieNeeded(), isWebSocketEnabled()); response.getBody().write(content.getBytes()); } diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/server/support/OriginHandshakeInterceptorTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/server/support/OriginHandshakeInterceptorTests.java index 14e132bcf3..ec87b68154 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/server/support/OriginHandshakeInterceptorTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/server/support/OriginHandshakeInterceptorTests.java @@ -17,7 +17,9 @@ package org.springframework.web.socket.server.support; import java.util.Arrays; +import java.util.Collections; import java.util.HashMap; +import java.util.List; import java.util.Map; import java.util.Set; import java.util.concurrent.ConcurrentSkipListSet; @@ -39,31 +41,17 @@ import org.springframework.web.socket.WebSocketHandler; public class OriginHandshakeInterceptorTests extends AbstractHttpRequestTests { @Test(expected = IllegalArgumentException.class) - public void nullAllowedOriginList() { + public void invalidInput() { new OriginHandshakeInterceptor(null); } - @Test(expected = IllegalArgumentException.class) - public void invalidAllowedOrigin() { - new OriginHandshakeInterceptor(Arrays.asList("domain.com")); - } - - @Test - public void emtpyAllowedOriginList() { - new OriginHandshakeInterceptor(Arrays.asList()); - } - - @Test - public void validAllowedOrigins() { - new OriginHandshakeInterceptor(Arrays.asList("http://domain.com", "https://domain.com", "*")); - } - @Test public void originValueMatch() throws Exception { Map attributes = new HashMap(); WebSocketHandler wsHandler = Mockito.mock(WebSocketHandler.class); this.servletRequest.addHeader(HttpHeaders.ORIGIN, "http://mydomain1.com"); - OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor(Arrays.asList("http://mydomain1.com")); + List allowed = Collections.singletonList("http://mydomain1.com"); + OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor(allowed); assertTrue(interceptor.beforeHandshake(request, response, wsHandler, attributes)); assertNotEquals(servletResponse.getStatus(), HttpStatus.FORBIDDEN.value()); } @@ -73,7 +61,8 @@ public class OriginHandshakeInterceptorTests extends AbstractHttpRequestTests { Map attributes = new HashMap(); WebSocketHandler wsHandler = Mockito.mock(WebSocketHandler.class); this.servletRequest.addHeader(HttpHeaders.ORIGIN, "http://mydomain1.com"); - OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor(Arrays.asList("http://mydomain2.com")); + List allowed = Collections.singletonList("http://mydomain2.com"); + OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor(allowed); assertFalse(interceptor.beforeHandshake(request, response, wsHandler, attributes)); assertEquals(servletResponse.getStatus(), HttpStatus.FORBIDDEN.value()); } @@ -83,7 +72,8 @@ public class OriginHandshakeInterceptorTests extends AbstractHttpRequestTests { Map attributes = new HashMap(); WebSocketHandler wsHandler = Mockito.mock(WebSocketHandler.class); this.servletRequest.addHeader(HttpHeaders.ORIGIN, "http://mydomain2.com"); - OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor(Arrays.asList("http://mydomain1.com", "http://mydomain2.com", "http://mydomain3.com")); + List allowed = Arrays.asList("http://mydomain1.com", "http://mydomain2.com", "http://mydomain3.com"); + OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor(allowed); assertTrue(interceptor.beforeHandshake(request, response, wsHandler, attributes)); assertNotEquals(servletResponse.getStatus(), HttpStatus.FORBIDDEN.value()); } @@ -93,7 +83,8 @@ public class OriginHandshakeInterceptorTests extends AbstractHttpRequestTests { Map attributes = new HashMap(); WebSocketHandler wsHandler = Mockito.mock(WebSocketHandler.class); this.servletRequest.addHeader(HttpHeaders.ORIGIN, "http://mydomain4.com"); - OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor(Arrays.asList("http://mydomain1.com", "http://mydomain2.com", "http://mydomain3.com")); + List allowed = Arrays.asList("http://mydomain1.com", "http://mydomain2.com", "http://mydomain3.com"); + OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor(allowed); assertFalse(interceptor.beforeHandshake(request, response, wsHandler, attributes)); assertEquals(servletResponse.getStatus(), HttpStatus.FORBIDDEN.value()); } @@ -117,7 +108,7 @@ public class OriginHandshakeInterceptorTests extends AbstractHttpRequestTests { WebSocketHandler wsHandler = Mockito.mock(WebSocketHandler.class); this.servletRequest.addHeader(HttpHeaders.ORIGIN, "http://mydomain1.com"); OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor(); - interceptor.setAllowedOrigins(Arrays.asList("*")); + interceptor.setAllowedOrigins(Collections.singletonList("*")); assertTrue(interceptor.beforeHandshake(request, response, wsHandler, attributes)); assertNotEquals(servletResponse.getStatus(), HttpStatus.FORBIDDEN.value()); } @@ -128,7 +119,7 @@ public class OriginHandshakeInterceptorTests extends AbstractHttpRequestTests { WebSocketHandler wsHandler = Mockito.mock(WebSocketHandler.class); this.servletRequest.addHeader(HttpHeaders.ORIGIN, "http://mydomain2.com"); this.servletRequest.setServerName("mydomain2.com"); - OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor(Arrays.asList()); + OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor(Collections.emptyList()); assertTrue(interceptor.beforeHandshake(request, response, wsHandler, attributes)); assertNotEquals(servletResponse.getStatus(), HttpStatus.FORBIDDEN.value()); } @@ -139,7 +130,7 @@ public class OriginHandshakeInterceptorTests extends AbstractHttpRequestTests { WebSocketHandler wsHandler = Mockito.mock(WebSocketHandler.class); this.servletRequest.addHeader(HttpHeaders.ORIGIN, "http://mydomain3.com"); this.servletRequest.setServerName("mydomain2.com"); - OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor(Arrays.asList()); + OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor(Collections.emptyList()); assertFalse(interceptor.beforeHandshake(request, response, wsHandler, attributes)); assertEquals(servletResponse.getStatus(), HttpStatus.FORBIDDEN.value()); } diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/handler/DefaultSockJsServiceTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/handler/DefaultSockJsServiceTests.java index bf2c43cd35..4b40baa42a 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/handler/DefaultSockJsServiceTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/handler/DefaultSockJsServiceTests.java @@ -16,11 +16,14 @@ package org.springframework.web.socket.sockjs.transport.handler; +import static org.junit.Assert.*; +import static org.mockito.BDDMockito.*; + import java.util.Arrays; import java.util.Collections; +import java.util.List; import java.util.Map; -import org.hamcrest.Matchers; import org.junit.Before; import org.junit.Test; import org.mockito.Mock; @@ -41,9 +44,6 @@ import org.springframework.web.socket.sockjs.transport.TransportType; import org.springframework.web.socket.sockjs.transport.session.StubSockJsServiceConfig; import org.springframework.web.socket.sockjs.transport.session.TestSockJsSession; -import static org.junit.Assert.*; -import static org.mockito.BDDMockito.*; - /** * Test fixture for {@link org.springframework.web.socket.sockjs.transport.handler.DefaultSockJsService}. * @@ -125,26 +125,10 @@ public class DefaultSockJsServiceTests extends AbstractHttpRequestTests { } @Test(expected = IllegalArgumentException.class) - public void nullAllowedOriginList() { + public void invalidAllowedOrigins() { this.service.setAllowedOrigins(null); } - @Test - public void emptyAllowedOriginList() { - this.service.setAllowedOrigins(Arrays.asList()); - assertThat(this.service.getAllowedOrigins(), Matchers.empty()); - } - - @Test(expected = IllegalArgumentException.class) - public void invalidAllowedOrigin() { - this.service.setAllowedOrigins(Arrays.asList("domain.com")); - } - - @Test - public void validAllowedOrigins() { - this.service.setAllowedOrigins(Arrays.asList("http://domain.com", "https://domain.com", "*")); - } - @Test public void customizedTransportHandlerList() { TransportHandlingSockJsService service = new TransportHandlingSockJsService( @@ -268,13 +252,13 @@ public class DefaultSockJsServiceTests extends AbstractHttpRequestTests { assertEquals(404, this.servletResponse.getStatus()); resetRequestAndResponse(); - jsonpService.setAllowedOrigins(Arrays.asList("http://mydomain1.com")); + jsonpService.setAllowedOrigins(Collections.singletonList("http://mydomain1.com")); setRequest("GET", sockJsPrefix + sockJsPath); jsonpService.handleRequest(this.request, this.response, sockJsPath, this.wsHandler); assertEquals(404, this.servletResponse.getStatus()); resetRequestAndResponse(); - jsonpService.setAllowedOrigins(Arrays.asList("*")); + jsonpService.setAllowedOrigins(Collections.singletonList("*")); setRequest("GET", sockJsPrefix + sockJsPath); jsonpService.handleRequest(this.request, this.response, sockJsPath, this.wsHandler); assertNotEquals(404, this.servletResponse.getStatus()); @@ -289,8 +273,9 @@ public class DefaultSockJsServiceTests extends AbstractHttpRequestTests { assertNotEquals(403, this.servletResponse.getStatus()); resetRequestAndResponse(); - OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor(Arrays.asList("http://mydomain1.com")); - wsService.setHandshakeInterceptors(Arrays.asList(interceptor)); + List allowed = Collections.singletonList("http://mydomain1.com"); + OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor(allowed); + wsService.setHandshakeInterceptors(Collections.singletonList(interceptor)); setRequest("GET", sockJsPrefix + sockJsPath); this.servletRequest.addHeader(HttpHeaders.ORIGIN, "http://mydomain1.com"); wsService.handleRequest(this.request, this.response, sockJsPath, this.wsHandler); @@ -313,14 +298,14 @@ public class DefaultSockJsServiceTests extends AbstractHttpRequestTests { resetRequestAndResponse(); setRequest("GET", sockJsPrefix + sockJsPath); - this.service.setAllowedOrigins(Arrays.asList("http://mydomain1.com")); + this.service.setAllowedOrigins(Collections.singletonList("http://mydomain1.com")); this.service.handleRequest(this.request, this.response, sockJsPath, this.wsHandler); assertEquals(404, this.servletResponse.getStatus()); assertNull(this.servletResponse.getHeader("X-Frame-Options")); resetRequestAndResponse(); setRequest("GET", sockJsPrefix + sockJsPath); - this.service.setAllowedOrigins(Arrays.asList("*")); + this.service.setAllowedOrigins(Collections.singletonList("*")); this.service.handleRequest(this.request, this.response, sockJsPath, this.wsHandler); assertNotEquals(404, this.servletResponse.getStatus()); assertNull(this.servletResponse.getHeader("X-Frame-Options")); -- GitLab