An error in building the OG of `while_op`
Created by: JiayiFeng
In the current code, we build the OG of `while_op` like this:
```cpp
// Names already available to the gradient block: parameters and outputs.
std::unordered_set<std::string> block_ins;
{
  for (auto &p : Input(kParameters)) {
    block_ins.insert(p);
  }
  for (auto &o : Output(kOutputs)) {
    block_ins.insert(o);
  }
}
// Every input that is not yet known when it is first read is collected here.
std::unordered_set<std::string> extra_inputs;
for (size_t i = 0; i < grad_block_[0]->OpSize(); ++i) {
  for (auto &input_name : grad_block_[0]->Op(i)->InputArgumentNames()) {
    if (block_ins.find(input_name) != block_ins.end()) {
      continue;
    }
    extra_inputs.insert(input_name);
  }
  // Outputs of this op become known names for all later ops.
  for (auto &output_name : grad_block_[0]->Op(i)->OutputArgumentNames()) {
    block_ins.insert(output_name);
  }
}
```
We traverse all the gradient ops in the gradient block. If an input variable is in neither `I` nor `G`, and is not generated by some previous op in the gradient block, it is marked as `OG`.
However, this assumption is not correct. Because `while_op`'s forward ops and backward ops are in two different blocks, every variable generated in the forward block also meets all of the above criteria, even though it is obviously not `OG`.
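The misclassification can be reproduced outside Paddle with a small sketch. Everything here is hypothetical and only for illustration: `ToyOp`, `CollectExtraInputs`, `ToyExample`, and the variable names `x`, `tmp`, `out` are made up, and the loop simply mirrors the rule from the snippet above. The assumed forward block computes `tmp` from `x` and `out` from `tmp`, so `tmp` is a forward intermediate.

```cpp
#include <cassert>
#include <string>
#include <unordered_set>
#include <vector>

// Hypothetical stand-in for an op description: only its argument names.
struct ToyOp {
  std::vector<std::string> inputs;
  std::vector<std::string> outputs;
};

// Applies the same rule as the snippet above: an input that is not already
// known (not in block_ins and not produced by an earlier op) is "extra".
// block_ins is taken by value because it grows while the ops are visited.
std::unordered_set<std::string> CollectExtraInputs(
    std::unordered_set<std::string> block_ins,
    const std::vector<ToyOp> &grad_ops) {
  std::unordered_set<std::string> extra_inputs;
  for (const auto &op : grad_ops) {
    for (const auto &name : op.inputs) {
      if (block_ins.count(name) == 0) {
        extra_inputs.insert(name);
      }
    }
    for (const auto &name : op.outputs) {
      block_ins.insert(name);
    }
  }
  return extra_inputs;
}

// Made-up gradient block for the forward computation x -> tmp -> out.
std::unordered_set<std::string> ToyExample() {
  std::unordered_set<std::string> block_ins = {"x", "out"};
  std::vector<ToyOp> grad_ops = {
      ToyOp{{"tmp", "out@GRAD"}, {"tmp@GRAD"}},  // reads forward variable tmp
      ToyOp{{"x", "tmp@GRAD"}, {"x@GRAD"}},
  };
  return CollectExtraInputs(block_ins, grad_ops);
}
```

In this sketch `ToyExample()` returns both `out@GRAD` and `tmp`. Only the former is a real output gradient; `tmp` is picked up simply because it was produced in the other block.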
This leads to a runtime error. In https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/operators/while_op.cc#L127 we try to find the `OG` variables in the higher-level scope. However, if a variable is generated in `while_op`'s forward block, it lives in the current scope, not in the higher-level scope, so the lookup fails.
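The scope mismatch can be illustrated with a toy model. This is a sketch under stated assumptions, not Paddle's `Scope` class: `ToyScope`, `FindLocal`, `FindInParent`, and `LookupSucceeds` are hypothetical names, and the variable values are placeholders.

```cpp
#include <cassert>
#include <string>
#include <unordered_map>

// Hypothetical, simplified scope: a name -> value map plus a parent pointer.
struct ToyScope {
  const ToyScope *parent = nullptr;
  std::unordered_map<std::string, int> vars;

  // Looks only in this scope, never in the parent.
  const int *FindLocal(const std::string &name) const {
    auto it = vars.find(name);
    return it == vars.end() ? nullptr : &it->second;
  }
};

// Mimics "look the variable up in the higher-level scope".
const int *FindInParent(const ToyScope &cur, const std::string &name) {
  return cur.parent != nullptr ? cur.parent->FindLocal(name) : nullptr;
}

// Builds an outer scope holding a real output gradient and a step scope
// holding a forward-block intermediate, then tries the parent lookup.
bool LookupSucceeds(const std::string &name) {
  ToyScope outer;
  outer.vars["out@GRAD"] = 1;  // created outside the while block
  ToyScope step;
  step.parent = &outer;
  step.vars["tmp"] = 2;  // created inside the forward block
  return FindInParent(step, name) != nullptr;
}
```

Here `LookupSucceeds("out@GRAD")` is true, while `LookupSucceeds("tmp")` is false: the forward intermediate exists only in the step scope, so a lookup that starts at the parent cannot see it.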