提交 e82df99a 编写于 作者: R Rossen Stoyanchev

Add SockJS client

This change adds a new implementation of WebSocketClient that can
connect to a SockJS server using one of the SockJS transports
"websocket", "xhr_streaming", or "xhr". From a client perspective
there is no implementation difference between "xhr_streaming" and
"xhr". Just keep receiving and when the response is complete,
start over. Other SockJS transports are browser specific
and therefore not relevant in Java ("eventsource", "htmlfile" or
iframe based variations).

The client loosely mimics the behavior of the JavaScript SockJS client.
First it sends an info request to find the server capabilities,
then it tries to connect with each configured transport, falling
back, or forcing a timeout and then falling back, until one of the
configured transports succeeds.

The WebSocketTransport can be configured with any Spring Framework
WebSocketClient implementation (currently JSR-356 or Jetty 9).

The XhrTransport currently has a RestTemplate-based and a Jetty
HttpClient-based implementations. To use those to simulate a large
number of users be sure to configure Jetty's HttpClient executor
and maxConnectionsPerDestination to high numbers. The same is true
for whichever underlying HTTP library is used with the RestTemplate
(e.g. maxConnPerRoute and maxConnTotal in Apache HttpComponents).

Issue: SPR-10797
上级 dc1d85d0
......@@ -670,6 +670,7 @@ project("spring-websocket") {
exclude group: "javax.servlet", module: "javax.servlet"
}
optional("org.eclipse.jetty.websocket:websocket-client:${jettyVersion}")
optional("org.eclipse.jetty:jetty-client:${jettyVersion}")
optional("io.undertow:undertow-core:1.0.15.Final")
optional("io.undertow:undertow-servlet:1.0.15.Final") {
exclude group: "org.jboss.spec.javax.servlet", module: "jboss-servlet-api_3.1_spec"
......
......@@ -217,7 +217,7 @@ public final class CloseStatus {
@Override
public String toString() {
return "CloseStatus [code=" + this.code + ", reason=" + this.reason + "]";
return "CloseStatus[code=" + this.code + ", reason=" + this.reason + "]";
}
}
......@@ -73,10 +73,7 @@ public abstract class AbstractWebSocketClient implements WebSocketClient {
WebSocketHttpHeaders headers, URI uri) {
Assert.notNull(webSocketHandler, "webSocketHandler must not be null");
Assert.notNull(uri, "uri must not be null");
String scheme = uri.getScheme();
Assert.isTrue(((scheme != null) && ("ws".equals(scheme) || "wss".equals(scheme))), "Invalid scheme: " + scheme);
assertUri(uri);
if (logger.isDebugEnabled()) {
logger.debug("Connecting to " + uri);
......@@ -101,6 +98,12 @@ public abstract class AbstractWebSocketClient implements WebSocketClient {
Collections.<String, Object>emptyMap());
}
protected void assertUri(URI uri) {
Assert.notNull(uri, "uri must not be null");
String scheme = uri.getScheme();
Assert.isTrue(scheme != null && ("ws".equals(scheme) || "wss".equals(scheme)), "Invalid scheme: " + scheme);
}
/**
* Perform the actual handshake to establish a connection to the server.
*
......
......@@ -194,7 +194,8 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE
}
}
catch (Throwable ex) {
logger.error("Failed to parse WebSocket message to STOMP frame(s)", ex);
logger.error("Failed to parse WebSocket message to STOMP." +
"Sending STOMP ERROR to client, sessionId=" + session.getId(), ex);
sendErrorMessage(session, ex);
return;
}
......@@ -232,7 +233,8 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE
}
}
catch (Throwable ex) {
logger.error("Terminating STOMP session due to failure to send message", ex);
logger.error("Parsed STOMP message but could not send it to to message channel. " +
"Sending STOMP ERROR to client, sessionId=" + session.getId(), ex);
sendErrorMessage(session, ex);
}
}
......@@ -248,7 +250,6 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE
}
protected void sendErrorMessage(WebSocketSession session, Throwable error) {
StompHeaderAccessor headerAccessor = StompHeaderAccessor.create(StompCommand.ERROR);
headerAccessor.setMessage(error.getMessage());
byte[] bytes = this.stompEncoder.encode(headerAccessor.getMessageHeaders(), EMPTY_PAYLOAD);
......@@ -331,7 +332,8 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE
throw ex;
}
catch (Throwable ex) {
sendErrorMessage(session, ex);
logger.error("Failed to send WebSocket message to client, sessionId=" + session.getId(), ex);
command = StompCommand.ERROR;
}
finally {
if (StompCommand.ERROR.equals(command)) {
......
/*
* Copyright 2002-2014 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.web.socket.sockjs.client;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.http.HttpHeaders;
import org.springframework.util.Assert;
import org.springframework.util.concurrent.SettableListenableFuture;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.WebSocketMessage;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.sockjs.frame.SockJsFrame;
import org.springframework.web.socket.sockjs.frame.SockJsFrameType;
import org.springframework.web.socket.sockjs.frame.SockJsMessageCodec;
import java.io.IOException;
import java.net.URI;
import java.security.Principal;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
/**
* Base class for SockJS client implementations of {@link WebSocketSession}.
* Provides processing of incoming SockJS message frames and delegates lifecycle
* events and messages to the (application) {@link WebSocketHandler}.
* Sub-classes implement actual send as well as disconnect logic.
*
* @author Rossen Stoyanchev
* @since 4.1
*/
public abstract class AbstractClientSockJsSession implements WebSocketSession {
protected final Log logger = LogFactory.getLog(getClass());
private final TransportRequest request;
private final WebSocketHandler webSocketHandler;
private final SettableListenableFuture<WebSocketSession> connectFuture;
private final Map<String, Object> attributes = new ConcurrentHashMap<String, Object>();
private volatile State state = State.NEW;
private volatile CloseStatus closeStatus;
protected AbstractClientSockJsSession(TransportRequest request, WebSocketHandler handler,
SettableListenableFuture<WebSocketSession> connectFuture) {
Assert.notNull(request, "'request' is required");
Assert.notNull(handler, "'handler' is required");
Assert.notNull(connectFuture, "'connectFuture' is required");
this.request = request;
this.webSocketHandler = handler;
this.connectFuture = connectFuture;
}
@Override
public String getId() {
return this.request.getSockJsUrlInfo().getSessionId();
}
@Override
public URI getUri() {
return this.request.getSockJsUrlInfo().getSockJsUrl();
}
@Override
public HttpHeaders getHandshakeHeaders() {
return this.request.getHandshakeHeaders();
}
@Override
public Map<String, Object> getAttributes() {
return this.attributes;
}
@Override
public Principal getPrincipal() {
return this.request.getUser();
}
public SockJsMessageCodec getMessageCodec() {
return this.request.getMessageCodec();
}
public WebSocketHandler getWebSocketHandler() {
return this.webSocketHandler;
}
/**
* Return a timeout cleanup task to invoke if the SockJS sessions is not
* fully established within the retransmission timeout period calculated in
* {@code SockJsRequest} based on the duration of the initial SockJS "Info"
* request.
*/
Runnable getTimeoutTask() {
return new Runnable() {
@Override
public void run() {
closeInternal(new CloseStatus(2007, "Transport timed out"));
}
};
}
@Override
public boolean isOpen() {
return State.OPEN.equals(this.state);
}
public boolean isDisconnected() {
return (State.CLOSING.equals(this.state) || State.CLOSED.equals(this.state));
}
@Override
public final void sendMessage(WebSocketMessage<?> message) throws IOException {
Assert.state(State.OPEN.equals(this.state), this + " is not open, current state=" + this.state);
Assert.isInstanceOf(TextMessage.class, message, this + " supports text messages only.");
String payload = ((TextMessage) message).getPayload();
payload = getMessageCodec().encode(new String[] { payload });
payload = payload.substring(1); // the client-side doesn't need message framing (letter "a")
message = new TextMessage(payload);
if (logger.isTraceEnabled()) {
logger.trace("Sending message " + message + " in " + this);
}
sendInternal((TextMessage) message);
}
protected abstract void sendInternal(TextMessage textMessage) throws IOException;
@Override
public final void close() throws IOException {
close(CloseStatus.NORMAL);
}
@Override
public final void close(CloseStatus status) {
Assert.isTrue(status != null && isUserSetStatus(status), "Invalid close status: " + status);
if (logger.isInfoEnabled()) {
logger.info("Closing session with " + status + " in " + this);
}
closeInternal(status);
}
private boolean isUserSetStatus(CloseStatus status) {
return (status.getCode() == 1000 || (status.getCode() >= 3000 && status.getCode() <= 4999));
}
protected void closeInternal(CloseStatus status) {
if (this.state == null) {
logger.warn("Ignoring close since connect() was never invoked");
return;
}
if (State.CLOSING.equals(this.state) || State.CLOSED.equals(this.state)) {
logger.debug("Ignoring close (already closing or closed), current state=" + this.state);
return;
}
this.state = State.CLOSING;
this.closeStatus = status;
try {
disconnect(status);
}
catch (Throwable ex) {
if (logger.isErrorEnabled()) {
logger.error("Failed to close " + this, ex);
}
}
}
protected abstract void disconnect(CloseStatus status) throws IOException;
public void handleFrame(String payload) {
SockJsFrame frame = new SockJsFrame(payload);
if (SockJsFrameType.OPEN.equals(frame.getType())) {
handleOpenFrame();
}
else if (SockJsFrameType.MESSAGE.equals(frame.getType())) {
handleMessageFrame(frame);
}
else if (SockJsFrameType.CLOSE.equals(frame.getType())) {
handleCloseFrame(frame);
}
else if (SockJsFrameType.HEARTBEAT.equals(frame.getType())) {
if (logger.isTraceEnabled()) {
logger.trace("Received heartbeat in " + this);
}
}
else {
// should never happen
throw new IllegalStateException("Unknown SockJS frame type " + frame + " in " + this);
}
}
private void handleOpenFrame() {
if (logger.isInfoEnabled()) {
logger.info("Processing SockJS open frame in " + this);
}
if (State.NEW.equals(state)) {
this.state = State.OPEN;
try {
this.webSocketHandler.afterConnectionEstablished(this);
this.connectFuture.set(this);
}
catch (Throwable ex) {
if (logger.isErrorEnabled()) {
Class<?> type = this.webSocketHandler.getClass();
logger.error(type + ".afterConnectionEstablished threw exception in " + this, ex);
}
}
}
else {
if (logger.isDebugEnabled()) {
logger.debug("Open frame received in " + getId() + " but we're not" +
"connecting (current state=" + this.state + "). The server might " +
"have been restarted and lost track of the session.");
}
closeInternal(new CloseStatus(1006, "Server lost session"));
}
}
private void handleMessageFrame(SockJsFrame frame) {
if (!isOpen()) {
if (logger.isWarnEnabled()) {
logger.warn("Ignoring received message due to state=" + this.state + " in " + this);
}
return;
}
String[] messages;
try {
messages = getMessageCodec().decode(frame.getFrameData());
}
catch (IOException ex) {
if (logger.isErrorEnabled()) {
logger.error("Failed to decode data for SockJS \"message\" frame: " + frame + " in " + this, ex);
}
closeInternal(CloseStatus.BAD_DATA);
return;
}
if (logger.isTraceEnabled()) {
logger.trace("Processing SockJS message frame " + frame.getContent() + " in " + this);
}
for (String message : messages) {
try {
if (isOpen()) {
this.webSocketHandler.handleMessage(this, new TextMessage(message));
}
}
catch (Throwable ex) {
Class<?> type = this.webSocketHandler.getClass();
logger.error(type + ".handleMessage threw an exception on " + frame + " in " + this, ex);
}
}
}
private void handleCloseFrame(SockJsFrame frame) {
CloseStatus closeStatus = CloseStatus.NO_STATUS_CODE;
try {
String[] data = getMessageCodec().decode(frame.getFrameData());
if (data.length == 2) {
closeStatus = new CloseStatus(Integer.valueOf(data[0]), data[1]);
}
if (logger.isInfoEnabled()) {
logger.info("Processing SockJS close frame with " + closeStatus + " in " + this);
}
}
catch (IOException ex) {
if (logger.isErrorEnabled()) {
logger.error("Failed to decode data for " + frame + " in " + this, ex);
}
}
closeInternal(closeStatus);
}
public void handleTransportError(Throwable error) {
try {
if (logger.isErrorEnabled()) {
logger.error("Transport error in " + this, error);
}
this.webSocketHandler.handleTransportError(this, error);
}
catch (Exception ex) {
Class<?> type = this.webSocketHandler.getClass();
if (logger.isErrorEnabled()) {
logger.error(type + ".handleTransportError threw an exception", ex);
}
}
}
public void afterTransportClosed(CloseStatus closeStatus) {
this.closeStatus = (this.closeStatus != null ? this.closeStatus : closeStatus);
Assert.state(this.closeStatus != null, "CloseStatus not available");
if (logger.isInfoEnabled()) {
logger.info("Transport closed with " + this.closeStatus + " in " + this);
}
this.state = State.CLOSED;
try {
this.webSocketHandler.afterConnectionClosed(this, this.closeStatus);
}
catch (Exception ex) {
if (logger.isErrorEnabled()) {
Class<?> type = this.webSocketHandler.getClass();
logger.error(type + ".afterConnectionClosed threw an exception", ex);
}
}
}
@Override
public String toString() {
return getClass().getSimpleName() + "[id='" + getId() + ", url=" + getUri() + "]";
}
private enum State { NEW, OPEN, CLOSING, CLOSED }
}
/*
* Copyright 2002-2014 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.web.socket.sockjs.client;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpStatus;
import org.springframework.http.MediaType;
import org.springframework.http.ResponseEntity;
import org.springframework.util.concurrent.ListenableFuture;
import org.springframework.util.concurrent.SettableListenableFuture;
import org.springframework.web.client.HttpServerErrorException;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.sockjs.frame.SockJsFrame;
import java.net.URI;
/**
* Abstract base class for XHR transport implementations to extend.
*
* @author Rossen Stoyanchev
* @since 4.1
*/
public abstract class AbstractXhrTransport implements XhrTransport {
protected static final String PRELUDE;
static {
byte[] bytes = new byte[2048];
for (int i = 0; i < bytes.length; i++) {
bytes[i] = 'h';
}
PRELUDE = new String(bytes, SockJsFrame.CHARSET);
}
protected Log logger = LogFactory.getLog(getClass());
private boolean xhrStreamingDisabled;
private HttpHeaders requestHeaders = new HttpHeaders();
private HttpHeaders xhrSendRequestHeaders = new HttpHeaders();
/**
* Whether to attempt to connect with "xhr_streaming" first before trying
* with "xhr" next, see {@link XhrTransport#isXhrStreamingDisabled()}.
*
* <p>By default this property is set to {@code false} which means both
* "xhr_streaming" and "xhr" will be tried.
*/
public void setXhrStreamingDisabled(boolean disabled) {
this.xhrStreamingDisabled = disabled;
}
public boolean isXhrStreamingDisabled() {
return this.xhrStreamingDisabled;
}
/**
* Configure headers to be added to every executed HTTP request.
* @param requestHeaders the headers to add to requests
*/
public void setRequestHeaders(HttpHeaders requestHeaders) {
this.requestHeaders.clear();
this.xhrSendRequestHeaders.clear();
if (requestHeaders != null) {
this.requestHeaders.putAll(requestHeaders);
this.xhrSendRequestHeaders.putAll(requestHeaders);
this.xhrSendRequestHeaders.setContentType(MediaType.APPLICATION_JSON);
}
}
public HttpHeaders getRequestHeaders() {
return this.requestHeaders;
}
@Override
public String executeInfoRequest(URI infoUrl) {
if (logger.isDebugEnabled()) {
logger.debug("Executing SockJS Info request, url=" + infoUrl);
}
ResponseEntity<String> response = executeInfoRequestInternal(infoUrl);
if (response.getStatusCode() != HttpStatus.OK) {
if (logger.isErrorEnabled()) {
logger.error("SockJS Info request (url=" + infoUrl + ") failed: " + response);
}
throw new HttpServerErrorException(response.getStatusCode());
}
if (logger.isDebugEnabled()) {
logger.debug("SockJS Info request (url=" + infoUrl + ") response: " + response);
}
return response.getBody();
}
protected abstract ResponseEntity<String> executeInfoRequestInternal(URI infoUrl);
@Override
public void executeSendRequest(URI url, TextMessage message) {
if (logger.isDebugEnabled()) {
logger.debug("Starting XHR send, url=" + url);
}
ResponseEntity<String> response = executeSendRequestInternal(url, this.xhrSendRequestHeaders, message);
if (response.getStatusCode() != HttpStatus.NO_CONTENT) {
if (logger.isErrorEnabled()) {
logger.error("XHR send request (url=" + url + ") failed: " + response);
}
throw new HttpServerErrorException(response.getStatusCode());
}
if (logger.isDebugEnabled()) {
logger.debug("XHR send request (url=" + url + ") response: " + response);
}
}
protected abstract ResponseEntity<String> executeSendRequestInternal(URI url, HttpHeaders headers, TextMessage message);
@Override
public ListenableFuture<WebSocketSession> connect(TransportRequest request, WebSocketHandler handler) {
SettableListenableFuture<WebSocketSession> connectFuture = new SettableListenableFuture<WebSocketSession>();
XhrClientSockJsSession session = new XhrClientSockJsSession(request, handler, this, connectFuture);
request.addTimeoutTask(session.getTimeoutTask());
URI receiveUrl = request.getTransportUrl();
if (logger.isDebugEnabled()) {
logger.debug("Opening XHR session, receive url=" + receiveUrl);
}
HttpHeaders handshakeHeaders = new HttpHeaders();
handshakeHeaders.putAll(request.getHandshakeHeaders());
handshakeHeaders.putAll(getRequestHeaders());
connectInternal(request, handler, receiveUrl, handshakeHeaders, session, connectFuture);
return connectFuture;
}
protected abstract void connectInternal(TransportRequest request, WebSocketHandler handler,
URI receiveUrl, HttpHeaders handshakeHeaders, XhrClientSockJsSession session,
SettableListenableFuture<WebSocketSession> connectFuture);
@Override
public String toString() {
return getClass().getSimpleName();
}
}
/*
* Copyright 2002-2014 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.web.socket.sockjs.client;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.http.HttpHeaders;
import org.springframework.scheduling.TaskScheduler;
import org.springframework.util.Assert;
import org.springframework.util.concurrent.ListenableFutureCallback;
import org.springframework.util.concurrent.SettableListenableFuture;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.sockjs.SockJsTransportFailureException;
import org.springframework.web.socket.sockjs.frame.SockJsMessageCodec;
import org.springframework.web.socket.sockjs.transport.TransportType;
import java.net.URI;
import java.security.Principal;
import java.util.ArrayList;
import java.util.Date;
import java.util.List;
import java.util.concurrent.atomic.AtomicBoolean;
/**
* A default implementation of
* {@link org.springframework.web.socket.sockjs.client.TransportRequest
* TransportRequest}.
*
* @author Rossen Stoyanchev
* @since 4.1
*/
class DefaultTransportRequest implements TransportRequest {
private static Log logger = LogFactory.getLog(DefaultTransportRequest.class);
private final SockJsUrlInfo sockJsUrlInfo;
private final HttpHeaders handshakeHeaders;
private final Transport transport;
private final TransportType serverTransportType;
private SockJsMessageCodec codec;
private Principal user;
private long timeoutValue;
private TaskScheduler timeoutScheduler;
private final List<Runnable> timeoutTasks = new ArrayList<Runnable>();
private DefaultTransportRequest fallbackRequest;
public DefaultTransportRequest(SockJsUrlInfo sockJsUrlInfo, HttpHeaders handshakeHeaders,
Transport transport, TransportType serverTransportType, SockJsMessageCodec codec) {
Assert.notNull(sockJsUrlInfo, "'sockJsUrlInfo' is required");
Assert.notNull(transport, "'transport' is required");
Assert.notNull(serverTransportType, "'transportType' is required");
Assert.notNull(codec, "'codec' is required");
this.sockJsUrlInfo = sockJsUrlInfo;
this.handshakeHeaders = (handshakeHeaders != null ? handshakeHeaders : new HttpHeaders());
this.transport = transport;
this.serverTransportType = serverTransportType;
this.codec = codec;
}
@Override
public SockJsUrlInfo getSockJsUrlInfo() {
return this.sockJsUrlInfo;
}
@Override
public HttpHeaders getHandshakeHeaders() {
return this.handshakeHeaders;
}
@Override
public URI getTransportUrl() {
return this.sockJsUrlInfo.getTransportUrl(this.serverTransportType);
}
public void setUser(Principal user) {
this.user = user;
}
@Override
public Principal getUser() {
return this.user;
}
@Override
public SockJsMessageCodec getMessageCodec() {
return this.codec;
}
public void setTimeoutValue(long timeoutValue) {
this.timeoutValue = timeoutValue;
}
public void setTimeoutScheduler(TaskScheduler scheduler) {
this.timeoutScheduler = scheduler;
}
@Override
public void addTimeoutTask(Runnable runnable) {
this.timeoutTasks.add(runnable);
}
public void setFallbackRequest(DefaultTransportRequest fallbackRequest) {
this.fallbackRequest = fallbackRequest;
}
public void connect(WebSocketHandler handler, SettableListenableFuture<WebSocketSession> future) {
if (logger.isDebugEnabled()) {
logger.debug("Starting " + this);
}
ConnectCallback connectCallback = new ConnectCallback(handler, future);
scheduleConnectTimeoutTask(connectCallback);
this.transport.connect(this, handler).addCallback(connectCallback);
}
private void scheduleConnectTimeoutTask(ConnectCallback connectHandler) {
if (this.timeoutScheduler != null) {
if (logger.isDebugEnabled()) {
logger.debug("Scheduling connect to time out after " + this.timeoutValue + " milliseconds");
}
Date timeoutDate = new Date(System.currentTimeMillis() + this.timeoutValue);
this.timeoutScheduler.schedule(connectHandler, timeoutDate);
}
else if (logger.isDebugEnabled()) {
logger.debug("Connect timeout task not scheduled. Is SockJsClient configured with a TaskScheduler?");
}
}
@Override
public String toString() {
return "TransportRequest[url=" + getTransportUrl() + "]";
}
/**
* Updates the given (global) future based success or failure to connect for
* the entire SockJS request regardless of which transport actually managed
* to connect. Also implements {@code Runnable} to handle a scheduled timeout
* callback.
*/
private class ConnectCallback implements ListenableFutureCallback<WebSocketSession>, Runnable {
private final WebSocketHandler handler;
private final SettableListenableFuture<WebSocketSession> future;
private final AtomicBoolean handled = new AtomicBoolean(false);
public ConnectCallback(WebSocketHandler handler, SettableListenableFuture<WebSocketSession> future) {
this.handler = handler;
this.future = future;
}
@Override
public void onSuccess(WebSocketSession session) {
if (this.handled.compareAndSet(false, true)) {
this.future.set(session);
}
else {
logger.error("Connect success/failure already handled for " + DefaultTransportRequest.this);
}
}
@Override
public void onFailure(Throwable failure) {
handleFailure(failure, false);
}
@Override
public void run() {
handleFailure(null, true);
}
private void handleFailure(Throwable failure, boolean isTimeoutFailure) {
if (this.handled.compareAndSet(false, true)) {
if (isTimeoutFailure) {
String message = "Connect timed out for " + DefaultTransportRequest.this;
logger.error(message);
failure = new SockJsTransportFailureException(message, getSockJsUrlInfo().getSessionId(), null);
}
if (fallbackRequest != null) {
logger.error(DefaultTransportRequest.this + " failed. Falling back on next transport.", failure);
fallbackRequest.connect(this.handler, this.future);
}
else {
logger.error("No more fallback transports after " + DefaultTransportRequest.this, failure);
this.future.setException(failure);
}
if (isTimeoutFailure) {
try {
for (Runnable runnable : timeoutTasks) {
runnable.run();
}
}
catch (Throwable ex) {
logger.error("Transport failed to run timeout tasks for " + DefaultTransportRequest.this, ex);
}
}
}
else {
logger.error("Connect success/failure events already took place for " +
DefaultTransportRequest.this + ". Ignoring this additional failure event.", failure);
}
}
}
}
package org.springframework.web.socket.sockjs.client;
import java.net.URI;
/**
* A simple contract for executing the SockJS "Info" request before the SockJS
* session starts. The request is used to check server capabilities such as
* whether it permits use of the WebSocket transport.
*
* @author Rossen Stoyanchev
* @since 4.1
*/
public interface InfoReceiver {
/**
* Perform an HTTP request to the SockJS "Info" URL.
* and return the resulting JSON response content, or raise an exception.
*
* @param infoUrl the URL to obtain SockJS server information from
* @return the body of the response
*/
String executeInfoRequest(URI infoUrl);
}
\ No newline at end of file
/*
* Copyright 2002-2014 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.web.socket.sockjs.client;
import org.eclipse.jetty.client.HttpClient;
import org.eclipse.jetty.client.api.ContentResponse;
import org.eclipse.jetty.client.api.Request;
import org.eclipse.jetty.client.api.Response;
import org.eclipse.jetty.client.util.StringContentProvider;
import org.eclipse.jetty.http.HttpFields;
import org.eclipse.jetty.http.HttpMethod;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpStatus;
import org.springframework.http.ResponseEntity;
import org.springframework.util.Assert;
import org.springframework.util.concurrent.SettableListenableFuture;
import org.springframework.web.client.HttpServerErrorException;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.sockjs.SockJsException;
import org.springframework.web.socket.sockjs.SockJsTransportFailureException;
import org.springframework.web.socket.sockjs.frame.SockJsFrame;
import java.io.ByteArrayOutputStream;
import java.net.URI;
import java.nio.ByteBuffer;
import java.util.Enumeration;
/**
* An XHR transport based on Jetty's {@link org.eclipse.jetty.client.HttpClient}.
*
* <p>When used for testing purposes (e.g. load testing) the {@code HttpClient}
* properties must be set to allow a larger than usual number of connections and
* threads. For example:
*
* <pre class="code">
* HttpClient httpClient = new HttpClient();
* httpClient.setMaxConnectionsPerDestination(1000);
* httpClient.setExecutor(new QueuedThreadPool(500));
* </pre>
*
* @author Rossen Stoyanchev
* @since 4.1
*/
public class JettyXhrTransport extends AbstractXhrTransport implements XhrTransport {
private final HttpClient httpClient;
public JettyXhrTransport(HttpClient httpClient) {
Assert.notNull(httpClient, "'httpClient' is required");
this.httpClient = httpClient;
}
public HttpClient getHttpClient() {
return this.httpClient;
}
@Override
protected ResponseEntity<String> executeInfoRequestInternal(URI infoUrl) {
return executeRequest(infoUrl, HttpMethod.GET, getRequestHeaders(), null);
}
@Override
public ResponseEntity<String> executeSendRequestInternal(URI url, HttpHeaders headers, TextMessage message) {
return executeRequest(url, HttpMethod.POST, headers, message.getPayload());
}
protected ResponseEntity<String> executeRequest(URI url, HttpMethod method, HttpHeaders headers, String body) {
Request httpRequest = this.httpClient.newRequest(url).method(method);
addHttpHeaders(httpRequest, headers);
if (body != null) {
httpRequest.content(new StringContentProvider(body));
}
ContentResponse response;
try {
response = httpRequest.send();
}
catch (Exception ex) {
throw new SockJsTransportFailureException("Failed to execute request to " + url, null, ex);
}
HttpStatus status = HttpStatus.valueOf(response.getStatus());
HttpHeaders responseHeaders = toHttpHeaders(response.getHeaders());
return (response.getContent() != null ?
new ResponseEntity<String>(response.getContentAsString(), responseHeaders, status) :
new ResponseEntity<String>(responseHeaders, status));
}
private static void addHttpHeaders(Request request, HttpHeaders headers) {
for (String name : headers.keySet()) {
for (String value : headers.get(name)) {
request.header(name, value);
}
}
}
private static HttpHeaders toHttpHeaders(HttpFields httpFields) {
HttpHeaders responseHeaders = new HttpHeaders();
Enumeration<String> names = httpFields.getFieldNames();
while (names.hasMoreElements()) {
String name = names.nextElement();
Enumeration<String> values = httpFields.getValues(name);
while (values.hasMoreElements()) {
String value = values.nextElement();
responseHeaders.add(name, value);
}
}
return responseHeaders;
}
@Override
protected void connectInternal(TransportRequest request, WebSocketHandler handler,
URI url, HttpHeaders handshakeHeaders, XhrClientSockJsSession session,
SettableListenableFuture<WebSocketSession> connectFuture) {
SockJsResponseListener listener = new SockJsResponseListener(url, getRequestHeaders(), session, connectFuture);
executeReceiveRequest(url, handshakeHeaders, listener);
}
private void executeReceiveRequest(URI url, HttpHeaders headers, SockJsResponseListener listener) {
if (logger.isDebugEnabled()) {
logger.debug("Starting XHR receive request, url=" + url);
}
Request httpRequest = this.httpClient.newRequest(url).method(HttpMethod.POST);
addHttpHeaders(httpRequest, headers);
httpRequest.send(listener);
}
/**
* Splits the body of an HTTP response into SockJS frames and delegates those
* to an {@link XhrClientSockJsSession}.
*/
private class SockJsResponseListener extends Response.Listener.Adapter {
private final URI transportUrl;
private final HttpHeaders receiveHeaders;
private final XhrClientSockJsSession sockJsSession;
private final SettableListenableFuture<WebSocketSession> connectFuture;
private final ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
public SockJsResponseListener(URI url, HttpHeaders headers, XhrClientSockJsSession sockJsSession,
SettableListenableFuture<WebSocketSession> connectFuture) {
this.transportUrl = url;
this.receiveHeaders = headers;
this.connectFuture = connectFuture;
this.sockJsSession = sockJsSession;
}
@Override
public void onBegin(Response response) {
if (response.getStatus() != 200) {
HttpStatus status = HttpStatus.valueOf(response.getStatus());
response.abort(new HttpServerErrorException(status, "Unexpected XHR receive status"));
}
}
@Override
public void onHeaders(Response response) {
if (logger.isDebugEnabled()) {
// Convert to HttpHeaders to avoid "\n"
logger.debug("XHR receive headers: " + toHttpHeaders(response.getHeaders()));
}
}
@Override
public void onContent(Response response, ByteBuffer buffer) {
while (true) {
if (this.sockJsSession.isDisconnected()) {
if (logger.isDebugEnabled()) {
logger.debug("SockJS sockJsSession closed. Closing ClientHttpResponse.");
}
response.abort(new SockJsException("Session closed.", this.sockJsSession.getId(), null));
return;
}
if (buffer.remaining() == 0) {
break;
}
int b = buffer.get();
if (b == '\n') {
handleFrame();
}
else {
this.outputStream.write(b);
}
}
}
private void handleFrame() {
byte[] bytes = this.outputStream.toByteArray();
this.outputStream.reset();
String content = new String(bytes, SockJsFrame.CHARSET);
if (logger.isTraceEnabled()) {
logger.trace("XHR content received: " + content);
}
if (!PRELUDE.equals(content)) {
this.sockJsSession.handleFrame(new String(bytes, SockJsFrame.CHARSET));
}
}
@Override
public void onSuccess(Response response) {
if (this.outputStream.size() > 0) {
handleFrame();
}
if (logger.isDebugEnabled()) {
logger.debug("XHR receive request completed.");
}
executeReceiveRequest(this.transportUrl, this.receiveHeaders, this);
}
@Override
public void onFailure(Response response, Throwable failure) {
if (connectFuture.setException(failure)) {
return;
}
if (this.sockJsSession.isDisconnected()) {
this.sockJsSession.afterTransportClosed(null);
}
else {
this.sockJsSession.handleTransportError(failure);
this.sockJsSession.afterTransportClosed(new CloseStatus(1006, failure.getMessage()));
}
}
}
}
/*
* Copyright 2002-2014 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.web.socket.sockjs.client;
import org.springframework.core.task.SimpleAsyncTaskExecutor;
import org.springframework.core.task.TaskExecutor;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.HttpStatus;
import org.springframework.http.ResponseEntity;
import org.springframework.http.client.ClientHttpRequest;
import org.springframework.http.client.ClientHttpResponse;
import org.springframework.util.Assert;
import org.springframework.util.StreamUtils;
import org.springframework.util.concurrent.SettableListenableFuture;
import org.springframework.web.client.HttpServerErrorException;
import org.springframework.web.client.RequestCallback;
import org.springframework.web.client.ResponseExtractor;
import org.springframework.web.client.RestOperations;
import org.springframework.web.client.RestTemplate;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.sockjs.frame.SockJsFrame;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.net.URI;
/**
* An {@code XhrTransport} implementation that uses a
* {@link org.springframework.web.client.RestTemplate RestTemplate}.
*
* @author Rossen Stoyanchev
* @since 4.1
*/
public class RestTemplateXhrTransport extends AbstractXhrTransport implements XhrTransport {
private final RestOperations restTemplate;
private TaskExecutor taskExecutor = new SimpleAsyncTaskExecutor();
public RestTemplateXhrTransport() {
this(new RestTemplate());
}
public RestTemplateXhrTransport(RestOperations restTemplate) {
Assert.notNull(restTemplate, "'restTemplate' is required");
this.restTemplate = restTemplate;
}
/**
* Return the configured {@code RestTemplate}.
*/
public RestOperations getRestTemplate() {
return this.restTemplate;
}
/**
* Configure the {@code TaskExecutor} to use to execute XHR receive requests.
*
* <p>By default {@link org.springframework.core.task.SimpleAsyncTaskExecutor
* SimpleAsyncTaskExecutor} is configured which creates a new thread every
* time the transports connects.
*
* @param taskExecutor the task executor, cannot be {@code null}
*/
public void setTaskExecutor(TaskExecutor taskExecutor) {
Assert.notNull(this.taskExecutor);
this.taskExecutor = taskExecutor;
}
/**
* Return the configured {@code TaskExecutor}.
*/
public TaskExecutor getTaskExecutor() {
return this.taskExecutor;
}
@Override
public ResponseEntity<String> executeInfoRequestInternal(URI infoUrl) {
RequestCallback requestCallback = new XhrRequestCallback(getRequestHeaders());
return this.restTemplate.execute(infoUrl, HttpMethod.GET, requestCallback, textExtractor);
}
@Override
public ResponseEntity<String> executeSendRequestInternal(URI url, HttpHeaders headers, TextMessage message) {
RequestCallback requestCallback = new XhrRequestCallback(headers, message.getPayload());
return this.restTemplate.execute(url, HttpMethod.POST, requestCallback, textExtractor);
}
@Override
protected void connectInternal(final TransportRequest request, final WebSocketHandler handler,
final URI receiveUrl, final HttpHeaders handshakeHeaders, final XhrClientSockJsSession session,
final SettableListenableFuture<WebSocketSession> connectFuture) {
getTaskExecutor().execute(new Runnable() {
@Override
public void run() {
XhrRequestCallback requestCallback = new XhrRequestCallback(handshakeHeaders);
XhrRequestCallback requestCallbackAfterHandshake = new XhrRequestCallback(getRequestHeaders());
XhrReceiveExtractor responseExtractor = new XhrReceiveExtractor(session);
while (true) {
if (session.isDisconnected()) {
session.afterTransportClosed(null);
break;
}
try {
if (logger.isDebugEnabled()) {
logger.debug("Starting XHR receive request, url=" + receiveUrl);
}
getRestTemplate().execute(receiveUrl, HttpMethod.POST, requestCallback, responseExtractor);
requestCallback = requestCallbackAfterHandshake;
}
catch (Throwable ex) {
if (!connectFuture.isDone()) {
connectFuture.setException(ex);
}
else {
session.handleTransportError(ex);
session.afterTransportClosed(new CloseStatus(1006, ex.getMessage()));
}
break;
}
}
}
});
}
/**
* A RequestCallback to add the headers and (optionally) String content.
*/
private static class XhrRequestCallback implements RequestCallback {
private final HttpHeaders headers;
private final String body;
public XhrRequestCallback(HttpHeaders headers) {
this(headers, null);
}
public XhrRequestCallback(HttpHeaders headers, String body) {
this.headers = headers;
this.body = body;
}
@Override
public void doWithRequest(ClientHttpRequest request) throws IOException {
if (this.headers != null) {
request.getHeaders().putAll(this.headers);
}
if (this.body != null) {
StreamUtils.copy(this.body, SockJsFrame.CHARSET, request.getBody());
}
}
}
/**
* A simple ResponseExtractor that reads the body into a String.
*/
private final static ResponseExtractor<ResponseEntity<String>> textExtractor =
new ResponseExtractor<ResponseEntity<String>>() {
@Override
public ResponseEntity<String> extractData(ClientHttpResponse response) throws IOException {
if (response.getBody() == null) {
return new ResponseEntity<String>(response.getHeaders(), response.getStatusCode());
}
else {
String body = StreamUtils.copyToString(response.getBody(), SockJsFrame.CHARSET);
return new ResponseEntity<String>(body, response.getHeaders(), response.getStatusCode());
}
}
};
/**
* Splits the body of an HTTP response into SockJS frames and delegates those
* to an {@link XhrClientSockJsSession}.
*/
private class XhrReceiveExtractor implements ResponseExtractor<Object> {
private final XhrClientSockJsSession sockJsSession;
public XhrReceiveExtractor(XhrClientSockJsSession sockJsSession) {
this.sockJsSession = sockJsSession;
}
@Override
public Object extractData(ClientHttpResponse response) throws IOException {
if (!HttpStatus.OK.equals(response.getStatusCode())) {
throw new HttpServerErrorException(response.getStatusCode());
}
if (logger.isDebugEnabled()) {
logger.debug("XHR receive headers: " + response.getHeaders());
}
InputStream is = response.getBody();
ByteArrayOutputStream os = new ByteArrayOutputStream();
while (true) {
if (this.sockJsSession.isDisconnected()) {
if (logger.isDebugEnabled()) {
logger.debug("SockJS sockJsSession closed. Closing ClientHttpResponse.");
}
response.close();
break;
}
int b = is.read();
if (b == -1) {
if (os.size() > 0) {
handleFrame(os);
}
if (logger.isDebugEnabled()) {
logger.debug("XHR receive completed");
}
break;
}
if (b == '\n') {
handleFrame(os);
}
else {
os.write(b);
}
}
return null;
}
private void handleFrame(ByteArrayOutputStream os) {
byte[] bytes = os.toByteArray();
os.reset();
String content = new String(bytes, SockJsFrame.CHARSET);
if (logger.isTraceEnabled()) {
logger.trace("XHR receive content: " + content);
}
if (!PRELUDE.equals(content)) {
this.sockJsSession.handleFrame(new String(bytes, SockJsFrame.CHARSET));
}
}
}
}
/*
* Copyright 2002-2013 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.web.socket.sockjs.client;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.http.HttpHeaders;
import org.springframework.scheduling.TaskScheduler;
import org.springframework.util.Assert;
import org.springframework.util.ClassUtils;
import org.springframework.util.concurrent.ListenableFuture;
import org.springframework.util.concurrent.SettableListenableFuture;
import org.springframework.web.socket.WebSocketExtension;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.client.AbstractWebSocketClient;
import org.springframework.web.socket.sockjs.frame.Jackson2SockJsMessageCodec;
import org.springframework.web.socket.sockjs.frame.SockJsMessageCodec;
import org.springframework.web.socket.sockjs.transport.TransportType;
import java.net.URI;
import java.security.Principal;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
/**
* A SockJS implementation of
* {@link org.springframework.web.socket.client.WebSocketClient WebSocketClient}
* with HTTP-based fallback alternative simulating a WebSocket interaction.
*
* @author Rossen Stoyanchev
* @since 4.1
*
* @see <a href="http://sockjs.org">http://sockjs.org</a>
* @see org.springframework.web.socket.sockjs.client.Transport
*/
public class SockJsClient extends AbstractWebSocketClient {
private static final boolean jackson2Present = ClassUtils.isPresent(
"com.fasterxml.jackson.databind.ObjectMapper", SockJsClient.class.getClassLoader());
private static final Log logger = LogFactory.getLog(SockJsClient.class);
private final List<Transport> transports;
private InfoReceiver infoReceiver;
private SockJsMessageCodec messageCodec;
private TaskScheduler taskScheduler;
private final Map<URI, ServerInfo> infoCache = new ConcurrentHashMap<URI, ServerInfo>();
/**
* Create a {@code SockJsClient} with the given transports.
* @param transports the transports to use
*/
public SockJsClient(List<Transport> transports) {
Assert.notEmpty(transports, "No transports provided");
this.transports = new ArrayList<Transport>(transports);
this.infoReceiver = initInfoReceiver(transports);
if (jackson2Present) {
this.messageCodec = new Jackson2SockJsMessageCodec();
}
}
private static InfoReceiver initInfoReceiver(List<Transport> transports) {
for (Transport transport : transports) {
if (transport instanceof InfoReceiver) {
return ((InfoReceiver) transport);
}
}
return new RestTemplateXhrTransport();
}
/**
* Configure the {@code InfoReceiver} to use to perform the SockJS "Info"
* request before the SockJS session starts.
*
* <p>By default this is initialized either by looking through the configured
* transports to find the first {@code XhrTransport} or by creating an
* instance of {@code RestTemplateXhrTransport}.
*
* @param infoReceiver the transport to use for the SockJS "Info" request
*/
public void setInfoReceiver(InfoReceiver infoReceiver) {
this.infoReceiver = infoReceiver;
}
public InfoReceiver getInfoReceiver() {
return this.infoReceiver;
}
/**
* Set the SockJsMessageCodec to use.
*
* <p>By default {@link org.springframework.web.socket.sockjs.frame.Jackson2SockJsMessageCodec
* Jackson2SockJsMessageCodec} is used if Jackson is on the classpath.
*
* @param messageCodec the message messageCodec to use
*/
public void setMessageCodec(SockJsMessageCodec messageCodec) {
Assert.notNull(messageCodec, "'messageCodec' is required");
this.messageCodec = messageCodec;
}
public SockJsMessageCodec getMessageCodec() {
return this.messageCodec;
}
/**
* Configure a {@code TaskScheduler} for scheduling a connect timeout task
* where the timeout value is calculated based on the duration of the initial
* SockJS info request. Having a connect timeout task is optional but can
* improve the speed with which the client falls back to alternative
* transport options.
*
* <p>By default no task scheduler is configured in which case it may take
* longer before a fallback transport can be used.
*
* @param taskScheduler the scheduler to use
*/
public void setTaskScheduler(TaskScheduler taskScheduler) {
this.taskScheduler = taskScheduler;
}
public void clearServerInfoCache() {
this.infoCache.clear();
}
@Override
protected void assertUri(URI uri) {
Assert.notNull(uri, "uri must not be null");
String scheme = uri.getScheme();
Assert.isTrue(scheme != null && ("ws".equals(scheme) || "wss".equals(scheme)
|| "http".equals(scheme) || "https".equals(scheme)), "Invalid scheme: " + scheme);
}
@Override
protected ListenableFuture<WebSocketSession> doHandshakeInternal(WebSocketHandler handler,
HttpHeaders handshakeHeaders, URI url, List<String> protocols,
List<WebSocketExtension> extensions, Map<String, Object> attributes) {
SettableListenableFuture<WebSocketSession> connectFuture = new SettableListenableFuture<WebSocketSession>();
try {
SockJsUrlInfo sockJsUrlInfo = new SockJsUrlInfo(url);
ServerInfo serverInfo = getServerInfo(sockJsUrlInfo);
createFallbackChain(sockJsUrlInfo, handshakeHeaders, serverInfo).connect(handler, connectFuture);
}
catch (Throwable exception) {
if (logger.isErrorEnabled()) {
logger.error("Initial SockJS \"Info\" request to server failed, url=" + url, exception);
}
connectFuture.setException(exception);
}
return connectFuture;
}
private ServerInfo getServerInfo(SockJsUrlInfo sockJsUrlInfo) {
URI infoUrl = sockJsUrlInfo.getInfoUrl();
ServerInfo info = this.infoCache.get(infoUrl);
if (info == null) {
long start = System.currentTimeMillis();
String response = this.infoReceiver.executeInfoRequest(infoUrl);
long infoRequestTime = System.currentTimeMillis() - start;
info = new ServerInfo(response, infoRequestTime);
this.infoCache.put(infoUrl, info);
}
return info;
}
private DefaultTransportRequest createFallbackChain(SockJsUrlInfo urlInfo, HttpHeaders headers, ServerInfo serverInfo) {
List<DefaultTransportRequest> requests = new ArrayList<DefaultTransportRequest>(this.transports.size());
for (Transport transport : this.transports) {
if (transport instanceof XhrTransport) {
XhrTransport xhrTransport = (XhrTransport) transport;
if (!xhrTransport.isXhrStreamingDisabled()) {
addRequest(requests, urlInfo, headers, serverInfo, transport, TransportType.XHR_STREAMING);
}
addRequest(requests, urlInfo, headers, serverInfo, transport, TransportType.XHR);
}
else if (serverInfo.isWebSocketEnabled()) {
addRequest(requests, urlInfo, headers, serverInfo, transport, TransportType.WEBSOCKET);
}
}
Assert.notEmpty(requests,
"0 transports for request to " + urlInfo + " . Configured transports: " +
this.transports + ". SockJS server webSocketEnabled=" + serverInfo.isWebSocketEnabled());
for (int i = 0; i < requests.size() - 1; i++) {
requests.get(i).setFallbackRequest(requests.get(i + 1));
}
return requests.get(0);
}
private void addRequest(List<DefaultTransportRequest> requests, SockJsUrlInfo info, HttpHeaders headers,
ServerInfo serverInfo, Transport transport, TransportType type) {
DefaultTransportRequest request = new DefaultTransportRequest(info, headers, transport, type, getMessageCodec());
request.setUser(getUser());
if (this.taskScheduler != null) {
request.setTimeoutValue(serverInfo.getRetransmissionTimeout());
request.setTimeoutScheduler(this.taskScheduler);
}
requests.add(request);
}
/**
* Return the user to associate with the SockJS session and make available via
* {@link org.springframework.web.socket.WebSocketSession#getPrincipal()
* WebSocketSession#getPrincipal()}.
* <p>By default this method returns {@code null}.
* @return the user to associate with the session, possibly {@code null}
*/
protected Principal getUser() {
return null;
}
private static class ServerInfo {
private final boolean webSocketEnabled;
private final long responseTime;
private ServerInfo(String response, long responseTime) {
this.responseTime = responseTime;
this.webSocketEnabled = !response.matches(".*[\"']websocket[\"']\\s*:\\s*false.*");
}
public boolean isWebSocketEnabled() {
return this.webSocketEnabled;
}
public long getRetransmissionTimeout() {
return (this.responseTime > 100 ? 4 * this.responseTime : this.responseTime + 300);
}
}
}
\ No newline at end of file
/*
* Copyright 2002-2014 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.web.socket.sockjs.client;
import org.springframework.util.AlternativeJdkIdGenerator;
import org.springframework.util.IdGenerator;
import org.springframework.web.socket.sockjs.transport.TransportType;
import org.springframework.web.util.UriComponentsBuilder;
import java.net.URI;
import java.util.UUID;
/**
* Given the base URL to a SockJS server endpoint, also provides methods to
* generate and obtain session and a server id used for construct a transport URL.
*
* @author Rossen Stoyanchev
* @since 4.1
*/
public class SockJsUrlInfo {
private static final IdGenerator idGenerator = new AlternativeJdkIdGenerator();
private final URI sockJsUrl;
private String serverId;
private String sessionId;
private UUID uuid;
public SockJsUrlInfo(URI sockJsUrl) {
this.sockJsUrl = sockJsUrl;
}
public URI getSockJsUrl() {
return this.sockJsUrl;
}
public String getServerId() {
if (this.serverId == null) {
this.serverId = String.valueOf(Math.abs(getUuid().getMostSignificantBits()) % 1000);
}
return this.serverId;
}
public String getSessionId() {
if (this.sessionId == null) {
this.sessionId = getUuid().toString().replace("-","");
}
return this.sessionId;
}
protected UUID getUuid() {
if (this.uuid == null) {
this.uuid = idGenerator.generateId();
}
return this.uuid;
}
public URI getInfoUrl() {
return UriComponentsBuilder.fromUri(this.sockJsUrl)
.scheme(getScheme(TransportType.XHR))
.pathSegment("info")
.build(true).toUri();
}
public URI getTransportUrl(TransportType transportType) {
return UriComponentsBuilder.fromUri(this.sockJsUrl)
.scheme(getScheme(transportType))
.pathSegment(getServerId())
.pathSegment(getSessionId())
.pathSegment(transportType.toString())
.build(true).toUri();
}
private String getScheme(TransportType transportType) {
String scheme = this.sockJsUrl.getScheme();
if (TransportType.WEBSOCKET.equals(transportType)) {
if (!scheme.startsWith("ws")) {
scheme = ("https".equals(scheme) ? "wss" : "ws");
}
}
else {
if (!scheme.startsWith("http")) {
scheme = ("wss".equals(scheme) ? "https" : "http");
}
}
return scheme;
}
@Override
public String toString() {
return "SockJsUrlInfo[url=" + this.sockJsUrl + "]";
}
}
/*
* Copyright 2002-2014 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.web.socket.sockjs.client;
import org.springframework.util.concurrent.ListenableFuture;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.WebSocketSession;
/**
* A client-side implementation for a SockJS transport.
*
* @author Rossen Stoyanchev
* @since 4.1
*/
public interface Transport {
/**
* Connect the transport.
*
* @param request the transport request.
* @param webSocketHandler the application handler to delegate lifecycle events to.
* @return a future to indicate success or failure to connect.
*/
ListenableFuture<WebSocketSession> connect(TransportRequest request, WebSocketHandler webSocketHandler);
}
package org.springframework.web.socket.sockjs.client;
import org.springframework.http.HttpHeaders;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.sockjs.frame.SockJsMessageCodec;
import org.springframework.web.socket.sockjs.transport.TransportType;
import java.net.URI;
import java.security.Principal;
/**
* Represents a request to connect to a SockJS service using a specific
* Transport. A single SockJS request however may require falling back
* and therefore multiple TransportRequest instances.
*
* @author Rossen Stoyanchev
* @since 4.1
*/
public interface TransportRequest {
/**
* Return information about the SockJS URL including server and session id..
*/
SockJsUrlInfo getSockJsUrlInfo();
/**
* Return the headers to send with the connect request.
*/
HttpHeaders getHandshakeHeaders();
/**
* Return the transport URL for the given transport.
* For an {@link XhrTransport} this is the URL for receiving messages.
*/
URI getTransportUrl();
/**
* Return the user associated with the request, if any.
*/
Principal getUser();
/**
* Return the message codec to use for encoding SockJS messages.
*/
SockJsMessageCodec getMessageCodec();
/**
* Register a timeout cleanup task to invoke if the SockJS session is not
* fully established within the calculated retransmission timeout period.
*/
void addTimeoutTask(Runnable runnable);
}
/*
* Copyright 2002-2014 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.web.socket.sockjs.client;
import org.springframework.util.Assert;
import org.springframework.util.concurrent.SettableListenableFuture;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketExtension;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.adapter.NativeWebSocketSession;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.util.List;
/**
* An extension of {@link AbstractClientSockJsSession} wrapping and delegating
* to an actual WebSocket session.
*
* @author Rossen Stoyanchev
* @since 4.1
*/
public class WebSocketClientSockJsSession extends AbstractClientSockJsSession implements NativeWebSocketSession {
private WebSocketSession webSocketSession;
public WebSocketClientSockJsSession(TransportRequest request, WebSocketHandler handler,
SettableListenableFuture<WebSocketSession> connectFuture) {
super(request, handler, connectFuture);
}
@Override
public Object getNativeSession() {
return this.webSocketSession;
}
@SuppressWarnings("unchecked")
@Override
public <T> T getNativeSession(Class<T> requiredType) {
if (requiredType != null) {
if (requiredType.isInstance(this.webSocketSession)) {
return (T) this.webSocketSession;
}
}
return null;
}
@Override
public InetSocketAddress getLocalAddress() {
checkDelegateSessionInitialized();
return this.webSocketSession.getLocalAddress();
}
@Override
public InetSocketAddress getRemoteAddress() {
checkDelegateSessionInitialized();
return this.webSocketSession.getRemoteAddress();
}
@Override
public String getAcceptedProtocol() {
checkDelegateSessionInitialized();
return this.webSocketSession.getAcceptedProtocol();
}
@Override
public void setTextMessageSizeLimit(int messageSizeLimit) {
checkDelegateSessionInitialized();
this.webSocketSession.setTextMessageSizeLimit(messageSizeLimit);
}
@Override
public int getTextMessageSizeLimit() {
checkDelegateSessionInitialized();
return this.webSocketSession.getTextMessageSizeLimit();
}
@Override
public void setBinaryMessageSizeLimit(int messageSizeLimit) {
checkDelegateSessionInitialized();
this.webSocketSession.setBinaryMessageSizeLimit(messageSizeLimit);
}
@Override
public int getBinaryMessageSizeLimit() {
checkDelegateSessionInitialized();
return this.webSocketSession.getBinaryMessageSizeLimit();
}
@Override
public List<WebSocketExtension> getExtensions() {
checkDelegateSessionInitialized();
return this.webSocketSession.getExtensions();
}
private void checkDelegateSessionInitialized() {
Assert.state(this.webSocketSession != null, "WebSocketSession not yet initialized");
}
public void initializeDelegateSession(WebSocketSession session) {
this.webSocketSession = session;
}
@Override
protected void sendInternal(TextMessage textMessage) throws IOException {
this.webSocketSession.sendMessage(textMessage);
}
@Override
protected void disconnect(CloseStatus status) throws IOException {
if (this.webSocketSession != null) {
this.webSocketSession.close(status);
}
}
}
/*
* Copyright 2002-2014 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.web.socket.sockjs.client;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.util.Assert;
import org.springframework.util.concurrent.ListenableFuture;
import org.springframework.util.concurrent.ListenableFutureCallback;
import org.springframework.util.concurrent.SettableListenableFuture;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.WebSocketHttpHeaders;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.client.WebSocketClient;
import org.springframework.web.socket.handler.TextWebSocketHandler;
import java.net.URI;
import java.util.concurrent.atomic.AtomicInteger;
/**
* A SockJS {@link Transport} that uses a
* {@link org.springframework.web.socket.client.WebSocketClient WebSocketClient}.
*
* @author Rossen Stoyanchev
* @since 4.1
*/
public class WebSocketTransport implements Transport {
private static Log logger = LogFactory.getLog(WebSocketTransport.class);
private final WebSocketClient webSocketClient;
public WebSocketTransport(WebSocketClient webSocketClient) {
Assert.notNull(webSocketClient, "'webSocketClient' is required");
this.webSocketClient = webSocketClient;
}
/**
* Return the configured {@code WebSocketClient}.
*/
public WebSocketClient getWebSocketClient() {
return this.webSocketClient;
}
@Override
public ListenableFuture<WebSocketSession> connect(TransportRequest request, WebSocketHandler handler) {
final SettableListenableFuture<WebSocketSession> future = new SettableListenableFuture<WebSocketSession>();
WebSocketClientSockJsSession session = new WebSocketClientSockJsSession(request, handler, future);
handler = new ClientSockJsWebSocketHandler(session);
request.addTimeoutTask(session.getTimeoutTask());
URI url = request.getTransportUrl();
WebSocketHttpHeaders headers = new WebSocketHttpHeaders(request.getHandshakeHeaders());
if (logger.isDebugEnabled()) {
logger.debug("Opening WebSocket connection, url=" + url);
}
this.webSocketClient.doHandshake(handler, headers, url).addCallback(
new ListenableFutureCallback<WebSocketSession>() {
@Override
public void onSuccess(WebSocketSession webSocketSession) {
// WebSocket session ready, SockJS Session not yet
}
@Override
public void onFailure(Throwable t) {
future.setException(t);
}
});
return future;
}
@Override
public String toString() {
return "WebSocketTransport[client=" + this.webSocketClient + "]";
}
private static class ClientSockJsWebSocketHandler extends TextWebSocketHandler {
private final WebSocketClientSockJsSession sockJsSession;
private final AtomicInteger connectCount = new AtomicInteger(0);
private ClientSockJsWebSocketHandler(WebSocketClientSockJsSession session) {
Assert.notNull(session);
this.sockJsSession = session;
}
@Override
public void afterConnectionEstablished(WebSocketSession webSocketSession) throws Exception {
Assert.isTrue(this.connectCount.compareAndSet(0, 1));
this.sockJsSession.initializeDelegateSession(webSocketSession);
}
@Override
public void handleTextMessage(WebSocketSession webSocketSession, TextMessage message) throws Exception {
this.sockJsSession.handleFrame(message.getPayload());
}
@Override
public void handleTransportError(WebSocketSession webSocketSession, Throwable ex) throws Exception {
this.sockJsSession.handleTransportError(ex);
}
@Override
public void afterConnectionClosed(WebSocketSession webSocketSession, CloseStatus status) throws Exception {
this.sockJsSession.afterTransportClosed(status);
}
}
}
\ No newline at end of file
/*
* Copyright 2002-2014 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.web.socket.sockjs.client;
import org.springframework.util.Assert;
import org.springframework.util.concurrent.SettableListenableFuture;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketExtension;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.sockjs.transport.TransportType;
import java.net.InetSocketAddress;
import java.net.URI;
import java.util.List;
/**
* An extension of {@link AbstractClientSockJsSession} for use with HTTP
* transports simulating a WebSocket session.
*
* @author Rossen Stoyanchev
* @since 4.1
*/
public class XhrClientSockJsSession extends AbstractClientSockJsSession {
private final URI sendUrl;
private final XhrTransport transport;
private int textMessageSizeLimit = -1;
private int binaryMessageSizeLimit = -1;
public XhrClientSockJsSession(TransportRequest request, WebSocketHandler handler,
XhrTransport transport, SettableListenableFuture<WebSocketSession> connectFuture) {
super(request, handler, connectFuture);
Assert.notNull(transport, "'restTemplate' is required");
this.sendUrl = request.getSockJsUrlInfo().getTransportUrl(TransportType.XHR_SEND);
this.transport = transport;
}
@Override
public InetSocketAddress getLocalAddress() {
return null;
}
@Override
public InetSocketAddress getRemoteAddress() {
return new InetSocketAddress(getUri().getHost(), getUri().getPort());
}
@Override
public String getAcceptedProtocol() {
return null;
}
@Override
public void setTextMessageSizeLimit(int messageSizeLimit) {
this.textMessageSizeLimit = messageSizeLimit;
}
@Override
public int getTextMessageSizeLimit() {
return this.textMessageSizeLimit;
}
@Override
public void setBinaryMessageSizeLimit(int messageSizeLimit) {
this.binaryMessageSizeLimit = -1;
}
@Override
public int getBinaryMessageSizeLimit() {
return this.binaryMessageSizeLimit;
}
@Override
public List<WebSocketExtension> getExtensions() {
return null;
}
@Override
protected void sendInternal(TextMessage message) {
this.transport.executeSendRequest(this.sendUrl, message);
}
@Override
protected void disconnect(CloseStatus status) {
// Nothing to do, XHR transports check if session is disconnected
}
}
\ No newline at end of file
package org.springframework.web.socket.sockjs.client;
import org.springframework.http.HttpHeaders;
import org.springframework.http.ResponseEntity;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.TextMessage;
import java.net.URI;
/**
* A SockJS {@link Transport} that uses HTTP requests to simulate a WebSocket
* interaction. The {@code connect} method of the base {@code Transport} interface
* is used to receive messages from the server while the
* {@link #executeSendRequest(java.net.URI, org.springframework.web.socket.TextMessage)
* executeSendRequest(URI, TextMessage)} method here is used to send messages.
*
* @author Rossen Stoyanchev
* @since 4.1
*/
public interface XhrTransport extends Transport, InfoReceiver {
/**
* An {@code XhrTransport} supports both the "xhr_streaming" and "xhr" SockJS
* server transports. From a client perspective there is no implementation
* difference.
*
* <p>By default an {@code XhrTransport} will be used with "xhr_streaming"
* first and then with "xhr", if the streaming fails to connect. In some
* cases it may be useful to suppress streaming so that only "xhr" is used.
*/
boolean isXhrStreamingDisabled();
/**
* Execute a request to send the message to the server.
* @param transportUrl the URL for sending messages.
* @param message the message to send
*/
void executeSendRequest(URI transportUrl, TextMessage message);
}
/*
* Copyright 2002-2014 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/**
* SockJS client implementation of
* {@link org.springframework.web.socket.client.WebSocketClient}.
*/
package org.springframework.web.socket.sockjs.client;
......@@ -145,7 +145,7 @@ public class SockJsFrame {
if (!(other instanceof SockJsFrame)) {
return false;
}
return this.content.equals(((SockJsFrame) other).content);
return (this.type.equals(((SockJsFrame) other).type) && this.content.equals(((SockJsFrame) other).content));
}
@Override
......
......@@ -238,7 +238,7 @@ public class TransportHandlingSockJsService extends AbstractSockJsService implem
}
else {
response.setStatusCode(HttpStatus.NOT_FOUND);
logger.warn("Session not found");
logger.warn("Session not found, sessionId=" + sessionId);
return;
}
}
......
......@@ -78,6 +78,7 @@ public class JettyWebSocketTestServer implements WebSocketTestServer {
@Override
public void stop() throws Exception {
if (this.jettyServer.isRunning()) {
this.jettyServer.setStopTimeout(0);
this.jettyServer.stop();
}
}
......
/*
* Copyright 2002-2014 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.web.socket.sockjs.client;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.hamcrest.Matchers;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.scheduling.concurrent.ThreadPoolTaskScheduler;
import org.springframework.util.concurrent.ListenableFutureCallback;
import org.springframework.web.context.support.AnnotationConfigWebApplicationContext;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.WebSocketTestServer;
import org.springframework.web.socket.config.annotation.EnableWebSocket;
import org.springframework.web.socket.config.annotation.WebSocketConfigurer;
import org.springframework.web.socket.config.annotation.WebSocketHandlerRegistry;
import org.springframework.web.socket.handler.TextWebSocketHandler;
import org.springframework.web.socket.server.HandshakeHandler;
import org.springframework.web.socket.server.RequestUpgradeStrategy;
import org.springframework.web.socket.server.support.DefaultHandshakeHandler;
import javax.servlet.Filter;
import javax.servlet.FilterChain;
import javax.servlet.FilterConfig;
import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertThat;
import static org.junit.Assert.assertTrue;
import static org.hamcrest.Matchers.*;
/**
* Integration tests using the
* {@link org.springframework.web.socket.sockjs.client.SockJsClient}.
* against actual SockJS server endpoints.
*
* @author Rossen Stoyanchev
*/
public abstract class AbstractSockJsIntegrationTests {
protected Log logger = LogFactory.getLog(getClass());
private WebSocketTestServer server;
private AnnotationConfigWebApplicationContext wac;
private ErrorFilter errorFilter;
private String baseUrl;
@Before
public void setup() throws Exception {
this.errorFilter = new ErrorFilter();
this.wac = new AnnotationConfigWebApplicationContext();
this.wac.register(TestConfig.class, upgradeStrategyConfigClass());
this.wac.refresh();
this.server = createWebSocketTestServer();
this.server.deployConfig(this.wac, this.errorFilter);
this.server.start();
this.baseUrl = "http://localhost:" + this.server.getPort();
}
@After
public void teardown() throws Exception {
try {
this.server.undeployConfig();
}
catch (Throwable t) {
logger.error("Failed to undeploy application config", t);
}
try {
this.server.stop();
}
catch (Throwable t) {
logger.error("Failed to stop server", t);
}
}
protected abstract WebSocketTestServer createWebSocketTestServer();
protected abstract Class<?> upgradeStrategyConfigClass();
protected abstract Transport getWebSocketTransport();
protected abstract AbstractXhrTransport getXhrTransport();
protected SockJsClient createSockJsClient(Transport... transports) {
return new SockJsClient(Arrays.<Transport>asList(transports));
}
@Test
public void echoWebSocket() throws Exception {
testEcho(100, getWebSocketTransport());
}
@Test
public void echoXhrStreaming() throws Exception {
testEcho(100, getXhrTransport());
}
@Test
public void echoXhr() throws Exception {
AbstractXhrTransport xhrTransport = getXhrTransport();
xhrTransport.setXhrStreamingDisabled(true);
testEcho(100, xhrTransport);
}
@Test
public void closeAfterOneMessageWebSocket() throws Exception {
testCloseAfterOneMessage(getWebSocketTransport());
}
@Test
public void closeAfterOneMessageXhrStreaming() throws Exception {
testCloseAfterOneMessage(getXhrTransport());
}
@Test
public void closeAfterOneMessageXhr() throws Exception {
AbstractXhrTransport xhrTransport = getXhrTransport();
xhrTransport.setXhrStreamingDisabled(true);
testCloseAfterOneMessage(xhrTransport);
}
@Test
public void infoRequestFailure() throws Exception {
TestClientHandler handler = new TestClientHandler();
this.errorFilter.responseStatusMap.put("/info", 500);
CountDownLatch latch = new CountDownLatch(1);
createSockJsClient(getWebSocketTransport()).doHandshake(handler, this.baseUrl + "/echo").addCallback(
new ListenableFutureCallback<WebSocketSession>() {
@Override
public void onSuccess(WebSocketSession result) {
}
@Override
public void onFailure(Throwable t) {
latch.countDown();
}
}
);
assertTrue(latch.await(5000, TimeUnit.MILLISECONDS));
}
@Test
public void fallbackAfterTransportFailure() throws Exception {
this.errorFilter.responseStatusMap.put("/websocket", 200);
this.errorFilter.responseStatusMap.put("/xhr_streaming", 500);
TestClientHandler handler = new TestClientHandler();
Transport[] transports = { getWebSocketTransport(), getXhrTransport() };
WebSocketSession session = createSockJsClient(transports).doHandshake(handler, this.baseUrl + "/echo").get();
assertEquals("Fallback didn't occur", XhrClientSockJsSession.class, session.getClass());
TextMessage message = new TextMessage("message1");
session.sendMessage(message);
handler.awaitMessage(message, 5000);
}
@Test(timeout = 5000)
public void fallbackAfterConnectTimeout() throws Exception {
TestClientHandler clientHandler = new TestClientHandler();
this.errorFilter.sleepDelayMap.put("/xhr_streaming", 10000L);
this.errorFilter.responseStatusMap.put("/xhr_streaming", 503);
SockJsClient sockJsClient = createSockJsClient(getXhrTransport());
sockJsClient.setTaskScheduler(this.wac.getBean(ThreadPoolTaskScheduler.class));
WebSocketSession clientSession = sockJsClient.doHandshake(clientHandler, this.baseUrl + "/echo").get();
assertEquals("Fallback didn't occur", XhrClientSockJsSession.class, clientSession.getClass());
TextMessage message = new TextMessage("message1");
clientSession.sendMessage(message);
clientHandler.awaitMessage(message, 5000);
clientSession.close();
}
private void testEcho(int messageCount, Transport transport) throws Exception {
List<TextMessage> messages = new ArrayList<>();
for (int i = 0; i < messageCount; i++) {
messages.add(new TextMessage("m" + i));
}
TestClientHandler handler = new TestClientHandler();
WebSocketSession session = createSockJsClient(transport).doHandshake(handler, this.baseUrl + "/echo").get();
for (TextMessage message : messages) {
session.sendMessage(message);
}
handler.awaitMessageCount(messageCount, 5000);
for (TextMessage message : messages) {
assertTrue("Message not received: " + message, handler.receivedMessages.remove(message));
}
assertEquals("Remaining messages: " + handler.receivedMessages, 0, handler.receivedMessages.size());
session.close();
}
private void testCloseAfterOneMessage(Transport transport) throws Exception {
TestClientHandler clientHandler = new TestClientHandler();
createSockJsClient(transport).doHandshake(clientHandler, this.baseUrl + "/test").get();
TestServerHandler serverHandler = this.wac.getBean(TestServerHandler.class);
assertNotNull("afterConnectionEstablished should have been called", clientHandler.session);
serverHandler.awaitSession(5000);
TextMessage message = new TextMessage("message1");
serverHandler.session.sendMessage(message);
clientHandler.awaitMessage(message, 5000);
CloseStatus expected = new CloseStatus(3500, "Oops");
serverHandler.session.close(expected);
CloseStatus actual = clientHandler.awaitCloseStatus(5000);
if (transport instanceof XhrTransport) {
assertThat(actual, Matchers.anyOf(equalTo(expected), equalTo(new CloseStatus(3000, "Go away!"))));
}
else {
assertEquals(expected, actual);
}
}
@Configuration
@EnableWebSocket
static class TestConfig implements WebSocketConfigurer {
@Autowired
private RequestUpgradeStrategy upgradeStrategy;
@Override
public void registerWebSocketHandlers(WebSocketHandlerRegistry registry) {
HandshakeHandler handshakeHandler = new DefaultHandshakeHandler(this.upgradeStrategy);
registry.addHandler(new EchoHandler(), "/echo").setHandshakeHandler(handshakeHandler).withSockJS();
registry.addHandler(testServerHandler(), "/test").setHandshakeHandler(handshakeHandler).withSockJS();
}
@Bean
public TestServerHandler testServerHandler() {
return new TestServerHandler();
}
}
private static interface Condition {
boolean match();
}
private static void awaitEvent(Condition condition, long timeToWait, String description) {
long timeToSleep = 200;
for (int i = 0 ; i < Math.floor(timeToWait / timeToSleep); i++) {
if (condition.match()) {
return;
}
try {
Thread.sleep(timeToSleep);
}
catch (InterruptedException e) {
throw new IllegalStateException("Interrupted while waiting for " + description, e);
}
}
throw new IllegalStateException("Timed out waiting for " + description);
}
private static class TestClientHandler extends TextWebSocketHandler {
private final BlockingQueue<TextMessage> receivedMessages = new LinkedBlockingQueue<>();
private volatile WebSocketSession session;
private volatile CloseStatus closeStatus;
@Override
public void afterConnectionEstablished(WebSocketSession session) throws Exception {
this.session = session;
}
@Override
protected void handleTextMessage(WebSocketSession session, TextMessage message) throws Exception {
this.receivedMessages.add(message);
}
@Override
public void afterConnectionClosed(WebSocketSession session, CloseStatus status) throws Exception {
this.closeStatus = status;
}
public void awaitMessageCount(final int count, long timeToWait) throws Exception {
awaitEvent(() -> receivedMessages.size() >= count, timeToWait,
count + " number of messages. Received so far: " + this.receivedMessages);
}
public void awaitMessage(TextMessage expected, long timeToWait) throws InterruptedException {
TextMessage actual = this.receivedMessages.poll(timeToWait, TimeUnit.MILLISECONDS);
assertNotNull("Timed out waiting for [" + expected + "]", actual);
assertEquals(expected, actual);
}
public CloseStatus awaitCloseStatus(long timeToWait) throws InterruptedException {
awaitEvent(() -> this.closeStatus != null, timeToWait, " CloseStatus");
return this.closeStatus;
}
}
private static class TestServerHandler extends TextWebSocketHandler {
private WebSocketSession session;
@Override
public void afterConnectionEstablished(WebSocketSession session) throws Exception {
this.session = session;
}
public WebSocketSession awaitSession(long timeToWait) throws InterruptedException {
awaitEvent(() -> this.session != null, timeToWait, " session");
return this.session;
}
}
private static class EchoHandler extends TextWebSocketHandler {
@Override
protected void handleTextMessage(WebSocketSession session, TextMessage message) throws Exception {
session.sendMessage(message);
}
}
private static class ErrorFilter implements Filter {
private final Map<String, Integer> responseStatusMap = new HashMap<>();
private final Map<String, Long> sleepDelayMap = new HashMap<>();
@Override
public void doFilter(ServletRequest req, ServletResponse resp, FilterChain chain) throws IOException, ServletException {
for (String suffix : this.sleepDelayMap.keySet()) {
if (((HttpServletRequest) req).getRequestURI().endsWith(suffix)) {
try {
Thread.sleep(this.sleepDelayMap.get(suffix));
break;
}
catch (InterruptedException e) {
e.printStackTrace();
}
}
}
for (String suffix : this.responseStatusMap.keySet()) {
if (((HttpServletRequest) req).getRequestURI().endsWith(suffix)) {
((HttpServletResponse) resp).sendError(this.responseStatusMap.get(suffix));
return;
}
}
chain.doFilter(req, resp);
}
@Override
public void init(FilterConfig filterConfig) throws ServletException {
}
@Override
public void destroy() {
}
}
}
/*
* Copyright 2002-2014 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.web.socket.sockjs.client;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.springframework.util.concurrent.SettableListenableFuture;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketExtension;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.sockjs.frame.Jackson2SockJsMessageCodec;
import org.springframework.web.socket.sockjs.frame.SockJsFrame;
import org.springframework.web.socket.sockjs.transport.TransportType;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.net.URI;
import java.util.List;
import static org.junit.Assert.assertThat;
import static org.hamcrest.CoreMatchers.*;
import static org.junit.Assert.assertTrue;
import static org.mockito.Mockito.*;
/**
* Unit tests for
* {@link org.springframework.web.socket.sockjs.client.AbstractClientSockJsSession}.
*
* @author Rossen Stoyanchev
*/
public class ClientSockJsSessionTests {
private static final Jackson2SockJsMessageCodec CODEC = new Jackson2SockJsMessageCodec();
private TestClientSockJsSession session;
private WebSocketHandler handler;
private SettableListenableFuture<WebSocketSession> connectFuture;
@Rule
public final ExpectedException thrown = ExpectedException.none();
@Before
public void setup() throws Exception {
SockJsUrlInfo urlInfo = new SockJsUrlInfo(new URI("http://example.com"));
Transport transport = mock(Transport.class);
TransportRequest request = new DefaultTransportRequest(urlInfo, null, transport, TransportType.XHR, CODEC);
this.handler = mock(WebSocketHandler.class);
this.connectFuture = new SettableListenableFuture<>();
this.session = new TestClientSockJsSession(request, this.handler, this.connectFuture);
}
@Test
public void handleFrameOpen() throws Exception {
assertThat(this.session.isOpen(), is(false));
this.session.handleFrame(SockJsFrame.openFrame().getContent());
assertThat(this.session.isOpen(), is(true));
assertTrue(this.connectFuture.isDone());
assertThat(this.connectFuture.get(), sameInstance(this.session));
verify(this.handler).afterConnectionEstablished(this.session);
verifyNoMoreInteractions(this.handler);
}
@Test
public void handleFrameOpenWhenStatusNotNew() throws Exception {
this.session.handleFrame(SockJsFrame.openFrame().getContent());
assertThat(this.session.isOpen(), is(true));
this.session.handleFrame(SockJsFrame.openFrame().getContent());
assertThat(this.session.disconnectStatus, equalTo(new CloseStatus(1006, "Server lost session")));
}
@Test
public void handleFrameOpenWithWebSocketHandlerException() throws Exception {
doThrow(new IllegalStateException("Fake error")).when(this.handler).afterConnectionEstablished(this.session);
this.session.handleFrame(SockJsFrame.openFrame().getContent());
assertThat(this.session.isOpen(), is(true));
}
@Test
public void handleFrameMessage() throws Exception {
this.session.handleFrame(SockJsFrame.openFrame().getContent());
this.session.handleFrame(SockJsFrame.messageFrame(CODEC, "foo", "bar").getContent());
verify(this.handler).afterConnectionEstablished(this.session);
verify(this.handler).handleMessage(this.session, new TextMessage("foo"));
verify(this.handler).handleMessage(this.session, new TextMessage("bar"));
verifyNoMoreInteractions(this.handler);
}
@Test
public void handleFrameMessageWhenNotOpen() throws Exception {
this.session.handleFrame(SockJsFrame.openFrame().getContent());
this.session.close();
reset(this.handler);
this.session.handleFrame(SockJsFrame.messageFrame(CODEC, "foo", "bar").getContent());
verifyNoMoreInteractions(this.handler);
}
@Test
public void handleFrameMessageWithBadData() throws Exception {
this.session.handleFrame(SockJsFrame.openFrame().getContent());
this.session.handleFrame("a['bad data");
assertThat(this.session.isOpen(), equalTo(false));
assertThat(this.session.disconnectStatus, equalTo(CloseStatus.BAD_DATA));
verify(this.handler).afterConnectionEstablished(this.session);
verifyNoMoreInteractions(this.handler);
}
@Test
public void handleFrameMessageWithWebSocketHandlerException() throws Exception {
this.session.handleFrame(SockJsFrame.openFrame().getContent());
doThrow(new IllegalStateException("Fake error")).when(this.handler).handleMessage(this.session, new TextMessage("foo"));
doThrow(new IllegalStateException("Fake error")).when(this.handler).handleMessage(this.session, new TextMessage("bar"));
this.session.handleFrame(SockJsFrame.messageFrame(CODEC, "foo", "bar").getContent());
assertThat(this.session.isOpen(), equalTo(true));
verify(this.handler).afterConnectionEstablished(this.session);
verify(this.handler).handleMessage(this.session, new TextMessage("foo"));
verify(this.handler).handleMessage(this.session, new TextMessage("bar"));
verifyNoMoreInteractions(this.handler);
}
@Test
public void handleFrameClose() throws Exception {
this.session.handleFrame(SockJsFrame.openFrame().getContent());
this.session.handleFrame(SockJsFrame.closeFrame(1007, "").getContent());
assertThat(this.session.isOpen(), equalTo(false));
assertThat(this.session.disconnectStatus, equalTo(new CloseStatus(1007, "")));
verify(this.handler).afterConnectionEstablished(this.session);
verifyNoMoreInteractions(this.handler);
}
@Test
public void handleTransportError() throws Exception {
final IllegalStateException ex = new IllegalStateException("Fake error");
this.session.handleTransportError(ex);
verify(this.handler).handleTransportError(this.session, ex);
verifyNoMoreInteractions(this.handler);
}
@Test
public void afterTransportClosed() throws Exception {
this.session.handleFrame(SockJsFrame.openFrame().getContent());
this.session.afterTransportClosed(CloseStatus.SERVER_ERROR);
assertThat(this.session.isOpen(), equalTo(false));
verify(this.handler).afterConnectionEstablished(this.session);
verify(this.handler).afterConnectionClosed(this.session, CloseStatus.SERVER_ERROR);
verifyNoMoreInteractions(this.handler);
}
@Test
public void close() throws Exception {
this.session.handleFrame(SockJsFrame.openFrame().getContent());
this.session.close();
assertThat(this.session.isOpen(), equalTo(false));
assertThat(this.session.disconnectStatus, equalTo(CloseStatus.NORMAL));
verify(this.handler).afterConnectionEstablished(this.session);
verifyNoMoreInteractions(this.handler);
}
@Test
public void closeWithStatus() throws Exception {
this.session.handleFrame(SockJsFrame.openFrame().getContent());
this.session.close(new CloseStatus(3000, "reason"));
assertThat(this.session.disconnectStatus, equalTo(new CloseStatus(3000, "reason")));
}
@Test
public void closeWithNullStatus() throws Exception {
this.session.handleFrame(SockJsFrame.openFrame().getContent());
this.thrown.expect(IllegalArgumentException.class);
this.thrown.expectMessage("Invalid close status");
this.session.close(null);
}
@Test
public void closeWithStatusOutOfRange() throws Exception {
this.session.handleFrame(SockJsFrame.openFrame().getContent());
this.thrown.expect(IllegalArgumentException.class);
this.thrown.expectMessage("Invalid close status");
this.session.close(new CloseStatus(2999, "reason"));
}
@Test
public void timeoutTask() {
this.session.getTimeoutTask().run();
assertThat(this.session.disconnectStatus, equalTo(new CloseStatus(2007, "Transport timed out")));
}
@Test
public void send() throws Exception {
this.session.handleFrame(SockJsFrame.openFrame().getContent());
this.session.sendMessage(new TextMessage("foo"));
assertThat(this.session.sentMessage, equalTo(new TextMessage("[\"foo\"]")));
}
private static class TestClientSockJsSession extends AbstractClientSockJsSession {
private TextMessage sentMessage;
private CloseStatus disconnectStatus;
protected TestClientSockJsSession(TransportRequest request, WebSocketHandler handler,
SettableListenableFuture<WebSocketSession> connectFuture) {
super(request, handler, connectFuture);
}
@Override
protected void sendInternal(TextMessage textMessage) throws IOException {
this.sentMessage = textMessage;
}
@Override
protected void disconnect(CloseStatus status) throws IOException {
this.disconnectStatus = status;
}
@Override
public InetSocketAddress getLocalAddress() {
return null;
}
@Override
public InetSocketAddress getRemoteAddress() {
return null;
}
@Override
public String getAcceptedProtocol() {
return null;
}
@Override
public void setTextMessageSizeLimit(int messageSizeLimit) {
}
@Override
public int getTextMessageSizeLimit() {
return 0;
}
@Override
public void setBinaryMessageSizeLimit(int messageSizeLimit) {
}
@Override
public int getBinaryMessageSizeLimit() {
return 0;
}
@Override
public List<WebSocketExtension> getExtensions() {
return null;
}
}
}
/*
* Copyright 2002-2014 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.web.socket.sockjs.client;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.mockito.ArgumentCaptor;
import org.springframework.http.HttpHeaders;
import org.springframework.scheduling.TaskScheduler;
import org.springframework.util.concurrent.ListenableFutureCallback;
import org.springframework.util.concurrent.SettableListenableFuture;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.sockjs.frame.Jackson2SockJsMessageCodec;
import org.springframework.web.socket.sockjs.transport.TransportType;
import java.io.IOException;
import java.net.URI;
import java.util.Date;
import java.util.concurrent.ExecutionException;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertSame;
import static org.junit.Assert.assertTrue;
import static org.mockito.Matchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
/**
* Unit tests for {@link DefaultTransportRequest}.
*
* @author Rossen Stoyanchev
*/
public class DefaultTransportRequestTests {
private static final Jackson2SockJsMessageCodec CODEC = new Jackson2SockJsMessageCodec();
private SettableListenableFuture<WebSocketSession> connectFuture;
private ListenableFutureCallback<WebSocketSession> connectCallback;
private TestTransport webSocketTransport;
private TestTransport xhrTransport;
@Rule
public final ExpectedException thrown = ExpectedException.none();
@SuppressWarnings("unchecked")
@Before
public void setup() throws Exception {
this.connectCallback = mock(ListenableFutureCallback.class);
this.connectFuture = new SettableListenableFuture<>();
this.connectFuture.addCallback(this.connectCallback);
this.webSocketTransport = new TestTransport("WebSocketTestTransport");
this.xhrTransport = new TestTransport("XhrTestTransport");
}
@Test
@SuppressWarnings("unchecked")
public void connect() throws Exception {
DefaultTransportRequest request = createTransportRequest(this.webSocketTransport, TransportType.WEBSOCKET);
request.connect(null, this.connectFuture);
WebSocketSession session = mock(WebSocketSession.class);
this.webSocketTransport.getConnectCallback().onSuccess(session);
assertSame(session, this.connectFuture.get());
}
@Test
public void fallbackAfterTransportError() throws Exception {
DefaultTransportRequest request1 = createTransportRequest(this.webSocketTransport, TransportType.WEBSOCKET);
DefaultTransportRequest request2 = createTransportRequest(this.xhrTransport, TransportType.XHR_STREAMING);
request1.setFallbackRequest(request2);
request1.connect(null, this.connectFuture);
// Transport error => fallback
this.webSocketTransport.getConnectCallback().onFailure(new IOException("Fake exception 1"));
assertFalse(this.connectFuture.isDone());
assertTrue(this.xhrTransport.invoked());
// Transport error => no more fallback
this.xhrTransport.getConnectCallback().onFailure(new IOException("Fake exception 2"));
assertTrue(this.connectFuture.isDone());
this.thrown.expect(ExecutionException.class);
this.thrown.expectMessage("Fake exception 2");
this.connectFuture.get();
}
@Test
public void fallbackAfterTimeout() throws Exception {
TaskScheduler scheduler = mock(TaskScheduler.class);
Runnable sessionCleanupTask = mock(Runnable.class);
DefaultTransportRequest request1 = createTransportRequest(this.webSocketTransport, TransportType.WEBSOCKET);
DefaultTransportRequest request2 = createTransportRequest(this.xhrTransport, TransportType.XHR_STREAMING);
request1.setFallbackRequest(request2);
request1.setTimeoutScheduler(scheduler);
request1.addTimeoutTask(sessionCleanupTask);
request1.connect(null, this.connectFuture);
assertTrue(this.webSocketTransport.invoked());
assertFalse(this.xhrTransport.invoked());
// Get and invoke the scheduled timeout task
ArgumentCaptor<Runnable> taskCaptor = ArgumentCaptor.forClass(Runnable.class);
verify(scheduler).schedule(taskCaptor.capture(), any(Date.class));
verifyNoMoreInteractions(scheduler);
taskCaptor.getValue().run();
assertTrue(this.xhrTransport.invoked());
verify(sessionCleanupTask).run();
}
protected DefaultTransportRequest createTransportRequest(Transport transport, TransportType type) throws Exception {
SockJsUrlInfo urlInfo = new SockJsUrlInfo(new URI("http://example.com"));
return new DefaultTransportRequest(urlInfo, new HttpHeaders(), transport, type, CODEC);
}
}
/*
* Copyright 2002-2014 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.web.socket.sockjs.client;
import org.eclipse.jetty.client.HttpClient;
import org.eclipse.jetty.websocket.client.WebSocketClient;
import org.junit.After;
import org.junit.Before;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.web.socket.JettyWebSocketTestServer;
import org.springframework.web.socket.client.jetty.JettyWebSocketClient;
import org.springframework.web.socket.server.RequestUpgradeStrategy;
import org.springframework.web.socket.server.jetty.JettyRequestUpgradeStrategy;
import java.util.ArrayList;
import java.util.List;
/**
* SockJS integration tests using Jetty for client and server.
*
* @author Rossen Stoyanchev
*/
public class JettySockJsIntegrationTests extends AbstractSockJsIntegrationTests {
private WebSocketClient webSocketClient;
private HttpClient httpClient;
@Before
public void setup() throws Exception {
super.setup();
this.webSocketClient = new WebSocketClient();
this.webSocketClient.start();
this.httpClient = new HttpClient();
this.httpClient.start();
}
@After
public void teardown() throws Exception {
super.teardown();
try {
this.webSocketClient.stop();
}
catch (Throwable ex) {
logger.error("Failed to stop Jetty WebSocketClient", ex);
}
try {
this.httpClient.stop();
}
catch (Throwable ex) {
logger.error("Failed to stop Jetty HttpClient", ex);
}
}
@Override
protected JettyWebSocketTestServer createWebSocketTestServer() {
return new JettyWebSocketTestServer();
}
@Override
protected Class<?> upgradeStrategyConfigClass() {
return JettyTestConfig.class;
}
@Override
protected Transport getWebSocketTransport() {
return new WebSocketTransport(new JettyWebSocketClient(this.webSocketClient));
}
@Override
protected AbstractXhrTransport getXhrTransport() {
return new JettyXhrTransport(this.httpClient);
}
@Configuration
static class JettyTestConfig {
@Bean
public RequestUpgradeStrategy upgradeStrategy() {
return new JettyRequestUpgradeStrategy();
}
}
}
/*
* Copyright 2002-2014 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.web.socket.sockjs.client;
import org.junit.Before;
import org.junit.Test;
import org.springframework.core.task.SyncTaskExecutor;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.HttpStatus;
import org.springframework.http.client.ClientHttpResponse;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageHeaders;
import org.springframework.messaging.simp.stomp.StompCommand;
import org.springframework.messaging.simp.stomp.StompEncoder;
import org.springframework.messaging.simp.stomp.StompHeaderAccessor;
import org.springframework.messaging.support.MessageBuilder;
import org.springframework.util.concurrent.ListenableFuture;
import org.springframework.util.concurrent.ListenableFutureCallback;
import org.springframework.web.client.HttpServerErrorException;
import org.springframework.web.client.RequestCallback;
import org.springframework.web.client.ResponseExtractor;
import org.springframework.web.client.RestClientException;
import org.springframework.web.client.RestOperations;
import org.springframework.web.client.RestTemplate;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.sockjs.frame.Jackson2SockJsMessageCodec;
import org.springframework.web.socket.sockjs.frame.SockJsFrame;
import org.springframework.web.socket.sockjs.transport.TransportType;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.net.URI;
import java.nio.charset.Charset;
import java.util.Arrays;
import java.util.Queue;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.LinkedBlockingDeque;
import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.*;
import static org.mockito.Mockito.verifyNoMoreInteractions;
/**
* Unit tests for {@link RestTemplateXhrTransport}.
*
* @author Rossen Stoyanchev
*/
public class RestTemplateXhrTransportTests {
private static final Jackson2SockJsMessageCodec CODEC = new Jackson2SockJsMessageCodec();
private WebSocketHandler webSocketHandler;
@Before
public void setup() throws Exception {
this.webSocketHandler = mock(WebSocketHandler.class);
}
@Test
public void connectReceiveAndClose() throws Exception {
String body = "o\n" + "a[\"foo\"]\n" + "c[3000,\"Go away!\"]";
ClientHttpResponse response = response(HttpStatus.OK, body);
connect(response);
verify(this.webSocketHandler).afterConnectionEstablished(any());
verify(this.webSocketHandler).handleMessage(any(), eq(new TextMessage("foo")));
verify(this.webSocketHandler).afterConnectionClosed(any(), eq(new CloseStatus(3000, "Go away!")));
verifyNoMoreInteractions(this.webSocketHandler);
}
@Test
public void connectReceiveAndCloseWithPrelude() throws Exception {
StringBuilder sb = new StringBuilder(2048);
for (int i=0; i < 2048; i++) {
sb.append('h');
}
String body = sb.toString() + "\n" + "o\n" + "a[\"foo\"]\n" + "c[3000,\"Go away!\"]";
ClientHttpResponse response = response(HttpStatus.OK, body);
connect(response);
verify(this.webSocketHandler).afterConnectionEstablished(any());
verify(this.webSocketHandler).handleMessage(any(), eq(new TextMessage("foo")));
verify(this.webSocketHandler).afterConnectionClosed(any(), eq(new CloseStatus(3000, "Go away!")));
verifyNoMoreInteractions(this.webSocketHandler);
}
@Test
public void connectReceiveAndCloseWithStompFrame() throws Exception {
StompHeaderAccessor accessor = StompHeaderAccessor.create(StompCommand.SEND);
accessor.setDestination("/destination");
MessageHeaders headers = accessor.getMessageHeaders();
Message<byte[]> message = MessageBuilder.createMessage("body".getBytes(Charset.forName("UTF-8")), headers);
byte[] bytes = new StompEncoder().encode(message);
TextMessage textMessage = new TextMessage(bytes);
SockJsFrame frame = SockJsFrame.messageFrame(new Jackson2SockJsMessageCodec(), textMessage.getPayload());
String body = "o\n" + frame.getContent() + "\n" + "c[3000,\"Go away!\"]";
ClientHttpResponse response = response(HttpStatus.OK, body);
connect(response);
verify(this.webSocketHandler).afterConnectionEstablished(any());
verify(this.webSocketHandler).handleMessage(any(), eq(textMessage));
verify(this.webSocketHandler).afterConnectionClosed(any(), eq(new CloseStatus(3000, "Go away!")));
verifyNoMoreInteractions(this.webSocketHandler);
}
@Test
public void connectFailure() throws Exception {
final HttpServerErrorException expected = new HttpServerErrorException(HttpStatus.INTERNAL_SERVER_ERROR);
RestOperations restTemplate = mock(RestOperations.class);
when(restTemplate.execute(any(), eq(HttpMethod.POST), any(), any())).thenThrow(expected);
final CountDownLatch latch = new CountDownLatch(1);
connect(restTemplate).addCallback(
new ListenableFutureCallback<WebSocketSession>() {
@Override
public void onSuccess(WebSocketSession result) {
}
@Override
public void onFailure(Throwable actual) {
if (actual == expected) {
latch.countDown();
}
}
}
);
verifyNoMoreInteractions(this.webSocketHandler);
}
@Test
public void errorResponseStatus() throws Exception {
connect(response(HttpStatus.OK, "o\n"), response(HttpStatus.INTERNAL_SERVER_ERROR, "Oops"));
verify(this.webSocketHandler).afterConnectionEstablished(any());
verify(this.webSocketHandler).handleTransportError(any(), any());
verify(this.webSocketHandler).afterConnectionClosed(any(), any());
verifyNoMoreInteractions(this.webSocketHandler);
}
@Test
public void responseClosedAfterDisconnected() throws Exception {
String body = "o\n" + "c[3000,\"Go away!\"]\n" + "a[\"foo\"]\n";
ClientHttpResponse response = response(HttpStatus.OK, body);
connect(response);
verify(this.webSocketHandler).afterConnectionEstablished(any());
verify(this.webSocketHandler).afterConnectionClosed(any(), any());
verifyNoMoreInteractions(this.webSocketHandler);
verify(response).close();
}
private ListenableFuture<WebSocketSession> connect(ClientHttpResponse... responses) throws Exception {
return connect(new TestRestTemplate(responses));
}
private ListenableFuture<WebSocketSession> connect(RestOperations restTemplate, ClientHttpResponse... responses)
throws Exception {
RestTemplateXhrTransport transport = new RestTemplateXhrTransport(restTemplate);
transport.setTaskExecutor(new SyncTaskExecutor());
SockJsUrlInfo urlInfo = new SockJsUrlInfo(new URI("http://example.com"));
HttpHeaders headers = new HttpHeaders();
headers.add("h-foo", "h-bar");
TransportRequest request = new DefaultTransportRequest(urlInfo, headers, transport, TransportType.XHR, CODEC);
return transport.connect(request, this.webSocketHandler);
}
private ClientHttpResponse response(HttpStatus status, String body) throws IOException {
ClientHttpResponse response = mock(ClientHttpResponse.class);
InputStream inputStream = getInputStream(body);
when(response.getStatusCode()).thenReturn(status);
when(response.getBody()).thenReturn(inputStream);
return response;
}
private InputStream getInputStream(String content) {
byte[] bytes = content.getBytes(Charset.forName("UTF-8"));
return new ByteArrayInputStream(bytes);
}
private static class TestRestTemplate extends RestTemplate {
private Queue<ClientHttpResponse> responses = new LinkedBlockingDeque<>();
private TestRestTemplate(ClientHttpResponse... responses) {
this.responses.addAll(Arrays.asList(responses));
}
@Override
public <T> T execute(URI url, HttpMethod method, RequestCallback callback, ResponseExtractor<T> extractor) throws RestClientException {
try {
extractor.extractData(this.responses.remove());
}
catch (Throwable t) {
throw new RestClientException("Failed to invoke extractor", t);
}
return null;
}
}
}
/*
* Copyright 2002-2014 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.web.socket.sockjs.client;
import org.junit.Before;
import org.junit.Test;
import org.springframework.http.HttpStatus;
import org.springframework.util.concurrent.ListenableFutureCallback;
import org.springframework.web.client.HttpServerErrorException;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.sockjs.client.TestTransport.XhrTestTransport;
import java.net.URISyntaxException;
import java.util.ArrayList;
import java.util.List;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
import static org.mockito.Mockito.*;
/**
* Unit tests for {@link org.springframework.web.socket.sockjs.client.SockJsClient}.
*
* @author Rossen Stoyanchev
*/
public class SockJsClientTests {
private static final String URL = "http://example.com";
private static final WebSocketHandler handler = mock(WebSocketHandler.class);
private SockJsClient sockJsClient;
private InfoReceiver infoReceiver;
private TestTransport webSocketTransport;
private XhrTestTransport xhrTransport;
private ListenableFutureCallback<WebSocketSession> connectCallback;
@Before
@SuppressWarnings("unchecked")
public void setup() {
this.infoReceiver = mock(InfoReceiver.class);
this.webSocketTransport = new TestTransport("WebSocketTestTransport");
this.xhrTransport = new XhrTestTransport("XhrTestTransport");
List<Transport> transports = new ArrayList<>();
transports.add(this.webSocketTransport);
transports.add(this.xhrTransport);
this.sockJsClient = new SockJsClient(transports);
this.sockJsClient.setInfoReceiver(this.infoReceiver);
this.connectCallback = mock(ListenableFutureCallback.class);
}
@Test
public void connectWebSocket() throws Exception {
setupInfoRequest(true);
this.sockJsClient.doHandshake(handler, URL).addCallback(this.connectCallback);
assertTrue(this.webSocketTransport.invoked());
WebSocketSession session = mock(WebSocketSession.class);
this.webSocketTransport.getConnectCallback().onSuccess(session);
verify(this.connectCallback).onSuccess(session);
verifyNoMoreInteractions(this.connectCallback);
}
@Test
public void connectWebSocketDisabled() throws URISyntaxException {
setupInfoRequest(false);
this.sockJsClient.doHandshake(handler, URL);
assertFalse(this.webSocketTransport.invoked());
assertTrue(this.xhrTransport.invoked());
assertTrue(this.xhrTransport.getRequest().getTransportUrl().toString().endsWith("xhr_streaming"));
}
@Test
public void connectXhrStreamingDisabled() throws Exception {
setupInfoRequest(false);
this.xhrTransport.setStreamingDisabled(true);
this.sockJsClient.doHandshake(handler, URL).addCallback(this.connectCallback);
assertFalse(this.webSocketTransport.invoked());
assertTrue(this.xhrTransport.invoked());
assertTrue(this.xhrTransport.getRequest().getTransportUrl().toString().endsWith("xhr"));
}
@Test
public void connectSockJsInfo() throws Exception {
setupInfoRequest(true);
this.sockJsClient.doHandshake(handler, URL);
verify(this.infoReceiver, times(1)).executeInfoRequest(any());
}
@Test
public void connectSockJsInfoCached() throws Exception {
setupInfoRequest(true);
this.sockJsClient.doHandshake(handler, URL);
this.sockJsClient.doHandshake(handler, URL);
this.sockJsClient.doHandshake(handler, URL);
verify(this.infoReceiver, times(1)).executeInfoRequest(any());
}
@Test
@SuppressWarnings("unchecked")
public void connectInfoRequestFailure() throws URISyntaxException {
HttpServerErrorException exception = new HttpServerErrorException(HttpStatus.SERVICE_UNAVAILABLE);
when(this.infoReceiver.executeInfoRequest(any())).thenThrow(exception);
this.sockJsClient.doHandshake(handler, URL).addCallback(this.connectCallback);
verify(this.connectCallback).onFailure(exception);
assertFalse(this.webSocketTransport.invoked());
assertFalse(this.xhrTransport.invoked());
}
private void setupInfoRequest(boolean webSocketEnabled) {
when(this.infoReceiver.executeInfoRequest(any())).thenReturn("{\"entropy\":123," +
"\"origins\":[\"*:*\"],\"cookie_needed\":true,\"websocket\":" + webSocketEnabled + "}");
}
}
/*
* Copyright 2002-2014 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.web.socket.sockjs.client;
import org.junit.Assert;
import org.junit.Test;
import org.springframework.web.socket.sockjs.frame.Jackson2SockJsMessageCodec;
import org.springframework.web.socket.sockjs.transport.TransportType;
import java.net.URI;
import static org.hamcrest.CoreMatchers.equalTo;
import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
/**
* Unit tests for {@code SockJsUrlInfo}.
* @author Rossen Stoyanchev
*/
public class SockJsUrlInfoTests {
@Test
public void serverId() throws Exception {
SockJsUrlInfo info = new SockJsUrlInfo(new URI("http://example.com"));
int serverId = Integer.valueOf(info.getServerId());
assertTrue("Invalid serverId: " + serverId, serverId > 0 && serverId < 1000);
}
@Test
public void sessionId() throws Exception {
SockJsUrlInfo info = new SockJsUrlInfo(new URI("http://example.com"));
assertEquals("Invalid sessionId: " + info.getSessionId(), 32, info.getSessionId().length());
}
@Test
public void infoUrl() throws Exception {
testInfoUrl("http", "http");
testInfoUrl("http", "http");
testInfoUrl("https", "https");
testInfoUrl("https", "https");
testInfoUrl("ws", "http");
testInfoUrl("ws", "http");
testInfoUrl("wss", "https");
testInfoUrl("wss", "https");
}
private void testInfoUrl(String scheme, String expectedScheme) throws Exception {
SockJsUrlInfo info = new SockJsUrlInfo(new URI(scheme + "://example.com"));
Assert.assertThat(info.getInfoUrl(), is(equalTo(new URI(expectedScheme + "://example.com/info"))));
}
@Test
public void transportUrl() throws Exception {
testTransportUrl("http", "http", TransportType.XHR_STREAMING);
testTransportUrl("http", "ws", TransportType.WEBSOCKET);
testTransportUrl("https", "https", TransportType.XHR_STREAMING);
testTransportUrl("https", "wss", TransportType.WEBSOCKET);
testTransportUrl("ws", "http", TransportType.XHR_STREAMING);
testTransportUrl("ws", "ws", TransportType.WEBSOCKET);
testTransportUrl("wss", "https", TransportType.XHR_STREAMING);
testTransportUrl("wss", "wss", TransportType.WEBSOCKET);
}
private void testTransportUrl(String scheme, String expectedScheme, TransportType transportType) throws Exception {
SockJsUrlInfo info = new SockJsUrlInfo(new URI(scheme + "://example.com"));
String serverId = info.getServerId();
String sessionId = info.getSessionId();
String transport = transportType.toString().toLowerCase();
URI expected = new URI(expectedScheme + "://example.com/" + serverId + "/" + sessionId + "/" + transport);
assertThat(info.getTransportUrl(transportType), equalTo(expected));
}
}
/*
* Copyright 2002-2014 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.web.socket.sockjs.client;
import org.mockito.ArgumentCaptor;
import org.springframework.util.concurrent.ListenableFuture;
import org.springframework.util.concurrent.ListenableFutureCallback;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.WebSocketSession;
import java.net.URI;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
/**
* Test SockJS Transport.
*
* @author Rossen Stoyanchev
*/
class TestTransport implements Transport {
private final String name;
private TransportRequest request;
private ListenableFuture future;
public TestTransport(String name) {
this.name = name;
}
public TransportRequest getRequest() {
return this.request;
}
public boolean invoked() {
return this.future != null;
}
@SuppressWarnings("unchecked")
public ListenableFutureCallback<WebSocketSession> getConnectCallback() {
ArgumentCaptor<ListenableFutureCallback> captor = ArgumentCaptor.forClass(ListenableFutureCallback.class);
verify(this.future).addCallback(captor.capture());
return captor.getValue();
}
@SuppressWarnings("unchecked")
@Override
public ListenableFuture<WebSocketSession> connect(TransportRequest request, WebSocketHandler handler) {
this.request = request;
this.future = mock(ListenableFuture.class);
return this.future;
}
@Override
public String toString() {
return "TestTransport[" + name + "]";
}
static class XhrTestTransport extends TestTransport implements XhrTransport {
private boolean streamingDisabled;
XhrTestTransport(String name) {
super(name);
}
public void setStreamingDisabled(boolean streamingDisabled) {
this.streamingDisabled = streamingDisabled;
}
@Override
public boolean isXhrStreamingDisabled() {
return this.streamingDisabled;
}
@Override
public void executeSendRequest(URI transportUrl, TextMessage message) {
}
@Override
public String executeInfoRequest(URI infoUrl) {
return null;
}
}
}
/*
* Copyright 2002-2014 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.web.socket.sockjs.client;
import org.junit.Test;
import org.mockito.ArgumentCaptor;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpStatus;
import org.springframework.http.MediaType;
import org.springframework.http.ResponseEntity;
import org.springframework.util.concurrent.SettableListenableFuture;
import org.springframework.web.client.HttpServerErrorException;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.WebSocketSession;
import java.net.URI;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
import static org.mockito.Matchers.notNull;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when;
/**
* Unit tests for
* {@link org.springframework.web.socket.sockjs.client.AbstractXhrTransport}.
*
* @author Rossen Stoyanchev
*/
public class XhrTransportTests {
@Test
public void infoResponse() throws Exception {
TestXhrTransport transport = new TestXhrTransport();
transport.infoResponseToReturn = new ResponseEntity<>("body", HttpStatus.OK);
assertEquals("body", transport.executeInfoRequest(new URI("http://example.com/info")));
}
@Test(expected = HttpServerErrorException.class)
public void infoResponseError() throws Exception {
TestXhrTransport transport = new TestXhrTransport();
transport.infoResponseToReturn = new ResponseEntity<>("body", HttpStatus.BAD_REQUEST);
assertEquals("body", transport.executeInfoRequest(new URI("http://example.com/info")));
}
@Test
public void sendMessage() throws Exception {
HttpHeaders requestHeaders = new HttpHeaders();
requestHeaders.set("foo", "bar");
TestXhrTransport transport = new TestXhrTransport();
transport.setRequestHeaders(requestHeaders);
transport.sendMessageResponseToReturn = new ResponseEntity<>(HttpStatus.NO_CONTENT);
URI url = new URI("http://example.com");
transport.executeSendRequest(url, new TextMessage("payload"));
assertEquals(2, transport.actualSendRequestHeaders.size());
assertEquals("bar", transport.actualSendRequestHeaders.getFirst("foo"));
assertEquals(MediaType.APPLICATION_JSON, transport.actualSendRequestHeaders.getContentType());
}
@Test(expected = HttpServerErrorException.class)
public void sendMessageError() throws Exception {
TestXhrTransport transport = new TestXhrTransport();
transport.sendMessageResponseToReturn = new ResponseEntity<>(HttpStatus.BAD_REQUEST);
URI url = new URI("http://example.com");
transport.executeSendRequest(url, new TextMessage("payload"));
}
@Test
public void connect() throws Exception {
HttpHeaders handshakeHeaders = new HttpHeaders();
handshakeHeaders.setOrigin("foo");
TransportRequest request = mock(TransportRequest.class);
when(request.getSockJsUrlInfo()).thenReturn(new SockJsUrlInfo(new URI("http://example.com")));
when(request.getHandshakeHeaders()).thenReturn(handshakeHeaders);
HttpHeaders requestHeaders = new HttpHeaders();
requestHeaders.set("foo", "bar");
TestXhrTransport transport = new TestXhrTransport();
transport.setRequestHeaders(requestHeaders);
WebSocketHandler handler = mock(WebSocketHandler.class);
transport.connect(request, handler);
ArgumentCaptor<Runnable> captor = ArgumentCaptor.forClass(Runnable.class);
verify(request).getSockJsUrlInfo();
verify(request).addTimeoutTask(captor.capture());
verify(request).getTransportUrl();
verify(request).getHandshakeHeaders();
verifyNoMoreInteractions(request);
assertEquals(2, transport.actualHandshakeHeaders.size());
assertEquals("foo", transport.actualHandshakeHeaders.getOrigin());
assertEquals("bar", transport.actualHandshakeHeaders.getFirst("foo"));
assertFalse(transport.actualSession.isDisconnected());
captor.getValue().run();
assertTrue(transport.actualSession.isDisconnected());
}
private static class TestXhrTransport extends AbstractXhrTransport {
private ResponseEntity<String> infoResponseToReturn;
private ResponseEntity<String> sendMessageResponseToReturn;
private HttpHeaders actualSendRequestHeaders;
private HttpHeaders actualHandshakeHeaders;
private XhrClientSockJsSession actualSession;
@Override
protected ResponseEntity<String> executeInfoRequestInternal(URI infoUrl) {
return this.infoResponseToReturn;
}
@Override
protected ResponseEntity<String> executeSendRequestInternal(URI url, HttpHeaders headers, TextMessage message) {
this.actualSendRequestHeaders = headers;
return this.sendMessageResponseToReturn;
}
@Override
protected void connectInternal(TransportRequest request, WebSocketHandler handler, URI receiveUrl,
HttpHeaders handshakeHeaders, XhrClientSockJsSession session,
SettableListenableFuture<WebSocketSession> connectFuture) {
this.actualHandshakeHeaders = handshakeHeaders;
this.actualSession = session;
}
}
}
log4j.appender.console=org.apache.log4j.ConsoleAppender
log4j.appender.console.layout=org.apache.log4j.PatternLayout
log4j.appender.console.layout.ConversionPattern=%d{HH:mm:ss,SSS} [%c] - %m%n
log4j.appender.console.layout.ConversionPattern=%d{HH:mm:ss,SSS} [%c][%t] - %m%n
log4j.rootCategory=WARN, console
log4j.logger.org.springframework.web=DEBUG
log4j.logger.org.springframework.web.socket=DEBUG
log4j.logger.org.springframework.web.socket=TRACE
log4j.logger.org.springframework.messaging=DEBUG
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册