From 494cb36d09f54db7e4051527d24caf4b7dc73d8f Mon Sep 17 00:00:00 2001 From: Aurelius84 Date: Wed, 1 Jul 2020 19:23:24 +0800 Subject: [PATCH] Modify tmp var name prefix in dygraph (#25280) * Modify tmp var name prefix in dygraph test=develop * refine comment test=develop --- paddle/fluid/imperative/tests/test_tracer.cc | 5 +++++ paddle/fluid/imperative/tracer.h | 9 ++++++++- paddle/fluid/pybind/imperative.cc | 2 +- .../fluid/tests/unittests/test_unique_name.py | 15 +++++++++++++++ 4 files changed, 29 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/imperative/tests/test_tracer.cc b/paddle/fluid/imperative/tests/test_tracer.cc index 5852e60a48..3c3ec2e626 100644 --- a/paddle/fluid/imperative/tests/test_tracer.cc +++ b/paddle/fluid/imperative/tests/test_tracer.cc @@ -285,6 +285,11 @@ TEST(test_tracer, test_unique_name_generator) { auto fc_2 = tracer.GenerateUniqueName("fc"); ASSERT_STREQ("fc_0", fc_1.c_str()); ASSERT_STREQ("fc_1", fc_2.c_str()); + // use `eager_tmp` as key if not specify it. + auto tmp_var_2 = tracer.GenerateUniqueName(); + ASSERT_STREQ("eager_tmp_2", tmp_var_2.c_str()); + auto tmp_var_3 = tracer.GenerateUniqueName("eager_tmp"); + ASSERT_STREQ("eager_tmp_3", tmp_var_3.c_str()); } TEST(test_tracer, test_current_tracer) { diff --git a/paddle/fluid/imperative/tracer.h b/paddle/fluid/imperative/tracer.h index 49aa39d2b0..1bcd97a929 100644 --- a/paddle/fluid/imperative/tracer.h +++ b/paddle/fluid/imperative/tracer.h @@ -76,7 +76,14 @@ class Tracer { return program_desc_tracer_.get(); } - std::string GenerateUniqueName(std::string key = "tmp") { + // Note(Aurelius84): The `tmp` is used as prefix key while naming a temporary + // intermediate var both in imperative and static mode. But the + // `UniqueNameGenerator` in C++ and `unique_name.py` in Python doesn't share + // the same auto-increment id. It will create a variable repeatedly with same + // name like `tmp_0` in some cases when transform dygraph into static layers. + // So we modify the default prefix key into `eager_tmp` to distinguish with + // static graph. + std::string GenerateUniqueName(std::string key = "eager_tmp") { return generator_->Generate(key); } diff --git a/paddle/fluid/pybind/imperative.cc b/paddle/fluid/pybind/imperative.cc index e2ff4161db..626f6b1ecc 100644 --- a/paddle/fluid/pybind/imperative.cc +++ b/paddle/fluid/pybind/imperative.cc @@ -873,7 +873,7 @@ void BindImperative(py::module *m_ptr) { &imperative::Tracer::GetProgramDescTracer, py::return_value_policy::reference) .def("_generate_unique_name", &imperative::Tracer::GenerateUniqueName, - py::arg("key") = "tmp") + py::arg("key") = "eager_tmp") .def("trace", [](imperative::Tracer &self, const std::string &type, const PyNameVarBaseMap &ins, const PyNameVarBaseMap &outs, diff --git a/python/paddle/fluid/tests/unittests/test_unique_name.py b/python/paddle/fluid/tests/unittests/test_unique_name.py index b8c751b2e9..8f116db855 100644 --- a/python/paddle/fluid/tests/unittests/test_unique_name.py +++ b/python/paddle/fluid/tests/unittests/test_unique_name.py @@ -43,3 +43,18 @@ class TestUniqueName(unittest.TestCase): name3 = fluid.unique_name.generate('tmp') self.assertNotEqual(name1, name2) self.assertEqual(name1[-2:], name3[-2:]) + + +class TestImperativeUniqueName(unittest.TestCase): + def test_name_generator(self): + with fluid.dygraph.guard(): + tracer = fluid.framework._dygraph_tracer() + tmp_var_0 = tracer._generate_unique_name() + self.assertEqual(tmp_var_0, "eager_tmp_0") + + tmp_var_1 = tracer._generate_unique_name("eager_tmp") + self.assertEqual(tmp_var_1, "eager_tmp_1") + + +if __name__ == '__main__': + unittest.main() -- GitLab