From 62753c4d30bd568343b552a9871c6b032d102512 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Wed, 3 Jun 2020 16:55:34 +0800 Subject: [PATCH] fix(mge/sdk): fix comp_node bug in dump_with_testcast_mge GitOrigin-RevId: 26a8dc50b8fb63e4cbec45785cb49aa74290f8d3 --- python_module/megengine/_internal/__init__.py | 1 + sdk/load-and-run/dump_with_testcase_mge.py | 5 ++--- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/python_module/megengine/_internal/__init__.py b/python_module/megengine/_internal/__init__.py index e95067473..add99da07 100644 --- a/python_module/megengine/_internal/__init__.py +++ b/python_module/megengine/_internal/__init__.py @@ -172,6 +172,7 @@ def make_arg( infer would be deferred to first graph execution :param enable_static_infer: whether to enable static inference for this var """ + comp_node = _detail.as_comp_node(comp_node) host_val = mgb._HostSharedND(comp_node, dtype) if value is not None: diff --git a/sdk/load-and-run/dump_with_testcase_mge.py b/sdk/load-and-run/dump_with_testcase_mge.py index 5af630db6..cec42f85e 100755 --- a/sdk/load-and-run/dump_with_testcase_mge.py +++ b/sdk/load-and-run/dump_with_testcase_mge.py @@ -232,9 +232,8 @@ def make_feeds(args): outputs_new = [] for i in outputs: get = mgb.make_arg( - mge.core.graph.get_default_device(), + i.comp_node, cg, - shape=expect_shp(i), dtype=i.dtype, name=expect_name(i) ) @@ -463,7 +462,7 @@ def main(): for testcase in feeds['testcases']: assert isinstance(testcase, dict) cg = mgb.comp_graph() - cn = mgb.comp_node('cpux') + cn = mgb.comp_node('xpux') output_mgbvars = [] for name, dtype in inputs: output_mgbvars.append( -- GitLab