提交 d23d1352 编写于 作者: M Megvii Engine Team

fix(imperative/python): add the default warning for args descending

GitOrigin-RevId: cb5f065e6ca7e3d18f39e95966316d0a2110d499
上级 5e80f021
......@@ -22,6 +22,7 @@ from ..core.tensor import amp
from ..core.tensor.utils import _normalize_axis, cast_tensors, setscalar, subgraph
from ..jit import exclude_from_trace
from ..tensor import Tensor
from ..utils.deprecation import deprecated_kwargs_default
from .debug_param import get_execution_strategy
from .elemwise import clip, minimum
from .tensor import broadcast_to, concat, expand_dims, squeeze
......@@ -684,6 +685,7 @@ def sort(inp: Tensor, descending: bool = False) -> Tuple[Tensor, Tensor]:
return tns, ind
@deprecated_kwargs_default("1.12", "descending", 3)
def topk(
inp: Tensor,
k: int,
......@@ -712,7 +714,7 @@ def topk(
import megengine.functional as F
x = tensor(np.array([2, 4, 6, 8, 7, 5, 3, 1], dtype=np.float32))
top, indices = F.topk(x, 5)
top, indices = F.topk(x, 5, descending=False)
print(top.numpy(), indices.numpy())
Outputs:
......
......@@ -7,9 +7,12 @@
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import importlib
import warnings
from functools import wraps
from deprecated.sphinx import deprecated
warnings.filterwarnings(action="default", module="megengine")
def deprecated_func(version, origin, name, tbd):
r"""
......@@ -27,16 +30,39 @@ def deprecated_func(version, origin, name, tbd):
module = importlib.import_module(origin)
func = module.__getattribute__(name)
if should_warning:
with warnings.catch_warnings():
warnings.simplefilter(action="always")
warnings.warn(
"Call to deprecated function {}. (use {}.{} instead) -- Deprecated since version {}.".format(
name, origin, name, version
),
category=DeprecationWarning,
stacklevel=2,
)
return func(*args, **kwargs)
return wrapper
def deprecated_kwargs_default(version, kwargs_name, kwargs_pos):
r"""
Args:
version: version to deprecate this default
kwargs_name: kwargs name
kwargs_pos: kwargs position
"""
def deprecated(func):
@wraps(func)
def wrapper(*args, **kwargs):
if len(args) < kwargs_pos and kwargs_name not in kwargs:
warnings.warn(
"Call to deprecated function {}. (use {}.{} instead) -- Deprecated since version {}.".format(
name, origin, name, version
"the default behavior for {} will be changed in version {}, please use it in keyword parameter way".format(
kwargs_name, version
),
category=DeprecationWarning,
category=PendingDeprecationWarning,
stacklevel=2,
)
should_warning = False
return func(*args, **kwargs)
return func(*args, **kwargs)
return wrapper
return wrapper
return deprecated
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册