From e083683f4fe9206609201bb39a60bbd8ee0c8a0f Mon Sep 17 00:00:00 2001 From: Rossen Stoyanchev Date: Mon, 22 Jun 2015 22:28:29 -0400 Subject: [PATCH] Update WebSocket support for Jetty 9.3 Issue: SPR-13140 --- build.gradle | 2 +- .../config/HandlersBeanDefinitionParser.java | 4 +- .../ServletWebSocketHandlerRegistry.java | 3 +- .../jetty/JettyRequestUpgradeStrategy.java | 25 +- .../support/AbstractHandshakeHandler.java | 399 ++++++++++++++++++ .../support/DefaultHandshakeHandler.java | 342 +-------------- .../support/WebSocketHandlerMapping.java | 12 + .../support/WebSocketHttpRequestHandler.java | 12 +- .../support/SockJsHttpRequestHandler.java | 11 +- .../handler/DefaultSockJsService.java | 13 +- .../handler/WebSocketTransportHandler.java | 12 +- .../AbstractWebSocketIntegrationTests.java | 1 - .../web/socket/TomcatWebSocketTestServer.java | 8 + .../web/socket/WebSocketIntegrationTests.java | 6 +- ...essageBrokerBeanDefinitionParserTests.java | 2 + 15 files changed, 504 insertions(+), 348 deletions(-) create mode 100644 spring-websocket/src/main/java/org/springframework/web/socket/server/support/AbstractHandshakeHandler.java diff --git a/build.gradle b/build.gradle index 057e3bd83c..a0a60eb8bb 100644 --- a/build.gradle +++ b/build.gradle @@ -46,7 +46,7 @@ configure(allprojects) { project -> ext.jackson2Version = "2.6.0-rc2" // to be upgraded to 2.6 final in time for Spring Framework 4.2 GA ext.jasperreportsVersion = "6.1.0" ext.javamailVersion = "1.5.3" - ext.jettyVersion = "9.2.11.v20150529" + ext.jettyVersion = "9.3.0.v20150612" ext.jodaVersion = "2.8.1" ext.jrubyVersion = "1.7.20" ext.jtaVersion = "1.2" diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/config/HandlersBeanDefinitionParser.java b/spring-websocket/src/main/java/org/springframework/web/socket/config/HandlersBeanDefinitionParser.java index 1f4e7f6587..deced0d005 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/config/HandlersBeanDefinitionParser.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/config/HandlersBeanDefinitionParser.java @@ -33,8 +33,8 @@ import org.springframework.beans.factory.xml.BeanDefinitionParser; import org.springframework.beans.factory.xml.ParserContext; import org.springframework.util.StringUtils; import org.springframework.util.xml.DomUtils; -import org.springframework.web.servlet.handler.SimpleUrlHandlerMapping; import org.springframework.web.socket.server.support.OriginHandshakeInterceptor; +import org.springframework.web.socket.server.support.WebSocketHandlerMapping; import org.springframework.web.socket.server.support.WebSocketHttpRequestHandler; import org.springframework.web.socket.sockjs.support.SockJsHttpRequestHandler; @@ -64,7 +64,7 @@ class HandlersBeanDefinitionParser implements BeanDefinitionParser { String orderAttribute = element.getAttribute("order"); int order = orderAttribute.isEmpty() ? DEFAULT_MAPPING_ORDER : Integer.valueOf(orderAttribute); - RootBeanDefinition handlerMappingDef = new RootBeanDefinition(SimpleUrlHandlerMapping.class); + RootBeanDefinition handlerMappingDef = new RootBeanDefinition(WebSocketHandlerMapping.class); handlerMappingDef.setSource(source); handlerMappingDef.setRole(BeanDefinition.ROLE_INFRASTRUCTURE); handlerMappingDef.getPropertyValues().add("order", order); diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/ServletWebSocketHandlerRegistry.java b/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/ServletWebSocketHandlerRegistry.java index 0485f9ce1c..b14341b265 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/ServletWebSocketHandlerRegistry.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/ServletWebSocketHandlerRegistry.java @@ -29,6 +29,7 @@ import org.springframework.web.servlet.HandlerMapping; import org.springframework.web.servlet.handler.AbstractHandlerMapping; import org.springframework.web.servlet.handler.SimpleUrlHandlerMapping; import org.springframework.web.socket.WebSocketHandler; +import org.springframework.web.socket.server.support.WebSocketHandlerMapping; import org.springframework.web.util.UrlPathHelper; /** @@ -101,7 +102,7 @@ public class ServletWebSocketHandlerRegistry implements WebSocketHandlerRegistry } } } - SimpleUrlHandlerMapping hm = new SimpleUrlHandlerMapping(); + WebSocketHandlerMapping hm = new WebSocketHandlerMapping(); hm.setUrlMap(urlMap); hm.setOrder(this.order); if (this.urlPathHelper != null) { diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/server/jetty/JettyRequestUpgradeStrategy.java b/spring-websocket/src/main/java/org/springframework/web/socket/server/jetty/JettyRequestUpgradeStrategy.java index d9d5c17523..0b94597ffb 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/server/jetty/JettyRequestUpgradeStrategy.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/server/jetty/JettyRequestUpgradeStrategy.java @@ -17,11 +17,13 @@ package org.springframework.web.socket.server.jetty; import java.io.IOException; +import java.lang.reflect.Method; import java.security.Principal; import java.util.ArrayList; import java.util.List; import java.util.Map; +import javax.servlet.ServletContext; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; @@ -42,7 +44,9 @@ import org.springframework.http.server.ServerHttpResponse; import org.springframework.http.server.ServletServerHttpRequest; import org.springframework.http.server.ServletServerHttpResponse; import org.springframework.util.Assert; +import org.springframework.util.ClassUtils; import org.springframework.util.CollectionUtils; +import org.springframework.web.context.ServletContextAware; import org.springframework.web.socket.WebSocketExtension; import org.springframework.web.socket.WebSocketHandler; import org.springframework.web.socket.adapter.jetty.JettyWebSocketHandlerAdapter; @@ -59,7 +63,11 @@ import org.springframework.web.socket.server.RequestUpgradeStrategy; * @author Rossen Stoyanchev * @since 4.0 */ -public class JettyRequestUpgradeStrategy implements RequestUpgradeStrategy, Lifecycle { +public class JettyRequestUpgradeStrategy implements RequestUpgradeStrategy, Lifecycle, ServletContextAware { + + // Pre-Jetty 9.3 init method without ServletContext + private static final Method webSocketFactoryInitMethod = + ClassUtils.getMethodIfAvailable(WebSocketServerFactory.class, "init"); private static final ThreadLocal wsContainerHolder = new NamedThreadLocal("WebSocket Handler Container"); @@ -69,6 +77,8 @@ public class JettyRequestUpgradeStrategy implements RequestUpgradeStrategy, Life private volatile List supportedExtensions; + private ServletContext servletContext; + private volatile boolean running = false; @@ -94,7 +104,6 @@ public class JettyRequestUpgradeStrategy implements RequestUpgradeStrategy, Life // Cast to avoid infinite recursion return createWebSocket((UpgradeRequest) request, (UpgradeResponse) response); } - // For Jetty 9.0.x public Object createWebSocket(UpgradeRequest request, UpgradeResponse response) { WebSocketHandlerContainer container = wsContainerHolder.get(); @@ -128,6 +137,11 @@ public class JettyRequestUpgradeStrategy implements RequestUpgradeStrategy, Life return result; } + @Override + public void setServletContext(ServletContext servletContext) { + this.servletContext = servletContext; + } + @Override public boolean isRunning() { return this.running; @@ -139,7 +153,12 @@ public class JettyRequestUpgradeStrategy implements RequestUpgradeStrategy, Life if (!isRunning()) { this.running = true; try { - this.factory.init(); + if (webSocketFactoryInitMethod != null) { + webSocketFactoryInitMethod.invoke(this.factory); + } + else { + this.factory.init(this.servletContext); + } } catch (Exception ex) { throw new IllegalStateException("Unable to initialize Jetty WebSocketServerFactory", ex); diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/server/support/AbstractHandshakeHandler.java b/spring-websocket/src/main/java/org/springframework/web/socket/server/support/AbstractHandshakeHandler.java new file mode 100644 index 0000000000..85096911ff --- /dev/null +++ b/spring-websocket/src/main/java/org/springframework/web/socket/server/support/AbstractHandshakeHandler.java @@ -0,0 +1,399 @@ +/* + * Copyright 2002-2015 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.socket.server.support; + +import java.io.IOException; +import java.nio.charset.Charset; +import java.security.Principal; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Map; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; + +import org.springframework.context.Lifecycle; +import org.springframework.http.HttpMethod; +import org.springframework.http.HttpStatus; +import org.springframework.http.server.ServerHttpRequest; +import org.springframework.http.server.ServerHttpResponse; +import org.springframework.util.Assert; +import org.springframework.util.ClassUtils; +import org.springframework.util.StringUtils; +import org.springframework.web.socket.SubProtocolCapable; +import org.springframework.web.socket.WebSocketExtension; +import org.springframework.web.socket.WebSocketHandler; +import org.springframework.web.socket.WebSocketHttpHeaders; +import org.springframework.web.socket.handler.WebSocketHandlerDecorator; +import org.springframework.web.socket.server.HandshakeFailureException; +import org.springframework.web.socket.server.HandshakeHandler; +import org.springframework.web.socket.server.RequestUpgradeStrategy; + +/** + * A base class to use for {@link HandshakeHandler} implementations. + * Performs initial validation of the WebSocket handshake request -- possibly rejecting it + * through the appropriate HTTP status code -- while also allowing sub-classes to override + * various parts of the negotiation process (e.g. origin validation, sub-protocol negotiation, + * extensions negotiation, etc). + * + *

If the negotiation succeeds, the actual upgrade is delegated to a server-specific + * {@link RequestUpgradeStrategy}, which will update + * the response as necessary and initialize the WebSocket. Currently supported servers are + * Tomcat 7 and 8, Jetty 9, and GlassFish 4. + * + * @author Rossen Stoyanchev + * @since 4.0 + */ +public abstract class AbstractHandshakeHandler implements HandshakeHandler, Lifecycle { + + private static final Charset UTF8_CHARSET = Charset.forName("UTF-8"); + + + private static final boolean jettyWsPresent = ClassUtils.isPresent( + "org.eclipse.jetty.websocket.server.WebSocketServerFactory", AbstractHandshakeHandler.class.getClassLoader()); + + private static final boolean tomcatWsPresent = ClassUtils.isPresent( + "org.apache.tomcat.websocket.server.WsHttpUpgradeHandler", AbstractHandshakeHandler.class.getClassLoader()); + + private static final boolean undertowWsPresent = ClassUtils.isPresent( + "io.undertow.websockets.jsr.ServerWebSocketContainer", AbstractHandshakeHandler.class.getClassLoader()); + + private static final boolean glassFishWsPresent = ClassUtils.isPresent( + "org.glassfish.tyrus.servlet.TyrusHttpUpgradeHandler", AbstractHandshakeHandler.class.getClassLoader()); + + private static final boolean webLogicWsPresent = ClassUtils.isPresent( + "weblogic.websocket.tyrus.TyrusServletWriter", AbstractHandshakeHandler.class.getClassLoader()); + + + protected final Log logger = LogFactory.getLog(getClass()); + + private final RequestUpgradeStrategy requestUpgradeStrategy; + + private final List supportedProtocols = new ArrayList(); + + private volatile boolean running = false; + + + /** + * Default constructor that auto-detects and instantiates a + * {@link RequestUpgradeStrategy} suitable for the runtime container. + * @throws IllegalStateException if no {@link RequestUpgradeStrategy} can be found. + */ + protected AbstractHandshakeHandler() { + this(initRequestUpgradeStrategy()); + } + + /** + * A constructor that accepts a runtime-specific {@link RequestUpgradeStrategy}. + * @param requestUpgradeStrategy the upgrade strategy to use + */ + protected AbstractHandshakeHandler(RequestUpgradeStrategy requestUpgradeStrategy) { + Assert.notNull(requestUpgradeStrategy, "RequestUpgradeStrategy must not be null"); + this.requestUpgradeStrategy = requestUpgradeStrategy; + } + + + private static RequestUpgradeStrategy initRequestUpgradeStrategy() { + String className; + if (tomcatWsPresent) { + className = "org.springframework.web.socket.server.standard.TomcatRequestUpgradeStrategy"; + } + else if (jettyWsPresent) { + className = "org.springframework.web.socket.server.jetty.JettyRequestUpgradeStrategy"; + } + else if (undertowWsPresent) { + className = "org.springframework.web.socket.server.standard.UndertowRequestUpgradeStrategy"; + } + else if (glassFishWsPresent) { + className = "org.springframework.web.socket.server.standard.GlassFishRequestUpgradeStrategy"; + } + else if (webLogicWsPresent) { + className = "org.springframework.web.socket.server.standard.WebLogicRequestUpgradeStrategy"; + } + else { + throw new IllegalStateException("No suitable default RequestUpgradeStrategy found"); + } + try { + Class clazz = ClassUtils.forName(className, AbstractHandshakeHandler.class.getClassLoader()); + return (RequestUpgradeStrategy) clazz.newInstance(); + } + catch (Throwable ex) { + throw new IllegalStateException("Failed to instantiate RequestUpgradeStrategy: " + className, ex); + } + } + + + /** + * Return the {@link RequestUpgradeStrategy} for WebSocket requests. + */ + public RequestUpgradeStrategy getRequestUpgradeStrategy() { + return this.requestUpgradeStrategy; + } + + /** + * Use this property to configure the list of supported sub-protocols. + * The first configured sub-protocol that matches a client-requested sub-protocol + * is accepted. If there are no matches the response will not contain a + * {@literal Sec-WebSocket-Protocol} header. + *

Note that if the WebSocketHandler passed in at runtime is an instance of + * {@link SubProtocolCapable} then there is not need to explicitly configure + * this property. That is certainly the case with the built-in STOMP over + * WebSocket support. Therefore this property should be configured explicitly + * only if the WebSocketHandler does not implement {@code SubProtocolCapable}. + */ + public void setSupportedProtocols(String... protocols) { + this.supportedProtocols.clear(); + for (String protocol : protocols) { + this.supportedProtocols.add(protocol.toLowerCase()); + } + } + + /** + * Return the list of supported sub-protocols. + */ + public String[] getSupportedProtocols() { + return this.supportedProtocols.toArray(new String[this.supportedProtocols.size()]); + } + + @Override + public boolean isRunning() { + return this.running; + } + + @Override + public void start() { + if (!isRunning()) { + this.running = true; + doStart(); + } + } + + protected void doStart() { + if (this.requestUpgradeStrategy instanceof Lifecycle) { + ((Lifecycle) this.requestUpgradeStrategy).start(); + } + } + + @Override + public void stop() { + if (isRunning()) { + this.running = false; + doStop(); + } + } + + protected void doStop() { + if (this.requestUpgradeStrategy instanceof Lifecycle) { + ((Lifecycle) this.requestUpgradeStrategy).stop(); + } + } + + + @Override + public final boolean doHandshake(ServerHttpRequest request, ServerHttpResponse response, + WebSocketHandler wsHandler, Map attributes) throws HandshakeFailureException { + + WebSocketHttpHeaders headers = new WebSocketHttpHeaders(request.getHeaders()); + if (logger.isTraceEnabled()) { + logger.trace("Processing request " + request.getURI() + " with headers=" + headers); + } + try { + if (!HttpMethod.GET.equals(request.getMethod())) { + response.setStatusCode(HttpStatus.METHOD_NOT_ALLOWED); + response.getHeaders().setAllow(Collections.singleton(HttpMethod.GET)); + if (logger.isErrorEnabled()) { + logger.error("Handshake failed due to unexpected HTTP method: " + request.getMethod()); + } + return false; + } + if (!"WebSocket".equalsIgnoreCase(headers.getUpgrade())) { + handleInvalidUpgradeHeader(request, response); + return false; + } + if (!headers.getConnection().contains("Upgrade") && !headers.getConnection().contains("upgrade")) { + handleInvalidConnectHeader(request, response); + return false; + } + if (!isWebSocketVersionSupported(headers)) { + handleWebSocketVersionNotSupported(request, response); + return false; + } + if (!isValidOrigin(request)) { + response.setStatusCode(HttpStatus.FORBIDDEN); + return false; + } + String wsKey = headers.getSecWebSocketKey(); + if (wsKey == null) { + if (logger.isErrorEnabled()) { + logger.error("Missing \"Sec-WebSocket-Key\" header"); + } + response.setStatusCode(HttpStatus.BAD_REQUEST); + return false; + } + } + catch (IOException ex) { + throw new HandshakeFailureException( + "Response update failed during upgrade to WebSocket, uri=" + request.getURI(), ex); + } + + String subProtocol = selectProtocol(headers.getSecWebSocketProtocol(), wsHandler); + List requested = headers.getSecWebSocketExtensions(); + List supported = this.requestUpgradeStrategy.getSupportedExtensions(request); + List extensions = filterRequestedExtensions(request, requested, supported); + Principal user = determineUser(request, wsHandler, attributes); + + if (logger.isTraceEnabled()) { + logger.trace("Upgrading to WebSocket, subProtocol=" + subProtocol + ", extensions=" + extensions); + } + this.requestUpgradeStrategy.upgrade(request, response, subProtocol, extensions, user, wsHandler, attributes); + return true; + } + + protected void handleInvalidUpgradeHeader(ServerHttpRequest request, ServerHttpResponse response) throws IOException { + if (logger.isErrorEnabled()) { + logger.error("Handshake failed due to invalid Upgrade header: " + request.getHeaders().getUpgrade()); + } + response.setStatusCode(HttpStatus.BAD_REQUEST); + response.getBody().write("Can \"Upgrade\" only to \"WebSocket\".".getBytes(UTF8_CHARSET)); + } + + protected void handleInvalidConnectHeader(ServerHttpRequest request, ServerHttpResponse response) throws IOException { + if (logger.isErrorEnabled()) { + logger.error("Handshake failed due to invalid Connection header " + request.getHeaders().getConnection()); + } + response.setStatusCode(HttpStatus.BAD_REQUEST); + response.getBody().write("\"Connection\" must be \"upgrade\".".getBytes(UTF8_CHARSET)); + } + + protected boolean isWebSocketVersionSupported(WebSocketHttpHeaders httpHeaders) { + String version = httpHeaders.getSecWebSocketVersion(); + String[] supportedVersions = getSupportedVersions(); + for (String supportedVersion : supportedVersions) { + if (supportedVersion.trim().equals(version)) { + return true; + } + } + return false; + } + + protected String[] getSupportedVersions() { + return this.requestUpgradeStrategy.getSupportedVersions(); + } + + protected void handleWebSocketVersionNotSupported(ServerHttpRequest request, ServerHttpResponse response) { + if (logger.isErrorEnabled()) { + String version = request.getHeaders().getFirst("Sec-WebSocket-Version"); + logger.error("Handshake failed due to unsupported WebSocket version: " + version + + ". Supported versions: " + Arrays.toString(getSupportedVersions())); + } + response.setStatusCode(HttpStatus.UPGRADE_REQUIRED); + response.getHeaders().put(WebSocketHttpHeaders.SEC_WEBSOCKET_VERSION, + Arrays.asList(StringUtils.arrayToCommaDelimitedString(getSupportedVersions()))); + } + + /** + * Return whether the request {@code Origin} header value is valid or not. + * By default, all origins as considered as valid. Consider using an + * {@link OriginHandshakeInterceptor} for filtering origins if needed. + */ + protected boolean isValidOrigin(ServerHttpRequest request) { + return true; + } + + /** + * Perform the sub-protocol negotiation based on requested and supported sub-protocols. + * For the list of supported sub-protocols, this method first checks if the target + * WebSocketHandler is a {@link SubProtocolCapable} and then also checks if any + * sub-protocols have been explicitly configured with + * {@link #setSupportedProtocols(String...)}. + * @param requestedProtocols the requested sub-protocols + * @param webSocketHandler the WebSocketHandler that will be used + * @return the selected protocols or {@code null} + * @see #determineHandlerSupportedProtocols(WebSocketHandler) + */ + protected String selectProtocol(List requestedProtocols, WebSocketHandler webSocketHandler) { + if (requestedProtocols != null) { + List handlerProtocols = determineHandlerSupportedProtocols(webSocketHandler); + for (String protocol : requestedProtocols) { + if (handlerProtocols.contains(protocol.toLowerCase())) { + return protocol; + } + if (this.supportedProtocols.contains(protocol.toLowerCase())) { + return protocol; + } + } + } + return null; + } + + /** + * Determine the sub-protocols supported by the given WebSocketHandler by + * checking whether it is an instance of {@link SubProtocolCapable}. + * @param handler the handler to check + * @return a list of supported protocols, or an empty list if none available + */ + protected final List determineHandlerSupportedProtocols(WebSocketHandler handler) { + WebSocketHandler handlerToCheck = WebSocketHandlerDecorator.unwrap(handler); + List subProtocols = null; + if (handlerToCheck instanceof SubProtocolCapable) { + subProtocols = ((SubProtocolCapable) handlerToCheck).getSubProtocols(); + } + return (subProtocols != null ? subProtocols : Collections.emptyList()); + } + + /** + * Filter the list of requested WebSocket extensions. + *

As of 4.1 the default implementation of this method filters the list to + * leave only extensions that are both requested and supported. + * @param request the current request + * @param requestedExtensions the list of extensions requested by the client + * @param supportedExtensions the list of extensions supported by the server + * @return the selected extensions or an empty list + */ + protected List filterRequestedExtensions(ServerHttpRequest request, + List requestedExtensions, List supportedExtensions) { + + List result = new ArrayList(requestedExtensions.size()); + for (WebSocketExtension extension : requestedExtensions) { + if (supportedExtensions.contains(extension)) { + result.add(extension); + } + } + return result; + } + + /** + * A method that can be used to associate a user with the WebSocket session + * in the process of being established. The default implementation calls + * {@link ServerHttpRequest#getPrincipal()} + *

Subclasses can provide custom logic for associating a user with a session, + * for example for assigning a name to anonymous users (i.e. not fully + * authenticated). + * @param request the handshake request + * @param wsHandler the WebSocket handler that will handle messages + * @param attributes handshake attributes to pass to the WebSocket session + * @return the user for the WebSocket session, or {@code null} if not available + */ + protected Principal determineUser(ServerHttpRequest request, WebSocketHandler wsHandler, + Map attributes) { + + return request.getPrincipal(); + } + +} diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/server/support/DefaultHandshakeHandler.java b/spring-websocket/src/main/java/org/springframework/web/socket/server/support/DefaultHandshakeHandler.java index 1af82bc9e7..f4fd9358d1 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/server/support/DefaultHandshakeHandler.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/server/support/DefaultHandshakeHandler.java @@ -16,33 +16,9 @@ package org.springframework.web.socket.server.support; -import java.io.IOException; -import java.nio.charset.Charset; -import java.security.Principal; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collections; -import java.util.List; -import java.util.Map; +import javax.servlet.ServletContext; -import org.apache.commons.logging.Log; -import org.apache.commons.logging.LogFactory; - -import org.springframework.context.Lifecycle; -import org.springframework.http.HttpMethod; -import org.springframework.http.HttpStatus; -import org.springframework.http.server.ServerHttpRequest; -import org.springframework.http.server.ServerHttpResponse; -import org.springframework.util.Assert; -import org.springframework.util.ClassUtils; -import org.springframework.util.StringUtils; -import org.springframework.web.socket.SubProtocolCapable; -import org.springframework.web.socket.WebSocketExtension; -import org.springframework.web.socket.WebSocketHandler; -import org.springframework.web.socket.WebSocketHttpHeaders; -import org.springframework.web.socket.handler.WebSocketHandlerDecorator; -import org.springframework.web.socket.server.HandshakeFailureException; -import org.springframework.web.socket.server.HandshakeHandler; +import org.springframework.web.context.ServletContextAware; import org.springframework.web.socket.server.RequestUpgradeStrategy; /** @@ -60,325 +36,23 @@ import org.springframework.web.socket.server.RequestUpgradeStrategy; * @author Rossen Stoyanchev * @since 4.0 */ -public class DefaultHandshakeHandler implements HandshakeHandler, Lifecycle { - - private static final Charset UTF8_CHARSET = Charset.forName("UTF-8"); - - - private static final boolean jettyWsPresent = ClassUtils.isPresent( - "org.eclipse.jetty.websocket.server.WebSocketServerFactory", DefaultHandshakeHandler.class.getClassLoader()); - - private static final boolean tomcatWsPresent = ClassUtils.isPresent( - "org.apache.tomcat.websocket.server.WsHttpUpgradeHandler", DefaultHandshakeHandler.class.getClassLoader()); - - private static final boolean undertowWsPresent = ClassUtils.isPresent( - "io.undertow.websockets.jsr.ServerWebSocketContainer", DefaultHandshakeHandler.class.getClassLoader()); - - private static final boolean glassFishWsPresent = ClassUtils.isPresent( - "org.glassfish.tyrus.servlet.TyrusHttpUpgradeHandler", DefaultHandshakeHandler.class.getClassLoader()); - - private static final boolean webLogicWsPresent = ClassUtils.isPresent( - "weblogic.websocket.tyrus.TyrusServletWriter", DefaultHandshakeHandler.class.getClassLoader()); - - - protected final Log logger = LogFactory.getLog(getClass()); +public class DefaultHandshakeHandler extends AbstractHandshakeHandler implements ServletContextAware { - private final RequestUpgradeStrategy requestUpgradeStrategy; - private final List supportedProtocols = new ArrayList(); - - private volatile boolean running = false; - - - /** - * Default constructor that autodetects and instantiates a - * {@link RequestUpgradeStrategy} suitable for the runtime container. - * @throws IllegalStateException if no {@link RequestUpgradeStrategy} can be found. - */ public DefaultHandshakeHandler() { - this(initRequestUpgradeStrategy()); } - /** - * A constructor that accepts a runtime-specific {@link RequestUpgradeStrategy}. - * @param requestUpgradeStrategy the upgrade strategy to use - */ public DefaultHandshakeHandler(RequestUpgradeStrategy requestUpgradeStrategy) { - Assert.notNull(requestUpgradeStrategy, "RequestUpgradeStrategy must not be null"); - this.requestUpgradeStrategy = requestUpgradeStrategy; - } - - - private static RequestUpgradeStrategy initRequestUpgradeStrategy() { - String className; - if (jettyWsPresent) { - className = "org.springframework.web.socket.server.jetty.JettyRequestUpgradeStrategy"; - } - else if (tomcatWsPresent) { - className = "org.springframework.web.socket.server.standard.TomcatRequestUpgradeStrategy"; - } - else if (undertowWsPresent) { - className = "org.springframework.web.socket.server.standard.UndertowRequestUpgradeStrategy"; - } - else if (glassFishWsPresent) { - className = "org.springframework.web.socket.server.standard.GlassFishRequestUpgradeStrategy"; - } - else if (webLogicWsPresent) { - className = "org.springframework.web.socket.server.standard.WebLogicRequestUpgradeStrategy"; - } - else { - throw new IllegalStateException("No suitable default RequestUpgradeStrategy found"); - } - try { - Class clazz = ClassUtils.forName(className, DefaultHandshakeHandler.class.getClassLoader()); - return (RequestUpgradeStrategy) clazz.newInstance(); - } - catch (Throwable ex) { - throw new IllegalStateException("Failed to instantiate RequestUpgradeStrategy: " + className, ex); - } - } - - - /** - * Use this property to configure the list of supported sub-protocols. - * The first configured sub-protocol that matches a client-requested sub-protocol - * is accepted. If there are no matches the response will not contain a - * {@literal Sec-WebSocket-Protocol} header. - *

Note that if the WebSocketHandler passed in at runtime is an instance of - * {@link SubProtocolCapable} then there is not need to explicitly configure - * this property. That is certainly the case with the built-in STOMP over - * WebSocket support. Therefore this property should be configured explicitly - * only if the WebSocketHandler does not implement {@code SubProtocolCapable}. - */ - public void setSupportedProtocols(String... protocols) { - this.supportedProtocols.clear(); - for (String protocol : protocols) { - this.supportedProtocols.add(protocol.toLowerCase()); - } - } - - /** - * Return the list of supported sub-protocols. - */ - public String[] getSupportedProtocols() { - return this.supportedProtocols.toArray(new String[this.supportedProtocols.size()]); - } - - @Override - public boolean isRunning() { - return this.running; - } - - @Override - public void start() { - if (!isRunning()) { - this.running = true; - if (this.requestUpgradeStrategy instanceof Lifecycle) { - ((Lifecycle) this.requestUpgradeStrategy).start(); - } - } - } - - @Override - public void stop() { - if (isRunning()) { - this.running = false; - if (this.requestUpgradeStrategy instanceof Lifecycle) { - ((Lifecycle) this.requestUpgradeStrategy).stop(); - } - } + super(requestUpgradeStrategy); } @Override - public final boolean doHandshake(ServerHttpRequest request, ServerHttpResponse response, - WebSocketHandler wsHandler, Map attributes) throws HandshakeFailureException { - - WebSocketHttpHeaders headers = new WebSocketHttpHeaders(request.getHeaders()); - if (logger.isTraceEnabled()) { - logger.trace("Processing request " + request.getURI() + " with headers=" + headers); - } - try { - if (!HttpMethod.GET.equals(request.getMethod())) { - response.setStatusCode(HttpStatus.METHOD_NOT_ALLOWED); - response.getHeaders().setAllow(Collections.singleton(HttpMethod.GET)); - if (logger.isErrorEnabled()) { - logger.error("Handshake failed due to unexpected HTTP method: " + request.getMethod()); - } - return false; - } - if (!"WebSocket".equalsIgnoreCase(headers.getUpgrade())) { - handleInvalidUpgradeHeader(request, response); - return false; - } - if (!headers.getConnection().contains("Upgrade") && !headers.getConnection().contains("upgrade")) { - handleInvalidConnectHeader(request, response); - return false; - } - if (!isWebSocketVersionSupported(headers)) { - handleWebSocketVersionNotSupported(request, response); - return false; - } - if (!isValidOrigin(request)) { - response.setStatusCode(HttpStatus.FORBIDDEN); - return false; - } - String wsKey = headers.getSecWebSocketKey(); - if (wsKey == null) { - if (logger.isErrorEnabled()) { - logger.error("Missing \"Sec-WebSocket-Key\" header"); - } - response.setStatusCode(HttpStatus.BAD_REQUEST); - return false; - } - } - catch (IOException ex) { - throw new HandshakeFailureException( - "Response update failed during upgrade to WebSocket, uri=" + request.getURI(), ex); - } - - String subProtocol = selectProtocol(headers.getSecWebSocketProtocol(), wsHandler); - List requested = headers.getSecWebSocketExtensions(); - List supported = this.requestUpgradeStrategy.getSupportedExtensions(request); - List extensions = filterRequestedExtensions(request, requested, supported); - Principal user = determineUser(request, wsHandler, attributes); - - if (logger.isTraceEnabled()) { - logger.trace("Upgrading to WebSocket, subProtocol=" + subProtocol + ", extensions=" + extensions); - } - this.requestUpgradeStrategy.upgrade(request, response, subProtocol, extensions, user, wsHandler, attributes); - return true; - } - - protected void handleInvalidUpgradeHeader(ServerHttpRequest request, ServerHttpResponse response) throws IOException { - if (logger.isErrorEnabled()) { - logger.error("Handshake failed due to invalid Upgrade header: " + request.getHeaders().getUpgrade()); - } - response.setStatusCode(HttpStatus.BAD_REQUEST); - response.getBody().write("Can \"Upgrade\" only to \"WebSocket\".".getBytes(UTF8_CHARSET)); - } - - protected void handleInvalidConnectHeader(ServerHttpRequest request, ServerHttpResponse response) throws IOException { - if (logger.isErrorEnabled()) { - logger.error("Handshake failed due to invalid Connection header " + request.getHeaders().getConnection()); - } - response.setStatusCode(HttpStatus.BAD_REQUEST); - response.getBody().write("\"Connection\" must be \"upgrade\".".getBytes(UTF8_CHARSET)); - } - - protected boolean isWebSocketVersionSupported(WebSocketHttpHeaders httpHeaders) { - String version = httpHeaders.getSecWebSocketVersion(); - String[] supportedVersions = getSupportedVersions(); - for (String supportedVersion : supportedVersions) { - if (supportedVersion.trim().equals(version)) { - return true; - } - } - return false; - } - - protected String[] getSupportedVersions() { - return this.requestUpgradeStrategy.getSupportedVersions(); - } - - protected void handleWebSocketVersionNotSupported(ServerHttpRequest request, ServerHttpResponse response) { - if (logger.isErrorEnabled()) { - String version = request.getHeaders().getFirst("Sec-WebSocket-Version"); - logger.error("Handshake failed due to unsupported WebSocket version: " + version + - ". Supported versions: " + Arrays.toString(getSupportedVersions())); - } - response.setStatusCode(HttpStatus.UPGRADE_REQUIRED); - response.getHeaders().put(WebSocketHttpHeaders.SEC_WEBSOCKET_VERSION, - Arrays.asList(StringUtils.arrayToCommaDelimitedString(getSupportedVersions()))); - } - - /** - * Return whether the request {@code Origin} header value is valid or not. - * By default, all origins as considered as valid. Consider using an - * {@link OriginHandshakeInterceptor} for filtering origins if needed. - */ - protected boolean isValidOrigin(ServerHttpRequest request) { - return true; - } - - /** - * Perform the sub-protocol negotiation based on requested and supported sub-protocols. - * For the list of supported sub-protocols, this method first checks if the target - * WebSocketHandler is a {@link SubProtocolCapable} and then also checks if any - * sub-protocols have been explicitly configured with - * {@link #setSupportedProtocols(String...)}. - * @param requestedProtocols the requested sub-protocols - * @param webSocketHandler the WebSocketHandler that will be used - * @return the selected protocols or {@code null} - * @see #determineHandlerSupportedProtocols(org.springframework.web.socket.WebSocketHandler) - */ - protected String selectProtocol(List requestedProtocols, WebSocketHandler webSocketHandler) { - if (requestedProtocols != null) { - List handlerProtocols = determineHandlerSupportedProtocols(webSocketHandler); - for (String protocol : requestedProtocols) { - if (handlerProtocols.contains(protocol.toLowerCase())) { - return protocol; - } - if (this.supportedProtocols.contains(protocol.toLowerCase())) { - return protocol; - } - } + public void setServletContext(ServletContext servletContext) { + RequestUpgradeStrategy strategy = getRequestUpgradeStrategy(); + if (strategy instanceof ServletContextAware) { + ((ServletContextAware) strategy).setServletContext(servletContext); } - return null; - } - - /** - * Determine the sub-protocols supported by the given WebSocketHandler by - * checking whether it is an instance of {@link SubProtocolCapable}. - * @param handler the handler to check - * @return a list of supported protocols, or an empty list if none available - */ - protected final List determineHandlerSupportedProtocols(WebSocketHandler handler) { - WebSocketHandler handlerToCheck = WebSocketHandlerDecorator.unwrap(handler); - List subProtocols = null; - if (handlerToCheck instanceof SubProtocolCapable) { - subProtocols = ((SubProtocolCapable) handlerToCheck).getSubProtocols(); - } - return (subProtocols != null ? subProtocols : Collections.emptyList()); - } - - /** - * Filter the list of requested WebSocket extensions. - *

As of 4.1 the default implementation of this method filters the list to - * leave only extensions that are both requested and supported. - * @param request the current request - * @param requestedExtensions the list of extensions requested by the client - * @param supportedExtensions the list of extensions supported by the server - * @return the selected extensions or an empty list - */ - protected List filterRequestedExtensions(ServerHttpRequest request, - List requestedExtensions, List supportedExtensions) { - - List result = new ArrayList(requestedExtensions.size()); - for (WebSocketExtension extension : requestedExtensions) { - if (supportedExtensions.contains(extension)) { - result.add(extension); - } - } - return result; - } - - /** - * A method that can be used to associate a user with the WebSocket session - * in the process of being established. The default implementation calls - * {@link org.springframework.http.server.ServerHttpRequest#getPrincipal()} - *

Subclasses can provide custom logic for associating a user with a session, - * for example for assigning a name to anonymous users (i.e. not fully - * authenticated). - * @param request the handshake request - * @param wsHandler the WebSocket handler that will handle messages - * @param attributes handshake attributes to pass to the WebSocket session - * @return the user for the WebSocket session, or {@code null} if not available - */ - protected Principal determineUser(ServerHttpRequest request, WebSocketHandler wsHandler, - Map attributes) { - - return request.getPrincipal(); } } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/server/support/WebSocketHandlerMapping.java b/spring-websocket/src/main/java/org/springframework/web/socket/server/support/WebSocketHandlerMapping.java index 5b6bbb693b..d5c9937c3d 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/server/support/WebSocketHandlerMapping.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/server/support/WebSocketHandlerMapping.java @@ -15,8 +15,11 @@ */ package org.springframework.web.socket.server.support; +import javax.servlet.ServletContext; + import org.springframework.context.Lifecycle; import org.springframework.context.SmartLifecycle; +import org.springframework.web.context.ServletContextAware; import org.springframework.web.servlet.handler.SimpleUrlHandlerMapping; /** @@ -33,6 +36,15 @@ public class WebSocketHandlerMapping extends SimpleUrlHandlerMapping implements private volatile boolean running = false; + @Override + protected void initServletContext(ServletContext servletContext) { + for (Object handler : getUrlMap().values()) { + if (handler instanceof ServletContextAware) { + ((ServletContextAware) handler).setServletContext(servletContext); + } + } + } + @Override public boolean isAutoStartup() { return true; diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/server/support/WebSocketHttpRequestHandler.java b/spring-websocket/src/main/java/org/springframework/web/socket/server/support/WebSocketHttpRequestHandler.java index f2f560c35e..ae814653d9 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/server/support/WebSocketHttpRequestHandler.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/server/support/WebSocketHttpRequestHandler.java @@ -21,6 +21,8 @@ import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; + +import javax.servlet.ServletContext; import javax.servlet.ServletException; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; @@ -35,6 +37,7 @@ import org.springframework.http.server.ServletServerHttpRequest; import org.springframework.http.server.ServletServerHttpResponse; import org.springframework.util.Assert; import org.springframework.web.HttpRequestHandler; +import org.springframework.web.context.ServletContextAware; import org.springframework.web.socket.WebSocketHandler; import org.springframework.web.socket.handler.ExceptionWebSocketHandlerDecorator; import org.springframework.web.socket.handler.LoggingWebSocketHandlerDecorator; @@ -53,7 +56,7 @@ import org.springframework.web.socket.server.HandshakeInterceptor; * @author Rossen Stoyanchev * @since 4.0 */ -public class WebSocketHttpRequestHandler implements HttpRequestHandler, Lifecycle { +public class WebSocketHttpRequestHandler implements HttpRequestHandler, Lifecycle, ServletContextAware { private final Log logger = LogFactory.getLog(WebSocketHttpRequestHandler.class); @@ -109,6 +112,13 @@ public class WebSocketHttpRequestHandler implements HttpRequestHandler, Lifecycl return this.interceptors; } + @Override + public void setServletContext(ServletContext servletContext) { + if (this.handshakeHandler instanceof ServletContextAware) { + ((ServletContextAware) this.handshakeHandler).setServletContext(servletContext); + } + } + @Override public boolean isRunning() { return this.running; diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/support/SockJsHttpRequestHandler.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/support/SockJsHttpRequestHandler.java index 6c8fec4074..d2ba85ec16 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/support/SockJsHttpRequestHandler.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/support/SockJsHttpRequestHandler.java @@ -18,6 +18,7 @@ package org.springframework.web.socket.sockjs.support; import java.io.IOException; +import javax.servlet.ServletContext; import javax.servlet.ServletException; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; @@ -29,6 +30,7 @@ import org.springframework.http.server.ServletServerHttpRequest; import org.springframework.http.server.ServletServerHttpResponse; import org.springframework.util.Assert; import org.springframework.web.HttpRequestHandler; +import org.springframework.web.context.ServletContextAware; import org.springframework.web.cors.CorsConfiguration; import org.springframework.web.cors.CorsConfigurationSource; import org.springframework.web.servlet.HandlerMapping; @@ -47,7 +49,7 @@ import org.springframework.web.socket.sockjs.SockJsService; * @since 4.0 */ public class SockJsHttpRequestHandler - implements HttpRequestHandler, CorsConfigurationSource, Lifecycle { + implements HttpRequestHandler, CorsConfigurationSource, Lifecycle, ServletContextAware { // No logging: HTTP transports too verbose and we don't know enough to log anything of value @@ -86,6 +88,13 @@ public class SockJsHttpRequestHandler return this.webSocketHandler; } + @Override + public void setServletContext(ServletContext servletContext) { + if (this.sockJsService instanceof ServletContextAware) { + ((ServletContextAware) this.sockJsService).setServletContext(servletContext); + } + } + @Override public boolean isRunning() { return this.running; diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/DefaultSockJsService.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/DefaultSockJsService.java index 6a1f1ed522..5f2200e05a 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/DefaultSockJsService.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/DefaultSockJsService.java @@ -21,10 +21,13 @@ import java.util.Collection; import java.util.LinkedHashSet; import java.util.Set; +import javax.servlet.ServletContext; + import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.springframework.scheduling.TaskScheduler; +import org.springframework.web.context.ServletContextAware; import org.springframework.web.socket.server.support.DefaultHandshakeHandler; import org.springframework.web.socket.sockjs.transport.TransportHandler; import org.springframework.web.socket.sockjs.transport.TransportHandlingSockJsService; @@ -37,7 +40,7 @@ import org.springframework.web.socket.sockjs.transport.TransportHandlingSockJsSe * @author Juergen Hoeller * @since 4.0 */ -public class DefaultSockJsService extends TransportHandlingSockJsService { +public class DefaultSockJsService extends TransportHandlingSockJsService implements ServletContextAware { /** * Create a DefaultSockJsService with default {@link TransportHandler handler} types. @@ -99,4 +102,12 @@ public class DefaultSockJsService extends TransportHandlingSockJsService { return result; } + @Override + public void setServletContext(ServletContext servletContext) { + for (TransportHandler handler : getTransportHandlers().values()) { + if (handler instanceof ServletContextAware) { + ((ServletContextAware) handler).setServletContext(servletContext); + } + } + } } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/WebSocketTransportHandler.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/WebSocketTransportHandler.java index ba70fdacba..d7316f4786 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/WebSocketTransportHandler.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/WebSocketTransportHandler.java @@ -18,10 +18,13 @@ package org.springframework.web.socket.sockjs.transport.handler; import java.util.Map; +import javax.servlet.ServletContext; + import org.springframework.context.Lifecycle; import org.springframework.http.server.ServerHttpRequest; import org.springframework.http.server.ServerHttpResponse; import org.springframework.util.Assert; +import org.springframework.web.context.ServletContextAware; import org.springframework.web.socket.CloseStatus; import org.springframework.web.socket.WebSocketHandler; import org.springframework.web.socket.server.HandshakeFailureException; @@ -46,7 +49,7 @@ import org.springframework.web.socket.sockjs.transport.session.WebSocketServerSo * @since 4.0 */ public class WebSocketTransportHandler extends AbstractTransportHandler - implements SockJsSessionFactory, HandshakeHandler, Lifecycle { + implements SockJsSessionFactory, HandshakeHandler, Lifecycle, ServletContextAware { private final HandshakeHandler handshakeHandler; @@ -68,6 +71,13 @@ public class WebSocketTransportHandler extends AbstractTransportHandler return this.handshakeHandler; } + @Override + public void setServletContext(ServletContext servletContext) { + if (this.handshakeHandler instanceof ServletContextAware) { + ((ServletContextAware) this.handshakeHandler).setServletContext(servletContext); + } + } + @Override public boolean isRunning() { return this.running; diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/AbstractWebSocketIntegrationTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/AbstractWebSocketIntegrationTests.java index 123d6c07a1..8b6d0b9316 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/AbstractWebSocketIntegrationTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/AbstractWebSocketIntegrationTests.java @@ -78,7 +78,6 @@ public abstract class AbstractWebSocketIntegrationTests { this.wac = new AnnotationConfigWebApplicationContext(); this.wac.register(getAnnotatedConfigClasses()); this.wac.register(upgradeStrategyConfigTypes.get(this.server.getClass())); - this.wac.refresh(); if (this.webSocketClient instanceof Lifecycle) { ((Lifecycle) this.webSocketClient).start(); diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/TomcatWebSocketTestServer.java b/spring-websocket/src/test/java/org/springframework/web/socket/TomcatWebSocketTestServer.java index 7c3c47cd12..4adc6aae77 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/TomcatWebSocketTestServer.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/TomcatWebSocketTestServer.java @@ -22,6 +22,8 @@ import java.io.IOException; import javax.servlet.Filter; import org.apache.catalina.Context; +import org.apache.catalina.LifecycleEvent; +import org.apache.catalina.LifecycleListener; import org.apache.catalina.connector.Connector; import org.apache.catalina.startup.Tomcat; import org.apache.coyote.http11.Http11NioProtocol; @@ -115,6 +117,12 @@ public class TomcatWebSocketTestServer implements WebSocketTestServer { @Override public void start() throws Exception { this.tomcatServer.start(); + this.context.addLifecycleListener(new LifecycleListener() { + @Override + public void lifecycleEvent(LifecycleEvent event) { + System.out.println(event.getType()); + } + }); } @Override diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/WebSocketIntegrationTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/WebSocketIntegrationTests.java index 76ef4a63a5..1d86998dae 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/WebSocketIntegrationTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/WebSocketIntegrationTests.java @@ -79,11 +79,13 @@ public class WebSocketIntegrationTests extends AbstractWebSocketIntegrationTest @Test public void unsolicitedPongWithEmptyPayload() throws Exception { - TestWebSocketHandler serverHandler = this.wac.getBean(TestWebSocketHandler.class); - serverHandler.setWaitMessageCount(1); String url = getWsBaseUrl() + "/ws"; WebSocketSession session = this.webSocketClient.doHandshake(new AbstractWebSocketHandler() {}, url).get(); + + TestWebSocketHandler serverHandler = this.wac.getBean(TestWebSocketHandler.class); + serverHandler.setWaitMessageCount(1); + session.sendMessage(new PongMessage()); serverHandler.await(); diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/config/MessageBrokerBeanDefinitionParserTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/config/MessageBrokerBeanDefinitionParserTests.java index 1c7f876214..e3c8847762 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/config/MessageBrokerBeanDefinitionParserTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/config/MessageBrokerBeanDefinitionParserTests.java @@ -60,6 +60,7 @@ import org.springframework.messaging.simp.user.UserRegistryMessageHandler; import org.springframework.messaging.support.AbstractSubscribableChannel; import org.springframework.messaging.support.ChannelInterceptor; import org.springframework.messaging.support.ImmutableMessageChannelInterceptor; +import org.springframework.mock.web.test.MockServletContext; import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor; import org.springframework.scheduling.concurrent.ThreadPoolTaskScheduler; import org.springframework.util.MimeTypeUtils; @@ -453,6 +454,7 @@ public class MessageBrokerBeanDefinitionParserTests { XmlBeanDefinitionReader reader = new XmlBeanDefinitionReader(this.appContext); ClassPathResource resource = new ClassPathResource(fileName, MessageBrokerBeanDefinitionParserTests.class); reader.loadBeanDefinitions(resource); + this.appContext.setServletContext(new MockServletContext()); this.appContext.refresh(); } -- GitLab