提交 ef188371 编写于 作者: Y Yang Yu

Polish `Scope::LocalVarNames`

Cannot get var name recursive since they could be same.
上级 f8391545
...@@ -74,17 +74,9 @@ void Scope::DropKids() { ...@@ -74,17 +74,9 @@ void Scope::DropKids() {
kids_.clear(); kids_.clear();
} }
std::vector<std::string> Scope::GetAllNames(bool recursive) const { std::vector<std::string> Scope::LocalVarNames() const {
std::vector<std::string> known_vars(vars_.size()); std::vector<std::string> known_vars;
known_vars.reserve(this->vars_.size());
if (recursive) {
for (auto& kid : kids_) {
auto kid_vars = kid->GetAllNames();
for (auto& p : kid_vars) {
known_vars.emplace_back(p);
}
}
}
for (auto& p : vars_) { for (auto& p : vars_) {
known_vars.emplace_back(p.first); known_vars.emplace_back(p.first);
} }
......
...@@ -66,7 +66,7 @@ class Scope { ...@@ -66,7 +66,7 @@ class Scope {
void DropKids(); void DropKids();
// enumerate all the variables current contains. // enumerate all the variables current contains.
std::vector<std::string> GetAllNames(bool recursive = false) const; std::vector<std::string> LocalVarNames() const;
// Rename variable to a new name // Rename variable to a new name
void Rename(const std::string& origin_name, void Rename(const std::string& origin_name,
......
...@@ -61,7 +61,7 @@ TEST(Scope, GetAllNames) { ...@@ -61,7 +61,7 @@ TEST(Scope, GetAllNames) {
Variable* v = s.Var("a"); Variable* v = s.Var("a");
EXPECT_EQ(&s, s.FindScope(v)); EXPECT_EQ(&s, s.FindScope(v));
std::vector<std::string> ans = s.GetAllNames(); std::vector<std::string> ans = s.LocalVarNames();
std::string str; std::string str;
for (auto& var : ans) { for (auto& var : ans) {
str += var; str += var;
......
...@@ -491,7 +491,7 @@ class RecurrentGradOp : public RecurrentBase { ...@@ -491,7 +491,7 @@ class RecurrentGradOp : public RecurrentBase {
std::unordered_set<std::string> LocalVarNames( std::unordered_set<std::string> LocalVarNames(
const framework::Scope &scope) const { const framework::Scope &scope) const {
return this->List2Set(scope.GetAllNames(false)); return this->List2Set(scope.LocalVarNames());
} }
static std::vector<std::string> GradVarLists( static std::vector<std::string> GradVarLists(
const std::vector<std::string> &var_names) { const std::vector<std::string> &var_names) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册