提交 582e625f 编写于 作者: A Arjen Poutsma

Allow HandlerFunction to return Mono<ServerResponse>

This commit makes it possible for handler functions to return
asynchronous status codes and headers, by making HandlerFunction.handle
return a Mono<ServerResponse> instead of a ServerResponse. As a
consequence, all other types that deal with HandlerFunctions
(RouterFunction, HandlerFilterFunction, etc.) had to change as well.

However, when combining the above change with method references (a very
typical use case), resulting signatures would have been something like:

```
public Mono<ServerResponse<Mono<Person>>> getPerson(ServerRequest request)
```

which was too ugly to consider, especially the two uses of Mono. It was
considered to merge ServerResponse with the last Mono, essentialy making
ServerResponse always contain a Publisher, but this had unfortunate
consequences in view rendering.

It was therefore decided to drop the parameterization of ServerResponse,
as the only usage of the extra type information was to manipulate the
response objects in a filter. Even before the above change this was
suggested; it just made the change even more necessary.

As a consequence, `BodyInserter` could be turned into a real
`FunctionalInterface`, which resulted in changes in ClientRequest.

We did, however, make HandlerFunction.handle return a `Mono<? extends
ServerResponse>`, adding little complexity, but allowing for
future `ServerResponse` subtypes that do expose type information, if
it's needed. For instance, a RenderingResponse could expose the view
name and model.

