提交 9e20a256 编写于 作者: A Andy Wilkinson 提交者: Rossen Stoyanchev

Introduce SubProtocolHandler abstraction

Add SubProtocolHandler to encapsulate the logic for using a
sub-protocol.

A SubProtocolWebSocketHandler is also provided to
delegate to the appropriate SubProtocolHandler based on the
negotiated sub-protocol value at handshake.

StompSubProtocolHandler provides handling for STOMP messages.

Issue: SPR-10786
上级 e4d83bbe
/*
* Copyright 2002-2013 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.messaging.handler.websocket;
import java.util.List;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageChannel;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.WebSocketMessage;
import org.springframework.web.socket.WebSocketSession;
/**
* A contract for handling WebSocket messages as part of a higher level protocol, referred
* to as "sub-protocol" in the WebSocket RFC specification. Handles both
* {@link WebSocketMessage}s from a client as well as {@link Message}s to a client.
* <p>
* Implementations of this interface can be configured on a
* {@link SubProtocolWebSocketHandler} which selects a sub-protocol handler to delegate
* messages to based on the sub-protocol requested by the client through the
* {@code Sec-WebSocket-Protocol} request header.
*
* @author Andy Wilkinson
* @author Rossen Stoyanchev
*
* @since 4.0
*/
public interface SubProtocolHandler {
/**
* Return the list of sub-protocols supported by this handler, never {@code null}.
*/
List<String> getSupportedProtocols();
/**
* Handle the given {@link WebSocketMessage} received from a client.
*
* @param session the client session
* @param message the client message
* @param outputChannel an output channel to send messages to
*/
void handleMessageFromClient(WebSocketSession session, WebSocketMessage message,
MessageChannel outputChannel) throws Exception;
/**
* Handle the given {@link Message} to the client associated with the given WebSocket
* session.
*
* @param session the client session
* @param message the client message
*/
void handleMessageToClient(WebSocketSession session, Message<?> message) throws Exception;
/**
* Resolve the session id from the given message or return {@code null}.
*
* @param the message to resolve the session id from
*/
String resolveSessionId(Message<?> message);
/**
* Invoked after a {@link WebSocketSession} has started.
*
* @param session the client session
* @param outputChannel a channel
*/
void afterSessionStarted(WebSocketSession session, MessageChannel outputChannel) throws Exception;
/**
* Invoked after a {@link WebSocketSession} has ended.
*
* @param session the client session
* @param closeStatus the reason why the session was closed
* @param outputChannel a channel
*/
void afterSessionEnded(WebSocketSession session, CloseStatus closeStatus,
MessageChannel outputChannel) throws Exception;
}
/*
* Copyright 2002-2013 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.messaging.handler.websocket;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;
import java.util.concurrent.ConcurrentHashMap;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageChannel;
import org.springframework.messaging.MessageHandler;
import org.springframework.messaging.MessagingException;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.WebSocketMessage;
import org.springframework.web.socket.WebSocketSession;
/**
* A {@link WebSocketHandler} that delegates messages to a {@link SubProtocolHandler}
* based on the sub-protocol value requested by the client through the
* {@code Sec-WebSocket-Protocol} request header A default handler can also be configured
* to use if the client does not request a specific sub-protocol.
*
* @author Rossen Stoyanchev
* @author Andy Wilkinson
*
* @since 4.0
*/
public class SubProtocolWebSocketHandler implements WebSocketHandler, MessageHandler {
private final Log logger = LogFactory.getLog(SubProtocolWebSocketHandler.class);
private final MessageChannel outputChannel;
private final Map<String, SubProtocolHandler> protocolHandlers =
new TreeMap<String, SubProtocolHandler>(String.CASE_INSENSITIVE_ORDER);
private SubProtocolHandler defaultProtocolHandler;
private final Map<String, WebSocketSession> sessions = new ConcurrentHashMap<String, WebSocketSession>();
/**
* @param outputChannel
*/
public SubProtocolWebSocketHandler(MessageChannel outputChannel) {
Assert.notNull(outputChannel, "outputChannel is required");
this.outputChannel = outputChannel;
}
/**
* Configure one or more handlers to use depending on the sub-protocol requested by
* the client in the WebSocket handshake request.
*
* @param protocolHandlers the sub-protocol handlers to use
*/
public void setProtocolHandlers(List<SubProtocolHandler> protocolHandlers) {
this.protocolHandlers.clear();
for (SubProtocolHandler handler: protocolHandlers) {
List<String> protocols = handler.getSupportedProtocols();
if (CollectionUtils.isEmpty(protocols)) {
logger.warn("No sub-protocols, ignoring handler " + handler);
continue;
}
for (String protocol: protocols) {
SubProtocolHandler replaced = this.protocolHandlers.put(protocol, handler);
if (replaced != null) {
throw new IllegalStateException("Failed to map handler " + handler
+ " to protocol '" + protocol + "', it is already mapped to handler " + replaced);
}
}
}
if ((this.protocolHandlers.size() == 1) &&(this.defaultProtocolHandler == null)) {
this.defaultProtocolHandler = this.protocolHandlers.values().iterator().next();
}
}
/**
* @return the configured sub-protocol handlers
*/
public Map<String, SubProtocolHandler> getProtocolHandlers() {
return this.protocolHandlers;
}
/**
* Set the {@link SubProtocolHandler} to use when the client did not request a
* sub-protocol.
*
* @param defaultProtocolHandler the default handler
*/
public void setDefaultProtocolHandler(SubProtocolHandler defaultProtocolHandler) {
this.defaultProtocolHandler = defaultProtocolHandler;
if (this.protocolHandlers.isEmpty()) {
setProtocolHandlers(Arrays.asList(defaultProtocolHandler));
}
}
/**
* @return the default sub-protocol handler to use
*/
public SubProtocolHandler getDefaultProtocolHandler() {
return this.defaultProtocolHandler;
}
@Override
public void afterConnectionEstablished(WebSocketSession session) throws Exception {
this.sessions.put(session.getId(), session);
getProtocolHandler(session).afterSessionStarted(session, this.outputChannel);
}
protected final SubProtocolHandler getProtocolHandler(WebSocketSession session) {
SubProtocolHandler handler;
String protocol = session.getAcceptedProtocol();
if (protocol != null) {
handler = this.protocolHandlers.get(protocol);
Assert.state(handler != null,
"No handler for sub-protocol '" + protocol + "', handlers=" + this.protocolHandlers);
}
else {
handler = this.defaultProtocolHandler;
Assert.state(handler != null,
"No sub-protocol was requested and a default sub-protocol handler was not configured");
}
return handler;
}
@Override
public void handleMessage(WebSocketSession session, WebSocketMessage<?> message) throws Exception {
getProtocolHandler(session).handleMessageFromClient(session, message, this.outputChannel);
}
@Override
public void handleMessage(Message<?> message) throws MessagingException {
String sessionId = resolveSessionId(message);
if (sessionId == null) {
logger.error("sessionId not found in message " + message);
return;
}
WebSocketSession session = this.sessions.get(sessionId);
if (session == null) {
logger.error("Session not found for session with id " + sessionId);
return;
}
try {
getProtocolHandler(session).handleMessageToClient(session, message);
}
catch (Exception e) {
logger.error("Failed to send message to client " + message, e);
}
}
private String resolveSessionId(Message<?> message) {
for (SubProtocolHandler handler : this.protocolHandlers.values()) {
String sessionId = handler.resolveSessionId(message);
if (sessionId != null) {
return sessionId;
}
}
if (this.defaultProtocolHandler != null) {
String sessionId = this.defaultProtocolHandler.resolveSessionId(message);
if (sessionId != null) {
return sessionId;
}
}
return null;
}
@Override
public void handleTransportError(WebSocketSession session, Throwable exception) throws Exception {
}
@Override
public void afterConnectionClosed(WebSocketSession session, CloseStatus closeStatus) throws Exception {
this.sessions.remove(session.getId());
getProtocolHandler(session).afterSessionEnded(session, closeStatus, this.outputChannel);
}
@Override
public boolean supportsPartialMessages() {
return false;
}
}
/**
* Generic support for simple messaging protocols (like STOMP).
* Generic support for SImple Messaging Protocols such as STOMP.
*/
package org.springframework.messaging.simp;
......@@ -197,7 +197,7 @@ public class StompBrokerRelayMessageHandler implements MessageHandler, SmartLife
/**
* Open a "system" session for sending messages from parts of the application
* not assoicated with a client STOMP session.
* not associated with a client STOMP session.
*/
private void openSystemSession() {
......@@ -449,12 +449,21 @@ public class StompBrokerRelayMessageHandler implements MessageHandler, SmartLife
}
}
private boolean forwardInternal(Message<?> message, TcpConnection<String, String> connection) {
private boolean forwardInternal(final Message<?> message, TcpConnection<String, String> connection) {
if (logger.isTraceEnabled()) {
logger.trace("Forwarding message to STOMP broker, message id=" + message.getHeaders().getId());
}
byte[] bytes = stompMessageConverter.fromMessage(message);
connection.send(new String(bytes, Charset.forName("UTF-8")));
connection.send(new String(bytes, Charset.forName("UTF-8")), new Consumer<Boolean>() {
@Override
public void accept(Boolean success) {
if (!success) {
String sessionId = StompHeaderAccessor.wrap(message).getSessionId();
relaySessions.remove(sessionId);
sendError(sessionId, "Failed to relay message to broker");
}
}
});
// TODO: detect if send fails and send ERROR downstream (except on DISCONNECT)
return true;
......
......@@ -5,7 +5,7 @@
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
......@@ -13,36 +13,38 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.messaging.simp.stomp;
import java.io.IOException;
import java.nio.charset.Charset;
import java.security.Principal;
import java.util.Map;
import java.util.Arrays;
import java.util.List;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageChannel;
import org.springframework.messaging.MessageHandler;
import org.springframework.messaging.handler.websocket.SubProtocolHandler;
import org.springframework.messaging.simp.SimpMessageType;
import org.springframework.messaging.simp.handler.MutableUserQueueSuffixResolver;
import org.springframework.messaging.support.MessageBuilder;
import org.springframework.util.Assert;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketMessage;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.adapter.TextWebSocketHandlerAdapter;
import reactor.util.Assert;
/**
* A {@link SubProtocolHandler} for STOMP that supports versions 1.0, 1.1, and 1.2 of the
* STOMP specification.
*
* @author Rossen Stoyanchev
* @since 4.0
* @author Andy Wilkinson
*/
public class StompWebSocketHandler extends TextWebSocketHandlerAdapter implements MessageHandler {
public class StompProtocolHandler implements SubProtocolHandler {
/**
* The name of the header set on the CONNECTED frame indicating the name of the user
......@@ -58,25 +60,11 @@ public class StompWebSocketHandler extends TextWebSocketHandlerAdapter implement
*/
public static final String QUEUE_SUFFIX_HEADER = "queue-suffix";
private static Log logger = LogFactory.getLog(StompWebSocketHandler.class);
private MessageChannel dispatchChannel;
private MutableUserQueueSuffixResolver queueSuffixResolver;
private final Log logger = LogFactory.getLog(StompProtocolHandler.class);
private final StompMessageConverter stompMessageConverter = new StompMessageConverter();
private final Map<String, WebSocketSession> sessions = new ConcurrentHashMap<String, WebSocketSession>();
/**
* @param dispatchChannel the channel to send client STOMP/WebSocket messages to
*/
public StompWebSocketHandler(MessageChannel dispatchChannel) {
Assert.notNull(dispatchChannel, "dispatchChannel is required");
this.dispatchChannel = dispatchChannel;
}
private MutableUserQueueSuffixResolver queueSuffixResolver;
/**
......@@ -94,23 +82,20 @@ public class StompWebSocketHandler extends TextWebSocketHandlerAdapter implement
return this.queueSuffixResolver;
}
public StompMessageConverter getStompMessageConverter() {
return this.stompMessageConverter;
}
@Override
public void afterConnectionEstablished(WebSocketSession session) throws Exception {
this.sessions.put(session.getId(), session);
public List<String> getSupportedProtocols() {
return Arrays.asList("v10.stomp", "v11.stomp", "v12.stomp");
}
/**
* Handle incoming WebSocket messages from clients.
*/
@Override
protected void handleTextMessage(WebSocketSession session, TextMessage textMessage) {
public void handleMessageFromClient(WebSocketSession session, WebSocketMessage webSocketMessage,
MessageChannel outputChannel) {
try {
String payload = textMessage.getPayload();
Assert.isInstanceOf(TextMessage.class, webSocketMessage);
String payload = ((TextMessage)webSocketMessage).getPayload();
Message<?> message = this.stompMessageConverter.toMessage(payload);
// TODO: validate size limits
......@@ -124,13 +109,14 @@ public class StompWebSocketHandler extends TextWebSocketHandlerAdapter implement
StompHeaderAccessor headers = StompHeaderAccessor.wrap(message);
headers.setSessionId(session.getId());
headers.setUser(session.getPrincipal());
message = MessageBuilder.withPayloadAndHeaders(message.getPayload(), headers).build();
if (SimpMessageType.CONNECT.equals(headers.getMessageType())) {
handleConnect(session, message);
}
this.dispatchChannel.send(message);
outputChannel.send(message);
}
catch (Throwable t) {
......@@ -147,6 +133,51 @@ public class StompWebSocketHandler extends TextWebSocketHandlerAdapter implement
}
}
/**
* Handle STOMP messages going back out to WebSocket clients.
*/
@Override
public void handleMessageToClient(WebSocketSession session, Message<?> message) {
StompHeaderAccessor headers = StompHeaderAccessor.wrap(message);
headers.setCommandIfNotSet(StompCommand.MESSAGE);
if (StompCommand.CONNECTED.equals(headers.getCommand())) {
// Ignore for now since we already sent it
return;
}
if (StompCommand.MESSAGE.equals(headers.getCommand()) && (headers.getSubscriptionId() == null)) {
// TODO: failed message delivery mechanism
logger.error("Ignoring message, no subscriptionId header: " + message);
return;
}
if (!(message.getPayload() instanceof byte[])) {
// TODO: failed message delivery mechanism
logger.error("Ignoring message, expected byte[] content: " + message);
return;
}
try {
message = MessageBuilder.withPayloadAndHeaders(message.getPayload(), headers).build();
byte[] bytes = this.stompMessageConverter.fromMessage(message);
session.sendMessage(new TextMessage(new String(bytes, Charset.forName("UTF-8"))));
}
catch (Throwable t) {
sendErrorMessage(session, t);
}
finally {
if (StompCommand.ERROR.equals(headers.getCommand())) {
try {
session.close(CloseStatus.PROTOCOL_ERROR);
}
catch (IOException e) {
}
}
}
}
protected void handleConnect(WebSocketSession session, Message<?> message) throws IOException {
StompHeaderAccessor connectHeaders = StompHeaderAccessor.wrap(message);
......@@ -200,78 +231,26 @@ public class StompWebSocketHandler extends TextWebSocketHandlerAdapter implement
}
@Override
public void afterConnectionClosed(WebSocketSession session, CloseStatus status) throws Exception {
String sessionId = session.getId();
this.sessions.remove(sessionId);
if ((this.queueSuffixResolver != null) && (session.getPrincipal() != null)) {
this.queueSuffixResolver.removeQueueSuffix(session.getPrincipal().getName(), sessionId);
}
StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.DISCONNECT);
headers.setSessionId(sessionId);
Message<?> message = MessageBuilder.withPayloadAndHeaders(new byte[0], headers).build();
this.dispatchChannel.send(message);
public String resolveSessionId(Message<?> message) {
StompHeaderAccessor headers = StompHeaderAccessor.wrap(message);
return headers.getSessionId();
}
/**
* Handle STOMP messages going back out to WebSocket clients.
*/
@Override
public void handleMessage(Message<?> message) {
StompHeaderAccessor headers = StompHeaderAccessor.wrap(message);
headers.setCommandIfNotSet(StompCommand.MESSAGE);
if (StompCommand.CONNECTED.equals(headers.getCommand())) {
// Ignore for now since we already sent it
return;
}
String sessionId = headers.getSessionId();
if (sessionId == null) {
// TODO: failed message delivery mechanism
logger.error("Ignoring message, no sessionId header: " + message);
return;
}
WebSocketSession session = this.sessions.get(sessionId);
if (session == null) {
// TODO: failed message delivery mechanism
logger.error("Ignoring message, sessionId not found: " + message);
return;
}
public void afterSessionStarted(WebSocketSession session, MessageChannel outputChannel) {
}
if (StompCommand.MESSAGE.equals(headers.getCommand()) && (headers.getSubscriptionId() == null)) {
// TODO: failed message delivery mechanism
logger.error("Ignoring message, no subscriptionId header: " + message);
return;
}
@Override
public void afterSessionEnded(WebSocketSession session, CloseStatus closeStatus, MessageChannel outputChannel) {
if (!(message.getPayload() instanceof byte[])) {
// TODO: failed message delivery mechanism
logger.error("Ignoring message, expected byte[] content: " + message);
return;
if ((this.queueSuffixResolver != null) && (session.getPrincipal() != null)) {
this.queueSuffixResolver.removeQueueSuffix(session.getPrincipal().getName(), session.getId());
}
try {
message = MessageBuilder.withPayloadAndHeaders(message.getPayload(), headers).build();
byte[] bytes = this.stompMessageConverter.fromMessage(message);
session.sendMessage(new TextMessage(new String(bytes, Charset.forName("UTF-8"))));
}
catch (Throwable t) {
sendErrorMessage(session, t);
}
finally {
if (StompCommand.ERROR.equals(headers.getCommand())) {
try {
session.close(CloseStatus.PROTOCOL_ERROR);
}
catch (IOException e) {
}
}
}
StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.DISCONNECT);
headers.setSessionId(session.getId());
Message<?> message = MessageBuilder.withPayloadAndHeaders(new byte[0], headers).build();
outputChannel.send(message);
}
}
/**
* Generic support for simple messaging protocols (like STOMP).
*/
package org.springframework.messaging.simp.stomp;
/*
* Copyright 2002-2013 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIOsNS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.messaging.handler.websocket;
import java.util.Arrays;
import org.junit.Before;
import org.junit.Test;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
import org.springframework.messaging.MessageChannel;
import org.springframework.web.socket.support.TestWebSocketSession;
import static org.mockito.Mockito.*;
/**
* Test fixture for {@link SubProtocolWebSocketHandler}.
*
* @author Rossen Stoyanchev
* @author Andy Wilkinson
*/
public class SubProtocolWebSocketHandlerTests {
private SubProtocolWebSocketHandler webSocketHandler;
private TestWebSocketSession session;
@Mock
SubProtocolHandler stompHandler;
@Mock
SubProtocolHandler mqttHandler;
@Mock
SubProtocolHandler defaultHandler;
@Mock
MessageChannel channel;
@Before
public void setup() {
MockitoAnnotations.initMocks(this);
this.webSocketHandler = new SubProtocolWebSocketHandler(this.channel);
when(stompHandler.getSupportedProtocols()).thenReturn(Arrays.asList("STOMP"));
when(mqttHandler.getSupportedProtocols()).thenReturn(Arrays.asList("MQTT"));
this.session = new TestWebSocketSession();
this.session.setId("1");
}
@Test
public void subProtocolMatch() throws Exception {
this.webSocketHandler.setProtocolHandlers(Arrays.asList(stompHandler, mqttHandler));
this.session.setAcceptedProtocol("sToMp");
this.webSocketHandler.afterConnectionEstablished(session);
verify(this.stompHandler).afterSessionStarted(session, this.channel);
verify(this.mqttHandler, times(0)).afterSessionStarted(session, this.channel);
}
@Test
public void subProtocolDefaultHandlerOnly() throws Exception {
this.webSocketHandler.setDefaultProtocolHandler(stompHandler);
this.session.setAcceptedProtocol("sToMp");
this.webSocketHandler.afterConnectionEstablished(session);
verify(this.stompHandler).afterSessionStarted(session, this.channel);
}
@Test(expected=IllegalStateException.class)
public void subProtocolNoMatch() throws Exception {
this.webSocketHandler.setDefaultProtocolHandler(defaultHandler);
this.webSocketHandler.setProtocolHandlers(Arrays.asList(stompHandler, mqttHandler));
this.session.setAcceptedProtocol("wamp");
this.webSocketHandler.afterConnectionEstablished(session);
}
@Test
public void noSubProtocol() throws Exception {
this.webSocketHandler.setDefaultProtocolHandler(defaultHandler);
this.webSocketHandler.afterConnectionEstablished(session);
verify(this.defaultHandler).afterSessionStarted(session, this.channel);
verify(this.stompHandler, times(0)).afterSessionStarted(session, this.channel);
verify(this.mqttHandler, times(0)).afterSessionStarted(session, this.channel);
}
@Test
public void noSubProtocolOneHandler() throws Exception {
this.webSocketHandler.setProtocolHandlers(Arrays.asList(stompHandler));
this.webSocketHandler.afterConnectionEstablished(session);
verify(this.stompHandler).afterSessionStarted(session, this.channel);
}
@Test(expected=IllegalStateException.class)
public void noSubProtocolTwoHandlers() throws Exception {
this.webSocketHandler.setProtocolHandlers(Arrays.asList(stompHandler, mqttHandler));
this.webSocketHandler.afterConnectionEstablished(session);
}
@Test(expected=IllegalStateException.class)
public void noSubProtocolNoDefaultHandler() throws Exception {
this.webSocketHandler.setProtocolHandlers(Arrays.asList(stompHandler, mqttHandler));
this.webSocketHandler.afterConnectionEstablished(session);
}
}
......@@ -20,12 +20,12 @@ package org.springframework.web.socket;
* A handler for WebSocket messages and lifecycle events.
*
* <p>Implementations of this interface are encouraged to handle exceptions locally where
* it makes sense or alternatively let the exception bubble up in which case the exception
* is logged and the session closed with
* {@link CloseStatus#SERVER_ERROR SERVER_ERROR(1011)} by default. The exception handling
* it makes sense or alternatively let the exception bubble up in which case by default
* the exception is logged and the session closed with
* {@link CloseStatus#SERVER_ERROR SERVER_ERROR(1011)}. The exception handling
* strategy is provided by
* {@link org.springframework.web.socket.support.ExceptionWebSocketHandlerDecorator
* ExceptionWebSocketHandlerDecorator}, which can be customized or replaced by decorating
* ExceptionWebSocketHandlerDecorator} and it can be customized or replaced by decorating
* the {@link WebSocketHandler} with a different decorator.
*
* @author Rossen Stoyanchev
......@@ -61,6 +61,7 @@ public interface WebSocketHandler {
* transport error has occurred. Although the session may technically still be open,
* depending on the underlying implementation, sending messages at this point is
* discouraged and most likely will not succeed.
*
* @throws Exception this method can handle or propagate exceptions; see class-level
* Javadoc for details.
*/
......
/*
* Copyright 2002-2013 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.web.socket.support;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import org.springframework.util.Assert;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.WebSocketMessage;
import org.springframework.web.socket.WebSocketSession;
/**
* A {@link WebSocketHandler} that delegates to other {@link WebSocketHandler} instances
* based on the sub-protocol value accepted at the handshake. A default handler can also
* be configured for use by default when a sub-protocol value if the WebSocket session
* does not have a sub-protocol value associated with it.
*
* @author Rossen Stoyanchev
* @since 4.0
*/
public class MultiProtocolWebSocketHandler implements WebSocketHandler {
private WebSocketHandler defaultHandler;
private Map<String, WebSocketHandler> handlers = new HashMap<String, WebSocketHandler>();
/**
* Configure {@link WebSocketHandler}'s to use by sub-protocol. The values for
* sub-protocols are case insensitive.
*/
public void setProtocolHandlers(Map<String, WebSocketHandler> protocolHandlers) {
this.handlers.clear();
for (String protocol : protocolHandlers.keySet()) {
this.handlers.put(protocol.toLowerCase(), protocolHandlers.get(protocol));
}
}
/**
* Return a read-only copy of the sub-protocol handler map.
*/
public Map<String, WebSocketHandler> getProtocolHandlers() {
return Collections.unmodifiableMap(this.handlers);
}
/**
* Set the default {@link WebSocketHandler} to use if a sub-protocol was not
* requested.
*/
public void setDefaultProtocolHandler(WebSocketHandler defaultHandler) {
this.defaultHandler = defaultHandler;
}
/**
* Return the default {@link WebSocketHandler} to be used.
*/
public WebSocketHandler getDefaultProtocolHandler() {
return this.defaultHandler;
}
@Override
public void afterConnectionEstablished(WebSocketSession session) throws Exception {
WebSocketHandler handler = getHandlerForSession(session);
handler.afterConnectionEstablished(session);
}
private WebSocketHandler getHandlerForSession(WebSocketSession session) {
WebSocketHandler handler = null;
String protocol = session.getAcceptedProtocol();
if (protocol != null) {
handler = this.handlers.get(protocol.toLowerCase());
Assert.state(handler != null,
"No WebSocketHandler for sub-protocol '" + protocol + "', handlers=" + this.handlers);
}
else {
handler = this.defaultHandler;
Assert.state(handler != null,
"No sub-protocol was requested and no default WebSocketHandler was configured");
}
return handler;
}
@Override
public void handleMessage(WebSocketSession session, WebSocketMessage<?> message) throws Exception {
WebSocketHandler handler = getHandlerForSession(session);
handler.handleMessage(session, message);
}
@Override
public void handleTransportError(WebSocketSession session, Throwable exception) throws Exception {
WebSocketHandler handler = getHandlerForSession(session);
handler.handleTransportError(session, exception);
}
@Override
public void afterConnectionClosed(WebSocketSession session, CloseStatus closeStatus) throws Exception {
WebSocketHandler handler = getHandlerForSession(session);
handler.afterConnectionClosed(session, closeStatus);
}
@Override
public boolean supportsPartialMessages() {
return false;
}
}
/*
* Copyright 2002-2013 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIOsNS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.web.socket.support;
import java.util.HashMap;
import java.util.Map;
import org.junit.Before;
import org.junit.Test;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
import org.springframework.web.socket.WebSocketHandler;
import static org.mockito.Mockito.*;
/**
* Test fixture for {@link MultiProtocolWebSocketHandler}.
*
* @author Rossen Stoyanchev
*/
public class MultiProtocolWebSocketHandlerTests {
private MultiProtocolWebSocketHandler multiProtocolHandler;
@Mock
WebSocketHandler stompHandler;
@Mock
WebSocketHandler mqttHandler;
@Mock
WebSocketHandler defaultHandler;
@Before
public void setup() {
MockitoAnnotations.initMocks(this);
Map<String, WebSocketHandler> handlers = new HashMap<String, WebSocketHandler>();
handlers.put("STOMP", this.stompHandler);
handlers.put("MQTT", this.mqttHandler);
this.multiProtocolHandler = new MultiProtocolWebSocketHandler();
this.multiProtocolHandler.setProtocolHandlers(handlers);
this.multiProtocolHandler.setDefaultProtocolHandler(this.defaultHandler);
}
@Test
public void subProtocol() throws Exception {
TestWebSocketSession session = new TestWebSocketSession();
session.setAcceptedProtocol("sToMp");
this.multiProtocolHandler.afterConnectionEstablished(session);
verify(this.stompHandler).afterConnectionEstablished(session);
verifyZeroInteractions(this.mqttHandler);
}
@Test(expected=IllegalStateException.class)
public void subProtocolNoMatch() throws Exception {
TestWebSocketSession session = new TestWebSocketSession();
session.setAcceptedProtocol("wamp");
this.multiProtocolHandler.afterConnectionEstablished(session);
}
@Test
public void noSubProtocol() throws Exception {
TestWebSocketSession session = new TestWebSocketSession();
this.multiProtocolHandler.afterConnectionEstablished(session);
verify(this.defaultHandler).afterConnectionEstablished(session);
verifyZeroInteractions(this.stompHandler, this.mqttHandler);
}
@Test(expected=IllegalStateException.class)
public void noSubProtocolNoDefaultHandler() throws Exception {
TestWebSocketSession session = new TestWebSocketSession();
this.multiProtocolHandler.setDefaultProtocolHandler(null);
this.multiProtocolHandler.afterConnectionEstablished(session);
}
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册