diff --git a/doc/design/scope.md b/doc/design/scope.md index 2ff416f06e8ada48b1d4922f8869a106f35799e2..4d14a64977cf583a07975898da5961aa63f70faf 100644 --- a/doc/design/scope.md +++ b/doc/design/scope.md @@ -41,7 +41,7 @@ class Scope { const Variable* GetVariable(const std::string& name) const; private: - std::unordered_map> vars_; + std::unordered_map> vars_; }; ``` diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt index e3c3155aa902c941058ea1b15488360df6c06175..7ea17f7114c6558f5a786d02eeff1ba097a870ee 100644 --- a/paddle/framework/CMakeLists.txt +++ b/paddle/framework/CMakeLists.txt @@ -1,6 +1,11 @@ +# ddim lib cc_library(ddim SRCS ddim.cc) cc_test(ddim_test SRCS ddim_test.cc DEPS ddim) nv_test(dim_test SRCS dim_test.cu DEPS ddim) cc_test(variable_test SRCS variable_test.cc) + +# scope lib +cc_library(scope SRCS scope.cc) +cc_test(scope_test SRCS scope_test.cc DEPS scope) diff --git a/paddle/framework/scope.cc b/paddle/framework/scope.cc new file mode 100644 index 0000000000000000000000000000000000000000..ed75aece01579f37f5237397767b6e4ee3b9c9d6 --- /dev/null +++ b/paddle/framework/scope.cc @@ -0,0 +1,54 @@ +#include "paddle/framework/scope.h" + +namespace paddle { +namespace framework { + +Error Scope::CreateVariable(const std::string &name) { + if (name == "") { + return Error("Variable name should not be empty"); + } + + if (HaveVariable(name)) { + return AlreadyCreated; + } + vars_[name] = std::unique_ptr(new Variable()); + return Error(); +} + +Variable* Scope::GetVarLocally(const std::string& name) const { + if (vars_.count(name)) { + return vars_.at(name).get(); + } + return nullptr; +} + +Variable* Scope::GetVariable(const std::string &name) const { + Variable* var = GetVarLocally(name); + if (var != nullptr) { + return var; + } else if (parent_ != nullptr) { + return parent_->GetVariable(name); + } else { + return nullptr; + } +} + +Variable* Scope::GetOrCreateVariable(const std::string &name) { + Variable* var; + var = GetVariable(name); + if (var == nullptr) { + auto err = CreateVariable(name); + if (!err.isOK()) { + return nullptr; + } + } + return GetVariable(name); +} + +bool Scope::HaveVariable(const std::string &name) { + return vars_.count(name) != 0; +} + +} // namespace framework +} // namespace paddle + diff --git a/paddle/framework/scope.h b/paddle/framework/scope.h new file mode 100644 index 0000000000000000000000000000000000000000..ad1ed2ddab901379555642cfa97165ab7581a9d9 --- /dev/null +++ b/paddle/framework/scope.h @@ -0,0 +1,51 @@ +#pragma once + +#include +#include +#include "paddle/framework/variable.h" +#include "paddle/utils/Error.h" + +namespace paddle { +namespace framework { + +const static Error AlreadyCreated("Variable has already been created"); + +/** + * Scope is an association of a name to Variable. All variables belong to `Scope`. + * You need to specify a scope to run a Net, i.e., `net.Run(&scope)`. One net can + * run in different scopes and update different variable in the scope. + */ +class Scope { + public: + Scope() {} + + explicit Scope(const std::shared_ptr &scope): + parent_(scope) {} + + ~Scope() {} + + // Create Variable in this Scope. Return error if Variable already been + // created. + Error __must_check CreateVariable(const std::string& name); + + // Get Variable from this Scope, this function will recursive find Variable + // from it's parent scope. + // Return nullptr if not found. + Variable* GetVariable(const std::string& name) const; + + // find and return Variables in the scope it self. + Variable* GetVarLocally(const std::string& name) const; + + // Get a Variable from Scope, if the Variable is not exist then create it. + // User should call this function most of time. + Variable* GetOrCreateVariable(const std::string& name); + + bool HaveVariable(const std::string& name); + + private: + std::unordered_map> vars_; + std::shared_ptr parent_ {nullptr}; +}; + +} // namespace framework +} // namespace paddle diff --git a/paddle/framework/scope_test.cc b/paddle/framework/scope_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..09fbb78d69f2be44982f1d15cb3ec13610434f63 --- /dev/null +++ b/paddle/framework/scope_test.cc @@ -0,0 +1,47 @@ +#include "paddle/framework/scope.h" +#include "gtest/gtest.h" + +TEST(Scope, Create) { + using paddle::framework::Scope; + using paddle::Error; + using paddle::framework::Variable; + using paddle::framework::AlreadyCreated; + + Scope* scope = new Scope(); + + Error err = scope->CreateVariable(""); + EXPECT_FALSE(err.isOK()); + + Variable* var1 = scope->GetVariable("a"); + EXPECT_EQ(var1, nullptr); + + Error err1 = scope->CreateVariable("a"); + EXPECT_TRUE(err1.isOK()); + + Error err2 = scope->CreateVariable("a"); + EXPECT_EQ(err2, AlreadyCreated); + + Variable* var2 = scope->GetVariable("a"); + EXPECT_NE(var2, nullptr); + + Variable* var3 = scope->GetOrCreateVariable("b"); + EXPECT_NE(var3, nullptr); +} + +TEST(Scope, Parent) { + using paddle::framework::Scope; + using paddle::framework::Variable; + using paddle::Error; + + const auto parent_scope_ptr = std::shared_ptr(new Scope()); + Scope* scope = new Scope(parent_scope_ptr); + + Error err = parent_scope_ptr->CreateVariable("a"); + EXPECT_TRUE(err.isOK()); + + Variable* var1 = scope->GetVarLocally("a"); + EXPECT_EQ(var1, nullptr); + + Variable* var2 = scope->GetVariable("a"); + EXPECT_NE(var2, nullptr); +} \ No newline at end of file