From 210be9cde42efc2e958eb602fd93bbc71ac609c2 Mon Sep 17 00:00:00 2001 From: Rossen Stoyanchev Date: Fri, 12 Jul 2013 15:15:37 -0400 Subject: [PATCH] Add PrincipalMessageArgumentResolver --- ...nvalidMessageMethodParameterException.java | 54 +++++++++++++++++++ .../simp/SimpMessageHeaderAccessor.java | 20 ++++--- .../PrincipalMessageArgumentResolver.java | 51 ++++++++++++++++++ .../handler/AnnotationSimpMessageHandler.java | 10 ++-- .../simp/stomp/StompMessageConverter.java | 6 +-- .../simp/stomp/StompRelayMessageHandler.java | 8 ++- .../simp/stomp/StompWebSocketHandler.java | 34 +++--------- .../stomp/StompMessageConverterTests.java | 15 ++---- 8 files changed, 146 insertions(+), 52 deletions(-) create mode 100644 spring-messaging/src/main/java/org/springframework/messaging/handler/method/InvalidMessageMethodParameterException.java create mode 100644 spring-messaging/src/main/java/org/springframework/messaging/simp/annotation/support/PrincipalMessageArgumentResolver.java diff --git a/spring-messaging/src/main/java/org/springframework/messaging/handler/method/InvalidMessageMethodParameterException.java b/spring-messaging/src/main/java/org/springframework/messaging/handler/method/InvalidMessageMethodParameterException.java new file mode 100644 index 0000000000..11b9aaebab --- /dev/null +++ b/spring-messaging/src/main/java/org/springframework/messaging/handler/method/InvalidMessageMethodParameterException.java @@ -0,0 +1,54 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.messaging.handler.method; + +import org.springframework.core.MethodParameter; +import org.springframework.messaging.Message; +import org.springframework.messaging.MessagingException; + + +/** + * + * @author Rossen Stoyanchev + * @since 4.0 + */ +public class InvalidMessageMethodParameterException extends MessagingException { + + private static final long serialVersionUID = -6905878930083523161L; + + private final MethodParameter parameter; + + + public InvalidMessageMethodParameterException(Message message, String description, + MethodParameter parameter, Throwable cause) { + super(message, description, cause); + this.parameter = parameter; + } + + public InvalidMessageMethodParameterException(Message message, String description, + MethodParameter parameter) { + + super(message, description); + this.parameter = parameter; + } + + + public MethodParameter getParameter() { + return this.parameter; + } + +} diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/SimpMessageHeaderAccessor.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/SimpMessageHeaderAccessor.java index e18d72cc0b..a356d65a14 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/SimpMessageHeaderAccessor.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/SimpMessageHeaderAccessor.java @@ -16,12 +16,14 @@ package org.springframework.messaging.simp; +import java.security.Principal; import java.util.Arrays; import java.util.List; import java.util.Map; import org.springframework.http.MediaType; import org.springframework.messaging.Message; +import org.springframework.messaging.MessageHeaders; import org.springframework.messaging.support.NativeMessageHeaderAccessor; import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; @@ -43,9 +45,6 @@ public class SimpMessageHeaderAccessor extends NativeMessageHeaderAccessor { public static final String DESTINATIONS = "destinations"; - // TODO - public static final String CONTENT_TYPE = "contentType"; - public static final String MESSAGE_TYPE = "messageType"; public static final String PROTOCOL_MESSAGE_TYPE = "protocolMessageType"; @@ -54,6 +53,8 @@ public class SimpMessageHeaderAccessor extends NativeMessageHeaderAccessor { public static final String SUBSCRIPTION_ID = "subscriptionId"; + public static final String USER = "user"; + /** * A constructor for creating new message headers. @@ -140,12 +141,11 @@ public class SimpMessageHeaderAccessor extends NativeMessageHeaderAccessor { } public MediaType getContentType() { - return (MediaType) getHeader(CONTENT_TYPE); + return (MediaType) getHeader(MessageHeaders.CONTENT_TYPE); } public void setContentType(MediaType contentType) { - Assert.notNull(contentType, "contentType is required"); - setHeader(CONTENT_TYPE, contentType); + setHeader(MessageHeaders.CONTENT_TYPE, contentType); } public String getSubscriptionId() { @@ -164,4 +164,12 @@ public class SimpMessageHeaderAccessor extends NativeMessageHeaderAccessor { setHeader(SESSION_ID, sessionId); } + public Principal getUser() { + return (Principal) getHeader(USER); + } + + public void setUser(Principal principal) { + setHeader(USER, principal); + } + } diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/annotation/support/PrincipalMessageArgumentResolver.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/annotation/support/PrincipalMessageArgumentResolver.java new file mode 100644 index 0000000000..7d1e17db5e --- /dev/null +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/annotation/support/PrincipalMessageArgumentResolver.java @@ -0,0 +1,51 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.messaging.simp.annotation.support; + +import java.security.Principal; + +import org.springframework.core.MethodParameter; +import org.springframework.messaging.Message; +import org.springframework.messaging.handler.method.InvalidMessageMethodParameterException; +import org.springframework.messaging.handler.method.MessageArgumentResolver; +import org.springframework.messaging.simp.SimpMessageHeaderAccessor; + + +/** + * @author Rossen Stoyanchev + * @since 4.0 + */ +public class PrincipalMessageArgumentResolver implements MessageArgumentResolver { + + + @Override + public boolean supportsParameter(MethodParameter parameter) { + Class paramType = parameter.getParameterType(); + return Principal.class.isAssignableFrom(paramType); + } + + @Override + public Object resolveArgument(MethodParameter parameter, Message message) throws Exception { + SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.wrap(message); + Principal user = headers.getUser(); + if (user == null) { + throw new InvalidMessageMethodParameterException(message, "User not available", parameter); + } + return user; + } + +} diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/handler/AnnotationSimpMessageHandler.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/handler/AnnotationSimpMessageHandler.java index 800bc78097..4f1bc2fb9e 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/handler/AnnotationSimpMessageHandler.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/handler/AnnotationSimpMessageHandler.java @@ -37,15 +37,16 @@ import org.springframework.messaging.MessageChannel; import org.springframework.messaging.handler.annotation.MessageMapping; import org.springframework.messaging.handler.annotation.support.MessageBodyArgumentResolver; import org.springframework.messaging.handler.annotation.support.MessageExceptionHandlerMethodResolver; -import org.springframework.messaging.handler.method.MessageArgumentResolverComposite; import org.springframework.messaging.handler.method.InvocableMessageHandlerMethod; +import org.springframework.messaging.handler.method.MessageArgumentResolverComposite; import org.springframework.messaging.handler.method.MessageReturnValueHandlerComposite; -import org.springframework.messaging.simp.annotation.SubscribeEvent; -import org.springframework.messaging.simp.annotation.UnsubscribeEvent; -import org.springframework.messaging.simp.annotation.support.MessageSendingReturnValueHandler; import org.springframework.messaging.simp.MessageHolder; import org.springframework.messaging.simp.SimpMessageHeaderAccessor; import org.springframework.messaging.simp.SimpMessageType; +import org.springframework.messaging.simp.annotation.SubscribeEvent; +import org.springframework.messaging.simp.annotation.UnsubscribeEvent; +import org.springframework.messaging.simp.annotation.support.MessageSendingReturnValueHandler; +import org.springframework.messaging.simp.annotation.support.PrincipalMessageArgumentResolver; import org.springframework.messaging.support.converter.MessageConverter; import org.springframework.stereotype.Controller; import org.springframework.util.Assert; @@ -113,6 +114,7 @@ public class AnnotationSimpMessageHandler extends AbstractSimpMessageHandler initHandlerMethods(); + this.argumentResolvers.addResolver(new PrincipalMessageArgumentResolver()); this.argumentResolvers.addResolver(new MessageBodyArgumentResolver(this.messageConverter)); this.returnValueHandlers.addHandler( diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompMessageConverter.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompMessageConverter.java index bd4c8413d8..cd79d657a1 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompMessageConverter.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompMessageConverter.java @@ -46,7 +46,7 @@ public class StompMessageConverter { /** * @param stompContent a complete STOMP message (without the trailing 0x00) as byte[] or String. */ - public Message toMessage(Object stompContent, String sessionId) { + public Message toMessage(Object stompContent) { byte[] byteContent = null; if (stompContent instanceof String) { @@ -91,12 +91,10 @@ public class StompMessageConverter { } } - StompHeaderAccessor stompHeaders = StompHeaderAccessor.create(command, headers); - stompHeaders.setSessionId(sessionId); - byte[] payload = new byte[totalLength - payloadIndex]; System.arraycopy(byteContent, payloadIndex, payload, 0, totalLength - payloadIndex); + StompHeaderAccessor stompHeaders = StompHeaderAccessor.create(command, headers); return MessageBuilder.withPayload(payload).copyHeaders(stompHeaders.toMap()).build(); } diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompRelayMessageHandler.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompRelayMessageHandler.java index 305445aba4..4fae99140d 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompRelayMessageHandler.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompRelayMessageHandler.java @@ -29,9 +29,9 @@ import java.util.concurrent.TimeUnit; import org.springframework.context.SmartLifecycle; import org.springframework.messaging.Message; import org.springframework.messaging.MessageChannel; -import org.springframework.messaging.simp.handler.AbstractSimpMessageHandler; import org.springframework.messaging.simp.SimpMessageHeaderAccessor; import org.springframework.messaging.simp.SimpMessageType; +import org.springframework.messaging.simp.handler.AbstractSimpMessageHandler; import org.springframework.messaging.support.MessageBuilder; import org.springframework.util.Assert; import org.springframework.util.StringUtils; @@ -350,7 +350,7 @@ public class StompRelayMessageHandler extends AbstractSimpMessageHandler impleme return; } - Message message = stompMessageConverter.toMessage(stompFrame, this.sessionId); + Message message = stompMessageConverter.toMessage(stompFrame); if (logger.isTraceEnabled()) { logger.trace("Reading message " + message); } @@ -369,6 +369,10 @@ public class StompRelayMessageHandler extends AbstractSimpMessageHandler impleme } relaySessions.remove(this.sessionId); } + + headers.setSessionId(this.sessionId); + message = MessageBuilder.fromMessage(message).copyHeaders(headers.toMap()).build(); + sendMessageToClient(message); } diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompWebSocketHandler.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompWebSocketHandler.java index c5bff8effb..0cfb9d5ebb 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompWebSocketHandler.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompWebSocketHandler.java @@ -81,7 +81,11 @@ public class StompWebSocketHandler extends TextWebSocketHandlerAdapter implement protected void handleTextMessage(WebSocketSession session, TextMessage textMessage) { try { String payload = textMessage.getPayload(); - Message message = this.stompMessageConverter.toMessage(payload, session.getId()); + Message message = this.stompMessageConverter.toMessage(payload); + + StompHeaderAccessor headers = StompHeaderAccessor.wrap(message); + headers.setSessionId(session.getId()); + headers.setUser(session.getPrincipal()); // TODO: validate size limits // http://stomp.github.io/stomp-specification-1.2.html#Size_Limits @@ -96,18 +100,8 @@ public class StompWebSocketHandler extends TextWebSocketHandlerAdapter implement if (SimpMessageType.CONNECT.equals(messageType)) { handleConnect(session, message); } - else if (SimpMessageType.MESSAGE.equals(messageType)) { - handlePublish(message); - } - else if (SimpMessageType.SUBSCRIBE.equals(messageType)) { - handleSubscribe(message); - } - else if (SimpMessageType.UNSUBSCRIBE.equals(messageType)) { - handleUnsubscribe(message); - } - else if (SimpMessageType.DISCONNECT.equals(messageType)) { - handleDisconnect(message); - } + + message = MessageBuilder.fromMessage(message).copyHeaders(headers.toMap()).build(); this.outputChannel.send(message); } catch (Throwable t) { @@ -124,7 +118,7 @@ public class StompWebSocketHandler extends TextWebSocketHandlerAdapter implement } } - protected void handleConnect(final WebSocketSession session, Message message) throws IOException { + protected void handleConnect(WebSocketSession session, Message message) throws IOException { StompHeaderAccessor connectHeaders = StompHeaderAccessor.wrap(message); StompHeaderAccessor connectedHeaders = StompHeaderAccessor.create(StompCommand.CONNECTED); @@ -152,18 +146,6 @@ public class StompWebSocketHandler extends TextWebSocketHandlerAdapter implement session.sendMessage(new TextMessage(new String(bytes, Charset.forName("UTF-8")))); } - protected void handlePublish(Message stompMessage) { - } - - protected void handleSubscribe(Message message) { - } - - protected void handleUnsubscribe(Message message) { - } - - protected void handleDisconnect(Message message) { - } - protected void sendErrorMessage(WebSocketSession session, Throwable error) { StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.ERROR); diff --git a/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/StompMessageConverterTests.java b/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/StompMessageConverterTests.java index cbff1c7e11..2f94f1cc15 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/StompMessageConverterTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/StompMessageConverterTests.java @@ -24,9 +24,6 @@ import org.springframework.messaging.Message; import org.springframework.messaging.MessageHeaders; import org.springframework.messaging.simp.SimpMessageHeaderAccessor; import org.springframework.messaging.simp.SimpMessageType; -import org.springframework.messaging.simp.stomp.StompCommand; -import org.springframework.messaging.simp.stomp.StompHeaderAccessor; -import org.springframework.messaging.simp.stomp.StompMessageConverter; import static org.junit.Assert.*; @@ -51,17 +48,16 @@ public class StompMessageConverterTests { String accept = "accept-version:1.1\n"; String host = "host:github.org\n"; String frame = "\n\n\nCONNECT\n" + accept + host + "\n"; - Message message = (Message) this.converter.toMessage(frame.getBytes("UTF-8"), "session-123"); + Message message = (Message) this.converter.toMessage(frame.getBytes("UTF-8")); assertEquals(0, message.getPayload().length); MessageHeaders headers = message.getHeaders(); StompHeaderAccessor stompHeaders = StompHeaderAccessor.wrap(message); Map map = stompHeaders.toMap(); - assertEquals(6, map.size()); + assertEquals(5, map.size()); assertNotNull(map.get(MessageHeaders.ID)); assertNotNull(map.get(MessageHeaders.TIMESTAMP)); - assertNotNull(map.get(SimpMessageHeaderAccessor.SESSION_ID)); assertNotNull(map.get(SimpMessageHeaderAccessor.NATIVE_HEADERS)); assertNotNull(map.get(SimpMessageHeaderAccessor.MESSAGE_TYPE)); assertNotNull(map.get(SimpMessageHeaderAccessor.PROTOCOL_MESSAGE_TYPE)); @@ -71,7 +67,6 @@ public class StompMessageConverterTests { assertEquals(SimpMessageType.CONNECT, stompHeaders.getMessageType()); assertEquals(StompCommand.CONNECT, stompHeaders.getStompCommand()); - assertEquals("session-123", stompHeaders.getSessionId()); assertNotNull(headers.get(MessageHeaders.ID)); assertNotNull(headers.get(MessageHeaders.TIMESTAMP)); @@ -89,7 +84,7 @@ public class StompMessageConverterTests { String host = "ho\\c\\ns\\rt:st\\nomp.gi\\cthu\\b.org\n"; String frame = "CONNECT\n" + accept + host + "\n"; @SuppressWarnings("unchecked") - Message message = (Message) this.converter.toMessage(frame.getBytes("UTF-8"), "session-123"); + Message message = (Message) this.converter.toMessage(frame.getBytes("UTF-8")); assertEquals(0, message.getPayload().length); @@ -111,7 +106,7 @@ public class StompMessageConverterTests { String host = "host:github.org\n"; String test = "CONNECT\r\n" + accept.replaceAll("\n", "\r\n") + host.replaceAll("\n", "\r\n") + "\r\n"; @SuppressWarnings("unchecked") - Message message = (Message) this.converter.toMessage(test.getBytes("UTF-8"), "session-123"); + Message message = (Message) this.converter.toMessage(test.getBytes("UTF-8")); assertEquals(0, message.getPayload().length); @@ -133,7 +128,7 @@ public class StompMessageConverterTests { String host = "ho\\c\\ns\\rt:st\\nomp.gi\\cthu\\b.org\n"; String test = "\n\n\nCONNECT\r\n" + accept.replaceAll("\n", "\r\n") + host.replaceAll("\n", "\r\n") + "\r\n"; @SuppressWarnings("unchecked") - Message message = (Message) this.converter.toMessage(test.getBytes("UTF-8"), "session-123"); + Message message = (Message) this.converter.toMessage(test.getBytes("UTF-8")); assertEquals(0, message.getPayload().length); -- GitLab