提交 743356fa 编写于 作者: S Sebastien Deleuze

Add an option to set an Origin whitelist for Websocket and SockJS

This commit introduces a new OriginHandshakeInterceptor. It filters
Origin header value against a list of allowed origins.

AbstractSockJsService as been modified to:
 - Reject CORS requests with forbidden origins
 - Disable transport types that does not support CORS when an origin
   check is required
 - Use the Origin request header value instead of "*" for
   Access-Control-Allow-Origin response header value
   (mandatory when  Access-Control-Allow-Credentials=true)
 - Return CORS header only if the request contains an Origin header

It is possible to configure easily this behavior thanks to JavaConfig API
WebSocketHandlerRegistration#addAllowedOrigins(String...) and
StompWebSocketEndpointRegistration#addAllowedOrigins(String...).
It is also possible to configure it using the websocket XML namespace.

Please notice that this commit does not change the default behavior:
cross origin requests are still enabled by default.

Issues: SPR-12226
上级 28a3cd50
......@@ -34,6 +34,7 @@ import org.springframework.beans.factory.xml.ParserContext;
import org.springframework.util.StringUtils;
import org.springframework.util.xml.DomUtils;
import org.springframework.web.servlet.handler.SimpleUrlHandlerMapping;
import org.springframework.web.socket.server.support.OriginHandshakeInterceptor;
import org.springframework.web.socket.server.support.WebSocketHttpRequestHandler;
import org.springframework.web.socket.sockjs.support.SockJsHttpRequestHandler;
......@@ -79,7 +80,14 @@ class HandlersBeanDefinitionParser implements BeanDefinitionParser {
else {
RuntimeBeanReference handshakeHandler = WebSocketNamespaceUtils.registerHandshakeHandler(element, context, source);
Element interceptorsElement = DomUtils.getChildElementByTagName(element, "handshake-interceptors");
ManagedList<?> interceptors = WebSocketNamespaceUtils.parseBeanSubElements(interceptorsElement, context);
ManagedList<? super Object> interceptors = WebSocketNamespaceUtils.parseBeanSubElements(interceptorsElement, context);
String allowedOriginsAttribute = element.getAttribute("allowed-origins");
List<String> allowedOrigins = Arrays.asList(StringUtils.tokenizeToStringArray(allowedOriginsAttribute, ","));
if(!allowedOrigins.isEmpty()) {
OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor();
interceptor.setAllowedOrigins(allowedOrigins);
interceptors.add(interceptor);
}
strategy = new WebSocketHandlerMappingStrategy(handshakeHandler, interceptors);
}
......
......@@ -62,6 +62,7 @@ import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.handler.WebSocketHandlerDecoratorFactory;
import org.springframework.web.socket.messaging.StompSubProtocolHandler;
import org.springframework.web.socket.messaging.SubProtocolWebSocketHandler;
import org.springframework.web.socket.server.support.OriginHandshakeInterceptor;
import org.springframework.web.socket.server.support.WebSocketHttpRequestHandler;
import org.springframework.web.socket.sockjs.support.SockJsHttpRequestHandler;
......@@ -282,7 +283,14 @@ class MessageBrokerBeanDefinitionParser implements BeanDefinitionParser {
else {
RuntimeBeanReference handshakeHandler = WebSocketNamespaceUtils.registerHandshakeHandler(element, context, source);
Element interceptorsElement = DomUtils.getChildElementByTagName(element, "handshake-interceptors");
ManagedList<?> interceptors = WebSocketNamespaceUtils.parseBeanSubElements(interceptorsElement, context);
ManagedList<? super Object> interceptors = WebSocketNamespaceUtils.parseBeanSubElements(interceptorsElement, context);
String allowedOriginsAttribute = element.getAttribute("allowed-origins");
List<String> allowedOrigins = Arrays.asList(StringUtils.tokenizeToStringArray(allowedOriginsAttribute, ","));
if(!allowedOrigins.isEmpty()) {
OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor();
interceptor.setAllowedOrigins(allowedOrigins);
interceptors.add(interceptor);
}
ConstructorArgumentValues cavs = new ConstructorArgumentValues();
cavs.addIndexedArgumentValue(0, subProtoHandler);
if (handshakeHandler != null) {
......
......@@ -16,6 +16,9 @@
package org.springframework.web.socket.config;
import java.util.Arrays;
import java.util.List;
import org.w3c.dom.Element;
import org.springframework.beans.factory.config.BeanDefinition;
......@@ -25,8 +28,10 @@ import org.springframework.beans.factory.support.ManagedList;
import org.springframework.beans.factory.support.RootBeanDefinition;
import org.springframework.beans.factory.xml.ParserContext;
import org.springframework.scheduling.concurrent.ThreadPoolTaskScheduler;
import org.springframework.util.StringUtils;
import org.springframework.util.xml.DomUtils;
import org.springframework.web.socket.server.support.DefaultHandshakeHandler;
import org.springframework.web.socket.server.support.OriginHandshakeInterceptor;
import org.springframework.web.socket.sockjs.transport.TransportHandlingSockJsService;
import org.springframework.web.socket.sockjs.transport.handler.DefaultSockJsService;
import org.springframework.web.socket.sockjs.transport.handler.WebSocketTransportHandler;
......@@ -97,7 +102,15 @@ class WebSocketNamespaceUtils {
}
Element interceptorsElement = DomUtils.getChildElementByTagName(element, "handshake-interceptors");
ManagedList<?> interceptors = WebSocketNamespaceUtils.parseBeanSubElements(interceptorsElement, context);
ManagedList<? super Object> interceptors = WebSocketNamespaceUtils.parseBeanSubElements(interceptorsElement, context);
String allowedOriginsAttribute = element.getAttribute("allowed-origins");
List<String> allowedOrigins = Arrays.asList(StringUtils.tokenizeToStringArray(allowedOriginsAttribute, ","));
if(!allowedOrigins.isEmpty()) {
sockJsServiceDef.getPropertyValues().add("allowedOrigins", allowedOrigins);
OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor();
interceptor.setAllowedOrigins(allowedOrigins);
interceptors.add(interceptor);
}
sockJsServiceDef.getPropertyValues().add("handshakeInterceptors", interceptors);
String attrValue = sockJsElement.getAttribute("name");
......
/*
* Copyright 2002-2013 the original author or authors.
* Copyright 2002-2014 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
......@@ -16,15 +16,19 @@
package org.springframework.web.socket.config.annotation;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.springframework.scheduling.TaskScheduler;
import org.springframework.util.Assert;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import org.springframework.util.ObjectUtils;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.server.HandshakeHandler;
import org.springframework.web.socket.server.HandshakeInterceptor;
import org.springframework.web.socket.server.support.OriginHandshakeInterceptor;
import org.springframework.web.socket.server.support.DefaultHandshakeHandler;
import org.springframework.web.socket.sockjs.SockJsService;
import org.springframework.web.socket.sockjs.transport.handler.WebSocketTransportHandler;
......@@ -34,6 +38,7 @@ import org.springframework.web.socket.sockjs.transport.handler.WebSocketTranspor
* options but allows sub-classes to put together the actual HTTP request mappings.
*
* @author Rossen Stoyanchev
* @author Sebastien Deleuze
* @since 4.0
*/
public abstract class AbstractWebSocketHandlerRegistration<M> implements WebSocketHandlerRegistration {
......@@ -44,7 +49,9 @@ public abstract class AbstractWebSocketHandlerRegistration<M> implements WebSock
private HandshakeHandler handshakeHandler;
private HandshakeInterceptor[] interceptors;
private final List<HandshakeInterceptor> interceptors = new ArrayList<HandshakeInterceptor>();
private final List<String> allowedOrigins = new ArrayList<String>();
private SockJsServiceRegistration sockJsServiceRegistration;
......@@ -74,27 +81,49 @@ public abstract class AbstractWebSocketHandlerRegistration<M> implements WebSock
@Override
public WebSocketHandlerRegistration addInterceptors(HandshakeInterceptor... interceptors) {
this.interceptors = interceptors;
if (!ObjectUtils.isEmpty(interceptors)) {
this.interceptors.addAll(Arrays.asList(interceptors));
}
return this;
}
protected HandshakeInterceptor[] getInterceptors() {
return this.interceptors;
@Override
public WebSocketHandlerRegistration setAllowedOrigins(String... origins) {
this.allowedOrigins.clear();
if (!ObjectUtils.isEmpty(origins)) {
this.allowedOrigins.addAll(Arrays.asList(origins));
}
return this;
}
@Override
public SockJsServiceRegistration withSockJS() {
this.sockJsServiceRegistration = new SockJsServiceRegistration(this.sockJsTaskScheduler);
if (this.interceptors != null) {
this.sockJsServiceRegistration.setInterceptors(this.interceptors);
HandshakeInterceptor[] interceptors = getInterceptors();
if (interceptors.length > 0) {
this.sockJsServiceRegistration.setInterceptors(interceptors);
}
if (this.handshakeHandler != null) {
WebSocketTransportHandler transportHandler = new WebSocketTransportHandler(this.handshakeHandler);
this.sockJsServiceRegistration.setTransportHandlerOverrides(transportHandler);
}
if (!this.allowedOrigins.isEmpty()) {
this.sockJsServiceRegistration.setAllowedOrigins(this.allowedOrigins.toArray(new String[this.allowedOrigins.size()]));
}
return this.sockJsServiceRegistration;
}
protected HandshakeInterceptor[] getInterceptors() {
List<HandshakeInterceptor> interceptors = new ArrayList<HandshakeInterceptor>();
interceptors.addAll(this.interceptors);
if(!this.allowedOrigins.isEmpty()) {
OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor();
interceptor.setAllowedOrigins(this.allowedOrigins);
interceptors.add(interceptor);
}
return interceptors.toArray(new HandshakeInterceptor[interceptors.size()]);
}
protected final M getMappings() {
M mappings = createMappings();
if (this.sockJsServiceRegistration != null) {
......@@ -108,9 +137,10 @@ public abstract class AbstractWebSocketHandlerRegistration<M> implements WebSock
}
else {
HandshakeHandler handshakeHandler = getOrCreateHandshakeHandler();
HandshakeInterceptor[] interceptors = getInterceptors();
for (WebSocketHandler wsHandler : this.handlerMap.keySet()) {
for (String path : this.handlerMap.get(wsHandler)) {
addWebSocketHandlerMapping(mappings, wsHandler, handshakeHandler, this.interceptors, path);
addWebSocketHandlerMapping(mappings, wsHandler, handshakeHandler, interceptors, path);
}
}
}
......
......@@ -62,6 +62,8 @@ public class SockJsServiceRegistration {
private final List<HandshakeInterceptor> interceptors = new ArrayList<HandshakeInterceptor>();
private final List<String> allowedOrigins = new ArrayList<String>();
private SockJsMessageCodec messageCodec;
......@@ -195,6 +197,7 @@ public class SockJsServiceRegistration {
}
public SockJsServiceRegistration setInterceptors(HandshakeInterceptor... interceptors) {
this.interceptors.clear();
if (!ObjectUtils.isEmpty(interceptors)) {
this.interceptors.addAll(Arrays.asList(interceptors));
}
......@@ -213,6 +216,17 @@ public class SockJsServiceRegistration {
return this;
}
/**
* @since 4.1.2
*/
protected SockJsServiceRegistration setAllowedOrigins(String... origins) {
this.allowedOrigins.clear();
if (!ObjectUtils.isEmpty(origins)) {
this.allowedOrigins.addAll(Arrays.asList(origins));
}
return this;
}
protected SockJsService getSockJsService() {
TransportHandlingSockJsService service = createSockJsService();
service.setHandshakeInterceptors(this.interceptors);
......@@ -237,6 +251,9 @@ public class SockJsServiceRegistration {
if (this.webSocketEnabled != null) {
service.setWebSocketEnabled(this.webSocketEnabled);
}
if (!this.allowedOrigins.isEmpty()) {
service.setAllowedOrigins(this.allowedOrigins);
}
if (this.messageCodec != null) {
service.setMessageCodec(this.messageCodec);
}
......
/*
* Copyright 2002-2013 the original author or authors.
* Copyright 2002-2014 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
......@@ -42,4 +42,19 @@ public interface StompWebSocketEndpointRegistration {
*/
StompWebSocketEndpointRegistration addInterceptors(HandshakeInterceptor... interceptors);
/**
* Configure allowed {@code Origin} header values. This check is mostly designed for browser
* clients. There is noting preventing other types of client to modify the Origin header value.
*
* <p>When SockJS is enabled and allowed origins are restricted, transport types that do not
* use {@code Origin} headers for cross origin requests (jsonp-polling, iframe-xhr-polling,
* iframe-eventsource and iframe-htmlfile) are disabled. As a consequence, IE6/IE7 won't be
* supported anymore and IE8/IE9 will only be supported without cookies.
*
* <p>By default, all origins are allowed.
* @since 4.1.2
* @see <a href="https://github.com/sockjs/sockjs-client#supported-transports-by-browser-html-served-from-http-or-https">SockJS supported transports by browser</a>
*/
StompWebSocketEndpointRegistration setAllowedOrigins(String... origins);
}
\ No newline at end of file
/*
* Copyright 2002-2013 the original author or authors.
* Copyright 2002-2014 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
......@@ -22,15 +22,19 @@ import org.springframework.scheduling.TaskScheduler;
import org.springframework.util.Assert;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import org.springframework.util.ObjectUtils;
import org.springframework.web.HttpRequestHandler;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.server.HandshakeHandler;
import org.springframework.web.socket.server.HandshakeInterceptor;
import org.springframework.web.socket.server.support.OriginHandshakeInterceptor;
import org.springframework.web.socket.server.support.WebSocketHttpRequestHandler;
import org.springframework.web.socket.sockjs.SockJsService;
import org.springframework.web.socket.sockjs.support.SockJsHttpRequestHandler;
import org.springframework.web.socket.sockjs.transport.handler.WebSocketTransportHandler;
import java.util.ArrayList;
import java.util.List;
/**
* An abstract base class class for configuring STOMP over WebSocket/SockJS endpoints.
*
......@@ -47,7 +51,9 @@ public class WebMvcStompWebSocketEndpointRegistration implements StompWebSocketE
private HandshakeHandler handshakeHandler;
private HandshakeInterceptor[] interceptors;
private final List<HandshakeInterceptor> interceptors = new ArrayList<HandshakeInterceptor>();
private final List<String> allowedOrigins = new ArrayList<String>();
private StompSockJsServiceRegistration registration;
......@@ -72,27 +78,49 @@ public class WebMvcStompWebSocketEndpointRegistration implements StompWebSocketE
@Override
public StompWebSocketEndpointRegistration addInterceptors(HandshakeInterceptor... interceptors) {
this.interceptors = interceptors;
if (!ObjectUtils.isEmpty(interceptors)) {
this.interceptors.addAll(Arrays.asList(interceptors));
}
return this;
}
protected HandshakeInterceptor[] getInterceptors() {
return this.interceptors;
@Override
public StompWebSocketEndpointRegistration setAllowedOrigins(String... origins) {
this.allowedOrigins.clear();
if (!ObjectUtils.isEmpty(origins)) {
this.allowedOrigins.addAll(Arrays.asList(origins));
}
return this;
}
@Override
public SockJsServiceRegistration withSockJS() {
this.registration = new StompSockJsServiceRegistration(this.sockJsTaskScheduler);
if (this.interceptors != null) {
this.registration.setInterceptors(this.interceptors);
HandshakeInterceptor[] interceptors = getInterceptors();
if (interceptors.length > 0) {
this.registration.setInterceptors(interceptors);
}
if (this.handshakeHandler != null) {
WebSocketTransportHandler transportHandler = new WebSocketTransportHandler(this.handshakeHandler);
this.registration.setTransportHandlerOverrides(transportHandler);
}
if (!this.allowedOrigins.isEmpty()) {
this.registration.setAllowedOrigins(this.allowedOrigins.toArray(new String[this.allowedOrigins.size()]));
}
return this.registration;
}
protected HandshakeInterceptor[] getInterceptors() {
List<HandshakeInterceptor> interceptors = new ArrayList<HandshakeInterceptor>();
interceptors.addAll(this.interceptors);
if(!this.allowedOrigins.isEmpty()) {
OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor();
interceptor.setAllowedOrigins(this.allowedOrigins);
interceptors.add(interceptor);
}
return interceptors.toArray(new HandshakeInterceptor[interceptors.size()]);
}
public final MultiValueMap<HttpRequestHandler, String> getMappings() {
MultiValueMap<HttpRequestHandler, String> mappings = new LinkedMultiValueMap<HttpRequestHandler, String>();
if (this.registration != null) {
......@@ -112,8 +140,9 @@ public class WebMvcStompWebSocketEndpointRegistration implements StompWebSocketE
else {
handler = new WebSocketHttpRequestHandler(this.webSocketHandler);
}
if (this.interceptors != null) {
handler.setHandshakeInterceptors(Arrays.asList(this.interceptors));
HandshakeInterceptor[] interceptors = getInterceptors();
if (interceptors.length > 0) {
handler.setHandshakeInterceptors(Arrays.asList(interceptors));
}
mappings.add(handler, path);
}
......
......@@ -44,9 +44,27 @@ public interface WebSocketHandlerRegistration {
*/
WebSocketHandlerRegistration addInterceptors(HandshakeInterceptor... interceptors);
/**
* Configure allowed {@code Origin} header values. This check is mostly designed for browser
* clients. There is noting preventing other types of client to modify the Origin header value.
*
* <p>When SockJS is enabled and allowed origins are restricted, transport types that do not
* use {@code Origin} headers for cross origin requests (jsonp-polling, iframe-xhr-polling,
* iframe-eventsource and iframe-htmlfile) are disabled. As a consequence, IE6/IE7 won't be
* supported anymore and IE8/IE9 will only be supported without cookies.
*
* <p>By default, all origins are allowed.
*
* @since 4.1.2
* @see <a href="https://github.com/sockjs/sockjs-client#supported-transports-by-browser-html-served-from-http-or-https">SockJS supported transports by browser</a>
*/
WebSocketHandlerRegistration setAllowedOrigins(String... origins);
/**
* Enable SockJS fallback options.
*/
SockJsServiceRegistration withSockJS();
}
\ No newline at end of file
......@@ -260,6 +260,11 @@ public class DefaultHandshakeHandler implements HandshakeHandler {
Arrays.asList(StringUtils.arrayToCommaDelimitedString(getSupportedVersions())));
}
/**
* Return whether the request {@code Origin} header value is valid or not.
* By default, all origins as considered as valid. Consider using an
* {@link OriginHandshakeInterceptor} for filtering origins if needed.
*/
protected boolean isValidOrigin(ServerHttpRequest request) {
return true;
}
......
/*
* Copyright 2002-2014 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.web.socket.server.support;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.http.HttpStatus;
import org.springframework.http.server.ServerHttpRequest;
import org.springframework.http.server.ServerHttpResponse;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.server.HandshakeInterceptor;
/**
* An interceptor to check request {@code Origin} header value against a collection of
* allowed origins.
*
* @author Sebastien Deleuze
* @since 4.1.2
*/
public class OriginHandshakeInterceptor implements HandshakeInterceptor {
protected Log logger = LogFactory.getLog(getClass());
private final List<String> allowedOrigins;
/**
* Default constructor with no origin allowed.
*/
public OriginHandshakeInterceptor() {
this.allowedOrigins = new ArrayList<String>();
}
/**
* Use this property to define a collection of allowed origins.
*/
public void setAllowedOrigins(Collection<String> allowedOrigins) {
this.allowedOrigins.clear();
if (allowedOrigins != null) {
this.allowedOrigins.addAll(allowedOrigins);
}
}
@Override
public boolean beforeHandshake(ServerHttpRequest request, ServerHttpResponse response,
WebSocketHandler wsHandler, Map<String, Object> attributes) throws Exception {
if(!isValidOrigin(request)) {
response.setStatusCode(HttpStatus.FORBIDDEN);
if (logger.isDebugEnabled()) {
logger.debug("Handshake request rejected, Origin header value "
+ request.getHeaders().getOrigin() + " not allowed");
}
return false;
}
return true;
}
protected boolean isValidOrigin(ServerHttpRequest request) {
return this.allowedOrigins.contains(request.getHeaders().getOrigin());
}
@Override
public void afterHandshake(ServerHttpRequest request, ServerHttpResponse response,
WebSocketHandler wsHandler, Exception exception) {
}
}
......@@ -18,7 +18,9 @@ package org.springframework.web.socket.sockjs.support;
import java.io.IOException;
import java.nio.charset.Charset;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Date;
import java.util.HashSet;
import java.util.List;
......@@ -44,6 +46,7 @@ import org.springframework.util.StringUtils;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.sockjs.SockJsException;
import org.springframework.web.socket.sockjs.SockJsService;
import org.springframework.web.socket.sockjs.transport.TransportType;
/**
* An abstract base class for {@link SockJsService} implementations that provides SockJS
......@@ -51,6 +54,7 @@ import org.springframework.web.socket.sockjs.SockJsService;
* etc). Sub-classes must handle session URLs (i.e. transport-specific requests).
*
* @author Rossen Stoyanchev
* @author Sebastien Deleuze
* @since 4.0
*/
public abstract class AbstractSockJsService implements SockJsService {
......@@ -82,6 +86,8 @@ public abstract class AbstractSockJsService implements SockJsService {
private boolean webSocketEnabled = true;
private final List<String> allowedOrigins = new ArrayList<String>(Arrays.asList("*"));
public AbstractSockJsService(TaskScheduler scheduler) {
Assert.notNull(scheduler, "TaskScheduler must not be null");
......@@ -258,6 +264,34 @@ public abstract class AbstractSockJsService implements SockJsService {
return this.webSocketEnabled;
}
/**
* Configure allowed {@code Origin} header values. This check is mostly designed for browser
* clients. There is noting preventing other types of client to modify the Origin header value.
*
* <p>When SockJS is enabled and allowed origins are restricted, transport types that do not
* use {@code Origin} headers for cross origin requests (jsonp-polling, iframe-xhr-polling,
* iframe-eventsource and iframe-htmlfile) are disabled. As a consequence, IE6/IE7 won't be
* supported anymore and IE8/IE9 will only be supported without cookies.
*
* <p>By default, all origins are allowed.
*
* @since 4.1.2
* @see <a href="https://github.com/sockjs/sockjs-client#supported-transports-by-browser-html-served-from-http-or-https">SockJS supported transports by browser</a>
*/
public void setAllowedOrigins(List<String> allowedOrigins) {
this.allowedOrigins.clear();
if (allowedOrigins != null) {
this.allowedOrigins.addAll(allowedOrigins);
}
}
/**
* @since 4.1.2
* @see #setAllowedOrigins(List)
*/
public List<String> getAllowedOrigins() {
return Collections.unmodifiableList(allowedOrigins);
}
/**
* This method determines the SockJS path and handles SockJS static URLs.
......@@ -325,6 +359,12 @@ public abstract class AbstractSockJsService implements SockJsService {
response.setStatusCode(HttpStatus.NOT_FOUND);
return;
}
else if(!this.allowedOrigins.contains("*") && !TransportType.fromValue(transport).supportsOrigin()) {
logger.debug("Origin check has been enabled, but this transport does not support it, ignoring "
+ requestInfo);
response.setStatusCode(HttpStatus.NOT_FOUND);
return;
}
handleTransportRequest(request, response, wsHandler, sessionId, transport);
}
response.close();
......@@ -360,23 +400,43 @@ public abstract class AbstractSockJsService implements SockJsService {
protected abstract void handleTransportRequest(ServerHttpRequest request, ServerHttpResponse response,
WebSocketHandler webSocketHandler, String sessionId, String transport) throws SockJsException;
protected void addCorsHeaders(ServerHttpRequest request, ServerHttpResponse response, HttpMethod... httpMethods) {
/**
* Check the {@code Origin} header value and eventually call {@link #addCorsHeaders(ServerHttpRequest, ServerHttpResponse, HttpMethod...)}.
* If the request origin is not allowed, the request is rejected.
* @return false if the request is rejected, else true
* @since 4.1.2
*/
protected boolean checkAndAddCorsHeaders(ServerHttpRequest request, ServerHttpResponse response, HttpMethod... httpMethods) {
HttpHeaders requestHeaders = request.getHeaders();
HttpHeaders responseHeaders = response.getHeaders();
String origin = requestHeaders.getOrigin();
if(!this.allowedOrigins.contains("*") && (origin == null || !this.allowedOrigins.contains(origin))) {
logger.debug("Request rejected, Origin header value " + origin + " not allowed");
response.setStatusCode(HttpStatus.FORBIDDEN);
return false;
}
boolean hasCorsResponseHeaders = false;
try {
// Perhaps a CORS Filter has already added this?
if (!CollectionUtils.isEmpty(responseHeaders.get("Access-Control-Allow-Origin"))) {
return;
}
hasCorsResponseHeaders = !CollectionUtils.isEmpty(responseHeaders.get("Access-Control-Allow-Origin"));
}
catch (NullPointerException npe) {
// See SPR-11919 and https://issues.jboss.org/browse/WFLY-3474
}
String origin = requestHeaders.getFirst("origin");
origin = (origin == null || origin.equals("null") ? "*" : origin);
responseHeaders.add("Access-Control-Allow-Origin", origin);
if(origin != null && !hasCorsResponseHeaders) {
addCorsHeaders(request, response, httpMethods);
}
return true;
}
protected void addCorsHeaders(ServerHttpRequest request, ServerHttpResponse response, HttpMethod... httpMethods) {
HttpHeaders requestHeaders = request.getHeaders();
HttpHeaders responseHeaders = response.getHeaders();
responseHeaders.add("Access-Control-Allow-Origin", requestHeaders.getFirst("Origin"));
responseHeaders.add("Access-Control-Allow-Credentials", "true");
List<String> accessControllerHeaders = requestHeaders.get("Access-Control-Request-Headers");
......@@ -424,16 +484,19 @@ public abstract class AbstractSockJsService implements SockJsService {
@Override
public void handle(ServerHttpRequest request, ServerHttpResponse response) throws IOException {
if (HttpMethod.GET.equals(request.getMethod())) {
response.getHeaders().setContentType(new MediaType("application", "json", UTF8_CHARSET));
addCorsHeaders(request, response);
addNoCacheHeaders(response);
String content = String.format(INFO_CONTENT, random.nextInt(), isSessionCookieNeeded(), isWebSocketEnabled());
response.getBody().write(content.getBytes());
if(checkAndAddCorsHeaders(request, response)) {
response.getHeaders().setContentType(new MediaType("application", "json", UTF8_CHARSET));
String content = String.format(INFO_CONTENT, random.nextInt(), isSessionCookieNeeded(), isWebSocketEnabled());
response.getBody().write(content.getBytes());
}
}
else if (HttpMethod.OPTIONS.equals(request.getMethod())) {
response.setStatusCode(HttpStatus.NO_CONTENT);
addCorsHeaders(request, response, HttpMethod.OPTIONS, HttpMethod.GET);
addCacheHeaders(response);
if(checkAndAddCorsHeaders(request, response, HttpMethod.OPTIONS,
HttpMethod.GET)) {
addCacheHeaders(response);
response.setStatusCode(HttpStatus.NO_CONTENT);
}
}
else {
sendMethodNotAllowed(response, HttpMethod.OPTIONS, HttpMethod.GET);
......
......@@ -207,9 +207,10 @@ public class TransportHandlingSockJsService extends AbstractSockJsService implem
HttpMethod supportedMethod = transportType.getHttpMethod();
if (!supportedMethod.equals(request.getMethod())) {
if (HttpMethod.OPTIONS.equals(request.getMethod()) && transportType.supportsCors()) {
response.setStatusCode(HttpStatus.NO_CONTENT);
addCorsHeaders(request, response, HttpMethod.OPTIONS, supportedMethod);
addCacheHeaders(response);
if(checkAndAddCorsHeaders(request, response, HttpMethod.OPTIONS, supportedMethod)) {
response.setStatusCode(HttpStatus.NO_CONTENT);
addCacheHeaders(response);
}
}
else if (transportType.supportsCors()) {
sendMethodNotAllowed(response, supportedMethod, HttpMethod.OPTIONS);
......@@ -250,7 +251,9 @@ public class TransportHandlingSockJsService extends AbstractSockJsService implem
}
if (transportType.supportsCors()) {
addCorsHeaders(request, response);
if(!checkAndAddCorsHeaders(request, response)) {
return;
}
}
transportHandler.handleRequest(request, response, handler, session);
......
/*
* Copyright 2002-2013 the original author or authors.
* Copyright 2002-2014 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
......@@ -28,11 +28,12 @@ import org.springframework.http.HttpMethod;
* SockJS transport types.
*
* @author Rossen Stoyanchev
* @author Sebastien Deleuze
* @since 4.0
*/
public enum TransportType {
WEBSOCKET("websocket", HttpMethod.GET),
WEBSOCKET("websocket", HttpMethod.GET, "origin"),
XHR("xhr", HttpMethod.POST, "cors", "jsessionid", "no_cache"),
......@@ -91,6 +92,10 @@ public enum TransportType {
return this.headerHints.contains("cors");
}
public boolean supportsOrigin() {
return this.headerHints.contains("cors") || this.headerHints.contains("origin");
}
public boolean sendsSessionCookie() {
return this.headerHints.contains("jsessionid");
}
......
......@@ -37,7 +37,7 @@ import org.springframework.web.socket.sockjs.transport.session.StreamingSockJsSe
import org.springframework.web.util.JavaScriptUtils;
/**
* An HTTP {@link TransportHandler} that uses a famous browsder document.domain technique:
* An HTTP {@link TransportHandler} that uses a famous browser document.domain technique:
* <a href="http://stackoverflow.com/questions/1481251/what-does-document-domain-document-domain-do">
* http://stackoverflow.com/questions/1481251/what-does-document-domain-document-domain-do</a>
*
......
......@@ -474,6 +474,24 @@
]]></xsd:documentation>
</xsd:annotation>
</xsd:attribute>
<xsd:attribute name="allowed-origins" type="xsd:string">
<xsd:annotation>
<xsd:documentation><![CDATA[
Configure allowed {@code Origin} header values. Multiple origins may be specified
as a comma-separated list.
This check is mostly designed for browser clients. There is noting preventing other
types of client to modify the Origin header value.
When SockJS is enabled and allowed origins are restricted, transport types that do not
use {@code Origin} headers for cross origin requests (jsonp-polling, iframe-xhr-polling,
iframe-eventsource and iframe-htmlfile) are disabled. As a consequence, IE6/IE7 won't be
supported anymore and IE8/IE9 will only be supported without cookies.
By default, all origins are allowed.
]]></xsd:documentation>
</xsd:annotation>
</xsd:attribute>
</xsd:complexType>
</xsd:element>
......@@ -641,6 +659,24 @@
]]></xsd:documentation>
</xsd:annotation>
</xsd:attribute>
<xsd:attribute name="allowed-origins" type="xsd:string">
<xsd:annotation>
<xsd:documentation><![CDATA[
Configure allowed {@code Origin} header values. Multiple origins may be specified
as a comma-separated list.
This check is mostly designed for browser clients. There is noting preventing other
types of client to modify the Origin header value.
When SockJS is enabled and allowed origins are restricted, transport types that do not
use {@code Origin} headers for cross origin requests (jsonp-polling, iframe-xhr-polling,
iframe-eventsource and iframe-htmlfile) are disabled. As a consequence, IE6/IE7 won't be
supported anymore and IE8/IE9 will only be supported without cookies.
By default, all origins are allowed.
]]></xsd:documentation>
</xsd:annotation>
</xsd:attribute>
</xsd:complexType>
</xsd:element>
<xsd:choice>
......
......@@ -54,6 +54,10 @@ public abstract class AbstractHttpRequestTests {
this.servletRequest.setRequestURI(requestUri);
}
protected void setOrigin(String origin) {
this.servletRequest.addHeader("Origin", origin);
}
protected void resetRequestAndResponse() {
resetRequest();
resetResponse();
......
......@@ -18,11 +18,13 @@ package org.springframework.web.socket.config;
import java.io.IOException;
import java.io.InputStream;
import java.util.Arrays;
import java.util.Date;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ScheduledFuture;
import static org.junit.Assert.assertEquals;
import org.junit.Before;
import org.junit.Test;
......@@ -45,6 +47,7 @@ import org.springframework.web.socket.handler.WebSocketHandlerDecorator;
import org.springframework.web.socket.server.HandshakeHandler;
import org.springframework.web.socket.server.HandshakeInterceptor;
import org.springframework.web.socket.server.support.DefaultHandshakeHandler;
import org.springframework.web.socket.server.support.OriginHandshakeInterceptor;
import org.springframework.web.socket.server.support.WebSocketHttpRequestHandler;
import org.springframework.web.socket.sockjs.SockJsService;
import org.springframework.web.socket.sockjs.frame.SockJsMessageCodec;
......@@ -103,6 +106,7 @@ public class HandlersBeanDefinitionParserTests {
HandshakeHandler handshakeHandler = handler.getHandshakeHandler();
assertNotNull(handshakeHandler);
assertTrue(handshakeHandler instanceof DefaultHandshakeHandler);
assertTrue(handler.getHandshakeInterceptors().isEmpty());
}
else {
assertThat(shm.getUrlMap().keySet(), contains("/test"));
......@@ -112,6 +116,7 @@ public class HandlersBeanDefinitionParserTests {
HandshakeHandler handshakeHandler = handler.getHandshakeHandler();
assertNotNull(handshakeHandler);
assertTrue(handshakeHandler instanceof DefaultHandshakeHandler);
assertTrue(handler.getHandshakeInterceptors().isEmpty());
}
}
}
......@@ -135,7 +140,8 @@ public class HandlersBeanDefinitionParserTests {
assertNotNull(handshakeHandler);
assertTrue(handshakeHandler instanceof TestHandshakeHandler);
List<HandshakeInterceptor> interceptors = handler.getHandshakeInterceptors();
assertThat(interceptors, contains(instanceOf(FooTestInterceptor.class), instanceOf(BarTestInterceptor.class)));
assertThat(interceptors, contains(instanceOf(FooTestInterceptor.class),
instanceOf(BarTestInterceptor.class), instanceOf(OriginHandshakeInterceptor.class)));
handler = (WebSocketHttpRequestHandler) urlHandlerMapping.getUrlMap().get("/test");
assertNotNull(handler);
......@@ -144,8 +150,8 @@ public class HandlersBeanDefinitionParserTests {
assertNotNull(handshakeHandler);
assertTrue(handshakeHandler instanceof TestHandshakeHandler);
interceptors = handler.getHandshakeInterceptors();
assertThat(interceptors, contains(instanceOf(FooTestInterceptor.class), instanceOf(BarTestInterceptor.class)));
assertThat(interceptors, contains(instanceOf(FooTestInterceptor.class),
instanceOf(BarTestInterceptor.class), instanceOf(OriginHandshakeInterceptor.class)));
}
@Test
......@@ -222,6 +228,10 @@ public class HandlersBeanDefinitionParserTests {
assertEquals(1024, transportService.getHttpMessageCacheSize());
assertEquals(20, transportService.getHeartbeatTime());
assertEquals(TestMessageCodec.class, transportService.getMessageCodec().getClass());
List<HandshakeInterceptor> interceptors = transportService.getHandshakeInterceptors();
assertThat(interceptors, contains(instanceOf(OriginHandshakeInterceptor.class)));
assertEquals(Arrays.asList("http://mydomain1.com", "http://mydomain2.com"), transportService.getAllowedOrigins());
}
private void loadBeanDefinitions(String fileName) {
......
......@@ -68,6 +68,7 @@ import org.springframework.web.socket.messaging.StompSubProtocolHandler;
import org.springframework.web.socket.messaging.SubProtocolWebSocketHandler;
import org.springframework.web.socket.server.HandshakeHandler;
import org.springframework.web.socket.server.HandshakeInterceptor;
import org.springframework.web.socket.server.support.OriginHandshakeInterceptor;
import org.springframework.web.socket.server.support.WebSocketHttpRequestHandler;
import org.springframework.web.socket.sockjs.support.SockJsHttpRequestHandler;
import org.springframework.web.socket.sockjs.transport.TransportType;
......@@ -115,7 +116,8 @@ public class MessageBrokerBeanDefinitionParserTests {
assertNotNull(handshakeHandler);
assertTrue(handshakeHandler instanceof TestHandshakeHandler);
List<HandshakeInterceptor> interceptors = wsHttpRequestHandler.getHandshakeInterceptors();
assertThat(interceptors, contains(instanceOf(FooTestInterceptor.class), instanceOf(BarTestInterceptor.class)));
assertThat(interceptors, contains(instanceOf(FooTestInterceptor.class),
instanceOf(BarTestInterceptor.class), instanceOf(OriginHandshakeInterceptor.class)));
WebSocketSession session = new TestWebSocketSession("id");
wsHttpRequestHandler.getWebSocketHandler().afterConnectionEstablished(session);
......@@ -158,7 +160,9 @@ public class MessageBrokerBeanDefinitionParserTests {
assertTrue(scheduler.getScheduledThreadPoolExecutor().getRemoveOnCancelPolicy());
interceptors = defaultSockJsService.getHandshakeInterceptors();
assertThat(interceptors, contains(instanceOf(FooTestInterceptor.class), instanceOf(BarTestInterceptor.class)));
assertThat(interceptors, contains(instanceOf(FooTestInterceptor.class),
instanceOf(BarTestInterceptor.class), instanceOf(OriginHandshakeInterceptor.class)));
assertEquals(Arrays.asList("http://mydomain3.com", "http://mydomain4.com"), defaultSockJsService.getAllowedOrigins());
UserSessionRegistry userSessionRegistry = this.appContext.getBean(UserSessionRegistry.class);
assertNotNull(userSessionRegistry);
......
......@@ -29,6 +29,7 @@ import org.springframework.scheduling.TaskScheduler;
import org.springframework.util.MultiValueMap;
import org.springframework.web.HttpRequestHandler;
import org.springframework.web.socket.messaging.SubProtocolWebSocketHandler;
import org.springframework.web.socket.server.support.OriginHandshakeInterceptor;
import org.springframework.web.socket.server.support.DefaultHandshakeHandler;
import org.springframework.web.socket.server.support.HttpSessionHandshakeInterceptor;
import org.springframework.web.socket.server.support.WebSocketHttpRequestHandler;
......@@ -70,19 +71,60 @@ public class WebMvcStompWebSocketEndpointRegistrationTests {
Map.Entry<HttpRequestHandler, List<String>> entry = mappings.entrySet().iterator().next();
assertNotNull(((WebSocketHttpRequestHandler) entry.getKey()).getWebSocketHandler());
assertTrue(((WebSocketHttpRequestHandler) entry.getKey()).getHandshakeInterceptors().isEmpty());
assertEquals(Arrays.asList("/foo"), entry.getValue());
}
@Test
public void handshakeHandlerAndInterceptors() {
public void allowedOrigins() {
WebMvcStompWebSocketEndpointRegistration registration =
new WebMvcStompWebSocketEndpointRegistration(new String[] {"/foo"}, this.handler, this.scheduler);
registration.setAllowedOrigins("http://mydomain.com");
MultiValueMap<HttpRequestHandler, String> mappings = registration.getMappings();
assertEquals(1, mappings.size());
WebSocketHttpRequestHandler requestHandler = (WebSocketHttpRequestHandler)mappings.entrySet().iterator().next().getKey();
assertNotNull(requestHandler.getWebSocketHandler());
assertEquals(1, requestHandler.getHandshakeInterceptors().size());
assertEquals(OriginHandshakeInterceptor.class, requestHandler.getHandshakeInterceptors().get(0).getClass());
}
@Test
public void allowedOriginsWithSockJsService() {
WebMvcStompWebSocketEndpointRegistration registration =
new WebMvcStompWebSocketEndpointRegistration(new String[] {"/foo"}, this.handler, this.scheduler);
String origin = "http://mydomain.com";
registration.setAllowedOrigins(origin).withSockJS();
MultiValueMap<HttpRequestHandler, String> mappings = registration.getMappings();
assertEquals(1, mappings.size());
SockJsHttpRequestHandler requestHandler = (SockJsHttpRequestHandler)mappings.entrySet().iterator().next().getKey();
assertNotNull(requestHandler.getSockJsService());
DefaultSockJsService sockJsService = (DefaultSockJsService)requestHandler.getSockJsService();
assertEquals(Arrays.asList(origin), sockJsService.getAllowedOrigins());
registration =
new WebMvcStompWebSocketEndpointRegistration(new String[] {"/foo"}, this.handler, this.scheduler);
registration.withSockJS().setAllowedOrigins(origin);
mappings = registration.getMappings();
assertEquals(1, mappings.size());
requestHandler = (SockJsHttpRequestHandler)mappings.entrySet().iterator().next().getKey();
assertNotNull(requestHandler.getSockJsService());
sockJsService = (DefaultSockJsService)requestHandler.getSockJsService();
assertEquals(Arrays.asList(origin), sockJsService.getAllowedOrigins());
}
@Test
public void handshakeHandlerAndInterceptor() {
WebMvcStompWebSocketEndpointRegistration registration =
new WebMvcStompWebSocketEndpointRegistration(new String[] {"/foo"}, this.handler, this.scheduler);
DefaultHandshakeHandler handshakeHandler = new DefaultHandshakeHandler();
HttpSessionHandshakeInterceptor interceptor = new HttpSessionHandshakeInterceptor();
registration.setHandshakeHandler(handshakeHandler);
registration.addInterceptors(interceptor);
registration.setHandshakeHandler(handshakeHandler).addInterceptors(interceptor);
MultiValueMap<HttpRequestHandler, String> mappings = registration.getMappings();
assertEquals(1, mappings.size());
......@@ -97,16 +139,38 @@ public class WebMvcStompWebSocketEndpointRegistrationTests {
}
@Test
public void handshakeHandlerAndInterceptorsWithSockJsService() {
public void handshakeHandlerAndInterceptorWithAllowedOrigins() {
WebMvcStompWebSocketEndpointRegistration registration =
new WebMvcStompWebSocketEndpointRegistration(new String[] {"/foo"}, this.handler, this.scheduler);
DefaultHandshakeHandler handshakeHandler = new DefaultHandshakeHandler();
HttpSessionHandshakeInterceptor interceptor = new HttpSessionHandshakeInterceptor();
String origin = "http://mydomain.com";
registration.setHandshakeHandler(handshakeHandler).addInterceptors(interceptor).setAllowedOrigins(origin);
MultiValueMap<HttpRequestHandler, String> mappings = registration.getMappings();
assertEquals(1, mappings.size());
registration.setHandshakeHandler(handshakeHandler);
registration.addInterceptors(interceptor);
registration.withSockJS();
Map.Entry<HttpRequestHandler, List<String>> entry = mappings.entrySet().iterator().next();
assertEquals(Arrays.asList("/foo"), entry.getValue());
WebSocketHttpRequestHandler requestHandler = (WebSocketHttpRequestHandler) entry.getKey();
assertNotNull(requestHandler.getWebSocketHandler());
assertSame(handshakeHandler, requestHandler.getHandshakeHandler());
assertEquals(2, requestHandler.getHandshakeInterceptors().size());
assertEquals(interceptor, requestHandler.getHandshakeInterceptors().get(0));
assertEquals(OriginHandshakeInterceptor.class, requestHandler.getHandshakeInterceptors().get(1).getClass());
}
@Test
public void handshakeHandlerInterceptorWithSockJsService() {
WebMvcStompWebSocketEndpointRegistration registration =
new WebMvcStompWebSocketEndpointRegistration(new String[] {"/foo"}, this.handler, this.scheduler);
DefaultHandshakeHandler handshakeHandler = new DefaultHandshakeHandler();
HttpSessionHandshakeInterceptor interceptor = new HttpSessionHandshakeInterceptor();
registration.setHandshakeHandler(handshakeHandler).addInterceptors(interceptor).withSockJS();
MultiValueMap<HttpRequestHandler, String> mappings = registration.getMappings();
assertEquals(1, mappings.size());
......@@ -126,4 +190,37 @@ public class WebMvcStompWebSocketEndpointRegistrationTests {
assertEquals(Arrays.asList(interceptor), sockJsService.getHandshakeInterceptors());
}
@Test
public void handshakeHandlerInterceptorWithSockJsServiceAndAllowedOrigins() {
WebMvcStompWebSocketEndpointRegistration registration =
new WebMvcStompWebSocketEndpointRegistration(new String[] {"/foo"}, this.handler, this.scheduler);
DefaultHandshakeHandler handshakeHandler = new DefaultHandshakeHandler();
HttpSessionHandshakeInterceptor interceptor = new HttpSessionHandshakeInterceptor();
String origin = "http://mydomain.com";
registration.setHandshakeHandler(handshakeHandler).addInterceptors(interceptor).setAllowedOrigins(origin).withSockJS();
MultiValueMap<HttpRequestHandler, String> mappings = registration.getMappings();
assertEquals(1, mappings.size());
Map.Entry<HttpRequestHandler, List<String>> entry = mappings.entrySet().iterator().next();
assertEquals(Arrays.asList("/foo/**"), entry.getValue());
SockJsHttpRequestHandler requestHandler = (SockJsHttpRequestHandler) entry.getKey();
assertNotNull(requestHandler.getWebSocketHandler());
DefaultSockJsService sockJsService = (DefaultSockJsService) requestHandler.getSockJsService();
assertNotNull(sockJsService);
Map<TransportType, TransportHandler> handlers = sockJsService.getTransportHandlers();
WebSocketTransportHandler transportHandler = (WebSocketTransportHandler) handlers.get(TransportType.WEBSOCKET);
assertSame(handshakeHandler, transportHandler.getHandshakeHandler());
assertEquals(2, sockJsService.getHandshakeInterceptors().size());
assertEquals(interceptor, sockJsService.getHandshakeInterceptors().get(0));
assertEquals(OriginHandshakeInterceptor.class,
sockJsService.getHandshakeInterceptors().get(1).getClass());
assertEquals(Arrays.asList(origin), sockJsService.getAllowedOrigins());
}
}
......@@ -29,6 +29,7 @@ import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.handler.TextWebSocketHandler;
import org.springframework.web.socket.server.HandshakeHandler;
import org.springframework.web.socket.server.HandshakeInterceptor;
import org.springframework.web.socket.server.support.OriginHandshakeInterceptor;
import org.springframework.web.socket.server.support.DefaultHandshakeHandler;
import org.springframework.web.socket.server.support.HttpSessionHandshakeInterceptor;
import org.springframework.web.socket.sockjs.SockJsService;
......@@ -68,10 +69,12 @@ public class WebSocketHandlerRegistrationTests {
Mapping m1 = mappings.get(0);
assertEquals(handler, m1.webSocketHandler);
assertEquals("/foo", m1.path);
assertEquals(0, m1.interceptors.length);
Mapping m2 = mappings.get(1);
assertEquals(handler, m2.webSocketHandler);
assertEquals("/bar", m2.path);
assertEquals(0, m2.interceptors.length);
}
@Test
......@@ -90,12 +93,31 @@ public class WebSocketHandlerRegistrationTests {
assertArrayEquals(new HandshakeInterceptor[] {interceptor}, mapping.interceptors);
}
@Test
public void interceptorsWithAllowedOrigins() {
WebSocketHandler handler = new TextWebSocketHandler();
HttpSessionHandshakeInterceptor interceptor = new HttpSessionHandshakeInterceptor();
this.registration.addHandler(handler, "/foo").addInterceptors(interceptor).setAllowedOrigins("http://mydomain1.com");
List<Mapping> mappings = this.registration.getMappings();
assertEquals(1, mappings.size());
Mapping mapping = mappings.get(0);
assertEquals(handler, mapping.webSocketHandler);
assertEquals("/foo", mapping.path);
assertEquals(2, mapping.interceptors.length);
assertEquals(interceptor, mapping.interceptors[0]);
assertEquals(OriginHandshakeInterceptor.class, mapping.interceptors[1].getClass());
}
@Test
public void interceptorsPassedToSockJsRegistration() {
WebSocketHandler handler = new TextWebSocketHandler();
HttpSessionHandshakeInterceptor interceptor = new HttpSessionHandshakeInterceptor();
this.registration.addHandler(handler, "/foo").addInterceptors(interceptor).withSockJS();
this.registration.addHandler(handler, "/foo").addInterceptors(interceptor)
.setAllowedOrigins("http://mydomain1.com").withSockJS();
List<Mapping> mappings = this.registration.getMappings();
assertEquals(1, mappings.size());
......@@ -104,7 +126,11 @@ public class WebSocketHandlerRegistrationTests {
assertEquals(handler, mapping.webSocketHandler);
assertEquals("/foo/**", mapping.path);
assertNotNull(mapping.sockJsService);
assertEquals(Arrays.asList(interceptor), mapping.sockJsService.getHandshakeInterceptors());
assertEquals(Arrays.asList("http://mydomain1.com"),
mapping.sockJsService.getAllowedOrigins());
List<HandshakeInterceptor> interceptors = mapping.sockJsService.getHandshakeInterceptors();
assertEquals(interceptor, interceptors.get(0));
assertEquals(OriginHandshakeInterceptor.class, interceptors.get(1).getClass());
}
@Test
......
/*
* Copyright 2002-2014 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.web.socket.server.support;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentSkipListSet;
import static org.junit.Assert.*;
import org.junit.Test;
import org.mockito.Mockito;
import org.springframework.http.HttpStatus;
import org.springframework.web.socket.AbstractHttpRequestTests;
import org.springframework.web.socket.WebSocketHandler;
/**
* Test fixture for {@link OriginHandshakeInterceptor}.
*
* @author Sebastien Deleuze
*/
public class AllowedOriginsInterceptorTests extends AbstractHttpRequestTests {
@Test
public void originValueMatch() throws Exception {
Map<String, Object> attributes = new HashMap<String, Object>();
WebSocketHandler wsHandler = Mockito.mock(WebSocketHandler.class);
setOrigin("http://mydomain1.com");
OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor();
interceptor.setAllowedOrigins(Arrays.asList("http://mydomain1.com"));
assertTrue(interceptor.beforeHandshake(request, response, wsHandler, attributes));
assertNotEquals(servletResponse.getStatus(), HttpStatus.FORBIDDEN.value());
}
@Test
public void originValueNoMatch() throws Exception {
Map<String, Object> attributes = new HashMap<String, Object>();
WebSocketHandler wsHandler = Mockito.mock(WebSocketHandler.class);
setOrigin("http://mydomain1.com");
OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor();
interceptor.setAllowedOrigins(Arrays.asList("http://mydomain2.com"));
assertFalse(interceptor.beforeHandshake(request, response, wsHandler, attributes));
assertEquals(servletResponse.getStatus(), HttpStatus.FORBIDDEN.value());
}
@Test
public void originListMatch() throws Exception {
Map<String, Object> attributes = new HashMap<String, Object>();
WebSocketHandler wsHandler = Mockito.mock(WebSocketHandler.class);
setOrigin("http://mydomain2.com");
OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor();
interceptor.setAllowedOrigins(Arrays.asList("http://mydomain1.com", "http://mydomain2.com", "http://mydomain3.com"));
assertTrue(interceptor.beforeHandshake(request, response, wsHandler, attributes));
assertNotEquals(servletResponse.getStatus(), HttpStatus.FORBIDDEN.value());
}
@Test
public void originListNoMatch() throws Exception {
Map<String, Object> attributes = new HashMap<String, Object>();
WebSocketHandler wsHandler = Mockito.mock(WebSocketHandler.class);
setOrigin("http://mydomain4.com");
OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor();
interceptor.setAllowedOrigins(Arrays.asList("http://mydomain1.com", "http://mydomain2.com", "http://mydomain3.com"));
assertFalse(interceptor.beforeHandshake(request, response, wsHandler, attributes));
assertEquals(servletResponse.getStatus(), HttpStatus.FORBIDDEN.value());
}
@Test
public void noOriginNoMatchWithNullHostileCollection() throws Exception {
Map<String, Object> attributes = new HashMap<String, Object>();
WebSocketHandler wsHandler = Mockito.mock(WebSocketHandler.class);
Set<String> allowedOrigins = new ConcurrentSkipListSet<String>();
allowedOrigins.add("http://mydomain1.com");
OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor();
interceptor.setAllowedOrigins(allowedOrigins);
assertFalse(interceptor.beforeHandshake(request, response, wsHandler, attributes));
assertEquals(servletResponse.getStatus(), HttpStatus.FORBIDDEN.value());
}
@Test
public void noOriginNoMatch() throws Exception {
Map<String, Object> attributes = new HashMap<String, Object>();
WebSocketHandler wsHandler = Mockito.mock(WebSocketHandler.class);
OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor();
assertFalse(interceptor.beforeHandshake(request, response, wsHandler, attributes));
assertEquals(servletResponse.getStatus(), HttpStatus.FORBIDDEN.value());
}
}
......@@ -17,9 +17,12 @@
package org.springframework.web.socket.sockjs.support;
import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import javax.servlet.ServletOutputStream;
import javax.servlet.http.HttpServletResponse;
import static org.junit.Assert.assertEquals;
import org.junit.Before;
import org.junit.Test;
......@@ -40,9 +43,12 @@ import static org.mockito.BDDMockito.*;
* Test fixture for {@link AbstractSockJsService}.
*
* @author Rossen Stoyanchev
* @author Sebastien Deleuze
*/
public class SockJsServiceTests extends AbstractHttpRequestTests {
private static final List<String> origins = Arrays.asList("http://mydomain1.com", "http://mydomain2.com");
private TestSockJsService service;
private WebSocketHandler handler;
......@@ -80,10 +86,10 @@ public class SockJsServiceTests extends AbstractHttpRequestTests {
resetResponseAndHandleRequest("GET", "/echo/info", HttpStatus.OK);
assertEquals("application/json;charset=UTF-8", this.servletResponse.getContentType());
assertEquals("*", this.servletResponse.getHeader("Access-Control-Allow-Origin"));
assertEquals("true", this.servletResponse.getHeader("Access-Control-Allow-Credentials"));
assertEquals("no-store, no-cache, must-revalidate, max-age=0", this.servletResponse.getHeader("Cache-Control"));
assertEquals("Origin", this.servletResponse.getHeader("Vary"));
assertNull(this.servletResponse.getHeader("Access-Control-Allow-Origin"));
assertNull(this.servletResponse.getHeader("Access-Control-Allow-Credentials"));
assertNull(this.servletResponse.getHeader("Vary"));
String body = this.servletResponse.getContentAsString();
assertEquals("{\"entropy\"", body.substring(0, body.indexOf(':')));
......@@ -97,6 +103,47 @@ public class SockJsServiceTests extends AbstractHttpRequestTests {
body = this.servletResponse.getContentAsString();
assertEquals(",\"origins\":[\"*:*\"],\"cookie_needed\":false,\"websocket\":false}",
body.substring(body.indexOf(',')));
this.service.setAllowedOrigins(Arrays.asList("http://mydomain1.com"));
resetResponseAndHandleRequest("GET", "/echo/info", HttpStatus.FORBIDDEN);
assertNull(this.servletResponse.getHeader("Access-Control-Allow-Origin"));
assertNull(this.servletResponse.getHeader("Access-Control-Allow-Credentials"));
assertNull(this.servletResponse.getHeader("Vary"));
}
@Test // SPR-12226
public void handleInfoGetWithOrigin() throws Exception {
setOrigin("http://mydomain2.com");
resetResponseAndHandleRequest("GET", "/echo/info", HttpStatus.OK);
assertEquals("application/json;charset=UTF-8", this.servletResponse.getContentType());
assertEquals("no-store, no-cache, must-revalidate, max-age=0", this.servletResponse.getHeader("Cache-Control"));
assertEquals("http://mydomain2.com", this.servletResponse.getHeader("Access-Control-Allow-Origin"));
assertEquals("true", this.servletResponse.getHeader("Access-Control-Allow-Credentials"));
assertEquals("Origin", this.servletResponse.getHeader("Vary"));
String body = this.servletResponse.getContentAsString();
assertEquals("{\"entropy\"", body.substring(0, body.indexOf(':')));
assertEquals(",\"origins\":[\"*:*\"],\"cookie_needed\":true,\"websocket\":true}",
body.substring(body.indexOf(',')));
this.service.setAllowedOrigins(null);
resetResponseAndHandleRequest("GET", "/echo/info", HttpStatus.FORBIDDEN);
assertNull(this.servletResponse.getHeader("Access-Control-Allow-Origin"));
assertNull(this.servletResponse.getHeader("Access-Control-Allow-Credentials"));
assertNull(this.servletResponse.getHeader("Vary"));
this.service.setAllowedOrigins(Arrays.asList("http://mydomain1.com"));
resetResponseAndHandleRequest("GET", "/echo/info", HttpStatus.FORBIDDEN);
assertNull(this.servletResponse.getHeader("Access-Control-Allow-Origin"));
assertNull(this.servletResponse.getHeader("Access-Control-Allow-Credentials"));
assertNull(this.servletResponse.getHeader("Vary"));
this.service.setAllowedOrigins(Arrays.asList("http://mydomain1.com", "http://mydomain2.com", "http://mydomain3.com"));
resetResponseAndHandleRequest("GET", "/echo/info", HttpStatus.OK);
assertEquals("http://mydomain2.com", this.servletResponse.getHeader("Access-Control-Allow-Origin"));
assertEquals("true", this.servletResponse.getHeader("Access-Control-Allow-Credentials"));
assertEquals("Origin", this.servletResponse.getHeader("Vary"));
}
@Test // SPR-11443
......@@ -129,7 +176,60 @@ public class SockJsServiceTests extends AbstractHttpRequestTests {
resetResponseAndHandleRequest("OPTIONS", "/echo/info", HttpStatus.NO_CONTENT);
this.response.flush();
assertEquals("*", this.servletResponse.getHeader("Access-Control-Allow-Origin"));
assertNull(this.servletResponse.getHeader("Access-Control-Allow-Origin"));
assertNull(this.servletResponse.getHeader("Access-Control-Allow-Credentials"));
assertNull(this.servletResponse.getHeader("Access-Control-Allow-Headers"));
assertNull(this.servletResponse.getHeader("Access-Control-Allow-Methods"));
assertNull(this.servletResponse.getHeader("Access-Control-Max-Age"));
assertEquals("Origin", this.servletResponse.getHeader("Vary"));
this.service.setAllowedOrigins(Arrays.asList("http://mydomain1.com"));
resetResponseAndHandleRequest("OPTIONS", "/echo/info", HttpStatus.FORBIDDEN);
assertNull(this.servletResponse.getHeader("Access-Control-Allow-Origin"));
assertNull(this.servletResponse.getHeader("Access-Control-Allow-Credentials"));
assertNull(this.servletResponse.getHeader("Access-Control-Allow-Headers"));
assertNull(this.servletResponse.getHeader("Access-Control-Allow-Methods"));
assertNull(this.servletResponse.getHeader("Access-Control-Max-Age"));
assertNull(this.servletResponse.getHeader("Vary"));
}
@Test // SPR-12226
public void handleInfoOptionsWithOrigin() throws Exception {
setOrigin("http://mydomain2.com");
this.servletRequest.addHeader("Access-Control-Request-Headers", "Last-Modified");
resetResponseAndHandleRequest("OPTIONS", "/echo/info", HttpStatus.NO_CONTENT);
this.response.flush();
assertEquals("http://mydomain2.com", this.servletResponse.getHeader("Access-Control-Allow-Origin"));
assertEquals("true", this.servletResponse.getHeader("Access-Control-Allow-Credentials"));
assertEquals("Last-Modified", this.servletResponse.getHeader("Access-Control-Allow-Headers"));
assertEquals("OPTIONS, GET", this.servletResponse.getHeader("Access-Control-Allow-Methods"));
assertEquals("31536000", this.servletResponse.getHeader("Access-Control-Max-Age"));
assertEquals("Origin", this.servletResponse.getHeader("Vary"));
this.service.setAllowedOrigins(null);
resetResponseAndHandleRequest("OPTIONS", "/echo/info", HttpStatus.FORBIDDEN);
this.response.flush();
assertNull(this.servletResponse.getHeader("Access-Control-Allow-Origin"));
assertNull(this.servletResponse.getHeader("Access-Control-Allow-Credentials"));
assertNull(this.servletResponse.getHeader("Access-Control-Allow-Headers"));
assertNull(this.servletResponse.getHeader("Access-Control-Allow-Methods"));
assertNull(this.servletResponse.getHeader("Access-Control-Max-Age"));
assertNull(this.servletResponse.getHeader("Vary"));
this.service.setAllowedOrigins(Arrays.asList("http://mydomain1.com"));
resetResponseAndHandleRequest("OPTIONS", "/echo/info", HttpStatus.FORBIDDEN);
this.response.flush();
assertNull(this.servletResponse.getHeader("Access-Control-Allow-Origin"));
assertNull(this.servletResponse.getHeader("Access-Control-Allow-Credentials"));
assertNull(this.servletResponse.getHeader("Access-Control-Allow-Headers"));
assertNull(this.servletResponse.getHeader("Access-Control-Allow-Methods"));
assertNull(this.servletResponse.getHeader("Access-Control-Max-Age"));
assertNull(this.servletResponse.getHeader("Vary"));
this.service.setAllowedOrigins(Arrays.asList("http://mydomain1.com", "http://mydomain2.com", "http://mydomain3.com"));
resetResponseAndHandleRequest("OPTIONS", "/echo/info", HttpStatus.NO_CONTENT);
this.response.flush();
assertEquals("http://mydomain2.com", this.servletResponse.getHeader("Access-Control-Allow-Origin"));
assertEquals("true", this.servletResponse.getHeader("Access-Control-Allow-Credentials"));
assertEquals("Last-Modified", this.servletResponse.getHeader("Access-Control-Allow-Headers"));
assertEquals("OPTIONS, GET", this.servletResponse.getHeader("Access-Control-Allow-Methods"));
......
/*
* Copyright 2002-2013 the original author or authors.
* Copyright 2002-2014 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
......@@ -16,7 +16,9 @@
package org.springframework.web.socket.sockjs.transport.handler;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import org.junit.Before;
......@@ -27,6 +29,8 @@ import org.mockito.MockitoAnnotations;
import org.springframework.scheduling.TaskScheduler;
import org.springframework.web.socket.AbstractHttpRequestTests;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.server.HandshakeHandler;
import org.springframework.web.socket.server.support.OriginHandshakeInterceptor;
import org.springframework.web.socket.sockjs.transport.SockJsSessionFactory;
import org.springframework.web.socket.sockjs.transport.TransportHandler;
import org.springframework.web.socket.sockjs.transport.TransportHandlingSockJsService;
......@@ -41,6 +45,7 @@ import static org.mockito.BDDMockito.*;
* Test fixture for {@link org.springframework.web.socket.sockjs.transport.handler.DefaultSockJsService}.
*
* @author Rossen Stoyanchev
* @author Sebastien Deleuze
*/
public class DefaultSockJsServiceTests extends AbstractHttpRequestTests {
......@@ -50,11 +55,19 @@ public class DefaultSockJsServiceTests extends AbstractHttpRequestTests {
private static final String sessionUrlPrefix = "/server1/" + sessionId + "/";
private static final List<String> origins = Arrays.asList("http://mydomain1.com", "http://mydomain2.com");
@Mock private SessionCreatingTransportHandler xhrHandler;
@Mock private TransportHandler xhrSendHandler;
@Mock private SessionCreatingTransportHandler jsonpHandler;
@Mock private TransportHandler jsonpSendHandler;
@Mock private HandshakeTransportHandler wsTransportHandler;
@Mock private WebSocketHandler wsHandler;
@Mock private TaskScheduler taskScheduler;
......@@ -75,6 +88,10 @@ public class DefaultSockJsServiceTests extends AbstractHttpRequestTests {
given(this.xhrHandler.getTransportType()).willReturn(TransportType.XHR);
given(this.xhrHandler.createSession(sessionId, this.wsHandler, attributes)).willReturn(this.session);
given(this.xhrSendHandler.getTransportType()).willReturn(TransportType.XHR_SEND);
given(this.jsonpHandler.getTransportType()).willReturn(TransportType.JSONP);
given(this.jsonpHandler.createSession(sessionId, this.wsHandler, attributes)).willReturn(this.session);
given(this.jsonpSendHandler.getTransportType()).willReturn(TransportType.JSONP_SEND);
given(this.wsTransportHandler.getTransportType()).willReturn(TransportType.WEBSOCKET);
this.service = new TransportHandlingSockJsService(this.taskScheduler, this.xhrHandler, this.xhrSendHandler);
}
......@@ -126,10 +143,47 @@ public class DefaultSockJsServiceTests extends AbstractHttpRequestTests {
verify(taskScheduler).scheduleAtFixedRate(any(Runnable.class), eq(service.getDisconnectDelay()));
assertEquals("no-store, no-cache, must-revalidate, max-age=0", this.response.getHeaders().getCacheControl());
assertEquals("*", this.response.getHeaders().getFirst("Access-Control-Allow-Origin"));
assertNull(this.response.getHeaders().getFirst("Access-Control-Allow-Origin"));
assertNull(this.response.getHeaders().getFirst("Access-Control-Allow-Credentials"));
}
@Test // SPR-12226
public void handleTransportRequestXhrAllowNullOrigin() throws Exception {
String sockJsPath = sessionUrlPrefix + "xhr";
setRequest("POST", sockJsPrefix + sockJsPath);
this.service.setAllowedOrigins(null);
this.service.handleRequest(this.request, this.response, sockJsPath, this.wsHandler);
assertNull(this.response.getHeaders().getFirst("Access-Control-Allow-Origin"));
assertNull(this.response.getHeaders().getFirst("Access-Control-Allow-Credentials"));
}
@Test // SPR-12226
public void handleTransportRequestXhrAllowedOriginsMatch() throws Exception {
String sockJsPath = sessionUrlPrefix + "xhr";
setRequest("POST", sockJsPrefix + sockJsPath);
setOrigin(origins.get(0));
this.service.setAllowedOrigins(origins);
this.service.handleRequest(this.request, this.response, sockJsPath, this.wsHandler);
assertEquals(200, this.servletResponse.getStatus());
assertEquals(origins.get(0), this.response.getHeaders().getFirst("Access-Control-Allow-Origin"));
assertEquals("true", this.response.getHeaders().getFirst("Access-Control-Allow-Credentials"));
}
@Test // SPR-12226
public void handleTransportRequestXhrAllowedOriginsNoMatch() throws Exception {
String sockJsPath = sessionUrlPrefix + "xhr";
setRequest("POST", sockJsPrefix + sockJsPath);
setOrigin("http://mydomain3.com");
this.service.setAllowedOrigins(origins);
this.service.handleRequest(this.request, this.response, sockJsPath, this.wsHandler);
assertEquals(403, this.servletResponse.getStatus());
assertNull(this.response.getHeaders().getFirst("Access-Control-Allow-Origin"));
assertNull(this.response.getHeaders().getFirst("Access-Control-Allow-Credentials"));
}
@Test
public void handleTransportRequestXhrOptions() throws Exception {
String sockJsPath = sessionUrlPrefix + "xhr";
......@@ -137,9 +191,22 @@ public class DefaultSockJsServiceTests extends AbstractHttpRequestTests {
this.service.handleRequest(this.request, this.response, sockJsPath, this.wsHandler);
assertEquals(204, this.servletResponse.getStatus());
assertEquals("*", this.response.getHeaders().getFirst("Access-Control-Allow-Origin"));
assertEquals("true", this.response.getHeaders().getFirst("Access-Control-Allow-Credentials"));
assertEquals("OPTIONS, POST", this.response.getHeaders().getFirst("Access-Control-Allow-Methods"));
assertNull(this.response.getHeaders().getFirst("Access-Control-Allow-Origin"));
assertNull(this.response.getHeaders().getFirst("Access-Control-Allow-Credentials"));
assertNull(this.response.getHeaders().getFirst("Access-Control-Allow-Methods"));
}
@Test // SPR-12226
public void handleTransportRequestXhrOptionsAllowNullOrigin() throws Exception {
String sockJsPath = sessionUrlPrefix + "xhr";
setRequest("OPTIONS", sockJsPrefix + sockJsPath);
this.service.setAllowedOrigins(null);
this.service.handleRequest(this.request, this.response, sockJsPath, this.wsHandler);
assertEquals(403, this.servletResponse.getStatus());
assertNull(this.response.getHeaders().getFirst("Access-Control-Allow-Origin"));
assertNull(this.response.getHeaders().getFirst("Access-Control-Allow-Credentials"));
assertNull(this.response.getHeaders().getFirst("Access-Control-Allow-Methods"));
}
@Test
......@@ -176,8 +243,56 @@ public class DefaultSockJsServiceTests extends AbstractHttpRequestTests {
verify(this.xhrSendHandler).handleRequest(this.request, this.response, this.wsHandler, this.session);
}
@Test
public void handleTransportRequestJsonp() throws Exception {
TransportHandlingSockJsService jsonpService = new TransportHandlingSockJsService(this.taskScheduler, this.jsonpHandler, this.jsonpSendHandler);
String sockJsPath = sessionUrlPrefix+ "jsonp";
setRequest("GET", sockJsPrefix + sockJsPath);
jsonpService.handleRequest(this.request, this.response, sockJsPath, this.wsHandler);
assertNotEquals(404, this.servletResponse.getStatus());
resetRequestAndResponse();
jsonpService.setAllowedOrigins(Arrays.asList("http://mydomain1.com"));
setRequest("GET", sockJsPrefix + sockJsPath);
jsonpService.handleRequest(this.request, this.response, sockJsPath, this.wsHandler);
assertEquals(404, this.servletResponse.getStatus());
resetRequestAndResponse();
jsonpService.setAllowedOrigins(null);
setRequest("GET", sockJsPrefix + sockJsPath);
jsonpService.handleRequest(this.request, this.response, sockJsPath, this.wsHandler);
assertEquals(404, this.servletResponse.getStatus());
}
@Test
public void handleTransportRequestWebsocket() throws Exception {
TransportHandlingSockJsService wsService = new TransportHandlingSockJsService(this.taskScheduler, this.wsTransportHandler);
String sockJsPath = "/websocket";
setRequest("GET", sockJsPrefix + sockJsPath);
wsService.handleRequest(this.request, this.response, sockJsPath, this.wsHandler);
assertNotEquals(403, this.servletResponse.getStatus());
resetRequestAndResponse();
OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor();
interceptor.setAllowedOrigins(Arrays.asList("http://mydomain1.com"));
wsService.setHandshakeInterceptors(Arrays.asList(interceptor));
setRequest("GET", sockJsPrefix + sockJsPath);
setOrigin("http://mydomain1.com");
wsService.handleRequest(this.request, this.response, sockJsPath, this.wsHandler);
assertNotEquals(403, this.servletResponse.getStatus());
resetRequestAndResponse();
setRequest("GET", sockJsPrefix + sockJsPath);
setOrigin("http://mydomain2.com");
wsService.handleRequest(this.request, this.response, sockJsPath, this.wsHandler);
assertEquals(403, this.servletResponse.getStatus());
}
interface SessionCreatingTransportHandler extends TransportHandler, SockJsSessionFactory {
}
interface HandshakeTransportHandler extends TransportHandler, HandshakeHandler {
}
}
......@@ -15,7 +15,7 @@
</websocket:decorator-factories>
</websocket:transport>
<websocket:stomp-endpoint path=" /foo,/bar">
<websocket:stomp-endpoint path=" /foo,/bar" allowed-origins="http://mydomain1.com,http://mydomain2.com">
<websocket:handshake-handler ref="myHandler"/>
<websocket:handshake-interceptors>
<bean class="org.springframework.web.socket.config.FooTestInterceptor"/>
......@@ -23,7 +23,7 @@
</websocket:handshake-interceptors>
</websocket:stomp-endpoint>
<websocket:stomp-endpoint path="/test,/sockjs">
<websocket:stomp-endpoint path="/test,/sockjs" allowed-origins="http://mydomain3.com,http://mydomain4.com">
<websocket:handshake-handler ref="myHandler"/>
<websocket:handshake-interceptors>
<bean class="org.springframework.web.socket.config.FooTestInterceptor"/>
......
......@@ -5,7 +5,7 @@
http://www.springframework.org/schema/beans http://www.springframework.org/schema/beans/spring-beans.xsd
http://www.springframework.org/schema/websocket http://www.springframework.org/schema/websocket/spring-websocket.xsd">
<websocket:handlers order="2">
<websocket:handlers order="2" allowed-origins="http://mydomain1.com, http://mydomain2.com">
<websocket:mapping path="/foo" handler="fooHandler"/>
<websocket:mapping path="/test" handler="testHandler"/>
<websocket:handshake-handler ref="testHandshakeHandler"/>
......
......@@ -5,7 +5,7 @@
http://www.springframework.org/schema/beans http://www.springframework.org/schema/beans/spring-beans.xsd
http://www.springframework.org/schema/websocket http://www.springframework.org/schema/websocket/spring-websocket.xsd">
<websocket:handlers >
<websocket:handlers allowed-origins="http://mydomain1.com, http://mydomain2.com">
<websocket:mapping path="/test" handler="testHandler"/>
<websocket:sockjs name="testSockJsService" scheduler="testTaskScheduler" websocket-enabled="false"
session-cookie-needed="false" stream-bytes-limit="2048" disconnect-delay="256"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册