提交 210be9cd 编写于 作者: R Rossen Stoyanchev

Add PrincipalMessageArgumentResolver

上级 d3cecfc6
/*
* Copyright 2002-2013 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.messaging.handler.method;
import org.springframework.core.MethodParameter;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessagingException;
/**
*
* @author Rossen Stoyanchev
* @since 4.0
*/
public class InvalidMessageMethodParameterException extends MessagingException {
private static final long serialVersionUID = -6905878930083523161L;
private final MethodParameter parameter;
public InvalidMessageMethodParameterException(Message<?> message, String description,
MethodParameter parameter, Throwable cause) {
super(message, description, cause);
this.parameter = parameter;
}
public InvalidMessageMethodParameterException(Message<?> message, String description,
MethodParameter parameter) {
super(message, description);
this.parameter = parameter;
}
public MethodParameter getParameter() {
return this.parameter;
}
}
......@@ -16,12 +16,14 @@
package org.springframework.messaging.simp;
import java.security.Principal;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import org.springframework.http.MediaType;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageHeaders;
import org.springframework.messaging.support.NativeMessageHeaderAccessor;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
......@@ -43,9 +45,6 @@ public class SimpMessageHeaderAccessor extends NativeMessageHeaderAccessor {
public static final String DESTINATIONS = "destinations";
// TODO
public static final String CONTENT_TYPE = "contentType";
public static final String MESSAGE_TYPE = "messageType";
public static final String PROTOCOL_MESSAGE_TYPE = "protocolMessageType";
......@@ -54,6 +53,8 @@ public class SimpMessageHeaderAccessor extends NativeMessageHeaderAccessor {
public static final String SUBSCRIPTION_ID = "subscriptionId";
public static final String USER = "user";
/**
* A constructor for creating new message headers.
......@@ -140,12 +141,11 @@ public class SimpMessageHeaderAccessor extends NativeMessageHeaderAccessor {
}
public MediaType getContentType() {
return (MediaType) getHeader(CONTENT_TYPE);
return (MediaType) getHeader(MessageHeaders.CONTENT_TYPE);
}
public void setContentType(MediaType contentType) {
Assert.notNull(contentType, "contentType is required");
setHeader(CONTENT_TYPE, contentType);
setHeader(MessageHeaders.CONTENT_TYPE, contentType);
}
public String getSubscriptionId() {
......@@ -164,4 +164,12 @@ public class SimpMessageHeaderAccessor extends NativeMessageHeaderAccessor {
setHeader(SESSION_ID, sessionId);
}
public Principal getUser() {
return (Principal) getHeader(USER);
}
public void setUser(Principal principal) {
setHeader(USER, principal);
}
}
/*
* Copyright 2002-2013 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.messaging.simp.annotation.support;
import java.security.Principal;
import org.springframework.core.MethodParameter;
import org.springframework.messaging.Message;
import org.springframework.messaging.handler.method.InvalidMessageMethodParameterException;
import org.springframework.messaging.handler.method.MessageArgumentResolver;
import org.springframework.messaging.simp.SimpMessageHeaderAccessor;
/**
* @author Rossen Stoyanchev
* @since 4.0
*/
public class PrincipalMessageArgumentResolver implements MessageArgumentResolver {
@Override
public boolean supportsParameter(MethodParameter parameter) {
Class<?> paramType = parameter.getParameterType();
return Principal.class.isAssignableFrom(paramType);
}
@Override
public Object resolveArgument(MethodParameter parameter, Message<?> message) throws Exception {
SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.wrap(message);
Principal user = headers.getUser();
if (user == null) {
throw new InvalidMessageMethodParameterException(message, "User not available", parameter);
}
return user;
}
}
......@@ -37,15 +37,16 @@ import org.springframework.messaging.MessageChannel;
import org.springframework.messaging.handler.annotation.MessageMapping;
import org.springframework.messaging.handler.annotation.support.MessageBodyArgumentResolver;
import org.springframework.messaging.handler.annotation.support.MessageExceptionHandlerMethodResolver;
import org.springframework.messaging.handler.method.MessageArgumentResolverComposite;
import org.springframework.messaging.handler.method.InvocableMessageHandlerMethod;
import org.springframework.messaging.handler.method.MessageArgumentResolverComposite;
import org.springframework.messaging.handler.method.MessageReturnValueHandlerComposite;
import org.springframework.messaging.simp.annotation.SubscribeEvent;
import org.springframework.messaging.simp.annotation.UnsubscribeEvent;
import org.springframework.messaging.simp.annotation.support.MessageSendingReturnValueHandler;
import org.springframework.messaging.simp.MessageHolder;
import org.springframework.messaging.simp.SimpMessageHeaderAccessor;
import org.springframework.messaging.simp.SimpMessageType;
import org.springframework.messaging.simp.annotation.SubscribeEvent;
import org.springframework.messaging.simp.annotation.UnsubscribeEvent;
import org.springframework.messaging.simp.annotation.support.MessageSendingReturnValueHandler;
import org.springframework.messaging.simp.annotation.support.PrincipalMessageArgumentResolver;
import org.springframework.messaging.support.converter.MessageConverter;
import org.springframework.stereotype.Controller;
import org.springframework.util.Assert;
......@@ -113,6 +114,7 @@ public class AnnotationSimpMessageHandler extends AbstractSimpMessageHandler
initHandlerMethods();
this.argumentResolvers.addResolver(new PrincipalMessageArgumentResolver());
this.argumentResolvers.addResolver(new MessageBodyArgumentResolver(this.messageConverter));
this.returnValueHandlers.addHandler(
......
......@@ -46,7 +46,7 @@ public class StompMessageConverter {
/**
* @param stompContent a complete STOMP message (without the trailing 0x00) as byte[] or String.
*/
public Message<?> toMessage(Object stompContent, String sessionId) {
public Message<?> toMessage(Object stompContent) {
byte[] byteContent = null;
if (stompContent instanceof String) {
......@@ -91,12 +91,10 @@ public class StompMessageConverter {
}
}
StompHeaderAccessor stompHeaders = StompHeaderAccessor.create(command, headers);
stompHeaders.setSessionId(sessionId);
byte[] payload = new byte[totalLength - payloadIndex];
System.arraycopy(byteContent, payloadIndex, payload, 0, totalLength - payloadIndex);
StompHeaderAccessor stompHeaders = StompHeaderAccessor.create(command, headers);
return MessageBuilder.withPayload(payload).copyHeaders(stompHeaders.toMap()).build();
}
......
......@@ -29,9 +29,9 @@ import java.util.concurrent.TimeUnit;
import org.springframework.context.SmartLifecycle;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageChannel;
import org.springframework.messaging.simp.handler.AbstractSimpMessageHandler;
import org.springframework.messaging.simp.SimpMessageHeaderAccessor;
import org.springframework.messaging.simp.SimpMessageType;
import org.springframework.messaging.simp.handler.AbstractSimpMessageHandler;
import org.springframework.messaging.support.MessageBuilder;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;
......@@ -350,7 +350,7 @@ public class StompRelayMessageHandler extends AbstractSimpMessageHandler impleme
return;
}
Message<?> message = stompMessageConverter.toMessage(stompFrame, this.sessionId);
Message<?> message = stompMessageConverter.toMessage(stompFrame);
if (logger.isTraceEnabled()) {
logger.trace("Reading message " + message);
}
......@@ -369,6 +369,10 @@ public class StompRelayMessageHandler extends AbstractSimpMessageHandler impleme
}
relaySessions.remove(this.sessionId);
}
headers.setSessionId(this.sessionId);
message = MessageBuilder.fromMessage(message).copyHeaders(headers.toMap()).build();
sendMessageToClient(message);
}
......
......@@ -81,7 +81,11 @@ public class StompWebSocketHandler extends TextWebSocketHandlerAdapter implement
protected void handleTextMessage(WebSocketSession session, TextMessage textMessage) {
try {
String payload = textMessage.getPayload();
Message<?> message = this.stompMessageConverter.toMessage(payload, session.getId());
Message<?> message = this.stompMessageConverter.toMessage(payload);
StompHeaderAccessor headers = StompHeaderAccessor.wrap(message);
headers.setSessionId(session.getId());
headers.setUser(session.getPrincipal());
// TODO: validate size limits
// http://stomp.github.io/stomp-specification-1.2.html#Size_Limits
......@@ -96,18 +100,8 @@ public class StompWebSocketHandler extends TextWebSocketHandlerAdapter implement
if (SimpMessageType.CONNECT.equals(messageType)) {
handleConnect(session, message);
}
else if (SimpMessageType.MESSAGE.equals(messageType)) {
handlePublish(message);
}
else if (SimpMessageType.SUBSCRIBE.equals(messageType)) {
handleSubscribe(message);
}
else if (SimpMessageType.UNSUBSCRIBE.equals(messageType)) {
handleUnsubscribe(message);
}
else if (SimpMessageType.DISCONNECT.equals(messageType)) {
handleDisconnect(message);
}
message = MessageBuilder.fromMessage(message).copyHeaders(headers.toMap()).build();
this.outputChannel.send(message);
}
catch (Throwable t) {
......@@ -124,7 +118,7 @@ public class StompWebSocketHandler extends TextWebSocketHandlerAdapter implement
}
}
protected void handleConnect(final WebSocketSession session, Message<?> message) throws IOException {
protected void handleConnect(WebSocketSession session, Message<?> message) throws IOException {
StompHeaderAccessor connectHeaders = StompHeaderAccessor.wrap(message);
StompHeaderAccessor connectedHeaders = StompHeaderAccessor.create(StompCommand.CONNECTED);
......@@ -152,18 +146,6 @@ public class StompWebSocketHandler extends TextWebSocketHandlerAdapter implement
session.sendMessage(new TextMessage(new String(bytes, Charset.forName("UTF-8"))));
}
protected void handlePublish(Message<?> stompMessage) {
}
protected void handleSubscribe(Message<?> message) {
}
protected void handleUnsubscribe(Message<?> message) {
}
protected void handleDisconnect(Message<?> message) {
}
protected void sendErrorMessage(WebSocketSession session, Throwable error) {
StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.ERROR);
......
......@@ -24,9 +24,6 @@ import org.springframework.messaging.Message;
import org.springframework.messaging.MessageHeaders;
import org.springframework.messaging.simp.SimpMessageHeaderAccessor;
import org.springframework.messaging.simp.SimpMessageType;
import org.springframework.messaging.simp.stomp.StompCommand;
import org.springframework.messaging.simp.stomp.StompHeaderAccessor;
import org.springframework.messaging.simp.stomp.StompMessageConverter;
import static org.junit.Assert.*;
......@@ -51,17 +48,16 @@ public class StompMessageConverterTests {
String accept = "accept-version:1.1\n";
String host = "host:github.org\n";
String frame = "\n\n\nCONNECT\n" + accept + host + "\n";
Message<byte[]> message = (Message<byte[]>) this.converter.toMessage(frame.getBytes("UTF-8"), "session-123");
Message<byte[]> message = (Message<byte[]>) this.converter.toMessage(frame.getBytes("UTF-8"));
assertEquals(0, message.getPayload().length);
MessageHeaders headers = message.getHeaders();
StompHeaderAccessor stompHeaders = StompHeaderAccessor.wrap(message);
Map<String, Object> map = stompHeaders.toMap();
assertEquals(6, map.size());
assertEquals(5, map.size());
assertNotNull(map.get(MessageHeaders.ID));
assertNotNull(map.get(MessageHeaders.TIMESTAMP));
assertNotNull(map.get(SimpMessageHeaderAccessor.SESSION_ID));
assertNotNull(map.get(SimpMessageHeaderAccessor.NATIVE_HEADERS));
assertNotNull(map.get(SimpMessageHeaderAccessor.MESSAGE_TYPE));
assertNotNull(map.get(SimpMessageHeaderAccessor.PROTOCOL_MESSAGE_TYPE));
......@@ -71,7 +67,6 @@ public class StompMessageConverterTests {
assertEquals(SimpMessageType.CONNECT, stompHeaders.getMessageType());
assertEquals(StompCommand.CONNECT, stompHeaders.getStompCommand());
assertEquals("session-123", stompHeaders.getSessionId());
assertNotNull(headers.get(MessageHeaders.ID));
assertNotNull(headers.get(MessageHeaders.TIMESTAMP));
......@@ -89,7 +84,7 @@ public class StompMessageConverterTests {
String host = "ho\\c\\ns\\rt:st\\nomp.gi\\cthu\\b.org\n";
String frame = "CONNECT\n" + accept + host + "\n";
@SuppressWarnings("unchecked")
Message<byte[]> message = (Message<byte[]>) this.converter.toMessage(frame.getBytes("UTF-8"), "session-123");
Message<byte[]> message = (Message<byte[]>) this.converter.toMessage(frame.getBytes("UTF-8"));
assertEquals(0, message.getPayload().length);
......@@ -111,7 +106,7 @@ public class StompMessageConverterTests {
String host = "host:github.org\n";
String test = "CONNECT\r\n" + accept.replaceAll("\n", "\r\n") + host.replaceAll("\n", "\r\n") + "\r\n";
@SuppressWarnings("unchecked")
Message<byte[]> message = (Message<byte[]>) this.converter.toMessage(test.getBytes("UTF-8"), "session-123");
Message<byte[]> message = (Message<byte[]>) this.converter.toMessage(test.getBytes("UTF-8"));
assertEquals(0, message.getPayload().length);
......@@ -133,7 +128,7 @@ public class StompMessageConverterTests {
String host = "ho\\c\\ns\\rt:st\\nomp.gi\\cthu\\b.org\n";
String test = "\n\n\nCONNECT\r\n" + accept.replaceAll("\n", "\r\n") + host.replaceAll("\n", "\r\n") + "\r\n";
@SuppressWarnings("unchecked")
Message<byte[]> message = (Message<byte[]>) this.converter.toMessage(test.getBytes("UTF-8"), "session-123");
Message<byte[]> message = (Message<byte[]>) this.converter.toMessage(test.getBytes("UTF-8"));
assertEquals(0, message.getPayload().length);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册