diff --git a/paddle/fluid/framework/device_worker.h b/paddle/fluid/framework/device_worker.h
index c44bda490bb6f05ae77001de4748bb2b73a88df8..45efa43ccb74bc4dd48c0f78e43d479eb1c7d341 100644
--- a/paddle/fluid/framework/device_worker.h
+++ b/paddle/fluid/framework/device_worker.h
@@ -548,6 +548,7 @@ class SectionWorker : public DeviceWorker {
   ~SectionWorker() override {}
 
   void Initialize(const TrainerDesc& desc) override;
+  void PrepareUnusedVar();
 
   void BindingDataFeedMemory() override {}
   void CreateDeviceResource(const ProgramDesc& main_prog) override{};
@@ -581,7 +582,8 @@ class SectionWorker : public DeviceWorker {
   void RunUpdate(
       std::unique_ptr<GarbageCollector>&,
       std::unordered_map<const OperatorBase*, std::vector<std::string>>&);
-  void PrepareUnusedVar();
+  void RunFThenB(std::unique_ptr<GarbageCollector>&);
+  void Run1F1B(std::unique_ptr<GarbageCollector>&);
 
  protected:
   int section_id_;
@@ -591,9 +593,12 @@ class SectionWorker : public DeviceWorker {
   int pipeline_stage_;
   int schedule_mode_;  // 0 for F-then-B and 1 for 1F1B
   std::vector<Scope*> microbatch_scopes_;
-  std::vector<std::string> skip_vars_;
   const Scope* minibatch_scope_;
 
+  // skip&backward vars are only used in 1F1B
+  std::vector<std::string> skip_vars_;
+  std::vector<std::string> backward_send_vars_;
+
   std::vector<std::unique_ptr<OperatorBase>> ops_;
   std::shared_ptr<framework::ProgramDesc> program_;
   std::unordered_map<const OperatorBase*, std::vector<std::string>>
diff --git a/paddle/fluid/framework/executor_gc_helper.cc b/paddle/fluid/framework/executor_gc_helper.cc
index 4b7c8c6e3f49bca036a0bf1f367071b273381f01..1fe6b70d26ec5b5611b977e477ec53ecb7a563e5 100644
--- a/paddle/fluid/framework/executor_gc_helper.cc
+++ b/paddle/fluid/framework/executor_gc_helper.cc
@@ -146,18 +146,9 @@ GetUnusedVars(const BlockDesc &block,
   return result;
 }
 
-void DeleteUnusedTensors(
-    const Scope &scope, const OperatorBase *op,
-    const std::unordered_map<const OperatorBase *, std::vector<std::string>>
-        &delete_vars_map,
-    GarbageCollector *gc) {
-  auto iter = delete_vars_map.find(op);
-  if (iter == delete_vars_map.end()) {
-    return;
-  }
-
-  auto &delete_vars = iter->second;
-
+void DeleteUnusedTensors(const Scope &scope,
+                         const std::vector<std::string> &delete_vars,
+                         GarbageCollector *gc) {
   std::deque<std::shared_ptr<memory::Allocation>> garbages;
 
   for (auto &var_name : delete_vars) {
@@ -189,6 +180,20 @@ void DeleteUnusedTensors(
   }
 }
 
+void DeleteUnusedTensors(
+    const Scope &scope, const OperatorBase *op,
+    const std::unordered_map<const OperatorBase *, std::vector<std::string>>
+        &delete_vars_map,
+    GarbageCollector *gc) {
+  auto iter = delete_vars_map.find(op);
+  if (iter == delete_vars_map.end()) {
+    return;
+  }
+
+  auto &delete_vars = iter->second;
+  DeleteUnusedTensors(scope, delete_vars, gc);
+}
+
 static std::vector<std::unique_ptr<OperatorBase>> CreateOpsFromBlock(
     const BlockDesc &block) {
   std::vector<std::unique_ptr<OperatorBase>> ops;
diff --git a/paddle/fluid/framework/executor_gc_helper.h b/paddle/fluid/framework/executor_gc_helper.h
index 886341791bade8697773bac69722f6827d5e33d8..184516d4d6160b5ca74f2c9ff7ffa491d2603220 100644
--- a/paddle/fluid/framework/executor_gc_helper.h
+++ b/paddle/fluid/framework/executor_gc_helper.h
@@ -36,6 +36,11 @@ GetUnusedVars(const BlockDesc &block,
               const std::vector<std::unique_ptr<OperatorBase>> &ops,
               const std::vector<std::string> &skip_vars);
 
+// Collect unused tensors for a given list of vars
+void DeleteUnusedTensors(const Scope &scope,
+                         const std::vector<std::string> &delete_vars,
+                         GarbageCollector *gc);
+
 // Collect unused tensors after op runs
 void DeleteUnusedTensors(
     const Scope &scope, const OperatorBase *op,
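The executor_gc_helper change splits DeleteUnusedTensors into a list-based core plus the original op-keyed entry point, which now only looks up the op's dead-variable list and delegates. The sketch below is a minimal standalone analogue of that pattern, not the real Paddle API: Scope and GarbageCollector are simplified stand-ins and ops are keyed by name rather than by OperatorBase*. It shows why the split is useful: a caller that already knows which variables are dead (for 1F1B, the backward send vars) can free them directly through the list overload.

```cpp
#include <iostream>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

// Stand-ins for framework::Scope and GarbageCollector; "collecting" a
// variable here just erases it from the scope and logs the name.
struct Scope {
  std::unordered_set<std::string> vars;
};

struct GarbageCollector {
  void Collect(Scope *scope, const std::string &name) {
    scope->vars.erase(name);
    std::cout << "freed " << name << "\n";
  }
};

// List-based core: free an explicit list of variables.
void DeleteUnusedTensors(Scope *scope, const std::vector<std::string> &vars,
                         GarbageCollector *gc) {
  for (const auto &name : vars) {
    if (scope->vars.count(name)) gc->Collect(scope, name);
  }
}

// Op-keyed overload: look up which vars became dead after `op`, then delegate
// to the core (mirrors the wrapper kept in executor_gc_helper.cc).
void DeleteUnusedTensors(
    Scope *scope, const std::string &op,
    const std::unordered_map<std::string, std::vector<std::string>> &dead_after,
    GarbageCollector *gc) {
  auto iter = dead_after.find(op);
  if (iter == dead_after.end()) return;
  DeleteUnusedTensors(scope, iter->second, gc);
}

int main() {
  Scope scope{{"x", "y", "x@GRAD"}};
  GarbageCollector gc;
  DeleteUnusedTensors(&scope, {"x@GRAD"}, &gc);  // caller already knows the list
  DeleteUnusedTensors(&scope, "relu_grad", {{"relu_grad", {"x"}}}, &gc);  // per-op lookup
}
```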
diff --git a/paddle/fluid/framework/pipeline_trainer.cc b/paddle/fluid/framework/pipeline_trainer.cc
index 42577972e9b79d2dcfdf692afdec19b3ab576c90..695525c876a3dbe956bf2f67c9ec9deb08ae1383 100644
--- a/paddle/fluid/framework/pipeline_trainer.cc
+++ b/paddle/fluid/framework/pipeline_trainer.cc
@@ -45,11 +45,11 @@ void PipelineTrainer::Initialize(const TrainerDesc& trainer_desc,
   auto this_worker =
       std::dynamic_pointer_cast<paddle::framework::SectionWorker>(worker_);
   this_worker->SetPlace(place_);
-  this_worker->Initialize(trainer_desc);
   this_worker->SetMicrobatchNum(num_microbatches_);
   this_worker->SetPipelineStageNum(num_pipeline_stages_);
   this_worker->SetPipelineStage(pipeline_stage_);
   this_worker->SetScheduleMode(schedule_mode_);
+  this_worker->Initialize(trainer_desc);
 }
 
 void PipelineTrainer::InitOtherEnv(const ProgramDesc& main_program) {
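The pipeline_trainer.cc hunk moves Initialize() after the Set* calls. The order now matters because Initialize() reads schedule_mode_ and pipeline_stage_ to decide whether to collect backward send vars (see the section_worker.cc hunk below). A minimal sketch of that dependency, using a hypothetical Worker class rather than the real SectionWorker:

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical worker: Initialize() branches on state that the Set* calls
// fill in, so the call order in pipeline_trainer.cc matters.
class Worker {
 public:
  void SetScheduleMode(int mode) { schedule_mode_ = mode; }
  void SetPipelineStage(int stage) { pipeline_stage_ = stage; }

  // Loosely mirrors the new SectionWorker::Initialize logic: only a
  // non-first stage running 1F1B collects backward send vars.
  void Initialize() {
    if (schedule_mode_ != 1 || pipeline_stage_ == 0) return;
    backward_send_vars_.push_back("grad_to_prev_stage");
  }

  const std::vector<std::string> &backward_send_vars() const {
    return backward_send_vars_;
  }

 private:
  int schedule_mode_ = 0;   // defaults stand for an un-configured worker
  int pipeline_stage_ = 0;
  std::vector<std::string> backward_send_vars_;
};

int main() {
  Worker wrong;
  wrong.Initialize();            // old order: runs before the setters
  wrong.SetScheduleMode(1);
  wrong.SetPipelineStage(1);
  assert(wrong.backward_send_vars().empty());   // nothing was collected

  Worker right;
  right.SetScheduleMode(1);      // new order: configure first ...
  right.SetPipelineStage(1);
  right.Initialize();            // ... then Initialize sees 1F1B on stage 1
  assert(!right.backward_send_vars().empty());
}
```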
diff --git a/paddle/fluid/framework/section_worker.cc b/paddle/fluid/framework/section_worker.cc
index a7e84b34b2436bf60d1af19f4f128597250d5033..f68ee153e0025a2fed1fe3055ebb1f1acd4d3935 100644
--- a/paddle/fluid/framework/section_worker.cc
+++ b/paddle/fluid/framework/section_worker.cc
@@ -30,6 +30,36 @@ void SectionWorker::Initialize(const TrainerDesc &desc) {
   for (auto &op_desc : program_->Block(0).AllOps()) {
     ops_.push_back(OpRegistry::CreateOp(*op_desc));
   }
+
+  // if not 1F1B scheduler
+  if (schedule_mode_ != 1) return;
+
+  bool is_first_stage = (pipeline_stage_ == 0);
+  int BACKWARD = static_cast<int>(OpRole::kBackward);
+  for (auto &op : ops_) {
+    int op_role = op->Attr<int>("op_role");
+    auto op_type = op->Type();
+
+    // pipeline backward send op
+    if (op_role != BACKWARD) continue;
+    if (op_type != "send_v2" && op_type != "partial_send") continue;
+
+    auto var_name = op->InputVars()[0];
+    VLOG(3) << "Pipeline backward send var " << var_name;
+    PADDLE_ENFORCE_NE(is_first_stage, true,
+                      platform::errors::PreconditionNotMet(
+                          "The first pipeline stage must not have a "
+                          "backward send var, please check var %s",
+                          var_name));
+
+    backward_send_vars_.push_back(var_name);
+    skip_vars_.push_back(var_name);
+  }
+}
+
+void SectionWorker::PrepareUnusedVar() {
+  VLOG(5) << "begin prepare the unused vars";
+  unused_vars_ = GetUnusedVars(program_->Block(0), ops_, skip_vars_);
 }
 
 void SectionWorker::RunForward(
@@ -96,9 +126,79 @@ void SectionWorker::RunUpdate(
   }
 }
 
-void SectionWorker::PrepareUnusedVar() {
-  VLOG(5) << "begin prepare the unsed vars";
-  unused_vars_ = GetUnusedVars(program_->Block(0), ops_, skip_vars_);
+void SectionWorker::RunFThenB(std::unique_ptr<GarbageCollector> &gc) {
+  // F-then-B scheduler which runs Forward phase for all microbatches,
+  // then runs Backward phase for all microbatches.
+  // step1: run forward
+  for (int i = 0; i < num_microbatches_; ++i) {
+    RunForward(i, gc, unused_vars_);
+  }
+  // step2: run backward
+  for (int i = 0; i < num_microbatches_; ++i) {
+    RunBackward(i, gc, unused_vars_);
+  }
+  // step3: run update
+  RunUpdate(gc, unused_vars_);
+}
+
+void SectionWorker::Run1F1B(std::unique_ptr<GarbageCollector> &gc) {
+  // 1F1B scheduler, which runs forward phase and backward phase alternately
+  // after startup phase. For a stage, the number of microbatches for
+  // startup is num_pipeline_stages_ - pipeline_stage_ - 1, where
+  // num_pipeline_stages_ is the total number of pipeline stages and
+  // pipeline_stage_ is the pipeline stage of the current device.
+  auto startup_steps = num_pipeline_stages_ - pipeline_stage_ - 1;
+  VLOG(3) << "startup_steps:" << startup_steps
+          << ", num_stages: " << num_pipeline_stages_
+          << ", stage:" << pipeline_stage_;
+  PADDLE_ENFORCE_GT(
+      num_microbatches_, startup_steps,
+      platform::errors::InvalidArgument(
+          "To use pipeline with 1F1B scheduler, please make sure number of "
+          "microbatches (%d) is larger than startup steps (%d).",
+          num_microbatches_, startup_steps));
+  int fw_step = 0;
+  int bw_step = 0;
+
+  // startup phase
+  while (fw_step < startup_steps) {
+    RunForward(fw_step, gc, unused_vars_);
+    fw_step += 1;
+  }
+
+  // 1f1b phase
+  while (fw_step < num_microbatches_) {
+    RunForward(fw_step, gc, unused_vars_);
+
+    // delete backward send var at step=(bw_step - 2)
+    if (gc && bw_step >= 2) {
+      DeleteUnusedTensors(*microbatch_scopes_[bw_step - 2], backward_send_vars_,
+                          gc.get());
+    }
+
+    RunBackward(bw_step, gc, unused_vars_);
+
+    fw_step += 1;
+    bw_step += 1;
+  }
+
+  int reserve_bw_send_step = bw_step - 2;
+  // backward phase
+  while (bw_step < num_microbatches_) {
+    RunBackward(bw_step, gc, unused_vars_);
+    bw_step += 1;
+  }
+
+  RunUpdate(gc, unused_vars_);
+
+  if (gc) {
+    // NOTE(wangxi): program must add sync backward send comm at update
+    // delete backward send var
+    for (int i = reserve_bw_send_step; i < num_microbatches_; ++i) {
+      DeleteUnusedTensors(*microbatch_scopes_[i], backward_send_vars_,
+                          gc.get());
+    }
+  }
 }
 
 void SectionWorker::TrainFiles() {
@@ -132,56 +232,11 @@ void SectionWorker::TrainFiles() {
   }
   // max_memory_size >= 0
   if (schedule_mode_ == 0) {
-    // F-then-B scheduler which runs Forward phase for all microbatches,
-    // then runs Backward phase for all microbatches.
-    // step1: run forward
-    for (int i = 0; i < num_microbatches_; ++i) {
-      RunForward(i, gc, unused_vars_);
-    }
-    // step2: run backward
-    for (int i = 0; i < num_microbatches_; ++i) {
-      RunBackward(i, gc, unused_vars_);
-    }
-    // step3: run update
-    RunUpdate(gc, unused_vars_);
+    RunFThenB(gc);
   } else {
-    // 1F1B scheduler, which runs forward phase and backward phase altertively
-    // after startup phase. For a stage, the number of microbatches for
-    // startup is num_pipeline_stages_ - pipeline_stage_ - 1, where
-    // num_pipeline_stages_ is the total number of pipeline stages and
-    // pipeline_stage_ is the pipeline stage of the current device.
-    auto startup_steps = num_pipeline_stages_ - pipeline_stage_ - 1;
-    VLOG(3) << "startup_steps:" << startup_steps
-            << ", num_stages: " << num_pipeline_stages_
-            << ", stage:" << pipeline_stage_;
-    PADDLE_ENFORCE_GT(
-        num_microbatches_, startup_steps,
-        platform::errors::InvalidArgument(
-            "To use pipeline with 1F1B scheduler, please make sure number of "
-            "microbatches (%d) is than startup steps (%d).",
-            num_microbatches_, startup_steps));
-    int fw_step = 0;
-    int bw_step = 0;
-    // startup phase
-    while (fw_step < startup_steps) {
-      RunForward(fw_step, gc, unused_vars_);
-      fw_step += 1;
-    }
-
-    // 1f1b phase
-    while (fw_step < num_microbatches_) {
-      RunForward(fw_step, gc, unused_vars_);
-      fw_step += 1;
-      RunBackward(bw_step, gc, unused_vars_);
-      bw_step += 1;
-    }
-    // backward phase
-    while (bw_step < num_microbatches_) {
-      RunBackward(bw_step, gc, unused_vars_);
-      bw_step += 1;
-    }
-    RunUpdate(gc, unused_vars_);
+    Run1F1B(gc);
   }
+
   dev_ctx_->Wait();
   ++batch_id_;
 }
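To make Run1F1B's freeing pattern concrete, the snippet below dry-runs its control flow for one assumed configuration (4 stages, stage 1, 6 microbatches; these numbers are not from the diff). During the 1F1B loop the send vars of microbatch bw_step - 2 are freed, so the two most recent backward sends stay alive; everything from reserve_bw_send_step onward is freed only after RunUpdate, matching the NOTE about syncing the backward send comm at the update step.

```cpp
#include <cstdio>

int main() {
  const int num_pipeline_stages = 4, pipeline_stage = 1, num_microbatches = 6;
  const int startup_steps = num_pipeline_stages - pipeline_stage - 1;  // = 2

  int fw_step = 0, bw_step = 0;

  // startup phase: forwards only
  while (fw_step < startup_steps) {
    std::printf("F%d\n", fw_step);
    fw_step += 1;
  }

  // 1f1b phase: one forward, then one backward per iteration; Run1F1B frees
  // the send vars of microbatch (bw_step - 2), keeping the two most recent
  // backward sends alive.
  while (fw_step < num_microbatches) {
    std::printf("F%d", fw_step);
    if (bw_step >= 2) std::printf("  free send vars of microbatch %d", bw_step - 2);
    std::printf("  B%d\n", bw_step);
    fw_step += 1;
    bw_step += 1;
  }

  const int reserve_bw_send_step = bw_step - 2;

  // drain phase: backwards only
  while (bw_step < num_microbatches) {
    std::printf("B%d\n", bw_step);
    bw_step += 1;
  }

  std::printf("U (optimizer update)\n");

  // send vars held back during the loop are only freed after the update
  for (int i = reserve_bw_send_step; i < num_microbatches; ++i) {
    std::printf("free send vars of microbatch %d\n", i);
  }
  return 0;
}
```

For this configuration it prints F0, F1, then F2/B0 through F5/B3, then B4, B5 and the update; the send vars of microbatches 0 and 1 are freed inside the loop and those of 2 through 5 after the update.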
diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py
index ddd9ef2327c605c167973aa98647a860a71b4ed3..486792093a35cf9c3078b6e3201fedad86a9662d 100755
--- a/python/paddle/fluid/optimizer.py
+++ b/python/paddle/fluid/optimizer.py
@@ -5277,6 +5277,7 @@ class PipelineOptimizer(object):
                 backward_recv_index = index
                 break
 
+        # last pipeline stage
         if backward_recv_index is None: return
 
         offset = 0
diff --git a/python/paddle/fluid/tests/unittests/test_dist_base.py b/python/paddle/fluid/tests/unittests/test_dist_base.py
index a4f9436f262bf36d8d2b28f1c8e751701756314f..446b9a1e697e9053ae0fb84e7f6254030252b902 100755
--- a/python/paddle/fluid/tests/unittests/test_dist_base.py
+++ b/python/paddle/fluid/tests/unittests/test_dist_base.py
@@ -1258,7 +1258,7 @@ class TestDistBase(unittest.TestCase):
                 "fused_all_reduce_op_handle=10,all_reduce_op_handle=10,alloc_continuous_space_op=10,fuse_all_reduce_op_pass=10," \
                 "alloc_continuous_space_for_grad_pass=10,fast_threaded_ssa_graph_executor=10,executor=10,operator=10," \
                 "sparse_all_reduce_op_handle=10,gen_nccl_id_op=10,gen_nccl_id_op_help=10,nccl_helper=10,grpc_client=10," \
-                "grpc_server=10,request_handler_impl=10"
+                "grpc_server=10,request_handler_impl=10,section_worker=10"
             required_envs["GLOG_logtostderr"] = "1"
 
             required_envs.update(need_envs)
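The motivation for both the 1F1B scheduler and the eager freeing is peak activation memory. The rough comparison below, again with assumed numbers (4 stages, 8 microbatches), counts how many microbatches a stage can hold in flight, forward finished but backward not yet run, under the two schedulers dispatched in TrainFiles():

```cpp
#include <algorithm>
#include <cstdio>

int main() {
  const int num_stages = 4, num_microbatches = 8;

  for (int stage = 0; stage < num_stages; ++stage) {
    // F-then-B: all forwards finish before the first backward starts, so a
    // stage holds every microbatch's activations at once.
    const int peak_f_then_b = num_microbatches;

    // 1F1B: after the startup forwards each new forward is paired with a
    // backward, so the number of in-flight microbatches stops growing.
    const int startup_steps = num_stages - stage - 1;
    const int peak_1f1b = std::min(startup_steps + 1, num_microbatches);

    std::printf("stage %d: F-then-B peak = %d, 1F1B peak = %d\n", stage,
                peak_f_then_b, peak_1f1b);
  }
  return 0;
}
```

F-then-B peaks at num_microbatches on every stage, while 1F1B peaks at startup_steps + 1, so the earliest stage holds at most num_stages microbatches and the last stage only one.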