From 2c48dc22ecbe8954854ad5a3c103b7df02e049e3 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Mon, 13 Feb 2023 13:30:00 +0800 Subject: [PATCH] fix(mge/traced_module): fix function compatible issue GitOrigin-RevId: 31d28dd2f336a8c3940c482d0a5c576d459ca683 --- .../python/megengine/traced_module/compat.py | 66 +++++++++++++++++++ 1 file changed, 66 insertions(+) diff --git a/imperative/python/megengine/traced_module/compat.py b/imperative/python/megengine/traced_module/compat.py index 757b0d4ee..7e0e955ce 100644 --- a/imperative/python/megengine/traced_module/compat.py +++ b/imperative/python/megengine/traced_module/compat.py @@ -11,6 +11,7 @@ from .serialization import ( register_opdef_loader, register_tensor_method_loader, ) +from .utils import _convert_kwargs_to_args """ @@ -224,3 +225,68 @@ def square_func_loader(expr): astype_expr.set_args_kwargs(oup, "float32") orig_oup.expr = astype_expr astype_expr.return_val = (orig_oup,) + + +@register_functional_loader(("megengine.functional.math", "topk")) +def topk_loader(expr): + if not hasattr(expr, "version"): # for mge 1.6 + + def origin_topk_signature( + inp, k, descending=False, kth_only=False, no_sort=False + ): + pass + + args, kwargs = _convert_kwargs_to_args( + origin_topk_signature, expr.args, expr.kwargs + ) + expr.set_args_kwargs(*args, **kwargs) + + +@register_functional_loader(("megengine.functional.tensor", "arange")) +def arange_func_loader(expr): + kwargs = expr.kwargs + args = expr.args + if not hasattr(expr, "version"): + + def origin_arange_signature( + start=0, stop=None, step=1, dtype="float32", device=None + ): + pass + + args, kwargs = _convert_kwargs_to_args( + origin_arange_signature, expr.args, expr.kwargs + ) + expr.set_args_kwargs(*args, **kwargs) + if len(args) == 5: + device = args[-1] + dtype = args[-2] + args = args[: len(args) - 2] + + kwargs["dtype"] = dtype + kwargs["device"] = device + + expr.set_args_kwargs(*args, **kwargs) + + +@register_functional_loader(("megengine.functional.tensor", "full")) +def full_func_loader(expr): + kwargs = expr.kwargs + args = expr.args + if not hasattr(expr, "version"): + + def orig_full_signature(shape, value, dtype=None, device=None): + pass + + args, kwargs = _convert_kwargs_to_args( + orig_full_signature, expr.args, expr.kwargs + ) + expr.set_args_kwargs(*args, **kwargs) + if len(args) == 4: + device = args[-1] + dtype = args[-2] + args = args[: len(args) - 2] + + kwargs["dtype"] = dtype + kwargs["device"] = device + + expr.set_args_kwargs(*args, **kwargs) -- GitLab