Unverified commit ccbd03d5, authored by sneaxiy, committed by GitHub

Add vectorized bfloat16 atomicAdd (#48056)

* add vectorized bfloat16 atomicAdd

* fix compile error

* fix compile error again

* fix V100 compile error

* fix V100 compile again
Parent 33d81aa4
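
The first diff below generalizes the fp16-only fastAtomicAdd / VectorizedAtomicAddPerBlock overloads: a VecAtomicAddHelper trait now selects the vectorized path for any type with a native 2-element CUDA vector type, which adds bfloat16 on CUDA 11.0+ and compute capability 8.0+ devices. As a hedged usage sketch (the kernel, header path, and phi:: qualification are illustrative assumptions; only the fastAtomicAdd signature is taken from the diff):

// Hypothetical sketch (not part of this commit): a scatter-add kernel, e.g. an
// embedding-gradient update, calling the generalized fastAtomicAdd. For
// float16/bfloat16 on capable devices this resolves to the vectorized
// __half2 / __nv_bfloat162 path; otherwise it uses the scalar fallback overload.
#include "paddle/phi/backends/gpu/gpu_primitives.h"  // assumed header path

template <typename T>
__global__ void ScatterAddRows(const int64_t *ids,  // [num_ids]
                               const T *grad_out,   // [num_ids, width]
                               T *grad_weight,      // [num_rows, width]
                               size_t num_ids,
                               size_t width,
                               size_t numel) {      // num_rows * width
  size_t i = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (i >= num_ids * width) return;
  size_t row = i / width;
  size_t col = i % width;
  size_t dst = static_cast<size_t>(ids[row]) * width + col;
  phi::fastAtomicAdd(grad_weight, dst, numel, grad_out[i]);
}
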
@@ -151,47 +151,68 @@ CUDA_ATOMIC_WRAPPER(Add, float16) {
}
#endif
template <typename T, bool IsAvailable, typename NVType, typename NVVec2Type>
struct VecAtomicAddHelperBase {
static constexpr auto kIsAvailable = IsAvailable;
using NVT = NVType;
using NVVec2T = NVVec2Type;
};
template <typename T>
struct VecAtomicAddHelper : VecAtomicAddHelperBase<T, false, void, void> {};
#if CUDA_VERSION >= 10000 && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
template <>
struct VecAtomicAddHelper<platform::float16>
: VecAtomicAddHelperBase<platform::float16, true, __half, __half2> {};
#endif
#if CUDA_VERSION >= 11000 && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
template <>
struct VecAtomicAddHelper<platform::bfloat16>
: VecAtomicAddHelperBase<platform::bfloat16,
true,
__nv_bfloat16,
__nv_bfloat162> {};
#endif
// The performance of "atomicAdd(half *)" is poor, while "atomicAdd(half2 *)"
// performs well. So for fp16 (and now bf16 via the helper above), we use the
// vectorized 2-element atomicAdd to speed things up.
template <typename T,
typename std::enable_if<
std::is_same<platform::float16, T>::value>::type * = nullptr>
typename std::enable_if<VecAtomicAddHelper<T>::kIsAvailable>::type * =
nullptr>
__device__ __forceinline__ void fastAtomicAdd(T *tensor,
size_t index,
const size_t numel,
T value) {
#if ((CUDA_VERSION < 10000) || \
(defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 700)))
CudaAtomicAdd(reinterpret_cast<platform::float16 *>(tensor) + index,
static_cast<platform::float16>(value));
#else
// whether the address is aligned to the 2-element vector type (4 bytes).
__half *target_addr = reinterpret_cast<__half *>(tensor + index);
using NVT = typename VecAtomicAddHelper<T>::NVT;
using NVVec2T = typename VecAtomicAddHelper<T>::NVVec2T;
NVT *target_addr = reinterpret_cast<NVT *>(tensor + index);
bool aligned_half2 =
(reinterpret_cast<std::uintptr_t>(target_addr) % sizeof(__half2) == 0);
(reinterpret_cast<std::uintptr_t>(target_addr) % sizeof(NVVec2T) == 0);
if (aligned_half2 && index < (numel - 1)) {
__half2 value2;
value2.x = *reinterpret_cast<__half *>(&value);
value2.y = __int2half_rz(0);
atomicAdd(reinterpret_cast<__half2 *>(target_addr), value2);
NVVec2T value2;
value2.x = *reinterpret_cast<NVT *>(&value);
value2.y = 0.0;
atomicAdd(reinterpret_cast<NVVec2T *>(target_addr), value2);
} else if (!aligned_half2 && index > 0) {
__half2 value2;
value2.x = __int2half_rz(0);
value2.y = *reinterpret_cast<__half *>(&value);
atomicAdd(reinterpret_cast<__half2 *>(target_addr - 1), value2);
NVVec2T value2;
value2.x = 0.0;
value2.y = *reinterpret_cast<NVT *>(&value);
atomicAdd(reinterpret_cast<NVVec2T *>(target_addr - 1), value2);
} else {
atomicAdd(reinterpret_cast<__half *>(tensor) + index,
*reinterpret_cast<__half *>(&value));
atomicAdd(reinterpret_cast<NVT *>(tensor) + index,
*reinterpret_cast<NVT *>(&value));
}
#endif
}
template <typename T,
typename std::enable_if<
!std::is_same<platform::float16, T>::value>::type * = nullptr>
typename std::enable_if<!VecAtomicAddHelper<T>::kIsAvailable>::type
* = nullptr>
__device__ __forceinline__ void fastAtomicAdd(T *arr,
size_t index,
const size_t numel,
@@ -546,16 +567,16 @@ CUDA_ATOMIC_WRAPPER(Min, float16) {
}
#endif
#ifdef PADDLE_CUDA_FP16
#ifdef PADDLE_WITH_CUDA
/*
* One thread block performs an element-wise atomicAdd over a vector of length len.
* @in: [x1, x2, x3, ...]
* @out:[y1+x1, y2+x2, y3+x3, ...]
*/
template <typename T,
typename std::enable_if<
!std::is_same<platform::float16, T>::value>::type * = nullptr>
typename std::enable_if<!VecAtomicAddHelper<T>::kIsAvailable>::type
* = nullptr>
__device__ __forceinline__ void VectorizedAtomicAddPerBlock(
const int64_t len, int tid, int threads_per_block, const T *in, T *out) {
for (int i = tid; i < len; i += threads_per_block) {
@@ -565,30 +586,26 @@ __device__ __forceinline__ void VectorizedAtomicAddPerBlock(
// Note: assume that len is even. If len is odd, call fastAtomicAdd directly.
template <typename T,
typename std::enable_if<
std::is_same<platform::float16, T>::value>::type * = nullptr>
typename std::enable_if<VecAtomicAddHelper<T>::kIsAvailable>::type * =
nullptr>
__device__ __forceinline__ void VectorizedAtomicAddPerBlock(
const int64_t len, int tid, int threads_per_block, const T *in, T *out) {
#if ((CUDA_VERSION < 10000) || \
(defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 700)))
for (int i = tid; i < len; i += threads_per_block) {
CudaAtomicAdd(&out[i], in[i]);
}
#else
int i = 0;
int loops = len / 2 * 2;
using NVT = typename VecAtomicAddHelper<T>::NVT;
using NVVec2T = typename VecAtomicAddHelper<T>::NVVec2T;
bool aligned_half2 =
(reinterpret_cast<std::uintptr_t>(out) % sizeof(__half2) == 0);
(reinterpret_cast<std::uintptr_t>(out) % sizeof(NVT) == 0);
if (aligned_half2) {
for (i = tid * 2; i < loops; i += threads_per_block * 2) {
__half2 value2;
NVVec2T value2;
T value_1 = in[i];
T value_2 = in[i + 1];
value2.x = *reinterpret_cast<__half *>(&value_1);
value2.y = *reinterpret_cast<__half *>(&value_2);
atomicAdd(reinterpret_cast<__half2 *>(&out[i]), value2);
value2.x = *reinterpret_cast<NVT *>(&value_1);
value2.y = *reinterpret_cast<NVT *>(&value_2);
atomicAdd(reinterpret_cast<NVVec2T *>(&out[i]), value2);
}
for (; i < len; i += threads_per_block) {
fastAtomicAdd(out, i, len, in[i]);
@@ -598,9 +615,8 @@ __device__ __forceinline__ void VectorizedAtomicAddPerBlock(
fastAtomicAdd(out, i, len, in[i]);
}
}
#endif
}
#endif
#endif
} // namespace platform
} // namespace paddle
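
The block-wide VectorizedAtomicAddPerBlock is generalized the same way. A minimal sketch of how one thread block might accumulate an input row into a shared output vector with it (kernel name, launch shape, and header path are illustrative assumptions; the signature and the even-length note come from the hunk above):

// Hypothetical sketch: each thread block accumulates one row of `in` into the
// shared vector `out` via the block-wide vectorized atomic add. Per the note
// in the header, the vectorized path assumes an even `len`; for odd lengths,
// call fastAtomicAdd element by element instead.
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"  // assumed path

template <typename T>
__global__ void BlockRowAccumulate(const T *in,  // [rows, len]
                                   T *out,       // [len]
                                   int64_t len) {
  const T *row = in + static_cast<int64_t>(blockIdx.x) * len;
  paddle::platform::VectorizedAtomicAddPerBlock(
      len, static_cast<int>(threadIdx.x), static_cast<int>(blockDim.x), row, out);
}
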
@@ -156,47 +156,65 @@ CUDA_ATOMIC_WRAPPER(Add, float16) {
}
#endif
template <typename T, bool IsAvailable, typename NVType, typename NVVec2Type>
struct VecAtomicAddHelperBase {
static constexpr auto kIsAvailable = IsAvailable;
using NVT = NVType;
using NVVec2T = NVVec2Type;
};
template <typename T>
struct VecAtomicAddHelper : VecAtomicAddHelperBase<T, false, void, void> {};
#if CUDA_VERSION >= 10000 && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
template <>
struct VecAtomicAddHelper<float16>
: VecAtomicAddHelperBase<float16, true, __half, __half2> {};
#endif
#if CUDA_VERSION >= 11000 && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
template <>
struct VecAtomicAddHelper<bfloat16>
: VecAtomicAddHelperBase<bfloat16, true, __nv_bfloat16, __nv_bfloat162> {};
#endif
// The performance of "atomicAdd(half *)" is poor, while "atomicAdd(half2 *)"
// performs well. So for fp16 (and now bf16 via the helper above), we use the
// vectorized 2-element atomicAdd to speed things up.
template <
typename T,
typename std::enable_if<std::is_same<float16, T>::value>::type * = nullptr>
template <typename T,
typename std::enable_if<VecAtomicAddHelper<T>::kIsAvailable>::type * =
nullptr>
__device__ __forceinline__ void fastAtomicAdd(T *tensor,
size_t index,
const size_t numel,
T value) {
#if ((CUDA_VERSION < 10000) || \
(defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 700)))
CudaAtomicAdd(reinterpret_cast<float16 *>(tensor) + index,
static_cast<float16>(value));
#else
// whether the address is aligned to the 2-element vector type (4 bytes).
__half *target_addr = reinterpret_cast<__half *>(tensor + index);
using NVT = typename VecAtomicAddHelper<T>::NVT;
using NVVec2T = typename VecAtomicAddHelper<T>::NVVec2T;
NVT *target_addr = reinterpret_cast<NVT *>(tensor + index);
bool aligned_half2 =
(reinterpret_cast<std::uintptr_t>(target_addr) % sizeof(__half2) == 0);
(reinterpret_cast<std::uintptr_t>(target_addr) % sizeof(NVVec2T) == 0);
if (aligned_half2 && index < (numel - 1)) {
__half2 value2;
value2.x = *reinterpret_cast<__half *>(&value);
value2.y = __int2half_rz(0);
atomicAdd(reinterpret_cast<__half2 *>(target_addr), value2);
NVVec2T value2;
value2.x = *reinterpret_cast<NVT *>(&value);
value2.y = 0.0;
atomicAdd(reinterpret_cast<NVVec2T *>(target_addr), value2);
} else if (!aligned_half2 && index > 0) {
__half2 value2;
value2.x = __int2half_rz(0);
value2.y = *reinterpret_cast<__half *>(&value);
atomicAdd(reinterpret_cast<__half2 *>(target_addr - 1), value2);
NVVec2T value2;
value2.x = 0.0;
value2.y = *reinterpret_cast<NVT *>(&value);
atomicAdd(reinterpret_cast<NVVec2T *>(target_addr - 1), value2);
} else {
atomicAdd(reinterpret_cast<__half *>(tensor) + index,
*reinterpret_cast<__half *>(&value));
atomicAdd(reinterpret_cast<NVT *>(tensor) + index,
*reinterpret_cast<NVT *>(&value));
}
#endif
}
template <
typename T,
typename std::enable_if<!std::is_same<float16, T>::value>::type * = nullptr>
template <typename T,
typename std::enable_if<!VecAtomicAddHelper<T>::kIsAvailable>::type
* = nullptr>
__device__ __forceinline__ void fastAtomicAdd(T *arr,
size_t index,
const size_t numel,
@@ -551,16 +569,16 @@ CUDA_ATOMIC_WRAPPER(Min, float16) {
}
#endif
#ifdef PADDLE_CUDA_FP16
#ifdef PADDLE_WITH_CUDA
/*
* One thread block performs an element-wise atomicAdd over a vector of length len.
* @in: [x1, x2, x3, ...]
* @out:[y1+x1, y2+x2, y3+x3, ...]
*/
template <
typename T,
typename std::enable_if<!std::is_same<float16, T>::value>::type * = nullptr>
template <typename T,
typename std::enable_if<!VecAtomicAddHelper<T>::kIsAvailable>::type
* = nullptr>
__device__ __forceinline__ void VectorizedAtomicAddPerBlock(
const int64_t len, int tid, int threads_per_block, const T *in, T *out) {
for (int i = tid; i < len; i += threads_per_block) {
@@ -569,31 +587,27 @@ __device__ __forceinline__ void VectorizedAtomicAddPerBlock(
}
// Note: assume that len is even. If len is odd, call fastAtomicAdd directly.
template <
typename T,
typename std::enable_if<std::is_same<float16, T>::value>::type * = nullptr>
template <typename T,
typename std::enable_if<VecAtomicAddHelper<T>::kIsAvailable>::type * =
nullptr>
__device__ __forceinline__ void VectorizedAtomicAddPerBlock(
const int64_t len, int tid, int threads_per_block, const T *in, T *out) {
#if ((CUDA_VERSION < 10000) || \
(defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 700)))
for (int i = tid; i < len; i += threads_per_block) {
CudaAtomicAdd(&out[i], in[i]);
}
#else
int i = 0;
int loops = len / 2 * 2;
using NVT = typename VecAtomicAddHelper<T>::NVT;
using NVVec2T = typename VecAtomicAddHelper<T>::NVVec2T;
bool aligned_half2 =
(reinterpret_cast<std::uintptr_t>(out) % sizeof(__half2) == 0);
(reinterpret_cast<std::uintptr_t>(out) % sizeof(NVT) == 0);
if (aligned_half2) {
for (i = tid * 2; i < loops; i += threads_per_block * 2) {
__half2 value2;
NVVec2T value2;
T value_1 = in[i];
T value_2 = in[i + 1];
value2.x = *reinterpret_cast<__half *>(&value_1);
value2.y = *reinterpret_cast<__half *>(&value_2);
atomicAdd(reinterpret_cast<__half2 *>(&out[i]), value2);
value2.x = *reinterpret_cast<NVT *>(&value_1);
value2.y = *reinterpret_cast<NVT *>(&value_2);
atomicAdd(reinterpret_cast<NVVec2T *>(&out[i]), value2);
}
for (; i < len; i += threads_per_block) {
fastAtomicAdd(out, i, len, in[i]);
@@ -603,8 +617,7 @@ __device__ __forceinline__ void VectorizedAtomicAddPerBlock(
fastAtomicAdd(out, i, len, in[i]);
}
}
#endif
}
#endif
#endif
} // namespace phi
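
For reference, the primitive the new bfloat16 specialization maps to is the native __nv_bfloat162 overload of atomicAdd, which requires CUDA 11.0+ and compute capability 8.0+; that is why the specialization is guarded, and why, per the commit message, V100 (sm_70) builds fall back to the non-vectorized overload. A self-contained sketch of that primitive, independent of the Paddle wrappers (kernel and variable names are illustrative):

#include <cuda_bf16.h>

// Each thread adds a packed pair of bfloat16 values into a 4-byte-aligned
// accumulator. The __nv_bfloat162 atomicAdd overload requires CUDA >= 11.0
// and __CUDA_ARCH__ >= 800; on older architectures this body compiles away.
__global__ void PackedBf16AtomicAdd(__nv_bfloat162 *acc, float a, float b) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
  __nv_bfloat162 v = __floats2bfloat162_rn(a, b);
  atomicAdd(acc, v);
#endif
}
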
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import numpy as np
import unittest
import paddle.nn.functional as F
from test_sparse_attention_op import get_cuda_version
class BF16EmbeddingTest(unittest.TestCase):
def setUp(self):
self.batch_size = 30
self.vocab_size = 1024
self.hidden_size = 512
self.seed = 10
def run_main(self, dtype):
ids, weight, dout = self.gen_random()
origin_dtype = weight.dtype
weight_cast = weight.astype(dtype)
out = F.embedding(ids, weight_cast)
dout = dout.astype(out.dtype)
dweight = paddle.autograd.grad(out, weight, dout)
return (
out.astype(origin_dtype).numpy(),
dweight[0].astype(origin_dtype).numpy(),
)
def gen_random(self):
np.random.seed(self.seed)
weight = np.random.random([self.vocab_size, self.hidden_size]).astype(
'float32'
)
ids = np.random.randint(
low=0, high=self.vocab_size, size=[self.batch_size]
)
dout = np.random.random([self.batch_size, self.hidden_size]).astype(
'float32'
)
weight = paddle.to_tensor(weight)
weight.stop_gradient = False
ids = paddle.to_tensor(ids)
dout = paddle.to_tensor(dout)
return ids, weight, dout
def test_main(self):
if not paddle.is_compiled_with_cuda() or get_cuda_version() < 11000:
return
ret1 = self.run_main('float32')
ret2 = self.run_main('bfloat16')
self.assertEqual(len(ret1), len(ret2))
for i, (r1, r2) in enumerate(zip(ret1, ret2)):
np.testing.assert_allclose(r1, r2, atol=1e-3, rtol=1e-2)
class BF16EmbeddingTestOddHiddenSize(BF16EmbeddingTest):
def setUp(self):
self.batch_size = 30
self.vocab_size = 511
self.hidden_size = 512
self.seed = 20
if __name__ == "__main__":
unittest.main()