提交 574bcdab 编写于 作者: Y Yu Yang

Add comments

上级 7ffd50b9
...@@ -261,35 +261,37 @@ class WhileGradOpDescMaker : public framework::SingleGradOpDescMaker { ...@@ -261,35 +261,37 @@ class WhileGradOpDescMaker : public framework::SingleGradOpDescMaker {
for (auto &o : Output(kOutputs)) { for (auto &o : Output(kOutputs)) {
block_ins.insert(o); block_ins.insert(o);
} }
std::unordered_set<std::string> extra_inputs; std::unordered_set<std::string> output_grads;
for (const auto *op : grad_block->AllOps()) { for (const auto *op : grad_block->AllOps()) {
for (auto &input_name : op->InputArgumentNames()) { for (auto &input_name : op->InputArgumentNames()) {
// If the input of Op has been recorded or is generated by the forward // If the input of Op has been recorded or is generated by the forward
// block, do not make it as input again. // block, do not make it as input again.
// The input is located in I/O or other op's outputs or the variable is
// located in grad_block's parents
if (block_ins.find(input_name) != block_ins.end() || if (block_ins.find(input_name) != block_ins.end() ||
fwd_block->FindVar(input_name) != nullptr || (fwd_block->FindVarRecursive(input_name) != nullptr ||
parent_block->FindVar(input_name) != nullptr) { parent_block->FindVarRecursive(input_name) != nullptr)) {
continue; continue;
} }
extra_inputs.insert(input_name); output_grads.insert(input_name);
} }
for (auto &output_name : op->OutputArgumentNames()) { for (auto &output_name : op->OutputArgumentNames()) {
block_ins.insert(output_name); block_ins.insert(output_name);
} }
} }
std::vector<std::string> extra_inputs_list; std::vector<std::string> output_grads_list;
extra_inputs_list.resize(extra_inputs.size()); output_grads_list.resize(output_grads.size());
std::copy(extra_inputs.begin(), extra_inputs.end(), std::copy(output_grads.begin(), output_grads.end(),
extra_inputs_list.begin()); output_grads_list.begin());
while_grad->SetInput(framework::GradVarName(kOutputs), extra_inputs_list); while_grad->SetInput(framework::GradVarName(kOutputs), output_grads_list);
while_grad->SetAttrMap(this->Attrs()); while_grad->SetAttrMap(this->Attrs());
while_grad->SetBlockAttr(kStepBlock, *grad_block); while_grad->SetBlockAttr(kStepBlock, *grad_block);
// record the original output gradient names, since the gradient name of // record the original output gradient names, since the gradient name of
// while operator could be renamed. // while operator could be renamed.
while_grad->SetAttr("original_output_grad", extra_inputs_list); while_grad->SetAttr("original_output_grad", output_grads_list);
return std::unique_ptr<framework::OpDesc>(while_grad); return std::unique_ptr<framework::OpDesc>(while_grad);
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册