From 423ea9782ab7ae76501e6acc7ec92b2cecd8633e Mon Sep 17 00:00:00 2001 From: Yuang Liu Date: Fri, 30 Jul 2021 10:04:57 +0800 Subject: [PATCH] all reduce fusion for shardinug, test=develop (#34480) --- .../fleet/meta_optimizers/sharding/utils.py | 84 +++++++++++++++++-- .../meta_optimizers/sharding_optimizer.py | 19 +++-- .../test_fleet_sharding_meta_optimizer.py | 30 +++++++ 3 files changed, 121 insertions(+), 12 deletions(-) diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding/utils.py b/python/paddle/distributed/fleet/meta_optimizers/sharding/utils.py index c10978e9d94..a0e18eb16b6 100755 --- a/python/paddle/distributed/fleet/meta_optimizers/sharding/utils.py +++ b/python/paddle/distributed/fleet/meta_optimizers/sharding/utils.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import paddle -from paddle.fluid import core +from paddle.fluid import core, unique_name from functools import reduce from paddle.distributed.fleet.meta_optimizers.common import is_loss_grad_op from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_KEY, OP_ROLE_VAR_KEY @@ -333,26 +333,96 @@ def insert_allreduce_ops(block, ring_id, allreduce_vars, op_role=OpRole.Backward, - use_calc_stream=False): + use_calc_stream=False, + user_defined_strategy=None): """ _add_allreduce_ops """ if len(allreduce_vars) == 0: return + if user_defined_strategy and user_defined_strategy.fuse_all_reduce_ops: + insert_fused_allreduce_ops(block, insert_idx, ring_id, allreduce_vars, + op_role, use_calc_stream, + user_defined_strategy.fuse_grad_size_in_MB) + else: + for var in allreduce_vars: + block._insert_op_without_sync( + insert_idx, + type='c_allreduce_sum', + inputs={'X': var}, + outputs={'Out': var}, + attrs={ + 'ring_id': ring_id, + 'use_calc_stream': use_calc_stream, + OP_ROLE_KEY: op_role + }) + + return + + +def insert_fused_allreduce_ops(block, + insert_idx, + ring_id, + allreduce_vars, + op_role=OpRole.Backward, + use_calc_stream=False, + fuse_grad_size_in_MB=32): + segments = [] + cur_size = 0. + last_dtype = None for var in allreduce_vars: + real_var = block.var(var) + var_size = get_var_size(real_var) + if cur_size + var_size > fuse_grad_size_in_MB \ + or len(segments) == 0 \ + or real_var.dtype != last_dtype: + segments.append([real_var]) + cur_size = var_size + last_dtype = real_var.dtype + else: + segments[-1].append(real_var) + cur_size += var_size + + fused_vars = [] + for segment in segments: + tmp_var = block.create_var( + name=unique_name.generate('FusedOutput_{}'.format(segment[0].name)), + dtype=segment[0].dtype, + persistable=False, + stop_gradient=True) + fused_vars.append(tmp_var) block._insert_op_without_sync( insert_idx, + type="coalesce_tensor", + inputs={"Input": segment}, + outputs={"Output": segment, + "FusedOutput": tmp_var}, + attrs={ + "copy_data": True, + "use_align": True, + "dtype": segment[0].dtype, + OP_ROLE_KEY: op_role + }) + + for fused_var in fused_vars: + block._insert_op_without_sync( + insert_idx + len(fused_vars), type='c_allreduce_sum', - inputs={'X': var}, - outputs={'Out': var}, + inputs={'X': fused_var}, + outputs={'Out': fused_var}, attrs={ 'ring_id': ring_id, 'use_calc_stream': use_calc_stream, OP_ROLE_KEY: op_role }) - - return + if not use_calc_stream: + block._insert_op_without_sync( + insert_idx + len(fused_vars), + type='c_sync_calc_stream', + inputs={'X': fused_var}, + outputs={'Out': fused_var}, + attrs={OP_ROLE_KEY: op_role}) def insert_reduce_ops(block, @@ -528,7 +598,7 @@ def add_sync_comm(program, sharding_ring_id): add the sync_comm op for the test prog. """ - #NOTE (liangjianzhong): only support one comm stream by now, use more than one + #NOTE (liangjianzhong): only support one comm stream by now, use more than one # comm streams will cause error. should be revise in future. assert sharding_ring_id >= 0, "sharding_ring_id should larger than zero" diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py index 878614ca152..8211f3ea0fb 100755 --- a/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py @@ -322,7 +322,8 @@ class ShardingOptimizer(MetaOptimizerBase): self.dp_ring_id, accumulated_grad_names, core.op_proto_and_checker_maker.OpRole.Optimize, - use_calc_stream=True) + use_calc_stream=True, + user_defined_strategy=self.user_defined_strategy) # if not use sharding, adapt amp/clip, for remain parallelism. # cast --> amp --> clip --> opt @@ -778,8 +779,12 @@ class ShardingOptimizer(MetaOptimizerBase): shard_allredue_vars) >= 1: insert_sync_comm_ops(block, self._segments[-1]._end_idx, self.dp_ring_id, shard_allredue_vars) - insert_allreduce_ops(block, self._segments[-1]._end_idx, - self.dp_ring_id, shard_allredue_vars) + insert_allreduce_ops( + block, + self._segments[-1]._end_idx, + self.dp_ring_id, + shard_allredue_vars, + user_defined_strategy=self.user_defined_strategy) # gradient merge elif self.gradient_merge_mode == "sharding_gm" and self._gradient_merge_acc_step > 1: self.create_persistable_gradients_and_insert_merge_ops( @@ -896,8 +901,12 @@ class ShardingOptimizer(MetaOptimizerBase): if self.gradient_merge_mode != "sharding_gm" or self._gradient_merge_acc_step <= 1: if self.hybrid_dp and self.hybrid_dp_mode == "sharding_hybrid_dp" and len( shard_allredue_vars) >= 1: - insert_allreduce_ops(block, segment._start_idx, - self.dp_ring_id, shard_allredue_vars) + insert_allreduce_ops( + block, + segment._start_idx, + self.dp_ring_id, + shard_allredue_vars, + user_defined_strategy=self.user_defined_strategy) insert_sync_comm_ops(block, segment._start_idx, self.sharding_ring_id, allreduce_vars) # gradient merge diff --git a/python/paddle/fluid/tests/unittests/test_fleet_sharding_meta_optimizer.py b/python/paddle/fluid/tests/unittests/test_fleet_sharding_meta_optimizer.py index a29d752ed75..d66fb2c36b7 100755 --- a/python/paddle/fluid/tests/unittests/test_fleet_sharding_meta_optimizer.py +++ b/python/paddle/fluid/tests/unittests/test_fleet_sharding_meta_optimizer.py @@ -586,6 +586,36 @@ class TestFleetMetaOptimizer(TestFleetMetaOptimizer): self.assertEqual(dp_group_waiting_ports, ['127.0.0.1:36002']) + def test_sharding_dp_with_allreduce_fuse(self): + train_prog, startup_prog = paddle.fluid.Program(), paddle.fluid.Program( + ) + avg_cost, _ = self.net(train_prog, startup_prog) + strategy = paddle.distributed.fleet.DistributedStrategy() + strategy.sharding = True + strategy.sharding_configs = { + "sharding_segment_strategy": "segment_broadcast_MB", + "segment_broadcast_MB": 0.1, + "segment_anchors": None, + "sharding_degree": 2, + "dp_degree": 2, + "hybrid_dp": True, + "gradient_merge_acc_step": 1, + "mp_degree": 1 + } + strategy.fuse_all_reduce_ops = True + strategy.fuse_grad_size_in_MB = 2 + self.optimizer(avg_cost, strategy, train_prog, startup_prog) + + main_prog_ops = train_prog.global_block().ops + main_prog_op_types = [op.type for op in main_prog_ops] + + assert 'c_allreduce_sum' in main_prog_op_types + assert 'coalesce_tensor' in main_prog_op_types + + for op in main_prog_ops: + if op.type == 'c_allreduce_sum': + assert 'FusedOutput' in op.input_arg_names[0] + if __name__ == "__main__": unittest.main() -- GitLab