diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt index 7ea17f7114c6558f5a786d02eeff1ba097a870ee..6caeb1be3a8f1d15df87f4b4f7a99bad352ec5b5 100644 --- a/paddle/framework/CMakeLists.txt +++ b/paddle/framework/CMakeLists.txt @@ -6,6 +6,4 @@ nv_test(dim_test SRCS dim_test.cu DEPS ddim) cc_test(variable_test SRCS variable_test.cc) -# scope lib -cc_library(scope SRCS scope.cc) -cc_test(scope_test SRCS scope_test.cc DEPS scope) +cc_test(scope_test SRCS scope_test.cc) diff --git a/paddle/framework/scope.cc b/paddle/framework/scope.cc deleted file mode 100644 index 72cb744707dd6f8ccf8fa02a818af4927d662f24..0000000000000000000000000000000000000000 --- a/paddle/framework/scope.cc +++ /dev/null @@ -1,50 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/framework/scope.h" - -namespace paddle { -namespace framework { - -Variable* Scope::CreateVariable(const std::string& name) { - if (!HasVariable(name)) { - vars_[name] = std::unique_ptr(new Variable()); - } - return GetVariable(name); -} - -Variable* Scope::GetVarLocally(const std::string& name) const { - if (vars_.count(name)) { - return vars_.at(name).get(); - } - return nullptr; -} - -Variable* Scope::GetVariable(const std::string& name) const { - Variable* var = GetVarLocally(name); - if (var != nullptr) { - return var; - } else if (parent_ != nullptr) { - return parent_->GetVariable(name); - } else { - return nullptr; - } -} - -bool Scope::HasVariable(const std::string& name) { - return (vars_.count(name) > 0 || (parent_ && parent_->HasVariable(name))); -} - -} // namespace framework -} // namespace paddle diff --git a/paddle/framework/scope.h b/paddle/framework/scope.h index a624fe3bbef3ae3c7671e6f549c2cb7c63def15f..2f8d6dbd9763220903d52d6f26aa769f08ad86bb 100644 --- a/paddle/framework/scope.h +++ b/paddle/framework/scope.h @@ -19,37 +19,58 @@ limitations under the License. */ #include #include "paddle/framework/variable.h" +#include "paddle/platform/assert.h" namespace paddle { namespace framework { /** * Scope is an association of a name to Variable. All variables belong to - * `Scope`. You need to specify a scope to run a Net, i.e., `net.Run(&scope)`. + * Scope. You need to specify a scope to run a Net, i.e., `net.Run(&scope)`. * One net can run in different scopes and update different variable in the * scope. */ class Scope { public: - Scope() {} - - explicit Scope(const std::shared_ptr& scope) : parent_(scope) {} - - ~Scope() {} - - // Create Variable in this Scope. Return error if Variable already been - // created. - Variable* CreateVariable(const std::string& name); - - // Get Variable from this Scope, this function will recursive find Variable - // from it's parent scope. Return nullptr if not found. - Variable* GetVariable(const std::string& name) const; - - // Find and return Variables in the scope it self. - Variable* GetVarLocally(const std::string& name) const; - - // Find if there is a Variable in this scope and it's parent scope - bool HasVariable(const std::string& name); + explicit Scope(const std::shared_ptr& parent = nullptr) + : parent_(parent) {} + + /// Create Variable in this Scope. Failed if Variable already been + /// created. + Variable* CreateVariable(const std::string& name) { + PADDLE_ASSERT(!HasVariable(name)); + vars_[name] = std::unique_ptr(new Variable()); + return GetVariable(name); + } + + /// Get Variable from this Scope, this function will recursive find Variable + /// from it's parent scope. Return nullptr if not found. + Variable* GetVariable(const std::string& name) const { + auto it = vars_.find(name); + if (it != vars_.end()) { + return it->second.get(); + } else if (parent_ != nullptr) { + return parent_->GetVariable(name); + } else { + return nullptr; + } + } + + /// Get Variable from scope, if Variable is not exist, creat one and return. + Variable* GetOrCreateVariable(const std::string& name) { + auto var = GetVariable(name); + if (var) { + return var; + } else { + return CreateVariable(name); + } + } + + /// Find if there is a Variable in this scope and it's parent scope + bool HasVariable(const std::string& name) const { + return (vars_.find(name) != vars_.end() || + (parent_ && parent_->HasVariable(name))); + } private: std::unordered_map> vars_; diff --git a/paddle/framework/scope_test.cc b/paddle/framework/scope_test.cc index 25c144868b45da47f897caa7e0522590b80b6019..34ee21e1aaa11202f837a2a0cac239da2c0b2e66 100644 --- a/paddle/framework/scope_test.cc +++ b/paddle/framework/scope_test.cc @@ -28,12 +28,13 @@ TEST(Scope, Create) { EXPECT_EQ(var1, nullptr); Variable* var2 = scope->CreateVariable("a"); - EXPECT_NE(var2, nullptr); - Variable* var3 = scope->CreateVariable("a"); + ASSERT_DEATH({ scope->CreateVariable("a"); }, ""); + + Variable* var3 = scope->GetVariable("a"); EXPECT_EQ(var2, var3); - Variable* var4 = scope->GetVariable("a"); + Variable* var4 = scope->GetOrCreateVariable("a"); EXPECT_EQ(var2, var4); } @@ -47,9 +48,6 @@ TEST(Scope, Parent) { Variable* var0 = parent_scope_ptr->CreateVariable("a"); EXPECT_NE(var0, nullptr); - Variable* var1 = scope->GetVarLocally("a"); - EXPECT_EQ(var1, nullptr); - - Variable* var2 = scope->GetVariable("a"); - EXPECT_EQ(var2, var0); -} \ No newline at end of file + Variable* var1 = scope->GetVariable("a"); + EXPECT_EQ(var0, var1); +}