diff --git a/paddle/framework/scope.cc b/paddle/framework/scope.cc index e98559884906cf5907e618b48d9f63faa953e2c3..5c197cec2a0054722e1c84f8a6920cb5d4f32725 100644 --- a/paddle/framework/scope.cc +++ b/paddle/framework/scope.cc @@ -1,18 +1,27 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + #include "paddle/framework/scope.h" namespace paddle { namespace framework { -Error Scope::CreateVariable(const std::string& name) { - if (name == "") { - return Error("Variable name should not be empty"); +Variable* Scope::CreateVariable(const std::string& name) { + if (!HasVariable(name)) { + vars_[name] = std::unique_ptr(new Variable()); } - - if (HaveVariable(name)) { - return AlreadyCreated; - } - vars_[name] = std::unique_ptr(new Variable()); - return Error(); + return GetVariable(name); } Variable* Scope::GetVarLocally(const std::string& name) const { @@ -33,22 +42,8 @@ Variable* Scope::GetVariable(const std::string& name) const { } } -Variable* Scope::GetOrCreateVariable(const std::string& name) { - Variable* var = GetVariable(name); - if (var != nullptr) { - return var; - } - - Error err = CreateVariable(name); - if (!err.isOK()) { - return nullptr; - } else { - return GetVariable(name); - } -} - -bool Scope::HaveVariable(const std::string& name) { - return vars_.count(name) != 0; +bool Scope::HasVariable(const std::string &name) { + return (vars_.count(name) > 0 || (parent_ && parent_->HasVariable(name))); } } // namespace framework diff --git a/paddle/framework/scope.h b/paddle/framework/scope.h index 90c8141e4f41a1c1cb73439012c38a60f24f0774..81491f34d8cb22b408c572fc0dba60ad90c008d3 100644 --- a/paddle/framework/scope.h +++ b/paddle/framework/scope.h @@ -1,15 +1,28 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + #pragma once #include #include +#include + #include "paddle/framework/variable.h" -#include "paddle/utils/Error.h" namespace paddle { namespace framework { -const static Error AlreadyCreated("Variable has already been created"); - /** * Scope is an association of a name to Variable. All variables belong to * `Scope`. You need to specify a scope to run a Net, i.e., `net.Run(&scope)`. @@ -26,20 +39,17 @@ class Scope { // Create Variable in this Scope. Return error if Variable already been // created. - Error __must_check CreateVariable(const std::string& name); + Variable* CreateVariable(const std::string& name); // Get Variable from this Scope, this function will recursive find Variable // from it's parent scope. Return nullptr if not found. Variable* GetVariable(const std::string& name) const; - // find and return Variables in the scope it self. + // Find and return Variables in the scope it self. Variable* GetVarLocally(const std::string& name) const; - // Get a Variable from Scope, if the Variable is not exist then create it. - // User should call this function most of time. - Variable* GetOrCreateVariable(const std::string& name); - - bool HaveVariable(const std::string& name); + // Find if there is a Variable in this scope and it's parent scope + bool HasVariable(const std::string &name); private: std::unordered_map> vars_; diff --git a/paddle/framework/scope_test.cc b/paddle/framework/scope_test.cc index 09fbb78d69f2be44982f1d15cb3ec13610434f63..25c144868b45da47f897caa7e0522590b80b6019 100644 --- a/paddle/framework/scope_test.cc +++ b/paddle/framework/scope_test.cc @@ -1,47 +1,55 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + #include "paddle/framework/scope.h" #include "gtest/gtest.h" TEST(Scope, Create) { using paddle::framework::Scope; - using paddle::Error; using paddle::framework::Variable; - using paddle::framework::AlreadyCreated; Scope* scope = new Scope(); - Error err = scope->CreateVariable(""); - EXPECT_FALSE(err.isOK()); + Variable* var0 = scope->CreateVariable(""); + EXPECT_NE(var0, nullptr); Variable* var1 = scope->GetVariable("a"); EXPECT_EQ(var1, nullptr); - Error err1 = scope->CreateVariable("a"); - EXPECT_TRUE(err1.isOK()); - - Error err2 = scope->CreateVariable("a"); - EXPECT_EQ(err2, AlreadyCreated); - - Variable* var2 = scope->GetVariable("a"); + Variable* var2 = scope->CreateVariable("a"); EXPECT_NE(var2, nullptr); - Variable* var3 = scope->GetOrCreateVariable("b"); - EXPECT_NE(var3, nullptr); + Variable* var3 = scope->CreateVariable("a"); + EXPECT_EQ(var2, var3); + + Variable* var4 = scope->GetVariable("a"); + EXPECT_EQ(var2, var4); } TEST(Scope, Parent) { using paddle::framework::Scope; using paddle::framework::Variable; - using paddle::Error; const auto parent_scope_ptr = std::shared_ptr(new Scope()); Scope* scope = new Scope(parent_scope_ptr); - Error err = parent_scope_ptr->CreateVariable("a"); - EXPECT_TRUE(err.isOK()); + Variable* var0 = parent_scope_ptr->CreateVariable("a"); + EXPECT_NE(var0, nullptr); Variable* var1 = scope->GetVarLocally("a"); EXPECT_EQ(var1, nullptr); Variable* var2 = scope->GetVariable("a"); - EXPECT_NE(var2, nullptr); + EXPECT_EQ(var2, var0); } \ No newline at end of file