HandshakeWebSocketService.java 9.1 KB
Newer Older
1
/*
2
 * Copyright 2002-2018 the original author or authors.
3 4 5 6 7 8 9 10 11 12 13 14 15
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16

17 18
package org.springframework.web.reactive.socket.server.support;

19
import java.net.InetSocketAddress;
20
import java.net.URI;
21
import java.security.Principal;
22 23
import java.util.Collections;
import java.util.List;
24 25 26
import java.util.Map;
import java.util.function.Predicate;
import java.util.stream.Collectors;
27 28 29 30 31

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import reactor.core.publisher.Mono;

32
import org.springframework.context.Lifecycle;
R
Rossen Stoyanchev 已提交
33
import org.springframework.http.HttpHeaders;
34 35
import org.springframework.http.HttpMethod;
import org.springframework.http.server.reactive.ServerHttpRequest;
36
import org.springframework.lang.Nullable;
37 38 39
import org.springframework.util.Assert;
import org.springframework.util.ClassUtils;
import org.springframework.util.ReflectionUtils;
40
import org.springframework.util.StringUtils;
41
import org.springframework.web.reactive.socket.HandshakeInfo;
42 43 44 45 46
import org.springframework.web.reactive.socket.WebSocketHandler;
import org.springframework.web.reactive.socket.server.RequestUpgradeStrategy;
import org.springframework.web.reactive.socket.server.WebSocketService;
import org.springframework.web.server.MethodNotAllowedException;
import org.springframework.web.server.ServerWebExchange;
R
Rossen Stoyanchev 已提交
47
import org.springframework.web.server.ServerWebInputException;
48 49

/**
50 51 52 53
 * {@code WebSocketService} implementation that handles a WebSocket HTTP
 * handshake request by delegating to a {@link RequestUpgradeStrategy}, which
 * is either auto-detected from the classpath (no-arg constructor) or
 * explicitly configured (constructor argument).
54 55 56 57
 *
 * @author Rossen Stoyanchev
 * @since 5.0
 */
58
public class HandshakeWebSocketService implements WebSocketService, Lifecycle {
59 60 61

	private static final String SEC_WEBSOCKET_KEY = "Sec-WebSocket-Key";

62 63
	private static final String SEC_WEBSOCKET_PROTOCOL = "Sec-WebSocket-Protocol";

64 65
	private static final Mono<Map<String, Object>> EMPTY_ATTRIBUTES = Mono.just(Collections.emptyMap());

66

67
	private static final boolean tomcatPresent;
68

69
	private static final boolean jettyPresent;
70

71
	private static final boolean undertowPresent;
72

73 74 75 76 77 78 79 80 81
	private static final boolean reactorNettyPresent;

	static {
		ClassLoader classLoader = HandshakeWebSocketService.class.getClassLoader();
		tomcatPresent = ClassUtils.isPresent("org.apache.tomcat.websocket.server.WsHttpUpgradeHandler", classLoader);
		jettyPresent = ClassUtils.isPresent("org.eclipse.jetty.websocket.server.WebSocketServerFactory", classLoader);
		undertowPresent = ClassUtils.isPresent("io.undertow.websockets.WebSocketProtocolHandshakeHandler", classLoader);
		reactorNettyPresent = ClassUtils.isPresent("reactor.netty.http.server.HttpServerResponse", classLoader);
	}
82

83 84 85 86 87 88

	protected static final Log logger = LogFactory.getLog(HandshakeWebSocketService.class);


	private final RequestUpgradeStrategy upgradeStrategy;

89 90 91
	@Nullable
	private Predicate<String> sessionAttributePredicate;

92 93
	private volatile boolean running = false;

94 95 96 97 98 99 100 101 102 103 104 105 106 107

	/**
	 * Default constructor automatic, classpath detection based discovery of the
	 * {@link RequestUpgradeStrategy} to use.
	 */
	public HandshakeWebSocketService() {
		this(initUpgradeStrategy());
	}

	/**
	 * Alternative constructor with the {@link RequestUpgradeStrategy} to use.
	 * @param upgradeStrategy the strategy to use
	 */
	public HandshakeWebSocketService(RequestUpgradeStrategy upgradeStrategy) {
108
		Assert.notNull(upgradeStrategy, "RequestUpgradeStrategy is required");
109 110 111 112 113
		this.upgradeStrategy = upgradeStrategy;
	}

	private static RequestUpgradeStrategy initUpgradeStrategy() {
		String className;
114 115 116 117 118 119 120 121 122
		if (tomcatPresent) {
			className = "TomcatRequestUpgradeStrategy";
		}
		else if (jettyPresent) {
			className = "JettyRequestUpgradeStrategy";
		}
		else if (undertowPresent) {
			className = "UndertowRequestUpgradeStrategy";
		}
123 124 125 126
		else if (reactorNettyPresent) {
			// As late as possible (Reactor Netty commonly used for WebClient)
			className = "ReactorNettyRequestUpgradeStrategy";
		}
127 128 129 130 131
		else {
			throw new IllegalStateException("No suitable default RequestUpgradeStrategy found");
		}

		try {
132
			className = "org.springframework.web.reactive.socket.server.upgrade." + className;
133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149
			Class<?> clazz = ClassUtils.forName(className, HandshakeWebSocketService.class.getClassLoader());
			return (RequestUpgradeStrategy) ReflectionUtils.accessibleConstructor(clazz).newInstance();
		}
		catch (Throwable ex) {
			throw new IllegalStateException(
					"Failed to instantiate RequestUpgradeStrategy: " + className, ex);
		}
	}


