From b230e14685949894a795df591a655c82d5a8c366 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Tue, 11 May 2021 19:37:25 +0800 Subject: [PATCH] fix(mge/dump): fix dump_with_testcase_mge Varnode type mismatch GitOrigin-RevId: 05618e5ac5f8e3170698c2716f6ae2f5d11a7326 --- sdk/load-and-run/dump_with_testcase_mge.py | 51 +++++++++------------- 1 file changed, 21 insertions(+), 30 deletions(-) diff --git a/sdk/load-and-run/dump_with_testcase_mge.py b/sdk/load-and-run/dump_with_testcase_mge.py index 9aafe3f56..36c2bde90 100755 --- a/sdk/load-and-run/dump_with_testcase_mge.py +++ b/sdk/load-and-run/dump_with_testcase_mge.py @@ -20,7 +20,6 @@ import megengine.core.tensor.megbrain_graph as G from megengine import tensor from megengine.core._imperative_rt.core2 import apply from megengine.core.ops import builtin -from megengine.core.tensor.megbrain_graph import VarNode from megengine.utils import comp_graph_tools as cgtools logger = mge.get_logger(__name__) @@ -268,8 +267,8 @@ def make_feeds(args): def assert_equal(expect, real, **kwargs): op = builtin.AssertEqual(**kwargs) - (res,) = G.apply_normal_varnode(op, expect, real) - return G.VarNode(res) + (res,) = apply(op, expect, real) + return res verbose = not args.silent @@ -284,8 +283,8 @@ def make_feeds(args): # insert assert opr to check expect and real. outputs_new.append( assert_equal( - G.VarNode(expect_get), - G.VarNode(i), + expect_get, + i, verbose=verbose, maxerr=args.maxerr, ) @@ -297,29 +296,26 @@ def make_feeds(args): def optimize_for_inference(args, outputs): - args_map = { - "enable_io16xc32": "f16_io_f32_comp", - "enable_ioc16": "f16_io_comp", - "enable_hwcd4": "use_nhwcd4", - "enable_nchw4": "use_nchw4", - "enable_nchw88": "use_nchw88", - "enable_nchw44": "use_nchw44", - "enable_nchw44_dot": "use_nchw44_dot", - "enable_nchw32": "use_nchw32", - "enable_chwn4": "use_chwn4", - "enable_fuse_conv_bias_nonlinearity": "fuse_conv_bias_nonlinearity", - "enable_fuse_conv_bias_with_z": "fuse_conv_bias_with_z", - } + args_list = [ + "enable_io16xc32", + "enable_ioc16", + "enable_hwcd4", + "enable_nchw4", + "enable_nchw88", + "enable_nchw44", + "enable_nchw44_dot", + "enable_nchw32", + "enable_chwn4", + "enable_fuse_conv_bias_nonlinearity", + "enable_fuse_conv_bias_with_z", + ] kwargs = {} - for k, v in args_map.items(): + for k in args_list: if getattr(args, k): - assert ( - args.optimize_for_inference - ), "optimize_for_inference should be set when {} is given".format(k) - kwargs[v] = True + kwargs[k] = True if args.optimize_for_inference: - outputs = [i._node for i in G.optimize_for_inference(outputs, **kwargs)] + outputs = G.optimize_for_inference(outputs, **kwargs) return outputs @@ -476,7 +472,6 @@ def main(): output_mgbvars = feeds["outputs"] output_mgbvars = optimize_for_inference(args, output_mgbvars) - output_mgbvars = [var._node for var in output_mgbvars] inputs = cgtools.get_dep_vars(output_mgbvars, "Host2DeviceCopy") inputs = sorted((i.name, i.dtype) for i in inputs) @@ -491,12 +486,8 @@ def main(): with open(args.output, "wb") as fout: fout.write(b"mgbtest0") fout.write(struct.pack("I", len(feeds["testcases"]))) - if isinstance(output_mgbvars, dict): - wrap_output_vars = dict([(i, VarNode(j)) for i, j in output_mgbvars]) - else: - wrap_output_vars = [VarNode(i) for i in output_mgbvars] dump_content, stat = G.dump_graph( - wrap_output_vars, + output_mgbvars, append_json=True, strip_info_file=strip_info_file, **sereg_kwargs, -- GitLab