提交 c0b9e071 编写于 作者: M Megvii Engine Team

fix(traced_module): fix compatible test and fix functional compatiblity

GitOrigin-RevId: 5824d232b36199d1b65d779fd442bb039c4ede6a
上级 2c48dc22
......@@ -229,7 +229,11 @@ def square_func_loader(expr):
@register_functional_loader(("megengine.functional.math", "topk"))
def topk_loader(expr):
if not hasattr(expr, "version"): # for mge 1.6
import pkg_resources as pkg
if not hasattr(expr, "version") or pkg.parse_version(
expr.version
) <= pkg.parse_version("1.12.0"):
def origin_topk_signature(
inp, k, descending=False, kth_only=False, no_sort=False
......@@ -260,7 +264,7 @@ def arange_func_loader(expr):
if len(args) == 5:
device = args[-1]
dtype = args[-2]
args = args[: len(args) - 2]
args = args[:-2]
kwargs["dtype"] = dtype
kwargs["device"] = device
......@@ -268,6 +272,24 @@ def arange_func_loader(expr):
expr.set_args_kwargs(*args, **kwargs)
@register_functional_loader(("megengine.functional.tensor", "linspace"))
def linespace_loader(expr):
args, kwargs = expr.args, expr.kwargs
if not hasattr(expr, "version"):
def orig_linspace_signature(start, stop, num, dtype="float32", device=None):
pass
args, kwargs = _convert_kwargs_to_args(
orig_linspace_signature, expr.args, expr.kwargs
)
expr.set_args_kwargs(*args, **kwargs)
if len(args) == 5:
new_args = args[0:-2]
new_kwargs = {"dtype": args[-2], "device": args[-1]}
expr.set_args_kwargs(*new_args, **new_kwargs)
@register_functional_loader(("megengine.functional.tensor", "full"))
def full_func_loader(expr):
kwargs = expr.kwargs
......@@ -290,3 +312,80 @@ def full_func_loader(expr):
kwargs["device"] = device
expr.set_args_kwargs(*args, **kwargs)
@register_functional_loader(("megengine.functional.nn", "conv_transpose2d"))
def deconv_loader(expr):
args, kwargs = expr.args, expr.kwargs
if not hasattr(expr, "version"):
def orig_conv_transpose2d_signature(
inp,
weight,
bias=None,
stride=1,
padding=0,
dilation=1,
groups=1,
conv_mode="cross_correlation",
compute_mode="default",
):
pass
args, kwargs = _convert_kwargs_to_args(
orig_conv_transpose2d_signature, expr.args, expr.kwargs
)
expr.set_args_kwargs(*args, **kwargs)
if len(args) == 9:
args = list(args)
args.insert(4, 0) # output padding = 0
expr.set_args_kwargs(*args, **kwargs)
@register_functional_loader(("megengine.functional.quantized", "conv_transpose2d"))
def deconv_loader(expr):
args, kwargs = expr.args, expr.kwargs
if not hasattr(expr, "version"):
def orig_conv_transpose2d_signature(
inp,
weight,
bias=None,
dtype=None,
stride=1,
padding=0,
dilation=1,
groups=1,
conv_mode="cross_correlation",
compute_mode="default",
):
pass
args, kwargs = _convert_kwargs_to_args(
orig_conv_transpose2d_signature, expr.args, expr.kwargs
)
expr.set_args_kwargs(*args, **kwargs)
if len(args) == 10:
args = list(args)
args.insert(5, 0) # output padding = 0
expr.set_args_kwargs(*args, **kwargs)
@register_functional_loader(("megengine.functional.nn", "conv_transpose3d"))
def deconv3d_loader(expr):
args, kwargs = expr.args, expr.kwargs
if not hasattr(expr, "version"):
def origin_conv_transpose3d_signature(
inp, weight, bias=None, stride=1, padding=0, dilation=1, groups=1,
):
pass
args, kwargs = _convert_kwargs_to_args(
origin_conv_transpose3d_signature, expr.args, expr.kwargs
)
expr.set_args_kwargs(*args, **kwargs)
if len(args) == 7:
args = list(args)
args.insert(4, 0)
expr.set_args_kwargs(*args, **kwargs)
......@@ -328,6 +328,12 @@ class LeafDef(TreeDef):
assert isinstance(leaves[0], self.type), self.type
return leaves[0]
def __setstate__(self, state):
for k, v in state.items():
setattr(self, k, v)
if hasattr(self, "const_val") and isinstance(self.const_val, np.dtype):
self.type = _leaf_type(self.const_val)
def __ne__(self, other) -> bool:
return not self.__eq__(other)
......
......@@ -56,6 +56,9 @@ def _convert_kwargs_to_args(
argspecs: Union[Callable, FullArgSpec], args, kwargs, is_bounded=False
):
# is_bounded = True when func is a method and provided args don't include 'self'
if isinstance(argspecs, Callable) and hasattr(argspecs, "__wrapped__"):
argspecs = inspect.unwrap(argspecs)
arg_specs = (
inspect.getfullargspec(argspecs) if isinstance(argspecs, Callable) else argspecs
)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册