From e474994f466fe47d44a794e59ffd9e1d912944f7 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Fri, 2 Apr 2021 16:59:58 +0800 Subject: [PATCH] feat(imperative/jit): catch input tensors name when tracing GitOrigin-RevId: 9c692548663654265f9f9e2753f8637d444cb78d --- imperative/python/megengine/jit/tracing.py | 6 ++++-- .../test/unit/utils/test_dump_naming.py | 21 +++++++++++++++++-- 2 files changed, 23 insertions(+), 4 deletions(-) diff --git a/imperative/python/megengine/jit/tracing.py b/imperative/python/megengine/jit/tracing.py index 669a659f0..08631bc2d 100644 --- a/imperative/python/megengine/jit/tracing.py +++ b/imperative/python/megengine/jit/tracing.py @@ -772,7 +772,8 @@ class trace: len(self._output_bindings) ) ) - if arg_names is None: + without_arg_names = arg_names is None + if without_arg_names: arg_names = ["arg_%d" % i for i in range(len(self._arg_bindings))] if arg_names and not isinstance(arg_names, collections.abc.Sequence): arg_names = (arg_names,) @@ -802,7 +803,7 @@ class trace: dtype=info.dtype, device=dumped_device(info), shape=info.shape or (1,), - name=arg_names[i] if arg_names else None, + name=info.name if without_arg_names and info.name else arg_names[i], ) for k, h in self._kwarg_bindings.items(): info = self._tinfo[h] @@ -889,6 +890,7 @@ class trace: return h, info = self._new_handle() info.external = False + info.name = x.c_name info.device = x.device info.dtype = x.dtype info.shape = x.numpy().shape diff --git a/imperative/python/test/unit/utils/test_dump_naming.py b/imperative/python/test/unit/utils/test_dump_naming.py index 44845a0b2..a546c338a 100644 --- a/imperative/python/test/unit/utils/test_dump_naming.py +++ b/imperative/python/test/unit/utils/test_dump_naming.py @@ -203,14 +203,31 @@ def test_with_same_operators(symbolic): assert ops[-2].name == "simple.RELU[0]" -def test_not_keep_opr_name(): +@pytest.mark.parametrize("symbolic", [False, True]) +def test_not_keep_opr_name(symbolic): def f(x): return 2 * x - op = _dump_and_load(f, True, False)[-1] + op = _dump_and_load(f, symbolic, False)[-1] assert op.name == "MUL(x,const<2>[2])[4]" +@pytest.mark.parametrize("tensor_name, var_name", [("data", "data"), (None, "arg_0")]) +def test_catch_input_name(tensor_name, var_name): + def f(x): + return 2 * x + + func = trace(f, symbolic=True, capture_as_const=True) + x = Tensor(np.ones(shape=(2, 3)), name=tensor_name) + func(x).numpy() + file = io.BytesIO() + func.dump(file, optimize_for_inference=False, keep_opr_name=True, keep_var_name=2) + file.seek(0) + *_, outputs = G.load_graph(file) + op = cgtools.get_oprs_seq(outputs)[-1] + assert op.inputs[0].name == var_name + + @pytest.mark.parametrize("symbolic", [False, True]) def test_quantized_module_auto_naming(symbolic): class Simple(M.Module): -- GitLab