提交 96cb7c06 编写于 作者: R Rossen Stoyanchev

Fix test failures

上级 0c92b85c
...@@ -140,7 +140,7 @@ public class CachingSessionSubscriptionRegistry implements SessionSubscriptionRe ...@@ -140,7 +140,7 @@ public class CachingSessionSubscriptionRegistry implements SessionSubscriptionRe
@Override @Override
public void addSubscription(String destination, String subscriptionId) { public void addSubscription(String destination, String subscriptionId) {
CachingSessionSubscriptionRegistry.this.destinationCache.mapRegistration(destination, this.delegate); destinationCache.mapRegistration(destination, this);
this.delegate.addSubscription(destination, subscriptionId); this.delegate.addSubscription(destination, subscriptionId);
} }
...@@ -148,7 +148,7 @@ public class CachingSessionSubscriptionRegistry implements SessionSubscriptionRe ...@@ -148,7 +148,7 @@ public class CachingSessionSubscriptionRegistry implements SessionSubscriptionRe
public String removeSubscription(String subscriptionId) { public String removeSubscription(String subscriptionId) {
String destination = this.delegate.removeSubscription(subscriptionId); String destination = this.delegate.removeSubscription(subscriptionId);
if (destination != null && this.delegate.getSubscriptionsByDestination(destination) == null) { if (destination != null && this.delegate.getSubscriptionsByDestination(destination) == null) {
CachingSessionSubscriptionRegistry.this.destinationCache.unmapRegistration(destination, this); destinationCache.unmapRegistration(destination, this);
} }
return destination; return destination;
} }
...@@ -163,6 +163,23 @@ public class CachingSessionSubscriptionRegistry implements SessionSubscriptionRe ...@@ -163,6 +163,23 @@ public class CachingSessionSubscriptionRegistry implements SessionSubscriptionRe
return this.delegate.getDestinations(); return this.delegate.getDestinations();
} }
@Override
public boolean equals(Object other) {
if (this == other) {
return true;
}
if (!(other instanceof CachingSessionSubscriptionRegistration)) {
return false;
}
CachingSessionSubscriptionRegistration otherType = (CachingSessionSubscriptionRegistration) other;
return this.delegate.equals(otherType.delegate);
}
@Override
public int hashCode() {
return this.delegate.hashCode();
}
@Override @Override
public String toString() { public String toString() {
return "CachingSessionSubscriptionRegistration [delegate=" + delegate + "]"; return "CachingSessionSubscriptionRegistration [delegate=" + delegate + "]";
......
...@@ -89,6 +89,22 @@ public class DefaultSessionSubscriptionRegistration implements SessionSubscripti ...@@ -89,6 +89,22 @@ public class DefaultSessionSubscriptionRegistration implements SessionSubscripti
return this.subscriptions.get(destination); return this.subscriptions.get(destination);
} }
@Override
public boolean equals(Object other) {
if (this == other) {
return true;
}
if (!(other instanceof DefaultSessionSubscriptionRegistration)) {
return false;
}
DefaultSessionSubscriptionRegistration otherType = (DefaultSessionSubscriptionRegistration) other;
return this.sessionId.equals(otherType.sessionId);
}
@Override
public int hashCode() {
return 31 + this.sessionId.hashCode();
}
@Override @Override
public String toString() { public String toString() {
......
...@@ -78,12 +78,12 @@ public class SimpleBrokerWebMessageHandlerTests { ...@@ -78,12 +78,12 @@ public class SimpleBrokerWebMessageHandlerTests {
this.messageHandler.handlePublish(createMessage("/bar", "message2")); this.messageHandler.handlePublish(createMessage("/bar", "message2"));
verify(this.clientChannel, times(6)).send(this.messageCaptor.capture()); verify(this.clientChannel, times(6)).send(this.messageCaptor.capture());
assertCapturedMessage(this.messageCaptor.getAllValues().get(0), "sess1", "sub1", "/foo"); assertCapturedMessage("sess1", "sub1", "/foo");
assertCapturedMessage(this.messageCaptor.getAllValues().get(1), "sess1", "sub2", "/foo"); assertCapturedMessage("sess1", "sub2", "/foo");
assertCapturedMessage(this.messageCaptor.getAllValues().get(2), "sess2", "sub1", "/foo"); assertCapturedMessage("sess2", "sub1", "/foo");
assertCapturedMessage(this.messageCaptor.getAllValues().get(3), "sess2", "sub2", "/foo"); assertCapturedMessage("sess2", "sub2", "/foo");
assertCapturedMessage(this.messageCaptor.getAllValues().get(4), "sess1", "sub3", "/bar"); assertCapturedMessage("sess1", "sub3", "/bar");
assertCapturedMessage(this.messageCaptor.getAllValues().get(5), "sess2", "sub3", "/bar"); assertCapturedMessage("sess2", "sub3", "/bar");
} }
@Test @Test
...@@ -105,10 +105,13 @@ public class SimpleBrokerWebMessageHandlerTests { ...@@ -105,10 +105,13 @@ public class SimpleBrokerWebMessageHandlerTests {
this.messageHandler.handlePublish(createMessage("/foo", "message1")); this.messageHandler.handlePublish(createMessage("/foo", "message1"));
this.messageHandler.handlePublish(createMessage("/bar", "message2")); this.messageHandler.handlePublish(createMessage("/bar", "message2"));
verify(this.clientChannel, times(3)).send(this.messageCaptor.capture()); verify(this.clientChannel, times(6)).send(this.messageCaptor.capture());
assertCapturedMessage(this.messageCaptor.getAllValues().get(0), "sess2", "sub1", "/foo"); assertCapturedMessage("sess1", "sub1", "/foo");
assertCapturedMessage(this.messageCaptor.getAllValues().get(1), "sess2", "sub2", "/foo"); assertCapturedMessage("sess1", "sub2", "/foo");
assertCapturedMessage(this.messageCaptor.getAllValues().get(2), "sess2", "sub3", "/bar"); assertCapturedMessage("sess2", "sub1", "/foo");
assertCapturedMessage("sess2", "sub2", "/foo");
assertCapturedMessage("sess1", "sub3", "/bar");
assertCapturedMessage("sess2", "sub3", "/bar");
} }
...@@ -130,13 +133,18 @@ public class SimpleBrokerWebMessageHandlerTests { ...@@ -130,13 +133,18 @@ public class SimpleBrokerWebMessageHandlerTests {
return MessageBuilder.withPayload(payload).copyHeaders(headers.toMap()).build(); return MessageBuilder.withPayload(payload).copyHeaders(headers.toMap()).build();
} }
protected void assertCapturedMessage(Message<?> message, String sessionId, protected boolean assertCapturedMessage(String sessionId, String subcriptionId, String destination) {
String subcriptionId, String destination) { for (Message<?> message : this.messageCaptor.getAllValues()) {
WebMessageHeaderAccesssor headers = WebMessageHeaderAccesssor.wrap(message);
WebMessageHeaderAccesssor headers = WebMessageHeaderAccesssor.wrap(message); if (sessionId.equals(headers.getSessionId())) {
assertEquals(sessionId, headers.getSessionId()); if (subcriptionId.equals(headers.getSubscriptionId())) {
assertEquals(subcriptionId, headers.getSubscriptionId()); if (destination.equals(headers.getDestination())) {
assertEquals(destination, headers.getDestination()); return true;
}
}
}
}
return false;
} }
} }
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
package org.springframework.web.messaging.stomp.support; package org.springframework.web.messaging.stomp.support;
import java.util.Collections; import java.util.Collections;
import java.util.Map;
import org.junit.Before; import org.junit.Before;
import org.junit.Test; import org.junit.Test;
...@@ -23,6 +24,7 @@ import org.springframework.messaging.Message; ...@@ -23,6 +24,7 @@ import org.springframework.messaging.Message;
import org.springframework.messaging.MessageHeaders; import org.springframework.messaging.MessageHeaders;
import org.springframework.web.messaging.MessageType; import org.springframework.web.messaging.MessageType;
import org.springframework.web.messaging.stomp.StompCommand; import org.springframework.web.messaging.stomp.StompCommand;
import org.springframework.web.messaging.support.WebMessageHeaderAccesssor;
import static org.junit.Assert.*; import static org.junit.Assert.*;
...@@ -53,7 +55,14 @@ public class StompMessageConverterTests { ...@@ -53,7 +55,14 @@ public class StompMessageConverterTests {
MessageHeaders headers = message.getHeaders(); MessageHeaders headers = message.getHeaders();
StompHeaderAccessor stompHeaders = StompHeaderAccessor.wrap(message); StompHeaderAccessor stompHeaders = StompHeaderAccessor.wrap(message);
assertEquals(7, stompHeaders.toMap().size()); Map<String, Object> map = stompHeaders.toMap();
assertEquals(6, map.size());
assertNotNull(map.get(MessageHeaders.ID));
assertNotNull(map.get(MessageHeaders.TIMESTAMP));
assertNotNull(map.get(WebMessageHeaderAccesssor.SESSION_ID));
assertNotNull(map.get(WebMessageHeaderAccesssor.NATIVE_HEADERS));
assertNotNull(map.get(WebMessageHeaderAccesssor.MESSAGE_TYPE));
assertNotNull(map.get(WebMessageHeaderAccesssor.PROTOCOL_MESSAGE_TYPE));
assertEquals(Collections.singleton("1.1"), stompHeaders.getAcceptVersion()); assertEquals(Collections.singleton("1.1"), stompHeaders.getAcceptVersion());
assertEquals("github.org", stompHeaders.getHost()); assertEquals("github.org", stompHeaders.getHost());
......
...@@ -47,11 +47,9 @@ public class CachingSessionSubscriptionRegistryTests { ...@@ -47,11 +47,9 @@ public class CachingSessionSubscriptionRegistryTests {
SessionSubscriptionRegistration reg1 = this.registry.getOrCreateRegistration("sess1"); SessionSubscriptionRegistration reg1 = this.registry.getOrCreateRegistration("sess1");
reg1.addSubscription("/foo", "sub1"); reg1.addSubscription("/foo", "sub1");
reg1.addSubscription("/foo", "sub1");
SessionSubscriptionRegistration reg2 = this.registry.getOrCreateRegistration("sess2"); SessionSubscriptionRegistration reg2 = this.registry.getOrCreateRegistration("sess2");
reg2.addSubscription("/foo", "sub1"); reg2.addSubscription("/foo", "sub1");
reg2.addSubscription("/foo", "sub1");
Set<SessionSubscriptionRegistration> actual = this.registry.getRegistrationsByDestination("/foo"); Set<SessionSubscriptionRegistration> actual = this.registry.getRegistrationsByDestination("/foo");
assertEquals(2, actual.size()); assertEquals(2, actual.size());
...@@ -59,14 +57,12 @@ public class CachingSessionSubscriptionRegistryTests { ...@@ -59,14 +57,12 @@ public class CachingSessionSubscriptionRegistryTests {
assertTrue(actual.contains(reg2)); assertTrue(actual.contains(reg2));
reg1.removeSubscription("sub1"); reg1.removeSubscription("sub1");
reg1.removeSubscription("sub2");
actual = this.registry.getRegistrationsByDestination("/foo"); actual = this.registry.getRegistrationsByDestination("/foo");
assertEquals("Invalid set of registrations " + actual, 1, actual.size()); assertEquals("Invalid set of registrations " + actual, 1, actual.size());
assertTrue(actual.contains(reg2)); assertTrue(actual.contains(reg2));
reg2.removeSubscription("sub1"); reg2.removeSubscription("sub1");
reg2.removeSubscription("sub2");
actual = this.registry.getRegistrationsByDestination("/foo"); actual = this.registry.getRegistrationsByDestination("/foo");
assertNull("Unexpected registrations " + actual, actual); assertNull("Unexpected registrations " + actual, actual);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册