From ccbd03d5fb0473f4a35955ebf5ea4d8656551e12 Mon Sep 17 00:00:00 2001
From: sneaxiy <32832641+sneaxiy@users.noreply.github.com>
Date: Thu, 17 Nov 2022 11:06:26 +0800
Subject: [PATCH] Add vectorized bfloat16 atomicAdd (#48056)

* add vectorized bfloat16 atomicAdd

* fix compile error

* fix compile error again

* fix V100 compile error

* fix V100 compile again
---
 .../platform/device/gpu/gpu_primitives.h      |  96 ++++++++++-------
 paddle/phi/backends/gpu/gpu_primitives.h      | 101 ++++++++++--------
 .../unittests/test_bfloat16_embedding.py      |  79 ++++++++++++++
 3 files changed, 192 insertions(+), 84 deletions(-)
 create mode 100644 python/paddle/fluid/tests/unittests/test_bfloat16_embedding.py

diff --git a/paddle/fluid/platform/device/gpu/gpu_primitives.h b/paddle/fluid/platform/device/gpu/gpu_primitives.h
index 96eddf09237..4df203b48bb 100644
--- a/paddle/fluid/platform/device/gpu/gpu_primitives.h
+++ b/paddle/fluid/platform/device/gpu/gpu_primitives.h
@@ -151,47 +151,68 @@ CUDA_ATOMIC_WRAPPER(Add, float16) {
 }
 #endif
 
+template <bool IsAvailable, typename NVType, typename NVVec2Type>
+struct VecAtomicAddHelperBase {
+  static constexpr auto kIsAvailable = IsAvailable;
+  using NVT = NVType;
+  using NVVec2T = NVVec2Type;
+};
+
+template <typename T>
+struct VecAtomicAddHelper : VecAtomicAddHelperBase<false, void, void> {};
+
+#if CUDA_VERSION >= 10000 && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
+template <>
+struct VecAtomicAddHelper<float16>
+    : VecAtomicAddHelperBase<true, __half, __half2> {};
+#endif
+
+#if CUDA_VERSION >= 11000 && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+template <>
+struct VecAtomicAddHelper<bfloat16>
+    : VecAtomicAddHelperBase<true, __nv_bfloat16, __nv_bfloat162> {};
+#endif
+
 // The performance of "atomicAdd(half* )" is bad, but for "atomicAdd(half2* )"
 // is good. So for fp16 type, we can use "atomicAdd(half2* )" to speed up.
 template <typename T,
-          typename std::enable_if<std::is_same<float16, T>::value>::type * = nullptr>
+          typename std::enable_if<VecAtomicAddHelper<T>::kIsAvailable>::type * =
+              nullptr>
 __device__ __forceinline__ void fastAtomicAdd(T *tensor,
                                               size_t index,
                                               const size_t numel,
                                               T value) {
-#if ((CUDA_VERSION < 10000) || \
-     (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 700)))
-  CudaAtomicAdd(reinterpret_cast<float16 *>(tensor) + index,
-                static_cast<float16>(value));
-#else
   // whether the address is 32-byte aligned.
-  __half *target_addr = reinterpret_cast<__half *>(tensor + index);
+  using NVT = typename VecAtomicAddHelper<T>::NVT;
+  using NVVec2T = typename VecAtomicAddHelper<T>::NVVec2T;
+  NVT *target_addr = reinterpret_cast<NVT *>(tensor + index);
   bool aligned_half2 =
-      (reinterpret_cast<std::uintptr_t>(target_addr) % sizeof(__half2) == 0);
+      (reinterpret_cast<std::uintptr_t>(target_addr) % sizeof(NVVec2T) == 0);
   if (aligned_half2 && index < (numel - 1)) {
-    __half2 value2;
-    value2.x = *reinterpret_cast<__half *>(&value);
-    value2.y = __int2half_rz(0);
-    atomicAdd(reinterpret_cast<__half2 *>(target_addr), value2);
+    NVVec2T value2;
+    value2.x = *reinterpret_cast<NVT *>(&value);
+    value2.y = 0.0;
+    atomicAdd(reinterpret_cast<NVVec2T *>(target_addr), value2);
   } else if (!aligned_half2 && index > 0) {
-    __half2 value2;
-    value2.x = __int2half_rz(0);
-    value2.y = *reinterpret_cast<__half *>(&value);
-    atomicAdd(reinterpret_cast<__half2 *>(target_addr - 1), value2);
+    NVVec2T value2;
+    value2.x = 0.0;
+    value2.y = *reinterpret_cast<NVT *>(&value);
+    atomicAdd(reinterpret_cast<NVVec2T *>(target_addr - 1), value2);
   } else {
-    atomicAdd(reinterpret_cast<__half *>(tensor) + index,
-              *reinterpret_cast<__half *>(&value));
+    atomicAdd(reinterpret_cast<NVT *>(tensor) + index,
+              *reinterpret_cast<NVT *>(&value));
   }
-#endif
 }
 
 template <typename T,
-          typename std::enable_if<!std::is_same<float16, T>::value>::type * = nullptr>
+          typename std::enable_if<!VecAtomicAddHelper<T>::kIsAvailable>::type
+              * = nullptr>
 __device__ __forceinline__ void fastAtomicAdd(T *arr,
                                               size_t index,
                                               const size_t numel,
@@ -546,16 +567,16 @@ CUDA_ATOMIC_WRAPPER(Min, float16) {
 }
 #endif
 
-#ifdef PADDLE_CUDA_FP16
 #ifdef PADDLE_WITH_CUDA
 /*
  * One thead block deals with elementwise atomicAdd for vector of len.
  * @in: [x1, x2, x3, ...]
  * @out:[y1+x1, y2+x2, y3+x3, ...]
  * */
+
 template <typename T,
-          typename std::enable_if<!std::is_same<float16, T>::value>::type * = nullptr>
+          typename std::enable_if<!VecAtomicAddHelper<T>::kIsAvailable>::type
+              * = nullptr>
 __device__ __forceinline__ void VectorizedAtomicAddPerBlock(
     const int64_t len, int tid, int threads_per_block, const T *in, T *out) {
   for (int i = tid; i < len; i += threads_per_block) {
@@ -565,30 +586,26 @@ __device__ __forceinline__ void VectorizedAtomicAddPerBlock(
 
 // Note: assume that len is even. If len is odd, call fastAtomicAdd directly.
 template <typename T,
-          typename std::enable_if<std::is_same<float16, T>::value>::type * = nullptr>
+          typename std::enable_if<VecAtomicAddHelper<T>::kIsAvailable>::type * =
+              nullptr>
 __device__ __forceinline__ void VectorizedAtomicAddPerBlock(
     const int64_t len, int tid, int threads_per_block, const T *in, T *out) {
-#if ((CUDA_VERSION < 10000) || \
-     (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 700)))
-  for (int i = tid; i < len; i += threads_per_block) {
-    CudaAtomicAdd(&out[i], in[i]);
-  }
-#else
   int i = 0;
   int loops = len / 2 * 2;
+  using NVT = typename VecAtomicAddHelper<T>::NVT;
+  using NVVec2T = typename VecAtomicAddHelper<T>::NVVec2T;
   bool aligned_half2 =
-      (reinterpret_cast<std::uintptr_t>(out) % sizeof(__half2) == 0);
+      (reinterpret_cast<std::uintptr_t>(out) % sizeof(NVT) == 0);
   if (aligned_half2) {
     for (i = tid * 2; i < loops; i += threads_per_block * 2) {
-      __half2 value2;
+      NVVec2T value2;
       T value_1 = in[i];
       T value_2 = in[i + 1];
-      value2.x = *reinterpret_cast<__half *>(&value_1);
-      value2.y = *reinterpret_cast<__half *>(&value_2);
-      atomicAdd(reinterpret_cast<__half2 *>(&out[i]), value2);
+      value2.x = *reinterpret_cast<NVT *>(&value_1);
+      value2.y = *reinterpret_cast<NVT *>(&value_2);
+      atomicAdd(reinterpret_cast<NVVec2T *>(&out[i]), value2);
     }
     for (; i < len; i += threads_per_block) {
       fastAtomicAdd(out, i, len, in[i]);
@@ -598,9 +615,8 @@ __device__ __forceinline__ void VectorizedAtomicAddPerBlock(
       fastAtomicAdd(out, i, len, in[i]);
     }
   }
-#endif
 }
-#endif
+
 #endif
 }  // namespace platform
 }  // namespace paddle
diff --git a/paddle/phi/backends/gpu/gpu_primitives.h b/paddle/phi/backends/gpu/gpu_primitives.h
index be08f29aa81..12f58257cf0 100644
--- a/paddle/phi/backends/gpu/gpu_primitives.h
+++ b/paddle/phi/backends/gpu/gpu_primitives.h
@@ -156,47 +156,65 @@ CUDA_ATOMIC_WRAPPER(Add, float16) {
 }
 #endif
 
+template <bool IsAvailable, typename NVType, typename NVVec2Type>
+struct VecAtomicAddHelperBase {
+  static constexpr auto kIsAvailable = IsAvailable;
+  using NVT = NVType;
+  using NVVec2T = NVVec2Type;
+};
+
+template <typename T>
+struct VecAtomicAddHelper : VecAtomicAddHelperBase<false, void, void> {};
+
+#if CUDA_VERSION >= 10000 && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
+template <>
+struct VecAtomicAddHelper<float16>
+    : VecAtomicAddHelperBase<true, __half, __half2> {};
+#endif
+
+#if CUDA_VERSION >= 11000 && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+template <>
+struct VecAtomicAddHelper<bfloat16>
+    : VecAtomicAddHelperBase<true, __nv_bfloat16, __nv_bfloat162> {};
+#endif
+
 // The performance of "atomicAdd(half* )" is bad, but for "atomicAdd(half2* )"
 // is good. So for fp16 type, we can use "atomicAdd(half2* )" to speed up.
-template <
-    typename T,
-    typename std::enable_if<std::is_same<float16, T>::value>::type * = nullptr>
+template <typename T,
+          typename std::enable_if<VecAtomicAddHelper<T>::kIsAvailable>::type * =
+              nullptr>
 __device__ __forceinline__ void fastAtomicAdd(T *tensor,
                                               size_t index,
                                               const size_t numel,
                                               T value) {
-#if ((CUDA_VERSION < 10000) || \
-     (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 700)))
-  CudaAtomicAdd(reinterpret_cast<float16 *>(tensor) + index,
-                static_cast<float16>(value));
-#else
   // whether the address is 32-byte aligned.
-  __half *target_addr = reinterpret_cast<__half *>(tensor + index);
+  using NVT = typename VecAtomicAddHelper<T>::NVT;
+  using NVVec2T = typename VecAtomicAddHelper<T>::NVVec2T;
+  NVT *target_addr = reinterpret_cast<NVT *>(tensor + index);
   bool aligned_half2 =
-      (reinterpret_cast<std::uintptr_t>(target_addr) % sizeof(__half2) == 0);
+      (reinterpret_cast<std::uintptr_t>(target_addr) % sizeof(NVVec2T) == 0);
   if (aligned_half2 && index < (numel - 1)) {
-    __half2 value2;
-    value2.x = *reinterpret_cast<__half *>(&value);
-    value2.y = __int2half_rz(0);
-    atomicAdd(reinterpret_cast<__half2 *>(target_addr), value2);
+    NVVec2T value2;
+    value2.x = *reinterpret_cast<NVT *>(&value);
+    value2.y = 0.0;
+    atomicAdd(reinterpret_cast<NVVec2T *>(target_addr), value2);
   } else if (!aligned_half2 && index > 0) {
-    __half2 value2;
-    value2.x = __int2half_rz(0);
-    value2.y = *reinterpret_cast<__half *>(&value);
-    atomicAdd(reinterpret_cast<__half2 *>(target_addr - 1), value2);
+    NVVec2T value2;
+    value2.x = 0.0;
+    value2.y = *reinterpret_cast<NVT *>(&value);
+    atomicAdd(reinterpret_cast<NVVec2T *>(target_addr - 1), value2);
   } else {
-    atomicAdd(reinterpret_cast<__half *>(tensor) + index,
-              *reinterpret_cast<__half *>(&value));
+    atomicAdd(reinterpret_cast<NVT *>(tensor) + index,
+              *reinterpret_cast<NVT *>(&value));
   }
-#endif
 }
 
-template <
-    typename T,
-    typename std::enable_if<!std::is_same<float16, T>::value>::type * = nullptr>
+template <typename T,
+          typename std::enable_if<!VecAtomicAddHelper<T>::kIsAvailable>::type
+              * = nullptr>
 __device__ __forceinline__ void fastAtomicAdd(T *arr,
                                               size_t index,
                                               const size_t numel,
@@ -551,16 +569,16 @@ CUDA_ATOMIC_WRAPPER(Min, float16) {
 }
 #endif
 
-#ifdef PADDLE_CUDA_FP16
 #ifdef PADDLE_WITH_CUDA
 /*
  * One thead block deals with elementwise atomicAdd for vector of len.
  * @in: [x1, x2, x3, ...]
  * @out:[y1+x1, y2+x2, y3+x3, ...]
  * */
-template <
-    typename T,
-    typename std::enable_if<!std::is_same<float16, T>::value>::type * = nullptr>
+
+template <typename T,
+          typename std::enable_if<!VecAtomicAddHelper<T>::kIsAvailable>::type
+              * = nullptr>
 __device__ __forceinline__ void VectorizedAtomicAddPerBlock(
     const int64_t len, int tid, int threads_per_block, const T *in, T *out) {
   for (int i = tid; i < len; i += threads_per_block) {
@@ -569,31 +587,27 @@ __device__ __forceinline__ void VectorizedAtomicAddPerBlock(
 }
 
 // Note: assume that len is even. If len is odd, call fastAtomicAdd directly.
-template <
-    typename T,
-    typename std::enable_if<std::is_same<float16, T>::value>::type * = nullptr>
+template <typename T,
+          typename std::enable_if<VecAtomicAddHelper<T>::kIsAvailable>::type * =
+              nullptr>
 __device__ __forceinline__ void VectorizedAtomicAddPerBlock(
     const int64_t len, int tid, int threads_per_block, const T *in, T *out) {
-#if ((CUDA_VERSION < 10000) || \
-     (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 700)))
-  for (int i = tid; i < len; i += threads_per_block) {
-    CudaAtomicAdd(&out[i], in[i]);
-  }
-#else
   int i = 0;
   int loops = len / 2 * 2;
+  using NVT = typename VecAtomicAddHelper<T>::NVT;
+  using NVVec2T = typename VecAtomicAddHelper<T>::NVVec2T;
   bool aligned_half2 =
-      (reinterpret_cast<std::uintptr_t>(out) % sizeof(__half2) == 0);
+      (reinterpret_cast<std::uintptr_t>(out) % sizeof(NVT) == 0);
   if (aligned_half2) {
     for (i = tid * 2; i < loops; i += threads_per_block * 2) {
-      __half2 value2;
+      NVVec2T value2;
       T value_1 = in[i];
       T value_2 = in[i + 1];
-      value2.x = *reinterpret_cast<__half *>(&value_1);
-      value2.y = *reinterpret_cast<__half *>(&value_2);
-      atomicAdd(reinterpret_cast<__half2 *>(&out[i]), value2);
+      value2.x = *reinterpret_cast<NVT *>(&value_1);
+      value2.y = *reinterpret_cast<NVT *>(&value_2);
+      atomicAdd(reinterpret_cast<NVVec2T *>(&out[i]), value2);
     }
     for (; i < len; i += threads_per_block) {
       fastAtomicAdd(out, i, len, in[i]);
@@ -603,8 +617,7 @@ __device__ __forceinline__ void VectorizedAtomicAddPerBlock(
      fastAtomicAdd(out, i, len, in[i]);
     }
   }
-#endif
 }
-#endif
+
 #endif
 }  // namespace phi
diff --git a/python/paddle/fluid/tests/unittests/test_bfloat16_embedding.py b/python/paddle/fluid/tests/unittests/test_bfloat16_embedding.py
new file mode 100644
index 00000000000..e86c45cf541
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_bfloat16_embedding.py
@@ -0,0 +1,79 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle
+import numpy as np
+import unittest
+import paddle.nn.functional as F
+from test_sparse_attention_op import get_cuda_version
+
+
+class BF16EmbeddingTest(unittest.TestCase):
+    def setUp(self):
+        self.batch_size = 30
+        self.vocab_size = 1024
+        self.hidden_size = 512
+        self.seed = 10
+
+    def run_main(self, dtype):
+        ids, weight, dout = self.gen_random()
+        origin_dtype = weight.dtype
+        weight_cast = weight.astype(dtype)
+        out = F.embedding(ids, weight_cast)
+        dout = dout.astype(out.dtype)
+        dweight = paddle.autograd.grad(out, weight, dout)
+        return (
+            out.astype(origin_dtype).numpy(),
+            dweight[0].astype(origin_dtype).numpy(),
+        )
+
+    def gen_random(self):
+        np.random.seed(self.seed)
+        weight = np.random.random([self.vocab_size, self.hidden_size]).astype(
+            'float32'
+        )
+        ids = np.random.randint(
+            low=0, high=self.vocab_size, size=[self.batch_size]
+        )
+        dout = np.random.random([self.batch_size, self.hidden_size]).astype(
+            'float32'
+        )
+
+        weight = paddle.to_tensor(weight)
+        weight.stop_gradient = False
+        ids = paddle.to_tensor(ids)
+        dout = paddle.to_tensor(dout)
+        return ids, weight, dout
+
+    def test_main(self):
+        if not paddle.is_compiled_with_cuda() or get_cuda_version() < 11000:
+            return
+
+        ret1 = self.run_main('float32')
+        ret2 = self.run_main('bfloat16')
+        self.assertEqual(len(ret1), len(ret2))
+        for i, (r1, r2) in enumerate(zip(ret1, ret2)):
+            np.testing.assert_allclose(r1, r2, atol=1e-3, rtol=1e-2)
+
+
+class BF16EmbeddingTestOddHiddenSize(BF16EmbeddingTest):
+    def setUp(self):
+        self.batch_size = 30
+        self.vocab_size = 511
+        self.hidden_size = 512
+        self.seed = 20
+
+
+if __name__ == "__main__":
+    unittest.main()
-- 
GitLab
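
For context on the underlying trick, the sketch below is a minimal, self-contained CUDA illustration of the packed atomicAdd pattern that this patch generalizes from __half/__half2 to __nv_bfloat16/__nv_bfloat162 through VecAtomicAddHelper. It is not taken from the patch itself: the function and kernel names (packed_half_atomic_add, scatter_add_kernel) are illustrative only, and it assumes an sm_70+ device built with CUDA 10 or newer.

// Illustrative sketch (not from the patch): emulate a scalar __half atomic add
// with a packed __half2 atomicAdd whenever 4-byte alignment allows it.
#include <cstdint>
#include <cuda_fp16.h>

__device__ void packed_half_atomic_add(__half *data,
                                       size_t index,
                                       size_t numel,
                                       __half value) {
  __half *target = data + index;
  bool aligned =
      (reinterpret_cast<std::uintptr_t>(target) % sizeof(__half2) == 0);
  if (aligned && index + 1 < numel) {
    // target starts a __half2 pair: add (value, 0) in one 32-bit atomic.
    __half2 packed = __halves2half2(value, __float2half(0.0f));
    atomicAdd(reinterpret_cast<__half2 *>(target), packed);
  } else if (!aligned && index > 0) {
    // target is the high half of the pair starting at target - 1: add (0, value).
    __half2 packed = __halves2half2(__float2half(0.0f), value);
    atomicAdd(reinterpret_cast<__half2 *>(target - 1), packed);
  } else {
    // Array boundary: fall back to the scalar __half atomicAdd (sm_70+).
    atomicAdd(target, value);
  }
}

// Example use: scatter-add values into out, as an embedding-gradient kernel would.
__global__ void scatter_add_kernel(__half *out,
                                   const int *rows,
                                   const __half *src,
                                   size_t n,
                                   size_t out_numel) {
  size_t i = blockIdx.x * static_cast<size_t>(blockDim.x) + threadIdx.x;
  if (i < n) {
    packed_half_atomic_add(out, rows[i], out_numel, src[i]);
  }
}

The bfloat16 case follows the same shape, with __nv_bfloat162 from <cuda_bf16.h> in place of __half2; packed __nv_bfloat162 atomicAdd needs CUDA 11 and compute capability 8.0, which is why the patch guards that specialization behind __CUDA_ARCH__ >= 800.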