未验证 提交 b6d1d890 编写于 作者: C chengduo 提交者: GitHub

Increase num_iteration_per_drop_scope (#19075)

* increase num_iteration_per_drop_scope
test=develop

* Fix a bug in while_op
test=develop

* Fix a bug in WhileOp
test=develop
上级 1d0f0431
...@@ -31,7 +31,7 @@ struct ExecutionStrategy { ...@@ -31,7 +31,7 @@ struct ExecutionStrategy {
// iterations the framework cleans up a local execution scope. // iterations the framework cleans up a local execution scope.
// In some models, the value of this parameter has a great // In some models, the value of this parameter has a great
// influence on the performance (about 15%) of the program. // influence on the performance (about 15%) of the program.
size_t num_iteration_per_drop_scope_{1}; size_t num_iteration_per_drop_scope_{100};
// At present, the kExperimental executor is the fastest in most models. // At present, the kExperimental executor is the fastest in most models.
ExecutorType type_{kExperimental}; ExecutorType type_{kExperimental};
// This debug option. // This debug option.
......
...@@ -62,7 +62,7 @@ class WhileOp : public framework::OperatorBase { ...@@ -62,7 +62,7 @@ class WhileOp : public framework::OperatorBase {
auto step_scopes = auto step_scopes =
scope.FindVar(Output(kStepScopes))->GetMutable<StepScopeVar>(); scope.FindVar(Output(kStepScopes))->GetMutable<StepScopeVar>();
PADDLE_ENFORCE_EQ(step_scopes->size(), 0, "The StepScope should be empty.");
PADDLE_ENFORCE(platform::is_cpu_place(cond.place()), PADDLE_ENFORCE(platform::is_cpu_place(cond.place()),
"Condition of while op must in CPU memory."); "Condition of while op must in CPU memory.");
...@@ -197,17 +197,22 @@ class WhileGradOp : public framework::OperatorBase { ...@@ -197,17 +197,22 @@ class WhileGradOp : public framework::OperatorBase {
inside_tensor.set_lod(outside_tensor.lod()); inside_tensor.set_lod(outside_tensor.lod());
inside_tensor.ShareDataWith(outside_tensor); inside_tensor.ShareDataWith(outside_tensor);
} else if (og_outside.IsType<framework::LoDTensorArray>()) { } else if (og_outside.IsType<framework::LoDTensorArray>()) {
auto &outside_array = og_outside.Get<framework::LoDTensorArray>(); auto outside_array =
og_outside.GetMutable<framework::LoDTensorArray>();
auto &inside_array = auto &inside_array =
detail::Ref(og_inside.GetMutable<framework::LoDTensorArray>()); detail::Ref(og_inside.GetMutable<framework::LoDTensorArray>());
VLOG(8) << outside_og_name << " size = " << outside_array.size(); inside_array.clear();
inside_array.resize(outside_array.size()); inside_array.resize(outside_array->size());
VLOG(8) << outside_og_name << " size = " << outside_array->size();
for (size_t j = 0; j < inside_array.size(); ++j) { for (size_t j = 0; j < inside_array.size(); ++j) {
VLOG(8) << j << " " << outside_array[j].numel(); if (!outside_array->at(j).IsInitialized()) {
if (outside_array[j].numel() != 0) { outside_array->at(j).Resize({0});
inside_array[j].set_lod(outside_array[j].lod()); }
inside_array[j].ShareDataWith(outside_array[j]); VLOG(8) << j << " " << outside_array->at(j).numel();
if (outside_array->at(j).numel() != 0) {
inside_array[j].set_lod(outside_array->at(j).lod());
inside_array[j].ShareDataWith(outside_array->at(j));
} else { } else {
PADDLE_ENFORCE_EQ(inside_array[j].numel(), 0); PADDLE_ENFORCE_EQ(inside_array[j].numel(), 0);
} }
...@@ -300,6 +305,7 @@ class WhileGradOp : public framework::OperatorBase { ...@@ -300,6 +305,7 @@ class WhileGradOp : public framework::OperatorBase {
dev_ctx.Wait(); dev_ctx.Wait();
const_cast<framework::Scope &>(scope).DeleteScope(&cur_scope); const_cast<framework::Scope &>(scope).DeleteScope(&cur_scope);
} }
step_scopes->clear();
} }
}; };
......
...@@ -141,7 +141,7 @@ class SumOp : public framework::OperatorWithKernel { ...@@ -141,7 +141,7 @@ class SumOp : public framework::OperatorWithKernel {
for (auto& x_var : x_vars) { for (auto& x_var : x_vars) {
auto& array = x_var->Get<framework::LoDTensorArray>(); auto& array = x_var->Get<framework::LoDTensorArray>();
for (auto& each : array) { for (auto& each : array) {
if (each.numel() != 0) { if (each.numel() != 0 && each.IsInitialized()) {
return framework::OpKernelType(each.type(), ctx.device_context(), return framework::OpKernelType(each.type(), ctx.device_context(),
layout, library); layout, library);
} }
......
...@@ -97,11 +97,11 @@ void LodTensorArrayCompute(const framework::ExecutionContext &context) { ...@@ -97,11 +97,11 @@ void LodTensorArrayCompute(const framework::ExecutionContext &context) {
auto &in_array = in_vars[i]->Get<framework::LoDTensorArray>(); auto &in_array = in_vars[i]->Get<framework::LoDTensorArray>();
for (size_t i = 0; i < in_array.size(); ++i) { for (size_t i = 0; i < in_array.size(); ++i) {
if (in_array[i].numel() != 0) { if (in_array[i].IsInitialized() && (in_array[i].numel() != 0)) {
if (i >= out_array.size()) { if (i >= out_array.size()) {
out_array.resize(i + 1); out_array.resize(i + 1);
} }
if (out_array[i].numel() == 0) { if (!out_array[i].IsInitialized() || (out_array[i].numel() == 0)) {
framework::TensorCopy(in_array[i], in_array[i].place(), framework::TensorCopy(in_array[i], in_array[i].place(),
context.device_context(), &out_array[i]); context.device_context(), &out_array[i]);
out_array[i].set_lod(in_array[i].lod()); out_array[i].set_lod(in_array[i].lod());
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册