From 6760cbd991abbc559aa536a13d4f881786430bf3 Mon Sep 17 00:00:00 2001 From: sandyhouse Date: Tue, 2 Mar 2021 17:37:57 +0800 Subject: [PATCH] update --- paddle/fluid/framework/distributed_strategy.proto | 1 + paddle/fluid/framework/pipeline_trainer.cc | 4 ++++ paddle/fluid/framework/trainer_desc.proto | 2 ++ 3 files changed, 7 insertions(+) diff --git a/paddle/fluid/framework/distributed_strategy.proto b/paddle/fluid/framework/distributed_strategy.proto index e0b20bbba91..2f8820fc114 100644 --- a/paddle/fluid/framework/distributed_strategy.proto +++ b/paddle/fluid/framework/distributed_strategy.proto @@ -35,6 +35,7 @@ message ShardingConfig { optional bool as_outer_parallelism = 4 [ default = false ]; optional int32 parallelism = 5 [ default = 1 ]; optional bool use_pipeline = 6 [ default = false ]; + optional int32 acc_steps = 7 [ default = 1 ]; } message AMPConfig { diff --git a/paddle/fluid/framework/pipeline_trainer.cc b/paddle/fluid/framework/pipeline_trainer.cc index 58e09203299..0fd6b25bcac 100644 --- a/paddle/fluid/framework/pipeline_trainer.cc +++ b/paddle/fluid/framework/pipeline_trainer.cc @@ -25,6 +25,8 @@ namespace framework { void PipelineTrainer::Initialize(const TrainerDesc& trainer_desc, Dataset* dataset) { const auto& section_params = trainer_desc.section_param(); + const auto num_pipeline_stages_ = section_params.num_pipeline_stages(); + const auto pipeline_stage_ = section_params.pipeline_stage(); num_microbatches_ = section_params.num_microbatches(); VLOG(3) << "Number of microbatches per minibatch: " << num_microbatches_; trainer_desc_ = trainer_desc; @@ -40,6 +42,8 @@ void PipelineTrainer::Initialize(const TrainerDesc& trainer_desc, this_worker->SetPlace(place_); this_worker->Initialize(trainer_desc); this_worker->SetMicrobatchNum(num_microbatches_); + this_worker->SetPipelineStageNum(num_pipeline_stages_); + this_worker->SetPipelineStage(pipeline_stage_); } void PipelineTrainer::InitOtherEnv(const ProgramDesc& main_program) { diff --git a/paddle/fluid/framework/trainer_desc.proto b/paddle/fluid/framework/trainer_desc.proto index 70481cf3727..1bcdc8458c7 100644 --- a/paddle/fluid/framework/trainer_desc.proto +++ b/paddle/fluid/framework/trainer_desc.proto @@ -93,6 +93,8 @@ message SectionWorkerParameter { optional int32 start_cpu_core_id = 4 [ default = 1 ]; repeated string param_need_sync = 5; optional int32 num_microbatches = 6; + optional int32 num_pipeline_stages = 7 [ default = 1 ]; + optional int32 pipeline_stage = 8 [ default = 1 ]; } message SectionConfig { -- GitLab