提交 e556aacc 编写于 作者: A Arjen Poutsma

Use PathPattern.getPathRemaining in WebFlux fn

This commit uses the newly introduced
`PathPattern.getPathRemaining(String)` in the functional web framework.
With this change, all path predicates can be used for nested router
functions, so the `pathPrefix` predicate is no longer required and has
been removed.

Issue: SPR-15336
上级 b2459185
......@@ -16,6 +16,8 @@
package org.springframework.web.reactive.function.server;
import java.util.Optional;
import org.springframework.util.Assert;
/**
......@@ -47,20 +49,7 @@ public interface RequestPredicate {
*/
default RequestPredicate and(RequestPredicate other) {
Assert.notNull(other, "'other' must not be null");
return new RequestPredicate() {
@Override
public boolean test(ServerRequest t) {
return RequestPredicate.this.test(t) && other.test(t);
}
@Override
public ServerRequest nestRequest(ServerRequest request) {
return other.nestRequest(RequestPredicate.this.nestRequest(request));
}
@Override
public String toString() {
return String.format("(%s && %s)", RequestPredicate.this, other);
}
};
return new RequestPredicates.AndRequestPredicate(this, other);
}
/**
......@@ -80,41 +69,22 @@ public interface RequestPredicate {
*/
default RequestPredicate or(RequestPredicate other) {
Assert.notNull(other, "'other' must not be null");
return new RequestPredicate() {
@Override
public boolean test(ServerRequest t) {
return RequestPredicate.this.test(t) || other.test(t);
}
@Override
public ServerRequest nestRequest(ServerRequest request) {
if (RequestPredicate.this.test(request)) {
return RequestPredicate.this.nestRequest(request);
}
else if (other.test(request)) {
return other.nestRequest(request);
}
else {
throw new IllegalStateException("Neither " + RequestPredicate.this.toString() +
" nor " + other + "matches");
}
}
@Override
public String toString() {
return String.format("(%s || %s)", RequestPredicate.this, other);
}
};
return new RequestPredicates.OrRequestPredicate(this, other);
}
/**
* Transform the given request into a request used for a nested route. For instance,
* a path-based predicate can return a {@code ServerRequest} with a nested path.
* <p>The default implementation returns the given path.
* a path-based predicate can return a {@code ServerRequest} with a the path remaining after a
* match.
* <p>The default implementation returns an {@code Optional} wrapping the given path if
* {@link #test(ServerRequest)} evaluates to {@code true}; or {@link Optional#empty()} if it
* evaluates to {@code false}.
* @param request the request to be nested
* @return the nested request
* @see RouterFunctions#nest(RequestPredicate, RouterFunction)
*/
default ServerRequest nestRequest(ServerRequest request) {
return request;
default Optional<ServerRequest> nest(ServerRequest request) {
return test(request) ? Optional.of(request) : Optional.empty();
}
}
......@@ -20,6 +20,7 @@ import java.net.URI;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
......@@ -279,22 +280,6 @@ public abstract class RequestPredicates {
};
}
/**
* Return a {@code RequestPredicate} that tests the beginning of the request path against the
* given path pattern. This predicate is effectively identical to a
* {@linkplain #path(String) standard path predicate} with path {@code pathPrefixPattern + "/**"}.
* @param pathPrefixPattern the pattern to match against the start of the request path
* @return a predicate that matches if the given predicate matches against the beginning of
* the request's path
*/
public static RequestPredicate pathPrefix(String pathPrefixPattern) {
Assert.notNull(pathPrefixPattern, "'pathPrefixPattern' must not be null");
if (!pathPrefixPattern.endsWith("/**")) {
pathPrefixPattern += "/**";
}
return path(pathPrefixPattern);
}
/**
* Return a {@code RequestPredicate} that tests the request's query parameter of the given name
* against the given predicate.
......@@ -358,8 +343,7 @@ public abstract class RequestPredicates {
boolean match = this.pattern.matches(path);
traceMatch("Pattern", this.pattern.getPatternString(), path, match);
if (match) {
Map<String, String> uriTemplateVariables = this.pattern.matchAndExtract(path);
request.attributes().put(RouterFunctions.URI_TEMPLATE_VARIABLES_ATTRIBUTE, uriTemplateVariables);
mergeTemplateVariables(request);
return true;
}
else {
......@@ -368,16 +352,26 @@ public abstract class RequestPredicates {
}
@Override
public ServerRequest nestRequest(ServerRequest request) {
String requestPath = request.path();
String subPath = this.pattern.extractPathWithinPattern(requestPath);
if (!subPath.startsWith("/")) {
subPath = "/" + subPath;
}
if (requestPath.endsWith("/") && !subPath.endsWith("/")) {
subPath += "/";
public Optional<ServerRequest> nest(ServerRequest request) {
String remainingPath = this.pattern.getPathRemaining(request.path());
return Optional.ofNullable(remainingPath)
.map(path -> !path.startsWith("/") ? "/" + path : path)
.map(path -> {
// TODO: re-enable when SPR-15419 has been fixed.
// mergeTemplateVariables(request);
return new SubPathServerRequestWrapper(request, path);
});
}
private void mergeTemplateVariables(ServerRequest request) {
Map<String, String> newVariables = this.pattern.matchAndExtract(request.path());
if (!newVariables.isEmpty()) {
Map<String, String> oldVariables = request.pathVariables();
Map<String, String> variables = new LinkedHashMap<>(oldVariables);
variables.putAll(newVariables);
request.attributes().put(RouterFunctions.URI_TEMPLATE_VARIABLES_ATTRIBUTE,
Collections.unmodifiableMap(variables));
}
return new SubPathServerRequestWrapper(request, subPath);
}
@Override
......@@ -407,6 +401,64 @@ public abstract class RequestPredicates {
}
}
static class AndRequestPredicate implements RequestPredicate {
private final RequestPredicate left;
private final RequestPredicate right;
public AndRequestPredicate(RequestPredicate left, RequestPredicate right) {
this.left = left;
this.right = right;
}
@Override
public boolean test(ServerRequest t) {
return this.left.test(t) && this.right.test(t);
}
@Override
public Optional<ServerRequest> nest(ServerRequest request) {
return this.left.nest(request).flatMap(this.right::nest);
}
@Override
public String toString() {
return String.format("(%s && %s)", this.left, this.right);
}
}
static class OrRequestPredicate implements RequestPredicate {
private final RequestPredicate left;
private final RequestPredicate right;
public OrRequestPredicate(RequestPredicate left, RequestPredicate right) {
this.left = left;
this.right = right;
}
@Override
public boolean test(ServerRequest t) {
return this.left.test(t) || this.right.test(t);
}
@Override
public Optional<ServerRequest> nest(ServerRequest request) {
Optional<ServerRequest> leftResult = this.left.nest(request);
if (leftResult.isPresent()) {
return leftResult;
}
else {
return this.right.nest(request);
}
}
@Override
public String toString() {
return String.format("(%s || %s)", this.left, this.right);
}
}
private static class SubPathServerRequestWrapper implements ServerRequest {
......@@ -414,6 +466,7 @@ public abstract class RequestPredicates {
private final String subPath;
public SubPathServerRequestWrapper(ServerRequest request, String subPath) {
this.request = request;
this.subPath = subPath;
......
......@@ -28,7 +28,8 @@ import org.springframework.core.io.Resource;
import org.springframework.http.server.reactive.HttpHandler;
import org.springframework.util.Assert;
import org.springframework.web.reactive.HandlerMapping;
import org.springframework.web.reactive.function.server.support.*;
import org.springframework.web.reactive.function.server.support.HandlerFunctionAdapter;
import org.springframework.web.reactive.function.server.support.ServerResponseResultHandler;
import org.springframework.web.server.ResponseStatusException;
import org.springframework.web.server.ServerWebExchange;
import org.springframework.web.server.WebHandler;
......@@ -136,19 +137,17 @@ public abstract class RouterFunctions {
Assert.notNull(predicate, "'predicate' must not be null");
Assert.notNull(routerFunction, "'routerFunction' must not be null");
return request -> {
if (predicate.test(request)) {
if (logger.isDebugEnabled()) {
logger.debug(String.format("Nested predicate \"%s\" matches against \"%s\"",
predicate, request));
}
ServerRequest subRequest = predicate.nestRequest(request);
return routerFunction.route(subRequest);
}
else {
return Mono.empty();
}
};
return request -> predicate.nest(request)
.map(nestedRequest -> {
if (logger.isDebugEnabled()) {
logger.debug(
String.format("Nested predicate \"%s\" matches against \"%s\"",
predicate, request));
}
return routerFunction.route(nestedRequest);
}
)
.orElseGet(Mono::empty);
}
/**
......
......@@ -19,7 +19,6 @@ package org.springframework.web.reactive.function.server
import org.springframework.core.io.Resource
import org.springframework.http.HttpMethod
import org.springframework.http.MediaType
import org.springframework.web.reactive.function.server.RequestPredicates.*
import reactor.core.publisher.Mono
/**
......@@ -66,13 +65,13 @@ class RouterDsl {
val routes = mutableListOf<RouterFunction<ServerResponse>>()
infix fun RequestPredicate.and(other: String): RequestPredicate = this.and(pathPrefix(other))
infix fun RequestPredicate.and(other: String): RequestPredicate = this.and(path(other))
infix fun RequestPredicate.or(other: String): RequestPredicate = this.or(pathPrefix(other))
infix fun RequestPredicate.or(other: String): RequestPredicate = this.or(path(other))
infix fun String.and(other: RequestPredicate): RequestPredicate = pathPrefix(this).and(other)
infix fun String.and(other: RequestPredicate): RequestPredicate = path(this).and(other)
infix fun String.or(other: RequestPredicate): RequestPredicate = pathPrefix(this).or(other)
infix fun String.or(other: RequestPredicate): RequestPredicate = path(this).or(other)
infix fun RequestPredicate.and(other: RequestPredicate): RequestPredicate = this.and(other)
......@@ -85,7 +84,7 @@ class RouterDsl {
}
fun String.nest(r: Routes) {
routes += RouterFunctions.nest(pathPrefix(this), RouterDsl().apply(r).router())
routes += RouterFunctions.nest(path(this), RouterDsl().apply(r).router())
}
operator fun RequestPredicate.invoke(f: (ServerRequest) -> Mono<ServerResponse>) {
......
......@@ -18,6 +18,7 @@ package org.springframework.web.reactive.function.server;
import org.junit.Ignore;
import org.junit.Test;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import org.springframework.http.HttpStatus;
......@@ -25,9 +26,11 @@ import org.springframework.http.ResponseEntity;
import org.springframework.web.client.RestTemplate;
import static org.junit.Assert.*;
import static org.springframework.web.reactive.function.BodyInserters.*;
import static org.springframework.web.reactive.function.server.RequestPredicates.*;
import static org.springframework.web.reactive.function.server.RouterFunctions.*;
import static org.springframework.web.reactive.function.BodyInserters.fromObject;
import static org.springframework.web.reactive.function.server.RequestPredicates.GET;
import static org.springframework.web.reactive.function.server.RequestPredicates.path;
import static org.springframework.web.reactive.function.server.RouterFunctions.nest;
import static org.springframework.web.reactive.function.server.RouterFunctions.route;
/**
* @author Arjen Poutsma
......@@ -40,9 +43,12 @@ public class NestedRouteIntegrationTests extends AbstractRouterFunctionIntegrati
@Override
protected RouterFunction<?> routerFunction() {
NestedHandler nestedHandler = new NestedHandler();
return nest(pathPrefix("/foo"),
return nest(path("/foo/"),
route(GET("/bar"), nestedHandler::bar)
.andRoute(GET("/baz"), nestedHandler::baz));
.andRoute(GET("/baz"), nestedHandler::baz))
.andNest(GET("/{foo}"),
nest(GET("/{bar}"),
route(GET("/{baz}"), nestedHandler::variables)));
}
......@@ -56,7 +62,6 @@ public class NestedRouteIntegrationTests extends AbstractRouterFunctionIntegrati
}
@Test
@Ignore
public void baz() throws Exception {
ResponseEntity<String> result =
restTemplate.getForEntity("http://localhost:" + port + "/foo/baz", String.class);
......@@ -65,6 +70,16 @@ public class NestedRouteIntegrationTests extends AbstractRouterFunctionIntegrati
assertEquals("baz", result.getBody());
}
@Test
@Ignore("SPR-15419")
public void variables() throws Exception {
ResponseEntity<String> result =
restTemplate.getForEntity("http://localhost:" + port + "/1/2/3", String.class);
assertEquals(HttpStatus.OK, result.getStatusCode());
assertEquals("1-2-3", result.getBody());
}
private static class NestedHandler {
......@@ -75,6 +90,13 @@ public class NestedRouteIntegrationTests extends AbstractRouterFunctionIntegrati
public Mono<ServerResponse> baz(ServerRequest request) {
return ServerResponse.ok().body(fromObject("baz"));
}
public Mono<ServerResponse> variables(ServerRequest request) {
Flux<String> responseBody =
Flux.just(request.pathVariable("foo"), "-", request.pathVariable("bar"), "-",
request.pathVariable("baz"));
return ServerResponse.ok().body(responseBody, String.class);
}
}
}
......@@ -163,23 +163,6 @@ public class RequestPredicatesTests {
assertFalse(predicate.test(request));
}
@Test
public void pathPrefix() throws Exception {
RequestPredicate predicate = RequestPredicates.pathPrefix("/foo");
URI uri = URI.create("http://localhost/foo/bar");
MockServerRequest request = MockServerRequest.builder().uri(uri).build();
assertTrue(predicate.test(request));
uri = URI.create("http://localhost/foo");
request = MockServerRequest.builder().uri(uri).build();
assertTrue(predicate.test(request));
uri = URI.create("http://localhost/bar");
request = MockServerRequest.builder().uri(uri).build();
assertFalse(predicate.test(request));
}
@Test
public void queryParam() throws Exception {
MockServerRequest request = MockServerRequest.builder().queryParam("foo", "bar").build();
......
......@@ -16,6 +16,8 @@
package org.springframework.web.reactive.function.server;
import java.util.Optional;
import org.junit.Test;
import reactor.core.publisher.Mono;
import reactor.test.StepVerifier;
......@@ -82,7 +84,7 @@ public class RouterFunctionsTests {
MockServerRequest request = MockServerRequest.builder().build();
RequestPredicate requestPredicate = mock(RequestPredicate.class);
when(requestPredicate.test(request)).thenReturn(true);
when(requestPredicate.nest(request)).thenReturn(Optional.of(request));
RouterFunction<ServerResponse> result = RouterFunctions.nest(requestPredicate, routerFunction);
assertNotNull(result);
......@@ -101,7 +103,7 @@ public class RouterFunctionsTests {
MockServerRequest request = MockServerRequest.builder().build();
RequestPredicate requestPredicate = mock(RequestPredicate.class);
when(requestPredicate.test(request)).thenReturn(false);
when(requestPredicate.nest(request)).thenReturn(Optional.empty());
RouterFunction<ServerResponse> result = RouterFunctions.nest(requestPredicate, routerFunction);
assertNotNull(result);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册