diff --git a/paddle/fluid/framework/distributed_strategy.proto b/paddle/fluid/framework/distributed_strategy.proto
index 7cf8d55aeeb1d99acd2f501461f0563f87a25e78..524a112ff197499e5dcf0937c23ceafa5784ee1c 100644
--- a/paddle/fluid/framework/distributed_strategy.proto
+++ b/paddle/fluid/framework/distributed_strategy.proto
@@ -32,6 +32,9 @@ message ShardingConfig {
   optional float fuse_broadcast_MB = 1 [ default = 32.0 ];
   optional bool hybrid_dp = 2 [ default = false ];
   optional int32 sharding_group_size = 3 [ default = 8 ];
+  optional bool as_outer_parallelism = 4 [ default = false ];
+  optional int32 inner_parallelism_size = 5 [ default = 8 ];
+  optional bool use_pipeline = 6 [ default = false ];
 }
 
 message AMPConfig {
@@ -117,6 +120,8 @@ message AsyncConfig {
 message PipelineConfig { optional int32 micro_batch = 1 [ default = 1 ]; }
 
+message ModelParallelConfig { optional int32 parallelism = 1 [ default = 1 ]; }
+
 message DistributedStrategy {
   // bool options
   optional Mode mode = 1 [ default = COLLECTIVE ];
@@ -146,6 +151,7 @@ message DistributedStrategy {
   optional bool fp16_allreduce = 25 [ default = false ];
   optional bool sharding = 26 [ default = false ];
   optional float last_comm_group_size_MB = 27 [ default = 1 ];
+  optional bool model_parallel = 28 [ default = false ];
 
   optional RecomputeConfig recompute_configs = 101;
   optional AMPConfig amp_configs = 102;
@@ -158,6 +164,7 @@ message DistributedStrategy {
   optional LambConfig lamb_configs = 109;
   optional AdaptiveLocalSGDConfig adaptive_localsgd_configs = 110;
   optional ShardingConfig sharding_configs = 111;
+  optional ModelParallelConfig model_parallel_configs = 112;
   optional BuildStrategy build_strategy = 201;
   optional ExecutionStrategy execution_strategy = 202;
 }
diff --git a/paddle/fluid/framework/section_worker.cc b/paddle/fluid/framework/section_worker.cc
index 6634cb98d67413087f6a9acb4bac3378bf15dcab..54682e3e1cf1c2bad3a4c7904866009105e45502 100644
--- a/paddle/fluid/framework/section_worker.cc
+++ b/paddle/fluid/framework/section_worker.cc
@@ -107,7 +107,7 @@ void SectionWorker::TrainFiles() {
     int op_role = op->Attr<int>(std::string("op_role"));
     if (op_role == static_cast<int>(OpRole::kOptimize)) {
       VLOG(3) << "Update: running op " << op->Type();
-      op->Run(*microbatch_scopes_[0], place_);
+      op->Run(*microbatch_scopes_[num_microbatches_ - 1], place_);
       if (gc) {
         DeleteUnusedTensors(*microbatch_scopes_[0], op.get(), unused_vars_,
                             gc.get());
diff --git a/python/paddle/distributed/fleet/meta_optimizers/amp_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/amp_optimizer.py
index c751e229cbbe2b900ead900297ff9956946b9e75..eefa206c8d87c1da925e2a39cb2002dae772543d 100644
--- a/python/paddle/distributed/fleet/meta_optimizers/amp_optimizer.py
+++ b/python/paddle/distributed/fleet/meta_optimizers/amp_optimizer.py
@@ -56,9 +56,10 @@ class AMPOptimizer(MetaOptimizerBase):
         # add is_distributed to optimize amp, overlap communication and
         # computation by split the check_finite_and_unscale op.
         is_distributed = self.role_maker._worker_num() > 1
-        if self.user_defined_strategy.sharding:
-            # FIXME(wangxi). sharding failed when split check_finite_and_unscale
-            is_distributed = False
+        # if self.user_defined_strategy.sharding or self.user_defined_strategy.model_parallel:
+        #     # FIXME(wangxi). sharding failed when split check_finite_and_unscale
+        #     # FIXME(JZ-LIANG). To support Sharding-Megatron-AMP, Megatron should follow Sharding's behavior
+        #     is_distributed = False
         self.wrapped_opt._set_distributed(is_distributed)
 
     def _can_apply(self):
diff --git a/python/paddle/distributed/fleet/meta_optimizers/pipeline_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/pipeline_optimizer.py
index da8adf47b854bf3cf74eab712088ad1d481face3..779b7534494c2ddb2c346c79717ab0ff4d8ca15f 100644
--- a/python/paddle/distributed/fleet/meta_optimizers/pipeline_optimizer.py
+++ b/python/paddle/distributed/fleet/meta_optimizers/pipeline_optimizer.py
@@ -154,8 +154,10 @@ class PipelineOptimizer(MetaOptimizerBase):
     def __init__(self, optimizer):
         super(PipelineOptimizer, self).__init__(optimizer)
         self.inner_opt = optimizer
-        # we do not allow meta optimizer to be inner optimizer currently
-        self.meta_optimizers_white_list = []
+        self.meta_optimizers_white_list = [
+            "RecomputeOptimizer",
+            "AMPOptimizer",
+        ]
        self.meta_optimizers_black_list = ["GraphExecutionOptimizer", ]
 
     def _set_basic_info(self, loss, role_maker, user_defined_optimizer,
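For reference, below is a minimal sketch of how the new knobs might be switched on from user code. It assumes the Python-side fleet.DistributedStrategy exposes the new proto fields (model_parallel, model_parallel_configs, and the extra sharding_configs keys) the same way it exposes the existing ones; that Python wiring is not part of this diff, so the attribute names for the new fields are hypothetical.

import paddle
import paddle.distributed.fleet as fleet

fleet.init(is_collective=True)

strategy = fleet.DistributedStrategy()

# Existing sharding switch plus the three ShardingConfig fields added above
# (as_outer_parallelism, inner_parallelism_size, use_pipeline).
strategy.sharding = True
strategy.sharding_configs = {
    "fuse_broadcast_MB": 32,
    "sharding_group_size": 8,
    "as_outer_parallelism": True,
    "inner_parallelism_size": 8,
    "use_pipeline": False,
}

# New top-level switch (field 28) and its config message (field 112);
# hypothetical attribute names mirroring the proto field names.
strategy.model_parallel = True
strategy.model_parallel_configs = {"parallelism": 8}

optimizer = paddle.optimizer.Momentum(learning_rate=0.001)
optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)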