diff --git a/paddle/fluid/operators/while_op.cc b/paddle/fluid/operators/while_op.cc index 5f51a273dd4d2e45b20e3714805297dad85196aa..8b62b242cf8745378eb216db10605388b294ca75 100644 --- a/paddle/fluid/operators/while_op.cc +++ b/paddle/fluid/operators/while_op.cc @@ -261,35 +261,37 @@ class WhileGradOpDescMaker : public framework::SingleGradOpDescMaker { for (auto &o : Output(kOutputs)) { block_ins.insert(o); } - std::unordered_set extra_inputs; + std::unordered_set output_grads; for (const auto *op : grad_block->AllOps()) { for (auto &input_name : op->InputArgumentNames()) { // If the input of Op has been recorded or is generated by the forward // block, do not make it as input again. + // The input is located in I/O or other op's outputs or the variable is + // located in grad_block's parents if (block_ins.find(input_name) != block_ins.end() || - fwd_block->FindVar(input_name) != nullptr || - parent_block->FindVar(input_name) != nullptr) { + (fwd_block->FindVarRecursive(input_name) != nullptr || + parent_block->FindVarRecursive(input_name) != nullptr)) { continue; } - extra_inputs.insert(input_name); + output_grads.insert(input_name); } for (auto &output_name : op->OutputArgumentNames()) { block_ins.insert(output_name); } } - std::vector extra_inputs_list; - extra_inputs_list.resize(extra_inputs.size()); - std::copy(extra_inputs.begin(), extra_inputs.end(), - extra_inputs_list.begin()); - while_grad->SetInput(framework::GradVarName(kOutputs), extra_inputs_list); + std::vector output_grads_list; + output_grads_list.resize(output_grads.size()); + std::copy(output_grads.begin(), output_grads.end(), + output_grads_list.begin()); + while_grad->SetInput(framework::GradVarName(kOutputs), output_grads_list); while_grad->SetAttrMap(this->Attrs()); while_grad->SetBlockAttr(kStepBlock, *grad_block); // record the original output gradient names, since the gradient name of // while operator could be renamed. - while_grad->SetAttr("original_output_grad", extra_inputs_list); + while_grad->SetAttr("original_output_grad", output_grads_list); return std::unique_ptr(while_grad); }