From 8f962f74338833abe8a5186270f57d1e1897bbdb Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Tue, 9 Jan 2018 18:53:15 +0800 Subject: [PATCH] Update --- paddle/operators/while_op.cc | 79 +++++++++++++++++++----------------- 1 file changed, 42 insertions(+), 37 deletions(-) diff --git a/paddle/operators/while_op.cc b/paddle/operators/while_op.cc index 65d827e0e0..3b78dd128f 100644 --- a/paddle/operators/while_op.cc +++ b/paddle/operators/while_op.cc @@ -211,59 +211,64 @@ class WhileGradOpDescMaker : public framework::SingleGradOpDescMaker { protected: std::unique_ptr Apply() const override { - auto *grad = new framework::OpDesc(); - grad->SetType("while_grad"); - grad->SetInput(kX, Input(kX)); + auto *while_grad = new framework::OpDesc(); + while_grad->SetType("while_grad"); + while_grad->SetInput(kX, Input(kX)); + while_grad->SetInput(kOutputs, Output(kOutputs)); + while_grad->SetInput(kStepScopes, Output(kStepScopes)); + + auto *grad_block = this->grad_block_[0]; + auto *fwd_block = grad_block->ParentBlock(); + // auto *parent_block = fwd_block->ParentBlock(); // Not all of IGs will be generated by inner gradient operators of while op. // Ignore IGs that is not generated by the inside block. - auto igs = InputGrad(kX, /*do not drop empty gradient*/ false); - std::unordered_set all_outs; - for (size_t i = 0; i < grad_block_[0]->OpSize(); ++i) { - for (auto &oname : grad_block_[0]->Op(i)->OutputArgumentNames()) { - all_outs.insert(oname); + std::unordered_set inner_op_outputs; + LOG(INFO) << "FUCK1"; + for (const auto *op : grad_block->AllOps()) { + for (auto &oname : op->OutputArgumentNames()) { + inner_op_outputs.insert(oname); } } + LOG(INFO) << "FUCK2"; + auto igs = InputGrad(kX, /*do not drop empty gradient*/ false); for (auto &each_ig : igs) { - if (all_outs.find(each_ig) == all_outs.end()) { + if (inner_op_outputs.find(each_ig) == inner_op_outputs.end()) { VLOG(10) << "Ignore " << each_ig; each_ig = framework::kEmptyVarName; } } - - grad->SetOutput(framework::GradVarName(kX), igs); - - grad->SetInput(kOutputs, Output(kOutputs)); + while_grad->SetOutput(framework::GradVarName(kX), igs); // OG should be re-calculated by step blocks, since many outputs of while op // do not need to calculate gradients. std::unordered_set block_ins; - auto *fwd_block = this->grad_block_[0]->ParentBlock(); - { - for (auto &p : Input(kX)) { - block_ins.insert(p); - } - for (auto &o : Output(kOutputs)) { - block_ins.insert(o); - } - } + std::copy(Input(kX).begin(), Input(kX).end(), + std::inserter(block_ins, block_ins.end())); + std::copy(Output(kOutputs).begin(), Output(kOutputs).end(), + std::inserter(block_ins, block_ins.end())); + std::unordered_set extra_inputs; - for (size_t i = 0; i < grad_block_[0]->OpSize(); ++i) { - for (auto &input_name : grad_block_[0]->Op(i)->InputArgumentNames()) { - if (block_ins.find(input_name) != block_ins.end()) { + for (const auto *op : grad_block->AllOps()) { + for (auto &input_name : op->InputArgumentNames()) { + // If the input of Op has been recorded or is generated by the forward + // block, do not make it as input again. + if (block_ins.find(input_name) != block_ins.end() || + fwd_block->FindVar(input_name) != nullptr) { continue; } - // If the input of Op is generated by the forward block, do not make it - // as input again. - if (fwd_block->FindVar(input_name) != nullptr) { + /* + if (parent_block->FindVarRecursive(input_name) == nullptr) { + VLOG(5) << "WARNING! Variable '" << input_name + << "' is the input of '" << op->Type() + << "'. But can not be found in any block."; continue; } - + */ extra_inputs.insert(input_name); } - - for (auto &output_name : grad_block_[0]->Op(i)->OutputArgumentNames()) { + for (auto &output_name : op->OutputArgumentNames()) { block_ins.insert(output_name); } } @@ -272,15 +277,15 @@ class WhileGradOpDescMaker : public framework::SingleGradOpDescMaker { extra_inputs_list.resize(extra_inputs.size()); std::copy(extra_inputs.begin(), extra_inputs.end(), extra_inputs_list.begin()); - grad->SetInput(framework::GradVarName(kOutputs), extra_inputs_list); - grad->SetInput(kStepScopes, Output(kStepScopes)); - grad->SetAttrMap(this->Attrs()); - grad->SetBlockAttr(kStepBlock, *grad_block_[0]); + while_grad->SetInput(framework::GradVarName(kOutputs), extra_inputs_list); + + while_grad->SetAttrMap(this->Attrs()); + while_grad->SetBlockAttr(kStepBlock, *grad_block); // record the original output gradient names, since the gradient name of // while operator could be renamed. - grad->SetAttr("original_output_grad", extra_inputs_list); + while_grad->SetAttr("original_output_grad", extra_inputs_list); - return std::unique_ptr(grad); + return std::unique_ptr(while_grad); } }; -- GitLab