diff --git a/spring-test/src/main/java/org/springframework/mock/web/MockHttpServletRequest.java b/spring-test/src/main/java/org/springframework/mock/web/MockHttpServletRequest.java index f28a3008e4762458ded372941c2465d36f420846..510ad4926b346a11475202a96188972213127323 100644 --- a/spring-test/src/main/java/org/springframework/mock/web/MockHttpServletRequest.java +++ b/spring-test/src/main/java/org/springframework/mock/web/MockHttpServletRequest.java @@ -107,6 +107,8 @@ public class MockHttpServletRequest implements HttpServletRequest { private static final String CONTENT_TYPE_HEADER = "Content-Type"; + private static final String HOST_HEADER = "Host"; + private static final String CHARSET_PREFIX = "charset="; private static final ServletInputStream EMPTY_SERVLET_INPUT_STREAM = @@ -544,6 +546,19 @@ public class MockHttpServletRequest implements HttpServletRequest { @Override public String getServerName() { + String host = getHeader(HOST_HEADER); + if (host != null) { + host = host.trim(); + if (host.startsWith("[")) { + host = host.substring(1, host.indexOf(']')); + } + else if (host.contains(":")) { + host = host.substring(0, host.indexOf(':')); + } + return host; + } + + // else return this.serverName; } @@ -553,6 +568,22 @@ public class MockHttpServletRequest implements HttpServletRequest { @Override public int getServerPort() { + String host = getHeader(HOST_HEADER); + if (host != null) { + host = host.trim(); + int idx; + if (host.startsWith("[")) { + idx = host.indexOf(':', host.indexOf(']')); + } + else { + idx = host.indexOf(':'); + } + if (idx != -1) { + return Integer.parseInt(host.substring(idx + 1)); + } + } + + // else return this.serverPort; } diff --git a/spring-test/src/test/java/org/springframework/mock/web/MockHttpServletRequestTests.java b/spring-test/src/test/java/org/springframework/mock/web/MockHttpServletRequestTests.java index c6cfe46c16edc27b39f0032111a589ec3d2d272e..633a26795003d6672776ca3386b0a5a934034103 100644 --- a/spring-test/src/test/java/org/springframework/mock/web/MockHttpServletRequestTests.java +++ b/spring-test/src/test/java/org/springframework/mock/web/MockHttpServletRequestTests.java @@ -42,6 +42,8 @@ import static org.junit.Assert.*; */ public class MockHttpServletRequestTests { + private static final String HOST = "Host"; + private MockHttpServletRequest request = new MockHttpServletRequest(); @@ -212,6 +214,86 @@ public class MockHttpServletRequestTests { assertEqualEnumerations(Collections.enumeration(preferredLocales), request.getLocales()); } + @Test + public void getServerNameWithDefaultName() { + assertEquals("localhost", request.getServerName()); + } + + @Test + public void getServerNameWithCustomName() { + request.setServerName("example.com"); + assertEquals("example.com", request.getServerName()); + } + + @Test + public void getServerNameViaHostHeaderWithoutPort() { + String testServer = "test.server"; + request.addHeader(HOST, testServer); + assertEquals(testServer, request.getServerName()); + } + + @Test + public void getServerNameViaHostHeaderWithPort() { + String testServer = "test.server"; + request.addHeader(HOST, testServer + ":8080"); + assertEquals(testServer, request.getServerName()); + } + + @Test + public void getServerNameViaHostHeaderAsIpv6AddressWithoutPort() { + String ipv6Address = "[2001:db8:0:1]"; + request.addHeader(HOST, ipv6Address); + assertEquals("2001:db8:0:1", request.getServerName()); + } + + @Test + public void getServerNameViaHostHeaderAsIpv6AddressWithPort() { + String ipv6Address = "[2001:db8:0:1]:8081"; + request.addHeader(HOST, ipv6Address); + assertEquals("2001:db8:0:1", request.getServerName()); + } + + @Test + public void getServerPortWithDefaultPort() { + assertEquals(80, request.getServerPort()); + } + + @Test + public void getServerPortWithCustomPort() { + request.setServerPort(8080); + assertEquals(8080, request.getServerPort()); + } + + @Test + public void getServerPortViaHostHeaderAsIpv6AddressWithoutPort() { + String testServer = "[2001:db8:0:1]"; + request.addHeader(HOST, testServer); + assertEquals(80, request.getServerPort()); + } + + @Test + public void getServerPortViaHostHeaderAsIpv6AddressWithPort() { + String testServer = "[2001:db8:0:1]"; + int testPort = 9999; + request.addHeader(HOST, testServer + ":" + testPort); + assertEquals(testPort, request.getServerPort()); + } + + @Test + public void getServerPortViaHostHeaderWithoutPort() { + String testServer = "test.server"; + request.addHeader(HOST, testServer); + assertEquals(80, request.getServerPort()); + } + + @Test + public void getServerPortViaHostHeaderWithPort() { + String testServer = "test.server"; + int testPort = 9999; + request.addHeader(HOST, testServer + ":" + testPort); + assertEquals(testPort, request.getServerPort()); + } + @Test public void getRequestURL() { request.setServerPort(8080);