提交 e6dcfbe8 编写于 作者: M Megvii Engine Team

fix(traced_module): fix traced module compatible issues

GitOrigin-RevId: 67e68ef5eae78d93a167d8d32ac78837932f3b45
上级 18f83a25
......@@ -99,11 +99,10 @@ def add_loader(expr):
("megengine.module.batchnorm", "SyncBatchNorm"),
)
def bn2d_module_loader(expr):
# mge 1.6
if not hasattr(expr, "version"):
module = expr.inputs[0].owner
if not hasattr(module, "param_dim"):
module.param_dim = "dim_1c11"
module = expr.inputs[0].owner
if hasattr(module, "param_dim"):
assert module.param_dim == "dim_1c11"
delattr(module, "param_dim")
@register_module_loader(
......@@ -113,12 +112,10 @@ def bn2d_module_loader(expr):
("megengine.module.qat.conv_bn", "ConvBnRelu2d"),
)
def convbn2d_module_loader(expr):
# mge 1.6
if not hasattr(expr, "version"):
module = expr.inputs[0].owner
if not hasattr(module.bn, "param_dim"):
module.bn.param_dim = "dim_1c11"
module = expr.inputs[0].owner
if hasattr(module.bn, "param_dim"):
assert module.bn.param_dim == "dim_1c11"
delattr(module.bn, "param_dim")
if not hasattr(module.conv, "padding_mode"):
module.conv.padding_mode = "zeros"
......@@ -167,6 +164,26 @@ def pad_func_loader(expr):
expr.set_args_kwargs(*expr.args, **kwargs)
@register_functional_loader(("megengine.functional.nn", "batch_norm"))
def bn_func_loader(expr):
kwargs = expr.kwargs
if "compute_mode" in kwargs:
assert kwargs["compute_mode"] == "default"
kwargs.pop("compute_mode")
if "param_dim" in kwargs:
assert kwargs["param_dim"] == "dim_1c11"
kwargs.pop("param_dim")
expr.set_args_kwargs(*expr.args, **kwargs)
@register_functional_loader(("megengine.functional.math", "matmul"))
def matmul_func_loader(expr):
args = expr.args
if len(args) == 6:
assert args[5] == "default"
expr.set_args_kwargs(*args[0:5])
@register_module_loader(
("megengine.module.conv", "Conv1d"),
("megengine.module.conv", "Conv2d"),
......
......@@ -17,11 +17,12 @@ def deprecated_func(version, origin, name, tbd):
tbd: to be discussed, if true, ignore warnings
"""
should_warning = not tbd
module = importlib.import_module(origin)
func = module.__getattribute__(name)
@wraps(func)
def wrapper(*args, **kwargs):
nonlocal should_warning
module = importlib.import_module(origin)
func = module.__getattribute__(name)
if should_warning:
warnings.warn(
"Call to deprecated function {}. (use {}.{} instead) -- Deprecated since version {}.".format(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册