提交 39ff1e2c 编写于 作者: R Rossen Stoyanchev

Add StompProtocolHandler tests

上级 364bc357
......@@ -30,6 +30,7 @@ import org.springframework.messaging.MessageChannel;
import org.springframework.messaging.handler.websocket.SubProtocolHandler;
import org.springframework.messaging.simp.SimpMessageType;
import org.springframework.messaging.simp.handler.MutableUserQueueSuffixResolver;
import org.springframework.messaging.simp.handler.SimpleUserQueueSuffixResolver;
import org.springframework.messaging.support.MessageBuilder;
import org.springframework.util.Assert;
import org.springframework.web.socket.CloseStatus;
......@@ -64,7 +65,7 @@ public class StompProtocolHandler implements SubProtocolHandler {
private final StompMessageConverter stompMessageConverter = new StompMessageConverter();
private MutableUserQueueSuffixResolver queueSuffixResolver;
private MutableUserQueueSuffixResolver queueSuffixResolver = new SimpleUserQueueSuffixResolver();
/**
......
......@@ -27,15 +27,12 @@ import org.junit.runners.Parameterized.Parameters;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.messaging.Message;
import org.springframework.messaging.SubscribableChannel;
import org.springframework.messaging.handler.annotation.MessageMapping;
import org.springframework.messaging.simp.AbstractWebSocketIntegrationTests;
import org.springframework.messaging.simp.JettyTestServer;
import org.springframework.messaging.simp.stomp.StompCommand;
import org.springframework.messaging.simp.stomp.StompHeaderAccessor;
import org.springframework.messaging.simp.stomp.StompMessageConverter;
import org.springframework.messaging.support.MessageBuilder;
import org.springframework.messaging.simp.stomp.StompTextMessageBuilder;
import org.springframework.messaging.support.channel.ExecutorSubscribableChannel;
import org.springframework.stereotype.Controller;
import org.springframework.web.context.support.AnnotationConfigWebApplicationContext;
......@@ -76,16 +73,13 @@ public class WebSocketMessageBrokerConfigurationTests extends AbstractWebSocketI
this.server.init(cxt);
this.server.start();
StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.SEND);
headers.setDestination("/app/foo");
Message<byte[]> message = MessageBuilder.withPayloadAndHeaders(new byte[0], headers).build();
byte[] bytes = new StompMessageConverter().fromMessage(message);
final TextMessage webSocketMessage = new TextMessage(new String(bytes));
final TextMessage textMessage = StompTextMessageBuilder.create(StompCommand.SEND)
.headers("destination:/app/foo").build();
WebSocketHandler clientHandler = new TextWebSocketHandlerAdapter() {
@Override
public void afterConnectionEstablished(WebSocketSession session) throws Exception {
session.sendMessage(webSocketMessage);
session.sendMessage(textMessage);
}
};
......
......@@ -24,6 +24,7 @@ import org.springframework.messaging.Message;
import org.springframework.messaging.MessageHeaders;
import org.springframework.messaging.simp.SimpMessageHeaderAccessor;
import org.springframework.messaging.simp.SimpMessageType;
import org.springframework.web.socket.TextMessage;
import static org.junit.Assert.*;
......@@ -41,14 +42,17 @@ public class StompMessageConverterTests {
this.converter = new StompMessageConverter();
}
@SuppressWarnings("unchecked")
@Test
public void connectFrame() throws Exception {
String accept = "accept-version:1.1\n";
String host = "host:github.org\n";
String frame = "\n\n\nCONNECT\n" + accept + host + "\n";
Message<byte[]> message = (Message<byte[]>) this.converter.toMessage(frame.getBytes("UTF-8"));
String accept = "accept-version:1.1";
String host = "host:github.org";
TextMessage textMessage = StompTextMessageBuilder.create(StompCommand.CONNECT)
.headers(accept, host).build();
@SuppressWarnings("unchecked")
Message<byte[]> message = (Message<byte[]>) this.converter.toMessage(textMessage.getPayload());
assertEquals(0, message.getPayload().length);
......@@ -80,11 +84,14 @@ public class StompMessageConverterTests {
@Test
public void connectWithEscapes() throws Exception {
String accept = "accept-version:1.1\n";
String host = "ho\\c\\ns\\rt:st\\nomp.gi\\cthu\\b.org\n";
String frame = "CONNECT\n" + accept + host + "\n";
String accept = "accept-version:1.1";
String host = "ho\\c\\ns\\rt:st\\nomp.gi\\cthu\\b.org";
TextMessage textMessage = StompTextMessageBuilder.create(StompCommand.CONNECT)
.headers(accept, host).build();
@SuppressWarnings("unchecked")
Message<byte[]> message = (Message<byte[]>) this.converter.toMessage(frame.getBytes("UTF-8"));
Message<byte[]> message = (Message<byte[]>) this.converter.toMessage(textMessage.getPayload());
assertEquals(0, message.getPayload().length);
......
/*
* Copyright 2002-2013 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.messaging.simp.stomp;
import java.util.Arrays;
import java.util.HashSet;
import org.junit.Before;
import org.junit.Test;
import org.mockito.ArgumentCaptor;
import org.mockito.Mockito;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageChannel;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.support.TestPrincipal;
import org.springframework.web.socket.support.TestWebSocketSession;
import static org.junit.Assert.*;
import static org.mockito.Mockito.*;
/**
* Test fixture for {@link StompProtocolHandler} tests.
*
* @author Rossen Stoyanchev
*/
public class StompProtocolHandlerTests {
private StompProtocolHandler stompHandler;
private TestWebSocketSession session;
private MessageChannel channel;
private ArgumentCaptor<Message> messageCaptor;
@Before
public void setup() {
this.stompHandler = new StompProtocolHandler();
this.channel = Mockito.mock(MessageChannel.class);
this.messageCaptor = ArgumentCaptor.forClass(Message.class);
this.session = new TestWebSocketSession();
this.session.setId("s1");
this.session.setPrincipal(new TestPrincipal("joe"));
}
@Test
public void handleConnect() {
TextMessage textMessage = StompTextMessageBuilder.create(StompCommand.CONNECT).headers(
"login:guest", "passcode:guest", "accept-version:1.1,1.0", "heart-beat:10000,10000").build();
this.stompHandler.handleMessageFromClient(this.session, textMessage, this.channel);
verify(this.channel).send(this.messageCaptor.capture());
Message<?> actual = this.messageCaptor.getValue();
assertNotNull(actual);
StompHeaderAccessor headers = StompHeaderAccessor.wrap(actual);
assertEquals(StompCommand.CONNECT, headers.getCommand());
assertEquals("s1", headers.getSessionId());
assertEquals("joe", headers.getUser().getName());
assertEquals("guest", headers.getLogin());
assertEquals("PROTECTED", headers.getPasscode());
assertArrayEquals(new long[] {10000, 10000}, headers.getHeartbeat());
assertEquals(new HashSet<>(Arrays.asList("1.1","1.0")), headers.getAcceptVersion());
// Check CONNECTED reply
assertEquals(1, this.session.getSentMessages().size());
textMessage = (TextMessage) this.session.getSentMessages().get(0);
Message<?> message = new StompMessageConverter().toMessage(textMessage.getPayload());
StompHeaderAccessor replyHeaders = StompHeaderAccessor.wrap(message);
assertEquals(StompCommand.CONNECTED, replyHeaders.getCommand());
assertEquals("1.1", replyHeaders.getVersion());
assertArrayEquals(new long[] {0, 0}, replyHeaders.getHeartbeat());
assertEquals("joe", replyHeaders.getNativeHeader("user-name").get(0));
assertEquals("s1", replyHeaders.getNativeHeader("queue-suffix").get(0));
}
}
/*
* Copyright 2002-2013 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.messaging.simp.stomp;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.springframework.web.socket.TextMessage;
/**
* A builder for creating WebSocket messages with STOMP frame content.
*
* @author Rossen Stoyanchev
*/
public class StompTextMessageBuilder {
private StompCommand command;
private final List<String> headerLines = new ArrayList<String>();
private String body;
private StompTextMessageBuilder(StompCommand command) {
this.command = command;
}
public static StompTextMessageBuilder create(StompCommand command) {
return new StompTextMessageBuilder(command);
}
public StompTextMessageBuilder headers(String... headerLines) {
this.headerLines.addAll(Arrays.asList(headerLines));
return this;
}
public StompTextMessageBuilder body(String body) {
this.body = body;
return this;
}
public TextMessage build() {
StringBuilder sb = new StringBuilder(this.command.name()).append("\n");
for (String line : this.headerLines) {
sb.append(line).append("\n");
}
sb.append("\n");
if (this.body != null) {
sb.append(this.body);
}
sb.append("\u0000");
return new TextMessage(sb.toString());
}
}
......@@ -17,26 +17,26 @@
package org.springframework.web.socket.server.config;
import java.util.Arrays;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.junit.runners.Parameterized.Parameters;
import org.mockito.Mockito;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.web.context.support.AnnotationConfigWebApplicationContext;
import org.springframework.web.socket.AbstractWebSocketIntegrationTests;
import org.springframework.web.socket.JettyTestServer;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.adapter.WebSocketHandlerAdapter;
import org.springframework.web.socket.client.jetty.JettyWebSocketClient;
import org.springframework.web.socket.server.HandshakeHandler;
import org.springframework.web.socket.sockjs.transport.handler.WebSocketTransportHandler;
import static org.mockito.Matchers.*;
import static org.mockito.Mockito.*;
import static org.junit.Assert.*;
/**
......@@ -63,13 +63,10 @@ public class WebSocketConfigurationTests extends AbstractWebSocketIntegrationTes
this.server.init(cxt);
this.server.start();
WebSocketHandler clientHandler = Mockito.mock(WebSocketHandler.class);
WebSocketHandler serverHandler = cxt.getBean(WebSocketHandler.class);
this.webSocketClient.doHandshake(new WebSocketHandlerAdapter(), getWsBaseUrl() + "/ws");
this.webSocketClient.doHandshake(clientHandler, getWsBaseUrl() + "/ws");
verify(serverHandler).afterConnectionEstablished(any(WebSocketSession.class));
verify(clientHandler).afterConnectionEstablished(any(WebSocketSession.class));
TestWebSocketHandler serverHandler = cxt.getBean(TestWebSocketHandler.class);
assertTrue(serverHandler.latch.await(2, TimeUnit.SECONDS));
}
@Test
......@@ -81,13 +78,10 @@ public class WebSocketConfigurationTests extends AbstractWebSocketIntegrationTes
this.server.init(cxt);
this.server.start();
WebSocketHandler clientHandler = Mockito.mock(WebSocketHandler.class);
WebSocketHandler serverHandler = cxt.getBean(WebSocketHandler.class);
this.webSocketClient.doHandshake(clientHandler, getWsBaseUrl() + "/sockjs/websocket");
this.webSocketClient.doHandshake(new WebSocketHandlerAdapter(), getWsBaseUrl() + "/sockjs/websocket");
verify(serverHandler).afterConnectionEstablished(any(WebSocketSession.class));
verify(clientHandler).afterConnectionEstablished(any(WebSocketSession.class));
TestWebSocketHandler serverHandler = cxt.getBean(TestWebSocketHandler.class);
assertTrue(serverHandler.latch.await(2, TimeUnit.SECONDS));
}
......@@ -110,8 +104,18 @@ public class WebSocketConfigurationTests extends AbstractWebSocketIntegrationTes
}
@Bean
public WebSocketHandler serverHandler() {
return Mockito.mock(WebSocketHandler.class);
public TestWebSocketHandler serverHandler() {
return new TestWebSocketHandler();
}
}
private static class TestWebSocketHandler extends WebSocketHandlerAdapter {
private CountDownLatch latch = new CountDownLatch(1);
@Override
public void afterConnectionEstablished(WebSocketSession session) throws Exception {
this.latch.countDown();
}
}
......
/*
* Copyright 2002-2013 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.web.socket.support;
import java.security.Principal;
/**
* An implementation of Prinicipal for testing.
* @author Rossen Stoyanchev
*/
public class TestPrincipal implements Principal {
private String name;
public TestPrincipal(String name) {
this.name = name;
}
@Override
public String getName() {
return this.name;
}
@Override
public boolean equals(Object obj) {
if (obj == this) {
return true;
}
if (!(obj instanceof TestPrincipal)) {
return false;
}
TestPrincipal p = (TestPrincipal) obj;
return this.name.equals(p.name);
}
@Override
public int hashCode() {
return this.name.hashCode();
}
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册