From 503f422eaa7a7c4aaeb87a202e3877168c77a1b9 Mon Sep 17 00:00:00 2001 From: wuhuachaocoding <77733235+wuhuachaocoding@users.noreply.github.com> Date: Tue, 25 Apr 2023 16:22:20 +0800 Subject: [PATCH] add mp_sync config. (#53254) --- .../framework/distributed_strategy.proto | 1 + .../fleet/base/distributed_strategy.py | 6 ++ .../hybrid_parallel_optimizer.py | 95 +++++++++++++------ .../fleet/hybrid_parallel_mp_model.py | 52 +++++++++- .../fleet/test_fleet_distributed_strategy.py | 30 ++++++ 5 files changed, 150 insertions(+), 34 deletions(-) diff --git a/paddle/fluid/framework/distributed_strategy.proto b/paddle/fluid/framework/distributed_strategy.proto index d0e494ad494..6b093e9ee03 100755 --- a/paddle/fluid/framework/distributed_strategy.proto +++ b/paddle/fluid/framework/distributed_strategy.proto @@ -55,6 +55,7 @@ message MpConfig { optional bool sync_param= 1 [ default = false ]; optional bool sync_grad= 2 [ default = false ]; optional bool sync_moment= 3 [ default = false ]; + optional string sync_mode= 4 [ default = 'broadcast' ]; } message PpConfig { diff --git a/python/paddle/distributed/fleet/base/distributed_strategy.py b/python/paddle/distributed/fleet/base/distributed_strategy.py index 14e2fc09d33..b7519891024 100755 --- a/python/paddle/distributed/fleet/base/distributed_strategy.py +++ b/python/paddle/distributed/fleet/base/distributed_strategy.py @@ -146,6 +146,8 @@ class DistributedStrategy: self.strategy.sync_nccl_allreduce = bool(_global_flags()[key]) self.hybrid_parallel_order = ['dp', 'pp', 'sharding', 'mp'] + self.sync_param_name = ["embedding", "layer_norm", ".b_"] + self.__lock_attr = True logger.info("distributed strategy initialized") @@ -1698,6 +1700,10 @@ class DistributedStrategy: ) if "mp_configs" in configs: + if "sync_param_name" in configs["mp_configs"]: + self.sync_param_name = configs["mp_configs"]["sync_param_name"] + configs["mp_configs"].pop("sync_param_name") + assign_configs_value( self.strategy.hybrid_configs.mp_configs, configs["mp_configs"] ) diff --git a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_optimizer.py index ab1b270e2fd..3254bffb254 100755 --- a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_optimizer.py @@ -303,9 +303,20 @@ class HybridParallelOptimizer: inner_opt._grad_clip, hcg ) - def _filter_fn(self, param): + def _insert_sync(self, sync_var, src, mp_group, sync_mode): + if sync_mode == "broadcast": + paddle.distributed.broadcast( + sync_var, src=src, group=mp_group, sync_op=True + ) + else: + paddle.distributed.all_reduce( + sync_var, group=mp_group, sync_op=True + ) + sync_var.scale_(1.0 / mp_group.nranks) + + def _filter_fn(self, param, strategy): p_name = param.name - tar_param = ["embedding", "layer_norm", ".b_"] + tar_param = strategy.sync_param_name if param.is_distributed is False: for tar in tar_param: if tar in p_name: @@ -329,26 +340,48 @@ class HybridParallelOptimizer: or mp_configs.sync_moment ): params = sorted( - [p for p in parameters_list if self._filter_fn(p)], + [ + p + for p in parameters_list + if self._filter_fn(p, fleet.fleet._user_defined_strategy) + ], key=lambda p: p.name, ) + # Grad sync before opt if mp_group.nranks > 1 and mp_configs and mp_configs.sync_grad: for p in params: - if p.grad is None: - continue - paddle.distributed.broadcast( - 
p.grad, src=src_rank, group=mp_group, sync_op=True - ) + if hasattr(p, "main_grad") and p.main_grad is not None: + assert p.grad is None + self._insert_sync( + p.main_grad, src_rank, mp_group, mp_configs.sync_mode + ) + elif p.grad is not None: + self._insert_sync( + p.grad, src_rank, mp_group, mp_configs.sync_mode + ) self._inner_opt.step() if mp_group.nranks > 1 and mp_configs and mp_configs.sync_param: for p in params: - paddle.distributed.broadcast( - p, src=src_rank, group=mp_group, sync_op=True - ) + # Param sync after opt + self._insert_sync(p, src_rank, mp_group, mp_configs.sync_mode) + + # Master param sync after opt + if ( + hasattr(self._inner_opt, "_multi_precision") + and self._inner_opt._multi_precision + and p.name in self._inner_opt._master_weights + ): + self._insert_sync( + self._inner_opt._master_weights[p.name], + src_rank, + mp_group, + mp_configs.sync_mode, + ) + # Moment sync after opt if mp_group.nranks > 1 and mp_configs and mp_configs.sync_moment: for p in params: # support opt state of adam and adamw to broadcast now. @@ -357,28 +390,30 @@ class HybridParallelOptimizer: (paddle.optimizer.Adam, paddle.optimizer.AdamW), ): if ( - self._inner_opt._multi_precision - and p.name in self._master_weights + p.name + in self._inner_opt._accumulators[ + self._inner_opt._moment1_acc_str + ] ): - paddle.distributed.broadcast( - self._inner_opt._master_weights[p.name], - src=src_rank, - group=mp_group, - sync_op=True, + moment1 = self._inner_opt._get_accumulator( + self._inner_opt._moment1_acc_str, p + ) + self._insert_sync( + moment1, src_rank, mp_group, mp_configs.sync_mode ) - moment1 = self._inner_opt._get_accumulator( - self._inner_opt._moment1_acc_str, p - ) - moment2 = self._inner_opt._get_accumulator( - self._inner_opt._moment2_acc_str, p - ) - paddle.distributed.broadcast( - moment1, src=src_rank, group=mp_group, sync_op=True - ) - paddle.distributed.broadcast( - moment2, src=src_rank, group=mp_group, sync_op=True - ) + if ( + p.name + in self._inner_opt._accumulators[ + self._inner_opt._moment2_acc_str + ] + ): + moment2 = self._inner_opt._get_accumulator( + self._inner_opt._moment2_acc_str, p + ) + self._insert_sync( + moment2, src_rank, mp_group, mp_configs.sync_mode + ) @no_grad() @framework.dygraph_only diff --git a/python/paddle/fluid/tests/unittests/collective/fleet/hybrid_parallel_mp_model.py b/python/paddle/fluid/tests/unittests/collective/fleet/hybrid_parallel_mp_model.py index 26e740bfa6b..82efb6fa466 100644 --- a/python/paddle/fluid/tests/unittests/collective/fleet/hybrid_parallel_mp_model.py +++ b/python/paddle/fluid/tests/unittests/collective/fleet/hybrid_parallel_mp_model.py @@ -202,6 +202,7 @@ class TestDistMPSyncTraning(unittest.TestCase): self, batchs, fp16=False, + amp_level="O1", mp_sync_param=False, mp_sync_grad=False, mp_sync_moment=False, @@ -232,6 +233,11 @@ class TestDistMPSyncTraning(unittest.TestCase): learning_rate=0.1, parameters=model.parameters() ) + if fp16 and amp_level == "O2": + model, optimizer = paddle.amp.decorate( + models=model, optimizers=optimizer, level='O2' + ) + strategy = fleet.fleet._user_defined_strategy strategy.hybrid_configs = { "dp_degree": self.data_parallel_size, @@ -246,15 +252,15 @@ class TestDistMPSyncTraning(unittest.TestCase): model = fleet.distributed_model(model) optimizer = fleet.distributed_optimizer(optimizer) - return self.train_batch(batchs, model, optimizer, fp16) + return self.train_batch(batchs, model, optimizer, fp16, amp_level) - def train_batch(self, batchs, model, optimizer, fp16=False): + def 
train_batch(self, batchs, model, optimizer, fp16=False, amp_level="O1"): losses = [] if fp16: scaler = paddle.amp.GradScaler(init_loss_scaling=1024) scaler = fleet.distributed_scaler(scaler) for batch in batchs: - with paddle.amp.auto_cast(enable=fp16, level='O1'): + with paddle.amp.auto_cast(enable=fp16, level=amp_level): output = model(batch) loss = output.mean() losses.append(loss.numpy()) @@ -295,7 +301,7 @@ class TestDistMPSyncTraning(unittest.TestCase): for i in range(len(losses)): np.testing.assert_allclose(losses[i], losses_sync[i], rtol=1e-6) - # test fp16 + # test fp16 O1 losses_fp16 = self.build_model_optimizer_train(batchs, fp16=True) losses_sync_fp16 = self.build_model_optimizer_train( batchs, @@ -310,6 +316,24 @@ class TestDistMPSyncTraning(unittest.TestCase): losses_fp16[i], losses_sync_fp16[i], rtol=1e-6 ) + # test fp16 O2 + losses_fp16_O2 = self.build_model_optimizer_train( + batchs, fp16=True, amp_level="O2" + ) + losses_sync_fp16_O2 = self.build_model_optimizer_train( + batchs, + fp16=True, + amp_level="O2", + mp_sync_param=mp_sync_param, + mp_sync_grad=mp_sync_grad, + mp_sync_moment=mp_sync_moment, + ) + + for i in range(len(losses_fp16_O2)): + np.testing.assert_allclose( + losses_fp16_O2[i], losses_sync_fp16_O2[i], rtol=1e-6 + ) + def test_mp_sync_param(self): self.mp_sync_base(mp_sync_param=True) @@ -325,6 +349,26 @@ class TestDistMPSyncTraning(unittest.TestCase): ) +class TestDistMPSyncModelTraning(TestDistMPSyncTraning): + def setUp(self): + strategy = fleet.DistributedStrategy() + self.model_parallel_size = 2 + self.data_parallel_size = 1 + strategy.hybrid_configs = { + "dp_degree": self.data_parallel_size, + "mp_degree": self.model_parallel_size, + "pp_degree": 1, + "mp_configs": { + "sync_param": False, + "sync_grad": False, + "sync_moment": False, + "sync_mode": "average", + "sync_param_name": ["embedding", "layer_norm", ".b_"], + }, + } + fleet.init(is_collective=True, strategy=strategy) + + class TestDistMPTraning(unittest.TestCase): def setUp(self): strategy = fleet.DistributedStrategy() diff --git a/python/paddle/fluid/tests/unittests/collective/fleet/test_fleet_distributed_strategy.py b/python/paddle/fluid/tests/unittests/collective/fleet/test_fleet_distributed_strategy.py index 99f235b5887..ba49cbf125a 100644 --- a/python/paddle/fluid/tests/unittests/collective/fleet/test_fleet_distributed_strategy.py +++ b/python/paddle/fluid/tests/unittests/collective/fleet/test_fleet_distributed_strategy.py @@ -84,6 +84,36 @@ class TestStrategyConfig(unittest.TestCase): self.assertEqual(strategy.hybrid_configs["mp_degree"], 2) self.assertEqual(strategy.hybrid_configs["pp_degree"], 4) + def test_hybrid_parallel_mp_configs(self): + strategy = paddle.distributed.fleet.DistributedStrategy() + strategy.hybrid_configs = { + "dp_degree": 1, + "mp_degree": 2, + "pp_degree": 4, + "mp_configs": { + "sync_param": True, + "sync_grad": False, + "sync_moment": False, + "sync_mode": "broadcast", + "sync_param_name": ["embedding", "layer_norm", ".w", ".b_"], + }, + } + self.assertEqual(strategy.hybrid_configs["dp_degree"], 1) + self.assertEqual(strategy.hybrid_configs["mp_degree"], 2) + self.assertEqual(strategy.hybrid_configs["pp_degree"], 4) + self.assertEqual(strategy.hybrid_configs["mp_configs"].sync_param, True) + self.assertEqual(strategy.hybrid_configs["mp_configs"].sync_grad, False) + self.assertEqual( + strategy.hybrid_configs["mp_configs"].sync_moment, False + ) + self.assertEqual( + strategy.hybrid_configs["mp_configs"].sync_mode, "broadcast" + ) + + self.assertEqual( + 
strategy.sync_param_name, ["embedding", "layer_norm", ".w", ".b_"] + ) + def test_hybrid_parallel_configs_order(self): strategy = paddle.distributed.fleet.DistributedStrategy() strategy.hybrid_configs = { -- GitLab
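
Usage note (not part of the patch): the sketch below shows how the options added by this change are meant to be wired together, mirroring TestDistMPSyncModelTraning.setUp and test_hybrid_parallel_mp_configs above. The parallel degrees, the build_model() helper, and the learning rate are placeholder assumptions; the mp_configs keys (sync_param, sync_grad, sync_moment, the new sync_mode, and sync_param_name) are taken from this diff.

    import paddle
    from paddle.distributed import fleet

    strategy = fleet.DistributedStrategy()
    strategy.hybrid_configs = {
        "dp_degree": 1,
        "mp_degree": 2,
        "pp_degree": 1,
        "mp_configs": {
            "sync_param": True,    # sync selected params (and master weights) across the mp group after optimizer.step()
            "sync_grad": False,    # sync grad (or main_grad) across the mp group before optimizer.step()
            "sync_moment": False,  # sync Adam/AdamW moment1/moment2 accumulators after optimizer.step()
            "sync_mode": "broadcast",  # new field added by this patch: "broadcast" or e.g. "average"
            # name substrings selecting which non-distributed params are synced;
            # popped by the hybrid_configs setter into strategy.sync_param_name
            "sync_param_name": ["embedding", "layer_norm", ".b_"],
        },
    }
    fleet.init(is_collective=True, strategy=strategy)

    model = build_model()  # placeholder: any paddle.nn.Layer, e.g. the MP test model above
    optimizer = paddle.optimizer.AdamW(
        learning_rate=0.1, parameters=model.parameters()
    )
    model = fleet.distributed_model(model)
    optimizer = fleet.distributed_optimizer(optimizer)

Per _insert_sync in hybrid_parallel_optimizer.py above, "broadcast" copies the tensors from the mp group's source rank, while any other value (the tests use "average") all-reduces across the mp group and scales the result by 1.0 / nranks.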