提交 38392233 编写于 作者: S Sébastien Deleuze

Add context function to CoRouterFunctionDsl

This new function allows to customize the CoroutineContext
potentially dynamically based on the incoming
ServerRequest.

Closes gh-27010
上级 64ff37f4
...@@ -21,6 +21,7 @@ import kotlinx.coroutines.Job ...@@ -21,6 +21,7 @@ import kotlinx.coroutines.Job
import kotlinx.coroutines.currentCoroutineContext import kotlinx.coroutines.currentCoroutineContext
import kotlinx.coroutines.reactor.awaitSingle import kotlinx.coroutines.reactor.awaitSingle
import kotlinx.coroutines.reactor.mono import kotlinx.coroutines.reactor.mono
import kotlinx.coroutines.withContext
import org.springframework.core.io.Resource import org.springframework.core.io.Resource
import org.springframework.http.HttpMethod import org.springframework.http.HttpMethod
import org.springframework.http.HttpStatusCode import org.springframework.http.HttpStatusCode
...@@ -72,6 +73,8 @@ class CoRouterFunctionDsl internal constructor (private val init: (CoRouterFunct ...@@ -72,6 +73,8 @@ class CoRouterFunctionDsl internal constructor (private val init: (CoRouterFunct
@PublishedApi @PublishedApi
internal val builder = RouterFunctions.route() internal val builder = RouterFunctions.route()
private var contextProvider: (suspend (ServerRequest) -> CoroutineContext)? = null
/** /**
* Return a composed request predicate that tests against both this predicate AND * Return a composed request predicate that tests against both this predicate AND
* the [other] predicate (String processed as a path predicate). When evaluating the * the [other] predicate (String processed as a path predicate). When evaluating the
...@@ -510,9 +513,7 @@ class CoRouterFunctionDsl internal constructor (private val init: (CoRouterFunct ...@@ -510,9 +513,7 @@ class CoRouterFunctionDsl internal constructor (private val init: (CoRouterFunct
*/ */
fun resources(lookupFunction: suspend (ServerRequest) -> Resource?) { fun resources(lookupFunction: suspend (ServerRequest) -> Resource?) {
builder.resources { builder.resources {
mono(Dispatchers.Unconfined) { asMono(it, handler = lookupFunction)
lookupFunction.invoke(it)
}
} }
} }
...@@ -534,7 +535,7 @@ class CoRouterFunctionDsl internal constructor (private val init: (CoRouterFunct ...@@ -534,7 +535,7 @@ class CoRouterFunctionDsl internal constructor (private val init: (CoRouterFunct
*/ */
fun filter(filterFunction: suspend (ServerRequest, suspend (ServerRequest) -> ServerResponse) -> ServerResponse) { fun filter(filterFunction: suspend (ServerRequest, suspend (ServerRequest) -> ServerResponse) -> ServerResponse) {
builder.filter { serverRequest, handlerFunction -> builder.filter { serverRequest, handlerFunction ->
mono(Dispatchers.Unconfined) { asMono(serverRequest) {
filterFunction(serverRequest) { handlerRequest -> filterFunction(serverRequest) { handlerRequest ->
if (handlerFunction is CoroutineContextAwareHandlerFunction<*>) { if (handlerFunction is CoroutineContextAwareHandlerFunction<*>) {
handlerFunction.handle(currentCoroutineContext().minusKey(Job.Key), handlerRequest).awaitSingle() handlerFunction.handle(currentCoroutineContext().minusKey(Job.Key), handlerRequest).awaitSingle()
...@@ -578,7 +579,7 @@ class CoRouterFunctionDsl internal constructor (private val init: (CoRouterFunct ...@@ -578,7 +579,7 @@ class CoRouterFunctionDsl internal constructor (private val init: (CoRouterFunct
*/ */
fun onError(predicate: (Throwable) -> Boolean, responseProvider: suspend (Throwable, ServerRequest) -> ServerResponse) { fun onError(predicate: (Throwable) -> Boolean, responseProvider: suspend (Throwable, ServerRequest) -> ServerResponse) {
builder.onError(predicate) { throwable, request -> builder.onError(predicate) { throwable, request ->
mono(Dispatchers.Unconfined) { responseProvider.invoke(throwable, request) } asMono(request) { responseProvider.invoke(throwable, request) }
} }
} }
...@@ -591,7 +592,7 @@ class CoRouterFunctionDsl internal constructor (private val init: (CoRouterFunct ...@@ -591,7 +592,7 @@ class CoRouterFunctionDsl internal constructor (private val init: (CoRouterFunct
*/ */
inline fun <reified E : Throwable> onError(noinline responseProvider: suspend (Throwable, ServerRequest) -> ServerResponse) { inline fun <reified E : Throwable> onError(noinline responseProvider: suspend (Throwable, ServerRequest) -> ServerResponse) {
builder.onError({it is E}) { throwable, request -> builder.onError({it is E}) { throwable, request ->
mono(Dispatchers.Unconfined) { responseProvider.invoke(throwable, request) } asMono(request) { responseProvider.invoke(throwable, request) }
} }
} }
...@@ -619,6 +620,19 @@ class CoRouterFunctionDsl internal constructor (private val init: (CoRouterFunct ...@@ -619,6 +620,19 @@ class CoRouterFunctionDsl internal constructor (private val init: (CoRouterFunct
builder.withAttributes(attributesConsumer) builder.withAttributes(attributesConsumer)
} }
/**
* Allow to provide the default [CoroutineContext], potentially dynamically based on
* the incoming [ServerRequest].
* @param provider the [CoroutineContext] provider
* @since 6.1.0
*/
fun context(provider: suspend (ServerRequest) -> CoroutineContext) {
if (this.contextProvider != null) {
throw IllegalStateException("The Coroutine context provider should be defined not more than once")
}
this.contextProvider = provider
}
/** /**
* Return a composed routing function created from all the registered routes. * Return a composed routing function created from all the registered routes.
*/ */
...@@ -627,8 +641,22 @@ class CoRouterFunctionDsl internal constructor (private val init: (CoRouterFunct ...@@ -627,8 +641,22 @@ class CoRouterFunctionDsl internal constructor (private val init: (CoRouterFunct
return builder.build() return builder.build()
} }
private fun <T : ServerResponse> asHandlerFunction(handler: suspend (ServerRequest) -> T) = @PublishedApi
CoroutineContextAwareHandlerFunction(handler) internal fun <T> asMono(request: ServerRequest, context: CoroutineContext = Dispatchers.Unconfined, handler: suspend (ServerRequest) -> T): Mono<T> {
return mono(context) {
contextProvider?.let {
withContext(it.invoke(request)) {
handler.invoke(request)
}
} ?: run {
handler.invoke(request)
}
}
}
private fun asHandlerFunction(handler: suspend (ServerRequest) -> ServerResponse) = CoroutineContextAwareHandlerFunction { request ->
handler.invoke(request)
}
/** /**
* @see ServerResponse.from * @see ServerResponse.from
...@@ -698,7 +726,7 @@ class CoRouterFunctionDsl internal constructor (private val init: (CoRouterFunct ...@@ -698,7 +726,7 @@ class CoRouterFunctionDsl internal constructor (private val init: (CoRouterFunct
fun status(status: Int) = ServerResponse.status(status) fun status(status: Int) = ServerResponse.status(status)
private class CoroutineContextAwareHandlerFunction<T : ServerResponse>( private inner class CoroutineContextAwareHandlerFunction<T : ServerResponse>(
private val handler: suspend (ServerRequest) -> T private val handler: suspend (ServerRequest) -> T
) : HandlerFunction<T> { ) : HandlerFunction<T> {
...@@ -706,7 +734,7 @@ class CoRouterFunctionDsl internal constructor (private val init: (CoRouterFunct ...@@ -706,7 +734,7 @@ class CoRouterFunctionDsl internal constructor (private val init: (CoRouterFunct
return handle(Dispatchers.Unconfined, request) return handle(Dispatchers.Unconfined, request)
} }
fun handle(context: CoroutineContext, request: ServerRequest) = mono(context) { fun handle(context: CoroutineContext, request: ServerRequest) = asMono(request, context) {
handler(request) handler(request)
} }
......
...@@ -16,11 +16,8 @@ ...@@ -16,11 +16,8 @@
package org.springframework.web.reactive.function.server package org.springframework.web.reactive.function.server
import kotlinx.coroutines.CoroutineName import kotlinx.coroutines.*
import kotlinx.coroutines.currentCoroutineContext import org.assertj.core.api.Assertions.*
import kotlinx.coroutines.withContext
import org.assertj.core.api.Assertions.assertThat
import org.assertj.core.api.Assertions.assertThatExceptionOfType
import org.junit.jupiter.api.Test import org.junit.jupiter.api.Test
import org.springframework.core.io.ClassPathResource import org.springframework.core.io.ClassPathResource
import org.springframework.http.HttpHeaders.ACCEPT import org.springframework.http.HttpHeaders.ACCEPT
...@@ -179,6 +176,48 @@ class CoRouterFunctionDslTests { ...@@ -179,6 +176,48 @@ class CoRouterFunctionDslTests {
.verifyComplete() .verifyComplete()
} }
@Test
fun contextProvider() {
val mockRequest = get("https://example.com/")
.header("Custom-Header", "foo")
.build()
val request = DefaultServerRequest(MockServerWebExchange.from(mockRequest), emptyList())
StepVerifier.create(routerWithContextProvider.route(request).flatMap { it.handle(request) })
.expectNextMatches { response ->
response.headers().getFirst("context")!!.contains("foo")
}
.verifyComplete()
}
@Test
fun contextProviderAndFilter() {
val mockRequest = get("https://example.com/")
.header("Custom-Header", "bar")
.build()
val request = DefaultServerRequest(MockServerWebExchange.from(mockRequest), emptyList())
StepVerifier.create(routerWithContextProvider.route(request).flatMap { it.handle(request) })
.expectNextMatches { response ->
response.headers().getFirst("context")!!.let {
it.contains("bar") && it.contains("Dispatchers.Default")
}
}
.verifyComplete()
}
@Test
fun multipleContextProviders() {
assertThatIllegalStateException().isThrownBy {
coRouter {
context {
CoroutineName("foo")
}
context {
Dispatchers.Default
}
}
}
}
@Test @Test
fun attributes() { fun attributes() {
val visitor = AttributesTestVisitor() val visitor = AttributesTestVisitor()
...@@ -251,6 +290,25 @@ class CoRouterFunctionDslTests { ...@@ -251,6 +290,25 @@ class CoRouterFunctionDslTests {
} }
} }
private val routerWithContextProvider = coRouter {
context {
CoroutineName(it.headers().firstHeader("Custom-Header")!!)
}
GET("/") {
ok().header("context", currentCoroutineContext().toString()).buildAndAwait()
}
filter { request, next ->
if (request.headers().firstHeader("Custom-Header") == "bar") {
withContext(currentCoroutineContext() + Dispatchers.Default) {
next.invoke(request)
}
}
else {
next.invoke(request)
}
}
}
private val otherRouter = router { private val otherRouter = router {
"/other" { "/other" {
ok().build() ok().build()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册