From 0b39b244f1567f5fb8dc89e888ded57f5daf792c Mon Sep 17 00:00:00 2001
From: Ghost Screaming
Date: Mon, 17 Oct 2022 22:17:36 +0800
Subject: [PATCH] Support BF16 training for sharding (#46846)

* Fix bug of reduce_sum op. When input.numel() > INT32_MAX, its result is wrong.
* support pure bfloat16
* support bf16 linear
* update PR to pass CI
* tiny fix where_grad_kernel.cu
* Support bfloat16 type for reducer and sharding.
* Fix some bug.
* Polish code.
* Polise code.
* Add bfloat16 datatype in fill_grad kernels.

Co-authored-by: sneaxiy
---
 paddle/fluid/distributed/collective/reducer.cc             | 8 ++++++++
 paddle/phi/kernels/cpu/fill_grad_kernel.cc                 | 1 +
 paddle/phi/kernels/cpu/fill_kernel.cc                      | 1 +
 paddle/phi/kernels/gpu/fill_grad_kernel.cu                 | 1 +
 paddle/phi/kernels/gpu/fill_kernel.cu                      | 1 +
 .../sharding/group_sharded_optimizer_stage2.py             | 1 +
 .../fleet/meta_parallel/sharding/group_sharded_stage2.py   | 6 ++++++
 .../fleet/meta_parallel/sharding/group_sharded_storage.py  | 2 ++
 .../fleet/meta_parallel/sharding/group_sharded_utils.py    | 1 +
 .../fleet/meta_parallel/sharding/sharding_utils.py         | 1 +
 10 files changed, 23 insertions(+)

diff --git a/paddle/fluid/distributed/collective/reducer.cc b/paddle/fluid/distributed/collective/reducer.cc
index 75a16bac371..0d46425b2e8 100644
--- a/paddle/fluid/distributed/collective/reducer.cc
+++ b/paddle/fluid/distributed/collective/reducer.cc
@@ -254,6 +254,10 @@ static void ConcatTensorsWithType(
       ConcatTensorsForAllReduce()(
           context, dense_tensors_, p_dense_contents);
       break;
+    case phi::DataType::BFLOAT16:
+      ConcatTensorsForAllReduce<DeviceContext, platform::bfloat16>()(
+          context, dense_tensors_, p_dense_contents);
+      break;
     default:
       PADDLE_THROW(platform::errors::Unimplemented(
           "Data type (%s) is not supported when it concats tensors for "
@@ -281,6 +285,10 @@ static void SplitTensorsWithType(const DeviceContext &context,
       SplitTensorsForAllReduce()(
           context, p_dense_contents, p_dense_tensors);
       break;
+    case phi::DataType::BFLOAT16:
+      SplitTensorsForAllReduce<DeviceContext, platform::bfloat16>()(
+          context, p_dense_contents, p_dense_tensors);
+      break;
     default:
       PADDLE_THROW(platform::errors::Unimplemented(
           "Data type (%s) is not supported when it splits tensors for "
diff --git a/paddle/phi/kernels/cpu/fill_grad_kernel.cc b/paddle/phi/kernels/cpu/fill_grad_kernel.cc
index ee676773762..07448c85a57 100644
--- a/paddle/phi/kernels/cpu/fill_grad_kernel.cc
+++ b/paddle/phi/kernels/cpu/fill_grad_kernel.cc
@@ -26,4 +26,5 @@ PD_REGISTER_KERNEL(fill_grad,
                    int64_t,
                    int,
                    paddle::platform::float16,
+                   paddle::platform::bfloat16,
                    bool) {}
diff --git a/paddle/phi/kernels/cpu/fill_kernel.cc b/paddle/phi/kernels/cpu/fill_kernel.cc
index ee8dac7f677..adca39e6ab9 100644
--- a/paddle/phi/kernels/cpu/fill_kernel.cc
+++ b/paddle/phi/kernels/cpu/fill_kernel.cc
@@ -26,4 +26,5 @@ PD_REGISTER_KERNEL(fill,
                    int64_t,
                    int,
                    paddle::platform::float16,
+                   paddle::platform::bfloat16,
                    bool) {}
diff --git a/paddle/phi/kernels/gpu/fill_grad_kernel.cu b/paddle/phi/kernels/gpu/fill_grad_kernel.cu
index 32559ba95df..e18bb5c6dbb 100644
--- a/paddle/phi/kernels/gpu/fill_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/fill_grad_kernel.cu
@@ -27,4 +27,5 @@ PD_REGISTER_KERNEL(fill_grad,
                    int64_t,
                    int,
                    paddle::platform::float16,
+                   paddle::platform::bfloat16,
                    bool) {}
diff --git a/paddle/phi/kernels/gpu/fill_kernel.cu b/paddle/phi/kernels/gpu/fill_kernel.cu
index 141e47b8cb1..3fedb4118ff 100644
--- a/paddle/phi/kernels/gpu/fill_kernel.cu
+++ b/paddle/phi/kernels/gpu/fill_kernel.cu
@@ -27,4 +27,5 @@ PD_REGISTER_KERNEL(fill,
                    int64_t,
                    int,
                    paddle::platform::float16,
+                   paddle::platform::bfloat16,
                    bool) {}
diff --git a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_optimizer_stage2.py b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_optimizer_stage2.py
index 4a5d6c85f85..3f479d073c9 100644
--- a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_optimizer_stage2.py
+++ b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_optimizer_stage2.py
@@ -40,6 +40,7 @@ from .group_sharded_utils import Type, device_guard, GroupShardedClipGrad
 alignment = {"gpu": 256, "cpu": 4096}
 align = {
     Type.fp16.value: 2,
+    Type.bf16.value: 2,
     Type.fp32.value: 4,
 }
diff --git a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_stage2.py b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_stage2.py
index 573e0b597c8..5876df9d3ff 100644
--- a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_stage2.py
+++ b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_stage2.py
@@ -555,6 +555,12 @@ class GroupShardedStage2(nn.Layer):
                 "====== FP16 GradStorage size: {:.2f}M parameters, Model size {:.2f}M parameters ======"
                 .format(rank_buffer_size[Type.fp16.value] / 2**19,
                         model_size / 2**19))
+        if Type.bf16.value in rank_buffer_size.keys():
+            # BF16 GradStorage and model size
+            logger_.info(
+                "====== BF16 GradStorage size: {:.2f}M parameters, Model size {:.2f}M parameters ======"
+                .format(rank_buffer_size[Type.bf16.value] / 2**19,
+                        model_size / 2**19))
         if Type.fp32.value in rank_buffer_size.keys():
             # FP32 GradStorage and model size
             logger_.info(
diff --git a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_storage.py b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_storage.py
index 219090d9467..66e9617f00e 100644
--- a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_storage.py
+++ b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_storage.py
@@ -51,6 +51,8 @@ class InternalStorage:
                 dtype=np.float16) if Type.fp16.value == dtype else np.zeros(
                     size, dtype=np.float32)
             self.buffer = core.eager.Tensor(value=value, place=core.CPUPlace())
+            if dtype == Type.bf16.value:
+                self.buffer = paddle.cast(self.buffer, dtype=paddle.bfloat16)
         else:
             self.buffer = paddle.zeros(size, dtype=dtype)
diff --git a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_utils.py b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_utils.py
index 86ed36799cb..0981b2cab76 100644
--- a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_utils.py
+++ b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_utils.py
@@ -40,6 +40,7 @@ class Type(Enum):
     Type of trainable parameters
     """
     fp16 = paddle.float16
+    bf16 = paddle.bfloat16
     fp32 = paddle.float32
diff --git a/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_utils.py b/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_utils.py
index 2303a61cdb3..0e7725e3e21 100644
--- a/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_utils.py
+++ b/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_utils.py
@@ -41,6 +41,7 @@ class Type(Enum):
     Type of trainable parameters
     """
     fp16 = paddle.float16
+    bf16 = paddle.bfloat16
     fp32 = paddle.float32
-- 
GitLab
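For readers who want to see what this series enables end to end, the following is a minimal, illustrative sketch of BF16 training under group sharding stage 2. It is not part of the patch: it assumes a Paddle 2.4-style API surface (paddle.distributed.sharding.group_sharded_parallel, plus the dtype='bfloat16' arguments of paddle.amp.decorate and paddle.amp.auto_cast) and a made-up toy model, so names and arguments should be checked against the installed release.

# Illustrative only: BF16 + group-sharded (stage 2) training sketch.
# Assumes Paddle >= 2.4 with this patch applied; the availability of
# dtype='bfloat16' on paddle.amp.decorate / paddle.amp.auto_cast and of
# paddle.distributed.sharding.group_sharded_parallel is an assumption here.
import paddle
import paddle.distributed as dist
from paddle.distributed.sharding import group_sharded_parallel


def main():
    dist.init_parallel_env()

    # Hypothetical toy model and optimizer, just to exercise the path.
    model = paddle.nn.Sequential(
        paddle.nn.Linear(1024, 4096),
        paddle.nn.ReLU(),
        paddle.nn.Linear(4096, 1024),
    )
    optimizer = paddle.optimizer.AdamW(learning_rate=1e-4,
                                       parameters=model.parameters())

    # Pure BF16: cast parameters via AMP O2 decoration (assumed API).
    model, optimizer = paddle.amp.decorate(models=model,
                                           optimizers=optimizer,
                                           level='O2',
                                           dtype='bfloat16')

    # Stage 2: shard optimizer states and gradients across ranks.
    # BF16 needs no loss scaling, so no GradScaler is passed.
    model, optimizer, _ = group_sharded_parallel(model=model,
                                                 optimizer=optimizer,
                                                 level='os_g')

    for _ in range(10):
        data = paddle.randn([8, 1024])
        with paddle.amp.auto_cast(level='O2', dtype='bfloat16'):
            loss = model(data).mean()
        loss.backward()
        optimizer.step()
        optimizer.clear_grad()


if __name__ == '__main__':
    main()

In practice this would be launched with python -m paddle.distributed.launch so that each GPU runs one rank; the reducer and InternalStorage changes above are what allow the sharded gradient buffers to be allocated and all-reduced in bfloat16.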