diff --git a/imperative/python/megengine/traced_module/compat.py b/imperative/python/megengine/traced_module/compat.py index 757b0d4ee5ee2bed8e3a2efccde525910bb76385..7e0e955ce599d8ad74468967836abf2d7f1c0c35 100644 --- a/imperative/python/megengine/traced_module/compat.py +++ b/imperative/python/megengine/traced_module/compat.py @@ -11,6 +11,7 @@ from .serialization import ( register_opdef_loader, register_tensor_method_loader, ) +from .utils import _convert_kwargs_to_args """ @@ -224,3 +225,68 @@ def square_func_loader(expr): astype_expr.set_args_kwargs(oup, "float32") orig_oup.expr = astype_expr astype_expr.return_val = (orig_oup,) + + +@register_functional_loader(("megengine.functional.math", "topk")) +def topk_loader(expr): + if not hasattr(expr, "version"): # for mge 1.6 + + def origin_topk_signature( + inp, k, descending=False, kth_only=False, no_sort=False + ): + pass + + args, kwargs = _convert_kwargs_to_args( + origin_topk_signature, expr.args, expr.kwargs + ) + expr.set_args_kwargs(*args, **kwargs) + + +@register_functional_loader(("megengine.functional.tensor", "arange")) +def arange_func_loader(expr): + kwargs = expr.kwargs + args = expr.args + if not hasattr(expr, "version"): + + def origin_arange_signature( + start=0, stop=None, step=1, dtype="float32", device=None + ): + pass + + args, kwargs = _convert_kwargs_to_args( + origin_arange_signature, expr.args, expr.kwargs + ) + expr.set_args_kwargs(*args, **kwargs) + if len(args) == 5: + device = args[-1] + dtype = args[-2] + args = args[: len(args) - 2] + + kwargs["dtype"] = dtype + kwargs["device"] = device + + expr.set_args_kwargs(*args, **kwargs) + + +@register_functional_loader(("megengine.functional.tensor", "full")) +def full_func_loader(expr): + kwargs = expr.kwargs + args = expr.args + if not hasattr(expr, "version"): + + def orig_full_signature(shape, value, dtype=None, device=None): + pass + + args, kwargs = _convert_kwargs_to_args( + orig_full_signature, expr.args, expr.kwargs + ) + expr.set_args_kwargs(*args, **kwargs) + if len(args) == 4: + device = args[-1] + dtype = args[-2] + args = args[: len(args) - 2] + + kwargs["dtype"] = dtype + kwargs["device"] = device + + expr.set_args_kwargs(*args, **kwargs)