提交 6760cbd9 编写于 作者: S sandyhouse

update

上级 03e99e26
......@@ -35,6 +35,7 @@ message ShardingConfig {
optional bool as_outer_parallelism = 4 [ default = false ];
optional int32 parallelism = 5 [ default = 1 ];
optional bool use_pipeline = 6 [ default = false ];
optional int32 acc_steps = 7 [ default = 1 ];
}
message AMPConfig {
......
......@@ -25,6 +25,8 @@ namespace framework {
void PipelineTrainer::Initialize(const TrainerDesc& trainer_desc,
Dataset* dataset) {
const auto& section_params = trainer_desc.section_param();
const auto num_pipeline_stages_ = section_params.num_pipeline_stages();
const auto pipeline_stage_ = section_params.pipeline_stage();
num_microbatches_ = section_params.num_microbatches();
VLOG(3) << "Number of microbatches per minibatch: " << num_microbatches_;
trainer_desc_ = trainer_desc;
......@@ -40,6 +42,8 @@ void PipelineTrainer::Initialize(const TrainerDesc& trainer_desc,
this_worker->SetPlace(place_);
this_worker->Initialize(trainer_desc);
this_worker->SetMicrobatchNum(num_microbatches_);
this_worker->SetPipelineStageNum(num_pipeline_stages_);
this_worker->SetPipelineStage(pipeline_stage_);
}
void PipelineTrainer::InitOtherEnv(const ProgramDesc& main_program) {
......
......@@ -93,6 +93,8 @@ message SectionWorkerParameter {
optional int32 start_cpu_core_id = 4 [ default = 1 ];
repeated string param_need_sync = 5;
optional int32 num_microbatches = 6;
optional int32 num_pipeline_stages = 7 [ default = 1 ];
optional int32 pipeline_stage = 8 [ default = 1 ];
}
message SectionConfig {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册