提交 ccb2c653 编写于 作者: R Rossen Stoyanchev

Support for @ControllerAdvice in WebFlux

Issue: SPR-15132
上级 24034447
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
*/ */
package org.springframework.test.web.reactive.server; package org.springframework.test.web.reactive.server;
import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.List; import java.util.List;
import java.util.Optional; import java.util.Optional;
...@@ -44,6 +45,8 @@ class DefaultControllerSpec implements WebTestClient.ControllerSpec { ...@@ -44,6 +45,8 @@ class DefaultControllerSpec implements WebTestClient.ControllerSpec {
private final List<Object> controllers; private final List<Object> controllers;
private final List<Object> controllerAdvice = new ArrayList<>(8);
private final TestWebFluxConfigurer configurer = new TestWebFluxConfigurer(); private final TestWebFluxConfigurer configurer = new TestWebFluxConfigurer();
...@@ -53,6 +56,12 @@ class DefaultControllerSpec implements WebTestClient.ControllerSpec { ...@@ -53,6 +56,12 @@ class DefaultControllerSpec implements WebTestClient.ControllerSpec {
} }
@Override
public DefaultControllerSpec controllerAdvice(Object... controllerAdvice) {
this.controllerAdvice.addAll(Arrays.asList(controllerAdvice));
return this;
}
@Override @Override
public DefaultControllerSpec contentTypeResolver(Consumer<RequestedContentTypeResolverBuilder> consumer) { public DefaultControllerSpec contentTypeResolver(Consumer<RequestedContentTypeResolverBuilder> consumer) {
this.configurer.contentTypeResolverConsumer = consumer; this.configurer.contentTypeResolverConsumer = consumer;
...@@ -103,12 +112,17 @@ class DefaultControllerSpec implements WebTestClient.ControllerSpec { ...@@ -103,12 +112,17 @@ class DefaultControllerSpec implements WebTestClient.ControllerSpec {
@Override @Override
public WebTestClient.Builder configureClient() { public WebTestClient.Builder configureClient() {
return WebTestClient.bindToApplicationContext(createApplicationContext());
}
protected AnnotationConfigApplicationContext createApplicationContext() {
AnnotationConfigApplicationContext context = new AnnotationConfigApplicationContext(); AnnotationConfigApplicationContext context = new AnnotationConfigApplicationContext();
this.controllers.forEach(controller -> registerBean(context, controller)); this.controllers.forEach(controller -> registerBean(context, controller));
this.controllerAdvice.forEach(advice -> registerBean(context, advice));
context.register(DelegatingWebFluxConfiguration.class); context.register(DelegatingWebFluxConfiguration.class);
context.registerBean(WebFluxConfigurer.class, () -> this.configurer); context.registerBean(WebFluxConfigurer.class, () -> this.configurer);
context.refresh(); context.refresh();
return WebTestClient.bindToApplicationContext(context); return context;
} }
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
......
...@@ -179,6 +179,13 @@ public interface WebTestClient { ...@@ -179,6 +179,13 @@ public interface WebTestClient {
*/ */
interface ControllerSpec { interface ControllerSpec {
/**
* Register one or more
* {@link org.springframework.web.bind.annotation.ControllerAdvice
* ControllerAdvice} instances to be used in tests.
*/
ControllerSpec controllerAdvice(Object... controllerAdvice);
/** /**
* Customize content type resolution. * Customize content type resolution.
* @see WebFluxConfigurer#configureContentTypeResolver * @see WebFluxConfigurer#configureContentTypeResolver
......
/*
* Copyright 2002-2017 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.test.web.reactive.server;
import org.junit.Test;
import org.springframework.context.ApplicationContext;
import org.springframework.context.annotation.AnnotationConfigApplicationContext;
import static org.junit.Assert.assertSame;
/**
* Unit tests for {@link DefaultControllerSpec}.
* @author Rossen Stoyanchev
*/
public class DefaultControllerSpecTests {
@Test
public void controllers() throws Exception {
OneController controller1 = new OneController();
SecondController controller2 = new SecondController();
TestControllerSpec spec = new TestControllerSpec(controller1, controller2);
ApplicationContext context = spec.createApplicationContext();
assertSame(controller1, context.getBean(OneController.class));
assertSame(controller2, context.getBean(SecondController.class));
}
@Test
public void controllerAdvice() throws Exception {
OneControllerAdvice advice = new OneControllerAdvice();
TestControllerSpec spec = new TestControllerSpec(new OneController());
spec.controllerAdvice(advice);
ApplicationContext context = spec.createApplicationContext();
assertSame(advice, context.getBean(OneControllerAdvice.class));
}
private static class OneController {}
private static class SecondController {}
private static class OneControllerAdvice {}
private static class TestControllerSpec extends DefaultControllerSpec {
TestControllerSpec(Object... controllers) {
super(controllers);
}
@Override
public AnnotationConfigApplicationContext createApplicationContext() {
return super.createApplicationContext();
}
}
}
...@@ -125,6 +125,17 @@ public class ExceptionHandlerMethodResolver { ...@@ -125,6 +125,17 @@ public class ExceptionHandlerMethodResolver {
* @return a Method to handle the exception, or {@code null} if none found * @return a Method to handle the exception, or {@code null} if none found
*/ */
public Method resolveMethod(Exception exception) { public Method resolveMethod(Exception exception) {
return resolveMethod(exception);
}
/**
* Find a {@link Method} to handle the given Throwable.
* Use {@link ExceptionDepthComparator} if more than one match is found.
* @param exception the exception
* @return a Method to handle the exception, or {@code null} if none found
* @since 5.0
*/
public Method resolveMethodByThrowable(Throwable exception) {
Method method = resolveMethodByExceptionType(exception.getClass()); Method method = resolveMethodByExceptionType(exception.getClass());
if (method == null) { if (method == null) {
Throwable cause = exception.getCause(); Throwable cause = exception.getCause();
......
...@@ -18,24 +18,26 @@ package org.springframework.web.reactive.result.method.annotation; ...@@ -18,24 +18,26 @@ package org.springframework.web.reactive.result.method.annotation;
import java.lang.reflect.Method; import java.lang.reflect.Method;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set; import java.util.Set;
import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Function; import java.util.function.Function;
import java.util.stream.Collectors;
import org.apache.commons.logging.Log; import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory; import org.apache.commons.logging.LogFactory;
import reactor.core.publisher.Mono; import reactor.core.publisher.Mono;
import org.springframework.beans.BeansException;
import org.springframework.beans.factory.BeanFactory;
import org.springframework.beans.factory.BeanFactoryAware;
import org.springframework.beans.factory.InitializingBean; import org.springframework.beans.factory.InitializingBean;
import org.springframework.beans.factory.config.ConfigurableBeanFactory; import org.springframework.beans.factory.config.ConfigurableBeanFactory;
import org.springframework.core.MethodIntrospector; import org.springframework.context.ApplicationContext;
import org.springframework.context.ApplicationContextAware;
import org.springframework.context.ConfigurableApplicationContext;
import org.springframework.core.ReactiveAdapterRegistry; import org.springframework.core.ReactiveAdapterRegistry;
import org.springframework.core.annotation.AnnotationAwareOrderComparator;
import org.springframework.core.annotation.AnnotationUtils; import org.springframework.core.annotation.AnnotationUtils;
import org.springframework.core.codec.ByteArrayDecoder; import org.springframework.core.codec.ByteArrayDecoder;
import org.springframework.core.codec.ByteBufferDecoder; import org.springframework.core.codec.ByteBufferDecoder;
...@@ -43,11 +45,13 @@ import org.springframework.core.codec.DataBufferDecoder; ...@@ -43,11 +45,13 @@ import org.springframework.core.codec.DataBufferDecoder;
import org.springframework.core.codec.StringDecoder; import org.springframework.core.codec.StringDecoder;
import org.springframework.http.codec.DecoderHttpMessageReader; import org.springframework.http.codec.DecoderHttpMessageReader;
import org.springframework.http.codec.HttpMessageReader; import org.springframework.http.codec.HttpMessageReader;
import org.springframework.util.Assert;
import org.springframework.util.ReflectionUtils; import org.springframework.util.ReflectionUtils;
import org.springframework.web.bind.annotation.InitBinder; import org.springframework.web.bind.annotation.InitBinder;
import org.springframework.web.bind.annotation.ModelAttribute; import org.springframework.web.bind.annotation.ModelAttribute;
import org.springframework.web.bind.annotation.RequestMapping; import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.support.WebBindingInitializer; import org.springframework.web.bind.support.WebBindingInitializer;
import org.springframework.web.method.ControllerAdviceBean;
import org.springframework.web.method.HandlerMethod; import org.springframework.web.method.HandlerMethod;
import org.springframework.web.method.annotation.ExceptionHandlerMethodResolver; import org.springframework.web.method.annotation.ExceptionHandlerMethodResolver;
import org.springframework.web.reactive.BindingContext; import org.springframework.web.reactive.BindingContext;
...@@ -59,13 +63,15 @@ import org.springframework.web.reactive.result.method.SyncHandlerMethodArgumentR ...@@ -59,13 +63,15 @@ import org.springframework.web.reactive.result.method.SyncHandlerMethodArgumentR
import org.springframework.web.reactive.result.method.SyncInvocableHandlerMethod; import org.springframework.web.reactive.result.method.SyncInvocableHandlerMethod;
import org.springframework.web.server.ServerWebExchange; import org.springframework.web.server.ServerWebExchange;
import static org.springframework.core.MethodIntrospector.selectMethods;
/** /**
* Supports the invocation of {@code @RequestMapping} methods. * Supports the invocation of {@code @RequestMapping} methods.
* *
* @author Rossen Stoyanchev * @author Rossen Stoyanchev
* @since 5.0 * @since 5.0
*/ */
public class RequestMappingHandlerAdapter implements HandlerAdapter, BeanFactoryAware, InitializingBean { public class RequestMappingHandlerAdapter implements HandlerAdapter, ApplicationContextAware, InitializingBean {
private static final Log logger = LogFactory.getLog(RequestMappingHandlerAdapter.class); private static final Log logger = LogFactory.getLog(RequestMappingHandlerAdapter.class);
...@@ -84,10 +90,8 @@ public class RequestMappingHandlerAdapter implements HandlerAdapter, BeanFactory ...@@ -84,10 +90,8 @@ public class RequestMappingHandlerAdapter implements HandlerAdapter, BeanFactory
private List<SyncHandlerMethodArgumentResolver> initBinderArgumentResolvers; private List<SyncHandlerMethodArgumentResolver> initBinderArgumentResolvers;
private ConfigurableBeanFactory beanFactory; private ConfigurableApplicationContext applicationContext;
private ModelInitializer modelInitializer;
private final Map<Class<?>, Set<Method>> binderMethodCache = new ConcurrentHashMap<>(64); private final Map<Class<?>, Set<Method>> binderMethodCache = new ConcurrentHashMap<>(64);
...@@ -97,6 +101,18 @@ public class RequestMappingHandlerAdapter implements HandlerAdapter, BeanFactory ...@@ -97,6 +101,18 @@ public class RequestMappingHandlerAdapter implements HandlerAdapter, BeanFactory
new ConcurrentHashMap<>(64); new ConcurrentHashMap<>(64);
private final Map<ControllerAdviceBean, Set<Method>> binderAdviceCache = new LinkedHashMap<>(64);
private final Map<ControllerAdviceBean, Set<Method>> attributeAdviceCache = new LinkedHashMap<>(64);
private final Map<ControllerAdviceBean, ExceptionHandlerMethodResolver> exceptionHandlerAdviceCache =
new LinkedHashMap<>(64);
private ModelInitializer modelInitializer;
public RequestMappingHandlerAdapter() { public RequestMappingHandlerAdapter() {
this.messageReaders.add(new DecoderHttpMessageReader<>(new ByteArrayDecoder())); this.messageReaders.add(new DecoderHttpMessageReader<>(new ByteArrayDecoder()));
this.messageReaders.add(new DecoderHttpMessageReader<>(new ByteBufferDecoder())); this.messageReaders.add(new DecoderHttpMessageReader<>(new ByteBufferDecoder()));
...@@ -204,25 +220,30 @@ public class RequestMappingHandlerAdapter implements HandlerAdapter, BeanFactory ...@@ -204,25 +220,30 @@ public class RequestMappingHandlerAdapter implements HandlerAdapter, BeanFactory
return this.initBinderArgumentResolvers; return this.initBinderArgumentResolvers;
} }
/** /**
* A {@link ConfigurableBeanFactory} is expected for resolving expressions * A {@link ConfigurableApplicationContext} is expected for resolving
* in method argument default values. * expressions in method argument default values as well as for
* detecting {@code @ControllerAdvice} beans.
*/ */
@Override @Override
public void setBeanFactory(BeanFactory beanFactory) throws BeansException { public void setApplicationContext(ApplicationContext applicationContext) {
if (beanFactory instanceof ConfigurableBeanFactory) { if (applicationContext instanceof ConfigurableApplicationContext) {
this.beanFactory = (ConfigurableBeanFactory) beanFactory; this.applicationContext = (ConfigurableApplicationContext) applicationContext;
} }
} }
public ConfigurableApplicationContext getApplicationContext() {
return this.applicationContext;
}
public ConfigurableBeanFactory getBeanFactory() { public ConfigurableBeanFactory getBeanFactory() {
return this.beanFactory; return this.applicationContext.getBeanFactory();
} }
@Override @Override
public void afterPropertiesSet() throws Exception { public void afterPropertiesSet() throws Exception {
initControllerAdviceCache();
if (this.argumentResolvers == null) { if (this.argumentResolvers == null) {
this.argumentResolvers = getDefaultArgumentResolvers(); this.argumentResolvers = getDefaultArgumentResolvers();
} }
...@@ -232,6 +253,43 @@ public class RequestMappingHandlerAdapter implements HandlerAdapter, BeanFactory ...@@ -232,6 +253,43 @@ public class RequestMappingHandlerAdapter implements HandlerAdapter, BeanFactory
this.modelInitializer = new ModelInitializer(getReactiveAdapterRegistry()); this.modelInitializer = new ModelInitializer(getReactiveAdapterRegistry());
} }
private void initControllerAdviceCache() {
if (getApplicationContext() == null) {
return;
}
if (logger.isInfoEnabled()) {
logger.info("Looking for @ControllerAdvice: " + getApplicationContext());
}
List<ControllerAdviceBean> beans = ControllerAdviceBean.findAnnotatedBeans(getApplicationContext());
AnnotationAwareOrderComparator.sort(beans);
for (ControllerAdviceBean bean : beans) {
Class<?> beanType = bean.getBeanType();
Set<Method> attrMethods = selectMethods(beanType, ATTRIBUTE_METHODS);
if (!attrMethods.isEmpty()) {
this.attributeAdviceCache.put(bean, attrMethods);
if (logger.isInfoEnabled()) {
logger.info("Detected @ModelAttribute methods in " + bean);
}
}
Set<Method> binderMethods = selectMethods(beanType, BINDER_METHODS);
if (!binderMethods.isEmpty()) {
this.binderAdviceCache.put(bean, binderMethods);
if (logger.isInfoEnabled()) {
logger.info("Detected @InitBinder methods in " + bean);
}
}
ExceptionHandlerMethodResolver resolver = new ExceptionHandlerMethodResolver(beanType);
if (resolver.hasExceptionMappings()) {
this.exceptionHandlerAdviceCache.put(bean, resolver);
if (logger.isInfoEnabled()) {
logger.info("Detected @ExceptionHandler methods in " + bean);
}
}
}
}
protected List<HandlerMethodArgumentResolver> getDefaultArgumentResolvers() { protected List<HandlerMethodArgumentResolver> getDefaultArgumentResolvers() {
List<HandlerMethodArgumentResolver> resolvers = new ArrayList<>(); List<HandlerMethodArgumentResolver> resolvers = new ArrayList<>();
...@@ -305,80 +363,97 @@ public class RequestMappingHandlerAdapter implements HandlerAdapter, BeanFactory ...@@ -305,80 +363,97 @@ public class RequestMappingHandlerAdapter implements HandlerAdapter, BeanFactory
@Override @Override
public Mono<HandlerResult> handle(ServerWebExchange exchange, Object handler) { public Mono<HandlerResult> handle(ServerWebExchange exchange, Object handler) {
Assert.notNull(handler, "Expected handler");
HandlerMethod handlerMethod = (HandlerMethod) handler; HandlerMethod handlerMethod = (HandlerMethod) handler;
BindingContext bindingContext = new InitBinderBindingContext( BindingContext bindingContext = new InitBinderBindingContext(
getWebBindingInitializer(), getBinderMethods(handlerMethod)); getWebBindingInitializer(), getBinderMethods(handlerMethod));
Mono<Void> modelCompletion = this.modelInitializer.initModel( return this.modelInitializer
bindingContext, getAttributeMethods(handlerMethod), exchange); .initModel(bindingContext, getAttributeMethods(handlerMethod), exchange)
.then(() -> {
Function<Throwable, Mono<HandlerResult>> exceptionHandler = Function<Throwable, Mono<HandlerResult>> exceptionHandler =
ex -> handleException(ex, handlerMethod, bindingContext, exchange); ex -> handleException(exchange, handlerMethod, bindingContext, ex);
return modelCompletion.then(() -> {
InvocableHandlerMethod invocable = new InvocableHandlerMethod(handlerMethod); InvocableHandlerMethod invocable = new InvocableHandlerMethod(handlerMethod);
invocable.setArgumentResolvers(getArgumentResolvers()); invocable.setArgumentResolvers(getArgumentResolvers());
return invocable.invoke(exchange, bindingContext) return invocable.invoke(exchange, bindingContext)
.doOnNext(result -> result.setExceptionHandler(exceptionHandler)) .doOnNext(result -> result.setExceptionHandler(exceptionHandler))
.otherwise(exceptionHandler); .otherwise(exceptionHandler);
}); });
} }
private List<SyncInvocableHandlerMethod> getBinderMethods(HandlerMethod handlerMethod) { private List<SyncInvocableHandlerMethod> getBinderMethods(HandlerMethod handlerMethod) {
List<SyncInvocableHandlerMethod> result = new ArrayList<>();
Class<?> handlerType = handlerMethod.getBeanType(); Class<?> handlerType = handlerMethod.getBeanType();
Set<Method> methods = this.binderMethodCache.computeIfAbsent(handlerType, aClass -> // Global methods first
MethodIntrospector.selectMethods(handlerType, BINDER_METHODS)); this.binderAdviceCache.entrySet().forEach(entry -> {
if (entry.getKey().isApplicableToBeanType(handlerType)) {
Object bean = entry.getKey().resolveBean();
entry.getValue().forEach(method -> result.add(createBinderMethod(bean, method)));
}
});
return methods.stream() this.binderMethodCache
.map(method -> { .computeIfAbsent(handlerType, aClass -> selectMethods(handlerType, BINDER_METHODS))
.forEach(method -> {
Object bean = handlerMethod.getBean(); Object bean = handlerMethod.getBean();
SyncInvocableHandlerMethod invocable = new SyncInvocableHandlerMethod(bean, method); result.add(createBinderMethod(bean, method));
invocable.setSyncArgumentResolvers(getInitBinderArgumentResolvers()); });
return invocable;
}) return result;
.collect(Collectors.toList()); }
private SyncInvocableHandlerMethod createBinderMethod(Object bean, Method method) {
SyncInvocableHandlerMethod invocable = new SyncInvocableHandlerMethod(bean, method);
invocable.setSyncArgumentResolvers(getInitBinderArgumentResolvers());
return invocable;
} }
private List<InvocableHandlerMethod> getAttributeMethods(HandlerMethod handlerMethod) { private List<InvocableHandlerMethod> getAttributeMethods(HandlerMethod handlerMethod) {
List<InvocableHandlerMethod> result = new ArrayList<>();
Class<?> handlerType = handlerMethod.getBeanType(); Class<?> handlerType = handlerMethod.getBeanType();
Set<Method> methods = this.attributeMethodCache.computeIfAbsent(handlerType, aClass -> // Global methods first
MethodIntrospector.selectMethods(handlerType, ATTRIBUTE_METHODS)); this.attributeAdviceCache.entrySet().forEach(entry -> {
if (entry.getKey().isApplicableToBeanType(handlerType)) {
Object bean = entry.getKey().resolveBean();
entry.getValue().forEach(method -> result.add(createHandlerMethod(bean, method)));
}
});
return methods.stream() this.attributeMethodCache
.map(method -> { .computeIfAbsent(handlerType, aClass -> selectMethods(handlerType, ATTRIBUTE_METHODS))
.forEach(method -> {
Object bean = handlerMethod.getBean(); Object bean = handlerMethod.getBean();
InvocableHandlerMethod invocable = new InvocableHandlerMethod(bean, method); result.add(createHandlerMethod(bean, method));
invocable.setArgumentResolvers(getArgumentResolvers()); });
return invocable;
})
.collect(Collectors.toList());
}
private Mono<HandlerResult> handleException(Throwable ex, HandlerMethod handlerMethod, return result;
BindingContext bindingContext, ServerWebExchange exchange) { }
ExceptionHandlerMethodResolver resolver = this.exceptionHandlerCache private InvocableHandlerMethod createHandlerMethod(Object bean, Method method) {
.computeIfAbsent(handlerMethod.getBeanType(), ExceptionHandlerMethodResolver::new); InvocableHandlerMethod invocable = new InvocableHandlerMethod(bean, method);
invocable.setArgumentResolvers(getArgumentResolvers());
return invocable;
}
Method method = resolver.resolveMethodByExceptionType(ex.getClass()); private Mono<HandlerResult> handleException(ServerWebExchange exchange, HandlerMethod handlerMethod,
BindingContext bindingContext, Throwable ex) {
if (method != null) { InvocableHandlerMethod invocable = getExceptionHandlerMethod(ex, handlerMethod);
Object bean = handlerMethod.getBean(); if (invocable != null) {
InvocableHandlerMethod invocable = new InvocableHandlerMethod(bean, method);
invocable.setArgumentResolvers(getArgumentResolvers());
try { try {
if (logger.isDebugEnabled()) { if (logger.isDebugEnabled()) {
logger.debug("Invoking @ExceptionHandler method: " + invocable.getMethod()); logger.debug("Invoking @ExceptionHandler method: " + invocable.getMethod());
} }
bindingContext.getModel().asMap().clear(); bindingContext.getModel().asMap().clear();
return invocable.invoke(exchange, bindingContext, ex); Throwable cause = ex.getCause() != null ? ex.getCause() : ex;
return invocable.invoke(exchange, bindingContext, cause, handlerMethod);
} }
catch (Throwable invocationEx) { catch (Throwable invocationEx) {
if (logger.isWarnEnabled()) { if (logger.isWarnEnabled()) {
...@@ -386,10 +461,36 @@ public class RequestMappingHandlerAdapter implements HandlerAdapter, BeanFactory ...@@ -386,10 +461,36 @@ public class RequestMappingHandlerAdapter implements HandlerAdapter, BeanFactory
} }
} }
} }
return Mono.error(ex); return Mono.error(ex);
} }
private InvocableHandlerMethod getExceptionHandlerMethod(Throwable ex, HandlerMethod handlerMethod) {
Class<?> handlerType = handlerMethod.getBeanType();
ExceptionHandlerMethodResolver resolver = this.exceptionHandlerCache
.computeIfAbsent(handlerType, ExceptionHandlerMethodResolver::new);
return Optional
.ofNullable(resolver.resolveMethodByThrowable(ex))
.map(method -> createHandlerMethod(handlerMethod.getBean(), method))
.orElseGet(() ->
this.exceptionHandlerAdviceCache.entrySet().stream()
.map(entry -> {
if (entry.getKey().isApplicableToBeanType(handlerType)) {
Method method = entry.getValue().resolveMethodByThrowable(ex);
if (method != null) {
Object bean = entry.getKey().resolveBean();
return createHandlerMethod(bean, method);
}
}
return null;
})
.filter(Objects::nonNull)
.findFirst()
.orElse(null));
}
/** /**
* MethodFilter that matches {@link InitBinder @InitBinder} methods. * MethodFilter that matches {@link InitBinder @InitBinder} methods.
......
/*
* Copyright 2002-2017 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.web.reactive.result.method.annotation;
import java.lang.reflect.Method;
import java.time.Duration;
import java.util.Collections;
import org.junit.Before;
import org.junit.Test;
import org.springframework.beans.FatalBeanException;
import org.springframework.context.ApplicationContext;
import org.springframework.context.annotation.AnnotationConfigApplicationContext;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.core.annotation.Order;
import org.springframework.mock.http.server.reactive.test.MockServerHttpRequest;
import org.springframework.mock.http.server.reactive.test.MockServerHttpResponse;
import org.springframework.stereotype.Controller;
import org.springframework.ui.Model;
import org.springframework.util.ClassUtils;
import org.springframework.validation.Validator;
import org.springframework.web.bind.WebDataBinder;
import org.springframework.web.bind.annotation.ControllerAdvice;
import org.springframework.web.bind.annotation.ExceptionHandler;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.InitBinder;
import org.springframework.web.bind.annotation.ModelAttribute;
import org.springframework.web.bind.support.WebExchangeDataBinder;
import org.springframework.web.method.HandlerMethod;
import org.springframework.web.reactive.BindingContext;
import org.springframework.web.reactive.HandlerResult;
import org.springframework.web.server.ServerWebExchange;
import org.springframework.web.server.adapter.DefaultServerWebExchange;
import static org.junit.Assert.assertEquals;
import static org.mockito.Mockito.mock;
/**
* {@code @ControllerAdvice} related tests for {@link RequestMappingHandlerAdapter}.
* @author Rossen Stoyanchev
*/
public class ControllerAdviceTests {
private ServerWebExchange exchange;
@Before
public void setUp() throws Exception {
MockServerHttpRequest request = MockServerHttpRequest.get("/").build();
MockServerHttpResponse response = new MockServerHttpResponse();
this.exchange = new DefaultServerWebExchange(request, response);
}
@Test
public void resolveExceptionGlobalHandler() throws Exception {
testException(new IllegalAccessException(), "SecondControllerAdvice: IllegalAccessException");
}
@Test
public void resolveExceptionGlobalHandlerOrdered() throws Exception {
testException(new IllegalStateException(), "OneControllerAdvice: IllegalStateException");
}
@Test // SPR-12605
public void resolveExceptionWithHandlerMethodArg() throws Exception {
testException(new ArrayIndexOutOfBoundsException(), "HandlerMethod: handle");
}
@Test
public void resolveExceptionWithAssertionError() throws Exception {
AssertionError error = new AssertionError("argh");
testException(error, error.toString());
}
@Test
public void resolveExceptionWithAssertionErrorAsRootCause() throws Exception {
AssertionError cause = new AssertionError("argh");
FatalBeanException exception = new FatalBeanException("wrapped", cause);
testException(exception, cause.toString());
}
private void testException(Throwable exception, String expected) throws Exception {
ApplicationContext context = new AnnotationConfigApplicationContext(TestConfig.class);
RequestMappingHandlerAdapter adapter = createAdapter(context);
TestController controller = context.getBean(TestController.class);
controller.setException(exception);
Object actual = handle(adapter, controller, "handle").getReturnValue().orElse(null);
assertEquals(expected, actual);
}
@Test
public void modelAttributeAdvice() throws Exception {
ApplicationContext context = new AnnotationConfigApplicationContext(TestConfig.class);
RequestMappingHandlerAdapter adapter = createAdapter(context);
TestController controller = context.getBean(TestController.class);
Model model = handle(adapter, controller, "handle").getModel();
assertEquals(2, model.asMap().size());
assertEquals("lAttr1", model.asMap().get("attr1"));
assertEquals("gAttr2", model.asMap().get("attr2"));
}
@Test
public void initBinderAdvice() throws Exception {
ApplicationContext context = new AnnotationConfigApplicationContext(TestConfig.class);
RequestMappingHandlerAdapter adapter = createAdapter(context);
TestController controller = context.getBean(TestController.class);
Validator validator = mock(Validator.class);
controller.setValidator(validator);
BindingContext bindingContext = handle(adapter, controller, "handle").getBindingContext();
WebExchangeDataBinder binder = bindingContext.createDataBinder(this.exchange, "name");
assertEquals(Collections.singletonList(validator), binder.getValidators());
}
private RequestMappingHandlerAdapter createAdapter(ApplicationContext context) throws Exception {
RequestMappingHandlerAdapter adapter = new RequestMappingHandlerAdapter();
adapter.setApplicationContext(context);
adapter.afterPropertiesSet();
return adapter;
}
private HandlerResult handle(RequestMappingHandlerAdapter adapter,
Object controller, String methodName) throws Exception {
Method method = controller.getClass().getMethod(methodName);
HandlerMethod handlerMethod = new HandlerMethod(controller, method);
return adapter.handle(this.exchange, handlerMethod).block(Duration.ZERO);
}
@Configuration
static class TestConfig {
@Bean
public TestController testController() {
return new TestController();
}
@Bean
public OneControllerAdvice testExceptionResolver() {
return new OneControllerAdvice();
}
@Bean
public SecondControllerAdvice anotherTestExceptionResolver() {
return new SecondControllerAdvice();
}
}
@Controller
static class TestController {
private Validator validator;
private Throwable exception;
void setValidator(Validator validator) {
this.validator = validator;
}
void setException(Throwable exception) {
this.exception = exception;
}
@InitBinder
public void initDataBinder(WebDataBinder dataBinder) {
if (this.validator != null) {
dataBinder.addValidators(this.validator);
}
}
@ModelAttribute
public void addAttributes(Model model) {
model.addAttribute("attr1", "lAttr1");
}
@GetMapping
public void handle() throws Throwable {
if (this.exception != null) {
throw this.exception;
}
}
}
@ControllerAdvice
@Order(1)
static class OneControllerAdvice {
@ModelAttribute
public void addAttributes(Model model) {
model.addAttribute("attr1", "gAttr1");
model.addAttribute("attr2", "gAttr2");
}
@ExceptionHandler
public String handleException(IllegalStateException ex) {
return "OneControllerAdvice: " + ClassUtils.getShortName(ex.getClass());
}
@ExceptionHandler(ArrayIndexOutOfBoundsException.class)
public String handleWithHandlerMethod(HandlerMethod handlerMethod) {
return "HandlerMethod: " + handlerMethod.getMethod().getName();
}
@ExceptionHandler(AssertionError.class)
public String handleAssertionError(Error err) {
return err.toString();
}
}
@ControllerAdvice
@Order(2)
static class SecondControllerAdvice {
@ExceptionHandler({IllegalStateException.class, IllegalAccessException.class})
public String handleException(Exception ex) {
return "SecondControllerAdvice: " + ClassUtils.getShortName(ex.getClass());
}
}
}
...@@ -31,7 +31,6 @@ import org.springframework.core.ReactiveAdapterRegistry; ...@@ -31,7 +31,6 @@ import org.springframework.core.ReactiveAdapterRegistry;
import org.springframework.mock.http.server.reactive.test.MockServerHttpRequest; import org.springframework.mock.http.server.reactive.test.MockServerHttpRequest;
import org.springframework.mock.http.server.reactive.test.MockServerHttpResponse; import org.springframework.mock.http.server.reactive.test.MockServerHttpResponse;
import org.springframework.ui.Model; import org.springframework.ui.Model;
import org.springframework.util.ObjectUtils;
import org.springframework.validation.Validator; import org.springframework.validation.Validator;
import org.springframework.web.bind.WebDataBinder; import org.springframework.web.bind.WebDataBinder;
import org.springframework.web.bind.annotation.InitBinder; import org.springframework.web.bind.annotation.InitBinder;
...@@ -48,8 +47,8 @@ import org.springframework.web.server.adapter.DefaultServerWebExchange; ...@@ -48,8 +47,8 @@ import org.springframework.web.server.adapter.DefaultServerWebExchange;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertEquals;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.springframework.web.reactive.result.method.annotation.RequestMappingHandlerAdapter.BINDER_METHODS;
import static org.springframework.web.reactive.result.method.annotation.RequestMappingHandlerAdapter.ATTRIBUTE_METHODS; import static org.springframework.web.reactive.result.method.annotation.RequestMappingHandlerAdapter.ATTRIBUTE_METHODS;
import static org.springframework.web.reactive.result.method.annotation.RequestMappingHandlerAdapter.BINDER_METHODS;
/** /**
* Unit tests for {@link ModelInitializer}. * Unit tests for {@link ModelInitializer}.
...@@ -76,8 +75,10 @@ public class ModelInitializerTests { ...@@ -76,8 +75,10 @@ public class ModelInitializerTests {
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
@Test @Test
public void basic() throws Exception { public void basic() throws Exception {
TestController controller = new TestController();
Validator validator = mock(Validator.class); Validator validator = mock(Validator.class);
Object controller = new TestController(validator); controller.setValidator(validator);
List<SyncInvocableHandlerMethod> binderMethods = getBinderMethods(controller); List<SyncInvocableHandlerMethod> binderMethods = getBinderMethods(controller);
List<InvocableHandlerMethod> attributeMethods = getAttributeMethods(controller); List<InvocableHandlerMethod> attributeMethods = getAttributeMethods(controller);
...@@ -131,16 +132,18 @@ public class ModelInitializerTests { ...@@ -131,16 +132,18 @@ public class ModelInitializerTests {
@SuppressWarnings("unused") @SuppressWarnings("unused")
private static class TestController { private static class TestController {
private Validator[] validators; private Validator validator;
public TestController(Validator... validators) {
this.validators = validators; void setValidator(Validator validator) {
this.validator = validator;
} }
@InitBinder @InitBinder
public void initDataBinder(WebDataBinder dataBinder) { public void initDataBinder(WebDataBinder dataBinder) {
if (!ObjectUtils.isEmpty(this.validators)) { if (this.validator != null) {
dataBinder.addValidators(this.validators); dataBinder.addValidators(this.validator);
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册