提交 d73c2e26 编写于 作者: R Rossen Stoyanchev

Polish handling of STOMP message headers

上级 ba7998d0
......@@ -41,6 +41,10 @@ public abstract class AbstractMessageSendingTemplate<D> implements MessageSendin
this.defaultDestination = defaultDestination;
}
public D getDefaultDestination() {
return this.defaultDestination;
}
/**
* Set the {@link MessageConverter} that is to be used to convert
* between Messages and objects for this template.
......@@ -82,7 +86,7 @@ public abstract class AbstractMessageSendingTemplate<D> implements MessageSendin
this.doSend(destination, message);
}
protected abstract void doSend(D destination, Message<?> message) ;
protected abstract void doSend(D destination, Message<?> message);
@Override
......
......@@ -17,7 +17,6 @@
package org.springframework.messaging.simp;
import java.security.Principal;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
......@@ -26,7 +25,6 @@ import org.springframework.messaging.Message;
import org.springframework.messaging.MessageHeaders;
import org.springframework.messaging.support.NativeMessageHeaderAccessor;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
/**
......@@ -43,35 +41,25 @@ import org.springframework.util.CollectionUtils;
*/
public class SimpMessageHeaderAccessor extends NativeMessageHeaderAccessor {
public static final String DESTINATIONS = "destinations";
public static final String DESTINATION_HEADER = "destination";
public static final String MESSAGE_TYPE = "messageType";
public static final String MESSAGE_TYPE_HEADER = "messageType";
// TODO
public static final String PROTOCOL_MESSAGE_TYPE = "protocolMessageType";
public static final String SESSION_ID_HEADER = "sessionId";
public static final String SESSION_ID = "sessionId";
public static final String SUBSCRIPTION_ID_HEADER = "subscriptionId";
public static final String SUBSCRIPTION_ID = "subscriptionId";
public static final String USER = "user";
public static final String USER_HEADER = "user";
/**
* A constructor for creating new message headers.
* This constructor is protected. See factory methods in this and sub-classes.
*/
protected SimpMessageHeaderAccessor(SimpMessageType messageType, Object protocolMessageType,
Map<String, List<String>> externalSourceHeaders) {
protected SimpMessageHeaderAccessor(SimpMessageType messageType, Map<String, List<String>> externalSourceHeaders) {
super(externalSourceHeaders);
Assert.notNull(messageType, "messageType is required");
setHeader(MESSAGE_TYPE, messageType);
if (protocolMessageType != null) {
setHeader(PROTOCOL_MESSAGE_TYPE, protocolMessageType);
}
setHeader(MESSAGE_TYPE_HEADER, messageType);
}
/**
......@@ -89,14 +77,14 @@ public class SimpMessageHeaderAccessor extends NativeMessageHeaderAccessor {
* {@link SimpMessageType#MESSAGE}.
*/
public static SimpMessageHeaderAccessor create() {
return new SimpMessageHeaderAccessor(SimpMessageType.MESSAGE, null, null);
return new SimpMessageHeaderAccessor(SimpMessageType.MESSAGE, null);
}
/**
* Create {@link SimpMessageHeaderAccessor} for a new {@link Message} of a specific type.
*/
public static SimpMessageHeaderAccessor create(SimpMessageType messageType) {
return new SimpMessageHeaderAccessor(messageType, null, null);
return new SimpMessageHeaderAccessor(messageType, null);
}
/**
......@@ -106,39 +94,23 @@ public class SimpMessageHeaderAccessor extends NativeMessageHeaderAccessor {
return new SimpMessageHeaderAccessor(message);
}
public SimpMessageType getMessageType() {
return (SimpMessageType) getHeader(MESSAGE_TYPE);
}
protected void setProtocolMessageType(Object protocolMessageType) {
setHeader(PROTOCOL_MESSAGE_TYPE, protocolMessageType);
public void setMessageTypeIfNotSet(SimpMessageType messageType) {
if (getMessageType() == null) {
setHeader(MESSAGE_TYPE_HEADER, messageType);
}
}
protected Object getProtocolMessageType() {
return getHeader(PROTOCOL_MESSAGE_TYPE);
public SimpMessageType getMessageType() {
return (SimpMessageType) getHeader(MESSAGE_TYPE_HEADER);
}
public void setDestination(String destination) {
Assert.notNull(destination, "destination is required");
setHeader(DESTINATIONS, Arrays.asList(destination));
setHeader(DESTINATION_HEADER, destination);
}
@SuppressWarnings("unchecked")
public String getDestination() {
List<String> destinations = (List<String>) getHeader(DESTINATIONS);
return CollectionUtils.isEmpty(destinations) ? null : destinations.get(0);
}
@SuppressWarnings("unchecked")
public List<String> getDestinations() {
List<String> destinations = (List<String>) getHeader(DESTINATIONS);
return CollectionUtils.isEmpty(destinations) ? null : destinations;
}
public void setDestinations(List<String> destinations) {
Assert.notNull(destinations, "destinations are required");
setHeader(DESTINATIONS, destinations);
return (String) getHeader(DESTINATION_HEADER);
}
public MediaType getContentType() {
......@@ -150,27 +122,27 @@ public class SimpMessageHeaderAccessor extends NativeMessageHeaderAccessor {
}
public String getSubscriptionId() {
return (String) getHeader(SUBSCRIPTION_ID);
return (String) getHeader(SUBSCRIPTION_ID_HEADER);
}
public void setSubscriptionId(String subscriptionId) {
setHeader(SUBSCRIPTION_ID, subscriptionId);
setHeader(SUBSCRIPTION_ID_HEADER, subscriptionId);
}
public String getSessionId() {
return (String) getHeader(SESSION_ID);
return (String) getHeader(SESSION_ID_HEADER);
}
public void setSessionId(String sessionId) {
setHeader(SESSION_ID, sessionId);
setHeader(SESSION_ID_HEADER, sessionId);
}
public Principal getUser() {
return (Principal) getHeader(USER);
return (Principal) getHeader(USER_HEADER);
}
public void setUser(Principal principal) {
setHeader(USER, principal);
setHeader(USER_HEADER, principal);
}
}
......@@ -22,13 +22,31 @@ import org.springframework.messaging.core.MessageSendingOperations;
/**
* A specialization of {@link MessageSendingOperations} with methods for use with
* the Spring Framework support for simple messaging protocols (like STOMP).
*
* @author Rossen Stoyanchev
* @since 4.0
*/
public interface SimpMessageSendingOperations extends MessageSendingOperations<String> {
/**
* Send a message to a specific user.
*
* @param user the user that should receive the message.
* @param destination the destination to send the message to.
* @param message the message to send
*/
<T> void convertAndSendToUser(String user, String destination, T message) throws MessagingException;
/**
* Send a message to a specific user.
*
* @param user the user that should receive the message.
* @param destination the destination to send the message to.
* @param message the message to send
* @param postProcessor a postProcessor to post-process or modify the created message
*/
<T> void convertAndSendToUser(String user, String destination, T message, MessagePostProcessor postProcessor)
throws MessagingException;
......
......@@ -15,8 +15,6 @@
*/
package org.springframework.messaging.simp;
import java.util.Arrays;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageChannel;
import org.springframework.messaging.MessageDeliveryException;
......@@ -56,6 +54,7 @@ public class SimpMessagingTemplate extends AbstractMessageSendingTemplate<String
* @see org.springframework.messaging.simp.handler.UserDestinationMessageHandler
*/
public void setUserDestinationPrefix(String prefix) {
Assert.notNull(prefix, "userDestinationPrefix is required");
this.userDestinationPrefix = prefix;
}
......@@ -92,31 +91,32 @@ public class SimpMessagingTemplate extends AbstractMessageSendingTemplate<String
@Override
public <P> void send(Message<P> message) {
// TODO: maybe look up destination of current message (via ThreadLocal)
this.send(getRequiredDefaultDestination(), message);
SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.wrap(message);
String destination = headers.getDestination();
destination = (destination != null) ? destination : getRequiredDefaultDestination();
doSend(getRequiredDefaultDestination(), message);
}
@Override
protected void doSend(String destination, Message<?> message) {
Assert.notNull(destination, "destination is required");
message = updateMessageHeaders(message, destination);
SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.wrap(message);
headers.setDestination(destination);
headers.setMessageTypeIfNotSet(SimpMessageType.MESSAGE);
message = MessageBuilder.withPayloadAndHeaders(message.getPayload(), headers).build();
long timeout = this.sendTimeout;
boolean sent = (timeout >= 0)
? this.messageChannel.send(message, timeout)
: this.messageChannel.send(message);
if (!sent) {
throw new MessageDeliveryException(message,
"failed to send message to destination '" + destination + "' within timeout: " + timeout);
}
}
protected <P> Message<P> updateMessageHeaders(Message<P> message, String destination) {
Assert.notNull(destination, "destination is required");
return MessageBuilder.fromMessage(message)
.setHeader(SimpMessageHeaderAccessor.MESSAGE_TYPE, SimpMessageType.MESSAGE)
.setHeader(SimpMessageHeaderAccessor.DESTINATIONS, Arrays.asList(destination)).build();
}
@Override
public <T> void convertAndSendToUser(String user, String destination, T message) throws MessagingException {
convertAndSendToUser(user, destination, message, null);
......
......@@ -16,8 +16,6 @@
package org.springframework.messaging.simp.annotation.support;
import java.security.Principal;
import org.springframework.core.MethodParameter;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageChannel;
......@@ -69,7 +67,10 @@ public class ReplyToMethodReturnValueHandler implements HandlerMethodReturnValue
return;
}
MessagePostProcessor postProcessor = new SessionHeaderPostProcessor(inputMessage);
SimpMessageHeaderAccessor inputHeaders = SimpMessageHeaderAccessor.wrap(inputMessage);
String sessionId = inputHeaders.getSessionId();
MessagePostProcessor postProcessor = new SessionHeaderPostProcessor(sessionId);
ReplyTo replyTo = returnType.getMethodAnnotation(ReplyTo.class);
if (replyTo != null) {
......@@ -80,37 +81,30 @@ public class ReplyToMethodReturnValueHandler implements HandlerMethodReturnValue
ReplyToUser replyToUser = returnType.getMethodAnnotation(ReplyToUser.class);
if (replyToUser != null) {
String user = getUser(inputMessage).getName();
if (inputHeaders.getUser() == null) {
throw new MissingSessionUserException(inputMessage);
}
String user = inputHeaders.getUser().getName();
for (String destination : replyToUser.value()) {
this.messagingTemplate.convertAndSendToUser(user, destination, returnValue, postProcessor);
}
}
}
private Principal getUser(Message<?> inputMessage) {
SimpMessageHeaderAccessor inputHeaders = SimpMessageHeaderAccessor.wrap(inputMessage);
Principal user = inputHeaders.getUser();
if (user == null) {
throw new MissingSessionUserException(inputMessage);
}
return user;
}
private final class SessionHeaderPostProcessor implements MessagePostProcessor {
private final Message<?> inputMessage;
private final String sessionId;
public SessionHeaderPostProcessor(Message<?> inputMessage) {
this.inputMessage = inputMessage;
public SessionHeaderPostProcessor(String sessionId) {
this.sessionId = sessionId;
}
@Override
public Message<?> postProcessMessage(Message<?> message) {
String headerName = SimpMessageHeaderAccessor.SESSION_ID;
String sessionId = (String) this.inputMessage.getHeaders().get(headerName);
return MessageBuilder.fromMessage(message).setHeader(headerName, sessionId).build();
SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.wrap(message);
headers.setSessionId(this.sessionId);
return MessageBuilder.withPayloadAndHeaders(message.getPayload(), headers).build();
}
}
}
......@@ -68,32 +68,37 @@ public class SubscriptionMethodReturnValueHandler implements HandlerMethodReturn
}
SimpMessageHeaderAccessor inputHeaders = SimpMessageHeaderAccessor.wrap(message);
String sessionId = inputHeaders.getSessionId();
String subscriptionId = inputHeaders.getSubscriptionId();
String destination = inputHeaders.getDestination();
Assert.state(inputHeaders.getSubscriptionId() != null,
"No subsriptiondId in input message. Add @ReplyTo or @ReplyToUser to method: "
+ returnType.getMethod());
MessagePostProcessor postProcessor = new SubscriptionHeaderPostProcessor(inputHeaders);
MessagePostProcessor postProcessor = new SubscriptionHeaderPostProcessor(sessionId, subscriptionId);
this.messagingTemplate.convertAndSend(destination, returnValue, postProcessor);
}
private final class SubscriptionHeaderPostProcessor implements MessagePostProcessor {
private final SimpMessageHeaderAccessor inputHeaders;
private final String sessionId;
private final String subscriptionId;
public SubscriptionHeaderPostProcessor(SimpMessageHeaderAccessor inputHeaders) {
this.inputHeaders = inputHeaders;
public SubscriptionHeaderPostProcessor(String sessionId, String subscriptionId) {
this.sessionId = sessionId;
this.subscriptionId = subscriptionId;
}
@Override
public Message<?> postProcessMessage(Message<?> message) {
return MessageBuilder.fromMessage(message)
.setHeader(SimpMessageHeaderAccessor.SESSION_ID, this.inputHeaders.getSessionId())
.setHeader(SimpMessageHeaderAccessor.SUBSCRIPTION_ID, this.inputHeaders.getSubscriptionId())
.build();
SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.wrap(message);
headers.setSessionId(this.sessionId);
headers.setSubscriptionId(this.subscriptionId);
return MessageBuilder.withPayloadAndHeaders(message.getPayload(), headers).build();
}
}
}
......@@ -23,8 +23,8 @@ package org.springframework.messaging.simp.handler;
*/
public interface MutableUserSessionResolver extends UserSessionResolver {
void storeUserSessionId(String user, String sessionId);
void addUserSessionId(String user, String sessionId);
void deleteUserSessionId(String user, String sessionId);
void removeUserSessionId(String user, String sessionId);
}
......@@ -94,8 +94,8 @@ public class SimpleBrokerMessageHandler implements MessageHandler {
headers.setSessionId(sessionId);
headers.setSubscriptionId(subscriptionId);
Message<?> clientMessage = MessageBuilder.withPayload(
message.getPayload()).copyHeaders(headers.toMap()).build();
Object payload = message.getPayload();
Message<?> clientMessage = MessageBuilder.withPayloadAndHeaders(payload, headers).build();
try {
this.outboundChannel.send(clientMessage);
}
......
......@@ -34,7 +34,7 @@ public class SimpleUserSessionResolver implements MutableUserSessionResolver {
@Override
public void storeUserSessionId(String user, String sessionId) {
public void addUserSessionId(String user, String sessionId) {
Set<String> sessionIds = this.userSessionIds.get(user);
if (sessionIds == null) {
sessionIds = new CopyOnWriteArraySet<String>();
......@@ -44,7 +44,7 @@ public class SimpleUserSessionResolver implements MutableUserSessionResolver {
}
@Override
public void deleteUserSessionId(String user, String sessionId) {
public void removeUserSessionId(String user, String sessionId) {
Set<String> sessionIds = this.userSessionIds.get(user);
if (sessionIds != null) {
if (sessionIds.remove(sessionId) && sessionIds.isEmpty()) {
......
......@@ -120,7 +120,7 @@ public class UserDestinationMessageHandler implements MessageHandler {
String targetDestination = destinationParser.getTargetDestination(sessionId);
headers.setDestination(targetDestination);
message = MessageBuilder.fromMessage(message).copyHeaders(headers.toMap()).build();
message = MessageBuilder.withPayloadAndHeaders(message.getPayload(), headers).build();
if (logger.isTraceEnabled()) {
logger.trace("Sending message to resolved target destination " + targetDestination);
......
......@@ -219,7 +219,7 @@ public class StompBrokerRelayMessageHandler implements MessageHandler, SmartLife
if (logger.isDebugEnabled()) {
logger.debug("Sending STOMP CONNECT frame to initialize \"system\" TCP connection");
}
Message<?> message = MessageBuilder.withPayload(new byte[0]).copyHeaders(headers.toMap()).build();
Message<?> message = MessageBuilder.withPayloadAndHeaders(new byte[0], headers).build();
session.open(message);
}
......@@ -259,7 +259,7 @@ public class StompBrokerRelayMessageHandler implements MessageHandler, SmartLife
StompHeaderAccessor headers = StompHeaderAccessor.wrap(message);
String sessionId = headers.getSessionId();
String destination = headers.getDestination();
StompCommand command = headers.getStompCommand();
StompCommand command = headers.getCommand();
SimpMessageType messageType = headers.getMessageType();
if (!this.running) {
......@@ -273,11 +273,11 @@ public class StompBrokerRelayMessageHandler implements MessageHandler, SmartLife
sessionId = (sessionId == null) ? STOMP_RELAY_SYSTEM_SESSION_ID : sessionId;
headers.setSessionId(sessionId);
command = (command == null) ? StompCommand.SEND : command;
headers.setStompCommandIfNotSet(command);
message = MessageBuilder.fromMessage(message).copyHeaders(headers.toMap()).build();
headers.setCommandIfNotSet(command);
message = MessageBuilder.withPayloadAndHeaders(message.getPayload(), headers).build();
}
if (headers.getStompCommand() == null) {
if (headers.getCommand() == null) {
logger.error("Ignoring message, no STOMP command: " + message);
return;
}
......@@ -397,7 +397,7 @@ public class StompBrokerRelayMessageHandler implements MessageHandler, SmartLife
}
StompHeaderAccessor headers = StompHeaderAccessor.wrap(message);
if (StompCommand.CONNECTED == headers.getStompCommand()) {
if (StompCommand.CONNECTED == headers.getCommand()) {
synchronized(this.monitor) {
this.isConnected = true;
flushMessages(this.promise.get());
......@@ -406,7 +406,7 @@ public class StompBrokerRelayMessageHandler implements MessageHandler, SmartLife
}
headers.setSessionId(this.sessionId);
message = MessageBuilder.fromMessage(message).copyHeaders(headers.toMap()).build();
message = MessageBuilder.withPayloadAndHeaders(message.getPayload(), headers).build();
sendMessageToClient(message);
}
......@@ -418,7 +418,7 @@ public class StompBrokerRelayMessageHandler implements MessageHandler, SmartLife
StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.ERROR);
headers.setSessionId(sessionId);
headers.setMessage(errorText);
Message<?> errorMessage = MessageBuilder.withPayload(new byte[0]).copyHeaders(headers.toMap()).build();
Message<?> errorMessage = MessageBuilder.withPayloadAndHeaders(new byte[0], headers).build();
sendMessageToClient(errorMessage);
}
......
......@@ -44,37 +44,44 @@ import org.springframework.util.StringUtils;
*/
public class StompHeaderAccessor extends SimpMessageHeaderAccessor {
public static final String STOMP_ID = "id";
// STOMP header names
public static final String HOST = "host";
public static final String STOMP_ID_HEADER = "id";
public static final String ACCEPT_VERSION = "accept-version";
public static final String STOMP_HOST_HEADER = "host";
public static final String MESSAGE_ID = "message-id";
public static final String STOMP_ACCEPT_VERSION_HEADER = "accept-version";
public static final String RECEIPT_ID = "receipt-id";
public static final String STOMP_MESSAGE_ID_HEADER = "message-id";
public static final String SUBSCRIPTION = "subscription";
public static final String STOMP_RECEIPT_ID_HEADER = "receipt-id";
public static final String VERSION = "version";
public static final String STOMP_SUBSCRIPTION_HEADER = "subscription";
public static final String MESSAGE = "message";
public static final String STOMP_VERSION_HEADER = "version";
public static final String ACK = "ack";
public static final String STOMP_MESSAGE_HEADER = "message";
public static final String NACK = "nack";
public static final String STOMP_ACK_HEADER = "ack";
public static final String LOGIN = "login";
public static final String STOMP_NACK_HEADER = "nack";
public static final String PASSCODE = "passcode";
public static final String STOMP_LOGIN_HEADER = "login";
public static final String DESTINATION = "destination";
public static final String STOMP_PASSCODE_HEADER = "passcode";
public static final String CONTENT_TYPE = "content-type";
public static final String STOMP_DESTINATION_HEADER = "destination";
public static final String CONTENT_LENGTH = "content-length";
public static final String STOMP_CONTENT_TYPE_HEADER = "content-type";
public static final String HEARTBEAT = "heart-beat";
public static final String STOMP_CONTENT_LENGTH_HEADER = "content-length";
public static final String STOMP_HEARTBEAT_HEADER = "heart-beat";
// Other header names
public static final String COMMAND_HEADER = "stompCommand";
private static final AtomicLong messageIdCounter = new AtomicLong();
......@@ -84,30 +91,37 @@ public class StompHeaderAccessor extends SimpMessageHeaderAccessor {
* A constructor for creating new STOMP message headers.
*/
private StompHeaderAccessor(StompCommand command, Map<String, List<String>> externalSourceHeaders) {
super(command.getMessageType(), command, externalSourceHeaders);
super(command.getMessageType(), externalSourceHeaders);
Assert.notNull(command, "command is required");
setHeader(COMMAND_HEADER, command);
if (externalSourceHeaders != null) {
setSimpMessageHeaders(externalSourceHeaders);
setSimpMessageHeaders(command, externalSourceHeaders);
}
}
private void setSimpMessageHeaders(Map<String, List<String>> extHeaders) {
List<String> values = extHeaders.get(StompHeaderAccessor.DESTINATION);
private void setSimpMessageHeaders(StompCommand command, Map<String, List<String>> extHeaders) {
List<String> values = extHeaders.get(StompHeaderAccessor.STOMP_DESTINATION_HEADER);
if (!CollectionUtils.isEmpty(values)) {
super.setDestination(values.get(0));
}
values = extHeaders.get(StompHeaderAccessor.CONTENT_TYPE);
values = extHeaders.get(StompHeaderAccessor.STOMP_CONTENT_TYPE_HEADER);
if (!CollectionUtils.isEmpty(values)) {
super.setContentType(MediaType.parseMediaType(values.get(0)));
}
StompCommand command = getStompCommand();
if (StompCommand.SUBSCRIBE.equals(command) || StompCommand.UNSUBSCRIBE.equals(command)) {
values = extHeaders.get(StompHeaderAccessor.STOMP_ID);
values = extHeaders.get(StompHeaderAccessor.STOMP_ID_HEADER);
if (!CollectionUtils.isEmpty(values)) {
super.setSubscriptionId(values.get(0));
}
}
else if (StompCommand.MESSAGE.equals(command)) {
values = extHeaders.get(StompHeaderAccessor.SUBSCRIPTION);
values = extHeaders.get(StompHeaderAccessor.STOMP_SUBSCRIPTION_HEADER);
if (!CollectionUtils.isEmpty(values)) {
super.setSubscriptionId(values.get(0));
}
......@@ -154,73 +168,66 @@ public class StompHeaderAccessor extends SimpMessageHeaderAccessor {
String destination = super.getDestination();
if (destination != null) {
result.put(DESTINATION, Arrays.asList(destination));
result.put(STOMP_DESTINATION_HEADER, Arrays.asList(destination));
}
MediaType contentType = getContentType();
if (contentType != null) {
result.put(CONTENT_TYPE, Arrays.asList(contentType.toString()));
result.put(STOMP_CONTENT_TYPE_HEADER, Arrays.asList(contentType.toString()));
}
if (StompCommand.MESSAGE.equals(getStompCommand())) {
if (StompCommand.MESSAGE.equals(getCommand())) {
String subscriptionId = getSubscriptionId();
if (subscriptionId != null) {
result.put(SUBSCRIPTION, Arrays.asList(subscriptionId));
result.put(STOMP_SUBSCRIPTION_HEADER, Arrays.asList(subscriptionId));
}
else {
logger.warn("STOMP MESSAGE frame should have a subscription: " + this.toString());
}
if ((getMessageId() == null)) {
String messageId = getSessionId() + "-" + messageIdCounter.getAndIncrement();
result.put(MESSAGE_ID, Arrays.asList(messageId));
result.put(STOMP_MESSAGE_ID_HEADER, Arrays.asList(messageId));
}
}
return result;
}
public void setStompCommandIfNotSet(StompCommand command) {
if (getStompCommand() == null) {
setProtocolMessageType(command);
public void setCommandIfNotSet(StompCommand command) {
if (getCommand() == null) {
setHeader(COMMAND_HEADER, command);
}
}
public StompCommand getStompCommand() {
return (StompCommand) super.getProtocolMessageType();
public StompCommand getCommand() {
return (StompCommand) getHeader(COMMAND_HEADER);
}
public Set<String> getAcceptVersion() {
String rawValue = getFirstNativeHeader(ACCEPT_VERSION);
String rawValue = getFirstNativeHeader(STOMP_ACCEPT_VERSION_HEADER);
return (rawValue != null) ? StringUtils.commaDelimitedListToSet(rawValue) : Collections.<String>emptySet();
}
public void setAcceptVersion(String acceptVersion) {
setNativeHeader(ACCEPT_VERSION, acceptVersion);
setNativeHeader(STOMP_ACCEPT_VERSION_HEADER, acceptVersion);
}
public void setHost(String host) {
setNativeHeader(HOST, host);
setNativeHeader(STOMP_HOST_HEADER, host);
}
public String getHost() {
return getFirstNativeHeader(HOST);
return getFirstNativeHeader(STOMP_HOST_HEADER);
}
@Override
public void setDestination(String destination) {
super.setDestination(destination);
setNativeHeader(DESTINATION, destination);
}
@Override
public void setDestinations(List<String> destinations) {
Assert.isTrue((destinations != null) && (destinations.size() == 1), "STOMP allows one destination per message");
super.setDestinations(destinations);
setNativeHeader(DESTINATION, destinations.get(0));
setNativeHeader(STOMP_DESTINATION_HEADER, destination);
}
public long[] getHeartbeat() {
String rawValue = getFirstNativeHeader(HEARTBEAT);
String rawValue = getFirstNativeHeader(STOMP_HEARTBEAT_HEADER);
if (!StringUtils.hasText(rawValue)) {
return null;
}
......@@ -232,91 +239,91 @@ public class StompHeaderAccessor extends SimpMessageHeaderAccessor {
public void setContentType(MediaType mediaType) {
if (mediaType != null) {
super.setContentType(mediaType);
setNativeHeader(CONTENT_TYPE, mediaType.toString());
setNativeHeader(STOMP_CONTENT_TYPE_HEADER, mediaType.toString());
}
}
public MediaType getContentType() {
String value = getFirstNativeHeader(CONTENT_TYPE);
String value = getFirstNativeHeader(STOMP_CONTENT_TYPE_HEADER);
return (value != null) ? MediaType.parseMediaType(value) : null;
}
public Integer getContentLength() {
String contentLength = getFirstNativeHeader(CONTENT_LENGTH);
String contentLength = getFirstNativeHeader(STOMP_CONTENT_LENGTH_HEADER);
return StringUtils.hasText(contentLength) ? new Integer(contentLength) : null;
}
public void setContentLength(int contentLength) {
setNativeHeader(CONTENT_LENGTH, String.valueOf(contentLength));
setNativeHeader(STOMP_CONTENT_LENGTH_HEADER, String.valueOf(contentLength));
}
public void setHeartbeat(long cx, long cy) {
setNativeHeader(HEARTBEAT, StringUtils.arrayToCommaDelimitedString(new Object[] {cx, cy}));
setNativeHeader(STOMP_HEARTBEAT_HEADER, StringUtils.arrayToCommaDelimitedString(new Object[] {cx, cy}));
}
public void setAck(String ack) {
setNativeHeader(ACK, ack);
setNativeHeader(STOMP_ACK_HEADER, ack);
}
public String getAck() {
return getFirstNativeHeader(ACK);
return getFirstNativeHeader(STOMP_ACK_HEADER);
}
public void setNack(String nack) {
setNativeHeader(NACK, nack);
setNativeHeader(STOMP_NACK_HEADER, nack);
}
public String getNack() {
return getFirstNativeHeader(NACK);
return getFirstNativeHeader(STOMP_NACK_HEADER);
}
public void setLogin(String login) {
setNativeHeader(LOGIN, login);
setNativeHeader(STOMP_LOGIN_HEADER, login);
}
public String getLogin() {
return getFirstNativeHeader(LOGIN);
return getFirstNativeHeader(STOMP_LOGIN_HEADER);
}
public void setPasscode(String passcode) {
setNativeHeader(PASSCODE, passcode);
setNativeHeader(STOMP_PASSCODE_HEADER, passcode);
}
public String getPasscode() {
return getFirstNativeHeader(PASSCODE);
return getFirstNativeHeader(STOMP_PASSCODE_HEADER);
}
public void setReceiptId(String receiptId) {
setNativeHeader(RECEIPT_ID, receiptId);
setNativeHeader(STOMP_RECEIPT_ID_HEADER, receiptId);
}
public String getReceiptId() {
return getFirstNativeHeader(RECEIPT_ID);
return getFirstNativeHeader(STOMP_RECEIPT_ID_HEADER);
}
public String getMessage() {
return getFirstNativeHeader(MESSAGE);
return getFirstNativeHeader(STOMP_MESSAGE_HEADER);
}
public void setMessage(String content) {
setNativeHeader(MESSAGE, content);
setNativeHeader(STOMP_MESSAGE_HEADER, content);
}
public String getMessageId() {
return getFirstNativeHeader(MESSAGE_ID);
return getFirstNativeHeader(STOMP_MESSAGE_ID_HEADER);
}
public void setMessageId(String id) {
setNativeHeader(MESSAGE_ID, id);
setNativeHeader(STOMP_MESSAGE_ID_HEADER, id);
}
public String getVersion() {
return getFirstNativeHeader(VERSION);
return getFirstNativeHeader(STOMP_VERSION_HEADER);
}
public void setVersion(String version) {
setNativeHeader(VERSION, version);
setNativeHeader(STOMP_VERSION_HEADER, version);
}
}
......@@ -93,9 +93,8 @@ public class StompMessageConverter {
byte[] payload = new byte[totalLength - payloadIndex];
System.arraycopy(byteContent, payloadIndex, payload, 0, totalLength - payloadIndex);
StompHeaderAccessor stompHeaders = StompHeaderAccessor.create(command, headers);
return MessageBuilder.withPayload(payload).copyHeaders(stompHeaders.toMap()).build();
return MessageBuilder.withPayloadAndHeaders(payload, stompHeaders).build();
}
private int findIndexOfPayload(byte[] bytes) {
......@@ -140,7 +139,7 @@ public class StompMessageConverter {
StompHeaderAccessor stompHeaders = StompHeaderAccessor.wrap(message);
try {
out.write(stompHeaders.getStompCommand().toString().getBytes("UTF-8"));
out.write(stompHeaders.getCommand().toString().getBytes("UTF-8"));
out.write(LF);
for (Entry<String, List<String>> entry : stompHeaders.toNativeHeaderMap().entrySet()) {
String key = entry.getKey();
......
......@@ -27,8 +27,8 @@ import org.springframework.messaging.Message;
import org.springframework.messaging.MessageChannel;
import org.springframework.messaging.MessageHandler;
import org.springframework.messaging.simp.SimpMessageType;
import org.springframework.messaging.simp.handler.UserDestinationMessageHandler;
import org.springframework.messaging.simp.handler.MutableUserSessionResolver;
import org.springframework.messaging.simp.handler.UserDestinationMessageHandler;
import org.springframework.messaging.support.MessageBuilder;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.TextMessage;
......@@ -57,8 +57,6 @@ public class StompWebSocketHandler extends TextWebSocketHandlerAdapter implement
public static final String QUEUE_SUFFIX_HEADER = "queue-suffix";
private static final byte[] EMPTY_PAYLOAD = new byte[0];
private static Log logger = LogFactory.getLog(StompWebSocketHandler.class);
private MessageChannel clientInputChannel;
......@@ -107,7 +105,7 @@ public class StompWebSocketHandler extends TextWebSocketHandlerAdapter implement
this.sessions.put(session.getId(), session);
if ((this.userSessionStore != null) && (session.getPrincipal() != null)) {
this.userSessionStore.storeUserSessionId(session.getPrincipal().getName(), session.getId());
this.userSessionStore.addUserSessionId(session.getPrincipal().getName(), session.getId());
}
}
......@@ -120,10 +118,6 @@ public class StompWebSocketHandler extends TextWebSocketHandlerAdapter implement
String payload = textMessage.getPayload();
Message<?> message = this.stompMessageConverter.toMessage(payload);
StompHeaderAccessor headers = StompHeaderAccessor.wrap(message);
headers.setSessionId(session.getId());
headers.setUser(session.getPrincipal());
// TODO: validate size limits
// http://stomp.github.io/stomp-specification-1.2.html#Size_Limits
......@@ -132,14 +126,17 @@ public class StompWebSocketHandler extends TextWebSocketHandlerAdapter implement
}
try {
StompHeaderAccessor stompHeaders = StompHeaderAccessor.wrap(message);
SimpMessageType messageType = stompHeaders.getMessageType();
if (SimpMessageType.CONNECT.equals(messageType)) {
StompHeaderAccessor headers = StompHeaderAccessor.wrap(message);
headers.setSessionId(session.getId());
headers.setUser(session.getPrincipal());
message = MessageBuilder.withPayloadAndHeaders(message.getPayload(), headers).build();
if (SimpMessageType.CONNECT.equals(headers.getMessageType())) {
handleConnect(session, message);
}
message = MessageBuilder.fromMessage(message).copyHeaders(headers.toMap()).build();
this.clientInputChannel.send(message);
}
catch (Throwable t) {
logger.error("Terminating STOMP session due to failure to send message: ", t);
......@@ -182,8 +179,7 @@ public class StompWebSocketHandler extends TextWebSocketHandlerAdapter implement
// TODO: security
Message<?> connectedMessage = MessageBuilder.withPayload(EMPTY_PAYLOAD).copyHeaders(
connectedHeaders.toMap()).build();
Message<?> connectedMessage = MessageBuilder.withPayloadAndHeaders(new byte[0], connectedHeaders).build();
byte[] bytes = this.stompMessageConverter.fromMessage(connectedMessage);
session.sendMessage(new TextMessage(new String(bytes, Charset.forName("UTF-8"))));
}
......@@ -192,10 +188,8 @@ public class StompWebSocketHandler extends TextWebSocketHandlerAdapter implement
StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.ERROR);
headers.setMessage(error.getMessage());
Message<?> message = MessageBuilder.withPayload(EMPTY_PAYLOAD).copyHeaders(headers.toMap()).build();
Message<?> message = MessageBuilder.withPayloadAndHeaders(new byte[0], headers).build();
byte[] bytes = this.stompMessageConverter.fromMessage(message);
try {
session.sendMessage(new TextMessage(new String(bytes, Charset.forName("UTF-8"))));
}
......@@ -211,12 +205,12 @@ public class StompWebSocketHandler extends TextWebSocketHandlerAdapter implement
this.sessions.remove(sessionId);
if ((this.userSessionStore != null) && (session.getPrincipal() != null)) {
this.userSessionStore.deleteUserSessionId(session.getPrincipal().getName(), sessionId);
this.userSessionStore.removeUserSessionId(session.getPrincipal().getName(), sessionId);
}
StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.DISCONNECT);
headers.setSessionId(sessionId);
Message<?> message = MessageBuilder.withPayload(new byte[0]).copyHeaders(headers.toMap()).build();
Message<?> message = MessageBuilder.withPayloadAndHeaders(new byte[0], headers).build();
this.clientInputChannel.send(message);
}
......@@ -227,9 +221,9 @@ public class StompWebSocketHandler extends TextWebSocketHandlerAdapter implement
public void handleMessage(Message<?> message) {
StompHeaderAccessor headers = StompHeaderAccessor.wrap(message);
headers.setStompCommandIfNotSet(StompCommand.MESSAGE);
headers.setCommandIfNotSet(StompCommand.MESSAGE);
if (StompCommand.CONNECTED.equals(headers.getStompCommand())) {
if (StompCommand.CONNECTED.equals(headers.getCommand())) {
// Ignore for now since we already sent it
return;
}
......@@ -248,7 +242,7 @@ public class StompWebSocketHandler extends TextWebSocketHandlerAdapter implement
return;
}
if (StompCommand.MESSAGE.equals(headers.getStompCommand()) && (headers.getSubscriptionId() == null)) {
if (StompCommand.MESSAGE.equals(headers.getCommand()) && (headers.getSubscriptionId() == null)) {
// TODO: failed message delivery mechanism
logger.error("Ignoring message, no subscriptionId header: " + message);
return;
......@@ -261,7 +255,7 @@ public class StompWebSocketHandler extends TextWebSocketHandlerAdapter implement
}
try {
message = MessageBuilder.fromMessage(message).copyHeaders(headers.toMap()).build();
message = MessageBuilder.withPayloadAndHeaders(message.getPayload(), headers).build();
byte[] bytes = this.stompMessageConverter.fromMessage(message);
session.sendMessage(new TextMessage(new String(bytes, Charset.forName("UTF-8"))));
}
......@@ -269,7 +263,7 @@ public class StompWebSocketHandler extends TextWebSocketHandlerAdapter implement
sendErrorMessage(session, t);
}
finally {
if (StompCommand.ERROR.equals(headers.getStompCommand())) {
if (StompCommand.ERROR.equals(headers.getCommand())) {
try {
session.close(CloseStatus.PROTOCOL_ERROR);
}
......
......@@ -44,14 +44,14 @@ public final class MessageBuilder<T> {
/**
* Private constructor to be invoked from the static factory methods only.
*/
private MessageBuilder(T payload, Message<T> originalMessage) {
private MessageBuilder(T payload, Message<T> originalMessage, MessageHeaderAccessor headerAccessor) {
Assert.notNull(payload, "payload must not be null");
this.payload = payload;
this.originalMessage = originalMessage;
this.headerAccessor = new MessageHeaderAccessor(originalMessage);
this.headerAccessor = (headerAccessor != null) ?
headerAccessor : new MessageHeaderAccessor(originalMessage);
}
/**
* Create a builder for a new {@link Message} instance pre-populated with all of the
* headers copied from the provided message. The payload of the provided Message will
......@@ -61,7 +61,7 @@ public final class MessageBuilder<T> {
*/
public static <T> MessageBuilder<T> fromMessage(Message<T> message) {
Assert.notNull(message, "message must not be null");
MessageBuilder<T> builder = new MessageBuilder<T>(message.getPayload(), message);
MessageBuilder<T> builder = new MessageBuilder<T>(message.getPayload(), message, null);
return builder;
}
......@@ -71,7 +71,18 @@ public final class MessageBuilder<T> {
* @param payload the payload for the new message
*/
public static <T> MessageBuilder<T> withPayload(T payload) {
MessageBuilder<T> builder = new MessageBuilder<T>(payload, null);
MessageBuilder<T> builder = new MessageBuilder<T>(payload, null, null);
return builder;
}
/**
* Create a builder for a new {@link Message} instance with the provided payload and headers.
*
* @param payload the payload for the new message
* @param headerAccessor the headers for the message
*/
public static <T> MessageBuilder<T> withPayloadAndHeaders(T payload, MessageHeaderAccessor headerAccessor) {
MessageBuilder<T> builder = new MessageBuilder<T>(payload, null, headerAccessor);
return builder;
}
......
/*
* Copyright 2002-2013 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.messaging.simp.stomp;
import org.junit.Test;
import org.springframework.util.LinkedMultiValueMap;
import static org.junit.Assert.*;
/**
* Test fixture for {@link StompHeaderAccessor}.
*
* @author Rossen Stoyanchev
* @since 4.0
*/
public class StompHeaderAccessorTests {
@Test
public void testStompCommandSet() {
StompHeaderAccessor accessor = StompHeaderAccessor.create(StompCommand.CONNECTED);
assertEquals(StompCommand.CONNECTED, accessor.getCommand());
accessor = StompHeaderAccessor.create(StompCommand.CONNECTED, new LinkedMultiValueMap<String, String>());
assertEquals(StompCommand.CONNECTED, accessor.getCommand());
}
}
......@@ -56,17 +56,17 @@ public class StompMessageConverterTests {
StompHeaderAccessor stompHeaders = StompHeaderAccessor.wrap(message);
Map<String, Object> map = stompHeaders.toMap();
assertEquals(5, map.size());
assertNotNull(map.get(MessageHeaders.ID));
assertNotNull(map.get(MessageHeaders.TIMESTAMP));
assertNotNull(stompHeaders.getId());
assertNotNull(stompHeaders.getTimestamp());
assertEquals(SimpMessageType.CONNECT, stompHeaders.getMessageType());
assertEquals(StompCommand.CONNECT, stompHeaders.getCommand());
assertNotNull(map.get(SimpMessageHeaderAccessor.NATIVE_HEADERS));
assertNotNull(map.get(SimpMessageHeaderAccessor.MESSAGE_TYPE));
assertNotNull(map.get(SimpMessageHeaderAccessor.PROTOCOL_MESSAGE_TYPE));
assertEquals(Collections.singleton("1.1"), stompHeaders.getAcceptVersion());
assertEquals("github.org", stompHeaders.getHost());
assertEquals(SimpMessageType.CONNECT, stompHeaders.getMessageType());
assertEquals(StompCommand.CONNECT, stompHeaders.getStompCommand());
assertEquals(StompCommand.CONNECT, stompHeaders.getCommand());
assertNotNull(headers.get(MessageHeaders.ID));
assertNotNull(headers.get(MessageHeaders.TIMESTAMP));
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册