From c5afddc681d2c4ffbbb747aad8f273d85994e7d8 Mon Sep 17 00:00:00 2001 From: Yi Wang Date: Sun, 30 Jul 2017 14:06:03 -0700 Subject: [PATCH] Rewrite Scope --- paddle/framework/scope.cc | 58 +++++++++++++++++++++ paddle/framework/scope.h | 93 ++++++++++------------------------ paddle/framework/scope_test.cc | 63 ++++++++++------------- paddle/framework/variable.h | 12 +++++ 4 files changed, 126 insertions(+), 100 deletions(-) create mode 100644 paddle/framework/scope.cc diff --git a/paddle/framework/scope.cc b/paddle/framework/scope.cc new file mode 100644 index 00000000000..ad5360d98fb --- /dev/null +++ b/paddle/framework/scope.cc @@ -0,0 +1,58 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/framework/scope.h" + +namespace paddle { +namespace framework { + +Scope::~Scope() { + for (Variable* v : vars_) delete v; + + for (Scope* s : kids_) delete s; +} + +Scope& Scope::NewScope() { + kids_.push_back(new Scope(this)); + return *kids_.back(); +} + +Variable* Scope::NewVar(const std::string& name) { + atuo iter = vars_.find(name); + if (iter != vars_.end()) { + return iter.second->get(); + } + Variable* v = new Variable(); + v->name_ = name; + var_[name] = v; + return v; +} + +Variable* Scope::NewVar() { + return NewVar(string.Sprintf("%p.%d", this, vars_.size())); +} + +Variable* Scope::FindVar(const std::string& name) const { + auto it = vars_.find(name); + if (it != vars_.end()) return it->second.get(); + return (parent_ == nullptr) ? nullptr : parent_->FindVar(name); +} + +Scope* Scope::FindScope(const Variable* var) const { + if (FindVar(var->name_) != nullptr) return this; + return (parent_ == nullptr) ? nullptr : parent_->FindScope(var); +} + +} // namespace framework +} // namespace paddle diff --git a/paddle/framework/scope.h b/paddle/framework/scope.h index 4faaf841440..b145ae3a4d4 100644 --- a/paddle/framework/scope.h +++ b/paddle/framework/scope.h @@ -14,9 +14,9 @@ limitations under the License. */ #pragma once +#include +#include #include -#include -#include #include "paddle/framework/variable.h" @@ -35,73 +35,36 @@ class Scope; */ class Scope { public: - /** - * @brief Initialize s Scope without parent. - */ Scope() {} + ~Scope(); - /** - * @brief Initialize a Scope with parent. - */ - explicit Scope(const std::shared_ptr& parent) : parent_(parent) {} - - /** - * @brief Create Variable - * - * Create Variable in this Scope. Return the exist one if Variable already - * been created. - */ - Variable* CreateVariable(const std::string& name) { - auto var = GetVariable(name); - if (var) { - return var; - } else { - auto ptr = new Variable(); - name_to_var_[name] = std::unique_ptr(ptr); - var_to_name_[ptr] = name; - return GetVariable(name); - } - } - - /** - * @brief Get Variable. - * - * Get Variable from this Scope, this function will recursive find Variable - * from it's parent scope. Return nullptr if not found. - */ - Variable* GetVariable(const std::string& name) const { - auto it = name_to_var_.find(name); - if (it != name_to_var_.end()) { - return it->second.get(); - } else if (parent_ != nullptr) { - return parent_->GetVariable(name); - } else { - return nullptr; - } - } - - /** - * @brief If this scope has a Var named name. - * - * Find if there is a Variable in this scope and it's parent scope - */ - bool HasVariable(const std::string& name) const { - return (name_to_var_.find(name) != name_to_var_.end() || - (parent_ && parent_->HasVariable(name))); - } - - std::string GetVariableName(Variable* const var) const { - try { - return var_to_name_.at(var); - } catch (...) { - return ""; - } - } + // Create a sub-scope. Returns a reference other than a pointer so + // to prevent from manual deletion. + Scope& NewScope(); + + // Create a variable with given name if it doesn't exist. + Variable* NewVar(const std::string& name); + + // Create a variable with a scope-unique name. + Variable* NewVar(); + + // Find a variable in the scope or any of its ancestors. Returns + // nullptr if cannot find. + Variable* FindVar(const std::string& name) const; + + // Find the scope or an ancestor scope that contains the given variable. + Scope* FindScope(const Variable* var) const; + + // Returns the name of a variable in this scope. + std::string VarName(const Variable* var) const { return var->name_; } private: - std::unordered_map var_to_name_; - std::unordered_map> name_to_var_; - std::shared_ptr parent_{nullptr}; + // Call Scope::NewScope for a sub-scope. + explicit Scope(Scope* parent) : parent_(parent) {} + + std::map vars_; + std::list kids_; + Scope* parent_{nullptr}; }; } // namespace framework diff --git a/paddle/framework/scope_test.cc b/paddle/framework/scope_test.cc index ff069c7be00..6f5e735d824 100644 --- a/paddle/framework/scope_test.cc +++ b/paddle/framework/scope_test.cc @@ -15,49 +15,42 @@ limitations under the License. */ #include "paddle/framework/scope.h" #include "gtest/gtest.h" -TEST(Scope, Create) { - using paddle::framework::Scope; - using paddle::framework::Variable; +using paddle::framework::Scope; +using paddle::framework::Variable; - auto scope = std::make_shared(); +TEST(Scope, VarsShadowing) { + Scope s; + Scope& ss1 = s.NewScope(); + Scope& ss2 = s.NewScope(); - Variable* var0 = scope->CreateVariable(""); - EXPECT_NE(var0, nullptr); + Variable* v0 = s.NewVar("a"); + Variable* v1 = ss1.NewVar("a"); - /// GetVariable will return nullptr if not exist. - Variable* var1 = scope->GetVariable("a"); - EXPECT_EQ(var1, nullptr); + EXPECT_NE(v0, v1); - /// CreateVariable will return one. - Variable* var2 = scope->CreateVariable("a"); - EXPECT_NE(var2, nullptr); - - /// Get the created variable. - Variable* var3 = scope->GetVariable("a"); - EXPECT_EQ(var2, var3); + EXPECT_EQ(v0, s.FindVar("a")); + EXPECT_EQ(v1, ss1.FindVar("a")); + EXPECT_EQ(v0, ss2.FindVar("a")); +} - /// CreateVariable will just return the variable if it's - /// already exist. - Variable* var4 = scope->CreateVariable("a"); - EXPECT_EQ(var4, var2); +TEST(Scope, FindVar) { + Scope s; + Scope& ss = s.NewScope(); - EXPECT_EQ("a", scope->GetVariableName(var4)); - Scope scope2; - auto var = scope2.CreateVariable("tmp"); - EXPECT_EQ("", scope->GetVariableName(var)); -} + EXPECT_EQ(nullptr, s.FindVar("a")); + EXPECT_EQ(nullptr, ss.FindVar("a")); -TEST(Scope, Parent) { - using paddle::framework::Scope; - using paddle::framework::Variable; + ss.NewVar("a"); - auto parent_scope = std::make_shared(); - auto scope = std::make_shared(parent_scope); + EXPECT_EQ(nullptr, s.FindVar("a")); + EXPECT_NE(nullptr, ss.FindVar("a")); +} - Variable* var0 = parent_scope->CreateVariable("a"); - EXPECT_NE(var0, nullptr); +TEST(Scope, FindScope) { + Scope s; + Scope& ss = s.NewScope(); + s.NewVar("a"); - /// GetVariable will get Variable from parent scope if exist. - Variable* var1 = scope->GetVariable("a"); - EXPECT_EQ(var0, var1); + EXPECT_EQ(&s, s.FindVar("a")); + EXPECT_EQ(&s, ss.FindVar("a")); } diff --git a/paddle/framework/variable.h b/paddle/framework/variable.h index 72c4a7a2a1d..68a443a06ea 100644 --- a/paddle/framework/variable.h +++ b/paddle/framework/variable.h @@ -17,6 +17,7 @@ #include #include "paddle/platform/assert.h" +#include "paddle/string/piece.h" namespace paddle { namespace framework { @@ -65,6 +66,17 @@ class Variable { std::unique_ptr holder_; // pointers to a PlaceholderImpl object indeed. + + // name_ is only meaningful with a Scope and accessible by it. + // + // NOTE: Please don't expose name_ by adding methods like + // Variable::Name or Scope::VarName! A variable could have a human + // readable name or an auto-generated scope-unique name. In the + // former case, the caller knows the name and doesn't need to access + // the name; in the latter case, the variable should be identified + // by its address but not the unreadable name. + friend class Scope; + string::Piece name_; }; } // namespace framework -- GitLab