From 5c85f1a7d621f6b588a87abec38649b478975187 Mon Sep 17 00:00:00 2001
From: Ghost Screaming
Date: Mon, 24 Oct 2022 11:30:41 +0800
Subject: [PATCH] Support BF16 training for sharding (#46846) (#47246)

* Fix bug of reduce_sum op. When input.numel() > INT32_MAX, its result is wrong.

* support pure bfloat16

* support bf16 linear

* update PR to pass CI

* tiny fix where_grad_kernel.cu

* Support bfloat16 type for reducer and sharding.

* Fix some bug.

* Polish code.

* Polise code.

* Add bfloat16 datatype in fill_grad kernels.

Co-authored-by: sneaxiy
Co-authored-by: sneaxiy
---
 paddle/fluid/distributed/collective/reducer.cc             | 8 ++++++++
 paddle/phi/kernels/cpu/fill_grad_kernel.cc                 | 1 +
 paddle/phi/kernels/cpu/fill_kernel.cc                      | 1 +
 paddle/phi/kernels/gpu/fill_grad_kernel.cu                 | 1 +
 paddle/phi/kernels/gpu/fill_kernel.cu                      | 1 +
 .../sharding/group_sharded_optimizer_stage2.py             | 1 +
 .../fleet/meta_parallel/sharding/group_sharded_stage2.py   | 6 ++++++
 .../fleet/meta_parallel/sharding/group_sharded_storage.py  | 2 ++
 .../fleet/meta_parallel/sharding/group_sharded_utils.py    | 1 +
 .../fleet/meta_parallel/sharding/sharding_utils.py         | 1 +
 10 files changed, 23 insertions(+)

diff --git a/paddle/fluid/distributed/collective/reducer.cc b/paddle/fluid/distributed/collective/reducer.cc
index 75a16bac37..0d46425b2e 100644
--- a/paddle/fluid/distributed/collective/reducer.cc
+++ b/paddle/fluid/distributed/collective/reducer.cc
@@ -254,6 +254,10 @@ static void ConcatTensorsWithType(
       ConcatTensorsForAllReduce<DeviceContext, platform::float16>()(
           context, dense_tensors_, p_dense_contents);
       break;
+    case phi::DataType::BFLOAT16:
+      ConcatTensorsForAllReduce<DeviceContext, platform::bfloat16>()(
+          context, dense_tensors_, p_dense_contents);
+      break;
     default:
       PADDLE_THROW(platform::errors::Unimplemented(
           "Data type (%s) is not supported when it concats tensors for "
@@ -281,6 +285,10 @@ static void SplitTensorsWithType(const DeviceContext &context,
       SplitTensorsForAllReduce<DeviceContext, platform::float16>()(
           context, p_dense_contents, p_dense_tensors);
       break;
+    case phi::DataType::BFLOAT16:
+      SplitTensorsForAllReduce<DeviceContext, platform::bfloat16>()(
+          context, p_dense_contents, p_dense_tensors);
+      break;
     default:
       PADDLE_THROW(platform::errors::Unimplemented(
           "Data type (%s) is not supported when it splits tensors for "
diff --git a/paddle/phi/kernels/cpu/fill_grad_kernel.cc b/paddle/phi/kernels/cpu/fill_grad_kernel.cc
index ee67677376..07448c85a5 100644
--- a/paddle/phi/kernels/cpu/fill_grad_kernel.cc
+++ b/paddle/phi/kernels/cpu/fill_grad_kernel.cc
@@ -26,4 +26,5 @@ PD_REGISTER_KERNEL(fill_grad,
                    int64_t,
                    int,
                    paddle::platform::float16,
+                   paddle::platform::bfloat16,
                    bool) {}
diff --git a/paddle/phi/kernels/cpu/fill_kernel.cc b/paddle/phi/kernels/cpu/fill_kernel.cc
index ee8dac7f67..adca39e6ab 100644
--- a/paddle/phi/kernels/cpu/fill_kernel.cc
+++ b/paddle/phi/kernels/cpu/fill_kernel.cc
@@ -26,4 +26,5 @@ PD_REGISTER_KERNEL(fill,
                    int64_t,
                    int,
                    paddle::platform::float16,
+                   paddle::platform::bfloat16,
                    bool) {}
diff --git a/paddle/phi/kernels/gpu/fill_grad_kernel.cu b/paddle/phi/kernels/gpu/fill_grad_kernel.cu
index 32559ba95d..e18bb5c6db 100644
--- a/paddle/phi/kernels/gpu/fill_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/fill_grad_kernel.cu
@@ -27,4 +27,5 @@ PD_REGISTER_KERNEL(fill_grad,
                    int64_t,
                    int,
                    paddle::platform::float16,
+                   paddle::platform::bfloat16,
                    bool) {}
diff --git a/paddle/phi/kernels/gpu/fill_kernel.cu b/paddle/phi/kernels/gpu/fill_kernel.cu
index 141e47b8cb..3fedb4118f 100644
--- a/paddle/phi/kernels/gpu/fill_kernel.cu
+++ b/paddle/phi/kernels/gpu/fill_kernel.cu
@@ -27,4 +27,5 @@ PD_REGISTER_KERNEL(fill,
                    int64_t,
                    int,
                    paddle::platform::float16,
+                   paddle::platform::bfloat16,
                    bool) {}
diff --git a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_optimizer_stage2.py b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_optimizer_stage2.py
index de5743a226..beda2401b7 100644
--- a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_optimizer_stage2.py
+++ b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_optimizer_stage2.py
@@ -43,6 +43,7 @@ from .group_sharded_utils import Type, device_guard, GroupShardedClipGrad
 alignment = {"gpu": 256, "cpu": 4096}
 align = {
     Type.fp16.value: 2,
+    Type.bf16.value: 2,
     Type.fp32.value: 4,
 }
 
diff --git a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_stage2.py b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_stage2.py
index 4caf2d6013..3f3ab817e9 100644
--- a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_stage2.py
+++ b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_stage2.py
@@ -532,6 +532,12 @@ class GroupShardedStage2(nn.Layer):
                     "====== FP16 GradStorage size: {:.2f}M parameters, Model size {:.2f}M parameters ======"
                     .format(rank_buffer_size[Type.fp16.value] / 2**19,
                             model_size / 2**19))
+            if Type.bf16.value in rank_buffer_size.keys():
+                # BF16 GradStorage and model size
+                logger_.info(
+                    "====== BF16 GradStorage size: {:.2f}M parameters, Model size {:.2f}M parameters ======"
+                    .format(rank_buffer_size[Type.bf16.value] / 2**19,
+                            model_size / 2**19))
             if Type.fp32.value in rank_buffer_size.keys():
                 # FP32 GradStorage and model size
                 logger_.info(
diff --git a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_storage.py b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_storage.py
index c448724910..5b9ab7343f 100644
--- a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_storage.py
+++ b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_storage.py
@@ -53,6 +53,8 @@ class InternalStorage:
                 dtype=np.float16) if Type.fp16.value == dtype else np.zeros(
                     size, dtype=np.float32)
             self.buffer = core.eager.Tensor(value=value, place=core.CPUPlace())
+            if dtype == Type.bf16.value:
+                self.buffer = paddle.cast(self.buffer, dtype=paddle.bfloat16)
         else:
             self.buffer = paddle.zeros(size, dtype=dtype)
 
diff --git a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_utils.py b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_utils.py
index 8cff407363..7eb7b1e878 100644
--- a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_utils.py
+++ b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_utils.py
@@ -41,6 +41,7 @@ class Type(Enum):
     Type of trainable parameters
     """
     fp16 = paddle.float16
+    bf16 = paddle.bfloat16
     fp32 = paddle.float32
 
 
diff --git a/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_utils.py b/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_utils.py
index d21502bcc1..42f43ce537 100644
--- a/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_utils.py
+++ b/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_utils.py
@@ -45,6 +45,7 @@ class Type(Enum):
     Type of trainable parameters
     """
     fp16 = paddle.float16
+    bf16 = paddle.bfloat16
     fp32 = paddle.float32
 
 
-- 
GitLab
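
Taken together, the patch registers bfloat16 in the fill/fill_grad kernels, teaches the reducer to concat/split bf16 gradient buffers, and adds a `Type.bf16` entry to the sharding utilities — the pieces a pure-BF16 group-sharded training loop needs. Below is a minimal sketch of such a loop. It is not part of the patch and assumes Paddle 2.4-era dynamic-graph APIs; in particular, the `dtype='bfloat16'` arguments to `paddle.amp.decorate`/`auto_cast` and the `level='os_g'` stage-2 setting of `group_sharded_parallel` are assumptions to verify against the installed release.

```python
# Illustrative sketch only (not part of the patch). Assumes Paddle 2.4-era
# dygraph APIs; dtype='bfloat16' and level='os_g' are assumptions to be
# verified against the installed release.
import paddle
import paddle.distributed as dist
from paddle.distributed.sharding import group_sharded_parallel

dist.init_parallel_env()

model = paddle.nn.Linear(1024, 1024)
optimizer = paddle.optimizer.AdamW(learning_rate=1e-3,
                                   parameters=model.parameters())

# Pure-BF16 decoration: parameters are held in bfloat16, master weights in fp32.
model, optimizer = paddle.amp.decorate(models=model,
                                       optimizers=optimizer,
                                       level='O2',
                                       dtype='bfloat16')

# Stage-2 sharding ('os_g'): optimizer states and gradients are partitioned
# across ranks; bf16 gradients flow through the reducer paths patched above.
model, optimizer, _ = group_sharded_parallel(model, optimizer, level='os_g')

for step in range(10):
    x = paddle.randn([16, 1024], dtype='float32')
    with paddle.amp.auto_cast(level='O2', dtype='bfloat16'):
        loss = model(x).mean()
    loss.backward()
    optimizer.step()
    optimizer.clear_grad()
```

Because bfloat16 keeps float32's exponent range, the sketch omits `paddle.amp.GradScaler`; loss scaling is normally only needed on the fp16 path.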