提交 7ecdbf25 编写于 作者: M Megvii Engine Team

fix(traced_module): fix ones/zeros functional compatiable

GitOrigin-RevId: 7ec2c4d3f5e0233659c63568c9bd40adf4958002
上级 d24f198c
......@@ -8,7 +8,8 @@
import numpy as np
from .. import tensor
from megengine.functional.tensor import zeros
from ..core.ops.builtin import BatchNorm
from .expr import CallMethod, Constant
from .node import TensorNode
......@@ -135,3 +136,29 @@ def bn_opdef_loader(expr):
output = expr.outputs[-1]
oup = TensorNode(expr, shape=(0,), dtype=None, qparams=output._qparams,)
expr.outputs.insert(4, oup)
@register_functional_loader(
("megengine.functional.tensor", "ones"), ("megengine.functional.tensor", "zeros")
)
def tensor_gen_func_loader(expr):
if hasattr(expr, "version") and expr.version == "1.7.0":
expr.set_args_kwargs(expr.args[0], dtype=expr.args[1], device=expr.args[2])
if not hasattr(expr, "version"):
# compatiable for version 1.6
shape = expr.args[0] if len(expr.args) > 0 else expr.kwargs["shape"]
if len(expr.args) > 1:
dtype = expr.args[1]
elif "dtype" in expr.kwargs:
dtype = expr.kwargs["dtype"]
else:
dtype = "float32"
if len(expr.args) > 2:
device = expr.args[2]
elif "device" in expr.kwargs:
device = expr.kwargs["device"]
else:
device = None
expr.set_args_kwargs(shape, dtype=dtype, device=device)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册