提交 d02e4fb5 编写于 作者: S sdeleuze

Add Vary:Access-Control-Request-Method/Headers CORS headers

This commit adds these 2 Vary headers in addition to the existing
Origin one to avoid caching of Access-Control-Request-Method and
Access-Control-Request-Headers headers which can be an issue
when allowed methods or headers are unbounded and only the
requested method or headers are returned in the response.

Issue: SPR-16413
上级 857a5b03
/* /*
* Copyright 2002-201/ the original author or authors. * Copyright 2002-2018 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
...@@ -19,6 +19,7 @@ package org.springframework.web.cors; ...@@ -19,6 +19,7 @@ package org.springframework.web.cors;
import java.io.IOException; import java.io.IOException;
import java.nio.charset.StandardCharsets; import java.nio.charset.StandardCharsets;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays;
import java.util.List; import java.util.List;
import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse; import javax.servlet.http.HttpServletResponse;
...@@ -121,7 +122,8 @@ public class DefaultCorsProcessor implements CorsProcessor { ...@@ -121,7 +122,8 @@ public class DefaultCorsProcessor implements CorsProcessor {
String allowOrigin = checkOrigin(config, requestOrigin); String allowOrigin = checkOrigin(config, requestOrigin);
HttpHeaders responseHeaders = response.getHeaders(); HttpHeaders responseHeaders = response.getHeaders();
responseHeaders.add(HttpHeaders.VARY, HttpHeaders.ORIGIN); responseHeaders.addAll(HttpHeaders.VARY, Arrays.asList(HttpHeaders.ORIGIN,
HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, HttpHeaders.ACCESS_CONTROL_REQUEST_HEADERS));
if (allowOrigin == null) { if (allowOrigin == null) {
logger.debug("Rejecting CORS request because '" + requestOrigin + "' origin is not allowed"); logger.debug("Rejecting CORS request because '" + requestOrigin + "' origin is not allowed");
......
/* /*
* Copyright 2002-2017 the original author or authors. * Copyright 2002-2018 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
package org.springframework.web.cors.reactive; package org.springframework.web.cors.reactive;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays;
import java.util.List; import java.util.List;
import org.apache.commons.logging.Log; import org.apache.commons.logging.Log;
...@@ -107,7 +108,8 @@ public class DefaultCorsProcessor implements CorsProcessor { ...@@ -107,7 +108,8 @@ public class DefaultCorsProcessor implements CorsProcessor {
ServerHttpResponse response = exchange.getResponse(); ServerHttpResponse response = exchange.getResponse();
HttpHeaders responseHeaders = response.getHeaders(); HttpHeaders responseHeaders = response.getHeaders();
response.getHeaders().add(HttpHeaders.VARY, HttpHeaders.ORIGIN); response.getHeaders().addAll(HttpHeaders.VARY, Arrays.asList(HttpHeaders.ORIGIN,
HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, HttpHeaders.ACCESS_CONTROL_REQUEST_HEADERS));
String requestOrigin = request.getHeaders().getOrigin(); String requestOrigin = request.getHeaders().getOrigin();
String allowOrigin = checkOrigin(config, requestOrigin); String allowOrigin = checkOrigin(config, requestOrigin);
......
/* /*
* Copyright 2002-2016 the original author or authors. * Copyright 2002-2018 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
...@@ -26,6 +26,7 @@ import org.springframework.http.HttpMethod; ...@@ -26,6 +26,7 @@ import org.springframework.http.HttpMethod;
import org.springframework.mock.web.test.MockHttpServletRequest; import org.springframework.mock.web.test.MockHttpServletRequest;
import org.springframework.mock.web.test.MockHttpServletResponse; import org.springframework.mock.web.test.MockHttpServletResponse;
import static org.hamcrest.Matchers.contains;
import static org.junit.Assert.*; import static org.junit.Assert.*;
/** /**
...@@ -65,7 +66,8 @@ public class DefaultCorsProcessorTests { ...@@ -65,7 +66,8 @@ public class DefaultCorsProcessorTests {
this.processor.processRequest(this.conf, this.request, this.response); this.processor.processRequest(this.conf, this.request, this.response);
assertFalse(this.response.containsHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN)); assertFalse(this.response.containsHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN));
assertEquals(HttpHeaders.ORIGIN, this.response.getHeader(HttpHeaders.VARY)); assertThat(this.response.getHeaders(HttpHeaders.VARY), contains(HttpHeaders.ORIGIN,
HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, HttpHeaders.ACCESS_CONTROL_REQUEST_HEADERS));
assertEquals(HttpServletResponse.SC_FORBIDDEN, this.response.getStatus()); assertEquals(HttpServletResponse.SC_FORBIDDEN, this.response.getStatus());
} }
...@@ -90,7 +92,8 @@ public class DefaultCorsProcessorTests { ...@@ -90,7 +92,8 @@ public class DefaultCorsProcessorTests {
assertEquals("*", this.response.getHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN)); assertEquals("*", this.response.getHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN));
assertFalse(this.response.containsHeader(HttpHeaders.ACCESS_CONTROL_MAX_AGE)); assertFalse(this.response.containsHeader(HttpHeaders.ACCESS_CONTROL_MAX_AGE));
assertFalse(this.response.containsHeader(HttpHeaders.ACCESS_CONTROL_EXPOSE_HEADERS)); assertFalse(this.response.containsHeader(HttpHeaders.ACCESS_CONTROL_EXPOSE_HEADERS));
assertEquals(HttpHeaders.ORIGIN, this.response.getHeader(HttpHeaders.VARY)); assertThat(this.response.getHeaders(HttpHeaders.VARY), contains(HttpHeaders.ORIGIN,
HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, HttpHeaders.ACCESS_CONTROL_REQUEST_HEADERS));
assertEquals(HttpServletResponse.SC_OK, this.response.getStatus()); assertEquals(HttpServletResponse.SC_OK, this.response.getStatus());
} }
...@@ -108,7 +111,8 @@ public class DefaultCorsProcessorTests { ...@@ -108,7 +111,8 @@ public class DefaultCorsProcessorTests {
assertEquals("http://domain2.com", this.response.getHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN)); assertEquals("http://domain2.com", this.response.getHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN));
assertTrue(this.response.containsHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_CREDENTIALS)); assertTrue(this.response.containsHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_CREDENTIALS));
assertEquals("true", this.response.getHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_CREDENTIALS)); assertEquals("true", this.response.getHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_CREDENTIALS));
assertEquals(HttpHeaders.ORIGIN, this.response.getHeader(HttpHeaders.VARY)); assertThat(this.response.getHeaders(HttpHeaders.VARY), contains(HttpHeaders.ORIGIN,
HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, HttpHeaders.ACCESS_CONTROL_REQUEST_HEADERS));
assertEquals(HttpServletResponse.SC_OK, this.response.getStatus()); assertEquals(HttpServletResponse.SC_OK, this.response.getStatus());
} }
...@@ -124,7 +128,8 @@ public class DefaultCorsProcessorTests { ...@@ -124,7 +128,8 @@ public class DefaultCorsProcessorTests {
assertEquals("http://domain2.com", this.response.getHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN)); assertEquals("http://domain2.com", this.response.getHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN));
assertTrue(this.response.containsHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_CREDENTIALS)); assertTrue(this.response.containsHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_CREDENTIALS));
assertEquals("true", this.response.getHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_CREDENTIALS)); assertEquals("true", this.response.getHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_CREDENTIALS));
assertEquals(HttpHeaders.ORIGIN, this.response.getHeader(HttpHeaders.VARY)); assertThat(this.response.getHeaders(HttpHeaders.VARY), contains(HttpHeaders.ORIGIN,
HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, HttpHeaders.ACCESS_CONTROL_REQUEST_HEADERS));
assertEquals(HttpServletResponse.SC_OK, this.response.getStatus()); assertEquals(HttpServletResponse.SC_OK, this.response.getStatus());
} }
...@@ -136,7 +141,8 @@ public class DefaultCorsProcessorTests { ...@@ -136,7 +141,8 @@ public class DefaultCorsProcessorTests {
this.processor.processRequest(this.conf, this.request, this.response); this.processor.processRequest(this.conf, this.request, this.response);
assertTrue(this.response.containsHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN)); assertTrue(this.response.containsHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN));
assertEquals(HttpHeaders.ORIGIN, this.response.getHeader(HttpHeaders.VARY)); assertThat(this.response.getHeaders(HttpHeaders.VARY), contains(HttpHeaders.ORIGIN,
HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, HttpHeaders.ACCESS_CONTROL_REQUEST_HEADERS));
assertEquals(HttpServletResponse.SC_OK, this.response.getStatus()); assertEquals(HttpServletResponse.SC_OK, this.response.getStatus());
} }
...@@ -154,7 +160,8 @@ public class DefaultCorsProcessorTests { ...@@ -154,7 +160,8 @@ public class DefaultCorsProcessorTests {
assertTrue(this.response.containsHeader(HttpHeaders.ACCESS_CONTROL_EXPOSE_HEADERS)); assertTrue(this.response.containsHeader(HttpHeaders.ACCESS_CONTROL_EXPOSE_HEADERS));
assertTrue(this.response.getHeader(HttpHeaders.ACCESS_CONTROL_EXPOSE_HEADERS).contains("header1")); assertTrue(this.response.getHeader(HttpHeaders.ACCESS_CONTROL_EXPOSE_HEADERS).contains("header1"));
assertTrue(this.response.getHeader(HttpHeaders.ACCESS_CONTROL_EXPOSE_HEADERS).contains("header2")); assertTrue(this.response.getHeader(HttpHeaders.ACCESS_CONTROL_EXPOSE_HEADERS).contains("header2"));
assertEquals(HttpHeaders.ORIGIN, this.response.getHeader(HttpHeaders.VARY)); assertThat(this.response.getHeaders(HttpHeaders.VARY), contains(HttpHeaders.ORIGIN,
HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, HttpHeaders.ACCESS_CONTROL_REQUEST_HEADERS));
assertEquals(HttpServletResponse.SC_OK, this.response.getStatus()); assertEquals(HttpServletResponse.SC_OK, this.response.getStatus());
} }
...@@ -166,7 +173,8 @@ public class DefaultCorsProcessorTests { ...@@ -166,7 +173,8 @@ public class DefaultCorsProcessorTests {
this.conf.addAllowedOrigin("*"); this.conf.addAllowedOrigin("*");
this.processor.processRequest(this.conf, this.request, this.response); this.processor.processRequest(this.conf, this.request, this.response);
assertEquals(HttpHeaders.ORIGIN, this.response.getHeader(HttpHeaders.VARY)); assertThat(this.response.getHeaders(HttpHeaders.VARY), contains(HttpHeaders.ORIGIN,
HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, HttpHeaders.ACCESS_CONTROL_REQUEST_HEADERS));
assertEquals(HttpServletResponse.SC_OK, this.response.getStatus()); assertEquals(HttpServletResponse.SC_OK, this.response.getStatus());
} }
...@@ -178,7 +186,8 @@ public class DefaultCorsProcessorTests { ...@@ -178,7 +186,8 @@ public class DefaultCorsProcessorTests {
this.conf.addAllowedOrigin("*"); this.conf.addAllowedOrigin("*");
this.processor.processRequest(this.conf, this.request, this.response); this.processor.processRequest(this.conf, this.request, this.response);
assertEquals(HttpHeaders.ORIGIN, this.response.getHeader(HttpHeaders.VARY)); assertThat(this.response.getHeaders(HttpHeaders.VARY), contains(HttpHeaders.ORIGIN,
HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, HttpHeaders.ACCESS_CONTROL_REQUEST_HEADERS));
assertEquals(HttpServletResponse.SC_FORBIDDEN, this.response.getStatus()); assertEquals(HttpServletResponse.SC_FORBIDDEN, this.response.getStatus());
} }
...@@ -192,7 +201,8 @@ public class DefaultCorsProcessorTests { ...@@ -192,7 +201,8 @@ public class DefaultCorsProcessorTests {
this.processor.processRequest(this.conf, this.request, this.response); this.processor.processRequest(this.conf, this.request, this.response);
assertEquals(HttpServletResponse.SC_OK, this.response.getStatus()); assertEquals(HttpServletResponse.SC_OK, this.response.getStatus());
assertEquals("GET,HEAD", this.response.getHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_METHODS)); assertEquals("GET,HEAD", this.response.getHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_METHODS));
assertEquals(HttpHeaders.ORIGIN, this.response.getHeader(HttpHeaders.VARY)); assertThat(this.response.getHeaders(HttpHeaders.VARY), contains(HttpHeaders.ORIGIN,
HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, HttpHeaders.ACCESS_CONTROL_REQUEST_HEADERS));
} }
@Test @Test
...@@ -202,7 +212,8 @@ public class DefaultCorsProcessorTests { ...@@ -202,7 +212,8 @@ public class DefaultCorsProcessorTests {
this.processor.processRequest(this.conf, this.request, this.response); this.processor.processRequest(this.conf, this.request, this.response);
assertFalse(this.response.containsHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN)); assertFalse(this.response.containsHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN));
assertEquals(HttpHeaders.ORIGIN, this.response.getHeader(HttpHeaders.VARY)); assertThat(this.response.getHeaders(HttpHeaders.VARY), contains(HttpHeaders.ORIGIN,
HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, HttpHeaders.ACCESS_CONTROL_REQUEST_HEADERS));
assertEquals(HttpServletResponse.SC_FORBIDDEN, this.response.getStatus()); assertEquals(HttpServletResponse.SC_FORBIDDEN, this.response.getStatus());
} }
...@@ -214,7 +225,8 @@ public class DefaultCorsProcessorTests { ...@@ -214,7 +225,8 @@ public class DefaultCorsProcessorTests {
this.processor.processRequest(this.conf, this.request, this.response); this.processor.processRequest(this.conf, this.request, this.response);
assertFalse(this.response.containsHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN)); assertFalse(this.response.containsHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN));
assertEquals(HttpHeaders.ORIGIN, this.response.getHeader(HttpHeaders.VARY)); assertThat(this.response.getHeaders(HttpHeaders.VARY), contains(HttpHeaders.ORIGIN,
HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, HttpHeaders.ACCESS_CONTROL_REQUEST_HEADERS));
assertEquals(HttpServletResponse.SC_FORBIDDEN, this.response.getStatus()); assertEquals(HttpServletResponse.SC_FORBIDDEN, this.response.getStatus());
} }
...@@ -227,7 +239,8 @@ public class DefaultCorsProcessorTests { ...@@ -227,7 +239,8 @@ public class DefaultCorsProcessorTests {
this.processor.processRequest(this.conf, this.request, this.response); this.processor.processRequest(this.conf, this.request, this.response);
assertFalse(this.response.containsHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN)); assertFalse(this.response.containsHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN));
assertEquals(HttpHeaders.ORIGIN, this.response.getHeader(HttpHeaders.VARY)); assertThat(this.response.getHeaders(HttpHeaders.VARY), contains(HttpHeaders.ORIGIN,
HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, HttpHeaders.ACCESS_CONTROL_REQUEST_HEADERS));
assertEquals(HttpServletResponse.SC_FORBIDDEN, this.response.getStatus()); assertEquals(HttpServletResponse.SC_FORBIDDEN, this.response.getStatus());
} }
...@@ -249,7 +262,8 @@ public class DefaultCorsProcessorTests { ...@@ -249,7 +262,8 @@ public class DefaultCorsProcessorTests {
assertTrue(this.response.containsHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_METHODS)); assertTrue(this.response.containsHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_METHODS));
assertEquals("GET,PUT", this.response.getHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_METHODS)); assertEquals("GET,PUT", this.response.getHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_METHODS));
assertFalse(this.response.containsHeader(HttpHeaders.ACCESS_CONTROL_MAX_AGE)); assertFalse(this.response.containsHeader(HttpHeaders.ACCESS_CONTROL_MAX_AGE));
assertEquals(HttpHeaders.ORIGIN, this.response.getHeader(HttpHeaders.VARY)); assertThat(this.response.getHeaders(HttpHeaders.VARY), contains(HttpHeaders.ORIGIN,
HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, HttpHeaders.ACCESS_CONTROL_REQUEST_HEADERS));
assertEquals(HttpServletResponse.SC_OK, this.response.getStatus()); assertEquals(HttpServletResponse.SC_OK, this.response.getStatus());
} }
...@@ -270,7 +284,8 @@ public class DefaultCorsProcessorTests { ...@@ -270,7 +284,8 @@ public class DefaultCorsProcessorTests {
assertEquals("http://domain2.com", this.response.getHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN)); assertEquals("http://domain2.com", this.response.getHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN));
assertTrue(this.response.containsHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_CREDENTIALS)); assertTrue(this.response.containsHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_CREDENTIALS));
assertEquals("true", this.response.getHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_CREDENTIALS)); assertEquals("true", this.response.getHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_CREDENTIALS));
assertEquals(HttpHeaders.ORIGIN, this.response.getHeader(HttpHeaders.VARY)); assertThat(this.response.getHeaders(HttpHeaders.VARY), contains(HttpHeaders.ORIGIN,
HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, HttpHeaders.ACCESS_CONTROL_REQUEST_HEADERS));
assertEquals(HttpServletResponse.SC_OK, this.response.getStatus()); assertEquals(HttpServletResponse.SC_OK, this.response.getStatus());
} }
...@@ -289,7 +304,8 @@ public class DefaultCorsProcessorTests { ...@@ -289,7 +304,8 @@ public class DefaultCorsProcessorTests {
this.processor.processRequest(this.conf, this.request, this.response); this.processor.processRequest(this.conf, this.request, this.response);
assertTrue(this.response.containsHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN)); assertTrue(this.response.containsHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN));
assertEquals("http://domain2.com", this.response.getHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN)); assertEquals("http://domain2.com", this.response.getHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN));
assertEquals(HttpHeaders.ORIGIN, this.response.getHeader(HttpHeaders.VARY)); assertThat(this.response.getHeaders(HttpHeaders.VARY), contains(HttpHeaders.ORIGIN,
HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, HttpHeaders.ACCESS_CONTROL_REQUEST_HEADERS));
assertEquals(HttpServletResponse.SC_OK, this.response.getStatus()); assertEquals(HttpServletResponse.SC_OK, this.response.getStatus());
} }
...@@ -310,7 +326,8 @@ public class DefaultCorsProcessorTests { ...@@ -310,7 +326,8 @@ public class DefaultCorsProcessorTests {
assertTrue(this.response.getHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_HEADERS).contains("Header1")); assertTrue(this.response.getHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_HEADERS).contains("Header1"));
assertTrue(this.response.getHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_HEADERS).contains("Header2")); assertTrue(this.response.getHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_HEADERS).contains("Header2"));
assertFalse(this.response.getHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_HEADERS).contains("Header3")); assertFalse(this.response.getHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_HEADERS).contains("Header3"));
assertEquals(HttpHeaders.ORIGIN, this.response.getHeader(HttpHeaders.VARY)); assertThat(this.response.getHeaders(HttpHeaders.VARY), contains(HttpHeaders.ORIGIN,
HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, HttpHeaders.ACCESS_CONTROL_REQUEST_HEADERS));
assertEquals(HttpServletResponse.SC_OK, this.response.getStatus()); assertEquals(HttpServletResponse.SC_OK, this.response.getStatus());
} }
...@@ -329,7 +346,8 @@ public class DefaultCorsProcessorTests { ...@@ -329,7 +346,8 @@ public class DefaultCorsProcessorTests {
assertTrue(this.response.getHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_HEADERS).contains("Header1")); assertTrue(this.response.getHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_HEADERS).contains("Header1"));
assertTrue(this.response.getHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_HEADERS).contains("Header2")); assertTrue(this.response.getHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_HEADERS).contains("Header2"));
assertFalse(this.response.getHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_HEADERS).contains("*")); assertFalse(this.response.getHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_HEADERS).contains("*"));
assertEquals(HttpHeaders.ORIGIN, this.response.getHeader(HttpHeaders.VARY)); assertThat(this.response.getHeaders(HttpHeaders.VARY), contains(HttpHeaders.ORIGIN,
HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, HttpHeaders.ACCESS_CONTROL_REQUEST_HEADERS));
assertEquals(HttpServletResponse.SC_OK, this.response.getStatus()); assertEquals(HttpServletResponse.SC_OK, this.response.getStatus());
} }
...@@ -345,7 +363,8 @@ public class DefaultCorsProcessorTests { ...@@ -345,7 +363,8 @@ public class DefaultCorsProcessorTests {
this.processor.processRequest(this.conf, this.request, this.response); this.processor.processRequest(this.conf, this.request, this.response);
assertTrue(this.response.containsHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN)); assertTrue(this.response.containsHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN));
assertFalse(this.response.containsHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_HEADERS)); assertFalse(this.response.containsHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_HEADERS));
assertEquals(HttpHeaders.ORIGIN, this.response.getHeader(HttpHeaders.VARY)); assertThat(this.response.getHeaders(HttpHeaders.VARY), contains(HttpHeaders.ORIGIN,
HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, HttpHeaders.ACCESS_CONTROL_REQUEST_HEADERS));
assertEquals(HttpServletResponse.SC_OK, this.response.getStatus()); assertEquals(HttpServletResponse.SC_OK, this.response.getStatus());
} }
......
/* /*
* Copyright 2002-2017 the original author or authors. * Copyright 2002-2018 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
...@@ -28,14 +28,18 @@ import org.springframework.mock.web.test.server.MockServerWebExchange; ...@@ -28,14 +28,18 @@ import org.springframework.mock.web.test.server.MockServerWebExchange;
import org.springframework.web.cors.CorsConfiguration; import org.springframework.web.cors.CorsConfiguration;
import org.springframework.web.server.ServerWebExchange; import org.springframework.web.server.ServerWebExchange;
import static org.hamcrest.Matchers.contains;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNull; import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertThat;
import static org.junit.Assert.assertTrue; import static org.junit.Assert.assertTrue;
import static org.springframework.http.HttpHeaders.ACCESS_CONTROL_ALLOW_HEADERS; import static org.springframework.http.HttpHeaders.ACCESS_CONTROL_ALLOW_HEADERS;
import static org.springframework.http.HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN; import static org.springframework.http.HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN;
import static org.springframework.http.HttpHeaders.ACCESS_CONTROL_REQUEST_HEADERS; import static org.springframework.http.HttpHeaders.ACCESS_CONTROL_REQUEST_HEADERS;
import static org.springframework.http.HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD; import static org.springframework.http.HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD;
import static org.springframework.http.HttpHeaders.ORIGIN;
import static org.springframework.http.HttpHeaders.VARY;
/** /**
* {@link DefaultCorsProcessor} tests with simple or pre-flight CORS request. * {@link DefaultCorsProcessor} tests with simple or pre-flight CORS request.
...@@ -63,7 +67,8 @@ public class DefaultCorsProcessorTests { ...@@ -63,7 +67,8 @@ public class DefaultCorsProcessorTests {
ServerHttpResponse response = exchange.getResponse(); ServerHttpResponse response = exchange.getResponse();
assertFalse(response.getHeaders().containsKey(ACCESS_CONTROL_ALLOW_ORIGIN)); assertFalse(response.getHeaders().containsKey(ACCESS_CONTROL_ALLOW_ORIGIN));
assertEquals(HttpHeaders.ORIGIN, response.getHeaders().getFirst(HttpHeaders.VARY)); assertThat(response.getHeaders().get(VARY), contains(ORIGIN,
ACCESS_CONTROL_REQUEST_METHOD, ACCESS_CONTROL_REQUEST_HEADERS));
assertEquals(HttpStatus.FORBIDDEN, response.getStatusCode()); assertEquals(HttpStatus.FORBIDDEN, response.getStatusCode());
} }
...@@ -88,7 +93,8 @@ public class DefaultCorsProcessorTests { ...@@ -88,7 +93,8 @@ public class DefaultCorsProcessorTests {
assertEquals("*", response.getHeaders().getFirst(ACCESS_CONTROL_ALLOW_ORIGIN)); assertEquals("*", response.getHeaders().getFirst(ACCESS_CONTROL_ALLOW_ORIGIN));
assertFalse(response.getHeaders().containsKey(HttpHeaders.ACCESS_CONTROL_MAX_AGE)); assertFalse(response.getHeaders().containsKey(HttpHeaders.ACCESS_CONTROL_MAX_AGE));
assertFalse(response.getHeaders().containsKey(HttpHeaders.ACCESS_CONTROL_EXPOSE_HEADERS)); assertFalse(response.getHeaders().containsKey(HttpHeaders.ACCESS_CONTROL_EXPOSE_HEADERS));
assertEquals(HttpHeaders.ORIGIN, response.getHeaders().getFirst(HttpHeaders.VARY)); assertThat(response.getHeaders().get(VARY), contains(ORIGIN,
ACCESS_CONTROL_REQUEST_METHOD, ACCESS_CONTROL_REQUEST_HEADERS));
assertNull(response.getStatusCode()); assertNull(response.getStatusCode());
} }
...@@ -106,7 +112,8 @@ public class DefaultCorsProcessorTests { ...@@ -106,7 +112,8 @@ public class DefaultCorsProcessorTests {
assertEquals("http://domain2.com", response.getHeaders().getFirst(ACCESS_CONTROL_ALLOW_ORIGIN)); assertEquals("http://domain2.com", response.getHeaders().getFirst(ACCESS_CONTROL_ALLOW_ORIGIN));
assertTrue(response.getHeaders().containsKey(HttpHeaders.ACCESS_CONTROL_ALLOW_CREDENTIALS)); assertTrue(response.getHeaders().containsKey(HttpHeaders.ACCESS_CONTROL_ALLOW_CREDENTIALS));
assertEquals("true", response.getHeaders().getFirst(HttpHeaders.ACCESS_CONTROL_ALLOW_CREDENTIALS)); assertEquals("true", response.getHeaders().getFirst(HttpHeaders.ACCESS_CONTROL_ALLOW_CREDENTIALS));
assertEquals(HttpHeaders.ORIGIN, response.getHeaders().getFirst(HttpHeaders.VARY)); assertThat(response.getHeaders().get(VARY), contains(ORIGIN,
ACCESS_CONTROL_REQUEST_METHOD, ACCESS_CONTROL_REQUEST_HEADERS));
assertNull(response.getStatusCode()); assertNull(response.getStatusCode());
} }
...@@ -122,7 +129,8 @@ public class DefaultCorsProcessorTests { ...@@ -122,7 +129,8 @@ public class DefaultCorsProcessorTests {
assertEquals("http://domain2.com", response.getHeaders().getFirst(ACCESS_CONTROL_ALLOW_ORIGIN)); assertEquals("http://domain2.com", response.getHeaders().getFirst(ACCESS_CONTROL_ALLOW_ORIGIN));
assertTrue(response.getHeaders().containsKey(HttpHeaders.ACCESS_CONTROL_ALLOW_CREDENTIALS)); assertTrue(response.getHeaders().containsKey(HttpHeaders.ACCESS_CONTROL_ALLOW_CREDENTIALS));
assertEquals("true", response.getHeaders().getFirst(HttpHeaders.ACCESS_CONTROL_ALLOW_CREDENTIALS)); assertEquals("true", response.getHeaders().getFirst(HttpHeaders.ACCESS_CONTROL_ALLOW_CREDENTIALS));
assertEquals(HttpHeaders.ORIGIN, response.getHeaders().getFirst(HttpHeaders.VARY)); assertThat(response.getHeaders().get(VARY), contains(ORIGIN,
ACCESS_CONTROL_REQUEST_METHOD, ACCESS_CONTROL_REQUEST_HEADERS));
assertNull(response.getStatusCode()); assertNull(response.getStatusCode());
} }
...@@ -134,7 +142,8 @@ public class DefaultCorsProcessorTests { ...@@ -134,7 +142,8 @@ public class DefaultCorsProcessorTests {
ServerHttpResponse response = exchange.getResponse(); ServerHttpResponse response = exchange.getResponse();
assertTrue(response.getHeaders().containsKey(ACCESS_CONTROL_ALLOW_ORIGIN)); assertTrue(response.getHeaders().containsKey(ACCESS_CONTROL_ALLOW_ORIGIN));
assertEquals(HttpHeaders.ORIGIN, response.getHeaders().getFirst(HttpHeaders.VARY)); assertThat(response.getHeaders().get(VARY), contains(ORIGIN,
ACCESS_CONTROL_REQUEST_METHOD, ACCESS_CONTROL_REQUEST_HEADERS));
assertNull(response.getStatusCode()); assertNull(response.getStatusCode());
} }
...@@ -152,7 +161,8 @@ public class DefaultCorsProcessorTests { ...@@ -152,7 +161,8 @@ public class DefaultCorsProcessorTests {
assertTrue(response.getHeaders().containsKey(HttpHeaders.ACCESS_CONTROL_EXPOSE_HEADERS)); assertTrue(response.getHeaders().containsKey(HttpHeaders.ACCESS_CONTROL_EXPOSE_HEADERS));
assertTrue(response.getHeaders().getFirst(HttpHeaders.ACCESS_CONTROL_EXPOSE_HEADERS).contains("header1")); assertTrue(response.getHeaders().getFirst(HttpHeaders.ACCESS_CONTROL_EXPOSE_HEADERS).contains("header1"));
assertTrue(response.getHeaders().getFirst(HttpHeaders.ACCESS_CONTROL_EXPOSE_HEADERS).contains("header2")); assertTrue(response.getHeaders().getFirst(HttpHeaders.ACCESS_CONTROL_EXPOSE_HEADERS).contains("header2"));
assertEquals(HttpHeaders.ORIGIN, response.getHeaders().getFirst(HttpHeaders.VARY)); assertThat(response.getHeaders().get(VARY), contains(ORIGIN,
ACCESS_CONTROL_REQUEST_METHOD, ACCESS_CONTROL_REQUEST_HEADERS));
assertNull(response.getStatusCode()); assertNull(response.getStatusCode());
} }
...@@ -164,7 +174,8 @@ public class DefaultCorsProcessorTests { ...@@ -164,7 +174,8 @@ public class DefaultCorsProcessorTests {
this.processor.process(this.conf, exchange); this.processor.process(this.conf, exchange);
ServerHttpResponse response = exchange.getResponse(); ServerHttpResponse response = exchange.getResponse();
assertEquals(HttpHeaders.ORIGIN, response.getHeaders().getFirst(HttpHeaders.VARY)); assertThat(response.getHeaders().get(VARY), contains(ORIGIN,
ACCESS_CONTROL_REQUEST_METHOD, ACCESS_CONTROL_REQUEST_HEADERS));
assertNull(response.getStatusCode()); assertNull(response.getStatusCode());
} }
...@@ -177,7 +188,8 @@ public class DefaultCorsProcessorTests { ...@@ -177,7 +188,8 @@ public class DefaultCorsProcessorTests {
this.processor.process(this.conf, exchange); this.processor.process(this.conf, exchange);
ServerHttpResponse response = exchange.getResponse(); ServerHttpResponse response = exchange.getResponse();
assertEquals(HttpHeaders.ORIGIN, response.getHeaders().getFirst(HttpHeaders.VARY)); assertThat(response.getHeaders().get(VARY), contains(ORIGIN,
ACCESS_CONTROL_REQUEST_METHOD, ACCESS_CONTROL_REQUEST_HEADERS));
assertEquals(HttpStatus.FORBIDDEN, response.getStatusCode()); assertEquals(HttpStatus.FORBIDDEN, response.getStatusCode());
} }
...@@ -190,7 +202,8 @@ public class DefaultCorsProcessorTests { ...@@ -190,7 +202,8 @@ public class DefaultCorsProcessorTests {
ServerHttpResponse response = exchange.getResponse(); ServerHttpResponse response = exchange.getResponse();
assertNull(response.getStatusCode()); assertNull(response.getStatusCode());
assertEquals(HttpHeaders.ORIGIN, response.getHeaders().getFirst(HttpHeaders.VARY)); assertThat(response.getHeaders().get(VARY), contains(ORIGIN,
ACCESS_CONTROL_REQUEST_METHOD, ACCESS_CONTROL_REQUEST_HEADERS));
assertEquals("GET,HEAD", response.getHeaders().getFirst(HttpHeaders.ACCESS_CONTROL_ALLOW_METHODS)); assertEquals("GET,HEAD", response.getHeaders().getFirst(HttpHeaders.ACCESS_CONTROL_ALLOW_METHODS));
} }
...@@ -201,7 +214,8 @@ public class DefaultCorsProcessorTests { ...@@ -201,7 +214,8 @@ public class DefaultCorsProcessorTests {
ServerHttpResponse response = exchange.getResponse(); ServerHttpResponse response = exchange.getResponse();
assertFalse(response.getHeaders().containsKey(ACCESS_CONTROL_ALLOW_ORIGIN)); assertFalse(response.getHeaders().containsKey(ACCESS_CONTROL_ALLOW_ORIGIN));
assertEquals(HttpHeaders.ORIGIN, response.getHeaders().getFirst(HttpHeaders.VARY)); assertThat(response.getHeaders().get(VARY), contains(ORIGIN,
ACCESS_CONTROL_REQUEST_METHOD, ACCESS_CONTROL_REQUEST_HEADERS));
assertEquals(HttpStatus.FORBIDDEN, response.getStatusCode()); assertEquals(HttpStatus.FORBIDDEN, response.getStatusCode());
} }
...@@ -213,7 +227,8 @@ public class DefaultCorsProcessorTests { ...@@ -213,7 +227,8 @@ public class DefaultCorsProcessorTests {
ServerHttpResponse response = exchange.getResponse(); ServerHttpResponse response = exchange.getResponse();
assertFalse(response.getHeaders().containsKey(ACCESS_CONTROL_ALLOW_ORIGIN)); assertFalse(response.getHeaders().containsKey(ACCESS_CONTROL_ALLOW_ORIGIN));
assertEquals(HttpHeaders.ORIGIN, response.getHeaders().getFirst(HttpHeaders.VARY)); assertThat(response.getHeaders().get(VARY), contains(ORIGIN,
ACCESS_CONTROL_REQUEST_METHOD, ACCESS_CONTROL_REQUEST_HEADERS));
assertEquals(HttpStatus.FORBIDDEN, response.getStatusCode()); assertEquals(HttpStatus.FORBIDDEN, response.getStatusCode());
} }
...@@ -227,7 +242,8 @@ public class DefaultCorsProcessorTests { ...@@ -227,7 +242,8 @@ public class DefaultCorsProcessorTests {
ServerHttpResponse response = exchange.getResponse(); ServerHttpResponse response = exchange.getResponse();
assertFalse(response.getHeaders().containsKey(ACCESS_CONTROL_ALLOW_ORIGIN)); assertFalse(response.getHeaders().containsKey(ACCESS_CONTROL_ALLOW_ORIGIN));
assertEquals(HttpHeaders.ORIGIN, response.getHeaders().getFirst(HttpHeaders.VARY)); assertThat(response.getHeaders().get(VARY), contains(ORIGIN,
ACCESS_CONTROL_REQUEST_METHOD, ACCESS_CONTROL_REQUEST_HEADERS));
assertEquals(HttpStatus.FORBIDDEN, response.getStatusCode()); assertEquals(HttpStatus.FORBIDDEN, response.getStatusCode());
} }
...@@ -251,7 +267,8 @@ public class DefaultCorsProcessorTests { ...@@ -251,7 +267,8 @@ public class DefaultCorsProcessorTests {
assertTrue(response.getHeaders().containsKey(HttpHeaders.ACCESS_CONTROL_ALLOW_METHODS)); assertTrue(response.getHeaders().containsKey(HttpHeaders.ACCESS_CONTROL_ALLOW_METHODS));
assertEquals("GET,PUT", response.getHeaders().getFirst(HttpHeaders.ACCESS_CONTROL_ALLOW_METHODS)); assertEquals("GET,PUT", response.getHeaders().getFirst(HttpHeaders.ACCESS_CONTROL_ALLOW_METHODS));
assertFalse(response.getHeaders().containsKey(HttpHeaders.ACCESS_CONTROL_MAX_AGE)); assertFalse(response.getHeaders().containsKey(HttpHeaders.ACCESS_CONTROL_MAX_AGE));
assertEquals(HttpHeaders.ORIGIN, response.getHeaders().getFirst(HttpHeaders.VARY)); assertThat(response.getHeaders().get(VARY), contains(ORIGIN,
ACCESS_CONTROL_REQUEST_METHOD, ACCESS_CONTROL_REQUEST_HEADERS));
assertNull(response.getStatusCode()); assertNull(response.getStatusCode());
} }
...@@ -274,7 +291,8 @@ public class DefaultCorsProcessorTests { ...@@ -274,7 +291,8 @@ public class DefaultCorsProcessorTests {
assertEquals("http://domain2.com", response.getHeaders().getFirst(ACCESS_CONTROL_ALLOW_ORIGIN)); assertEquals("http://domain2.com", response.getHeaders().getFirst(ACCESS_CONTROL_ALLOW_ORIGIN));
assertTrue(response.getHeaders().containsKey(HttpHeaders.ACCESS_CONTROL_ALLOW_CREDENTIALS)); assertTrue(response.getHeaders().containsKey(HttpHeaders.ACCESS_CONTROL_ALLOW_CREDENTIALS));
assertEquals("true", response.getHeaders().getFirst(HttpHeaders.ACCESS_CONTROL_ALLOW_CREDENTIALS)); assertEquals("true", response.getHeaders().getFirst(HttpHeaders.ACCESS_CONTROL_ALLOW_CREDENTIALS));
assertEquals(HttpHeaders.ORIGIN, response.getHeaders().getFirst(HttpHeaders.VARY)); assertThat(response.getHeaders().get(VARY), contains(ORIGIN,
ACCESS_CONTROL_REQUEST_METHOD, ACCESS_CONTROL_REQUEST_HEADERS));
assertNull(response.getStatusCode()); assertNull(response.getStatusCode());
} }
...@@ -295,7 +313,8 @@ public class DefaultCorsProcessorTests { ...@@ -295,7 +313,8 @@ public class DefaultCorsProcessorTests {
ServerHttpResponse response = exchange.getResponse(); ServerHttpResponse response = exchange.getResponse();
assertTrue(response.getHeaders().containsKey(ACCESS_CONTROL_ALLOW_ORIGIN)); assertTrue(response.getHeaders().containsKey(ACCESS_CONTROL_ALLOW_ORIGIN));
assertEquals("http://domain2.com", response.getHeaders().getFirst(ACCESS_CONTROL_ALLOW_ORIGIN)); assertEquals("http://domain2.com", response.getHeaders().getFirst(ACCESS_CONTROL_ALLOW_ORIGIN));
assertEquals(HttpHeaders.ORIGIN, response.getHeaders().getFirst(HttpHeaders.VARY)); assertThat(response.getHeaders().get(VARY), contains(ORIGIN,
ACCESS_CONTROL_REQUEST_METHOD, ACCESS_CONTROL_REQUEST_HEADERS));
assertNull(response.getStatusCode()); assertNull(response.getStatusCode());
} }
...@@ -318,7 +337,8 @@ public class DefaultCorsProcessorTests { ...@@ -318,7 +337,8 @@ public class DefaultCorsProcessorTests {
assertTrue(response.getHeaders().getFirst(ACCESS_CONTROL_ALLOW_HEADERS).contains("Header1")); assertTrue(response.getHeaders().getFirst(ACCESS_CONTROL_ALLOW_HEADERS).contains("Header1"));
assertTrue(response.getHeaders().getFirst(ACCESS_CONTROL_ALLOW_HEADERS).contains("Header2")); assertTrue(response.getHeaders().getFirst(ACCESS_CONTROL_ALLOW_HEADERS).contains("Header2"));
assertFalse(response.getHeaders().getFirst(ACCESS_CONTROL_ALLOW_HEADERS).contains("Header3")); assertFalse(response.getHeaders().getFirst(ACCESS_CONTROL_ALLOW_HEADERS).contains("Header3"));
assertEquals(HttpHeaders.ORIGIN, response.getHeaders().getFirst(HttpHeaders.VARY)); assertThat(response.getHeaders().get(VARY), contains(ORIGIN,
ACCESS_CONTROL_REQUEST_METHOD, ACCESS_CONTROL_REQUEST_HEADERS));
assertNull(response.getStatusCode()); assertNull(response.getStatusCode());
} }
...@@ -339,7 +359,8 @@ public class DefaultCorsProcessorTests { ...@@ -339,7 +359,8 @@ public class DefaultCorsProcessorTests {
assertTrue(response.getHeaders().getFirst(ACCESS_CONTROL_ALLOW_HEADERS).contains("Header1")); assertTrue(response.getHeaders().getFirst(ACCESS_CONTROL_ALLOW_HEADERS).contains("Header1"));
assertTrue(response.getHeaders().getFirst(ACCESS_CONTROL_ALLOW_HEADERS).contains("Header2")); assertTrue(response.getHeaders().getFirst(ACCESS_CONTROL_ALLOW_HEADERS).contains("Header2"));
assertFalse(response.getHeaders().getFirst(ACCESS_CONTROL_ALLOW_HEADERS).contains("*")); assertFalse(response.getHeaders().getFirst(ACCESS_CONTROL_ALLOW_HEADERS).contains("*"));
assertEquals(HttpHeaders.ORIGIN, response.getHeaders().getFirst(HttpHeaders.VARY)); assertThat(response.getHeaders().get(VARY), contains(ORIGIN,
ACCESS_CONTROL_REQUEST_METHOD, ACCESS_CONTROL_REQUEST_HEADERS));
assertNull(response.getStatusCode()); assertNull(response.getStatusCode());
} }
...@@ -357,7 +378,8 @@ public class DefaultCorsProcessorTests { ...@@ -357,7 +378,8 @@ public class DefaultCorsProcessorTests {
ServerHttpResponse response = exchange.getResponse(); ServerHttpResponse response = exchange.getResponse();
assertTrue(response.getHeaders().containsKey(ACCESS_CONTROL_ALLOW_ORIGIN)); assertTrue(response.getHeaders().containsKey(ACCESS_CONTROL_ALLOW_ORIGIN));
assertFalse(response.getHeaders().containsKey(ACCESS_CONTROL_ALLOW_HEADERS)); assertFalse(response.getHeaders().containsKey(ACCESS_CONTROL_ALLOW_HEADERS));
assertEquals(HttpHeaders.ORIGIN, response.getHeaders().getFirst(HttpHeaders.VARY)); assertThat(response.getHeaders().get(VARY), contains(ORIGIN,
ACCESS_CONTROL_REQUEST_METHOD, ACCESS_CONTROL_REQUEST_HEADERS));
assertNull(response.getStatusCode()); assertNull(response.getStatusCode());
} }
......
/* /*
* Copyright 2002-2016 the original author or authors. * Copyright 2002-2018 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
...@@ -24,7 +24,9 @@ import org.springframework.context.annotation.AnnotationConfigApplicationContext ...@@ -24,7 +24,9 @@ import org.springframework.context.annotation.AnnotationConfigApplicationContext
import org.springframework.context.annotation.ComponentScan; import org.springframework.context.annotation.ComponentScan;
import org.springframework.context.annotation.Configuration; import org.springframework.context.annotation.Configuration;
import org.springframework.http.HttpHeaders; import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.HttpStatus; import org.springframework.http.HttpStatus;
import org.springframework.http.MediaType;
import org.springframework.http.ResponseEntity; import org.springframework.http.ResponseEntity;
import org.springframework.http.client.HttpComponentsClientHttpRequestFactory; import org.springframework.http.client.HttpComponentsClientHttpRequestFactory;
import org.springframework.web.bind.annotation.GetMapping; import org.springframework.web.bind.annotation.GetMapping;
...@@ -34,8 +36,10 @@ import org.springframework.web.client.RestTemplate; ...@@ -34,8 +36,10 @@ import org.springframework.web.client.RestTemplate;
import org.springframework.web.reactive.config.CorsRegistry; import org.springframework.web.reactive.config.CorsRegistry;
import org.springframework.web.reactive.config.WebFluxConfigurationSupport; import org.springframework.web.reactive.config.WebFluxConfigurationSupport;
import static org.hamcrest.Matchers.contains;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNull; import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertThat;
import static org.junit.Assert.fail; import static org.junit.Assert.fail;
/** /**
...@@ -101,12 +105,22 @@ public class GlobalCorsConfigIntegrationTests extends AbstractRequestMappingInte ...@@ -101,12 +105,22 @@ public class GlobalCorsConfigIntegrationTests extends AbstractRequestMappingInte
assertEquals("welcome", entity.getBody()); assertEquals("welcome", entity.getBody());
} }
@Test
public void actualRequestWithAmbiguousMapping() throws Exception {
this.headers.add(HttpHeaders.ACCEPT, MediaType.TEXT_HTML_VALUE);
ResponseEntity<String> entity = performGet("/ambiguous", this.headers, String.class);
assertEquals(HttpStatus.OK, entity.getStatusCode());
assertEquals("*", entity.getHeaders().getAccessControlAllowOrigin());
}
@Test @Test
public void preFlightRequestWithCorsEnabled() throws Exception { public void preFlightRequestWithCorsEnabled() throws Exception {
this.headers.add(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "GET"); this.headers.add(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "GET");
ResponseEntity<String> entity = performOptions("/cors", this.headers, String.class); ResponseEntity<String> entity = performOptions("/cors", this.headers, String.class);
assertEquals(HttpStatus.OK, entity.getStatusCode()); assertEquals(HttpStatus.OK, entity.getStatusCode());
assertEquals("*", entity.getHeaders().getAccessControlAllowOrigin()); assertEquals("*", entity.getHeaders().getAccessControlAllowOrigin());
assertThat(entity.getHeaders().getAccessControlAllowMethods(),
contains(HttpMethod.GET, HttpMethod.HEAD, HttpMethod.POST));
} }
@Test @Test
...@@ -133,6 +147,28 @@ public class GlobalCorsConfigIntegrationTests extends AbstractRequestMappingInte ...@@ -133,6 +147,28 @@ public class GlobalCorsConfigIntegrationTests extends AbstractRequestMappingInte
} }
} }
@Test
public void preFlightRequestWithCorsRestricted() throws Exception {
this.headers.set(HttpHeaders.ORIGIN, "http://foo");
this.headers.add(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "GET");
ResponseEntity<String> entity = performOptions("/cors-restricted", this.headers, String.class);
assertEquals(HttpStatus.OK, entity.getStatusCode());
assertEquals("http://foo", entity.getHeaders().getAccessControlAllowOrigin());
assertThat(entity.getHeaders().getAccessControlAllowMethods(), contains(HttpMethod.GET, HttpMethod.POST));
}
@Test
public void preFlightRequestWithAmbiguousMapping() throws Exception {
this.headers.add(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "GET");
ResponseEntity<String> entity = performOptions("/ambiguous", this.headers, String.class);
assertEquals(HttpStatus.OK, entity.getStatusCode());
assertEquals("http://localhost:9000", entity.getHeaders().getAccessControlAllowOrigin());
assertThat(entity.getHeaders().getAccessControlAllowMethods(), contains(HttpMethod.GET));
assertEquals(true, entity.getHeaders().getAccessControlAllowCredentials());
assertThat(entity.getHeaders().get(HttpHeaders.VARY), contains(HttpHeaders.ORIGIN,
HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, HttpHeaders.ACCESS_CONTROL_REQUEST_HEADERS));
}
@Configuration @Configuration
@ComponentScan(resourcePattern = "**/GlobalCorsConfigIntegrationTests*.class") @ComponentScan(resourcePattern = "**/GlobalCorsConfigIntegrationTests*.class")
...@@ -141,8 +177,12 @@ public class GlobalCorsConfigIntegrationTests extends AbstractRequestMappingInte ...@@ -141,8 +177,12 @@ public class GlobalCorsConfigIntegrationTests extends AbstractRequestMappingInte
@Override @Override
protected void addCorsMappings(CorsRegistry registry) { protected void addCorsMappings(CorsRegistry registry) {
registry.addMapping("/cors-restricted").allowedOrigins("http://foo"); registry.addMapping("/cors-restricted")
.allowedOrigins("http://foo")
.allowedMethods("GET", "POST");
registry.addMapping("/cors"); registry.addMapping("/cors");
registry.addMapping("/ambiguous")
.allowedMethods("GET", "POST");
} }
} }
...@@ -163,6 +203,16 @@ public class GlobalCorsConfigIntegrationTests extends AbstractRequestMappingInte ...@@ -163,6 +203,16 @@ public class GlobalCorsConfigIntegrationTests extends AbstractRequestMappingInte
public String corsRestricted() { public String corsRestricted() {
return "corsRestricted"; return "corsRestricted";
} }
@GetMapping(value = "/ambiguous", produces = MediaType.TEXT_PLAIN_VALUE)
public String ambiguous1() {
return "ambiguous";
}
@GetMapping(value = "/ambiguous", produces = MediaType.TEXT_HTML_VALUE)
public String ambiguous2() {
return "<p>ambiguous</p>";
}
} }
} }
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册