未验证 提交 7f87f3dd 编写于 作者: P pgjones 提交者: David Lord

Simplify the async handling code

Firstly `run_sync` was a misleading name as it didn't run anything,
instead I think `async_to_sync` is much clearer as it converts a
coroutine function to a function. (Name stolen from asgiref).

Secondly trying to run the ensure_sync during registration made the
code more complex and brittle, e.g. the _flask_async_wrapper
usage. This was done to pay any setup costs during registration rather
than runtime, however this only saved a iscoroutne check. It allows
the weirdness of the Blueprint and Scaffold ensure_sync methods to be
removed.

Switching to runtime ensure_sync usage provides a method for
extensions to also support async, as now documented.
上级 cb13128c
......@@ -92,6 +92,21 @@ not work with async views because they will not await the function or be
awaitable. Other functions they provide will not be awaitable either and
will probably be blocking if called within an async view.
Extension authors can support async functions by utilising the
:meth:`flask.Flask.ensure_sync` method. For example, if the extension
provides a view function decorator add ``ensure_sync`` before calling
the decorated function,
.. code-block:: python
def extension(func):
@wraps(func)
def wrapper(*args, **kwargs):
... # Extension logic
return current_app.ensure_sync(func)(*args, **kwargs)
return wrapper
Check the changelog of the extension you want to use to see if they've
implemented async support, or make a feature request or PR to them.
......
......@@ -35,12 +35,12 @@ from .globals import _request_ctx_stack
from .globals import g
from .globals import request
from .globals import session
from .helpers import async_to_sync
from .helpers import get_debug_flag
from .helpers import get_env
from .helpers import get_flashed_messages
from .helpers import get_load_dotenv
from .helpers import locked_cached_property
from .helpers import run_async
from .helpers import url_for
from .json import jsonify
from .logging import create_logger
......@@ -1080,14 +1080,12 @@ class Flask(Scaffold):
self.url_map.add(rule)
if view_func is not None:
old_func = self.view_functions.get(endpoint)
if getattr(old_func, "_flask_sync_wrapper", False):
old_func = old_func.__wrapped__ # type: ignore
if old_func is not None and old_func != view_func:
raise AssertionError(
"View function mapping is overwriting an existing"
f" endpoint function: {endpoint}"
)
self.view_functions[endpoint] = self.ensure_sync(view_func)
self.view_functions[endpoint] = view_func
@setupmethod
def template_filter(self, name: t.Optional[str] = None) -> t.Callable:
......@@ -1208,7 +1206,7 @@ class Flask(Scaffold):
.. versionadded:: 0.8
"""
self.before_first_request_funcs.append(self.ensure_sync(f))
self.before_first_request_funcs.append(f)
return f
@setupmethod
......@@ -1241,7 +1239,7 @@ class Flask(Scaffold):
.. versionadded:: 0.9
"""
self.teardown_appcontext_funcs.append(self.ensure_sync(f))
self.teardown_appcontext_funcs.append(f)
return f
@setupmethod
......@@ -1308,7 +1306,7 @@ class Flask(Scaffold):
handler = self._find_error_handler(e)
if handler is None:
return e
return handler(e)
return self.ensure_sync(handler)(e)
def trap_http_exception(self, e: Exception) -> bool:
"""Checks if an HTTP exception should be trapped or not. By default
......@@ -1375,7 +1373,7 @@ class Flask(Scaffold):
if handler is None:
raise
return handler(e)
return self.ensure_sync(handler)(e)
def handle_exception(self, e: Exception) -> Response:
"""Handle an exception that did not have an error handler
......@@ -1422,7 +1420,7 @@ class Flask(Scaffold):
handler = self._find_error_handler(server_error)
if handler is not None:
server_error = handler(server_error)
server_error = self.ensure_sync(handler)(server_error)
return self.finalize_request(server_error, from_error_handler=True)
......@@ -1484,7 +1482,7 @@ class Flask(Scaffold):
):
return self.make_default_options_response()
# otherwise dispatch to the handler for that endpoint
return self.view_functions[rule.endpoint](**req.view_args)
return self.ensure_sync(self.view_functions[rule.endpoint])(**req.view_args)
def full_dispatch_request(self) -> Response:
"""Dispatches the request and on top of that performs request
......@@ -1545,7 +1543,7 @@ class Flask(Scaffold):
if self._got_first_request:
return
for func in self.before_first_request_funcs:
func()
self.ensure_sync(func)()
self._got_first_request = True
def make_default_options_response(self) -> Response:
......@@ -1581,7 +1579,7 @@ class Flask(Scaffold):
.. versionadded:: 2.0
"""
if iscoroutinefunction(func):
return run_async(func)
return async_to_sync(func)
return func
......@@ -1807,7 +1805,7 @@ class Flask(Scaffold):
if bp in self.before_request_funcs:
funcs = chain(funcs, self.before_request_funcs[bp])
for func in funcs:
rv = func()
rv = self.ensure_sync(func)()
if rv is not None:
return rv
......@@ -1834,7 +1832,7 @@ class Flask(Scaffold):
if None in self.after_request_funcs:
funcs = chain(funcs, reversed(self.after_request_funcs[None]))
for handler in funcs:
response = handler(response)
response = self.ensure_sync(handler)(response)
if not self.session_interface.is_null_session(ctx.session):
self.session_interface.save_session(self, ctx.session, response)
return response
......@@ -1871,7 +1869,7 @@ class Flask(Scaffold):
if bp in self.teardown_request_funcs:
funcs = chain(funcs, reversed(self.teardown_request_funcs[bp]))
for func in funcs:
func(exc)
self.ensure_sync(func)(exc)
request_tearing_down.send(self, exc=exc)
def do_teardown_appcontext(
......@@ -1894,7 +1892,7 @@ class Flask(Scaffold):
if exc is _sentinel:
exc = sys.exc_info()[1]
for func in reversed(self.teardown_appcontext_funcs):
func(exc)
self.ensure_sync(func)(exc)
appcontext_tearing_down.send(self, exc=exc)
def app_context(self) -> AppContext:
......
......@@ -292,13 +292,10 @@ class Blueprint(Scaffold):
# Merge blueprint data into parent.
if first_registration:
def extend(bp_dict, parent_dict, ensure_sync=False):
def extend(bp_dict, parent_dict):
for key, values in bp_dict.items():
key = self.name if key is None else f"{self.name}.{key}"
if ensure_sync:
values = [app.ensure_sync(func) for func in values]
parent_dict[key].extend(values)
for key, value in self.error_handler_spec.items():
......@@ -307,8 +304,7 @@ class Blueprint(Scaffold):
dict,
{
code: {
exc_class: app.ensure_sync(func)
for exc_class, func in code_values.items()
exc_class: func for exc_class, func in code_values.items()
}
for code, code_values in value.items()
},
......@@ -316,16 +312,13 @@ class Blueprint(Scaffold):
app.error_handler_spec[key] = value
for endpoint, func in self.view_functions.items():
app.view_functions[endpoint] = app.ensure_sync(func)
app.view_functions[endpoint] = func
extend(
self.before_request_funcs, app.before_request_funcs, ensure_sync=True
)
extend(self.after_request_funcs, app.after_request_funcs, ensure_sync=True)
extend(self.before_request_funcs, app.before_request_funcs)
extend(self.after_request_funcs, app.after_request_funcs)
extend(
self.teardown_request_funcs,
app.teardown_request_funcs,
ensure_sync=True,
)
extend(self.url_default_functions, app.url_default_functions)
extend(self.url_value_preprocessors, app.url_value_preprocessors)
......@@ -478,9 +471,7 @@ class Blueprint(Scaffold):
before each request, even if outside of a blueprint.
"""
self.record_once(
lambda s: s.app.before_request_funcs.setdefault(None, []).append(
s.app.ensure_sync(f)
)
lambda s: s.app.before_request_funcs.setdefault(None, []).append(f)
)
return f
......@@ -490,9 +481,7 @@ class Blueprint(Scaffold):
"""Like :meth:`Flask.before_first_request`. Such a function is
executed before the first request to the application.
"""
self.record_once(
lambda s: s.app.before_first_request_funcs.append(s.app.ensure_sync(f))
)
self.record_once(lambda s: s.app.before_first_request_funcs.append(f))
return f
def after_app_request(self, f: AfterRequestCallable) -> AfterRequestCallable:
......@@ -500,9 +489,7 @@ class Blueprint(Scaffold):
is executed after each request, even if outside of the blueprint.
"""
self.record_once(
lambda s: s.app.after_request_funcs.setdefault(None, []).append(
s.app.ensure_sync(f)
)
lambda s: s.app.after_request_funcs.setdefault(None, []).append(f)
)
return f
......@@ -553,14 +540,3 @@ class Blueprint(Scaffold):
lambda s: s.app.url_default_functions.setdefault(None, []).append(f)
)
return f
def ensure_sync(self, f: t.Callable) -> t.Callable:
"""Ensure the function is synchronous.
Override if you would like custom async to sync behaviour in
this blueprint. Otherwise the app's
:meth:`~flask.Flask.ensure_sync` is used.
.. versionadded:: 2.0
"""
return f
......@@ -6,7 +6,6 @@ import typing as t
import warnings
from datetime import timedelta
from functools import update_wrapper
from functools import wraps
from threading import RLock
import werkzeug.utils
......@@ -803,10 +802,15 @@ def is_ip(value: str) -> bool:
return False
def run_async(func: t.Callable[..., t.Coroutine]) -> t.Callable[..., t.Any]:
"""Return a sync function that will run the coroutine function *func*."""
def async_to_sync(func: t.Callable[..., t.Coroutine]) -> t.Callable[..., t.Any]:
"""Return a sync function that will run the coroutine function *func*.
This can be used as so
result = async_to_async(func)(*args, **kwargs)
"""
try:
from asgiref.sync import async_to_sync
from asgiref.sync import async_to_sync as asgiref_async_to_sync
except ImportError:
raise RuntimeError(
"Install Flask with the 'async' extra in order to use async views."
......@@ -818,9 +822,4 @@ def run_async(func: t.Callable[..., t.Coroutine]) -> t.Callable[..., t.Any]:
"Async cannot be used with this combination of Python & Greenlet versions."
)
@wraps(func)
def wrapper(*args: t.Any, **kwargs: t.Any) -> t.Any:
return async_to_sync(func)(*args, **kwargs)
wrapper._flask_sync_wrapper = True # type: ignore
return wrapper
return asgiref_async_to_sync(func)
......@@ -521,7 +521,7 @@ class Scaffold:
"""
def decorator(f):
self.view_functions[endpoint] = self.ensure_sync(f)
self.view_functions[endpoint] = f
return f
return decorator
......@@ -545,7 +545,7 @@ class Scaffold:
return value from the view, and further request handling is
stopped.
"""
self.before_request_funcs.setdefault(None, []).append(self.ensure_sync(f))
self.before_request_funcs.setdefault(None, []).append(f)
return f
@setupmethod
......@@ -561,7 +561,7 @@ class Scaffold:
should not be used for actions that must execute, such as to
close resources. Use :meth:`teardown_request` for that.
"""
self.after_request_funcs.setdefault(None, []).append(self.ensure_sync(f))
self.after_request_funcs.setdefault(None, []).append(f)
return f
@setupmethod
......@@ -600,7 +600,7 @@ class Scaffold:
debugger can still access it. This behavior can be controlled
by the ``PRESERVE_CONTEXT_ON_EXCEPTION`` configuration variable.
"""
self.teardown_request_funcs.setdefault(None, []).append(self.ensure_sync(f))
self.teardown_request_funcs.setdefault(None, []).append(f)
return f
@setupmethod
......@@ -706,7 +706,7 @@ class Scaffold:
" instead."
)
self.error_handler_spec[None][code][exc_class] = self.ensure_sync(f)
self.error_handler_spec[None][code][exc_class] = f
@staticmethod
def _get_exc_class_and_code(
......@@ -734,9 +734,6 @@ class Scaffold:
else:
return exc_class, None
def ensure_sync(self, func: t.Callable) -> t.Callable:
raise NotImplementedError()
def _endpoint_from_view_func(view_func: t.Callable) -> str:
"""Internal helper that returns the default endpoint for a given
......
......@@ -6,7 +6,7 @@ import pytest
from flask import Blueprint
from flask import Flask
from flask import request
from flask.helpers import run_async
from flask.helpers import async_to_sync
pytest.importorskip("asgiref")
......@@ -137,4 +137,4 @@ def test_async_before_after_request():
@pytest.mark.skipif(sys.version_info >= (3, 7), reason="should only raise Python < 3.7")
def test_async_runtime_error():
with pytest.raises(RuntimeError):
run_async(None)
async_to_sync(None)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册