提交 d47c7f95 编写于 作者: S Sébastien Deleuze

Propagate CoroutineContext in coRouter filters

Closes gh-26977
上级 bcf11e89
/* /*
* Copyright 2002-2022 the original author or authors. * Copyright 2002-2023 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
...@@ -17,6 +17,8 @@ ...@@ -17,6 +17,8 @@
package org.springframework.web.reactive.function.server package org.springframework.web.reactive.function.server
import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.Job
import kotlinx.coroutines.currentCoroutineContext
import kotlinx.coroutines.reactor.awaitSingle import kotlinx.coroutines.reactor.awaitSingle
import kotlinx.coroutines.reactor.mono import kotlinx.coroutines.reactor.mono
import org.springframework.core.io.Resource import org.springframework.core.io.Resource
...@@ -24,7 +26,9 @@ import org.springframework.http.HttpMethod ...@@ -24,7 +26,9 @@ import org.springframework.http.HttpMethod
import org.springframework.http.HttpStatusCode import org.springframework.http.HttpStatusCode
import org.springframework.http.MediaType import org.springframework.http.MediaType
import org.springframework.web.reactive.function.server.RouterFunctions.nest import org.springframework.web.reactive.function.server.RouterFunctions.nest
import reactor.core.publisher.Mono
import java.net.URI import java.net.URI
import kotlin.coroutines.CoroutineContext
/** /**
* Allow to create easily a WebFlux.fn [RouterFunction] with a [Coroutines router Kotlin DSL][CoRouterFunctionDsl]. * Allow to create easily a WebFlux.fn [RouterFunction] with a [Coroutines router Kotlin DSL][CoRouterFunctionDsl].
...@@ -532,7 +536,12 @@ class CoRouterFunctionDsl internal constructor (private val init: (CoRouterFunct ...@@ -532,7 +536,12 @@ class CoRouterFunctionDsl internal constructor (private val init: (CoRouterFunct
builder.filter { serverRequest, handlerFunction -> builder.filter { serverRequest, handlerFunction ->
mono(Dispatchers.Unconfined) { mono(Dispatchers.Unconfined) {
filterFunction(serverRequest) { handlerRequest -> filterFunction(serverRequest) { handlerRequest ->
handlerFunction.handle(handlerRequest).awaitSingle() if (handlerFunction is CoroutineContextAwareHandlerFunction<*>) {
handlerFunction.handle(currentCoroutineContext().minusKey(Job.Key), handlerRequest).awaitSingle()
}
else {
handlerFunction.handle(handlerRequest).awaitSingle()
}
} }
} }
} }
...@@ -618,11 +627,8 @@ class CoRouterFunctionDsl internal constructor (private val init: (CoRouterFunct ...@@ -618,11 +627,8 @@ class CoRouterFunctionDsl internal constructor (private val init: (CoRouterFunct
return builder.build() return builder.build()
} }
private fun asHandlerFunction(init: suspend (ServerRequest) -> ServerResponse) = HandlerFunction { private fun <T : ServerResponse> asHandlerFunction(handler: suspend (ServerRequest) -> T) =
mono(Dispatchers.Unconfined) { CoroutineContextAwareHandlerFunction(handler)
init(it)
}
}
/** /**
* @see ServerResponse.from * @see ServerResponse.from
...@@ -691,6 +697,21 @@ class CoRouterFunctionDsl internal constructor (private val init: (CoRouterFunct ...@@ -691,6 +697,21 @@ class CoRouterFunctionDsl internal constructor (private val init: (CoRouterFunct
*/ */
fun status(status: Int) = ServerResponse.status(status) fun status(status: Int) = ServerResponse.status(status)
private class CoroutineContextAwareHandlerFunction<T : ServerResponse>(
private val handler: suspend (ServerRequest) -> T
) : HandlerFunction<T> {
override fun handle(request: ServerRequest): Mono<T> {
return handle(Dispatchers.Unconfined, request)
}
fun handle(context: CoroutineContext, request: ServerRequest) = mono(context) {
handler(request)
}
}
} }
/** /**
......
/* /*
* Copyright 2002-2022 the original author or authors. * Copyright 2002-2023 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
...@@ -16,17 +16,20 @@ ...@@ -16,17 +16,20 @@
package org.springframework.web.reactive.function.server package org.springframework.web.reactive.function.server
import kotlinx.coroutines.CoroutineName
import kotlinx.coroutines.currentCoroutineContext
import kotlinx.coroutines.withContext
import org.assertj.core.api.Assertions.assertThat import org.assertj.core.api.Assertions.assertThat
import org.assertj.core.api.Assertions.assertThatExceptionOfType import org.assertj.core.api.Assertions.assertThatExceptionOfType
import org.junit.jupiter.api.Test import org.junit.jupiter.api.Test
import org.springframework.core.io.ClassPathResource import org.springframework.core.io.ClassPathResource
import org.springframework.http.HttpHeaders.* import org.springframework.http.HttpHeaders.ACCEPT
import org.springframework.http.HttpMethod.* import org.springframework.http.HttpHeaders.CONTENT_TYPE
import org.springframework.http.HttpMethod.PATCH
import org.springframework.http.HttpStatus import org.springframework.http.HttpStatus
import org.springframework.http.MediaType.* import org.springframework.http.MediaType.*
import org.springframework.web.testfixture.http.server.reactive.MockServerHttpRequest.* import org.springframework.web.testfixture.http.server.reactive.MockServerHttpRequest.*
import org.springframework.web.testfixture.server.MockServerWebExchange import org.springframework.web.testfixture.server.MockServerWebExchange
import org.springframework.web.reactive.function.server.AttributesTestVisitor
import reactor.test.StepVerifier import reactor.test.StepVerifier
/** /**
...@@ -165,6 +168,17 @@ class CoRouterFunctionDslTests { ...@@ -165,6 +168,17 @@ class CoRouterFunctionDslTests {
.verifyComplete() .verifyComplete()
} }
@Test
fun filteringWithContext() {
val mockRequest = get("https://example.com/").build()
val request = DefaultServerRequest(MockServerWebExchange.from(mockRequest), emptyList())
StepVerifier.create(filterRouterWithContext.route(request).flatMap { it.handle(request) })
.expectNextMatches { response ->
response.headers().getFirst("context")!!.contains("Filter context")
}
.verifyComplete()
}
@Test @Test
fun attributes() { fun attributes() {
val visitor = AttributesTestVisitor() val visitor = AttributesTestVisitor()
...@@ -226,6 +240,17 @@ class CoRouterFunctionDslTests { ...@@ -226,6 +240,17 @@ class CoRouterFunctionDslTests {
} }
} }
private val filterRouterWithContext = coRouter {
filter { request, next ->
withContext(CoroutineName("Filter context")) {
next(request)
}
}
GET("/") {
ok().header("context", currentCoroutineContext().toString()).buildAndAwait()
}
}
private val otherRouter = router { private val otherRouter = router {
"/other" { "/other" {
ok().build() ok().build()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册