diff --git a/paddle/fluid/distributed/collective/reducer.cc b/paddle/fluid/distributed/collective/reducer.cc
index 75a16bac37130721a8e882b1585304e4e2d185a7..0d46425b2e83274fb5eb62306e11a0b0a6d7221a 100644
--- a/paddle/fluid/distributed/collective/reducer.cc
+++ b/paddle/fluid/distributed/collective/reducer.cc
@@ -254,6 +254,10 @@ static void ConcatTensorsWithType(
       ConcatTensorsForAllReduce()(
           context, dense_tensors_, p_dense_contents);
       break;
+    case phi::DataType::BFLOAT16:
+      ConcatTensorsForAllReduce<DeviceContext, platform::bfloat16>()(
+          context, dense_tensors_, p_dense_contents);
+      break;
     default:
       PADDLE_THROW(platform::errors::Unimplemented(
           "Data type (%s) is not supported when it concats tensors for "
@@ -281,6 +285,10 @@ static void SplitTensorsWithType(const DeviceContext &context,
       SplitTensorsForAllReduce()(
           context, p_dense_contents, p_dense_tensors);
       break;
+    case phi::DataType::BFLOAT16:
+      SplitTensorsForAllReduce<DeviceContext, platform::bfloat16>()(
+          context, p_dense_contents, p_dense_tensors);
+      break;
     default:
       PADDLE_THROW(platform::errors::Unimplemented(
           "Data type (%s) is not supported when it splits tensors for "
diff --git a/paddle/phi/kernels/cpu/fill_grad_kernel.cc b/paddle/phi/kernels/cpu/fill_grad_kernel.cc
index ee676773762ca5987aceb1b007ab2196de792d59..07448c85a57d6054cb82e15d3a64cac2505c0234 100644
--- a/paddle/phi/kernels/cpu/fill_grad_kernel.cc
+++ b/paddle/phi/kernels/cpu/fill_grad_kernel.cc
@@ -26,4 +26,5 @@ PD_REGISTER_KERNEL(fill_grad,
                    int64_t,
                    int,
                    paddle::platform::float16,
+                   paddle::platform::bfloat16,
                    bool) {}
diff --git a/paddle/phi/kernels/cpu/fill_kernel.cc b/paddle/phi/kernels/cpu/fill_kernel.cc
index ee8dac7f6770c40b2192d18016b8bc2582dd6d33..adca39e6ab95d36b1e5c8a90ac4abd2fc4512a26 100644
--- a/paddle/phi/kernels/cpu/fill_kernel.cc
+++ b/paddle/phi/kernels/cpu/fill_kernel.cc
@@ -26,4 +26,5 @@ PD_REGISTER_KERNEL(fill,
                    int64_t,
                    int,
                    paddle::platform::float16,
+                   paddle::platform::bfloat16,
                    bool) {}
diff --git a/paddle/phi/kernels/gpu/fill_grad_kernel.cu b/paddle/phi/kernels/gpu/fill_grad_kernel.cu
index 32559ba95dfbca8abdaf0539182879e3953effca..e18bb5c6dbb2446eae98969075a6da8ecd6fd4bd 100644
--- a/paddle/phi/kernels/gpu/fill_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/fill_grad_kernel.cu
@@ -27,4 +27,5 @@ PD_REGISTER_KERNEL(fill_grad,
                    int64_t,
                    int,
                    paddle::platform::float16,
+                   paddle::platform::bfloat16,
                    bool) {}
diff --git a/paddle/phi/kernels/gpu/fill_kernel.cu b/paddle/phi/kernels/gpu/fill_kernel.cu
index 141e47b8cb109bd3be55311c997f88ec117a6e3c..3fedb4118ff9e111828ac832bd25b04101060be1 100644
--- a/paddle/phi/kernels/gpu/fill_kernel.cu
+++ b/paddle/phi/kernels/gpu/fill_kernel.cu
@@ -27,4 +27,5 @@ PD_REGISTER_KERNEL(fill,
                    int64_t,
                    int,
                    paddle::platform::float16,
+                   paddle::platform::bfloat16,
                    bool) {}
diff --git a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_optimizer_stage2.py b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_optimizer_stage2.py
index de5743a22668384d4420573a198f330acc1a88e7..beda2401b7573e00659e4d01105e3c0e40fed82a 100644
--- a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_optimizer_stage2.py
+++ b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_optimizer_stage2.py
@@ -43,6 +43,7 @@ from .group_sharded_utils import Type, device_guard, GroupShardedClipGrad
 alignment = {"gpu": 256, "cpu": 4096}
 align = {
     Type.fp16.value: 2,
+    Type.bf16.value: 2,
     Type.fp32.value: 4,
 }
diff --git a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_stage2.py b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_stage2.py
index 4caf2d6013a4f41ac8ef40254fc2620fe9e7c4ba..3f3ab817e91461d26fad680767d62352e4b10c6c 100644
--- a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_stage2.py
+++ b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_stage2.py
@@ -532,6 +532,12 @@ class GroupShardedStage2(nn.Layer):
                 "====== FP16 GradStorage size: {:.2f}M parameters, Model size {:.2f}M parameters ======"
                 .format(rank_buffer_size[Type.fp16.value] / 2**19,
                         model_size / 2**19))
+        if Type.bf16.value in rank_buffer_size.keys():
+            # BF16 GradStorage and model size
+            logger_.info(
+                "====== BF16 GradStorage size: {:.2f}M parameters, Model size {:.2f}M parameters ======"
+                .format(rank_buffer_size[Type.bf16.value] / 2**19,
+                        model_size / 2**19))
         if Type.fp32.value in rank_buffer_size.keys():
             # FP32 GradStorage and model size
             logger_.info(
diff --git a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_storage.py b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_storage.py
index c44872491093ec09c579f3f56ce1177e7f42236f..5b9ab7343f08ca244217160c2f15299d8a72511c 100644
--- a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_storage.py
+++ b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_storage.py
@@ -53,6 +53,8 @@ class InternalStorage:
                 dtype=np.float16) if Type.fp16.value == dtype else np.zeros(
                     size, dtype=np.float32)
             self.buffer = core.eager.Tensor(value=value, place=core.CPUPlace())
+            if dtype == Type.bf16.value:
+                self.buffer = paddle.cast(self.buffer, dtype=paddle.bfloat16)
         else:
             self.buffer = paddle.zeros(size, dtype=dtype)
diff --git a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_utils.py b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_utils.py
index 8cff407363a3b7e730817ee8e093013e6db2c5a7..7eb7b1e8784aa9e912a6963abf29fa287486b9f0 100644
--- a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_utils.py
+++ b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_utils.py
@@ -41,6 +41,7 @@ class Type(Enum):
     Type of trainable parameters
     """
     fp16 = paddle.float16
+    bf16 = paddle.bfloat16
     fp32 = paddle.float32
diff --git a/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_utils.py b/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_utils.py
index d21502bcc16b8853c29699484465ef23bfe9ff2c..42f43ce5377484412bbc69b50d4848fd5f6a9f58 100644
--- a/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_utils.py
+++ b/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_utils.py
@@ -45,6 +45,7 @@ class Type(Enum):
     Type of trainable parameters
     """
     fp16 = paddle.float16
+    bf16 = paddle.bfloat16
     fp32 = paddle.float32
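
For context, a minimal usage sketch of what this patch enables: sharding stage 2 training with bfloat16 parameters and gradients, which drives the new Type.bf16 alignment/GradStorage paths and the bf16 fill/concat/split kernels above. This sketch is not part of the patch; it assumes a Paddle build containing these changes, bf16-capable hardware, and a Paddle version whose AMP APIs accept dtype="bfloat16".

# Hypothetical sketch, not part of the patch.
# Launch with: python -m paddle.distributed.launch --gpus=0,1 this_script.py
import paddle
from paddle.distributed import init_parallel_env
from paddle.distributed.sharding import group_sharded_parallel

init_parallel_env()

model = paddle.nn.Linear(1024, 1024)
optimizer = paddle.optimizer.AdamW(learning_rate=1e-3,
                                   parameters=model.parameters())

# Cast parameters to bf16 (AMP O2), so the sharded optimizer sees
# paddle.bfloat16 params, i.e. Type.bf16 in group_sharded_utils.py.
# The dtype="bfloat16" argument depends on the Paddle version.
model, optimizer = paddle.amp.decorate(models=model,
                                       optimizers=optimizer,
                                       level="O2",
                                       dtype="bfloat16")

# "os_g" shards optimizer state and gradients (stage 2); gradients are
# accumulated into the bf16 GradStorage buffers this patch adds.
model, optimizer, _ = group_sharded_parallel(model, optimizer, level="os_g")

for _ in range(2):
    x = paddle.rand([8, 1024])
    with paddle.amp.auto_cast(level="O2", dtype="bfloat16"):
        loss = model(x).mean()
    loss.backward()
    optimizer.step()
    optimizer.clear_grad()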