提交 d30b3eaf 编写于 作者: R Rossen Stoyanchev

Add STOMP client

WebSocketStompClient can be used with any implementation of
org.springframework.web.socket.client.WebSocketClient, which includes
org.springframework.web.socket.sockjs.client.SockJsClient.

Reactor11TcpStompClient can be used with reactor-net and provides STOMP
over TCP. It's also possible to adapt other WebSocket and TCP client
libraries (see StompClientSupport for more details).

For example usage see WebSocketStompClientIntegrationTests.

Issue: SPR-11588
上级 8a5e47a0
/*
* Copyright 2002-2015 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.messaging.simp.stomp;
import org.springframework.messaging.tcp.TcpConnectionHandler;
import org.springframework.util.concurrent.ListenableFuture;
/**
* A {@link StompSession} that implements
* {@link org.springframework.messaging.tcp.TcpConnectionHandler
* TcpConnectionHandler} in order to send and receive messages.
*
* <p>A ConnectionHandlingStompSession can be used with any TCP or WebSocket
* library that is adapted to the {@code TcpConnectionHandler} contract.
*
* @author Rossen Stoyanchev
* @since 4.2
*/
public interface ConnectionHandlingStompSession extends StompSession, TcpConnectionHandler<byte[]> {
/**
* Return a future that will complete when the session is ready for use.
*/
ListenableFuture<StompSession> getSessionFuture();
}
/*
* Copyright 2002-2015 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.messaging.simp.stomp;
/**
* Raised when the connection for a STOMP session is lost rather than closed.
*
* @author Rossen Stoyanchev
* @since 4.2
*/
@SuppressWarnings("serial")
public class ConnectionLostException extends Exception {
public ConnectionLostException(String message) {
super(message);
}
}
/*
* Copyright 2002-2015 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.messaging.simp.stomp;
import java.util.Arrays;
import java.util.Properties;
import reactor.core.Environment;
import reactor.core.configuration.ConfigurationReader;
import reactor.core.configuration.DispatcherConfiguration;
import reactor.core.configuration.DispatcherType;
import reactor.core.configuration.ReactorConfiguration;
import reactor.net.netty.tcp.NettyTcpClient;
import reactor.net.tcp.TcpClient;
import reactor.net.tcp.spec.TcpClientSpec;
import org.springframework.messaging.Message;
import org.springframework.messaging.tcp.TcpOperations;
import org.springframework.messaging.tcp.reactor.Reactor11TcpClient;
import org.springframework.util.concurrent.ListenableFuture;
import org.springframework.util.concurrent.SettableListenableFuture;
/**
* A STOMP over TCP client that uses
* {@link org.springframework.messaging.tcp.reactor.Reactor11TcpClient
* Reactor11TcpClient}.
*
* @author Rossen Stoyanchev
* @since 4.2
*/
public class Reactor11TcpStompClient extends StompClientSupport {
private final TcpOperations<byte[]> tcpClient;
/**
* Create an instance with host "127.0.0.1" and port 61613.
*/
public Reactor11TcpStompClient() {
this("127.0.0.1", 61613);
}
/**
* Create an instance with the given host and port.
* @param host the host
* @param port the port
*/
public Reactor11TcpStompClient(String host, int port) {
this.tcpClient = new Reactor11TcpClient<byte[]>(createNettyTcpClient(host, port));
}
private TcpClient<Message<byte[]>, Message<byte[]>> createNettyTcpClient(String host, int port) {
return new TcpClientSpec<Message<byte[]>, Message<byte[]>>(NettyTcpClient.class)
.env(new Environment(new StompClientDispatcherConfigReader()))
.codec(new Reactor11StompCodec(new StompEncoder(), new StompDecoder()))
.connect(host, port)
.get();
}
/**
* Create an instance with a pre-configured TCP client.
* @param tcpClient the client to use
*/
public Reactor11TcpStompClient(TcpOperations<byte[]> tcpClient) {
this.tcpClient = tcpClient;
}
/**
* Connect and notify the given {@link StompSessionHandler} when connected
* on the STOMP level,
* @param handler the handler for the STOMP session
* @return ListenableFuture for access to the session when ready for use
*/
public ListenableFuture<StompSession> connect(StompSessionHandler handler) {
return connect(null, handler);
}
/**
* An overloaded version of {@link #connect(StompSessionHandler)} that
* accepts headers to use for the STOMP CONNECT frame.
* @param connectHeaders headers to add to the CONNECT frame
* @param handler the handler for the STOMP session
* @return ListenableFuture for access to the session when ready for use
*/
public ListenableFuture<StompSession> connect(StompHeaders connectHeaders, StompSessionHandler handler) {
ConnectionHandlingStompSession session = createSession(connectHeaders, handler);
this.tcpClient.connect(session);
return session.getSessionFuture();
}
/**
* Shut down the client and release resources.
*/
public void shutdown() {
this.tcpClient.shutdown();
}
/**
* A ConfigurationReader with a thread pool-based dispatcher.
*/
private static class StompClientDispatcherConfigReader implements ConfigurationReader {
@Override
public ReactorConfiguration read() {
String dispatcherName = "StompClient";
DispatcherType dispatcherType = DispatcherType.THREAD_POOL_EXECUTOR;
DispatcherConfiguration config = new DispatcherConfiguration(dispatcherName, dispatcherType, 128, 0);
return new ReactorConfiguration(Arrays.<DispatcherConfiguration>asList(config), dispatcherName, new Properties());
}
}
}
/*
* Copyright 2002-2015 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.messaging.simp.stomp;
import java.util.Arrays;
import org.springframework.messaging.converter.MessageConverter;
import org.springframework.messaging.converter.StringMessageConverter;
import org.springframework.scheduling.TaskScheduler;
import org.springframework.util.Assert;
/**
* Base class for STOMP client implementations.
*
* <p>Sub-classes can connect over WebSocket or TCP using any library.
* When creating a new connection a sub-class can create an instance of
* {@link DefaultStompSession} which extends
* {@link org.springframework.messaging.tcp.TcpConnectionHandler
* TcpConnectionHandler} whose lifecycle methods the sub-class must then invoke.
*
* <p>In effect {@code TcpConnectionHandler} and {@code TcpConnection} are the
* contracts any sub-class must adapt to while using {@link StompEncoder} and
* {@link StompDecoder} to encode and decode STOMP messages.
*
* @author Rossen Stoyanchev
* @since 4.2
*/
public abstract class StompClientSupport {
private MessageConverter messageConverter = new StringMessageConverter();
private TaskScheduler taskScheduler;
private long[] defaultHeartbeat = new long[] {10000, 10000};
private long receiptTimeLimit = 15 * 1000;
/**
* Set the {@link MessageConverter} to use to convert the payload of incoming
* and outgoing messages to and from {@code byte[]} based on object type
* and the "content-type" header.
* <p>By default, {@link StringMessageConverter} is configured.
* @param messageConverter the message converter to use
*/
public void setMessageConverter(MessageConverter messageConverter) {
Assert.notNull(messageConverter, "'messageConverter' must not be null");
this.messageConverter = messageConverter;
}
/**
* Return the configured {@link MessageConverter}.
*/
public MessageConverter getMessageConverter() {
return this.messageConverter;
}
/**
* Configure a scheduler to use for heartbeats and for receipt tracking.
*
* <p><strong>Note:</strong> some transports have built-in support to work
* with heartbeats and therefore do not require a TaskScheduler.
* Receipts however, if needed, do require a TaskScheduler to be configured.
*
* <p>By default this is not set.
*/
public void setTaskScheduler(TaskScheduler taskScheduler) {
this.taskScheduler = taskScheduler;
}
/**
* The configured TaskScheduler.
*/
public TaskScheduler getTaskScheduler() {
return this.taskScheduler;
}
/**
* Configure the default value for the "heart-beat" header of the STOMP
* CONNECT frame. The first number represents how often the client will write
* or send a heart-beat. The second is how often the server should write.
* A value of 0 means no heart-beats.
* <p>By default this is set to "10000,10000" but sub-classes may override
* that default and for example set it to "0,0" if they require a
* TaskScheduler to be configured first.
* @param heartbeat the value for the CONNECT "heart-beat" header
* @see <a href="http://stomp.github.io/stomp-specification-1.2.html#Heart-beating">
* http://stomp.github.io/stomp-specification-1.2.html#Heart-beating</a>
*/
public void setDefaultHeartbeat(long[] heartbeat) {
Assert.notNull(heartbeat);
Assert.isTrue(heartbeat[0] >= 0 && heartbeat[1] >=0 , "Invalid heart-beat: " + Arrays.toString(heartbeat));
this.defaultHeartbeat = heartbeat;
}
/**
* Return the configured default heart-beat value, never {@code null}.
*/
public long[] getDefaultHeartbeat() {
return this.defaultHeartbeat;
}
/**
* Whether heartbeats are enabled. Returns {@code false} if
* {@link #setDefaultHeartbeat defaultHeartbeat} is set to "0,0", and
* {@code true} otherwise.
*/
public boolean isDefaultHeartbeatEnabled() {
return (getDefaultHeartbeat() != null && getDefaultHeartbeat()[0] != 0 && getDefaultHeartbeat()[1] != 0);
}
/**
* Configure the number of milliseconds before a receipt is considered expired.
* <p>By default set to 15,000 (15 seconds).
*/
public void setReceiptTimeLimit(long receiptTimeLimit) {
Assert.isTrue(receiptTimeLimit > 0);
this.receiptTimeLimit = receiptTimeLimit;
}
/**
* Return the configured receipt time limit.
*/
public long getReceiptTimeLimit() {
return receiptTimeLimit;
}
/**
* Factory method for create and configure a new session.
* @param connectHeaders headers for the STOMP CONNECT frame
* @param handler the handler for the STOMP session
* @return the created session
*/
protected ConnectionHandlingStompSession createSession(StompHeaders connectHeaders, StompSessionHandler handler) {
connectHeaders = processConnectHeaders(connectHeaders);
DefaultStompSession session = new DefaultStompSession(handler, connectHeaders);
session.setMessageConverter(getMessageConverter());
session.setTaskScheduler(getTaskScheduler());
session.setReceiptTimeLimit(getReceiptTimeLimit());
return session;
}
/**
* Further initialize the StompHeaders, for example setting the heart-beat
* header if necessary.
* @param connectHeaders the headers to modify
* @return the modified headers
*/
protected StompHeaders processConnectHeaders(StompHeaders connectHeaders) {
connectHeaders = (connectHeaders != null ? connectHeaders : new StompHeaders());
if (connectHeaders.getHeartbeat() == null) {
connectHeaders.setHeartbeat(getDefaultHeartbeat());
}
return connectHeaders;
}
}
......@@ -73,6 +73,9 @@ public final class StompEncoder {
DataOutputStream output = new DataOutputStream(baos);
if (SimpMessageType.HEARTBEAT.equals(SimpMessageHeaderAccessor.getMessageType(headers))) {
if (logger.isTraceEnabled()) {
logger.trace("Encoding heartbeat");
}
output.write(StompDecoder.HEARTBEAT_PAYLOAD);
}
else {
......
/*
* Copyright 2002-2015 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.messaging.simp.stomp;
import java.lang.reflect.Type;
/**
* Contract to handle a STOMP frame.
*
* @author Rossen Stoyanchev
* @since 4.2
*/
public interface StompFrameHandler {
/**
* Invoked before {@link #handleFrame(StompHeaders, Object)} to determine the
* type of Object the payload should be converted to.
* @param headers the headers of a message
*/
Type getPayloadType(StompHeaders headers);
/**
* Handle a STOMP frame with the payload converted to the target type returned
* from {@link #getPayloadType(StompHeaders)}.
*
* @param headers the headers of the frame
* @param payload the payload or {@code null} if there was no payload
*/
void handleFrame(StompHeaders headers, Object payload);
}
......@@ -162,7 +162,7 @@ public class StompHeaderAccessor extends SimpMessageHeaderAccessor {
}
}
private void updateStompHeadersFromSimpMessageHeaders() {
void updateStompHeadersFromSimpMessageHeaders() {
if (getDestination() != null) {
setNativeHeader(STOMP_DESTINATION_HEADER, getDestination());
}
......
/*
* Copyright 2002-2015 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.messaging.simp.stomp;
import java.io.Serializable;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.springframework.util.Assert;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MimeType;
import org.springframework.util.MimeTypeUtils;
import org.springframework.util.MultiValueMap;
import org.springframework.util.StringUtils;
/**
* Represents STOMP frame headers.
*
* <p>In addition to the normal methods defined by {@link Map}, this class offers
* the following convenience methods:
* <ul>
* <li>{@link #getFirst(String)} return the first value for a header name</li>
* <li>{@link #add(String, String)} add to the list of values for a header name</li>
* <li>{@link #set(String, String)} set a header name to a single string value</li>
* </ul>
*
* @author Rossen Stoyanchev
* @since 4.2
* @see <a href="http://stomp.github.io/stomp-specification-1.2.html#Frames_and_Headers">
* http://stomp.github.io/stomp-specification-1.2.html#Frames_and_Headers</a>
*/
public class StompHeaders implements MultiValueMap<String, String>, Serializable {
private static final long serialVersionUID = 7514642206528452544L;
// Standard headers (as defined in the spec)
public static final String CONTENT_TYPE = "content-type"; // SEND, MESSAGE, ERROR
public static final String CONTENT_LENGTH = "content-length"; // SEND, MESSAGE, ERROR
public static final String RECEIPT = "receipt"; // any client frame other than CONNECT
// CONNECT
public static final String HOST = "host";
public static final String LOGIN = "login";
public static final String PASSCODE = "passcode";
public static final String HEARTBEAT = "heart-beat";
// CONNECTED
public static final String SESSION = "session";
public static final String SERVER = "server";
// SEND
public static final String DESTINATION = "destination";
// SUBSCRIBE, UNSUBSCRIBE
public static final String ID = "id";
public static final String ACK = "ack";
// MESSAGE
public static final String SUBSCRIPTION = "subscription";
public static final String MESSAGE_ID = "message-id";
// RECEIPT
public static final String RECEIPT_ID = "receipt-id";
private final Map<String, List<String>> headers;
/**
* Create a new instance to be populated with new header values.
*/
public StompHeaders() {
this(new LinkedMultiValueMap<String, String>(4), false);
}
private StompHeaders(Map<String, List<String>> headers, boolean readOnly) {
Assert.notNull(headers, "'headers' must not be null");
if (readOnly) {
Map<String, List<String>> map = new LinkedMultiValueMap<String, String>(headers.size());
for (Entry<String, List<String>> entry : headers.entrySet()) {
List<String> values = Collections.unmodifiableList(entry.getValue());
map.put(entry.getKey(), values);
}
this.headers = Collections.unmodifiableMap(map);
}
else {
this.headers = headers;
}
}
/**
* Set the content-type header.
* Applies to the SEND, MESSAGE, and ERROR frames.
*/
public void setContentType(MimeType mimeType) {
Assert.isTrue(!mimeType.isWildcardType(), "'Content-Type' cannot contain wildcard type '*'");
Assert.isTrue(!mimeType.isWildcardSubtype(), "'Content-Type' cannot contain wildcard subtype '*'");
set(CONTENT_TYPE, mimeType.toString());
}
/**
* Return the content-type header value.
*/
public MimeType getContentType() {
String value = getFirst(CONTENT_TYPE);
return (StringUtils.hasLength(value) ? MimeTypeUtils.parseMimeType(value) : null);
}
/**
* Set the content-length header.
* Applies to the SEND, MESSAGE, and ERROR frames.
*/
public void setContentLength(long contentLength) {
set(CONTENT_LENGTH, Long.toString(contentLength));
}
/**
* Return the content-length header or -1 if unknown.
*/
public long getContentLength() {
String value = getFirst(CONTENT_LENGTH);
return (value != null ? Long.parseLong(value) : -1);
}
/**
* Set the receipt header.
* Applies to any client frame other than CONNECT.
*/
public void setReceipt(String receipt) {
set(RECEIPT, receipt);
}
/**
* Get the receipt header.
*/
public String getReceipt() {
return getFirst(RECEIPT);
}
/**
* Set the host header.
* Applies to the CONNECT frame.
*/
public void setHost(String host) {
set(HOST, host);
}
/**
* Get the host header.
*/
public String getHost() {
return getFirst(HOST);
}
/**
* Set the login header.
* Applies to the CONNECT frame.
*/
public void setLogin(String login) {
set(LOGIN, login);
}
/**
* Get the login header.
*/
public String getLogin() {
return getFirst(LOGIN);
}
/**
* Set the passcode header.
* Applies to the CONNECT frame.
*/
public void setPasscode(String passcode) {
set(PASSCODE, passcode);
}
/**
* Get the passcode header.
*/
public String getPasscode() {
return getFirst(PASSCODE);
}
/**
* Set the heartbeat header.
* Applies to the CONNECT and CONNECTED frames.
*/
public void setHeartbeat(long[] heartbeat) {
Assert.notNull(heartbeat);
String value = heartbeat[0] + "," + heartbeat[1];
Assert.isTrue(heartbeat[0] >= 0 && heartbeat[1] >= 0, "Heart-beat values cannot be negative: " + value);
set(HEARTBEAT, value);
}
/**
* Get the heartbeat header.
*/
public long[] getHeartbeat() {
String rawValue = getFirst(HEARTBEAT);
if (!StringUtils.hasText(rawValue)) {
return null;
}
String[] rawValues = StringUtils.commaDelimitedListToStringArray(rawValue);
return new long[] {Long.valueOf(rawValues[0]), Long.valueOf(rawValues[1])};
}
/**
* Whether heartbeats are enabled. Returns {@code false} if
* {@link #setHeartbeat} is set to "0,0", and {@code true} otherwise.
*/
public boolean isHeartbeatEnabled() {
long[] heartbeat = getHeartbeat();
return (heartbeat != null && heartbeat[0] != 0 && heartbeat[1] != 0);
}
/**
* Set the session header.
* Applies to the CONNECTED frame.
*/
public void setSession(String session) {
set(SESSION, session);
}
/**
* Get the session header.
*/
public String getSession() {
return getFirst(SESSION);
}
/**
* Set the server header.
* Applies to the CONNECTED frame.
*/
public void setServer(String server) {
set(SERVER, server);
}
/**
* Get the server header.
* Applies to the CONNECTED frame.
*/
public String getServer() {
return getFirst(SERVER);
}
/**
* Set the destination header.
*/
public void setDestination(String destination) {
set(DESTINATION, destination);
}
/**
* Get the destination header.
* Applies to the SEND, SUBSCRIBE, and MESSAGE frames.
*/
public String getDestination() {
return getFirst(DESTINATION);
}
/**
* Set the id header.
* Applies to the SUBSCR0BE, UNSUBSCRIBE, and ACK or NACK frames.
*/
public void setId(String id) {
set(ID, id);
}
/**
* Get the id header.
*/
public String getId() {
return getFirst(ID);
}
/**
* Set the ack header to one of "auto", "client", or "client-individual".
* Applies to the SUBSCRIBE and MESSAGE frames.
*/
public void setAck(String ack) {
set(ACK, ack);
}
/**
* Get the ack header.
*/
public String getAck() {
return getFirst(ACK);
}
/**
* Set the login header.
* Applies to the MESSAGE frame.
*/
public void setSubscription(String subscription) {
set(SUBSCRIPTION, subscription);
}
/**
* Get the subscription header.
*/
public String getSubscription() {
return getFirst(SUBSCRIPTION);
}
/**
* Set the message-id header.
* Applies to the MESSAGE frame.
*/
public void setMessageId(String messageId) {
set(MESSAGE_ID, messageId);
}
/**
* Get the message-id header.
*/
public String getMessageId() {
return getFirst(MESSAGE_ID);
}
/**
* Set the receipt-id header.
* Applies to the RECEIPT frame.
*/
public void setReceiptId(String receiptId) {
set(RECEIPT_ID, receiptId);
}
/**
* Get the receipt header.
*/
public String getReceiptId() {
return getFirst(RECEIPT_ID);
}
/**
* Return the first header value for the given header name, if any.
* @param headerName the header name
* @return the first header value, or {@code null} if none
*/
@Override
public String getFirst(String headerName) {
List<String> headerValues = headers.get(headerName);
return headerValues != null ? headerValues.get(0) : null;
}
/**
* Add the given, single header value under the given name.
* @param headerName the header name
* @param headerValue the header value
* @throws UnsupportedOperationException if adding headers is not supported
* @see #put(String, List)
* @see #set(String, String)
*/
@Override
public void add(String headerName, String headerValue) {
List<String> headerValues = headers.get(headerName);
if (headerValues == null) {
headerValues = new LinkedList<String>();
this.headers.put(headerName, headerValues);
}
headerValues.add(headerValue);
}
/**
* Set the given, single header value under the given name.
* @param headerName the header name
* @param headerValue the header value
* @throws UnsupportedOperationException if adding headers is not supported
* @see #put(String, List)
* @see #add(String, String)
*/
@Override
public void set(String headerName, String headerValue) {
List<String> headerValues = new LinkedList<String>();
headerValues.add(headerValue);
headers.put(headerName, headerValues);
}
@Override
public void setAll(Map<String, String> values) {
for (Entry<String, String> entry : values.entrySet()) {
set(entry.getKey(), entry.getValue());
}
}
@Override
public Map<String, String> toSingleValueMap() {
LinkedHashMap<String, String> singleValueMap = new LinkedHashMap<String,String>(this.headers.size());
for (Entry<String, List<String>> entry : headers.entrySet()) {
singleValueMap.put(entry.getKey(), entry.getValue().get(0));
}
return singleValueMap;
}
// Map implementation
@Override
public int size() {
return this.headers.size();
}
@Override
public boolean isEmpty() {
return this.headers.isEmpty();
}
@Override
public boolean containsKey(Object key) {
return this.headers.containsKey(key);
}
@Override
public boolean containsValue(Object value) {
return this.headers.containsValue(value);
}
@Override
public List<String> get(Object key) {
return this.headers.get(key);
}
@Override
public List<String> put(String key, List<String> value) {
return this.headers.put(key, value);
}
@Override
public List<String> remove(Object key) {
return this.headers.remove(key);
}
@Override
public void putAll(Map<? extends String, ? extends List<String>> map) {
this.headers.putAll(map);
}
@Override
public void clear() {
this.headers.clear();
}
@Override
public Set<String> keySet() {
return this.headers.keySet();
}
@Override
public Collection<List<String>> values() {
return this.headers.values();
}
@Override
public Set<Entry<String, List<String>>> entrySet() {
return this.headers.entrySet();
}
@Override
public boolean equals(Object other) {
if (this == other) {
return true;
}
if (!(other instanceof StompHeaders)) {
return false;
}
StompHeaders otherHeaders = (StompHeaders) other;
return this.headers.equals(otherHeaders.headers);
}
@Override
public int hashCode() {
return this.headers.hashCode();
}
@Override
public String toString() {
return this.headers.toString();
}
/**
* Return a {@code StompHeaders} object that can only be read, not written to.
*/
public static StompHeaders readOnlyStompHeaders(Map<String, List<String>> headers) {
return new StompHeaders(headers, true);
}
}
/*
* Copyright 2002-2015 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.messaging.simp.stomp;
/**
* Represents a STOMP session with operations to send messages, create
* subscriptions and receive messages on those subscriptions.
*
* @author Rossen Stoyanchev
* @since 4.2
*/
public interface StompSession {
/**
* Return the id for the session.
*/
String getSessionId();
/**
* Whether the session is connected.
*/
boolean isConnected();
/**
* When enabled, a receipt header is automatically added to future
* {@code send} and {@code subscribe} operations on this session, which causes
* the server to return a RECEIPT. An application can then use the
* {@link StompSession.Receiptable
* Receiptable} returned from the operation to track the receipt.
*
* <p>A receipt header can also be added manually through the overloaded
* methods that accept {@code StompHeaders}.
*/
void setAutoReceipt(boolean enabled);
/**
* Send a message to the specified destination, converting the payload to a
* {@code byte[]} with the help of a
* {@link org.springframework.messaging.converter.MessageConverter MessageConverter}.
* @param destination the destination to send a message to
* @param payload the message payload
* @return a Receiptable for tracking receipts
*/
Receiptable send(String destination, Object payload);
/**
* An overloaded version of {@link #send(String, Object)} that accepts
* full {@link StompHeaders} instead of a destination. The headers must
* contain a destination and may also have other headers such as
* "content-type" or custom headers for the broker to propagate to subscribers,
* or broker-specific, non-standard headers..
* @param headers the message headers
* @param payload the message payload
* @return a Receiptable for tracking receipts
*/
Receiptable send(StompHeaders headers, Object payload);
/**
* Subscribe to the given destination by sending a SUBSCRIBE frame and handle
* received messages with the specified {@link StompFrameHandler}.
* @param destination the destination to subscribe to
* @param handler the handler for received messages
* @return a handle to use to unsubscribe and/or track receipts
*/
Subscription subscribe(String destination, StompFrameHandler handler);
/**
* An overloaded version of {@link #subscribe(String, StompFrameHandler)}
* that accepts full {@link StompHeaders} rather instead of a destination.
* @param headers the headers for the subscribe message frame
* @param handler the handler for received messages
* @return a handle to use to unsubscribe and/or track receipts
*/
Subscription subscribe(StompHeaders headers, StompFrameHandler handler);
/**
* Disconnect the session by sending a DISCONNECT frame.
*/
void disconnect();
/**
* A handle to use to track receipts.
* @see #setAutoReceipt(boolean)
*/
interface Receiptable {
/**
* Return the receipt id, or {@code null} if the STOMP frame for which
* the handle was returned did not have a "receipt" header.
*/
String getReceiptId();
/**
* Task to invoke when a receipt is received.
* @throws java.lang.IllegalArgumentException if the receiptId is {@code null}
*/
void addReceiptTask(Runnable runnable);
/**
* Task to invoke when a receipt is not received in the configured time.
* @throws java.lang.IllegalArgumentException if the receiptId is {@code null}
* @see org.springframework.messaging.simp.stomp.StompClientSupport#setReceiptTimeLimit(long)
*/
void addReceiptLostTask(Runnable runnable);
}
/**
* A handle to use to unsubscribe or to track a receipt.
*/
interface Subscription extends Receiptable {
/**
* Return the id for the subscription.
*/
String getSubscriptionId();
/**
* Remove the subscription by sending an UNSUBSCRIBE frame.
*/
void unsubscribe();
}
}
/*
* Copyright 2002-2015 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.messaging.simp.stomp;
/**
* A contract for client STOMP session lifecycle events including a callback
* when the session is established and notifications of transport or message
* handling failures.
*
* <p>This contract also extends {@link StompFrameHandler} in order to handle
* STOMP ERROR frames received from the broker.
*
* <p>Implementations of this interface should consider extending
* {@link StompSessionHandlerAdapter}.
*
* @author Rossen Stoyanchev
* @since 4.2
* @see StompSessionHandlerAdapter
*/
public interface StompSessionHandler extends StompFrameHandler {
/**
* Invoked when the session is ready to use, i.e. after the underlying
* transport (TCP, WebSocket) is connected and a STOMP CONNECTED frame is
* received from the broker.
* @param session the client STOMP session
* @param connectedHeaders the STOMP CONNECTED frame headers
*/
void afterConnected(StompSession session, StompHeaders connectedHeaders);
/**
* Handle any exception arising while processing a STOMP frame such as a
* failure to convert the payload or an unhandled exception in the
* application {@code StompFrameHandler}.
* @param session the client STOMP session
* @param command the STOMP command of the frame
* @param headers the headers
* @param payload the raw payload
* @param exception the exception
*/
void handleException(StompSession session, StompCommand command, StompHeaders headers,
byte[] payload, Throwable exception);
/**
* Handle a low level transport error which could be an I/O error or a
* failure to encode or decode a STOMP message.
* <p>Note that
* {@link org.springframework.messaging.simp.stomp.ConnectionLostException
* ConnectionLostException} will be passed into this method when the
* connection is lost rather than closed normally via
* {@link StompSession#disconnect()}.
* @param session the client STOMP session
* @param exception the exception that occurred
*/
void handleTransportError(StompSession session, Throwable exception);
}
/*
* Copyright 2002-2015 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.messaging.simp.stomp;
import java.lang.reflect.Type;
/**
* Abstract adapter class for {@link StompSessionHandler} with mostly empty
* implementation methods except for {@link #getPayloadType} which returns String
* as the default Object type expected for STOMP ERROR frame payloads.
*
* @author Rossen Stoyanchev
* @since 4.2
*/
public abstract class StompSessionHandlerAdapter implements StompSessionHandler {
/**
* This implementation is empty.
*/
@Override
public void afterConnected(StompSession session, StompHeaders connectedHeaders) {
}
/**
* This implementation returns String as the expected payload type
* for STOMP ERROR frames.
*/
@Override
public Type getPayloadType(StompHeaders headers) {
return String.class;
}
/**
* This implementation is empty.
*/
@Override
public void handleFrame(StompHeaders headers, Object payload) {
}
/**
* This implementation is empty.
*/
@Override
public void handleException(StompSession session, StompCommand command, StompHeaders headers,
byte[] payload, Throwable exception) {
}
/**
* This implementation is empty.
*/
@Override
public void handleTransportError(StompSession session, Throwable exception) {
}
}
......@@ -25,6 +25,7 @@ import org.springframework.messaging.Message;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import org.springframework.util.ObjectUtils;
/**
......@@ -186,6 +187,17 @@ public class NativeMessageHeaderAccessor extends MessageHeaderAccessor {
setModified(true);
}
public void addNativeHeaders(MultiValueMap<String, String> headers) {
if (headers == null) {
return;
}
for (String header : headers.keySet()) {
for (String value : headers.get(header)) {
addNativeHeader(header, value);
}
}
}
public List<String> removeNativeHeader(String name) {
Assert.state(isMutable(), "Already immutable");
Map<String, List<String>> nativeHeaders = getNativeHeaders();
......
......@@ -196,7 +196,9 @@ public class Reactor11TcpClient<P> implements TcpOperations<P> {
};
}
finally {
this.environment.shutdown();
if (this.environment != null) {
this.environment.shutdown();
}
}
}
......
/*
* Copyright 2002-2015 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.messaging.simp.stomp;
import static org.hamcrest.Matchers.*;
import static org.junit.Assert.*;
import java.lang.reflect.Type;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import org.apache.activemq.broker.BrokerService;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.junit.After;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TestName;
import org.springframework.messaging.converter.StringMessageConverter;
import org.springframework.messaging.simp.stomp.StompSession.Subscription;
import org.springframework.scheduling.concurrent.ThreadPoolTaskScheduler;
import org.springframework.util.Assert;
import org.springframework.util.SocketUtils;
import org.springframework.util.concurrent.ListenableFuture;
/**
* Integration tests for {@link Reactor11TcpStompClient}.
*
* @author Rossen Stoyanchev
*/
public class Reactor11TcpStompClientTests {
private static final Log logger = LogFactory.getLog(Reactor11TcpStompClientTests.class);
@Rule
public final TestName testName = new TestName();
private BrokerService activeMQBroker;
private Reactor11TcpStompClient client;
@Before
public void setUp() throws Exception {
logger.debug("Setting up before '" + this.testName.getMethodName() + "'");
int port = SocketUtils.findAvailableTcpPort(61613);
this.activeMQBroker = new BrokerService();
this.activeMQBroker.addConnector("stomp://127.0.0.1:" + port);
this.activeMQBroker.setStartAsync(false);
this.activeMQBroker.setPersistent(false);
this.activeMQBroker.setUseJmx(false);
this.activeMQBroker.getSystemUsage().getMemoryUsage().setLimit(1024 * 1024 * 5);
this.activeMQBroker.getSystemUsage().getTempUsage().setLimit(1024 * 1024 * 5);
this.activeMQBroker.start();
ThreadPoolTaskScheduler taskScheduler = new ThreadPoolTaskScheduler();
taskScheduler.afterPropertiesSet();
this.client = new Reactor11TcpStompClient("127.0.0.1", port);
this.client.setMessageConverter(new StringMessageConverter());
this.client.setTaskScheduler(taskScheduler);
}
@After
public void tearDown() throws Exception {
try {
this.client.shutdown();
}
catch (Throwable ex) {
logger.error("Failed to shut client", ex);
}
final CountDownLatch latch = new CountDownLatch(1);
this.activeMQBroker.addShutdownHook(latch::countDown);
logger.debug("Stopping ActiveMQ broker and will await shutdown");
this.activeMQBroker.stop();
if (!latch.await(5, TimeUnit.SECONDS)) {
logger.debug("ActiveMQ broker did not shut in the expected time.");
}
}
@Test
public void publishSubscribe() throws Exception {
String destination = "/topic/foo";
ConsumingHandler consumingHandler1 = new ConsumingHandler(destination);
ListenableFuture<StompSession> consumerFuture1 = this.client.connect(consumingHandler1);
ConsumingHandler consumingHandler2 = new ConsumingHandler(destination);
ListenableFuture<StompSession> consumerFuture2 = this.client.connect(consumingHandler2);
assertTrue(consumingHandler1.awaitForSubscriptions(5000));
assertTrue(consumingHandler2.awaitForSubscriptions(5000));
ProducingHandler producingHandler = new ProducingHandler();
producingHandler.addToSend(destination, "foo1");
producingHandler.addToSend(destination, "foo2");
ListenableFuture<StompSession> producerFuture = this.client.connect(producingHandler);
assertTrue(consumingHandler1.awaitForMessageCount(2, 5000));
assertThat(consumingHandler1.getReceived(), containsInAnyOrder("foo1", "foo2"));
assertTrue(consumingHandler2.awaitForMessageCount(2, 5000));
assertThat(consumingHandler2.getReceived(), containsInAnyOrder("foo1", "foo2"));
consumerFuture1.get().disconnect();
consumerFuture2.get().disconnect();
producerFuture.get().disconnect();
}
private static class LoggingSessionHandler extends StompSessionHandlerAdapter {
@Override
public void handleException(StompSession session, StompCommand command,
StompHeaders headers, byte[] payload, Throwable ex) {
logger.error(command + " " + headers, ex);
}
@Override
public void handleFrame(StompHeaders headers, Object payload) {
logger.error("STOMP error frame " + headers + " payload=" + payload);
}
@Override
public void handleTransportError(StompSession session, Throwable exception) {
logger.error(exception);
}
}
private static class ConsumingHandler extends LoggingSessionHandler {
private final List<String> topics;
private final CountDownLatch subscriptionLatch;
private final List<String> received = new ArrayList<>();
public ConsumingHandler(String... topics) {
Assert.notEmpty(topics);
this.topics = Arrays.asList(topics);
this.subscriptionLatch = new CountDownLatch(this.topics.size());
}
public List<String> getReceived() {
return this.received;
}
@Override
public void afterConnected(StompSession session, StompHeaders connectedHeaders) {
for (String topic : this.topics) {
session.setAutoReceipt(true);
Subscription subscription = session.subscribe(topic, new StompFrameHandler() {
@Override
public Type getPayloadType(StompHeaders headers) {
return String.class;
}
@Override
public void handleFrame(StompHeaders headers, Object payload) {
received.add((String) payload);
}
});
subscription.addReceiptTask(subscriptionLatch::countDown);
}
}
public boolean awaitForSubscriptions(long millisToWait) throws InterruptedException {
if (logger.isDebugEnabled()) {
logger.debug("Awaiting for subscription receipts");
}
return this.subscriptionLatch.await(millisToWait, TimeUnit.MILLISECONDS);
}
public boolean awaitForMessageCount(int expected, long millisToWait) throws InterruptedException {
if (logger.isDebugEnabled()) {
logger.debug("Awaiting for message count: " + expected);
}
long startTime = System.currentTimeMillis();
while (this.received.size() < expected) {
Thread.sleep(500);
if ((System.currentTimeMillis() - startTime) > millisToWait) {
return false;
}
}
return true;
}
}
private static class ProducingHandler extends LoggingSessionHandler {
private final List<String> topics = new ArrayList<>();
private final List<Object> payloads = new ArrayList<>();
public ProducingHandler addToSend(String topic, Object payload) {
this.topics.add(topic);
this.payloads.add(payload);
return this;
}
@Override
public void afterConnected(StompSession session, StompHeaders connectedHeaders) {
for (int i=0; i < this.topics.size(); i++) {
session.send(this.topics.get(i), this.payloads.get(i));
}
}
}
}
/*
* Copyright 2002-2015 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.messaging.simp.stomp;
import static org.junit.Assert.*;
import static org.mockito.Matchers.same;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.mockito.MockitoAnnotations;
/**
* Unit tests for {@@link StompClientSupport}.
* @author Rossen Stoyanchev
*/
public class StompClientSupportTests {
private StompClientSupport stompClient;
@Before
public void setUp() throws Exception {
this.stompClient = new StompClientSupport() {};
}
@Test
public void defaultHearbeatValidation() throws Exception {
trySetDefaultHeartbeat(null);
trySetDefaultHeartbeat(new long[] {-1, 0});
trySetDefaultHeartbeat(new long[] {0, -1});
}
private void trySetDefaultHeartbeat(long[] heartbeat) {
try {
this.stompClient.setDefaultHeartbeat(heartbeat);
fail("Expected exception");
}
catch (IllegalArgumentException ex) {
// Ignore
}
}
@Test
public void defaultHeartbeatValue() throws Exception {
assertArrayEquals(new long[] {10000, 10000}, this.stompClient.getDefaultHeartbeat());
}
@Test
public void isDefaultHeartbeatEnabled() throws Exception {
assertArrayEquals(new long[] {10000, 10000}, this.stompClient.getDefaultHeartbeat());
assertTrue(this.stompClient.isDefaultHeartbeatEnabled());
this.stompClient.setDefaultHeartbeat(new long[] {0, 0});
assertFalse(this.stompClient.isDefaultHeartbeatEnabled());
}
@Test
public void processConnectHeadersDefault() throws Exception {
StompHeaders connectHeaders = this.stompClient.processConnectHeaders(null);
assertNotNull(connectHeaders);
assertArrayEquals(new long[] {10000, 10000}, connectHeaders.getHeartbeat());
}
@Test
public void processConnectHeadersWithExplicitHeartbeat() throws Exception {
StompHeaders connectHeaders = new StompHeaders();
connectHeaders.setHeartbeat(new long[] {15000, 15000});
connectHeaders = this.stompClient.processConnectHeaders(connectHeaders);
assertNotNull(connectHeaders);
assertArrayEquals(new long[] {15000, 15000}, connectHeaders.getHeartbeat());
}
}
......@@ -4,13 +4,10 @@ log4j.appender.console.layout.ConversionPattern=%d{HH:mm:ss,SSS} [%c] - %m%n
log4j.rootCategory=WARN, console
log4j.logger.org.springframework.messaging=DEBUG
log4j.logger.org.springframework.web=DEBUG
log4j.logger.org.apache.activemq=TRACE
# Enable TRACE level to chase integration test issues on CI servers
log4j.logger.org.springframework.messaging.simp.stomp=TRACE
log4j.logger.reactor.net=TRACE
log4j.logger.io.netty=TRACE
#log4j.logger.org.springframework.messaging=TRACE
/*
* Copyright 2002-2015 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.web.socket.messaging;
import java.io.IOException;
import java.net.URI;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.ScheduledFuture;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.context.Lifecycle;
import org.springframework.context.SmartLifecycle;
import org.springframework.messaging.Message;
import org.springframework.messaging.simp.stomp.BufferingStompDecoder;
import org.springframework.messaging.simp.stomp.ConnectionHandlingStompSession;
import org.springframework.messaging.simp.stomp.StompClientSupport;
import org.springframework.messaging.simp.stomp.StompDecoder;
import org.springframework.messaging.simp.stomp.StompEncoder;
import org.springframework.messaging.simp.stomp.StompHeaderAccessor;
import org.springframework.messaging.simp.stomp.StompHeaders;
import org.springframework.messaging.simp.stomp.StompSession;
import org.springframework.messaging.simp.stomp.StompSessionHandler;
import org.springframework.messaging.support.MessageHeaderAccessor;
import org.springframework.messaging.tcp.TcpConnection;
import org.springframework.messaging.tcp.TcpConnectionHandler;
import org.springframework.scheduling.TaskScheduler;
import org.springframework.util.Assert;
import org.springframework.util.MimeTypeUtils;
import org.springframework.util.concurrent.ListenableFuture;
import org.springframework.util.concurrent.ListenableFutureCallback;
import org.springframework.util.concurrent.SettableListenableFuture;
import org.springframework.web.socket.BinaryMessage;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.WebSocketHttpHeaders;
import org.springframework.web.socket.WebSocketMessage;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.client.WebSocketClient;
import org.springframework.web.socket.sockjs.transport.SockJsSession;
import org.springframework.web.util.UriComponentsBuilder;
/**
* A STOMP over WebSocket client that connects using an implementation of
* {@link org.springframework.web.socket.client.WebSocketClient WebSocketClient}
* including {@link org.springframework.web.socket.sockjs.client.SockJsClient
* SockJsClient}.
*
* @author Rossen Stoyanchev
* @since 4.2
*/
public class WebSocketStompClient extends StompClientSupport implements SmartLifecycle {
private static Log logger = LogFactory.getLog(WebSocketStompClient.class);
private final WebSocketClient webSocketClient;
private int inboundMessageSizeLimit = 64 * 1024;
private boolean autoStartup = true;
private boolean running = false;
private int phase = Integer.MAX_VALUE;
/**
* Class constructor. Sets {@link #setDefaultHeartbeat} to "0,0" but will
* reset it back to the preferred "10000,10000" when a
* {@link #setTaskScheduler} is configured.
*
* @param webSocketClient the WebSocket client to connect with
*/
public WebSocketStompClient(WebSocketClient webSocketClient) {
Assert.notNull(webSocketClient, "'webSocketClient' is required.");
this.webSocketClient = webSocketClient;
setDefaultHeartbeat(new long[] {0, 0});
}
/**
* Return the configured WebSocketClient.
*/
public WebSocketClient getWebSocketClient() {
return this.webSocketClient;
}
/**
* {@inheritDoc}
* <p>Also automatically sets the {@link #setDefaultHeartbeat defaultHeartbeat}
* property to "10000,10000" if it is currently set to "0,0".
*/
@Override
public void setTaskScheduler(TaskScheduler taskScheduler) {
if (taskScheduler != null && !isDefaultHeartbeatEnabled()) {
setDefaultHeartbeat(new long[] {10000, 10000});
}
super.setTaskScheduler(taskScheduler);
}
/**
* Configure the maximum size allowed for inbound STOMP message.
* Since a STOMP message can be received in multiple WebSocket messages,
* buffering may be required and this property determines the maximum buffer
* size per message.
* <p>By default this is set to 64 * 1024 (64K).
*/
public void setInboundMessageSizeLimit(int inboundMessageSizeLimit) {
this.inboundMessageSizeLimit = inboundMessageSizeLimit;
}
/**
* Get the configured inbound message buffer size in bytes.
*/
public int getInboundMessageSizeLimit() {
return this.inboundMessageSizeLimit;
}
/**
* Set whether to auto-start the contained WebSocketClient when the Spring
* context has been refreshed.
* <p>Default is "true".
*/
public void setAutoStartup(boolean autoStartup) {
this.autoStartup = autoStartup;
}
/**
* Return the value for the 'autoStartup' property. If "true", this client
* will automatically start and stop the contained WebSocketClient.
*/
@Override
public boolean isAutoStartup() {
return this.autoStartup;
}
@Override
public boolean isRunning() {
return this.running;
}
/**
* Specify the phase in which the WebSocket client should be started and
* subsequently closed. The startup order proceeds from lowest to highest,
* and the shutdown order is the reverse of that.
* <p>By default this is Integer.MAX_VALUE meaning that the WebSocket client
* is started as late as possible and stopped as soon as possible.
*/
public void setPhase(int phase) {
this.phase = phase;
}
/**
* Return the configured phase.
*/
@Override
public int getPhase() {
return this.phase;
}
@Override
public void start() {
if (!isRunning()) {
this.running = true;
if (getWebSocketClient() instanceof Lifecycle) {
((Lifecycle) getWebSocketClient()).start();
}
}
}
@Override
public void stop() {
if (isRunning()) {
this.running = false;
if (getWebSocketClient() instanceof Lifecycle) {
((Lifecycle) getWebSocketClient()).stop();
}
}
}
@Override
public void stop(Runnable callback) {
this.stop();
callback.run();
}
/**
* Connect to the given WebSocket URL and notify the given
* {@link org.springframework.messaging.simp.stomp.StompSessionHandler}
* when connected on the STOMP level after the CONNECTED frame is received.
* @param url the url to connect to
* @param handler the session handler
* @param uriVars URI variables to expand into the URL
* @return ListenableFuture for access to the session when ready for use
*/
public ListenableFuture<StompSession> connect(String url, StompSessionHandler handler, Object... uriVars) {
return connect(url, null, handler, uriVars);
}
/**
* An overloaded version of
* {@link #connect(String, StompSessionHandler, Object...)} that also
* accepts {@link WebSocketHttpHeaders} to use for the WebSocket handshake.
* @param url the url to connect to
* @param handshakeHeaders the headers for the WebSocket handshake
* @param handler the session handler
* @param uriVariables URI variables to expand into the URL
* @return ListenableFuture for access to the session when ready for use
*/
public ListenableFuture<StompSession> connect(String url, WebSocketHttpHeaders handshakeHeaders,
StompSessionHandler handler, Object... uriVariables) {
return connect(url, handshakeHeaders, null, handler, uriVariables);
}
/**
* An overloaded version of
* {@link #connect(String, StompSessionHandler, Object...)} that also accepts
* {@link WebSocketHttpHeaders} to use for the WebSocket handshake and
* {@link StompHeaders} for the STOMP CONNECT frame.
* @param url the url to connect to
* @param handshakeHeaders headers for the WebSocket handshake
* @param connectHeaders headers for the STOMP CONNECT frame
* @param handler the session handler
* @param uriVariables URI variables to expand into the URL
* @return ListenableFuture for access to the session when ready for use
*/
public ListenableFuture<StompSession> connect(String url, WebSocketHttpHeaders handshakeHeaders,
StompHeaders connectHeaders, StompSessionHandler handler, Object... uriVariables) {
Assert.notNull(url, "uriTemplate must not be null");
URI uri = UriComponentsBuilder.fromUriString(url).buildAndExpand(uriVariables).encode().toUri();
return connect(uri, handshakeHeaders, connectHeaders, handler);
}
/**
* An overloaded version of
* {@link #connect(String, WebSocketHttpHeaders, StompSessionHandler, Object...)}
* that accepts a fully prepared {@link java.net.URI}.
* @param url the url to connect to
* @param handshakeHeaders the headers for the WebSocket handshake
* @param connectHeaders headers for the STOMP CONNECT frame
* @param sessionHandler the STOMP session handler
* @return ListenableFuture for access to the session when ready for use
*/
public ListenableFuture<StompSession> connect(URI url, WebSocketHttpHeaders handshakeHeaders,
StompHeaders connectHeaders, StompSessionHandler sessionHandler) {
Assert.notNull(url, "'uri' must not be null");
ConnectionHandlingStompSession session = createSession(connectHeaders, sessionHandler);
WebSocketTcpConnectionHandlerAdapter adapter = new WebSocketTcpConnectionHandlerAdapter(session);
getWebSocketClient().doHandshake(adapter, handshakeHeaders, url).addCallback(adapter);
return session.getSessionFuture();
}
@Override
protected StompHeaders processConnectHeaders(StompHeaders connectHeaders) {
connectHeaders = super.processConnectHeaders(connectHeaders);
if (connectHeaders.isHeartbeatEnabled()) {
Assert.notNull(getTaskScheduler(), "TaskScheduler cannot be null if heartbeats are enabled.");
}
return connectHeaders;
}
/**
* Adapt WebSocket to the TcpConnectionHandler and TcpConnection contracts.
*/
private class WebSocketTcpConnectionHandlerAdapter implements ListenableFutureCallback<WebSocketSession>,
WebSocketHandler, TcpConnection<byte[]> {
private final TcpConnectionHandler<byte[]> connectionHandler;
private final StompWebSocketMessageCodec codec = new StompWebSocketMessageCodec(getInboundMessageSizeLimit());
private volatile WebSocketSession session;
private volatile long lastReadTime = -1;
private volatile long lastWriteTime = -1;
private final List<ScheduledFuture<?>> inactivityTasks = new ArrayList<ScheduledFuture<?>>(2);
public WebSocketTcpConnectionHandlerAdapter(TcpConnectionHandler<byte[]> connectionHandler) {
Assert.notNull(connectionHandler);
this.connectionHandler = connectionHandler;
}
// ListenableFutureCallback implementation: handshake outcome
@Override
public void onSuccess(WebSocketSession webSocketSession) {
}
@Override
public void onFailure(Throwable ex) {
this.connectionHandler.afterConnectFailure(ex);
}
// WebSocketHandler implementation
@Override
public void afterConnectionEstablished(WebSocketSession session) {
this.session = session;
this.connectionHandler.afterConnected(this);
}
@Override
public void handleMessage(WebSocketSession session, WebSocketMessage<?> webSocketMessage) {
this.lastReadTime = (this.lastReadTime != -1 ? System.currentTimeMillis() : -1);
List<Message<byte[]>> messages;
try {
messages = this.codec.decode(webSocketMessage);
}
catch (Throwable ex) {
this.connectionHandler.handleFailure(ex);
return;
}
for (Message<byte[]> message : messages) {
this.connectionHandler.handleMessage(message);
}
}
@Override
public void handleTransportError(WebSocketSession session, Throwable ex) throws Exception {
this.connectionHandler.handleFailure(ex);
}
@Override
public void afterConnectionClosed(WebSocketSession session, CloseStatus closeStatus) throws Exception {
cancelInactivityTasks();
this.connectionHandler.afterConnectionClosed();
}
private void cancelInactivityTasks() {
for (ScheduledFuture<?> task : this.inactivityTasks) {
try {
task.cancel(true);
}
catch (Throwable ex) {
// Ignore
}
}
this.lastReadTime = -1;
this.lastWriteTime = -1;
this.inactivityTasks.clear();
}
@Override
public boolean supportsPartialMessages() {
return false;
}
// TcpConnection implementation
@Override
public ListenableFuture<Void> send(Message<byte[]> message) {
updateLastWriteTime();
SettableListenableFuture<Void> future = new SettableListenableFuture<Void>();
try {
this.session.sendMessage(this.codec.encode(message, this.session.getClass()));
future.set(null);
}
catch (Throwable ex) {
future.setException(ex);
}
finally {
updateLastWriteTime();
}
return future;
}
private void updateLastWriteTime() {
this.lastWriteTime = (this.lastWriteTime != -1 ? System.currentTimeMillis() : -1);
}
@Override
public void onReadInactivity(final Runnable runnable, final long duration) {
Assert.notNull(getTaskScheduler(), "No scheduler configured.");
this.lastReadTime = System.currentTimeMillis();
this.inactivityTasks.add(getTaskScheduler().scheduleWithFixedDelay(new Runnable() {
@Override
public void run() {
if (System.currentTimeMillis() - lastReadTime > duration) {
try {
runnable.run();
}
catch (Throwable ex) {
if (logger.isDebugEnabled()) {
logger.debug("ReadInactivityTask failure", ex);
}
}
}
}
}, duration / 2));
}
@Override
public void onWriteInactivity(final Runnable runnable, final long duration) {
Assert.notNull(getTaskScheduler(), "No scheduler configured.");
this.lastWriteTime = System.currentTimeMillis();
this.inactivityTasks.add(getTaskScheduler().scheduleWithFixedDelay(new Runnable() {
@Override
public void run() {
if (System.currentTimeMillis() - lastWriteTime > duration) {
try {
runnable.run();
}
catch (Throwable ex) {
if (logger.isDebugEnabled()) {
logger.debug("WriteInactivityTask failure", ex);
}
}
}
}
}, duration / 2));
}
@Override
public void close() {
try {
this.session.close();
}
catch (IOException ex) {
if (logger.isDebugEnabled()) {
logger.debug("Failed to close session: " + this.session.getId(), ex);
}
}
}
}
/**
* Encode and decode STOMP WebSocket messages.
*/
private static class StompWebSocketMessageCodec {
private static final StompEncoder ENCODER = new StompEncoder();
private static final StompDecoder DECODER = new StompDecoder();
private final BufferingStompDecoder bufferingDecoder;
public StompWebSocketMessageCodec(int messageSizeLimit) {
this.bufferingDecoder = new BufferingStompDecoder(DECODER, messageSizeLimit);
}
public List<Message<byte[]>> decode(WebSocketMessage<?> webSocketMessage) {
List<Message<byte[]>> result = Collections.<Message<byte[]>>emptyList();
ByteBuffer byteBuffer;
if (webSocketMessage instanceof TextMessage) {
byteBuffer = ByteBuffer.wrap(((TextMessage) webSocketMessage).asBytes());
}
else if (webSocketMessage instanceof BinaryMessage) {
byteBuffer = ((BinaryMessage) webSocketMessage).getPayload();
}
else {
return result;
}
result = this.bufferingDecoder.decode(byteBuffer);
if (result.isEmpty()) {
if (logger.isTraceEnabled()) {
logger.trace("Incomplete STOMP frame content received, bufferSize=" +
this.bufferingDecoder.getBufferSize() + ", bufferSizeLimit=" +
this.bufferingDecoder.getBufferSizeLimit() + ".");
}
}
return result;
}
public WebSocketMessage<?> encode(Message<byte[]> message, Class<? extends WebSocketSession> sessionType) {
StompHeaderAccessor accessor = MessageHeaderAccessor.getAccessor(message, StompHeaderAccessor.class);
Assert.notNull(accessor);
byte[] payload = message.getPayload();
byte[] bytes = ENCODER.encode(accessor.getMessageHeaders(), payload);
boolean useBinary = (payload.length > 0 &&
!(SockJsSession.class.isAssignableFrom(sessionType)) &&
MimeTypeUtils.APPLICATION_OCTET_STREAM.isCompatibleWith(accessor.getContentType()));
return (useBinary ? new BinaryMessage(bytes) : new TextMessage(bytes));
}
}
}
/*
* Copyright 2002-2015 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.web.socket.messaging;
import static org.hamcrest.Matchers.*;
import static org.junit.Assert.*;
import java.lang.reflect.Type;
import java.util.ArrayList;
import java.util.List;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.junit.After;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TestName;
import org.springframework.context.annotation.Configuration;
import org.springframework.messaging.converter.StringMessageConverter;
import org.springframework.messaging.simp.config.MessageBrokerRegistry;
import org.springframework.messaging.simp.stomp.StompCommand;
import org.springframework.messaging.simp.stomp.StompFrameHandler;
import org.springframework.messaging.simp.stomp.StompHeaders;
import org.springframework.messaging.simp.stomp.StompSession;
import org.springframework.messaging.simp.stomp.StompSessionHandlerAdapter;
import org.springframework.web.context.support.AnnotationConfigWebApplicationContext;
import org.springframework.web.socket.TomcatWebSocketTestServer;
import org.springframework.web.socket.WebSocketTestServer;
import org.springframework.web.socket.client.WebSocketClient;
import org.springframework.web.socket.client.standard.StandardWebSocketClient;
import org.springframework.web.socket.config.annotation.StompEndpointRegistry;
import org.springframework.web.socket.config.annotation.WebSocketMessageBrokerConfigurationSupport;
import org.springframework.web.socket.server.RequestUpgradeStrategy;
import org.springframework.web.socket.server.standard.TomcatRequestUpgradeStrategy;
import org.springframework.web.socket.server.support.DefaultHandshakeHandler;
/**
* Integration tests for {@link WebSocketStompClient}.
* @author Rossen Stoyanchev
*/
public class WebSocketStompClientIntegrationTests {
private static final Log logger = LogFactory.getLog(WebSocketStompClientIntegrationTests.class);
@Rule
public final TestName testName = new TestName();
private WebSocketStompClient stompClient;
private WebSocketTestServer server;
private AnnotationConfigWebApplicationContext wac;
@Before
public void setUp() throws Exception {
logger.debug("Setting up before '" + this.testName.getMethodName() + "'");
this.wac = new AnnotationConfigWebApplicationContext();
this.wac.register(TestConfig.class);
this.wac.refresh();
this.server = new TomcatWebSocketTestServer();
this.server.setup();
this.server.deployConfig(this.wac);
this.server.start();
WebSocketClient webSocketClient = new StandardWebSocketClient();
this.stompClient = new WebSocketStompClient(webSocketClient);
this.stompClient.setMessageConverter(new StringMessageConverter());
}
@After
public void tearDown() throws Exception {
try {
this.server.undeployConfig();
}
catch (Throwable t) {
logger.error("Failed to undeploy application config", t);
}
try {
this.server.stop();
}
catch (Throwable t) {
logger.error("Failed to stop server", t);
}
try {
this.wac.close();
}
catch (Throwable t) {
logger.error("Failed to close WebApplicationContext", t);
}
}
@Test
public void publishSubscribe() throws Exception {
String url = "ws://127.0.0.1:" + this.server.getPort() + "/stomp";
TestHandler testHandler = new TestHandler("/topic/foo", "payload");
this.stompClient.connect(url, testHandler);
assertTrue(testHandler.awaitForMessageCount(1, 5000));
assertThat(testHandler.getReceived(), containsInAnyOrder("payload"));
}
@Configuration
static class TestConfig extends WebSocketMessageBrokerConfigurationSupport {
@Override
protected void registerStompEndpoints(StompEndpointRegistry registry) {
// Can't rely on classpath detection
RequestUpgradeStrategy upgradeStrategy = new TomcatRequestUpgradeStrategy();
registry.addEndpoint("/stomp")
.setHandshakeHandler(new DefaultHandshakeHandler(upgradeStrategy))
.setAllowedOrigins("*");
}
@Override
public void configureMessageBroker(MessageBrokerRegistry configurer) {
configurer.setApplicationDestinationPrefixes("/app");
configurer.enableSimpleBroker("/topic", "/queue");
}
}
private static class TestHandler extends StompSessionHandlerAdapter {
private final String topic;
private final Object payload;
private final List<String> received = new ArrayList<>();
public TestHandler(String topic, Object payload) {
this.topic = topic;
this.payload = payload;
}
public List<String> getReceived() {
return this.received;
}
@Override
public void afterConnected(StompSession session, StompHeaders connectedHeaders) {
session.subscribe(this.topic, new StompFrameHandler() {
@Override
public Type getPayloadType(StompHeaders headers) {
return String.class;
}
@Override
public void handleFrame(StompHeaders headers, Object payload) {
received.add((String) payload);
}
});
session.send(this.topic, this.payload);
}
public boolean awaitForMessageCount(int expected, long millisToWait) throws InterruptedException {
if (logger.isDebugEnabled()) {
logger.debug("Awaiting for message count: " + expected);
}
long startTime = System.currentTimeMillis();
while (this.received.size() < expected) {
Thread.sleep(500);
if ((System.currentTimeMillis() - startTime) > millisToWait) {
return false;
}
}
return true;
}
@Override
public void handleException(StompSession session, StompCommand command,
StompHeaders headers, byte[] payload, Throwable ex) {
logger.error(command + " " + headers, ex);
}
@Override
public void handleFrame(StompHeaders headers, Object payload) {
logger.error("STOMP error frame " + headers + " payload=" + payload);
}
@Override
public void handleTransportError(StompSession session, Throwable exception) {
logger.error(exception);
}
}
}
\ No newline at end of file
/*
* Copyright 2002-2015 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.web.socket.messaging;
import static org.junit.Assert.*;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.isNotNull;
import static org.mockito.Matchers.same;
import static org.mockito.Mockito.*;
import java.net.URI;
import java.nio.charset.Charset;
import java.util.concurrent.ScheduledFuture;
import org.junit.Before;
import org.junit.Test;
import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
import org.springframework.messaging.Message;
import org.springframework.messaging.simp.stomp.ConnectionHandlingStompSession;
import org.springframework.messaging.simp.stomp.StompCommand;
import org.springframework.messaging.simp.stomp.StompHeaderAccessor;
import org.springframework.messaging.simp.stomp.StompHeaders;
import org.springframework.messaging.simp.stomp.StompSessionHandler;
import org.springframework.messaging.support.MessageBuilder;
import org.springframework.messaging.support.MessageHeaderAccessor;
import org.springframework.messaging.tcp.TcpConnection;
import org.springframework.scheduling.TaskScheduler;
import org.springframework.util.MimeTypeUtils;
import org.springframework.util.concurrent.SettableListenableFuture;
import org.springframework.web.socket.BinaryMessage;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.PongMessage;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.client.WebSocketClient;
/**
* Unit tests for {@link \WebSocketStompClient}.
*
* @author Rossen Stoyanchev
*/
public class WebSocketStompClientTests {
private static final Charset UTF_8 = Charset.forName("UTF-8");
private TestWebSocketStompClient stompClient;
@Mock
private TaskScheduler taskScheduler;
@Mock
private ConnectionHandlingStompSession stompSession;
private ArgumentCaptor<WebSocketHandler> webSocketHandlerCaptor;
private SettableListenableFuture<WebSocketSession> handshakeFuture;
@Mock
private WebSocketSession webSocketSession;
@Before
public void setUp() throws Exception {
MockitoAnnotations.initMocks(this);
WebSocketClient webSocketClient = mock(WebSocketClient.class);
this.stompClient = new TestWebSocketStompClient(webSocketClient);
this.stompClient.setTaskScheduler(this.taskScheduler);
this.stompClient.setStompSession(this.stompSession);
this.webSocketHandlerCaptor = ArgumentCaptor.forClass(WebSocketHandler.class);
this.handshakeFuture = new SettableListenableFuture<>();
when(webSocketClient.doHandshake(this.webSocketHandlerCaptor.capture(), any(), any(URI.class)))
.thenReturn(this.handshakeFuture);
}
@SuppressWarnings("unchecked")
@Test
public void webSocketHandshakeFailure() throws Exception {
connect();
IllegalStateException handshakeFailure = new IllegalStateException("simulated exception");
this.handshakeFuture.setException(handshakeFailure);
verify(this.stompSession).afterConnectFailure(same(handshakeFailure));
}
@SuppressWarnings("unchecked")
@Test
public void webSocketConnectionEstablished() throws Exception {
connect().afterConnectionEstablished(this.webSocketSession);
verify(this.stompSession).afterConnected(isNotNull(TcpConnection.class));
}
@Test
public void webSocketTransportError() throws Exception {
IllegalStateException exception = new IllegalStateException("simulated exception");
connect().handleTransportError(this.webSocketSession, exception);
verify(this.stompSession).handleFailure(same(exception));
}
@Test
public void webSocketConnectionClosed() throws Exception {
connect().afterConnectionClosed(this.webSocketSession, CloseStatus.NORMAL);
verify(this.stompSession).afterConnectionClosed();
}
@Test
public void handleWebSocketMessage() throws Exception {
String text = "SEND\na:alpha\n\nMessage payload\0";
connect().handleMessage(this.webSocketSession, new TextMessage(text));
ArgumentCaptor<? extends Message<byte[]>> captor = ArgumentCaptor.forClass(Message.class);
verify(this.stompSession).handleMessage(captor.capture());
Message<byte[]> message = captor.getValue();
assertNotNull(message);
StompHeaderAccessor accessor = MessageHeaderAccessor.getAccessor(message, StompHeaderAccessor.class);
StompHeaders headers = StompHeaders.readOnlyStompHeaders(accessor.toNativeHeaderMap());
assertEquals(StompCommand.SEND, accessor.getCommand());
assertEquals("alpha", headers.getFirst("a"));
assertEquals("Message payload", new String(message.getPayload(), UTF_8));
}
@Test
public void handleWebSocketMessageSplitAcrossTwoMessage() throws Exception {
WebSocketHandler webSocketHandler = connect();
String part1 = "SEND\na:alpha\n\nMessage";
webSocketHandler.handleMessage(this.webSocketSession, new TextMessage(part1));
verifyNoMoreInteractions(this.stompSession);
String part2 = " payload\0";
webSocketHandler.handleMessage(this.webSocketSession, new TextMessage(part2));
ArgumentCaptor<? extends Message<byte[]>> captor = ArgumentCaptor.forClass(Message.class);
verify(this.stompSession).handleMessage(captor.capture());
Message<byte[]> message = captor.getValue();
assertNotNull(message);
StompHeaderAccessor accessor = MessageHeaderAccessor.getAccessor(message, StompHeaderAccessor.class);
StompHeaders headers = StompHeaders.readOnlyStompHeaders(accessor.toNativeHeaderMap());
assertEquals(StompCommand.SEND, accessor.getCommand());
assertEquals("alpha", headers.getFirst("a"));
assertEquals("Message payload", new String(message.getPayload(), UTF_8));
}
@Test
public void handleWebSocketMessageBinary() throws Exception {
String text = "SEND\na:alpha\n\nMessage payload\0";
connect().handleMessage(this.webSocketSession, new BinaryMessage(text.getBytes(UTF_8)));
ArgumentCaptor<? extends Message<byte[]>> captor = ArgumentCaptor.forClass(Message.class);
verify(this.stompSession).handleMessage(captor.capture());
Message<byte[]> message = captor.getValue();
assertNotNull(message);
StompHeaderAccessor accessor = MessageHeaderAccessor.getAccessor(message, StompHeaderAccessor.class);
StompHeaders headers = StompHeaders.readOnlyStompHeaders(accessor.toNativeHeaderMap());
assertEquals(StompCommand.SEND, accessor.getCommand());
assertEquals("alpha", headers.getFirst("a"));
assertEquals("Message payload", new String(message.getPayload(), UTF_8));
}
@Test
public void handleWebSocketMessagePong() throws Exception {
connect().handleMessage(this.webSocketSession, new PongMessage());
verifyNoMoreInteractions(this.stompSession);
}
@Test
public void sendWebSocketMessage() throws Exception {
StompHeaderAccessor accessor = StompHeaderAccessor.create(StompCommand.SEND);
accessor.setDestination("/topic/foo");
byte[] payload = "payload".getBytes(UTF_8);
getTcpConnection().send(MessageBuilder.createMessage(payload, accessor.getMessageHeaders()));
ArgumentCaptor<TextMessage> textMessageCaptor = ArgumentCaptor.forClass(TextMessage.class);
verify(this.webSocketSession).sendMessage(textMessageCaptor.capture());
TextMessage textMessage = textMessageCaptor.getValue();
assertNotNull(textMessage);
assertEquals("SEND\ndestination:/topic/foo\ncontent-length:7\n\npayload\0", textMessage.getPayload());
}
@Test
public void sendWebSocketBinary() throws Exception {
StompHeaderAccessor accessor = StompHeaderAccessor.create(StompCommand.SEND);
accessor.setDestination("/b");
accessor.setContentType(MimeTypeUtils.APPLICATION_OCTET_STREAM);
byte[] payload = "payload".getBytes(UTF_8);
getTcpConnection().send(MessageBuilder.createMessage(payload, accessor.getMessageHeaders()));
ArgumentCaptor<BinaryMessage> binaryMessageCaptor = ArgumentCaptor.forClass(BinaryMessage.class);
verify(this.webSocketSession).sendMessage(binaryMessageCaptor.capture());
BinaryMessage binaryMessage = binaryMessageCaptor.getValue();
assertNotNull(binaryMessage);
assertEquals("SEND\ndestination:/b\ncontent-type:application/octet-stream\ncontent-length:7\n\npayload\0",
new String(binaryMessage.getPayload().array(), UTF_8));
}
@Test
public void heartbeatDefaultValue() throws Exception {
WebSocketStompClient stompClient = new WebSocketStompClient(mock(WebSocketClient.class));
assertArrayEquals(new long[] {0, 0}, stompClient.getDefaultHeartbeat());
StompHeaders connectHeaders = stompClient.processConnectHeaders(null);
assertArrayEquals(new long[] {0, 0}, connectHeaders.getHeartbeat());
}
@Test
public void heartbeatDefaultValueWithScheduler() throws Exception {
WebSocketStompClient stompClient = new WebSocketStompClient(mock(WebSocketClient.class));
stompClient.setTaskScheduler(mock(TaskScheduler.class));
assertArrayEquals(new long[] {10000, 10000}, stompClient.getDefaultHeartbeat());
StompHeaders connectHeaders = stompClient.processConnectHeaders(null);
assertArrayEquals(new long[] {10000, 10000}, connectHeaders.getHeartbeat());
}
@Test
public void heartbeatDefaultValueSetWithoutScheduler() throws Exception {
WebSocketStompClient stompClient = new WebSocketStompClient(mock(WebSocketClient.class));
stompClient.setDefaultHeartbeat(new long[] {5, 5});
try {
stompClient.processConnectHeaders(null);
fail("Expected exception");
}
catch (IllegalArgumentException ex) {
// Ignore
}
}
@Test
public void readInactivityAfterDelayHasElapsed() throws Exception {
TcpConnection<byte[]> tcpConnection = getTcpConnection();
Runnable runnable = mock(Runnable.class);
long delay = 2;
tcpConnection.onReadInactivity(runnable, delay);
testInactivityTaskScheduling(runnable, delay, 10);
}
@Test
public void readInactivityBeforeDelayHasElapsed() throws Exception {
TcpConnection<byte[]> tcpConnection = getTcpConnection();
Runnable runnable = mock(Runnable.class);
long delay = 10000;
tcpConnection.onReadInactivity(runnable, delay);
testInactivityTaskScheduling(runnable, delay, 0);
}
@Test
public void writeInactivityAfterDelayHasElapsed() throws Exception {
TcpConnection<byte[]> tcpConnection = getTcpConnection();
Runnable runnable = mock(Runnable.class);
long delay = 2;
tcpConnection.onWriteInactivity(runnable, delay);
testInactivityTaskScheduling(runnable, delay, 10);
}
@Test
public void writeInactivityBeforeDelayHasElapsed() throws Exception {
TcpConnection<byte[]> tcpConnection = getTcpConnection();
Runnable runnable = mock(Runnable.class);
long delay = 1000;
tcpConnection.onWriteInactivity(runnable, delay);
testInactivityTaskScheduling(runnable, delay, 0);
}
@Test
public void cancelInactivityTasks() throws Exception {
TcpConnection<byte[]> tcpConnection = getTcpConnection();
ScheduledFuture future = mock(ScheduledFuture.class);
when(this.taskScheduler.scheduleWithFixedDelay(any(), eq(1L))).thenReturn(future);
tcpConnection.onReadInactivity(mock(Runnable.class), 2L);
tcpConnection.onWriteInactivity(mock(Runnable.class), 2L);
this.webSocketHandlerCaptor.getValue().afterConnectionClosed(this.webSocketSession, CloseStatus.NORMAL);
verify(future, times(2)).cancel(true);
verifyNoMoreInteractions(future);
}
private WebSocketHandler connect() {
this.stompClient.connect("/foo", mock(StompSessionHandler.class));
verify(this.stompSession).getSessionFuture();
verifyNoMoreInteractions(this.stompSession);
WebSocketHandler webSocketHandler = this.webSocketHandlerCaptor.getValue();
assertNotNull(webSocketHandler);
return webSocketHandler;
}
@SuppressWarnings("unchecked")
private TcpConnection<byte[]> getTcpConnection() throws Exception {
WebSocketHandler webSocketHandler = connect();
webSocketHandler.afterConnectionEstablished(this.webSocketSession);
return (TcpConnection<byte[]>) webSocketHandler;
}
private void testInactivityTaskScheduling(Runnable runnable, long delay, long sleepTime)
throws InterruptedException {
ArgumentCaptor<Runnable> inactivityTaskCaptor = ArgumentCaptor.forClass(Runnable.class);
verify(this.taskScheduler).scheduleWithFixedDelay(inactivityTaskCaptor.capture(), eq(delay/2));
verifyNoMoreInteractions(this.taskScheduler);
if (sleepTime > 0) {
Thread.sleep(sleepTime);
}
Runnable inactivityTask = inactivityTaskCaptor.getValue();
assertNotNull(inactivityTask);
inactivityTask.run();
if (sleepTime > 0) {
verify(runnable).run();
}
else {
verifyNoMoreInteractions(runnable);
}
}
private static class TestWebSocketStompClient extends WebSocketStompClient {
private ConnectionHandlingStompSession stompSession;
public TestWebSocketStompClient(WebSocketClient webSocketClient) {
super(webSocketClient);
}
public void setStompSession(ConnectionHandlingStompSession stompSession) {
this.stompSession = stompSession;
}
@Override
protected ConnectionHandlingStompSession createSession(StompHeaders headers, StompSessionHandler handler) {
return this.stompSession;
}
}
}
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册