diff --git a/imperative/python/megengine/traced_module/compat.py b/imperative/python/megengine/traced_module/compat.py index acd4afa0dfad1133708b5bb934bfbdc0acb1d2bb..9350b1bdaddb5b1315dc9040c657a9cb8c8ba59f 100644 --- a/imperative/python/megengine/traced_module/compat.py +++ b/imperative/python/megengine/traced_module/compat.py @@ -8,7 +8,8 @@ import numpy as np -from .. import tensor +from megengine.functional.tensor import zeros + from ..core.ops.builtin import BatchNorm from .expr import CallMethod, Constant from .node import TensorNode @@ -135,3 +136,29 @@ def bn_opdef_loader(expr): output = expr.outputs[-1] oup = TensorNode(expr, shape=(0,), dtype=None, qparams=output._qparams,) expr.outputs.insert(4, oup) + + +@register_functional_loader( + ("megengine.functional.tensor", "ones"), ("megengine.functional.tensor", "zeros") +) +def tensor_gen_func_loader(expr): + if hasattr(expr, "version") and expr.version == "1.7.0": + expr.set_args_kwargs(expr.args[0], dtype=expr.args[1], device=expr.args[2]) + if not hasattr(expr, "version"): + # compatiable for version 1.6 + shape = expr.args[0] if len(expr.args) > 0 else expr.kwargs["shape"] + + if len(expr.args) > 1: + dtype = expr.args[1] + elif "dtype" in expr.kwargs: + dtype = expr.kwargs["dtype"] + else: + dtype = "float32" + + if len(expr.args) > 2: + device = expr.args[2] + elif "device" in expr.kwargs: + device = expr.kwargs["device"] + else: + device = None + expr.set_args_kwargs(shape, dtype=dtype, device=device)