From 6f1246ac6e45e69fb279f34eb85e851cc8d15e30 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Wed, 15 Mar 2023 13:51:08 +0800 Subject: [PATCH] fix(trace): fix name duplication and fix error message for invalid input GitOrigin-RevId: 7fe1605c2639ba2c67488e0bffd6d0e2fab73e6a --- imperative/python/megengine/jit/tracing.py | 3 ++ imperative/python/src/tensor.cpp | 1 + .../python/test/unit/jit/test_tracing.py | 34 +++++++++++++++++++ imperative/src/impl/transformations/trace.cpp | 9 +++-- 4 files changed, 45 insertions(+), 2 deletions(-) diff --git a/imperative/python/megengine/jit/tracing.py b/imperative/python/megengine/jit/tracing.py index 1d1fc5df7..28d8eeb77 100644 --- a/imperative/python/megengine/jit/tracing.py +++ b/imperative/python/megengine/jit/tracing.py @@ -216,6 +216,9 @@ class trace: def _process_inputs(self, *args, **kwargs): for i, arg in enumerate(args): + assert isinstance( + arg, RawTensor + ), "Only support tensor type args when capture_as_const is enabled" name_tensor("arg_{}".format(i), arg) # TODO: mark kwargs in order diff --git a/imperative/python/src/tensor.cpp b/imperative/python/src/tensor.cpp index ff7f99174..97785fec3 100644 --- a/imperative/python/src/tensor.cpp +++ b/imperative/python/src/tensor.cpp @@ -1229,6 +1229,7 @@ void init_tensor(py::module m) { m.def("name_tensor", [](std::string name, py::object tensor) { auto* tw = TensorWrapper::try_cast(tensor.ptr()); + mgb_assert(tw, "Arg_1 should be Tensor!"); auto output = imperative::apply(TraceMarkVar(name), tw->m_tensor->data())[0]; tw->m_tensor->reset(output); }); diff --git a/imperative/python/test/unit/jit/test_tracing.py b/imperative/python/test/unit/jit/test_tracing.py index c12a27565..dc0d3f5f3 100644 --- a/imperative/python/test/unit/jit/test_tracing.py +++ b/imperative/python/test/unit/jit/test_tracing.py @@ -748,3 +748,37 @@ def test_trace_jit_config(): for fuse_dimshuffle in [None, False, True]: for fuse_reduce in [None, False, True]: run(fuse_dimshuffle, fuse_reduce) 
+ + +def test_trace_naming(): + @trace(symbolic=True, capture_as_const=True) + def func(x): + return F.max(x, axis=2, keepdims=False) + 1 + + inp = tensor(np.random.random((1, 3, 3, 3))) + func(inp) + file = io.BytesIO() + func.dump(file, optimize_for_inference=False) + file.seek(0) + import megengine.utils.network as network + + net = network.Network.load(file) + names = set() + for var in net.all_vars: + assert var.name not in names + names.add(var.name) + + +def test_invalid_inp_error(): + @trace(capture_as_const=True) + def func(a): + return a * 2 + + try: + func(1) + except Exception as e: + assert ( + str(e) == "Only support tensor type args when capture_as_const is enabled" + ) + else: + assert False diff --git a/imperative/src/impl/transformations/trace.cpp b/imperative/src/impl/transformations/trace.cpp index d65d358f4..d1c7a56c3 100644 --- a/imperative/src/impl/transformations/trace.cpp +++ b/imperative/src/impl/transformations/trace.cpp @@ -98,8 +98,6 @@ VarNodeArray TraceResult::dump( "do model.eval()?"); } output_nodes = OpDef::apply_on_var_node(*op, input_nodes); - name2ops[output_nodes[0]->owner_opr()->name()].push_back( - output_nodes[0]->owner_opr()); } else { // no opr, just forward VarNode mgb_assert( @@ -121,6 +119,13 @@ VarNodeArray TraceResult::dump( } } } + auto on_opr = [&name2ops](cg::OperatorNodeBase* opr) { + name2ops[opr->name()].push_back(opr); + }; + cg::DepOprIter dep_iter(on_opr); + for (auto&& [output, name] : outputs) { + dep_iter.add(nodes[output]->owner_opr()); + } for (auto&& [name, ops] : name2ops) { if (ops.size() <= 1) { continue; -- GitLab