From ef188371a1b6106437e4c580f76ca7bbba1babc3 Mon Sep 17 00:00:00 2001 From: Yang Yu Date: Tue, 26 Dec 2017 14:49:30 +0800 Subject: [PATCH] Polish `Scope::LocalVarNames` Cannot get var name recursive since they could be same. --- paddle/framework/scope.cc | 14 +++----------- paddle/framework/scope.h | 2 +- paddle/framework/scope_test.cc | 2 +- paddle/operators/recurrent_op.cc | 2 +- 4 files changed, 6 insertions(+), 14 deletions(-) diff --git a/paddle/framework/scope.cc b/paddle/framework/scope.cc index 656736e238..0c01d605bc 100644 --- a/paddle/framework/scope.cc +++ b/paddle/framework/scope.cc @@ -74,17 +74,9 @@ void Scope::DropKids() { kids_.clear(); } -std::vector Scope::GetAllNames(bool recursive) const { - std::vector known_vars(vars_.size()); - - if (recursive) { - for (auto& kid : kids_) { - auto kid_vars = kid->GetAllNames(); - for (auto& p : kid_vars) { - known_vars.emplace_back(p); - } - } - } +std::vector Scope::LocalVarNames() const { + std::vector known_vars; + known_vars.reserve(this->vars_.size()); for (auto& p : vars_) { known_vars.emplace_back(p.first); } diff --git a/paddle/framework/scope.h b/paddle/framework/scope.h index 56e815db54..10143326df 100644 --- a/paddle/framework/scope.h +++ b/paddle/framework/scope.h @@ -66,7 +66,7 @@ class Scope { void DropKids(); // enumerate all the variables current contains. - std::vector GetAllNames(bool recursive = false) const; + std::vector LocalVarNames() const; // Rename variable to a new name void Rename(const std::string& origin_name, diff --git a/paddle/framework/scope_test.cc b/paddle/framework/scope_test.cc index f738d5ba9e..0f5b86061d 100644 --- a/paddle/framework/scope_test.cc +++ b/paddle/framework/scope_test.cc @@ -61,7 +61,7 @@ TEST(Scope, GetAllNames) { Variable* v = s.Var("a"); EXPECT_EQ(&s, s.FindScope(v)); - std::vector ans = s.GetAllNames(); + std::vector ans = s.LocalVarNames(); std::string str; for (auto& var : ans) { str += var; diff --git a/paddle/operators/recurrent_op.cc b/paddle/operators/recurrent_op.cc index 77f3a40b76..c4740e0ce1 100644 --- a/paddle/operators/recurrent_op.cc +++ b/paddle/operators/recurrent_op.cc @@ -491,7 +491,7 @@ class RecurrentGradOp : public RecurrentBase { std::unordered_set LocalVarNames( const framework::Scope &scope) const { - return this->List2Set(scope.GetAllNames(false)); + return this->List2Set(scope.LocalVarNames()); } static std::vector GradVarLists( const std::vector &var_names) { -- GitLab