提交 6f029392 编写于 作者: B Brian Clozel

Move response status processing in InvocableHandlerMethod

Prior to this commit, WebFlux would look at the handler method
annotations (`@ResponseStatus`) for each handler execution, even calling
the expensive `synthesizeAnnotation`.

This commit moves this logic to the InvocableHandlerMethod so that this
executed once at instantiation time and for all result handlers.

Issue: SPR-15227
上级 ab50f7b0
......@@ -24,13 +24,10 @@ import java.util.Optional;
import java.util.Set;
import java.util.function.Supplier;
import org.springframework.core.MethodParameter;
import org.springframework.core.Ordered;
import org.springframework.core.ReactiveAdapterRegistry;
import org.springframework.core.annotation.AnnotationUtils;
import org.springframework.http.MediaType;
import org.springframework.util.Assert;
import org.springframework.web.bind.annotation.ResponseStatus;
import org.springframework.web.reactive.HandlerMapping;
import org.springframework.web.reactive.accept.RequestedContentTypeResolver;
import org.springframework.web.server.ServerWebExchange;
......@@ -158,17 +155,4 @@ public abstract class AbstractHandlerResultHandler implements Ordered {
return (comparator.compare(acceptable, producible) <= 0 ? acceptable : producible);
}
/**
* Optionally set the response status using the information provided by {@code @ResponseStatus}.
* @param methodParameter the controller method return parameter
* @param exchange the server exchange being handled
*/
protected void updateResponseStatus(MethodParameter methodParameter, ServerWebExchange exchange) {
ResponseStatus annotation = methodParameter.getMethodAnnotation(ResponseStatus.class);
if (annotation != null) {
annotation = AnnotationUtils.synthesizeAnnotation(annotation, methodParameter.getMethod());
exchange.getResponse().setStatusCode(annotation.code());
}
}
}
......@@ -32,9 +32,12 @@ import org.springframework.core.DefaultParameterNameDiscoverer;
import org.springframework.core.GenericTypeResolver;
import org.springframework.core.MethodParameter;
import org.springframework.core.ParameterNameDiscoverer;
import org.springframework.core.annotation.AnnotatedElementUtils;
import org.springframework.http.HttpStatus;
import org.springframework.util.ClassUtils;
import org.springframework.util.ObjectUtils;
import org.springframework.util.ReflectionUtils;
import org.springframework.web.bind.annotation.ResponseStatus;
import org.springframework.web.method.HandlerMethod;
import org.springframework.web.reactive.BindingContext;
import org.springframework.web.reactive.HandlerResult;
......@@ -54,6 +57,8 @@ public class InvocableHandlerMethod extends HandlerMethod {
private static final Object NO_ARG_VALUE = new Object();
private HttpStatus responseStatus;
private List<HandlerMethodArgumentResolver> resolvers = new ArrayList<>();
......@@ -62,12 +67,23 @@ public class InvocableHandlerMethod extends HandlerMethod {
public InvocableHandlerMethod(HandlerMethod handlerMethod) {
super(handlerMethod);
initResponseStatus();
}
public InvocableHandlerMethod(Object bean, Method method) {
super(bean, method);
initResponseStatus();
}
private void initResponseStatus() {
ResponseStatus annotation = getMethodAnnotation(ResponseStatus.class);
if (annotation == null) {
annotation = AnnotatedElementUtils.findMergedAnnotation(getBeanType(), ResponseStatus.class);
}
if (annotation != null) {
this.responseStatus = annotation.code();
}
}
/**
* Configure the argument resolvers to use to use for resolving method
......@@ -102,6 +118,9 @@ public class InvocableHandlerMethod extends HandlerMethod {
try {
Object value = doInvoke(args);
HandlerResult result = new HandlerResult(this, value, getReturnType(), bindingContext);
if (this.responseStatus != null) {
exchange.getResponse().setStatusCode(this.responseStatus);
}
return Mono.just(result);
}
catch (InvocationTargetException ex) {
......@@ -165,7 +184,7 @@ public class InvocableHandlerMethod extends HandlerMethod {
return resolver.resolveArgument(parameter, bindingContext, exchange)
.defaultIfEmpty(NO_ARG_VALUE)
.doOnError(cause -> {
if(logger.isDebugEnabled()) {
if (logger.isDebugEnabled()) {
logger.debug(getDetailedErrorMessage("Failed to resolve", parameter), cause);
}
});
......
......@@ -111,8 +111,7 @@ public abstract class AbstractMessageWriterResultHandler extends AbstractHandler
}
if (void.class == elementType.getRawClass() || Void.class == elementType.getRawClass()) {
return Mono.from((Publisher<Void>) publisher)
.doOnSubscribe(sub -> updateResponseStatus(bodyParameter, exchange));
return Mono.from((Publisher<Void>) publisher);
}
ServerHttpRequest request = exchange.getRequest();
......@@ -121,12 +120,11 @@ public abstract class AbstractMessageWriterResultHandler extends AbstractHandler
if (bestMediaType != null) {
for (HttpMessageWriter<?> messageWriter : getMessageWriters()) {
if (messageWriter.canWrite(elementType, bestMediaType)) {
Mono<Void> bodyWriter = (messageWriter instanceof ServerHttpMessageWriter ?
return (messageWriter instanceof ServerHttpMessageWriter ?
((ServerHttpMessageWriter<?>) messageWriter).write((Publisher) publisher,
valueType, elementType, bestMediaType, request, response, Collections.emptyMap()) :
messageWriter.write((Publisher) publisher, elementType,
bestMediaType, response, Collections.emptyMap()));
return bodyWriter.doOnSubscribe(sub -> updateResponseStatus(bodyParameter, exchange));
}
}
}
......
......@@ -208,8 +208,6 @@ public class ViewResolutionResultHandler extends AbstractHandlerResultHandler
.otherwiseIfEmpty(exchange.isNotModified() ? Mono.empty() : NO_VALUE_MONO)
.then(returnValue -> {
updateResponseStatus(result.getReturnTypeSource(), exchange);
Mono<List<View>> viewsMono;
Model model = result.getModel();
......
......@@ -24,8 +24,10 @@ import org.junit.Test;
import reactor.core.publisher.Mono;
import reactor.test.StepVerifier;
import org.springframework.http.HttpStatus;
import org.springframework.mock.http.server.reactive.test.MockServerHttpRequest;
import org.springframework.mock.http.server.reactive.test.MockServerHttpResponse;
import org.springframework.web.bind.annotation.ResponseStatus;
import org.springframework.web.reactive.BindingContext;
import org.springframework.web.reactive.HandlerResult;
import org.springframework.web.reactive.result.ResolvableMethod;
......@@ -147,6 +149,15 @@ public class InvocableHandlerMethodTests {
}
}
@Test
public void invokeMethodWithResponseStatus() throws Exception {
InvocableHandlerMethod hm = handlerMethod("responseStatus");
Mono<HandlerResult> mono = hm.invoke(this.exchange, new BindingContext());
assertHandlerResultValue(mono, "created");
assertThat(this.exchange.getResponse().getStatusCode(), is(HttpStatus.CREATED));
}
private InvocableHandlerMethod handlerMethod(String name) throws Exception {
TestController controller = new TestController();
......@@ -186,6 +197,11 @@ public class InvocableHandlerMethodTests {
public void exceptionMethod() {
throw new IllegalStateException("boo");
}
@ResponseStatus(HttpStatus.CREATED)
public String responseStatus() {
return "created";
}
}
}
......@@ -23,13 +23,11 @@ import java.util.List;
import org.junit.Before;
import org.junit.Test;
import reactor.core.publisher.Mono;
import reactor.test.StepVerifier;
import rx.Completable;
import rx.Single;
import org.springframework.core.codec.ByteBufferEncoder;
import org.springframework.core.codec.CharSequenceEncoder;
import org.springframework.http.HttpStatus;
import org.springframework.http.ResponseEntity;
import org.springframework.http.codec.EncoderHttpMessageWriter;
import org.springframework.http.codec.HttpMessageWriter;
......@@ -42,7 +40,6 @@ import org.springframework.mock.http.server.reactive.test.MockServerHttpResponse
import org.springframework.stereotype.Controller;
import org.springframework.util.ObjectUtils;
import org.springframework.web.bind.annotation.ResponseBody;
import org.springframework.web.bind.annotation.ResponseStatus;
import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.method.HandlerMethod;
import org.springframework.web.reactive.HandlerResult;
......@@ -119,24 +116,6 @@ public class ResponseBodyResultHandlerTests {
testSupports(controller, "handleToMonoResponseEntity", false);
}
@Test
public void writeResponseStatus() throws NoSuchMethodException {
Object controller = new TestRestController();
HandlerMethod hm = handlerMethod(controller, "handleToString");
HandlerResult handlerResult = new HandlerResult(hm, null, hm.getReturnType());
initExchange();
StepVerifier.create(this.resultHandler.handleResult(this.exchange, handlerResult)).expectComplete().verify();
assertEquals(HttpStatus.NO_CONTENT, this.response.getStatusCode());
hm = handlerMethod(controller, "handleToMonoVoid");
handlerResult = new HandlerResult(hm, null, hm.getReturnType());
initExchange();
StepVerifier.create(this.resultHandler.handleResult(this.exchange, handlerResult)).expectComplete().verify();
assertEquals(HttpStatus.CREATED, this.response.getStatusCode());
}
private void testSupports(Object controller, String method, boolean result) throws NoSuchMethodException {
HandlerMethod hm = handlerMethod(controller, method);
HandlerResult handlerResult = new HandlerResult(hm, null, hm.getReturnType());
......@@ -157,10 +136,8 @@ public class ResponseBodyResultHandlerTests {
@RestController @SuppressWarnings("unused")
private static class TestRestController {
@ResponseStatus(code = HttpStatus.CREATED)
public Mono<Void> handleToMonoVoid() { return null;}
@ResponseStatus(code = HttpStatus.NO_CONTENT)
public String handleToString() {
return null;
}
......
......@@ -41,7 +41,6 @@ import org.springframework.core.Ordered;
import org.springframework.core.ResolvableType;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.core.io.buffer.DefaultDataBufferFactory;
import org.springframework.http.HttpStatus;
import org.springframework.http.MediaType;
import org.springframework.http.server.reactive.ServerHttpResponse;
import org.springframework.mock.http.server.reactive.test.MockServerHttpRequest;
......@@ -49,7 +48,6 @@ import org.springframework.mock.http.server.reactive.test.MockServerHttpResponse
import org.springframework.ui.ConcurrentModel;
import org.springframework.ui.Model;
import org.springframework.web.bind.annotation.ModelAttribute;
import org.springframework.web.bind.annotation.ResponseStatus;
import org.springframework.web.reactive.BindingContext;
import org.springframework.web.reactive.HandlerResult;
import org.springframework.web.reactive.accept.HeaderContentTypeResolver;
......@@ -131,23 +129,19 @@ public class ViewResolutionResultHandlerTests {
returnType = forClass(View.class);
returnValue = new TestView("account");
ServerWebExchange exchange = testHandle("/path", returnType, returnValue, "account: {id=123}");
assertEquals(HttpStatus.NO_CONTENT, exchange.getResponse().getStatusCode());
testHandle("/path", returnType, returnValue, "account: {id=123}");
returnType = forClassWithGenerics(Mono.class, View.class);
returnValue = Mono.just(new TestView("account"));
exchange = testHandle("/path", returnType, returnValue, "account: {id=123}");
assertEquals(HttpStatus.SEE_OTHER, exchange.getResponse().getStatusCode());
testHandle("/path", returnType, returnValue, "account: {id=123}");
returnType = forClass(String.class);
returnValue = "account";
exchange = testHandle("/path", returnType, returnValue, "account: {id=123}", resolver);
assertEquals(HttpStatus.CREATED, exchange.getResponse().getStatusCode());
testHandle("/path", returnType, returnValue, "account: {id=123}", resolver);
returnType = forClassWithGenerics(Mono.class, String.class);
returnValue = Mono.just("account");
exchange = testHandle("/path", returnType, returnValue, "account: {id=123}", resolver);
assertEquals(HttpStatus.PARTIAL_CONTENT, exchange.getResponse().getStatusCode());
testHandle("/path", returnType, returnValue, "account: {id=123}", resolver);
returnType = forClass(Model.class);
returnValue = new ConcurrentModel().addAttribute("name", "Joe");
......@@ -438,38 +432,61 @@ public class ViewResolutionResultHandlerTests {
@SuppressWarnings("unused")
private static class TestController {
@ResponseStatus(code = HttpStatus.CREATED)
String string() { return null; }
String string() {
return null;
}
@ResponseStatus(HttpStatus.NO_CONTENT)
View view() { return null; }
View view() {
return null;
}
@ResponseStatus(HttpStatus.PARTIAL_CONTENT)
Mono<String> monoString() { return null; }
Mono<String> monoString() {
return null;
}
@ResponseStatus(code = HttpStatus.SEE_OTHER)
Mono<View> monoView() { return null; }
Mono<View> monoView() {
return null;
}
Mono<Void> monoVoid() { return null; }
Mono<Void> monoVoid() {
return null;
}
void voidMethod() {}
void voidMethod() {
}
Single<String> singleString() { return null; }
Single<String> singleString() {
return null;
}
Single<View> singleView() { return null; }
Single<View> singleView() {
return null;
}
Completable completable() { return null; }
Completable completable() {
return null;
}
Model model() { return null; }
Model model() {
return null;
}
Map map() { return null; }
Map map() {
return null;
}
TestBean testBean() { return null; }
TestBean testBean() {
return null;
}
Integer integer() { return null; }
Integer integer() {
return null;
}
@ModelAttribute("num")
Long longAttribute() { return null; }
Long longAttribute() {
return null;
}
}
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册