diff --git a/paddle/framework/scope.h b/paddle/framework/scope.h index bb22c4b834f08e585045a97e9756cf359a1d89c2..88a13145ca9ce622894b36fdf9638817f523dfb8 100644 --- a/paddle/framework/scope.h +++ b/paddle/framework/scope.h @@ -19,7 +19,6 @@ limitations under the License. */ #include #include "paddle/framework/variable.h" -#include "paddle/platform/assert.h" namespace paddle { namespace framework { @@ -44,9 +43,13 @@ class Scope { /// Create Variable in this Scope. Failed if Variable already been /// created. Variable* CreateVariable(const std::string& name) { - PADDLE_ASSERT(!HasVariable(name)); - vars_[name] = std::unique_ptr(new Variable()); - return GetVariable(name); + auto var = GetVariable(name); + if (var) { + return var; + } else { + vars_[name] = std::unique_ptr(new Variable()); + return GetVariable(name); + } } /// Get Variable from this Scope, this function will recursive find Variable @@ -62,16 +65,6 @@ class Scope { } } - /// Get Variable from scope, if Variable is not exist, creat one and return. - Variable* GetOrCreateVariable(const std::string& name) { - auto var = GetVariable(name); - if (var) { - return var; - } else { - return CreateVariable(name); - } - } - /// Find if there is a Variable in this scope and it's parent scope bool HasVariable(const std::string& name) const { return (vars_.find(name) != vars_.end() || diff --git a/paddle/framework/scope_test.cc b/paddle/framework/scope_test.cc index d73391d9770e84b5f3c7a15bd240dedfd1d8a559..ec6236ec62197f20b77cbdcfaad6be35ef42835b 100644 --- a/paddle/framework/scope_test.cc +++ b/paddle/framework/scope_test.cc @@ -24,18 +24,22 @@ TEST(Scope, Create) { Variable* var0 = scope->CreateVariable(""); EXPECT_NE(var0, nullptr); + /// GetVariable will return nullptr if not exist. Variable* var1 = scope->GetVariable("a"); EXPECT_EQ(var1, nullptr); + /// CreateVariable will return one. Variable* var2 = scope->CreateVariable("a"); + EXPECT_NE(var2, nullptr); - ASSERT_DEATH({ scope->CreateVariable("a"); }, ""); - + /// Get the created variable. Variable* var3 = scope->GetVariable("a"); EXPECT_EQ(var2, var3); - Variable* var4 = scope->GetOrCreateVariable("a"); - EXPECT_EQ(var2, var4); + /// CreateVariable will just return the variable if it's + /// already exist. + Variable* var4 = scope->CreateVariable("a"); + EXPECT_EQ(var4, var2); } TEST(Scope, Parent) { @@ -48,6 +52,7 @@ TEST(Scope, Parent) { Variable* var0 = parent_scope->CreateVariable("a"); EXPECT_NE(var0, nullptr); + /// GetVariable will get Variable from parent scope if exist. Variable* var1 = scope->GetVariable("a"); EXPECT_EQ(var0, var1); }