From 04fdb10a95e8808c04a66b4b6cce093a4b432ab5 Mon Sep 17 00:00:00 2001
From: Yuang Liu
Date: Tue, 14 Sep 2021 13:59:35 +0800
Subject: [PATCH] [hybrid performance] Optimize Pipeline Scheduler (#35680)

---
 paddle/fluid/framework/device_worker.h   |  4 ++
 paddle/fluid/framework/section_worker.cc | 89 +++++++++++++-----------
 2 files changed, 53 insertions(+), 40 deletions(-)

diff --git a/paddle/fluid/framework/device_worker.h b/paddle/fluid/framework/device_worker.h
index 6dd6fed0151..810e9a087d1 100644
--- a/paddle/fluid/framework/device_worker.h
+++ b/paddle/fluid/framework/device_worker.h
@@ -601,6 +601,10 @@ class SectionWorker : public DeviceWorker {
   std::vector<std::string> backward_send_vars_;
 
   std::vector<std::unique_ptr<OperatorBase>> ops_;
+  std::vector<OperatorBase *> forward_and_lr_ops_;
+  std::vector<OperatorBase *> forward_ops_;
+  std::vector<OperatorBase *> backward_ops_;
+  std::vector<OperatorBase *> optimizer_ops_;
   std::shared_ptr<framework::ProgramDesc> program_;
   std::unordered_map<const OperatorBase *, std::vector<std::string>>
       unused_vars_;
diff --git a/paddle/fluid/framework/section_worker.cc b/paddle/fluid/framework/section_worker.cc
index 5df01e151f8..64d8332e223 100644
--- a/paddle/fluid/framework/section_worker.cc
+++ b/paddle/fluid/framework/section_worker.cc
@@ -31,6 +31,33 @@ void SectionWorker::Initialize(const TrainerDesc &desc) {
     ops_.push_back(OpRegistry::CreateOp(*op_desc));
   }
 
+  for (auto &op : ops_) {
+    // cache the op type during the init part
+    // reduce unnecessary op visit during running
+    int op_role = op->Attr<int>("op_role");
+    if ((op_role == static_cast<int>(OpRole::kForward)) ||
+        (op_role == (static_cast<int>(OpRole::kForward) |
+                     static_cast<int>(OpRole::kLoss))) ||
+        (op_role == static_cast<int>(OpRole::kLRSched))) {
+      // forward ops and lr schedule ops, used for first micro step
+      forward_and_lr_ops_.push_back(op.get());
+      if ((op_role != static_cast<int>(OpRole::kLRSched))) {
+        // only forward ops, used for second and later micro steps
+        forward_ops_.push_back(op.get());
+      }
+    } else if ((op_role == static_cast<int>(OpRole::kBackward)) ||
+               (op_role == (static_cast<int>(OpRole::kBackward) |
+                            static_cast<int>(OpRole::kLoss)))) {
+      backward_ops_.push_back(op.get());
+    } else if (op_role == static_cast<int>(OpRole::kOptimize)) {
+      optimizer_ops_.push_back(op.get());
+    } else {
+      PADDLE_THROW(platform::errors::PreconditionNotMet(
+          "The op %s is None of LRSched, Forward, Backward or Optimize.",
+          op->Type()));
+    }
+  }
+
   // if not 1F1B scheduler
   if (schedule_mode_ != 1) return;
 
@@ -66,25 +93,15 @@ void SectionWorker::RunForward(
     int micro_id, std::unique_ptr<GarbageCollector> &gc,
     std::unordered_map<const OperatorBase *, std::vector<std::string>>
         &unused_vars_) {
-  for (auto &op : ops_) {
-    int op_role = op->Attr<int>(std::string("op_role"));
-    // We run op with op_role = kLRSched only for the first microbatch
-    // to avoid increasing the @LR_DECAY_STEP@ multiple times.
-    bool run_first_mbatch = (op_role == static_cast<int>(OpRole::kForward)) ||
-                            (op_role == (static_cast<int>(OpRole::kForward) |
-                                         static_cast<int>(OpRole::kLoss))) ||
-                            (op_role == static_cast<int>(OpRole::kLRSched));
-    bool run_others = (op_role == static_cast<int>(OpRole::kForward)) ||
-                      (op_role == (static_cast<int>(OpRole::kForward) |
-                                   static_cast<int>(OpRole::kLoss)));
-    if ((micro_id == 0 && run_first_mbatch) || (micro_id != 0 && run_others)) {
-      VLOG(3) << "Forward: running op " << op->Type() << " for micro-batch "
-              << micro_id;
-      op->Run(*microbatch_scopes_[micro_id], place_);
-      if (gc) {
-        DeleteUnusedTensors(*microbatch_scopes_[micro_id], op.get(),
-                            unused_vars_, gc.get());
-      }
+  std::vector<OperatorBase *> &forward_tmp =
+      micro_id == 0 ? forward_and_lr_ops_ : forward_ops_;
+  for (auto &op : forward_tmp) {
+    VLOG(3) << "Forward: running op " << op->Type() << " for micro-batch "
+            << micro_id;
+    op->Run(*microbatch_scopes_[micro_id], place_);
+    if (gc) {
+      DeleteUnusedTensors(*microbatch_scopes_[micro_id], op, unused_vars_,
+                          gc.get());
     }
   }
 }
@@ -93,18 +110,13 @@ void SectionWorker::RunBackward(
     int micro_id, std::unique_ptr<GarbageCollector> &gc,
     std::unordered_map<const OperatorBase *, std::vector<std::string>>
         &unused_vars_) {
-  for (auto &op : ops_) {
-    int op_role = op->Attr<int>(std::string("op_role"));
-    if ((op_role == static_cast<int>(OpRole::kBackward)) ||
-        (op_role == (static_cast<int>(OpRole::kBackward) |
-                     static_cast<int>(OpRole::kLoss)))) {
-      VLOG(3) << "Backward: running op " << op->Type() << " for micro-batch "
-              << micro_id;
-      op->Run(*microbatch_scopes_[micro_id], place_);
-      if (gc) {
-        DeleteUnusedTensors(*microbatch_scopes_[micro_id], op.get(),
-                            unused_vars_, gc.get());
-      }
+  for (auto &op : backward_ops_) {
+    VLOG(3) << "Backward: running op " << op->Type() << " for micro-batch "
+            << micro_id;
+    op->Run(*microbatch_scopes_[micro_id], place_);
+    if (gc) {
+      DeleteUnusedTensors(*microbatch_scopes_[micro_id], op, unused_vars_,
+                          gc.get());
     }
   }
 }
@@ -113,15 +125,12 @@ void SectionWorker::RunUpdate(
     std::unique_ptr<GarbageCollector> &gc,
     std::unordered_map<const OperatorBase *, std::vector<std::string>>
        &unused_vars_) {
-  for (auto &op : ops_) {
-    int op_role = op->Attr<int>(std::string("op_role"));
-    if (op_role == static_cast<int>(OpRole::kOptimize)) {
-      VLOG(3) << "Update: running op " << op->Type();
-      op->Run(*microbatch_scopes_[num_microbatches_ - 1], place_);
-      if (gc) {
-        DeleteUnusedTensors(*microbatch_scopes_[num_microbatches_ - 1],
-                            op.get(), unused_vars_, gc.get());
-      }
+  for (auto &op : optimizer_ops_) {
+    VLOG(3) << "Update: running op " << op->Type();
+    op->Run(*microbatch_scopes_[num_microbatches_ - 1], place_);
+    if (gc) {
+      DeleteUnusedTensors(*microbatch_scopes_[num_microbatches_ - 1], op,
+                          unused_vars_, gc.get());
    }
  }
}
-- 
GitLab
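
Note on the pattern used by this patch: Initialize() now buckets the ops into role-filtered lists once, so RunForward/RunBackward/RunUpdate no longer re-read the "op_role" attribute of every op on every micro-batch. Below is a minimal, self-contained C++ sketch of that caching idea; Op, Role, and Scheduler are hypothetical names for illustration only, not Paddle's API.

// Sketch (illustrative names, not Paddle's API): bucket ops by role once at
// construction, then run the pre-filtered lists each micro-step.
#include <memory>
#include <string>
#include <utility>
#include <vector>

enum class Role { kForward, kBackward, kOptimize };

struct Op {
  Role role;
  std::string name;
  void Run() const { /* run the op on the current micro-batch scope */ }
};

class Scheduler {
 public:
  explicit Scheduler(std::vector<std::unique_ptr<Op>> ops)
      : ops_(std::move(ops)) {
    // Classify once, up front, instead of re-checking the role per step.
    for (auto &op : ops_) {
      switch (op->role) {
        case Role::kForward:
          forward_ops_.push_back(op.get());
          break;
        case Role::kBackward:
          backward_ops_.push_back(op.get());
          break;
        case Role::kOptimize:
          optimizer_ops_.push_back(op.get());
          break;
      }
    }
  }

  void RunForward() { for (Op *op : forward_ops_) op->Run(); }
  void RunBackward() { for (Op *op : backward_ops_) op->Run(); }
  void RunUpdate() { for (Op *op : optimizer_ops_) op->Run(); }

 private:
  std::vector<std::unique_ptr<Op>> ops_;  // owning storage
  std::vector<Op *> forward_ops_;         // non-owning, role-filtered views
  std::vector<Op *> backward_ops_;
  std::vector<Op *> optimizer_ops_;
};

Because the owning vector keeps the operators alive for the worker's lifetime, the role-filtered lists can safely hold raw pointers, mirroring how the patch stores op.get() in forward_ops_, backward_ops_, and optimizer_ops_.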