提交 574bcdab 编写于 作者: Y Yu Yang

Add comments

上级 7ffd50b9
......@@ -261,35 +261,37 @@ class WhileGradOpDescMaker : public framework::SingleGradOpDescMaker {
for (auto &o : Output(kOutputs)) {
block_ins.insert(o);
}
std::unordered_set<std::string> extra_inputs;
std::unordered_set<std::string> output_grads;
for (const auto *op : grad_block->AllOps()) {
for (auto &input_name : op->InputArgumentNames()) {
// If the input of Op has been recorded or is generated by the forward
// block, do not make it as input again.
// The input is located in I/O or other op's outputs or the variable is
// located in grad_block's parents
if (block_ins.find(input_name) != block_ins.end() ||
fwd_block->FindVar(input_name) != nullptr ||
parent_block->FindVar(input_name) != nullptr) {
(fwd_block->FindVarRecursive(input_name) != nullptr ||
parent_block->FindVarRecursive(input_name) != nullptr)) {
continue;
}
extra_inputs.insert(input_name);
output_grads.insert(input_name);
}
for (auto &output_name : op->OutputArgumentNames()) {
block_ins.insert(output_name);
}
}
std::vector<std::string> extra_inputs_list;
extra_inputs_list.resize(extra_inputs.size());
std::copy(extra_inputs.begin(), extra_inputs.end(),
extra_inputs_list.begin());
while_grad->SetInput(framework::GradVarName(kOutputs), extra_inputs_list);
std::vector<std::string> output_grads_list;
output_grads_list.resize(output_grads.size());
std::copy(output_grads.begin(), output_grads.end(),
output_grads_list.begin());
while_grad->SetInput(framework::GradVarName(kOutputs), output_grads_list);
while_grad->SetAttrMap(this->Attrs());
while_grad->SetBlockAttr(kStepBlock, *grad_block);
// record the original output gradient names, since the gradient name of
// while operator could be renamed.
while_grad->SetAttr("original_output_grad", extra_inputs_list);
while_grad->SetAttr("original_output_grad", output_grads_list);
return std::unique_ptr<framework::OpDesc>(while_grad);
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册