diff --git a/spring-web/src/main/java/org/springframework/http/client/SimpleBufferingClientHttpRequest.java b/spring-web/src/main/java/org/springframework/http/client/SimpleBufferingClientHttpRequest.java index 52c3eb60e81eac0039f8613a85fdc28613647b14..394a83076354e186ffdd63a48a5e23144a6258fd 100644 --- a/spring-web/src/main/java/org/springframework/http/client/SimpleBufferingClientHttpRequest.java +++ b/spring-web/src/main/java/org/springframework/http/client/SimpleBufferingClientHttpRequest.java @@ -91,6 +91,14 @@ final class SimpleBufferingClientHttpRequest extends AbstractBufferingClientHttp * @param headers the headers to add */ static void addHeaders(HttpURLConnection connection, HttpHeaders headers) { + String method = connection.getRequestMethod(); + if (method.equals("PUT") || method.equals("DELETE")) { + if (!StringUtils.hasText(headers.getFirst(HttpHeaders.ACCEPT))) { + // Avoid "text/html, image/gif, image/jpeg, *; q=.2, */*; q=.2" + // from HttpUrlConnection which prevents JSON error response details. + headers.set(HttpHeaders.ACCEPT, "*/*"); + } + } headers.forEach((headerName, headerValues) -> { if (HttpHeaders.COOKIE.equalsIgnoreCase(headerName)) { // RFC 6265 String headerValue = StringUtils.collectionToDelimitedString(headerValues, "; "); diff --git a/spring-web/src/main/java/org/springframework/web/client/RestTemplate.java b/spring-web/src/main/java/org/springframework/web/client/RestTemplate.java index 4e21a9f9e06174dbb8cf63974e156411b83e8eb7..9ccd5e4dd7317c691ac7377ab304bd8aa704bebc 100644 --- a/spring-web/src/main/java/org/springframework/web/client/RestTemplate.java +++ b/spring-web/src/main/java/org/springframework/web/client/RestTemplate.java @@ -739,9 +739,6 @@ public class RestTemplate extends InterceptingHttpAccessor implements RestOperat if (requestCallback != null) { requestCallback.doWithRequest(request); } - if ((method == HttpMethod.DELETE || method == HttpMethod.PUT) && request.getHeaders().getAccept().isEmpty()) { - request.getHeaders().add("Accept", "*/*"); - } response = request.execute(); handleResponse(url, method, response); return (responseExtractor != null ? responseExtractor.extractData(response) : null); diff --git a/spring-web/src/test/java/org/springframework/http/client/SimpleClientHttpRequestFactoryTests.java b/spring-web/src/test/java/org/springframework/http/client/SimpleClientHttpRequestFactoryTests.java index f698b01728b0b76c8e0009913d7a841ecb0f8d6e..df8f9a58893ff06412309670e0f84c331cf8c2f7 100644 --- a/spring-web/src/test/java/org/springframework/http/client/SimpleClientHttpRequestFactoryTests.java +++ b/spring-web/src/test/java/org/springframework/http/client/SimpleClientHttpRequestFactoryTests.java @@ -22,6 +22,7 @@ import org.junit.jupiter.api.Test; import org.springframework.http.HttpHeaders; +import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -31,9 +32,11 @@ import static org.mockito.Mockito.verify; */ public class SimpleClientHttpRequestFactoryTests { - @Test // SPR-13225 + + @Test // SPR-13225 public void headerWithNullValue() { HttpURLConnection urlConnection = mock(HttpURLConnection.class); + given(urlConnection.getRequestMethod()).willReturn("GET"); HttpHeaders headers = new HttpHeaders(); headers.set("foo", null); SimpleBufferingClientHttpRequest.addHeaders(urlConnection, headers); diff --git a/spring-web/src/test/java/org/springframework/web/client/RestTemplateTests.java b/spring-web/src/test/java/org/springframework/web/client/RestTemplateTests.java index f899e6802ae4ba31231b970636eb7aa5b57f8366..3877572704d447f337daa500047a23978c6af0b5 100644 --- a/spring-web/src/test/java/org/springframework/web/client/RestTemplateTests.java +++ b/spring-web/src/test/java/org/springframework/web/client/RestTemplateTests.java @@ -492,45 +492,47 @@ public class RestTemplateTests { verify(response).close(); } - @Test + @Test // gh-23740 public void headerAcceptAllOnPut() throws Exception { MockWebServer server = new MockWebServer(); server.enqueue(new MockResponse().setResponseCode(500).setBody("internal server error")); server.start(); - template.setRequestFactory(new SimpleClientHttpRequestFactory()); - - template.put(server.url("/internal/server/error").uri(), null); - - RecordedRequest request = server.takeRequest(); - assertThat(request.getHeader("Accept")).isEqualTo("*/*"); - - server.shutdown(); + try { + template.setRequestFactory(new SimpleClientHttpRequestFactory()); + template.put(server.url("/internal/server/error").uri(), null); + assertThat(server.takeRequest().getHeader("Accept")).isEqualTo("*/*"); + } + finally { + server.shutdown(); + } } - @Test + @Test // gh-23740 public void keepGivenAcceptHeaderOnPut() throws Exception { MockWebServer server = new MockWebServer(); server.enqueue(new MockResponse().setResponseCode(500).setBody("internal server error")); server.start(); + try { + HttpHeaders headers = new HttpHeaders(); + headers.setAccept(Collections.singletonList(MediaType.APPLICATION_JSON)); + HttpEntity entity = new HttpEntity<>(null, headers); + template.setRequestFactory(new SimpleClientHttpRequestFactory()); + template.exchange(server.url("/internal/server/error").uri(), PUT, entity, Void.class); - template.setRequestFactory(new SimpleClientHttpRequestFactory()); - HttpHeaders headers = new HttpHeaders(); - headers.setAccept(Collections.singletonList(MediaType.APPLICATION_JSON)); - HttpEntity entity = new HttpEntity<>(null, headers); - template.exchange(server.url("/internal/server/error").uri(), PUT, entity, Void.class); - - RecordedRequest request = server.takeRequest(); + RecordedRequest request = server.takeRequest(); final List> accepts = request.getHeaders().toMultimap().entrySet().stream() - .filter(entry -> entry.getKey().equalsIgnoreCase("accept")) - .map(Entry::getValue) - .collect(Collectors.toList()); + .filter(entry -> entry.getKey().equalsIgnoreCase("accept")) + .map(Entry::getValue) + .collect(Collectors.toList()); - assertThat(accepts).hasSize(1); - assertThat(accepts.get(0)).hasSize(1); - assertThat(accepts.get(0).get(0)).isEqualTo("application/json"); - - server.shutdown(); + assertThat(accepts).hasSize(1); + assertThat(accepts.get(0)).hasSize(1); + assertThat(accepts.get(0).get(0)).isEqualTo("application/json"); + } + finally { + server.shutdown(); + } } @Test @@ -579,19 +581,19 @@ public class RestTemplateTests { verify(response).close(); } - @Test + @Test // gh-23740 public void headerAcceptAllOnDelete() throws Exception { MockWebServer server = new MockWebServer(); server.enqueue(new MockResponse().setResponseCode(500).setBody("internal server error")); server.start(); - template.setRequestFactory(new SimpleClientHttpRequestFactory()); - - template.delete(server.url("/internal/server/error").uri()); - - RecordedRequest request = server.takeRequest(); - assertThat(request.getHeader("Accept")).isEqualTo("*/*"); - - server.shutdown(); + try { + template.setRequestFactory(new SimpleClientHttpRequestFactory()); + template.delete(server.url("/internal/server/error").uri()); + assertThat(server.takeRequest().getHeader("Accept")).isEqualTo("*/*"); + } + finally { + server.shutdown(); + } } @Test