提交 7de6cfa1 编写于 作者: R Rossen Stoyanchev

Refactor WebSession#getAttribute options

Issue: SPR-15718
上级 9253facf
...@@ -43,9 +43,9 @@ public class MockServerIntegrationTests { ...@@ -43,9 +43,9 @@ public class MockServerIntegrationTests {
} }
else { else {
return exchange.getSession() return exchange.getSession()
.map(session -> session.getAttribute("foo").orElse("none")) .map(session -> session.getAttributeOrDefault("foo", "none"))
.flatMap(value -> { .flatMap(value -> {
byte[] bytes = value.toString().getBytes(UTF_8); byte[] bytes = value.getBytes(UTF_8);
DataBuffer buffer = new DefaultDataBufferFactory().wrap(bytes); DataBuffer buffer = new DefaultDataBufferFactory().wrap(bytes);
return exchange.getResponse().writeWith(Mono.just(buffer)); return exchange.getResponse().writeWith(Mono.just(buffer));
}); });
......
...@@ -19,10 +19,12 @@ package org.springframework.web.server; ...@@ -19,10 +19,12 @@ package org.springframework.web.server;
import java.time.Duration; import java.time.Duration;
import java.time.Instant; import java.time.Instant;
import java.util.Map; import java.util.Map;
import java.util.Optional;
import reactor.core.publisher.Mono; import reactor.core.publisher.Mono;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
/** /**
* Main contract for using a server-side session that provides access to session * Main contract for using a server-side session that provides access to session
* attributes across HTTP requests. * attributes across HTTP requests.
...@@ -48,12 +50,42 @@ public interface WebSession { ...@@ -48,12 +50,42 @@ public interface WebSession {
Map<String, Object> getAttributes(); Map<String, Object> getAttributes();
/** /**
* Return the attribute value if present. * Return the session attribute value if present.
* @param name the attribute name
* @param <T> the attribute type
* @return the attribute value
*/
@SuppressWarnings("unchecked")
@Nullable
default <T> T getAttribute(String name) {
return (T) getAttributes().get(name);
}
/**
* Return the session attribute value or if not present raise an
* {@link IllegalArgumentException}.
* @param name the attribute name
* @param <T> the attribute type
* @return the attribute value
*/
@SuppressWarnings("unchecked")
default <T> T getRequiredAttribute(String name) {
T value = getAttribute(name);
Assert.notNull(value, "Required attribute '" + name + "' is missing.");
return value;
}
/**
* Return the session attribute value, or a default, fallback value.
* @param name the attribute name * @param name the attribute name
* @param defaultValue a default value to return instead
* @param <T> the attribute type * @param <T> the attribute type
* @return the attribute value * @return the attribute value
*/ */
<T> Optional<T> getAttribute(String name); @SuppressWarnings("unchecked")
default <T> T getAttributeOrDefault(String name, T defaultValue) {
return (T) getAttributes().getOrDefault(name, defaultValue);
}
/** /**
* Force the creation of a session causing the session id to be sent when * Force the creation of a session causing the session id to be sent when
......
...@@ -20,7 +20,6 @@ import java.time.Clock; ...@@ -20,7 +20,6 @@ import java.time.Clock;
import java.time.Duration; import java.time.Duration;
import java.time.Instant; import java.time.Instant;
import java.util.Map; import java.util.Map;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicReference; import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Supplier; import java.util.function.Supplier;
...@@ -108,11 +107,6 @@ public class DefaultWebSession implements ConfigurableWebSession, Serializable { ...@@ -108,11 +107,6 @@ public class DefaultWebSession implements ConfigurableWebSession, Serializable {
return this.attributes; return this.attributes;
} }
@Override @SuppressWarnings("unchecked")
public <T> Optional<T> getAttribute(String name) {
return Optional.ofNullable((T) this.attributes.get(name));
}
@Override @Override
public Instant getCreationTime() { public Instant getCreationTime() {
return this.creationTime; return this.creationTime;
......
...@@ -58,9 +58,8 @@ public class SessionAttributeMethodArgumentResolver extends AbstractNamedValueAr ...@@ -58,9 +58,8 @@ public class SessionAttributeMethodArgumentResolver extends AbstractNamedValueAr
@Override @Override
protected Mono<Object> resolveName(String name, MethodParameter parameter, ServerWebExchange exchange) { protected Mono<Object> resolveName(String name, MethodParameter parameter, ServerWebExchange exchange) {
return exchange.getSession() return exchange.getSession()
.map(session -> session.getAttribute(name)) .filter(session -> session.getAttribute(name) != null)
.filter(Optional::isPresent) .map(session -> session.getAttribute(name));
.map(Optional::get);
} }
@Override @Override
......
...@@ -47,7 +47,6 @@ import org.springframework.web.server.i18n.AcceptHeaderLocaleContextResolver; ...@@ -47,7 +47,6 @@ import org.springframework.web.server.i18n.AcceptHeaderLocaleContextResolver;
import org.springframework.web.server.session.MockWebSessionManager; import org.springframework.web.server.session.MockWebSessionManager;
import org.springframework.web.server.session.WebSessionManager; import org.springframework.web.server.session.WebSessionManager;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull; import static org.junit.Assert.assertNull;
...@@ -103,7 +102,7 @@ public class SessionAttributeMethodArgumentResolverTests { ...@@ -103,7 +102,7 @@ public class SessionAttributeMethodArgumentResolverTests {
StepVerifier.create(mono).expectError(ServerWebInputException.class).verify(); StepVerifier.create(mono).expectError(ServerWebInputException.class).verify();
Foo foo = new Foo(); Foo foo = new Foo();
when(this.session.getAttribute("foo")).thenReturn(Optional.of(foo)); when(this.session.getAttribute("foo")).thenReturn(foo);
mono = this.resolver.resolveArgument(param, new BindingContext(), this.exchange); mono = this.resolver.resolveArgument(param, new BindingContext(), this.exchange);
assertSame(foo, mono.block()); assertSame(foo, mono.block());
} }
...@@ -112,7 +111,7 @@ public class SessionAttributeMethodArgumentResolverTests { ...@@ -112,7 +111,7 @@ public class SessionAttributeMethodArgumentResolverTests {
public void resolveWithName() throws Exception { public void resolveWithName() throws Exception {
MethodParameter param = initMethodParameter(1); MethodParameter param = initMethodParameter(1);
Foo foo = new Foo(); Foo foo = new Foo();
when(this.session.getAttribute("specialFoo")).thenReturn(Optional.of(foo)); when(this.session.getAttribute("specialFoo")).thenReturn(foo);
Mono<Object> mono = this.resolver.resolveArgument(param, new BindingContext(), this.exchange); Mono<Object> mono = this.resolver.resolveArgument(param, new BindingContext(), this.exchange);
assertSame(foo, mono.block()); assertSame(foo, mono.block());
} }
...@@ -124,32 +123,32 @@ public class SessionAttributeMethodArgumentResolverTests { ...@@ -124,32 +123,32 @@ public class SessionAttributeMethodArgumentResolverTests {
assertNull(mono.block()); assertNull(mono.block());
Foo foo = new Foo(); Foo foo = new Foo();
when(this.session.getAttribute("foo")).thenReturn(Optional.of(foo)); when(this.session.getAttribute("foo")).thenReturn(foo);
mono = this.resolver.resolveArgument(param, new BindingContext(), this.exchange); mono = this.resolver.resolveArgument(param, new BindingContext(), this.exchange);
assertSame(foo, mono.block()); assertSame(foo, mono.block());
} }
@SuppressWarnings("unchecked")
@Test @Test
public void resolveOptional() throws Exception { public void resolveOptional() throws Exception {
MethodParameter param = initMethodParameter(3); MethodParameter param = initMethodParameter(3);
Mono<Object> mono = this.resolver.resolveArgument(param, new BindingContext(), this.exchange); Optional<Object> actual = (Optional<Object>) this.resolver
assertNotNull(mono.block()); .resolveArgument(param, new BindingContext(), this.exchange).block();
assertEquals(Optional.class, mono.block().getClass());
assertFalse(((Optional<?>) mono.block()).isPresent()); assertNotNull(actual);
assertFalse(actual.isPresent());
ConfigurableWebBindingInitializer initializer = new ConfigurableWebBindingInitializer(); ConfigurableWebBindingInitializer initializer = new ConfigurableWebBindingInitializer();
initializer.setConversionService(new DefaultFormattingConversionService()); initializer.setConversionService(new DefaultFormattingConversionService());
BindingContext bindingContext = new BindingContext(initializer); BindingContext bindingContext = new BindingContext(initializer);
Foo foo = new Foo(); Foo foo = new Foo();
when(this.session.getAttribute("foo")).thenReturn(Optional.of(foo)); when(this.session.getAttribute("foo")).thenReturn(foo);
mono = this.resolver.resolveArgument(param, bindingContext, this.exchange); actual = (Optional<Object>) this.resolver.resolveArgument(param, bindingContext, this.exchange).block();
assertNotNull(mono.block()); assertNotNull(actual);
assertEquals(Optional.class, mono.block().getClass()); assertTrue(actual.isPresent());
Optional<?> optional = (Optional<?>) mono.block(); assertSame(foo, actual.get());
assertTrue(optional.isPresent());
assertSame(foo, optional.get());
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册