未验证 提交 3fd3b663 编写于 作者: Z Zeng Jinle 提交者: GitHub

fix gc bug in controlflow ops, test=develop (#19827)

上级 982e61f5
......@@ -78,10 +78,11 @@ void ExecutorPrepareContext::PrepareUnusedVars(
// If gc is enabled and block size > 1
if (prog_.Size() > 1) {
operators::PrepareSafeEagerDeletionOnConditionalOpAndConditionalGradOp(
block_id_, ops_);
operators::PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp(block_id_, ops_);
prog_, block_id_, ops_);
operators::PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp(prog_, block_id_,
ops_);
operators::PrepareSafeEagerDeletionOnRecurrentOpAndRecurrentGradOp(
block_id_, ops_);
prog_, block_id_, ops_);
}
unused_vars_ = GetUnusedVars(prog_.Block(block_id_), ops_, keep_vars);
}
......
......@@ -48,7 +48,7 @@ class ConditionalOpEagerDeletionPass : public Pass {
auto &ifelse_ops = ops_pair.second.first;
auto &ifelse_grad_ops = ops_pair.second.second;
operators::PrepareSafeEagerDeletionOnConditionalOpAndConditionalGradOp(
ifelse_ops, ifelse_grad_ops);
graph->OriginProgram(), ifelse_ops, ifelse_grad_ops);
}
}
};
......
......@@ -40,7 +40,8 @@ void RecurrentOpEagerDeletionPass::ApplyImpl(Graph *graph) const {
// Prepare safe eager deletion on different devices because the garbage
// collection may be different across devices
OpAndGradOpPair &op_pair = entry.second;
PrepareSafeEagerDeletionOnRecurrentOpAndRecurrentGradOp(&op_pair);
PrepareSafeEagerDeletionOnRecurrentOpAndRecurrentGradOp(
graph->OriginProgram(), &op_pair);
}
}
......
......@@ -47,7 +47,7 @@ class WhileOpEagerDeletionPass : public ir::Pass {
auto &while_ops = ops_pair.second.first;
auto &while_grad_ops = ops_pair.second.second;
operators::PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp(
while_ops, while_grad_ops);
graph->OriginProgram(), while_ops, while_grad_ops);
}
}
};
......
......@@ -29,16 +29,12 @@ static bool IsMatchedConditionalBlockOpAndConditionalBlockGradOp(
}
static void FindAllConditionalBlockAndConditionalBlockGradOp(
std::vector<OpVariant> *fwd_ops, std::vector<OpVariant> *bwd_ops) {
const framework::ProgramDesc &program, std::vector<OpVariant> *fwd_ops,
std::vector<OpVariant> *bwd_ops) {
PADDLE_ENFORCE_GE(fwd_ops->size(), bwd_ops->size());
if (fwd_ops->empty()) return;
const auto *program =
fwd_ops->front().Attr<framework::BlockDesc *>("sub_block")->Program();
for (size_t i = 1; i < program->Size(); ++i) {
auto &block = program->Block(i);
for (size_t i = 1; i < program.Size(); ++i) {
auto &block = program.Block(i);
for (size_t j = 0; j < block.OpSize(); ++j) {
auto *op = block.Op(j);
if (op->Type() == "conditional_block") {
......@@ -86,9 +82,10 @@ static void SetSkipVarsForConditionalBlockOp(OpVariant *fwd_op,
}
static void PrepareSafeEagerDeletionOnConditionalOpAndConditionalGradOpImpl(
std::vector<OpVariant> *ifelse_ops,
const framework::ProgramDesc &program, std::vector<OpVariant> *ifelse_ops,
std::vector<OpVariant> *ifelse_grad_ops) {
FindAllConditionalBlockAndConditionalBlockGradOp(ifelse_ops, ifelse_grad_ops);
FindAllConditionalBlockAndConditionalBlockGradOp(program, ifelse_ops,
ifelse_grad_ops);
VLOG(2) << "Found conditional_block op num: " << ifelse_ops->size()
<< ", conditional_block_grad op num: " << ifelse_grad_ops->size();
......@@ -121,7 +118,7 @@ static void PrepareSafeEagerDeletionOnConditionalOpAndConditionalGradOpImpl(
}
void PrepareSafeEagerDeletionOnConditionalOpAndConditionalGradOp(
int block_id,
const framework::ProgramDesc &program, int block_id,
const std::vector<std::unique_ptr<framework::OperatorBase>> &all_ops) {
// If block_id is not 0, returns
// This is because all conditional_block_ops and conditional_block_grad_ops
......@@ -143,11 +140,12 @@ void PrepareSafeEagerDeletionOnConditionalOpAndConditionalGradOp(
}
}
PrepareSafeEagerDeletionOnConditionalOpAndConditionalGradOpImpl(&fwd_ops,
&bwd_ops);
PrepareSafeEagerDeletionOnConditionalOpAndConditionalGradOpImpl(
program, &fwd_ops, &bwd_ops);
}
void PrepareSafeEagerDeletionOnConditionalOpAndConditionalGradOp(
const framework::ProgramDesc &program,
const std::vector<framework::OperatorBase *> &ifelse_ops,
const std::vector<framework::OperatorBase *> &ifelse_grad_ops) {
std::vector<OpVariant> fwd_ops, bwd_ops;
......@@ -161,8 +159,8 @@ void PrepareSafeEagerDeletionOnConditionalOpAndConditionalGradOp(
bwd_ops.emplace_back(op);
}
PrepareSafeEagerDeletionOnConditionalOpAndConditionalGradOpImpl(&fwd_ops,
&bwd_ops);
PrepareSafeEagerDeletionOnConditionalOpAndConditionalGradOpImpl(
program, &fwd_ops, &bwd_ops);
}
} // namespace operators
......
......@@ -23,10 +23,11 @@ namespace paddle {
namespace operators {
void PrepareSafeEagerDeletionOnConditionalOpAndConditionalGradOp(
int block_id,
const framework::ProgramDesc &program, int block_id,
const std::vector<std::unique_ptr<framework::OperatorBase>> &all_ops);
void PrepareSafeEagerDeletionOnConditionalOpAndConditionalGradOp(
const framework::ProgramDesc &program,
const std::vector<framework::OperatorBase *> &ifelse_ops,
const std::vector<framework::OperatorBase *> &ifelse_grad_ops);
......
......@@ -65,7 +65,8 @@ static void AddSkipVars(const OpVariant &op, const Container &skip_vars) {
// Find all ops and grad ops with given type name. The ops and grad ops
// may locate in different blocks so we should traverse all blocks in the
// program and find them out
static void FindAllOpAndGradOp(OpAndGradOpPair *op_and_grad_op,
static void FindAllOpAndGradOp(const framework::ProgramDesc &program,
OpAndGradOpPair *op_and_grad_op,
const std::string &type_name,
const std::string &backward_type_name) {
OpVariantSet &ops = op_and_grad_op->first;
......@@ -74,14 +75,8 @@ static void FindAllOpAndGradOp(OpAndGradOpPair *op_and_grad_op,
PADDLE_ENFORCE_GE(ops.size(), grad_ops.size(),
"There are extra grad ops in the graph or program");
if (ops.empty()) return;
const auto *program =
ops.begin()
->Attr<framework::BlockDesc *>(RecurrentBase::kStepBlock)
->Program();
for (size_t i = 1; i < program->Size(); ++i) {
auto &block = program->Block(i);
for (size_t i = 1; i < program.Size(); ++i) {
auto &block = program.Block(i);
for (size_t j = 0; j < block.OpSize(); ++j) {
auto *op = block.Op(j);
if (op->Type() == type_name) {
......@@ -201,7 +196,7 @@ static void SetRecurrentOpAndRecurrentGradOpSkipVarAttr(
}
void PrepareSafeEagerDeletionOnRecurrentOpAndRecurrentGradOp(
int block_id,
const framework::ProgramDesc &program, int block_id,
const std::vector<std::unique_ptr<paddle::framework::OperatorBase>>
&all_ops) {
// If block_id is not 0, returns
......@@ -224,13 +219,13 @@ void PrepareSafeEagerDeletionOnRecurrentOpAndRecurrentGradOp(
op_pair.second.emplace(op.get());
}
}
PrepareSafeEagerDeletionOnRecurrentOpAndRecurrentGradOp(&op_pair);
PrepareSafeEagerDeletionOnRecurrentOpAndRecurrentGradOp(program, &op_pair);
}
void PrepareSafeEagerDeletionOnRecurrentOpAndRecurrentGradOp(
OpAndGradOpPair *op_pair) {
const framework::ProgramDesc &program, OpAndGradOpPair *op_pair) {
// Find all ops and grad ops at all blocks
FindAllOpAndGradOp(op_pair, "recurrent", "recurrent_grad");
FindAllOpAndGradOp(program, op_pair, "recurrent", "recurrent_grad");
OpVariantSet &recurrent_ops = op_pair->first;
OpVariantSet &recurrent_grad_ops = op_pair->second;
......
......@@ -37,14 +37,14 @@ using OpAndGradOpPair = std::pair<OpVariantSet, OpVariantSet>;
// recurrent_grad ops at block 0 and the function will find all recurrent and
// recurrent_grad ops across blocks.
void PrepareSafeEagerDeletionOnRecurrentOpAndRecurrentGradOp(
OpAndGradOpPair *op_pair);
const framework::ProgramDesc &program, OpAndGradOpPair *op_pair);
// Set vars to skip eager deletion on input recurrent and recurrent_grad for
// preparing safe eager deletion. The input block_id must be 0 and caller can
// input all ops in the block. The function will find all recurrent and
// recurrent_grad ops across blocks.
void PrepareSafeEagerDeletionOnRecurrentOpAndRecurrentGradOp(
int block_id,
const framework::ProgramDesc &program, int block_id,
const std::vector<std::unique_ptr<paddle::framework::OperatorBase>>
&all_ops);
......
......@@ -100,16 +100,12 @@ static void ModifyWhileOpAndWhileGradOpAttr(const OpVariant &fwd_op,
// Find all while_ops and while_grad_ops in the graph or program
// The while_grad_op and while_op may located in different blocks
// So we should traverse all blocks in the program and find them out.
static void FindAllWhileAndWhileGradOp(std::vector<OpVariant> *while_ops,
static void FindAllWhileAndWhileGradOp(const framework::ProgramDesc &program,
std::vector<OpVariant> *while_ops,
std::vector<OpVariant> *while_grad_ops) {
PADDLE_ENFORCE_GE(while_ops->size(), while_grad_ops->size());
if (while_ops->empty()) return;
const auto *program =
while_ops->front().Attr<framework::BlockDesc *>(kStepBlock)->Program();
for (size_t i = 1; i < program->Size(); ++i) {
auto &block = program->Block(i);
for (size_t i = 1; i < program.Size(); ++i) {
auto &block = program.Block(i);
for (size_t j = 0; j < block.OpSize(); ++j) {
auto *op = block.Op(j);
if (op->Type() == "while") {
......@@ -125,8 +121,9 @@ static void FindAllWhileAndWhileGradOp(std::vector<OpVariant> *while_ops,
}
static void PrepareSafeEagerDeletionOnWhileOpAndWhileGradOpImpl(
std::vector<OpVariant> *while_ops, std::vector<OpVariant> *while_grad_ops) {
FindAllWhileAndWhileGradOp(while_ops, while_grad_ops);
const framework::ProgramDesc &program, std::vector<OpVariant> *while_ops,
std::vector<OpVariant> *while_grad_ops) {
FindAllWhileAndWhileGradOp(program, while_ops, while_grad_ops);
VLOG(2) << "Found while op num: " << while_ops->size()
<< ", while grad op num: " << while_grad_ops->size();
......@@ -155,7 +152,7 @@ static void PrepareSafeEagerDeletionOnWhileOpAndWhileGradOpImpl(
}
void PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp(
int block_id,
const framework::ProgramDesc &program, int block_id,
const std::vector<std::unique_ptr<framework::OperatorBase>> &all_ops) {
// If block_id is not 0, returns
// This is because all while_ops and while_grad_ops in the whole program
......@@ -176,10 +173,12 @@ void PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp(
bwd_ops.emplace_back(op.get());
}
}
PrepareSafeEagerDeletionOnWhileOpAndWhileGradOpImpl(&fwd_ops, &bwd_ops);
PrepareSafeEagerDeletionOnWhileOpAndWhileGradOpImpl(program, &fwd_ops,
&bwd_ops);
}
void PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp(
const framework::ProgramDesc &program,
const std::vector<framework::OperatorBase *> &while_ops,
const std::vector<framework::OperatorBase *> &while_grad_ops) {
std::vector<OpVariant> fwd_ops, bwd_ops;
......@@ -193,7 +192,8 @@ void PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp(
bwd_ops.emplace_back(op);
}
PrepareSafeEagerDeletionOnWhileOpAndWhileGradOpImpl(&fwd_ops, &bwd_ops);
PrepareSafeEagerDeletionOnWhileOpAndWhileGradOpImpl(program, &fwd_ops,
&bwd_ops);
}
} // namespace operators
......
......@@ -32,10 +32,11 @@ static constexpr char kOutputs[] = "Out";
static constexpr char kSkipEagerDeletionVars[] = "skip_eager_deletion_vars";
void PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp(
int block_id,
const framework::ProgramDesc &program, int block_id,
const std::vector<std::unique_ptr<framework::OperatorBase>> &all_ops);
void PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp(
const framework::ProgramDesc &program,
const std::vector<framework::OperatorBase *> &while_ops,
const std::vector<framework::OperatorBase *> &while_grad_ops);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册