未验证 提交 494cb36d 编写于 作者: A Aurelius84 提交者: GitHub

Modify tmp var name prefix in dygraph (#25280)

* Modify tmp var name prefix in dygraph test=develop

* refine comment test=develop
上级 82ec247a
...@@ -285,6 +285,11 @@ TEST(test_tracer, test_unique_name_generator) { ...@@ -285,6 +285,11 @@ TEST(test_tracer, test_unique_name_generator) {
auto fc_2 = tracer.GenerateUniqueName("fc"); auto fc_2 = tracer.GenerateUniqueName("fc");
ASSERT_STREQ("fc_0", fc_1.c_str()); ASSERT_STREQ("fc_0", fc_1.c_str());
ASSERT_STREQ("fc_1", fc_2.c_str()); ASSERT_STREQ("fc_1", fc_2.c_str());
// use `eager_tmp` as key if not specify it.
auto tmp_var_2 = tracer.GenerateUniqueName();
ASSERT_STREQ("eager_tmp_2", tmp_var_2.c_str());
auto tmp_var_3 = tracer.GenerateUniqueName("eager_tmp");
ASSERT_STREQ("eager_tmp_3", tmp_var_3.c_str());
} }
TEST(test_tracer, test_current_tracer) { TEST(test_tracer, test_current_tracer) {
......
...@@ -76,7 +76,14 @@ class Tracer { ...@@ -76,7 +76,14 @@ class Tracer {
return program_desc_tracer_.get(); return program_desc_tracer_.get();
} }
std::string GenerateUniqueName(std::string key = "tmp") { // Note(Aurelius84): The `tmp` is used as prefix key while naming a temporary
// intermediate var both in imperative and static mode. But the
// `UniqueNameGenerator` in C++ and `unique_name.py` in Python doesn't share
// the same auto-increment id. It will create a variable repeatedly with same
// name like `tmp_0` in some cases when transform dygraph into static layers.
// So we modify the default prefix key into `eager_tmp` to distinguish with
// static graph.
std::string GenerateUniqueName(std::string key = "eager_tmp") {
return generator_->Generate(key); return generator_->Generate(key);
} }
......
...@@ -873,7 +873,7 @@ void BindImperative(py::module *m_ptr) { ...@@ -873,7 +873,7 @@ void BindImperative(py::module *m_ptr) {
&imperative::Tracer::GetProgramDescTracer, &imperative::Tracer::GetProgramDescTracer,
py::return_value_policy::reference) py::return_value_policy::reference)
.def("_generate_unique_name", &imperative::Tracer::GenerateUniqueName, .def("_generate_unique_name", &imperative::Tracer::GenerateUniqueName,
py::arg("key") = "tmp") py::arg("key") = "eager_tmp")
.def("trace", .def("trace",
[](imperative::Tracer &self, const std::string &type, [](imperative::Tracer &self, const std::string &type,
const PyNameVarBaseMap &ins, const PyNameVarBaseMap &outs, const PyNameVarBaseMap &ins, const PyNameVarBaseMap &outs,
......
...@@ -43,3 +43,18 @@ class TestUniqueName(unittest.TestCase): ...@@ -43,3 +43,18 @@ class TestUniqueName(unittest.TestCase):
name3 = fluid.unique_name.generate('tmp') name3 = fluid.unique_name.generate('tmp')
self.assertNotEqual(name1, name2) self.assertNotEqual(name1, name2)
self.assertEqual(name1[-2:], name3[-2:]) self.assertEqual(name1[-2:], name3[-2:])
class TestImperativeUniqueName(unittest.TestCase):
def test_name_generator(self):
with fluid.dygraph.guard():
tracer = fluid.framework._dygraph_tracer()
tmp_var_0 = tracer._generate_unique_name()
self.assertEqual(tmp_var_0, "eager_tmp_0")
tmp_var_1 = tracer._generate_unique_name("eager_tmp")
self.assertEqual(tmp_var_1, "eager_tmp_1")
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册