From fbc30215d4f92a593288aea7f0deb2f54dcff786 Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Tue, 9 Jan 2018 19:54:34 +0800 Subject: [PATCH] refine WhileGradOp code --- paddle/operators/while_op.cc | 24 +++++++----------------- 1 file changed, 7 insertions(+), 17 deletions(-) diff --git a/paddle/operators/while_op.cc b/paddle/operators/while_op.cc index 3b78dd128f..7a3400919e 100644 --- a/paddle/operators/while_op.cc +++ b/paddle/operators/while_op.cc @@ -219,18 +219,15 @@ class WhileGradOpDescMaker : public framework::SingleGradOpDescMaker { auto *grad_block = this->grad_block_[0]; auto *fwd_block = grad_block->ParentBlock(); - // auto *parent_block = fwd_block->ParentBlock(); // Not all of IGs will be generated by inner gradient operators of while op. // Ignore IGs that is not generated by the inside block. std::unordered_set inner_op_outputs; - LOG(INFO) << "FUCK1"; for (const auto *op : grad_block->AllOps()) { for (auto &oname : op->OutputArgumentNames()) { inner_op_outputs.insert(oname); } } - LOG(INFO) << "FUCK2"; auto igs = InputGrad(kX, /*do not drop empty gradient*/ false); for (auto &each_ig : igs) { if (inner_op_outputs.find(each_ig) == inner_op_outputs.end()) { @@ -243,11 +240,13 @@ class WhileGradOpDescMaker : public framework::SingleGradOpDescMaker { // OG should be re-calculated by step blocks, since many outputs of while op // do not need to calculate gradients. std::unordered_set block_ins; - std::copy(Input(kX).begin(), Input(kX).end(), - std::inserter(block_ins, block_ins.end())); - std::copy(Output(kOutputs).begin(), Output(kOutputs).end(), - std::inserter(block_ins, block_ins.end())); - + block_ins.reserve(Input(kX).size() + Output(kOutputs).size()); + for (auto &p : Input(kX)) { + block_ins.insert(p); + } + for (auto &o : Output(kOutputs)) { + block_ins.insert(o); + } std::unordered_set extra_inputs; for (const auto *op : grad_block->AllOps()) { for (auto &input_name : op->InputArgumentNames()) { @@ -257,15 +256,6 @@ class WhileGradOpDescMaker : public framework::SingleGradOpDescMaker { fwd_block->FindVar(input_name) != nullptr) { continue; } - - /* - if (parent_block->FindVarRecursive(input_name) == nullptr) { - VLOG(5) << "WARNING! Variable '" << input_name - << "' is the input of '" << op->Type() - << "'. But can not be found in any block."; - continue; - } - */ extra_inputs.insert(input_name); } for (auto &output_name : op->OutputArgumentNames()) { -- GitLab