diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/messaging/StompSubProtocolHandler.java b/spring-websocket/src/main/java/org/springframework/web/socket/messaging/StompSubProtocolHandler.java index 636baa5618d879e8d60fb8907acb7325197859f4..7349d844db0d5e3519a5c883e3be6c244ab2119c 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/messaging/StompSubProtocolHandler.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/messaging/StompSubProtocolHandler.java @@ -97,16 +97,25 @@ public class StompSubProtocolHandler implements SubProtocolHandler { public void handleMessageFromClient(WebSocketSession session, WebSocketMessage webSocketMessage, MessageChannel outputChannel) { - Message message; + Message message = null; + Throwable decodeFailure = null; try { Assert.isInstanceOf(TextMessage.class, webSocketMessage); - String payload = ((TextMessage)webSocketMessage).getPayload(); + String payload = ((TextMessage) webSocketMessage).getPayload(); ByteBuffer byteBuffer = ByteBuffer.wrap(payload.getBytes(UTF8_CHARSET)); + message = this.stompDecoder.decode(byteBuffer); + if (message == null) { + decodeFailure = new IllegalStateException("Not a valid STOMP frame: " + payload); + } } catch (Throwable ex) { - logger.error("Failed to parse STOMP frame, WebSocket message payload", ex); - sendErrorMessage(session, ex); + decodeFailure = ex; + } + + if (decodeFailure != null) { + logger.error("Failed to parse WebSocket message as STOMP frame", decodeFailure); + sendErrorMessage(session, decodeFailure); return; } diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/messaging/StompSubProtocolHandlerTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/messaging/StompSubProtocolHandlerTests.java index d0b143e5aa3eec0e2d8d28ad017cb4323d8915da..84fe4605b2f12df418b2c26e96e228dbd97f1168 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/messaging/StompSubProtocolHandlerTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/messaging/StompSubProtocolHandlerTests.java @@ -34,6 +34,7 @@ import org.springframework.messaging.simp.stomp.StompDecoder; import org.springframework.messaging.simp.stomp.StompHeaderAccessor; import org.springframework.messaging.support.MessageBuilder; import org.springframework.web.socket.TextMessage; +import org.springframework.web.socket.WebSocketMessage; import org.springframework.web.socket.handler.TestWebSocketSession; import static org.junit.Assert.*; @@ -119,4 +120,17 @@ public class StompSubProtocolHandlerTests { assertEquals(0, this.session.getSentMessages().size()); } + @Test + public void invalidStompCommand() { + + TextMessage textMessage = new TextMessage("FOO"); + + this.stompHandler.handleMessageFromClient(this.session, textMessage, this.channel); + + verifyZeroInteractions(this.channel); + assertEquals(1, this.session.getSentMessages().size()); + TextMessage actual = (TextMessage) this.session.getSentMessages().get(0); + assertTrue(actual.getPayload().startsWith("ERROR")); + } + }