提交 919f6c96 编写于 作者: R Rossen Stoyanchev

ForwardedHeaderFilter is case-insensitive

Issue: SPR-14372
上级 981a748d
...@@ -19,9 +19,8 @@ package org.springframework.web.filter; ...@@ -19,9 +19,8 @@ package org.springframework.web.filter;
import java.io.IOException; import java.io.IOException;
import java.util.Collections; import java.util.Collections;
import java.util.Enumeration; import java.util.Enumeration;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.List; import java.util.List;
import java.util.Locale;
import java.util.Map; import java.util.Map;
import java.util.Set; import java.util.Set;
import javax.servlet.FilterChain; import javax.servlet.FilterChain;
...@@ -33,6 +32,7 @@ import javax.servlet.http.HttpServletResponse; ...@@ -33,6 +32,7 @@ import javax.servlet.http.HttpServletResponse;
import org.springframework.http.HttpRequest; import org.springframework.http.HttpRequest;
import org.springframework.http.server.ServletServerHttpRequest; import org.springframework.http.server.ServletServerHttpRequest;
import org.springframework.util.CollectionUtils; import org.springframework.util.CollectionUtils;
import org.springframework.util.LinkedCaseInsensitiveMap;
import org.springframework.web.util.UriComponents; import org.springframework.web.util.UriComponents;
import org.springframework.web.util.UriComponentsBuilder; import org.springframework.web.util.UriComponentsBuilder;
import org.springframework.web.util.UrlPathHelper; import org.springframework.web.util.UrlPathHelper;
...@@ -52,10 +52,10 @@ import org.springframework.web.util.UrlPathHelper; ...@@ -52,10 +52,10 @@ import org.springframework.web.util.UrlPathHelper;
*/ */
public class ForwardedHeaderFilter extends OncePerRequestFilter { public class ForwardedHeaderFilter extends OncePerRequestFilter {
private static final Set<String> FORWARDED_HEADER_NAMES; private static final Set<String> FORWARDED_HEADER_NAMES =
Collections.newSetFromMap(new LinkedCaseInsensitiveMap<Boolean>(5, Locale.ENGLISH));
static { static {
FORWARDED_HEADER_NAMES = new HashSet<String>(5);
FORWARDED_HEADER_NAMES.add("Forwarded"); FORWARDED_HEADER_NAMES.add("Forwarded");
FORWARDED_HEADER_NAMES.add("X-Forwarded-Host"); FORWARDED_HEADER_NAMES.add("X-Forwarded-Host");
FORWARDED_HEADER_NAMES.add("X-Forwarded-Port"); FORWARDED_HEADER_NAMES.add("X-Forwarded-Port");
...@@ -69,9 +69,9 @@ public class ForwardedHeaderFilter extends OncePerRequestFilter { ...@@ -69,9 +69,9 @@ public class ForwardedHeaderFilter extends OncePerRequestFilter {
@Override @Override
protected boolean shouldNotFilter(HttpServletRequest request) throws ServletException { protected boolean shouldNotFilter(HttpServletRequest request) throws ServletException {
Enumeration<String> headerNames = request.getHeaderNames(); Enumeration<String> names = request.getHeaderNames();
while (headerNames.hasMoreElements()) { while (names.hasMoreElements()) {
String name = headerNames.nextElement(); String name = names.nextElement();
if (FORWARDED_HEADER_NAMES.contains(name)) { if (FORWARDED_HEADER_NAMES.contains(name)) {
return false; return false;
} }
...@@ -136,27 +136,33 @@ public class ForwardedHeaderFilter extends OncePerRequestFilter { ...@@ -136,27 +136,33 @@ public class ForwardedHeaderFilter extends OncePerRequestFilter {
} }
private static String getForwardedPrefix(HttpServletRequest request) { private static String getForwardedPrefix(HttpServletRequest request) {
String header = request.getHeader("X-Forwarded-Prefix"); String prefix = null;
if (header != null) { Enumeration<String> names = request.getHeaderNames();
while (header.endsWith("/")) { while (names.hasMoreElements()) {
header = header.substring(0, header.length() - 1); String name = names.nextElement();
if ("X-Forwarded-Prefix".equalsIgnoreCase(name)) {
prefix = request.getHeader(name);
} }
} }
return header; if (prefix != null) {
while (prefix.endsWith("/")) {
prefix = prefix.substring(0, prefix.length() - 1);
}
}
return prefix;
} }
/** /**
* Copy the headers excluding any {@link #FORWARDED_HEADER_NAMES}. * Copy the headers excluding any {@link #FORWARDED_HEADER_NAMES}.
*/ */
private static Map<String, List<String>> initHeaders(HttpServletRequest request) { private static Map<String, List<String>> initHeaders(HttpServletRequest request) {
Map<String, List<String>> headers = new LinkedHashMap<String, List<String>>(); Map<String, List<String>> headers = new LinkedCaseInsensitiveMap<List<String>>(Locale.ENGLISH);
Enumeration<String> headerNames = request.getHeaderNames(); Enumeration<String> names = request.getHeaderNames();
while (headerNames.hasMoreElements()) { while (names.hasMoreElements()) {
String name = headerNames.nextElement(); String name = names.nextElement();
if (!FORWARDED_HEADER_NAMES.contains(name)) {
headers.put(name, Collections.list(request.getHeaders(name))); headers.put(name, Collections.list(request.getHeaders(name)));
} }
for (String name : FORWARDED_HEADER_NAMES) {
headers.remove(name);
} }
return headers; return headers;
} }
......
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
package org.springframework.web.filter; package org.springframework.web.filter;
import java.io.IOException; import java.io.IOException;
import java.util.Enumeration;
import javax.servlet.ServletException; import javax.servlet.ServletException;
import javax.servlet.http.HttpServlet; import javax.servlet.http.HttpServlet;
import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletRequest;
...@@ -39,6 +40,12 @@ import static org.junit.Assert.assertTrue; ...@@ -39,6 +40,12 @@ import static org.junit.Assert.assertTrue;
*/ */
public class ForwardedHeaderFilterTests { public class ForwardedHeaderFilterTests {
private static final String X_FORWARDED_PROTO = "x-forwarded-proto"; // SPR-14372 (case insensitive)
private static final String X_FORWARDED_HOST = "x-forwarded-host";
private static final String X_FORWARDED_PORT = "x-forwarded-port";
private static final String X_FORWARDED_PREFIX = "x-forwarded-prefix";
private final ForwardedHeaderFilter filter = new ForwardedHeaderFilter(); private final ForwardedHeaderFilter filter = new ForwardedHeaderFilter();
private MockHttpServletRequest request; private MockHttpServletRequest request;
...@@ -59,25 +66,25 @@ public class ForwardedHeaderFilterTests { ...@@ -59,25 +66,25 @@ public class ForwardedHeaderFilterTests {
@Test @Test
public void contextPathEmpty() throws Exception { public void contextPathEmpty() throws Exception {
this.request.addHeader("X-Forwarded-Prefix", ""); this.request.addHeader(X_FORWARDED_PREFIX, "");
assertEquals("", filterAndGetContextPath()); assertEquals("", filterAndGetContextPath());
} }
@Test @Test
public void contextPathWithTrailingSlash() throws Exception { public void contextPathWithTrailingSlash() throws Exception {
this.request.addHeader("X-Forwarded-Prefix", "/foo/bar/"); this.request.addHeader(X_FORWARDED_PREFIX, "/foo/bar/");
assertEquals("/foo/bar", filterAndGetContextPath()); assertEquals("/foo/bar", filterAndGetContextPath());
} }
@Test @Test
public void contextPathWithTrailingSlashes() throws Exception { public void contextPathWithTrailingSlashes() throws Exception {
this.request.addHeader("X-Forwarded-Prefix", "/foo/bar/baz///"); this.request.addHeader(X_FORWARDED_PREFIX, "/foo/bar/baz///");
assertEquals("/foo/bar/baz", filterAndGetContextPath()); assertEquals("/foo/bar/baz", filterAndGetContextPath());
} }
@Test @Test
public void requestUri() throws Exception { public void requestUri() throws Exception {
this.request.addHeader("X-Forwarded-Prefix", "/"); this.request.addHeader(X_FORWARDED_PREFIX, "/");
this.request.setContextPath("/app"); this.request.setContextPath("/app");
this.request.setRequestURI("/app/path"); this.request.setRequestURI("/app/path");
HttpServletRequest actual = filterAndGetWrappedRequest(); HttpServletRequest actual = filterAndGetWrappedRequest();
...@@ -88,7 +95,7 @@ public class ForwardedHeaderFilterTests { ...@@ -88,7 +95,7 @@ public class ForwardedHeaderFilterTests {
@Test @Test
public void requestUriWithTrailingSlash() throws Exception { public void requestUriWithTrailingSlash() throws Exception {
this.request.addHeader("X-Forwarded-Prefix", "/"); this.request.addHeader(X_FORWARDED_PREFIX, "/");
this.request.setContextPath("/app"); this.request.setContextPath("/app");
this.request.setRequestURI("/app/path/"); this.request.setRequestURI("/app/path/");
HttpServletRequest actual = filterAndGetWrappedRequest(); HttpServletRequest actual = filterAndGetWrappedRequest();
...@@ -98,7 +105,7 @@ public class ForwardedHeaderFilterTests { ...@@ -98,7 +105,7 @@ public class ForwardedHeaderFilterTests {
} }
@Test @Test
public void requestUriEqualsContextPath() throws Exception { public void requestUriEqualsContextPath() throws Exception {
this.request.addHeader("X-Forwarded-Prefix", "/"); this.request.addHeader(X_FORWARDED_PREFIX, "/");
this.request.setContextPath("/app"); this.request.setContextPath("/app");
this.request.setRequestURI("/app"); this.request.setRequestURI("/app");
HttpServletRequest actual = filterAndGetWrappedRequest(); HttpServletRequest actual = filterAndGetWrappedRequest();
...@@ -109,7 +116,7 @@ public class ForwardedHeaderFilterTests { ...@@ -109,7 +116,7 @@ public class ForwardedHeaderFilterTests {
@Test @Test
public void requestUriRootUrl() throws Exception { public void requestUriRootUrl() throws Exception {
this.request.addHeader("X-Forwarded-Prefix", "/"); this.request.addHeader(X_FORWARDED_PREFIX, "/");
this.request.setContextPath("/app"); this.request.setContextPath("/app");
this.request.setRequestURI("/app/"); this.request.setRequestURI("/app/");
HttpServletRequest actual = filterAndGetWrappedRequest(); HttpServletRequest actual = filterAndGetWrappedRequest();
...@@ -118,12 +125,37 @@ public class ForwardedHeaderFilterTests { ...@@ -118,12 +125,37 @@ public class ForwardedHeaderFilterTests {
assertEquals("/", actual.getRequestURI()); assertEquals("/", actual.getRequestURI());
} }
@Test
public void caseInsensitiveForwardedPrefix() throws Exception {
this.request = new MockHttpServletRequest() {
// Make it case-sensitive (SPR-14372)
@Override
public String getHeader(String header) {
Enumeration<String> names = getHeaderNames();
while (names.hasMoreElements()) {
String name = names.nextElement();
if (name.equals(header)) {
return super.getHeader(header);
}
}
return null;
}
};
this.request.addHeader(X_FORWARDED_PREFIX, "/prefix");
this.request.setRequestURI("/path");
HttpServletRequest actual = filterAndGetWrappedRequest();
assertEquals("/prefix/path", actual.getRequestURI());
}
@Test @Test
public void shouldFilter() throws Exception { public void shouldFilter() throws Exception {
testShouldFilter("Forwarded"); testShouldFilter("Forwarded");
testShouldFilter("X-Forwarded-Host"); testShouldFilter(X_FORWARDED_HOST);
testShouldFilter("X-Forwarded-Port"); testShouldFilter(X_FORWARDED_PORT);
testShouldFilter("X-Forwarded-Proto"); testShouldFilter(X_FORWARDED_PROTO);
} }
@Test @Test
...@@ -134,9 +166,9 @@ public class ForwardedHeaderFilterTests { ...@@ -134,9 +166,9 @@ public class ForwardedHeaderFilterTests {
@Test @Test
public void forwardedRequest() throws Exception { public void forwardedRequest() throws Exception {
this.request.setRequestURI("/mvc-showcase"); this.request.setRequestURI("/mvc-showcase");
this.request.addHeader("X-Forwarded-Proto", "https"); this.request.addHeader(X_FORWARDED_PROTO, "https");
this.request.addHeader("X-Forwarded-Host", "84.198.58.199"); this.request.addHeader(X_FORWARDED_HOST, "84.198.58.199");
this.request.addHeader("X-Forwarded-Port", "443"); this.request.addHeader(X_FORWARDED_PORT, "443");
this.request.addHeader("foo", "bar"); this.request.addHeader("foo", "bar");
this.filter.doFilter(this.request, new MockHttpServletResponse(), this.filterChain); this.filter.doFilter(this.request, new MockHttpServletResponse(), this.filterChain);
...@@ -148,15 +180,15 @@ public class ForwardedHeaderFilterTests { ...@@ -148,15 +180,15 @@ public class ForwardedHeaderFilterTests {
assertEquals(443, actual.getServerPort()); assertEquals(443, actual.getServerPort());
assertTrue(actual.isSecure()); assertTrue(actual.isSecure());
assertNull(actual.getHeader("X-Forwarded-Proto")); assertNull(actual.getHeader(X_FORWARDED_PROTO));
assertNull(actual.getHeader("X-Forwarded-Host")); assertNull(actual.getHeader(X_FORWARDED_HOST));
assertNull(actual.getHeader("X-Forwarded-Port")); assertNull(actual.getHeader(X_FORWARDED_PORT));
assertEquals("bar", actual.getHeader("foo")); assertEquals("bar", actual.getHeader("foo"));
} }
@Test @Test
public void requestUriWithForwardedPrefix() throws Exception { public void requestUriWithForwardedPrefix() throws Exception {
this.request.addHeader("X-Forwarded-Prefix", "/prefix"); this.request.addHeader(X_FORWARDED_PREFIX, "/prefix");
this.request.setRequestURI("/mvc-showcase"); this.request.setRequestURI("/mvc-showcase");
HttpServletRequest actual = filterAndGetWrappedRequest(); HttpServletRequest actual = filterAndGetWrappedRequest();
...@@ -165,7 +197,7 @@ public class ForwardedHeaderFilterTests { ...@@ -165,7 +197,7 @@ public class ForwardedHeaderFilterTests {
@Test @Test
public void requestUriWithForwardedPrefixTrailingSlash() throws Exception { public void requestUriWithForwardedPrefixTrailingSlash() throws Exception {
this.request.addHeader("X-Forwarded-Prefix", "/prefix/"); this.request.addHeader(X_FORWARDED_PREFIX, "/prefix/");
this.request.setRequestURI("/mvc-showcase"); this.request.setRequestURI("/mvc-showcase");
HttpServletRequest actual = filterAndGetWrappedRequest(); HttpServletRequest actual = filterAndGetWrappedRequest();
...@@ -174,7 +206,7 @@ public class ForwardedHeaderFilterTests { ...@@ -174,7 +206,7 @@ public class ForwardedHeaderFilterTests {
@Test @Test
public void contextPathWithForwardedPrefix() throws Exception { public void contextPathWithForwardedPrefix() throws Exception {
this.request.addHeader("X-Forwarded-Prefix", "/prefix"); this.request.addHeader(X_FORWARDED_PREFIX, "/prefix");
this.request.setContextPath("/mvc-showcase"); this.request.setContextPath("/mvc-showcase");
String actual = filterAndGetContextPath(); String actual = filterAndGetContextPath();
...@@ -183,7 +215,7 @@ public class ForwardedHeaderFilterTests { ...@@ -183,7 +215,7 @@ public class ForwardedHeaderFilterTests {
@Test @Test
public void contextPathWithForwardedPrefixTrailingSlash() throws Exception { public void contextPathWithForwardedPrefixTrailingSlash() throws Exception {
this.request.addHeader("X-Forwarded-Prefix", "/prefix/"); this.request.addHeader(X_FORWARDED_PREFIX, "/prefix/");
this.request.setContextPath("/mvc-showcase"); this.request.setContextPath("/mvc-showcase");
String actual = filterAndGetContextPath(); String actual = filterAndGetContextPath();
......
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
package org.springframework.web.servlet.support; package org.springframework.web.servlet.support;
import java.util.Enumeration;
import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletRequest;
import org.springframework.http.HttpRequest; import org.springframework.http.HttpRequest;
...@@ -133,8 +134,15 @@ public class ServletUriComponentsBuilder extends UriComponentsBuilder { ...@@ -133,8 +134,15 @@ public class ServletUriComponentsBuilder extends UriComponentsBuilder {
} }
private static String prependForwardedPrefix(HttpServletRequest request, String path) { private static String prependForwardedPrefix(HttpServletRequest request, String path) {
String prefix = request.getHeader("X-Forwarded-Prefix"); String prefix = null;
if (StringUtils.hasText(prefix)) { Enumeration<String> names = request.getHeaderNames();
while (names.hasMoreElements()) {
String name = names.nextElement();
if ("X-Forwarded-Prefix".equalsIgnoreCase(name)) {
prefix = request.getHeader(name);
}
}
if (prefix != null) {
path = prefix + path; path = prefix + path;
} }
return path; return path;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册