diff --git a/paddle/fluid/framework/distributed_strategy.proto b/paddle/fluid/framework/distributed_strategy.proto index e0b20bbba91b9a0af34f5d8e6e8fabab3b24785d..2f8820fc1147736a18bc5b8e4ca6d63a09450f98 100644 --- a/paddle/fluid/framework/distributed_strategy.proto +++ b/paddle/fluid/framework/distributed_strategy.proto @@ -35,6 +35,7 @@ message ShardingConfig { optional bool as_outer_parallelism = 4 [ default = false ]; optional int32 parallelism = 5 [ default = 1 ]; optional bool use_pipeline = 6 [ default = false ]; + optional int32 acc_steps = 7 [ default = 1 ]; /* New int32 field 7; reads back as 1 when unset. NOTE(review): presumably a gradient-accumulation step count for sharding -- confirm and state the meaning here. */ } message AMPConfig { diff --git a/paddle/fluid/framework/pipeline_trainer.cc b/paddle/fluid/framework/pipeline_trainer.cc index 58e09203299e87e6b0ec21042b4af1fb26e38d0a..0fd6b25bcaca6323b34b4fc1c2e4a2ae68229be3 100644 --- a/paddle/fluid/framework/pipeline_trainer.cc +++ b/paddle/fluid/framework/pipeline_trainer.cc @@ -25,6 +25,8 @@ namespace framework { void PipelineTrainer::Initialize(const TrainerDesc& trainer_desc, Dataset* dataset) { const auto& section_params = trainer_desc.section_param(); + const auto num_pipeline_stages_ = section_params.num_pipeline_stages(); + const auto pipeline_stage_ = section_params.pipeline_stage(); /* NOTE(review): the two declarations above introduce function-local constants even though their names end in '_'. Unlike num_microbatches_ and trainer_desc_ (assigned here without a local declaration), they do not persist past this function; they only carry the section_param values to the worker setters further down. Consider dropping the trailing underscore, after checking that nothing else in this function relies on these names. */ num_microbatches_ = section_params.num_microbatches(); VLOG(3) << "Number of microbatches per minibatch: " << num_microbatches_; trainer_desc_ = trainer_desc; @@ -40,6 +42,8 @@ void PipelineTrainer::Initialize(const TrainerDesc& trainer_desc, this_worker->SetPlace(place_); this_worker->Initialize(trainer_desc); this_worker->SetMicrobatchNum(num_microbatches_); + this_worker->SetPipelineStageNum(num_pipeline_stages_); + this_worker->SetPipelineStage(pipeline_stage_); /* Passes the num_pipeline_stages and pipeline_stage values read from section_param at the top of this function on to the worker. */ } void PipelineTrainer::InitOtherEnv(const ProgramDesc& main_program) { diff --git a/paddle/fluid/framework/trainer_desc.proto b/paddle/fluid/framework/trainer_desc.proto index 70481cf3727012e4cf41d235154eb277d92cc92f..1bcdc8458c7d0ece664d352201c93c95bf834111 100644 --- a/paddle/fluid/framework/trainer_desc.proto 
+++ b/paddle/fluid/framework/trainer_desc.proto @@ -93,6 +93,8 @@ message SectionWorkerParameter { optional int32 start_cpu_core_id = 4 [ default = 1 ]; repeated string param_need_sync = 5; optional int32 num_microbatches = 6; + optional int32 num_pipeline_stages = 7 [ default = 1 ]; /* New int32 field 7; reads back as 1 when unset. */ + optional int32 pipeline_stage = 8 [ default = 1 ]; /* New int32 field 8; reads back as 1 when unset. NOTE(review): confirm whether stage ids are 0-based or 1-based -- with this default an unset field is indistinguishable from an explicit stage 1. */ } message SectionConfig {