From 3272917cf276017df63a214e18d2914999b0baeb Mon Sep 17 00:00:00 2001 From: Rossen Stoyanchev Date: Fri, 19 Jul 2013 10:57:23 -0400 Subject: [PATCH] Polish concurrency in UserSessionResolver impl --- .../handler/DefaultSubscriptionRegistry.java | 30 +++---- .../handler/SimpleUserSessionResolver.java | 13 +-- .../SimpleUserSessionResolverTests.java | 82 +++++++++++++++++++ 3 files changed, 105 insertions(+), 20 deletions(-) create mode 100644 spring-messaging/src/test/java/org/springframework/messaging/simp/handler/SimpleUserSessionResolverTests.java diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/handler/DefaultSubscriptionRegistry.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/handler/DefaultSubscriptionRegistry.java index 310fb21f8e..17f242d9b2 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/handler/DefaultSubscriptionRegistry.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/handler/DefaultSubscriptionRegistry.java @@ -21,6 +21,7 @@ import java.util.HashSet; import java.util.Map; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; import java.util.concurrent.CopyOnWriteArraySet; import org.springframework.messaging.Message; @@ -180,11 +181,9 @@ public class DefaultSubscriptionRegistry extends AbstractSubscriptionRegistry { */ private static class SessionSubscriptionRegistry { - private final Map sessions = + private final ConcurrentMap sessions = new ConcurrentHashMap(); - private final Object monitor = new Object(); - public SessionSubscriptionInfo getSubscriptions(String sessionId) { return this.sessions.get(sessionId); @@ -197,12 +196,10 @@ public class DefaultSubscriptionRegistry extends AbstractSubscriptionRegistry { public SessionSubscriptionInfo addSubscription(String sessionId, String subscriptionId, String destination) { SessionSubscriptionInfo info = this.sessions.get(sessionId); if (info == null) { - synchronized(this.monitor) { - info = this.sessions.get(sessionId); - if (info == null) { - info = new SessionSubscriptionInfo(sessionId); - this.sessions.put(sessionId, info); - } + info = new SessionSubscriptionInfo(sessionId); + SessionSubscriptionInfo value = this.sessions.putIfAbsent(sessionId, info); + if (value != null) { + info = value; } } info.addSubscription(destination, subscriptionId); @@ -249,14 +246,17 @@ public class DefaultSubscriptionRegistry extends AbstractSubscriptionRegistry { } public void addSubscription(String destination, String subscriptionId) { - synchronized(this.monitor) { - Set subs = this.subscriptions.get(destination); - if (subs == null) { - subs = new HashSet(4); - this.subscriptions.put(destination, subs); + Set subs = this.subscriptions.get(destination); + if (subs == null) { + synchronized(this.monitor) { + subs = this.subscriptions.get(destination); + if (subs == null) { + subs = new HashSet(4); + this.subscriptions.put(destination, subs); + } } - subs.add(subscriptionId); } + subs.add(subscriptionId); } public String removeSubscription(String subscriptionId) { diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/handler/SimpleUserSessionResolver.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/handler/SimpleUserSessionResolver.java index c00770f3ea..cd897f3024 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/handler/SimpleUserSessionResolver.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/handler/SimpleUserSessionResolver.java @@ -17,9 +17,9 @@ package org.springframework.messaging.simp.handler; import java.util.Collections; -import java.util.Map; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; import java.util.concurrent.CopyOnWriteArraySet; @@ -30,7 +30,7 @@ import java.util.concurrent.CopyOnWriteArraySet; public class SimpleUserSessionResolver implements MutableUserSessionResolver { // userId -> sessionId's - private final Map> userSessionIds = new ConcurrentHashMap>(); + private final ConcurrentMap> userSessionIds = new ConcurrentHashMap>(); @Override @@ -38,7 +38,10 @@ public class SimpleUserSessionResolver implements MutableUserSessionResolver { Set sessionIds = this.userSessionIds.get(user); if (sessionIds == null) { sessionIds = new CopyOnWriteArraySet(); - this.userSessionIds.put(user, sessionIds); + Set value = this.userSessionIds.putIfAbsent(user, sessionIds); + if (value != null) { + sessionIds = value; + } } sessionIds.add(sessionId); } @@ -47,8 +50,8 @@ public class SimpleUserSessionResolver implements MutableUserSessionResolver { public void removeUserSessionId(String user, String sessionId) { Set sessionIds = this.userSessionIds.get(user); if (sessionIds != null) { - if (sessionIds.remove(sessionId) && sessionIds.isEmpty()) { - this.userSessionIds.remove(user); + if (sessionIds.remove(sessionId)) { + this.userSessionIds.remove(user, Collections.emptySet()); } } } diff --git a/spring-messaging/src/test/java/org/springframework/messaging/simp/handler/SimpleUserSessionResolverTests.java b/spring-messaging/src/test/java/org/springframework/messaging/simp/handler/SimpleUserSessionResolverTests.java new file mode 100644 index 0000000000..51a9bc9ab0 --- /dev/null +++ b/spring-messaging/src/test/java/org/springframework/messaging/simp/handler/SimpleUserSessionResolverTests.java @@ -0,0 +1,82 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.messaging.simp.handler; + +import java.util.Arrays; +import java.util.Collections; +import java.util.LinkedHashSet; +import java.util.List; + +import org.junit.Test; + +import static org.junit.Assert.*; + + +/** + * Test fixture for {@link SimpleUserSessionResolver} + * + * @author Rossen Stoyanchev + * @since 4.0 + */ +public class SimpleUserSessionResolverTests { + + private static final String user = "joe"; + private static final List sessionIds = Arrays.asList("sess01", "sess02", "sess03"); + + + @Test + public void addOneSessionId() { + + SimpleUserSessionResolver resolver = new SimpleUserSessionResolver(); + resolver.addUserSessionId(user, sessionIds.get(0)); + + assertEquals(Collections.singleton(sessionIds.get(0)), resolver.resolveUserSessionIds(user)); + assertSame(Collections.emptySet(), resolver.resolveUserSessionIds("jane")); + } + + @Test + public void addMultipleSessionIds() { + + SimpleUserSessionResolver resolver = new SimpleUserSessionResolver(); + for (String sessionId : sessionIds) { + resolver.addUserSessionId(user, sessionId); + } + + assertEquals(new LinkedHashSet<>(sessionIds), resolver.resolveUserSessionIds(user)); + assertEquals(Collections.emptySet(), resolver.resolveUserSessionIds("jane")); + } + + + @Test + public void removeSessionIds() { + + SimpleUserSessionResolver resolver = new SimpleUserSessionResolver(); + for (String sessionId : sessionIds) { + resolver.addUserSessionId(user, sessionId); + } + + assertEquals(new LinkedHashSet<>(sessionIds), resolver.resolveUserSessionIds(user)); + + resolver.removeUserSessionId(user, sessionIds.get(1)); + resolver.removeUserSessionId(user, sessionIds.get(2)); + assertEquals(Collections.singleton(sessionIds.get(0)), resolver.resolveUserSessionIds(user)); + + resolver.removeUserSessionId(user, sessionIds.get(0)); + assertSame(Collections.emptySet(), resolver.resolveUserSessionIds(user)); + } + +} -- GitLab