diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/SimpMessageHeaderAccessor.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/SimpMessageHeaderAccessor.java index 7e3d7f12b69b38f0bd38e0f40b77e0a08eb1a63c..5bec0672bb1ee1574adea155346667b63cd5b6b6 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/SimpMessageHeaderAccessor.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/SimpMessageHeaderAccessor.java @@ -65,6 +65,8 @@ public class SimpMessageHeaderAccessor extends NativeMessageHeaderAccessor { public static final String CONNECT_MESSAGE_HEADER = "simpConnectMessage"; + public static final String DISCONNECT_MESSAGE_HEADER = "simpDisconnectMessage"; + public static final String HEART_BEAT_HEADER = "simpHeartbeat"; diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/broker/SimpleBrokerMessageHandler.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/broker/SimpleBrokerMessageHandler.java index 35722bd2fb6928e1f08e7e4dbccbdc15516473b2..2c9c8f52e6bf6e3b26189076dd507b1b09ee067a 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/broker/SimpleBrokerMessageHandler.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/broker/SimpleBrokerMessageHandler.java @@ -275,7 +275,7 @@ public class SimpleBrokerMessageHandler extends AbstractBrokerMessageHandler { } else if (SimpMessageType.DISCONNECT.equals(messageType)) { logMessage(message); - handleDisconnect(sessionId, user); + handleDisconnect(sessionId, user, message); } else if (SimpMessageType.SUBSCRIBE.equals(messageType)) { logMessage(message); @@ -310,12 +310,15 @@ public class SimpleBrokerMessageHandler extends AbstractBrokerMessageHandler { } } - private void handleDisconnect(String sessionId, Principal user) { + private void handleDisconnect(String sessionId, Principal user, Message origMessage) { this.sessions.remove(sessionId); this.subscriptionRegistry.unregisterAllSubscriptions(sessionId); SimpMessageHeaderAccessor accessor = SimpMessageHeaderAccessor.create(SimpMessageType.DISCONNECT_ACK); accessor.setSessionId(sessionId); accessor.setUser(user); + if (origMessage != null) { + accessor.setHeader(SimpMessageHeaderAccessor.DISCONNECT_MESSAGE_HEADER, origMessage); + } initHeaders(accessor); Message message = MessageBuilder.createMessage(EMPTY_PAYLOAD, accessor.getMessageHeaders()); getClientOutboundChannel().send(message); @@ -432,7 +435,7 @@ public class SimpleBrokerMessageHandler extends AbstractBrokerMessageHandler { long now = System.currentTimeMillis(); for (SessionInfo info : sessions.values()) { if (info.getReadInterval() > 0 && (now - info.getLastReadTime()) > info.getReadInterval()) { - handleDisconnect(info.getSessiondId(), info.getUser()); + handleDisconnect(info.getSessiondId(), info.getUser(), null); } if (info.getWriteInterval() > 0 && (now - info.getLastWriteTime()) > info.getWriteInterval()) { SimpMessageHeaderAccessor accessor = SimpMessageHeaderAccessor.create(SimpMessageType.HEARTBEAT); diff --git a/spring-messaging/src/test/java/org/springframework/messaging/simp/broker/SimpleBrokerMessageHandlerTests.java b/spring-messaging/src/test/java/org/springframework/messaging/simp/broker/SimpleBrokerMessageHandlerTests.java index cc969cc5c7c359f663b0a3b6305e73d248fcd38e..86d8b4f0fe76e245fd76df491313a594be176be7 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/simp/broker/SimpleBrokerMessageHandlerTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/simp/broker/SimpleBrokerMessageHandlerTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2015 the original author or authors. + * Copyright 2002-2016 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,9 +16,6 @@ package org.springframework.messaging.simp.broker; -import static org.junit.Assert.*; -import static org.mockito.Mockito.*; - import java.security.Principal; import java.util.Collections; import java.util.List; @@ -41,6 +38,21 @@ import org.springframework.messaging.simp.TestPrincipal; import org.springframework.messaging.support.MessageBuilder; import org.springframework.scheduling.TaskScheduler; +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertSame; +import static org.junit.Assert.assertTrue; +import static org.mockito.Mockito.any; +import static org.mockito.Mockito.atLeast; +import static org.mockito.Mockito.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.when; + /** * Unit tests for SimpleBrokerMessageHandler. * @@ -72,7 +84,7 @@ public class SimpleBrokerMessageHandlerTests { public void setup() { MockitoAnnotations.initMocks(this); this.messageHandler = new SimpleBrokerMessageHandler(this.clientInboundChannel, - this.clientOutboundChannel, this.brokerChannel, Collections.emptyList()); + this.clientOutboundChannel, this.brokerChannel, Collections.emptyList()); } @@ -130,6 +142,7 @@ public class SimpleBrokerMessageHandlerTests { Message captured = this.messageCaptor.getAllValues().get(0); assertEquals(SimpMessageType.DISCONNECT_ACK, SimpMessageHeaderAccessor.getMessageType(captured.getHeaders())); + assertSame(message, captured.getHeaders().get(SimpMessageHeaderAccessor.DISCONNECT_MESSAGE_HEADER)); assertEquals(sess1, SimpMessageHeaderAccessor.getSessionId(captured.getHeaders())); assertEquals("joe", SimpMessageHeaderAccessor.getUser(captured.getHeaders()).getName()); diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/messaging/StompSubProtocolHandler.java b/spring-websocket/src/main/java/org/springframework/web/socket/messaging/StompSubProtocolHandler.java index 2d516a8e083575d61781321a1b7948efb89574ee..07bbbf991fce8cedaf41248bb0a5f7c46c57d022 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/messaging/StompSubProtocolHandler.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/messaging/StompSubProtocolHandler.java @@ -34,6 +34,7 @@ import org.springframework.context.ApplicationEventPublisher; import org.springframework.context.ApplicationEventPublisherAware; import org.springframework.messaging.Message; import org.springframework.messaging.MessageChannel; +import org.springframework.messaging.MessageHeaders; import org.springframework.messaging.simp.SimpAttributes; import org.springframework.messaging.simp.SimpAttributesContextHolder; import org.springframework.messaging.simp.SimpMessageHeaderAccessor; @@ -449,8 +450,15 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE stompAccessor = convertConnectAcktoStompConnected(stompAccessor); } else if (SimpMessageType.DISCONNECT_ACK.equals(messageType)) { - stompAccessor = StompHeaderAccessor.create(StompCommand.ERROR); - stompAccessor.setMessage("Session closed."); + String receipt = getDisconnectReceipt(stompAccessor); + if (receipt != null) { + stompAccessor = StompHeaderAccessor.create(StompCommand.RECEIPT); + stompAccessor.setReceiptId(receipt); + } + else { + stompAccessor = StompHeaderAccessor.create(StompCommand.ERROR); + stompAccessor.setMessage("Session closed."); + } } else if (SimpMessageType.HEARTBEAT.equals(messageType)) { stompAccessor = StompHeaderAccessor.createForHeartbeat(); @@ -503,6 +511,16 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE return connectedHeaders; } + private String getDisconnectReceipt(SimpMessageHeaderAccessor simpHeaders) { + String name = StompHeaderAccessor.DISCONNECT_MESSAGE_HEADER; + Message message = (Message) simpHeaders.getHeader(name); + if (message != null) { + StompHeaderAccessor accessor = MessageHeaderAccessor.getAccessor(message, StompHeaderAccessor.class); + return accessor.getReceipt(); + } + return null; + } + protected StompHeaderAccessor toMutableAccessor(StompHeaderAccessor headerAccessor, Message message) { return (headerAccessor.isMutable() ? headerAccessor : StompHeaderAccessor.wrap(message)); } diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/messaging/StompSubProtocolHandlerTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/messaging/StompSubProtocolHandlerTests.java index 73e6f1d28472b98090da7b47321ac2f68763a107..445cdfa1c451a8b0f7a73c80fd08a24bcdd81224 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/messaging/StompSubProtocolHandlerTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/messaging/StompSubProtocolHandlerTests.java @@ -169,6 +169,40 @@ public class StompSubProtocolHandlerTests { "user-name:joe\n" + "\n" + "\u0000", actual.getPayload()); } + @Test + public void handleMessageToClientWithSimpDisconnectAck() { + + StompHeaderAccessor accessor = StompHeaderAccessor.create(StompCommand.DISCONNECT); + Message connectMessage = MessageBuilder.createMessage(EMPTY_PAYLOAD, accessor.getMessageHeaders()); + + SimpMessageHeaderAccessor ackAccessor = SimpMessageHeaderAccessor.create(SimpMessageType.DISCONNECT_ACK); + ackAccessor.setHeader(SimpMessageHeaderAccessor.DISCONNECT_MESSAGE_HEADER, connectMessage); + Message ackMessage = MessageBuilder.createMessage(EMPTY_PAYLOAD, ackAccessor.getMessageHeaders()); + this.protocolHandler.handleMessageToClient(this.session, ackMessage); + + assertEquals(1, this.session.getSentMessages().size()); + TextMessage actual = (TextMessage) this.session.getSentMessages().get(0); + assertEquals("ERROR\n" + "message:Session closed.\n" + "content-length:0\n" + + "\n\u0000", actual.getPayload()); + } + + @Test + public void handleMessageToClientWithSimpDisconnectAckAndReceipt() { + + StompHeaderAccessor accessor = StompHeaderAccessor.create(StompCommand.DISCONNECT); + accessor.setReceipt("message-123"); + Message connectMessage = MessageBuilder.createMessage(EMPTY_PAYLOAD, accessor.getMessageHeaders()); + + SimpMessageHeaderAccessor ackAccessor = SimpMessageHeaderAccessor.create(SimpMessageType.DISCONNECT_ACK); + ackAccessor.setHeader(SimpMessageHeaderAccessor.DISCONNECT_MESSAGE_HEADER, connectMessage); + Message ackMessage = MessageBuilder.createMessage(EMPTY_PAYLOAD, ackAccessor.getMessageHeaders()); + this.protocolHandler.handleMessageToClient(this.session, ackMessage); + + assertEquals(1, this.session.getSentMessages().size()); + TextMessage actual = (TextMessage) this.session.getSentMessages().get(0); + assertEquals("RECEIPT\n" + "receipt-id:message-123\n" + "\n\u0000", actual.getPayload()); + } + @Test public void handleMessageToClientWithSimpHeartbeat() {