diff --git a/doc/design/scope.md b/doc/design/scope.md index 2ff416f06e8ada48b1d4922f8869a106f35799e2..afe6bc028cafc5ee24b0041905857af58d3f5790 100644 --- a/doc/design/scope.md +++ b/doc/design/scope.md @@ -41,7 +41,7 @@ class Scope { const Variable* GetVariable(const std::string& name) const; private: - std::unordered_map> vars_; + std::unordered_map> vars_; }; ``` @@ -59,9 +59,9 @@ class Scope { Scope(const std::shared_ptr& scope): parent_(scope) {} Variable* GetVariable(const std::string& name) const { - Variable* var = GetVarLocally(name); - if (var != nullptr) { - return var; + auto it = vars_.find(name); + if (it != vars_.end()) { + return it->second.get(); } else if (parent_ != nullptr) { return parent_->GetVariable(name); } else { @@ -97,8 +97,8 @@ class Scope { // return nullptr if not found. Variable* GetVariable(const std::string& name) const; - // return Error if already contains same name variable. - Error CreateVariable(const std::string& name); + // return if already contains same name variable. + Variable* CreateVariable(const std::string& name); private: std::shared_ptr parent_; diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt index b06ecc26286de1385f6ea4eabc01396c07d7aa52..6aa6b9bc2db6a223dd8562b76ba9d777206bfd40 100644 --- a/paddle/framework/CMakeLists.txt +++ b/paddle/framework/CMakeLists.txt @@ -1,5 +1,7 @@ +# ddim lib cc_library(ddim SRCS ddim.cc) cc_test(ddim_test SRCS ddim_test.cc DEPS ddim) nv_test(dim_test SRCS dim_test.cu DEPS ddim) cc_test(variable_test SRCS variable_test.cc) +cc_test(scope_test SRCS scope_test.cc) cc_test(enforce_test SRCS enforce_test.cc) diff --git a/paddle/framework/scope.h b/paddle/framework/scope.h new file mode 100644 index 0000000000000000000000000000000000000000..a4470f726fb0d59a82db29b3239c111ea1569c55 --- /dev/null +++ b/paddle/framework/scope.h @@ -0,0 +1,95 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include +#include + +#include "paddle/framework/variable.h" + +namespace paddle { +namespace framework { + +/** + * @brief Scope that manage all variables. + * + * Scope is an association of a name to Variable. All variables belong to + * Scope. You need to specify a scope to run a Net, i.e., `net.Run(&scope)`. + * One net can run in different scopes and update different variable in the + * scope. + */ +class Scope { + public: + /** + * @brief Initialize s Scope without parent. + */ + Scope() {} + + /** + * @brief Initialize a Scope with parent. + */ + explicit Scope(const std::shared_ptr& parent) : parent_(parent) {} + + /** + * @brief Create Variable + * + * Create Variable in this Scope. Return the exist one if Variable already + * been created. + */ + Variable* CreateVariable(const std::string& name) { + auto var = GetVariable(name); + if (var) { + return var; + } else { + vars_[name] = std::unique_ptr(new Variable()); + return GetVariable(name); + } + } + + /** + * @brief Get Variable. + * + * Get Variable from this Scope, this function will recursive find Variable + * from it's parent scope. Return nullptr if not found. + */ + Variable* GetVariable(const std::string& name) const { + auto it = vars_.find(name); + if (it != vars_.end()) { + return it->second.get(); + } else if (parent_ != nullptr) { + return parent_->GetVariable(name); + } else { + return nullptr; + } + } + + /** + * @brief If this scope has a Var named name. + * + * Find if there is a Variable in this scope and it's parent scope + */ + bool HasVariable(const std::string& name) const { + return (vars_.find(name) != vars_.end() || + (parent_ && parent_->HasVariable(name))); + } + + private: + std::unordered_map> vars_; + std::shared_ptr parent_{nullptr}; +}; + +} // namespace framework +} // namespace paddle diff --git a/paddle/framework/scope_test.cc b/paddle/framework/scope_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..df1afb200ce9e75c5b1e40f2da56667487ae3576 --- /dev/null +++ b/paddle/framework/scope_test.cc @@ -0,0 +1,58 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/framework/scope.h" +#include "gtest/gtest.h" + +TEST(Scope, Create) { + using paddle::framework::Scope; + using paddle::framework::Variable; + + auto scope = std::make_shared(); + + Variable* var0 = scope->CreateVariable(""); + EXPECT_NE(var0, nullptr); + + /// GetVariable will return nullptr if not exist. + Variable* var1 = scope->GetVariable("a"); + EXPECT_EQ(var1, nullptr); + + /// CreateVariable will return one. + Variable* var2 = scope->CreateVariable("a"); + EXPECT_NE(var2, nullptr); + + /// Get the created variable. + Variable* var3 = scope->GetVariable("a"); + EXPECT_EQ(var2, var3); + + /// CreateVariable will just return the variable if it's + /// already exist. + Variable* var4 = scope->CreateVariable("a"); + EXPECT_EQ(var4, var2); +} + +TEST(Scope, Parent) { + using paddle::framework::Scope; + using paddle::framework::Variable; + + auto parent_scope = std::make_shared(); + auto scope = std::make_shared(parent_scope); + + Variable* var0 = parent_scope->CreateVariable("a"); + EXPECT_NE(var0, nullptr); + + /// GetVariable will get Variable from parent scope if exist. + Variable* var1 = scope->GetVariable("a"); + EXPECT_EQ(var0, var1); +}