未验证 提交 2673798d 编写于 作者: D dzhwinter 提交者: GitHub

"fix float16 ShuffleDownSync Bug" (#12756)

* "fix bug"

* "add test case"
上级 6fe5547d
...@@ -36,7 +36,7 @@ __forceinline__ __device__ T CudaShuffleDownSync(unsigned mask, T val, ...@@ -36,7 +36,7 @@ __forceinline__ __device__ T CudaShuffleDownSync(unsigned mask, T val,
#if CUDA_VERSION < 9000 #if CUDA_VERSION < 9000
return __shfl_down(val, delta, width); return __shfl_down(val, delta, width);
#else #else
return __shfl_down_sync(mask, val, delta, width); return __shfl_down_sync(mask, val, static_cast<unsigned>(delta), width);
#endif #endif
} }
...@@ -46,9 +46,16 @@ template <> ...@@ -46,9 +46,16 @@ template <>
__forceinline__ __device__ float16 CudaShuffleDownSync(unsigned mask, __forceinline__ __device__ float16 CudaShuffleDownSync(unsigned mask,
float16 val, int delta, float16 val, int delta,
int width) { int width) {
half tmp = static_cast<half>(val); return float16(
__shfl_down(tmp, static_cast<unsigned>(delta), width); __shfl_down(static_cast<half>(val), static_cast<unsigned>(delta), width));
return float16(tmp); }
#else
template <>
__forceinline__ __device__ float16 CudaShuffleDownSync(unsigned mask,
float16 val, int delta,
int width) {
return float16(__shfl_down_sync(mask, static_cast<half>(val),
static_cast<unsigned>(delta), width));
} }
#endif #endif
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
// limitations under the License. // limitations under the License.
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <algorithm>
#include <iostream> #include <iostream>
#include <random> #include <random>
...@@ -123,7 +124,7 @@ void TestUnalign(size_t num, const int shift_bit) { ...@@ -123,7 +124,7 @@ void TestUnalign(size_t num, const int shift_bit) {
cudaMemcpy(out, d_in2, array_size, cudaMemcpyDeviceToHost); cudaMemcpy(out, d_in2, array_size, cudaMemcpyDeviceToHost);
cudaDeviceSynchronize(); cudaDeviceSynchronize();
for (size_t i = 0; i < num / 2; ++i) { for (size_t i = 0; i < num / 2; ++i) {
// NOTE(dzhwinter): the float16 add has small underflow/overflow // NOTE(dzhwinter): the float16 add has small truncate error.
// so we use EXPECT_NEAR to check the result. // so we use EXPECT_NEAR to check the result.
EXPECT_NEAR(static_cast<float>(out[i]), EXPECT_NEAR(static_cast<float>(out[i]),
static_cast<float>(AddFunctor<float16>()(r_in1[i], r_in2[i])), static_cast<float>(AddFunctor<float16>()(r_in1[i], r_in2[i])),
...@@ -151,3 +152,83 @@ TEST(CudaAtomic, float16Unalign) { ...@@ -151,3 +152,83 @@ TEST(CudaAtomic, float16Unalign) {
TestUnalign(static_cast<size_t>(1024), /*shift_bit*/ 3); TestUnalign(static_cast<size_t>(1024), /*shift_bit*/ 3);
TestUnalign(static_cast<size_t>(1024 * 1024), /*shift_bit*/ 3); TestUnalign(static_cast<size_t>(1024 * 1024), /*shift_bit*/ 3);
} }
// https://devblogs.nvidia.com/faster-parallel-reductions-kepler/
template <typename T>
static __forceinline__ __device__ T WarpReduceSum(T val) {
unsigned mask = 0u;
CREATE_SHFL_MASK(mask, true);
for (int offset = warpSize / 2; offset > 0; offset /= 2) {
val += paddle::platform::CudaShuffleDownSync(mask, val, offset);
}
return val;
}
template <typename T>
__forceinline__ __device__ T BlockReduce(T val) {
static __shared__ T shared[32]; // Shared mem for 32 partial sums
int lane = threadIdx.x % warpSize;
int wid = threadIdx.x / warpSize;
val = WarpReduceSum(val); // Each warp performs partial reduction
if (lane == 0) shared[wid] = val; // Write reduced value to shared memory
__syncthreads(); // Wait for all partial reductions
// read from shared memory only if that warp existed
val =
(threadIdx.x < blockDim.x / warpSize) ? shared[lane] : static_cast<T>(0);
if (wid == 0) val = WarpReduceSum(val); // Final reduce within first warp
return val;
}
template <typename T>
__global__ void DeviceReduceSum(T* in, T* out, size_t N) {
T sum(0);
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
i += blockDim.x * gridDim.x) {
sum += in[i];
}
sum = BlockReduce<T>(sum);
__syncthreads();
if (threadIdx.x == 0) out[blockIdx.x] = sum;
}
template <typename T>
void TestReduce(size_t num, float atol = 0.01) {
T* in1;
T *d_in1, *d_in2;
size_t size = sizeof(T) * num;
cudaMalloc(reinterpret_cast<void**>(&d_in1), size);
cudaMalloc(reinterpret_cast<void**>(&d_in2), sizeof(T));
in1 = reinterpret_cast<T*>(malloc(size));
std::minstd_rand engine;
std::uniform_real_distribution<double> dist(0.0, 1.0);
for (size_t i = 0; i < num; ++i) {
in1[i] = static_cast<T>(dist(engine));
}
auto out = std::accumulate(in1, in1 + num, static_cast<T>(0));
cudaMemcpy(d_in1, in1, size, cudaMemcpyHostToDevice);
cudaDeviceSynchronize();
DeviceReduceSum<T><<<1, PADDLE_CUDA_NUM_THREADS>>>(d_in1, d_in2, num);
cudaMemcpy(in1, d_in2, sizeof(T), cudaMemcpyDeviceToHost);
cudaDeviceSynchronize();
// NOTE(dzhwinter): the float16 add has small underflow/overflow
// so we use EXPECT_NEAR to check the result.
EXPECT_NEAR(static_cast<float>(in1[0]), static_cast<float>(out), atol);
free(in1);
cudaFree(d_in1);
cudaFree(d_in2);
}
TEST(CudaShuffleSync, float16) {
TestReduce<float>(10);
TestReduce<float>(1000);
// float16 will overflow or accumulate truncate errors in big size.
TestReduce<float16>(10);
TestReduce<float16>(100, /*atol error*/ 1.0);
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册