未验证 提交 494cb36d 编写于 作者: A Aurelius84 提交者: GitHub

Modify tmp var name prefix in dygraph (#25280)

* Modify tmp var name prefix in dygraph test=develop

* refine comment test=develop
上级 82ec247a
......@@ -285,6 +285,11 @@ TEST(test_tracer, test_unique_name_generator) {
auto fc_2 = tracer.GenerateUniqueName("fc");
ASSERT_STREQ("fc_0", fc_1.c_str());
ASSERT_STREQ("fc_1", fc_2.c_str());
// use `eager_tmp` as key if not specify it.
auto tmp_var_2 = tracer.GenerateUniqueName();
ASSERT_STREQ("eager_tmp_2", tmp_var_2.c_str());
auto tmp_var_3 = tracer.GenerateUniqueName("eager_tmp");
ASSERT_STREQ("eager_tmp_3", tmp_var_3.c_str());
}
TEST(test_tracer, test_current_tracer) {
......
......@@ -76,7 +76,14 @@ class Tracer {
return program_desc_tracer_.get();
}
std::string GenerateUniqueName(std::string key = "tmp") {
// Note(Aurelius84): The `tmp` is used as prefix key while naming a temporary
// intermediate var both in imperative and static mode. But the
// `UniqueNameGenerator` in C++ and `unique_name.py` in Python doesn't share
// the same auto-increment id. It will create a variable repeatedly with same
// name like `tmp_0` in some cases when transform dygraph into static layers.
// So we modify the default prefix key into `eager_tmp` to distinguish with
// static graph.
std::string GenerateUniqueName(std::string key = "eager_tmp") {
return generator_->Generate(key);
}
......
......@@ -873,7 +873,7 @@ void BindImperative(py::module *m_ptr) {
&imperative::Tracer::GetProgramDescTracer,
py::return_value_policy::reference)
.def("_generate_unique_name", &imperative::Tracer::GenerateUniqueName,
py::arg("key") = "tmp")
py::arg("key") = "eager_tmp")
.def("trace",
[](imperative::Tracer &self, const std::string &type,
const PyNameVarBaseMap &ins, const PyNameVarBaseMap &outs,
......
......@@ -43,3 +43,18 @@ class TestUniqueName(unittest.TestCase):
name3 = fluid.unique_name.generate('tmp')
self.assertNotEqual(name1, name2)
self.assertEqual(name1[-2:], name3[-2:])
class TestImperativeUniqueName(unittest.TestCase):
def test_name_generator(self):
with fluid.dygraph.guard():
tracer = fluid.framework._dygraph_tracer()
tmp_var_0 = tracer._generate_unique_name()
self.assertEqual(tmp_var_0, "eager_tmp_0")
tmp_var_1 = tracer._generate_unique_name("eager_tmp")
self.assertEqual(tmp_var_1, "eager_tmp_1")
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册