Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
爱吃血肠
spring-framework
提交
68ecb92d
S
spring-framework
项目概览
爱吃血肠
/
spring-framework
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
S
spring-framework
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
提交
68ecb92d
编写于
5月 03, 2015
作者:
R
Rossen Stoyanchev
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Allow "ws" and "wss" for isValidCorsOrigin checks
Issue: SPR-12956
上级
222f6998
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
117 addition
and
148 deletion
+117
-148
spring-web/src/main/java/org/springframework/web/util/UriComponentsBuilder.java
...va/org/springframework/web/util/UriComponentsBuilder.java
+23
-0
spring-web/src/main/java/org/springframework/web/util/WebUtils.java
.../src/main/java/org/springframework/web/util/WebUtils.java
+8
-17
spring-web/src/test/java/org/springframework/web/util/WebUtilsTests.java
...test/java/org/springframework/web/util/WebUtilsTests.java
+38
-53
spring-websocket/src/main/java/org/springframework/web/socket/server/support/OriginHandshakeInterceptor.java
...web/socket/server/support/OriginHandshakeInterceptor.java
+7
-10
spring-websocket/src/main/java/org/springframework/web/socket/sockjs/support/AbstractSockJsService.java
...work/web/socket/sockjs/support/AbstractSockJsService.java
+15
-18
spring-websocket/src/test/java/org/springframework/web/socket/server/support/OriginHandshakeInterceptorTests.java
...ocket/server/support/OriginHandshakeInterceptorTests.java
+14
-23
spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/handler/DefaultSockJsServiceTests.java
...t/sockjs/transport/handler/DefaultSockJsServiceTests.java
+12
-27
未找到文件。
spring-web/src/main/java/org/springframework/web/util/UriComponentsBuilder.java
浏览文件 @
68ecb92d
...
...
@@ -317,6 +317,29 @@ public class UriComponentsBuilder implements Cloneable {
}
/**
* Create an instance by parsing the "origin" header of an HTTP request.
*/
public
static
UriComponentsBuilder
fromOriginHeader
(
String
origin
)
{
UriComponentsBuilder
builder
=
UriComponentsBuilder
.
newInstance
();
if
(
StringUtils
.
hasText
(
origin
))
{
int
schemaIdx
=
origin
.
indexOf
(
"://"
);
String
schema
=
(
schemaIdx
!=
-
1
?
origin
.
substring
(
0
,
schemaIdx
)
:
"http"
);
builder
.
scheme
(
schema
);
String
hostString
=
(
schemaIdx
!=
-
1
?
origin
.
substring
(
schemaIdx
+
3
)
:
origin
);
if
(
hostString
.
contains
(
":"
))
{
String
[]
hostAndPort
=
StringUtils
.
split
(
hostString
,
":"
);
builder
.
host
(
hostAndPort
[
0
]);
builder
.
port
(
Integer
.
parseInt
(
hostAndPort
[
1
]));
}
else
{
builder
.
host
(
hostString
);
}
}
return
builder
;
}
// build methods
/**
...
...
spring-web/src/main/java/org/springframework/web/util/WebUtils.java
浏览文件 @
68ecb92d
...
...
@@ -23,6 +23,7 @@ import java.util.Enumeration;
import
java.util.Map
;
import
java.util.StringTokenizer
;
import
java.util.TreeMap
;
import
javax.servlet.ServletContext
;
import
javax.servlet.ServletRequest
;
import
javax.servlet.ServletRequestWrapper
;
...
...
@@ -38,6 +39,7 @@ import org.apache.commons.logging.LogFactory;
import
org.springframework.http.HttpRequest
;
import
org.springframework.util.Assert
;
import
org.springframework.util.CollectionUtils
;
import
org.springframework.util.LinkedMultiValueMap
;
import
org.springframework.util.MultiValueMap
;
import
org.springframework.util.StringUtils
;
...
...
@@ -790,21 +792,10 @@ public abstract class WebUtils {
if
(
origin
==
null
||
allowedOrigins
.
contains
(
"*"
))
{
return
true
;
}
else
if
(
allowedOrigins
.
isEmpty
())
{
UriComponents
originComponents
;
try
{
originComponents
=
UriComponentsBuilder
.
fromHttpUrl
(
origin
).
build
();
}
catch
(
IllegalArgumentException
ex
)
{
if
(
logger
.
isWarnEnabled
())
{
logger
.
warn
(
"Failed to parse Origin header value ["
+
origin
+
"]"
);
}
return
false
;
}
UriComponents
requestComponents
=
UriComponentsBuilder
.
fromHttpRequest
(
request
).
build
();
int
originPort
=
getPort
(
originComponents
);
int
requestPort
=
getPort
(
requestComponents
);
return
(
originComponents
.
getHost
().
equals
(
requestComponents
.
getHost
())
&&
originPort
==
requestPort
);
else
if
(
CollectionUtils
.
isEmpty
(
allowedOrigins
))
{
UriComponents
actualUrl
=
UriComponentsBuilder
.
fromHttpRequest
(
request
).
build
();
UriComponents
originUrl
=
UriComponentsBuilder
.
fromOriginHeader
(
origin
).
build
();
return
(
actualUrl
.
getHost
().
equals
(
originUrl
.
getHost
())
&&
getPort
(
actualUrl
)
==
getPort
(
originUrl
));
}
else
{
return
allowedOrigins
.
contains
(
origin
);
...
...
@@ -814,10 +805,10 @@ public abstract class WebUtils {
private
static
int
getPort
(
UriComponents
component
)
{
int
port
=
component
.
getPort
();
if
(
port
==
-
1
)
{
if
(
"http"
.
equals
(
component
.
getScheme
()))
{
if
(
"http"
.
equals
(
component
.
getScheme
())
||
"ws"
.
equals
(
component
.
getScheme
())
)
{
port
=
80
;
}
else
if
(
"https"
.
equals
(
component
.
getScheme
()))
{
else
if
(
"https"
.
equals
(
component
.
getScheme
())
||
"wss"
.
equals
(
component
.
getScheme
())
)
{
port
=
443
;
}
}
...
...
spring-web/src/test/java/org/springframework/web/util/WebUtilsTests.java
浏览文件 @
68ecb92d
...
...
@@ -16,8 +16,8 @@
package
org.springframework.web.util
;
import
java.util.ArrayList
;
import
java.util.Arrays
;
import
java.util.Collections
;
import
java.util.HashMap
;
import
java.util.List
;
import
java.util.Map
;
...
...
@@ -106,60 +106,45 @@ public class WebUtilsTests {
}
@Test
public
void
isValidOrigin
()
{
List
<
String
>
allowedOrigins
=
new
ArrayList
<>();
public
void
isValidOriginSuccess
()
{
List
<
String
>
allowed
=
Collections
.
emptyList
();
assertTrue
(
checkOrigin
(
"mydomain1.com"
,
-
1
,
"http://mydomain1.com"
,
allowed
));
assertTrue
(
checkOrigin
(
"mydomain1.com"
,
-
1
,
"http://mydomain1.com:80"
,
allowed
));
assertTrue
(
checkOrigin
(
"mydomain1.com"
,
443
,
"https://mydomain1.com"
,
allowed
));
assertTrue
(
checkOrigin
(
"mydomain1.com"
,
443
,
"https://mydomain1.com:443"
,
allowed
));
assertTrue
(
checkOrigin
(
"mydomain1.com"
,
123
,
"http://mydomain1.com:123"
,
allowed
));
assertTrue
(
checkOrigin
(
"mydomain1.com"
,
-
1
,
"ws://mydomain1.com"
,
allowed
));
assertTrue
(
checkOrigin
(
"mydomain1.com"
,
443
,
"wss://mydomain1.com"
,
allowed
));
allowed
=
Collections
.
singletonList
(
"*"
);
assertTrue
(
checkOrigin
(
"mydomain1.com"
,
-
1
,
"http://mydomain2.com"
,
allowed
));
allowed
=
Collections
.
singletonList
(
"http://mydomain1.com"
);
assertTrue
(
checkOrigin
(
"mydomain2.com"
,
-
1
,
"http://mydomain1.com"
,
allowed
));
}
@Test
public
void
isValidOriginFailure
()
{
List
<
String
>
allowed
=
Collections
.
emptyList
();
assertFalse
(
checkOrigin
(
"mydomain1.com"
,
-
1
,
"http://mydomain2.com"
,
allowed
));
assertFalse
(
checkOrigin
(
"mydomain1.com"
,
-
1
,
"https://mydomain1.com"
,
allowed
));
assertFalse
(
checkOrigin
(
"mydomain1.com"
,
-
1
,
"invalid-origin"
,
allowed
));
allowed
=
Collections
.
singletonList
(
"http://mydomain1.com"
);
assertFalse
(
checkOrigin
(
"mydomain2.com"
,
-
1
,
"http://mydomain3.com"
,
allowed
));
}
private
boolean
checkOrigin
(
String
serverName
,
int
port
,
String
originHeader
,
List
<
String
>
allowed
)
{
MockHttpServletRequest
servletRequest
=
new
MockHttpServletRequest
();
ServerHttpRequest
request
=
new
ServletServerHttpRequest
(
servletRequest
);
servletRequest
.
setServerName
(
"mydomain1.com"
);
request
.
getHeaders
().
set
(
HttpHeaders
.
ORIGIN
,
"http://mydomain1.com"
);
assertTrue
(
WebUtils
.
isValidOrigin
(
request
,
allowedOrigins
));
servletRequest
.
setServerName
(
"mydomain1.com"
);
request
.
getHeaders
().
set
(
HttpHeaders
.
ORIGIN
,
"http://mydomain1.com:80"
);
assertTrue
(
WebUtils
.
isValidOrigin
(
request
,
allowedOrigins
));
servletRequest
.
setServerName
(
"mydomain1.com"
);
servletRequest
.
setServerPort
(
443
);
request
.
getHeaders
().
set
(
HttpHeaders
.
ORIGIN
,
"https://mydomain1.com"
);
assertTrue
(
WebUtils
.
isValidOrigin
(
request
,
allowedOrigins
));
servletRequest
.
setServerName
(
"mydomain1.com"
);
servletRequest
.
setServerPort
(
443
);
request
.
getHeaders
().
set
(
HttpHeaders
.
ORIGIN
,
"https://mydomain1.com:443"
);
assertTrue
(
WebUtils
.
isValidOrigin
(
request
,
allowedOrigins
));
servletRequest
.
setServerName
(
"mydomain1.com"
);
servletRequest
.
setServerPort
(
123
);
request
.
getHeaders
().
set
(
HttpHeaders
.
ORIGIN
,
"http://mydomain1.com:123"
);
assertTrue
(
WebUtils
.
isValidOrigin
(
request
,
allowedOrigins
));
servletRequest
.
setServerName
(
"mydomain1.com"
);
request
.
getHeaders
().
set
(
HttpHeaders
.
ORIGIN
,
"http://mydomain2.com"
);
assertFalse
(
WebUtils
.
isValidOrigin
(
request
,
allowedOrigins
));
servletRequest
.
setServerName
(
"mydomain1.com"
);
request
.
getHeaders
().
set
(
HttpHeaders
.
ORIGIN
,
"https://mydomain1.com"
);
assertFalse
(
WebUtils
.
isValidOrigin
(
request
,
allowedOrigins
));
servletRequest
.
setServerName
(
"invalid-origin"
);
request
.
getHeaders
().
set
(
HttpHeaders
.
ORIGIN
,
"invalid-origin"
);
assertFalse
(
WebUtils
.
isValidOrigin
(
request
,
allowedOrigins
));
allowedOrigins
=
Arrays
.
asList
(
"*"
);
servletRequest
.
setServerName
(
"mydomain1.com"
);
request
.
getHeaders
().
set
(
HttpHeaders
.
ORIGIN
,
"http://mydomain2.com"
);
assertTrue
(
WebUtils
.
isValidOrigin
(
request
,
allowedOrigins
));
allowedOrigins
=
Arrays
.
asList
(
"http://mydomain1.com"
);
servletRequest
.
setServerName
(
"mydomain2.com"
);
request
.
getHeaders
().
set
(
HttpHeaders
.
ORIGIN
,
"http://mydomain1.com"
);
assertTrue
(
WebUtils
.
isValidOrigin
(
request
,
allowedOrigins
));
allowedOrigins
=
Arrays
.
asList
(
"http://mydomain1.com"
);
servletRequest
.
setServerName
(
"mydomain2.com"
);
request
.
getHeaders
().
set
(
HttpHeaders
.
ORIGIN
,
"http://mydomain3.com"
);
assertFalse
(
WebUtils
.
isValidOrigin
(
request
,
allowedOrigins
));
servletRequest
.
setServerName
(
serverName
);
if
(
port
!=
-
1
)
{
servletRequest
.
setServerPort
(
port
);
}
request
.
getHeaders
().
set
(
HttpHeaders
.
ORIGIN
,
originHeader
);
return
WebUtils
.
isValidOrigin
(
request
,
allowed
);
}
}
spring-websocket/src/main/java/org/springframework/web/socket/server/support/OriginHandshakeInterceptor.java
浏览文件 @
68ecb92d
...
...
@@ -65,22 +65,18 @@ public class OriginHandshakeInterceptor implements HandshakeInterceptor {
}
/**
* Configure allowed {@code Origin} header values. This check is mostly
designed for
*
browser clients. There is nothing preventing other types of client to modify the
* {@code Origin} header value.
* Configure allowed {@code Origin} header values. This check is mostly
*
designed for browsers. There is nothing preventing other types of client
*
to modify the
{@code Origin} header value.
*
* <p>Each provided allowed origin must start by "http://", "https://" or be "*"
* (means that all origins are allowed).
* <p>Each provided allowed origin must have a scheme, and optionally a port
* (e.g. "http://example.org", "http://example.org:9090"). An allowed origin
* string may also be "*" in which case all origins are allowed.
*
* @see <a href="https://tools.ietf.org/html/rfc6454">RFC 6454: The Web Origin Concept</a>
*/
public
void
setAllowedOrigins
(
Collection
<
String
>
allowedOrigins
)
{
Assert
.
notNull
(
allowedOrigins
,
"Allowed origin Collection must not be null"
);
for
(
String
allowedOrigin
:
allowedOrigins
)
{
Assert
.
isTrue
(
allowedOrigin
.
equals
(
"*"
)
||
allowedOrigin
.
startsWith
(
"http://"
)
||
allowedOrigin
.
startsWith
(
"https://"
),
"Invalid allowed origin provided: \""
+
allowedOrigin
+
"\". It must start with \"http://\", \"https://\" or be \"*\""
);
}
this
.
allowedOrigins
.
clear
();
this
.
allowedOrigins
.
addAll
(
allowedOrigins
);
}
...
...
@@ -93,6 +89,7 @@ public class OriginHandshakeInterceptor implements HandshakeInterceptor {
return
Collections
.
unmodifiableList
(
this
.
allowedOrigins
);
}
@Override
public
boolean
beforeHandshake
(
ServerHttpRequest
request
,
ServerHttpResponse
response
,
WebSocketHandler
wsHandler
,
Map
<
String
,
Object
>
attributes
)
throws
Exception
{
...
...
spring-websocket/src/main/java/org/springframework/web/socket/sockjs/support/AbstractSockJsService.java
浏览文件 @
68ecb92d
...
...
@@ -276,16 +276,18 @@ public abstract class AbstractSockJsService implements SockJsService, CorsConfig
}
/**
* Configure allowed {@code Origin} header values. This check is mostly
designed for
*
browser clients. There is nothing preventing other types of client to modify the
* {@code Origin} header value.
* Configure allowed {@code Origin} header values. This check is mostly
*
designed for browsers. There is nothing preventing other types of client
*
to modify the
{@code Origin} header value.
*
* <p>When SockJS is enabled and origins are restricted, transport types that do not
* allow to check request origin (JSONP and Iframe based transports) are disabled.
* As a consequence, IE 6 to 9 are not supported when origins are restricted.
* <p>When SockJS is enabled and origins are restricted, transport types
* that do not allow to check request origin (JSONP and Iframe based
* transports) are disabled. As a consequence, IE 6 to 9 are not supported
* when origins are restricted.
*
* <p>Each provided allowed origin must start by "http://", "https://" or be "*"
* (means that all origins are allowed).
* <p>Each provided allowed origin must have a scheme, and optionally a port
* (e.g. "http://example.org", "http://example.org:9090"). An allowed origin
* string may also be "*" in which case all origins are allowed.
*
* @since 4.1.2
* @see <a href="https://tools.ietf.org/html/rfc6454">RFC 6454: The Web Origin Concept</a>
...
...
@@ -293,14 +295,6 @@ public abstract class AbstractSockJsService implements SockJsService, CorsConfig
*/
public
void
setAllowedOrigins
(
List
<
String
>
allowedOrigins
)
{
Assert
.
notNull
(
allowedOrigins
,
"Allowed origin List must not be null"
);
for
(
String
allowedOrigin
:
allowedOrigins
)
{
Assert
.
isTrue
(
allowedOrigin
.
equals
(
"*"
)
||
allowedOrigin
.
startsWith
(
"http://"
)
||
allowedOrigin
.
startsWith
(
"https://"
),
"Invalid allowed origin provided: \""
+
allowedOrigin
+
"\". It must start with \"http://\", \"https://\" or be \"*\""
);
}
this
.
allowedOrigins
.
clear
();
this
.
allowedOrigins
.
addAll
(
allowedOrigins
);
}
...
...
@@ -451,7 +445,9 @@ public abstract class AbstractSockJsService implements SockJsService, CorsConfig
protected
abstract
void
handleTransportRequest
(
ServerHttpRequest
request
,
ServerHttpResponse
response
,
WebSocketHandler
webSocketHandler
,
String
sessionId
,
String
transport
)
throws
SockJsException
;
protected
boolean
checkOrigin
(
ServerHttpRequest
request
,
ServerHttpResponse
response
,
HttpMethod
...
httpMethods
)
throws
IOException
{
protected
boolean
checkOrigin
(
ServerHttpRequest
request
,
ServerHttpResponse
response
,
HttpMethod
...
httpMethods
)
throws
IOException
{
String
origin
=
request
.
getHeaders
().
getOrigin
();
if
(
origin
==
null
)
{
...
...
@@ -514,7 +510,8 @@ public abstract class AbstractSockJsService implements SockJsService, CorsConfig
addNoCacheHeaders
(
response
);
if
(
checkOrigin
(
request
,
response
))
{
response
.
getHeaders
().
setContentType
(
new
MediaType
(
"application"
,
"json"
,
UTF8_CHARSET
));
String
content
=
String
.
format
(
INFO_CONTENT
,
random
.
nextInt
(),
isSessionCookieNeeded
(),
isWebSocketEnabled
());
String
content
=
String
.
format
(
INFO_CONTENT
,
random
.
nextInt
(),
isSessionCookieNeeded
(),
isWebSocketEnabled
());
response
.
getBody
().
write
(
content
.
getBytes
());
}
...
...
spring-websocket/src/test/java/org/springframework/web/socket/server/support/OriginHandshakeInterceptorTests.java
浏览文件 @
68ecb92d
...
...
@@ -17,7 +17,9 @@
package
org.springframework.web.socket.server.support
;
import
java.util.Arrays
;
import
java.util.Collections
;
import
java.util.HashMap
;
import
java.util.List
;
import
java.util.Map
;
import
java.util.Set
;
import
java.util.concurrent.ConcurrentSkipListSet
;
...
...
@@ -39,31 +41,17 @@ import org.springframework.web.socket.WebSocketHandler;
public
class
OriginHandshakeInterceptorTests
extends
AbstractHttpRequestTests
{
@Test
(
expected
=
IllegalArgumentException
.
class
)
public
void
nullAllowedOriginLis
t
()
{
public
void
invalidInpu
t
()
{
new
OriginHandshakeInterceptor
(
null
);
}
@Test
(
expected
=
IllegalArgumentException
.
class
)
public
void
invalidAllowedOrigin
()
{
new
OriginHandshakeInterceptor
(
Arrays
.
asList
(
"domain.com"
));
}
@Test
public
void
emtpyAllowedOriginList
()
{
new
OriginHandshakeInterceptor
(
Arrays
.
asList
());
}
@Test
public
void
validAllowedOrigins
()
{
new
OriginHandshakeInterceptor
(
Arrays
.
asList
(
"http://domain.com"
,
"https://domain.com"
,
"*"
));
}
@Test
public
void
originValueMatch
()
throws
Exception
{
Map
<
String
,
Object
>
attributes
=
new
HashMap
<
String
,
Object
>();
WebSocketHandler
wsHandler
=
Mockito
.
mock
(
WebSocketHandler
.
class
);
this
.
servletRequest
.
addHeader
(
HttpHeaders
.
ORIGIN
,
"http://mydomain1.com"
);
OriginHandshakeInterceptor
interceptor
=
new
OriginHandshakeInterceptor
(
Arrays
.
asList
(
"http://mydomain1.com"
));
List
<
String
>
allowed
=
Collections
.
singletonList
(
"http://mydomain1.com"
);
OriginHandshakeInterceptor
interceptor
=
new
OriginHandshakeInterceptor
(
allowed
);
assertTrue
(
interceptor
.
beforeHandshake
(
request
,
response
,
wsHandler
,
attributes
));
assertNotEquals
(
servletResponse
.
getStatus
(),
HttpStatus
.
FORBIDDEN
.
value
());
}
...
...
@@ -73,7 +61,8 @@ public class OriginHandshakeInterceptorTests extends AbstractHttpRequestTests {
Map
<
String
,
Object
>
attributes
=
new
HashMap
<
String
,
Object
>();
WebSocketHandler
wsHandler
=
Mockito
.
mock
(
WebSocketHandler
.
class
);
this
.
servletRequest
.
addHeader
(
HttpHeaders
.
ORIGIN
,
"http://mydomain1.com"
);
OriginHandshakeInterceptor
interceptor
=
new
OriginHandshakeInterceptor
(
Arrays
.
asList
(
"http://mydomain2.com"
));
List
<
String
>
allowed
=
Collections
.
singletonList
(
"http://mydomain2.com"
);
OriginHandshakeInterceptor
interceptor
=
new
OriginHandshakeInterceptor
(
allowed
);
assertFalse
(
interceptor
.
beforeHandshake
(
request
,
response
,
wsHandler
,
attributes
));
assertEquals
(
servletResponse
.
getStatus
(),
HttpStatus
.
FORBIDDEN
.
value
());
}
...
...
@@ -83,7 +72,8 @@ public class OriginHandshakeInterceptorTests extends AbstractHttpRequestTests {
Map
<
String
,
Object
>
attributes
=
new
HashMap
<
String
,
Object
>();
WebSocketHandler
wsHandler
=
Mockito
.
mock
(
WebSocketHandler
.
class
);
this
.
servletRequest
.
addHeader
(
HttpHeaders
.
ORIGIN
,
"http://mydomain2.com"
);
OriginHandshakeInterceptor
interceptor
=
new
OriginHandshakeInterceptor
(
Arrays
.
asList
(
"http://mydomain1.com"
,
"http://mydomain2.com"
,
"http://mydomain3.com"
));
List
<
String
>
allowed
=
Arrays
.
asList
(
"http://mydomain1.com"
,
"http://mydomain2.com"
,
"http://mydomain3.com"
);
OriginHandshakeInterceptor
interceptor
=
new
OriginHandshakeInterceptor
(
allowed
);
assertTrue
(
interceptor
.
beforeHandshake
(
request
,
response
,
wsHandler
,
attributes
));
assertNotEquals
(
servletResponse
.
getStatus
(),
HttpStatus
.
FORBIDDEN
.
value
());
}
...
...
@@ -93,7 +83,8 @@ public class OriginHandshakeInterceptorTests extends AbstractHttpRequestTests {
Map
<
String
,
Object
>
attributes
=
new
HashMap
<
String
,
Object
>();
WebSocketHandler
wsHandler
=
Mockito
.
mock
(
WebSocketHandler
.
class
);
this
.
servletRequest
.
addHeader
(
HttpHeaders
.
ORIGIN
,
"http://mydomain4.com"
);
OriginHandshakeInterceptor
interceptor
=
new
OriginHandshakeInterceptor
(
Arrays
.
asList
(
"http://mydomain1.com"
,
"http://mydomain2.com"
,
"http://mydomain3.com"
));
List
<
String
>
allowed
=
Arrays
.
asList
(
"http://mydomain1.com"
,
"http://mydomain2.com"
,
"http://mydomain3.com"
);
OriginHandshakeInterceptor
interceptor
=
new
OriginHandshakeInterceptor
(
allowed
);
assertFalse
(
interceptor
.
beforeHandshake
(
request
,
response
,
wsHandler
,
attributes
));
assertEquals
(
servletResponse
.
getStatus
(),
HttpStatus
.
FORBIDDEN
.
value
());
}
...
...
@@ -117,7 +108,7 @@ public class OriginHandshakeInterceptorTests extends AbstractHttpRequestTests {
WebSocketHandler
wsHandler
=
Mockito
.
mock
(
WebSocketHandler
.
class
);
this
.
servletRequest
.
addHeader
(
HttpHeaders
.
ORIGIN
,
"http://mydomain1.com"
);
OriginHandshakeInterceptor
interceptor
=
new
OriginHandshakeInterceptor
();
interceptor
.
setAllowedOrigins
(
Arrays
.
as
List
(
"*"
));
interceptor
.
setAllowedOrigins
(
Collections
.
singleton
List
(
"*"
));
assertTrue
(
interceptor
.
beforeHandshake
(
request
,
response
,
wsHandler
,
attributes
));
assertNotEquals
(
servletResponse
.
getStatus
(),
HttpStatus
.
FORBIDDEN
.
value
());
}
...
...
@@ -128,7 +119,7 @@ public class OriginHandshakeInterceptorTests extends AbstractHttpRequestTests {
WebSocketHandler
wsHandler
=
Mockito
.
mock
(
WebSocketHandler
.
class
);
this
.
servletRequest
.
addHeader
(
HttpHeaders
.
ORIGIN
,
"http://mydomain2.com"
);
this
.
servletRequest
.
setServerName
(
"mydomain2.com"
);
OriginHandshakeInterceptor
interceptor
=
new
OriginHandshakeInterceptor
(
Arrays
.
as
List
());
OriginHandshakeInterceptor
interceptor
=
new
OriginHandshakeInterceptor
(
Collections
.
empty
List
());
assertTrue
(
interceptor
.
beforeHandshake
(
request
,
response
,
wsHandler
,
attributes
));
assertNotEquals
(
servletResponse
.
getStatus
(),
HttpStatus
.
FORBIDDEN
.
value
());
}
...
...
@@ -139,7 +130,7 @@ public class OriginHandshakeInterceptorTests extends AbstractHttpRequestTests {
WebSocketHandler
wsHandler
=
Mockito
.
mock
(
WebSocketHandler
.
class
);
this
.
servletRequest
.
addHeader
(
HttpHeaders
.
ORIGIN
,
"http://mydomain3.com"
);
this
.
servletRequest
.
setServerName
(
"mydomain2.com"
);
OriginHandshakeInterceptor
interceptor
=
new
OriginHandshakeInterceptor
(
Arrays
.
as
List
());
OriginHandshakeInterceptor
interceptor
=
new
OriginHandshakeInterceptor
(
Collections
.
empty
List
());
assertFalse
(
interceptor
.
beforeHandshake
(
request
,
response
,
wsHandler
,
attributes
));
assertEquals
(
servletResponse
.
getStatus
(),
HttpStatus
.
FORBIDDEN
.
value
());
}
...
...
spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/handler/DefaultSockJsServiceTests.java
浏览文件 @
68ecb92d
...
...
@@ -16,11 +16,14 @@
package
org.springframework.web.socket.sockjs.transport.handler
;
import
static
org
.
junit
.
Assert
.*;
import
static
org
.
mockito
.
BDDMockito
.*;
import
java.util.Arrays
;
import
java.util.Collections
;
import
java.util.List
;
import
java.util.Map
;
import
org.hamcrest.Matchers
;
import
org.junit.Before
;
import
org.junit.Test
;
import
org.mockito.Mock
;
...
...
@@ -41,9 +44,6 @@ import org.springframework.web.socket.sockjs.transport.TransportType;
import
org.springframework.web.socket.sockjs.transport.session.StubSockJsServiceConfig
;
import
org.springframework.web.socket.sockjs.transport.session.TestSockJsSession
;
import
static
org
.
junit
.
Assert
.*;
import
static
org
.
mockito
.
BDDMockito
.*;
/**
* Test fixture for {@link org.springframework.web.socket.sockjs.transport.handler.DefaultSockJsService}.
*
...
...
@@ -125,26 +125,10 @@ public class DefaultSockJsServiceTests extends AbstractHttpRequestTests {
}
@Test
(
expected
=
IllegalArgumentException
.
class
)
public
void
nullAllowedOriginList
()
{
public
void
invalidAllowedOrigins
()
{
this
.
service
.
setAllowedOrigins
(
null
);
}
@Test
public
void
emptyAllowedOriginList
()
{
this
.
service
.
setAllowedOrigins
(
Arrays
.
asList
());
assertThat
(
this
.
service
.
getAllowedOrigins
(),
Matchers
.
empty
());
}
@Test
(
expected
=
IllegalArgumentException
.
class
)
public
void
invalidAllowedOrigin
()
{
this
.
service
.
setAllowedOrigins
(
Arrays
.
asList
(
"domain.com"
));
}
@Test
public
void
validAllowedOrigins
()
{
this
.
service
.
setAllowedOrigins
(
Arrays
.
asList
(
"http://domain.com"
,
"https://domain.com"
,
"*"
));
}
@Test
public
void
customizedTransportHandlerList
()
{
TransportHandlingSockJsService
service
=
new
TransportHandlingSockJsService
(
...
...
@@ -268,13 +252,13 @@ public class DefaultSockJsServiceTests extends AbstractHttpRequestTests {
assertEquals
(
404
,
this
.
servletResponse
.
getStatus
());
resetRequestAndResponse
();
jsonpService
.
setAllowedOrigins
(
Arrays
.
as
List
(
"http://mydomain1.com"
));
jsonpService
.
setAllowedOrigins
(
Collections
.
singleton
List
(
"http://mydomain1.com"
));
setRequest
(
"GET"
,
sockJsPrefix
+
sockJsPath
);
jsonpService
.
handleRequest
(
this
.
request
,
this
.
response
,
sockJsPath
,
this
.
wsHandler
);
assertEquals
(
404
,
this
.
servletResponse
.
getStatus
());
resetRequestAndResponse
();
jsonpService
.
setAllowedOrigins
(
Arrays
.
as
List
(
"*"
));
jsonpService
.
setAllowedOrigins
(
Collections
.
singleton
List
(
"*"
));
setRequest
(
"GET"
,
sockJsPrefix
+
sockJsPath
);
jsonpService
.
handleRequest
(
this
.
request
,
this
.
response
,
sockJsPath
,
this
.
wsHandler
);
assertNotEquals
(
404
,
this
.
servletResponse
.
getStatus
());
...
...
@@ -289,8 +273,9 @@ public class DefaultSockJsServiceTests extends AbstractHttpRequestTests {
assertNotEquals
(
403
,
this
.
servletResponse
.
getStatus
());
resetRequestAndResponse
();
OriginHandshakeInterceptor
interceptor
=
new
OriginHandshakeInterceptor
(
Arrays
.
asList
(
"http://mydomain1.com"
));
wsService
.
setHandshakeInterceptors
(
Arrays
.
asList
(
interceptor
));
List
<
String
>
allowed
=
Collections
.
singletonList
(
"http://mydomain1.com"
);
OriginHandshakeInterceptor
interceptor
=
new
OriginHandshakeInterceptor
(
allowed
);
wsService
.
setHandshakeInterceptors
(
Collections
.
singletonList
(
interceptor
));
setRequest
(
"GET"
,
sockJsPrefix
+
sockJsPath
);
this
.
servletRequest
.
addHeader
(
HttpHeaders
.
ORIGIN
,
"http://mydomain1.com"
);
wsService
.
handleRequest
(
this
.
request
,
this
.
response
,
sockJsPath
,
this
.
wsHandler
);
...
...
@@ -313,14 +298,14 @@ public class DefaultSockJsServiceTests extends AbstractHttpRequestTests {
resetRequestAndResponse
();
setRequest
(
"GET"
,
sockJsPrefix
+
sockJsPath
);
this
.
service
.
setAllowedOrigins
(
Arrays
.
as
List
(
"http://mydomain1.com"
));
this
.
service
.
setAllowedOrigins
(
Collections
.
singleton
List
(
"http://mydomain1.com"
));
this
.
service
.
handleRequest
(
this
.
request
,
this
.
response
,
sockJsPath
,
this
.
wsHandler
);
assertEquals
(
404
,
this
.
servletResponse
.
getStatus
());
assertNull
(
this
.
servletResponse
.
getHeader
(
"X-Frame-Options"
));
resetRequestAndResponse
();
setRequest
(
"GET"
,
sockJsPrefix
+
sockJsPath
);
this
.
service
.
setAllowedOrigins
(
Arrays
.
as
List
(
"*"
));
this
.
service
.
setAllowedOrigins
(
Collections
.
singleton
List
(
"*"
));
this
.
service
.
handleRequest
(
this
.
request
,
this
.
response
,
sockJsPath
,
this
.
wsHandler
);
assertNotEquals
(
404
,
this
.
servletResponse
.
getStatus
());
assertNull
(
this
.
servletResponse
.
getHeader
(
"X-Frame-Options"
));
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录