提交 2c48dc22 编写于 作者: M Megvii Engine Team

fix(mge/traced_module): fix function compatible issue

GitOrigin-RevId: 31d28dd2f336a8c3940c482d0a5c576d459ca683
上级 1a987b7b
...@@ -11,6 +11,7 @@ from .serialization import ( ...@@ -11,6 +11,7 @@ from .serialization import (
register_opdef_loader, register_opdef_loader,
register_tensor_method_loader, register_tensor_method_loader,
) )
from .utils import _convert_kwargs_to_args
""" """
...@@ -224,3 +225,68 @@ def square_func_loader(expr): ...@@ -224,3 +225,68 @@ def square_func_loader(expr):
astype_expr.set_args_kwargs(oup, "float32") astype_expr.set_args_kwargs(oup, "float32")
orig_oup.expr = astype_expr orig_oup.expr = astype_expr
astype_expr.return_val = (orig_oup,) astype_expr.return_val = (orig_oup,)
@register_functional_loader(("megengine.functional.math", "topk"))
def topk_loader(expr):
if not hasattr(expr, "version"): # for mge 1.6
def origin_topk_signature(
inp, k, descending=False, kth_only=False, no_sort=False
):
pass
args, kwargs = _convert_kwargs_to_args(
origin_topk_signature, expr.args, expr.kwargs
)
expr.set_args_kwargs(*args, **kwargs)
@register_functional_loader(("megengine.functional.tensor", "arange"))
def arange_func_loader(expr):
kwargs = expr.kwargs
args = expr.args
if not hasattr(expr, "version"):
def origin_arange_signature(
start=0, stop=None, step=1, dtype="float32", device=None
):
pass
args, kwargs = _convert_kwargs_to_args(
origin_arange_signature, expr.args, expr.kwargs
)
expr.set_args_kwargs(*args, **kwargs)
if len(args) == 5:
device = args[-1]
dtype = args[-2]
args = args[: len(args) - 2]
kwargs["dtype"] = dtype
kwargs["device"] = device
expr.set_args_kwargs(*args, **kwargs)
@register_functional_loader(("megengine.functional.tensor", "full"))
def full_func_loader(expr):
kwargs = expr.kwargs
args = expr.args
if not hasattr(expr, "version"):
def orig_full_signature(shape, value, dtype=None, device=None):
pass
args, kwargs = _convert_kwargs_to_args(
orig_full_signature, expr.args, expr.kwargs
)
expr.set_args_kwargs(*args, **kwargs)
if len(args) == 4:
device = args[-1]
dtype = args[-2]
args = args[: len(args) - 2]
kwargs["dtype"] = dtype
kwargs["device"] = device
expr.set_args_kwargs(*args, **kwargs)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册