Issue: SPR-14870
上级 4021d239
...@@ -21,11 +21,13 @@ import java.time.ZoneId; ...@@ -21,11 +21,13 @@ import java.time.ZoneId;
import java.time.ZonedDateTime; import java.time.ZonedDateTime;
import java.time.format.DateTimeFormatter; import java.time.format.DateTimeFormatter;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collections;
import java.util.LinkedHashMap; import java.util.LinkedHashMap;
import java.util.LinkedHashSet; import java.util.LinkedHashSet;
import java.util.Locale; import java.util.Locale;
import java.util.Map; import java.util.Map;
import java.util.Set; import java.util.Set;
import java.util.function.BiFunction;
import java.util.function.Supplier; import java.util.function.Supplier;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import java.util.stream.Stream; import java.util.stream.Stream;
...@@ -150,43 +152,54 @@ class DefaultServerResponseBuilder implements ServerResponse.BodyBuilder { ...@@ -150,43 +152,54 @@ class DefaultServerResponseBuilder implements ServerResponse.BodyBuilder {
} }
@Override @Override
public ServerResponse<Void> build() { public Mono<ServerResponse> build() {
return body(BodyInserters.empty()); return build((exchange, handlerStrategies) -> exchange.getResponse().setComplete());
} }
@Override @Override
public <T extends Publisher<Void>> ServerResponse<T> build(T voidPublisher) { public Mono<ServerResponse> build(Publisher<Void> voidPublisher) {
Assert.notNull(voidPublisher, "'voidPublisher' must not be null"); Assert.notNull(voidPublisher, "'voidPublisher' must not be null");
return body(BodyInserter.of( return build((exchange, handlerStrategies) ->
(response, context) -> Flux.from(voidPublisher).thenEmpty(response.setComplete()), Mono.from(voidPublisher).then(exchange.getResponse().setComplete()));
() -> null));
} }
@Override @Override
public <T> ServerResponse<T> body(BodyInserter<T, ? super ServerHttpResponse> inserter) { public Mono<ServerResponse> build(
Assert.notNull(inserter, "'inserter' must not be null"); BiFunction<ServerWebExchange, HandlerStrategies, Mono<Void>> writeFunction) {
return new BodyInserterServerResponse<T>(this.statusCode, this.headers, inserter);
Assert.notNull(writeFunction, "'writeFunction' must not be null");
return Mono.just(new WriterFunctionServerResponse(this.statusCode, this.headers,
writeFunction));
} }
@Override @Override
public <S extends Publisher<T>, T> ServerResponse<S> body(S publisher, Class<T> elementClass) { public <T, P extends Publisher<T>> Mono<ServerResponse> body(P publisher,
Class<T> elementClass) {
return body(BodyInserters.fromPublisher(publisher, elementClass)); return body(BodyInserters.fromPublisher(publisher, elementClass));
} }
@Override @Override
public ServerResponse<Rendering> render(String name, Object... modelAttributes) { public <T> Mono<ServerResponse> body(BodyInserter<T, ? super ServerHttpResponse> inserter) {
Assert.notNull(inserter, "'inserter' must not be null");
return Mono
.just(new BodyInserterServerResponse<T>(this.statusCode, this.headers, inserter));
}
@Override
public Mono<ServerResponse> render(String name, Object... modelAttributes) {
Assert.hasLength(name, "'name' must not be empty"); Assert.hasLength(name, "'name' must not be empty");
return render(name, toModelMap(modelAttributes)); return render(name, toModelMap(modelAttributes));
} }
@Override @Override
public ServerResponse<Rendering> render(String name, Map<String, ?> model) { public Mono<ServerResponse> render(String name, Map<String, ?> model) {
Assert.hasLength(name, "'name' must not be empty"); Assert.hasLength(name, "'name' must not be empty");
Map<String, Object> modelMap = new LinkedHashMap<>(); Map<String, Object> modelMap = new LinkedHashMap<>();
if (model != null) { if (model != null) {
modelMap.putAll(model); modelMap.putAll(model);
} }
return new RenderingServerResponse(this.statusCode, this.headers, name, modelMap); return Mono
.just(new RenderingServerResponse(this.statusCode, this.headers, name, modelMap));
} }
private Map<String, Object> toModelMap(Object[] modelAttributes) { private Map<String, Object> toModelMap(Object[] modelAttributes) {
...@@ -199,7 +212,7 @@ class DefaultServerResponseBuilder implements ServerResponse.BodyBuilder { ...@@ -199,7 +212,7 @@ class DefaultServerResponseBuilder implements ServerResponse.BodyBuilder {
} }
static abstract class AbstractServerResponse<T> implements ServerResponse<T> { private static abstract class AbstractServerResponse implements ServerResponse {
private final HttpStatus statusCode; private final HttpStatus statusCode;
...@@ -207,7 +220,13 @@ class DefaultServerResponseBuilder implements ServerResponse.BodyBuilder { ...@@ -207,7 +220,13 @@ class DefaultServerResponseBuilder implements ServerResponse.BodyBuilder {
protected AbstractServerResponse(HttpStatus statusCode, HttpHeaders headers) { protected AbstractServerResponse(HttpStatus statusCode, HttpHeaders headers) {
this.statusCode = statusCode; this.statusCode = statusCode;
this.headers = HttpHeaders.readOnlyHttpHeaders(headers); this.headers = readOnlyCopy(headers);
}
private static HttpHeaders readOnlyCopy(HttpHeaders headers) {
HttpHeaders copy = new HttpHeaders();
copy.putAll(headers);
return HttpHeaders.readOnlyHttpHeaders(copy);
} }
@Override @Override
...@@ -233,8 +252,27 @@ class DefaultServerResponseBuilder implements ServerResponse.BodyBuilder { ...@@ -233,8 +252,27 @@ class DefaultServerResponseBuilder implements ServerResponse.BodyBuilder {
} }
} }
private static final class WriterFunctionServerResponse extends AbstractServerResponse {
private final BiFunction<ServerWebExchange, HandlerStrategies, Mono<Void>> writeFunction;
public WriterFunctionServerResponse(HttpStatus statusCode,
HttpHeaders headers,
BiFunction<ServerWebExchange, HandlerStrategies, Mono<Void>> writeFunction) {
super(statusCode, headers);
this.writeFunction = writeFunction;
}
@Override
public Mono<Void> writeTo(ServerWebExchange exchange, HandlerStrategies strategies) {
writeStatusAndHeaders(exchange.getResponse());
return this.writeFunction.apply(exchange, strategies);
}
}
private static final class BodyInserterServerResponse<T> extends AbstractServerResponse<T> { private static final class BodyInserterServerResponse<T> extends AbstractServerResponse {
private final BodyInserter<T, ? super ServerHttpResponse> inserter; private final BodyInserter<T, ? super ServerHttpResponse> inserter;
...@@ -245,11 +283,6 @@ class DefaultServerResponseBuilder implements ServerResponse.BodyBuilder { ...@@ -245,11 +283,6 @@ class DefaultServerResponseBuilder implements ServerResponse.BodyBuilder {
this.inserter = inserter; this.inserter = inserter;
} }
@Override
public T body() {
return this.inserter.t();
}
@Override @Override
public Mono<Void> writeTo(ServerWebExchange exchange, HandlerStrategies strategies) { public Mono<Void> writeTo(ServerWebExchange exchange, HandlerStrategies strategies) {
ServerHttpResponse response = exchange.getResponse(); ServerHttpResponse response = exchange.getResponse();
...@@ -264,26 +297,19 @@ class DefaultServerResponseBuilder implements ServerResponse.BodyBuilder { ...@@ -264,26 +297,19 @@ class DefaultServerResponseBuilder implements ServerResponse.BodyBuilder {
} }
private static final class RenderingServerResponse extends AbstractServerResponse<Rendering> { private static final class RenderingServerResponse extends AbstractServerResponse {
private final String name; private final String name;
private final Map<String, Object> model; private final Map<String, Object> model;
private final Rendering rendering;
public RenderingServerResponse(HttpStatus statusCode, HttpHeaders headers, String name, public RenderingServerResponse(HttpStatus statusCode, HttpHeaders headers, String name,
Map<String, Object> model) { Map<String, Object> model) {
super(statusCode, headers); super(statusCode, headers);
this.name = name; this.name = name;
this.model = model; this.model = Collections.unmodifiableMap(model);
this.rendering = new DefaultRendering();
}
@Override
public Rendering body() {
return this.rendering;
} }
@Override @Override
...@@ -301,18 +327,6 @@ class DefaultServerResponseBuilder implements ServerResponse.BodyBuilder { ...@@ -301,18 +327,6 @@ class DefaultServerResponseBuilder implements ServerResponse.BodyBuilder {
.then(view -> view.render(this.model, contentType, exchange)); .then(view -> view.render(this.model, contentType, exchange));
} }
private class DefaultRendering implements Rendering {
@Override
public String name() {
return name;
}
@Override
public Map<String, Object> model() {
return model;
}
}
} }
} }
...@@ -18,6 +18,8 @@ package org.springframework.web.reactive.function; ...@@ -18,6 +18,8 @@ package org.springframework.web.reactive.function;
import java.util.function.Function; import java.util.function.Function;
import reactor.core.publisher.Mono;
import org.springframework.util.Assert; import org.springframework.util.Assert;
import org.springframework.web.reactive.function.support.ServerRequestWrapper; import org.springframework.web.reactive.function.support.ServerRequestWrapper;
...@@ -31,7 +33,7 @@ import org.springframework.web.reactive.function.support.ServerRequestWrapper; ...@@ -31,7 +33,7 @@ import org.springframework.web.reactive.function.support.ServerRequestWrapper;
* @see RouterFunction#filter(HandlerFilterFunction) * @see RouterFunction#filter(HandlerFilterFunction)
*/ */
@FunctionalInterface @FunctionalInterface
public interface HandlerFilterFunction<T, R> { public interface HandlerFilterFunction<T extends ServerResponse, R extends ServerResponse> {
/** /**
* Apply this filter to the given handler function. The given * Apply this filter to the given handler function. The given
...@@ -44,7 +46,7 @@ public interface HandlerFilterFunction<T, R> { ...@@ -44,7 +46,7 @@ public interface HandlerFilterFunction<T, R> {
* @return the filtered response * @return the filtered response
* @see ServerRequestWrapper * @see ServerRequestWrapper
*/ */
ServerResponse<R> filter(ServerRequest request, HandlerFunction<T> next); Mono<R> filter(ServerRequest request, HandlerFunction<T> next);
/** /**
* Return a composed filter function that first applies this filter, and then applies the * Return a composed filter function that first applies this filter, and then applies the
...@@ -79,10 +81,10 @@ public interface HandlerFilterFunction<T, R> { ...@@ -79,10 +81,10 @@ public interface HandlerFilterFunction<T, R> {
* @return the filter adaptation of the request processor * @return the filter adaptation of the request processor
*/ */
static HandlerFilterFunction<?, ?> ofRequestProcessor(Function<ServerRequest, static HandlerFilterFunction<?, ?> ofRequestProcessor(Function<ServerRequest,
ServerRequest> requestProcessor) { Mono<ServerRequest>> requestProcessor) {
Assert.notNull(requestProcessor, "'requestProcessor' must not be null"); Assert.notNull(requestProcessor, "'requestProcessor' must not be null");
return (request, next) -> next.handle(requestProcessor.apply(request)); return (request, next) -> requestProcessor.apply(request).then(next::handle);
} }
/** /**
...@@ -91,11 +93,11 @@ public interface HandlerFilterFunction<T, R> { ...@@ -91,11 +93,11 @@ public interface HandlerFilterFunction<T, R> {
* @param responseProcessor the response processor * @param responseProcessor the response processor
* @return the filter adaptation of the request processor * @return the filter adaptation of the request processor
*/ */
static <T, R> HandlerFilterFunction<T, R> ofResponseProcessor(Function<ServerResponse<T>, static <T extends ServerResponse, R extends ServerResponse> HandlerFilterFunction<T, R> ofResponseProcessor(Function<T,
ServerResponse<R>> responseProcessor) { R> responseProcessor) {
Assert.notNull(responseProcessor, "'responseProcessor' must not be null"); Assert.notNull(responseProcessor, "'responseProcessor' must not be null");
return (request, next) -> responseProcessor.apply(next.handle(request)); return (request, next) -> next.handle(request).map(responseProcessor);
} }
......
...@@ -16,6 +16,8 @@ ...@@ -16,6 +16,8 @@
package org.springframework.web.reactive.function; package org.springframework.web.reactive.function;
import reactor.core.publisher.Mono;
/** /**
* Represents a function that handles a {@linkplain ServerRequest request}. * Represents a function that handles a {@linkplain ServerRequest request}.
* *
...@@ -24,13 +26,13 @@ package org.springframework.web.reactive.function; ...@@ -24,13 +26,13 @@ package org.springframework.web.reactive.function;
* @since 5.0 * @since 5.0
*/ */
@FunctionalInterface @FunctionalInterface
public interface HandlerFunction<T> { public interface HandlerFunction<T extends ServerResponse> {
/** /**
* Handle the given request. * Handle the given request.
* @param request the request to handle * @param request the request to handle
* @return the response * @return the response
*/ */
ServerResponse<T> handle(ServerRequest request); Mono<T> handle(ServerRequest request);
} }
...@@ -19,9 +19,10 @@ package org.springframework.web.reactive.function; ...@@ -19,9 +19,10 @@ package org.springframework.web.reactive.function;
import java.io.IOException; import java.io.IOException;
import java.io.UncheckedIOException; import java.io.UncheckedIOException;
import java.nio.charset.StandardCharsets; import java.nio.charset.StandardCharsets;
import java.util.Optional;
import java.util.function.Function; import java.util.function.Function;
import reactor.core.publisher.Mono;
import org.springframework.core.io.ClassPathResource; import org.springframework.core.io.ClassPathResource;
import org.springframework.core.io.Resource; import org.springframework.core.io.Resource;
import org.springframework.core.io.UrlResource; import org.springframework.core.io.UrlResource;
...@@ -37,7 +38,7 @@ import org.springframework.web.util.UriUtils; ...@@ -37,7 +38,7 @@ import org.springframework.web.util.UriUtils;
* @author Arjen Poutsma * @author Arjen Poutsma
* @since 5.0 * @since 5.0
*/ */
class PathResourceLookupFunction implements Function<ServerRequest, Optional<Resource>> { class PathResourceLookupFunction implements Function<ServerRequest, Mono<Resource>> {
private static final PathMatcher PATH_MATCHER = new AntPathMatcher(); private static final PathMatcher PATH_MATCHER = new AntPathMatcher();
...@@ -51,16 +52,16 @@ class PathResourceLookupFunction implements Function<ServerRequest, Optional<Res ...@@ -51,16 +52,16 @@ class PathResourceLookupFunction implements Function<ServerRequest, Optional<Res
} }
@Override @Override
public Optional<Resource> apply(ServerRequest request) { public Mono<Resource> apply(ServerRequest request) {
String path = processPath(request.path()); String path = processPath(request.path());
if (path.contains("%")) { if (path.contains("%")) {
path = UriUtils.decode(path, StandardCharsets.UTF_8); path = UriUtils.decode(path, StandardCharsets.UTF_8);
} }
if (!StringUtils.hasLength(path) || isInvalidPath(path)) { if (!StringUtils.hasLength(path) || isInvalidPath(path)) {
return Optional.empty(); return Mono.empty();
} }
if (!PATH_MATCHER.match(this.pattern, path)) { if (!PATH_MATCHER.match(this.pattern, path)) {
return Optional.empty(); return Mono.empty();
} }
else { else {
path = PATH_MATCHER.extractPathWithinPattern(this.pattern, path); path = PATH_MATCHER.extractPathWithinPattern(this.pattern, path);
...@@ -68,10 +69,10 @@ class PathResourceLookupFunction implements Function<ServerRequest, Optional<Res ...@@ -68,10 +69,10 @@ class PathResourceLookupFunction implements Function<ServerRequest, Optional<Res
try { try {
Resource resource = this.location.createRelative(path); Resource resource = this.location.createRelative(path);
if (resource.exists() && resource.isReadable() && isResourceUnderLocation(resource)) { if (resource.exists() && resource.isReadable() && isResourceUnderLocation(resource)) {
return Optional.of(resource); return Mono.just(resource);
} }
else { else {
return Optional.empty(); return Mono.empty();
} }
} }
catch (IOException ex) { catch (IOException ex) {
......
/*
* Copyright 2002-2016 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.web.reactive.function;
import java.util.Map;
/**
* Represents a template rendering, based on a {@code String} name and a model {@code Map}.
*
* @author Arjen Poutsma
* @since 5.0
*/
public interface Rendering {
/**
* Return the name of the template to be rendered.
*/
String name();
/**
* Return the unmodifiable model map.
*/
Map<String, Object> model();
}
...@@ -25,6 +25,8 @@ import java.net.URL; ...@@ -25,6 +25,8 @@ import java.net.URL;
import java.util.EnumSet; import java.util.EnumSet;
import java.util.Set; import java.util.Set;
import reactor.core.publisher.Mono;
import org.springframework.core.io.Resource; import org.springframework.core.io.Resource;
import org.springframework.http.HttpMethod; import org.springframework.http.HttpMethod;
import org.springframework.http.HttpStatus; import org.springframework.http.HttpStatus;
...@@ -34,7 +36,7 @@ import org.springframework.http.codec.BodyInserters; ...@@ -34,7 +36,7 @@ import org.springframework.http.codec.BodyInserters;
* @author Arjen Poutsma * @author Arjen Poutsma
* @since 5.0 * @since 5.0
*/ */
class ResourceHandlerFunction implements HandlerFunction<Resource> { class ResourceHandlerFunction implements HandlerFunction<ServerResponse> {
private static final Set<HttpMethod> SUPPORTED_METHODS = private static final Set<HttpMethod> SUPPORTED_METHODS =
...@@ -48,7 +50,7 @@ class ResourceHandlerFunction implements HandlerFunction<Resource> { ...@@ -48,7 +50,7 @@ class ResourceHandlerFunction implements HandlerFunction<Resource> {
} }
@Override @Override
public ServerResponse<Resource> handle(ServerRequest request) { public Mono<ServerResponse> handle(ServerRequest request) {
switch (request.method()) { switch (request.method()) {
case GET: case GET:
return ServerResponse.ok() return ServerResponse.ok()
......
...@@ -18,6 +18,8 @@ package org.springframework.web.reactive.function; ...@@ -18,6 +18,8 @@ package org.springframework.web.reactive.function;
import java.util.Optional; import java.util.Optional;
import reactor.core.publisher.Mono;
/** /**
* Represents a function that routes to a {@linkplain HandlerFunction handler function}. * Represents a function that routes to a {@linkplain HandlerFunction handler function}.
* *
...@@ -27,30 +29,27 @@ import java.util.Optional; ...@@ -27,30 +29,27 @@ import java.util.Optional;
* @see RouterFunctions * @see RouterFunctions
*/ */
@FunctionalInterface @FunctionalInterface
public interface RouterFunction<T> { public interface RouterFunction<T extends ServerResponse> {
/** /**
* Return the {@linkplain HandlerFunction handler function} that matches the given request. * Return the {@linkplain HandlerFunction handler function} that matches the given request.
* @param request the request to route to * @param request the request to route to
* @return an {@code Optional} describing the {@code HandlerFunction} that matches this request, * @return an {@code Mono} describing the {@code HandlerFunction} that matches this request,
* or an empty {@code Optional} if there is no match * or an empty {@code Mono} if there is no match
*/ */
Optional<HandlerFunction<T>> route(ServerRequest request); Mono<HandlerFunction<T>> route(ServerRequest request);
/** /**
* Return a composed routing function that first invokes this function, * Return a composed routing function that first invokes this function,
* and then invokes the {@code other} function (of the same type {@code T}) if this route had * and then invokes the {@code other} function (of the same type {@code T}) if this route had
* {@linkplain Optional#empty() no result}. * {@linkplain Mono#empty() no result}.
* *
* @param other the function of type {@code T} to apply when this function has no result * @param other the function of type {@code T} to apply when this function has no result
* @return a composed function that first routes with this function and then the {@code other} function if this * @return a composed function that first routes with this function and then the {@code other} function if this
* function has no result * function has no result
*/ */
default RouterFunction<T> andSame(RouterFunction<T> other) { default RouterFunction<T> andSame(RouterFunction<T> other) {
return request -> { return request -> this.route(request).otherwiseIfEmpty(other.route(request));
Optional<HandlerFunction<T>> result = this.route(request);
return result.isPresent() ? result : other.route(request);
};
} }
/** /**
...@@ -63,12 +62,9 @@ public interface RouterFunction<T> { ...@@ -63,12 +62,9 @@ public interface RouterFunction<T> {
* function has no result * function has no result
*/ */
default RouterFunction<?> and(RouterFunction<?> other) { default RouterFunction<?> and(RouterFunction<?> other) {
return request -> { return request -> this.route(request)
Optional<HandlerFunction<Object>> result = this.route(request). .map(RouterFunctions::cast)
map(RouterFunctions::cast); .otherwiseIfEmpty(other.route(request).map(RouterFunctions::cast));
return result.isPresent() ? result : other.route(request)
.map(RouterFunctions::cast);
};
} }
/** /**
...@@ -83,7 +79,7 @@ public interface RouterFunction<T> { ...@@ -83,7 +79,7 @@ public interface RouterFunction<T> {
* created from {@code predicate} and {@code handlerFunction} if this * created from {@code predicate} and {@code handlerFunction} if this
* function has no result * function has no result
*/ */
default <S> RouterFunction<?> andRoute(RequestPredicate predicate, default <S extends ServerResponse> RouterFunction<?> andRoute(RequestPredicate predicate,
HandlerFunction<S> handlerFunction) { HandlerFunction<S> handlerFunction) {
return and(RouterFunctions.route(predicate, handlerFunction)); return and(RouterFunctions.route(predicate, handlerFunction));
} }
...@@ -96,7 +92,7 @@ public interface RouterFunction<T> { ...@@ -96,7 +92,7 @@ public interface RouterFunction<T> {
* @param <S> the filter return type * @param <S> the filter return type
* @return the filtered routing function * @return the filtered routing function
*/ */
default <S> RouterFunction<S> filter(HandlerFilterFunction<T, S> filterFunction) { default <S extends ServerResponse> RouterFunction<S> filter(HandlerFilterFunction<T, S> filterFunction) {
return request -> this.route(request).map(filterFunction::apply); return request -> this.route(request).map(filterFunction::apply);
} }
......
...@@ -17,7 +17,6 @@ ...@@ -17,7 +17,6 @@
package org.springframework.web.reactive.function; package org.springframework.web.reactive.function;
import java.util.Map; import java.util.Map;
import java.util.Optional;
import java.util.function.Function; import java.util.function.Function;
import reactor.core.publisher.Mono; import reactor.core.publisher.Mono;
...@@ -63,7 +62,7 @@ public abstract class RouterFunctions { ...@@ -63,7 +62,7 @@ public abstract class RouterFunctions {
public static final String URI_TEMPLATE_VARIABLES_ATTRIBUTE = public static final String URI_TEMPLATE_VARIABLES_ATTRIBUTE =
RouterFunctions.class.getName() + ".uriTemplateVariables"; RouterFunctions.class.getName() + ".uriTemplateVariables";
private static final HandlerFunction<Void> NOT_FOUND_HANDLER = request -> ServerResponse.notFound().build(); private static final HandlerFunction<ServerResponse> NOT_FOUND_HANDLER = request -> ServerResponse.notFound().build();
/** /**
...@@ -75,11 +74,13 @@ public abstract class RouterFunctions { ...@@ -75,11 +74,13 @@ public abstract class RouterFunctions {
* {@code predicate} evaluates to {@code true} * {@code predicate} evaluates to {@code true}
* @see RequestPredicates * @see RequestPredicates
*/ */
public static <T> RouterFunction<T> route(RequestPredicate predicate, HandlerFunction<T> handlerFunction) { public static <T extends ServerResponse> RouterFunction<T> route(RequestPredicate predicate,
HandlerFunction<T> handlerFunction) {
Assert.notNull(predicate, "'predicate' must not be null"); Assert.notNull(predicate, "'predicate' must not be null");
Assert.notNull(handlerFunction, "'handlerFunction' must not be null"); Assert.notNull(handlerFunction, "'handlerFunction' must not be null");
return request -> predicate.test(request) ? Optional.of(handlerFunction) : Optional.empty(); return request -> predicate.test(request) ? Mono.just(handlerFunction) : Mono.empty();
} }
/** /**
...@@ -91,7 +92,9 @@ public abstract class RouterFunctions { ...@@ -91,7 +92,9 @@ public abstract class RouterFunctions {
* {@code predicate} evaluates to {@code true} * {@code predicate} evaluates to {@code true}
* @see RequestPredicates * @see RequestPredicates
*/ */
public static <T> RouterFunction<T> subroute(RequestPredicate predicate, RouterFunction<T> routerFunction) { public static <T extends ServerResponse> RouterFunction<T> subroute(RequestPredicate predicate,
RouterFunction<T> routerFunction) {
Assert.notNull(predicate, "'predicate' must not be null"); Assert.notNull(predicate, "'predicate' must not be null");
Assert.notNull(routerFunction, "'routerFunction' must not be null"); Assert.notNull(routerFunction, "'routerFunction' must not be null");
...@@ -101,7 +104,7 @@ public abstract class RouterFunctions { ...@@ -101,7 +104,7 @@ public abstract class RouterFunctions {
return routerFunction.route(subRequest); return routerFunction.route(subRequest);
} }
else { else {
return Optional.empty(); return Mono.empty();
} }
}; };
} }
...@@ -117,7 +120,7 @@ public abstract class RouterFunctions { ...@@ -117,7 +120,7 @@ public abstract class RouterFunctions {
* @param location the location directory relative to which resources should be resolved * @param location the location directory relative to which resources should be resolved
* @return a router function that routes to resources * @return a router function that routes to resources
*/ */
public static RouterFunction<Resource> resources(String pattern, Resource location) { public static RouterFunction<ServerResponse> resources(String pattern, Resource location) {
Assert.hasLength(pattern, "'pattern' must not be empty"); Assert.hasLength(pattern, "'pattern' must not be empty");
Assert.notNull(location, "'location' must not be null"); Assert.notNull(location, "'location' must not be null");
...@@ -131,12 +134,10 @@ public abstract class RouterFunctions { ...@@ -131,12 +134,10 @@ public abstract class RouterFunctions {
* @param lookupFunction the function to provide a {@link Resource} given the {@link ServerRequest} * @param lookupFunction the function to provide a {@link Resource} given the {@link ServerRequest}
* @return a router function that routes to resources * @return a router function that routes to resources
*/ */
public static RouterFunction<Resource> resources(Function<ServerRequest, Optional<Resource>> lookupFunction) { public static RouterFunction<ServerResponse> resources(Function<ServerRequest, Mono<Resource>> lookupFunction) {
Assert.notNull(lookupFunction, "'lookupFunction' must not be null"); Assert.notNull(lookupFunction, "'lookupFunction' must not be null");
// TODO: make lookupFunction return Mono<Resource> once SPR-14870 is resolved
return request -> lookupFunction.apply(request).map(ResourceHandlerFunction::new); return request -> lookupFunction.apply(request).map(ResourceHandlerFunction::new);
} }
/** /**
...@@ -190,9 +191,10 @@ public abstract class RouterFunctions { ...@@ -190,9 +191,10 @@ public abstract class RouterFunctions {
return new HttpWebHandlerAdapter(exchange -> { return new HttpWebHandlerAdapter(exchange -> {
ServerRequest request = new DefaultServerRequest(exchange, strategies); ServerRequest request = new DefaultServerRequest(exchange, strategies);
addAttributes(exchange, request); addAttributes(exchange, request);
HandlerFunction<?> handlerFunction = routerFunction.route(request).orElse(notFound()); return routerFunction.route(request)
ServerResponse<?> response = handlerFunction.handle(request); .defaultIfEmpty(notFound())
return response.writeTo(exchange, strategies); .then(handlerFunction -> handlerFunction.handle(request))
.then(response -> response.writeTo(exchange, strategies));
}); });
} }
...@@ -225,11 +227,13 @@ public abstract class RouterFunctions { ...@@ -225,11 +227,13 @@ public abstract class RouterFunctions {
Assert.notNull(routerFunction, "RouterFunction must not be null"); Assert.notNull(routerFunction, "RouterFunction must not be null");
Assert.notNull(strategies, "HandlerStrategies must not be null"); Assert.notNull(strategies, "HandlerStrategies must not be null");
return exchange -> { return new HandlerMapping() {
ServerRequest request = new DefaultServerRequest(exchange, strategies); @Override
addAttributes(exchange, request); public Mono<Object> getHandler(ServerWebExchange exchange) {
Optional<? extends HandlerFunction<?>> route = routerFunction.route(request); ServerRequest request = new DefaultServerRequest(exchange, strategies);
return Mono.justOrEmpty(route); addAttributes(exchange, request);
return routerFunction.route(request).map(handlerFunction -> (Object)handlerFunction);
}
}; };
} }
...@@ -240,12 +244,12 @@ public abstract class RouterFunctions { ...@@ -240,12 +244,12 @@ public abstract class RouterFunctions {
} }
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
private static <T> HandlerFunction<T> notFound() { private static <T extends ServerResponse> HandlerFunction<T> notFound() {
return (HandlerFunction<T>) NOT_FOUND_HANDLER; return (HandlerFunction<T>) NOT_FOUND_HANDLER;
} }
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
static <T> HandlerFunction<T> cast(HandlerFunction<?> handlerFunction) { static <T extends ServerResponse> HandlerFunction<T> cast(HandlerFunction<?> handlerFunction) {
return (HandlerFunction<T>) handlerFunction; return (HandlerFunction<T>) handlerFunction;
} }
......
...@@ -21,6 +21,7 @@ import java.time.ZonedDateTime; ...@@ -21,6 +21,7 @@ import java.time.ZonedDateTime;
import java.util.Collection; import java.util.Collection;
import java.util.Map; import java.util.Map;
import java.util.Set; import java.util.Set;
import java.util.function.BiFunction;
import org.reactivestreams.Publisher; import org.reactivestreams.Publisher;
import reactor.core.publisher.Mono; import reactor.core.publisher.Mono;
...@@ -42,9 +43,8 @@ import org.springframework.web.server.ServerWebExchange; ...@@ -42,9 +43,8 @@ import org.springframework.web.server.ServerWebExchange;
* *
* @author Arjen Poutsma * @author Arjen Poutsma
* @since 5.0 * @since 5.0
* @param <T> the type of the body that this response contains
*/ */
public interface ServerResponse<T> { public interface ServerResponse {
// Instance methods // Instance methods
...@@ -58,11 +58,6 @@ public interface ServerResponse<T> { ...@@ -58,11 +58,6 @@ public interface ServerResponse<T> {
*/ */
HttpHeaders headers(); HttpHeaders headers();
/**
* Return the body of this response.
*/
T body();
/** /**
* Writes this response to the given web exchange. * Writes this response to the given web exchange.
* *
...@@ -80,7 +75,7 @@ public interface ServerResponse<T> { ...@@ -80,7 +75,7 @@ public interface ServerResponse<T> {
* @param other the response to copy the status and headers from * @param other the response to copy the status and headers from
* @return the created builder * @return the created builder
*/ */
static BodyBuilder from(ServerResponse<?> other) { static BodyBuilder from(ServerResponse other) {
Assert.notNull(other, "'other' must not be null"); Assert.notNull(other, "'other' must not be null");
DefaultServerResponseBuilder builder = new DefaultServerResponseBuilder(other.statusCode()); DefaultServerResponseBuilder builder = new DefaultServerResponseBuilder(other.statusCode());
return builder.headers(other.headers()); return builder.headers(other.headers());
...@@ -270,7 +265,7 @@ public interface ServerResponse<T> { ...@@ -270,7 +265,7 @@ public interface ServerResponse<T> {
* *
* @return the built response * @return the built response
*/ */
ServerResponse<Void> build(); Mono<ServerResponse> build();
/** /**
* Build the response entity with no body. * Build the response entity with no body.
...@@ -279,7 +274,16 @@ public interface ServerResponse<T> { ...@@ -279,7 +274,16 @@ public interface ServerResponse<T> {
* @param voidPublisher publisher publisher to indicate when the response should be committed * @param voidPublisher publisher publisher to indicate when the response should be committed
* @return the built response * @return the built response
*/ */
<T extends Publisher<Void>> ServerResponse<T> build(T voidPublisher); Mono<ServerResponse> build(Publisher<Void> voidPublisher);
/**
* Build the response entity with a custom writer function.
*
* @param writeFunction the function used to write to the {@link ServerWebExchange}
* @return the built response
*/
Mono<ServerResponse> build(BiFunction<ServerWebExchange, HandlerStrategies,
Mono<Void>> writeFunction);
} }
...@@ -308,14 +312,6 @@ public interface ServerResponse<T> { ...@@ -308,14 +312,6 @@ public interface ServerResponse<T> {
*/ */
BodyBuilder contentType(MediaType contentType); BodyBuilder contentType(MediaType contentType);
/**
* Set the body of the response to the given {@code BodyInserter} and return it.
* @param inserter the {@code BodyInserter} that writes to the response
* @param <T> the type contained in the body
* @return the built response
*/
<T> ServerResponse<T> body(BodyInserter<T, ? super ServerHttpResponse> inserter);
/** /**
* Set the body of the response to the given {@code Publisher} and return it. This * Set the body of the response to the given {@code Publisher} and return it. This
* convenience method combines {@link #body(BodyInserter)} and * convenience method combines {@link #body(BodyInserter)} and
...@@ -323,10 +319,18 @@ public interface ServerResponse<T> { ...@@ -323,10 +319,18 @@ public interface ServerResponse<T> {
* @param publisher the {@code Publisher} to write to the response * @param publisher the {@code Publisher} to write to the response
* @param elementClass the class of elements contained in the publisher * @param elementClass the class of elements contained in the publisher
* @param <T> the type of the elements contained in the publisher * @param <T> the type of the elements contained in the publisher
* @param <S> the type of the {@code Publisher} * @param <P> the type of the {@code Publisher}
* @return the built request * @return the built request
*/ */
<S extends Publisher<T>, T> ServerResponse<S> body(S publisher, Class<T> elementClass); <T, P extends Publisher<T>> Mono<ServerResponse> body(P publisher, Class<T> elementClass);
/**
* Set the body of the response to the given {@code BodyInserter} and return it.
* @param inserter the {@code BodyInserter} that writes to the response
* @param <T> the type contained in the body
* @return the built response
*/
<T> Mono<ServerResponse> body(BodyInserter<T, ? super ServerHttpResponse> inserter);
/** /**
* Render the template with the given {@code name} using the given {@code modelAttributes}. * Render the template with the given {@code name} using the given {@code modelAttributes}.
...@@ -339,7 +343,7 @@ public interface ServerResponse<T> { ...@@ -339,7 +343,7 @@ public interface ServerResponse<T> {
* @param modelAttributes the modelAttributes used to render the template * @param modelAttributes the modelAttributes used to render the template
* @return the built response * @return the built response
*/ */
ServerResponse<Rendering> render(String name, Object... modelAttributes); Mono<ServerResponse> render(String name, Object... modelAttributes);
/** /**
* Render the template with the given {@code name} using the given {@code model}. * Render the template with the given {@code name} using the given {@code model}.
...@@ -347,7 +351,7 @@ public interface ServerResponse<T> { ...@@ -347,7 +351,7 @@ public interface ServerResponse<T> {
* @param model the model used to render the template * @param model the model used to render the template
* @return the built response * @return the built response
*/ */
ServerResponse<Rendering> render(String name, Map<String, ?> model); Mono<ServerResponse> render(String name, Map<String, ?> model);
} }
......
...@@ -26,7 +26,6 @@ import org.springframework.web.reactive.HandlerResult; ...@@ -26,7 +26,6 @@ import org.springframework.web.reactive.HandlerResult;
import org.springframework.web.reactive.function.HandlerFunction; import org.springframework.web.reactive.function.HandlerFunction;
import org.springframework.web.reactive.function.RouterFunctions; import org.springframework.web.reactive.function.RouterFunctions;
import org.springframework.web.reactive.function.ServerRequest; import org.springframework.web.reactive.function.ServerRequest;
import org.springframework.web.reactive.function.ServerResponse;
import org.springframework.web.server.ServerWebExchange; import org.springframework.web.server.ServerWebExchange;
/** /**
...@@ -62,9 +61,7 @@ public class HandlerFunctionAdapter implements HandlerAdapter { ...@@ -62,9 +61,7 @@ public class HandlerFunctionAdapter implements HandlerAdapter {
.orElseThrow(() -> new IllegalStateException( .orElseThrow(() -> new IllegalStateException(
"Could not find ServerRequest in exchange attributes")); "Could not find ServerRequest in exchange attributes"));
ServerResponse<?> response = handlerFunction.handle(request); return handlerFunction.handle(request)
HandlerResult handlerResult = .map(response -> new HandlerResult(handlerFunction, response, HANDLER_FUNCTION_RETURN_TYPE));
new HandlerResult(handlerFunction, response, HANDLER_FUNCTION_RETURN_TYPE);
return Mono.just(handlerResult);
} }
} }
...@@ -21,8 +21,8 @@ import reactor.core.publisher.Mono; ...@@ -21,8 +21,8 @@ import reactor.core.publisher.Mono;
import org.springframework.util.Assert; import org.springframework.util.Assert;
import org.springframework.web.reactive.HandlerResult; import org.springframework.web.reactive.HandlerResult;
import org.springframework.web.reactive.HandlerResultHandler; import org.springframework.web.reactive.HandlerResultHandler;
import org.springframework.web.reactive.function.ServerResponse;
import org.springframework.web.reactive.function.HandlerStrategies; import org.springframework.web.reactive.function.HandlerStrategies;
import org.springframework.web.reactive.function.ServerResponse;
import org.springframework.web.server.ServerWebExchange; import org.springframework.web.server.ServerWebExchange;
/** /**
...@@ -59,7 +59,7 @@ public class ServerResponseResultHandler implements HandlerResultHandler { ...@@ -59,7 +59,7 @@ public class ServerResponseResultHandler implements HandlerResultHandler {
@Override @Override
public Mono<Void> handleResult(ServerWebExchange exchange, HandlerResult result) { public Mono<Void> handleResult(ServerWebExchange exchange, HandlerResult result) {
ServerResponse<?> response = (ServerResponse<?>) result.getReturnValue().orElseThrow( ServerResponse response = (ServerResponse) result.getReturnValue().orElseThrow(
IllegalStateException::new); IllegalStateException::new);
return response.writeTo(exchange, this.strategies); return response.writeTo(exchange, this.strategies);
} }
......
...@@ -17,44 +17,26 @@ ...@@ -17,44 +17,26 @@
package org.springframework.web.reactive.function; package org.springframework.web.reactive.function;
import java.net.URI; import java.net.URI;
import java.nio.ByteBuffer;
import java.time.ZonedDateTime; import java.time.ZonedDateTime;
import java.util.ArrayList;
import java.util.Collections; import java.util.Collections;
import java.util.EnumSet;
import java.util.List; import java.util.List;
import java.util.Locale; import java.util.Set;
import java.util.Map;
import java.util.function.BiFunction;
import java.util.function.Supplier;
import org.junit.Test; import org.junit.Test;
import reactor.core.publisher.Mono; import reactor.core.publisher.Mono;
import reactor.test.StepVerifier;
import org.springframework.core.codec.CharSequenceEncoder;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.core.io.buffer.DefaultDataBufferFactory;
import org.springframework.http.CacheControl; import org.springframework.http.CacheControl;
import org.springframework.http.HttpHeaders; import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod; import org.springframework.http.HttpMethod;
import org.springframework.http.HttpStatus; import org.springframework.http.HttpStatus;
import org.springframework.http.MediaType; import org.springframework.http.MediaType;
import org.springframework.http.codec.BodyInserter;
import org.springframework.http.codec.EncoderHttpMessageWriter;
import org.springframework.http.codec.HttpMessageWriter;
import org.springframework.http.server.reactive.ServerHttpResponse;
import org.springframework.mock.http.server.reactive.test.MockServerHttpRequest;
import org.springframework.mock.http.server.reactive.test.MockServerHttpResponse; import org.springframework.mock.http.server.reactive.test.MockServerHttpResponse;
import org.springframework.web.reactive.result.view.View;
import org.springframework.web.reactive.result.view.ViewResolver;
import org.springframework.web.server.ServerWebExchange; import org.springframework.web.server.ServerWebExchange;
import org.springframework.web.server.adapter.DefaultServerWebExchange;
import org.springframework.web.server.session.MockWebSessionManager;
import static java.nio.charset.StandardCharsets.UTF_8;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull; import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertSame;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
...@@ -65,130 +47,203 @@ public class DefaultServerResponseBuilderTests { ...@@ -65,130 +47,203 @@ public class DefaultServerResponseBuilderTests {
@Test @Test
public void from() throws Exception { public void from() throws Exception {
ServerResponse<Void> other = ServerResponse.ok().header("foo", "bar").build(); ServerResponse other = ServerResponse.ok().header("foo", "bar").build().block();
ServerResponse<Void> result = ServerResponse.from(other).build(); Mono<ServerResponse> result = ServerResponse.from(other).build();
assertEquals(HttpStatus.OK, result.statusCode()); StepVerifier.create(result)
assertEquals("bar", result.headers().getFirst("foo")); .expectNextMatches(response -> HttpStatus.OK.equals(response.statusCode()) &&
"bar".equals(response.headers().getFirst("foo")))
.expectComplete()
.verify();
} }
@Test @Test
public void status() throws Exception { public void status() throws Exception {
ServerResponse<Void> result = ServerResponse.status(HttpStatus.CREATED).build(); Mono<ServerResponse> result = ServerResponse.status(HttpStatus.CREATED).build();
assertEquals(HttpStatus.CREATED, result.statusCode()); StepVerifier.create(result)
.expectNextMatches(response -> HttpStatus.CREATED.equals(response.statusCode()))
.expectComplete()
.verify();
} }
@Test @Test
public void ok() throws Exception { public void ok() throws Exception {
ServerResponse<Void> result = ServerResponse.ok().build(); Mono<ServerResponse> result = ServerResponse.ok().build();
assertEquals(HttpStatus.OK, result.statusCode()); StepVerifier.create(result)
.expectNextMatches(response -> HttpStatus.OK.equals(response.statusCode()))
.expectComplete()
.verify();
} }
@Test @Test
public void created() throws Exception { public void created() throws Exception {
URI location = URI.create("http://example.com"); URI location = URI.create("http://example.com");
ServerResponse<Void> result = ServerResponse.created(location).build(); Mono<ServerResponse> result = ServerResponse.created(location).build();
assertEquals(HttpStatus.CREATED, result.statusCode()); StepVerifier.create(result)
assertEquals(location, result.headers().getLocation()); .expectNextMatches(response -> HttpStatus.CREATED.equals(response.statusCode()) &&
location.equals(response.headers().getLocation()))
.expectComplete()
.verify();
} }
@Test @Test
public void accepted() throws Exception { public void accepted() throws Exception {
ServerResponse<Void> result = ServerResponse.accepted().build(); Mono<ServerResponse> result = ServerResponse.accepted().build();
assertEquals(HttpStatus.ACCEPTED, result.statusCode()); StepVerifier.create(result)
.expectNextMatches(response -> HttpStatus.ACCEPTED.equals(response.statusCode()))
.expectComplete()
.verify();
} }
@Test @Test
public void noContent() throws Exception { public void noContent() throws Exception {
ServerResponse<Void> result = ServerResponse.noContent().build(); Mono<ServerResponse> result = ServerResponse.noContent().build();
assertEquals(HttpStatus.NO_CONTENT, result.statusCode()); StepVerifier.create(result)
.expectNextMatches(response -> HttpStatus.NO_CONTENT.equals(response.statusCode()))
.expectComplete()
.verify();
} }
@Test @Test
public void badRequest() throws Exception { public void badRequest() throws Exception {
ServerResponse<Void> result = ServerResponse.badRequest().build(); Mono<ServerResponse> result = ServerResponse.badRequest().build();
assertEquals(HttpStatus.BAD_REQUEST, result.statusCode()); StepVerifier.create(result)
.expectNextMatches(response -> HttpStatus.BAD_REQUEST.equals(response.statusCode()))
.expectComplete()
.verify();
} }
@Test @Test
public void notFound() throws Exception { public void notFound() throws Exception {
ServerResponse<Void> result = ServerResponse.notFound().build(); Mono<ServerResponse> result = ServerResponse.notFound().build();
assertEquals(HttpStatus.NOT_FOUND, result.statusCode()); StepVerifier.create(result)
.expectNextMatches(response -> HttpStatus.NOT_FOUND.equals(response.statusCode()))
.expectComplete()
.verify();
} }
@Test @Test
public void unprocessableEntity() throws Exception { public void unprocessableEntity() throws Exception {
ServerResponse<Void> result = ServerResponse.unprocessableEntity().build(); Mono<ServerResponse> result = ServerResponse.unprocessableEntity().build();
assertEquals(HttpStatus.UNPROCESSABLE_ENTITY, result.statusCode()); StepVerifier.create(result)
.expectNextMatches(response -> HttpStatus.UNPROCESSABLE_ENTITY.equals(response.statusCode()))
.expectComplete()
.verify();
} }
@Test @Test
public void allow() throws Exception { public void allow() throws Exception {
ServerResponse<Void> result = ServerResponse.ok().allow(HttpMethod.GET).build(); Mono<ServerResponse> result = ServerResponse.ok().allow(HttpMethod.GET).build();
assertEquals(Collections.singleton(HttpMethod.GET), result.headers().getAllow()); Set<HttpMethod> expected = EnumSet.of(HttpMethod.GET);
StepVerifier.create(result)
.expectNextMatches(response -> expected.equals(response.headers().getAllow()))
.expectComplete()
.verify();
} }
@Test @Test
public void contentLength() throws Exception { public void contentLength() throws Exception {
ServerResponse<Void> result = ServerResponse.ok().contentLength(42).build(); Mono<ServerResponse> result = ServerResponse.ok().contentLength(42).build();
assertEquals(42, result.headers().getContentLength()); StepVerifier.create(result)
.expectNextMatches(response -> Long.valueOf(42).equals(response.headers().getContentLength()))
.expectComplete()
.verify();
} }
@Test @Test
public void contentType() throws Exception { public void contentType() throws Exception {
ServerResponse<Void> result = ServerResponse.ok().contentType(MediaType.APPLICATION_JSON).build(); Mono<ServerResponse>
assertEquals(MediaType.APPLICATION_JSON, result.headers().getContentType()); result = ServerResponse.ok().contentType(MediaType.APPLICATION_JSON).build();
StepVerifier.create(result)
.expectNextMatches(response -> MediaType.APPLICATION_JSON.equals(response.headers().getContentType()))
.expectComplete()
.verify();
} }
@Test @Test
public void eTag() throws Exception { public void eTag() throws Exception {
ServerResponse<Void> result = ServerResponse.ok().eTag("foo").build(); Mono<ServerResponse> result = ServerResponse.ok().eTag("foo").build();
assertEquals("\"foo\"", result.headers().getETag()); StepVerifier.create(result)
.expectNextMatches(response -> "\"foo\"".equals(response.headers().getETag()))
.expectComplete()
.verify();
} }
@Test @Test
public void lastModified() throws Exception { public void lastModified() throws Exception {
ZonedDateTime now = ZonedDateTime.now(); ZonedDateTime now = ZonedDateTime.now();
ServerResponse<Void> result = ServerResponse.ok().lastModified(now).build(); Mono<ServerResponse> result = ServerResponse.ok().lastModified(now).build();
assertEquals(now.toInstant().toEpochMilli()/1000, result.headers().getLastModified()/1000); Long expected = now.toInstant().toEpochMilli() / 1000;
StepVerifier.create(result)
.expectNextMatches(response -> expected.equals(response.headers().getLastModified() / 1000))
.expectComplete()
.verify();
} }
@Test @Test
public void cacheControlTag() throws Exception { public void cacheControlTag() throws Exception {
ServerResponse<Void> result = ServerResponse.ok().cacheControl(CacheControl.noCache()).build(); Mono<ServerResponse>
assertEquals("no-cache", result.headers().getCacheControl()); result = ServerResponse.ok().cacheControl(CacheControl.noCache()).build();
StepVerifier.create(result)
.expectNextMatches(response -> "no-cache".equals(response.headers().getCacheControl()))
.expectComplete()
.verify();
} }
@Test @Test
public void varyBy() throws Exception { public void varyBy() throws Exception {
ServerResponse<Void> result = ServerResponse.ok().varyBy("foo").build(); Mono<ServerResponse> result = ServerResponse.ok().varyBy("foo").build();
assertEquals(Collections.singletonList("foo"), result.headers().getVary()); List<String> expected = Collections.singletonList("foo");
StepVerifier.create(result)
.expectNextMatches(response -> expected.equals(response.headers().getVary()))
.expectComplete()
.verify();
} }
@Test @Test
public void statusCode() throws Exception { public void statusCode() throws Exception {
HttpStatus statusCode = HttpStatus.ACCEPTED; HttpStatus statusCode = HttpStatus.ACCEPTED;
ServerResponse<Void> result = ServerResponse.status(statusCode).build(); Mono<ServerResponse> result = ServerResponse.status(statusCode).build();
assertSame(statusCode, result.statusCode()); StepVerifier.create(result)
.expectNextMatches(response -> statusCode.equals(response.statusCode()))
.expectComplete()
.verify();
} }
@Test @Test
public void headers() throws Exception { public void headers() throws Exception {
HttpHeaders headers = new HttpHeaders(); HttpHeaders headers = new HttpHeaders();
ServerResponse<Void> result = ServerResponse.ok().headers(headers).build(); Mono<ServerResponse> result = ServerResponse.ok().headers(headers).build();
assertEquals(headers, result.headers()); StepVerifier.create(result)
.expectNextMatches(response -> headers.equals(response.headers()))
.expectComplete()
.verify();
} }
@Test @Test
public void build() throws Exception { public void build() throws Exception {
ServerResponse<Void> result = ServerResponse.status(HttpStatus.CREATED).header("MyKey", "MyValue").build(); Mono<ServerResponse>
result = ServerResponse.status(HttpStatus.CREATED).header("MyKey", "MyValue").build();
ServerWebExchange exchange = mock(ServerWebExchange.class); ServerWebExchange exchange = mock(ServerWebExchange.class);
MockServerHttpResponse response = new MockServerHttpResponse(); MockServerHttpResponse response = new MockServerHttpResponse();
when(exchange.getResponse()).thenReturn(response); when(exchange.getResponse()).thenReturn(response);
HandlerStrategies strategies = mock(HandlerStrategies.class); HandlerStrategies strategies = mock(HandlerStrategies.class);
result.writeTo(exchange, strategies).block(); result.then(res -> res.writeTo(exchange, strategies)).block();
assertEquals(201, response.getStatusCode().value());
assertEquals(HttpStatus.CREATED, response.getStatusCode());
assertEquals("MyValue", response.getHeaders().getFirst("MyKey")); assertEquals("MyValue", response.getHeaders().getFirst("MyKey"));
assertNull(response.getBody()); assertNull(response.getBody());
...@@ -197,21 +252,25 @@ public class DefaultServerResponseBuilderTests { ...@@ -197,21 +252,25 @@ public class DefaultServerResponseBuilderTests {
@Test @Test
public void buildVoidPublisher() throws Exception { public void buildVoidPublisher() throws Exception {
Mono<Void> mono = Mono.empty(); Mono<Void> mono = Mono.empty();
ServerResponse<Mono<Void>> result = ServerResponse.ok().build(mono); Mono<ServerResponse> result = ServerResponse.ok().build(mono);
ServerWebExchange exchange = mock(ServerWebExchange.class); ServerWebExchange exchange = mock(ServerWebExchange.class);
MockServerHttpResponse response = new MockServerHttpResponse(); MockServerHttpResponse response = new MockServerHttpResponse();
when(exchange.getResponse()).thenReturn(response); when(exchange.getResponse()).thenReturn(response);
HandlerStrategies strategies = mock(HandlerStrategies.class); HandlerStrategies strategies = mock(HandlerStrategies.class);
result.writeTo(exchange, strategies).block(); result.then(res -> res.writeTo(exchange, strategies)).block();
assertNull(response.getBody()); assertNull(response.getBody());
} }
/*
TODO: enable when ServerEntityResponse is reintroduced
@Test @Test
public void bodyInserter() throws Exception { public void bodyInserter() throws Exception {
String body = "foo"; String body = "foo";
Supplier<String> supplier = () -> body; Publisher<String> publisher = Mono.just(body);
BiFunction<ServerHttpResponse, BodyInserter.Context, Mono<Void>> writer = BiFunction<ServerHttpResponse, BodyInserter.Context, Mono<Void>> writer =
(response, strategies) -> { (response, strategies) -> {
byte[] bodyBytes = body.getBytes(UTF_8); byte[] bodyBytes = body.getBytes(UTF_8);
...@@ -221,14 +280,13 @@ public class DefaultServerResponseBuilderTests { ...@@ -221,14 +280,13 @@ public class DefaultServerResponseBuilderTests {
return response.writeWith(Mono.just(buffer)); return response.writeWith(Mono.just(buffer));
}; };
ServerResponse<String> result = ServerResponse.ok().body(BodyInserter.of(writer, supplier)); Mono<ServerResponse> result = ServerResponse.ok().body(BodyInserter.of(writer, publisher));
assertEquals(body, result.body());
MockServerHttpRequest request = MockServerHttpRequest request =
new MockServerHttpRequest(HttpMethod.GET, "http://localhost"); new MockServerHttpRequest(HttpMethod.GET, "http://localhost");
MockServerHttpResponse response = new MockServerHttpResponse(); MockServerHttpResponse mockResponse = new MockServerHttpResponse();
ServerWebExchange exchange = ServerWebExchange exchange =
new DefaultServerWebExchange(request, response, new MockWebSessionManager()); new DefaultServerWebExchange(request, mockResponse, new MockWebSessionManager());
List<HttpMessageWriter<?>> messageWriters = new ArrayList<>(); List<HttpMessageWriter<?>> messageWriters = new ArrayList<>();
messageWriters.add(new EncoderHttpMessageWriter<CharSequence>(new CharSequenceEncoder())); messageWriters.add(new EncoderHttpMessageWriter<CharSequence>(new CharSequenceEncoder()));
...@@ -236,21 +294,32 @@ public class DefaultServerResponseBuilderTests { ...@@ -236,21 +294,32 @@ public class DefaultServerResponseBuilderTests {
HandlerStrategies strategies = mock(HandlerStrategies.class); HandlerStrategies strategies = mock(HandlerStrategies.class);
when(strategies.messageWriters()).thenReturn(messageWriters::stream); when(strategies.messageWriters()).thenReturn(messageWriters::stream);
result.writeTo(exchange, strategies).block(); StepVerifier.create(result)
assertNotNull(response.getBody()); .consumeNextWith(response -> {
StepVerifier.create(response.body())
.expectNext(body)
.expectComplete()
.verify();
response.writeTo(exchange, strategies);
})
.expectComplete()
.verify();
assertNotNull(mockResponse.getBody());
} }
*/
/*
TODO: enable when ServerEntityResponse is reintroduced
@Test @Test
public void render() throws Exception { public void render() throws Exception {
Map<String, Object> model = Collections.singletonMap("foo", "bar"); Map<String, Object> model = Collections.singletonMap("foo", "bar");
ServerResponse<Rendering> result = ServerResponse.ok().render("view", model); Mono<ServerResponse> result = ServerResponse.ok().render("view", model);
assertEquals("view", result.body().name());
assertEquals(model, result.body().model());
MockServerHttpRequest request = new MockServerHttpRequest(HttpMethod.GET, URI.create("http://localhost")); MockServerHttpRequest request = new MockServerHttpRequest(HttpMethod.GET, URI.create("http://localhost"));
MockServerHttpResponse response = new MockServerHttpResponse(); MockServerHttpResponse mockResponse = new MockServerHttpResponse();
ServerWebExchange exchange = new DefaultServerWebExchange(request, response, new MockWebSessionManager()); ServerWebExchange exchange = new DefaultServerWebExchange(request, mockResponse, new MockWebSessionManager());
ViewResolver viewResolver = mock(ViewResolver.class); ViewResolver viewResolver = mock(ViewResolver.class);
View view = mock(View.class); View view = mock(View.class);
when(viewResolver.resolveViewName("view", Locale.ENGLISH)).thenReturn(Mono.just(view)); when(viewResolver.resolveViewName("view", Locale.ENGLISH)).thenReturn(Mono.just(view));
...@@ -262,17 +331,37 @@ public class DefaultServerResponseBuilderTests { ...@@ -262,17 +331,37 @@ public class DefaultServerResponseBuilderTests {
HandlerStrategies mockConfig = mock(HandlerStrategies.class); HandlerStrategies mockConfig = mock(HandlerStrategies.class);
when(mockConfig.viewResolvers()).thenReturn(viewResolvers::stream); when(mockConfig.viewResolvers()).thenReturn(viewResolvers::stream);
result.writeTo(exchange, mockConfig).block(); StepVerifier.create(result)
.consumeNextWith(response -> {
StepVerifier.create(response.body())
.expectNextMatches(rendering -> "view".equals(rendering.name())
&& model.equals(rendering.model()))
.expectComplete()
.verify();
})
.expectComplete()
.verify();
} }
*/
/*
TODO: enable when ServerEntityResponse is reintroduced
@Test @Test
public void renderObjectArray() throws Exception { public void renderObjectArray() throws Exception {
ServerResponse<Rendering> result = Mono<ServerResponse> result =
ServerResponse.ok().render("name", this, Collections.emptyList(), "foo"); ServerResponse.ok().render("name", this, Collections.emptyList(), "foo");
Map<String, Object> model = result.body().model(); Flux<Rendering> map = result.flatMap(ServerResponse::body);
assertEquals(2, model.size());
assertEquals(this, model.get("defaultServerResponseBuilderTests")); Map<String, Object> expected = new HashMap<>(2);
assertEquals("foo", model.get("string")); expected.put("defaultServerResponseBuilderTests", this);
expected.put("string", "foo");
StepVerifier.create(map)
.expectNextMatches(rendering -> expected.equals(rendering.model()))
.expectComplete()
.verify();
} }
*/
} }
\ No newline at end of file
...@@ -16,14 +16,12 @@ ...@@ -16,14 +16,12 @@
package org.springframework.web.reactive.function; package org.springframework.web.reactive.function;
import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.function.Supplier; import java.util.function.Supplier;
import java.util.stream.Stream; import java.util.stream.Stream;
import org.junit.Before; import org.junit.Before;
import org.junit.Test; import org.junit.Test;
import org.reactivestreams.Publisher;
import reactor.core.publisher.Flux; import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono; import reactor.core.publisher.Mono;
...@@ -50,6 +48,7 @@ import org.springframework.web.reactive.result.view.ViewResolver; ...@@ -50,6 +48,7 @@ import org.springframework.web.reactive.result.view.ViewResolver;
import org.springframework.web.server.adapter.WebHttpHandlerBuilder; import org.springframework.web.server.adapter.WebHttpHandlerBuilder;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertEquals;
import static org.springframework.http.codec.BodyInserters.fromObject;
import static org.springframework.http.codec.BodyInserters.fromPublisher; import static org.springframework.http.codec.BodyInserters.fromPublisher;
import static org.springframework.web.reactive.function.RouterFunctions.route; import static org.springframework.web.reactive.function.RouterFunctions.route;
...@@ -84,7 +83,7 @@ public class DispatcherHandlerIntegrationTests extends AbstractHttpHandlerIntegr ...@@ -84,7 +83,7 @@ public class DispatcherHandlerIntegrationTests extends AbstractHttpHandlerIntegr
@Test @Test
public void mono() throws Exception { public void mono() throws Exception {
ResponseEntity<Person> result = ResponseEntity<Person> result =
restTemplate.getForEntity("http://localhost:" + port + "/mono", Person.class); this.restTemplate.getForEntity("http://localhost:" + this.port + "/mono", Person.class);
assertEquals(HttpStatus.OK, result.getStatusCode()); assertEquals(HttpStatus.OK, result.getStatusCode());
assertEquals("John", result.getBody().getName()); assertEquals("John", result.getBody().getName());
...@@ -94,7 +93,8 @@ public class DispatcherHandlerIntegrationTests extends AbstractHttpHandlerIntegr ...@@ -94,7 +93,8 @@ public class DispatcherHandlerIntegrationTests extends AbstractHttpHandlerIntegr
public void flux() throws Exception { public void flux() throws Exception {
ParameterizedTypeReference<List<Person>> reference = new ParameterizedTypeReference<List<Person>>() {}; ParameterizedTypeReference<List<Person>> reference = new ParameterizedTypeReference<List<Person>>() {};
ResponseEntity<List<Person>> result = ResponseEntity<List<Person>> result =
restTemplate.exchange("http://localhost:" + port + "/flux", HttpMethod.GET, null, reference); this.restTemplate
.exchange("http://localhost:" + this.port + "/flux", HttpMethod.GET, null, reference);
assertEquals(HttpStatus.OK, result.getStatusCode()); assertEquals(HttpStatus.OK, result.getStatusCode());
List<Person> body = result.getBody(); List<Person> body = result.getBody();
...@@ -134,7 +134,7 @@ public class DispatcherHandlerIntegrationTests extends AbstractHttpHandlerIntegr ...@@ -134,7 +134,7 @@ public class DispatcherHandlerIntegrationTests extends AbstractHttpHandlerIntegr
@Override @Override
public Supplier<Stream<ViewResolver>> viewResolvers() { public Supplier<Stream<ViewResolver>> viewResolvers() {
return () -> Collections.<ViewResolver>emptySet().stream(); return Stream::empty;
} }
}); });
} }
...@@ -154,18 +154,22 @@ public class DispatcherHandlerIntegrationTests extends AbstractHttpHandlerIntegr ...@@ -154,18 +154,22 @@ public class DispatcherHandlerIntegrationTests extends AbstractHttpHandlerIntegr
private static class PersonHandler { private static class PersonHandler {
public ServerResponse<Publisher<Person>> mono(ServerRequest request) { public Mono<ServerResponse> mono(ServerRequest request) {
Person person = new Person("John"); Person person = new Person("John");
return ServerResponse.ok().body(fromPublisher(Mono.just(person), Person.class)); return ServerResponse.ok().body(fromObject(person));
} }
public ServerResponse<Publisher<Person>> flux(ServerRequest request) { public Mono<ServerResponse> flux(ServerRequest request) {
Person person1 = new Person("John"); Person person1 = new Person("John");
Person person2 = new Person("Jane"); Person person2 = new Person("Jane");
return ServerResponse.ok().body( return ServerResponse.ok().body(
fromPublisher(Flux.just(person1, person2), Person.class)); fromPublisher(Flux.just(person1, person2), Person.class));
} }
public Mono<ServerResponse> view() {
return ServerResponse.ok().render("foo", "bar");
}
} }
private static class Person { private static class Person {
...@@ -181,7 +185,7 @@ public class DispatcherHandlerIntegrationTests extends AbstractHttpHandlerIntegr ...@@ -181,7 +185,7 @@ public class DispatcherHandlerIntegrationTests extends AbstractHttpHandlerIntegr
} }
public String getName() { public String getName() {
return name; return this.name;
} }
public void setName(String name) { public void setName(String name) {
...@@ -209,7 +213,7 @@ public class DispatcherHandlerIntegrationTests extends AbstractHttpHandlerIntegr ...@@ -209,7 +213,7 @@ public class DispatcherHandlerIntegrationTests extends AbstractHttpHandlerIntegr
@Override @Override
public String toString() { public String toString() {
return "Person{" + return "Person{" +
"name='" + name + '\'' + "name='" + this.name + '\'' +
'}'; '}';
} }
} }
......
...@@ -45,7 +45,7 @@ import org.springframework.util.MultiValueMap; ...@@ -45,7 +45,7 @@ import org.springframework.util.MultiValueMap;
/** /**
* @author Arjen Poutsma * @author Arjen Poutsma
*/ */
public class MockServerRequest<T> implements ServerRequest { public class MockServerRequest implements ServerRequest {
private final HttpMethod method; private final HttpMethod method;
...@@ -53,7 +53,7 @@ public class MockServerRequest<T> implements ServerRequest { ...@@ -53,7 +53,7 @@ public class MockServerRequest<T> implements ServerRequest {
private final MockHeaders headers; private final MockHeaders headers;
private final T body; private final Object body;
private final Map<String, Object> attributes; private final Map<String, Object> attributes;
...@@ -62,7 +62,7 @@ public class MockServerRequest<T> implements ServerRequest { ...@@ -62,7 +62,7 @@ public class MockServerRequest<T> implements ServerRequest {
private final Map<String, String> pathVariables; private final Map<String, String> pathVariables;
private MockServerRequest(HttpMethod method, URI uri, private MockServerRequest(HttpMethod method, URI uri,
MockHeaders headers, T body, Map<String, Object> attributes, MockHeaders headers, Object body, Map<String, Object> attributes,
MultiValueMap<String, String> queryParams, MultiValueMap<String, String> queryParams,
Map<String, String> pathVariables) { Map<String, String> pathVariables) {
this.method = method; this.method = method;
...@@ -74,8 +74,8 @@ public class MockServerRequest<T> implements ServerRequest { ...@@ -74,8 +74,8 @@ public class MockServerRequest<T> implements ServerRequest {
this.pathVariables = pathVariables; this.pathVariables = pathVariables;
} }
public static <T> Builder<T> builder() { public static Builder builder() {
return new BuilderImpl<T>(); return new BuilderImpl();
} }
@Override @Override
...@@ -127,35 +127,35 @@ public class MockServerRequest<T> implements ServerRequest { ...@@ -127,35 +127,35 @@ public class MockServerRequest<T> implements ServerRequest {
return Collections.unmodifiableMap(this.pathVariables); return Collections.unmodifiableMap(this.pathVariables);
} }
public interface Builder<T> { public interface Builder {
Builder<T> method(HttpMethod method); Builder method(HttpMethod method);
Builder<T> uri(URI uri); Builder uri(URI uri);
Builder<T> header(String key, String value); Builder header(String key, String value);
Builder<T> headers(HttpHeaders headers); Builder headers(HttpHeaders headers);
Builder<T> attribute(String name, Object value); Builder attribute(String name, Object value);
Builder<T> attributes(Map<String, Object> attributes); Builder attributes(Map<String, Object> attributes);
Builder<T> queryParam(String key, String value); Builder queryParam(String key, String value);
Builder<T> queryParams(MultiValueMap<String, String> queryParams); Builder queryParams(MultiValueMap<String, String> queryParams);
Builder<T> pathVariable(String key, String value); Builder pathVariable(String key, String value);
Builder<T> pathVariables(Map<String, String> pathVariables); Builder pathVariables(Map<String, String> pathVariables);
MockServerRequest<T> body(T body); MockServerRequest body(Object body);
MockServerRequest<Void> build(); MockServerRequest build();
} }
private static class BuilderImpl<T> implements Builder<T> { private static class BuilderImpl implements Builder {
private HttpMethod method = HttpMethod.GET; private HttpMethod method = HttpMethod.GET;
...@@ -163,7 +163,7 @@ public class MockServerRequest<T> implements ServerRequest { ...@@ -163,7 +163,7 @@ public class MockServerRequest<T> implements ServerRequest {
private MockHeaders headers = new MockHeaders(new HttpHeaders()); private MockHeaders headers = new MockHeaders(new HttpHeaders());
private T body; private Object body;
private Map<String, Object> attributes = new LinkedHashMap<>(); private Map<String, Object> attributes = new LinkedHashMap<>();
...@@ -172,21 +172,21 @@ public class MockServerRequest<T> implements ServerRequest { ...@@ -172,21 +172,21 @@ public class MockServerRequest<T> implements ServerRequest {
private Map<String, String> pathVariables = new LinkedHashMap<>(); private Map<String, String> pathVariables = new LinkedHashMap<>();
@Override @Override
public Builder<T> method(HttpMethod method) { public Builder method(HttpMethod method) {
Assert.notNull(method, "'method' must not be null"); Assert.notNull(method, "'method' must not be null");
this.method = method; this.method = method;
return this; return this;
} }
@Override @Override
public Builder<T> uri(URI uri) { public Builder uri(URI uri) {
Assert.notNull(uri, "'uri' must not be null"); Assert.notNull(uri, "'uri' must not be null");
this.uri = uri; this.uri = uri;
return this; return this;
} }
@Override @Override
public Builder<T> header(String key, String value) { public Builder header(String key, String value) {
Assert.notNull(key, "'key' must not be null"); Assert.notNull(key, "'key' must not be null");
Assert.notNull(value, "'value' must not be null"); Assert.notNull(value, "'value' must not be null");
this.headers.header(key, value); this.headers.header(key, value);
...@@ -194,14 +194,14 @@ public class MockServerRequest<T> implements ServerRequest { ...@@ -194,14 +194,14 @@ public class MockServerRequest<T> implements ServerRequest {
} }
@Override @Override
public Builder<T> headers(HttpHeaders headers) { public Builder headers(HttpHeaders headers) {
Assert.notNull(headers, "'headers' must not be null"); Assert.notNull(headers, "'headers' must not be null");
this.headers = new MockHeaders(headers); this.headers = new MockHeaders(headers);
return this; return this;
} }
@Override @Override
public Builder<T> attribute(String name, Object value) { public Builder attribute(String name, Object value) {
Assert.notNull(name, "'name' must not be null"); Assert.notNull(name, "'name' must not be null");
Assert.notNull(value, "'value' must not be null"); Assert.notNull(value, "'value' must not be null");
this.attributes.put(name, value); this.attributes.put(name, value);
...@@ -209,14 +209,14 @@ public class MockServerRequest<T> implements ServerRequest { ...@@ -209,14 +209,14 @@ public class MockServerRequest<T> implements ServerRequest {
} }
@Override @Override
public Builder<T> attributes(Map<String, Object> attributes) { public Builder attributes(Map<String, Object> attributes) {
Assert.notNull(attributes, "'attributes' must not be null"); Assert.notNull(attributes, "'attributes' must not be null");
this.attributes = attributes; this.attributes = attributes;
return this; return this;
} }
@Override @Override
public Builder<T> queryParam(String key, String value) { public Builder queryParam(String key, String value) {
Assert.notNull(key, "'key' must not be null"); Assert.notNull(key, "'key' must not be null");
Assert.notNull(value, "'value' must not be null"); Assert.notNull(value, "'value' must not be null");
this.queryParams.add(key, value); this.queryParams.add(key, value);
...@@ -224,14 +224,14 @@ public class MockServerRequest<T> implements ServerRequest { ...@@ -224,14 +224,14 @@ public class MockServerRequest<T> implements ServerRequest {
} }
@Override @Override
public Builder<T> queryParams(MultiValueMap<String, String> queryParams) { public Builder queryParams(MultiValueMap<String, String> queryParams) {
Assert.notNull(queryParams, "'queryParams' must not be null"); Assert.notNull(queryParams, "'queryParams' must not be null");
this.queryParams = queryParams; this.queryParams = queryParams;
return this; return this;
} }
@Override @Override
public Builder<T> pathVariable(String key, String value) { public Builder pathVariable(String key, String value) {
Assert.notNull(key, "'key' must not be null"); Assert.notNull(key, "'key' must not be null");
Assert.notNull(value, "'value' must not be null"); Assert.notNull(value, "'value' must not be null");
this.pathVariables.put(key, value); this.pathVariables.put(key, value);
...@@ -239,22 +239,22 @@ public class MockServerRequest<T> implements ServerRequest { ...@@ -239,22 +239,22 @@ public class MockServerRequest<T> implements ServerRequest {
} }
@Override @Override
public Builder<T> pathVariables(Map<String, String> pathVariables) { public Builder pathVariables(Map<String, String> pathVariables) {
Assert.notNull(pathVariables, "'pathVariables' must not be null"); Assert.notNull(pathVariables, "'pathVariables' must not be null");
this.pathVariables = pathVariables; this.pathVariables = pathVariables;
return this; return this;
} }
@Override @Override
public MockServerRequest<T> body(T body) { public MockServerRequest body(Object body) {
this.body = body; this.body = body;
return new MockServerRequest<T>(this.method, this.uri, this.headers, this.body, return new MockServerRequest(this.method, this.uri, this.headers, this.body,
this.attributes, this.queryParams, this.pathVariables); this.attributes, this.queryParams, this.pathVariables);
} }
@Override @Override
public MockServerRequest<Void> build() { public MockServerRequest build() {
return new MockServerRequest<Void>(this.method, this.uri, this.headers, null, return new MockServerRequest(this.method, this.uri, this.headers, null,
this.attributes, this.queryParams, this.pathVariables); this.attributes, this.queryParams, this.pathVariables);
} }
......
...@@ -16,18 +16,17 @@ ...@@ -16,18 +16,17 @@
package org.springframework.web.reactive.function; package org.springframework.web.reactive.function;
import java.io.File;
import java.io.IOException;
import java.net.URI; import java.net.URI;
import java.util.Optional;
import org.junit.Test; import org.junit.Test;
import reactor.core.publisher.Mono;
import reactor.test.StepVerifier;
import org.springframework.core.io.ClassPathResource; import org.springframework.core.io.ClassPathResource;
import org.springframework.core.io.Resource; import org.springframework.core.io.Resource;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
/** /**
* @author Arjen Poutsma * @author Arjen Poutsma
*/ */
...@@ -38,14 +37,23 @@ public class PathResourceLookupFunctionTests { ...@@ -38,14 +37,23 @@ public class PathResourceLookupFunctionTests {
ClassPathResource location = new ClassPathResource("org/springframework/web/reactive/function/"); ClassPathResource location = new ClassPathResource("org/springframework/web/reactive/function/");
PathResourceLookupFunction function = new PathResourceLookupFunction("/resources/**", location); PathResourceLookupFunction function = new PathResourceLookupFunction("/resources/**", location);
MockServerRequest<Void> request = MockServerRequest.builder() MockServerRequest request = MockServerRequest.builder()
.uri(new URI("http://localhost/resources/response.txt")) .uri(new URI("http://localhost/resources/response.txt"))
.build(); .build();
Optional<Resource> result = function.apply(request); Mono<Resource> result = function.apply(request);
assertTrue(result.isPresent());
ClassPathResource expected = new ClassPathResource("response.txt", getClass()); File expected = new ClassPathResource("response.txt", getClass()).getFile();
assertEquals(expected.getFile(), result.get().getFile()); StepVerifier.create(result)
.expectNextMatches(resource -> {
try {
return expected.equals(resource.getFile());
}
catch (IOException ex) {
return false;
}
})
.expectComplete()
.verify();
} }
@Test @Test
...@@ -53,14 +61,22 @@ public class PathResourceLookupFunctionTests { ...@@ -53,14 +61,22 @@ public class PathResourceLookupFunctionTests {
ClassPathResource location = new ClassPathResource("org/springframework/web/reactive/function/"); ClassPathResource location = new ClassPathResource("org/springframework/web/reactive/function/");
PathResourceLookupFunction function = new PathResourceLookupFunction("/resources/**", location); PathResourceLookupFunction function = new PathResourceLookupFunction("/resources/**", location);
MockServerRequest<Void> request = MockServerRequest.builder() MockServerRequest request = MockServerRequest.builder()
.uri(new URI("http://localhost/resources/child/response.txt")) .uri(new URI("http://localhost/resources/child/response.txt"))
.build(); .build();
Optional<Resource> result = function.apply(request); Mono<Resource> result = function.apply(request);
assertTrue(result.isPresent()); File expected = new ClassPathResource("org/springframework/web/reactive/function/child/response.txt").getFile();
StepVerifier.create(result)
ClassPathResource expected = new ClassPathResource("org/springframework/web/reactive/function/child/response.txt"); .expectNextMatches(resource -> {
assertEquals(expected.getFile(), result.get().getFile()); try {
return expected.equals(resource.getFile());
}
catch (IOException ex) {
return false;
}
})
.expectComplete()
.verify();
} }
@Test @Test
...@@ -68,11 +84,13 @@ public class PathResourceLookupFunctionTests { ...@@ -68,11 +84,13 @@ public class PathResourceLookupFunctionTests {
ClassPathResource location = new ClassPathResource("org/springframework/web/reactive/function/"); ClassPathResource location = new ClassPathResource("org/springframework/web/reactive/function/");
PathResourceLookupFunction function = new PathResourceLookupFunction("/resources/**", location); PathResourceLookupFunction function = new PathResourceLookupFunction("/resources/**", location);
MockServerRequest<Void> request = MockServerRequest.builder() MockServerRequest request = MockServerRequest.builder()
.uri(new URI("http://localhost/resources/foo")) .uri(new URI("http://localhost/resources/foo"))
.build(); .build();
Optional<Resource> result = function.apply(request); Mono<Resource> result = function.apply(request);
assertFalse(result.isPresent()); StepVerifier.create(result)
.expectComplete()
.verify();
} }
} }
\ No newline at end of file
...@@ -21,7 +21,6 @@ import java.util.List; ...@@ -21,7 +21,6 @@ import java.util.List;
import org.junit.Before; import org.junit.Before;
import org.junit.Test; import org.junit.Test;
import org.reactivestreams.Publisher;
import reactor.core.publisher.Flux; import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono; import reactor.core.publisher.Mono;
...@@ -97,17 +96,17 @@ public class PublisherHandlerFunctionIntegrationTests ...@@ -97,17 +96,17 @@ public class PublisherHandlerFunctionIntegrationTests
private static class PersonHandler { private static class PersonHandler {
public ServerResponse<Publisher<Person>> mono(ServerRequest request) { public Mono<ServerResponse> mono(ServerRequest request) {
Person person = new Person("John"); Person person = new Person("John");
return ServerResponse.ok().body(fromPublisher(Mono.just(person), Person.class)); return ServerResponse.ok().body(fromPublisher(Mono.just(person), Person.class));
} }
public ServerResponse<Publisher<Person>> postMono(ServerRequest request) { public Mono<ServerResponse> postMono(ServerRequest request) {
Mono<Person> personMono = request.body(toMono(Person.class)); Mono<Person> personMono = request.body(toMono(Person.class));
return ServerResponse.ok().body(fromPublisher(personMono, Person.class)); return ServerResponse.ok().body(fromPublisher(personMono, Person.class));
} }
public ServerResponse<Publisher<Person>> flux(ServerRequest request) { public Mono<ServerResponse> flux(ServerRequest request) {
Person person1 = new Person("John"); Person person1 = new Person("John");
Person person2 = new Person("Jane"); Person person2 = new Person("Jane");
return ServerResponse.ok().body( return ServerResponse.ok().body(
......
...@@ -57,18 +57,24 @@ public class ResourceHandlerFunctionTests { ...@@ -57,18 +57,24 @@ public class ResourceHandlerFunctionTests {
ServerRequest request = new DefaultServerRequest(exchange, HandlerStrategies.withDefaults()); ServerRequest request = new DefaultServerRequest(exchange, HandlerStrategies.withDefaults());
ServerResponse<Resource> response = this.handlerFunction.handle(request); Mono<ServerResponse> responseMono = this.handlerFunction.handle(request);
assertEquals(HttpStatus.OK, response.statusCode());
assertEquals(this.resource, response.body());
Mono<Void> result = response.writeTo(exchange, HandlerStrategies.withDefaults()); Mono<Void> result = responseMono.then(response -> {
assertEquals(HttpStatus.OK, response.statusCode());
/*
TODO: enable when ServerEntityResponse is reintroduced
StepVerifier.create(response.body())
.expectNext(this.resource)
.expectComplete()
.verify();
*/
return response.writeTo(exchange, HandlerStrategies.withDefaults());
});
StepVerifier.create(result) StepVerifier.create(result)
.expectComplete() .expectComplete()
.verify(); .verify();
StepVerifier.create(result).expectComplete().verify();
byte[] expectedBytes = Files.readAllBytes(this.resource.getFile().toPath()); byte[] expectedBytes = Files.readAllBytes(this.resource.getFile().toPath());
StepVerifier.create(mockResponse.getBody()) StepVerifier.create(mockResponse.getBody())
...@@ -93,10 +99,12 @@ public class ResourceHandlerFunctionTests { ...@@ -93,10 +99,12 @@ public class ResourceHandlerFunctionTests {
ServerRequest request = new DefaultServerRequest(exchange, HandlerStrategies.withDefaults()); ServerRequest request = new DefaultServerRequest(exchange, HandlerStrategies.withDefaults());
ServerResponse<Resource> response = this.handlerFunction.handle(request); Mono<ServerResponse> response = this.handlerFunction.handle(request);
assertEquals(HttpStatus.OK, response.statusCode());
Mono<Void> result = response.writeTo(exchange, HandlerStrategies.withDefaults()); Mono<Void> result = response.then(res -> {
assertEquals(HttpStatus.OK, res.statusCode());
return res.writeTo(exchange, HandlerStrategies.withDefaults());
});
StepVerifier.create(result) StepVerifier.create(result)
.expectComplete() .expectComplete()
...@@ -121,14 +129,20 @@ public class ResourceHandlerFunctionTests { ...@@ -121,14 +129,20 @@ public class ResourceHandlerFunctionTests {
ServerRequest request = new DefaultServerRequest(exchange, HandlerStrategies.withDefaults()); ServerRequest request = new DefaultServerRequest(exchange, HandlerStrategies.withDefaults());
ServerResponse<Resource> response = this.handlerFunction.handle(request); Mono<ServerResponse> responseMono = this.handlerFunction.handle(request);
Mono<Void> result = responseMono.then(response -> {
assertEquals(HttpStatus.OK, response.statusCode()); assertEquals(HttpStatus.OK, response.statusCode());
assertEquals(EnumSet.of(HttpMethod.GET, HttpMethod.HEAD, HttpMethod.OPTIONS), assertEquals(EnumSet.of(HttpMethod.GET, HttpMethod.HEAD, HttpMethod.OPTIONS),
response.headers().getAllow()); response.headers().getAllow());
assertNull(response.body()); /*
TODO: enable when ServerEntityResponse is reintroduced
StepVerifier.create(response.body())
.expectComplete()
.verify();
*/
return response.writeTo(exchange, HandlerStrategies.withDefaults());
});
Mono<Void> result = response.writeTo(exchange, HandlerStrategies.withDefaults());
StepVerifier.create(result) StepVerifier.create(result)
.expectComplete() .expectComplete()
......
...@@ -16,13 +16,11 @@ ...@@ -16,13 +16,11 @@
package org.springframework.web.reactive.function; package org.springframework.web.reactive.function;
import java.util.Optional;
import org.junit.Test; import org.junit.Test;
import reactor.core.publisher.Mono;
import reactor.test.StepVerifier;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue;
import static org.springframework.http.codec.BodyInserters.fromObject; import static org.springframework.http.codec.BodyInserters.fromObject;
/** /**
...@@ -33,69 +31,96 @@ public class RouterFunctionTests { ...@@ -33,69 +31,96 @@ public class RouterFunctionTests {
@Test @Test
public void andSame() throws Exception { public void andSame() throws Exception {
HandlerFunction<Void> handlerFunction = request -> ServerResponse.ok().build(); HandlerFunction<ServerResponse> handlerFunction = request -> ServerResponse.ok().build();
RouterFunction<Void> routerFunction1 = request -> Optional.empty(); RouterFunction<ServerResponse> routerFunction1 = request -> Mono.empty();
RouterFunction<Void> routerFunction2 = request -> Optional.of(handlerFunction); RouterFunction<ServerResponse> routerFunction2 = request -> Mono.just(handlerFunction);
RouterFunction<Void> result = routerFunction1.andSame(routerFunction2); RouterFunction<ServerResponse> result = routerFunction1.andSame(routerFunction2);
assertNotNull(result); assertNotNull(result);
MockServerRequest request = MockServerRequest.builder().build(); MockServerRequest request = MockServerRequest.builder().build();
Optional<HandlerFunction<Void>> resultHandlerFunction = result.route(request); Mono<HandlerFunction<ServerResponse>> resultHandlerFunction = result.route(request);
assertTrue(resultHandlerFunction.isPresent());
assertEquals(handlerFunction, resultHandlerFunction.get()); StepVerifier.create(resultHandlerFunction)
.expectNext(handlerFunction)
.expectComplete()
.verify();
} }
@Test @Test
public void and() throws Exception { public void and() throws Exception {
HandlerFunction<String> handlerFunction = request -> ServerResponse.ok().body(fromObject("42")); HandlerFunction<ServerResponse> handlerFunction =
RouterFunction<Void> routerFunction1 = request -> Optional.empty(); request -> ServerResponse.ok().body(fromObject("42"));
RouterFunction<String> routerFunction2 = request -> Optional.of(handlerFunction); RouterFunction<?> routerFunction1 = request -> Mono.empty();
RouterFunction<ServerResponse> routerFunction2 =
request -> Mono.just(handlerFunction);
RouterFunction<?> result = routerFunction1.and(routerFunction2); RouterFunction<?> result = routerFunction1.and(routerFunction2);
assertNotNull(result); assertNotNull(result);
MockServerRequest request = MockServerRequest.builder().build(); MockServerRequest request = MockServerRequest.builder().build();
Optional<? extends HandlerFunction<?>> resultHandlerFunction = result.route(request); Mono<? extends HandlerFunction<?>> resultHandlerFunction = result.route(request);
assertTrue(resultHandlerFunction.isPresent());
assertEquals(handlerFunction, resultHandlerFunction.get()); StepVerifier.create(resultHandlerFunction)
.expectNextMatches(o -> o.equals(handlerFunction))
.expectComplete()
.verify();
} }
@Test @Test
public void andRoute() throws Exception { public void andRoute() throws Exception {
RouterFunction<Integer> routerFunction1 = request -> Optional.empty(); RouterFunction<?> routerFunction1 = request -> Mono.empty();
RequestPredicate requestPredicate = request -> true; RequestPredicate requestPredicate = request -> true;
RouterFunction<?> result = routerFunction1.andRoute(requestPredicate, this::handlerMethod); RouterFunction<?> result = routerFunction1.andRoute(requestPredicate, this::handlerMethod);
assertNotNull(result); assertNotNull(result);
MockServerRequest request = MockServerRequest.builder().build(); MockServerRequest request = MockServerRequest.builder().build();
Optional<? extends HandlerFunction<?>> resultHandlerFunction = result.route(request); Mono<? extends HandlerFunction<?>> resultHandlerFunction = result.route(request);
assertTrue(resultHandlerFunction.isPresent());
StepVerifier.create(resultHandlerFunction)
.expectNextCount(1)
.expectComplete()
.verify();
} }
private ServerResponse<String> handlerMethod(ServerRequest request) { private Mono<ServerResponse> handlerMethod(ServerRequest request) {
return ServerResponse.ok().body(fromObject("42")); return ServerResponse.ok().body(fromObject("42"));
} }
/*
TODO: enable when ServerEntityResponse is reintroduced
@Test @Test
public void filter() throws Exception { public void filter() throws Exception {
HandlerFunction<String> handlerFunction = request -> ServerResponse.ok().body(fromObject("42")); HandlerFunction<ServerResponse> handlerFunction = request -> ServerResponse.ok().body(fromObject("42"));
RouterFunction<String> routerFunction = request -> Optional.of(handlerFunction); RouterFunction<ServerResponse> routerFunction = request -> Mono.just(handlerFunction);
HandlerFilterFunction<String, Integer> filterFunction = (request, next) -> { HandlerFilterFunction<String, Integer> filterFunction =
ServerResponse<String> response = next.handle(request); (request, next) -> next.handle(request).then(
int i = Integer.parseInt(response.body()); response -> {
return ServerResponse.ok().body(fromObject(i)); Flux<Integer> body = Flux.from(response.body())
}; .map(Integer::parseInt);
return ServerResponse.ok().body(body, Integer.class);
});
RouterFunction<Integer> result = routerFunction.filter(filterFunction); RouterFunction<Integer> result = routerFunction.filter(filterFunction);
assertNotNull(result); assertNotNull(result);
MockServerRequest request = MockServerRequest.builder().build(); MockServerRequest request = MockServerRequest.builder().build();
Optional<? extends HandlerFunction<?>> resultHandlerFunction = result.route(request); Mono<? extends ServerResponse<Integer>> responseMono =
assertTrue(resultHandlerFunction.isPresent()); result.route(request).then(hf -> hf.handle(request));
ServerResponse<?> resultResponse = resultHandlerFunction.get().handle(request);
assertEquals(42, resultResponse.body()); StepVerifier.create(responseMono)
.consumeNextWith(
serverResponse -> {
StepVerifier.create(serverResponse.body())
.expectNext(42)
.expectComplete()
.verify();
})
.expectComplete()
.verify();
} }
*/
} }
\ No newline at end of file
...@@ -16,11 +16,11 @@ ...@@ -16,11 +16,11 @@
package org.springframework.web.reactive.function; package org.springframework.web.reactive.function;
import java.util.Collections; import java.util.stream.Stream;
import java.util.Optional;
import org.junit.Test; import org.junit.Test;
import reactor.core.publisher.Mono; import reactor.core.publisher.Mono;
import reactor.test.StepVerifier;
import org.springframework.http.HttpMethod; import org.springframework.http.HttpMethod;
import org.springframework.http.codec.HttpMessageReader; import org.springframework.http.codec.HttpMessageReader;
...@@ -31,8 +31,11 @@ import org.springframework.mock.http.server.reactive.test.MockServerHttpResponse ...@@ -31,8 +31,11 @@ import org.springframework.mock.http.server.reactive.test.MockServerHttpResponse
import org.springframework.web.reactive.result.view.ViewResolver; import org.springframework.web.reactive.result.view.ViewResolver;
import org.springframework.web.server.ServerWebExchange; import org.springframework.web.server.ServerWebExchange;
import static org.junit.Assert.*; import static org.junit.Assert.assertNotNull;
import static org.mockito.Mockito.*; import static org.mockito.Mockito.any;
import static org.mockito.Mockito.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
/** /**
* @author Arjen Poutsma * @author Arjen Poutsma
...@@ -42,87 +45,97 @@ public class RouterFunctionsTests { ...@@ -42,87 +45,97 @@ public class RouterFunctionsTests {
@Test @Test
public void routeMatch() throws Exception { public void routeMatch() throws Exception {
HandlerFunction<Void> handlerFunction = request -> ServerResponse.ok().build(); HandlerFunction<ServerResponse> handlerFunction = request -> ServerResponse.ok().build();
MockServerRequest request = MockServerRequest.builder().build(); MockServerRequest request = MockServerRequest.builder().build();
RequestPredicate requestPredicate = mock(RequestPredicate.class); RequestPredicate requestPredicate = mock(RequestPredicate.class);
when(requestPredicate.test(request)).thenReturn(true); when(requestPredicate.test(request)).thenReturn(true);
RouterFunction<Void> result = RouterFunctions.route(requestPredicate, handlerFunction); RouterFunction<ServerResponse> result = RouterFunctions.route(requestPredicate, handlerFunction);
assertNotNull(result); assertNotNull(result);
Optional<HandlerFunction<Void>> resultHandlerFunction = result.route(request); Mono<HandlerFunction<ServerResponse>> resultHandlerFunction = result.route(request);
assertTrue(resultHandlerFunction.isPresent());
assertEquals(handlerFunction, resultHandlerFunction.get()); StepVerifier.create(resultHandlerFunction)
.expectNext(handlerFunction)
.expectComplete()
.verify();
} }
@Test @Test
public void routeNoMatch() throws Exception { public void routeNoMatch() throws Exception {
HandlerFunction<Void> handlerFunction = request -> ServerResponse.ok().build(); HandlerFunction<ServerResponse> handlerFunction = request -> ServerResponse.ok().build();
MockServerRequest request = MockServerRequest.builder().build(); MockServerRequest request = MockServerRequest.builder().build();
RequestPredicate requestPredicate = mock(RequestPredicate.class); RequestPredicate requestPredicate = mock(RequestPredicate.class);
when(requestPredicate.test(request)).thenReturn(false); when(requestPredicate.test(request)).thenReturn(false);
RouterFunction<Void> result = RouterFunctions.route(requestPredicate, handlerFunction); RouterFunction<ServerResponse> result = RouterFunctions.route(requestPredicate, handlerFunction);
assertNotNull(result); assertNotNull(result);
Optional<HandlerFunction<Void>> resultHandlerFunction = result.route(request); Mono<HandlerFunction<ServerResponse>> resultHandlerFunction = result.route(request);
assertFalse(resultHandlerFunction.isPresent()); StepVerifier.create(resultHandlerFunction)
.expectComplete()
.verify();
} }
@Test @Test
public void subrouteMatch() throws Exception { public void subrouteMatch() throws Exception {
HandlerFunction<Void> handlerFunction = request -> ServerResponse.ok().build(); HandlerFunction<ServerResponse> handlerFunction = request -> ServerResponse.ok().build();
RouterFunction<Void> routerFunction = request -> Optional.of(handlerFunction); RouterFunction<ServerResponse> routerFunction = request -> Mono.just(handlerFunction);
MockServerRequest request = MockServerRequest.builder().build(); MockServerRequest request = MockServerRequest.builder().build();
RequestPredicate requestPredicate = mock(RequestPredicate.class); RequestPredicate requestPredicate = mock(RequestPredicate.class);
when(requestPredicate.test(request)).thenReturn(true); when(requestPredicate.test(request)).thenReturn(true);
RouterFunction<Void> result = RouterFunctions.subroute(requestPredicate, routerFunction); RouterFunction<ServerResponse> result = RouterFunctions.subroute(requestPredicate, routerFunction);
assertNotNull(result); assertNotNull(result);
Optional<HandlerFunction<Void>> resultHandlerFunction = result.route(request); Mono<HandlerFunction<ServerResponse>> resultHandlerFunction = result.route(request);
assertTrue(resultHandlerFunction.isPresent()); StepVerifier.create(resultHandlerFunction)
assertEquals(handlerFunction, resultHandlerFunction.get()); .expectNext(handlerFunction)
.expectComplete()
.verify();
} }
@Test @Test
public void subrouteNoMatch() throws Exception { public void subrouteNoMatch() throws Exception {
HandlerFunction<Void> handlerFunction = request -> ServerResponse.ok().build(); HandlerFunction<ServerResponse> handlerFunction = request -> ServerResponse.ok().build();
RouterFunction<Void> routerFunction = request -> Optional.of(handlerFunction); RouterFunction<ServerResponse> routerFunction = request -> Mono.just(handlerFunction);
MockServerRequest request = MockServerRequest.builder().build(); MockServerRequest request = MockServerRequest.builder().build();
RequestPredicate requestPredicate = mock(RequestPredicate.class); RequestPredicate requestPredicate = mock(RequestPredicate.class);
when(requestPredicate.test(request)).thenReturn(false); when(requestPredicate.test(request)).thenReturn(false);
RouterFunction<Void> result = RouterFunctions.subroute(requestPredicate, routerFunction); RouterFunction<ServerResponse> result = RouterFunctions.subroute(requestPredicate, routerFunction);
assertNotNull(result); assertNotNull(result);
Optional<HandlerFunction<Void>> resultHandlerFunction = result.route(request); Mono<HandlerFunction<ServerResponse>> resultHandlerFunction = result.route(request);
assertFalse(resultHandlerFunction.isPresent()); StepVerifier.create(resultHandlerFunction)
.expectComplete()
.verify();
} }
@Test @Test
public void toHttpHandler() throws Exception { public void toHttpHandler() throws Exception {
HandlerStrategies strategies = mock(HandlerStrategies.class); HandlerStrategies strategies = mock(HandlerStrategies.class);
when(strategies.messageReaders()).thenReturn( when(strategies.messageReaders()).thenReturn(
() -> Collections.<HttpMessageReader<?>>emptyList().stream()); Stream::<HttpMessageReader<?>>empty);
when(strategies.messageWriters()).thenReturn( when(strategies.messageWriters()).thenReturn(
() -> Collections.<HttpMessageWriter<?>>emptyList().stream()); Stream::<HttpMessageWriter<?>>empty);
when(strategies.viewResolvers()).thenReturn( when(strategies.viewResolvers()).thenReturn(
() -> Collections.<ViewResolver>emptyList().stream()); Stream::<ViewResolver>empty);
ServerRequest request = mock(ServerRequest.class); ServerRequest request = mock(ServerRequest.class);
ServerResponse response = mock(ServerResponse.class); ServerResponse response = mock(ServerResponse.class);
when(response.writeTo(any(ServerWebExchange.class), eq(strategies))).thenReturn(Mono.empty()); when(response.writeTo(any(ServerWebExchange.class), eq(strategies))).thenReturn(Mono.empty());
HandlerFunction handlerFunction = mock(HandlerFunction.class); HandlerFunction<ServerResponse> handlerFunction = mock(HandlerFunction.class);
when(handlerFunction.handle(any(ServerRequest.class))).thenReturn(response); when(handlerFunction.handle(any(ServerRequest.class))).thenReturn(Mono.just(response));
RouterFunction routerFunction = mock(RouterFunction.class); RouterFunction<ServerResponse> routerFunction = mock(RouterFunction.class);
when(routerFunction.route(any(ServerRequest.class))).thenReturn(Optional.of(handlerFunction)); when(routerFunction.route(any(ServerRequest.class))).thenReturn(Mono.just(handlerFunction));
RequestPredicate requestPredicate = mock(RequestPredicate.class); RequestPredicate requestPredicate = mock(RequestPredicate.class);
when(requestPredicate.test(request)).thenReturn(false); when(requestPredicate.test(request)).thenReturn(false);
......
...@@ -20,7 +20,6 @@ import java.time.Duration; ...@@ -20,7 +20,6 @@ import java.time.Duration;
import org.junit.Before; import org.junit.Before;
import org.junit.Test; import org.junit.Test;
import org.reactivestreams.Publisher;
import reactor.core.publisher.Flux; import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono; import reactor.core.publisher.Mono;
import reactor.test.StepVerifier; import reactor.test.StepVerifier;
...@@ -127,18 +126,18 @@ public class SseHandlerFunctionIntegrationTests ...@@ -127,18 +126,18 @@ public class SseHandlerFunctionIntegrationTests
private static class SseHandler { private static class SseHandler {
public ServerResponse<Publisher<String>> string(ServerRequest request) { public Mono<ServerResponse> string(ServerRequest request) {
Flux<String> flux = Flux.interval(Duration.ofMillis(100)).map(l -> "foo " + l).take(2); Flux<String> flux = Flux.interval(Duration.ofMillis(100)).map(l -> "foo " + l).take(2);
return ServerResponse.ok().body(fromServerSentEvents(flux, String.class)); return ServerResponse.ok().body(fromServerSentEvents(flux, String.class));
} }
public ServerResponse<Publisher<Person>> person(ServerRequest request) { public Mono<ServerResponse> person(ServerRequest request) {
Flux<Person> flux = Flux.interval(Duration.ofMillis(100)) Flux<Person> flux = Flux.interval(Duration.ofMillis(100))
.map(l -> new Person("foo " + l)).take(2); .map(l -> new Person("foo " + l)).take(2);
return ServerResponse.ok().body(fromServerSentEvents(flux, Person.class)); return ServerResponse.ok().body(fromServerSentEvents(flux, Person.class));
} }
public ServerResponse<Publisher<ServerSentEvent<String>>> sse(ServerRequest request) { public Mono<ServerResponse> sse(ServerRequest request) {
Flux<ServerSentEvent<String>> flux = Flux.interval(Duration.ofMillis(100)) Flux<ServerSentEvent<String>> flux = Flux.interval(Duration.ofMillis(100))
.map(l -> ServerSentEvent.<String>builder().data("foo") .map(l -> ServerSentEvent.<String>builder().data("foo")
.id(Long.toString(l)) .id(Long.toString(l))
......
...@@ -33,7 +33,7 @@ import org.springframework.http.ReactiveHttpInputMessage; ...@@ -33,7 +33,7 @@ import org.springframework.http.ReactiveHttpInputMessage;
public interface BodyExtractor<T, M extends ReactiveHttpInputMessage> { public interface BodyExtractor<T, M extends ReactiveHttpInputMessage> {
/** /**
* Extract from the given request. * Extract from the given input message.
* @param inputMessage request to extract from * @param inputMessage request to extract from
* @param context the configuration to use * @param context the configuration to use
* @return the extracted data * @return the extracted data
......
...@@ -16,14 +16,12 @@ ...@@ -16,14 +16,12 @@
package org.springframework.http.codec; package org.springframework.http.codec;
import java.util.function.BiFunction;
import java.util.function.Supplier; import java.util.function.Supplier;
import java.util.stream.Stream; import java.util.stream.Stream;
import reactor.core.publisher.Mono; import reactor.core.publisher.Mono;
import org.springframework.http.ReactiveHttpOutputMessage; import org.springframework.http.ReactiveHttpOutputMessage;
import org.springframework.util.Assert;
/** /**
* A combination of functions that can populate a {@link ReactiveHttpOutputMessage} body. * A combination of functions that can populate a {@link ReactiveHttpOutputMessage} body.
...@@ -32,39 +30,17 @@ import org.springframework.util.Assert; ...@@ -32,39 +30,17 @@ import org.springframework.util.Assert;
* @since 5.0 * @since 5.0
* @see BodyInserters * @see BodyInserters
*/ */
@FunctionalInterface
public interface BodyInserter<T, M extends ReactiveHttpOutputMessage> { public interface BodyInserter<T, M extends ReactiveHttpOutputMessage> {
/** /**
* Insert into the given response. * Insert into the given output message.
* @param outputMessage the response to insert into * @param outputMessage the response to insert into
* @param context the context to use * @param context the context to use
* @return a {@code Mono} that indicates completion or error * @return a {@code Mono} that indicates completion or error
*/ */
Mono<Void> insert(M outputMessage, Context context); Mono<Void> insert(M outputMessage, Context context);
/**
* Return the type contained in the body.
* @return the type contained in the body
*/
T t();
/**
* Return a new {@code BodyInserter} described by the given writer and supplier functions.
* @param writer the writer function for the new inserter
* @param supplier the supplier function for the new inserter
* @param <T> the type supplied and written by the inserter
* @return the new {@code BodyInserter}
*/
static <T, M extends ReactiveHttpOutputMessage> BodyInserter<T, M> of(
BiFunction<M, Context, Mono<Void>> writer,
Supplier<T> supplier) {
Assert.notNull(writer, "'writer' must not be null");
Assert.notNull(supplier, "'supplier' must not be null");
return new BodyInserters.DefaultBodyInserter<T, M>(writer, supplier);
}
/** /**
* Defines the context used during the insertion. * Defines the context used during the insertion.
*/ */
......
...@@ -18,7 +18,6 @@ package org.springframework.http.codec; ...@@ -18,7 +18,6 @@ package org.springframework.http.codec;
import java.util.Collections; import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.function.BiFunction;
import java.util.function.Supplier; import java.util.function.Supplier;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import java.util.stream.Stream; import java.util.stream.Stream;
...@@ -48,14 +47,17 @@ public abstract class BodyInserters { ...@@ -48,14 +47,17 @@ public abstract class BodyInserters {
private static final ResolvableType SERVER_SIDE_EVENT_TYPE = private static final ResolvableType SERVER_SIDE_EVENT_TYPE =
ResolvableType.forClass(ServerSentEvent.class); ResolvableType.forClass(ServerSentEvent.class);
private static final BodyInserter<Void, ReactiveHttpOutputMessage> EMPTY =
(response, context) -> response.setComplete();
/** /**
* Return an empty {@code BodyInserter} that writes nothing. * Return an empty {@code BodyInserter} that writes nothing.
* @return an empty {@code BodyInserter} * @return an empty {@code BodyInserter}
*/ */
@SuppressWarnings("unchecked")
public static <T> BodyInserter<T, ReactiveHttpOutputMessage> empty() { public static <T> BodyInserter<T, ReactiveHttpOutputMessage> empty() {
return BodyInserter.of( return (BodyInserter<T, ReactiveHttpOutputMessage>)EMPTY;
(response, context) -> response.setComplete(),
() -> null);
} }
/** /**
...@@ -65,9 +67,7 @@ public abstract class BodyInserters { ...@@ -65,9 +67,7 @@ public abstract class BodyInserters {
*/ */
public static <T> BodyInserter<T, ReactiveHttpOutputMessage> fromObject(T body) { public static <T> BodyInserter<T, ReactiveHttpOutputMessage> fromObject(T body) {
Assert.notNull(body, "'body' must not be null"); Assert.notNull(body, "'body' must not be null");
return BodyInserter.of( return bodyInserterFor(Mono.just(body), ResolvableType.forInstance(body));
writeFunctionFor(Mono.just(body), ResolvableType.forInstance(body)),
() -> body);
} }
/** /**
...@@ -75,18 +75,15 @@ public abstract class BodyInserters { ...@@ -75,18 +75,15 @@ public abstract class BodyInserters {
* @param publisher the publisher to stream to the response body * @param publisher the publisher to stream to the response body
* @param elementClass the class of elements contained in the publisher * @param elementClass the class of elements contained in the publisher
* @param <T> the type of the elements contained in the publisher * @param <T> the type of the elements contained in the publisher
* @param <S> the type of the {@code Publisher} * @param <P> the type of the {@code Publisher}
* @return a {@code BodyInserter} that writes a {@code Publisher} * @return a {@code BodyInserter} that writes a {@code Publisher}
*/ */
public static <S extends Publisher<T>, T> BodyInserter<S, ReactiveHttpOutputMessage> fromPublisher(S publisher, public static <T, P extends Publisher<T>> BodyInserter<P, ReactiveHttpOutputMessage> fromPublisher(P publisher,
Class<T> elementClass) { Class<T> elementClass) {
Assert.notNull(publisher, "'publisher' must not be null"); Assert.notNull(publisher, "'publisher' must not be null");
Assert.notNull(elementClass, "'elementClass' must not be null"); Assert.notNull(elementClass, "'elementClass' must not be null");
return BodyInserter.of( return bodyInserterFor(publisher, ResolvableType.forClass(elementClass));
writeFunctionFor(publisher, ResolvableType.forClass(elementClass)),
() -> publisher
);
} }
/** /**
...@@ -94,18 +91,15 @@ public abstract class BodyInserters { ...@@ -94,18 +91,15 @@ public abstract class BodyInserters {
* @param publisher the publisher to stream to the response body * @param publisher the publisher to stream to the response body
* @param elementType the type of elements contained in the publisher * @param elementType the type of elements contained in the publisher
* @param <T> the type of the elements contained in the publisher * @param <T> the type of the elements contained in the publisher
* @param <S> the type of the {@code Publisher} * @param <P> the type of the {@code Publisher}
* @return a {@code BodyInserter} that writes a {@code Publisher} * @return a {@code BodyInserter} that writes a {@code Publisher}
*/ */
public static <S extends Publisher<T>, T> BodyInserter<S, ReactiveHttpOutputMessage> fromPublisher(S publisher, public static <T, P extends Publisher<T>> BodyInserter<P, ReactiveHttpOutputMessage> fromPublisher(P publisher,
ResolvableType elementType) { ResolvableType elementType) {
Assert.notNull(publisher, "'publisher' must not be null"); Assert.notNull(publisher, "'publisher' must not be null");
Assert.notNull(elementType, "'elementType' must not be null"); Assert.notNull(elementType, "'elementType' must not be null");
return BodyInserter.of( return bodyInserterFor(publisher, elementType);
writeFunctionFor(publisher, elementType),
() -> publisher
);
} }
/** /**
...@@ -119,14 +113,11 @@ public abstract class BodyInserters { ...@@ -119,14 +113,11 @@ public abstract class BodyInserters {
*/ */
public static <T extends Resource> BodyInserter<T, ReactiveHttpOutputMessage> fromResource(T resource) { public static <T extends Resource> BodyInserter<T, ReactiveHttpOutputMessage> fromResource(T resource) {
Assert.notNull(resource, "'resource' must not be null"); Assert.notNull(resource, "'resource' must not be null");
return BodyInserter.of( return (response, context) -> {
(response, context) -> {
HttpMessageWriter<Resource> messageWriter = resourceHttpMessageWriter(context); HttpMessageWriter<Resource> messageWriter = resourceHttpMessageWriter(context);
return messageWriter.write(Mono.just(resource), RESOURCE_TYPE, null, return messageWriter.write(Mono.just(resource), RESOURCE_TYPE, null,
response, Collections.emptyMap()); response, Collections.emptyMap());
}, };
() -> resource
);
} }
private static HttpMessageWriter<Resource> resourceHttpMessageWriter(BodyInserter.Context context) { private static HttpMessageWriter<Resource> resourceHttpMessageWriter(BodyInserter.Context context) {
...@@ -149,14 +140,11 @@ public abstract class BodyInserters { ...@@ -149,14 +140,11 @@ public abstract class BodyInserters {
S eventsPublisher) { S eventsPublisher) {
Assert.notNull(eventsPublisher, "'eventsPublisher' must not be null"); Assert.notNull(eventsPublisher, "'eventsPublisher' must not be null");
return BodyInserter.of( return (response, context) -> {
(response, context) -> {
HttpMessageWriter<ServerSentEvent<T>> messageWriter = sseMessageWriter(context); HttpMessageWriter<ServerSentEvent<T>> messageWriter = sseMessageWriter(context);
return messageWriter.write(eventsPublisher, SERVER_SIDE_EVENT_TYPE, return messageWriter.write(eventsPublisher, SERVER_SIDE_EVENT_TYPE,
MediaType.TEXT_EVENT_STREAM, response, Collections.emptyMap()); MediaType.TEXT_EVENT_STREAM, response, Collections.emptyMap());
}, };
() -> eventsPublisher
);
} }
/** /**
...@@ -192,15 +180,12 @@ public abstract class BodyInserters { ...@@ -192,15 +180,12 @@ public abstract class BodyInserters {
Assert.notNull(eventsPublisher, "'eventsPublisher' must not be null"); Assert.notNull(eventsPublisher, "'eventsPublisher' must not be null");
Assert.notNull(eventType, "'eventType' must not be null"); Assert.notNull(eventType, "'eventType' must not be null");
return BodyInserter.of( return (outputMessage, context) -> {
(outputMessage, context) -> {
HttpMessageWriter<T> messageWriter = sseMessageWriter(context); HttpMessageWriter<T> messageWriter = sseMessageWriter(context);
return messageWriter.write(eventsPublisher, eventType, return messageWriter.write(eventsPublisher, eventType,
MediaType.TEXT_EVENT_STREAM, outputMessage, Collections.emptyMap()); MediaType.TEXT_EVENT_STREAM, outputMessage, Collections.emptyMap());
}, };
() -> eventsPublisher
);
} }
/** /**
...@@ -214,10 +199,7 @@ public abstract class BodyInserters { ...@@ -214,10 +199,7 @@ public abstract class BodyInserters {
public static <T extends Publisher<DataBuffer>> BodyInserter<T, ReactiveHttpOutputMessage> fromDataBuffers(T publisher) { public static <T extends Publisher<DataBuffer>> BodyInserter<T, ReactiveHttpOutputMessage> fromDataBuffers(T publisher) {
Assert.notNull(publisher, "'publisher' must not be null"); Assert.notNull(publisher, "'publisher' must not be null");
return BodyInserter.of( return (outputMessage, context) -> outputMessage.writeWith(publisher);
(outputMessage, context) -> outputMessage.writeWith(publisher),
() -> publisher
);
} }
private static <T> HttpMessageWriter<T> sseMessageWriter(BodyInserter.Context context) { private static <T> HttpMessageWriter<T> sseMessageWriter(BodyInserter.Context context) {
...@@ -231,8 +213,7 @@ public abstract class BodyInserters { ...@@ -231,8 +213,7 @@ public abstract class BodyInserters {
MediaType.TEXT_EVENT_STREAM_VALUE)); MediaType.TEXT_EVENT_STREAM_VALUE));
} }
private static <T, M extends ReactiveHttpOutputMessage> BiFunction<M, BodyInserter.Context, Mono<Void>> private static <T, P extends Publisher<?>, M extends ReactiveHttpOutputMessage> BodyInserter<T, M> bodyInserterFor(P body, ResolvableType bodyType) {
writeFunctionFor(Publisher<T> body, ResolvableType bodyType) {
return (m, context) -> { return (m, context) -> {
...@@ -261,31 +242,5 @@ public abstract class BodyInserters { ...@@ -261,31 +242,5 @@ public abstract class BodyInserters {
return (HttpMessageWriter<T>) messageWriter; return (HttpMessageWriter<T>) messageWriter;
} }
static class DefaultBodyInserter<T, M extends ReactiveHttpOutputMessage>
implements BodyInserter<T, M> {
private final BiFunction<M, Context, Mono<Void>> writer;
private final Supplier<T> supplier;
public DefaultBodyInserter(
BiFunction<M, Context, Mono<Void>> writer,
Supplier<T> supplier) {
this.writer = writer;
this.supplier = supplier;
}
@Override
public Mono<Void> insert(M outputMessage, Context context) {
return this.writer.apply(outputMessage, context);
}
@Override
public T t() {
return this.supplier.get();
}
}
} }
...@@ -67,11 +67,6 @@ public interface ClientRequest<T> { ...@@ -67,11 +67,6 @@ public interface ClientRequest<T> {
*/ */
MultiValueMap<String, String> cookies(); MultiValueMap<String, String> cookies();
/**
* Return the body of this request.
*/
T body();
/** /**
* Return the body inserter of this request. * Return the body inserter of this request.
*/ */
......
...@@ -121,9 +121,7 @@ class DefaultClientRequestBuilder implements ClientRequest.BodyBuilder { ...@@ -121,9 +121,7 @@ class DefaultClientRequestBuilder implements ClientRequest.BodyBuilder {
@Override @Override
public ClientRequest<Void> build() { public ClientRequest<Void> build() {
return body(BodyInserter.of( return body(BodyInserters.empty());
(response, configuration) -> response.setComplete(),
() -> null));
} }
@Override @Override
...@@ -192,11 +190,6 @@ class DefaultClientRequestBuilder implements ClientRequest.BodyBuilder { ...@@ -192,11 +190,6 @@ class DefaultClientRequestBuilder implements ClientRequest.BodyBuilder {
return this.cookies; return this.cookies;
} }
@Override
public T body() {
return this.inserter.t();
}
@Override @Override
public BodyInserter<T, ? super ClientHttpRequest> inserter() { public BodyInserter<T, ? super ClientHttpRequest> inserter() {
return this.inserter; return this.inserter;
......
...@@ -46,7 +46,6 @@ import org.springframework.mock.http.server.reactive.test.MockServerHttpResponse ...@@ -46,7 +46,6 @@ import org.springframework.mock.http.server.reactive.test.MockServerHttpResponse
import static java.nio.charset.StandardCharsets.UTF_8; import static java.nio.charset.StandardCharsets.UTF_8;
import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
/** /**
* @author Arjen Poutsma * @author Arjen Poutsma
...@@ -83,8 +82,6 @@ public class BodyInsertersTests { ...@@ -83,8 +82,6 @@ public class BodyInsertersTests {
String body = "foo"; String body = "foo";
BodyInserter<String, ReactiveHttpOutputMessage> inserter = BodyInserters.fromObject(body); BodyInserter<String, ReactiveHttpOutputMessage> inserter = BodyInserters.fromObject(body);
assertEquals(body, inserter.t());
MockServerHttpResponse response = new MockServerHttpResponse(); MockServerHttpResponse response = new MockServerHttpResponse();
Mono<Void> result = inserter.insert(response, this.context); Mono<Void> result = inserter.insert(response, this.context);
StepVerifier.create(result).expectComplete().verify(); StepVerifier.create(result).expectComplete().verify();
...@@ -102,8 +99,6 @@ public class BodyInsertersTests { ...@@ -102,8 +99,6 @@ public class BodyInsertersTests {
Flux<String> body = Flux.just("foo"); Flux<String> body = Flux.just("foo");
BodyInserter<Flux<String>, ReactiveHttpOutputMessage> inserter = BodyInserters.fromPublisher(body, String.class); BodyInserter<Flux<String>, ReactiveHttpOutputMessage> inserter = BodyInserters.fromPublisher(body, String.class);
assertEquals(body, inserter.t());
MockServerHttpResponse response = new MockServerHttpResponse(); MockServerHttpResponse response = new MockServerHttpResponse();
Mono<Void> result = inserter.insert(response, this.context); Mono<Void> result = inserter.insert(response, this.context);
StepVerifier.create(result).expectComplete().verify(); StepVerifier.create(result).expectComplete().verify();
...@@ -121,8 +116,6 @@ public class BodyInsertersTests { ...@@ -121,8 +116,6 @@ public class BodyInsertersTests {
Resource body = new ClassPathResource("response.txt", getClass()); Resource body = new ClassPathResource("response.txt", getClass());
BodyInserter<Resource, ReactiveHttpOutputMessage> inserter = BodyInserters.fromResource(body); BodyInserter<Resource, ReactiveHttpOutputMessage> inserter = BodyInserters.fromResource(body);
assertEquals(body, inserter.t());
MockServerHttpResponse response = new MockServerHttpResponse(); MockServerHttpResponse response = new MockServerHttpResponse();
Mono<Void> result = inserter.insert(response, this.context); Mono<Void> result = inserter.insert(response, this.context);
StepVerifier.create(result).expectComplete().verify(); StepVerifier.create(result).expectComplete().verify();
...@@ -146,8 +139,6 @@ public class BodyInsertersTests { ...@@ -146,8 +139,6 @@ public class BodyInsertersTests {
BodyInserter<Flux<ServerSentEvent<String>>, ServerHttpResponse> inserter = BodyInserter<Flux<ServerSentEvent<String>>, ServerHttpResponse> inserter =
BodyInserters.fromServerSentEvents(body); BodyInserters.fromServerSentEvents(body);
assertEquals(body, inserter.t());
MockServerHttpResponse response = new MockServerHttpResponse(); MockServerHttpResponse response = new MockServerHttpResponse();
Mono<Void> result = inserter.insert(response, this.context); Mono<Void> result = inserter.insert(response, this.context);
StepVerifier.create(result).expectNextCount(0).expectComplete().verify(); StepVerifier.create(result).expectNextCount(0).expectComplete().verify();
...@@ -159,8 +150,6 @@ public class BodyInsertersTests { ...@@ -159,8 +150,6 @@ public class BodyInsertersTests {
BodyInserter<Flux<String>, ServerHttpResponse> inserter = BodyInserter<Flux<String>, ServerHttpResponse> inserter =
BodyInserters.fromServerSentEvents(body, String.class); BodyInserters.fromServerSentEvents(body, String.class);
assertEquals(body, inserter.t());
MockServerHttpResponse response = new MockServerHttpResponse(); MockServerHttpResponse response = new MockServerHttpResponse();
Mono<Void> result = inserter.insert(response, this.context); Mono<Void> result = inserter.insert(response, this.context);
StepVerifier.create(result).expectNextCount(0).expectComplete().verify(); StepVerifier.create(result).expectNextCount(0).expectComplete().verify();
...@@ -175,8 +164,6 @@ public class BodyInsertersTests { ...@@ -175,8 +164,6 @@ public class BodyInsertersTests {
BodyInserter<Flux<DataBuffer>, ReactiveHttpOutputMessage> inserter = BodyInserters.fromDataBuffers(body); BodyInserter<Flux<DataBuffer>, ReactiveHttpOutputMessage> inserter = BodyInserters.fromDataBuffers(body);
assertEquals(body, inserter.t());
MockServerHttpResponse response = new MockServerHttpResponse(); MockServerHttpResponse response = new MockServerHttpResponse();
Mono<Void> result = inserter.insert(response, this.context); Mono<Void> result = inserter.insert(response, this.context);
StepVerifier.create(result).expectComplete().verify(); StepVerifier.create(result).expectComplete().verify();
......
...@@ -24,8 +24,6 @@ import java.util.ArrayList; ...@@ -24,8 +24,6 @@ import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collections; import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.function.BiFunction;
import java.util.function.Supplier;
import org.junit.Test; import org.junit.Test;
import reactor.core.publisher.Mono; import reactor.core.publisher.Mono;
...@@ -186,8 +184,7 @@ public class DefaultClientRequestBuilderTests { ...@@ -186,8 +184,7 @@ public class DefaultClientRequestBuilderTests {
@Test @Test
public void bodyInserter() throws Exception { public void bodyInserter() throws Exception {
String body = "foo"; String body = "foo";
Supplier<String> supplier = () -> body; BodyInserter<String, ClientHttpRequest> inserter =
BiFunction<ClientHttpRequest, BodyInserter.Context, Mono<Void>> writer =
(response, strategies) -> { (response, strategies) -> {
byte[] bodyBytes = body.getBytes(UTF_8); byte[] bodyBytes = body.getBytes(UTF_8);
ByteBuffer byteBuffer = ByteBuffer.wrap(bodyBytes); ByteBuffer byteBuffer = ByteBuffer.wrap(bodyBytes);
...@@ -197,8 +194,7 @@ public class DefaultClientRequestBuilderTests { ...@@ -197,8 +194,7 @@ public class DefaultClientRequestBuilderTests {
}; };
ClientRequest<String> result = ClientRequest.POST("http://example.com") ClientRequest<String> result = ClientRequest.POST("http://example.com")
.body(BodyInserter.of(writer, supplier)); .body(inserter);
assertEquals(body, result.body());
MockClientHttpRequest request = new MockClientHttpRequest(); MockClientHttpRequest request = new MockClientHttpRequest();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册