提交 cdd7d7bd 编写于 作者: R Rossen Stoyanchev

Add javax.websocket.Endpoint configuration support

上级 4e67f809
......@@ -515,13 +515,24 @@ project("spring-websocket") {
compile(project(":spring-core"))
compile(project(":spring-context"))
compile(project(":spring-web"))
optional("javax.websocket:javax.websocket-api:1.0-b14")
optional("org.apache.tomcat:tomcat-servlet-api:8.0-SNAPSHOT") // TODO: replace with "javax.servlet:javax.servlet-api"
optional("org.apache.tomcat:tomcat-websocket-api:8.0-SNAPSHOT") // TODO: replace with "javax.websocket:javax.websocket-api"
optional("org.apache.tomcat:tomcat-websocket:8.0-SNAPSHOT") {
exclude group: "org.apache.tomcat", module: "tomcat-websocket-api"
exclude group: "org.apache.tomcat", module: "tomcat-servlet-api"
}
optional("org.eclipse.jetty:jetty-websocket:8.1.10.v20130312")
optional("org.glassfish.tyrus:tyrus-websocket-core:1.0-SNAPSHOT")
}
repositories {
maven { url "http://repo.springsource.org/libs-release" }
maven { url "https://repository.apache.org" } // tomcat-websocket snapshot
maven { url "https://maven.java.net/content/groups/public/" } // javax.websocket-*
maven { url "https://repository.apache.org/content/repositories/snapshots" } // tomcat-websocket snapshots
maven { url "https://maven.java.net/content/repositories/snapshots" } // tyrus/glassfish snapshots
}
}
......
......@@ -17,14 +17,10 @@
package org.springframework.http;
import java.io.Serializable;
import java.net.URI;
import java.nio.charset.Charset;
import java.text.ParseException;
import java.text.SimpleDateFormat;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
......@@ -40,6 +36,7 @@ import java.util.Set;
import java.util.TimeZone;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.util.LinkedCaseInsensitiveMap;
import org.springframework.util.MultiValueMap;
import org.springframework.util.StringUtils;
......@@ -71,6 +68,8 @@ public class HttpHeaders implements MultiValueMap<String, String>, Serializable
private static final String CACHE_CONTROL = "Cache-Control";
private static final String CONNECTION = "Connection";
private static final String CONTENT_DISPOSITION = "Content-Disposition";
private static final String CONTENT_LENGTH = "Content-Length";
......@@ -91,8 +90,22 @@ public class HttpHeaders implements MultiValueMap<String, String>, Serializable
private static final String LOCATION = "Location";
private static final String ORIGIN = "Origin";
private static final String SEC_WEBSOCKET_ACCEPT = "Sec-WebSocket-Accept";
private static final String SEC_WEBSOCKET_EXTENSIONS = "Sec-WebSocket-Extensions";
private static final String SEC_WEBSOCKET_KEY = "Sec-WebSocket-Key";
private static final String SEC_WEBSOCKET_PROTOCOL = "Sec-WebSocket-Protocol";
private static final String SEC_WEBSOCKET_VERSION = "Sec-WebSocket-Version";
private static final String PRAGMA = "Pragma";
private static final String UPGARDE = "Upgrade";
private static final String[] DATE_FORMATS = new String[] {
"EEE, dd MMM yyyy HH:mm:ss zzz",
......@@ -251,6 +264,30 @@ public class HttpHeaders implements MultiValueMap<String, String>, Serializable
return getFirst(CACHE_CONTROL);
}
/**
* Sets the (new) value of the {@code Connection} header.
* @param connection the value of the header
*/
public void setConnection(String connection) {
set(CONNECTION, connection);
}
/**
* Sets the (new) value of the {@code Connection} header.
* @param connection the value of the header
*/
public void setConnection(List<String> connection) {
set(CONNECTION, toCommaDelimitedString(connection));
}
/**
* Returns the value of the {@code Connection} header.
* @return the value of the header
*/
public List<String> getConnection() {
return getFirstValueAsList(CONNECTION);
}
/**
* Sets the (new) value of the {@code Content-Disposition} header for {@code form-data}.
* @param name the control name
......@@ -393,15 +430,19 @@ public class HttpHeaders implements MultiValueMap<String, String>, Serializable
* @param ifNoneMatchList the new value of the header
*/
public void setIfNoneMatch(List<String> ifNoneMatchList) {
set(IF_NONE_MATCH, toCommaDelimitedString(ifNoneMatchList));
}
private String toCommaDelimitedString(List<String> list) {
StringBuilder builder = new StringBuilder();
for (Iterator<String> iterator = ifNoneMatchList.iterator(); iterator.hasNext();) {
for (Iterator<String> iterator = list.iterator(); iterator.hasNext();) {
String ifNoneMatch = iterator.next();
builder.append(ifNoneMatch);
if (iterator.hasNext()) {
builder.append(", ");
}
}
set(IF_NONE_MATCH, builder.toString());
return builder.toString();
}
/**
......@@ -409,9 +450,13 @@ public class HttpHeaders implements MultiValueMap<String, String>, Serializable
* @return the header value
*/
public List<String> getIfNoneMatch() {
return getFirstValueAsList(IF_NONE_MATCH);
}
private List<String> getFirstValueAsList(String header) {
List<String> result = new ArrayList<String>();
String value = getFirst(IF_NONE_MATCH);
String value = getFirst(header);
if (value != null) {
String[] tokens = value.split(",\\s*");
for (String token : tokens) {
......@@ -457,6 +502,130 @@ public class HttpHeaders implements MultiValueMap<String, String>, Serializable
return (value != null ? URI.create(value) : null);
}
/**
* Sets the (new) value of the {@code Origin} header.
* @param origin the value of the header
*/
public void setOrigin(String origin) {
set(ORIGIN, origin);
}
/**
* Returns the value of the {@code Origin} header.
* @return the value of the header
*/
public String getOrigin() {
return getFirst(ORIGIN);
}
/**
* Sets the (new) value of the {@code Sec-WebSocket-Accept} header.
* @param secWebSocketAccept the value of the header
*/
public void setSecWebSocketAccept(String secWebSocketAccept) {
set(SEC_WEBSOCKET_ACCEPT, secWebSocketAccept);
}
/**
* Returns the value of the {@code Sec-WebSocket-Accept} header.
* @return the value of the header
*/
public String getSecWebSocketAccept() {
return getFirst(SEC_WEBSOCKET_ACCEPT);
}
/**
* Returns the value of the {@code Sec-WebSocket-Extensions} header.
* @return the value of the header
*/
public List<String> getSecWebSocketExtensions() {
List<String> values = get(SEC_WEBSOCKET_EXTENSIONS);
if (CollectionUtils.isEmpty(values)) {
return Collections.emptyList();
}
else if (values.size() == 1) {
return getFirstValueAsList(SEC_WEBSOCKET_EXTENSIONS);
}
else {
return values;
}
}
/**
* Sets the (new) value of the {@code Sec-WebSocket-Extensions} header.
* @param secWebSocketExtensions the value of the header
*/
public void setSecWebSocketExtensions(List<String> secWebSocketExtensions) {
set(SEC_WEBSOCKET_EXTENSIONS, toCommaDelimitedString(secWebSocketExtensions));
}
/**
* Sets the (new) value of the {@code Sec-WebSocket-Key} header.
* @param secWebSocketKey the value of the header
*/
public void setSecWebSocketKey(String secWebSocketKey) {
set(SEC_WEBSOCKET_KEY, secWebSocketKey);
}
/**
* Returns the value of the {@code Sec-WebSocket-Key} header.
* @return the value of the header
*/
public String getSecWebSocketKey() {
return getFirst(SEC_WEBSOCKET_KEY);
}
/**
* Sets the (new) value of the {@code Sec-WebSocket-Protocol} header.
* @param secWebSocketProtocol the value of the header
*/
public void setSecWebSocketProtocol(String secWebSocketProtocol) {
if (secWebSocketProtocol != null) {
set(SEC_WEBSOCKET_PROTOCOL, secWebSocketProtocol);
}
}
/**
* Sets the (new) value of the {@code Sec-WebSocket-Protocol} header.
* @param secWebSocketProtocols the value of the header
*/
public void setSecWebSocketProtocol(List<String> secWebSocketProtocols) {
set(SEC_WEBSOCKET_PROTOCOL, toCommaDelimitedString(secWebSocketProtocols));
}
/**
* Returns the value of the {@code Sec-WebSocket-Key} header.
* @return the value of the header
*/
public List<String> getSecWebSocketProtocol() {
List<String> values = get(SEC_WEBSOCKET_PROTOCOL);
if (CollectionUtils.isEmpty(values)) {
return Collections.emptyList();
}
else if (values.size() == 1) {
return getFirstValueAsList(SEC_WEBSOCKET_PROTOCOL);
}
else {
return values;
}
}
/**
* Sets the (new) value of the {@code Sec-WebSocket-Version} header.
* @param secWebSocketKey the value of the header
*/
public void setSecWebSocketVersion(String secWebSocketVersion) {
set(SEC_WEBSOCKET_VERSION, secWebSocketVersion);
}
/**
* Returns the value of the {@code Sec-WebSocket-Version} header.
* @return the value of the header
*/
public String getSecWebSocketVersion() {
return getFirst(SEC_WEBSOCKET_VERSION);
}
/**
* Sets the (new) value of the {@code Pragma} header.
* @param pragma the value of the header
......@@ -473,6 +642,22 @@ public class HttpHeaders implements MultiValueMap<String, String>, Serializable
return getFirst(PRAGMA);
}
/**
* Sets the (new) value of the {@code Upgrade} header.
* @param upgrade the value of the header
*/
public void setUpgrade(String upgrade) {
set(UPGARDE, upgrade);
}
/**
* Returns the value of the {@code Upgrade} header.
* @return the value of the header
*/
public String getUpgrade() {
return getFirst(UPGARDE);
}
// Utility methods
private long getFirstDate(String headerName) {
......
......@@ -16,17 +16,16 @@
package org.springframework.websocket;
import org.springframework.http.server.ServerHttpRequest;
import org.springframework.http.server.ServerHttpResponse;
/**
*
* @author Rossen Stoyanchev
*/
public interface HandshakeRequestHandler {
public interface Session {
void sendText(String text) throws Exception;
boolean doHandshake(ServerHttpRequest request, ServerHttpResponse response);
void close(int code, String reason) throws Exception;
}
/*
* Copyright 2002-2013 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.websocket;
import java.io.InputStream;
/**
*
* @author Rossen Stoyanchev
*/
public interface WebSocketHandler {
void newSession(Session session) throws Exception;
void handleTextMessage(Session session, String message) throws Exception;
void handleBinaryMessage(Session session, InputStream message) throws Exception;
void handleException(Session session, Throwable exception);
void sessionClosed(Session session, int statusCode, String reason) throws Exception;
}
/*
* Copyright 2002-2013 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.websocket;
import java.io.InputStream;
/**
*
* @author Rossen Stoyanchev
*/
public class WebSocketHandlerAdapter implements WebSocketHandler {
@Override
public void newSession(Session session) throws Exception {
}
@Override
public void handleTextMessage(Session session, String message) throws Exception {
}
@Override
public void handleBinaryMessage(Session session, InputStream message) throws Exception {
}
@Override
public void handleException(Session session, Throwable exception) {
}
@Override
public void sessionClosed(Session session, int statusCode, String reason) throws Exception {
}
}
/*
* Copyright 2002-2013 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.websocket.support;
import javax.servlet.ServletContext;
import javax.websocket.DeploymentException;
import javax.websocket.server.ServerContainer;
import javax.websocket.server.ServerContainerProvider;
import javax.websocket.server.ServerEndpointConfig;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.tomcat.websocket.server.WsServerContainer;
import org.springframework.beans.BeansException;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.beans.factory.config.BeanPostProcessor;
import org.springframework.util.Assert;
import org.springframework.web.context.ServletContextAware;
/**
* BeanPostProcessor that registers {@link javax.websocket.server.ServerEndpointConfig}
* beans with a standard Java WebSocket runtime and also configures the underlying
* {@link javax.websocket.server.ServerContainer}.
*
* @author Rossen Stoyanchev
* @since 4.0
*/
public class ServerEndpointPostProcessor implements ServletContextAware, BeanPostProcessor, InitializingBean {
private static Log logger = LogFactory.getLog(ServerEndpointPostProcessor.class);
private Long maxSessionIdleTimeout;
private Integer maxTextMessageBufferSize;
private Integer maxBinaryMessageBufferSize;
private ServletContext servletContext;
/**
* If this property set it is in turn used to configure
* {@link ServerContainer#setDefaultMaxSessionIdleTimeout(long)}.
*/
public void setMaxSessionIdleTimeout(long maxSessionIdleTimeout) {
this.maxSessionIdleTimeout = maxSessionIdleTimeout;
}
public Long getMaxSessionIdleTimeout() {
return this.maxSessionIdleTimeout;
}
/**
* If this property set it is in turn used to configure
* {@link ServerContainer#setDefaultMaxTextMessageBufferSize(int)}
*/
public void setMaxTextMessageBufferSize(int maxTextMessageBufferSize) {
this.maxTextMessageBufferSize = maxTextMessageBufferSize;
}
public Integer getMaxTextMessageBufferSize() {
return this.maxTextMessageBufferSize;
}
/**
* If this property set it is in turn used to configure
* {@link ServerContainer#setDefaultMaxBinaryMessageBufferSize(int)}.
*/
public void setMaxBinaryMessageBufferSize(int maxBinaryMessageBufferSize) {
this.maxBinaryMessageBufferSize = maxBinaryMessageBufferSize;
}
public Integer getMaxBinaryMessageBufferSize() {
return this.maxBinaryMessageBufferSize;
}
@Override
public void setServletContext(ServletContext servletContext) {
this.servletContext = servletContext;
}
public ServletContext getServletContext() {
return servletContext;
}
@Override
public void afterPropertiesSet() throws Exception {
ServerContainer serverContainer = ServerContainerProvider.getServerContainer();
Assert.notNull(serverContainer, "javax.websocket.server.ServerContainer not available");
if (this.maxSessionIdleTimeout != null) {
serverContainer.setDefaultMaxSessionIdleTimeout(this.maxSessionIdleTimeout);
}
if (this.maxTextMessageBufferSize != null) {
serverContainer.setDefaultMaxTextMessageBufferSize(this.maxTextMessageBufferSize);
}
if (this.maxBinaryMessageBufferSize != null) {
serverContainer.setDefaultMaxBinaryMessageBufferSize(this.maxBinaryMessageBufferSize);
}
// TODO: this is necessary but only done on Tomcat
WsServerContainer sc = WsServerContainer.getServerContainer();
sc.setServletContext(this.servletContext);
}
@Override
public Object postProcessAfterInitialization(Object bean, String beanName) throws BeansException {
if (bean instanceof ServerEndpointConfig) {
ServerEndpointConfig sec = (ServerEndpointConfig) bean;
ServerContainer serverContainer = ServerContainerProvider.getServerContainer();
try {
logger.debug("Registering javax.websocket.Endpoint for path " + sec.getPath());
serverContainer.addEndpoint(sec);
}
catch (DeploymentException e) {
throw new IllegalStateException("Failed to deploy Endpoint " + bean, e);
}
}
return bean;
}
@Override
public Object postProcessBeforeInitialization(Object bean, String beanName) throws BeansException {
return bean;
}
}
/*
* Copyright 2002-2013 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.websocket.support;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import javax.websocket.Decoder;
import javax.websocket.Encoder;
import javax.websocket.Endpoint;
import javax.websocket.Extension;
import javax.websocket.HandshakeResponse;
import javax.websocket.server.HandshakeRequest;
import javax.websocket.server.ServerEndpointConfig;
import org.springframework.beans.BeansException;
import org.springframework.beans.factory.BeanFactory;
import org.springframework.beans.factory.BeanFactoryAware;
import org.springframework.util.Assert;
import org.springframework.util.ClassUtils;
import org.springframework.websocket.WebSocketHandler;
/**
* An implementation of {@link javax.websocket.server.ServerEndpointConfig} that also
* holds the target {@link javax.websocket.Endpoint} as a reference or a bean name.
* The target can also be {@link org.springframework.websocket.WebSocketHandler}, in
* which case it will be adapted via {@link StandardWebSocketHandlerAdapter}.
*
* <p>
* Beans of this type are detected by {@link ServerEndpointPostProcessor} and
* registered with a Java WebSocket runtime at startup.
*
* @author Rossen Stoyanchev
*/
public class ServerEndpointRegistration implements ServerEndpointConfig, BeanFactoryAware {
private final String path;
private final Object bean;
private List<String> subprotocols = new ArrayList<String>();
private List<Extension> extensions = new ArrayList<Extension>();
private Map<String, Object> userProperties = new HashMap<String, Object>();
private BeanFactory beanFactory;
private final Configurator configurator = new Configurator() {};
public ServerEndpointRegistration(String path, String beanName) {
Assert.hasText(path, "path must not be empty");
Assert.notNull(beanName, "beanName is required");
this.path = path;
this.bean = beanName;
}
public ServerEndpointRegistration(String path, Object bean) {
Assert.hasText(path, "path must not be empty");
Assert.notNull(bean, "bean is required");
this.path = path;
this.bean = bean;
}
@Override
public String getPath() {
return this.path;
}
@SuppressWarnings("unchecked")
@Override
public Class<? extends Endpoint> getEndpointClass() {
Class<?> beanClass = this.bean.getClass();
if (beanClass.equals(String.class)) {
beanClass = this.beanFactory.getType((String) this.bean);
}
beanClass = ClassUtils.getUserClass(beanClass);
if (WebSocketHandler.class.isAssignableFrom(beanClass)) {
return StandardWebSocketHandlerAdapter.class;
}
else {
return (Class<? extends Endpoint>) beanClass;
}
}
protected Endpoint getEndpoint() {
Object bean = this.bean;
if (this.bean instanceof String) {
bean = this.beanFactory.getBean((String) this.bean);
}
if (bean instanceof WebSocketHandler) {
return new StandardWebSocketHandlerAdapter((WebSocketHandler) bean);
}
else {
return (Endpoint) bean;
}
}
@Override
public List<String> getSubprotocols() {
return this.subprotocols;
}
public void setSubprotocols(List<String> subprotocols) {
this.subprotocols = subprotocols;
}
@Override
public List<Extension> getExtensions() {
return this.extensions;
}
public void setExtensions(List<Extension> extensions) {
// TODO: verify against ServerContainer.getInstalledExtensions()
this.extensions = extensions;
}
@Override
public Map<String, Object> getUserProperties() {
return this.userProperties;
}
public void setUserProperties(Map<String, Object> userProperties) {
this.userProperties = userProperties;
}
@Override
public List<Class<? extends Encoder>> getEncoders() {
return Collections.emptyList();
}
@Override
public List<Class<? extends Decoder>> getDecoders() {
return Collections.emptyList();
}
@Override
public void setBeanFactory(BeanFactory beanFactory) throws BeansException {
this.beanFactory = beanFactory;
}
@Override
public Configurator getConfigurator() {
return new Configurator() {
@SuppressWarnings("unchecked")
@Override
public <T> T getEndpointInstance(Class<T> clazz) throws InstantiationException {
return (T) ServerEndpointRegistration.this.getEndpoint();
}
@Override
public void modifyHandshake(ServerEndpointConfig sec, HandshakeRequest request, HandshakeResponse response) {
ServerEndpointRegistration.this.modifyHandshake(request, response);
}
@Override
public boolean checkOrigin(String originHeaderValue) {
return ServerEndpointRegistration.this.checkOrigin(originHeaderValue);
}
@Override
public String getNegotiatedSubprotocol(List<String> supported, List<String> requested) {
return ServerEndpointRegistration.this.selectSubProtocol(requested);
}
@Override
public List<Extension> getNegotiatedExtensions(List<Extension> installed, List<Extension> requested) {
return ServerEndpointRegistration.this.selectExtensions(requested);
}
};
}
protected void modifyHandshake(HandshakeRequest request, HandshakeResponse response) {
this.configurator.modifyHandshake(this, request, response);
}
protected boolean checkOrigin(String originHeaderValue) {
return this.configurator.checkOrigin(originHeaderValue);
}
protected String selectSubProtocol(List<String> requested) {
return this.configurator.getNegotiatedSubprotocol(getSubprotocols(), requested);
}
protected List<Extension> selectExtensions(List<Extension> requested) {
return this.configurator.getNegotiatedExtensions(getExtensions(), requested);
}
}
/*
* Copyright 2002-2013 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.websocket.support;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.websocket.Session;
/**
*
* @author Rossen Stoyanchev
*/
public class StandardSessionAdapter implements Session {
private static Log logger = LogFactory.getLog(StandardSessionAdapter.class);
private javax.websocket.Session sourceSession;
public StandardSessionAdapter(javax.websocket.Session sourceSession) {
this.sourceSession = sourceSession;
}
@Override
public void sendText(String text) throws Exception {
logger.trace("Sending text message: " + text);
this.sourceSession.getBasicRemote().sendText(text);
}
@Override
public void close(int code, String reason) throws Exception {
this.sourceSession = null;
}
}
/*
* Copyright 2002-2013 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.websocket.support;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import javax.websocket.CloseReason;
import javax.websocket.Endpoint;
import javax.websocket.EndpointConfig;
import javax.websocket.MessageHandler;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.util.Assert;
import org.springframework.websocket.Session;
import org.springframework.websocket.WebSocketHandler;
/**
*
* @author Rossen Stoyanchev
*/
public class StandardWebSocketHandlerAdapter extends Endpoint {
private static Log logger = LogFactory.getLog(StandardWebSocketHandlerAdapter.class);
private final WebSocketHandler webSocketHandler;
private final Map<String, Session> sessionMap = new ConcurrentHashMap<String, Session>();
public StandardWebSocketHandlerAdapter(WebSocketHandler webSocketHandler) {
this.webSocketHandler = webSocketHandler;
}
@Override
public void onOpen(javax.websocket.Session sourceSession, EndpointConfig config) {
logger.debug("New WebSocket session: " + sourceSession);
try {
Session session = new StandardSessionAdapter(sourceSession);
this.sessionMap.put(sourceSession.getId(), session);
sourceSession.addMessageHandler(new StandardMessageHandler(sourceSession.getId()));
this.webSocketHandler.newSession(session);
}
catch (Throwable ex) {
// TODO
logger.error("Error while processing new session", ex);
}
}
@Override
public void onClose(javax.websocket.Session sourceSession, CloseReason closeReason) {
String id = sourceSession.getId();
if (logger.isDebugEnabled()) {
logger.debug("Closing session: " + sourceSession + ", " + closeReason);
}
try {
Session session = getSession(id);
this.sessionMap.remove(id);
int code = closeReason.getCloseCode().getCode();
String reason = closeReason.getReasonPhrase();
session.close(code, reason);
this.webSocketHandler.sessionClosed(session, code, reason);
}
catch (Throwable ex) {
// TODO
logger.error("Error while processing session closing", ex);
}
}
@Override
public void onError(javax.websocket.Session sourceSession, Throwable exception) {
logger.error("Error for WebSocket session: " + sourceSession.getId(), exception);
try {
Session session = getSession(sourceSession.getId());
this.webSocketHandler.handleException(session, exception);
}
catch (Throwable ex) {
// TODO
logger.error("Failed to handle error", ex);
}
}
private Session getSession(String sourceSessionId) {
Session session = this.sessionMap.get(sourceSessionId);
Assert.notNull(session, "No session");
return session;
}
private class StandardMessageHandler implements MessageHandler.Whole<String> {
private final String sourceSessionId;
public StandardMessageHandler(String sourceSessionId) {
this.sourceSessionId = sourceSessionId;
}
@Override
public void onMessage(String message) {
if (logger.isTraceEnabled()) {
logger.trace("Message for session [" + this.sourceSessionId + "]: " + message);
}
try {
Session session = getSession(this.sourceSessionId);
StandardWebSocketHandlerAdapter.this.webSocketHandler.handleTextMessage(session, message);
}
catch (Throwable ex) {
// TODO
logger.error("Error while processing message", ex);
}
}
}
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册