Unverified commit 3fd3b663, authored by Zeng Jinle, committed by GitHub

fix gc bug in controlflow ops, test=develop (#19827)

Parent 982e61f5
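Every hunk in this commit applies the same change: the eager-deletion helpers for the control-flow ops (conditional_block, while, recurrent) now receive the ProgramDesc as an explicit parameter instead of recovering it from the first forward op's sub-block attribute, and the early return on an empty forward-op list is gone. A condensed before/after sketch of that lookup pattern, simplified from the hunks below rather than quoted verbatim:

    // Before: derive the program from the first op's block attribute; if the
    // forward-op list happens to be empty, nothing is prepared at all.
    if (fwd_ops->empty()) return;
    const auto *program =
        fwd_ops->front().Attr<framework::BlockDesc *>("sub_block")->Program();
    for (size_t i = 1; i < program->Size(); ++i) {
      auto &block = program->Block(i);  // collect control-flow ops from block
    }

    // After: the caller passes the ProgramDesc it is actually executing
    // (prog_ in the executor, graph->OriginProgram() in the IR passes), and
    // every sub-block is always traversed.
    for (size_t i = 1; i < program.Size(); ++i) {
      auto &block = program.Block(i);   // collect control-flow ops from block
    }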
@@ -78,10 +78,11 @@ void ExecutorPrepareContext::PrepareUnusedVars(
   // If gc is enabled and block size > 1
   if (prog_.Size() > 1) {
     operators::PrepareSafeEagerDeletionOnConditionalOpAndConditionalGradOp(
-        block_id_, ops_);
-    operators::PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp(block_id_, ops_);
+        prog_, block_id_, ops_);
+    operators::PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp(prog_, block_id_,
+                                                               ops_);
     operators::PrepareSafeEagerDeletionOnRecurrentOpAndRecurrentGradOp(
-        block_id_, ops_);
+        prog_, block_id_, ops_);
   }
   unused_vars_ = GetUnusedVars(prog_.Block(block_id_), ops_, keep_vars);
 }
......
@@ -48,7 +48,7 @@ class ConditionalOpEagerDeletionPass : public Pass {
       auto &ifelse_ops = ops_pair.second.first;
       auto &ifelse_grad_ops = ops_pair.second.second;
       operators::PrepareSafeEagerDeletionOnConditionalOpAndConditionalGradOp(
-          ifelse_ops, ifelse_grad_ops);
+          graph->OriginProgram(), ifelse_ops, ifelse_grad_ops);
     }
   }
 };
......
@@ -40,7 +40,8 @@ void RecurrentOpEagerDeletionPass::ApplyImpl(Graph *graph) const {
     // Prepare safe eager deletion on different devices because the garbage
     // collection may be different across devices
     OpAndGradOpPair &op_pair = entry.second;
-    PrepareSafeEagerDeletionOnRecurrentOpAndRecurrentGradOp(&op_pair);
+    PrepareSafeEagerDeletionOnRecurrentOpAndRecurrentGradOp(
+        graph->OriginProgram(), &op_pair);
   }
 }
......
@@ -47,7 +47,7 @@ class WhileOpEagerDeletionPass : public ir::Pass {
       auto &while_ops = ops_pair.second.first;
       auto &while_grad_ops = ops_pair.second.second;
       operators::PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp(
-          while_ops, while_grad_ops);
+          graph->OriginProgram(), while_ops, while_grad_ops);
     }
   }
 };
......
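The three eager-deletion passes above all source the program from the graph they are rewriting via graph->OriginProgram() and forward it to the operator helpers. A hedged sketch of that call pattern; the class name, include paths, and the op-collection step are placeholders, and only the final helper call mirrors the hunks above:

    #include <vector>

    #include "paddle/fluid/framework/ir/graph.h"                      // assumed paths
    #include "paddle/fluid/framework/ir/pass.h"
    #include "paddle/fluid/operators/controlflow/while_op_helper.h"

    // Hypothetical pass skeleton showing how the program reaches the helper.
    class WhileOpEagerDeletionPassSketch : public paddle::framework::ir::Pass {
     protected:
      void ApplyImpl(paddle::framework::ir::Graph *graph) const override {
        // Collect raw while / while_grad op pointers from the graph
        // (collection logic elided in this sketch).
        std::vector<paddle::framework::OperatorBase *> while_ops, while_grad_ops;

        // Hand the helper the program the graph was built from, instead of
        // letting it recover the program from an op attribute.
        paddle::operators::PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp(
            graph->OriginProgram(), while_ops, while_grad_ops);
      }
    };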
@@ -29,16 +29,12 @@ static bool IsMatchedConditionalBlockOpAndConditionalBlockGradOp(
 }
 
 static void FindAllConditionalBlockAndConditionalBlockGradOp(
-    std::vector<OpVariant> *fwd_ops, std::vector<OpVariant> *bwd_ops) {
+    const framework::ProgramDesc &program, std::vector<OpVariant> *fwd_ops,
+    std::vector<OpVariant> *bwd_ops) {
   PADDLE_ENFORCE_GE(fwd_ops->size(), bwd_ops->size());
 
-  if (fwd_ops->empty()) return;
-
-  const auto *program =
-      fwd_ops->front().Attr<framework::BlockDesc *>("sub_block")->Program();
-
-  for (size_t i = 1; i < program->Size(); ++i) {
-    auto &block = program->Block(i);
+  for (size_t i = 1; i < program.Size(); ++i) {
+    auto &block = program.Block(i);
     for (size_t j = 0; j < block.OpSize(); ++j) {
       auto *op = block.Op(j);
       if (op->Type() == "conditional_block") {
@@ -86,9 +82,10 @@ static void SetSkipVarsForConditionalBlockOp(OpVariant *fwd_op,
 }
 
 static void PrepareSafeEagerDeletionOnConditionalOpAndConditionalGradOpImpl(
-    std::vector<OpVariant> *ifelse_ops,
+    const framework::ProgramDesc &program, std::vector<OpVariant> *ifelse_ops,
     std::vector<OpVariant> *ifelse_grad_ops) {
-  FindAllConditionalBlockAndConditionalBlockGradOp(ifelse_ops, ifelse_grad_ops);
+  FindAllConditionalBlockAndConditionalBlockGradOp(program, ifelse_ops,
+                                                   ifelse_grad_ops);
 
   VLOG(2) << "Found conditional_block op num: " << ifelse_ops->size()
           << ", conditional_block_grad op num: " << ifelse_grad_ops->size();
@@ -121,7 +118,7 @@ static void PrepareSafeEagerDeletionOnConditionalOpAndConditionalGradOpImpl(
 }
 
 void PrepareSafeEagerDeletionOnConditionalOpAndConditionalGradOp(
-    int block_id,
+    const framework::ProgramDesc &program, int block_id,
     const std::vector<std::unique_ptr<framework::OperatorBase>> &all_ops) {
   // If block_id is not 0, returns
   // This is because all conditional_block_ops and conditional_block_grad_ops
@@ -143,11 +140,12 @@ void PrepareSafeEagerDeletionOnConditionalOpAndConditionalGradOp(
     }
   }
 
-  PrepareSafeEagerDeletionOnConditionalOpAndConditionalGradOpImpl(&fwd_ops,
-                                                                  &bwd_ops);
+  PrepareSafeEagerDeletionOnConditionalOpAndConditionalGradOpImpl(
+      program, &fwd_ops, &bwd_ops);
 }
 
 void PrepareSafeEagerDeletionOnConditionalOpAndConditionalGradOp(
+    const framework::ProgramDesc &program,
     const std::vector<framework::OperatorBase *> &ifelse_ops,
     const std::vector<framework::OperatorBase *> &ifelse_grad_ops) {
   std::vector<OpVariant> fwd_ops, bwd_ops;
@@ -161,8 +159,8 @@ void PrepareSafeEagerDeletionOnConditionalOpAndConditionalGradOp(
     bwd_ops.emplace_back(op);
   }
 
-  PrepareSafeEagerDeletionOnConditionalOpAndConditionalGradOpImpl(&fwd_ops,
-                                                                  &bwd_ops);
+  PrepareSafeEagerDeletionOnConditionalOpAndConditionalGradOpImpl(
+      program, &fwd_ops, &bwd_ops);
 }
 
 }  // namespace operators
......
@@ -23,10 +23,11 @@ namespace paddle {
 namespace operators {
 
 void PrepareSafeEagerDeletionOnConditionalOpAndConditionalGradOp(
-    int block_id,
+    const framework::ProgramDesc &program, int block_id,
    const std::vector<std::unique_ptr<framework::OperatorBase>> &all_ops);
 
 void PrepareSafeEagerDeletionOnConditionalOpAndConditionalGradOp(
+    const framework::ProgramDesc &program,
     const std::vector<framework::OperatorBase *> &ifelse_ops,
     const std::vector<framework::OperatorBase *> &ifelse_grad_ops);
......
@@ -65,7 +65,8 @@ static void AddSkipVars(const OpVariant &op, const Container &skip_vars) {
 // Find all ops and grad ops with given type name. The ops and grad ops
 // may locate in different blocks so we should traverse all blocks in the
 // program and find them out
-static void FindAllOpAndGradOp(OpAndGradOpPair *op_and_grad_op,
+static void FindAllOpAndGradOp(const framework::ProgramDesc &program,
+                               OpAndGradOpPair *op_and_grad_op,
                                const std::string &type_name,
                                const std::string &backward_type_name) {
   OpVariantSet &ops = op_and_grad_op->first;
@@ -74,14 +75,8 @@ static void FindAllOpAndGradOp(OpAndGradOpPair *op_and_grad_op,
   PADDLE_ENFORCE_GE(ops.size(), grad_ops.size(),
                     "There are extra grad ops in the graph or program");
 
-  if (ops.empty()) return;
-
-  const auto *program =
-      ops.begin()
-          ->Attr<framework::BlockDesc *>(RecurrentBase::kStepBlock)
-          ->Program();
-  for (size_t i = 1; i < program->Size(); ++i) {
-    auto &block = program->Block(i);
+  for (size_t i = 1; i < program.Size(); ++i) {
+    auto &block = program.Block(i);
     for (size_t j = 0; j < block.OpSize(); ++j) {
       auto *op = block.Op(j);
       if (op->Type() == type_name) {
@@ -201,7 +196,7 @@ static void SetRecurrentOpAndRecurrentGradOpSkipVarAttr(
 }
 
 void PrepareSafeEagerDeletionOnRecurrentOpAndRecurrentGradOp(
-    int block_id,
+    const framework::ProgramDesc &program, int block_id,
     const std::vector<std::unique_ptr<paddle::framework::OperatorBase>>
         &all_ops) {
   // If block_id is not 0, returns
@@ -224,13 +219,13 @@ void PrepareSafeEagerDeletionOnRecurrentOpAndRecurrentGradOp(
       op_pair.second.emplace(op.get());
     }
   }
-  PrepareSafeEagerDeletionOnRecurrentOpAndRecurrentGradOp(&op_pair);
+  PrepareSafeEagerDeletionOnRecurrentOpAndRecurrentGradOp(program, &op_pair);
 }
 
 void PrepareSafeEagerDeletionOnRecurrentOpAndRecurrentGradOp(
-    OpAndGradOpPair *op_pair) {
+    const framework::ProgramDesc &program, OpAndGradOpPair *op_pair) {
   // Find all ops and grad ops at all blocks
-  FindAllOpAndGradOp(op_pair, "recurrent", "recurrent_grad");
+  FindAllOpAndGradOp(program, op_pair, "recurrent", "recurrent_grad");
 
   OpVariantSet &recurrent_ops = op_pair->first;
   OpVariantSet &recurrent_grad_ops = op_pair->second;
......
@@ -37,14 +37,14 @@ using OpAndGradOpPair = std::pair<OpVariantSet, OpVariantSet>;
 // recurrent_grad ops at block 0 and the function will find all recurrent and
 // recurrent_grad ops across blocks.
 void PrepareSafeEagerDeletionOnRecurrentOpAndRecurrentGradOp(
-    OpAndGradOpPair *op_pair);
+    const framework::ProgramDesc &program, OpAndGradOpPair *op_pair);
 
 // Set vars to skip eager deletion on input recurrent and recurrent_grad for
 // preparing safe eager deletion. The input block_id must be 0 and caller can
 // input all ops in the block. The function will find all recurrent and
 // recurrent_grad ops across blocks.
 void PrepareSafeEagerDeletionOnRecurrentOpAndRecurrentGradOp(
-    int block_id,
+    const framework::ProgramDesc &program, int block_id,
     const std::vector<std::unique_ptr<paddle::framework::OperatorBase>>
         &all_ops);
......
@@ -100,16 +100,12 @@ static void ModifyWhileOpAndWhileGradOpAttr(const OpVariant &fwd_op,
 // Find all while_ops and while_grad_ops in the graph or program
 // The while_grad_op and while_op may located in different blocks
 // So we should traverse all blocks in the program and find them out.
-static void FindAllWhileAndWhileGradOp(std::vector<OpVariant> *while_ops,
+static void FindAllWhileAndWhileGradOp(const framework::ProgramDesc &program,
+                                       std::vector<OpVariant> *while_ops,
                                        std::vector<OpVariant> *while_grad_ops) {
   PADDLE_ENFORCE_GE(while_ops->size(), while_grad_ops->size());
-
-  if (while_ops->empty()) return;
-
-  const auto *program =
-      while_ops->front().Attr<framework::BlockDesc *>(kStepBlock)->Program();
-  for (size_t i = 1; i < program->Size(); ++i) {
-    auto &block = program->Block(i);
+  for (size_t i = 1; i < program.Size(); ++i) {
+    auto &block = program.Block(i);
     for (size_t j = 0; j < block.OpSize(); ++j) {
       auto *op = block.Op(j);
       if (op->Type() == "while") {
@@ -125,8 +121,9 @@ static void FindAllWhileAndWhileGradOp(std::vector<OpVariant> *while_ops,
 }
 
 static void PrepareSafeEagerDeletionOnWhileOpAndWhileGradOpImpl(
-    std::vector<OpVariant> *while_ops, std::vector<OpVariant> *while_grad_ops) {
-  FindAllWhileAndWhileGradOp(while_ops, while_grad_ops);
+    const framework::ProgramDesc &program, std::vector<OpVariant> *while_ops,
+    std::vector<OpVariant> *while_grad_ops) {
+  FindAllWhileAndWhileGradOp(program, while_ops, while_grad_ops);
 
   VLOG(2) << "Found while op num: " << while_ops->size()
           << ", while grad op num: " << while_grad_ops->size();
@@ -155,7 +152,7 @@ static void PrepareSafeEagerDeletionOnWhileOpAndWhileGradOpImpl(
 }
 
 void PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp(
-    int block_id,
+    const framework::ProgramDesc &program, int block_id,
     const std::vector<std::unique_ptr<framework::OperatorBase>> &all_ops) {
   // If block_id is not 0, returns
   // This is because all while_ops and while_grad_ops in the whole program
@@ -176,10 +173,12 @@ void PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp(
       bwd_ops.emplace_back(op.get());
     }
   }
-  PrepareSafeEagerDeletionOnWhileOpAndWhileGradOpImpl(&fwd_ops, &bwd_ops);
+  PrepareSafeEagerDeletionOnWhileOpAndWhileGradOpImpl(program, &fwd_ops,
+                                                      &bwd_ops);
 }
 
 void PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp(
+    const framework::ProgramDesc &program,
     const std::vector<framework::OperatorBase *> &while_ops,
     const std::vector<framework::OperatorBase *> &while_grad_ops) {
   std::vector<OpVariant> fwd_ops, bwd_ops;
@@ -193,7 +192,8 @@ void PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp(
     bwd_ops.emplace_back(op);
   }
 
-  PrepareSafeEagerDeletionOnWhileOpAndWhileGradOpImpl(&fwd_ops, &bwd_ops);
+  PrepareSafeEagerDeletionOnWhileOpAndWhileGradOpImpl(program, &fwd_ops,
+                                                      &bwd_ops);
 }
 
 }  // namespace operators
......
@@ -32,10 +32,11 @@ static constexpr char kOutputs[] = "Out";
 static constexpr char kSkipEagerDeletionVars[] = "skip_eager_deletion_vars";
 
 void PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp(
-    int block_id,
+    const framework::ProgramDesc &program, int block_id,
     const std::vector<std::unique_ptr<framework::OperatorBase>> &all_ops);
 
 void PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp(
+    const framework::ProgramDesc &program,
     const std::vector<framework::OperatorBase *> &while_ops,
     const std::vector<framework::OperatorBase *> &while_grad_ops);
......
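The headers above now expose two entry points per control-flow op: one that takes the program plus all ops of block 0, and one that takes pre-collected op pointers. A hedged usage sketch of the first form, which is how the executor-side hunk at the top of this commit uses it; the wrapper function and include paths are assumptions, only the helper signature comes from the header above:

    #include <memory>
    #include <vector>

    #include "paddle/fluid/framework/program_desc.h"                  // assumed paths
    #include "paddle/fluid/operators/controlflow/while_op_helper.h"

    namespace paddle {

    // Hypothetical wrapper: prepare safe eager deletion for while/while_grad
    // ops found in block 0 of `program`.
    void PrepareWhileGC(
        const framework::ProgramDesc &program,
        const std::vector<std::unique_ptr<framework::OperatorBase>> &block0_ops) {
      // block_id must be 0; the helper walks every sub-block of `program`
      // itself, so it no longer depends on the op list being non-empty.
      operators::PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp(
          program, /*block_id=*/0, block0_ops);
    }

    }  // namespace paddle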