diff --git a/spring-web/src/main/java/org/springframework/web/filter/ForwardedHeaderFilter.java b/spring-web/src/main/java/org/springframework/web/filter/ForwardedHeaderFilter.java index d05f6825c64d309dbb12694778c171f0b1b8c590..a18ad9f501300d2b205631452d4a8a0dd62d2dd9 100644 --- a/spring-web/src/main/java/org/springframework/web/filter/ForwardedHeaderFilter.java +++ b/spring-web/src/main/java/org/springframework/web/filter/ForwardedHeaderFilter.java @@ -158,6 +158,12 @@ public class ForwardedHeaderFilter extends OncePerRequestFilter { } } + @Override + protected void doFilterNestedErrorDispatch(HttpServletRequest request, HttpServletResponse response, + FilterChain filterChain) throws ServletException, IOException { + + doFilterInternal(request, response, filterChain); + } /** * Hide "Forwarded" or "X-Forwarded-*" headers. diff --git a/spring-web/src/main/java/org/springframework/web/filter/OncePerRequestFilter.java b/spring-web/src/main/java/org/springframework/web/filter/OncePerRequestFilter.java index a21fc747948b1657329af840c911bd384b8db42a..a08d0df4da56b57ad3cdf598b2026c8f43d5940b 100644 --- a/spring-web/src/main/java/org/springframework/web/filter/OncePerRequestFilter.java +++ b/spring-web/src/main/java/org/springframework/web/filter/OncePerRequestFilter.java @@ -94,10 +94,19 @@ public abstract class OncePerRequestFilter extends GenericFilterBean { HttpServletResponse httpResponse = (HttpServletResponse) response; String alreadyFilteredAttributeName = getAlreadyFilteredAttributeName(); - alreadyFilteredAttributeName = updateForErrorDispatch(alreadyFilteredAttributeName, request); boolean hasAlreadyFilteredAttribute = request.getAttribute(alreadyFilteredAttributeName) != null; - if (hasAlreadyFilteredAttribute || skipDispatch(httpRequest) || shouldNotFilter(httpRequest)) { + if (skipDispatch(httpRequest) || shouldNotFilter(httpRequest)) { + + // Proceed without invoking this filter... + filterChain.doFilter(request, response); + } + else if (hasAlreadyFilteredAttribute) { + + if (DispatcherType.ERROR.equals(request.getDispatcherType())) { + doFilterNestedErrorDispatch(httpRequest, httpResponse, filterChain); + return; + } // Proceed without invoking this filter... filterChain.doFilter(request, response); @@ -115,7 +124,6 @@ public abstract class OncePerRequestFilter extends GenericFilterBean { } } - private boolean skipDispatch(HttpServletRequest request) { if (isAsyncDispatch(request) && shouldNotFilterAsyncDispatch()) { return true; @@ -167,21 +175,6 @@ public abstract class OncePerRequestFilter extends GenericFilterBean { return name + ALREADY_FILTERED_SUFFIX; } - private String updateForErrorDispatch(String alreadyFilteredAttributeName, ServletRequest request) { - - // Jetty does ERROR dispatch within sendError, so request attribute is still present - // Use a separate attribute for ERROR dispatches - - if (DispatcherType.ERROR.equals(request.getDispatcherType()) && !shouldNotFilterErrorDispatch() && - request.getAttribute(alreadyFilteredAttributeName) != null) { - - return alreadyFilteredAttributeName + ".ERROR"; - } - - - return alreadyFilteredAttributeName; - } - /** * Can be overridden in subclasses for custom filtering control, * returning {@code true} to avoid filtering of the given request. @@ -238,4 +231,23 @@ public abstract class OncePerRequestFilter extends GenericFilterBean { HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws ServletException, IOException; + /** + * Typically an ERROR dispatch happens after the REQUEST dispatch completes, + * and the filter chain starts anew. On some servers however the ERROR + * dispatch may be nested within the REQUEST dispatch, e.g. as a result of + * calling {@code sendError} on the response. In that case we are still in + * the filter chain, on the same thread, but the request and response have + * been switched to the original, unwrapped ones. + *

Sub-classes may use this method to filter such nested ERROR dispatches + * and re-apply wrapping on the request or response. {@code ThreadLocal} + * context, if any, should still be active as we are still nested within + * the filter chain. + * @since 5.1.9 + */ + protected void doFilterNestedErrorDispatch(HttpServletRequest request, HttpServletResponse response, + FilterChain filterChain) throws ServletException, IOException { + + doFilter(request, response, filterChain); + } + } diff --git a/spring-web/src/test/java/org/springframework/web/filter/OncePerRequestFilterTests.java b/spring-web/src/test/java/org/springframework/web/filter/OncePerRequestFilterTests.java new file mode 100644 index 0000000000000000000000000000000000000000..35c3476097946cc86725ecee1a7c72533194f476 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/web/filter/OncePerRequestFilterTests.java @@ -0,0 +1,192 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.web.filter; + +import java.io.IOException; +import javax.servlet.DispatcherType; +import javax.servlet.FilterChain; +import javax.servlet.ServletException; +import javax.servlet.http.HttpServlet; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; + +import org.junit.Before; +import org.junit.Test; + +import org.springframework.mock.web.test.MockFilterChain; +import org.springframework.mock.web.test.MockHttpServletRequest; +import org.springframework.mock.web.test.MockHttpServletResponse; +import org.springframework.web.util.WebUtils; + +import static org.assertj.core.api.Assertions.assertThat; + + +/** + * Unit tests for {@link OncePerRequestFilter}. + * @author Rossen Stoyanchev + * @since 5.1.9 + */ +public class OncePerRequestFilterTests { + + private final TestOncePerRequestFilter filter = new TestOncePerRequestFilter(); + + private MockHttpServletRequest request; + + private MockFilterChain filterChain; + + + @Before + @SuppressWarnings("serial") + public void setup() throws Exception { + this.request = new MockHttpServletRequest(); + this.request.setScheme("http"); + this.request.setServerName("localhost"); + this.request.setServerPort(80); + this.filterChain = new MockFilterChain(new HttpServlet() {}); + } + + + @Test + public void filterOnce() throws ServletException, IOException { + + // Already filtered + this.request.setAttribute(this.filter.getAlreadyFilteredAttributeName(), Boolean.TRUE); + + this.filter.doFilter(this.request, new MockHttpServletResponse(), this.filterChain); + assertThat(this.filter.didFilter).isFalse(); + assertThat(this.filter.didFilterNestedErrorDispatch).isFalse(); + + // Remove already filtered + this.request.removeAttribute(this.filter.getAlreadyFilteredAttributeName()); + this.filter.reset(); + + this.filter.doFilter(this.request, new MockHttpServletResponse(), this.filterChain); + assertThat(this.filter.didFilter).isTrue(); + assertThat(this.filter.didFilterNestedErrorDispatch).isFalse(); + } + + @Test + public void shouldNotFilterErrorDispatch() throws ServletException, IOException { + + initErrorDispatch(); + + this.filter.doFilter(this.request, new MockHttpServletResponse(), this.filterChain); + assertThat(this.filter.didFilter).isFalse(); + assertThat(this.filter.didFilterNestedErrorDispatch).isFalse(); + } + + @Test + public void shouldNotFilterNestedErrorDispatch() throws ServletException, IOException { + + initErrorDispatch(); + this.request.setAttribute(this.filter.getAlreadyFilteredAttributeName(), Boolean.TRUE); + + this.filter.doFilter(this.request, new MockHttpServletResponse(), this.filterChain); + assertThat(this.filter.didFilter).isFalse(); + assertThat(this.filter.didFilterNestedErrorDispatch).isFalse(); + } + + @Test // gh-23196 + public void filterNestedErrorDispatch() throws ServletException, IOException { + + // Opt in for ERROR dispatch + this.filter.setShouldNotFilterErrorDispatch(false); + + this.request.setAttribute(this.filter.getAlreadyFilteredAttributeName(), Boolean.TRUE); + initErrorDispatch(); + + this.filter.doFilter(this.request, new MockHttpServletResponse(), this.filterChain); + assertThat(this.filter.didFilter).isFalse(); + assertThat(this.filter.didFilterNestedErrorDispatch).isTrue(); + } + + private void initErrorDispatch() { + this.request.setDispatcherType(DispatcherType.ERROR); + this.request.setAttribute(WebUtils.ERROR_REQUEST_URI_ATTRIBUTE, "/error"); + } + + + private static class TestOncePerRequestFilter extends OncePerRequestFilter { + + private boolean shouldNotFilter; + + private boolean shouldNotFilterAsyncDispatch = true; + + private boolean shouldNotFilterErrorDispatch = true; + + private boolean didFilter; + + private boolean didFilterNestedErrorDispatch; + + + public void setShouldNotFilter(boolean shouldNotFilter) { + this.shouldNotFilter = shouldNotFilter; + } + + public void setShouldNotFilterAsyncDispatch(boolean shouldNotFilterAsyncDispatch) { + this.shouldNotFilterAsyncDispatch = shouldNotFilterAsyncDispatch; + } + + public void setShouldNotFilterErrorDispatch(boolean shouldNotFilterErrorDispatch) { + this.shouldNotFilterErrorDispatch = shouldNotFilterErrorDispatch; + } + + + public boolean didFilter() { + return this.didFilter; + } + + public boolean didFilterNestedErrorDispatch() { + return this.didFilterNestedErrorDispatch; + } + + public void reset() { + this.didFilter = false; + this.didFilterNestedErrorDispatch = false; + } + + + @Override + protected boolean shouldNotFilter(HttpServletRequest request) { + return this.shouldNotFilter; + } + + @Override + protected boolean shouldNotFilterAsyncDispatch() { + return this.shouldNotFilterAsyncDispatch; + } + + @Override + protected boolean shouldNotFilterErrorDispatch() { + return this.shouldNotFilterErrorDispatch; + } + + @Override + protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, + FilterChain filterChain) { + + this.didFilter = true; + } + + @Override + protected void doFilterNestedErrorDispatch(HttpServletRequest request, HttpServletResponse response, + FilterChain filterChain) { + + this.didFilterNestedErrorDispatch = true; + } + } + +}