提交 6760cbd9 编写于 作者: S sandyhouse

update

上级 03e99e26
...@@ -35,6 +35,7 @@ message ShardingConfig { ...@@ -35,6 +35,7 @@ message ShardingConfig {
optional bool as_outer_parallelism = 4 [ default = false ]; optional bool as_outer_parallelism = 4 [ default = false ];
optional int32 parallelism = 5 [ default = 1 ]; optional int32 parallelism = 5 [ default = 1 ];
optional bool use_pipeline = 6 [ default = false ]; optional bool use_pipeline = 6 [ default = false ];
optional int32 acc_steps = 7 [ default = 1 ];
} }
message AMPConfig { message AMPConfig {
......
...@@ -25,6 +25,8 @@ namespace framework { ...@@ -25,6 +25,8 @@ namespace framework {
void PipelineTrainer::Initialize(const TrainerDesc& trainer_desc, void PipelineTrainer::Initialize(const TrainerDesc& trainer_desc,
Dataset* dataset) { Dataset* dataset) {
const auto& section_params = trainer_desc.section_param(); const auto& section_params = trainer_desc.section_param();
const auto num_pipeline_stages_ = section_params.num_pipeline_stages();
const auto pipeline_stage_ = section_params.pipeline_stage();
num_microbatches_ = section_params.num_microbatches(); num_microbatches_ = section_params.num_microbatches();
VLOG(3) << "Number of microbatches per minibatch: " << num_microbatches_; VLOG(3) << "Number of microbatches per minibatch: " << num_microbatches_;
trainer_desc_ = trainer_desc; trainer_desc_ = trainer_desc;
...@@ -40,6 +42,8 @@ void PipelineTrainer::Initialize(const TrainerDesc& trainer_desc, ...@@ -40,6 +42,8 @@ void PipelineTrainer::Initialize(const TrainerDesc& trainer_desc,
this_worker->SetPlace(place_); this_worker->SetPlace(place_);
this_worker->Initialize(trainer_desc); this_worker->Initialize(trainer_desc);
this_worker->SetMicrobatchNum(num_microbatches_); this_worker->SetMicrobatchNum(num_microbatches_);
this_worker->SetPipelineStageNum(num_pipeline_stages_);
this_worker->SetPipelineStage(pipeline_stage_);
} }
void PipelineTrainer::InitOtherEnv(const ProgramDesc& main_program) { void PipelineTrainer::InitOtherEnv(const ProgramDesc& main_program) {
......
...@@ -93,6 +93,8 @@ message SectionWorkerParameter { ...@@ -93,6 +93,8 @@ message SectionWorkerParameter {
optional int32 start_cpu_core_id = 4 [ default = 1 ]; optional int32 start_cpu_core_id = 4 [ default = 1 ];
repeated string param_need_sync = 5; repeated string param_need_sync = 5;
optional int32 num_microbatches = 6; optional int32 num_microbatches = 6;
optional int32 num_pipeline_stages = 7 [ default = 1 ];
optional int32 pipeline_stage = 8 [ default = 1 ];
} }
message SectionConfig { message SectionConfig {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册