提交 2f47562d 编写于 作者: Q qiaolongfei

scope-impl

上级 dcece75b
...@@ -41,7 +41,7 @@ class Scope { ...@@ -41,7 +41,7 @@ class Scope {
const Variable* GetVariable(const std::string& name) const; const Variable* GetVariable(const std::string& name) const;
private: private:
std::unordered_map<std::string, std::unique_ptr<Vairable>> vars_; std::unordered_map<std::string, std::unique_ptr<Variable>> vars_;
}; };
``` ```
......
# ddim lib
cc_library(ddim SRCS ddim.cc) cc_library(ddim SRCS ddim.cc)
cc_test(ddim_test SRCS ddim_test.cc DEPS ddim) cc_test(ddim_test SRCS ddim_test.cc DEPS ddim)
nv_test(dim_test SRCS dim_test.cu DEPS ddim) nv_test(dim_test SRCS dim_test.cu DEPS ddim)
cc_test(variable_test SRCS variable_test.cc) cc_test(variable_test SRCS variable_test.cc)
# scope lib
cc_library(scope SRCS scope.cc)
cc_test(scope_test SRCS scope_test.cc DEPS scope)
#include "paddle/framework/scope.h"
namespace paddle {
namespace framework {
Error Scope::CreateVariable(const std::string &name) {
if (name == "") {
return Error("Variable name should not be empty");
}
if (HaveVariable(name)) {
return AlreadyCreated;
}
vars_[name] = std::unique_ptr<Variable>(new Variable());
return Error();
}
Variable* Scope::GetVarLocally(const std::string& name) const {
if (vars_.count(name)) {
return vars_.at(name).get();
}
return nullptr;
}
Variable* Scope::GetVariable(const std::string &name) const {
Variable* var = GetVarLocally(name);
if (var != nullptr) {
return var;
} else if (parent_ != nullptr) {
return parent_->GetVariable(name);
} else {
return nullptr;
}
}
Variable* Scope::GetOrCreateVariable(const std::string &name) {
Variable* var;
var = GetVariable(name);
if (var == nullptr) {
auto err = CreateVariable(name);
if (!err.isOK()) {
return nullptr;
}
}
return GetVariable(name);
}
bool Scope::HaveVariable(const std::string &name) {
return vars_.count(name) != 0;
}
} // namespace framework
} // namespace paddle
#pragma once
#include <vector>
#include <unordered_map>
#include "paddle/framework/variable.h"
#include "paddle/utils/Error.h"
namespace paddle {
namespace framework {
const static Error AlreadyCreated("Variable has already been created");
/**
* Scope is an association of a name to Variable. All variables belong to `Scope`.
* You need to specify a scope to run a Net, i.e., `net.Run(&scope)`. One net can
* run in different scopes and update different variable in the scope.
*/
class Scope {
public:
Scope() {}
explicit Scope(const std::shared_ptr<Scope> &scope):
parent_(scope) {}
~Scope() {}
// Create Variable in this Scope. Return error if Variable already been
// created.
Error __must_check CreateVariable(const std::string& name);
// Get Variable from this Scope, this function will recursive find Variable
// from it's parent scope.
// Return nullptr if not found.
Variable* GetVariable(const std::string& name) const;
// find and return Variables in the scope it self.
Variable* GetVarLocally(const std::string& name) const;
// Get a Variable from Scope, if the Variable is not exist then create it.
// User should call this function most of time.
Variable* GetOrCreateVariable(const std::string& name);
bool HaveVariable(const std::string& name);
private:
std::unordered_map<std::string, std::unique_ptr<Variable>> vars_;
std::shared_ptr<Scope> parent_ {nullptr};
};
} // namespace framework
} // namespace paddle
#include "paddle/framework/scope.h"
#include "gtest/gtest.h"
TEST(Scope, Create) {
using paddle::framework::Scope;
using paddle::Error;
using paddle::framework::Variable;
using paddle::framework::AlreadyCreated;
Scope* scope = new Scope();
Error err = scope->CreateVariable("");
EXPECT_FALSE(err.isOK());
Variable* var1 = scope->GetVariable("a");
EXPECT_EQ(var1, nullptr);
Error err1 = scope->CreateVariable("a");
EXPECT_TRUE(err1.isOK());
Error err2 = scope->CreateVariable("a");
EXPECT_EQ(err2, AlreadyCreated);
Variable* var2 = scope->GetVariable("a");
EXPECT_NE(var2, nullptr);
Variable* var3 = scope->GetOrCreateVariable("b");
EXPECT_NE(var3, nullptr);
}
TEST(Scope, Parent) {
using paddle::framework::Scope;
using paddle::framework::Variable;
using paddle::Error;
const auto parent_scope_ptr = std::shared_ptr<Scope>(new Scope());
Scope* scope = new Scope(parent_scope_ptr);
Error err = parent_scope_ptr->CreateVariable("a");
EXPECT_TRUE(err.isOK());
Variable* var1 = scope->GetVarLocally("a");
EXPECT_EQ(var1, nullptr);
Variable* var2 = scope->GetVariable("a");
EXPECT_NE(var2, nullptr);
}
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册