diff --git a/paddle/fluid/distributed/collective/reducer.cc b/paddle/fluid/distributed/collective/reducer.cc
index 75a16bac37130721a8e882b1585304e4e2d185a7..0d46425b2e83274fb5eb62306e11a0b0a6d7221a 100644
--- a/paddle/fluid/distributed/collective/reducer.cc
+++ b/paddle/fluid/distributed/collective/reducer.cc
@@ -254,6 +254,10 @@ static void ConcatTensorsWithType(
       ConcatTensorsForAllReduce<DeviceContext, double>()(
           context, dense_tensors_, p_dense_contents);
       break;
+    case phi::DataType::BFLOAT16:
+      ConcatTensorsForAllReduce<DeviceContext, platform::bfloat16>()(
+          context, dense_tensors_, p_dense_contents);
+      break;
     default:
       PADDLE_THROW(platform::errors::Unimplemented(
           "Data type (%s) is not supported when it concats tensors for "
@@ -281,6 +285,10 @@ static void SplitTensorsWithType(const DeviceContext &context,
       SplitTensorsForAllReduce<DeviceContext, double>()(
           context, p_dense_contents, p_dense_tensors);
       break;
+    case phi::DataType::BFLOAT16:
+      SplitTensorsForAllReduce<DeviceContext, platform::bfloat16>()(
+          context, p_dense_contents, p_dense_tensors);
+      break;
     default:
       PADDLE_THROW(platform::errors::Unimplemented(
           "Data type (%s) is not supported when it splits tensors for "
diff --git a/paddle/phi/kernels/cpu/fill_grad_kernel.cc b/paddle/phi/kernels/cpu/fill_grad_kernel.cc
index ee676773762ca5987aceb1b007ab2196de792d59..07448c85a57d6054cb82e15d3a64cac2505c0234 100644
--- a/paddle/phi/kernels/cpu/fill_grad_kernel.cc
+++ b/paddle/phi/kernels/cpu/fill_grad_kernel.cc
@@ -26,4 +26,5 @@ PD_REGISTER_KERNEL(fill_grad,
                    int64_t,
                    int,
                    paddle::platform::float16,
+                   paddle::platform::bfloat16,
                    bool) {}
diff --git a/paddle/phi/kernels/cpu/fill_kernel.cc b/paddle/phi/kernels/cpu/fill_kernel.cc
index ee8dac7f6770c40b2192d18016b8bc2582dd6d33..adca39e6ab95d36b1e5c8a90ac4abd2fc4512a26 100644
--- a/paddle/phi/kernels/cpu/fill_kernel.cc
+++ b/paddle/phi/kernels/cpu/fill_kernel.cc
@@ -26,4 +26,5 @@ PD_REGISTER_KERNEL(fill,
                    int64_t,
                    int,
                    paddle::platform::float16,
+                   paddle::platform::bfloat16,
                    bool) {}
diff --git a/paddle/phi/kernels/gpu/fill_grad_kernel.cu b/paddle/phi/kernels/gpu/fill_grad_kernel.cu
index 32559ba95dfbca8abdaf0539182879e3953effca..e18bb5c6dbb2446eae98969075a6da8ecd6fd4bd 100644
--- a/paddle/phi/kernels/gpu/fill_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/fill_grad_kernel.cu
@@ -27,4 +27,5 @@ PD_REGISTER_KERNEL(fill_grad,
                    int64_t,
                    int,
                    paddle::platform::float16,
+                   paddle::platform::bfloat16,
                    bool) {}
diff --git a/paddle/phi/kernels/gpu/fill_kernel.cu b/paddle/phi/kernels/gpu/fill_kernel.cu
index 141e47b8cb109bd3be55311c997f88ec117a6e3c..3fedb4118ff9e111828ac832bd25b04101060be1 100644
--- a/paddle/phi/kernels/gpu/fill_kernel.cu
+++ b/paddle/phi/kernels/gpu/fill_kernel.cu
@@ -27,4 +27,5 @@ PD_REGISTER_KERNEL(fill,
                    int64_t,
                    int,
                    paddle::platform::float16,
+                   paddle::platform::bfloat16,
                    bool) {}
diff --git a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_optimizer_stage2.py b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_optimizer_stage2.py
index 4a5d6c85f855b3e6b5ec97addcd873f6c9e68d96..3f479d073c97eab95b9ade2082a4dbe7b70ae881 100644
--- a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_optimizer_stage2.py
+++ b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_optimizer_stage2.py
@@ -40,6 +40,7 @@ from .group_sharded_utils import Type, device_guard, GroupShardedClipGrad
 alignment = {"gpu": 256, "cpu": 4096}
 align = {
     Type.fp16.value: 2,
+    Type.bf16.value: 2,
     Type.fp32.value: 4,
 }
 
diff --git a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_stage2.py b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_stage2.py
index 573e0b597c8fb1ad527862986be10e87a4f74733..5876df9d3ff74e417dcd2e97ddae4c17a2f9edae 100644
--- a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_stage2.py
+++ b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_stage2.py
@@ -555,6 +555,12 @@ class GroupShardedStage2(nn.Layer):
                 "====== FP16 GradStorage size: {:.2f}M parameters, Model size {:.2f}M parameters ======"
                 .format(rank_buffer_size[Type.fp16.value] / 2**19,
                         model_size / 2**19))
+        if Type.bf16.value in rank_buffer_size.keys():
+            # BF16 GradStorage and model size
+            logger_.info(
+                "====== BF16 GradStorage size: {:.2f}M parameters, Model size {:.2f}M parameters ======"
+                .format(rank_buffer_size[Type.bf16.value] / 2**19,
+                        model_size / 2**19))
         if Type.fp32.value in rank_buffer_size.keys():
             # FP32 GradStorage and model size
             logger_.info(
diff --git a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_storage.py b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_storage.py
index 219090d94672b4f2a455c122593a7a81b7fa7c56..66e9617f00efd22140a1f306cabe406996969418 100644
--- a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_storage.py
+++ b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_storage.py
@@ -51,6 +51,8 @@ class InternalStorage:
                 dtype=np.float16) if Type.fp16.value == dtype else np.zeros(
                     size, dtype=np.float32)
             self.buffer = core.eager.Tensor(value=value, place=core.CPUPlace())
+            if dtype == Type.bf16.value:
+                self.buffer = paddle.cast(self.buffer, dtype=paddle.bfloat16)
         else:
             self.buffer = paddle.zeros(size, dtype=dtype)
 
diff --git a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_utils.py b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_utils.py
index 86ed36799cb8dbd269d524e86a060fa1cfae2f38..0981b2cab76fd43a969a901de302a4d6dc68ffd9 100644
--- a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_utils.py
+++ b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_utils.py
@@ -40,6 +40,7 @@ class Type(Enum):
     Type of trainable parameters
     """
     fp16 = paddle.float16
+    bf16 = paddle.bfloat16
     fp32 = paddle.float32
 
 
diff --git a/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_utils.py b/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_utils.py
index 2303a61cdb3986c07051e15064d944a1469ba84a..0e7725e3e21f875ea1defa657e93198cd72e324b 100644
--- a/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_utils.py
+++ b/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_utils.py
@@ -41,6 +41,7 @@ class Type(Enum):
     Type of trainable parameters
     """
     fp16 = paddle.float16
+    bf16 = paddle.bfloat16
     fp32 = paddle.float32
 
 
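
For reviewers, a minimal smoke-test sketch of the new bfloat16 paths. It is not part of this PR: the script name, model, sizes, and the `Layer.to(dtype='bfloat16')` cast are illustrative assumptions only, and it presumes a multi-GPU launch via `paddle.distributed.launch`. It touches the bf16 `fill`/`fill_grad` registration plus the `Type.bf16` GradStorage and reducer paths through group sharded stage 2.

```python
# Hypothetical bf16 smoke test (illustration only, not part of this PR).
# Launch with: python -m paddle.distributed.launch --gpus=0,1 bf16_stage2_demo.py
import paddle
import paddle.nn as nn
from paddle.distributed.sharding import group_sharded_parallel

paddle.distributed.init_parallel_env()

# fill/fill_grad now register bfloat16, so Tensor.fill_ works on a bf16 tensor.
x = paddle.ones([4, 4], dtype='float32').cast('bfloat16')
x.fill_(0.5)

# Pure-bf16 parameters under stage 2 ("os_g") exercise Type.bf16 alignment,
# the bf16 GradStorage buffer, and the BFLOAT16 concat/split reducer cases.
model = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 1))
model = model.to(dtype='bfloat16')  # assumption: Layer.to can cast params to bf16
opt = paddle.optimizer.AdamW(learning_rate=1e-3,
                             parameters=model.parameters(),
                             multi_precision=True)  # keep fp32 master weights
model, opt, _ = group_sharded_parallel(model, opt, level="os_g")

data = paddle.randn([8, 64]).cast('bfloat16')
loss = model(data).mean()
loss.backward()
opt.step()
opt.clear_grad()
```

If `Layer.to` cannot cast parameters to bf16 on the target Paddle version, `paddle.amp.decorate(models=model, level='O2', dtype='bfloat16')` should serve the same purpose, assuming that overload is available there.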