Commit 300b048a authored by Megvii Engine Team

fix(mge/dump): fix dump_with_testcase_mge Varnode type mismatch

GitOrigin-RevId: 05618e5ac5f8e3170698c2716f6ae2f5d11a7326
Parent f9ed8d71
......@@ -20,7 +20,6 @@ import megengine.core.tensor.megbrain_graph as G
 from megengine import tensor
 from megengine.core._imperative_rt.core2 import apply
 from megengine.core.ops import builtin
-from megengine.core.tensor.megbrain_graph import VarNode
 from megengine.utils import comp_graph_tools as cgtools

 logger = mge.get_logger(__name__)
......@@ -268,8 +267,8 @@ def make_feeds(args):
     def assert_equal(expect, real, **kwargs):
         op = builtin.AssertEqual(**kwargs)
-        (res,) = G.apply_normal_varnode(op, expect, real)
-        return G.VarNode(res)
+        (res,) = apply(op, expect, real)
+        return res

     verbose = not args.silent
......@@ -284,8 +283,8 @@ def make_feeds(args):
         # insert assert opr to check expect and real.
         outputs_new.append(
             assert_equal(
-                G.VarNode(expect_get),
-                G.VarNode(i),
+                expect_get,
+                i,
                 verbose=verbose,
                 maxerr=args.maxerr,
             )
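
Taken together, these two hunks change assert_equal so that it no longer routes its operands and its result through G.VarNode. A minimal annotated sketch of the updated helper; the comments are added here, and the claim that apply keeps the variable type of its inputs is an assumption based on the commit title rather than something stated in the diff:

from megengine.core._imperative_rt.core2 import apply
from megengine.core.ops import builtin


def assert_equal(expect, real, **kwargs):
    # kwargs carries the checker options the script passes in,
    # e.g. verbose=... and maxerr=....
    op = builtin.AssertEqual(**kwargs)
    # apply() is assumed to return a result of the same variable type as
    # `expect` and `real`, so the old G.VarNode(...) re-wrapping (the likely
    # source of the type mismatch) is no longer needed.
    (res,) = apply(op, expect, real)
    return res
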
......@@ -297,29 +296,26 @@ def make_feeds(args):
 def optimize_for_inference(args, outputs):
-    args_map = {
-        "enable_io16xc32": "f16_io_f32_comp",
-        "enable_ioc16": "f16_io_comp",
-        "enable_hwcd4": "use_nhwcd4",
-        "enable_nchw4": "use_nchw4",
-        "enable_nchw88": "use_nchw88",
-        "enable_nchw44": "use_nchw44",
-        "enable_nchw44_dot": "use_nchw44_dot",
-        "enable_nchw32": "use_nchw32",
-        "enable_chwn4": "use_chwn4",
-        "enable_fuse_conv_bias_nonlinearity": "fuse_conv_bias_nonlinearity",
-        "enable_fuse_conv_bias_with_z": "fuse_conv_bias_with_z",
-    }
+    args_list = [
+        "enable_io16xc32",
+        "enable_ioc16",
+        "enable_hwcd4",
+        "enable_nchw4",
+        "enable_nchw88",
+        "enable_nchw44",
+        "enable_nchw44_dot",
+        "enable_nchw32",
+        "enable_chwn4",
+        "enable_fuse_conv_bias_nonlinearity",
+        "enable_fuse_conv_bias_with_z",
+    ]
     kwargs = {}
-    for k, v in args_map.items():
+    for k in args_list:
         if getattr(args, k):
-            assert (
-                args.optimize_for_inference
-            ), "optimize_for_inference should be set when {} is given".format(k)
-            kwargs[v] = True
+            kwargs[k] = True

     if args.optimize_for_inference:
-        outputs = [i._node for i in G.optimize_for_inference(outputs, **kwargs)]
+        outputs = G.optimize_for_inference(outputs, **kwargs)

     return outputs
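
The rewritten optimize_for_inference drops the flag-name translation table: the enable_* switches are assumed to be accepted by G.optimize_for_inference under their own names, and its return value is used as-is instead of being unwrapped through ._node. A small self-contained sketch of the flag collection; the helper name, the shortened flag list, and the Namespace are illustrative, not taken from the script:

import argparse

# A subset of the enable_* flags listed above; each one now maps to a
# G.optimize_for_inference keyword argument of the same name.
ARGS_LIST = ["enable_io16xc32", "enable_ioc16", "enable_nchw4"]


def collect_inference_kwargs(args):
    # Keep only the flags that were actually set.
    return {k: True for k in ARGS_LIST if getattr(args, k, False)}


# Illustrative stand-in for the parsed command-line arguments.
ns = argparse.Namespace(enable_io16xc32=True, enable_ioc16=False, enable_nchw4=False)
print(collect_inference_kwargs(ns))  # {'enable_io16xc32': True}

With args.optimize_for_inference set, these kwargs would then be forwarded as G.optimize_for_inference(outputs, **kwargs).
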
......@@ -476,7 +472,6 @@ def main():
     output_mgbvars = feeds["outputs"]
     output_mgbvars = optimize_for_inference(args, output_mgbvars)
-    output_mgbvars = [var._node for var in output_mgbvars]

     inputs = cgtools.get_dep_vars(output_mgbvars, "Host2DeviceCopy")
     inputs = sorted((i.name, i.dtype) for i in inputs)
......@@ -491,12 +486,8 @@ def main():
     with open(args.output, "wb") as fout:
         fout.write(b"mgbtest0")
         fout.write(struct.pack("I", len(feeds["testcases"])))
-        if isinstance(output_mgbvars, dict):
-            wrap_output_vars = dict([(i, VarNode(j)) for i, j in output_mgbvars])
-        else:
-            wrap_output_vars = [VarNode(i) for i in output_mgbvars]
         dump_content, stat = G.dump_graph(
-            wrap_output_vars,
+            output_mgbvars,
             append_json=True,
             strip_info_file=strip_info_file,
             **sereg_kwargs,
......
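
The final hunk removes the VarNode re-wrapping before serialization, so the output variables coming out of optimize_for_inference are handed to G.dump_graph unchanged. A rough sketch of that dump path; the helper name and the trailing write of the dumped bytes are illustrative, while the header writes and the G.dump_graph call mirror the lines above:

import struct

import megengine.core.tensor.megbrain_graph as G


def dump_with_testcases(output_vars, num_testcases, path):
    # output_vars go to G.dump_graph as-is; no G.VarNode(...) wrapping.
    dump_content, _ = G.dump_graph(output_vars, append_json=True)
    with open(path, "wb") as fout:
        fout.write(b"mgbtest0")                      # test-file magic
        fout.write(struct.pack("I", num_testcases))  # number of testcases
        fout.write(dump_content)                     # serialized graph

The per-testcase input data would presumably be appended after the graph, which is what the testcase count written into the header refers to.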