提交 06f8aa5b 编写于 作者: S sneaxiy

remove while_op support temporarily

test=develop
上级 79230423
...@@ -419,7 +419,8 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope, ...@@ -419,7 +419,8 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
int64_t max_memory_size = GetEagerDeletionThreshold(); int64_t max_memory_size = GetEagerDeletionThreshold();
std::unique_ptr<GarbageCollector> gc; std::unique_ptr<GarbageCollector> gc;
if (max_memory_size >= 0) { // skip while_op and while_grad_op temporarily
if (max_memory_size >= 0 && !keep_kids) {
ctx->ResetReferenceCount(); ctx->ResetReferenceCount();
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
if (platform::is_gpu_place(place_)) { if (platform::is_gpu_place(place_)) {
......
...@@ -365,51 +365,7 @@ class WhileGradOpDescMaker : public framework::SingleGradOpDescMaker { ...@@ -365,51 +365,7 @@ class WhileGradOpDescMaker : public framework::SingleGradOpDescMaker {
// while operator could be renamed. // while operator could be renamed.
while_grad->SetAttr("original_output_grad", output_grads_list); while_grad->SetAttr("original_output_grad", output_grads_list);
/* The following codes are used in eager deletion mode */ while_grad->SetAttr(kSkipEagerDeletionVars, std::vector<std::string>());
std::unordered_set<std::string> bwd_skip_vars;
if (framework::GetEagerDeletionThreshold() >= 0) {
std::unordered_set<std::string> fwd_skip_vars;
for (auto *op_desc : grad_block->AllOps()) {
auto skippable = [&](const std::string &name) {
return !grad_block->HasVar(name) &&
(fwd_block->HasVarRecursive(name) ||
parent_block->HasVarRecursive(name));
};
for (auto &in_arg_name : op_desc->InputArgumentNames()) {
if (skippable(in_arg_name)) {
fwd_skip_vars.insert(in_arg_name);
}
}
for (auto &out_arg_name : op_desc->OutputArgumentNames()) {
if (skippable(out_arg_name)) {
fwd_skip_vars.insert(out_arg_name);
}
}
}
if (!fwd_skip_vars.empty()) {
// FIXME(zjl): ugly const_cast here, maybe we should find a better way
// to modify forward while_op
auto &fwd_while_op = const_cast<framework::OpDesc &>(ForwardOp());
fwd_while_op.SetAttr(kSkipEagerDeletionVars,
std::vector<std::string>(fwd_skip_vars.begin(),
fwd_skip_vars.end()));
}
// Find backward skip vars
auto fwd_input = Input(kX);
for (size_t i = 0; i < igs.size(); ++i) {
if (igs[i] == framework::kEmptyVarName) {
continue;
}
bwd_skip_vars.insert(igs[i]);
bwd_skip_vars.insert(framework::GradVarName(fwd_input[i]));
}
}
while_grad->SetAttr(
kSkipEagerDeletionVars,
std::vector<std::string>(bwd_skip_vars.begin(), bwd_skip_vars.end()));
return std::unique_ptr<framework::OpDesc>(while_grad); return std::unique_ptr<framework::OpDesc>(while_grad);
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册