diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/broker/DefaultSubscriptionRegistry.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/broker/DefaultSubscriptionRegistry.java index d97b3ea1b497057b7378971c22131342309a9c35..bf3ae4428759457ea44ddf88a58f426579719e50 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/broker/DefaultSubscriptionRegistry.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/broker/DefaultSubscriptionRegistry.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2013 the original author or authors. + * Copyright 2002-2014 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,13 +16,9 @@ package org.springframework.messaging.simp.broker; -import java.util.Collection; -import java.util.HashSet; -import java.util.Map; -import java.util.Set; +import java.util.*; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; -import java.util.concurrent.CopyOnWriteArraySet; import org.springframework.messaging.Message; import org.springframework.util.AntPathMatcher; @@ -34,6 +30,7 @@ import org.springframework.util.MultiValueMap; * A default, simple in-memory implementation of {@link SubscriptionRegistry}. * * @author Rossen Stoyanchev + * @author Sebastien Deleuze * @since 4.0 */ public class DefaultSubscriptionRegistry extends AbstractSubscriptionRegistry { @@ -59,18 +56,16 @@ public class DefaultSubscriptionRegistry extends AbstractSubscriptionRegistry { @Override protected void addSubscriptionInternal(String sessionId, String subsId, String destination, Message message) { SessionSubscriptionInfo info = this.subscriptionRegistry.addSubscription(sessionId, subsId, destination); - if (!this.pathMatcher.isPattern(destination)) { - this.destinationCache.mapToDestination(destination, info); - } + this.destinationCache.mapToDestination(destination, sessionId, subsId); } @Override - protected void removeSubscriptionInternal(String sessionId, String subscriptionId, Message message) { + protected void removeSubscriptionInternal(String sessionId, String subsId, Message message) { SessionSubscriptionInfo info = this.subscriptionRegistry.getSubscriptions(sessionId); if (info != null) { - String destination = info.removeSubscription(subscriptionId); + String destination = info.removeSubscription(subsId); if (info.getSubscriptions(destination) == null) { - this.destinationCache.unmapFromDestination(destination, info); + this.destinationCache.unmapFromDestination(destination, sessionId, subsId); } } } @@ -88,8 +83,11 @@ public class DefaultSubscriptionRegistry extends AbstractSubscriptionRegistry { @Override protected MultiValueMap findSubscriptionsInternal(String destination, Message message) { - MultiValueMap result = this.destinationCache.getSubscriptions(destination); - if (result.isEmpty()) { + MultiValueMap result; + if (this.destinationCache.isCachedDestination(destination)) { + result = this.destinationCache.getSubscriptions(destination); + } + else { result = new LinkedMultiValueMap(); for (SessionSubscriptionInfo info : this.subscriptionRegistry.getAllSubscriptions()) { for (String destinationPattern : info.getDestinations()) { @@ -100,6 +98,9 @@ public class DefaultSubscriptionRegistry extends AbstractSubscriptionRegistry { } } } + if(!result.isEmpty()) { + this.destinationCache.addSubscriptions(destination, result); + } } return result; } @@ -114,60 +115,77 @@ public class DefaultSubscriptionRegistry extends AbstractSubscriptionRegistry { /** - * Provide direct lookup of session subscriptions by destination (for non-pattern destinations). + * Provide direct lookup of session subscriptions by destination */ private static class DestinationCache { + private AntPathMatcher pathMatcher = new AntPathMatcher(); + // destination -> .. - private final Map> subscriptionsByDestination = - new ConcurrentHashMap>(); + private final Map> subscriptionsByDestination = + new ConcurrentHashMap>(); private final Object monitor = new Object(); - public void mapToDestination(String destination, SessionSubscriptionInfo info) { + public void addSubscriptions(String destination, MultiValueMap subscriptions) { + this.subscriptionsByDestination.put(destination, subscriptions); + } + + public void mapToDestination(String destination, String sessionId, String subsId) { synchronized(this.monitor) { - Set registrations = this.subscriptionsByDestination.get(destination); - if (registrations == null) { - registrations = new CopyOnWriteArraySet(); - this.subscriptionsByDestination.put(destination, registrations); + for (String cachedDestination : this.subscriptionsByDestination.keySet()) { + if (this.pathMatcher.match(destination, cachedDestination)) { + MultiValueMap registrations = this.subscriptionsByDestination.get(cachedDestination); + if (registrations == null) { + registrations = new LinkedMultiValueMap(); + } + registrations.add(sessionId, subsId); + } } - registrations.add(info); } } - public void unmapFromDestination(String destination, SessionSubscriptionInfo info) { + public void unmapFromDestination(String destination, String sessionId, String subsId) { synchronized(this.monitor) { - Set infos = this.subscriptionsByDestination.get(destination); - if (infos != null) { - infos.remove(info); - if (infos.isEmpty()) { - this.subscriptionsByDestination.remove(destination); + for (String cachedDestination : this.subscriptionsByDestination.keySet()) { + if (this.pathMatcher.match(destination, cachedDestination)) { + MultiValueMap registrations = this.subscriptionsByDestination.get(cachedDestination); + List subscriptions = registrations.get(sessionId); + while(subscriptions.remove(subsId)); + if (subscriptions.isEmpty()) { + registrations.remove(sessionId); + } + if (registrations.isEmpty()) { + this.subscriptionsByDestination.remove(cachedDestination); + } } } } } public void removeSessionSubscriptions(SessionSubscriptionInfo info) { - for (String destination : info.getDestinations()) { - unmapFromDestination(destination, info); - } - } - - public MultiValueMap getSubscriptions(String destination) { - MultiValueMap result = new LinkedMultiValueMap(); - Set infos = this.subscriptionsByDestination.get(destination); - if (infos != null) { - for (SessionSubscriptionInfo info : infos) { - Set subscriptions = info.getSubscriptions(destination); - if (subscriptions != null) { - for (String subscription : subscriptions) { - result.add(info.getSessionId(), subscription); + synchronized(this.monitor) { + for (String destination : info.getDestinations()) { + for (String cachedDestination : this.subscriptionsByDestination.keySet()) { + if (this.pathMatcher.match(destination, cachedDestination)) { + MultiValueMap map = this.subscriptionsByDestination.get(cachedDestination); + map.remove(info.getSessionId()); + if (map.isEmpty()) { + this.subscriptionsByDestination.remove(cachedDestination); + } } } } } - return result; + } + + public MultiValueMap getSubscriptions(String destination) { + return this.subscriptionsByDestination.get(destination); + } + + public boolean isCachedDestination(String destination) { + return subscriptionsByDestination.containsKey(destination); } @Override diff --git a/spring-messaging/src/test/java/org/springframework/messaging/simp/broker/DefaultSubscriptionRegistryTests.java b/spring-messaging/src/test/java/org/springframework/messaging/simp/broker/DefaultSubscriptionRegistryTests.java index 854e00d3c022b81ab7d812afdbd817a14c871a44..5fc8c63260b2a0eb7b52a970bc8c688d9b8e313d 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/simp/broker/DefaultSubscriptionRegistryTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/simp/broker/DefaultSubscriptionRegistryTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2013 the original author or authors. + * Copyright 2002-2014 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -25,7 +25,6 @@ import org.junit.Test; import org.springframework.messaging.Message; import org.springframework.messaging.simp.SimpMessageHeaderAccessor; import org.springframework.messaging.simp.SimpMessageType; -import org.springframework.messaging.simp.broker.DefaultSubscriptionRegistry; import org.springframework.messaging.support.MessageBuilder; import org.springframework.util.MultiValueMap; @@ -35,6 +34,7 @@ import static org.junit.Assert.*; * Test fixture for {@link org.springframework.messaging.simp.broker.DefaultSubscriptionRegistry}. * * @author Rossen Stoyanchev + * @author Sebastien Deleuze */ public class DefaultSubscriptionRegistryTests { @@ -131,6 +131,64 @@ public class DefaultSubscriptionRegistryTests { assertEquals(Arrays.asList(subsId), actual.get(sessId)); } + // SPR-11657 + + @Test + public void registerMultipleSubscriptionsWithOneUsingDestinationPattern() { + + String sessId1 = "sess01"; + String sessId2 = "sess02"; + + String destPatternIbm = "/topic/PRICE.STOCK.*.IBM"; + String destNasdaqIbm = "/topic/PRICE.STOCK.NASDAQ.IBM"; + String destNyseIdm = "/topic/PRICE.STOCK.NYSE.IBM"; + String destNasdaqGoogle = "/topic/PRICE.STOCK.NASDAQ.GOOG"; + + String sessId1ToDestPatternIbm = "subs01"; + String sessId1ToDestNasdaqIbm = "subs02"; + String sessId2TodestNasdaqIbm = "subs03"; + String sessId2ToDestNyseIdm = "subs04"; + String sessId2ToDestNasdaqGoogle = "subs05"; + + this.registry.registerSubscription(subscribeMessage(sessId1, sessId1ToDestNasdaqIbm, destNasdaqIbm)); + this.registry.registerSubscription(subscribeMessage(sessId1, sessId1ToDestPatternIbm, destPatternIbm)); + MultiValueMap actual = this.registry.findSubscriptions(message(destNasdaqIbm)); + assertEquals("Expected 1 elements " + actual, 1, actual.size()); + assertEquals(Arrays.asList(sessId1ToDestNasdaqIbm, sessId1ToDestPatternIbm), actual.get(sessId1)); + + this.registry.registerSubscription(subscribeMessage(sessId2, sessId2TodestNasdaqIbm, destNasdaqIbm)); + this.registry.registerSubscription(subscribeMessage(sessId2, sessId2ToDestNyseIdm, destNyseIdm)); + this.registry.registerSubscription(subscribeMessage(sessId2, sessId2ToDestNasdaqGoogle, destNasdaqGoogle)); + actual = this.registry.findSubscriptions(message(destNasdaqIbm)); + assertEquals("Expected 2 elements " + actual, 2, actual.size()); + assertEquals(Arrays.asList(sessId1ToDestNasdaqIbm, sessId1ToDestPatternIbm), actual.get(sessId1)); + assertEquals(Arrays.asList(sessId2TodestNasdaqIbm), actual.get(sessId2)); + + this.registry.unregisterAllSubscriptions(sessId1); + actual = this.registry.findSubscriptions(message(destNasdaqIbm)); + assertEquals("Expected 1 elements " + actual, 1, actual.size()); + assertEquals(Arrays.asList(sessId2TodestNasdaqIbm), actual.get(sessId2)); + + this.registry.registerSubscription(subscribeMessage(sessId1, sessId1ToDestPatternIbm, destPatternIbm)); + this.registry.registerSubscription(subscribeMessage(sessId1, sessId1ToDestNasdaqIbm, destNasdaqIbm)); + actual = this.registry.findSubscriptions(message(destNasdaqIbm)); + assertEquals("Expected 2 elements " + actual, 2, actual.size()); + assertEquals(Arrays.asList(sessId1ToDestPatternIbm, sessId1ToDestNasdaqIbm), actual.get(sessId1)); + assertEquals(Arrays.asList(sessId2TodestNasdaqIbm), actual.get(sessId2)); + + this.registry.unregisterSubscription(unsubscribeMessage(sessId1, sessId1ToDestNasdaqIbm)); + actual = this.registry.findSubscriptions(message(destNasdaqIbm)); + assertEquals("Expected 2 elements " + actual, 2, actual.size()); + assertEquals(Arrays.asList(sessId1ToDestPatternIbm), actual.get(sessId1)); + assertEquals(Arrays.asList(sessId2TodestNasdaqIbm), actual.get(sessId2)); + this.registry.unregisterSubscription(unsubscribeMessage(sessId1, sessId1ToDestPatternIbm)); + assertEquals("Expected 1 elements " + actual, 1, actual.size()); + assertEquals(Arrays.asList(sessId2TodestNasdaqIbm), actual.get(sessId2)); + + this.registry.unregisterSubscription(unsubscribeMessage(sessId2, sessId2TodestNasdaqIbm)); + assertEquals("Expected 0 element " + actual, 0, actual.size()); + } + @Test public void registerSubscriptionWithDestinationPatternRegex() {