Unverified commit cedd9805, authored by fengjiayi, committed by GitHub

Merge pull request #7361 from JiayiFeng/refine_and_enhence_WhileGradOp

Refine while grad op
@@ -211,59 +211,54 @@ class WhileGradOpDescMaker : public framework::SingleGradOpDescMaker {
  protected:
   std::unique_ptr<framework::OpDesc> Apply() const override {
-    auto *grad = new framework::OpDesc();
-    grad->SetType("while_grad");
-    grad->SetInput(kX, Input(kX));
+    auto *while_grad = new framework::OpDesc();
+    while_grad->SetType("while_grad");
+    while_grad->SetInput(kX, Input(kX));
+    while_grad->SetInput(kOutputs, Output(kOutputs));
+    while_grad->SetInput(kStepScopes, Output(kStepScopes));
+
+    auto *grad_block = this->grad_block_[0];
+    auto *fwd_block = grad_block->ParentBlock();
 
     // Not all of IGs will be generated by inner gradient operators of while op.
     // Ignore IGs that is not generated by the inside block.
-    auto igs = InputGrad(kX, /*do not drop empty gradient*/ false);
-    std::unordered_set<std::string> all_outs;
-    for (size_t i = 0; i < grad_block_[0]->OpSize(); ++i) {
-      for (auto &oname : grad_block_[0]->Op(i)->OutputArgumentNames()) {
-        all_outs.insert(oname);
+    std::unordered_set<std::string> inner_op_outputs;
+    for (const auto *op : grad_block->AllOps()) {
+      for (auto &oname : op->OutputArgumentNames()) {
+        inner_op_outputs.insert(oname);
       }
     }
+    auto igs = InputGrad(kX, /*do not drop empty gradient*/ false);
     for (auto &each_ig : igs) {
-      if (all_outs.find(each_ig) == all_outs.end()) {
+      if (inner_op_outputs.find(each_ig) == inner_op_outputs.end()) {
         VLOG(10) << "Ignore " << each_ig;
         each_ig = framework::kEmptyVarName;
       }
     }
-
-    grad->SetOutput(framework::GradVarName(kX), igs);
-    grad->SetInput(kOutputs, Output(kOutputs));
+    while_grad->SetOutput(framework::GradVarName(kX), igs);
 
     // OG should be re-calculated by step blocks, since many outputs of while op
     // do not need to calculate gradients.
     std::unordered_set<std::string> block_ins;
-    auto *fwd_block = this->grad_block_[0]->ParentBlock();
-    {
-      for (auto &p : Input(kX)) {
-        block_ins.insert(p);
-      }
-      for (auto &o : Output(kOutputs)) {
-        block_ins.insert(o);
-      }
+    block_ins.reserve(Input(kX).size() + Output(kOutputs).size());
+    for (auto &p : Input(kX)) {
+      block_ins.insert(p);
+    }
+    for (auto &o : Output(kOutputs)) {
+      block_ins.insert(o);
     }
     std::unordered_set<std::string> extra_inputs;
-    for (size_t i = 0; i < grad_block_[0]->OpSize(); ++i) {
-      for (auto &input_name : grad_block_[0]->Op(i)->InputArgumentNames()) {
-        if (block_ins.find(input_name) != block_ins.end()) {
-          continue;
-        }
-
-        // If the input of Op is generated by the forward block, do not make it
-        // as input again.
-        if (fwd_block->FindVar(input_name) != nullptr) {
+    for (const auto *op : grad_block->AllOps()) {
+      for (auto &input_name : op->InputArgumentNames()) {
+        // If the input of Op has been recorded or is generated by the forward
+        // block, do not make it as input again.
+        if (block_ins.find(input_name) != block_ins.end() ||
+            fwd_block->FindVar(input_name) != nullptr) {
           continue;
         }
-
         extra_inputs.insert(input_name);
       }
-
-      for (auto &output_name : grad_block_[0]->Op(i)->OutputArgumentNames()) {
+      for (auto &output_name : op->OutputArgumentNames()) {
         block_ins.insert(output_name);
       }
     }
@@ -272,15 +267,15 @@ class WhileGradOpDescMaker : public framework::SingleGradOpDescMaker {
     extra_inputs_list.resize(extra_inputs.size());
     std::copy(extra_inputs.begin(), extra_inputs.end(),
              extra_inputs_list.begin());
-    grad->SetInput(framework::GradVarName(kOutputs), extra_inputs_list);
-    grad->SetInput(kStepScopes, Output(kStepScopes));
+    while_grad->SetInput(framework::GradVarName(kOutputs), extra_inputs_list);
 
-    grad->SetAttrMap(this->Attrs());
-    grad->SetBlockAttr(kStepBlock, *grad_block_[0]);
+    while_grad->SetAttrMap(this->Attrs());
+    while_grad->SetBlockAttr(kStepBlock, *grad_block);
     // record the original output gradient names, since the gradient name of
     // while operator could be renamed.
-    grad->SetAttr("original_output_grad", extra_inputs_list);
+    while_grad->SetAttr("original_output_grad", extra_inputs_list);
 
-    return std::unique_ptr<framework::OpDesc>(grad);
+    return std::unique_ptr<framework::OpDesc>(while_grad);
   }
 };
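Note: to make the two filtering passes in the rewritten Apply() easier to follow, below is a minimal, self-contained C++ sketch of the same logic using only the standard library. SimpleOp, the kEmptyVarName constant, the sample variable names, and the forward-block variable set are hypothetical stand-ins for framework::OpDesc, framework::kEmptyVarName, and fwd_block->FindVar(); it illustrates the idea, not the PaddlePaddle implementation.

// Hypothetical sketch of the while_grad filtering logic (not PaddlePaddle API).
#include <iostream>
#include <string>
#include <unordered_set>
#include <vector>

// Stand-in for an op description: just its input and output argument names.
struct SimpleOp {
  std::vector<std::string> inputs;
  std::vector<std::string> outputs;
};

const char kEmptyVarName[] = "@EMPTY@";  // stand-in for framework::kEmptyVarName

int main() {
  // Ops inside a hypothetical gradient block, in execution order.
  std::vector<SimpleOp> grad_block_ops = {
      {{"x@GRAD@RENAMED", "w"}, {"x@GRAD"}},
      {{"tmp"}, {"w@GRAD"}},
  };
  // Variables already defined by the hypothetical forward block
  // (stand-in for fwd_block->FindVar(name) != nullptr).
  std::unordered_set<std::string> fwd_block_vars = {"x", "w", "tmp"};

  // Pass 1: drop requested input gradients that no inner op produces.
  std::unordered_set<std::string> inner_op_outputs;
  for (const auto &op : grad_block_ops) {
    for (const auto &oname : op.outputs) inner_op_outputs.insert(oname);
  }
  std::vector<std::string> igs = {"x@GRAD", "b@GRAD"};  // requested IGs
  for (auto &ig : igs) {
    if (inner_op_outputs.count(ig) == 0) ig = kEmptyVarName;
  }

  // Pass 2: collect "extra" inputs, i.e. names an inner op reads that are
  // neither already recorded as block inputs nor defined by the forward block.
  std::unordered_set<std::string> block_ins = {"x", "w"};  // kX + kOutputs
  std::unordered_set<std::string> extra_inputs;
  for (const auto &op : grad_block_ops) {
    for (const auto &in : op.inputs) {
      if (block_ins.count(in) || fwd_block_vars.count(in)) continue;
      extra_inputs.insert(in);
    }
    // Anything produced by an earlier inner op is available to later ops.
    for (const auto &out : op.outputs) block_ins.insert(out);
  }

  for (const auto &ig : igs) std::cout << "IG: " << ig << "\n";
  for (const auto &e : extra_inputs) std::cout << "extra input: " << e << "\n";
  return 0;
}

Because both passes walk the same op list, the new code hoists grad_block and fwd_block to the top of Apply() once instead of re-deriving them inside each loop.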