From e2ee23bfc597c667d3f5792eb319666723cf3014 Mon Sep 17 00:00:00 2001 From: Rossen Stoyanchev Date: Mon, 17 Jul 2017 11:09:38 +0200 Subject: [PATCH] WebSession supports changeSessionId Issue: SPR-15571 --- .../web/server/WebSession.java | 10 ++- .../web/server/session/DefaultWebSession.java | 42 +++++++-- .../session/DefaultWebSessionManager.java | 16 ++-- .../session/InMemoryWebSessionStore.java | 9 +- .../web/server/session/WebSessionStore.java | 12 +++ .../DefaultWebSessionManagerTests.java | 32 +++++-- .../session/WebSessionIntegrationTests.java | 86 ++++++++++++------- 7 files changed, 152 insertions(+), 55 deletions(-) diff --git a/spring-web/src/main/java/org/springframework/web/server/WebSession.java b/spring-web/src/main/java/org/springframework/web/server/WebSession.java index b85ae5a558..8cbc0fdea9 100644 --- a/spring-web/src/main/java/org/springframework/web/server/WebSession.java +++ b/spring-web/src/main/java/org/springframework/web/server/WebSession.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2016 the original author or authors. + * Copyright 2002-2017 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -101,6 +101,14 @@ public interface WebSession { */ boolean isStarted(); + /** + * Generate a new id for the session and update the underlying session + * storage to reflect the new id. After a successful call {@link #getId()} + * reflects the new session id. + * @return completion notification (success or error) + */ + Mono changeSessionId(); + /** * Save the session persisting attributes (e.g. if stored remotely) and also * sending the session id to the client if the session is new. diff --git a/spring-web/src/main/java/org/springframework/web/server/session/DefaultWebSession.java b/spring-web/src/main/java/org/springframework/web/server/session/DefaultWebSession.java index 77d3f142f3..e6856a0699 100644 --- a/spring-web/src/main/java/org/springframework/web/server/session/DefaultWebSession.java +++ b/spring-web/src/main/java/org/springframework/web/server/session/DefaultWebSession.java @@ -21,11 +21,13 @@ import java.time.Instant; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicReference; +import java.util.function.BiFunction; import java.util.function.Function; import reactor.core.publisher.Mono; import org.springframework.util.Assert; +import org.springframework.util.IdGenerator; import org.springframework.web.server.WebSession; /** @@ -36,12 +38,16 @@ import org.springframework.web.server.WebSession; */ class DefaultWebSession implements WebSession { - private final String id; + private final AtomicReference id; + + private final IdGenerator idGenerator; private final Map attributes; private final Clock clock; + private final BiFunction> changeIdOperation; + private final Function> saveOperation; private final Instant creationTime; @@ -55,14 +61,22 @@ class DefaultWebSession implements WebSession { /** * Constructor for creating a brand, new session. - * @param id the session id + * @param idGenerator the session id generator * @param clock for access to current time */ - DefaultWebSession(String id, Clock clock, Function> saveOperation) { - Assert.notNull(id, "'id' is required."); + DefaultWebSession(IdGenerator idGenerator, Clock clock, + BiFunction> changeIdOperation, + Function> saveOperation) { + + Assert.notNull(idGenerator, "'idGenerator' is required."); Assert.notNull(clock, "'clock' is required."); - this.id = id; + Assert.notNull(changeIdOperation, "'changeIdOperation' is required."); + Assert.notNull(saveOperation, "'saveOperation' is required."); + + this.id = new AtomicReference<>(String.valueOf(idGenerator.generateId())); + this.idGenerator = idGenerator; this.clock = clock; + this.changeIdOperation = changeIdOperation; this.saveOperation = saveOperation; this.attributes = new ConcurrentHashMap<>(); this.creationTime = Instant.now(clock); @@ -81,12 +95,14 @@ class DefaultWebSession implements WebSession { Function> saveOperation) { this.id = existingSession.id; + this.idGenerator = existingSession.idGenerator; this.attributes = existingSession.attributes; this.clock = existingSession.clock; + this.changeIdOperation = existingSession.changeIdOperation; + this.saveOperation = saveOperation; this.creationTime = existingSession.creationTime; this.lastAccessTime = lastAccessTime; this.maxIdleTime = existingSession.maxIdleTime; - this.saveOperation = saveOperation; this.state = existingSession.state; } @@ -95,19 +111,21 @@ class DefaultWebSession implements WebSession { */ DefaultWebSession(DefaultWebSession existingSession, Instant lastAccessTime) { this.id = existingSession.id; + this.idGenerator = existingSession.idGenerator; this.attributes = existingSession.attributes; this.clock = existingSession.clock; + this.changeIdOperation = existingSession.changeIdOperation; + this.saveOperation = existingSession.saveOperation; this.creationTime = existingSession.creationTime; this.lastAccessTime = lastAccessTime; this.maxIdleTime = existingSession.maxIdleTime; - this.saveOperation = existingSession.saveOperation; this.state = existingSession.state; } @Override public String getId() { - return this.id; + return this.id.get(); } @Override @@ -151,6 +169,14 @@ class DefaultWebSession implements WebSession { return (State.STARTED.equals(value) || (State.NEW.equals(value) && !getAttributes().isEmpty())); } + @Override + public Mono changeSessionId() { + String oldId = this.id.get(); + String newId = String.valueOf(this.idGenerator.generateId()); + this.id.set(newId); + return this.changeIdOperation.apply(oldId, this).doOnError(ex -> this.id.set(oldId)); + } + @Override public Mono save() { return this.saveOperation.apply(this); diff --git a/spring-web/src/main/java/org/springframework/web/server/session/DefaultWebSessionManager.java b/spring-web/src/main/java/org/springframework/web/server/session/DefaultWebSessionManager.java index 75190de0d0..e950209070 100644 --- a/spring-web/src/main/java/org/springframework/web/server/session/DefaultWebSessionManager.java +++ b/spring-web/src/main/java/org/springframework/web/server/session/DefaultWebSessionManager.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2016 the original author or authors. + * Copyright 2002-2017 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -19,12 +19,13 @@ import java.time.Clock; import java.time.Instant; import java.time.ZoneId; import java.util.List; -import java.util.UUID; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; import org.springframework.util.Assert; +import org.springframework.util.IdGenerator; +import org.springframework.util.JdkIdGenerator; import org.springframework.web.server.ServerWebExchange; import org.springframework.web.server.WebSession; @@ -39,6 +40,9 @@ import org.springframework.web.server.WebSession; */ public class DefaultWebSessionManager implements WebSessionManager { + private static final IdGenerator idGenerator = new JdkIdGenerator(); + + private WebSessionIdResolver sessionIdResolver = new CookieWebSessionIdResolver(); private WebSessionStore sessionStore = new InMemoryWebSessionStore(); @@ -159,10 +163,10 @@ public class DefaultWebSessionManager implements WebSessionManager { } private Mono createSession(ServerWebExchange exchange) { - return Mono.fromSupplier(() -> { - String id = UUID.randomUUID().toString(); - return new DefaultWebSession(id, getClock(), sess -> saveSession(exchange, sess)); - }); + return Mono.fromSupplier(() -> + new DefaultWebSession(idGenerator, getClock(), + (oldId, session) -> this.sessionStore.changeSessionId(oldId, session), + session -> saveSession(exchange, session))); } } diff --git a/spring-web/src/main/java/org/springframework/web/server/session/InMemoryWebSessionStore.java b/spring-web/src/main/java/org/springframework/web/server/session/InMemoryWebSessionStore.java index 60fbed7ed2..dd193faf5e 100644 --- a/spring-web/src/main/java/org/springframework/web/server/session/InMemoryWebSessionStore.java +++ b/spring-web/src/main/java/org/springframework/web/server/session/InMemoryWebSessionStore.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2016 the original author or authors. + * Copyright 2002-2017 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -44,6 +44,13 @@ public class InMemoryWebSessionStore implements WebSessionStore { return (this.sessions.containsKey(id) ? Mono.just(this.sessions.get(id)) : Mono.empty()); } + @Override + public Mono changeSessionId(String oldId, WebSession session) { + this.sessions.remove(oldId); + this.sessions.put(session.getId(), session); + return Mono.empty(); + } + @Override public Mono removeSession(String id) { this.sessions.remove(id); diff --git a/spring-web/src/main/java/org/springframework/web/server/session/WebSessionStore.java b/spring-web/src/main/java/org/springframework/web/server/session/WebSessionStore.java index e6461c311d..b52cafe546 100644 --- a/spring-web/src/main/java/org/springframework/web/server/session/WebSessionStore.java +++ b/spring-web/src/main/java/org/springframework/web/server/session/WebSessionStore.java @@ -41,6 +41,18 @@ public interface WebSessionStore { */ Mono retrieveSession(String sessionId); + /** + * Update WebSession data storage to reflect a change in session id. + *

Note that the same can be achieved via a combination of + * {@link #removeSession} + {@link #storeSession}. The purpose of this method + * is to allow a more efficient replacement of the session id mapping + * without replacing and storing the session with all of its data. + * @param oldId the previous session id + * @param session the session reflecting the changed session id + * @return completion notification (success or error) + */ + Mono changeSessionId(String oldId, WebSession session); + /** * Remove the WebSession for the specified id. * @param sessionId the id of the session to remove diff --git a/spring-web/src/test/java/org/springframework/web/server/session/DefaultWebSessionManagerTests.java b/spring-web/src/test/java/org/springframework/web/server/session/DefaultWebSessionManagerTests.java index 82d47f0ea3..d05cd7fd04 100644 --- a/spring-web/src/test/java/org/springframework/web/server/session/DefaultWebSessionManagerTests.java +++ b/spring-web/src/test/java/org/springframework/web/server/session/DefaultWebSessionManagerTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2016 the original author or authors. + * Copyright 2002-2017 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -18,6 +18,7 @@ package org.springframework.web.server.session; import java.time.Clock; import java.time.Duration; import java.time.Instant; +import java.time.ZoneId; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; @@ -31,6 +32,8 @@ import org.springframework.http.codec.ServerCodecConfigurer; import org.springframework.lang.Nullable; import org.springframework.mock.http.server.reactive.test.MockServerHttpRequest; import org.springframework.mock.http.server.reactive.test.MockServerHttpResponse; +import org.springframework.util.IdGenerator; +import org.springframework.util.JdkIdGenerator; import org.springframework.web.server.ServerWebExchange; import org.springframework.web.server.WebSession; import org.springframework.web.server.adapter.DefaultServerWebExchange; @@ -44,10 +47,16 @@ import static org.junit.Assert.assertNull; import static org.junit.Assert.assertSame; /** + * Unit tests for {@link DefaultWebSessionManager}. * @author Rossen Stoyanchev */ public class DefaultWebSessionManagerTests { + private static final Clock CLOCK = Clock.system(ZoneId.of("GMT")); + + private static final IdGenerator idGenerator = new JdkIdGenerator(); + + private DefaultWebSessionManager manager; private TestWebSessionIdResolver idResolver; @@ -105,9 +114,10 @@ public class DefaultWebSessionManagerTests { @Test public void existingSession() throws Exception { - DefaultWebSession existing = new DefaultWebSession("1", Clock.systemDefaultZone(), s -> Mono.empty()); + DefaultWebSession existing = createDefaultWebSession(); + String id = existing.getId(); this.manager.getSessionStore().storeSession(existing); - this.idResolver.setIdsToResolve(Collections.singletonList("1")); + this.idResolver.setIdsToResolve(Collections.singletonList(id)); WebSession actual = this.manager.getSession(this.exchange).block(); assertNotNull(actual); @@ -116,10 +126,9 @@ public class DefaultWebSessionManagerTests { @Test public void existingSessionIsExpired() throws Exception { - Clock clock = Clock.systemDefaultZone(); - DefaultWebSession existing = new DefaultWebSession("1", clock, s -> Mono.empty()); + DefaultWebSession existing = createDefaultWebSession(); existing.start(); - Instant lastAccessTime = Instant.now(clock).minus(Duration.ofMinutes(31)); + Instant lastAccessTime = Instant.now(CLOCK).minus(Duration.ofMinutes(31)); existing = new DefaultWebSession(existing, lastAccessTime, s -> Mono.empty()); this.manager.getSessionStore().storeSession(existing); this.idResolver.setIdsToResolve(Collections.singletonList("1")); @@ -129,16 +138,21 @@ public class DefaultWebSessionManagerTests { } @Test - public void multipleSessions() throws Exception { - DefaultWebSession existing = new DefaultWebSession("3", Clock.systemDefaultZone(), s -> Mono.empty()); + public void multipleSessionIds() throws Exception { + DefaultWebSession existing = createDefaultWebSession(); + String id = existing.getId(); this.manager.getSessionStore().storeSession(existing); - this.idResolver.setIdsToResolve(Arrays.asList("1", "2", "3")); + this.idResolver.setIdsToResolve(Arrays.asList("neither-this", "nor-that", id)); WebSession actual = this.manager.getSession(this.exchange).block(); assertNotNull(actual); assertEquals(existing.getId(), actual.getId()); } + private DefaultWebSession createDefaultWebSession() { + return new DefaultWebSession(idGenerator, CLOCK, (s, session) -> Mono.empty(), s -> Mono.empty()); + } + private static class TestWebSessionIdResolver implements WebSessionIdResolver { diff --git a/spring-web/src/test/java/org/springframework/web/server/session/WebSessionIntegrationTests.java b/spring-web/src/test/java/org/springframework/web/server/session/WebSessionIntegrationTests.java index 26b1adc2a5..705044447f 100644 --- a/spring-web/src/test/java/org/springframework/web/server/session/WebSessionIntegrationTests.java +++ b/spring-web/src/test/java/org/springframework/web/server/session/WebSessionIntegrationTests.java @@ -22,7 +22,6 @@ import java.time.Clock; import java.time.Duration; import java.time.Instant; import java.util.List; -import java.util.Map; import java.util.concurrent.atomic.AtomicInteger; import org.junit.Test; @@ -34,19 +33,19 @@ import org.springframework.http.RequestEntity; import org.springframework.http.ResponseEntity; import org.springframework.http.server.reactive.AbstractHttpHandlerIntegrationTests; import org.springframework.http.server.reactive.HttpHandler; -import org.springframework.util.StringUtils; import org.springframework.web.client.RestTemplate; import org.springframework.web.server.ServerWebExchange; import org.springframework.web.server.WebHandler; +import org.springframework.web.server.WebSession; import org.springframework.web.server.adapter.WebHttpHandlerBuilder; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotEquals; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNull; /** * Integration tests for with a server-side session. - * * @author Rossen Stoyanchev */ public class WebSessionIntegrationTests extends AbstractHttpHandlerIntegrationTests { @@ -64,12 +63,6 @@ public class WebSessionIntegrationTests extends AbstractHttpHandlerIntegrationTe this.restTemplate = new RestTemplate(); } - private URI createUri(String pathAndQuery) throws URISyntaxException { - boolean prefix = !StringUtils.hasText(pathAndQuery) || !pathAndQuery.startsWith("/"); - pathAndQuery = (prefix ? "/" + pathAndQuery : pathAndQuery); - return new URI("http://localhost:" + port + pathAndQuery); - } - @Override protected HttpHandler createHttpHandler() { this.sessionManager = new DefaultWebSessionManager(); @@ -77,45 +70,46 @@ public class WebSessionIntegrationTests extends AbstractHttpHandlerIntegrationTe return WebHttpHandlerBuilder.webHandler(this.handler).sessionManager(this.sessionManager).build(); } + @Test public void createSession() throws Exception { - RequestEntity request = RequestEntity.get(createUri("/")).build(); + RequestEntity request = RequestEntity.get(createUri()).build(); ResponseEntity response = this.restTemplate.exchange(request, Void.class); assertEquals(HttpStatus.OK, response.getStatusCode()); String id = extractSessionId(response.getHeaders()); assertNotNull(id); - assertEquals(1, this.handler.getCount()); + assertEquals(1, this.handler.getSessionRequestCount()); - request = RequestEntity.get(createUri("/")).header("Cookie", "SESSION=" + id).build(); + request = RequestEntity.get(createUri()).header("Cookie", "SESSION=" + id).build(); response = this.restTemplate.exchange(request, Void.class); assertEquals(HttpStatus.OK, response.getStatusCode()); assertNull(response.getHeaders().get("Set-Cookie")); - assertEquals(2, this.handler.getCount()); + assertEquals(2, this.handler.getSessionRequestCount()); } @Test - public void expiredSession() throws Exception { + public void expiredSessionIsRecreated() throws Exception { // First request: no session yet, new session created - RequestEntity request = RequestEntity.get(createUri("/")).build(); + RequestEntity request = RequestEntity.get(createUri()).build(); ResponseEntity response = this.restTemplate.exchange(request, Void.class); assertEquals(HttpStatus.OK, response.getStatusCode()); String id = extractSessionId(response.getHeaders()); assertNotNull(id); - assertEquals(1, this.handler.getCount()); + assertEquals(1, this.handler.getSessionRequestCount()); // Second request: same session - request = RequestEntity.get(createUri("/")).header("Cookie", "SESSION=" + id).build(); + request = RequestEntity.get(createUri()).header("Cookie", "SESSION=" + id).build(); response = this.restTemplate.exchange(request, Void.class); assertEquals(HttpStatus.OK, response.getStatusCode()); assertNull(response.getHeaders().get("Set-Cookie")); - assertEquals(2, this.handler.getCount()); + assertEquals(2, this.handler.getSessionRequestCount()); - // Update lastAccessTime of the created session to -31 min + // Now set the clock of the session back by 31 minutes WebSessionStore store = this.sessionManager.getSessionStore(); DefaultWebSession session = (DefaultWebSession) store.retrieveSession(id).block(); assertNotNull(session); @@ -124,13 +118,37 @@ public class WebSessionIntegrationTests extends AbstractHttpHandlerIntegrationTe store.storeSession(session); // Third request: expired session, new session created - request = RequestEntity.get(createUri("/")).header("Cookie", "SESSION=" + id).build(); + request = RequestEntity.get(createUri()).header("Cookie", "SESSION=" + id).build(); response = this.restTemplate.exchange(request, Void.class); assertEquals(HttpStatus.OK, response.getStatusCode()); id = extractSessionId(response.getHeaders()); assertNotNull("Expected new session id", id); - assertEquals("Expected new session attribute", 1, this.handler.getCount()); + assertEquals(1, this.handler.getSessionRequestCount()); + } + + @Test + public void changeSessionId() throws Exception { + + // First request: no session yet, new session created + RequestEntity request = RequestEntity.get(createUri()).build(); + ResponseEntity response = this.restTemplate.exchange(request, Void.class); + + assertEquals(HttpStatus.OK, response.getStatusCode()); + String oldId = extractSessionId(response.getHeaders()); + assertNotNull(oldId); + assertEquals(1, this.handler.getSessionRequestCount()); + + // Second request: session id changes + URI uri = new URI("http://localhost:" + this.port + "/?changeId"); + request = RequestEntity.get(uri).header("Cookie", "SESSION=" + oldId).build(); + response = this.restTemplate.exchange(request, Void.class); + + assertEquals(HttpStatus.OK, response.getStatusCode()); + String newId = extractSessionId(response.getHeaders()); + assertNotNull("Expected new session id", newId); + assertNotEquals(oldId, newId); + assertEquals(2, this.handler.getSessionRequestCount()); } private String extractSessionId(HttpHeaders headers) { @@ -146,25 +164,33 @@ public class WebSessionIntegrationTests extends AbstractHttpHandlerIntegrationTe return null; } + private URI createUri() throws URISyntaxException { + return new URI("http://localhost:" + this.port + "/"); + } + + private static class TestWebHandler implements WebHandler { private AtomicInteger currentValue = new AtomicInteger(); - public int getCount() { + public int getSessionRequestCount() { return this.currentValue.get(); } @Override public Mono handle(ServerWebExchange exchange) { - return exchange.getSession().map(session -> { - Map map = session.getAttributes(); - int value = (map.get("counter") != null ? (int) map.get("counter") : 0); - value++; - map.put("counter", value); - this.currentValue.set(value); - return session; - }).then(); + if (exchange.getRequest().getQueryParams().containsKey("changeId")) { + return exchange.getSession().flatMap(session -> + session.changeSessionId().doOnSuccess(aVoid -> updateSessionAttribute(session))); + } + return exchange.getSession().doOnSuccess(this::updateSessionAttribute).then(); + } + + private void updateSessionAttribute(WebSession session) { + int value = session.getAttributeOrDefault("counter", 0); + session.getAttributes().put("counter", ++value); + this.currentValue.set(value); } } -- GitLab