Commit 0af45b5f authored by dongzhihong

NewVar to GetOrCreateVar

Parent b504a234
@@ -243,7 +243,7 @@ class SymbolTable {
   // TODO determine whether name is generated by python or C++.
   // Currently assume that a unique name will be generated by C++ if the
   // argument name is left default.
-  VarDesc* NewVar(const string& name="");
+  VarDesc* GetOrCreateVar(const string& name="");
   // find a VarDesc by name, if recursive is true, find parent's SymbolTable
   // recursively.
...
@@ -37,7 +37,7 @@ Scope is an association of a name to variable. All variables belong to `Scope`.
 ```cpp
 class Scope {
  public:
-  Variable* NewVar(const std::string& name);
+  Variable* GetOrCreateVar(const std::string& name);
   const Variable* FindVar(const std::string& name) const;
  private:
@@ -98,7 +98,7 @@ class Scope {
   Variable* FindVar(const std::string& name) const;
   // return if already contains same name variable.
-  Variable* NewVar(const std::string& name);
+  Variable* GetOrCreateVar(const std::string& name);
  private:
   std::shared_ptr<Scope> parent_;
@@ -107,7 +107,7 @@ class Scope {
 ```
 ## Only scope can create a variable
-To ensure `only scope can create a variable`, we should mark `Variable`'s constructor as a private member function, and Scope is a friend class of Variable. And then only `NewVar` can construct `Variable`.
+To ensure `only scope can create a variable`, we should mark `Variable`'s constructor as a private member function, and Scope is a friend class of Variable. And then only `GetOrCreateVar` can construct `Variable`.
 ## When scope destroyed, all variables inside this scope should be destroyed together
@@ -121,4 +121,4 @@ Also, as the parent scope is a `shared_ptr`, we can only `Create()` a scope shar
 ## Orthogonal interface
-`FindVar` will return `nullptr` when `name` is not found. It can be used as `Contains` method. `NewVar` will return an `Error` when there is a name conflict locally. Combine `FindVar` and `NewVar`, we can implement `NewVar` easily.
+`FindVar` will return `nullptr` when `name` is not found. It can be used as `Contains` method. `GetOrCreateVar` will return an `Error` when there is a name conflict locally. Combine `FindVar` and `GetOrCreateVar`, we can implement `GetOrCreateVar` easily.
...
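An aside on the design-doc hunk above: the "only scope can create a variable" rule rests on the C++ friend-class idiom it describes. Below is a minimal standalone sketch of that idiom; the raw-pointer storage and the map type are assumptions chosen for brevity, not the actual Paddle implementation.

```cpp
#include <string>
#include <unordered_map>

class Variable {
 private:
  Variable() = default;  // private: no one outside can construct a Variable
  friend class Scope;    // except Scope, which is declared a friend
};

class Scope {
 public:
  // Returns the existing Variable for `name`, or creates one if absent.
  Variable* GetOrCreateVar(const std::string& name) {
    auto it = vars_.find(name);
    if (it != vars_.end()) return it->second;
    Variable* v = new Variable();  // legal only here, thanks to friendship
    vars_[name] = v;
    return v;
  }

 private:
  std::unordered_map<std::string, Variable*> vars_;
};
```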
@@ -161,7 +161,7 @@ class TensorArray:
         @name: str
             the name of the variable to output.
         '''
-        tensor = NewVar(name)
+        tensor = GetOrCreateVar(name)
         tensor_array_stack(self.name, tensor)
         return tensor
...
@@ -18,7 +18,7 @@ limitations under the License. */
 namespace paddle {
 namespace framework {
-VarDescBind *BlockDescBind::NewVar(const std::string &name) {
+VarDescBind *BlockDescBind::GetOrCreateVar(const std::string &name) {
   need_update_ = true;
   auto it = vars_.find(name);
   PADDLE_ENFORCE(it == vars_.end(), "Duplicated variable %s", name);
...
@@ -66,7 +66,7 @@ void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id) {
   // Instantiate all the vars in the global scope
   for (auto& var : block.vars()) {
-    scope->NewVar(var.name());
+    scope->GetOrCreateVar(var.name());
   }
   Scope& local_scope = scope->NewScope();
@@ -78,7 +78,7 @@ void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id) {
     for (auto& var : block.ops(i).outputs()) {
       for (auto& argu : var.arguments()) {
         if (local_scope.FindVar(argu) == nullptr) {
-          local_scope.NewVar(argu);
+          local_scope.GetOrCreateVar(argu);
         }
       }
     }
...
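One detail worth noting in the executor hunk above: the `FindVar` guard before `GetOrCreateVar` is not redundant. `FindVar` (per the `scope.h` hunk further down) searches the scope and all of its ancestors, while `GetOrCreateVar` only consults the local scope, so the guard keeps the executor from creating a local variable that would shadow one already living in a parent scope. A self-contained sketch of that asymmetry, using a simplified `Scope` (the `parent_` pointer and smart-pointer storage are assumptions, not the real class):

```cpp
#include <memory>
#include <string>
#include <unordered_map>

struct Variable {};

// Trimmed-down Scope, just enough to show the lookup asymmetry.
class Scope {
 public:
  explicit Scope(const Scope* parent = nullptr) : parent_(parent) {}

  // Local-only get-or-create, mirroring Scope::GetOrCreateVar in the diff.
  Variable* GetOrCreateVar(const std::string& name) {
    auto it = vars_.find(name);
    if (it != vars_.end()) return it->second.get();
    vars_[name] = std::make_unique<Variable>();
    return vars_[name].get();
  }

  // Recursive lookup, mirroring Scope::FindVar: this scope, then ancestors.
  Variable* FindVar(const std::string& name) const {
    auto it = vars_.find(name);
    if (it != vars_.end()) return it->second.get();
    return parent_ ? parent_->FindVar(name) : nullptr;
  }

 private:
  const Scope* parent_;
  std::unordered_map<std::string, std::unique_ptr<Variable>> vars_;
};

int main() {
  Scope global;
  global.GetOrCreateVar("w");  // e.g. a weight living in the parent scope
  Scope local(&global);

  // The executor's guard: only create locally if no ancestor already has it.
  if (local.FindVar("w") == nullptr) {
    local.GetOrCreateVar("w");  // would shadow global's "w" -- skipped here
  }
  return 0;
}
```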
@@ -34,7 +34,7 @@ void AddOp(const std::string& type, const VariableNameMap& inputs,
   // insert output
   for (auto kv : outputs) {
     for (auto v : kv.second) {
-      auto var = block->NewVar(v);
+      auto var = block->GetOrCreateVar(v);
       var->SetDataType(paddle::framework::DataType::FP32);
     }
   }
...
@@ -31,7 +31,7 @@ Scope& Scope::NewScope() const {
   return *kids_.back();
 }
-Variable* Scope::NewVar(const std::string& name) {
+Variable* Scope::GetOrCreateVar(const std::string& name) {
   auto iter = vars_.find(name);
   if (iter != vars_.end()) {
     return iter->second;
@@ -42,8 +42,8 @@ Variable* Scope::NewVar(const std::string& name) {
   return v;
 }
-Variable* Scope::NewVar() {
-  return NewVar(string::Sprintf("%p.%d", this, vars_.size()));
+Variable* Scope::GetOrCreateVar() {
+  return GetOrCreateVar(string::Sprintf("%p.%d", this, vars_.size()));
 }
 Variable* Scope::FindVar(const std::string& name) const {
@@ -71,8 +71,8 @@ framework::Scope& GetGlobalScope() {
   static std::unique_ptr<framework::Scope> g_scope{nullptr};
   std::call_once(feed_variable_flag, [&]() {
     g_scope.reset(new framework::Scope());
-    g_scope->NewVar("feed_value");
-    g_scope->NewVar("fetch_value");
+    g_scope->GetOrCreateVar("feed_value");
+    g_scope->GetOrCreateVar("fetch_value");
   });
   return *(g_scope.get());
 }
...
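The parameterless overload above builds a scope-unique name with `string::Sprintf("%p.%d", this, vars_.size())`: the scope's own address keeps names distinct across scopes, and the running variable count keeps them distinct within one scope. A hedged sketch of the same scheme using plain `std::snprintf` (Paddle's `string::Sprintf` helper is not reproduced here):

```cpp
#include <cstdio>
#include <string>

// Sketch: generate a name unique within one scope instance, in the spirit
// of string::Sprintf("%p.%d", this, vars_.size()) from the diff above.
std::string MakeScopeUniqueName(const void* scope, size_t var_count) {
  char buf[64];
  std::snprintf(buf, sizeof(buf), "%p.%zu", scope, var_count);
  // e.g. "0x7ffee4c0a9d0.3" -- the pointer disambiguates scopes, the
  // counter disambiguates successive unnamed variables in the same scope.
  return std::string(buf);
}
```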
@@ -45,10 +45,10 @@ class Scope {
   Scope& NewScope() const;
   /// Create a variable with given name if it doesn't exist.
-  Variable* NewVar(const std::string& name);
+  Variable* GetOrCreateVar(const std::string& name);
   /// Create a variable with a scope-unique name.
-  Variable* NewVar();
+  Variable* GetOrCreateVar();
   /// Find a variable in the scope or any of its ancestors. Returns
   /// nullptr if cannot find.
...
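Read together, the two declarations give callers a named get-or-create and an auto-named variant; because the named form returns the existing entry on a repeat call (see the `scope.cc` hunk above), it is idempotent per (scope, name) pair. A short usage sketch, assuming the include path and namespaces of this revision; the variable name is arbitrary:

```cpp
#include <cassert>
#include "paddle/framework/scope.h"

using paddle::framework::Scope;
using paddle::framework::Variable;

int main() {
  Scope scope;
  Variable* w1 = scope.GetOrCreateVar("fc_0.w");  // first call: creates
  Variable* w2 = scope.GetOrCreateVar("fc_0.w");  // second call: fetches
  assert(w1 == w2);  // same underlying Variable object

  Variable* tmp = scope.GetOrCreateVar();  // auto-named, unique in this scope
  assert(tmp != w1);
  return 0;
}
```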
@@ -134,7 +134,7 @@ void CondOp::PrepareDataForSubnet(
   for (int i = 0; i < BRANCH_NUM; ++i) {
     for (auto& output : (*sub_net_op_[i]).Outputs()) {
       for (auto& var_name : output.second) {
-        sub_scopes[i]->NewVar(var_name);
+        sub_scopes[i]->GetOrCreateVar(var_name);
       }
     }
   }
...
@@ -29,7 +29,7 @@ namespace detail {
 inline void CreateVariables(Scope& scope,
                             const std::vector<std::string>& var_names) {
   for (const auto& name : var_names) {
-    scope.NewVar(name);
+    scope.GetOrCreateVar(name);
   }
 }
@@ -112,7 +112,7 @@ void DynamicRecurrentOp::WriteStepInputs() const {
       auto& step_scope = cache_.GetScope(step);
       Variable* var = step_scope.FindVar(item.first);
       if (var == nullptr) {
-        var = step_scope.NewVar(item.first);
+        var = step_scope.GetOrCreateVar(item.first);
       }
       var->GetMutable<LoDTensor>()->ShareDataWith<value_type>(tensor);
     }
@@ -125,7 +125,7 @@ void DynamicRecurrentOp::WriteStepOutputs() const {
     for (auto& item : step_outputs_) {
       auto* var = scope.FindVar(item.first);
       if (var == nullptr) {
-        var = scope.NewVar(item.first);
+        var = scope.GetOrCreateVar(item.first);
       }
       auto* tensor = var->GetMutable<LoDTensor>();
       item.second.WriteShared(step, *tensor);
...
@@ -36,7 +36,7 @@ void OpDescNewVar(const std::string& param_name,
 // create a LoD tensor in scope with specific dims
 LoDTensor* CreateVar(Scope& scope, std::string name, framework::DDim dims,
                      const platform::Place& place) {
-  auto* var = scope.NewVar(name);
+  auto* var = scope.GetOrCreateVar(name);
   auto* tensor = var->GetMutable<LoDTensor>();
   tensor->Resize(dims);
   tensor->mutable_data<float>(place);
@@ -85,7 +85,7 @@ class DynamicRecurrentOpTestHelper : public ::testing::Test {
   void CreateGlobalVariables() {
     platform::CPUPlace place;
-    scope.NewVar("step_scopes");
+    scope.GetOrCreateVar("step_scopes");
     CreateVar(scope, "boot_mem", framework::make_ddim({10, 20}), place);
     // auto* out0 =
     CreateVar(scope, "out0", framework::make_ddim({10, 20}), place);
...
@@ -70,14 +70,14 @@ void RecurrentAlgorithm::CreateScopes(const Scope& scope,
       // the weight are located in parent scope
       for (auto& var_name : input.second) {
         if (!step_scope.FindVar(var_name)) {
-          step_scope.NewVar(var_name)->GetMutable<LoDTensor>();
+          step_scope.GetOrCreateVar(var_name)->GetMutable<LoDTensor>();
         }
       }
     }
     // create stepnet's outputs
     for (const auto& output : (*stepnet_)->Outputs()) {
       for (auto& var_name : output.second) {
-        step_scope.NewVar(var_name);
+        step_scope.GetOrCreateVar(var_name);
       }
     }
     step_scopes->emplace_back(&step_scope);
@@ -87,7 +87,8 @@ void RecurrentAlgorithm::CreateScopes(const Scope& scope,
 void RecurrentAlgorithm::InitMemories(Scope* step_scope) const {
   for (auto& attr : arg_->memories) {
-    auto* pre_mem = step_scope->NewVar(attr.pre_var)->GetMutable<LoDTensor>();
+    auto* pre_mem =
+        step_scope->GetOrCreateVar(attr.pre_var)->GetMutable<LoDTensor>();
     PADDLE_ENFORCE(step_scope->FindVar(attr.boot_var) != nullptr,
                    "memory [%s]'s boot variable [%s] not exists", attr.var,
                    attr.boot_var);
@@ -167,9 +168,10 @@ void RecurrentGradientAlgorithm::LinkBootMemoryGradients(
                    "memory variable [%s] does not exists", attr.var);
     PADDLE_ENFORCE(step_scope->FindVar(attr.boot_var) != nullptr,
                    "boot variable [%s] does not exists", attr.boot_var);
-    auto* mem_grad = step_scope->NewVar(attr.var)->GetMutable<LoDTensor>();
+    auto* mem_grad =
+        step_scope->GetOrCreateVar(attr.var)->GetMutable<LoDTensor>();
     auto* boot_mem_grad =
-        step_scope->NewVar(attr.boot_var)->GetMutable<LoDTensor>();
+        step_scope->GetOrCreateVar(attr.boot_var)->GetMutable<LoDTensor>();
     boot_mem_grad->Resize(mem_grad->dims());
     boot_mem_grad->ShareDataWith<float>(*mem_grad);
   }
...
@@ -40,7 +40,7 @@ void SegmentInputs(const std::vector<Scope*>& step_scopes,
     f::DDim step_dims = slice_ddim(dims, 1, dims.size());
     for (size_t j = 0; j < seq_len; j++) {
       Tensor* step_input =
-          step_scopes[j]->NewVar(inlinks[i])->GetMutable<Tensor>();
+          step_scopes[j]->GetOrCreateVar(inlinks[i])->GetMutable<Tensor>();
       // The input of operators of each step is Tensor here.
       // Maybe need to modify Slice function.
       *step_input = input->Slice<float>(j, j + 1);
...
@@ -137,7 +137,7 @@ void BindBlockDesc(py::module &m) {
       .def("new_var",
           [](BlockDescBind &self, py::bytes byte_name) {
             std::string name = byte_name;
-            return self.NewVar(name);
+            return self.GetOrCreateVar(name);
           },
           py::return_value_policy::reference)
       .def("var",
...
@@ -165,7 +165,7 @@ All parameter, weight, gradient are variables in Paddle.
   py::class_<Scope>(m, "Scope", "")
       .def("new_var",
           [](Scope &self, const std::string &name) -> Variable * {
-            return self.NewVar(name);
+            return self.GetOrCreateVar(name);
          },
          py::return_value_policy::reference)
      .def("find_var", &Scope::FindVar, py::return_value_policy::reference)
...
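A final observation on the two pybind hunks: only the bound C++ method changes; the Python-facing names `new_var` on `Scope` and `BlockDesc` are left untouched, so existing Python callers see no API change from this rename.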