diff --git a/imperative/python/megengine/traced_module/compat.py b/imperative/python/megengine/traced_module/compat.py index 1cbf87d842f5a0464d4cd3d1a851ea2b7e6a0e77..3941b470ce59353a1ff49baa05614d72ae8bc24c 100644 --- a/imperative/python/megengine/traced_module/compat.py +++ b/imperative/python/megengine/traced_module/compat.py @@ -99,11 +99,10 @@ def add_loader(expr): ("megengine.module.batchnorm", "SyncBatchNorm"), ) def bn2d_module_loader(expr): - # mge 1.6 - if not hasattr(expr, "version"): - module = expr.inputs[0].owner - if not hasattr(module, "param_dim"): - module.param_dim = "dim_1c11" + module = expr.inputs[0].owner + if hasattr(module, "param_dim"): + assert module.param_dim == "dim_1c11" + delattr(module, "param_dim") @register_module_loader( @@ -113,12 +112,10 @@ def bn2d_module_loader(expr): ("megengine.module.qat.conv_bn", "ConvBnRelu2d"), ) def convbn2d_module_loader(expr): - # mge 1.6 - if not hasattr(expr, "version"): - module = expr.inputs[0].owner - if not hasattr(module.bn, "param_dim"): - module.bn.param_dim = "dim_1c11" module = expr.inputs[0].owner + if hasattr(module.bn, "param_dim"): + assert module.bn.param_dim == "dim_1c11" + delattr(module.bn, "param_dim") if not hasattr(module.conv, "padding_mode"): module.conv.padding_mode = "zeros" @@ -167,6 +164,26 @@ def pad_func_loader(expr): expr.set_args_kwargs(*expr.args, **kwargs) +@register_functional_loader(("megengine.functional.nn", "batch_norm")) +def bn_func_loader(expr): + kwargs = expr.kwargs + if "compute_mode" in kwargs: + assert kwargs["compute_mode"] == "default" + kwargs.pop("compute_mode") + if "param_dim" in kwargs: + assert kwargs["param_dim"] == "dim_1c11" + kwargs.pop("param_dim") + expr.set_args_kwargs(*expr.args, **kwargs) + + +@register_functional_loader(("megengine.functional.math", "matmul")) +def matmul_func_loader(expr): + args = expr.args + if len(args) == 6: + assert args[5] == "default" + expr.set_args_kwargs(*args[0:5]) + + @register_module_loader( ("megengine.module.conv", "Conv1d"), ("megengine.module.conv", "Conv2d"), diff --git a/imperative/python/megengine/utils/deprecation.py b/imperative/python/megengine/utils/deprecation.py index ea58d71aaf2f016c82f3f1f604a764ada405c0e8..f9ed8bdf5660969acecf5285c7cf50949f4dd736 100644 --- a/imperative/python/megengine/utils/deprecation.py +++ b/imperative/python/megengine/utils/deprecation.py @@ -17,11 +17,12 @@ def deprecated_func(version, origin, name, tbd): tbd: to be discussed, if true, ignore warnings """ should_warning = not tbd + module = importlib.import_module(origin) + func = module.__getattribute__(name) + @wraps(func) def wrapper(*args, **kwargs): nonlocal should_warning - module = importlib.import_module(origin) - func = module.__getattribute__(name) if should_warning: warnings.warn( "Call to deprecated function {}. (use {}.{} instead) -- Deprecated since version {}.".format(