提交 d73c2e26 编写于 作者: R Rossen Stoyanchev

Polish handling of STOMP message headers

上级 ba7998d0
...@@ -41,6 +41,10 @@ public abstract class AbstractMessageSendingTemplate<D> implements MessageSendin ...@@ -41,6 +41,10 @@ public abstract class AbstractMessageSendingTemplate<D> implements MessageSendin
this.defaultDestination = defaultDestination; this.defaultDestination = defaultDestination;
} }
public D getDefaultDestination() {
return this.defaultDestination;
}
/** /**
* Set the {@link MessageConverter} that is to be used to convert * Set the {@link MessageConverter} that is to be used to convert
* between Messages and objects for this template. * between Messages and objects for this template.
...@@ -82,7 +86,7 @@ public abstract class AbstractMessageSendingTemplate<D> implements MessageSendin ...@@ -82,7 +86,7 @@ public abstract class AbstractMessageSendingTemplate<D> implements MessageSendin
this.doSend(destination, message); this.doSend(destination, message);
} }
protected abstract void doSend(D destination, Message<?> message) ; protected abstract void doSend(D destination, Message<?> message);
@Override @Override
......
...@@ -17,7 +17,6 @@ ...@@ -17,7 +17,6 @@
package org.springframework.messaging.simp; package org.springframework.messaging.simp;
import java.security.Principal; import java.security.Principal;
import java.util.Arrays;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
...@@ -26,7 +25,6 @@ import org.springframework.messaging.Message; ...@@ -26,7 +25,6 @@ import org.springframework.messaging.Message;
import org.springframework.messaging.MessageHeaders; import org.springframework.messaging.MessageHeaders;
import org.springframework.messaging.support.NativeMessageHeaderAccessor; import org.springframework.messaging.support.NativeMessageHeaderAccessor;
import org.springframework.util.Assert; import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
/** /**
...@@ -43,35 +41,25 @@ import org.springframework.util.CollectionUtils; ...@@ -43,35 +41,25 @@ import org.springframework.util.CollectionUtils;
*/ */
public class SimpMessageHeaderAccessor extends NativeMessageHeaderAccessor { public class SimpMessageHeaderAccessor extends NativeMessageHeaderAccessor {
public static final String DESTINATIONS = "destinations"; public static final String DESTINATION_HEADER = "destination";
public static final String MESSAGE_TYPE = "messageType"; public static final String MESSAGE_TYPE_HEADER = "messageType";
// TODO public static final String SESSION_ID_HEADER = "sessionId";
public static final String PROTOCOL_MESSAGE_TYPE = "protocolMessageType";
public static final String SESSION_ID = "sessionId"; public static final String SUBSCRIPTION_ID_HEADER = "subscriptionId";
public static final String SUBSCRIPTION_ID = "subscriptionId"; public static final String USER_HEADER = "user";
public static final String USER = "user";
/** /**
* A constructor for creating new message headers. * A constructor for creating new message headers.
* This constructor is protected. See factory methods in this and sub-classes. * This constructor is protected. See factory methods in this and sub-classes.
*/ */
protected SimpMessageHeaderAccessor(SimpMessageType messageType, Object protocolMessageType, protected SimpMessageHeaderAccessor(SimpMessageType messageType, Map<String, List<String>> externalSourceHeaders) {
Map<String, List<String>> externalSourceHeaders) {
super(externalSourceHeaders); super(externalSourceHeaders);
Assert.notNull(messageType, "messageType is required"); Assert.notNull(messageType, "messageType is required");
setHeader(MESSAGE_TYPE, messageType); setHeader(MESSAGE_TYPE_HEADER, messageType);
if (protocolMessageType != null) {
setHeader(PROTOCOL_MESSAGE_TYPE, protocolMessageType);
}
} }
/** /**
...@@ -89,14 +77,14 @@ public class SimpMessageHeaderAccessor extends NativeMessageHeaderAccessor { ...@@ -89,14 +77,14 @@ public class SimpMessageHeaderAccessor extends NativeMessageHeaderAccessor {
* {@link SimpMessageType#MESSAGE}. * {@link SimpMessageType#MESSAGE}.
*/ */
public static SimpMessageHeaderAccessor create() { public static SimpMessageHeaderAccessor create() {
return new SimpMessageHeaderAccessor(SimpMessageType.MESSAGE, null, null); return new SimpMessageHeaderAccessor(SimpMessageType.MESSAGE, null);
} }
/** /**
* Create {@link SimpMessageHeaderAccessor} for a new {@link Message} of a specific type. * Create {@link SimpMessageHeaderAccessor} for a new {@link Message} of a specific type.
*/ */
public static SimpMessageHeaderAccessor create(SimpMessageType messageType) { public static SimpMessageHeaderAccessor create(SimpMessageType messageType) {
return new SimpMessageHeaderAccessor(messageType, null, null); return new SimpMessageHeaderAccessor(messageType, null);
} }
/** /**
...@@ -106,39 +94,23 @@ public class SimpMessageHeaderAccessor extends NativeMessageHeaderAccessor { ...@@ -106,39 +94,23 @@ public class SimpMessageHeaderAccessor extends NativeMessageHeaderAccessor {
return new SimpMessageHeaderAccessor(message); return new SimpMessageHeaderAccessor(message);
} }
public void setMessageTypeIfNotSet(SimpMessageType messageType) {
public SimpMessageType getMessageType() { if (getMessageType() == null) {
return (SimpMessageType) getHeader(MESSAGE_TYPE); setHeader(MESSAGE_TYPE_HEADER, messageType);
} }
protected void setProtocolMessageType(Object protocolMessageType) {
setHeader(PROTOCOL_MESSAGE_TYPE, protocolMessageType);
} }
protected Object getProtocolMessageType() { public SimpMessageType getMessageType() {
return getHeader(PROTOCOL_MESSAGE_TYPE); return (SimpMessageType) getHeader(MESSAGE_TYPE_HEADER);
} }
public void setDestination(String destination) { public void setDestination(String destination) {
Assert.notNull(destination, "destination is required"); Assert.notNull(destination, "destination is required");
setHeader(DESTINATIONS, Arrays.asList(destination)); setHeader(DESTINATION_HEADER, destination);
} }
@SuppressWarnings("unchecked")
public String getDestination() { public String getDestination() {
List<String> destinations = (List<String>) getHeader(DESTINATIONS); return (String) getHeader(DESTINATION_HEADER);
return CollectionUtils.isEmpty(destinations) ? null : destinations.get(0);
}
@SuppressWarnings("unchecked")
public List<String> getDestinations() {
List<String> destinations = (List<String>) getHeader(DESTINATIONS);
return CollectionUtils.isEmpty(destinations) ? null : destinations;
}
public void setDestinations(List<String> destinations) {
Assert.notNull(destinations, "destinations are required");
setHeader(DESTINATIONS, destinations);
} }
public MediaType getContentType() { public MediaType getContentType() {
...@@ -150,27 +122,27 @@ public class SimpMessageHeaderAccessor extends NativeMessageHeaderAccessor { ...@@ -150,27 +122,27 @@ public class SimpMessageHeaderAccessor extends NativeMessageHeaderAccessor {
} }
public String getSubscriptionId() { public String getSubscriptionId() {
return (String) getHeader(SUBSCRIPTION_ID); return (String) getHeader(SUBSCRIPTION_ID_HEADER);
} }
public void setSubscriptionId(String subscriptionId) { public void setSubscriptionId(String subscriptionId) {
setHeader(SUBSCRIPTION_ID, subscriptionId); setHeader(SUBSCRIPTION_ID_HEADER, subscriptionId);
} }
public String getSessionId() { public String getSessionId() {
return (String) getHeader(SESSION_ID); return (String) getHeader(SESSION_ID_HEADER);
} }
public void setSessionId(String sessionId) { public void setSessionId(String sessionId) {
setHeader(SESSION_ID, sessionId); setHeader(SESSION_ID_HEADER, sessionId);
} }
public Principal getUser() { public Principal getUser() {
return (Principal) getHeader(USER); return (Principal) getHeader(USER_HEADER);
} }
public void setUser(Principal principal) { public void setUser(Principal principal) {
setHeader(USER, principal); setHeader(USER_HEADER, principal);
} }
} }
...@@ -22,13 +22,31 @@ import org.springframework.messaging.core.MessageSendingOperations; ...@@ -22,13 +22,31 @@ import org.springframework.messaging.core.MessageSendingOperations;
/** /**
* A specialization of {@link MessageSendingOperations} with methods for use with
* the Spring Framework support for simple messaging protocols (like STOMP).
*
* @author Rossen Stoyanchev * @author Rossen Stoyanchev
* @since 4.0 * @since 4.0
*/ */
public interface SimpMessageSendingOperations extends MessageSendingOperations<String> { public interface SimpMessageSendingOperations extends MessageSendingOperations<String> {
/**
* Send a message to a specific user.
*
* @param user the user that should receive the message.
* @param destination the destination to send the message to.
* @param message the message to send
*/
<T> void convertAndSendToUser(String user, String destination, T message) throws MessagingException; <T> void convertAndSendToUser(String user, String destination, T message) throws MessagingException;
/**
* Send a message to a specific user.
*
* @param user the user that should receive the message.
* @param destination the destination to send the message to.
* @param message the message to send
* @param postProcessor a postProcessor to post-process or modify the created message
*/
<T> void convertAndSendToUser(String user, String destination, T message, MessagePostProcessor postProcessor) <T> void convertAndSendToUser(String user, String destination, T message, MessagePostProcessor postProcessor)
throws MessagingException; throws MessagingException;
......
...@@ -15,8 +15,6 @@ ...@@ -15,8 +15,6 @@
*/ */
package org.springframework.messaging.simp; package org.springframework.messaging.simp;
import java.util.Arrays;
import org.springframework.messaging.Message; import org.springframework.messaging.Message;
import org.springframework.messaging.MessageChannel; import org.springframework.messaging.MessageChannel;
import org.springframework.messaging.MessageDeliveryException; import org.springframework.messaging.MessageDeliveryException;
...@@ -56,6 +54,7 @@ public class SimpMessagingTemplate extends AbstractMessageSendingTemplate<String ...@@ -56,6 +54,7 @@ public class SimpMessagingTemplate extends AbstractMessageSendingTemplate<String
* @see org.springframework.messaging.simp.handler.UserDestinationMessageHandler * @see org.springframework.messaging.simp.handler.UserDestinationMessageHandler
*/ */
public void setUserDestinationPrefix(String prefix) { public void setUserDestinationPrefix(String prefix) {
Assert.notNull(prefix, "userDestinationPrefix is required");
this.userDestinationPrefix = prefix; this.userDestinationPrefix = prefix;
} }
...@@ -92,31 +91,32 @@ public class SimpMessagingTemplate extends AbstractMessageSendingTemplate<String ...@@ -92,31 +91,32 @@ public class SimpMessagingTemplate extends AbstractMessageSendingTemplate<String
@Override @Override
public <P> void send(Message<P> message) { public <P> void send(Message<P> message) {
// TODO: maybe look up destination of current message (via ThreadLocal) SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.wrap(message);
this.send(getRequiredDefaultDestination(), message); String destination = headers.getDestination();
destination = (destination != null) ? destination : getRequiredDefaultDestination();
doSend(getRequiredDefaultDestination(), message);
} }
@Override @Override
protected void doSend(String destination, Message<?> message) { protected void doSend(String destination, Message<?> message) {
Assert.notNull(destination, "destination is required"); Assert.notNull(destination, "destination is required");
message = updateMessageHeaders(message, destination);
SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.wrap(message);
headers.setDestination(destination);
headers.setMessageTypeIfNotSet(SimpMessageType.MESSAGE);
message = MessageBuilder.withPayloadAndHeaders(message.getPayload(), headers).build();
long timeout = this.sendTimeout; long timeout = this.sendTimeout;
boolean sent = (timeout >= 0) boolean sent = (timeout >= 0)
? this.messageChannel.send(message, timeout) ? this.messageChannel.send(message, timeout)
: this.messageChannel.send(message); : this.messageChannel.send(message);
if (!sent) { if (!sent) {
throw new MessageDeliveryException(message, throw new MessageDeliveryException(message,
"failed to send message to destination '" + destination + "' within timeout: " + timeout); "failed to send message to destination '" + destination + "' within timeout: " + timeout);
} }
} }
protected <P> Message<P> updateMessageHeaders(Message<P> message, String destination) {
Assert.notNull(destination, "destination is required");
return MessageBuilder.fromMessage(message)
.setHeader(SimpMessageHeaderAccessor.MESSAGE_TYPE, SimpMessageType.MESSAGE)
.setHeader(SimpMessageHeaderAccessor.DESTINATIONS, Arrays.asList(destination)).build();
}
@Override @Override
public <T> void convertAndSendToUser(String user, String destination, T message) throws MessagingException { public <T> void convertAndSendToUser(String user, String destination, T message) throws MessagingException {
convertAndSendToUser(user, destination, message, null); convertAndSendToUser(user, destination, message, null);
......
...@@ -16,8 +16,6 @@ ...@@ -16,8 +16,6 @@
package org.springframework.messaging.simp.annotation.support; package org.springframework.messaging.simp.annotation.support;
import java.security.Principal;
import org.springframework.core.MethodParameter; import org.springframework.core.MethodParameter;
import org.springframework.messaging.Message; import org.springframework.messaging.Message;
import org.springframework.messaging.MessageChannel; import org.springframework.messaging.MessageChannel;
...@@ -69,7 +67,10 @@ public class ReplyToMethodReturnValueHandler implements HandlerMethodReturnValue ...@@ -69,7 +67,10 @@ public class ReplyToMethodReturnValueHandler implements HandlerMethodReturnValue
return; return;
} }
MessagePostProcessor postProcessor = new SessionHeaderPostProcessor(inputMessage); SimpMessageHeaderAccessor inputHeaders = SimpMessageHeaderAccessor.wrap(inputMessage);
String sessionId = inputHeaders.getSessionId();
MessagePostProcessor postProcessor = new SessionHeaderPostProcessor(sessionId);
ReplyTo replyTo = returnType.getMethodAnnotation(ReplyTo.class); ReplyTo replyTo = returnType.getMethodAnnotation(ReplyTo.class);
if (replyTo != null) { if (replyTo != null) {
...@@ -80,37 +81,30 @@ public class ReplyToMethodReturnValueHandler implements HandlerMethodReturnValue ...@@ -80,37 +81,30 @@ public class ReplyToMethodReturnValueHandler implements HandlerMethodReturnValue
ReplyToUser replyToUser = returnType.getMethodAnnotation(ReplyToUser.class); ReplyToUser replyToUser = returnType.getMethodAnnotation(ReplyToUser.class);
if (replyToUser != null) { if (replyToUser != null) {
String user = getUser(inputMessage).getName(); if (inputHeaders.getUser() == null) {
throw new MissingSessionUserException(inputMessage);
}
String user = inputHeaders.getUser().getName();
for (String destination : replyToUser.value()) { for (String destination : replyToUser.value()) {
this.messagingTemplate.convertAndSendToUser(user, destination, returnValue, postProcessor); this.messagingTemplate.convertAndSendToUser(user, destination, returnValue, postProcessor);
} }
} }
} }
private Principal getUser(Message<?> inputMessage) {
SimpMessageHeaderAccessor inputHeaders = SimpMessageHeaderAccessor.wrap(inputMessage);
Principal user = inputHeaders.getUser();
if (user == null) {
throw new MissingSessionUserException(inputMessage);
}
return user;
}
private final class SessionHeaderPostProcessor implements MessagePostProcessor { private final class SessionHeaderPostProcessor implements MessagePostProcessor {
private final Message<?> inputMessage; private final String sessionId;
public SessionHeaderPostProcessor(Message<?> inputMessage) { public SessionHeaderPostProcessor(String sessionId) {
this.inputMessage = inputMessage; this.sessionId = sessionId;
} }
@Override @Override
public Message<?> postProcessMessage(Message<?> message) { public Message<?> postProcessMessage(Message<?> message) {
String headerName = SimpMessageHeaderAccessor.SESSION_ID; SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.wrap(message);
String sessionId = (String) this.inputMessage.getHeaders().get(headerName); headers.setSessionId(this.sessionId);
return MessageBuilder.fromMessage(message).setHeader(headerName, sessionId).build(); return MessageBuilder.withPayloadAndHeaders(message.getPayload(), headers).build();
} }
} }
} }
...@@ -68,32 +68,37 @@ public class SubscriptionMethodReturnValueHandler implements HandlerMethodReturn ...@@ -68,32 +68,37 @@ public class SubscriptionMethodReturnValueHandler implements HandlerMethodReturn
} }
SimpMessageHeaderAccessor inputHeaders = SimpMessageHeaderAccessor.wrap(message); SimpMessageHeaderAccessor inputHeaders = SimpMessageHeaderAccessor.wrap(message);
String sessionId = inputHeaders.getSessionId();
String subscriptionId = inputHeaders.getSubscriptionId();
String destination = inputHeaders.getDestination(); String destination = inputHeaders.getDestination();
Assert.state(inputHeaders.getSubscriptionId() != null, Assert.state(inputHeaders.getSubscriptionId() != null,
"No subsriptiondId in input message. Add @ReplyTo or @ReplyToUser to method: " "No subsriptiondId in input message. Add @ReplyTo or @ReplyToUser to method: "
+ returnType.getMethod()); + returnType.getMethod());
MessagePostProcessor postProcessor = new SubscriptionHeaderPostProcessor(inputHeaders); MessagePostProcessor postProcessor = new SubscriptionHeaderPostProcessor(sessionId, subscriptionId);
this.messagingTemplate.convertAndSend(destination, returnValue, postProcessor); this.messagingTemplate.convertAndSend(destination, returnValue, postProcessor);
} }
private final class SubscriptionHeaderPostProcessor implements MessagePostProcessor { private final class SubscriptionHeaderPostProcessor implements MessagePostProcessor {
private final SimpMessageHeaderAccessor inputHeaders; private final String sessionId;
private final String subscriptionId;
public SubscriptionHeaderPostProcessor(SimpMessageHeaderAccessor inputHeaders) {
this.inputHeaders = inputHeaders; public SubscriptionHeaderPostProcessor(String sessionId, String subscriptionId) {
this.sessionId = sessionId;
this.subscriptionId = subscriptionId;
} }
@Override @Override
public Message<?> postProcessMessage(Message<?> message) { public Message<?> postProcessMessage(Message<?> message) {
return MessageBuilder.fromMessage(message) SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.wrap(message);
.setHeader(SimpMessageHeaderAccessor.SESSION_ID, this.inputHeaders.getSessionId()) headers.setSessionId(this.sessionId);
.setHeader(SimpMessageHeaderAccessor.SUBSCRIPTION_ID, this.inputHeaders.getSubscriptionId()) headers.setSubscriptionId(this.subscriptionId);
.build(); return MessageBuilder.withPayloadAndHeaders(message.getPayload(), headers).build();
} }
} }
} }
...@@ -23,8 +23,8 @@ package org.springframework.messaging.simp.handler; ...@@ -23,8 +23,8 @@ package org.springframework.messaging.simp.handler;
*/ */
public interface MutableUserSessionResolver extends UserSessionResolver { public interface MutableUserSessionResolver extends UserSessionResolver {
void storeUserSessionId(String user, String sessionId); void addUserSessionId(String user, String sessionId);
void deleteUserSessionId(String user, String sessionId); void removeUserSessionId(String user, String sessionId);
} }
...@@ -94,8 +94,8 @@ public class SimpleBrokerMessageHandler implements MessageHandler { ...@@ -94,8 +94,8 @@ public class SimpleBrokerMessageHandler implements MessageHandler {
headers.setSessionId(sessionId); headers.setSessionId(sessionId);
headers.setSubscriptionId(subscriptionId); headers.setSubscriptionId(subscriptionId);
Message<?> clientMessage = MessageBuilder.withPayload( Object payload = message.getPayload();
message.getPayload()).copyHeaders(headers.toMap()).build(); Message<?> clientMessage = MessageBuilder.withPayloadAndHeaders(payload, headers).build();
try { try {
this.outboundChannel.send(clientMessage); this.outboundChannel.send(clientMessage);
} }
......
...@@ -34,7 +34,7 @@ public class SimpleUserSessionResolver implements MutableUserSessionResolver { ...@@ -34,7 +34,7 @@ public class SimpleUserSessionResolver implements MutableUserSessionResolver {
@Override @Override
public void storeUserSessionId(String user, String sessionId) { public void addUserSessionId(String user, String sessionId) {
Set<String> sessionIds = this.userSessionIds.get(user); Set<String> sessionIds = this.userSessionIds.get(user);
if (sessionIds == null) { if (sessionIds == null) {
sessionIds = new CopyOnWriteArraySet<String>(); sessionIds = new CopyOnWriteArraySet<String>();
...@@ -44,7 +44,7 @@ public class SimpleUserSessionResolver implements MutableUserSessionResolver { ...@@ -44,7 +44,7 @@ public class SimpleUserSessionResolver implements MutableUserSessionResolver {
} }
@Override @Override
public void deleteUserSessionId(String user, String sessionId) { public void removeUserSessionId(String user, String sessionId) {
Set<String> sessionIds = this.userSessionIds.get(user); Set<String> sessionIds = this.userSessionIds.get(user);
if (sessionIds != null) { if (sessionIds != null) {
if (sessionIds.remove(sessionId) && sessionIds.isEmpty()) { if (sessionIds.remove(sessionId) && sessionIds.isEmpty()) {
......
...@@ -120,7 +120,7 @@ public class UserDestinationMessageHandler implements MessageHandler { ...@@ -120,7 +120,7 @@ public class UserDestinationMessageHandler implements MessageHandler {
String targetDestination = destinationParser.getTargetDestination(sessionId); String targetDestination = destinationParser.getTargetDestination(sessionId);
headers.setDestination(targetDestination); headers.setDestination(targetDestination);
message = MessageBuilder.fromMessage(message).copyHeaders(headers.toMap()).build(); message = MessageBuilder.withPayloadAndHeaders(message.getPayload(), headers).build();
if (logger.isTraceEnabled()) { if (logger.isTraceEnabled()) {
logger.trace("Sending message to resolved target destination " + targetDestination); logger.trace("Sending message to resolved target destination " + targetDestination);
......
...@@ -219,7 +219,7 @@ public class StompBrokerRelayMessageHandler implements MessageHandler, SmartLife ...@@ -219,7 +219,7 @@ public class StompBrokerRelayMessageHandler implements MessageHandler, SmartLife
if (logger.isDebugEnabled()) { if (logger.isDebugEnabled()) {
logger.debug("Sending STOMP CONNECT frame to initialize \"system\" TCP connection"); logger.debug("Sending STOMP CONNECT frame to initialize \"system\" TCP connection");
} }
Message<?> message = MessageBuilder.withPayload(new byte[0]).copyHeaders(headers.toMap()).build(); Message<?> message = MessageBuilder.withPayloadAndHeaders(new byte[0], headers).build();
session.open(message); session.open(message);
} }
...@@ -259,7 +259,7 @@ public class StompBrokerRelayMessageHandler implements MessageHandler, SmartLife ...@@ -259,7 +259,7 @@ public class StompBrokerRelayMessageHandler implements MessageHandler, SmartLife
StompHeaderAccessor headers = StompHeaderAccessor.wrap(message); StompHeaderAccessor headers = StompHeaderAccessor.wrap(message);
String sessionId = headers.getSessionId(); String sessionId = headers.getSessionId();
String destination = headers.getDestination(); String destination = headers.getDestination();
StompCommand command = headers.getStompCommand(); StompCommand command = headers.getCommand();
SimpMessageType messageType = headers.getMessageType(); SimpMessageType messageType = headers.getMessageType();
if (!this.running) { if (!this.running) {
...@@ -273,11 +273,11 @@ public class StompBrokerRelayMessageHandler implements MessageHandler, SmartLife ...@@ -273,11 +273,11 @@ public class StompBrokerRelayMessageHandler implements MessageHandler, SmartLife
sessionId = (sessionId == null) ? STOMP_RELAY_SYSTEM_SESSION_ID : sessionId; sessionId = (sessionId == null) ? STOMP_RELAY_SYSTEM_SESSION_ID : sessionId;
headers.setSessionId(sessionId); headers.setSessionId(sessionId);
command = (command == null) ? StompCommand.SEND : command; command = (command == null) ? StompCommand.SEND : command;
headers.setStompCommandIfNotSet(command); headers.setCommandIfNotSet(command);
message = MessageBuilder.fromMessage(message).copyHeaders(headers.toMap()).build(); message = MessageBuilder.withPayloadAndHeaders(message.getPayload(), headers).build();
} }
if (headers.getStompCommand() == null) { if (headers.getCommand() == null) {
logger.error("Ignoring message, no STOMP command: " + message); logger.error("Ignoring message, no STOMP command: " + message);
return; return;
} }
...@@ -397,7 +397,7 @@ public class StompBrokerRelayMessageHandler implements MessageHandler, SmartLife ...@@ -397,7 +397,7 @@ public class StompBrokerRelayMessageHandler implements MessageHandler, SmartLife
} }
StompHeaderAccessor headers = StompHeaderAccessor.wrap(message); StompHeaderAccessor headers = StompHeaderAccessor.wrap(message);
if (StompCommand.CONNECTED == headers.getStompCommand()) { if (StompCommand.CONNECTED == headers.getCommand()) {
synchronized(this.monitor) { synchronized(this.monitor) {
this.isConnected = true; this.isConnected = true;
flushMessages(this.promise.get()); flushMessages(this.promise.get());
...@@ -406,7 +406,7 @@ public class StompBrokerRelayMessageHandler implements MessageHandler, SmartLife ...@@ -406,7 +406,7 @@ public class StompBrokerRelayMessageHandler implements MessageHandler, SmartLife
} }
headers.setSessionId(this.sessionId); headers.setSessionId(this.sessionId);
message = MessageBuilder.fromMessage(message).copyHeaders(headers.toMap()).build(); message = MessageBuilder.withPayloadAndHeaders(message.getPayload(), headers).build();
sendMessageToClient(message); sendMessageToClient(message);
} }
...@@ -418,7 +418,7 @@ public class StompBrokerRelayMessageHandler implements MessageHandler, SmartLife ...@@ -418,7 +418,7 @@ public class StompBrokerRelayMessageHandler implements MessageHandler, SmartLife
StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.ERROR); StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.ERROR);
headers.setSessionId(sessionId); headers.setSessionId(sessionId);
headers.setMessage(errorText); headers.setMessage(errorText);
Message<?> errorMessage = MessageBuilder.withPayload(new byte[0]).copyHeaders(headers.toMap()).build(); Message<?> errorMessage = MessageBuilder.withPayloadAndHeaders(new byte[0], headers).build();
sendMessageToClient(errorMessage); sendMessageToClient(errorMessage);
} }
......
...@@ -44,37 +44,44 @@ import org.springframework.util.StringUtils; ...@@ -44,37 +44,44 @@ import org.springframework.util.StringUtils;
*/ */
public class StompHeaderAccessor extends SimpMessageHeaderAccessor { public class StompHeaderAccessor extends SimpMessageHeaderAccessor {
public static final String STOMP_ID = "id"; // STOMP header names
public static final String HOST = "host"; public static final String STOMP_ID_HEADER = "id";
public static final String ACCEPT_VERSION = "accept-version"; public static final String STOMP_HOST_HEADER = "host";
public static final String MESSAGE_ID = "message-id"; public static final String STOMP_ACCEPT_VERSION_HEADER = "accept-version";
public static final String RECEIPT_ID = "receipt-id"; public static final String STOMP_MESSAGE_ID_HEADER = "message-id";
public static final String SUBSCRIPTION = "subscription"; public static final String STOMP_RECEIPT_ID_HEADER = "receipt-id";
public static final String VERSION = "version"; public static final String STOMP_SUBSCRIPTION_HEADER = "subscription";
public static final String MESSAGE = "message"; public static final String STOMP_VERSION_HEADER = "version";
public static final String ACK = "ack"; public static final String STOMP_MESSAGE_HEADER = "message";
public static final String NACK = "nack"; public static final String STOMP_ACK_HEADER = "ack";
public static final String LOGIN = "login"; public static final String STOMP_NACK_HEADER = "nack";
public static final String PASSCODE = "passcode"; public static final String STOMP_LOGIN_HEADER = "login";
public static final String DESTINATION = "destination"; public static final String STOMP_PASSCODE_HEADER = "passcode";
public static final String CONTENT_TYPE = "content-type"; public static final String STOMP_DESTINATION_HEADER = "destination";
public static final String CONTENT_LENGTH = "content-length"; public static final String STOMP_CONTENT_TYPE_HEADER = "content-type";
public static final String HEARTBEAT = "heart-beat"; public static final String STOMP_CONTENT_LENGTH_HEADER = "content-length";
public static final String STOMP_HEARTBEAT_HEADER = "heart-beat";
// Other header names
public static final String COMMAND_HEADER = "stompCommand";
private static final AtomicLong messageIdCounter = new AtomicLong(); private static final AtomicLong messageIdCounter = new AtomicLong();
...@@ -84,30 +91,37 @@ public class StompHeaderAccessor extends SimpMessageHeaderAccessor { ...@@ -84,30 +91,37 @@ public class StompHeaderAccessor extends SimpMessageHeaderAccessor {
* A constructor for creating new STOMP message headers. * A constructor for creating new STOMP message headers.
*/ */
private StompHeaderAccessor(StompCommand command, Map<String, List<String>> externalSourceHeaders) { private StompHeaderAccessor(StompCommand command, Map<String, List<String>> externalSourceHeaders) {
super(command.getMessageType(), command, externalSourceHeaders);
super(command.getMessageType(), externalSourceHeaders);
Assert.notNull(command, "command is required");
setHeader(COMMAND_HEADER, command);
if (externalSourceHeaders != null) { if (externalSourceHeaders != null) {
setSimpMessageHeaders(externalSourceHeaders); setSimpMessageHeaders(command, externalSourceHeaders);
} }
} }
private void setSimpMessageHeaders(Map<String, List<String>> extHeaders) { private void setSimpMessageHeaders(StompCommand command, Map<String, List<String>> extHeaders) {
List<String> values = extHeaders.get(StompHeaderAccessor.DESTINATION);
List<String> values = extHeaders.get(StompHeaderAccessor.STOMP_DESTINATION_HEADER);
if (!CollectionUtils.isEmpty(values)) { if (!CollectionUtils.isEmpty(values)) {
super.setDestination(values.get(0)); super.setDestination(values.get(0));
} }
values = extHeaders.get(StompHeaderAccessor.CONTENT_TYPE);
values = extHeaders.get(StompHeaderAccessor.STOMP_CONTENT_TYPE_HEADER);
if (!CollectionUtils.isEmpty(values)) { if (!CollectionUtils.isEmpty(values)) {
super.setContentType(MediaType.parseMediaType(values.get(0))); super.setContentType(MediaType.parseMediaType(values.get(0)));
} }
StompCommand command = getStompCommand();
if (StompCommand.SUBSCRIBE.equals(command) || StompCommand.UNSUBSCRIBE.equals(command)) { if (StompCommand.SUBSCRIBE.equals(command) || StompCommand.UNSUBSCRIBE.equals(command)) {
values = extHeaders.get(StompHeaderAccessor.STOMP_ID); values = extHeaders.get(StompHeaderAccessor.STOMP_ID_HEADER);
if (!CollectionUtils.isEmpty(values)) { if (!CollectionUtils.isEmpty(values)) {
super.setSubscriptionId(values.get(0)); super.setSubscriptionId(values.get(0));
} }
} }
else if (StompCommand.MESSAGE.equals(command)) { else if (StompCommand.MESSAGE.equals(command)) {
values = extHeaders.get(StompHeaderAccessor.SUBSCRIPTION); values = extHeaders.get(StompHeaderAccessor.STOMP_SUBSCRIPTION_HEADER);
if (!CollectionUtils.isEmpty(values)) { if (!CollectionUtils.isEmpty(values)) {
super.setSubscriptionId(values.get(0)); super.setSubscriptionId(values.get(0));
} }
...@@ -154,73 +168,66 @@ public class StompHeaderAccessor extends SimpMessageHeaderAccessor { ...@@ -154,73 +168,66 @@ public class StompHeaderAccessor extends SimpMessageHeaderAccessor {
String destination = super.getDestination(); String destination = super.getDestination();
if (destination != null) { if (destination != null) {
result.put(DESTINATION, Arrays.asList(destination)); result.put(STOMP_DESTINATION_HEADER, Arrays.asList(destination));
} }
MediaType contentType = getContentType(); MediaType contentType = getContentType();
if (contentType != null) { if (contentType != null) {
result.put(CONTENT_TYPE, Arrays.asList(contentType.toString())); result.put(STOMP_CONTENT_TYPE_HEADER, Arrays.asList(contentType.toString()));
} }
if (StompCommand.MESSAGE.equals(getStompCommand())) { if (StompCommand.MESSAGE.equals(getCommand())) {
String subscriptionId = getSubscriptionId(); String subscriptionId = getSubscriptionId();
if (subscriptionId != null) { if (subscriptionId != null) {
result.put(SUBSCRIPTION, Arrays.asList(subscriptionId)); result.put(STOMP_SUBSCRIPTION_HEADER, Arrays.asList(subscriptionId));
} }
else { else {
logger.warn("STOMP MESSAGE frame should have a subscription: " + this.toString()); logger.warn("STOMP MESSAGE frame should have a subscription: " + this.toString());
} }
if ((getMessageId() == null)) { if ((getMessageId() == null)) {
String messageId = getSessionId() + "-" + messageIdCounter.getAndIncrement(); String messageId = getSessionId() + "-" + messageIdCounter.getAndIncrement();
result.put(MESSAGE_ID, Arrays.asList(messageId)); result.put(STOMP_MESSAGE_ID_HEADER, Arrays.asList(messageId));
} }
} }
return result; return result;
} }
public void setStompCommandIfNotSet(StompCommand command) { public void setCommandIfNotSet(StompCommand command) {
if (getStompCommand() == null) { if (getCommand() == null) {
setProtocolMessageType(command); setHeader(COMMAND_HEADER, command);
} }
} }
public StompCommand getStompCommand() { public StompCommand getCommand() {
return (StompCommand) super.getProtocolMessageType(); return (StompCommand) getHeader(COMMAND_HEADER);
} }
public Set<String> getAcceptVersion() { public Set<String> getAcceptVersion() {
String rawValue = getFirstNativeHeader(ACCEPT_VERSION); String rawValue = getFirstNativeHeader(STOMP_ACCEPT_VERSION_HEADER);
return (rawValue != null) ? StringUtils.commaDelimitedListToSet(rawValue) : Collections.<String>emptySet(); return (rawValue != null) ? StringUtils.commaDelimitedListToSet(rawValue) : Collections.<String>emptySet();
} }
public void setAcceptVersion(String acceptVersion) { public void setAcceptVersion(String acceptVersion) {
setNativeHeader(ACCEPT_VERSION, acceptVersion); setNativeHeader(STOMP_ACCEPT_VERSION_HEADER, acceptVersion);
} }
public void setHost(String host) { public void setHost(String host) {
setNativeHeader(HOST, host); setNativeHeader(STOMP_HOST_HEADER, host);
} }
public String getHost() { public String getHost() {
return getFirstNativeHeader(HOST); return getFirstNativeHeader(STOMP_HOST_HEADER);
} }
@Override @Override
public void setDestination(String destination) { public void setDestination(String destination) {
super.setDestination(destination); super.setDestination(destination);
setNativeHeader(DESTINATION, destination); setNativeHeader(STOMP_DESTINATION_HEADER, destination);
}
@Override
public void setDestinations(List<String> destinations) {
Assert.isTrue((destinations != null) && (destinations.size() == 1), "STOMP allows one destination per message");
super.setDestinations(destinations);
setNativeHeader(DESTINATION, destinations.get(0));
} }
public long[] getHeartbeat() { public long[] getHeartbeat() {
String rawValue = getFirstNativeHeader(HEARTBEAT); String rawValue = getFirstNativeHeader(STOMP_HEARTBEAT_HEADER);
if (!StringUtils.hasText(rawValue)) { if (!StringUtils.hasText(rawValue)) {
return null; return null;
} }
...@@ -232,91 +239,91 @@ public class StompHeaderAccessor extends SimpMessageHeaderAccessor { ...@@ -232,91 +239,91 @@ public class StompHeaderAccessor extends SimpMessageHeaderAccessor {
public void setContentType(MediaType mediaType) { public void setContentType(MediaType mediaType) {
if (mediaType != null) { if (mediaType != null) {
super.setContentType(mediaType); super.setContentType(mediaType);
setNativeHeader(CONTENT_TYPE, mediaType.toString()); setNativeHeader(STOMP_CONTENT_TYPE_HEADER, mediaType.toString());
} }
} }
public MediaType getContentType() { public MediaType getContentType() {
String value = getFirstNativeHeader(CONTENT_TYPE); String value = getFirstNativeHeader(STOMP_CONTENT_TYPE_HEADER);
return (value != null) ? MediaType.parseMediaType(value) : null; return (value != null) ? MediaType.parseMediaType(value) : null;
} }
public Integer getContentLength() { public Integer getContentLength() {
String contentLength = getFirstNativeHeader(CONTENT_LENGTH); String contentLength = getFirstNativeHeader(STOMP_CONTENT_LENGTH_HEADER);
return StringUtils.hasText(contentLength) ? new Integer(contentLength) : null; return StringUtils.hasText(contentLength) ? new Integer(contentLength) : null;
} }
public void setContentLength(int contentLength) { public void setContentLength(int contentLength) {
setNativeHeader(CONTENT_LENGTH, String.valueOf(contentLength)); setNativeHeader(STOMP_CONTENT_LENGTH_HEADER, String.valueOf(contentLength));
} }
public void setHeartbeat(long cx, long cy) { public void setHeartbeat(long cx, long cy) {
setNativeHeader(HEARTBEAT, StringUtils.arrayToCommaDelimitedString(new Object[] {cx, cy})); setNativeHeader(STOMP_HEARTBEAT_HEADER, StringUtils.arrayToCommaDelimitedString(new Object[] {cx, cy}));
} }
public void setAck(String ack) { public void setAck(String ack) {
setNativeHeader(ACK, ack); setNativeHeader(STOMP_ACK_HEADER, ack);
} }
public String getAck() { public String getAck() {
return getFirstNativeHeader(ACK); return getFirstNativeHeader(STOMP_ACK_HEADER);
} }
public void setNack(String nack) { public void setNack(String nack) {
setNativeHeader(NACK, nack); setNativeHeader(STOMP_NACK_HEADER, nack);
} }
public String getNack() { public String getNack() {
return getFirstNativeHeader(NACK); return getFirstNativeHeader(STOMP_NACK_HEADER);
} }
public void setLogin(String login) { public void setLogin(String login) {
setNativeHeader(LOGIN, login); setNativeHeader(STOMP_LOGIN_HEADER, login);
} }
public String getLogin() { public String getLogin() {
return getFirstNativeHeader(LOGIN); return getFirstNativeHeader(STOMP_LOGIN_HEADER);
} }
public void setPasscode(String passcode) { public void setPasscode(String passcode) {
setNativeHeader(PASSCODE, passcode); setNativeHeader(STOMP_PASSCODE_HEADER, passcode);
} }
public String getPasscode() { public String getPasscode() {
return getFirstNativeHeader(PASSCODE); return getFirstNativeHeader(STOMP_PASSCODE_HEADER);
} }
public void setReceiptId(String receiptId) { public void setReceiptId(String receiptId) {
setNativeHeader(RECEIPT_ID, receiptId); setNativeHeader(STOMP_RECEIPT_ID_HEADER, receiptId);
} }
public String getReceiptId() { public String getReceiptId() {
return getFirstNativeHeader(RECEIPT_ID); return getFirstNativeHeader(STOMP_RECEIPT_ID_HEADER);
} }
public String getMessage() { public String getMessage() {
return getFirstNativeHeader(MESSAGE); return getFirstNativeHeader(STOMP_MESSAGE_HEADER);
} }
public void setMessage(String content) { public void setMessage(String content) {
setNativeHeader(MESSAGE, content); setNativeHeader(STOMP_MESSAGE_HEADER, content);
} }
public String getMessageId() { public String getMessageId() {
return getFirstNativeHeader(MESSAGE_ID); return getFirstNativeHeader(STOMP_MESSAGE_ID_HEADER);
} }
public void setMessageId(String id) { public void setMessageId(String id) {
setNativeHeader(MESSAGE_ID, id); setNativeHeader(STOMP_MESSAGE_ID_HEADER, id);
} }
public String getVersion() { public String getVersion() {
return getFirstNativeHeader(VERSION); return getFirstNativeHeader(STOMP_VERSION_HEADER);
} }
public void setVersion(String version) { public void setVersion(String version) {
setNativeHeader(VERSION, version); setNativeHeader(STOMP_VERSION_HEADER, version);
} }
} }
...@@ -93,9 +93,8 @@ public class StompMessageConverter { ...@@ -93,9 +93,8 @@ public class StompMessageConverter {
byte[] payload = new byte[totalLength - payloadIndex]; byte[] payload = new byte[totalLength - payloadIndex];
System.arraycopy(byteContent, payloadIndex, payload, 0, totalLength - payloadIndex); System.arraycopy(byteContent, payloadIndex, payload, 0, totalLength - payloadIndex);
StompHeaderAccessor stompHeaders = StompHeaderAccessor.create(command, headers); StompHeaderAccessor stompHeaders = StompHeaderAccessor.create(command, headers);
return MessageBuilder.withPayload(payload).copyHeaders(stompHeaders.toMap()).build(); return MessageBuilder.withPayloadAndHeaders(payload, stompHeaders).build();
} }
private int findIndexOfPayload(byte[] bytes) { private int findIndexOfPayload(byte[] bytes) {
...@@ -140,7 +139,7 @@ public class StompMessageConverter { ...@@ -140,7 +139,7 @@ public class StompMessageConverter {
StompHeaderAccessor stompHeaders = StompHeaderAccessor.wrap(message); StompHeaderAccessor stompHeaders = StompHeaderAccessor.wrap(message);
try { try {
out.write(stompHeaders.getStompCommand().toString().getBytes("UTF-8")); out.write(stompHeaders.getCommand().toString().getBytes("UTF-8"));
out.write(LF); out.write(LF);
for (Entry<String, List<String>> entry : stompHeaders.toNativeHeaderMap().entrySet()) { for (Entry<String, List<String>> entry : stompHeaders.toNativeHeaderMap().entrySet()) {
String key = entry.getKey(); String key = entry.getKey();
......
...@@ -27,8 +27,8 @@ import org.springframework.messaging.Message; ...@@ -27,8 +27,8 @@ import org.springframework.messaging.Message;
import org.springframework.messaging.MessageChannel; import org.springframework.messaging.MessageChannel;
import org.springframework.messaging.MessageHandler; import org.springframework.messaging.MessageHandler;
import org.springframework.messaging.simp.SimpMessageType; import org.springframework.messaging.simp.SimpMessageType;
import org.springframework.messaging.simp.handler.UserDestinationMessageHandler;
import org.springframework.messaging.simp.handler.MutableUserSessionResolver; import org.springframework.messaging.simp.handler.MutableUserSessionResolver;
import org.springframework.messaging.simp.handler.UserDestinationMessageHandler;
import org.springframework.messaging.support.MessageBuilder; import org.springframework.messaging.support.MessageBuilder;
import org.springframework.web.socket.CloseStatus; import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.TextMessage; import org.springframework.web.socket.TextMessage;
...@@ -57,8 +57,6 @@ public class StompWebSocketHandler extends TextWebSocketHandlerAdapter implement ...@@ -57,8 +57,6 @@ public class StompWebSocketHandler extends TextWebSocketHandlerAdapter implement
public static final String QUEUE_SUFFIX_HEADER = "queue-suffix"; public static final String QUEUE_SUFFIX_HEADER = "queue-suffix";
private static final byte[] EMPTY_PAYLOAD = new byte[0];
private static Log logger = LogFactory.getLog(StompWebSocketHandler.class); private static Log logger = LogFactory.getLog(StompWebSocketHandler.class);
private MessageChannel clientInputChannel; private MessageChannel clientInputChannel;
...@@ -107,7 +105,7 @@ public class StompWebSocketHandler extends TextWebSocketHandlerAdapter implement ...@@ -107,7 +105,7 @@ public class StompWebSocketHandler extends TextWebSocketHandlerAdapter implement
this.sessions.put(session.getId(), session); this.sessions.put(session.getId(), session);
if ((this.userSessionStore != null) && (session.getPrincipal() != null)) { if ((this.userSessionStore != null) && (session.getPrincipal() != null)) {
this.userSessionStore.storeUserSessionId(session.getPrincipal().getName(), session.getId()); this.userSessionStore.addUserSessionId(session.getPrincipal().getName(), session.getId());
} }
} }
...@@ -120,10 +118,6 @@ public class StompWebSocketHandler extends TextWebSocketHandlerAdapter implement ...@@ -120,10 +118,6 @@ public class StompWebSocketHandler extends TextWebSocketHandlerAdapter implement
String payload = textMessage.getPayload(); String payload = textMessage.getPayload();
Message<?> message = this.stompMessageConverter.toMessage(payload); Message<?> message = this.stompMessageConverter.toMessage(payload);
StompHeaderAccessor headers = StompHeaderAccessor.wrap(message);
headers.setSessionId(session.getId());
headers.setUser(session.getPrincipal());
// TODO: validate size limits // TODO: validate size limits
// http://stomp.github.io/stomp-specification-1.2.html#Size_Limits // http://stomp.github.io/stomp-specification-1.2.html#Size_Limits
...@@ -132,14 +126,17 @@ public class StompWebSocketHandler extends TextWebSocketHandlerAdapter implement ...@@ -132,14 +126,17 @@ public class StompWebSocketHandler extends TextWebSocketHandlerAdapter implement
} }
try { try {
StompHeaderAccessor stompHeaders = StompHeaderAccessor.wrap(message); StompHeaderAccessor headers = StompHeaderAccessor.wrap(message);
SimpMessageType messageType = stompHeaders.getMessageType(); headers.setSessionId(session.getId());
if (SimpMessageType.CONNECT.equals(messageType)) { headers.setUser(session.getPrincipal());
message = MessageBuilder.withPayloadAndHeaders(message.getPayload(), headers).build();
if (SimpMessageType.CONNECT.equals(headers.getMessageType())) {
handleConnect(session, message); handleConnect(session, message);
} }
message = MessageBuilder.fromMessage(message).copyHeaders(headers.toMap()).build();
this.clientInputChannel.send(message); this.clientInputChannel.send(message);
} }
catch (Throwable t) { catch (Throwable t) {
logger.error("Terminating STOMP session due to failure to send message: ", t); logger.error("Terminating STOMP session due to failure to send message: ", t);
...@@ -182,8 +179,7 @@ public class StompWebSocketHandler extends TextWebSocketHandlerAdapter implement ...@@ -182,8 +179,7 @@ public class StompWebSocketHandler extends TextWebSocketHandlerAdapter implement
// TODO: security // TODO: security
Message<?> connectedMessage = MessageBuilder.withPayload(EMPTY_PAYLOAD).copyHeaders( Message<?> connectedMessage = MessageBuilder.withPayloadAndHeaders(new byte[0], connectedHeaders).build();
connectedHeaders.toMap()).build();
byte[] bytes = this.stompMessageConverter.fromMessage(connectedMessage); byte[] bytes = this.stompMessageConverter.fromMessage(connectedMessage);
session.sendMessage(new TextMessage(new String(bytes, Charset.forName("UTF-8")))); session.sendMessage(new TextMessage(new String(bytes, Charset.forName("UTF-8"))));
} }
...@@ -192,10 +188,8 @@ public class StompWebSocketHandler extends TextWebSocketHandlerAdapter implement ...@@ -192,10 +188,8 @@ public class StompWebSocketHandler extends TextWebSocketHandlerAdapter implement
StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.ERROR); StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.ERROR);
headers.setMessage(error.getMessage()); headers.setMessage(error.getMessage());
Message<?> message = MessageBuilder.withPayloadAndHeaders(new byte[0], headers).build();
Message<?> message = MessageBuilder.withPayload(EMPTY_PAYLOAD).copyHeaders(headers.toMap()).build();
byte[] bytes = this.stompMessageConverter.fromMessage(message); byte[] bytes = this.stompMessageConverter.fromMessage(message);
try { try {
session.sendMessage(new TextMessage(new String(bytes, Charset.forName("UTF-8")))); session.sendMessage(new TextMessage(new String(bytes, Charset.forName("UTF-8"))));
} }
...@@ -211,12 +205,12 @@ public class StompWebSocketHandler extends TextWebSocketHandlerAdapter implement ...@@ -211,12 +205,12 @@ public class StompWebSocketHandler extends TextWebSocketHandlerAdapter implement
this.sessions.remove(sessionId); this.sessions.remove(sessionId);
if ((this.userSessionStore != null) && (session.getPrincipal() != null)) { if ((this.userSessionStore != null) && (session.getPrincipal() != null)) {
this.userSessionStore.deleteUserSessionId(session.getPrincipal().getName(), sessionId); this.userSessionStore.removeUserSessionId(session.getPrincipal().getName(), sessionId);
} }
StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.DISCONNECT); StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.DISCONNECT);
headers.setSessionId(sessionId); headers.setSessionId(sessionId);
Message<?> message = MessageBuilder.withPayload(new byte[0]).copyHeaders(headers.toMap()).build(); Message<?> message = MessageBuilder.withPayloadAndHeaders(new byte[0], headers).build();
this.clientInputChannel.send(message); this.clientInputChannel.send(message);
} }
...@@ -227,9 +221,9 @@ public class StompWebSocketHandler extends TextWebSocketHandlerAdapter implement ...@@ -227,9 +221,9 @@ public class StompWebSocketHandler extends TextWebSocketHandlerAdapter implement
public void handleMessage(Message<?> message) { public void handleMessage(Message<?> message) {
StompHeaderAccessor headers = StompHeaderAccessor.wrap(message); StompHeaderAccessor headers = StompHeaderAccessor.wrap(message);
headers.setStompCommandIfNotSet(StompCommand.MESSAGE); headers.setCommandIfNotSet(StompCommand.MESSAGE);
if (StompCommand.CONNECTED.equals(headers.getStompCommand())) { if (StompCommand.CONNECTED.equals(headers.getCommand())) {
// Ignore for now since we already sent it // Ignore for now since we already sent it
return; return;
} }
...@@ -248,7 +242,7 @@ public class StompWebSocketHandler extends TextWebSocketHandlerAdapter implement ...@@ -248,7 +242,7 @@ public class StompWebSocketHandler extends TextWebSocketHandlerAdapter implement
return; return;
} }
if (StompCommand.MESSAGE.equals(headers.getStompCommand()) && (headers.getSubscriptionId() == null)) { if (StompCommand.MESSAGE.equals(headers.getCommand()) && (headers.getSubscriptionId() == null)) {
// TODO: failed message delivery mechanism // TODO: failed message delivery mechanism
logger.error("Ignoring message, no subscriptionId header: " + message); logger.error("Ignoring message, no subscriptionId header: " + message);
return; return;
...@@ -261,7 +255,7 @@ public class StompWebSocketHandler extends TextWebSocketHandlerAdapter implement ...@@ -261,7 +255,7 @@ public class StompWebSocketHandler extends TextWebSocketHandlerAdapter implement
} }
try { try {
message = MessageBuilder.fromMessage(message).copyHeaders(headers.toMap()).build(); message = MessageBuilder.withPayloadAndHeaders(message.getPayload(), headers).build();
byte[] bytes = this.stompMessageConverter.fromMessage(message); byte[] bytes = this.stompMessageConverter.fromMessage(message);
session.sendMessage(new TextMessage(new String(bytes, Charset.forName("UTF-8")))); session.sendMessage(new TextMessage(new String(bytes, Charset.forName("UTF-8"))));
} }
...@@ -269,7 +263,7 @@ public class StompWebSocketHandler extends TextWebSocketHandlerAdapter implement ...@@ -269,7 +263,7 @@ public class StompWebSocketHandler extends TextWebSocketHandlerAdapter implement
sendErrorMessage(session, t); sendErrorMessage(session, t);
} }
finally { finally {
if (StompCommand.ERROR.equals(headers.getStompCommand())) { if (StompCommand.ERROR.equals(headers.getCommand())) {
try { try {
session.close(CloseStatus.PROTOCOL_ERROR); session.close(CloseStatus.PROTOCOL_ERROR);
} }
......
...@@ -44,14 +44,14 @@ public final class MessageBuilder<T> { ...@@ -44,14 +44,14 @@ public final class MessageBuilder<T> {
/** /**
* Private constructor to be invoked from the static factory methods only. * Private constructor to be invoked from the static factory methods only.
*/ */
private MessageBuilder(T payload, Message<T> originalMessage) { private MessageBuilder(T payload, Message<T> originalMessage, MessageHeaderAccessor headerAccessor) {
Assert.notNull(payload, "payload must not be null"); Assert.notNull(payload, "payload must not be null");
this.payload = payload; this.payload = payload;
this.originalMessage = originalMessage; this.originalMessage = originalMessage;
this.headerAccessor = new MessageHeaderAccessor(originalMessage); this.headerAccessor = (headerAccessor != null) ?
headerAccessor : new MessageHeaderAccessor(originalMessage);
} }
/** /**
* Create a builder for a new {@link Message} instance pre-populated with all of the * Create a builder for a new {@link Message} instance pre-populated with all of the
* headers copied from the provided message. The payload of the provided Message will * headers copied from the provided message. The payload of the provided Message will
...@@ -61,7 +61,7 @@ public final class MessageBuilder<T> { ...@@ -61,7 +61,7 @@ public final class MessageBuilder<T> {
*/ */
public static <T> MessageBuilder<T> fromMessage(Message<T> message) { public static <T> MessageBuilder<T> fromMessage(Message<T> message) {
Assert.notNull(message, "message must not be null"); Assert.notNull(message, "message must not be null");
MessageBuilder<T> builder = new MessageBuilder<T>(message.getPayload(), message); MessageBuilder<T> builder = new MessageBuilder<T>(message.getPayload(), message, null);
return builder; return builder;
} }
...@@ -71,7 +71,18 @@ public final class MessageBuilder<T> { ...@@ -71,7 +71,18 @@ public final class MessageBuilder<T> {
* @param payload the payload for the new message * @param payload the payload for the new message
*/ */
public static <T> MessageBuilder<T> withPayload(T payload) { public static <T> MessageBuilder<T> withPayload(T payload) {
MessageBuilder<T> builder = new MessageBuilder<T>(payload, null); MessageBuilder<T> builder = new MessageBuilder<T>(payload, null, null);
return builder;
}
/**
* Create a builder for a new {@link Message} instance with the provided payload and headers.
*
* @param payload the payload for the new message
* @param headerAccessor the headers for the message
*/
public static <T> MessageBuilder<T> withPayloadAndHeaders(T payload, MessageHeaderAccessor headerAccessor) {
MessageBuilder<T> builder = new MessageBuilder<T>(payload, null, headerAccessor);
return builder; return builder;
} }
......
/*
* Copyright 2002-2013 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.messaging.simp.stomp;
import org.junit.Test;
import org.springframework.util.LinkedMultiValueMap;
import static org.junit.Assert.*;
/**
* Test fixture for {@link StompHeaderAccessor}.
*
* @author Rossen Stoyanchev
* @since 4.0
*/
public class StompHeaderAccessorTests {
@Test
public void testStompCommandSet() {
StompHeaderAccessor accessor = StompHeaderAccessor.create(StompCommand.CONNECTED);
assertEquals(StompCommand.CONNECTED, accessor.getCommand());
accessor = StompHeaderAccessor.create(StompCommand.CONNECTED, new LinkedMultiValueMap<String, String>());
assertEquals(StompCommand.CONNECTED, accessor.getCommand());
}
}
...@@ -56,17 +56,17 @@ public class StompMessageConverterTests { ...@@ -56,17 +56,17 @@ public class StompMessageConverterTests {
StompHeaderAccessor stompHeaders = StompHeaderAccessor.wrap(message); StompHeaderAccessor stompHeaders = StompHeaderAccessor.wrap(message);
Map<String, Object> map = stompHeaders.toMap(); Map<String, Object> map = stompHeaders.toMap();
assertEquals(5, map.size()); assertEquals(5, map.size());
assertNotNull(map.get(MessageHeaders.ID)); assertNotNull(stompHeaders.getId());
assertNotNull(map.get(MessageHeaders.TIMESTAMP)); assertNotNull(stompHeaders.getTimestamp());
assertEquals(SimpMessageType.CONNECT, stompHeaders.getMessageType());
assertEquals(StompCommand.CONNECT, stompHeaders.getCommand());
assertNotNull(map.get(SimpMessageHeaderAccessor.NATIVE_HEADERS)); assertNotNull(map.get(SimpMessageHeaderAccessor.NATIVE_HEADERS));
assertNotNull(map.get(SimpMessageHeaderAccessor.MESSAGE_TYPE));
assertNotNull(map.get(SimpMessageHeaderAccessor.PROTOCOL_MESSAGE_TYPE));
assertEquals(Collections.singleton("1.1"), stompHeaders.getAcceptVersion()); assertEquals(Collections.singleton("1.1"), stompHeaders.getAcceptVersion());
assertEquals("github.org", stompHeaders.getHost()); assertEquals("github.org", stompHeaders.getHost());
assertEquals(SimpMessageType.CONNECT, stompHeaders.getMessageType()); assertEquals(SimpMessageType.CONNECT, stompHeaders.getMessageType());
assertEquals(StompCommand.CONNECT, stompHeaders.getStompCommand()); assertEquals(StompCommand.CONNECT, stompHeaders.getCommand());
assertNotNull(headers.get(MessageHeaders.ID)); assertNotNull(headers.get(MessageHeaders.ID));
assertNotNull(headers.get(MessageHeaders.TIMESTAMP)); assertNotNull(headers.get(MessageHeaders.TIMESTAMP));
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册