From d2c815294d8fb1eb385296cea118ae0cecd1889f Mon Sep 17 00:00:00 2001 From: sandyhouse Date: Tue, 2 Mar 2021 17:45:23 +0800 Subject: [PATCH] update fb_scheduler --- paddle/fluid/framework/section_worker.cc | 135 ++++++++++++++++++++--- 1 file changed, 121 insertions(+), 14 deletions(-) diff --git a/paddle/fluid/framework/section_worker.cc b/paddle/fluid/framework/section_worker.cc index 13736c49e1e..09d882ff39b 100644 --- a/paddle/fluid/framework/section_worker.cc +++ b/paddle/fluid/framework/section_worker.cc @@ -48,7 +48,18 @@ void SectionWorker::TrainFiles() { #endif } - for (int i = 0; i < num_microbatches_; ++i) { + auto startup_steps = num_pipeline_stages_ - pipeline_stage_ - 1; + VLOG(3) << "startup_steps:" << startup_steps + << ", num_stages: " << num_pipeline_stages_ + << ", stage:" << pipeline_stage_; + if (startup_steps > num_microbatches_) { + startup_steps = num_microbatches_; + } + int fw_step = 0; + int bw_step = 0; + // startup phase + while (fw_step < startup_steps) { + VLOG(3) << "to run forward batch:" << fw_step; for (auto& op : ops_) { int op_role = op->Attr(std::string("op_role")); // We run op with op_role = kLRSched only for the first microbatch @@ -60,37 +71,129 @@ void SectionWorker::TrainFiles() { bool run_others = op_role == static_cast(OpRole::kForward) || op_role == (static_cast(OpRole::kForward) | static_cast(OpRole::kLoss)); - if ((i == 0 && run_first_mbatch) || (i != 0 && run_others)) { + if ((fw_step == 0 && run_first_mbatch) || (fw_step != 0 && run_others)) { VLOG(3) << "Forward: running op " << op->Type() << " for micro-batch " - << i; - op->Run(*microbatch_scopes_[i], place_); + << fw_step; + op->Run(*microbatch_scopes_[fw_step], place_); if (gc) { - DeleteUnusedTensors(*microbatch_scopes_[i], op.get(), unused_vars_, - gc.get()); + DeleteUnusedTensors(*microbatch_scopes_[fw_step], op.get(), + unused_vars_, gc.get()); } } } - cudaDeviceSynchronize(); + fw_step += 1; } - // backward pass - for (int i = 0; i < num_microbatches_; ++i) { + // 1f1b phase + while (fw_step < num_microbatches_) { + VLOG(3) << "to run forward batch:" << fw_step; + for (auto& op : ops_) { + int op_role = op->Attr(std::string("op_role")); + // We run op with op_role = kLRSched only for the first microbatch + // to avoid increasing the @LR_DECAY_STEP@ multiple times. + bool run_first_mbatch = op_role == static_cast(OpRole::kForward) || + op_role == (static_cast(OpRole::kForward) | + static_cast(OpRole::kLoss)) || + op_role == static_cast(OpRole::kLRSched); + bool run_others = op_role == static_cast(OpRole::kForward) || + op_role == (static_cast(OpRole::kForward) | + static_cast(OpRole::kLoss)); + if ((fw_step == 0 && run_first_mbatch) || (fw_step != 0 && run_others)) { + VLOG(3) << "Forward: running op " << op->Type() << " for micro-batch " + << fw_step; + op->Run(*microbatch_scopes_[fw_step], place_); + if (gc) { + DeleteUnusedTensors(*microbatch_scopes_[fw_step], op.get(), + unused_vars_, gc.get()); + } + } + } + fw_step += 1; + VLOG(3) << "to run backward batch:" << bw_step; + for (auto& op : ops_) { int op_role = op->Attr(std::string("op_role")); if (op_role == static_cast(OpRole::kBackward) || op_role == (static_cast(OpRole::kBackward) | static_cast(OpRole::kLoss))) { VLOG(3) << "Backward: running op " << op->Type() << " for micro-batch " - << i; - op->Run(*microbatch_scopes_[i], place_); + << bw_step; + op->Run(*microbatch_scopes_[bw_step], place_); if (gc) { - DeleteUnusedTensors(*microbatch_scopes_[i], op.get(), unused_vars_, - gc.get()); + DeleteUnusedTensors(*microbatch_scopes_[bw_step], op.get(), + unused_vars_, gc.get()); } } } - cudaDeviceSynchronize(); + bw_step += 1; } + // backward phase + while (bw_step < num_microbatches_) { + VLOG(3) << "to run backward batch:" << bw_step; + for (auto& op : ops_) { + int op_role = op->Attr(std::string("op_role")); + if (op_role == static_cast(OpRole::kBackward) || + op_role == (static_cast(OpRole::kBackward) | + static_cast(OpRole::kLoss))) { + VLOG(3) << "Backward: running op " << op->Type() << " for micro-batch " + << bw_step; + op->Run(*microbatch_scopes_[bw_step], place_); + if (gc) { + DeleteUnusedTensors(*microbatch_scopes_[bw_step], op.get(), + unused_vars_, gc.get()); + } + } + } + bw_step += 1; + } + + // for (int i = 0; i < num_microbatches_; ++i) { + // for (auto& op : ops_) { + // int op_role = op->Attr(std::string("op_role")); + // // We run op with op_role = kLRSched only for the first microbatch + // // to avoid increasing the @LR_DECAY_STEP@ multiple times. + // bool run_first_mbatch = op_role == static_cast(OpRole::kForward) + // || + // op_role == (static_cast(OpRole::kForward) + // | + // static_cast(OpRole::kLoss)) || + // op_role == static_cast(OpRole::kLRSched); + // bool run_others = op_role == static_cast(OpRole::kForward) || + // op_role == (static_cast(OpRole::kForward) | + // static_cast(OpRole::kLoss)); + // if ((i == 0 && run_first_mbatch) || (i != 0 && run_others)) { + // VLOG(3) << "Forward: running op " << op->Type() << " for micro-batch + // " + // << i; + // op->Run(*microbatch_scopes_[i], place_); + // if (gc) { + // DeleteUnusedTensors(*microbatch_scopes_[i], op.get(), unused_vars_, + // gc.get()); + // } + // } + // } + // cudaDeviceSynchronize(); + // } + + // // backward pass + // for (int i = 0; i < num_microbatches_; ++i) { + // for (auto& op : ops_) { + // int op_role = op->Attr(std::string("op_role")); + // if (op_role == static_cast(OpRole::kBackward) || + // op_role == (static_cast(OpRole::kBackward) | + // static_cast(OpRole::kLoss))) { + // VLOG(3) << "Backward: running op " << op->Type() << " for micro-batch + // " + // << i; + // op->Run(*microbatch_scopes_[i], place_); + // if (gc) { + // DeleteUnusedTensors(*microbatch_scopes_[i], op.get(), unused_vars_, + // gc.get()); + // } + // } + // } + // cudaDeviceSynchronize(); + // } // update pass for (auto& op : ops_) { @@ -99,6 +202,10 @@ void SectionWorker::TrainFiles() { VLOG(3) << "Update: running op " << op->Type(); op->Run(*microbatch_scopes_[num_microbatches_ - 1], place_); if (gc) { + // for (int i = 0; i < num_microbatches_; ++i) { + // DeleteUnusedTensors(*microbatch_scopes_[i], + // op.get(), unused_vars_, gc.get()); + //} DeleteUnusedTensors(*microbatch_scopes_[num_microbatches_ - 1], op.get(), unused_vars_, gc.get()); } -- GitLab