diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/result/AbstractHandlerResultHandler.java b/spring-webflux/src/main/java/org/springframework/web/reactive/result/AbstractHandlerResultHandler.java index 485d80d25f2560916a76a2a631e56d15a2b44056..56eb1b406236c38d4691fafcfc05c32a80027a7e 100644 --- a/spring-webflux/src/main/java/org/springframework/web/reactive/result/AbstractHandlerResultHandler.java +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/result/AbstractHandlerResultHandler.java @@ -24,13 +24,10 @@ import java.util.Optional; import java.util.Set; import java.util.function.Supplier; -import org.springframework.core.MethodParameter; import org.springframework.core.Ordered; import org.springframework.core.ReactiveAdapterRegistry; -import org.springframework.core.annotation.AnnotationUtils; import org.springframework.http.MediaType; import org.springframework.util.Assert; -import org.springframework.web.bind.annotation.ResponseStatus; import org.springframework.web.reactive.HandlerMapping; import org.springframework.web.reactive.accept.RequestedContentTypeResolver; import org.springframework.web.server.ServerWebExchange; @@ -158,17 +155,4 @@ public abstract class AbstractHandlerResultHandler implements Ordered { return (comparator.compare(acceptable, producible) <= 0 ? acceptable : producible); } - /** - * Optionally set the response status using the information provided by {@code @ResponseStatus}. - * @param methodParameter the controller method return parameter - * @param exchange the server exchange being handled - */ - protected void updateResponseStatus(MethodParameter methodParameter, ServerWebExchange exchange) { - ResponseStatus annotation = methodParameter.getMethodAnnotation(ResponseStatus.class); - if (annotation != null) { - annotation = AnnotationUtils.synthesizeAnnotation(annotation, methodParameter.getMethod()); - exchange.getResponse().setStatusCode(annotation.code()); - } - } - } diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/InvocableHandlerMethod.java b/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/InvocableHandlerMethod.java index 9fb496951e9c13522d0b1a74f057721681873de8..66dd274815a1259bab468af904066ae328397831 100644 --- a/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/InvocableHandlerMethod.java +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/InvocableHandlerMethod.java @@ -32,9 +32,12 @@ import org.springframework.core.DefaultParameterNameDiscoverer; import org.springframework.core.GenericTypeResolver; import org.springframework.core.MethodParameter; import org.springframework.core.ParameterNameDiscoverer; +import org.springframework.core.annotation.AnnotatedElementUtils; +import org.springframework.http.HttpStatus; import org.springframework.util.ClassUtils; import org.springframework.util.ObjectUtils; import org.springframework.util.ReflectionUtils; +import org.springframework.web.bind.annotation.ResponseStatus; import org.springframework.web.method.HandlerMethod; import org.springframework.web.reactive.BindingContext; import org.springframework.web.reactive.HandlerResult; @@ -54,6 +57,8 @@ public class InvocableHandlerMethod extends HandlerMethod { private static final Object NO_ARG_VALUE = new Object(); + private HttpStatus responseStatus; + private List resolvers = new ArrayList<>(); @@ -62,12 +67,23 @@ public class InvocableHandlerMethod extends HandlerMethod { public InvocableHandlerMethod(HandlerMethod handlerMethod) { super(handlerMethod); + initResponseStatus(); } public InvocableHandlerMethod(Object bean, Method method) { super(bean, method); + initResponseStatus(); } + private void initResponseStatus() { + ResponseStatus annotation = getMethodAnnotation(ResponseStatus.class); + if (annotation == null) { + annotation = AnnotatedElementUtils.findMergedAnnotation(getBeanType(), ResponseStatus.class); + } + if (annotation != null) { + this.responseStatus = annotation.code(); + } + } /** * Configure the argument resolvers to use to use for resolving method @@ -102,6 +118,9 @@ public class InvocableHandlerMethod extends HandlerMethod { try { Object value = doInvoke(args); HandlerResult result = new HandlerResult(this, value, getReturnType(), bindingContext); + if (this.responseStatus != null) { + exchange.getResponse().setStatusCode(this.responseStatus); + } return Mono.just(result); } catch (InvocationTargetException ex) { @@ -165,7 +184,7 @@ public class InvocableHandlerMethod extends HandlerMethod { return resolver.resolveArgument(parameter, bindingContext, exchange) .defaultIfEmpty(NO_ARG_VALUE) .doOnError(cause -> { - if(logger.isDebugEnabled()) { + if (logger.isDebugEnabled()) { logger.debug(getDetailedErrorMessage("Failed to resolve", parameter), cause); } }); diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/annotation/AbstractMessageWriterResultHandler.java b/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/annotation/AbstractMessageWriterResultHandler.java index 5aaa381a07b8261a01c114ab8d81341c2f1e2676..57784837e06020b8a09ef4aef83d66eb8f12ae79 100644 --- a/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/annotation/AbstractMessageWriterResultHandler.java +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/annotation/AbstractMessageWriterResultHandler.java @@ -111,8 +111,7 @@ public abstract class AbstractMessageWriterResultHandler extends AbstractHandler } if (void.class == elementType.getRawClass() || Void.class == elementType.getRawClass()) { - return Mono.from((Publisher) publisher) - .doOnSubscribe(sub -> updateResponseStatus(bodyParameter, exchange)); + return Mono.from((Publisher) publisher); } ServerHttpRequest request = exchange.getRequest(); @@ -121,12 +120,11 @@ public abstract class AbstractMessageWriterResultHandler extends AbstractHandler if (bestMediaType != null) { for (HttpMessageWriter messageWriter : getMessageWriters()) { if (messageWriter.canWrite(elementType, bestMediaType)) { - Mono bodyWriter = (messageWriter instanceof ServerHttpMessageWriter ? + return (messageWriter instanceof ServerHttpMessageWriter ? ((ServerHttpMessageWriter) messageWriter).write((Publisher) publisher, valueType, elementType, bestMediaType, request, response, Collections.emptyMap()) : messageWriter.write((Publisher) publisher, elementType, bestMediaType, response, Collections.emptyMap())); - return bodyWriter.doOnSubscribe(sub -> updateResponseStatus(bodyParameter, exchange)); } } } diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/result/view/ViewResolutionResultHandler.java b/spring-webflux/src/main/java/org/springframework/web/reactive/result/view/ViewResolutionResultHandler.java index b3d8e1857a1902ede337f5ff60854e8fd4ed842f..b1a4bac2af87b9b645caecd7d927c15647e34654 100644 --- a/spring-webflux/src/main/java/org/springframework/web/reactive/result/view/ViewResolutionResultHandler.java +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/result/view/ViewResolutionResultHandler.java @@ -208,8 +208,6 @@ public class ViewResolutionResultHandler extends AbstractHandlerResultHandler .otherwiseIfEmpty(exchange.isNotModified() ? Mono.empty() : NO_VALUE_MONO) .then(returnValue -> { - updateResponseStatus(result.getReturnTypeSource(), exchange); - Mono> viewsMono; Model model = result.getModel(); diff --git a/spring-webflux/src/test/java/org/springframework/web/reactive/result/method/InvocableHandlerMethodTests.java b/spring-webflux/src/test/java/org/springframework/web/reactive/result/method/InvocableHandlerMethodTests.java index a8a3c0c729b23b13f99370260aa7049ee82c0900..21828b3c4061f27051d6fb2bc91dfbf7ccbd617c 100644 --- a/spring-webflux/src/test/java/org/springframework/web/reactive/result/method/InvocableHandlerMethodTests.java +++ b/spring-webflux/src/test/java/org/springframework/web/reactive/result/method/InvocableHandlerMethodTests.java @@ -24,8 +24,10 @@ import org.junit.Test; import reactor.core.publisher.Mono; import reactor.test.StepVerifier; +import org.springframework.http.HttpStatus; import org.springframework.mock.http.server.reactive.test.MockServerHttpRequest; import org.springframework.mock.http.server.reactive.test.MockServerHttpResponse; +import org.springframework.web.bind.annotation.ResponseStatus; import org.springframework.web.reactive.BindingContext; import org.springframework.web.reactive.HandlerResult; import org.springframework.web.reactive.result.ResolvableMethod; @@ -147,6 +149,15 @@ public class InvocableHandlerMethodTests { } } + @Test + public void invokeMethodWithResponseStatus() throws Exception { + InvocableHandlerMethod hm = handlerMethod("responseStatus"); + Mono mono = hm.invoke(this.exchange, new BindingContext()); + + assertHandlerResultValue(mono, "created"); + assertThat(this.exchange.getResponse().getStatusCode(), is(HttpStatus.CREATED)); + } + private InvocableHandlerMethod handlerMethod(String name) throws Exception { TestController controller = new TestController(); @@ -186,6 +197,11 @@ public class InvocableHandlerMethodTests { public void exceptionMethod() { throw new IllegalStateException("boo"); } + + @ResponseStatus(HttpStatus.CREATED) + public String responseStatus() { + return "created"; + } } } diff --git a/spring-webflux/src/test/java/org/springframework/web/reactive/result/method/annotation/ResponseBodyResultHandlerTests.java b/spring-webflux/src/test/java/org/springframework/web/reactive/result/method/annotation/ResponseBodyResultHandlerTests.java index 3581f8a5bc83748212a671f12b20df0eb30f4992..51a80bd1d51a9c579cd2d3742d8a5cc0854dcf05 100644 --- a/spring-webflux/src/test/java/org/springframework/web/reactive/result/method/annotation/ResponseBodyResultHandlerTests.java +++ b/spring-webflux/src/test/java/org/springframework/web/reactive/result/method/annotation/ResponseBodyResultHandlerTests.java @@ -23,13 +23,11 @@ import java.util.List; import org.junit.Before; import org.junit.Test; import reactor.core.publisher.Mono; -import reactor.test.StepVerifier; import rx.Completable; import rx.Single; import org.springframework.core.codec.ByteBufferEncoder; import org.springframework.core.codec.CharSequenceEncoder; -import org.springframework.http.HttpStatus; import org.springframework.http.ResponseEntity; import org.springframework.http.codec.EncoderHttpMessageWriter; import org.springframework.http.codec.HttpMessageWriter; @@ -42,7 +40,6 @@ import org.springframework.mock.http.server.reactive.test.MockServerHttpResponse import org.springframework.stereotype.Controller; import org.springframework.util.ObjectUtils; import org.springframework.web.bind.annotation.ResponseBody; -import org.springframework.web.bind.annotation.ResponseStatus; import org.springframework.web.bind.annotation.RestController; import org.springframework.web.method.HandlerMethod; import org.springframework.web.reactive.HandlerResult; @@ -119,24 +116,6 @@ public class ResponseBodyResultHandlerTests { testSupports(controller, "handleToMonoResponseEntity", false); } - @Test - public void writeResponseStatus() throws NoSuchMethodException { - Object controller = new TestRestController(); - HandlerMethod hm = handlerMethod(controller, "handleToString"); - HandlerResult handlerResult = new HandlerResult(hm, null, hm.getReturnType()); - - initExchange(); - StepVerifier.create(this.resultHandler.handleResult(this.exchange, handlerResult)).expectComplete().verify(); - assertEquals(HttpStatus.NO_CONTENT, this.response.getStatusCode()); - - hm = handlerMethod(controller, "handleToMonoVoid"); - handlerResult = new HandlerResult(hm, null, hm.getReturnType()); - - initExchange(); - StepVerifier.create(this.resultHandler.handleResult(this.exchange, handlerResult)).expectComplete().verify(); - assertEquals(HttpStatus.CREATED, this.response.getStatusCode()); - } - private void testSupports(Object controller, String method, boolean result) throws NoSuchMethodException { HandlerMethod hm = handlerMethod(controller, method); HandlerResult handlerResult = new HandlerResult(hm, null, hm.getReturnType()); @@ -157,10 +136,8 @@ public class ResponseBodyResultHandlerTests { @RestController @SuppressWarnings("unused") private static class TestRestController { - @ResponseStatus(code = HttpStatus.CREATED) public Mono handleToMonoVoid() { return null;} - @ResponseStatus(code = HttpStatus.NO_CONTENT) public String handleToString() { return null; } diff --git a/spring-webflux/src/test/java/org/springframework/web/reactive/result/view/ViewResolutionResultHandlerTests.java b/spring-webflux/src/test/java/org/springframework/web/reactive/result/view/ViewResolutionResultHandlerTests.java index 60354364e00efa8a5762acbf5ddf94564f2b4a60..918c25885c265fbf7e875116f70c0e8926c19013 100644 --- a/spring-webflux/src/test/java/org/springframework/web/reactive/result/view/ViewResolutionResultHandlerTests.java +++ b/spring-webflux/src/test/java/org/springframework/web/reactive/result/view/ViewResolutionResultHandlerTests.java @@ -41,7 +41,6 @@ import org.springframework.core.Ordered; import org.springframework.core.ResolvableType; import org.springframework.core.io.buffer.DataBuffer; import org.springframework.core.io.buffer.DefaultDataBufferFactory; -import org.springframework.http.HttpStatus; import org.springframework.http.MediaType; import org.springframework.http.server.reactive.ServerHttpResponse; import org.springframework.mock.http.server.reactive.test.MockServerHttpRequest; @@ -49,7 +48,6 @@ import org.springframework.mock.http.server.reactive.test.MockServerHttpResponse import org.springframework.ui.ConcurrentModel; import org.springframework.ui.Model; import org.springframework.web.bind.annotation.ModelAttribute; -import org.springframework.web.bind.annotation.ResponseStatus; import org.springframework.web.reactive.BindingContext; import org.springframework.web.reactive.HandlerResult; import org.springframework.web.reactive.accept.HeaderContentTypeResolver; @@ -131,23 +129,19 @@ public class ViewResolutionResultHandlerTests { returnType = forClass(View.class); returnValue = new TestView("account"); - ServerWebExchange exchange = testHandle("/path", returnType, returnValue, "account: {id=123}"); - assertEquals(HttpStatus.NO_CONTENT, exchange.getResponse().getStatusCode()); + testHandle("/path", returnType, returnValue, "account: {id=123}"); returnType = forClassWithGenerics(Mono.class, View.class); returnValue = Mono.just(new TestView("account")); - exchange = testHandle("/path", returnType, returnValue, "account: {id=123}"); - assertEquals(HttpStatus.SEE_OTHER, exchange.getResponse().getStatusCode()); + testHandle("/path", returnType, returnValue, "account: {id=123}"); returnType = forClass(String.class); returnValue = "account"; - exchange = testHandle("/path", returnType, returnValue, "account: {id=123}", resolver); - assertEquals(HttpStatus.CREATED, exchange.getResponse().getStatusCode()); + testHandle("/path", returnType, returnValue, "account: {id=123}", resolver); returnType = forClassWithGenerics(Mono.class, String.class); returnValue = Mono.just("account"); - exchange = testHandle("/path", returnType, returnValue, "account: {id=123}", resolver); - assertEquals(HttpStatus.PARTIAL_CONTENT, exchange.getResponse().getStatusCode()); + testHandle("/path", returnType, returnValue, "account: {id=123}", resolver); returnType = forClass(Model.class); returnValue = new ConcurrentModel().addAttribute("name", "Joe"); @@ -438,38 +432,61 @@ public class ViewResolutionResultHandlerTests { @SuppressWarnings("unused") private static class TestController { - @ResponseStatus(code = HttpStatus.CREATED) - String string() { return null; } + String string() { + return null; + } - @ResponseStatus(HttpStatus.NO_CONTENT) - View view() { return null; } + View view() { + return null; + } - @ResponseStatus(HttpStatus.PARTIAL_CONTENT) - Mono monoString() { return null; } + Mono monoString() { + return null; + } - @ResponseStatus(code = HttpStatus.SEE_OTHER) - Mono monoView() { return null; } + Mono monoView() { + return null; + } - Mono monoVoid() { return null; } + Mono monoVoid() { + return null; + } - void voidMethod() {} + void voidMethod() { + } - Single singleString() { return null; } + Single singleString() { + return null; + } - Single singleView() { return null; } + Single singleView() { + return null; + } - Completable completable() { return null; } + Completable completable() { + return null; + } - Model model() { return null; } + Model model() { + return null; + } - Map map() { return null; } + Map map() { + return null; + } - TestBean testBean() { return null; } + TestBean testBean() { + return null; + } - Integer integer() { return null; } + Integer integer() { + return null; + } @ModelAttribute("num") - Long longAttribute() { return null; } + Long longAttribute() { + return null; + } } }