提交 3d1ae9c6 编写于 作者: J Juergen Hoeller

Efficient and consistent setAllowedOrigins collection type

Issue: SPR-13761
上级 cd4ce872
......@@ -16,11 +16,11 @@
package org.springframework.web.socket.server.support;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.LinkedHashSet;
import java.util.Map;
import java.util.Set;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
......@@ -34,8 +34,8 @@ import org.springframework.web.socket.server.HandshakeInterceptor;
import org.springframework.web.util.WebUtils;
/**
* An interceptor to check request {@code Origin} header value against a collection of
* allowed origins.
* An interceptor to check request {@code Origin} header value against a
* collection of allowed origins.
*
* @author Sebastien Deleuze
* @since 4.1.2
......@@ -44,60 +44,57 @@ public class OriginHandshakeInterceptor implements HandshakeInterceptor {
protected Log logger = LogFactory.getLog(getClass());
private final List<String> allowedOrigins;
private final Set<String> allowedOrigins = new LinkedHashSet<String>();
/**
* Default constructor with only same origin requests allowed.
*/
public OriginHandshakeInterceptor() {
this.allowedOrigins = new ArrayList<String>();
}
/**
* Constructor using the specified allowed origin values.
*
* @see #setAllowedOrigins(Collection)
*/
public OriginHandshakeInterceptor(Collection<String> allowedOrigins) {
this();
setAllowedOrigins(allowedOrigins);
}
/**
* Configure allowed {@code Origin} header values. This check is mostly
* designed for browsers. There is nothing preventing other types of client
* to modify the {@code Origin} header value.
*
* <p>Each provided allowed origin must have a scheme, and optionally a port
* (e.g. "http://example.org", "http://example.org:9090"). An allowed origin
* string may also be "*" in which case all origins are allowed.
*
* @see <a href="https://tools.ietf.org/html/rfc6454">RFC 6454: The Web Origin Concept</a>
*/
public void setAllowedOrigins(Collection<String> allowedOrigins) {
Assert.notNull(allowedOrigins, "Allowed origin Collection must not be null");
Assert.notNull(allowedOrigins, "Allowed origins Collection must not be null");
this.allowedOrigins.clear();
this.allowedOrigins.addAll(allowedOrigins);
}
/**
* @see #setAllowedOrigins(Collection)
* @since 4.1.5
* @see #setAllowedOrigins
*/
public Collection<String> getAllowedOrigins() {
return Collections.unmodifiableList(this.allowedOrigins);
return Collections.unmodifiableSet(this.allowedOrigins);
}
@Override
public boolean beforeHandshake(ServerHttpRequest request, ServerHttpResponse response,
WebSocketHandler wsHandler, Map<String, Object> attributes) throws Exception {
if (!WebUtils.isSameOrigin(request) && !WebUtils.isValidOrigin(request, this.allowedOrigins)) {
response.setStatusCode(HttpStatus.FORBIDDEN);
if (logger.isDebugEnabled()) {
logger.debug("Handshake request rejected, Origin header value "
+ request.getHeaders().getOrigin() + " not allowed");
logger.debug("Handshake request rejected, Origin header value " +
request.getHeaders().getOrigin() + " not allowed");
}
return false;
}
......
......@@ -18,13 +18,15 @@ package org.springframework.web.socket.sockjs.support;
import java.io.IOException;
import java.nio.charset.Charset;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Date;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Random;
import java.util.Set;
import java.util.concurrent.TimeUnit;
import javax.servlet.http.HttpServletRequest;
......@@ -56,7 +58,7 @@ import org.springframework.web.util.WebUtils;
* path resolution and handling of static SockJS requests (e.g. "/info", "/iframe.html",
* etc). Sub-classes must handle session URLs (i.e. transport-specific requests).
*
* By default, only same origin requests are allowed. Use {@link #setAllowedOrigins(List)}
* By default, only same origin requests are allowed. Use {@link #setAllowedOrigins}
* to specify a list of allowed origins (a list containing "*" will allow all origins).
*
* @author Rossen Stoyanchev
......@@ -94,10 +96,10 @@ public abstract class AbstractSockJsService implements SockJsService, CorsConfig
private boolean webSocketEnabled = true;
private final List<String> allowedOrigins = new ArrayList<String>();
private boolean suppressCors = false;
protected final Set<String> allowedOrigins = new LinkedHashSet<String>();
public AbstractSockJsService(TaskScheduler scheduler) {
Assert.notNull(scheduler, "TaskScheduler must not be null");
......@@ -274,6 +276,24 @@ public abstract class AbstractSockJsService implements SockJsService, CorsConfig
return this.webSocketEnabled;
}
/**
* This option can be used to disable automatic addition of CORS headers for
* SockJS requests.
* <p>The default value is "false".
* @since 4.1.2
*/
public void setSuppressCors(boolean suppressCors) {
this.suppressCors = suppressCors;
}
/**
* @since 4.1.2
* @see #setSuppressCors(boolean)
*/
public boolean shouldSuppressCors() {
return this.suppressCors;
}
/**
* Configure allowed {@code Origin} header values. This check is mostly
* designed for browsers. There is nothing preventing other types of client
......@@ -289,36 +309,18 @@ public abstract class AbstractSockJsService implements SockJsService, CorsConfig
* @see <a href="https://tools.ietf.org/html/rfc6454">RFC 6454: The Web Origin Concept</a>
* @see <a href="https://github.com/sockjs/sockjs-client#supported-transports-by-browser-html-served-from-http-or-https">SockJS supported transports by browser</a>
*/
public void setAllowedOrigins(List<String> allowedOrigins) {
Assert.notNull(allowedOrigins, "Allowed origin List must not be null");
public void setAllowedOrigins(Collection<String> allowedOrigins) {
Assert.notNull(allowedOrigins, "Allowed origins Collection must not be null");
this.allowedOrigins.clear();
this.allowedOrigins.addAll(allowedOrigins);
}
/**
* @since 4.1.2
* @see #setAllowedOrigins(List)
*/
public List<String> getAllowedOrigins() {
return Collections.unmodifiableList(this.allowedOrigins);
}
/**
* This option can be used to disable automatic addition of CORS headers for
* SockJS requests.
* <p>The default value is "false".
* @since 4.1.2
*/
public void setSuppressCors(boolean suppressCors) {
this.suppressCors = suppressCors;
}
/**
* @since 4.1.2
* @see #setSuppressCors(boolean)
* @see #setAllowedOrigins
*/
public boolean shouldSuppressCors() {
return this.suppressCors;
public Collection<String> getAllowedOrigins() {
return Collections.unmodifiableSet(this.allowedOrigins);
}
......@@ -465,24 +467,11 @@ public abstract class AbstractSockJsService implements SockJsService, CorsConfig
String path = request.getURI().getPath();
int index = path.lastIndexOf('/') + 1;
String filename = path.substring(index);
return filename.indexOf(';') == -1;
return (filename.indexOf(';') == -1);
}
/**
* Handle request for raw WebSocket communication, i.e. without any SockJS message framing.
*/
protected abstract void handleRawWebSocketRequest(ServerHttpRequest request,
ServerHttpResponse response, WebSocketHandler webSocketHandler) throws IOException;
/**
* Handle a SockJS session URL (i.e. transport-specific request).
*/
protected abstract void handleTransportRequest(ServerHttpRequest request, ServerHttpResponse response,
WebSocketHandler webSocketHandler, String sessionId, String transport) throws SockJsException;
protected boolean checkOrigin(ServerHttpRequest request, ServerHttpResponse response,
HttpMethod... httpMethods) throws IOException {
protected boolean checkOrigin(ServerHttpRequest request, ServerHttpResponse response, HttpMethod... httpMethods)
throws IOException {
if (WebUtils.isSameOrigin(request)) {
return true;
......@@ -529,6 +518,19 @@ public abstract class AbstractSockJsService implements SockJsService, CorsConfig
}
/**
* Handle request for raw WebSocket communication, i.e. without any SockJS message framing.
*/
protected abstract void handleRawWebSocketRequest(ServerHttpRequest request,
ServerHttpResponse response, WebSocketHandler webSocketHandler) throws IOException;
/**
* Handle a SockJS session URL (i.e. transport-specific request).
*/
protected abstract void handleTransportRequest(ServerHttpRequest request, ServerHttpResponse response,
WebSocketHandler webSocketHandler, String sessionId, String transport) throws SockJsException;
private interface SockJsRequestHandler {
void handle(ServerHttpRequest request, ServerHttpResponse response) throws IOException;
......@@ -546,8 +548,8 @@ public abstract class AbstractSockJsService implements SockJsService, CorsConfig
addNoCacheHeaders(response);
if (checkOrigin(request, response)) {
response.getHeaders().setContentType(new MediaType("application", "json", UTF8_CHARSET));
String content = String.format(INFO_CONTENT, random.nextInt(),
isSessionCookieNeeded(), isWebSocketEnabled());
String content = String.format(
INFO_CONTENT, random.nextInt(), isSessionCookieNeeded(), isWebSocketEnabled());
response.getBody().write(content.getBytes());
}
......
......@@ -326,7 +326,7 @@ public class TransportHandlingSockJsService extends AbstractSockJsService implem
return false;
}
if (!getAllowedOrigins().contains("*")) {
if (!this.allowedOrigins.contains("*")) {
TransportType transportType = TransportType.fromValue(transport);
if (transportType == null || !transportType.supportsOrigin()) {
if (logger.isWarnEnabled()) {
......
......@@ -16,18 +16,13 @@
package org.springframework.web.socket.config;
import static org.hamcrest.Matchers.*;
import static org.junit.Assert.*;
import java.io.IOException;
import java.io.InputStream;
import java.util.Arrays;
import java.util.Date;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ScheduledFuture;
import org.junit.Before;
import org.junit.Test;
import org.springframework.beans.factory.xml.XmlBeanDefinitionReader;
......@@ -67,6 +62,9 @@ import org.springframework.web.socket.sockjs.transport.handler.XhrPollingTranspo
import org.springframework.web.socket.sockjs.transport.handler.XhrReceivingTransportHandler;
import org.springframework.web.socket.sockjs.transport.handler.XhrStreamingTransportHandler;
import static org.hamcrest.Matchers.*;
import static org.junit.Assert.*;
/**
* Test fixture for HandlersBeanDefinitionParser.
* See test configuration files websocket-config-handlers-*.xml.
......@@ -76,13 +74,7 @@ import org.springframework.web.socket.sockjs.transport.handler.XhrStreamingTrans
*/
public class HandlersBeanDefinitionParserTests {
private GenericWebApplicationContext appContext;
@Before
public void setup() {
this.appContext = new GenericWebApplicationContext();
}
private GenericWebApplicationContext appContext = new GenericWebApplicationContext();
@Test
......@@ -234,10 +226,12 @@ public class HandlersBeanDefinitionParserTests {
List<HandshakeInterceptor> interceptors = transportService.getHandshakeInterceptors();
assertThat(interceptors, contains(instanceOf(OriginHandshakeInterceptor.class)));
assertEquals(Arrays.asList("http://mydomain1.com", "http://mydomain2.com"), transportService.getAllowedOrigins());
assertTrue(transportService.shouldSuppressCors());
assertTrue(transportService.getAllowedOrigins().contains("http://mydomain1.com"));
assertTrue(transportService.getAllowedOrigins().contains("http://mydomain2.com"));
}
private void loadBeanDefinitions(String fileName) {
XmlBeanDefinitionReader reader = new XmlBeanDefinitionReader(this.appContext);
ClassPathResource resource = new ClassPathResource(fileName, HandlersBeanDefinitionParserTests.class);
......@@ -278,9 +272,11 @@ class TestWebSocketHandler implements WebSocketHandler {
}
}
class FooWebSocketHandler extends TestWebSocketHandler {
}
class TestHandshakeHandler implements HandshakeHandler {
@Override
......@@ -291,9 +287,11 @@ class TestHandshakeHandler implements HandshakeHandler {
}
}
class TestChannelInterceptor extends ChannelInterceptorAdapter {
}
class FooTestInterceptor implements HandshakeInterceptor {
@Override
......@@ -309,9 +307,11 @@ class FooTestInterceptor implements HandshakeInterceptor {
}
}
class BarTestInterceptor extends FooTestInterceptor {
}
@SuppressWarnings({ "unchecked", "rawtypes" })
class TestTaskScheduler implements TaskScheduler {
......@@ -344,9 +344,9 @@ class TestTaskScheduler implements TaskScheduler {
public ScheduledFuture scheduleWithFixedDelay(Runnable task, long delay) {
return null;
}
}
class TestMessageCodec implements SockJsMessageCodec {
@Override
......@@ -363,4 +363,4 @@ class TestMessageCodec implements SockJsMessageCodec {
public String[] decodeInputStream(InputStream content) throws IOException {
return new String[0];
}
}
\ No newline at end of file
}
......@@ -86,16 +86,8 @@ import org.springframework.web.socket.sockjs.transport.TransportType;
import org.springframework.web.socket.sockjs.transport.handler.DefaultSockJsService;
import org.springframework.web.socket.sockjs.transport.handler.WebSocketTransportHandler;
import static org.hamcrest.Matchers.contains;
import static org.hamcrest.Matchers.instanceOf;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertSame;
import static org.junit.Assert.assertThat;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import static org.hamcrest.Matchers.*;
import static org.junit.Assert.*;
/**
* Test fixture for MessageBrokerBeanDefinitionParser.
......@@ -192,7 +184,8 @@ public class MessageBrokerBeanDefinitionParserTests {
interceptors = defaultSockJsService.getHandshakeInterceptors();
assertThat(interceptors, contains(instanceOf(FooTestInterceptor.class),
instanceOf(BarTestInterceptor.class), instanceOf(OriginHandshakeInterceptor.class)));
assertEquals(Arrays.asList("http://mydomain3.com", "http://mydomain4.com"), defaultSockJsService.getAllowedOrigins());
assertTrue(defaultSockJsService.getAllowedOrigins().contains("http://mydomain3.com"));
assertTrue(defaultSockJsService.getAllowedOrigins().contains("http://mydomain4.com"));
SimpUserRegistry userRegistry = this.appContext.getBean(SimpUserRegistry.class);
assertNotNull(userRegistry);
......@@ -478,9 +471,9 @@ public class MessageBrokerBeanDefinitionParserTests {
return (handler instanceof WebSocketHandlerDecorator) ?
((WebSocketHandlerDecorator) handler).getLastHandler() : handler;
}
}
class CustomArgumentResolver implements HandlerMethodArgumentResolver {
@Override
......@@ -494,6 +487,7 @@ class CustomArgumentResolver implements HandlerMethodArgumentResolver {
}
}
class CustomReturnValueHandler implements HandlerMethodReturnValueHandler {
@Override
......@@ -507,6 +501,7 @@ class CustomReturnValueHandler implements HandlerMethodReturnValueHandler {
}
}
class TestWebSocketHandlerDecoratorFactory implements WebSocketHandlerDecoratorFactory {
@Override
......@@ -515,6 +510,7 @@ class TestWebSocketHandlerDecoratorFactory implements WebSocketHandlerDecoratorF
}
}
class TestWebSocketHandlerDecorator extends WebSocketHandlerDecorator {
public TestWebSocketHandlerDecorator(WebSocketHandler delegate) {
......@@ -528,6 +524,6 @@ class TestWebSocketHandlerDecorator extends WebSocketHandlerDecorator {
}
}
class TestStompErrorHandler extends StompSubProtocolErrorHandler {
}
\ No newline at end of file
class TestStompErrorHandler extends StompSubProtocolErrorHandler {
}
/*
* Copyright 2002-2014 the original author or authors.
* Copyright 2002-2015 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
......@@ -29,9 +29,9 @@ import org.springframework.scheduling.TaskScheduler;
import org.springframework.util.MultiValueMap;
import org.springframework.web.HttpRequestHandler;
import org.springframework.web.socket.messaging.SubProtocolWebSocketHandler;
import org.springframework.web.socket.server.support.OriginHandshakeInterceptor;
import org.springframework.web.socket.server.support.DefaultHandshakeHandler;
import org.springframework.web.socket.server.support.HttpSessionHandshakeInterceptor;
import org.springframework.web.socket.server.support.OriginHandshakeInterceptor;
import org.springframework.web.socket.server.support.WebSocketHttpRequestHandler;
import org.springframework.web.socket.sockjs.support.SockJsHttpRequestHandler;
import org.springframework.web.socket.sockjs.transport.TransportHandler;
......@@ -117,7 +117,7 @@ public class WebMvcStompWebSocketEndpointRegistrationTests {
SockJsHttpRequestHandler requestHandler = (SockJsHttpRequestHandler)mappings.entrySet().iterator().next().getKey();
assertNotNull(requestHandler.getSockJsService());
DefaultSockJsService sockJsService = (DefaultSockJsService)requestHandler.getSockJsService();
assertEquals(Arrays.asList(origin), sockJsService.getAllowedOrigins());
assertTrue(sockJsService.getAllowedOrigins().contains(origin));
assertFalse(sockJsService.shouldSuppressCors());
registration =
......@@ -128,7 +128,7 @@ public class WebMvcStompWebSocketEndpointRegistrationTests {
requestHandler = (SockJsHttpRequestHandler)mappings.entrySet().iterator().next().getKey();
assertNotNull(requestHandler.getSockJsService());
sockJsService = (DefaultSockJsService)requestHandler.getSockJsService();
assertEquals(Arrays.asList(origin), sockJsService.getAllowedOrigins());
assertTrue(sockJsService.getAllowedOrigins().contains(origin));
assertFalse(sockJsService.shouldSuppressCors());
}
......@@ -255,7 +255,7 @@ public class WebMvcStompWebSocketEndpointRegistrationTests {
assertEquals(interceptor, sockJsService.getHandshakeInterceptors().get(0));
assertEquals(OriginHandshakeInterceptor.class,
sockJsService.getHandshakeInterceptors().get(1).getClass());
assertEquals(Arrays.asList(origin), sockJsService.getAllowedOrigins());
assertTrue(sockJsService.getAllowedOrigins().contains(origin));
}
}
......@@ -17,7 +17,6 @@
package org.springframework.web.socket.config.annotation;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.junit.Before;
......@@ -29,9 +28,9 @@ import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.handler.TextWebSocketHandler;
import org.springframework.web.socket.server.HandshakeHandler;
import org.springframework.web.socket.server.HandshakeInterceptor;
import org.springframework.web.socket.server.support.OriginHandshakeInterceptor;
import org.springframework.web.socket.server.support.DefaultHandshakeHandler;
import org.springframework.web.socket.server.support.HttpSessionHandshakeInterceptor;
import org.springframework.web.socket.server.support.OriginHandshakeInterceptor;
import org.springframework.web.socket.sockjs.SockJsService;
import org.springframework.web.socket.sockjs.transport.TransportType;
import org.springframework.web.socket.sockjs.transport.handler.DefaultSockJsService;
......@@ -148,8 +147,7 @@ public class WebSocketHandlerRegistrationTests {
assertEquals(handler, mapping.webSocketHandler);
assertEquals("/foo/**", mapping.path);
assertNotNull(mapping.sockJsService);
assertEquals(Arrays.asList("http://mydomain1.com"),
mapping.sockJsService.getAllowedOrigins());
assertTrue(mapping.sockJsService.getAllowedOrigins().contains("http://mydomain1.com"));
List<HandshakeInterceptor> interceptors = mapping.sockJsService.getHandshakeInterceptors();
assertEquals(interceptor, interceptors.get(0));
assertEquals(OriginHandshakeInterceptor.class, interceptors.get(1).getClass());
......@@ -218,6 +216,7 @@ public class WebSocketHandlerRegistrationTests {
}
}
private static class Mapping {
private final WebSocketHandler webSocketHandler;
......@@ -230,7 +229,6 @@ public class WebSocketHandlerRegistrationTests {
private final DefaultSockJsService sockJsService;
public Mapping(WebSocketHandler handler, String path, SockJsService sockJsService) {
this.webSocketHandler = handler;
this.path = path;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册