diff --git a/spring-web/src/main/java/org/springframework/web/filter/ForwardedHeaderFilter.java b/spring-web/src/main/java/org/springframework/web/filter/ForwardedHeaderFilter.java index 940e27f650fc1e819d390730d4b2b6e531ec5b23..44b514cd00ec92f87c6968f9a3028c6502cbba44 100644 --- a/spring-web/src/main/java/org/springframework/web/filter/ForwardedHeaderFilter.java +++ b/spring-web/src/main/java/org/springframework/web/filter/ForwardedHeaderFilter.java @@ -23,11 +23,13 @@ import java.util.List; import java.util.Locale; import java.util.Map; import java.util.Set; + import javax.servlet.FilterChain; import javax.servlet.ServletException; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletRequestWrapper; import javax.servlet.http.HttpServletResponse; +import javax.servlet.http.HttpServletResponseWrapper; import org.springframework.http.HttpRequest; import org.springframework.http.server.ServletServerHttpRequest; @@ -38,16 +40,19 @@ import org.springframework.web.util.UriComponentsBuilder; import org.springframework.web.util.UrlPathHelper; /** - * Filter that wraps the request in order to override its + * Filter that wraps the request and response in order to override its * {@link HttpServletRequest#getServerName() getServerName()}, * {@link HttpServletRequest#getServerPort() getServerPort()}, - * {@link HttpServletRequest#getScheme() getScheme()}, and - * {@link HttpServletRequest#isSecure() isSecure()} methods with values derived - * from "Forwarded" or "X-Forwarded-*" headers. In effect the wrapped request - * reflects the client-originated protocol and address. + * {@link HttpServletRequest#getScheme() getScheme()}, + * {@link HttpServletRequest#isSecure() isSecure()}, + * {@link HttpServletResponse#sendRedirect(String) sendRedirect(String)}, + * methods with values derived from "Forwarded" or "X-Forwarded-*" + * headers. In effect the wrapped request and response reflects the + * client-originated protocol and address. * * @author Rossen Stoyanchev * @author Eddú Meléndez + * @author Rob Winch * @since 4.3 */ public class ForwardedHeaderFilter extends OncePerRequestFilter { @@ -93,7 +98,9 @@ public class ForwardedHeaderFilter extends OncePerRequestFilter { protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws ServletException, IOException { - filterChain.doFilter(new ForwardedHeaderRequestWrapper(request, this.pathHelper), response); + ForwardedHeaderRequestWrapper wrappedRequest = new ForwardedHeaderRequestWrapper(request, this.pathHelper); + ForwardedHeaderResponseWrapper wrappedResponse = new ForwardedHeaderResponseWrapper(response, wrappedRequest); + filterChain.doFilter(wrappedRequest, wrappedResponse); } @@ -222,4 +229,59 @@ public class ForwardedHeaderFilter extends OncePerRequestFilter { } } + private static class ForwardedHeaderResponseWrapper extends HttpServletResponseWrapper { + private static final String FOLDER_SEPARATOR = "/"; + + private final HttpServletRequest request; + + public ForwardedHeaderResponseWrapper(HttpServletResponse response, HttpServletRequest request) { + super(response); + this.request = request; + } + + @Override + public void sendRedirect(String location) throws IOException { + String forwardedLocation = forwardedLocation(location); + + super.sendRedirect(forwardedLocation); + } + + private String forwardedLocation(String location) { + if(hasScheme(location)) { + return location; + } + + return createForwardedLocation(location); + } + + private String createForwardedLocation(String location) { + boolean isNetworkPathReference = location.startsWith("//"); + if(isNetworkPathReference) { + UriComponentsBuilder schemeForwardedLocation = UriComponentsBuilder.fromUriString(location).scheme(request.getScheme()); + return schemeForwardedLocation.toUriString(); + } + + HttpRequest httpRequest = new ServletServerHttpRequest(request); + UriComponentsBuilder forwardedLocation = UriComponentsBuilder.fromHttpRequest(httpRequest); + boolean isRelativeToContextPath = location.startsWith(FOLDER_SEPARATOR); + if(isRelativeToContextPath) { + forwardedLocation.replacePath(request.getContextPath()); + } else if(endsWithFileSpecificPart(forwardedLocation)) { + // remove a file specific part from existing request + forwardedLocation.path("/../"); + } + forwardedLocation.path(location); + return forwardedLocation.build().normalize().toUriString(); + } + + private boolean endsWithFileSpecificPart(UriComponentsBuilder forwardedLocation) { + return !forwardedLocation.build().getPath().endsWith(FOLDER_SEPARATOR); + } + + private boolean hasScheme(String location) { + String locationScheme = UriComponentsBuilder.fromUriString(location).build().getScheme(); + return locationScheme != null; + } + } + } diff --git a/spring-web/src/test/java/org/springframework/web/filter/ForwardedHeaderFilterTests.java b/spring-web/src/test/java/org/springframework/web/filter/ForwardedHeaderFilterTests.java index 56d93f8c40b2584d29b3eab3814ff7f28e9bf24e..496d41c36257c976f93cd89538f5cf9e369c7b71 100644 --- a/spring-web/src/test/java/org/springframework/web/filter/ForwardedHeaderFilterTests.java +++ b/spring-web/src/test/java/org/springframework/web/filter/ForwardedHeaderFilterTests.java @@ -17,9 +17,13 @@ package org.springframework.web.filter; import java.io.IOException; import java.util.Enumeration; + +import javax.servlet.Filter; +import javax.servlet.FilterChain; import javax.servlet.ServletException; import javax.servlet.http.HttpServlet; import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; import org.junit.Before; import org.junit.Test; @@ -37,6 +41,7 @@ import static org.junit.Assert.assertTrue; * Unit tests for {@link ForwardedHeaderFilter}. * @author Rossen Stoyanchev * @author Eddú Meléndez + * @author Rob Winch */ public class ForwardedHeaderFilterTests { @@ -222,6 +227,157 @@ public class ForwardedHeaderFilterTests { assertEquals("/prefix", actual); } + @Test + public void sendRedirectWithAbsolutePath() throws Exception { + this.request.addHeader(X_FORWARDED_PROTO, "https"); + this.request.addHeader(X_FORWARDED_HOST, "example.com"); + this.request.addHeader(X_FORWARDED_PORT, "443"); + + String redirectedUrl = sendRedirect("/foo/bar"); + assertEquals("https://example.com/foo/bar", redirectedUrl); + } + + @Test + public void sendRedirectWithContextPath() throws Exception { + this.request.addHeader(X_FORWARDED_PROTO, "https"); + this.request.addHeader(X_FORWARDED_HOST, "example.com"); + this.request.addHeader(X_FORWARDED_PORT, "443"); + this.request.setContextPath("/context"); + + String redirectedUrl = sendRedirect("/foo/bar"); + assertEquals("https://example.com/context/foo/bar", redirectedUrl); + } + + @Test + public void sendRedirectWithXForwardedPrefix() throws Exception { + this.request.addHeader(X_FORWARDED_PROTO, "https"); + this.request.addHeader(X_FORWARDED_HOST, "example.com"); + this.request.addHeader(X_FORWARDED_PORT, "443"); + this.request.addHeader(X_FORWARDED_PREFIX, "/prefix"); + + String redirectedUrl = sendRedirect("/foo/bar"); + assertEquals("https://example.com/prefix/foo/bar", redirectedUrl); + } + + @Test + public void sendRedirectWithXForwardedPrefixAndContextPath() throws Exception { + this.request.addHeader(X_FORWARDED_PROTO, "https"); + this.request.addHeader(X_FORWARDED_HOST, "example.com"); + this.request.addHeader(X_FORWARDED_PORT, "443"); + this.request.addHeader(X_FORWARDED_PREFIX, "/prefix"); + this.request.setContextPath("/context"); + + String redirectedUrl = sendRedirect("/foo/bar"); + assertEquals("https://example.com/prefix/foo/bar", redirectedUrl); + } + + @Test + public void sendRedirectWithRelativePath() throws Exception { + this.request.addHeader(X_FORWARDED_PROTO, "https"); + this.request.addHeader(X_FORWARDED_HOST, "example.com"); + this.request.addHeader(X_FORWARDED_PORT, "443"); + this.request.setRequestURI("/parent/"); + + String redirectedUrl = sendRedirect("foo/bar"); + assertEquals("https://example.com/parent/foo/bar", redirectedUrl); + } + + @Test + public void sendRedirectWithFileInPathAndRelativeRedirect() throws Exception { + this.request.addHeader(X_FORWARDED_PROTO, "https"); + this.request.addHeader(X_FORWARDED_HOST, "example.com"); + this.request.addHeader(X_FORWARDED_PORT, "443"); + this.request.setRequestURI("/context/a"); + + String redirectedUrl = sendRedirect("foo/bar"); + assertEquals("https://example.com/context/foo/bar", redirectedUrl); + } + + @Test + public void sendRedirectWithRelativePathIgnoresFile() throws Exception { + this.request.addHeader(X_FORWARDED_PROTO, "https"); + this.request.addHeader(X_FORWARDED_HOST, "example.com"); + this.request.addHeader(X_FORWARDED_PORT, "443"); + this.request.setRequestURI("/parent"); + + String redirectedUrl = sendRedirect("foo/bar"); + assertEquals("https://example.com/foo/bar", redirectedUrl); + } + + @Test + public void sendRedirectWithLocationDotDotPath() throws Exception { + this.request.addHeader(X_FORWARDED_PROTO, "https"); + this.request.addHeader(X_FORWARDED_HOST, "example.com"); + this.request.addHeader(X_FORWARDED_PORT, "443"); + + String redirectedUrl = sendRedirect("parent/../foo/bar"); + assertEquals("https://example.com/foo/bar", redirectedUrl); + } + + @Test + public void sendRedirectWithLocationHasScheme() throws Exception { + this.request.addHeader(X_FORWARDED_PROTO, "https"); + this.request.addHeader(X_FORWARDED_HOST, "example.com"); + this.request.addHeader(X_FORWARDED_PORT, "443"); + + String location = "http://other.info/foo/bar"; + String redirectedUrl = sendRedirect(location); + assertEquals(location, redirectedUrl); + } + + @Test + public void sendRedirectWithLocationSlashSlash() throws Exception { + this.request.addHeader(X_FORWARDED_PROTO, "https"); + this.request.addHeader(X_FORWARDED_HOST, "example.com"); + this.request.addHeader(X_FORWARDED_PORT, "443"); + + String location = "//other.info/foo/bar"; + String redirectedUrl = sendRedirect(location); + assertEquals("https:" + location, redirectedUrl); + } + + @Test + public void sendRedirectWithLocationSlashSlashParentDotDot() throws Exception { + this.request.addHeader(X_FORWARDED_PROTO, "https"); + this.request.addHeader(X_FORWARDED_HOST, "example.com"); + this.request.addHeader(X_FORWARDED_PORT, "443"); + + String location = "//other.info/parent/../foo/bar"; + String redirectedUrl = sendRedirect(location); + assertEquals("https:" + location, redirectedUrl); + } + + @Test + public void sendRedirectWithNoXForwardedAndAbsolutePath() throws Exception { + String redirectedUrl = sendRedirect("/foo/bar"); + assertEquals("/foo/bar", redirectedUrl); + } + + @Test + public void sendRedirectWithNoXForwardedAndDotDotPath() throws Exception { + String redirectedUrl = sendRedirect("../foo/bar"); + assertEquals("../foo/bar", redirectedUrl); + } + + private String sendRedirect(final String location) throws ServletException, IOException { + MockHttpServletResponse response = doWithFiltersAndGetResponse(this.filter, new OncePerRequestFilter() { + @Override + protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) + throws ServletException, IOException { + response.sendRedirect(location); + } + }); + + return response.getRedirectedUrl(); + } + + private MockHttpServletResponse doWithFiltersAndGetResponse(Filter... filters) throws ServletException, IOException { + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain filterChain = new MockFilterChain(new HttpServlet() {}, filters); + filterChain.doFilter(request, response); + return response; + } + private String filterAndGetContextPath() throws ServletException, IOException { return filterAndGetWrappedRequest().getContextPath(); }