提交 e2ee23bf 编写于 作者: R Rossen Stoyanchev

WebSession supports changeSessionId

Issue: SPR-15571
上级 70252a73
/*
* Copyright 2002-2016 the original author or authors.
* Copyright 2002-2017 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
......@@ -101,6 +101,14 @@ public interface WebSession {
*/
boolean isStarted();
/**
* Generate a new id for the session and update the underlying session
* storage to reflect the new id. After a successful call {@link #getId()}
* reflects the new session id.
* @return completion notification (success or error)
*/
Mono<Void> changeSessionId();
/**
* Save the session persisting attributes (e.g. if stored remotely) and also
* sending the session id to the client if the session is new.
......
......@@ -21,11 +21,13 @@ import java.time.Instant;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.BiFunction;
import java.util.function.Function;
import reactor.core.publisher.Mono;
import org.springframework.util.Assert;
import org.springframework.util.IdGenerator;
import org.springframework.web.server.WebSession;
/**
......@@ -36,12 +38,16 @@ import org.springframework.web.server.WebSession;
*/
class DefaultWebSession implements WebSession {
private final String id;
private final AtomicReference<String> id;
private final IdGenerator idGenerator;
private final Map<String, Object> attributes;
private final Clock clock;
private final BiFunction<String, WebSession, Mono<Void>> changeIdOperation;
private final Function<WebSession, Mono<Void>> saveOperation;
private final Instant creationTime;
......@@ -55,14 +61,22 @@ class DefaultWebSession implements WebSession {
/**
* Constructor for creating a brand, new session.
* @param id the session id
* @param idGenerator the session id generator
* @param clock for access to current time
*/
DefaultWebSession(String id, Clock clock, Function<WebSession, Mono<Void>> saveOperation) {
Assert.notNull(id, "'id' is required.");
DefaultWebSession(IdGenerator idGenerator, Clock clock,
BiFunction<String, WebSession, Mono<Void>> changeIdOperation,
Function<WebSession, Mono<Void>> saveOperation) {
Assert.notNull(idGenerator, "'idGenerator' is required.");
Assert.notNull(clock, "'clock' is required.");
this.id = id;
Assert.notNull(changeIdOperation, "'changeIdOperation' is required.");
Assert.notNull(saveOperation, "'saveOperation' is required.");
this.id = new AtomicReference<>(String.valueOf(idGenerator.generateId()));
this.idGenerator = idGenerator;
this.clock = clock;
this.changeIdOperation = changeIdOperation;
this.saveOperation = saveOperation;
this.attributes = new ConcurrentHashMap<>();
this.creationTime = Instant.now(clock);
......@@ -81,12 +95,14 @@ class DefaultWebSession implements WebSession {
Function<WebSession, Mono<Void>> saveOperation) {
this.id = existingSession.id;
this.idGenerator = existingSession.idGenerator;
this.attributes = existingSession.attributes;
this.clock = existingSession.clock;
this.changeIdOperation = existingSession.changeIdOperation;
this.saveOperation = saveOperation;
this.creationTime = existingSession.creationTime;
this.lastAccessTime = lastAccessTime;
this.maxIdleTime = existingSession.maxIdleTime;
this.saveOperation = saveOperation;
this.state = existingSession.state;
}
......@@ -95,19 +111,21 @@ class DefaultWebSession implements WebSession {
*/
DefaultWebSession(DefaultWebSession existingSession, Instant lastAccessTime) {
this.id = existingSession.id;
this.idGenerator = existingSession.idGenerator;
this.attributes = existingSession.attributes;
this.clock = existingSession.clock;
this.changeIdOperation = existingSession.changeIdOperation;
this.saveOperation = existingSession.saveOperation;
this.creationTime = existingSession.creationTime;
this.lastAccessTime = lastAccessTime;
this.maxIdleTime = existingSession.maxIdleTime;
this.saveOperation = existingSession.saveOperation;
this.state = existingSession.state;
}
@Override
public String getId() {
return this.id;
return this.id.get();
}
@Override
......@@ -151,6 +169,14 @@ class DefaultWebSession implements WebSession {
return (State.STARTED.equals(value) || (State.NEW.equals(value) && !getAttributes().isEmpty()));
}
@Override
public Mono<Void> changeSessionId() {
String oldId = this.id.get();
String newId = String.valueOf(this.idGenerator.generateId());
this.id.set(newId);
return this.changeIdOperation.apply(oldId, this).doOnError(ex -> this.id.set(oldId));
}
@Override
public Mono<Void> save() {
return this.saveOperation.apply(this);
......
/*
* Copyright 2002-2016 the original author or authors.
* Copyright 2002-2017 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
......@@ -19,12 +19,13 @@ import java.time.Clock;
import java.time.Instant;
import java.time.ZoneId;
import java.util.List;
import java.util.UUID;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import org.springframework.util.Assert;
import org.springframework.util.IdGenerator;
import org.springframework.util.JdkIdGenerator;
import org.springframework.web.server.ServerWebExchange;
import org.springframework.web.server.WebSession;
......@@ -39,6 +40,9 @@ import org.springframework.web.server.WebSession;
*/
public class DefaultWebSessionManager implements WebSessionManager {
private static final IdGenerator idGenerator = new JdkIdGenerator();
private WebSessionIdResolver sessionIdResolver = new CookieWebSessionIdResolver();
private WebSessionStore sessionStore = new InMemoryWebSessionStore();
......@@ -159,10 +163,10 @@ public class DefaultWebSessionManager implements WebSessionManager {
}
private Mono<DefaultWebSession> createSession(ServerWebExchange exchange) {
return Mono.fromSupplier(() -> {
String id = UUID.randomUUID().toString();
return new DefaultWebSession(id, getClock(), sess -> saveSession(exchange, sess));
});
return Mono.fromSupplier(() ->
new DefaultWebSession(idGenerator, getClock(),
(oldId, session) -> this.sessionStore.changeSessionId(oldId, session),
session -> saveSession(exchange, session)));
}
}
/*
* Copyright 2002-2016 the original author or authors.
* Copyright 2002-2017 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
......@@ -44,6 +44,13 @@ public class InMemoryWebSessionStore implements WebSessionStore {
return (this.sessions.containsKey(id) ? Mono.just(this.sessions.get(id)) : Mono.empty());
}
@Override
public Mono<Void> changeSessionId(String oldId, WebSession session) {
this.sessions.remove(oldId);
this.sessions.put(session.getId(), session);
return Mono.empty();
}
@Override
public Mono<Void> removeSession(String id) {
this.sessions.remove(id);
......
......@@ -41,6 +41,18 @@ public interface WebSessionStore {
*/
Mono<WebSession> retrieveSession(String sessionId);
/**
* Update WebSession data storage to reflect a change in session id.
* <p>Note that the same can be achieved via a combination of
* {@link #removeSession} + {@link #storeSession}. The purpose of this method
* is to allow a more efficient replacement of the session id mapping
* without replacing and storing the session with all of its data.
* @param oldId the previous session id
* @param session the session reflecting the changed session id
* @return completion notification (success or error)
*/
Mono<Void> changeSessionId(String oldId, WebSession session);
/**
* Remove the WebSession for the specified id.
* @param sessionId the id of the session to remove
......
/*
* Copyright 2002-2016 the original author or authors.
* Copyright 2002-2017 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
......@@ -18,6 +18,7 @@ package org.springframework.web.server.session;
import java.time.Clock;
import java.time.Duration;
import java.time.Instant;
import java.time.ZoneId;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
......@@ -31,6 +32,8 @@ import org.springframework.http.codec.ServerCodecConfigurer;
import org.springframework.lang.Nullable;
import org.springframework.mock.http.server.reactive.test.MockServerHttpRequest;
import org.springframework.mock.http.server.reactive.test.MockServerHttpResponse;
import org.springframework.util.IdGenerator;
import org.springframework.util.JdkIdGenerator;
import org.springframework.web.server.ServerWebExchange;
import org.springframework.web.server.WebSession;
import org.springframework.web.server.adapter.DefaultServerWebExchange;
......@@ -44,10 +47,16 @@ import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertSame;
/**
* Unit tests for {@link DefaultWebSessionManager}.
* @author Rossen Stoyanchev
*/
public class DefaultWebSessionManagerTests {
private static final Clock CLOCK = Clock.system(ZoneId.of("GMT"));
private static final IdGenerator idGenerator = new JdkIdGenerator();
private DefaultWebSessionManager manager;
private TestWebSessionIdResolver idResolver;
......@@ -105,9 +114,10 @@ public class DefaultWebSessionManagerTests {
@Test
public void existingSession() throws Exception {
DefaultWebSession existing = new DefaultWebSession("1", Clock.systemDefaultZone(), s -> Mono.empty());
DefaultWebSession existing = createDefaultWebSession();
String id = existing.getId();
this.manager.getSessionStore().storeSession(existing);
this.idResolver.setIdsToResolve(Collections.singletonList("1"));
this.idResolver.setIdsToResolve(Collections.singletonList(id));
WebSession actual = this.manager.getSession(this.exchange).block();
assertNotNull(actual);
......@@ -116,10 +126,9 @@ public class DefaultWebSessionManagerTests {
@Test
public void existingSessionIsExpired() throws Exception {
Clock clock = Clock.systemDefaultZone();
DefaultWebSession existing = new DefaultWebSession("1", clock, s -> Mono.empty());
DefaultWebSession existing = createDefaultWebSession();
existing.start();
Instant lastAccessTime = Instant.now(clock).minus(Duration.ofMinutes(31));
Instant lastAccessTime = Instant.now(CLOCK).minus(Duration.ofMinutes(31));
existing = new DefaultWebSession(existing, lastAccessTime, s -> Mono.empty());
this.manager.getSessionStore().storeSession(existing);
this.idResolver.setIdsToResolve(Collections.singletonList("1"));
......@@ -129,16 +138,21 @@ public class DefaultWebSessionManagerTests {
}
@Test
public void multipleSessions() throws Exception {
DefaultWebSession existing = new DefaultWebSession("3", Clock.systemDefaultZone(), s -> Mono.empty());
public void multipleSessionIds() throws Exception {
DefaultWebSession existing = createDefaultWebSession();
String id = existing.getId();
this.manager.getSessionStore().storeSession(existing);
this.idResolver.setIdsToResolve(Arrays.asList("1", "2", "3"));
this.idResolver.setIdsToResolve(Arrays.asList("neither-this", "nor-that", id));
WebSession actual = this.manager.getSession(this.exchange).block();
assertNotNull(actual);
assertEquals(existing.getId(), actual.getId());
}
private DefaultWebSession createDefaultWebSession() {
return new DefaultWebSession(idGenerator, CLOCK, (s, session) -> Mono.empty(), s -> Mono.empty());
}
private static class TestWebSessionIdResolver implements WebSessionIdResolver {
......
......@@ -22,7 +22,6 @@ import java.time.Clock;
import java.time.Duration;
import java.time.Instant;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicInteger;
import org.junit.Test;
......@@ -34,19 +33,19 @@ import org.springframework.http.RequestEntity;
import org.springframework.http.ResponseEntity;
import org.springframework.http.server.reactive.AbstractHttpHandlerIntegrationTests;
import org.springframework.http.server.reactive.HttpHandler;
import org.springframework.util.StringUtils;
import org.springframework.web.client.RestTemplate;
import org.springframework.web.server.ServerWebExchange;
import org.springframework.web.server.WebHandler;
import org.springframework.web.server.WebSession;
import org.springframework.web.server.adapter.WebHttpHandlerBuilder;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull;
/**
* Integration tests for with a server-side session.
*
* @author Rossen Stoyanchev
*/
public class WebSessionIntegrationTests extends AbstractHttpHandlerIntegrationTests {
......@@ -64,12 +63,6 @@ public class WebSessionIntegrationTests extends AbstractHttpHandlerIntegrationTe
this.restTemplate = new RestTemplate();
}
private URI createUri(String pathAndQuery) throws URISyntaxException {
boolean prefix = !StringUtils.hasText(pathAndQuery) || !pathAndQuery.startsWith("/");
pathAndQuery = (prefix ? "/" + pathAndQuery : pathAndQuery);
return new URI("http://localhost:" + port + pathAndQuery);
}
@Override
protected HttpHandler createHttpHandler() {
this.sessionManager = new DefaultWebSessionManager();
......@@ -77,45 +70,46 @@ public class WebSessionIntegrationTests extends AbstractHttpHandlerIntegrationTe
return WebHttpHandlerBuilder.webHandler(this.handler).sessionManager(this.sessionManager).build();
}
@Test
public void createSession() throws Exception {
RequestEntity<Void> request = RequestEntity.get(createUri("/")).build();
RequestEntity<Void> request = RequestEntity.get(createUri()).build();
ResponseEntity<Void> response = this.restTemplate.exchange(request, Void.class);
assertEquals(HttpStatus.OK, response.getStatusCode());
String id = extractSessionId(response.getHeaders());
assertNotNull(id);
assertEquals(1, this.handler.getCount());
assertEquals(1, this.handler.getSessionRequestCount());
request = RequestEntity.get(createUri("/")).header("Cookie", "SESSION=" + id).build();
request = RequestEntity.get(createUri()).header("Cookie", "SESSION=" + id).build();
response = this.restTemplate.exchange(request, Void.class);
assertEquals(HttpStatus.OK, response.getStatusCode());
assertNull(response.getHeaders().get("Set-Cookie"));
assertEquals(2, this.handler.getCount());
assertEquals(2, this.handler.getSessionRequestCount());
}
@Test
public void expiredSession() throws Exception {
public void expiredSessionIsRecreated() throws Exception {
// First request: no session yet, new session created
RequestEntity<Void> request = RequestEntity.get(createUri("/")).build();
RequestEntity<Void> request = RequestEntity.get(createUri()).build();
ResponseEntity<Void> response = this.restTemplate.exchange(request, Void.class);
assertEquals(HttpStatus.OK, response.getStatusCode());
String id = extractSessionId(response.getHeaders());
assertNotNull(id);
assertEquals(1, this.handler.getCount());
assertEquals(1, this.handler.getSessionRequestCount());
// Second request: same session
request = RequestEntity.get(createUri("/")).header("Cookie", "SESSION=" + id).build();
request = RequestEntity.get(createUri()).header("Cookie", "SESSION=" + id).build();
response = this.restTemplate.exchange(request, Void.class);
assertEquals(HttpStatus.OK, response.getStatusCode());
assertNull(response.getHeaders().get("Set-Cookie"));
assertEquals(2, this.handler.getCount());
assertEquals(2, this.handler.getSessionRequestCount());
// Update lastAccessTime of the created session to -31 min
// Now set the clock of the session back by 31 minutes
WebSessionStore store = this.sessionManager.getSessionStore();
DefaultWebSession session = (DefaultWebSession) store.retrieveSession(id).block();
assertNotNull(session);
......@@ -124,13 +118,37 @@ public class WebSessionIntegrationTests extends AbstractHttpHandlerIntegrationTe
store.storeSession(session);
// Third request: expired session, new session created
request = RequestEntity.get(createUri("/")).header("Cookie", "SESSION=" + id).build();
request = RequestEntity.get(createUri()).header("Cookie", "SESSION=" + id).build();
response = this.restTemplate.exchange(request, Void.class);
assertEquals(HttpStatus.OK, response.getStatusCode());
id = extractSessionId(response.getHeaders());
assertNotNull("Expected new session id", id);
assertEquals("Expected new session attribute", 1, this.handler.getCount());
assertEquals(1, this.handler.getSessionRequestCount());
}
@Test
public void changeSessionId() throws Exception {
// First request: no session yet, new session created
RequestEntity<Void> request = RequestEntity.get(createUri()).build();
ResponseEntity<Void> response = this.restTemplate.exchange(request, Void.class);
assertEquals(HttpStatus.OK, response.getStatusCode());
String oldId = extractSessionId(response.getHeaders());
assertNotNull(oldId);
assertEquals(1, this.handler.getSessionRequestCount());
// Second request: session id changes
URI uri = new URI("http://localhost:" + this.port + "/?changeId");
request = RequestEntity.get(uri).header("Cookie", "SESSION=" + oldId).build();
response = this.restTemplate.exchange(request, Void.class);
assertEquals(HttpStatus.OK, response.getStatusCode());
String newId = extractSessionId(response.getHeaders());
assertNotNull("Expected new session id", newId);
assertNotEquals(oldId, newId);
assertEquals(2, this.handler.getSessionRequestCount());
}
private String extractSessionId(HttpHeaders headers) {
......@@ -146,25 +164,33 @@ public class WebSessionIntegrationTests extends AbstractHttpHandlerIntegrationTe
return null;
}
private URI createUri() throws URISyntaxException {
return new URI("http://localhost:" + this.port + "/");
}
private static class TestWebHandler implements WebHandler {
private AtomicInteger currentValue = new AtomicInteger();
public int getCount() {
public int getSessionRequestCount() {
return this.currentValue.get();
}
@Override
public Mono<Void> handle(ServerWebExchange exchange) {
return exchange.getSession().map(session -> {
Map<String, Object> map = session.getAttributes();
int value = (map.get("counter") != null ? (int) map.get("counter") : 0);
value++;
map.put("counter", value);
this.currentValue.set(value);
return session;
}).then();
if (exchange.getRequest().getQueryParams().containsKey("changeId")) {
return exchange.getSession().flatMap(session ->
session.changeSessionId().doOnSuccess(aVoid -> updateSessionAttribute(session)));
}
return exchange.getSession().doOnSuccess(this::updateSessionAttribute).then();
}
private void updateSessionAttribute(WebSession session) {
int value = session.getAttributeOrDefault("counter", 0);
session.getAttributes().put("counter", ++value);
this.currentValue.set(value);
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册