提交 6eb0e9e4 编写于 作者: R Rossen Stoyanchev

Unwrap decorated request or response

Closes: gh-23598
上级 9db41181
/*
* Copyright 2002-2018 the original author or authors.
* Copyright 2002-2019 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
......@@ -18,6 +18,7 @@ package org.springframework.web.reactive.socket.server.upgrade;
import java.io.IOException;
import java.util.function.Supplier;
import javax.servlet.ServletContext;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
......@@ -32,7 +33,9 @@ import org.springframework.core.io.buffer.DataBufferFactory;
import org.springframework.http.server.reactive.AbstractServerHttpRequest;
import org.springframework.http.server.reactive.AbstractServerHttpResponse;
import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.http.server.reactive.ServerHttpRequestDecorator;
import org.springframework.http.server.reactive.ServerHttpResponse;
import org.springframework.http.server.reactive.ServerHttpResponseDecorator;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
import org.springframework.web.reactive.socket.HandshakeInfo;
......@@ -144,8 +147,8 @@ public class JettyRequestUpgradeStrategy implements RequestUpgradeStrategy, Life
ServerHttpRequest request = exchange.getRequest();
ServerHttpResponse response = exchange.getResponse();
HttpServletRequest servletRequest = getHttpServletRequest(request);
HttpServletResponse servletResponse = getHttpServletResponse(response);
HttpServletRequest servletRequest = getNativeRequest(request);
HttpServletResponse servletResponse = getNativeResponse(response);
HandshakeInfo handshakeInfo = handshakeInfoFactory.get();
DataBufferFactory factory = response.bufferFactory();
......@@ -173,14 +176,30 @@ public class JettyRequestUpgradeStrategy implements RequestUpgradeStrategy, Life
return Mono.empty();
}
private HttpServletRequest getHttpServletRequest(ServerHttpRequest request) {
Assert.isInstanceOf(AbstractServerHttpRequest.class, request);
return ((AbstractServerHttpRequest) request).getNativeRequest();
private static HttpServletRequest getNativeRequest(ServerHttpRequest request) {
if (request instanceof AbstractServerHttpRequest) {
return ((AbstractServerHttpRequest) request).getNativeRequest();
}
else if (request instanceof ServerHttpRequestDecorator) {
return getNativeRequest(((ServerHttpRequestDecorator) request).getDelegate());
}
else {
throw new IllegalArgumentException(
"Couldn't find HttpServletRequest in " + request.getClass().getName());
}
}
private HttpServletResponse getHttpServletResponse(ServerHttpResponse response) {
Assert.isInstanceOf(AbstractServerHttpResponse.class, response);
return ((AbstractServerHttpResponse) response).getNativeResponse();
private static HttpServletResponse getNativeResponse(ServerHttpResponse response) {
if (response instanceof AbstractServerHttpResponse) {
return ((AbstractServerHttpResponse) response).getNativeResponse();
}
else if (response instanceof ServerHttpResponseDecorator) {
return getNativeResponse(((ServerHttpResponseDecorator) response).getDelegate());
}
else {
throw new IllegalArgumentException(
"Couldn't find HttpServletResponse in " + response.getClass().getName());
}
}
private void startLazily(HttpServletRequest request) {
......
/*
* Copyright 2002-2018 the original author or authors.
* Copyright 2002-2019 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
......@@ -24,6 +24,7 @@ import reactor.netty.http.server.HttpServerResponse;
import org.springframework.core.io.buffer.NettyDataBufferFactory;
import org.springframework.http.server.reactive.AbstractServerHttpResponse;
import org.springframework.http.server.reactive.ServerHttpResponse;
import org.springframework.http.server.reactive.ServerHttpResponseDecorator;
import org.springframework.lang.Nullable;
import org.springframework.web.reactive.socket.HandshakeInfo;
import org.springframework.web.reactive.socket.WebSocketHandler;
......@@ -72,7 +73,7 @@ public class ReactorNettyRequestUpgradeStrategy implements RequestUpgradeStrateg
@Nullable String subProtocol, Supplier<HandshakeInfo> handshakeInfoFactory) {
ServerHttpResponse response = exchange.getResponse();
HttpServerResponse reactorResponse = ((AbstractServerHttpResponse) response).getNativeResponse();
HttpServerResponse reactorResponse = getNativeResponse(response);
HandshakeInfo handshakeInfo = handshakeInfoFactory.get();
NettyDataBufferFactory bufferFactory = (NettyDataBufferFactory) response.bufferFactory();
......@@ -85,4 +86,17 @@ public class ReactorNettyRequestUpgradeStrategy implements RequestUpgradeStrateg
});
}
private static HttpServerResponse getNativeResponse(ServerHttpResponse response) {
if (response instanceof AbstractServerHttpResponse) {
return ((AbstractServerHttpResponse) response).getNativeResponse();
}
else if (response instanceof ServerHttpResponseDecorator) {
return getNativeResponse(((ServerHttpResponseDecorator) response).getDelegate());
}
else {
throw new IllegalArgumentException(
"Couldn't find native response in " + response.getClass().getName());
}
}
}
/*
* Copyright 2002-2018 the original author or authors.
* Copyright 2002-2019 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
......@@ -19,6 +19,7 @@ package org.springframework.web.reactive.socket.server.upgrade;
import java.io.IOException;
import java.util.Collections;
import java.util.function.Supplier;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
......@@ -32,7 +33,9 @@ import org.springframework.core.io.buffer.DataBufferFactory;
import org.springframework.http.server.reactive.AbstractServerHttpRequest;
import org.springframework.http.server.reactive.AbstractServerHttpResponse;
import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.http.server.reactive.ServerHttpRequestDecorator;
import org.springframework.http.server.reactive.ServerHttpResponse;
import org.springframework.http.server.reactive.ServerHttpResponseDecorator;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
import org.springframework.web.reactive.socket.HandshakeInfo;
......@@ -130,8 +133,8 @@ public class TomcatRequestUpgradeStrategy implements RequestUpgradeStrategy {
ServerHttpRequest request = exchange.getRequest();
ServerHttpResponse response = exchange.getResponse();
HttpServletRequest servletRequest = getHttpServletRequest(request);
HttpServletResponse servletResponse = getHttpServletResponse(response);
HttpServletRequest servletRequest = getNativeRequest(request);
HttpServletResponse servletResponse = getNativeResponse(response);
HandshakeInfo handshakeInfo = handshakeInfoFactory.get();
DataBufferFactory bufferFactory = response.bufferFactory();
......@@ -155,14 +158,30 @@ public class TomcatRequestUpgradeStrategy implements RequestUpgradeStrategy {
return Mono.empty();
}
private HttpServletRequest getHttpServletRequest(ServerHttpRequest request) {
Assert.isInstanceOf(AbstractServerHttpRequest.class, request, "ServletServerHttpRequest required");
return ((AbstractServerHttpRequest) request).getNativeRequest();
private static HttpServletRequest getNativeRequest(ServerHttpRequest request) {
if (request instanceof AbstractServerHttpRequest) {
return ((AbstractServerHttpRequest) request).getNativeRequest();
}
else if (request instanceof ServerHttpRequestDecorator) {
return getNativeRequest(((ServerHttpRequestDecorator) request).getDelegate());
}
else {
throw new IllegalArgumentException(
"Couldn't find HttpServletRequest in " + request.getClass().getName());
}
}
private HttpServletResponse getHttpServletResponse(ServerHttpResponse response) {
Assert.isInstanceOf(AbstractServerHttpResponse.class, response, "ServletServerHttpResponse required");
return ((AbstractServerHttpResponse) response).getNativeResponse();
private static HttpServletResponse getNativeResponse(ServerHttpResponse response) {
if (response instanceof AbstractServerHttpResponse) {
return ((AbstractServerHttpResponse) response).getNativeResponse();
}
else if (response instanceof ServerHttpResponseDecorator) {
return getNativeResponse(((ServerHttpResponseDecorator) response).getDelegate());
}
else {
throw new IllegalArgumentException(
"Couldn't find HttpServletResponse in " + response.getClass().getName());
}
}
private WsServerContainer getContainer(HttpServletRequest request) {
......
......@@ -33,6 +33,7 @@ import reactor.core.publisher.Mono;
import org.springframework.core.io.buffer.DataBufferFactory;
import org.springframework.http.server.reactive.AbstractServerHttpRequest;
import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.http.server.reactive.ServerHttpRequestDecorator;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
import org.springframework.web.reactive.socket.HandshakeInfo;
......@@ -55,9 +56,7 @@ public class UndertowRequestUpgradeStrategy implements RequestUpgradeStrategy {
public Mono<Void> upgrade(ServerWebExchange exchange, WebSocketHandler handler,
@Nullable String subProtocol, Supplier<HandshakeInfo> handshakeInfoFactory) {
ServerHttpRequest request = exchange.getRequest();
Assert.isInstanceOf(AbstractServerHttpRequest.class, request);
HttpServerExchange httpExchange = ((AbstractServerHttpRequest) request).getNativeRequest();
HttpServerExchange httpExchange = getNativeRequest(exchange.getRequest());
Set<String> protocols = (subProtocol != null ? Collections.singleton(subProtocol) : Collections.emptySet());
Hybi13Handshake handshake = new Hybi13Handshake(protocols, false);
......@@ -77,6 +76,19 @@ public class UndertowRequestUpgradeStrategy implements RequestUpgradeStrategy {
return Mono.empty();
}
private static HttpServerExchange getNativeRequest(ServerHttpRequest request) {
if (request instanceof AbstractServerHttpRequest) {
return ((AbstractServerHttpRequest) request).getNativeRequest();
}
else if (request instanceof ServerHttpRequestDecorator) {
return getNativeRequest(((ServerHttpRequestDecorator) request).getDelegate());
}
else {
throw new IllegalArgumentException(
"Couldn't find HttpServerExchange in " + request.getClass().getName());
}
}
private class DefaultCallback implements WebSocketConnectionCallback {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册