From d23d1352e77342618da7d2fb4cc0c4903caf6c6a Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Tue, 21 Dec 2021 13:57:31 +0800 Subject: [PATCH] fix(imperative/python): add the default warning for args descending GitOrigin-RevId: cb5f065e6ca7e3d18f39e95966316d0a2110d499 --- .../python/megengine/functional/math.py | 4 +- .../python/megengine/utils/deprecation.py | 42 +++++++++++++++---- 2 files changed, 37 insertions(+), 9 deletions(-) diff --git a/imperative/python/megengine/functional/math.py b/imperative/python/megengine/functional/math.py index facd8206e..9ddf15cd0 100644 --- a/imperative/python/megengine/functional/math.py +++ b/imperative/python/megengine/functional/math.py @@ -22,6 +22,7 @@ from ..core.tensor import amp from ..core.tensor.utils import _normalize_axis, cast_tensors, setscalar, subgraph from ..jit import exclude_from_trace from ..tensor import Tensor +from ..utils.deprecation import deprecated_kwargs_default from .debug_param import get_execution_strategy from .elemwise import clip, minimum from .tensor import broadcast_to, concat, expand_dims, squeeze @@ -684,6 +685,7 @@ def sort(inp: Tensor, descending: bool = False) -> Tuple[Tensor, Tensor]: return tns, ind +@deprecated_kwargs_default("1.12", "descending", 3) def topk( inp: Tensor, k: int, @@ -712,7 +714,7 @@ def topk( import megengine.functional as F x = tensor(np.array([2, 4, 6, 8, 7, 5, 3, 1], dtype=np.float32)) - top, indices = F.topk(x, 5) + top, indices = F.topk(x, 5, descending=False) print(top.numpy(), indices.numpy()) Outputs: diff --git a/imperative/python/megengine/utils/deprecation.py b/imperative/python/megengine/utils/deprecation.py index 42b510c2e..5e2da0bf4 100644 --- a/imperative/python/megengine/utils/deprecation.py +++ b/imperative/python/megengine/utils/deprecation.py @@ -7,9 +7,12 @@ # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. import importlib import warnings +from functools import wraps from deprecated.sphinx import deprecated +warnings.filterwarnings(action="default", module="megengine") + def deprecated_func(version, origin, name, tbd): r""" @@ -27,16 +30,39 @@ def deprecated_func(version, origin, name, tbd): module = importlib.import_module(origin) func = module.__getattribute__(name) if should_warning: - with warnings.catch_warnings(): - warnings.simplefilter(action="always") + warnings.warn( + "Call to deprecated function {}. (use {}.{} instead) -- Deprecated since version {}.".format( + name, origin, name, version + ), + category=DeprecationWarning, + stacklevel=2, + ) + return func(*args, **kwargs) + + return wrapper + + +def deprecated_kwargs_default(version, kwargs_name, kwargs_pos): + r""" + Args: + version: version to deprecate this default + kwargs_name: kwargs name + kwargs_pos: kwargs position + """ + + def deprecated(func): + @wraps(func) + def wrapper(*args, **kwargs): + if len(args) < kwargs_pos and kwargs_name not in kwargs: warnings.warn( - "Call to deprecated function {}. (use {}.{} instead) -- Deprecated since version {}.".format( - name, origin, name, version + "the default behavior for {} will be changed in version {}, please use it in keyword parameter way".format( + kwargs_name, version ), - category=DeprecationWarning, + category=PendingDeprecationWarning, stacklevel=2, ) - should_warning = False - return func(*args, **kwargs) + return func(*args, **kwargs) - return wrapper + return wrapper + + return deprecated -- GitLab