From 9e20a25607aeb54dc0d636727fb67b92e00dab1e Mon Sep 17 00:00:00 2001 From: Andy Wilkinson Date: Tue, 30 Jul 2013 11:29:57 +0100 Subject: [PATCH] Introduce SubProtocolHandler abstraction Add SubProtocolHandler to encapsulate the logic for using a sub-protocol. A SubProtocolWebSocketHandler is also provided to delegate to the appropriate SubProtocolHandler based on the negotiated sub-protocol value at handshake. StompSubProtocolHandler provides handling for STOMP messages. Issue: SPR-10786 --- .../handler/websocket/SubProtocolHandler.java | 94 ++++++++ .../SubProtocolWebSocketHandler.java | 208 ++++++++++++++++++ .../messaging/simp/package-info.java | 2 +- .../stomp/StompBrokerRelayMessageHandler.java | 15 +- ...Handler.java => StompProtocolHandler.java} | 183 +++++++-------- .../messaging/simp/stomp/package-info.java | 4 + .../SubProtocolWebSocketHandlerTests.java | 127 +++++++++++ .../web/socket/WebSocketHandler.java | 9 +- .../MultiProtocolWebSocketHandler.java | 125 ----------- .../MultiProtocolWebSocketHandlerTests.java | 106 --------- 10 files changed, 532 insertions(+), 341 deletions(-) create mode 100644 spring-messaging/src/main/java/org/springframework/messaging/handler/websocket/SubProtocolHandler.java create mode 100644 spring-messaging/src/main/java/org/springframework/messaging/handler/websocket/SubProtocolWebSocketHandler.java rename spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/{StompWebSocketHandler.java => StompProtocolHandler.java} (77%) create mode 100644 spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/package-info.java create mode 100644 spring-messaging/src/test/java/org/springframework/messaging/handler/websocket/SubProtocolWebSocketHandlerTests.java delete mode 100644 spring-websocket/src/main/java/org/springframework/web/socket/support/MultiProtocolWebSocketHandler.java delete mode 100644 spring-websocket/src/test/java/org/springframework/web/socket/support/MultiProtocolWebSocketHandlerTests.java diff --git a/spring-messaging/src/main/java/org/springframework/messaging/handler/websocket/SubProtocolHandler.java b/spring-messaging/src/main/java/org/springframework/messaging/handler/websocket/SubProtocolHandler.java new file mode 100644 index 0000000000..1d3dcb30a0 --- /dev/null +++ b/spring-messaging/src/main/java/org/springframework/messaging/handler/websocket/SubProtocolHandler.java @@ -0,0 +1,94 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.messaging.handler.websocket; + +import java.util.List; + +import org.springframework.messaging.Message; +import org.springframework.messaging.MessageChannel; +import org.springframework.web.socket.CloseStatus; +import org.springframework.web.socket.WebSocketMessage; +import org.springframework.web.socket.WebSocketSession; + + +/** + * A contract for handling WebSocket messages as part of a higher level protocol, referred + * to as "sub-protocol" in the WebSocket RFC specification. Handles both + * {@link WebSocketMessage}s from a client as well as {@link Message}s to a client. + *

+ * Implementations of this interface can be configured on a + * {@link SubProtocolWebSocketHandler} which selects a sub-protocol handler to delegate + * messages to based on the sub-protocol requested by the client through the + * {@code Sec-WebSocket-Protocol} request header. + * + * @author Andy Wilkinson + * @author Rossen Stoyanchev + * + * @since 4.0 + */ +public interface SubProtocolHandler { + + /** + * Return the list of sub-protocols supported by this handler, never {@code null}. + */ + List getSupportedProtocols(); + + /** + * Handle the given {@link WebSocketMessage} received from a client. + * + * @param session the client session + * @param message the client message + * @param outputChannel an output channel to send messages to + */ + void handleMessageFromClient(WebSocketSession session, WebSocketMessage message, + MessageChannel outputChannel) throws Exception; + + /** + * Handle the given {@link Message} to the client associated with the given WebSocket + * session. + * + * @param session the client session + * @param message the client message + */ + void handleMessageToClient(WebSocketSession session, Message message) throws Exception; + + /** + * Resolve the session id from the given message or return {@code null}. + * + * @param the message to resolve the session id from + */ + String resolveSessionId(Message message); + + /** + * Invoked after a {@link WebSocketSession} has started. + * + * @param session the client session + * @param outputChannel a channel + */ + void afterSessionStarted(WebSocketSession session, MessageChannel outputChannel) throws Exception; + + /** + * Invoked after a {@link WebSocketSession} has ended. + * + * @param session the client session + * @param closeStatus the reason why the session was closed + * @param outputChannel a channel + */ + void afterSessionEnded(WebSocketSession session, CloseStatus closeStatus, + MessageChannel outputChannel) throws Exception; + +} diff --git a/spring-messaging/src/main/java/org/springframework/messaging/handler/websocket/SubProtocolWebSocketHandler.java b/spring-messaging/src/main/java/org/springframework/messaging/handler/websocket/SubProtocolWebSocketHandler.java new file mode 100644 index 0000000000..42125d78d2 --- /dev/null +++ b/spring-messaging/src/main/java/org/springframework/messaging/handler/websocket/SubProtocolWebSocketHandler.java @@ -0,0 +1,208 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.messaging.handler.websocket; + +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.TreeMap; +import java.util.concurrent.ConcurrentHashMap; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.springframework.messaging.Message; +import org.springframework.messaging.MessageChannel; +import org.springframework.messaging.MessageHandler; +import org.springframework.messaging.MessagingException; +import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; +import org.springframework.web.socket.CloseStatus; +import org.springframework.web.socket.WebSocketHandler; +import org.springframework.web.socket.WebSocketMessage; +import org.springframework.web.socket.WebSocketSession; + + +/** + * A {@link WebSocketHandler} that delegates messages to a {@link SubProtocolHandler} + * based on the sub-protocol value requested by the client through the + * {@code Sec-WebSocket-Protocol} request header A default handler can also be configured + * to use if the client does not request a specific sub-protocol. + * + * @author Rossen Stoyanchev + * @author Andy Wilkinson + * + * @since 4.0 + */ +public class SubProtocolWebSocketHandler implements WebSocketHandler, MessageHandler { + + private final Log logger = LogFactory.getLog(SubProtocolWebSocketHandler.class); + + private final MessageChannel outputChannel; + + private final Map protocolHandlers = + new TreeMap(String.CASE_INSENSITIVE_ORDER); + + private SubProtocolHandler defaultProtocolHandler; + + private final Map sessions = new ConcurrentHashMap(); + + + /** + * @param outputChannel + */ + public SubProtocolWebSocketHandler(MessageChannel outputChannel) { + Assert.notNull(outputChannel, "outputChannel is required"); + this.outputChannel = outputChannel; + } + + /** + * Configure one or more handlers to use depending on the sub-protocol requested by + * the client in the WebSocket handshake request. + * + * @param protocolHandlers the sub-protocol handlers to use + */ + public void setProtocolHandlers(List protocolHandlers) { + this.protocolHandlers.clear(); + for (SubProtocolHandler handler: protocolHandlers) { + List protocols = handler.getSupportedProtocols(); + if (CollectionUtils.isEmpty(protocols)) { + logger.warn("No sub-protocols, ignoring handler " + handler); + continue; + } + for (String protocol: protocols) { + SubProtocolHandler replaced = this.protocolHandlers.put(protocol, handler); + if (replaced != null) { + throw new IllegalStateException("Failed to map handler " + handler + + " to protocol '" + protocol + "', it is already mapped to handler " + replaced); + } + } + } + if ((this.protocolHandlers.size() == 1) &&(this.defaultProtocolHandler == null)) { + this.defaultProtocolHandler = this.protocolHandlers.values().iterator().next(); + } + } + + /** + * @return the configured sub-protocol handlers + */ + public Map getProtocolHandlers() { + return this.protocolHandlers; + } + + /** + * Set the {@link SubProtocolHandler} to use when the client did not request a + * sub-protocol. + * + * @param defaultProtocolHandler the default handler + */ + public void setDefaultProtocolHandler(SubProtocolHandler defaultProtocolHandler) { + this.defaultProtocolHandler = defaultProtocolHandler; + if (this.protocolHandlers.isEmpty()) { + setProtocolHandlers(Arrays.asList(defaultProtocolHandler)); + } + } + + /** + * @return the default sub-protocol handler to use + */ + public SubProtocolHandler getDefaultProtocolHandler() { + return this.defaultProtocolHandler; + } + + + @Override + public void afterConnectionEstablished(WebSocketSession session) throws Exception { + this.sessions.put(session.getId(), session); + getProtocolHandler(session).afterSessionStarted(session, this.outputChannel); + } + + protected final SubProtocolHandler getProtocolHandler(WebSocketSession session) { + SubProtocolHandler handler; + String protocol = session.getAcceptedProtocol(); + if (protocol != null) { + handler = this.protocolHandlers.get(protocol); + Assert.state(handler != null, + "No handler for sub-protocol '" + protocol + "', handlers=" + this.protocolHandlers); + } + else { + handler = this.defaultProtocolHandler; + Assert.state(handler != null, + "No sub-protocol was requested and a default sub-protocol handler was not configured"); + } + return handler; + } + + @Override + public void handleMessage(WebSocketSession session, WebSocketMessage message) throws Exception { + getProtocolHandler(session).handleMessageFromClient(session, message, this.outputChannel); + } + + @Override + public void handleMessage(Message message) throws MessagingException { + + String sessionId = resolveSessionId(message); + if (sessionId == null) { + logger.error("sessionId not found in message " + message); + return; + } + + WebSocketSession session = this.sessions.get(sessionId); + if (session == null) { + logger.error("Session not found for session with id " + sessionId); + return; + } + + try { + getProtocolHandler(session).handleMessageToClient(session, message); + } + catch (Exception e) { + logger.error("Failed to send message to client " + message, e); + } + } + + private String resolveSessionId(Message message) { + for (SubProtocolHandler handler : this.protocolHandlers.values()) { + String sessionId = handler.resolveSessionId(message); + if (sessionId != null) { + return sessionId; + } + } + if (this.defaultProtocolHandler != null) { + String sessionId = this.defaultProtocolHandler.resolveSessionId(message); + if (sessionId != null) { + return sessionId; + } + } + return null; + } + + @Override + public void handleTransportError(WebSocketSession session, Throwable exception) throws Exception { + } + + @Override + public void afterConnectionClosed(WebSocketSession session, CloseStatus closeStatus) throws Exception { + this.sessions.remove(session.getId()); + getProtocolHandler(session).afterSessionEnded(session, closeStatus, this.outputChannel); + } + + @Override + public boolean supportsPartialMessages() { + return false; + } + +} diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/package-info.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/package-info.java index 0440189e7a..4073d52cb1 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/package-info.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/package-info.java @@ -1,4 +1,4 @@ /** - * Generic support for simple messaging protocols (like STOMP). + * Generic support for SImple Messaging Protocols such as STOMP. */ package org.springframework.messaging.simp; diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandler.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandler.java index 0f45bde028..3fda85e5d8 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandler.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandler.java @@ -197,7 +197,7 @@ public class StompBrokerRelayMessageHandler implements MessageHandler, SmartLife /** * Open a "system" session for sending messages from parts of the application - * not assoicated with a client STOMP session. + * not associated with a client STOMP session. */ private void openSystemSession() { @@ -449,12 +449,21 @@ public class StompBrokerRelayMessageHandler implements MessageHandler, SmartLife } } - private boolean forwardInternal(Message message, TcpConnection connection) { + private boolean forwardInternal(final Message message, TcpConnection connection) { if (logger.isTraceEnabled()) { logger.trace("Forwarding message to STOMP broker, message id=" + message.getHeaders().getId()); } byte[] bytes = stompMessageConverter.fromMessage(message); - connection.send(new String(bytes, Charset.forName("UTF-8"))); + connection.send(new String(bytes, Charset.forName("UTF-8")), new Consumer() { + @Override + public void accept(Boolean success) { + if (!success) { + String sessionId = StompHeaderAccessor.wrap(message).getSessionId(); + relaySessions.remove(sessionId); + sendError(sessionId, "Failed to relay message to broker"); + } + } + }); // TODO: detect if send fails and send ERROR downstream (except on DISCONNECT) return true; diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompWebSocketHandler.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompProtocolHandler.java similarity index 77% rename from spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompWebSocketHandler.java rename to spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompProtocolHandler.java index 260de8687f..3efea2df31 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompWebSocketHandler.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompProtocolHandler.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,36 +13,38 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.messaging.simp.stomp; import java.io.IOException; import java.nio.charset.Charset; import java.security.Principal; -import java.util.Map; +import java.util.Arrays; +import java.util.List; import java.util.Set; -import java.util.concurrent.ConcurrentHashMap; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.springframework.messaging.Message; import org.springframework.messaging.MessageChannel; -import org.springframework.messaging.MessageHandler; +import org.springframework.messaging.handler.websocket.SubProtocolHandler; import org.springframework.messaging.simp.SimpMessageType; import org.springframework.messaging.simp.handler.MutableUserQueueSuffixResolver; import org.springframework.messaging.support.MessageBuilder; +import org.springframework.util.Assert; import org.springframework.web.socket.CloseStatus; import org.springframework.web.socket.TextMessage; +import org.springframework.web.socket.WebSocketMessage; import org.springframework.web.socket.WebSocketSession; -import org.springframework.web.socket.adapter.TextWebSocketHandlerAdapter; - -import reactor.util.Assert; - /** + * A {@link SubProtocolHandler} for STOMP that supports versions 1.0, 1.1, and 1.2 of the + * STOMP specification. + * * @author Rossen Stoyanchev - * @since 4.0 + * @author Andy Wilkinson */ -public class StompWebSocketHandler extends TextWebSocketHandlerAdapter implements MessageHandler { +public class StompProtocolHandler implements SubProtocolHandler { /** * The name of the header set on the CONNECTED frame indicating the name of the user @@ -58,25 +60,11 @@ public class StompWebSocketHandler extends TextWebSocketHandlerAdapter implement */ public static final String QUEUE_SUFFIX_HEADER = "queue-suffix"; - - private static Log logger = LogFactory.getLog(StompWebSocketHandler.class); - - private MessageChannel dispatchChannel; - - private MutableUserQueueSuffixResolver queueSuffixResolver; + private final Log logger = LogFactory.getLog(StompProtocolHandler.class); private final StompMessageConverter stompMessageConverter = new StompMessageConverter(); - private final Map sessions = new ConcurrentHashMap(); - - - /** - * @param dispatchChannel the channel to send client STOMP/WebSocket messages to - */ - public StompWebSocketHandler(MessageChannel dispatchChannel) { - Assert.notNull(dispatchChannel, "dispatchChannel is required"); - this.dispatchChannel = dispatchChannel; - } + private MutableUserQueueSuffixResolver queueSuffixResolver; /** @@ -94,23 +82,20 @@ public class StompWebSocketHandler extends TextWebSocketHandlerAdapter implement return this.queueSuffixResolver; } - public StompMessageConverter getStompMessageConverter() { - return this.stompMessageConverter; - } - - @Override - public void afterConnectionEstablished(WebSocketSession session) throws Exception { - this.sessions.put(session.getId(), session); + public List getSupportedProtocols() { + return Arrays.asList("v10.stomp", "v11.stomp", "v12.stomp"); } /** * Handle incoming WebSocket messages from clients. */ - @Override - protected void handleTextMessage(WebSocketSession session, TextMessage textMessage) { + public void handleMessageFromClient(WebSocketSession session, WebSocketMessage webSocketMessage, + MessageChannel outputChannel) { + try { - String payload = textMessage.getPayload(); + Assert.isInstanceOf(TextMessage.class, webSocketMessage); + String payload = ((TextMessage)webSocketMessage).getPayload(); Message message = this.stompMessageConverter.toMessage(payload); // TODO: validate size limits @@ -124,13 +109,14 @@ public class StompWebSocketHandler extends TextWebSocketHandlerAdapter implement StompHeaderAccessor headers = StompHeaderAccessor.wrap(message); headers.setSessionId(session.getId()); headers.setUser(session.getPrincipal()); + message = MessageBuilder.withPayloadAndHeaders(message.getPayload(), headers).build(); if (SimpMessageType.CONNECT.equals(headers.getMessageType())) { handleConnect(session, message); } - this.dispatchChannel.send(message); + outputChannel.send(message); } catch (Throwable t) { @@ -147,6 +133,51 @@ public class StompWebSocketHandler extends TextWebSocketHandlerAdapter implement } } + /** + * Handle STOMP messages going back out to WebSocket clients. + */ + @Override + public void handleMessageToClient(WebSocketSession session, Message message) { + + StompHeaderAccessor headers = StompHeaderAccessor.wrap(message); + headers.setCommandIfNotSet(StompCommand.MESSAGE); + + if (StompCommand.CONNECTED.equals(headers.getCommand())) { + // Ignore for now since we already sent it + return; + } + + if (StompCommand.MESSAGE.equals(headers.getCommand()) && (headers.getSubscriptionId() == null)) { + // TODO: failed message delivery mechanism + logger.error("Ignoring message, no subscriptionId header: " + message); + return; + } + + if (!(message.getPayload() instanceof byte[])) { + // TODO: failed message delivery mechanism + logger.error("Ignoring message, expected byte[] content: " + message); + return; + } + + try { + message = MessageBuilder.withPayloadAndHeaders(message.getPayload(), headers).build(); + byte[] bytes = this.stompMessageConverter.fromMessage(message); + session.sendMessage(new TextMessage(new String(bytes, Charset.forName("UTF-8")))); + } + catch (Throwable t) { + sendErrorMessage(session, t); + } + finally { + if (StompCommand.ERROR.equals(headers.getCommand())) { + try { + session.close(CloseStatus.PROTOCOL_ERROR); + } + catch (IOException e) { + } + } + } + } + protected void handleConnect(WebSocketSession session, Message message) throws IOException { StompHeaderAccessor connectHeaders = StompHeaderAccessor.wrap(message); @@ -200,78 +231,26 @@ public class StompWebSocketHandler extends TextWebSocketHandlerAdapter implement } @Override - public void afterConnectionClosed(WebSocketSession session, CloseStatus status) throws Exception { - - String sessionId = session.getId(); - this.sessions.remove(sessionId); - - if ((this.queueSuffixResolver != null) && (session.getPrincipal() != null)) { - this.queueSuffixResolver.removeQueueSuffix(session.getPrincipal().getName(), sessionId); - } - - StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.DISCONNECT); - headers.setSessionId(sessionId); - Message message = MessageBuilder.withPayloadAndHeaders(new byte[0], headers).build(); - this.dispatchChannel.send(message); + public String resolveSessionId(Message message) { + StompHeaderAccessor headers = StompHeaderAccessor.wrap(message); + return headers.getSessionId(); } - /** - * Handle STOMP messages going back out to WebSocket clients. - */ @Override - public void handleMessage(Message message) { - - StompHeaderAccessor headers = StompHeaderAccessor.wrap(message); - headers.setCommandIfNotSet(StompCommand.MESSAGE); - - if (StompCommand.CONNECTED.equals(headers.getCommand())) { - // Ignore for now since we already sent it - return; - } - - String sessionId = headers.getSessionId(); - if (sessionId == null) { - // TODO: failed message delivery mechanism - logger.error("Ignoring message, no sessionId header: " + message); - return; - } - - WebSocketSession session = this.sessions.get(sessionId); - if (session == null) { - // TODO: failed message delivery mechanism - logger.error("Ignoring message, sessionId not found: " + message); - return; - } + public void afterSessionStarted(WebSocketSession session, MessageChannel outputChannel) { + } - if (StompCommand.MESSAGE.equals(headers.getCommand()) && (headers.getSubscriptionId() == null)) { - // TODO: failed message delivery mechanism - logger.error("Ignoring message, no subscriptionId header: " + message); - return; - } + @Override + public void afterSessionEnded(WebSocketSession session, CloseStatus closeStatus, MessageChannel outputChannel) { - if (!(message.getPayload() instanceof byte[])) { - // TODO: failed message delivery mechanism - logger.error("Ignoring message, expected byte[] content: " + message); - return; + if ((this.queueSuffixResolver != null) && (session.getPrincipal() != null)) { + this.queueSuffixResolver.removeQueueSuffix(session.getPrincipal().getName(), session.getId()); } - try { - message = MessageBuilder.withPayloadAndHeaders(message.getPayload(), headers).build(); - byte[] bytes = this.stompMessageConverter.fromMessage(message); - session.sendMessage(new TextMessage(new String(bytes, Charset.forName("UTF-8")))); - } - catch (Throwable t) { - sendErrorMessage(session, t); - } - finally { - if (StompCommand.ERROR.equals(headers.getCommand())) { - try { - session.close(CloseStatus.PROTOCOL_ERROR); - } - catch (IOException e) { - } - } - } + StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.DISCONNECT); + headers.setSessionId(session.getId()); + Message message = MessageBuilder.withPayloadAndHeaders(new byte[0], headers).build(); + outputChannel.send(message); } } diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/package-info.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/package-info.java new file mode 100644 index 0000000000..df73ffeb65 --- /dev/null +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/package-info.java @@ -0,0 +1,4 @@ +/** + * Generic support for simple messaging protocols (like STOMP). + */ +package org.springframework.messaging.simp.stomp; diff --git a/spring-messaging/src/test/java/org/springframework/messaging/handler/websocket/SubProtocolWebSocketHandlerTests.java b/spring-messaging/src/test/java/org/springframework/messaging/handler/websocket/SubProtocolWebSocketHandlerTests.java new file mode 100644 index 0000000000..ccc4d53e4f --- /dev/null +++ b/spring-messaging/src/test/java/org/springframework/messaging/handler/websocket/SubProtocolWebSocketHandlerTests.java @@ -0,0 +1,127 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIOsNS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.messaging.handler.websocket; + +import java.util.Arrays; + +import org.junit.Before; +import org.junit.Test; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.springframework.messaging.MessageChannel; +import org.springframework.web.socket.support.TestWebSocketSession; + +import static org.mockito.Mockito.*; + + +/** + * Test fixture for {@link SubProtocolWebSocketHandler}. + * + * @author Rossen Stoyanchev + * @author Andy Wilkinson + */ +public class SubProtocolWebSocketHandlerTests { + + private SubProtocolWebSocketHandler webSocketHandler; + + private TestWebSocketSession session; + + @Mock + SubProtocolHandler stompHandler; + + @Mock + SubProtocolHandler mqttHandler; + + @Mock + SubProtocolHandler defaultHandler; + + @Mock + MessageChannel channel; + + + @Before + public void setup() { + MockitoAnnotations.initMocks(this); + + this.webSocketHandler = new SubProtocolWebSocketHandler(this.channel); + when(stompHandler.getSupportedProtocols()).thenReturn(Arrays.asList("STOMP")); + when(mqttHandler.getSupportedProtocols()).thenReturn(Arrays.asList("MQTT")); + + this.session = new TestWebSocketSession(); + this.session.setId("1"); + } + + + @Test + public void subProtocolMatch() throws Exception { + this.webSocketHandler.setProtocolHandlers(Arrays.asList(stompHandler, mqttHandler)); + this.session.setAcceptedProtocol("sToMp"); + this.webSocketHandler.afterConnectionEstablished(session); + + verify(this.stompHandler).afterSessionStarted(session, this.channel); + verify(this.mqttHandler, times(0)).afterSessionStarted(session, this.channel); + } + + @Test + public void subProtocolDefaultHandlerOnly() throws Exception { + this.webSocketHandler.setDefaultProtocolHandler(stompHandler); + this.session.setAcceptedProtocol("sToMp"); + this.webSocketHandler.afterConnectionEstablished(session); + + verify(this.stompHandler).afterSessionStarted(session, this.channel); + } + + @Test(expected=IllegalStateException.class) + public void subProtocolNoMatch() throws Exception { + this.webSocketHandler.setDefaultProtocolHandler(defaultHandler); + this.webSocketHandler.setProtocolHandlers(Arrays.asList(stompHandler, mqttHandler)); + this.session.setAcceptedProtocol("wamp"); + + this.webSocketHandler.afterConnectionEstablished(session); + } + + @Test + public void noSubProtocol() throws Exception { + this.webSocketHandler.setDefaultProtocolHandler(defaultHandler); + this.webSocketHandler.afterConnectionEstablished(session); + + verify(this.defaultHandler).afterSessionStarted(session, this.channel); + verify(this.stompHandler, times(0)).afterSessionStarted(session, this.channel); + verify(this.mqttHandler, times(0)).afterSessionStarted(session, this.channel); + } + + @Test + public void noSubProtocolOneHandler() throws Exception { + this.webSocketHandler.setProtocolHandlers(Arrays.asList(stompHandler)); + this.webSocketHandler.afterConnectionEstablished(session); + + verify(this.stompHandler).afterSessionStarted(session, this.channel); + } + + @Test(expected=IllegalStateException.class) + public void noSubProtocolTwoHandlers() throws Exception { + this.webSocketHandler.setProtocolHandlers(Arrays.asList(stompHandler, mqttHandler)); + this.webSocketHandler.afterConnectionEstablished(session); + } + + @Test(expected=IllegalStateException.class) + public void noSubProtocolNoDefaultHandler() throws Exception { + this.webSocketHandler.setProtocolHandlers(Arrays.asList(stompHandler, mqttHandler)); + this.webSocketHandler.afterConnectionEstablished(session); + } + +} diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/WebSocketHandler.java b/spring-websocket/src/main/java/org/springframework/web/socket/WebSocketHandler.java index 309bb41abc..07f5499f2c 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/WebSocketHandler.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/WebSocketHandler.java @@ -20,12 +20,12 @@ package org.springframework.web.socket; * A handler for WebSocket messages and lifecycle events. * *

Implementations of this interface are encouraged to handle exceptions locally where - * it makes sense or alternatively let the exception bubble up in which case the exception - * is logged and the session closed with - * {@link CloseStatus#SERVER_ERROR SERVER_ERROR(1011)} by default. The exception handling + * it makes sense or alternatively let the exception bubble up in which case by default + * the exception is logged and the session closed with + * {@link CloseStatus#SERVER_ERROR SERVER_ERROR(1011)}. The exception handling * strategy is provided by * {@link org.springframework.web.socket.support.ExceptionWebSocketHandlerDecorator - * ExceptionWebSocketHandlerDecorator}, which can be customized or replaced by decorating + * ExceptionWebSocketHandlerDecorator} and it can be customized or replaced by decorating * the {@link WebSocketHandler} with a different decorator. * * @author Rossen Stoyanchev @@ -61,6 +61,7 @@ public interface WebSocketHandler { * transport error has occurred. Although the session may technically still be open, * depending on the underlying implementation, sending messages at this point is * discouraged and most likely will not succeed. + * * @throws Exception this method can handle or propagate exceptions; see class-level * Javadoc for details. */ diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/support/MultiProtocolWebSocketHandler.java b/spring-websocket/src/main/java/org/springframework/web/socket/support/MultiProtocolWebSocketHandler.java deleted file mode 100644 index c579e34f51..0000000000 --- a/spring-websocket/src/main/java/org/springframework/web/socket/support/MultiProtocolWebSocketHandler.java +++ /dev/null @@ -1,125 +0,0 @@ -/* - * Copyright 2002-2013 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.web.socket.support; - -import java.util.Collections; -import java.util.HashMap; -import java.util.Map; - -import org.springframework.util.Assert; -import org.springframework.web.socket.CloseStatus; -import org.springframework.web.socket.WebSocketHandler; -import org.springframework.web.socket.WebSocketMessage; -import org.springframework.web.socket.WebSocketSession; - - -/** - * A {@link WebSocketHandler} that delegates to other {@link WebSocketHandler} instances - * based on the sub-protocol value accepted at the handshake. A default handler can also - * be configured for use by default when a sub-protocol value if the WebSocket session - * does not have a sub-protocol value associated with it. - * - * @author Rossen Stoyanchev - * @since 4.0 - */ -public class MultiProtocolWebSocketHandler implements WebSocketHandler { - - private WebSocketHandler defaultHandler; - - private Map handlers = new HashMap(); - - - /** - * Configure {@link WebSocketHandler}'s to use by sub-protocol. The values for - * sub-protocols are case insensitive. - */ - public void setProtocolHandlers(Map protocolHandlers) { - this.handlers.clear(); - for (String protocol : protocolHandlers.keySet()) { - this.handlers.put(protocol.toLowerCase(), protocolHandlers.get(protocol)); - } - } - - /** - * Return a read-only copy of the sub-protocol handler map. - */ - public Map getProtocolHandlers() { - return Collections.unmodifiableMap(this.handlers); - } - - /** - * Set the default {@link WebSocketHandler} to use if a sub-protocol was not - * requested. - */ - public void setDefaultProtocolHandler(WebSocketHandler defaultHandler) { - this.defaultHandler = defaultHandler; - } - - /** - * Return the default {@link WebSocketHandler} to be used. - */ - public WebSocketHandler getDefaultProtocolHandler() { - return this.defaultHandler; - } - - - @Override - public void afterConnectionEstablished(WebSocketSession session) throws Exception { - WebSocketHandler handler = getHandlerForSession(session); - handler.afterConnectionEstablished(session); - } - - private WebSocketHandler getHandlerForSession(WebSocketSession session) { - WebSocketHandler handler = null; - String protocol = session.getAcceptedProtocol(); - if (protocol != null) { - handler = this.handlers.get(protocol.toLowerCase()); - Assert.state(handler != null, - "No WebSocketHandler for sub-protocol '" + protocol + "', handlers=" + this.handlers); - } - else { - handler = this.defaultHandler; - Assert.state(handler != null, - "No sub-protocol was requested and no default WebSocketHandler was configured"); - } - return handler; - } - - @Override - public void handleMessage(WebSocketSession session, WebSocketMessage message) throws Exception { - WebSocketHandler handler = getHandlerForSession(session); - handler.handleMessage(session, message); - } - - @Override - public void handleTransportError(WebSocketSession session, Throwable exception) throws Exception { - WebSocketHandler handler = getHandlerForSession(session); - handler.handleTransportError(session, exception); - } - - @Override - public void afterConnectionClosed(WebSocketSession session, CloseStatus closeStatus) throws Exception { - WebSocketHandler handler = getHandlerForSession(session); - handler.afterConnectionClosed(session, closeStatus); - } - - @Override - public boolean supportsPartialMessages() { - return false; - } - -} diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/support/MultiProtocolWebSocketHandlerTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/support/MultiProtocolWebSocketHandlerTests.java deleted file mode 100644 index 533fce706d..0000000000 --- a/spring-websocket/src/test/java/org/springframework/web/socket/support/MultiProtocolWebSocketHandlerTests.java +++ /dev/null @@ -1,106 +0,0 @@ -/* - * Copyright 2002-2013 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIOsNS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.web.socket.support; - -import java.util.HashMap; -import java.util.Map; - -import org.junit.Before; -import org.junit.Test; -import org.mockito.Mock; -import org.mockito.MockitoAnnotations; -import org.springframework.web.socket.WebSocketHandler; - -import static org.mockito.Mockito.*; - - -/** - * Test fixture for {@link MultiProtocolWebSocketHandler}. - * - * @author Rossen Stoyanchev - */ -public class MultiProtocolWebSocketHandlerTests { - - private MultiProtocolWebSocketHandler multiProtocolHandler; - - @Mock - WebSocketHandler stompHandler; - - @Mock - WebSocketHandler mqttHandler; - - @Mock - WebSocketHandler defaultHandler; - - - @Before - public void setup() { - - MockitoAnnotations.initMocks(this); - - Map handlers = new HashMap(); - handlers.put("STOMP", this.stompHandler); - handlers.put("MQTT", this.mqttHandler); - - this.multiProtocolHandler = new MultiProtocolWebSocketHandler(); - this.multiProtocolHandler.setProtocolHandlers(handlers); - this.multiProtocolHandler.setDefaultProtocolHandler(this.defaultHandler); - } - - - @Test - public void subProtocol() throws Exception { - - TestWebSocketSession session = new TestWebSocketSession(); - session.setAcceptedProtocol("sToMp"); - - this.multiProtocolHandler.afterConnectionEstablished(session); - - verify(this.stompHandler).afterConnectionEstablished(session); - verifyZeroInteractions(this.mqttHandler); - } - - @Test(expected=IllegalStateException.class) - public void subProtocolNoMatch() throws Exception { - - TestWebSocketSession session = new TestWebSocketSession(); - session.setAcceptedProtocol("wamp"); - - this.multiProtocolHandler.afterConnectionEstablished(session); - } - - @Test - public void noSubProtocol() throws Exception { - - TestWebSocketSession session = new TestWebSocketSession(); - - this.multiProtocolHandler.afterConnectionEstablished(session); - - verify(this.defaultHandler).afterConnectionEstablished(session); - verifyZeroInteractions(this.stompHandler, this.mqttHandler); - } - - @Test(expected=IllegalStateException.class) - public void noSubProtocolNoDefaultHandler() throws Exception { - - TestWebSocketSession session = new TestWebSocketSession(); - - this.multiProtocolHandler.setDefaultProtocolHandler(null); - this.multiProtocolHandler.afterConnectionEstablished(session); - } - -} -- GitLab