diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding/offload_helper.py b/python/paddle/distributed/fleet/meta_optimizers/sharding/offload_helper.py
index 3ad6e320316c61ed1b74829b6074685874eb61fc..bb6af1b3195f705bfa3813f3c0eb98a72f939212 100755
--- a/python/paddle/distributed/fleet/meta_optimizers/sharding/offload_helper.py
+++ b/python/paddle/distributed/fleet/meta_optimizers/sharding/offload_helper.py
@@ -25,8 +25,9 @@ class OffloadHelper(object):
     cuda_place_type = 1
     cuda_pinned_place_type = 2
 
-    def __init__(self, ring_id=None):
-        self.ring_id = ring_id
+    def __init__(self, mp_ring_id=None, dp_ring_id=None):
+        self.mp_ring_id = mp_ring_id
+        self.dp_ring_id = dp_ring_id
 
     def _insert_cast_op(self, block, idx, src_name, dst_name):
         src_var = block.var(src_name)
@@ -49,20 +50,31 @@ class OffloadHelper(object):
                 OP_ROLE_KEY: OpRole.Optimize
             })
 
-    def _insert_broadcast_op(self, block, idx, param):
-        if self.ring_id is None:
-            return
-        block._insert_op_without_sync(
-            idx,
-            type="c_broadcast",
-            inputs={'X': param},
-            outputs={'Out': param},
-            attrs={
-                'ring_id': self.ring_id,
-                'root': 0,
-                'use_calc_stream': True,
-                OP_ROLE_KEY: OpRole.Forward,
-            })
+    def _insert_broadcast_op(self, block, idx, param_name):
+        rings = []
+
+        if self.dp_ring_id is not None:
+            rings.append(self.dp_ring_id)
+
+        # need sync non distributed param in mp group
+        if self.mp_ring_id is not None:
+            param = block.var(param_name)
+            if not hasattr(param, 'is_distributed') or not param.is_distributed:
+                rings.append(self.mp_ring_id)
+
+        # the insert op order is: mp, dp
+        for ring in rings:
+            block._insert_op_without_sync(
+                idx,
+                type="c_broadcast",
+                inputs={'X': param_name},
+                outputs={'Out': param_name},
+                attrs={
+                    'ring_id': ring,
+                    'root': 0,
+                    'use_calc_stream': True,
+                    OP_ROLE_KEY: OpRole.Forward,
+                })
 
     def _insert_memcpy_op(self, block, idx, src_name, dst_name, dst_place_type):
         src_var = block.var(src_name)
@@ -236,7 +248,7 @@ class OffloadHelper(object):
                 self._insert_cast_op(startup_block, insert_idx, var_name,
                                      param_to_fp16[var_name])
                 # NOTE(wangxi): cast and offload should insert after broadcast param.
-                # the insert op order is: broadcast, cast, offload
+                # the insert op order is: {mp, dp}broadcast, cast, offload
                 self._insert_broadcast_op(startup_block, insert_idx, var_name)
 
@@ -489,6 +501,8 @@ class OffloadHelper(object):
                 self._insert_cast_op(startup_block, insert_idx, var_name,
                                      param_to_fp16[var_name])
 
+                # NOTE(wangxi): cast and offload should insert after broadcast param.
+                # the insert op order is: {mp, dp}broadcast, cast, offload
                 self._insert_broadcast_op(startup_block, insert_idx, var_name)
 
diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py
index 75a69e5527bc18e71eb9286ce1bda60c0aeaaf1d..18211459a4e0833ddafcbc56e1dd50e7543cbce3 100755
--- a/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py
+++ b/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py
@@ -467,14 +467,16 @@ class ShardingOptimizer(MetaOptimizerBase):
         main_block = self._main_program.global_block()
         startup_block = self._startup_program.global_block()
 
+        mp_ring_id = self.mp_ring_id if self.mp_degree > 1 else None
         dp_ring_id = self.dp_ring_id if self.dp_degree > 1 else None
+        offload_helper = OffloadHelper(
+            mp_ring_id=mp_ring_id, dp_ring_id=dp_ring_id)
 
         # optimize offload should be enable while gradient merge is enable and
         # acc_step is quite large (e.g. >> 100). Since its memcpy could not be
         # overlap with calc, otherwise it will slower down training severely.
         if sharding_configs["optimize_offload"]:
             logger.info("Sharding with optimize offload !")
-            offload_helper = OffloadHelper(ring_id=dp_ring_id)
             offload_helper.offload(main_block, startup_block)
             # The optimize_cast is already included in offload_fp32param
             offload_helper.offload_fp32param(main_block, startup_block)
@@ -482,7 +484,6 @@ class ShardingOptimizer(MetaOptimizerBase):
             logger.info("Sharding with optimize cast !")
             # NOTE(wangxi): optimize_cast will persist fp16 param, it
             # will take more memory, but will be faster. Trade space for time.
-            offload_helper = OffloadHelper(ring_id=dp_ring_id)
             if self._optimizer_sharding:
                 offload_helper.opt_sharding_cast_fp32param(
                     main_block, startup_block,
@@ -554,6 +555,10 @@ class ShardingOptimizer(MetaOptimizerBase):
         # init param broadcast should be called after startup pruning
         self._initialization_broadcast()
 
+        # NOTE(wangxi): if param is not persistable, program.clone will
+        # failed, so we remove no persistable param, recreate param as a var
+        self._recreate_not_persist_param_as_var()
+
         self._dump_program_for_debug()
 
         # GPU need to wait server ready, GPU and NPU is Layered connection
@@ -1385,23 +1390,14 @@ class ShardingOptimizer(MetaOptimizerBase):
 
         return
 
-    def _initialization_broadcast(self):
-        """
-        this funtion is to ensure the initialization between dp group to be
-        identical when hybrid-dp is used.
-        """
-        if not self.hybrid_dp:
-            return
-
-        startup_block = self._startup_program.global_block()
-        params = startup_block.all_parameters()
-        params_name = []
+    def _recreate_not_persist_param_as_var(self):
+        def recreate_not_persist_param_as_var(program):
+            block = program.global_block()
+            params = block.all_parameters()
+            for param in params:
+                if param.persistable:
+                    continue
 
-        # NOTE(wangxi): if param is not persistable, program.clone will
-        # failed, so we remove no persistable param, re add param as a var
-        for param in params:
-            params_name.append(param.name)
-            if not param.persistable:
                 name = param.name
                 shape = param.shape
                 dtype = param.dtype
@@ -1411,15 +1407,14 @@ class ShardingOptimizer(MetaOptimizerBase):
                 trainable = param.trainable
                 optimize_attr = param.optimize_attr
                 regularizer = param.regularizer
-
                 have_dist_attr = False
                 is_distributed = False
                 if hasattr(param, 'is_distributed'):
                     have_dist_attr = True
                     is_distributed = param.is_distributed
 
-                startup_block._remove_var(name, sync=False)
-                var = startup_block.create_var(
+                block._remove_var(name, sync=False)
+                var = block.create_var(
                     name=name,
                     shape=shape,
                     dtype=dtype,
@@ -1431,6 +1426,31 @@ class ShardingOptimizer(MetaOptimizerBase):
                 if have_dist_attr:
                     var.is_distributed = is_distributed
 
+            block._sync_with_cpp()
+
+        recreate_not_persist_param_as_var(self._startup_program)
+        recreate_not_persist_param_as_var(self._main_program)
+
+    def _initialization_broadcast(self):
+        """
+        this funtion is to ensure the initialization between dp group to be
+        identical when hybrid-dp is used, and the initialization of
+        not distributed param between mp group to be identical.
+        """
+        if self.dp_degree <= 1 and self.mp_degree <= 1:
+            return
+
+        startup_block = self._startup_program.global_block()
+
+        params = startup_block.all_parameters()
+        params_name = []
+        not_dist_param_name = set()
+
+        for param in params:
+            params_name.append(param.name)
+            if not hasattr(param, 'is_distributed') or not param.is_distributed:
+                not_dist_param_name.add(param.name)
+
         # offload and optimize_cast will insert broadcast op
         broadcast_params = set()
         for op in startup_block.ops:
@@ -1439,23 +1459,25 @@ class ShardingOptimizer(MetaOptimizerBase):
         for param in params_name:
             if param in broadcast_params:
                 continue
-            startup_block.append_op(
-                type='c_broadcast',
-                inputs={'X': param},
-                outputs={'Out': param},
-                attrs={
-                    'ring_id': self.dp_ring_id,
-                    'root': 0,
-                    'use_calc_stream': True,
-                    OP_ROLE_KEY: OpRole.Forward
-                })
-        startup_block.append_op(
-            type='c_sync_comm_stream',
-            inputs={'X': params_name},
-            outputs={'Out': params_name},
-            attrs={'ring_id': self.dp_ring_id,
-                   OP_ROLE_KEY: OpRole.Forward})
+            rings = []
+            # need sync not distributed param in mp group
+            if self.mp_degree > 1 and param in not_dist_param_name:
+                rings.append(self.mp_ring_id)
+            if self.dp_degree > 1:
+                rings.append(self.dp_ring_id)
+
+            for ring in rings:
+                startup_block.append_op(
+                    type='c_broadcast',
+                    inputs={'X': param},
+                    outputs={'Out': param},
+                    attrs={
+                        'ring_id': ring,
+                        'root': 0,
+                        'use_calc_stream': True,
+                        OP_ROLE_KEY: OpRole.Forward
+                    })
 
         startup_block._sync_with_cpp()
diff --git a/python/paddle/fluid/tests/unittests/test_fleet_hybrid_meta_optimizer.py b/python/paddle/fluid/tests/unittests/test_fleet_hybrid_meta_optimizer.py
index 6eb566935d9d52ced1444bad96dd16df94832fc0..35b74eac4b0750845b26884490eb811c1bfb6860 100755
--- a/python/paddle/fluid/tests/unittests/test_fleet_hybrid_meta_optimizer.py
+++ b/python/paddle/fluid/tests/unittests/test_fleet_hybrid_meta_optimizer.py
@@ -72,8 +72,7 @@ class TestFleetHybridOptimizer(TestFleetMetaOptimizer):
             'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init',
             'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init',
             'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast',
-            'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast',
-            'c_sync_comm_stream'
+            'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast'
         ])
 
         self.assertEqual(main_prog_op_types, [
@@ -155,8 +154,7 @@ class TestFleetHybridOptimizer(TestFleetMetaOptimizer):
             'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init',
             'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init',
             'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast',
-            'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast',
-            'c_sync_comm_stream'
+            'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast'
         ])
 
         self.assertEqual(main_prog_op_types, [
@@ -218,7 +216,7 @@ class TestFleetHybridOptimizer(TestFleetMetaOptimizer):
             'c_comm_init', 'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id',
             'c_comm_init', 'c_gen_nccl_id', 'c_comm_init', 'c_broadcast',
             'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast',
-            'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_sync_comm_stream'
+            'c_broadcast', 'c_broadcast', 'c_broadcast'
         ])
 
         self.assertEqual(main_prog_op_types, [
@@ -292,7 +290,7 @@ class TestFleetHybridOptimizer(TestFleetMetaOptimizer):
             'c_comm_init', 'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id',
             'c_comm_init', 'c_gen_nccl_id', 'c_comm_init', 'c_broadcast',
             'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast',
-            'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_sync_comm_stream'
+            'c_broadcast', 'c_broadcast', 'c_broadcast'
         ])
 
         self.assertEqual(main_prog_op_types, [
@@ -371,7 +369,7 @@ class TestFleetHybridOptimizer(TestFleetMetaOptimizer):
             'c_comm_init', 'c_gen_nccl_id', 'c_comm_init', 'c_broadcast',
             'cast', 'c_broadcast', 'cast', 'c_broadcast', 'cast', 'c_broadcast',
             'cast', 'c_broadcast', 'cast', 'c_broadcast', 'cast', 'c_broadcast',
-            'cast', 'c_broadcast', 'c_sync_comm_stream'
+            'cast', 'c_broadcast'
         ])
 
         self.assertEqual(main_prog_op_types, [
@@ -460,7 +458,7 @@ class TestFleetHybridOptimizerBoundary(TestFleetMetaOptimizer):
             'uniform_random', 'fill_constant', 'fill_constant', 'fill_constant',
             'fill_constant', 'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id',
             'c_comm_init', 'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id',
-            'c_comm_init', 'c_broadcast', 'c_sync_comm_stream'
+            'c_comm_init', 'c_broadcast'
         ])
 
         self.assertEqual(main_prog_op_types, [
@@ -511,7 +509,7 @@ class TestFleetHybridOptimizerBoundary(TestFleetMetaOptimizer):
             'uniform_random', 'fill_constant', 'fill_constant', 'fill_constant',
             'fill_constant', 'fill_constant', 'c_gen_nccl_id', 'c_comm_init',
             'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init',
-            'c_gen_nccl_id', 'c_comm_init', 'c_broadcast', 'c_sync_comm_stream'
+            'c_gen_nccl_id', 'c_comm_init', 'c_broadcast'
         ])
 
         self.assertEqual(main_prog_op_types, [
diff --git a/python/paddle/fluid/tests/unittests/test_fleet_sharding_meta_optimizer.py b/python/paddle/fluid/tests/unittests/test_fleet_sharding_meta_optimizer.py
index 73eacd118ecad506aa993e65d76f70f3177b3d26..7cb033b748874c08b7ae9defcfcf81ae4659a4e7 100755
--- a/python/paddle/fluid/tests/unittests/test_fleet_sharding_meta_optimizer.py
+++ b/python/paddle/fluid/tests/unittests/test_fleet_sharding_meta_optimizer.py
@@ -655,7 +655,9 @@ class TestFleetShardingHybridOptimizer(TestFleetMetaOptimizer):
             'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
             'fill_constant', 'fill_constant', 'c_gen_nccl_id', 'c_comm_init',
             'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init',
-            'c_gen_nccl_id', 'c_comm_init'
+            'c_gen_nccl_id', 'c_comm_init', 'c_broadcast', 'c_broadcast',
+            'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast',
+            'c_broadcast', 'c_broadcast'
         ])
 
         self.assertEqual(main_prog_op_types, [
@@ -764,7 +766,7 @@ class TestFleetShardingHybridOptimizer(TestFleetMetaOptimizer):
             'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init',
             'c_gen_nccl_id', 'c_comm_init', 'c_broadcast', 'c_broadcast',
             'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast',
-            'c_broadcast', 'c_broadcast', 'c_sync_comm_stream'
+            'c_broadcast', 'c_broadcast'
         ])
 
         self.assertEqual(main_prog_op_types, [
@@ -932,7 +934,7 @@ class TestFleetShardingHybridOptimizer(TestFleetMetaOptimizer):
             'c_gen_nccl_id', 'c_comm_init', 'c_broadcast', 'cast',
             'c_broadcast', 'cast', 'c_broadcast', 'cast', 'c_broadcast', 'cast',
             'c_broadcast', 'cast', 'c_broadcast', 'cast', 'c_broadcast', 'cast',
-            'c_broadcast', 'c_sync_comm_stream'
+            'c_broadcast'
         ])
 
         self.assertEqual(main_prog_op_types, [
@@ -1029,7 +1031,7 @@ class TestFleetShardingHybridOptimizer(TestFleetMetaOptimizer):
             'c_broadcast', 'cast', 'memcpy', 'c_broadcast', 'cast', 'memcpy',
             'c_broadcast', 'cast', 'memcpy', 'c_broadcast', 'cast', 'memcpy',
             'c_broadcast', 'cast', 'memcpy', 'c_broadcast', 'cast', 'memcpy',
-            'c_broadcast', 'c_sync_comm_stream'
+            'c_broadcast'
         ])
 
         self.assertEqual(main_prog_op_types, [
@@ -1129,7 +1131,7 @@ class TestFleetShardingHybridOptimizer(TestFleetMetaOptimizer):
             'c_gen_nccl_id', 'c_comm_init', 'c_broadcast', 'cast',
             'c_broadcast', 'cast', 'c_broadcast', 'cast', 'c_broadcast', 'cast',
             'c_broadcast', 'cast', 'c_broadcast', 'cast', 'c_broadcast', 'cast',
-            'c_broadcast', 'c_sync_comm_stream'
+            'c_broadcast'
         ])
 
         self.assertEqual(main_prog_op_types, [
@@ -1221,7 +1223,7 @@ class TestFleetShardingHybridOptimizer(TestFleetMetaOptimizer):
             'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init',
             'c_gen_nccl_id', 'c_comm_init', 'c_broadcast', 'c_broadcast',
             'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast',
-            'c_broadcast', 'c_broadcast', 'c_sync_comm_stream'
+            'c_broadcast', 'c_broadcast'
         ])
 
         self.assertEqual(main_prog_op_types, [
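Taken together, the patch applies one ring-selection rule in both `OffloadHelper._insert_broadcast_op` and `ShardingOptimizer._initialization_broadcast`: a startup parameter is broadcast in the mp ring only when it is not distributed, and in the dp ring whenever data parallelism is active, with the mp broadcast placed before the dp broadcast in the startup program (per the in-code comments). The sketch below is a minimal, standalone illustration of that rule only; `select_broadcast_rings` is a hypothetical helper written for this note and is not part of the patch or the Paddle API.

```python
# Hypothetical, simplified illustration of the broadcast-ring selection
# introduced by this patch; the helper name is not part of Paddle.
def select_broadcast_rings(mp_ring_id, dp_ring_id, param_is_distributed):
    """Return ring ids to broadcast a startup parameter on, in program order.

    The mp broadcast comes first and is only needed for parameters that are
    not model-parallel distributed; the dp broadcast follows whenever a dp
    ring is configured.
    """
    rings = []
    if mp_ring_id is not None and not param_is_distributed:
        rings.append(mp_ring_id)
    if dp_ring_id is not None:
        rings.append(dp_ring_id)
    return rings


# Example: a non-distributed parameter with mp ring 0 and dp ring 2 is
# broadcast on ring 0 (mp) first, then ring 2 (dp).
assert select_broadcast_rings(0, 2, param_is_distributed=False) == [0, 2]
# A distributed (e.g. column-parallel) parameter skips the mp broadcast.
assert select_broadcast_rings(0, 2, param_is_distributed=True) == [2]
```

This also matches the test expectations above: the `c_sync_comm_stream` op disappears from the startup programs, and only per-parameter `c_broadcast` ops (on the selected rings) remain.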