提交 aabe1715 编写于 作者: Q qiaolongfei

merge CreateVar and GetOrCreateVar

上级 1678ad7b
......@@ -19,7 +19,6 @@ limitations under the License. */
#include <vector>
#include "paddle/framework/variable.h"
#include "paddle/platform/assert.h"
namespace paddle {
namespace framework {
......@@ -44,9 +43,13 @@ class Scope {
/// Create Variable in this Scope. Failed if Variable already been
/// created.
Variable* CreateVariable(const std::string& name) {
PADDLE_ASSERT(!HasVariable(name));
vars_[name] = std::unique_ptr<Variable>(new Variable());
return GetVariable(name);
auto var = GetVariable(name);
if (var) {
return var;
} else {
vars_[name] = std::unique_ptr<Variable>(new Variable());
return GetVariable(name);
}
}
/// Get Variable from this Scope, this function will recursive find Variable
......@@ -62,16 +65,6 @@ class Scope {
}
}
/// Get Variable from scope, if Variable is not exist, creat one and return.
Variable* GetOrCreateVariable(const std::string& name) {
auto var = GetVariable(name);
if (var) {
return var;
} else {
return CreateVariable(name);
}
}
/// Find if there is a Variable in this scope and it's parent scope
bool HasVariable(const std::string& name) const {
return (vars_.find(name) != vars_.end() ||
......
......@@ -24,18 +24,22 @@ TEST(Scope, Create) {
Variable* var0 = scope->CreateVariable("");
EXPECT_NE(var0, nullptr);
/// GetVariable will return nullptr if not exist.
Variable* var1 = scope->GetVariable("a");
EXPECT_EQ(var1, nullptr);
/// CreateVariable will return one.
Variable* var2 = scope->CreateVariable("a");
EXPECT_NE(var2, nullptr);
ASSERT_DEATH({ scope->CreateVariable("a"); }, "");
/// Get the created variable.
Variable* var3 = scope->GetVariable("a");
EXPECT_EQ(var2, var3);
Variable* var4 = scope->GetOrCreateVariable("a");
EXPECT_EQ(var2, var4);
/// CreateVariable will just return the variable if it's
/// already exist.
Variable* var4 = scope->CreateVariable("a");
EXPECT_EQ(var4, var2);
}
TEST(Scope, Parent) {
......@@ -48,6 +52,7 @@ TEST(Scope, Parent) {
Variable* var0 = parent_scope->CreateVariable("a");
EXPECT_NE(var0, nullptr);
/// GetVariable will get Variable from parent scope if exist.
Variable* var1 = scope->GetVariable("a");
EXPECT_EQ(var0, var1);
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册