diff --git a/paddle/fluid/framework/distributed_strategy.proto b/paddle/fluid/framework/distributed_strategy.proto
index c965363d701c5245a825ac1c54326bee0a714745..aae9515a565e2ffed2965628c9c19cd8e6f6e6c0 100755
--- a/paddle/fluid/framework/distributed_strategy.proto
+++ b/paddle/fluid/framework/distributed_strategy.proto
@@ -38,7 +38,8 @@ message ShardingConfig {
   optional int32 acc_steps = 7 [ default = 1 ];
   optional int32 schedule_mode = 8 [ default = 0 ];
   optional int32 pp_bz = 9 [ default = 1 ];
-  optional bool pp_allreduce_in_optimize = 10 [ default = true ];
+  optional bool pp_allreduce_in_optimize = 10 [ default = false ];
+  optional bool optimize_offload = 11 [ default = false ];
 }

 message AMPConfig {
diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py
index 86399de9b756219136bde7dc6faa63a51c46ca62..6c00aa9fd45a5622dc61d18f214d85dc94c98de5 100755
--- a/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py
+++ b/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py
@@ -103,6 +103,8 @@ class ShardingOptimizer(MetaOptimizerBase):
         self.pp_bz = self.user_defined_strategy.sharding_configs["pp_bz"]
         self.pp_allreduce_in_optimize = self.user_defined_strategy.sharding_configs[
             "pp_allreduce_in_optimize"]
+        self.optimize_offload = self.user_defined_strategy.sharding_configs[
+            "optimize_offload"]

         if self.inner_opt is None:
             raise ValueError(
@@ -359,8 +361,10 @@ class ShardingOptimizer(MetaOptimizerBase):
         main_block._sync_with_cpp()

         # TODO(wangxi): add optimize offload
-        offload_helper = OffloadHelper()
-        offload_helper.offload(main_block, startup_block)
+        if self.optimize_offload:
+            logging.info("Sharding with optimize offload !")
+            offload_helper = OffloadHelper()
+            offload_helper.offload(main_block, startup_block)

         with open("start_sharding_%d" % self.role_maker._worker_index(),
                   'w') as f:
@@ -943,9 +947,8 @@ class ShardingOptimizer(MetaOptimizerBase):
             ]
             self.pp_group_size = self.pipeline_nodes
             self.pp_group_endpoints = [
-                ep for idx, ep in enumerate(self.endpoints)
-                if (idx % self.sharding_group_size
-                    ) == self.sharding_rank
+                ep for idx, ep in enumerate(self.endpoints) if
+                (idx % self.sharding_group_size) == self.sharding_rank
             ]
         else:
             self.mp_group_id = 0
@@ -969,11 +972,12 @@ class ShardingOptimizer(MetaOptimizerBase):
                 self._inner_parallelism_size * self.sharding_group_size)
             self.megatron_rank = self.global_rank % self._inner_parallelism_size
             self.sharding_group_endpoints = [
-                ep for idx, ep in enumerate(self.endpoints)
-                if (idx // (self._inner_parallelism_size *
-                            self.sharding_group_size)
-                    ) == self.sharding_group_id and idx %
-                self._inner_parallelism_size == self.megatron_rank
+                ep for idx, ep in enumerate(self.endpoints) if
+                (idx //
+                 (self._inner_parallelism_size *
+                  self.sharding_group_size)) == self.sharding_group_id
+                and
+                idx % self._inner_parallelism_size == self.megatron_rank
             ]
             print("sharding_endpoint:", self.sharding_group_endpoints)
             print("sharding_rank:", self.sharding_rank)
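
Usage note (not part of the diff): a minimal sketch of how the new `optimize_offload` key could be switched on from user code, assuming the standard fleet collective-training workflow in static graph mode. Only the keys touched by this diff are taken from the source; the optimizer choice and other values are illustrative, and additional sharding_configs entries (e.g. group size / segment settings) are omitted for brevity.

import paddle
import paddle.distributed.fleet as fleet

paddle.enable_static()
fleet.init(is_collective=True)

strategy = fleet.DistributedStrategy()
strategy.sharding = True
strategy.sharding_configs = {
    "acc_steps": 1,                      # gradient accumulation steps
    "schedule_mode": 0,
    "pp_bz": 1,
    "pp_allreduce_in_optimize": False,   # new default set by this diff
    "optimize_offload": True,            # enables the OffloadHelper pass guarded above
}

optimizer = paddle.optimizer.Momentum(learning_rate=0.01)
optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
# optimizer.minimize(loss) would then build the sharded program and, with
# optimize_offload=True, apply offload_helper.offload(main_block, startup_block).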