提交 91e96d80 编写于 作者: A Arjen Poutsma

Improve RouterFunction builder

This commit improves the RouterFunctions.Builder based on conversations
had during the weekly team meeting.

Issue: SPR-16953
上级 22ccdb28
...@@ -26,7 +26,7 @@ import java.util.function.Supplier; ...@@ -26,7 +26,7 @@ import java.util.function.Supplier;
import reactor.core.publisher.Mono; import reactor.core.publisher.Mono;
import org.springframework.http.HttpMethod; import org.springframework.core.io.Resource;
import org.springframework.util.Assert; import org.springframework.util.Assert;
/** /**
...@@ -40,89 +40,111 @@ class RouterFunctionBuilder implements RouterFunctions.Builder { ...@@ -40,89 +40,111 @@ class RouterFunctionBuilder implements RouterFunctions.Builder {
private List<HandlerFilterFunction<ServerResponse, ServerResponse>> filterFunctions = new ArrayList<>(); private List<HandlerFilterFunction<ServerResponse, ServerResponse>> filterFunctions = new ArrayList<>();
@Override @Override
public RouterFunctions.Builder route(RequestPredicate predicate, public RouterFunctions.Builder add(RouterFunction<ServerResponse> routerFunction) {
Assert.notNull(routerFunction, "RouterFunction must not be null");
this.routerFunctions.add(routerFunction);
return this;
}
private RouterFunctions.Builder add(RequestPredicate predicate,
HandlerFunction<ServerResponse> handlerFunction) { HandlerFunction<ServerResponse> handlerFunction) {
this.routerFunctions.add(RouterFunctions.route(predicate, handlerFunction)); this.routerFunctions.add(RouterFunctions.route(predicate, handlerFunction));
return this; return this;
} }
@Override @Override
public RouterFunctions.Builder routeGet(HandlerFunction<ServerResponse> handlerFunction) { public RouterFunctions.Builder GET(String pattern, HandlerFunction<ServerResponse> handlerFunction) {
return route(RequestPredicates.method(HttpMethod.GET), handlerFunction); return add(RequestPredicates.GET(pattern), handlerFunction);
}
@Override
public RouterFunctions.Builder GET(String pattern, RequestPredicate predicate,
HandlerFunction<ServerResponse> handlerFunction) {
return add(RequestPredicates.GET(pattern).and(predicate), handlerFunction);
}
@Override
public RouterFunctions.Builder HEAD(String pattern, HandlerFunction<ServerResponse> handlerFunction) {
return add(RequestPredicates.HEAD(pattern), handlerFunction);
} }
@Override @Override
public RouterFunctions.Builder routeGet(String pattern, HandlerFunction<ServerResponse> handlerFunction) { public RouterFunctions.Builder HEAD(String pattern, RequestPredicate predicate,
return route(RequestPredicates.GET(pattern), handlerFunction); HandlerFunction<ServerResponse> handlerFunction) {
return add(RequestPredicates.HEAD(pattern).and(predicate), handlerFunction);
} }
@Override @Override
public RouterFunctions.Builder routeHead(HandlerFunction<ServerResponse> handlerFunction) { public RouterFunctions.Builder POST(String pattern, HandlerFunction<ServerResponse> handlerFunction) {
return route(RequestPredicates.method(HttpMethod.HEAD), handlerFunction); return add(RequestPredicates.POST(pattern), handlerFunction);
} }
@Override @Override
public RouterFunctions.Builder routeHead(String pattern, HandlerFunction<ServerResponse> handlerFunction) { public RouterFunctions.Builder POST(String pattern, RequestPredicate predicate,
return route(RequestPredicates.HEAD(pattern), handlerFunction); HandlerFunction<ServerResponse> handlerFunction) {
return add(RequestPredicates.POST(pattern).and(predicate), handlerFunction);
} }
@Override @Override
public RouterFunctions.Builder routePost(HandlerFunction<ServerResponse> handlerFunction) { public RouterFunctions.Builder PUT(String pattern, HandlerFunction<ServerResponse> handlerFunction) {
return route(RequestPredicates.method(HttpMethod.POST), handlerFunction); return add(RequestPredicates.PUT(pattern), handlerFunction);
} }
@Override @Override
public RouterFunctions.Builder routePost(String pattern, HandlerFunction<ServerResponse> handlerFunction) { public RouterFunctions.Builder PUT(String pattern, RequestPredicate predicate,
return route(RequestPredicates.POST(pattern), handlerFunction); HandlerFunction<ServerResponse> handlerFunction) {
return add(RequestPredicates.PUT(pattern).and(predicate), handlerFunction);
} }
@Override @Override
public RouterFunctions.Builder routePut(HandlerFunction<ServerResponse> handlerFunction) { public RouterFunctions.Builder PATCH(String pattern, HandlerFunction<ServerResponse> handlerFunction) {
return route(RequestPredicates.method(HttpMethod.PUT), handlerFunction); return add(RequestPredicates.PATCH(pattern), handlerFunction);
} }
@Override @Override
public RouterFunctions.Builder routePut(String pattern, HandlerFunction<ServerResponse> handlerFunction) { public RouterFunctions.Builder PATCH(String pattern, RequestPredicate predicate,
return route(RequestPredicates.PUT(pattern), handlerFunction); HandlerFunction<ServerResponse> handlerFunction) {
return add(RequestPredicates.PATCH(pattern).and(predicate), handlerFunction);
} }
@Override @Override
public RouterFunctions.Builder routePatch(HandlerFunction<ServerResponse> handlerFunction) { public RouterFunctions.Builder DELETE(String pattern, HandlerFunction<ServerResponse> handlerFunction) {
return route(RequestPredicates.method(HttpMethod.PATCH), handlerFunction); return add(RequestPredicates.DELETE(pattern), handlerFunction);
} }
@Override @Override
public RouterFunctions.Builder routePatch(String pattern, HandlerFunction<ServerResponse> handlerFunction) { public RouterFunctions.Builder DELETE(String pattern, RequestPredicate predicate,
return route(RequestPredicates.PATCH(pattern), handlerFunction); HandlerFunction<ServerResponse> handlerFunction) {
return add(RequestPredicates.DELETE(pattern).and(predicate), handlerFunction);
} }
@Override @Override
public RouterFunctions.Builder routeDelete(HandlerFunction<ServerResponse> handlerFunction) { public RouterFunctions.Builder OPTIONS(String pattern, HandlerFunction<ServerResponse> handlerFunction) {
return route(RequestPredicates.method(HttpMethod.DELETE), handlerFunction); return add(RequestPredicates.OPTIONS(pattern), handlerFunction);
} }
@Override @Override
public RouterFunctions.Builder routeDelete(String pattern, HandlerFunction<ServerResponse> handlerFunction) { public RouterFunctions.Builder OPTIONS(String pattern, RequestPredicate predicate,
return route(RequestPredicates.DELETE(pattern), handlerFunction); HandlerFunction<ServerResponse> handlerFunction) {
return add(RequestPredicates.OPTIONS(pattern).and(predicate), handlerFunction);
} }
@Override @Override
public RouterFunctions.Builder routeOptions(HandlerFunction<ServerResponse> handlerFunction) { public RouterFunctions.Builder resources(String pattern, Resource location) {
return route(RequestPredicates.method(HttpMethod.OPTIONS), handlerFunction); return add(RouterFunctions.resources(pattern, location));
} }
@Override @Override
public RouterFunctions.Builder routeOptions(String pattern, HandlerFunction<ServerResponse> handlerFunction) { public RouterFunctions.Builder resources(Function<ServerRequest, Mono<Resource>> lookupFunction) {
return route(RequestPredicates.OPTIONS(pattern), handlerFunction); return add(RouterFunctions.resources(lookupFunction));
} }
@Override @Override
public RouterFunctions.Builder nest(RequestPredicate predicate, public RouterFunctions.Builder nest(RequestPredicate predicate,
Consumer<RouterFunctions.Builder> builderConsumer) { Consumer<RouterFunctions.Builder> builderConsumer) {
Assert.notNull(builderConsumer, "'builderConsumer' must not be null"); Assert.notNull(builderConsumer, "Consumer must not be null");
RouterFunctionBuilder nestedBuilder = new RouterFunctionBuilder(); RouterFunctionBuilder nestedBuilder = new RouterFunctionBuilder();
builderConsumer.accept(nestedBuilder); builderConsumer.accept(nestedBuilder);
...@@ -136,7 +158,7 @@ class RouterFunctionBuilder implements RouterFunctions.Builder { ...@@ -136,7 +158,7 @@ class RouterFunctionBuilder implements RouterFunctions.Builder {
public RouterFunctions.Builder nest(RequestPredicate predicate, public RouterFunctions.Builder nest(RequestPredicate predicate,
Supplier<RouterFunction<ServerResponse>> routerFunctionSupplier) { Supplier<RouterFunction<ServerResponse>> routerFunctionSupplier) {
Assert.notNull(routerFunctionSupplier, "'routerFunctionSupplier' must not be null"); Assert.notNull(routerFunctionSupplier, "RouterFunction Supplier must not be null");
RouterFunction<ServerResponse> nestedRoute = routerFunctionSupplier.get(); RouterFunction<ServerResponse> nestedRoute = routerFunctionSupplier.get();
...@@ -145,57 +167,56 @@ class RouterFunctionBuilder implements RouterFunctions.Builder { ...@@ -145,57 +167,56 @@ class RouterFunctionBuilder implements RouterFunctions.Builder {
} }
@Override @Override
public RouterFunctions.Builder nestPath(String pattern, public RouterFunctions.Builder path(String pattern,
Consumer<RouterFunctions.Builder> builderConsumer) { Consumer<RouterFunctions.Builder> builderConsumer) {
return nest(RequestPredicates.path(pattern), builderConsumer); return nest(RequestPredicates.path(pattern), builderConsumer);
} }
@Override @Override
public RouterFunctions.Builder nestPath(String pattern, public RouterFunctions.Builder path(String pattern,
Supplier<RouterFunction<ServerResponse>> routerFunctionSupplier) { Supplier<RouterFunction<ServerResponse>> routerFunctionSupplier) {
return nest(RequestPredicates.path(pattern), routerFunctionSupplier); return nest(RequestPredicates.path(pattern), routerFunctionSupplier);
} }
@Override @Override
public RouterFunctions.Builder filter(HandlerFilterFunction<ServerResponse, ServerResponse> filterFunction) { public RouterFunctions.Builder filter(HandlerFilterFunction<ServerResponse, ServerResponse> filterFunction) {
Assert.notNull(filterFunction, "'filterFunction' must not be null"); Assert.notNull(filterFunction, "HandlerFilterFunction must not be null");
this.filterFunctions.add(filterFunction); this.filterFunctions.add(filterFunction);
return this; return this;
} }
@Override @Override
public RouterFunctions.Builder filterBefore( public RouterFunctions.Builder before(Function<ServerRequest, ServerRequest> requestProcessor) {
Function<ServerRequest, Mono<ServerRequest>> requestProcessor) { Assert.notNull(requestProcessor, "RequestProcessor must not be null");
return filter((request, next) -> next.handle(requestProcessor.apply(request)));
Assert.notNull(requestProcessor, "Function must not be null");
return filter((request, next) -> requestProcessor.apply(request).flatMap(next::handle));
} }
@Override @Override
public RouterFunctions.Builder filterAfter( public RouterFunctions.Builder after(
BiFunction<ServerRequest, ServerResponse, Mono<ServerResponse>> responseProcessor) { BiFunction<ServerRequest, ServerResponse, ServerResponse> responseProcessor) {
Assert.notNull(responseProcessor, "ResponseProcessor must not be null");
return filter((request, next) -> next.handle(request) return filter((request, next) -> next.handle(request)
.flatMap(serverResponse -> responseProcessor.apply(request, serverResponse))); .map(serverResponse -> responseProcessor.apply(request, serverResponse)));
} }
@Override @Override
public RouterFunctions.Builder filterException(Predicate<? super Throwable> predicate, public RouterFunctions.Builder onError(Predicate<? super Throwable> predicate,
BiFunction<? super Throwable, ServerRequest, Mono<ServerResponse>> responseProvider) { BiFunction<? super Throwable, ServerRequest, Mono<ServerResponse>> responseProvider) {
Assert.notNull(predicate, "'exceptionType' must not be null"); Assert.notNull(predicate, "Predicate must not be null");
Assert.notNull(responseProvider, "'fallback' must not be null"); Assert.notNull(responseProvider, "ResponseProvider must not be null");
return filter((request, next) -> next.handle(request) return filter((request, next) -> next.handle(request)
.onErrorResume(predicate, t -> responseProvider.apply(t, request))); .onErrorResume(predicate, t -> responseProvider.apply(t, request)));
} }
@Override @Override
public <T extends Throwable> RouterFunctions.Builder filterException( public <T extends Throwable> RouterFunctions.Builder onError(
Class<T> exceptionType, Class<T> exceptionType,
BiFunction<? super T, ServerRequest, Mono<ServerResponse>> responseProvider) { BiFunction<? super T, ServerRequest, Mono<ServerResponse>> responseProvider) {
Assert.notNull(exceptionType, "'exceptionType' must not be null"); Assert.notNull(exceptionType, "ExceptionType must not be null");
Assert.notNull(responseProvider, "'fallback' must not be null"); Assert.notNull(responseProvider, "ResponseProvider must not be null");
return filter((request, next) -> next.handle(request) return filter((request, next) -> next.handle(request)
.onErrorResume(exceptionType, t -> responseProvider.apply(t, request))); .onErrorResume(exceptionType, t -> responseProvider.apply(t, request)));
......
...@@ -23,8 +23,11 @@ import org.junit.Test; ...@@ -23,8 +23,11 @@ import org.junit.Test;
import reactor.core.publisher.Mono; import reactor.core.publisher.Mono;
import reactor.test.StepVerifier; import reactor.test.StepVerifier;
import org.springframework.core.io.ClassPathResource;
import org.springframework.core.io.Resource;
import org.springframework.http.HttpMethod; import org.springframework.http.HttpMethod;
import org.springframework.http.HttpStatus; import org.springframework.http.HttpStatus;
import org.springframework.http.MediaType;
import static org.junit.Assert.*; import static org.junit.Assert.*;
...@@ -35,10 +38,11 @@ public class RouterFunctionBuilderTests { ...@@ -35,10 +38,11 @@ public class RouterFunctionBuilderTests {
@Test @Test
public void route() { public void route() {
RouterFunction<ServerResponse> route = RouterFunctions.builder() RouterFunction<ServerResponse> route = RouterFunctions.route()
.routeGet("/foo", request -> ServerResponse.ok().build()) .GET("/foo", request -> ServerResponse.ok().build())
.routePost(request -> ServerResponse.noContent().build()) .POST("/", RequestPredicates.contentType(MediaType.TEXT_PLAIN), request -> ServerResponse.noContent().build())
.build(); .build();
System.out.println(route);
MockServerRequest fooRequest = MockServerRequest.builder(). MockServerRequest fooRequest = MockServerRequest.builder().
method(HttpMethod.GET). method(HttpMethod.GET).
...@@ -56,7 +60,8 @@ public class RouterFunctionBuilderTests { ...@@ -56,7 +60,8 @@ public class RouterFunctionBuilderTests {
MockServerRequest barRequest = MockServerRequest.builder(). MockServerRequest barRequest = MockServerRequest.builder().
method(HttpMethod.POST). method(HttpMethod.POST).
uri(URI.create("http://localhost")) uri(URI.create("http://localhost/"))
.header("Content-Type", "text/plain")
.build(); .build();
responseMono = route.route(barRequest) responseMono = route.route(barRequest)
...@@ -68,15 +73,65 @@ public class RouterFunctionBuilderTests { ...@@ -68,15 +73,65 @@ public class RouterFunctionBuilderTests {
.expectNext(204) .expectNext(204)
.verifyComplete(); .verifyComplete();
MockServerRequest invalidRequest = MockServerRequest.builder().
method(HttpMethod.POST).
uri(URI.create("http://localhost/"))
.build();
responseMono = route.route(invalidRequest)
.flatMap(handlerFunction -> handlerFunction.handle(invalidRequest))
.map(ServerResponse::statusCode)
.map(HttpStatus::value);
StepVerifier.create(responseMono)
.verifyComplete();
}
@Test
public void resources() {
Resource resource = new ClassPathResource("/org/springframework/web/reactive/function/server/");
assertTrue(resource.exists());
RouterFunction<ServerResponse> route = RouterFunctions.route()
.resources("/resources/**", resource)
.build();
MockServerRequest resourceRequest = MockServerRequest.builder().
method(HttpMethod.GET).
uri(URI.create("http://localhost/resources/response.txt"))
.build();
Mono<Integer> responseMono = route.route(resourceRequest)
.flatMap(handlerFunction -> handlerFunction.handle(resourceRequest))
.map(ServerResponse::statusCode)
.map(HttpStatus::value);
StepVerifier.create(responseMono)
.expectNext(200)
.verifyComplete();
MockServerRequest invalidRequest = MockServerRequest.builder().
method(HttpMethod.POST).
uri(URI.create("http://localhost/resources/foo.txt"))
.build();
responseMono = route.route(invalidRequest)
.flatMap(handlerFunction -> handlerFunction.handle(invalidRequest))
.map(ServerResponse::statusCode)
.map(HttpStatus::value);
StepVerifier.create(responseMono)
.verifyComplete();
} }
@Test @Test
public void nest() { public void nest() {
RouterFunction<?> route = RouterFunctions.builder() RouterFunction<?> route = RouterFunctions.route()
.nestPath("/foo", builder -> .path("/foo", builder ->
builder.nestPath("/bar", builder.path("/bar",
() -> RouterFunctions.builder() () -> RouterFunctions.route()
.routeGet("/baz", request -> ServerResponse.ok().build()) .GET("/baz", request -> ServerResponse.ok().build())
.build())) .build()))
.build(); .build();
...@@ -99,18 +154,18 @@ public class RouterFunctionBuilderTests { ...@@ -99,18 +154,18 @@ public class RouterFunctionBuilderTests {
public void filters() { public void filters() {
AtomicInteger filterCount = new AtomicInteger(); AtomicInteger filterCount = new AtomicInteger();
RouterFunction<?> route = RouterFunctions.builder() RouterFunction<?> route = RouterFunctions.route()
.routeGet("/foo", request -> ServerResponse.ok().build()) .GET("/foo", request -> ServerResponse.ok().build())
.routeGet("/bar", request -> Mono.error(new IllegalStateException())) .GET("/bar", request -> Mono.error(new IllegalStateException()))
.filterBefore(request -> { .before(request -> {
int count = filterCount.getAndIncrement(); int count = filterCount.getAndIncrement();
assertEquals(0, count); assertEquals(0, count);
return Mono.just(request); return request;
}) })
.filterAfter((request, response) -> { .after((request, response) -> {
int count = filterCount.getAndIncrement(); int count = filterCount.getAndIncrement();
assertEquals(3, count); assertEquals(3, count);
return Mono.just(response); return response;
}) })
.filter((request, next) -> { .filter((request, next) -> {
int count = filterCount.getAndIncrement(); int count = filterCount.getAndIncrement();
...@@ -120,7 +175,7 @@ public class RouterFunctionBuilderTests { ...@@ -120,7 +175,7 @@ public class RouterFunctionBuilderTests {
assertEquals(2, count); assertEquals(2, count);
return responseMono; return responseMono;
}) })
.filterException(IllegalStateException.class, (e, request) -> ServerResponse.status(HttpStatus.INTERNAL_SERVER_ERROR).build()) .onError(IllegalStateException.class, (e, request) -> ServerResponse.status(HttpStatus.INTERNAL_SERVER_ERROR).build())
.build(); .build();
MockServerRequest fooRequest = MockServerRequest.builder(). MockServerRequest fooRequest = MockServerRequest.builder().
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册