提交 34eb6d54 编写于 作者: R Rossen Stoyanchev

Add support for @ExceptionHandler methods

上级 8c89b478
......@@ -25,7 +25,6 @@ import reactor.io.net.http.model.Status;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpStatus;
import org.springframework.http.server.reactive.ServerHttpResponse;
import org.springframework.util.Assert;
/**
......@@ -67,9 +66,11 @@ public class ReactorServerHttpResponse implements ServerHttpResponse {
}
@Override
public Publisher<Void> setBody(Publisher<ByteBuffer> contentPublisher) {
applyHeaders();
return this.channel.writeWith(Publishers.map(contentPublisher, Buffer::new));
public Publisher<Void> setBody(Publisher<ByteBuffer> publisher) {
return Publishers.lift(publisher, new WriteWithOperator<>(writePublisher -> {
applyHeaders();
return this.channel.writeWith(Publishers.map(writePublisher, Buffer::new));
}));
}
private void applyHeaders() {
......
......@@ -28,7 +28,6 @@ import rx.Observable;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.util.Assert;
/**
......
......@@ -28,7 +28,6 @@ import rx.Observable;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpStatus;
import org.springframework.http.server.reactive.ServerHttpResponse;
import org.springframework.util.Assert;
/**
......@@ -72,10 +71,12 @@ public class RxNettyServerHttpResponse implements ServerHttpResponse {
@Override
public Publisher<Void> setBody(Publisher<ByteBuffer> publisher) {
applyHeaders();
Observable<byte[]> observable = RxJava1Converter.from(publisher).map(
content -> new Buffer(content).asBytes());
return RxJava1Converter.from(this.response.writeBytes(observable));
return Publishers.lift(publisher, new WriteWithOperator<>(writePublisher -> {
applyHeaders();
Observable<byte[]> observable = RxJava1Converter.from(writePublisher)
.map(buffer -> new Buffer(buffer).asBytes());
return RxJava1Converter.from(this.response.writeBytes(observable));
}));
}
private void applyHeaders() {
......
......@@ -83,9 +83,11 @@ public class ServletServerHttpResponse implements ServerHttpResponse {
}
@Override
public Publisher<Void> setBody(final Publisher<ByteBuffer> contentPublisher) {
applyHeaders();
return (s -> contentPublisher.subscribe(subscriber));
public Publisher<Void> setBody(final Publisher<ByteBuffer> publisher) {
return Publishers.lift(publisher, new WriteWithOperator<>(writePublisher -> {
applyHeaders();
return (s -> writePublisher.subscribe(subscriber));
}));
}
private void applyHeaders() {
......
......@@ -38,6 +38,7 @@ import org.reactivestreams.Publisher;
import org.reactivestreams.Subscription;
import org.xnio.ChannelListener;
import org.xnio.channels.StreamSinkChannel;
import reactor.Publishers;
import reactor.core.subscriber.BaseSubscriber;
import static org.xnio.ChannelListeners.closingChannelExceptionHandler;
......@@ -74,13 +75,6 @@ public class UndertowServerHttpResponse implements ServerHttpResponse {
this.exchange.setStatusCode(status.value());
}
@Override
public Publisher<Void> setBody(Publisher<ByteBuffer> bodyPublisher) {
applyHeaders();
return (subscriber -> bodyPublisher.subscribe(bodySubscriber));
}
@Override
public HttpHeaders getHeaders() {
return (this.headersWritten ? HttpHeaders.readOnlyHttpHeaders(this.headers) : this.headers);
......@@ -112,6 +106,14 @@ public class UndertowServerHttpResponse implements ServerHttpResponse {
}
}
@Override
public Publisher<Void> setBody(Publisher<ByteBuffer> publisher) {
return Publishers.lift(publisher, new WriteWithOperator<>(writePublisher -> {
applyHeaders();
return (subscriber -> writePublisher.subscribe(bodySubscriber));
}));
}
private class ResponseBodySubscriber extends BaseSubscriber<ByteBuffer>
implements ChannelListener<StreamSinkChannel> {
......
......@@ -97,7 +97,7 @@ public class WriteWithOperator<T> implements Function<Subscriber<? super Void>,
else if (this.beforeFirstEmission) {
this.item = item;
this.beforeFirstEmission = false;
writeFunction.apply(this).subscribe(downstream());
writeFunction.apply(this).subscribe(new DownstreamBridge(downstream()));
}
else {
subscription.cancel();
......@@ -139,7 +139,7 @@ public class WriteWithOperator<T> implements Function<Subscriber<? super Void>,
else if (this.beforeFirstEmission) {
this.completed = true;
this.beforeFirstEmission = false;
writeFunction.apply(this).subscribe(downstream());
writeFunction.apply(this).subscribe(new DownstreamBridge(downstream()));
}
else {
this.completed = true;
......@@ -148,10 +148,10 @@ public class WriteWithOperator<T> implements Function<Subscriber<? super Void>,
}
@Override
public void subscribe(Subscriber<? super T> subscriber) {
public void subscribe(Subscriber<? super T> writeSubscriber) {
synchronized (this) {
Assert.isNull(this.writeSubscriber, "Only one writeSubscriber supported.");
this.writeSubscriber = subscriber;
this.writeSubscriber = writeSubscriber;
if (this.error != null || this.completed) {
this.writeSubscriber.onSubscribe(NO_OP_SUBSCRIPTION);
......@@ -184,7 +184,7 @@ public class WriteWithOperator<T> implements Function<Subscriber<? super Void>,
@Override
protected void doRequest(long n) {
if (this.readyToWrite) {
if (readyToWrite) {
super.doRequest(n);
return;
}
......@@ -204,6 +204,34 @@ public class WriteWithOperator<T> implements Function<Subscriber<? super Void>,
}
}
private class DownstreamBridge implements Subscriber<Void> {
private final Subscriber<? super Void> downstream;
public DownstreamBridge(Subscriber<? super Void> downstream) {
this.downstream = downstream;
}
@Override
public void onSubscribe(Subscription subscription) {
subscription.request(Long.MAX_VALUE);
}
@Override
public void onNext(Void aVoid) {
}
@Override
public void onError(Throwable ex) {
this.downstream.onError(ex);
}
@Override
public void onComplete() {
this.downstream.onComplete();
}
}
private final static Subscription NO_OP_SUBSCRIPTION = new Subscription() {
@Override
......
......@@ -129,8 +129,24 @@ public class DispatcherHandler implements HttpHandler, ApplicationContextAware {
});
Publisher<Void> completionPublisher = Publishers.concatMap(resultPublisher, result -> {
HandlerResultHandler handler = getResultHandler(result);
return handler.handleResult(request, response, result);
Publisher<Void> publisher;
if (result.hasError()) {
publisher = Publishers.error(result.getError());
}
else {
HandlerResultHandler handler = getResultHandler(result);
publisher = handler.handleResult(request, response, result);
}
if (result.hasExceptionMapper()) {
return Publishers.onErrorResumeNext(publisher, ex -> {
return Publishers.concatMap(result.getExceptionMapper().apply(ex),
errorResult -> {
HandlerResultHandler handler = getResultHandler(errorResult);
return handler.handleResult(request, response, errorResult);
});
});
}
return publisher;
});
return mapError(completionPublisher, this.errorMapper);
......
......@@ -16,11 +16,17 @@
package org.springframework.web.reactive;
import java.util.function.Function;
import java.util.logging.Handler;
import org.reactivestreams.Publisher;
import reactor.Publishers;
import org.springframework.core.ResolvableType;
import org.springframework.util.Assert;
/**
* Represent the result of the invocation of an handler.
* Represent the result of the invocation of a handler.
*
* @author Rossen Stoyanchev
*/
......@@ -32,6 +38,10 @@ public class HandlerResult {
private final ResolvableType resultType;
private final Throwable error;
private Function<Throwable, Publisher<HandlerResult>> exceptionMapper;
public HandlerResult(Object handler, Object result, ResolvableType resultType) {
Assert.notNull(handler, "'handler' is required");
......@@ -39,6 +49,16 @@ public class HandlerResult {
this.handler = handler;
this.result = result;
this.resultType = resultType;
this.error = null;
}
public HandlerResult(Object handler, Throwable error) {
Assert.notNull(handler, "'handler' is required");
Assert.notNull(error, "'error' is required");
this.handler = handler;
this.result = null;
this.resultType = null;
this.error = error;
}
......@@ -54,4 +74,38 @@ public class HandlerResult {
return this.resultType;
}
public Throwable getError() {
return this.error;
}
/**
* Whether handler invocation produced a result or failed with an error.
* <p>If {@code true} the {@link #getError()} returns the error while
* {@link #getResult()} and {@link #getResultType()} return {@code null}
* and vice versa.
* @return whether this instance contains a result or an error.
*/
public boolean hasError() {
return (this.error != null);
}
/**
* Configure a function for selecting an alternate {@code HandlerResult} in
* case of an {@link #hasError() error result} or in case of an async result
* that results in an error.
* @param function the exception resolving function
*/
public HandlerResult setExceptionMapper(Function<Throwable, Publisher<HandlerResult>> function) {
this.exceptionMapper = function;
return this;
}
public Function<Throwable, Publisher<HandlerResult>> getExceptionMapper() {
return this.exceptionMapper;
}
public boolean hasExceptionMapper() {
return (this.exceptionMapper != null);
}
}
......@@ -62,6 +62,10 @@ public class InvocableHandlerMethod extends HandlerMethod {
super(handlerMethod);
}
public InvocableHandlerMethod(Object bean, Method method) {
super(bean, method);
}
public void setHandlerMethodArgumentResolvers(List<HandlerMethodArgumentResolver> resolvers) {
this.resolvers.clear();
......@@ -75,9 +79,10 @@ public class InvocableHandlerMethod extends HandlerMethod {
/**
*
* @param request
* @param providedArgs
* Invoke the method and return a Publisher for the return value.
* @param request the current request
* @param providedArgs optional list of argument values to check by type
* (via {@code instanceof}) for resolving method arguments.
* @return Publisher that produces a single HandlerResult or an error signal;
* never throws an exception.
*/
......@@ -98,11 +103,8 @@ public class InvocableHandlerMethod extends HandlerMethod {
return Publishers.concatMap(argsPublisher, args -> {
try {
Object value = doInvoke(args);
HandlerMethod handlerMethod = InvocableHandlerMethod.this;
ResolvableType type = ResolvableType.forMethodParameter(handlerMethod.getReturnType());
HandlerResult handlerResult = new HandlerResult(handlerMethod, value, type);
ResolvableType type = ResolvableType.forMethodParameter(getReturnType());
HandlerResult handlerResult = new HandlerResult(this, value, type);
return Publishers.just(handlerResult);
}
catch (InvocationTargetException ex) {
......@@ -187,9 +189,8 @@ public class InvocableHandlerMethod extends HandlerMethod {
}
private static <E> Publisher<E> mapError(Publisher<E> source, Function<Throwable, Throwable> function) {
return Publishers.lift(source, null, (throwable, subscriber) -> {
subscriber.onError(function.apply(throwable));
}, null);
return Publishers.lift(source, null,
(throwable, subscriber) -> subscriber.onError(function.apply(throwable)), null);
}
}
......@@ -16,30 +16,35 @@
package org.springframework.web.reactive.method.annotation;
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.reactivestreams.Publisher;
import reactor.Publishers;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.core.ResolvableType;
import org.springframework.core.convert.ConversionService;
import org.springframework.core.convert.support.DefaultConversionService;
import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.http.server.reactive.ServerHttpResponse;
import org.springframework.core.codec.support.ByteBufferDecoder;
import org.springframework.core.codec.Decoder;
import org.springframework.core.codec.support.ByteBufferDecoder;
import org.springframework.core.codec.support.JacksonJsonDecoder;
import org.springframework.core.codec.support.JsonObjectDecoder;
import org.springframework.core.codec.support.StringDecoder;
import org.springframework.core.convert.ConversionService;
import org.springframework.core.convert.support.DefaultConversionService;
import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.http.server.reactive.ServerHttpResponse;
import org.springframework.util.ObjectUtils;
import org.springframework.web.method.HandlerMethod;
import org.springframework.web.method.annotation.ExceptionHandlerMethodResolver;
import org.springframework.web.reactive.HandlerAdapter;
import org.springframework.web.reactive.HandlerResult;
import org.springframework.web.reactive.method.HandlerMethodArgumentResolver;
import org.springframework.web.reactive.method.InvocableHandlerMethod;
import org.springframework.web.method.HandlerMethod;
/**
......@@ -47,16 +52,29 @@ import org.springframework.web.method.HandlerMethod;
*/
public class RequestMappingHandlerAdapter implements HandlerAdapter, InitializingBean {
private static Log logger = LogFactory.getLog(RequestMappingHandlerAdapter.class);
private final List<HandlerMethodArgumentResolver> argumentResolvers = new ArrayList<>();
private ConversionService conversionService = new DefaultConversionService();
private final Map<Class<?>, ExceptionHandlerMethodResolver> exceptionHandlerCache =
new ConcurrentHashMap<Class<?>, ExceptionHandlerMethodResolver>(64);
/**
* Configure the complete list of supported argument types thus overriding
* the resolvers that would otherwise be configured by default.
*/
public void setArgumentResolvers(List<HandlerMethodArgumentResolver> resolvers) {
this.argumentResolvers.clear();
this.argumentResolvers.addAll(resolvers);
}
/**
* Return the configured argument resolvers.
*/
public List<HandlerMethodArgumentResolver> getArgumentResolvers() {
return this.argumentResolvers;
}
......@@ -91,9 +109,59 @@ public class RequestMappingHandlerAdapter implements HandlerAdapter, Initializin
public Publisher<HandlerResult> handle(ServerHttpRequest request,
ServerHttpResponse response, Object handler) {
InvocableHandlerMethod handlerMethod = new InvocableHandlerMethod((HandlerMethod) handler);
handlerMethod.setHandlerMethodArgumentResolvers(this.argumentResolvers);
return handlerMethod.invokeForRequest(request);
HandlerMethod handlerMethod = (HandlerMethod) handler;
InvocableHandlerMethod invocable = new InvocableHandlerMethod(handlerMethod);
invocable.setHandlerMethodArgumentResolvers(this.argumentResolvers);
Publisher<HandlerResult> publisher = invocable.invokeForRequest(request);
publisher = Publishers.onErrorResumeNext(publisher, ex -> {
return Publishers.just(new HandlerResult(handler, ex));
});
publisher = Publishers.map(publisher,
result -> result.setExceptionMapper(
ex -> mapException((Exception) ex, handlerMethod, request, response)));
return publisher;
}
private Publisher<HandlerResult> mapException(Throwable ex, HandlerMethod handlerMethod,
ServerHttpRequest request, ServerHttpResponse response) {
if (ex instanceof Exception) {
InvocableHandlerMethod invocable = findExceptionHandler(handlerMethod, (Exception) ex);
if (invocable != null) {
try {
if (logger.isDebugEnabled()) {
logger.debug("Invoking @ExceptionHandler method: " + invocable);
}
invocable.setHandlerMethodArgumentResolvers(getArgumentResolvers());
return invocable.invokeForRequest(request, response, ex);
}
catch (Exception invocationEx) {
if (logger.isErrorEnabled()) {
logger.error("Failed to invoke @ExceptionHandler method: " + invocable, invocationEx);
}
}
}
}
return Publishers.error(ex);
}
protected InvocableHandlerMethod findExceptionHandler(HandlerMethod handlerMethod, Exception exception) {
if (handlerMethod == null) {
return null;
}
Class<?> handlerType = handlerMethod.getBeanType();
ExceptionHandlerMethodResolver resolver = this.exceptionHandlerCache.get(handlerType);
if (resolver == null) {
resolver = new ExceptionHandlerMethodResolver(handlerType);
this.exceptionHandlerCache.put(handlerType, resolver);
}
Method method = resolver.resolveMethod(exception);
return (method != null ? new InvocableHandlerMethod(handlerMethod.getBean(), method) : null);
}
}
\ No newline at end of file
......@@ -21,7 +21,6 @@ import org.junit.Before;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.springframework.http.server.reactive.HttpHandler;
import org.springframework.http.server.reactive.boot.HttpServer;
import org.springframework.http.server.reactive.boot.JettyHttpServer;
import org.springframework.http.server.reactive.boot.ReactorHttpServer;
......
......@@ -31,7 +31,7 @@ import reactor.Publishers;
import reactor.core.publisher.PublisherFactory;
import reactor.core.subscriber.SubscriberBarrier;
import reactor.rx.Streams;
import reactor.rx.action.Signal;
import reactor.rx.stream.Signal;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
......
......@@ -55,6 +55,7 @@ import org.springframework.http.ResponseEntity;
import org.springframework.http.server.reactive.AbstractHttpHandlerIntegrationTests;
import org.springframework.http.server.reactive.HttpHandler;
import org.springframework.stereotype.Controller;
import org.springframework.web.bind.annotation.ExceptionHandler;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestParam;
......@@ -122,6 +123,30 @@ public class RequestMappingIntegrationTests extends AbstractHttpHandlerIntegrati
assertEquals("Hello!", response.getBody());
}
@Test
public void handleWithThrownException() throws Exception {
RestTemplate restTemplate = new RestTemplate();
URI url = new URI("http://localhost:" + port + "/thrown-exception");
RequestEntity<Void> request = RequestEntity.get(url).build();
ResponseEntity<String> response = restTemplate.exchange(request, String.class);
assertEquals("Recovered from error: Boo", response.getBody());
}
@Test
public void handleWithErrorSignal() throws Exception {
RestTemplate restTemplate = new RestTemplate();
URI url = new URI("http://localhost:" + port + "/error-signal");
RequestEntity<Void> request = RequestEntity.get(url).build();
ResponseEntity<String> response = restTemplate.exchange(request, String.class);
assertEquals("Recovered from error: Boo", response.getBody());
}
@Test
public void serializeAsPojo() throws Exception {
serializeAsPojo("http://localhost:" + port + "/person");
......@@ -478,6 +503,24 @@ public class RequestMappingIntegrationTests extends AbstractHttpHandlerIntegrati
return personStream.toList().doOnNext(persons::addAll).flatMap(document -> Observable.empty());
}
@RequestMapping("/thrown-exception")
@ResponseBody
public Publisher<String> handleAndThrowException() {
throw new IllegalStateException("Boo");
}
@RequestMapping("/error-signal")
@ResponseBody
public Publisher<String> handleWithError() {
return Publishers.error(new IllegalStateException("Boo"));
}
@ExceptionHandler
@ResponseBody
public Publisher<String> handleException(IllegalStateException ex) {
return Streams.just("Recovered from error: " + ex.getMessage());
}
//TODO add mixed and T request mappings tests
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册