提交 e083683f 编写于 作者: R Rossen Stoyanchev

Update WebSocket support for Jetty 9.3

Issue: SPR-13140
上级 25ff34f3
......@@ -46,7 +46,7 @@ configure(allprojects) { project ->
ext.jackson2Version = "2.6.0-rc2" // to be upgraded to 2.6 final in time for Spring Framework 4.2 GA
ext.jasperreportsVersion = "6.1.0"
ext.javamailVersion = "1.5.3"
ext.jettyVersion = "9.2.11.v20150529"
ext.jettyVersion = "9.3.0.v20150612"
ext.jodaVersion = "2.8.1"
ext.jrubyVersion = "1.7.20"
ext.jtaVersion = "1.2"
......
......@@ -33,8 +33,8 @@ import org.springframework.beans.factory.xml.BeanDefinitionParser;
import org.springframework.beans.factory.xml.ParserContext;
import org.springframework.util.StringUtils;
import org.springframework.util.xml.DomUtils;
import org.springframework.web.servlet.handler.SimpleUrlHandlerMapping;
import org.springframework.web.socket.server.support.OriginHandshakeInterceptor;
import org.springframework.web.socket.server.support.WebSocketHandlerMapping;
import org.springframework.web.socket.server.support.WebSocketHttpRequestHandler;
import org.springframework.web.socket.sockjs.support.SockJsHttpRequestHandler;
......@@ -64,7 +64,7 @@ class HandlersBeanDefinitionParser implements BeanDefinitionParser {
String orderAttribute = element.getAttribute("order");
int order = orderAttribute.isEmpty() ? DEFAULT_MAPPING_ORDER : Integer.valueOf(orderAttribute);
RootBeanDefinition handlerMappingDef = new RootBeanDefinition(SimpleUrlHandlerMapping.class);
RootBeanDefinition handlerMappingDef = new RootBeanDefinition(WebSocketHandlerMapping.class);
handlerMappingDef.setSource(source);
handlerMappingDef.setRole(BeanDefinition.ROLE_INFRASTRUCTURE);
handlerMappingDef.getPropertyValues().add("order", order);
......
......@@ -29,6 +29,7 @@ import org.springframework.web.servlet.HandlerMapping;
import org.springframework.web.servlet.handler.AbstractHandlerMapping;
import org.springframework.web.servlet.handler.SimpleUrlHandlerMapping;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.server.support.WebSocketHandlerMapping;
import org.springframework.web.util.UrlPathHelper;
/**
......@@ -101,7 +102,7 @@ public class ServletWebSocketHandlerRegistry implements WebSocketHandlerRegistry
}
}
}
SimpleUrlHandlerMapping hm = new SimpleUrlHandlerMapping();
WebSocketHandlerMapping hm = new WebSocketHandlerMapping();
hm.setUrlMap(urlMap);
hm.setOrder(this.order);
if (this.urlPathHelper != null) {
......
......@@ -17,11 +17,13 @@
package org.springframework.web.socket.server.jetty;
import java.io.IOException;
import java.lang.reflect.Method;
import java.security.Principal;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import javax.servlet.ServletContext;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
......@@ -42,7 +44,9 @@ import org.springframework.http.server.ServerHttpResponse;
import org.springframework.http.server.ServletServerHttpRequest;
import org.springframework.http.server.ServletServerHttpResponse;
import org.springframework.util.Assert;
import org.springframework.util.ClassUtils;
import org.springframework.util.CollectionUtils;
import org.springframework.web.context.ServletContextAware;
import org.springframework.web.socket.WebSocketExtension;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.adapter.jetty.JettyWebSocketHandlerAdapter;
......@@ -59,7 +63,11 @@ import org.springframework.web.socket.server.RequestUpgradeStrategy;
* @author Rossen Stoyanchev
* @since 4.0
*/
public class JettyRequestUpgradeStrategy implements RequestUpgradeStrategy, Lifecycle {
public class JettyRequestUpgradeStrategy implements RequestUpgradeStrategy, Lifecycle, ServletContextAware {
// Pre-Jetty 9.3 init method without ServletContext
private static final Method webSocketFactoryInitMethod =
ClassUtils.getMethodIfAvailable(WebSocketServerFactory.class, "init");
private static final ThreadLocal<WebSocketHandlerContainer> wsContainerHolder =
new NamedThreadLocal<WebSocketHandlerContainer>("WebSocket Handler Container");
......@@ -69,6 +77,8 @@ public class JettyRequestUpgradeStrategy implements RequestUpgradeStrategy, Life
private volatile List<WebSocketExtension> supportedExtensions;
private ServletContext servletContext;
private volatile boolean running = false;
......@@ -94,7 +104,6 @@ public class JettyRequestUpgradeStrategy implements RequestUpgradeStrategy, Life
// Cast to avoid infinite recursion
return createWebSocket((UpgradeRequest) request, (UpgradeResponse) response);
}
// For Jetty 9.0.x
public Object createWebSocket(UpgradeRequest request, UpgradeResponse response) {
WebSocketHandlerContainer container = wsContainerHolder.get();
......@@ -128,6 +137,11 @@ public class JettyRequestUpgradeStrategy implements RequestUpgradeStrategy, Life
return result;
}
@Override
public void setServletContext(ServletContext servletContext) {
this.servletContext = servletContext;
}
@Override
public boolean isRunning() {
return this.running;
......@@ -139,7 +153,12 @@ public class JettyRequestUpgradeStrategy implements RequestUpgradeStrategy, Life
if (!isRunning()) {
this.running = true;
try {
this.factory.init();
if (webSocketFactoryInitMethod != null) {
webSocketFactoryInitMethod.invoke(this.factory);
}
else {
this.factory.init(this.servletContext);
}
}
catch (Exception ex) {
throw new IllegalStateException("Unable to initialize Jetty WebSocketServerFactory", ex);
......
/*
* Copyright 2002-2015 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.web.socket.server.support;
import java.io.IOException;
import java.nio.charset.Charset;
import java.security.Principal;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.context.Lifecycle;
import org.springframework.http.HttpMethod;
import org.springframework.http.HttpStatus;
import org.springframework.http.server.ServerHttpRequest;
import org.springframework.http.server.ServerHttpResponse;
import org.springframework.util.Assert;
import org.springframework.util.ClassUtils;
import org.springframework.util.StringUtils;
import org.springframework.web.socket.SubProtocolCapable;
import org.springframework.web.socket.WebSocketExtension;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.WebSocketHttpHeaders;
import org.springframework.web.socket.handler.WebSocketHandlerDecorator;
import org.springframework.web.socket.server.HandshakeFailureException;
import org.springframework.web.socket.server.HandshakeHandler;
import org.springframework.web.socket.server.RequestUpgradeStrategy;
/**
* A base class to use for {@link HandshakeHandler} implementations.
* Performs initial validation of the WebSocket handshake request -- possibly rejecting it
* through the appropriate HTTP status code -- while also allowing sub-classes to override
* various parts of the negotiation process (e.g. origin validation, sub-protocol negotiation,
* extensions negotiation, etc).
*
* <p>If the negotiation succeeds, the actual upgrade is delegated to a server-specific
* {@link RequestUpgradeStrategy}, which will update
* the response as necessary and initialize the WebSocket. Currently supported servers are
* Tomcat 7 and 8, Jetty 9, and GlassFish 4.
*
* @author Rossen Stoyanchev
* @since 4.0
*/
public abstract class AbstractHandshakeHandler implements HandshakeHandler, Lifecycle {
private static final Charset UTF8_CHARSET = Charset.forName("UTF-8");
private static final boolean jettyWsPresent = ClassUtils.isPresent(
"org.eclipse.jetty.websocket.server.WebSocketServerFactory", AbstractHandshakeHandler.class.getClassLoader());
private static final boolean tomcatWsPresent = ClassUtils.isPresent(
"org.apache.tomcat.websocket.server.WsHttpUpgradeHandler", AbstractHandshakeHandler.class.getClassLoader());
private static final boolean undertowWsPresent = ClassUtils.isPresent(
"io.undertow.websockets.jsr.ServerWebSocketContainer", AbstractHandshakeHandler.class.getClassLoader());
private static final boolean glassFishWsPresent = ClassUtils.isPresent(
"org.glassfish.tyrus.servlet.TyrusHttpUpgradeHandler", AbstractHandshakeHandler.class.getClassLoader());
private static final boolean webLogicWsPresent = ClassUtils.isPresent(
"weblogic.websocket.tyrus.TyrusServletWriter", AbstractHandshakeHandler.class.getClassLoader());
protected final Log logger = LogFactory.getLog(getClass());
private final RequestUpgradeStrategy requestUpgradeStrategy;
private final List<String> supportedProtocols = new ArrayList<String>();
private volatile boolean running = false;
/**
* Default constructor that auto-detects and instantiates a
* {@link RequestUpgradeStrategy} suitable for the runtime container.
* @throws IllegalStateException if no {@link RequestUpgradeStrategy} can be found.
*/
protected AbstractHandshakeHandler() {
this(initRequestUpgradeStrategy());
}
/**
* A constructor that accepts a runtime-specific {@link RequestUpgradeStrategy}.
* @param requestUpgradeStrategy the upgrade strategy to use
*/
protected AbstractHandshakeHandler(RequestUpgradeStrategy requestUpgradeStrategy) {
Assert.notNull(requestUpgradeStrategy, "RequestUpgradeStrategy must not be null");
this.requestUpgradeStrategy = requestUpgradeStrategy;
}
private static RequestUpgradeStrategy initRequestUpgradeStrategy() {
String className;
if (tomcatWsPresent) {
className = "org.springframework.web.socket.server.standard.TomcatRequestUpgradeStrategy";
}
else if (jettyWsPresent) {
className = "org.springframework.web.socket.server.jetty.JettyRequestUpgradeStrategy";
}
else if (undertowWsPresent) {
className = "org.springframework.web.socket.server.standard.UndertowRequestUpgradeStrategy";
}
else if (glassFishWsPresent) {
className = "org.springframework.web.socket.server.standard.GlassFishRequestUpgradeStrategy";
}
else if (webLogicWsPresent) {
className = "org.springframework.web.socket.server.standard.WebLogicRequestUpgradeStrategy";
}
else {
throw new IllegalStateException("No suitable default RequestUpgradeStrategy found");
}
try {
Class<?> clazz = ClassUtils.forName(className, AbstractHandshakeHandler.class.getClassLoader());
return (RequestUpgradeStrategy) clazz.newInstance();
}
catch (Throwable ex) {
throw new IllegalStateException("Failed to instantiate RequestUpgradeStrategy: " + className, ex);
}
}
/**
* Return the {@link RequestUpgradeStrategy} for WebSocket requests.
*/
public RequestUpgradeStrategy getRequestUpgradeStrategy() {
return this.requestUpgradeStrategy;
}
/**
* Use this property to configure the list of supported sub-protocols.
* The first configured sub-protocol that matches a client-requested sub-protocol
* is accepted. If there are no matches the response will not contain a
* {@literal Sec-WebSocket-Protocol} header.
* <p>Note that if the WebSocketHandler passed in at runtime is an instance of
* {@link SubProtocolCapable} then there is not need to explicitly configure
* this property. That is certainly the case with the built-in STOMP over
* WebSocket support. Therefore this property should be configured explicitly
* only if the WebSocketHandler does not implement {@code SubProtocolCapable}.
*/
public void setSupportedProtocols(String... protocols) {
this.supportedProtocols.clear();
for (String protocol : protocols) {
this.supportedProtocols.add(protocol.toLowerCase());
}
}
/**
* Return the list of supported sub-protocols.
*/
public String[] getSupportedProtocols() {
return this.supportedProtocols.toArray(new String[this.supportedProtocols.size()]);
}
@Override
public boolean isRunning() {
return this.running;
}
@Override
public void start() {
if (!isRunning()) {
this.running = true;
doStart();
}
}
protected void doStart() {
if (this.requestUpgradeStrategy instanceof Lifecycle) {
((Lifecycle) this.requestUpgradeStrategy).start();
}
}
@Override
public void stop() {
if (isRunning()) {
this.running = false;
doStop();
}
}
protected void doStop() {
if (this.requestUpgradeStrategy instanceof Lifecycle) {
((Lifecycle) this.requestUpgradeStrategy).stop();
}
}
@Override
public final boolean doHandshake(ServerHttpRequest request, ServerHttpResponse response,
WebSocketHandler wsHandler, Map<String, Object> attributes) throws HandshakeFailureException {
WebSocketHttpHeaders headers = new WebSocketHttpHeaders(request.getHeaders());
if (logger.isTraceEnabled()) {
logger.trace("Processing request " + request.getURI() + " with headers=" + headers);
}
try {
if (!HttpMethod.GET.equals(request.getMethod())) {
response.setStatusCode(HttpStatus.METHOD_NOT_ALLOWED);
response.getHeaders().setAllow(Collections.singleton(HttpMethod.GET));
if (logger.isErrorEnabled()) {
logger.error("Handshake failed due to unexpected HTTP method: " + request.getMethod());
}
return false;
}
if (!"WebSocket".equalsIgnoreCase(headers.getUpgrade())) {
handleInvalidUpgradeHeader(request, response);
return false;
}
if (!headers.getConnection().contains("Upgrade") && !headers.getConnection().contains("upgrade")) {
handleInvalidConnectHeader(request, response);
return false;
}
if (!isWebSocketVersionSupported(headers)) {
handleWebSocketVersionNotSupported(request, response);
return false;
}
if (!isValidOrigin(request)) {
response.setStatusCode(HttpStatus.FORBIDDEN);
return false;
}
String wsKey = headers.getSecWebSocketKey();
if (wsKey == null) {
if (logger.isErrorEnabled()) {
logger.error("Missing \"Sec-WebSocket-Key\" header");
}
response.setStatusCode(HttpStatus.BAD_REQUEST);
return false;
}
}
catch (IOException ex) {
throw new HandshakeFailureException(
"Response update failed during upgrade to WebSocket, uri=" + request.getURI(), ex);
}
String subProtocol = selectProtocol(headers.getSecWebSocketProtocol(), wsHandler);
List<WebSocketExtension> requested = headers.getSecWebSocketExtensions();
List<WebSocketExtension> supported = this.requestUpgradeStrategy.getSupportedExtensions(request);
List<WebSocketExtension> extensions = filterRequestedExtensions(request, requested, supported);
Principal user = determineUser(request, wsHandler, attributes);
if (logger.isTraceEnabled()) {
logger.trace("Upgrading to WebSocket, subProtocol=" + subProtocol + ", extensions=" + extensions);
}
this.requestUpgradeStrategy.upgrade(request, response, subProtocol, extensions, user, wsHandler, attributes);
return true;
}
protected void handleInvalidUpgradeHeader(ServerHttpRequest request, ServerHttpResponse response) throws IOException {
if (logger.isErrorEnabled()) {
logger.error("Handshake failed due to invalid Upgrade header: " + request.getHeaders().getUpgrade());
}
response.setStatusCode(HttpStatus.BAD_REQUEST);
response.getBody().write("Can \"Upgrade\" only to \"WebSocket\".".getBytes(UTF8_CHARSET));
}
protected void handleInvalidConnectHeader(ServerHttpRequest request, ServerHttpResponse response) throws IOException {
if (logger.isErrorEnabled()) {
logger.error("Handshake failed due to invalid Connection header " + request.getHeaders().getConnection());
}
response.setStatusCode(HttpStatus.BAD_REQUEST);
response.getBody().write("\"Connection\" must be \"upgrade\".".getBytes(UTF8_CHARSET));
}
protected boolean isWebSocketVersionSupported(WebSocketHttpHeaders httpHeaders) {
String version = httpHeaders.getSecWebSocketVersion();
String[] supportedVersions = getSupportedVersions();
for (String supportedVersion : supportedVersions) {
if (supportedVersion.trim().equals(version)) {
return true;
}
}
return false;
}
protected String[] getSupportedVersions() {
return this.requestUpgradeStrategy.getSupportedVersions();
}
protected void handleWebSocketVersionNotSupported(ServerHttpRequest request, ServerHttpResponse response) {
if (logger.isErrorEnabled()) {
String version = request.getHeaders().getFirst("Sec-WebSocket-Version");
logger.error("Handshake failed due to unsupported WebSocket version: " + version +
". Supported versions: " + Arrays.toString(getSupportedVersions()));
}
response.setStatusCode(HttpStatus.UPGRADE_REQUIRED);
response.getHeaders().put(WebSocketHttpHeaders.SEC_WEBSOCKET_VERSION,
Arrays.asList(StringUtils.arrayToCommaDelimitedString(getSupportedVersions())));
}
/**
* Return whether the request {@code Origin} header value is valid or not.
* By default, all origins as considered as valid. Consider using an
* {@link OriginHandshakeInterceptor} for filtering origins if needed.
*/
protected boolean isValidOrigin(ServerHttpRequest request) {
return true;
}
/**
* Perform the sub-protocol negotiation based on requested and supported sub-protocols.
* For the list of supported sub-protocols, this method first checks if the target
* WebSocketHandler is a {@link SubProtocolCapable} and then also checks if any
* sub-protocols have been explicitly configured with
* {@link #setSupportedProtocols(String...)}.
* @param requestedProtocols the requested sub-protocols
* @param webSocketHandler the WebSocketHandler that will be used
* @return the selected protocols or {@code null}
* @see #determineHandlerSupportedProtocols(WebSocketHandler)
*/
protected String selectProtocol(List<String> requestedProtocols, WebSocketHandler webSocketHandler) {
if (requestedProtocols != null) {
List<String> handlerProtocols = determineHandlerSupportedProtocols(webSocketHandler);
for (String protocol : requestedProtocols) {
if (handlerProtocols.contains(protocol.toLowerCase())) {
return protocol;
}
if (this.supportedProtocols.contains(protocol.toLowerCase())) {
return protocol;
}
}
}
return null;
}
/**
* Determine the sub-protocols supported by the given WebSocketHandler by
* checking whether it is an instance of {@link SubProtocolCapable}.
* @param handler the handler to check
* @return a list of supported protocols, or an empty list if none available
*/
protected final List<String> determineHandlerSupportedProtocols(WebSocketHandler handler) {
WebSocketHandler handlerToCheck = WebSocketHandlerDecorator.unwrap(handler);
List<String> subProtocols = null;
if (handlerToCheck instanceof SubProtocolCapable) {
subProtocols = ((SubProtocolCapable) handlerToCheck).getSubProtocols();
}
return (subProtocols != null ? subProtocols : Collections.<String>emptyList());
}
/**
* Filter the list of requested WebSocket extensions.
* <p>As of 4.1 the default implementation of this method filters the list to
* leave only extensions that are both requested and supported.
* @param request the current request
* @param requestedExtensions the list of extensions requested by the client
* @param supportedExtensions the list of extensions supported by the server
* @return the selected extensions or an empty list
*/
protected List<WebSocketExtension> filterRequestedExtensions(ServerHttpRequest request,
List<WebSocketExtension> requestedExtensions, List<WebSocketExtension> supportedExtensions) {
List<WebSocketExtension> result = new ArrayList<WebSocketExtension>(requestedExtensions.size());
for (WebSocketExtension extension : requestedExtensions) {
if (supportedExtensions.contains(extension)) {
result.add(extension);
}
}
return result;
}
/**
* A method that can be used to associate a user with the WebSocket session
* in the process of being established. The default implementation calls
* {@link ServerHttpRequest#getPrincipal()}
* <p>Subclasses can provide custom logic for associating a user with a session,
* for example for assigning a name to anonymous users (i.e. not fully
* authenticated).
* @param request the handshake request
* @param wsHandler the WebSocket handler that will handle messages
* @param attributes handshake attributes to pass to the WebSocket session
* @return the user for the WebSocket session, or {@code null} if not available
*/
protected Principal determineUser(ServerHttpRequest request, WebSocketHandler wsHandler,
Map<String, Object> attributes) {
return request.getPrincipal();
}
}
......@@ -16,33 +16,9 @@
package org.springframework.web.socket.server.support;
import java.io.IOException;
import java.nio.charset.Charset;
import java.security.Principal;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import javax.servlet.ServletContext;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.context.Lifecycle;
import org.springframework.http.HttpMethod;
import org.springframework.http.HttpStatus;
import org.springframework.http.server.ServerHttpRequest;
import org.springframework.http.server.ServerHttpResponse;
import org.springframework.util.Assert;
import org.springframework.util.ClassUtils;
import org.springframework.util.StringUtils;
import org.springframework.web.socket.SubProtocolCapable;
import org.springframework.web.socket.WebSocketExtension;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.WebSocketHttpHeaders;
import org.springframework.web.socket.handler.WebSocketHandlerDecorator;
import org.springframework.web.socket.server.HandshakeFailureException;
import org.springframework.web.socket.server.HandshakeHandler;
import org.springframework.web.context.ServletContextAware;
import org.springframework.web.socket.server.RequestUpgradeStrategy;
/**
......@@ -60,325 +36,23 @@ import org.springframework.web.socket.server.RequestUpgradeStrategy;
* @author Rossen Stoyanchev
* @since 4.0
*/
public class DefaultHandshakeHandler implements HandshakeHandler, Lifecycle {
private static final Charset UTF8_CHARSET = Charset.forName("UTF-8");
private static final boolean jettyWsPresent = ClassUtils.isPresent(
"org.eclipse.jetty.websocket.server.WebSocketServerFactory", DefaultHandshakeHandler.class.getClassLoader());
private static final boolean tomcatWsPresent = ClassUtils.isPresent(
"org.apache.tomcat.websocket.server.WsHttpUpgradeHandler", DefaultHandshakeHandler.class.getClassLoader());
private static final boolean undertowWsPresent = ClassUtils.isPresent(
"io.undertow.websockets.jsr.ServerWebSocketContainer", DefaultHandshakeHandler.class.getClassLoader());
private static final boolean glassFishWsPresent = ClassUtils.isPresent(
"org.glassfish.tyrus.servlet.TyrusHttpUpgradeHandler", DefaultHandshakeHandler.class.getClassLoader());
private static final boolean webLogicWsPresent = ClassUtils.isPresent(
"weblogic.websocket.tyrus.TyrusServletWriter", DefaultHandshakeHandler.class.getClassLoader());
protected final Log logger = LogFactory.getLog(getClass());
public class DefaultHandshakeHandler extends AbstractHandshakeHandler implements ServletContextAware {
private final RequestUpgradeStrategy requestUpgradeStrategy;
private final List<String> supportedProtocols = new ArrayList<String>();
private volatile boolean running = false;
/**
* Default constructor that autodetects and instantiates a
* {@link RequestUpgradeStrategy} suitable for the runtime container.
* @throws IllegalStateException if no {@link RequestUpgradeStrategy} can be found.
*/
public DefaultHandshakeHandler() {
this(initRequestUpgradeStrategy());
}
/**
* A constructor that accepts a runtime-specific {@link RequestUpgradeStrategy}.
* @param requestUpgradeStrategy the upgrade strategy to use
*/
public DefaultHandshakeHandler(RequestUpgradeStrategy requestUpgradeStrategy) {
Assert.notNull(requestUpgradeStrategy, "RequestUpgradeStrategy must not be null");
this.requestUpgradeStrategy = requestUpgradeStrategy;
}
private static RequestUpgradeStrategy initRequestUpgradeStrategy() {
String className;
if (jettyWsPresent) {
className = "org.springframework.web.socket.server.jetty.JettyRequestUpgradeStrategy";
}
else if (tomcatWsPresent) {
className = "org.springframework.web.socket.server.standard.TomcatRequestUpgradeStrategy";
}
else if (undertowWsPresent) {
className = "org.springframework.web.socket.server.standard.UndertowRequestUpgradeStrategy";
}
else if (glassFishWsPresent) {
className = "org.springframework.web.socket.server.standard.GlassFishRequestUpgradeStrategy";
}
else if (webLogicWsPresent) {
className = "org.springframework.web.socket.server.standard.WebLogicRequestUpgradeStrategy";
}
else {
throw new IllegalStateException("No suitable default RequestUpgradeStrategy found");
}
try {
Class<?> clazz = ClassUtils.forName(className, DefaultHandshakeHandler.class.getClassLoader());
return (RequestUpgradeStrategy) clazz.newInstance();
}
catch (Throwable ex) {
throw new IllegalStateException("Failed to instantiate RequestUpgradeStrategy: " + className, ex);
}
}
/**
* Use this property to configure the list of supported sub-protocols.
* The first configured sub-protocol that matches a client-requested sub-protocol
* is accepted. If there are no matches the response will not contain a
* {@literal Sec-WebSocket-Protocol} header.
* <p>Note that if the WebSocketHandler passed in at runtime is an instance of
* {@link SubProtocolCapable} then there is not need to explicitly configure
* this property. That is certainly the case with the built-in STOMP over
* WebSocket support. Therefore this property should be configured explicitly
* only if the WebSocketHandler does not implement {@code SubProtocolCapable}.
*/
public void setSupportedProtocols(String... protocols) {
this.supportedProtocols.clear();
for (String protocol : protocols) {
this.supportedProtocols.add(protocol.toLowerCase());
}
}
/**
* Return the list of supported sub-protocols.
*/
public String[] getSupportedProtocols() {
return this.supportedProtocols.toArray(new String[this.supportedProtocols.size()]);
}
@Override
public boolean isRunning() {
return this.running;
}
@Override
public void start() {
if (!isRunning()) {
this.running = true;
if (this.requestUpgradeStrategy instanceof Lifecycle) {
((Lifecycle) this.requestUpgradeStrategy).start();
}
}
}
@Override
public void stop() {
if (isRunning()) {
this.running = false;
if (this.requestUpgradeStrategy instanceof Lifecycle) {
((Lifecycle) this.requestUpgradeStrategy).stop();
}
}
super(requestUpgradeStrategy);
}
@Override
public final boolean doHandshake(ServerHttpRequest request, ServerHttpResponse response,
WebSocketHandler wsHandler, Map<String, Object> attributes) throws HandshakeFailureException {
WebSocketHttpHeaders headers = new WebSocketHttpHeaders(request.getHeaders());
if (logger.isTraceEnabled()) {
logger.trace("Processing request " + request.getURI() + " with headers=" + headers);
}
try {
if (!HttpMethod.GET.equals(request.getMethod())) {
response.setStatusCode(HttpStatus.METHOD_NOT_ALLOWED);
response.getHeaders().setAllow(Collections.singleton(HttpMethod.GET));
if (logger.isErrorEnabled()) {
logger.error("Handshake failed due to unexpected HTTP method: " + request.getMethod());
}
return false;
}
if (!"WebSocket".equalsIgnoreCase(headers.getUpgrade())) {
handleInvalidUpgradeHeader(request, response);
return false;
}
if (!headers.getConnection().contains("Upgrade") && !headers.getConnection().contains("upgrade")) {
handleInvalidConnectHeader(request, response);
return false;
}
if (!isWebSocketVersionSupported(headers)) {
handleWebSocketVersionNotSupported(request, response);
return false;
}
if (!isValidOrigin(request)) {
response.setStatusCode(HttpStatus.FORBIDDEN);
return false;
}
String wsKey = headers.getSecWebSocketKey();
if (wsKey == null) {
if (logger.isErrorEnabled()) {
logger.error("Missing \"Sec-WebSocket-Key\" header");
}
response.setStatusCode(HttpStatus.BAD_REQUEST);
return false;
}
}
catch (IOException ex) {
throw new HandshakeFailureException(
"Response update failed during upgrade to WebSocket, uri=" + request.getURI(), ex);
}
String subProtocol = selectProtocol(headers.getSecWebSocketProtocol(), wsHandler);
List<WebSocketExtension> requested = headers.getSecWebSocketExtensions();
List<WebSocketExtension> supported = this.requestUpgradeStrategy.getSupportedExtensions(request);
List<WebSocketExtension> extensions = filterRequestedExtensions(request, requested, supported);
Principal user = determineUser(request, wsHandler, attributes);
if (logger.isTraceEnabled()) {
logger.trace("Upgrading to WebSocket, subProtocol=" + subProtocol + ", extensions=" + extensions);
}
this.requestUpgradeStrategy.upgrade(request, response, subProtocol, extensions, user, wsHandler, attributes);
return true;
}
protected void handleInvalidUpgradeHeader(ServerHttpRequest request, ServerHttpResponse response) throws IOException {
if (logger.isErrorEnabled()) {
logger.error("Handshake failed due to invalid Upgrade header: " + request.getHeaders().getUpgrade());
}
response.setStatusCode(HttpStatus.BAD_REQUEST);
response.getBody().write("Can \"Upgrade\" only to \"WebSocket\".".getBytes(UTF8_CHARSET));
}
protected void handleInvalidConnectHeader(ServerHttpRequest request, ServerHttpResponse response) throws IOException {
if (logger.isErrorEnabled()) {
logger.error("Handshake failed due to invalid Connection header " + request.getHeaders().getConnection());
}
response.setStatusCode(HttpStatus.BAD_REQUEST);
response.getBody().write("\"Connection\" must be \"upgrade\".".getBytes(UTF8_CHARSET));
}
protected boolean isWebSocketVersionSupported(WebSocketHttpHeaders httpHeaders) {
String version = httpHeaders.getSecWebSocketVersion();
String[] supportedVersions = getSupportedVersions();
for (String supportedVersion : supportedVersions) {
if (supportedVersion.trim().equals(version)) {
return true;
}
}
return false;
}
protected String[] getSupportedVersions() {
return this.requestUpgradeStrategy.getSupportedVersions();
}
protected void handleWebSocketVersionNotSupported(ServerHttpRequest request, ServerHttpResponse response) {
if (logger.isErrorEnabled()) {
String version = request.getHeaders().getFirst("Sec-WebSocket-Version");
logger.error("Handshake failed due to unsupported WebSocket version: " + version +
". Supported versions: " + Arrays.toString(getSupportedVersions()));
}
response.setStatusCode(HttpStatus.UPGRADE_REQUIRED);
response.getHeaders().put(WebSocketHttpHeaders.SEC_WEBSOCKET_VERSION,
Arrays.asList(StringUtils.arrayToCommaDelimitedString(getSupportedVersions())));
}
/**
* Return whether the request {@code Origin} header value is valid or not.
* By default, all origins as considered as valid. Consider using an
* {@link OriginHandshakeInterceptor} for filtering origins if needed.
*/
protected boolean isValidOrigin(ServerHttpRequest request) {
return true;
}
/**
* Perform the sub-protocol negotiation based on requested and supported sub-protocols.
* For the list of supported sub-protocols, this method first checks if the target
* WebSocketHandler is a {@link SubProtocolCapable} and then also checks if any
* sub-protocols have been explicitly configured with
* {@link #setSupportedProtocols(String...)}.
* @param requestedProtocols the requested sub-protocols
* @param webSocketHandler the WebSocketHandler that will be used
* @return the selected protocols or {@code null}
* @see #determineHandlerSupportedProtocols(org.springframework.web.socket.WebSocketHandler)
*/
protected String selectProtocol(List<String> requestedProtocols, WebSocketHandler webSocketHandler) {
if (requestedProtocols != null) {
List<String> handlerProtocols = determineHandlerSupportedProtocols(webSocketHandler);
for (String protocol : requestedProtocols) {
if (handlerProtocols.contains(protocol.toLowerCase())) {
return protocol;
}
if (this.supportedProtocols.contains(protocol.toLowerCase())) {
return protocol;
}
}
public void setServletContext(ServletContext servletContext) {
RequestUpgradeStrategy strategy = getRequestUpgradeStrategy();
if (strategy instanceof ServletContextAware) {
((ServletContextAware) strategy).setServletContext(servletContext);
}
return null;
}
/**
* Determine the sub-protocols supported by the given WebSocketHandler by
* checking whether it is an instance of {@link SubProtocolCapable}.
* @param handler the handler to check
* @return a list of supported protocols, or an empty list if none available
*/
protected final List<String> determineHandlerSupportedProtocols(WebSocketHandler handler) {
WebSocketHandler handlerToCheck = WebSocketHandlerDecorator.unwrap(handler);
List<String> subProtocols = null;
if (handlerToCheck instanceof SubProtocolCapable) {
subProtocols = ((SubProtocolCapable) handlerToCheck).getSubProtocols();
}
return (subProtocols != null ? subProtocols : Collections.<String>emptyList());
}
/**
* Filter the list of requested WebSocket extensions.
* <p>As of 4.1 the default implementation of this method filters the list to
* leave only extensions that are both requested and supported.
* @param request the current request
* @param requestedExtensions the list of extensions requested by the client
* @param supportedExtensions the list of extensions supported by the server
* @return the selected extensions or an empty list
*/
protected List<WebSocketExtension> filterRequestedExtensions(ServerHttpRequest request,
List<WebSocketExtension> requestedExtensions, List<WebSocketExtension> supportedExtensions) {
List<WebSocketExtension> result = new ArrayList<WebSocketExtension>(requestedExtensions.size());
for (WebSocketExtension extension : requestedExtensions) {
if (supportedExtensions.contains(extension)) {
result.add(extension);
}
}
return result;
}
/**
* A method that can be used to associate a user with the WebSocket session
* in the process of being established. The default implementation calls
* {@link org.springframework.http.server.ServerHttpRequest#getPrincipal()}
* <p>Subclasses can provide custom logic for associating a user with a session,
* for example for assigning a name to anonymous users (i.e. not fully
* authenticated).
* @param request the handshake request
* @param wsHandler the WebSocket handler that will handle messages
* @param attributes handshake attributes to pass to the WebSocket session
* @return the user for the WebSocket session, or {@code null} if not available
*/
protected Principal determineUser(ServerHttpRequest request, WebSocketHandler wsHandler,
Map<String, Object> attributes) {
return request.getPrincipal();
}
}
......@@ -15,8 +15,11 @@
*/
package org.springframework.web.socket.server.support;
import javax.servlet.ServletContext;
import org.springframework.context.Lifecycle;
import org.springframework.context.SmartLifecycle;
import org.springframework.web.context.ServletContextAware;
import org.springframework.web.servlet.handler.SimpleUrlHandlerMapping;
/**
......@@ -33,6 +36,15 @@ public class WebSocketHandlerMapping extends SimpleUrlHandlerMapping implements
private volatile boolean running = false;
@Override
protected void initServletContext(ServletContext servletContext) {
for (Object handler : getUrlMap().values()) {
if (handler instanceof ServletContextAware) {
((ServletContextAware) handler).setServletContext(servletContext);
}
}
}
@Override
public boolean isAutoStartup() {
return true;
......
......@@ -21,6 +21,8 @@ import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import javax.servlet.ServletContext;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
......@@ -35,6 +37,7 @@ import org.springframework.http.server.ServletServerHttpRequest;
import org.springframework.http.server.ServletServerHttpResponse;
import org.springframework.util.Assert;
import org.springframework.web.HttpRequestHandler;
import org.springframework.web.context.ServletContextAware;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.handler.ExceptionWebSocketHandlerDecorator;
import org.springframework.web.socket.handler.LoggingWebSocketHandlerDecorator;
......@@ -53,7 +56,7 @@ import org.springframework.web.socket.server.HandshakeInterceptor;
* @author Rossen Stoyanchev
* @since 4.0
*/
public class WebSocketHttpRequestHandler implements HttpRequestHandler, Lifecycle {
public class WebSocketHttpRequestHandler implements HttpRequestHandler, Lifecycle, ServletContextAware {
private final Log logger = LogFactory.getLog(WebSocketHttpRequestHandler.class);
......@@ -109,6 +112,13 @@ public class WebSocketHttpRequestHandler implements HttpRequestHandler, Lifecycl
return this.interceptors;
}
@Override
public void setServletContext(ServletContext servletContext) {
if (this.handshakeHandler instanceof ServletContextAware) {
((ServletContextAware) this.handshakeHandler).setServletContext(servletContext);
}
}
@Override
public boolean isRunning() {
return this.running;
......
......@@ -18,6 +18,7 @@ package org.springframework.web.socket.sockjs.support;
import java.io.IOException;
import javax.servlet.ServletContext;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
......@@ -29,6 +30,7 @@ import org.springframework.http.server.ServletServerHttpRequest;
import org.springframework.http.server.ServletServerHttpResponse;
import org.springframework.util.Assert;
import org.springframework.web.HttpRequestHandler;
import org.springframework.web.context.ServletContextAware;
import org.springframework.web.cors.CorsConfiguration;
import org.springframework.web.cors.CorsConfigurationSource;
import org.springframework.web.servlet.HandlerMapping;
......@@ -47,7 +49,7 @@ import org.springframework.web.socket.sockjs.SockJsService;
* @since 4.0
*/
public class SockJsHttpRequestHandler
implements HttpRequestHandler, CorsConfigurationSource, Lifecycle {
implements HttpRequestHandler, CorsConfigurationSource, Lifecycle, ServletContextAware {
// No logging: HTTP transports too verbose and we don't know enough to log anything of value
......@@ -86,6 +88,13 @@ public class SockJsHttpRequestHandler
return this.webSocketHandler;
}
@Override
public void setServletContext(ServletContext servletContext) {
if (this.sockJsService instanceof ServletContextAware) {
((ServletContextAware) this.sockJsService).setServletContext(servletContext);
}
}
@Override
public boolean isRunning() {
return this.running;
......
......@@ -21,10 +21,13 @@ import java.util.Collection;
import java.util.LinkedHashSet;
import java.util.Set;
import javax.servlet.ServletContext;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.scheduling.TaskScheduler;
import org.springframework.web.context.ServletContextAware;
import org.springframework.web.socket.server.support.DefaultHandshakeHandler;
import org.springframework.web.socket.sockjs.transport.TransportHandler;
import org.springframework.web.socket.sockjs.transport.TransportHandlingSockJsService;
......@@ -37,7 +40,7 @@ import org.springframework.web.socket.sockjs.transport.TransportHandlingSockJsSe
* @author Juergen Hoeller
* @since 4.0
*/
public class DefaultSockJsService extends TransportHandlingSockJsService {
public class DefaultSockJsService extends TransportHandlingSockJsService implements ServletContextAware {
/**
* Create a DefaultSockJsService with default {@link TransportHandler handler} types.
......@@ -99,4 +102,12 @@ public class DefaultSockJsService extends TransportHandlingSockJsService {
return result;
}
@Override
public void setServletContext(ServletContext servletContext) {
for (TransportHandler handler : getTransportHandlers().values()) {
if (handler instanceof ServletContextAware) {
((ServletContextAware) handler).setServletContext(servletContext);
}
}
}
}
......@@ -18,10 +18,13 @@ package org.springframework.web.socket.sockjs.transport.handler;
import java.util.Map;
import javax.servlet.ServletContext;
import org.springframework.context.Lifecycle;
import org.springframework.http.server.ServerHttpRequest;
import org.springframework.http.server.ServerHttpResponse;
import org.springframework.util.Assert;
import org.springframework.web.context.ServletContextAware;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.server.HandshakeFailureException;
......@@ -46,7 +49,7 @@ import org.springframework.web.socket.sockjs.transport.session.WebSocketServerSo
* @since 4.0
*/
public class WebSocketTransportHandler extends AbstractTransportHandler
implements SockJsSessionFactory, HandshakeHandler, Lifecycle {
implements SockJsSessionFactory, HandshakeHandler, Lifecycle, ServletContextAware {
private final HandshakeHandler handshakeHandler;
......@@ -68,6 +71,13 @@ public class WebSocketTransportHandler extends AbstractTransportHandler
return this.handshakeHandler;
}
@Override
public void setServletContext(ServletContext servletContext) {
if (this.handshakeHandler instanceof ServletContextAware) {
((ServletContextAware) this.handshakeHandler).setServletContext(servletContext);
}
}
@Override
public boolean isRunning() {
return this.running;
......
......@@ -78,7 +78,6 @@ public abstract class AbstractWebSocketIntegrationTests {
this.wac = new AnnotationConfigWebApplicationContext();
this.wac.register(getAnnotatedConfigClasses());
this.wac.register(upgradeStrategyConfigTypes.get(this.server.getClass()));
this.wac.refresh();
if (this.webSocketClient instanceof Lifecycle) {
((Lifecycle) this.webSocketClient).start();
......
......@@ -22,6 +22,8 @@ import java.io.IOException;
import javax.servlet.Filter;
import org.apache.catalina.Context;
import org.apache.catalina.LifecycleEvent;
import org.apache.catalina.LifecycleListener;
import org.apache.catalina.connector.Connector;
import org.apache.catalina.startup.Tomcat;
import org.apache.coyote.http11.Http11NioProtocol;
......@@ -115,6 +117,12 @@ public class TomcatWebSocketTestServer implements WebSocketTestServer {
@Override
public void start() throws Exception {
this.tomcatServer.start();
this.context.addLifecycleListener(new LifecycleListener() {
@Override
public void lifecycleEvent(LifecycleEvent event) {
System.out.println(event.getType());
}
});
}
@Override
......
......@@ -79,11 +79,13 @@ public class WebSocketIntegrationTests extends AbstractWebSocketIntegrationTest
@Test
public void unsolicitedPongWithEmptyPayload() throws Exception {
TestWebSocketHandler serverHandler = this.wac.getBean(TestWebSocketHandler.class);
serverHandler.setWaitMessageCount(1);
String url = getWsBaseUrl() + "/ws";
WebSocketSession session = this.webSocketClient.doHandshake(new AbstractWebSocketHandler() {}, url).get();
TestWebSocketHandler serverHandler = this.wac.getBean(TestWebSocketHandler.class);
serverHandler.setWaitMessageCount(1);
session.sendMessage(new PongMessage());
serverHandler.await();
......
......@@ -60,6 +60,7 @@ import org.springframework.messaging.simp.user.UserRegistryMessageHandler;
import org.springframework.messaging.support.AbstractSubscribableChannel;
import org.springframework.messaging.support.ChannelInterceptor;
import org.springframework.messaging.support.ImmutableMessageChannelInterceptor;
import org.springframework.mock.web.test.MockServletContext;
import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;
import org.springframework.scheduling.concurrent.ThreadPoolTaskScheduler;
import org.springframework.util.MimeTypeUtils;
......@@ -453,6 +454,7 @@ public class MessageBrokerBeanDefinitionParserTests {
XmlBeanDefinitionReader reader = new XmlBeanDefinitionReader(this.appContext);
ClassPathResource resource = new ClassPathResource(fileName, MessageBrokerBeanDefinitionParserTests.class);
reader.loadBeanDefinitions(resource);
this.appContext.setServletContext(new MockServletContext());
this.appContext.refresh();
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册