提交 3272a3b8 编写于 作者: R Rossen Stoyanchev

Check HTTP method before raising 415

This commit moves the check whether an HTTP method supports request
body up to the base class so that all sub-classes can benefit (not just
@RequestBody).

Issue: SPR-13176
上级 244c95b0
...@@ -22,6 +22,7 @@ import java.io.PushbackInputStream; ...@@ -22,6 +22,7 @@ import java.io.PushbackInputStream;
import java.lang.annotation.Annotation; import java.lang.annotation.Annotation;
import java.lang.reflect.Type; import java.lang.reflect.Type;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections; import java.util.Collections;
import java.util.LinkedHashSet; import java.util.LinkedHashSet;
import java.util.List; import java.util.List;
...@@ -37,6 +38,8 @@ import org.springframework.core.ResolvableType; ...@@ -37,6 +38,8 @@ import org.springframework.core.ResolvableType;
import org.springframework.core.annotation.AnnotationUtils; import org.springframework.core.annotation.AnnotationUtils;
import org.springframework.http.HttpHeaders; import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpInputMessage; import org.springframework.http.HttpInputMessage;
import org.springframework.http.HttpMethod;
import org.springframework.http.HttpRequest;
import org.springframework.http.InvalidMediaTypeException; import org.springframework.http.InvalidMediaTypeException;
import org.springframework.http.MediaType; import org.springframework.http.MediaType;
import org.springframework.http.converter.GenericHttpMessageConverter; import org.springframework.http.converter.GenericHttpMessageConverter;
...@@ -60,6 +63,9 @@ import org.springframework.web.method.support.HandlerMethodArgumentResolver; ...@@ -60,6 +63,9 @@ import org.springframework.web.method.support.HandlerMethodArgumentResolver;
*/ */
public abstract class AbstractMessageConverterMethodArgumentResolver implements HandlerMethodArgumentResolver { public abstract class AbstractMessageConverterMethodArgumentResolver implements HandlerMethodArgumentResolver {
private static final List<HttpMethod> SUPPORTED_METHODS =
Arrays.asList(HttpMethod.POST, HttpMethod.PUT, HttpMethod.PATCH);
private static final Object NO_VALUE = new Object(); private static final Object NO_VALUE = new Object();
...@@ -170,6 +176,7 @@ public abstract class AbstractMessageConverterMethodArgumentResolver implements ...@@ -170,6 +176,7 @@ public abstract class AbstractMessageConverterMethodArgumentResolver implements
targetClass = (Class<T>) resolvableType.resolve(); targetClass = (Class<T>) resolvableType.resolve();
} }
HttpMethod httpMethod = ((HttpRequest) inputMessage).getMethod();
inputMessage = new EmptyBodyCheckingHttpInputMessage(inputMessage); inputMessage = new EmptyBodyCheckingHttpInputMessage(inputMessage);
Object body = NO_VALUE; Object body = NO_VALUE;
...@@ -213,6 +220,9 @@ public abstract class AbstractMessageConverterMethodArgumentResolver implements ...@@ -213,6 +220,9 @@ public abstract class AbstractMessageConverterMethodArgumentResolver implements
} }
if (body == NO_VALUE) { if (body == NO_VALUE) {
if (!SUPPORTED_METHODS.contains(httpMethod)) {
return null;
}
throw new HttpMediaTypeNotSupportedException(contentType, this.allSupportedMediaTypes); throw new HttpMediaTypeNotSupportedException(contentType, this.allSupportedMediaTypes);
} }
...@@ -273,6 +283,8 @@ public abstract class AbstractMessageConverterMethodArgumentResolver implements ...@@ -273,6 +283,8 @@ public abstract class AbstractMessageConverterMethodArgumentResolver implements
private final InputStream body; private final InputStream body;
private final HttpMethod method;
public EmptyBodyCheckingHttpInputMessage(HttpInputMessage inputMessage) throws IOException { public EmptyBodyCheckingHttpInputMessage(HttpInputMessage inputMessage) throws IOException {
this.headers = inputMessage.getHeaders(); this.headers = inputMessage.getHeaders();
...@@ -296,6 +308,7 @@ public abstract class AbstractMessageConverterMethodArgumentResolver implements ...@@ -296,6 +308,7 @@ public abstract class AbstractMessageConverterMethodArgumentResolver implements
pushbackInputStream.unread(b); pushbackInputStream.unread(b);
} }
} }
this.method = ((HttpRequest) inputMessage).getMethod();
} }
@Override @Override
...@@ -307,6 +320,10 @@ public abstract class AbstractMessageConverterMethodArgumentResolver implements ...@@ -307,6 +320,10 @@ public abstract class AbstractMessageConverterMethodArgumentResolver implements
public InputStream getBody() throws IOException { public InputStream getBody() throws IOException {
return this.body; return this.body;
} }
public HttpMethod getMethod() {
return this.method;
}
} }
} }
...@@ -150,14 +150,11 @@ public class RequestResponseBodyMethodProcessor extends AbstractMessageConverter ...@@ -150,14 +150,11 @@ public class RequestResponseBodyMethodProcessor extends AbstractMessageConverter
ServletServerHttpRequest inputMessage = new ServletServerHttpRequest(servletRequest); ServletServerHttpRequest inputMessage = new ServletServerHttpRequest(servletRequest);
Object arg = null; Object arg = null;
if (webRequest.getHeader("Content-Type") != null || arg = readWithMessageConverters(inputMessage, methodParam, paramType);
SUPPORTED_METHODS.contains(inputMessage.getMethod())) { if (arg == null) {
arg = readWithMessageConverters(inputMessage, methodParam, paramType); if (methodParam.getParameterAnnotation(RequestBody.class).required()) {
if (arg == null) { throw new HttpMessageNotReadableException("Required request body is missing: " +
if (methodParam.getParameterAnnotation(RequestBody.class).required()) { methodParam.getMethod().toGenericString());
throw new HttpMessageNotReadableException("Required request body is missing: " +
methodParam.getMethod().toGenericString());
}
} }
} }
return arg; return arg;
......
...@@ -16,15 +16,10 @@ ...@@ -16,15 +16,10 @@
package org.springframework.web.servlet.mvc.method.annotation; package org.springframework.web.servlet.mvc.method.annotation;
import static org.junit.Assert.*;
import static org.mockito.BDDMockito.*;
import static org.springframework.web.servlet.HandlerMapping.*;
import java.lang.reflect.Method; import java.lang.reflect.Method;
import java.net.URI; import java.net.URI;
import java.nio.charset.Charset; import java.nio.charset.Charset;
import java.text.SimpleDateFormat; import java.text.SimpleDateFormat;
import java.util.Arrays;
import java.util.Collections; import java.util.Collections;
import java.util.Date; import java.util.Date;
import java.util.Locale; import java.util.Locale;
...@@ -53,6 +48,10 @@ import org.springframework.web.bind.annotation.RequestMapping; ...@@ -53,6 +48,10 @@ import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.context.request.ServletWebRequest; import org.springframework.web.context.request.ServletWebRequest;
import org.springframework.web.method.support.ModelAndViewContainer; import org.springframework.web.method.support.ModelAndViewContainer;
import static org.junit.Assert.*;
import static org.mockito.BDDMockito.*;
import static org.springframework.web.servlet.HandlerMapping.*;
/** /**
* Test fixture for {@link HttpEntityMethodProcessor} delegating to a mock * Test fixture for {@link HttpEntityMethodProcessor} delegating to a mock
* {@link HttpMessageConverter}. * {@link HttpMessageConverter}.
...@@ -113,7 +112,7 @@ public class HttpEntityMethodProcessorMockTests { ...@@ -113,7 +112,7 @@ public class HttpEntityMethodProcessorMockTests {
returnTypeInt = new MethodParameter(getClass().getMethod("handle3"), -1); returnTypeInt = new MethodParameter(getClass().getMethod("handle3"), -1);
mavContainer = new ModelAndViewContainer(); mavContainer = new ModelAndViewContainer();
servletRequest = new MockHttpServletRequest("GET", "/foo"); servletRequest = new MockHttpServletRequest("POST", "/foo");
servletResponse = new MockHttpServletResponse(); servletResponse = new MockHttpServletResponse();
webRequest = new ServletWebRequest(servletRequest, servletResponse); webRequest = new ServletWebRequest(servletRequest, servletResponse);
} }
...@@ -185,7 +184,7 @@ public class HttpEntityMethodProcessorMockTests { ...@@ -185,7 +184,7 @@ public class HttpEntityMethodProcessorMockTests {
MediaType contentType = MediaType.TEXT_PLAIN; MediaType contentType = MediaType.TEXT_PLAIN;
servletRequest.addHeader("Content-Type", contentType.toString()); servletRequest.addHeader("Content-Type", contentType.toString());
given(messageConverter.getSupportedMediaTypes()).willReturn(Arrays.asList(contentType)); given(messageConverter.getSupportedMediaTypes()).willReturn(Collections.singletonList(contentType));
given(messageConverter.canRead(String.class, contentType)).willReturn(false); given(messageConverter.canRead(String.class, contentType)).willReturn(false);
processor.resolveArgument(paramHttpEntity, mavContainer, webRequest, null); processor.resolveArgument(paramHttpEntity, mavContainer, webRequest, null);
...@@ -266,7 +265,7 @@ public class HttpEntityMethodProcessorMockTests { ...@@ -266,7 +265,7 @@ public class HttpEntityMethodProcessorMockTests {
servletRequest.addHeader("Accept", accepted.toString()); servletRequest.addHeader("Accept", accepted.toString());
given(messageConverter.canWrite(String.class, null)).willReturn(true); given(messageConverter.canWrite(String.class, null)).willReturn(true);
given(messageConverter.getSupportedMediaTypes()).willReturn(Arrays.asList(MediaType.TEXT_PLAIN)); given(messageConverter.getSupportedMediaTypes()).willReturn(Collections.singletonList(MediaType.TEXT_PLAIN));
given(messageConverter.canWrite(String.class, accepted)).willReturn(false); given(messageConverter.canWrite(String.class, accepted)).willReturn(false);
processor.handleReturnValue(returnValue, returnTypeResponseEntity, mavContainer, webRequest); processor.handleReturnValue(returnValue, returnTypeResponseEntity, mavContainer, webRequest);
......
...@@ -16,16 +16,15 @@ ...@@ -16,16 +16,15 @@
package org.springframework.web.servlet.mvc.method.annotation; package org.springframework.web.servlet.mvc.method.annotation;
import com.fasterxml.jackson.annotation.JsonTypeInfo;
import com.fasterxml.jackson.annotation.JsonTypeName;
import static org.junit.Assert.*;
import java.io.Serializable; import java.io.Serializable;
import java.lang.reflect.Method; import java.lang.reflect.Method;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collections;
import java.util.List; import java.util.List;
import com.fasterxml.jackson.annotation.JsonTypeInfo;
import com.fasterxml.jackson.annotation.JsonTypeName;
import org.junit.Before; import org.junit.Before;
import org.junit.Test; import org.junit.Test;
...@@ -46,6 +45,8 @@ import org.springframework.web.context.request.ServletWebRequest; ...@@ -46,6 +45,8 @@ import org.springframework.web.context.request.ServletWebRequest;
import org.springframework.web.method.HandlerMethod; import org.springframework.web.method.HandlerMethod;
import org.springframework.web.method.support.ModelAndViewContainer; import org.springframework.web.method.support.ModelAndViewContainer;
import static org.junit.Assert.*;
/** /**
* Test fixture with {@link HttpEntityMethodProcessor} delegating to * Test fixture with {@link HttpEntityMethodProcessor} delegating to
* actual {@link HttpMessageConverter} instances. * actual {@link HttpMessageConverter} instances.
...@@ -54,6 +55,7 @@ import org.springframework.web.method.support.ModelAndViewContainer; ...@@ -54,6 +55,7 @@ import org.springframework.web.method.support.ModelAndViewContainer;
* *
* @author Rossen Stoyanchev * @author Rossen Stoyanchev
*/ */
@SuppressWarnings("unused")
public class HttpEntityMethodProcessorTests { public class HttpEntityMethodProcessorTests {
private MethodParameter paramList; private MethodParameter paramList;
...@@ -81,7 +83,9 @@ public class HttpEntityMethodProcessorTests { ...@@ -81,7 +83,9 @@ public class HttpEntityMethodProcessorTests {
binderFactory = new ValidatingBinderFactory(); binderFactory = new ValidatingBinderFactory();
servletRequest = new MockHttpServletRequest(); servletRequest = new MockHttpServletRequest();
servletResponse = new MockHttpServletResponse(); servletResponse = new MockHttpServletResponse();
servletRequest.setMethod("POST");
webRequest = new ServletWebRequest(servletRequest, servletResponse); webRequest = new ServletWebRequest(servletRequest, servletResponse);
} }
@Test @Test
...@@ -109,7 +113,7 @@ public class HttpEntityMethodProcessorTests { ...@@ -109,7 +113,7 @@ public class HttpEntityMethodProcessorTests {
this.servletRequest.setContent(new byte[0]); this.servletRequest.setContent(new byte[0]);
this.servletRequest.setContentType("application/json"); this.servletRequest.setContentType("application/json");
List<HttpMessageConverter<?>> converters = Arrays.asList(new MappingJackson2HttpMessageConverter()); List<HttpMessageConverter<?>> converters = Collections.singletonList(new MappingJackson2HttpMessageConverter());
HttpEntityMethodProcessor processor = new HttpEntityMethodProcessor(converters); HttpEntityMethodProcessor processor = new HttpEntityMethodProcessor(converters);
HttpEntity<?> result = (HttpEntity<?>) processor.resolveArgument(this.paramSimpleBean, HttpEntity<?> result = (HttpEntity<?>) processor.resolveArgument(this.paramSimpleBean,
...@@ -196,9 +200,9 @@ public class HttpEntityMethodProcessorTests { ...@@ -196,9 +200,9 @@ public class HttpEntityMethodProcessorTests {
private interface Identifiable extends Serializable { private interface Identifiable extends Serializable {
public Long getId(); Long getId();
public void setId(Long id); void setId(Long id);
} }
......
...@@ -80,7 +80,12 @@ import org.springframework.web.servlet.HandlerMapping; ...@@ -80,7 +80,12 @@ import org.springframework.web.servlet.HandlerMapping;
import org.springframework.web.servlet.ModelAndView; import org.springframework.web.servlet.ModelAndView;
import org.springframework.web.util.UriComponentsBuilder; import org.springframework.web.util.UriComponentsBuilder;
import static org.junit.Assert.*; import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertSame;
import static org.junit.Assert.assertTrue;
/** /**
* A test fixture with a controller with all supported method signature styles * A test fixture with a controller with all supported method signature styles
...@@ -102,6 +107,7 @@ public class RequestMappingHandlerAdapterIntegrationTests { ...@@ -102,6 +107,7 @@ public class RequestMappingHandlerAdapterIntegrationTests {
private MockHttpServletResponse response; private MockHttpServletResponse response;
@Before @Before
public void setup() throws Exception { public void setup() throws Exception {
ConfigurableWebBindingInitializer bindingInitializer = new ConfigurableWebBindingInitializer(); ConfigurableWebBindingInitializer bindingInitializer = new ConfigurableWebBindingInitializer();
...@@ -123,6 +129,8 @@ public class RequestMappingHandlerAdapterIntegrationTests { ...@@ -123,6 +129,8 @@ public class RequestMappingHandlerAdapterIntegrationTests {
request = new MockHttpServletRequest(); request = new MockHttpServletRequest();
response = new MockHttpServletResponse(); response = new MockHttpServletResponse();
request.setMethod("POST");
// Expose request to the current thread (for SpEL expressions) // Expose request to the current thread (for SpEL expressions)
RequestContextHolder.setRequestAttributes(new ServletWebRequest(request)); RequestContextHolder.setRequestAttributes(new ServletWebRequest(request));
} }
...@@ -132,6 +140,7 @@ public class RequestMappingHandlerAdapterIntegrationTests { ...@@ -132,6 +140,7 @@ public class RequestMappingHandlerAdapterIntegrationTests {
RequestContextHolder.resetRequestAttributes(); RequestContextHolder.resetRequestAttributes();
} }
@Test @Test
public void handle() throws Exception { public void handle() throws Exception {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册