diff --git a/python_module/megengine/_internal/__init__.py b/python_module/megengine/_internal/__init__.py index e95067473529bb8d05cd78c756fbf5db02ca1be1..add99da07959efd4585f0d26301112547515389e 100644 --- a/python_module/megengine/_internal/__init__.py +++ b/python_module/megengine/_internal/__init__.py @@ -172,6 +172,7 @@ def make_arg( infer would be deferred to first graph execution :param enable_static_infer: whether to enable static inference for this var """ + comp_node = _detail.as_comp_node(comp_node) host_val = mgb._HostSharedND(comp_node, dtype) if value is not None: diff --git a/sdk/load-and-run/dump_with_testcase_mge.py b/sdk/load-and-run/dump_with_testcase_mge.py index 5af630db6eee238914956141bce27c02713a90a5..cec42f85e0c4412acaab68708d4a7c63c816c67c 100755 --- a/sdk/load-and-run/dump_with_testcase_mge.py +++ b/sdk/load-and-run/dump_with_testcase_mge.py @@ -232,9 +232,8 @@ def make_feeds(args): outputs_new = [] for i in outputs: get = mgb.make_arg( - mge.core.graph.get_default_device(), + i.comp_node, cg, - shape=expect_shp(i), dtype=i.dtype, name=expect_name(i) ) @@ -463,7 +462,7 @@ def main(): for testcase in feeds['testcases']: assert isinstance(testcase, dict) cg = mgb.comp_graph() - cn = mgb.comp_node('cpux') + cn = mgb.comp_node('xpux') output_mgbvars = [] for name, dtype in inputs: output_mgbvars.append(