提交 e21bbdd9 编写于 作者: R Rossen Stoyanchev

Polish WebSocket/STOMP Java config

Ensure configuration provided for WebSocketHandler's (eg interceptors,
or HandshakeHandler) are passed on to the SockJsService if congiured.

Better separate Servlet-specific parts of the configuration to make it
more obvious where non-Servlet alternatives could fit in.

Add more tests.

Improve WebSocket integration tests.
上级 5d697005
/*
* Copyright 2002-2013 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.messaging.simp.config;
import java.util.Set;
import org.springframework.messaging.handler.websocket.SubProtocolWebSocketHandler;
import org.springframework.scheduling.TaskScheduler;
import org.springframework.util.Assert;
import org.springframework.util.ObjectUtils;
import org.springframework.web.socket.server.DefaultHandshakeHandler;
import org.springframework.web.socket.server.HandshakeHandler;
import org.springframework.web.socket.server.config.SockJsServiceRegistration;
import org.springframework.web.socket.sockjs.SockJsService;
import org.springframework.web.socket.sockjs.transport.handler.WebSocketTransportHandler;
/**
* A helper class for configuring STOMP protocol handling over WebSocket
* with optional SockJS fallback options.
*
* @author Rossen Stoyanchev
* @since 4.0
*/
public abstract class AbstractStompEndpointRegistration<M> implements StompEndpointRegistration {
private final String[] paths;
private final SubProtocolWebSocketHandler wsHandler;
private HandshakeHandler handshakeHandler;
private StompSockJsServiceRegistration sockJsServiceRegistration;
private final TaskScheduler defaultSockJsTaskScheduler;
public AbstractStompEndpointRegistration(String[] paths, SubProtocolWebSocketHandler webSocketHandler,
TaskScheduler defaultSockJsTaskScheduler) {
Assert.notEmpty(paths, "No paths specified");
this.paths = paths;
this.wsHandler = webSocketHandler;
this.defaultSockJsTaskScheduler = defaultSockJsTaskScheduler;
}
protected SubProtocolWebSocketHandler getWsHandler() {
return this.wsHandler;
}
@Override
public StompEndpointRegistration setHandshakeHandler(HandshakeHandler handshakeHandler) {
this.handshakeHandler = handshakeHandler;
return this;
}
@Override
public SockJsServiceRegistration withSockJS() {
this.sockJsServiceRegistration = new StompSockJsServiceRegistration(this.defaultSockJsTaskScheduler);
if (this.handshakeHandler != null) {
WebSocketTransportHandler transportHandler = new WebSocketTransportHandler(this.handshakeHandler);
this.sockJsServiceRegistration.setTransportHandlerOverrides(transportHandler);
}
return this.sockJsServiceRegistration;
}
protected M getMappings() {
M mappings = createMappings();
if (this.sockJsServiceRegistration != null) {
SockJsService sockJsService = this.sockJsServiceRegistration.getSockJsService();
for (String path : this.paths) {
String pathPattern = path.endsWith("/") ? path + "**" : path + "/**";
addSockJsServiceMapping(mappings, sockJsService, this.wsHandler, pathPattern);
}
}
else {
HandshakeHandler handshakeHandler = getOrCreateHandshakeHandler();
for (String path : this.paths) {
addWebSocketHandlerMapping(mappings, this.wsHandler, handshakeHandler, path);
}
}
return mappings;
}
protected abstract M createMappings();
private HandshakeHandler getOrCreateHandshakeHandler() {
HandshakeHandler handler = (this.handshakeHandler != null)
? this.handshakeHandler : new DefaultHandshakeHandler();
if (handler instanceof DefaultHandshakeHandler) {
DefaultHandshakeHandler defaultHandshakeHandler = (DefaultHandshakeHandler) handler;
if (ObjectUtils.isEmpty(defaultHandshakeHandler.getSupportedProtocols())) {
Set<String> protocols = this.wsHandler.getSupportedProtocols();
defaultHandshakeHandler.setSupportedProtocols(protocols.toArray(new String[protocols.size()]));
}
}
return handler;
}
protected abstract void addSockJsServiceMapping(M mappings, SockJsService sockJsService,
SubProtocolWebSocketHandler wsHandler, String pathPattern);
protected abstract void addWebSocketHandlerMapping(M mappings,
SubProtocolWebSocketHandler wsHandler, HandshakeHandler handshakeHandler, String path);
private class StompSockJsServiceRegistration extends SockJsServiceRegistration {
public StompSockJsServiceRegistration(TaskScheduler defaultTaskScheduler) {
super(defaultTaskScheduler);
}
protected SockJsService getSockJsService() {
return super.getSockJsService(paths);
}
}
}
/*
* Copyright 2002-2013 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.messaging.simp.config;
import org.springframework.messaging.handler.websocket.SubProtocolWebSocketHandler;
import org.springframework.scheduling.TaskScheduler;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import org.springframework.web.HttpRequestHandler;
import org.springframework.web.socket.server.HandshakeHandler;
import org.springframework.web.socket.server.support.WebSocketHttpRequestHandler;
import org.springframework.web.socket.sockjs.SockJsHttpRequestHandler;
import org.springframework.web.socket.sockjs.SockJsService;
/**
* A helper class for configuring STOMP protocol handling over WebSocket
* with optional SockJS fallback options.
*
* @author Rossen Stoyanchev
* @since 4.0
*/
public class ServletStompEndpointRegistration
extends AbstractStompEndpointRegistration<MultiValueMap<HttpRequestHandler, String>> {
public ServletStompEndpointRegistration(String[] paths, SubProtocolWebSocketHandler wsHandler,
TaskScheduler sockJsTaskScheduler) {
super(paths, wsHandler, sockJsTaskScheduler);
}
@Override
protected MultiValueMap<HttpRequestHandler, String> createMappings() {
return new LinkedMultiValueMap<HttpRequestHandler, String>();
}
@Override
protected void addSockJsServiceMapping(MultiValueMap<HttpRequestHandler, String> mappings,
SockJsService sockJsService, SubProtocolWebSocketHandler wsHandler, String pathPattern) {
SockJsHttpRequestHandler httpHandler = new SockJsHttpRequestHandler(sockJsService, wsHandler);
mappings.add(httpHandler, pathPattern);
}
@Override
protected void addWebSocketHandlerMapping(MultiValueMap<HttpRequestHandler, String> mappings,
SubProtocolWebSocketHandler wsHandler, HandshakeHandler handshakeHandler, String path) {
WebSocketHttpRequestHandler handler = new WebSocketHttpRequestHandler(wsHandler, handshakeHandler);
mappings.add(handler, path);
}
}
/*
* Copyright 2002-2013 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.messaging.simp.config;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import org.springframework.messaging.handler.websocket.SubProtocolWebSocketHandler;
import org.springframework.messaging.simp.handler.MutableUserQueueSuffixResolver;
import org.springframework.messaging.simp.stomp.StompProtocolHandler;
import org.springframework.scheduling.TaskScheduler;
import org.springframework.util.Assert;
import org.springframework.util.MultiValueMap;
import org.springframework.web.HttpRequestHandler;
import org.springframework.web.servlet.handler.AbstractHandlerMapping;
import org.springframework.web.servlet.handler.SimpleUrlHandlerMapping;
/**
* A helper class for configuring STOMP protocol handling over WebSocket.
*
* @author Rossen Stoyanchev
* @since 4.0
*/
public class ServletStompEndpointRegistry implements StompEndpointRegistry {
private final SubProtocolWebSocketHandler wsHandler;
private final StompProtocolHandler stompHandler;
private final List<ServletStompEndpointRegistration> registrations = new ArrayList<ServletStompEndpointRegistration>();
private final TaskScheduler sockJsScheduler;
public ServletStompEndpointRegistry(SubProtocolWebSocketHandler webSocketHandler,
MutableUserQueueSuffixResolver userQueueSuffixResolver, TaskScheduler defaultSockJsTaskScheduler) {
Assert.notNull(webSocketHandler);
Assert.notNull(userQueueSuffixResolver);
this.wsHandler = webSocketHandler;
this.stompHandler = new StompProtocolHandler();
this.stompHandler.setUserQueueSuffixResolver(userQueueSuffixResolver);
this.sockJsScheduler = defaultSockJsTaskScheduler;
}
@Override
public StompEndpointRegistration addEndpoint(String... paths) {
this.wsHandler.addProtocolHandler(this.stompHandler);
ServletStompEndpointRegistration r = new ServletStompEndpointRegistration(paths, this.wsHandler, this.sockJsScheduler);
this.registrations.add(r);
return r;
}
/**
* Returns a handler mapping with the mapped ViewControllers; or {@code null} in case of no registrations.
*/
protected AbstractHandlerMapping getHandlerMapping() {
Map<String, Object> urlMap = new LinkedHashMap<String, Object>();
for (ServletStompEndpointRegistration registration : this.registrations) {
MultiValueMap<HttpRequestHandler, String> mappings = registration.getMappings();
for (HttpRequestHandler httpHandler : mappings.keySet()) {
for (String pattern : mappings.get(httpHandler)) {
urlMap.put(pattern, httpHandler);
}
}
}
SimpleUrlHandlerMapping hm = new SimpleUrlHandlerMapping();
hm.setUrlMap(urlMap);
return hm;
}
}
......@@ -16,110 +16,26 @@
package org.springframework.messaging.simp.config;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Set;
import org.springframework.messaging.handler.websocket.SubProtocolWebSocketHandler;
import org.springframework.scheduling.TaskScheduler;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import org.springframework.util.ObjectUtils;
import org.springframework.web.HttpRequestHandler;
import org.springframework.web.socket.server.DefaultHandshakeHandler;
import org.springframework.web.socket.server.HandshakeHandler;
import org.springframework.web.socket.server.config.SockJsServiceRegistration;
import org.springframework.web.socket.server.support.WebSocketHttpRequestHandler;
import org.springframework.web.socket.sockjs.SockJsHttpRequestHandler;
import org.springframework.web.socket.sockjs.SockJsService;
/**
* A helper class for configuring STOMP protocol handling over WebSocket
* with optional SockJS fallback options.
* Provides methods for configuring a STOMP protocol handler including enabling SockJS
* fallback options.
*
* @author Rossen Stoyanchev
* @since 4.0
*/
public class StompEndpointRegistration {
private final List<String> paths;
private final SubProtocolWebSocketHandler wsHandler;
private HandshakeHandler handshakeHandler;
private StompSockJsServiceRegistration sockJsServiceRegistration;
private final TaskScheduler defaultSockJsTaskScheduler;
public StompEndpointRegistration(Collection<String> paths, SubProtocolWebSocketHandler webSocketHandler,
TaskScheduler defaultSockJsTaskScheduler) {
this.paths = new ArrayList<String>(paths);
this.wsHandler = webSocketHandler;
this.defaultSockJsTaskScheduler = defaultSockJsTaskScheduler;
}
public StompEndpointRegistration setHandshakeHandler(HandshakeHandler handshakeHandler) {
this.handshakeHandler = handshakeHandler;
return this;
}
public SockJsServiceRegistration withSockJS() {
this.sockJsServiceRegistration = new StompSockJsServiceRegistration(this.defaultSockJsTaskScheduler);
return this.sockJsServiceRegistration;
}
protected MultiValueMap<HttpRequestHandler, String> getMappings() {
MultiValueMap<HttpRequestHandler, String> mappings = new LinkedMultiValueMap<HttpRequestHandler, String>();
if (this.sockJsServiceRegistration == null) {
HandshakeHandler handshakeHandler = getOrCreateHandshakeHandler();
for (String path : this.paths) {
WebSocketHttpRequestHandler handler = new WebSocketHttpRequestHandler(this.wsHandler, handshakeHandler);
mappings.add(handler, path);
}
}
else {
SockJsService sockJsService = this.sockJsServiceRegistration.getSockJsService();
for (String path : this.paths) {
SockJsHttpRequestHandler httpHandler = new SockJsHttpRequestHandler(sockJsService, this.wsHandler);
mappings.add(httpHandler, path.endsWith("/") ? path + "**" : path + "/**");
}
}
return mappings;
}
private HandshakeHandler getOrCreateHandshakeHandler() {
HandshakeHandler handler = (this.handshakeHandler != null)
? this.handshakeHandler : new DefaultHandshakeHandler();
if (handler instanceof DefaultHandshakeHandler) {
DefaultHandshakeHandler defaultHandshakeHandler = (DefaultHandshakeHandler) handler;
if (ObjectUtils.isEmpty(defaultHandshakeHandler.getSupportedProtocols())) {
Set<String> protocols = this.wsHandler.getSupportedProtocols();
defaultHandshakeHandler.setSupportedProtocols(protocols.toArray(new String[protocols.size()]));
}
}
return handler;
}
private class StompSockJsServiceRegistration extends SockJsServiceRegistration {
public interface StompEndpointRegistration {
public StompSockJsServiceRegistration(TaskScheduler defaultTaskScheduler) {
super(defaultTaskScheduler);
}
/**
* Configure the HandshakeHandler to use.
*/
StompEndpointRegistration setHandshakeHandler(HandshakeHandler handshakeHandler);
protected SockJsService getSockJsService() {
return super.getSockJsService(paths.toArray(new String[paths.size()]));
}
}
/**
* Enable SockJS fallback options.
*/
SockJsServiceRegistration withSockJS();
}
}
\ No newline at end of file
......@@ -16,90 +16,18 @@
package org.springframework.messaging.simp.config;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import org.springframework.messaging.handler.websocket.SubProtocolWebSocketHandler;
import org.springframework.messaging.simp.handler.MutableUserQueueSuffixResolver;
import org.springframework.messaging.simp.stomp.StompProtocolHandler;
import org.springframework.scheduling.TaskScheduler;
import org.springframework.util.Assert;
import org.springframework.util.MultiValueMap;
import org.springframework.web.HttpRequestHandler;
import org.springframework.web.servlet.HandlerMapping;
import org.springframework.web.servlet.handler.AbstractHandlerMapping;
import org.springframework.web.servlet.handler.SimpleUrlHandlerMapping;
/**
* A helper class for configuring STOMP protocol handling over WebSocket.
* Provides methods for configuring STOMP protocol handlers at specific URL paths.
*
* @author Rossen Stoyanchev
* @since 4.0
*/
public class StompEndpointRegistry {
private final SubProtocolWebSocketHandler wsHandler;
private final StompProtocolHandler stompHandler;
private final List<StompEndpointRegistration> registrations = new ArrayList<StompEndpointRegistration>();
private int order = 1;
private final TaskScheduler defaultSockJsTaskScheduler;
public interface StompEndpointRegistry {
public StompEndpointRegistry(SubProtocolWebSocketHandler webSocketHandler,
MutableUserQueueSuffixResolver userQueueSuffixResolver, TaskScheduler defaultSockJsTaskScheduler) {
Assert.notNull(webSocketHandler);
Assert.notNull(userQueueSuffixResolver);
this.wsHandler = webSocketHandler;
this.stompHandler = new StompProtocolHandler();
this.stompHandler.setUserQueueSuffixResolver(userQueueSuffixResolver);
this.defaultSockJsTaskScheduler = defaultSockJsTaskScheduler;
}
public StompEndpointRegistration addEndpoint(String... paths) {
this.wsHandler.addProtocolHandler(this.stompHandler);
StompEndpointRegistration r = new StompEndpointRegistration(
Arrays.asList(paths), this.wsHandler, this.defaultSockJsTaskScheduler);
this.registrations.add(r);
return r;
}
/**
* Specify the order to use for the STOMP endpoint {@link HandlerMapping} relative to
* other handler mappings configured in the Spring MVC configuration. The default
* value is 1.
*/
public void setOrder(int order) {
this.order = order;
}
/**
* Returns a handler mapping with the mapped ViewControllers; or {@code null} in case of no registrations.
* Expose a STOMP endpoint at the specified URL path (or paths_.
*/
protected AbstractHandlerMapping getHandlerMapping() {
Map<String, Object> urlMap = new LinkedHashMap<String, Object>();
for (StompEndpointRegistration registration : this.registrations) {
MultiValueMap<HttpRequestHandler, String> mappings = registration.getMappings();
for (HttpRequestHandler httpHandler : mappings.keySet()) {
for (String pattern : mappings.get(httpHandler)) {
urlMap.put(pattern, httpHandler);
}
}
}
SimpleUrlHandlerMapping hm = new SimpleUrlHandlerMapping();
hm.setOrder(this.order);
hm.setUrlMap(urlMap);
return hm;
}
StompEndpointRegistration addEndpoint(String... paths);
}
}
\ No newline at end of file
......@@ -33,6 +33,7 @@ import org.springframework.messaging.support.converter.MessageConverter;
import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;
import org.springframework.scheduling.concurrent.ThreadPoolTaskScheduler;
import org.springframework.web.servlet.HandlerMapping;
import org.springframework.web.servlet.handler.AbstractHandlerMapping;
import org.springframework.web.socket.server.config.SockJsServiceRegistration;
......@@ -55,10 +56,12 @@ public abstract class WebSocketMessageBrokerConfigurationSupport {
@Bean
public HandlerMapping brokerWebSocketHandlerMapping() {
StompEndpointRegistry registry = new StompEndpointRegistry(
ServletStompEndpointRegistry registry = new ServletStompEndpointRegistry(
subProtocolWebSocketHandler(), userQueueSuffixResolver(), brokerDefaultSockJsTaskScheduler());
registerStompEndpoints(registry);
return registry.getHandlerMapping();
AbstractHandlerMapping hm = registry.getHandlerMapping();
hm.setOrder(1);
return hm;
}
@Bean
......@@ -75,7 +78,20 @@ public abstract class WebSocketMessageBrokerConfigurationSupport {
/**
* The default TaskScheduler to use if none is configured via
* {@link SockJsServiceRegistration#setTaskScheduler()}
* {@link SockJsServiceRegistration#setTaskScheduler()}, i.e.
* <pre class="code">
* &#064;Configuration
* &#064;EnableWebSocketMessageBroker
* public class WebSocketConfig implements WebSocketMessageBrokerConfigurer {
*
* public void registerStompEndpoints(StompEndpointRegistry registry) {
* registry.addEndpoint("/stomp").withSockJS().setTaskScheduler(myScheduler());
* }
*
* // ...
*
* }
* </pre>
*/
@Bean
public ThreadPoolTaskScheduler brokerDefaultSockJsTaskScheduler() {
......
......@@ -25,6 +25,7 @@ import org.junit.runners.Parameterized.Parameter;
import org.springframework.context.Lifecycle;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.web.context.support.AnnotationConfigWebApplicationContext;
import org.springframework.web.socket.client.WebSocketClient;
import org.springframework.web.socket.server.DefaultHandshakeHandler;
import org.springframework.web.socket.server.HandshakeHandler;
......@@ -52,14 +53,26 @@ public abstract class AbstractWebSocketIntegrationTests {
@Parameter(1)
public WebSocketClient webSocketClient;
protected AnnotationConfigWebApplicationContext wac;
@Before
public void setup() throws Exception {
this.wac = new AnnotationConfigWebApplicationContext();
this.wac.register(getAnnotatedConfigClasses());
this.wac.register(upgradeStrategyConfigTypes.get(this.server.getClass()));
if (this.webSocketClient instanceof Lifecycle) {
((Lifecycle) this.webSocketClient).start();
}
this.server.init(this.wac);
this.server.start();
}
protected abstract Class<?>[] getAnnotatedConfigClasses();
@After
public void teardown() throws Exception {
try {
......@@ -76,10 +89,6 @@ public abstract class AbstractWebSocketIntegrationTests {
return "ws://localhost:" + this.server.getPort();
}
protected Class<?> getUpgradeStrategyConfigClass() {
return upgradeStrategyConfigTypes.get(this.server.getClass());
}
static abstract class AbstractRequestUpgradeStrategyConfig {
......
/*
* Copyright 2002-2013 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.messaging.simp.config;
import java.util.ArrayList;
import java.util.List;
import org.junit.Before;
import org.junit.Test;
import org.mockito.Mockito;
import org.springframework.messaging.handler.websocket.SubProtocolWebSocketHandler;
import org.springframework.messaging.support.channel.ExecutorSubscribableChannel;
import org.springframework.scheduling.TaskScheduler;
import org.springframework.web.socket.server.DefaultHandshakeHandler;
import org.springframework.web.socket.server.HandshakeHandler;
import org.springframework.web.socket.sockjs.SockJsService;
import org.springframework.web.socket.sockjs.transport.TransportType;
import org.springframework.web.socket.sockjs.transport.handler.DefaultSockJsService;
import org.springframework.web.socket.sockjs.transport.handler.WebSocketTransportHandler;
import static org.junit.Assert.*;
/**
* Test fixture for {@link AbstractStompEndpointRegistration}.
*
* @author Rossen Stoyanchev
*/
public class AbstractStompEndpointRegistrationTests {
private SubProtocolWebSocketHandler wsHandler;
private TaskScheduler scheduler;
@Before
public void setup() {
this.wsHandler = new SubProtocolWebSocketHandler(new ExecutorSubscribableChannel());
this.scheduler = Mockito.mock(TaskScheduler.class);
}
@Test
public void minimal() {
TestStompEndpointRegistration registration =
new TestStompEndpointRegistration(new String[] {"/foo"}, this.wsHandler, this.scheduler);
List<Mapping> mappings = registration.getMappings();
assertEquals(1, mappings.size());
Mapping m1 = mappings.get(0);
assertSame(this.wsHandler, m1.webSocketHandler);
assertEquals("/foo", m1.path);
}
@Test
public void handshakeHandler() {
DefaultHandshakeHandler handshakeHandler = new DefaultHandshakeHandler();
TestStompEndpointRegistration registration =
new TestStompEndpointRegistration(new String[] {"/foo"}, this.wsHandler, this.scheduler);
registration.setHandshakeHandler(handshakeHandler);
List<Mapping> mappings = registration.getMappings();
assertEquals(1, mappings.size());
Mapping m1 = mappings.get(0);
assertSame(this.wsHandler, m1.webSocketHandler);
assertEquals("/foo", m1.path);
assertSame(handshakeHandler, m1.handshakeHandler);
}
@Test
public void handshakeHandlerPassedToSockJsService() {
DefaultHandshakeHandler handshakeHandler = new DefaultHandshakeHandler();
TestStompEndpointRegistration registration =
new TestStompEndpointRegistration(new String[] {"/foo"}, this.wsHandler, this.scheduler);
registration.setHandshakeHandler(handshakeHandler);
registration.withSockJS();
List<Mapping> mappings = registration.getMappings();
assertEquals(1, mappings.size());
Mapping m1 = mappings.get(0);
assertSame(this.wsHandler, m1.webSocketHandler);
assertEquals("/foo/**", m1.path);
assertNotNull(m1.sockJsService);
WebSocketTransportHandler transportHandler =
(WebSocketTransportHandler) m1.sockJsService.getTransportHandlers().get(TransportType.WEBSOCKET);
assertSame(handshakeHandler, transportHandler.getHandshakeHandler());
}
private static class TestStompEndpointRegistration extends AbstractStompEndpointRegistration<List<Mapping>> {
public TestStompEndpointRegistration(String[] paths, SubProtocolWebSocketHandler wsh, TaskScheduler scheduler) {
super(paths, wsh, scheduler);
}
@Override
protected List<Mapping> createMappings() {
return new ArrayList<>();
}
@Override
protected void addSockJsServiceMapping(List<Mapping> mappings, SockJsService sockJsService,
SubProtocolWebSocketHandler wsHandler, String pathPattern) {
mappings.add(new Mapping(wsHandler, pathPattern, sockJsService));
}
@Override
protected void addWebSocketHandlerMapping(List<Mapping> mappings, SubProtocolWebSocketHandler wsHandler,
HandshakeHandler handshakeHandler, String path) {
mappings.add(new Mapping(wsHandler, path, handshakeHandler));
}
}
private static class Mapping {
private final SubProtocolWebSocketHandler webSocketHandler;
private final String path;
private final HandshakeHandler handshakeHandler;
private final DefaultSockJsService sockJsService;
public Mapping(SubProtocolWebSocketHandler handler, String path, SockJsService sockJsService) {
this.webSocketHandler = handler;
this.path = path;
this.handshakeHandler = null;
this.sockJsService = (DefaultSockJsService) sockJsService;
}
public Mapping(SubProtocolWebSocketHandler h, String path, HandshakeHandler hh) {
this.webSocketHandler = h;
this.path = path;
this.handshakeHandler = hh;
this.sockJsService = null;
}
}
}
......@@ -35,7 +35,6 @@ import org.springframework.messaging.simp.stomp.StompCommand;
import org.springframework.messaging.simp.stomp.StompTextMessageBuilder;
import org.springframework.messaging.support.channel.ExecutorSubscribableChannel;
import org.springframework.stereotype.Controller;
import org.springframework.web.context.support.AnnotationConfigWebApplicationContext;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.WebSocketSession;
......@@ -63,16 +62,14 @@ public class WebSocketMessageBrokerConfigurationTests extends AbstractWebSocketI
};
@Override
protected Class<?>[] getAnnotatedConfigClasses() {
return new Class<?>[] { TestWebSocketMessageBrokerConfiguration.class, SimpleBrokerConfigurer.class };
}
@Test
public void sendMessage() throws Exception {
AnnotationConfigWebApplicationContext cxt = new AnnotationConfigWebApplicationContext();
cxt.register(TestWebSocketMessageBrokerConfiguration.class, SimpleBrokerConfigurer.class);
cxt.register(getUpgradeStrategyConfigClass());
this.server.init(cxt);
this.server.start();
final TextMessage textMessage = StompTextMessageBuilder.create(StompCommand.SEND)
.headers("destination:/app/foo").build();
......@@ -83,7 +80,7 @@ public class WebSocketMessageBrokerConfigurationTests extends AbstractWebSocketI
}
};
TestController testController = cxt.getBean(TestController.class);
TestController testController = this.wac.getBean(TestController.class);
this.webSocketClient.doHandshake(clientHandler, getWsBaseUrl() + "/ws");
assertTrue(testController.latch.await(2, TimeUnit.SECONDS));
......
/*
* Copyright 2002-2013 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.web.socket.server.config;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.springframework.scheduling.TaskScheduler;
import org.springframework.util.Assert;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.server.DefaultHandshakeHandler;
import org.springframework.web.socket.server.HandshakeHandler;
import org.springframework.web.socket.server.HandshakeInterceptor;
import org.springframework.web.socket.sockjs.SockJsService;
import org.springframework.web.socket.sockjs.transport.handler.WebSocketTransportHandler;
/**
* Base class for {@link WebSocketHandlerRegistration}s that gathers all the configuration
* options but allows sub-classes to put together the actual HTTP request mappings.
*
* @author Rossen Stoyanchev
* @since 4.0
*/
public abstract class AbstractWebSocketHandlerRegistration<M> implements WebSocketHandlerRegistration {
private MultiValueMap<WebSocketHandler, String> handlerMap = new LinkedMultiValueMap<WebSocketHandler, String>();
private HandshakeInterceptor[] interceptors;
private HandshakeHandler handshakeHandler;
private SockJsServiceRegistration sockJsServiceRegistration;
private final TaskScheduler sockJsTaskScheduler;
public AbstractWebSocketHandlerRegistration(TaskScheduler defaultTaskScheduler) {
this.sockJsTaskScheduler = defaultTaskScheduler;
}
@Override
public WebSocketHandlerRegistration addHandler(WebSocketHandler handler, String... paths) {
Assert.notNull(handler);
Assert.notEmpty(paths);
this.handlerMap.put(handler, Arrays.asList(paths));
return this;
}
@Override
public WebSocketHandlerRegistration setHandshakeHandler(HandshakeHandler handshakeHandler) {
this.handshakeHandler = handshakeHandler;
return this;
}
public HandshakeHandler getHandshakeHandler() {
return handshakeHandler;
}
@Override
public WebSocketHandlerRegistration addInterceptors(HandshakeInterceptor... interceptors) {
this.interceptors = interceptors;
return this;
}
protected HandshakeInterceptor[] getInterceptors() {
return this.interceptors;
}
/**
* @param interceptors the interceptors to set
*/
public void setInterceptors(HandshakeInterceptor[] interceptors) {
this.interceptors = interceptors;
}
@Override
public SockJsServiceRegistration withSockJS() {
this.sockJsServiceRegistration = new SockJsServiceRegistration(this.sockJsTaskScheduler);
this.sockJsServiceRegistration.setInterceptors(this.interceptors);
if (this.handshakeHandler != null) {
WebSocketTransportHandler transportHandler = new WebSocketTransportHandler(this.handshakeHandler);
this.sockJsServiceRegistration.setTransportHandlerOverrides(transportHandler);
}
return this.sockJsServiceRegistration;
}
final M getMappings() {
M mappings = createMappings();
if (this.sockJsServiceRegistration != null) {
SockJsService sockJsService = this.sockJsServiceRegistration.getSockJsService(getAllPrefixes());
for (WebSocketHandler wsHandler : this.handlerMap.keySet()) {
for (String path : this.handlerMap.get(wsHandler)) {
String pathPattern = path.endsWith("/") ? path + "**" : path + "/**";
addSockJsServiceMapping(mappings, sockJsService, wsHandler, pathPattern);
}
}
}
else {
HandshakeHandler handshakeHandler = getOrCreateHandshakeHandler();
for (WebSocketHandler wsHandler : this.handlerMap.keySet()) {
for (String path : this.handlerMap.get(wsHandler)) {
addWebSocketHandlerMapping(mappings, wsHandler, handshakeHandler, this.interceptors, path);
}
}
}
return mappings;
}
private final String[] getAllPrefixes() {
List<String> all = new ArrayList<String>();
for (List<String> prefixes: this.handlerMap.values()) {
all.addAll(prefixes);
}
return all.toArray(new String[all.size()]);
}
private HandshakeHandler getOrCreateHandshakeHandler() {
return (this.handshakeHandler != null) ? this.handshakeHandler : new DefaultHandshakeHandler();
}
protected abstract M createMappings();
protected abstract void addSockJsServiceMapping(M mappings, SockJsService sockJsService,
WebSocketHandler handler, String pathPattern);
protected abstract void addWebSocketHandlerMapping(M mappings, WebSocketHandler wsHandler,
HandshakeHandler handshakeHandler, HandshakeInterceptor[] interceptors, String path);
}
/*
* Copyright 2002-2013 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.web.socket.server.config;
import java.util.Arrays;
import org.springframework.scheduling.TaskScheduler;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import org.springframework.util.ObjectUtils;
import org.springframework.web.HttpRequestHandler;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.server.HandshakeHandler;
import org.springframework.web.socket.server.HandshakeInterceptor;
import org.springframework.web.socket.server.support.WebSocketHttpRequestHandler;
import org.springframework.web.socket.sockjs.SockJsHttpRequestHandler;
import org.springframework.web.socket.sockjs.SockJsService;
/**
* A helper class for configuring {@link WebSocketHandler} request handling
* including SockJS fallback options.
*
* @author Rossen Stoyanchev
* @since 4.0
*/
public class ServletWebSocketHandlerRegistration
extends AbstractWebSocketHandlerRegistration<MultiValueMap<HttpRequestHandler, String>> {
public ServletWebSocketHandlerRegistration(TaskScheduler sockJsTaskScheduler) {
super(sockJsTaskScheduler);
}
@Override
protected MultiValueMap<HttpRequestHandler, String> createMappings() {
return new LinkedMultiValueMap<HttpRequestHandler, String>();
}
@Override
protected void addSockJsServiceMapping(MultiValueMap<HttpRequestHandler, String> mappings,
SockJsService sockJsService, WebSocketHandler handler, String pathPattern) {
SockJsHttpRequestHandler httpHandler = new SockJsHttpRequestHandler(sockJsService, handler);
mappings.add(httpHandler, pathPattern);
}
@Override
protected void addWebSocketHandlerMapping(MultiValueMap<HttpRequestHandler, String> mappings,
WebSocketHandler wsHandler, HandshakeHandler handshakeHandler,
HandshakeInterceptor[] interceptors, String path) {
WebSocketHttpRequestHandler httpHandler = new WebSocketHttpRequestHandler(wsHandler, handshakeHandler);
if (!ObjectUtils.isEmpty(interceptors)) {
httpHandler.setHandshakeInterceptors(Arrays.asList(interceptors));
}
mappings.add(httpHandler, path);
}
}
/*
* Copyright 2002-2013 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.web.socket.server.config;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import org.springframework.scheduling.TaskScheduler;
import org.springframework.scheduling.concurrent.ThreadPoolTaskScheduler;
import org.springframework.util.MultiValueMap;
import org.springframework.web.HttpRequestHandler;
import org.springframework.web.servlet.HandlerMapping;
import org.springframework.web.servlet.handler.AbstractHandlerMapping;
import org.springframework.web.servlet.handler.SimpleUrlHandlerMapping;
import org.springframework.web.socket.WebSocketHandler;
/**
* A {@link WebSocketHandlerRegistry} that maps {@link WebSocketHandler}s to URLs for use
* in a Servlet container.
*
* @author Rossen Stoyanchev
* @since 4.0
*/
public class ServletWebSocketHandlerRegistry implements WebSocketHandlerRegistry {
private final List<ServletWebSocketHandlerRegistration> registrations =
new ArrayList<ServletWebSocketHandlerRegistration>();
private TaskScheduler sockJsTaskScheduler;
public ServletWebSocketHandlerRegistry(ThreadPoolTaskScheduler sockJsTaskScheduler) {
this.sockJsTaskScheduler = sockJsTaskScheduler;
}
@Override
public WebSocketHandlerRegistration addHandler(WebSocketHandler webSocketHandler, String... paths) {
ServletWebSocketHandlerRegistration r = new ServletWebSocketHandlerRegistration(this.sockJsTaskScheduler);
r.addHandler(webSocketHandler, paths);
this.registrations.add(r);
return r;
}
/**
* Returns a {@link HandlerMapping} with mapped {@link HttpRequestHandler}s.
*/
AbstractHandlerMapping getHandlerMapping() {
Map<String, Object> urlMap = new LinkedHashMap<String, Object>();
for (ServletWebSocketHandlerRegistration registration : this.registrations) {
MultiValueMap<HttpRequestHandler, String> mappings = registration.getMappings();
for (HttpRequestHandler httpHandler : mappings.keySet()) {
for (String pattern : mappings.get(httpHandler)) {
urlMap.put(pattern, httpHandler);
}
}
}
SimpleUrlHandlerMapping hm = new SimpleUrlHandlerMapping();
hm.setUrlMap(urlMap);
return hm;
}
}
......@@ -19,6 +19,7 @@ package org.springframework.web.socket.server.config;
import org.springframework.context.annotation.Bean;
import org.springframework.scheduling.concurrent.ThreadPoolTaskScheduler;
import org.springframework.web.servlet.HandlerMapping;
import org.springframework.web.servlet.handler.AbstractHandlerMapping;
/**
......@@ -32,9 +33,11 @@ public class WebSocketConfigurationSupport {
@Bean
public HandlerMapping webSocketHandlerMapping() {
WebSocketHandlerRegistry registry = new WebSocketHandlerRegistry(defaultSockJsTaskScheduler());
ServletWebSocketHandlerRegistry registry = new ServletWebSocketHandlerRegistry(defaultSockJsTaskScheduler());
registerWebSocketHandlers(registry);
return registry.getHandlerMapping();
AbstractHandlerMapping hm = registry.getHandlerMapping();
hm.setOrder(1);
return hm;
}
protected void registerWebSocketHandlers(WebSocketHandlerRegistry registry) {
......@@ -42,7 +45,20 @@ public class WebSocketConfigurationSupport {
/**
* The default TaskScheduler to use if none is configured via
* {@link SockJsServiceRegistration#setTaskScheduler()}
* {@link SockJsServiceRegistration#setTaskScheduler()}, i.e.
* <pre class="code">
* &#064;Configuration
* &#064;EnableWebSocket
* public class WebSocketConfig implements WebSocketConfigurer {
*
* public void registerWebSocketHandlers(WebSocketHandlerRegistry registry) {
* registry.addHandler(myWsHandler(), "/echo").withSockJS().setTaskScheduler(myScheduler());
* }
*
* // ...
*
* }
* </pre>
*/
@Bean
public ThreadPoolTaskScheduler defaultSockJsTaskScheduler() {
......
......@@ -16,110 +16,37 @@
package org.springframework.web.socket.server.config;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.springframework.scheduling.TaskScheduler;
import org.springframework.util.Assert;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import org.springframework.web.HttpRequestHandler;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.server.DefaultHandshakeHandler;
import org.springframework.web.socket.server.HandshakeHandler;
import org.springframework.web.socket.server.HandshakeInterceptor;
import org.springframework.web.socket.server.support.WebSocketHttpRequestHandler;
import org.springframework.web.socket.sockjs.SockJsHttpRequestHandler;
import org.springframework.web.socket.sockjs.SockJsService;
/**
* A helper class for configuring {@link WebSocketHandler} request handling
* including SockJS fallback options.
* Provides methods for configuring a WebSocket handler.
*
* @author Rossen Stoyanchev
* @since 4.0
*/
public class WebSocketHandlerRegistration {
private MultiValueMap<WebSocketHandler, String> handlerMap =
new LinkedMultiValueMap<WebSocketHandler, String>();
private final List<HandshakeInterceptor> interceptors = new ArrayList<HandshakeInterceptor>();
private HandshakeHandler handshakeHandler;
private SockJsServiceRegistration sockJsServiceRegistration;
private final TaskScheduler defaultTaskScheduler;
public WebSocketHandlerRegistration(TaskScheduler defaultTaskScheduler) {
this.defaultTaskScheduler = defaultTaskScheduler;
}
public WebSocketHandlerRegistration addHandler(WebSocketHandler handler, String... paths) {
Assert.notNull(handler);
Assert.notEmpty(paths);
this.handlerMap.put(handler, Arrays.asList(paths));
return this;
}
public WebSocketHandlerRegistration setHandshakeHandler(HandshakeHandler handshakeHandler) {
this.handshakeHandler = handshakeHandler;
return this;
}
public HandshakeHandler getHandshakeHandler() {
return handshakeHandler;
}
public void addInterceptors(HandshakeInterceptor... interceptors) {
this.interceptors.addAll(Arrays.asList(interceptors));
}
public SockJsServiceRegistration withSockJS() {
this.sockJsServiceRegistration = new SockJsServiceRegistration(this.defaultTaskScheduler);
this.sockJsServiceRegistration.setInterceptors(
this.interceptors.toArray(new HandshakeInterceptor[this.interceptors.size()]));
return this.sockJsServiceRegistration;
}
MultiValueMap<HttpRequestHandler, String> getMappings() {
MultiValueMap<HttpRequestHandler, String> mappings = new LinkedMultiValueMap<HttpRequestHandler, String>();
if (this.sockJsServiceRegistration == null) {
HandshakeHandler handshakeHandler = getOrCreateHandshakeHandler();
for (WebSocketHandler handler : this.handlerMap.keySet()) {
for (String path : this.handlerMap.get(handler)) {
WebSocketHttpRequestHandler httpHandler = new WebSocketHttpRequestHandler(handler, handshakeHandler);
httpHandler.setHandshakeInterceptors(this.interceptors);
mappings.add(httpHandler, path);
}
}
}
else {
SockJsService sockJsService = this.sockJsServiceRegistration.getSockJsService(getAllPrefixes());
for (WebSocketHandler handler : this.handlerMap.keySet()) {
for (String path : this.handlerMap.get(handler)) {
SockJsHttpRequestHandler httpHandler = new SockJsHttpRequestHandler(sockJsService, handler);
mappings.add(httpHandler, path.endsWith("/") ? path + "**" : path + "/**");
}
}
}
return mappings;
}
private HandshakeHandler getOrCreateHandshakeHandler() {
return (this.handshakeHandler != null) ? this.handshakeHandler : new DefaultHandshakeHandler();
}
private final String[] getAllPrefixes() {
List<String> all = new ArrayList<String>();
for (List<String> prefixes: this.handlerMap.values()) {
all.addAll(prefixes);
}
return all.toArray(new String[all.size()]);
}
}
public interface WebSocketHandlerRegistration {
/**
* Add more handlers that will share the same configuration (interceptors, SockJS
* config, etc)
*/
WebSocketHandlerRegistration addHandler(WebSocketHandler handler, String... paths);
/**
* Configure interceptors for the handshake request.
*/
WebSocketHandlerRegistration addInterceptors(HandshakeInterceptor... interceptors);
/**
* Configure the HandshakeHandler to use.
*/
WebSocketHandlerRegistration setHandshakeHandler(HandshakeHandler handshakeHandler);
/**
* Enable SockJS fallback options.
*/
SockJsServiceRegistration withSockJS();
}
\ No newline at end of file
......@@ -16,76 +16,19 @@
package org.springframework.web.socket.server.config;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import org.springframework.scheduling.TaskScheduler;
import org.springframework.scheduling.concurrent.ThreadPoolTaskScheduler;
import org.springframework.util.MultiValueMap;
import org.springframework.web.HttpRequestHandler;
import org.springframework.web.servlet.HandlerMapping;
import org.springframework.web.servlet.handler.AbstractHandlerMapping;
import org.springframework.web.servlet.handler.SimpleUrlHandlerMapping;
import org.springframework.web.socket.WebSocketHandler;
/**
* A helper class for configuring {@link WebSocketHandler} request handling.
* Provides methods for configuring {@link WebSocketHandler} request mappings.
*
* @author Rossen Stoyanchev
* @since 4.0
*/
public class WebSocketHandlerRegistry {
private final List<WebSocketHandlerRegistration> registrations = new ArrayList<WebSocketHandlerRegistration>();
private int order = 1;
private TaskScheduler defaultTaskScheduler;
public WebSocketHandlerRegistry(ThreadPoolTaskScheduler defaultSockJsTaskScheduler) {
this.defaultTaskScheduler = defaultSockJsTaskScheduler;
}
public WebSocketHandlerRegistration addHandler(WebSocketHandler wsHandler, String... paths) {
WebSocketHandlerRegistration r = new WebSocketHandlerRegistration(this.defaultTaskScheduler);
r.addHandler(wsHandler, paths);
this.registrations.add(r);
return r;
}
protected List<WebSocketHandlerRegistration> getRegistrations() {
return this.registrations;
}
/**
* Specify the order to use for WebSocket {@link HandlerMapping} relative to other
* handler mappings configured in the Spring MVC configuration. The default value is 1.
*/
public void setOrder(int order) {
this.order = order;
}
public interface WebSocketHandlerRegistry {
/**
* Returns a handler mapping with the mapped ViewControllers; or {@code null} in case of no registrations.
* Configure a WebSocketHandler at the specified URL paths.
*/
AbstractHandlerMapping getHandlerMapping() {
Map<String, Object> urlMap = new LinkedHashMap<String, Object>();
for (WebSocketHandlerRegistration registration : this.registrations) {
MultiValueMap<HttpRequestHandler, String> mappings = registration.getMappings();
for (HttpRequestHandler httpHandler : mappings.keySet()) {
for (String pattern : mappings.get(httpHandler)) {
urlMap.put(pattern, httpHandler);
}
}
}
SimpleUrlHandlerMapping hm = new SimpleUrlHandlerMapping();
hm.setOrder(this.order);
hm.setUrlMap(urlMap);
return hm;
}
WebSocketHandlerRegistration addHandler(WebSocketHandler webSocketHandler, String... paths);
}
}
\ No newline at end of file
......@@ -61,6 +61,10 @@ public class WebSocketTransportHandler extends TransportHandlerSupport
return TransportType.WEBSOCKET;
}
public HandshakeHandler getHandshakeHandler() {
return this.handshakeHandler;
}
@Override
public AbstractSockJsSession createSession(String sessionId, WebSocketHandler wsHandler,
Map<String, Object> attributes) {
......
......@@ -25,6 +25,7 @@ import org.junit.runners.Parameterized.Parameter;
import org.springframework.context.Lifecycle;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.web.context.support.AnnotationConfigWebApplicationContext;
import org.springframework.web.socket.client.WebSocketClient;
import org.springframework.web.socket.server.DefaultHandshakeHandler;
import org.springframework.web.socket.server.HandshakeHandler;
......@@ -52,14 +53,26 @@ public abstract class AbstractWebSocketIntegrationTests {
@Parameter(1)
public WebSocketClient webSocketClient;
protected AnnotationConfigWebApplicationContext wac;
@Before
public void setup() throws Exception {
this.wac = new AnnotationConfigWebApplicationContext();
this.wac.register(getAnnotatedConfigClasses());
this.wac.register(upgradeStrategyConfigTypes.get(this.server.getClass()));
if (this.webSocketClient instanceof Lifecycle) {
((Lifecycle) this.webSocketClient).start();
}
this.server.init(this.wac);
this.server.start();
}
protected abstract Class<?>[] getAnnotatedConfigClasses();
@After
public void teardown() throws Exception {
try {
......@@ -76,10 +89,6 @@ public abstract class AbstractWebSocketIntegrationTests {
return "ws://localhost:" + this.server.getPort();
}
protected Class<?> getUpgradeStrategyConfigClass() {
return upgradeStrategyConfigTypes.get(this.server.getClass());
}
static abstract class AbstractRequestUpgradeStrategyConfig {
......
/*
* Copyright 2002-2013 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.web.socket.server.config;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.junit.Before;
import org.junit.Test;
import org.mockito.Mockito;
import org.springframework.scheduling.TaskScheduler;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.adapter.TextWebSocketHandlerAdapter;
import org.springframework.web.socket.server.DefaultHandshakeHandler;
import org.springframework.web.socket.server.HandshakeHandler;
import org.springframework.web.socket.server.HandshakeInterceptor;
import org.springframework.web.socket.server.support.HttpSessionHandshakeInterceptor;
import org.springframework.web.socket.sockjs.SockJsService;
import org.springframework.web.socket.sockjs.transport.TransportType;
import org.springframework.web.socket.sockjs.transport.handler.DefaultSockJsService;
import org.springframework.web.socket.sockjs.transport.handler.WebSocketTransportHandler;
import static org.junit.Assert.*;
/**
* Test fixture for {@link AbstractWebSocketHandlerRegistration}.
*
* @author Rossen Stoyanchev
*/
public class AbstractWebSocketHandlerRegistrationTests {
private TestWebSocketHandlerRegistration registration;
private TaskScheduler taskScheduler;
@Before
public void setup() {
this.taskScheduler = Mockito.mock(TaskScheduler.class);
this.registration = new TestWebSocketHandlerRegistration(taskScheduler);
}
@Test
public void minimal() {
WebSocketHandler wsHandler = new TextWebSocketHandlerAdapter();
this.registration.addHandler(wsHandler, "/foo", "/bar");
List<Mapping> mappings = this.registration.getMappings();
assertEquals(2, mappings.size());
Mapping m1 = mappings.get(0);
assertEquals(wsHandler, m1.webSocketHandler);
assertEquals("/foo", m1.path);
Mapping m2 = mappings.get(1);
assertEquals(wsHandler, m2.webSocketHandler);
assertEquals("/bar", m2.path);
}
@Test
public void interceptors() {
WebSocketHandler wsHandler = new TextWebSocketHandlerAdapter();
HttpSessionHandshakeInterceptor interceptor = new HttpSessionHandshakeInterceptor();
this.registration.addHandler(wsHandler, "/foo").addInterceptors(interceptor);
List<Mapping> mappings = this.registration.getMappings();
assertEquals(1, mappings.size());
Mapping m1 = mappings.get(0);
assertEquals(wsHandler, m1.webSocketHandler);
assertEquals("/foo", m1.path);
assertArrayEquals(new HandshakeInterceptor[] { interceptor }, m1.interceptors);
}
@Test
public void interceptorsPassedToSockJsRegistration() {
WebSocketHandler wsHandler = new TextWebSocketHandlerAdapter();
HttpSessionHandshakeInterceptor interceptor = new HttpSessionHandshakeInterceptor();
this.registration.addHandler(wsHandler, "/foo").addInterceptors(interceptor).withSockJS();
List<Mapping> mappings = this.registration.getMappings();
assertEquals(1, mappings.size());
Mapping m1 = mappings.get(0);
assertEquals(wsHandler, m1.webSocketHandler);
assertEquals("/foo/**", m1.path);
assertNotNull(m1.sockJsService);
assertEquals(Arrays.asList(interceptor), m1.sockJsService.getHandshakeInterceptors());
}
@Test
public void handshakeHandler() {
WebSocketHandler wsHandler = new TextWebSocketHandlerAdapter();
HandshakeHandler handshakeHandler = new DefaultHandshakeHandler();
this.registration.addHandler(wsHandler, "/foo").setHandshakeHandler(handshakeHandler);
List<Mapping> mappings = this.registration.getMappings();
assertEquals(1, mappings.size());
Mapping m1 = mappings.get(0);
assertEquals(wsHandler, m1.webSocketHandler);
assertEquals("/foo", m1.path);
assertSame(handshakeHandler, m1.handshakeHandler);
}
@Test
public void handshakeHandlerPassedToSockJsRegistration() {
WebSocketHandler wsHandler = new TextWebSocketHandlerAdapter();
HandshakeHandler handshakeHandler = new DefaultHandshakeHandler();
this.registration.addHandler(wsHandler, "/foo").setHandshakeHandler(handshakeHandler).withSockJS();
List<Mapping> mappings = this.registration.getMappings();
assertEquals(1, mappings.size());
Mapping m1 = mappings.get(0);
assertEquals(wsHandler, m1.webSocketHandler);
assertEquals("/foo/**", m1.path);
assertNotNull(m1.sockJsService);
WebSocketTransportHandler transportHandler =
(WebSocketTransportHandler) m1.sockJsService.getTransportHandlers().get(TransportType.WEBSOCKET);
assertSame(handshakeHandler, transportHandler.getHandshakeHandler());
}
private static class TestWebSocketHandlerRegistration extends AbstractWebSocketHandlerRegistration<List<Mapping>> {
public TestWebSocketHandlerRegistration(TaskScheduler sockJsTaskScheduler) {
super(sockJsTaskScheduler);
}
@Override
protected List<Mapping> createMappings() {
return new ArrayList<>();
}
@Override
protected void addSockJsServiceMapping(List<Mapping> mappings, SockJsService sockJsService,
WebSocketHandler wsHandler, String pathPattern) {
mappings.add(new Mapping(wsHandler, pathPattern, sockJsService));
}
@Override
protected void addWebSocketHandlerMapping(List<Mapping> mappings,
WebSocketHandler wsHandler, HandshakeHandler handshakeHandler,
HandshakeInterceptor[] interceptors, String path) {
mappings.add(new Mapping(wsHandler, path, handshakeHandler, interceptors));
}
}
private static class Mapping {
private final WebSocketHandler webSocketHandler;
private final String path;
private final HandshakeHandler handshakeHandler;
private final HandshakeInterceptor[] interceptors;
private final DefaultSockJsService sockJsService;
public Mapping(WebSocketHandler handler, String path, SockJsService sockJsService) {
this.webSocketHandler = handler;
this.path = path;
this.handshakeHandler = null;
this.interceptors = null;
this.sockJsService = (DefaultSockJsService) sockJsService;
}
public Mapping(WebSocketHandler h, String path, HandshakeHandler hh, HandshakeInterceptor[] interceptors) {
this.webSocketHandler = h;
this.path = path;
this.handshakeHandler = hh;
this.interceptors = interceptors;
this.sockJsService = null;
}
}
}
......@@ -27,7 +27,6 @@ import org.junit.runners.Parameterized.Parameters;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.web.context.support.AnnotationConfigWebApplicationContext;
import org.springframework.web.socket.AbstractWebSocketIntegrationTests;
import org.springframework.web.socket.JettyTestServer;
import org.springframework.web.socket.WebSocketSession;
......@@ -54,33 +53,26 @@ public class WebSocketConfigurationTests extends AbstractWebSocketIntegrationTes
};
@Override
protected Class<?>[] getAnnotatedConfigClasses() {
return new Class<?>[] { TestWebSocketConfigurer.class };
}
@Test
public void registerWebSocketHandler() throws Exception {
AnnotationConfigWebApplicationContext cxt = new AnnotationConfigWebApplicationContext();
cxt.register(TestWebSocketConfigurer.class, getUpgradeStrategyConfigClass());
this.server.init(cxt);
this.server.start();
this.webSocketClient.doHandshake(new WebSocketHandlerAdapter(), getWsBaseUrl() + "/ws");
TestWebSocketHandler serverHandler = cxt.getBean(TestWebSocketHandler.class);
TestWebSocketHandler serverHandler = this.wac.getBean(TestWebSocketHandler.class);
assertTrue(serverHandler.latch.await(2, TimeUnit.SECONDS));
}
@Test
public void registerWebSocketHandlerWithSockJS() throws Exception {
AnnotationConfigWebApplicationContext cxt = new AnnotationConfigWebApplicationContext();
cxt.register(TestWebSocketConfigurer.class, getUpgradeStrategyConfigClass());
this.server.init(cxt);
this.server.start();
this.webSocketClient.doHandshake(new WebSocketHandlerAdapter(), getWsBaseUrl() + "/sockjs/websocket");
TestWebSocketHandler serverHandler = cxt.getBean(TestWebSocketHandler.class);
TestWebSocketHandler serverHandler = this.wac.getBean(TestWebSocketHandler.class);
assertTrue(serverHandler.latch.await(2, TimeUnit.SECONDS));
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册