	/**
	 * Return the {@link RequestUpgradeStrategy} for WebSocket requests.
	 */
	public RequestUpgradeStrategy getUpgradeStrategy() {
		return this.upgradeStrategy;
	}

150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171
	/**
	 * Configure a predicate to use to extract
	 * {@link org.springframework.web.server.WebSession WebSession} attributes
	 * and use them to initialize the WebSocket session with.
	 * <p>By default this is not set in which case no attributes are passed.
	 * @param predicate the predicate
	 * @since 5.1
	 */
	public void setSessionAttributePredicate(@Nullable Predicate<String> predicate) {
		this.sessionAttributePredicate = predicate;
	}

	/**
	 * Return the configured predicate for initialization WebSocket session
	 * attributes from {@code WebSession} attributes.
	 * @since 5.1
	 */
	@Nullable
	public Predicate<String> getSessionAttributePredicate() {
		return this.sessionAttributePredicate;
	}

172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200

	@Override
	public void start() {
		if (!isRunning()) {
			this.running = true;
			doStart();
		}
	}

	protected void doStart() {
		if (getUpgradeStrategy() instanceof Lifecycle) {
			((Lifecycle) getUpgradeStrategy()).start();
		}
	}

	@Override
	public void stop() {
		if (isRunning()) {
			this.running = false;
			doStop();
		}
	}

	protected void doStop() {
		if (getUpgradeStrategy() instanceof Lifecycle) {
			((Lifecycle) getUpgradeStrategy()).stop();
		}
	}

201 202 203 204 205
	@Override
	public boolean isRunning() {
		return this.running;
	}

206 207

	@Override
208
	public Mono<Void> handleRequest(ServerWebExchange exchange, WebSocketHandler handler) {
209
		ServerHttpRequest request = exchange.getRequest();
R
Rossen Stoyanchev 已提交
210 211
		HttpMethod method = request.getMethod();
		HttpHeaders headers = request.getHeaders();
212

R
Rossen Stoyanchev 已提交
213
		if (HttpMethod.GET != method) {
214 215
			return Mono.error(new MethodNotAllowedException(
					request.getMethodValue(), Collections.singleton(HttpMethod.GET)));
216 217
		}

R
Rossen Stoyanchev 已提交
218
		if (!"WebSocket".equalsIgnoreCase(headers.getUpgrade())) {
219
			return handleBadRequest(exchange, "Invalid 'Upgrade' header: " + headers);
220 221
		}

R
Rossen Stoyanchev 已提交
222
		List<String> connectionValue = headers.getConnection();
223
		if (!connectionValue.contains("Upgrade") && !connectionValue.contains("upgrade")) {
224
			return handleBadRequest(exchange, "Invalid 'Connection' header: " + headers);
225
		}
R
Rossen Stoyanchev 已提交
226 227

		String key = headers.getFirst(SEC_WEBSOCKET_KEY);
228
		if (key == null) {
229
			return handleBadRequest(exchange, "Missing \"Sec-WebSocket-Key\" header");
R
Rossen Stoyanchev 已提交
230 231
		}

232
		String protocol = selectProtocol(headers, handler);
233 234 235 236 237

		return initAttributes(exchange).flatMap(attributes ->
				this.upgradeStrategy.upgrade(exchange, handler, protocol,
						() -> createHandshakeInfo(exchange, request, protocol, attributes))
		);
R
Rossen Stoyanchev 已提交
238 239
	}

240
	private Mono<Void> handleBadRequest(ServerWebExchange exchange, String reason) {
R
Rossen Stoyanchev 已提交
241
		if (logger.isDebugEnabled()) {
242
			logger.debug(exchange.getLogPrefix() + reason);
243
		}
R
Rossen Stoyanchev 已提交
244
		return Mono.error(new ServerWebInputException(reason));
245 246
	}

247
	@Nullable
248
	private String selectProtocol(HttpHeaders headers, WebSocketHandler handler) {
R
Rossen Stoyanchev 已提交
249
		String protocolHeader = headers.getFirst(SEC_WEBSOCKET_PROTOCOL);
250 251 252 253 254 255 256 257 258
		if (protocolHeader != null) {
			List<String> supportedProtocols = handler.getSubProtocols();
			for (String protocol : StringUtils.commaDelimitedListToStringArray(protocolHeader)) {
				if (supportedProtocols.contains(protocol)) {
					return protocol;
				}
			}
		}
		return null;
259 260
	}

261 262 263 264 265 266 267 268 269 270 271 272 273
	private Mono<Map<String, Object>> initAttributes(ServerWebExchange exchange) {
		if (this.sessionAttributePredicate == null) {
			return EMPTY_ATTRIBUTES;
		}
		return exchange.getSession().map(session ->
				session.getAttributes().entrySet().stream()
						.filter(entry -> this.sessionAttributePredicate.test(entry.getKey()))
						.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)));
	}

	private HandshakeInfo createHandshakeInfo(ServerWebExchange exchange, ServerHttpRequest request,
			@Nullable String protocol, Map<String, Object> attributes) {

274 275
		URI uri = request.getURI();
		HttpHeaders headers = request.getHeaders();
276
		Mono<Principal> principal = exchange.getPrincipal();
277
		String logPrefix = exchange.getLogPrefix();
278 279
		InetSocketAddress remoteAddress = request.getRemoteAddress();
		return new HandshakeInfo(uri, headers, principal, protocol, remoteAddress, attributes, logPrefix);
280 281
	}

282
}