/*
* Copyright 2002-2014 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.web.socket.messaging;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeMap;
import java.util.concurrent.ConcurrentHashMap;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.context.SmartLifecycle;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageChannel;
import org.springframework.messaging.MessageHandler;
import org.springframework.messaging.MessagingException;
import org.springframework.messaging.SubscribableChannel;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.util.StringUtils;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.SubProtocolCapable;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.WebSocketMessage;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.handler.ConcurrentWebSocketSessionDecorator;
import org.springframework.web.socket.handler.SessionLimitExceededException;
/**
* An implementation of {@link WebSocketHandler} that delegates incoming WebSocket
* messages to a {@link SubProtocolHandler} along with a {@link MessageChannel} to
* which the sub-protocol handler can send messages from WebSocket clients to
* the application.
*
* Also an implementation of {@link MessageHandler} that finds the WebSocket
* session associated with the {@link Message} and passes it, along with the message,
* to the sub-protocol handler to send messages from the application back to the
* client.
*
* @author Rossen Stoyanchev
* @author Andy Wilkinson
* @since 4.0
*/
public class SubProtocolWebSocketHandler
implements WebSocketHandler, SubProtocolCapable, MessageHandler, SmartLifecycle {
private final Log logger = LogFactory.getLog(SubProtocolWebSocketHandler.class);
private final MessageChannel clientInboundChannel;
private final SubscribableChannel clientOutboundChannel;
private final Map protocolHandlers =
new TreeMap(String.CASE_INSENSITIVE_ORDER);
private SubProtocolHandler defaultProtocolHandler;
private final Map sessions = new ConcurrentHashMap();
private int sendTimeLimit = 10 * 1000;
private int sendBufferSizeLimit = 64 * 1024;
private Object lifecycleMonitor = new Object();
private volatile boolean running = false;
public SubProtocolWebSocketHandler(MessageChannel clientInboundChannel, SubscribableChannel clientOutboundChannel) {
Assert.notNull(clientInboundChannel, "ClientInboundChannel must not be null");
Assert.notNull(clientOutboundChannel, "ClientOutboundChannel must not be null");
this.clientInboundChannel = clientInboundChannel;
this.clientOutboundChannel = clientOutboundChannel;
}
/**
* Configure one or more handlers to use depending on the sub-protocol requested by
* the client in the WebSocket handshake request.
* @param protocolHandlers the sub-protocol handlers to use
*/
public void setProtocolHandlers(List protocolHandlers) {
this.protocolHandlers.clear();
for (SubProtocolHandler handler: protocolHandlers) {
addProtocolHandler(handler);
}
}
public List getProtocolHandlers() {
return new ArrayList(protocolHandlers.values());
}
/**
* Register a sub-protocol handler.
*/
public void addProtocolHandler(SubProtocolHandler handler) {
List protocols = handler.getSupportedProtocols();
if (CollectionUtils.isEmpty(protocols)) {
logger.warn("No sub-protocols, ignoring handler " + handler);
return;
}
for (String protocol: protocols) {
SubProtocolHandler replaced = this.protocolHandlers.put(protocol, handler);
if ((replaced != null) && (replaced != handler) ) {
throw new IllegalStateException("Failed to map handler " + handler
+ " to protocol '" + protocol + "', it is already mapped to handler " + replaced);
}
}
}
/**
* Return the sub-protocols keyed by protocol name.
*/
public Map getProtocolHandlerMap() {
return this.protocolHandlers;
}
/**
* Set the {@link SubProtocolHandler} to use when the client did not request a
* sub-protocol.
* @param defaultProtocolHandler the default handler
*/
public void setDefaultProtocolHandler(SubProtocolHandler defaultProtocolHandler) {
this.defaultProtocolHandler = defaultProtocolHandler;
if (this.protocolHandlers.isEmpty()) {
setProtocolHandlers(Arrays.asList(defaultProtocolHandler));
}
}
/**
* @return the default sub-protocol handler to use
*/
public SubProtocolHandler getDefaultProtocolHandler() {
return this.defaultProtocolHandler;
}
/**
* Return all supported protocols.
*/
public List getSubProtocols() {
return new ArrayList(this.protocolHandlers.keySet());
}
public void setSendTimeLimit(int sendTimeLimit) {
this.sendTimeLimit = sendTimeLimit;
}
public int getSendTimeLimit() {
return this.sendTimeLimit;
}
public void setSendBufferSizeLimit(int sendBufferSizeLimit) {
this.sendBufferSizeLimit = sendBufferSizeLimit;
}
public int getSendBufferSizeLimit() {
return sendBufferSizeLimit;
}
@Override
public boolean isAutoStartup() {
return true;
}
@Override
public int getPhase() {
return Integer.MAX_VALUE;
}
@Override
public final boolean isRunning() {
synchronized (this.lifecycleMonitor) {
return this.running;
}
}
@Override
public final void start() {
synchronized (this.lifecycleMonitor) {
this.clientOutboundChannel.subscribe(this);
this.running = true;
}
}
@Override
public final void stop() {
synchronized (this.lifecycleMonitor) {
this.running = false;
this.clientOutboundChannel.unsubscribe(this);
}
}
@Override
public final void stop(Runnable callback) {
synchronized (this.lifecycleMonitor) {
stop();
callback.run();
}
}
@Override
public void afterConnectionEstablished(WebSocketSession session) throws Exception {
session = new ConcurrentWebSocketSessionDecorator(session, getSendTimeLimit(), getSendBufferSizeLimit());
this.sessions.put(session.getId(), session);
if (logger.isDebugEnabled()) {
logger.debug("Started WebSocket session=" + session.getId() +
", number of sessions=" + this.sessions.size());
}
findProtocolHandler(session).afterSessionStarted(session, this.clientInboundChannel);
}
protected final SubProtocolHandler findProtocolHandler(WebSocketSession session) {
String protocol = null;
try {
protocol = session.getAcceptedProtocol();
}
catch (Exception ex) {
logger.warn("Ignoring protocol in WebSocket session after failure to obtain it: " + ex.toString());
}
SubProtocolHandler handler;
if (!StringUtils.isEmpty(protocol)) {
handler = this.protocolHandlers.get(protocol);
Assert.state(handler != null,
"No handler for sub-protocol '" + protocol + "', handlers=" + this.protocolHandlers);
}
else {
if (this.defaultProtocolHandler != null) {
handler = this.defaultProtocolHandler;
}
else {
Set handlers = new HashSet(this.protocolHandlers.values());
if (handlers.size() == 1) {
handler = handlers.iterator().next();
}
else {
throw new IllegalStateException(
"No sub-protocol was requested and a default sub-protocol handler was not configured");
}
}
}
return handler;
}
@Override
public void handleMessage(WebSocketSession session, WebSocketMessage> message) throws Exception {
findProtocolHandler(session).handleMessageFromClient(session, message, this.clientInboundChannel);
}
@Override
public void handleMessage(Message> message) throws MessagingException {
String sessionId = resolveSessionId(message);
if (sessionId == null) {
logger.error("sessionId not found in message " + message);
return;
}
WebSocketSession session = this.sessions.get(sessionId);
if (session == null) {
logger.error("Session not found for session with id " + sessionId);
return;
}
try {
findProtocolHandler(session).handleMessageToClient(session, message);
}
catch (SessionLimitExceededException ex) {
try {
logger.error("Terminating session id '" + sessionId + "'", ex);
// Session may be unresponsive so clear first
clearSession(session, ex.getStatus());
session.close(ex.getStatus());
}
catch (Exception secondException) {
logger.error("Exception terminating session id '" + sessionId + "'", secondException);
}
}
catch (Exception e) {
logger.error("Failed to send message to client " + message, e);
}
}
private String resolveSessionId(Message> message) {
for (SubProtocolHandler handler : this.protocolHandlers.values()) {
String sessionId = handler.resolveSessionId(message);
if (sessionId != null) {
return sessionId;
}
}
if (this.defaultProtocolHandler != null) {
String sessionId = this.defaultProtocolHandler.resolveSessionId(message);
if (sessionId != null) {
return sessionId;
}
}
return null;
}
@Override
public void handleTransportError(WebSocketSession session, Throwable exception) throws Exception {
}
@Override
public void afterConnectionClosed(WebSocketSession session, CloseStatus closeStatus) throws Exception {
clearSession(session, closeStatus);
}
private void clearSession(WebSocketSession session, CloseStatus closeStatus) throws Exception {
this.sessions.remove(session.getId());
findProtocolHandler(session).afterSessionEnded(session, closeStatus, this.clientInboundChannel);
}
@Override
public boolean supportsPartialMessages() {
return false;
}
}