diff --git a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_optimizer.py
index 9eca2e667a8fd8c81aa3a4b1083ada9204cbecb6..61aa3d894f05e60160b45dcd5c81552ea2263f08 100755
--- a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_optimizer.py
+++ b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_optimizer.py
@@ -50,8 +50,10 @@ class HybridParallelClipGrad:
     @no_grad()
     def _dygraph_clip(self, params_grads):
         sum_square_dist_fp16 = []
+        sum_square_dist_bf16 = []
         sum_square_dist_fp32 = []
         sum_square_not_dist_fp16 = []
+        sum_square_not_dist_bf16 = []
         sum_square_not_dist_fp32 = []

         for p, g in params_grads:
@@ -73,14 +75,18 @@ class HybridParallelClipGrad:

             if not_shared_enable:
                 if p.is_distributed:
-                    if p.dtype == paddle.float16:
+                    if g.dtype == paddle.float16:
                         sum_square_dist_fp16.append(sum_square)
-                    elif p.dtype == paddle.float32:
+                    elif g.dtype == paddle.bfloat16:
+                        sum_square_dist_bf16.append(sum_square)
+                    elif g.dtype == paddle.float32:
                         sum_square_dist_fp32.append(sum_square)
                 else:
-                    if p.dtype == paddle.float16:
+                    if g.dtype == paddle.float16:
                         sum_square_not_dist_fp16.append(sum_square)
-                    elif p.dtype == paddle.float32:
+                    elif g.dtype == paddle.bfloat16:
+                        sum_square_not_dist_bf16.append(sum_square)
+                    elif g.dtype == paddle.float32:
                         sum_square_not_dist_fp32.append(sum_square)

         # global norm of distributed FP16 params_and_grads
@@ -107,6 +113,30 @@ class HybridParallelClipGrad:
                 global_norm_not_dist_fp16, dtype=paddle.float32
             )

+        # global norm of distributed BF16 params_and_grads
+        if len(sum_square_dist_bf16) == 0:
+            global_norm_dist_bf16 = paddle.to_tensor(
+                [0.0], dtype=paddle.float32
+            )
+        else:
+            global_norm_dist_bf16 = paddle.concat(sum_square_dist_bf16)
+            global_norm_dist_bf16 = paddle.sum(global_norm_dist_bf16)
+            global_norm_dist_bf16 = paddle.cast(
+                global_norm_dist_bf16, dtype=paddle.float32
+            )
+
+        # global norm of non-distributed BF16 params_and_grads
+        if len(sum_square_not_dist_bf16) == 0:
+            global_norm_not_dist_bf16 = paddle.to_tensor(
+                [0.0], dtype=paddle.float32
+            )
+        else:
+            global_norm_not_dist_bf16 = paddle.concat(sum_square_not_dist_bf16)
+            global_norm_not_dist_bf16 = paddle.sum(global_norm_not_dist_bf16)
+            global_norm_not_dist_bf16 = paddle.cast(
+                global_norm_not_dist_bf16, dtype=paddle.float32
+            )
+
         # global norm of distributed FP32 params_and_grads
         global_norm_dist_fp32 = (
             paddle.concat(sum_square_dist_fp32)
@@ -123,9 +153,15 @@ class HybridParallelClipGrad:
         )
         global_norm_not_dist_fp32 = paddle.sum(global_norm_not_dist_fp32)

-        global_norm_var_dist = global_norm_dist_fp16 + global_norm_dist_fp32
+        global_norm_var_dist = (
+            global_norm_dist_fp16
+            + global_norm_dist_bf16
+            + global_norm_dist_fp32
+        )
         global_norm_var_not_dist = (
-            global_norm_not_dist_fp16 + global_norm_not_dist_fp32
+            global_norm_not_dist_fp16
+            + global_norm_not_dist_bf16
+            + global_norm_not_dist_fp32
         )

         # add all reduce to get global norm of distributed params_and_grads
@@ -160,16 +196,20 @@ class HybridParallelClipGrad:
         )
         clip_var = paddle.divide(
             x=max_global_norm,
-            y=paddle.maximum(x=global_norm_var_fp32, y=max_global_norm),
+            y=paddle.maximum(x=global_norm_var_fp32, y=max_global_norm)
+            + paddle.to_tensor([1.0e-6], dtype=paddle.float32),
         )
         clip_var_fp16 = paddle.cast(clip_var, paddle.float16)
+        clip_var_bf16 = paddle.cast(clip_var, paddle.bfloat16)
         for p, g in params_grads:
             if g is None:
                 continue
             if getattr(p, 'need_clip', True) is False:
                 continue
-            if p.dtype == paddle.float16:
+            if g.dtype == paddle.float16:
                 g.scale_(clip_var_fp16)
+            elif g.dtype == paddle.bfloat16:
+                g.scale_(clip_var_bf16)
             else:
                 g.scale_(clip_var)
             p._reset_grad_inplace_version(True)
@@ -216,6 +256,22 @@ class HybridParallelOptimizer:
                 self._inner_opt._inner_optimizer._grad_clip = (
                     HybridParallelClipGrad(self._inner_opt._grad_clip, hcg)
                 )
+            elif (
+                self._inner_opt._parameter_list
+                and not isinstance(self._inner_opt._parameter_list[0], dict)
+                and len(
+                    [
+                        p
+                        for p in self._inner_opt._parameter_list
+                        if hasattr(p, "main_grad")
+                    ]
+                )
+                > 0
+            ):
+
+                self._inner_opt._inner_opt._grad_clip = HybridParallelClipGrad(
+                    self._inner_opt._inner_opt._grad_clip, hcg
+                )
             else:
                 self._inner_opt._grad_clip = HybridParallelClipGrad(
                     self._inner_opt._grad_clip, hcg
diff --git a/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py b/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py
index b32b71e7180cfb11aeb516c7c5e3676a8a627ee1..86b0ea9eb5b0c9c72e2ab72c3dde56fd59720829 100755
--- a/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py
+++ b/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py
@@ -502,7 +502,10 @@ class PipelineLayer(nn.Layer):
             if framework.in_dygraph_mode():
                 with paddle.framework.no_grad():
                     paddle.distributed.all_reduce(
-                        param.grad, group=comm['group']
+                        param.grad
+                        if not hasattr(param, "main_grad")
+                        else param.main_grad,
+                        group=comm['group'],
                     )
             else:
                 with paddle.framework.no_grad():
diff --git a/python/paddle/distributed/fleet/utils/hybrid_parallel_util.py b/python/paddle/distributed/fleet/utils/hybrid_parallel_util.py
index 8df7e69ea0b6fc98998d38d20c1f188850a4700e..67084ad68b8a41ca674c6b602049a9a5064c9baa 100644
--- a/python/paddle/distributed/fleet/utils/hybrid_parallel_util.py
+++ b/python/paddle/distributed/fleet/utils/hybrid_parallel_util.py
@@ -77,8 +77,13 @@ def _apply_collective_grads_eager(
     grad_vars = []

     for param in parameters:
+        g_var = None
         if param.trainable and (param._grad_ivar() is not None):
             g_var = param._grad_ivar()
+        if param.trainable and hasattr(param, "main_grad"):
+            assert param._grad_ivar() is None, "param.grad is not None"
+            g_var = param.main_grad
+        if g_var is not None:
             assert (
                 not g_var.is_sparse()
             ), "Now, it doesn't support sparse parameters"
diff --git a/python/paddle/distributed/utils/nccl_utils.py b/python/paddle/distributed/utils/nccl_utils.py
index 00f49bc1e99e0df840718eec7ad7bd2d379ac5f7..5aafb6ff5a4bed25e22a844c93d0790001d2e589 100644
--- a/python/paddle/distributed/utils/nccl_utils.py
+++ b/python/paddle/distributed/utils/nccl_utils.py
@@ -49,3 +49,14 @@ def check_nccl_version_for_p2p():
             )
     else:
         logging.warning("No version for NCCL library found!")
+
+
+def check_nccl_version_for_bf16():
+    nccl_version_str = get_nccl_version_str()
+    if nccl_version_str:
+        nccl_version_str = nccl_version_str.replace("\n", "")
+        nccl_version_int = [int(s) for s in nccl_version_str.split(".")]
+        nccl_version_baseline = [2, 10, 0]
+        return nccl_version_int >= nccl_version_baseline
+
+    return False
diff --git a/python/paddle/fluid/tests/unittests/collective/fleet/CMakeLists.txt b/python/paddle/fluid/tests/unittests/collective/fleet/CMakeLists.txt
index 1e46921bdaa28a295aaad82b063f6c0607eba42f..bdb789c813be55714564f76e0b3ab458d4742d2c 100644
--- a/python/paddle/fluid/tests/unittests/collective/fleet/CMakeLists.txt
+++ b/python/paddle/fluid/tests/unittests/collective/fleet/CMakeLists.txt
@@ -2,8 +2,8 @@
 # Please don't modify this file manually.
 # If you need to change unittests in this file, please modify testslist.csv in the current directory
 # and then run the command `python3 ${PADDLE_ROOT}/tools/gen_ut_cmakelists.py -f ${CURRENT_DIRECTORY}/testslist.csv`
-set(LOCAL_ALL_ARCH ON)
 set(LOCAL_ALL_PLAT ON)
+set(LOCAL_ALL_ARCH ON)
 if((WITH_GPU
     OR WITH_XPU
     OR WITH_ASCEND
diff --git a/python/paddle/fluid/tests/unittests/collective/fleet/hybrid_parallel_mp_bf16.py b/python/paddle/fluid/tests/unittests/collective/fleet/hybrid_parallel_mp_bf16.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a06a0326f049628bc207bf417f78747f1b809a4
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/collective/fleet/hybrid_parallel_mp_bf16.py
@@ -0,0 +1,64 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+from hybrid_parallel_mp_model import TestDistMPTraning
+
+import paddle
+import paddle.distributed.fleet as fleet
+from paddle.distributed.utils.nccl_utils import check_nccl_version_for_bf16
+
+
+class TestMPBF16(TestDistMPTraning):
+    def build_optimizer(self, model):
+        grad_clip = paddle.nn.ClipGradByGlobalNorm(1.0)
+        scheduler = paddle.optimizer.lr.ExponentialDecay(
+            learning_rate=0.001, gamma=0.999, verbose=True
+        )
+        optimizer = paddle.optimizer.SGD(
+            scheduler, grad_clip=grad_clip, parameters=model.parameters()
+        )
+
+        model, optimizer = paddle.amp.decorate(
+            models=model,
+            optimizers=optimizer,
+            dtype='bfloat16',
+            level='O2',
+            save_dtype='float32',
+        )
+
+        return optimizer
+
+    def train_batch(self, batch, model, optimizer, is_mp):
+        scaler = paddle.amp.GradScaler(
+            init_loss_scaling=1, use_dynamic_loss_scaling=False
+        )
+        if is_mp:
+            scaler = fleet.distributed_scaler(scaler)
+        with paddle.amp.auto_cast(enable=True, dtype='bfloat16', level="O2"):
+            output = model(batch)
+            loss = output.mean()
+
+        scaled = scaler.scale(loss)
+        scaled.backward()
+        scaler.step(optimizer)
+        scaler.update()
+        optimizer.clear_grad()
+        return scaled
+
+
+if __name__ == "__main__":
+    if check_nccl_version_for_bf16():
+        unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/collective/fleet/hybrid_parallel_pp_bf16.py b/python/paddle/fluid/tests/unittests/collective/fleet/hybrid_parallel_pp_bf16.py
new file mode 100644
index 0000000000000000000000000000000000000000..a996baac9dc401a39523131d78563f732b0166d6
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/collective/fleet/hybrid_parallel_pp_bf16.py
@@ -0,0 +1,169 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import random
+import unittest
+
+import numpy as np
+from hybrid_parallel_pp_layer import AlexNet, AlexNetPipeDesc
+
+import paddle
+import paddle.distributed as dist
+import paddle.distributed.fleet as fleet
+from paddle.distributed.utils.nccl_utils import check_nccl_version_for_bf16
+
+
+def set_random_seed(seed, dp_id, rank_id):
+    """Set random seed for reproducibility."""
+    random.seed(seed)
+    np.random.seed(seed + dp_id)
+    paddle.seed(seed + dp_id)
+
+
+batch_size = 4
+micro_batch_size = 2
+
+
+class TestDistPPTraning(unittest.TestCase):
+    def setUp(self):
+        strategy = fleet.DistributedStrategy()
+        self.model_parallel_size = 1
+        self.data_parallel_size = 1
+        self.pipeline_parallel_size = 2
+        strategy.hybrid_configs = {
+            "dp_degree": self.data_parallel_size,
+            "mp_degree": self.model_parallel_size,
+            "pp_degree": self.pipeline_parallel_size,
+        }
+        strategy.pipeline_configs = {
+            "accumulate_steps": batch_size // micro_batch_size,
+            "micro_batch_size": micro_batch_size,
+        }
+        fleet.init(is_collective=True, strategy=strategy)
+
+    def test_pp_model(self):
+        hcg = fleet.get_hybrid_communicate_group()
+        world_size = hcg.get_model_parallel_world_size()
+        dp_id = hcg.get_data_parallel_rank()
+        pp_id = hcg.get_stage_id()
+        rank_id = dist.get_rank()
+        set_random_seed(1024, dp_id, rank_id)
+
+        grad_clip = paddle.nn.ClipGradByGlobalNorm(1.0)
+
+        # construct model a
+        model_a = AlexNet(10)
+        scheduler_a = paddle.optimizer.lr.PiecewiseDecay(
+            boundaries=[2], values=[0.001, 0.002], verbose=True
+        )
+        optimizer_a = paddle.optimizer.SGD(
+            learning_rate=scheduler_a,
+            grad_clip=grad_clip,
+            parameters=model_a.parameters(),
+        )
+
+        scaler_a = paddle.amp.GradScaler(
+            init_loss_scaling=1, use_dynamic_loss_scaling=False
+        )
+
+        # construct model b
+        model_b = AlexNetPipeDesc(num_stages=self.pipeline_parallel_size)
+        scheduler_b = paddle.optimizer.lr.PiecewiseDecay(
+            boundaries=[2], values=[0.001, 0.002], verbose=True
+        )
+        optimizer_b = paddle.optimizer.SGD(
+            learning_rate=scheduler_b,
+            grad_clip=grad_clip,
+            parameters=model_b.parameters(),
+        )
+
+        param_len = len(model_a.parameters())
+        parameters = []
+        for param in model_a.parameters():
+            parameters.append(param.numpy())
+
+        for idx, param in enumerate(model_b.parameters()):
+            param.set_value(parameters[idx + pp_id * (param_len // 2)])
+
+        model_a, optimizer_a = paddle.amp.decorate(
+            models=model_a,
+            optimizers=optimizer_a,
+            level='O2',
+            dtype='bfloat16',
+            save_dtype='float32',
+        )
+        model_b, optimizer_b = paddle.amp.decorate(
+            models=model_b,
+            optimizers=optimizer_b,
+            level='O2',
+            dtype='bfloat16',
+            save_dtype='float32',
+        )
+
+        model_b = fleet.distributed_model(model_b)
+        optimizer_b = fleet.distributed_optimizer(optimizer_b)
+        scaler_b = paddle.amp.GradScaler(
+            init_loss_scaling=1, use_dynamic_loss_scaling=False
+        )
+        scaler_b = fleet.distributed_scaler(scaler_b)
+
+        # construct reader
+        train_reader = paddle.batch(
+            paddle.dataset.mnist.train(), batch_size=batch_size, drop_last=True
+        )
+
+        for step_id, data in enumerate(train_reader()):
+            x_data = (
+                np.array([x[0] for x in data])
+                .astype('float32')
+                .reshape(batch_size, 1, 28, 28)
+            )
+            y_data = (
+                np.array([x[1] for x in data])
+                .astype('int64')
+                .reshape(batch_size, 1)
+            )
+            img = paddle.to_tensor(x_data)
+            label = paddle.to_tensor(y_data)
+            img.stop_gradient = True
+            label.stop_gradient = True
+
+            if step_id >= 5:
+                return True
+
+            with paddle.amp.auto_cast(
+                enable=True, dtype='bfloat16', level='O2'
+            ):
+                loss_a = model_a(img, label)
+            scaler_a.scale(loss_a).backward()
+            scaler_a.minimize(optimizer_a, loss_a)
+            optimizer_a.clear_grad()
+            scheduler_a.step()
+
+            with paddle.amp.auto_cast(
+                enable=True, dtype='bfloat16', level='O2'
+            ):
+                loss_b = model_b.train_batch(
+                    [img, label], optimizer_b, scheduler_b, scaler=scaler_b
+                )
+
+            print("loss: ", loss_a.numpy(), loss_b.numpy())
+            np.testing.assert_allclose(
+                loss_a.numpy(), loss_b.numpy(), rtol=5e-3
+            )
+
+
+if __name__ == "__main__":
+    if check_nccl_version_for_bf16():
+        unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/collective/fleet/test_parallel_dygraph_pipeline_parallel.py b/python/paddle/fluid/tests/unittests/collective/fleet/test_parallel_dygraph_pipeline_parallel.py
index 4bf669dbe57744c8c51160f47a9d33905db0dd9e..d1a0f5a6d38292ca681253f528740fefa6c3680d 100644
--- a/python/paddle/fluid/tests/unittests/collective/fleet/test_parallel_dygraph_pipeline_parallel.py
+++ b/python/paddle/fluid/tests/unittests/collective/fleet/test_parallel_dygraph_pipeline_parallel.py
@@ -36,6 +36,9 @@ class TestHybridPipeParallel(TestMultipleGpus):
     def test_pipeline_parallel_fp16(self):
         self.run_mnist_2gpu('hybrid_parallel_pp_fp16.py')

+    def test_pipeline_parallel_bf16(self):
+        self.run_mnist_2gpu('hybrid_parallel_pp_bf16.py')
+
     def test_hybrid_parallel_transformer(self):
         self.run_mnist_2gpu('hybrid_parallel_pp_transformer.py')

diff --git a/python/paddle/fluid/tests/unittests/collective/fleet/test_parallel_dygraph_tensor_parallel.py b/python/paddle/fluid/tests/unittests/collective/fleet/test_parallel_dygraph_tensor_parallel.py
index 45235965c2386589a550ada6d0b472c8c29f645b..d6112a0671506926e182a34c1c08f7b6fc71f23c 100644
--- a/python/paddle/fluid/tests/unittests/collective/fleet/test_parallel_dygraph_tensor_parallel.py
+++ b/python/paddle/fluid/tests/unittests/collective/fleet/test_parallel_dygraph_tensor_parallel.py
@@ -30,6 +30,9 @@ class TestHybridParallel(TestMultipleGpus):
     def test_hybrid_parallel_mp_fp16(self):
         self.run_mnist_2gpu('hybrid_parallel_mp_fp16.py')

+    def test_hybrid_parallel_mp_bf16(self):
+        self.run_mnist_2gpu('hybrid_parallel_mp_bf16.py')
+
     def test_hybrid_parallel_mp_clip_grad(self):
         self.run_mnist_2gpu('hybrid_parallel_mp_clip_grad.py')
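Usage sketch (illustrative only, not part of the patch): the BF16 hybrid-parallel path added above is driven roughly as follows; build_model() and data_loader() are hypothetical placeholders, everything else mirrors the new hybrid_parallel_mp_bf16.py test.

    import paddle
    import paddle.distributed.fleet as fleet
    from paddle.distributed.utils.nccl_utils import check_nccl_version_for_bf16

    if check_nccl_version_for_bf16():  # BF16 collectives require NCCL >= 2.10
        strategy = fleet.DistributedStrategy()
        strategy.hybrid_configs = {"dp_degree": 1, "mp_degree": 2, "pp_degree": 1}
        fleet.init(is_collective=True, strategy=strategy)

        model = build_model()  # placeholder: any paddle.nn.Layer
        optimizer = paddle.optimizer.SGD(
            learning_rate=0.001,
            # routed through HybridParallelClipGrad by fleet.distributed_optimizer
            grad_clip=paddle.nn.ClipGradByGlobalNorm(1.0),
            parameters=model.parameters(),
        )
        model, optimizer = paddle.amp.decorate(
            models=model,
            optimizers=optimizer,
            dtype='bfloat16',
            level='O2',
            save_dtype='float32',
        )
        model = fleet.distributed_model(model)
        optimizer = fleet.distributed_optimizer(optimizer)
        scaler = fleet.distributed_scaler(
            paddle.amp.GradScaler(init_loss_scaling=1, use_dynamic_loss_scaling=False)
        )

        for batch in data_loader():  # placeholder iterable of input tensors
            with paddle.amp.auto_cast(enable=True, dtype='bfloat16', level='O2'):
                loss = model(batch).mean()
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            optimizer.clear_grad()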