Commit fbc30215 authored by fengjiayi

refine WhileGradOp code

Parent 8f962f74
@@ -219,18 +219,15 @@ class WhileGradOpDescMaker : public framework::SingleGradOpDescMaker {
     auto *grad_block = this->grad_block_[0];
     auto *fwd_block = grad_block->ParentBlock();
-    // auto *parent_block = fwd_block->ParentBlock();
     // Not all of IGs will be generated by inner gradient operators of while op.
     // Ignore IGs that is not generated by the inside block.
     std::unordered_set<std::string> inner_op_outputs;
-    LOG(INFO) << "FUCK1";
     for (const auto *op : grad_block->AllOps()) {
       for (auto &oname : op->OutputArgumentNames()) {
         inner_op_outputs.insert(oname);
       }
     }
-    LOG(INFO) << "FUCK2";
     auto igs = InputGrad(kX, /*do not drop empty gradient*/ false);
     for (auto &each_ig : igs) {
       if (inner_op_outputs.find(each_ig) == inner_op_outputs.end()) {
@@ -243,11 +240,13 @@ class WhileGradOpDescMaker : public framework::SingleGradOpDescMaker {
     // OG should be re-calculated by step blocks, since many outputs of while op
     // do not need to calculate gradients.
     std::unordered_set<std::string> block_ins;
-    std::copy(Input(kX).begin(), Input(kX).end(),
-              std::inserter(block_ins, block_ins.end()));
-    std::copy(Output(kOutputs).begin(), Output(kOutputs).end(),
-              std::inserter(block_ins, block_ins.end()));
+    block_ins.reserve(Input(kX).size() + Output(kOutputs).size());
+    for (auto &p : Input(kX)) {
+      block_ins.insert(p);
+    }
+    for (auto &o : Output(kOutputs)) {
+      block_ins.insert(o);
+    }
     std::unordered_set<std::string> extra_inputs;
     for (const auto *op : grad_block->AllOps()) {
       for (auto &input_name : op->InputArgumentNames()) {
@@ -257,15 +256,6 @@ class WhileGradOpDescMaker : public framework::SingleGradOpDescMaker {
             fwd_block->FindVar(input_name) != nullptr) {
           continue;
         }
-        /*
-        if (parent_block->FindVarRecursive(input_name) == nullptr) {
-          VLOG(5) << "WARNING! Variable '" << input_name
-                  << "' is the input of '" << op->Type()
-                  << "'. But can not be found in any block.";
-          continue;
-        }
-        */
        extra_inputs.insert(input_name);
       }
       for (auto &output_name : op->OutputArgumentNames()) {
...
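For reference, below is a minimal, self-contained sketch of the set-building pattern the new code adopts: inserting names into an std::unordered_set with explicit range-for loops and an up-front reserve(), instead of std::copy with std::inserter. The containers here are illustrative stand-ins for the framework's Input(kX) and Output(kOutputs) lists, not the actual op-maker API.

// Sketch only: plain std::vector<std::string> stands in for Input(kX) / Output(kOutputs).
#include <iostream>
#include <string>
#include <unordered_set>
#include <vector>

int main() {
  std::vector<std::string> inputs = {"x0", "x1"};   // stand-in for Input(kX)
  std::vector<std::string> outputs = {"out0"};      // stand-in for Output(kOutputs)

  std::unordered_set<std::string> block_ins;
  // Reserving buckets up front avoids rehashing while inserting.
  block_ins.reserve(inputs.size() + outputs.size());
  for (auto &p : inputs) {
    block_ins.insert(p);
  }
  for (auto &o : outputs) {
    block_ins.insert(o);
  }

  for (const auto &name : block_ins) {
    std::cout << name << "\n";
  }
  return 0;
}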