提交 e6dcfbe8 编写于 作者: M Megvii Engine Team

fix(traced_module): fix traced module compatible issues

GitOrigin-RevId: 67e68ef5eae78d93a167d8d32ac78837932f3b45
上级 18f83a25
...@@ -99,11 +99,10 @@ def add_loader(expr): ...@@ -99,11 +99,10 @@ def add_loader(expr):
("megengine.module.batchnorm", "SyncBatchNorm"), ("megengine.module.batchnorm", "SyncBatchNorm"),
) )
def bn2d_module_loader(expr): def bn2d_module_loader(expr):
# mge 1.6 module = expr.inputs[0].owner
if not hasattr(expr, "version"): if hasattr(module, "param_dim"):
module = expr.inputs[0].owner assert module.param_dim == "dim_1c11"
if not hasattr(module, "param_dim"): delattr(module, "param_dim")
module.param_dim = "dim_1c11"
@register_module_loader( @register_module_loader(
...@@ -113,12 +112,10 @@ def bn2d_module_loader(expr): ...@@ -113,12 +112,10 @@ def bn2d_module_loader(expr):
("megengine.module.qat.conv_bn", "ConvBnRelu2d"), ("megengine.module.qat.conv_bn", "ConvBnRelu2d"),
) )
def convbn2d_module_loader(expr): def convbn2d_module_loader(expr):
# mge 1.6
if not hasattr(expr, "version"):
module = expr.inputs[0].owner
if not hasattr(module.bn, "param_dim"):
module.bn.param_dim = "dim_1c11"
module = expr.inputs[0].owner module = expr.inputs[0].owner
if hasattr(module.bn, "param_dim"):
assert module.bn.param_dim == "dim_1c11"
delattr(module.bn, "param_dim")
if not hasattr(module.conv, "padding_mode"): if not hasattr(module.conv, "padding_mode"):
module.conv.padding_mode = "zeros" module.conv.padding_mode = "zeros"
...@@ -167,6 +164,26 @@ def pad_func_loader(expr): ...@@ -167,6 +164,26 @@ def pad_func_loader(expr):
expr.set_args_kwargs(*expr.args, **kwargs) expr.set_args_kwargs(*expr.args, **kwargs)
@register_functional_loader(("megengine.functional.nn", "batch_norm"))
def bn_func_loader(expr):
kwargs = expr.kwargs
if "compute_mode" in kwargs:
assert kwargs["compute_mode"] == "default"
kwargs.pop("compute_mode")
if "param_dim" in kwargs:
assert kwargs["param_dim"] == "dim_1c11"
kwargs.pop("param_dim")
expr.set_args_kwargs(*expr.args, **kwargs)
@register_functional_loader(("megengine.functional.math", "matmul"))
def matmul_func_loader(expr):
args = expr.args
if len(args) == 6:
assert args[5] == "default"
expr.set_args_kwargs(*args[0:5])
@register_module_loader( @register_module_loader(
("megengine.module.conv", "Conv1d"), ("megengine.module.conv", "Conv1d"),
("megengine.module.conv", "Conv2d"), ("megengine.module.conv", "Conv2d"),
......
...@@ -17,11 +17,12 @@ def deprecated_func(version, origin, name, tbd): ...@@ -17,11 +17,12 @@ def deprecated_func(version, origin, name, tbd):
tbd: to be discussed, if true, ignore warnings tbd: to be discussed, if true, ignore warnings
""" """
should_warning = not tbd should_warning = not tbd
module = importlib.import_module(origin)
func = module.__getattribute__(name)
@wraps(func)
def wrapper(*args, **kwargs): def wrapper(*args, **kwargs):
nonlocal should_warning nonlocal should_warning
module = importlib.import_module(origin)
func = module.__getattribute__(name)
if should_warning: if should_warning:
warnings.warn( warnings.warn(
"Call to deprecated function {}. (use {}.{} instead) -- Deprecated since version {}.".format( "Call to deprecated function {}. (use {}.{} instead) -- Deprecated since version {}.".format(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册