提交 aabe1715 编写于 作者: Q qiaolongfei

merge CreateVar and GetOrCreateVar

上级 1678ad7b
...@@ -19,7 +19,6 @@ limitations under the License. */ ...@@ -19,7 +19,6 @@ limitations under the License. */
#include <vector> #include <vector>
#include "paddle/framework/variable.h" #include "paddle/framework/variable.h"
#include "paddle/platform/assert.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -44,10 +43,14 @@ class Scope { ...@@ -44,10 +43,14 @@ class Scope {
/// Create Variable in this Scope. Failed if Variable already been /// Create Variable in this Scope. Failed if Variable already been
/// created. /// created.
Variable* CreateVariable(const std::string& name) { Variable* CreateVariable(const std::string& name) {
PADDLE_ASSERT(!HasVariable(name)); auto var = GetVariable(name);
if (var) {
return var;
} else {
vars_[name] = std::unique_ptr<Variable>(new Variable()); vars_[name] = std::unique_ptr<Variable>(new Variable());
return GetVariable(name); return GetVariable(name);
} }
}
/// Get Variable from this Scope, this function will recursive find Variable /// Get Variable from this Scope, this function will recursive find Variable
/// from it's parent scope. Return nullptr if not found. /// from it's parent scope. Return nullptr if not found.
...@@ -62,16 +65,6 @@ class Scope { ...@@ -62,16 +65,6 @@ class Scope {
} }
} }
/// Get Variable from scope, if Variable is not exist, creat one and return.
Variable* GetOrCreateVariable(const std::string& name) {
auto var = GetVariable(name);
if (var) {
return var;
} else {
return CreateVariable(name);
}
}
/// Find if there is a Variable in this scope and it's parent scope /// Find if there is a Variable in this scope and it's parent scope
bool HasVariable(const std::string& name) const { bool HasVariable(const std::string& name) const {
return (vars_.find(name) != vars_.end() || return (vars_.find(name) != vars_.end() ||
......
...@@ -24,18 +24,22 @@ TEST(Scope, Create) { ...@@ -24,18 +24,22 @@ TEST(Scope, Create) {
Variable* var0 = scope->CreateVariable(""); Variable* var0 = scope->CreateVariable("");
EXPECT_NE(var0, nullptr); EXPECT_NE(var0, nullptr);
/// GetVariable will return nullptr if not exist.
Variable* var1 = scope->GetVariable("a"); Variable* var1 = scope->GetVariable("a");
EXPECT_EQ(var1, nullptr); EXPECT_EQ(var1, nullptr);
/// CreateVariable will return one.
Variable* var2 = scope->CreateVariable("a"); Variable* var2 = scope->CreateVariable("a");
EXPECT_NE(var2, nullptr);
ASSERT_DEATH({ scope->CreateVariable("a"); }, ""); /// Get the created variable.
Variable* var3 = scope->GetVariable("a"); Variable* var3 = scope->GetVariable("a");
EXPECT_EQ(var2, var3); EXPECT_EQ(var2, var3);
Variable* var4 = scope->GetOrCreateVariable("a"); /// CreateVariable will just return the variable if it's
EXPECT_EQ(var2, var4); /// already exist.
Variable* var4 = scope->CreateVariable("a");
EXPECT_EQ(var4, var2);
} }
TEST(Scope, Parent) { TEST(Scope, Parent) {
...@@ -48,6 +52,7 @@ TEST(Scope, Parent) { ...@@ -48,6 +52,7 @@ TEST(Scope, Parent) {
Variable* var0 = parent_scope->CreateVariable("a"); Variable* var0 = parent_scope->CreateVariable("a");
EXPECT_NE(var0, nullptr); EXPECT_NE(var0, nullptr);
/// GetVariable will get Variable from parent scope if exist.
Variable* var1 = scope->GetVariable("a"); Variable* var1 = scope->GetVariable("a");
EXPECT_EQ(var0, var1); EXPECT_EQ(var0, var1);
} }
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册