// cuda_helper_test.cu
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <gtest/gtest.h>

#include <algorithm>
#include <iostream>
#include <numeric>  // std::accumulate in TestReduce, needed on every platform
#include <random>
#include <vector>

#define PADDLE_CUDA_FP16
#include "paddle/fluid/platform/cuda_device_function.h"
#include "paddle/fluid/platform/cuda_primitives.h"
#include "paddle/fluid/platform/float16.h"

#include "paddle/fluid/platform/cuda_helper.h"

using paddle::platform::PADDLE_CUDA_NUM_THREADS;
using paddle::platform::float16;

// Device half of the atomic-add tests: for every index produced by
// CUDA_KERNEL_LOOP, adds data_a[idx] into data_b[idx] through
// paddle::platform::CudaAtomicAdd. data_b is updated in place; data_a is
// only read.
template <typename T>
__global__ void AddKernel(const T* data_a, T* data_b, size_t num) {
  CUDA_KERNEL_LOOP(idx, num) {
    T* target = data_b + idx;
    paddle::platform::CudaAtomicAdd(target, data_a[idx]);
  }
}

// Host-side reference used to check the device results: returns a + b
// computed with T's own operator+, so for float16 the expected value carries
// the same precision as the element type. The call operator is const so the
// functor can also be used through const objects/references.
template <typename T>
struct AddFunctor {
  T operator()(const T& a, const T& b) const { return a + b; }
};

// Checks paddle::platform::CudaAtomicAdd for element type T.
//
// Generates `num` pairs of pseudo-random values in [0, 1) on the host and runs
// AddKernel on one block of PADDLE_CUDA_NUM_THREADS threads, so that
// d_in2[i] += d_in1[i]. It then copies d_in2 back and compares every element
// with the host-side sum from AddFunctor<T>.
//
// Every runtime call is checked, so an allocation, copy or launch failure is
// reported where it happens rather than as a wall of value mismatches. Host
// buffers are std::vector, so a failing ASSERT_* (which returns early from
// this function) does not leak them. The device buffers can still leak on
// that path, but only after the test has already failed.
template <typename T>
void TestCase(size_t num) {
  std::vector<T> in1(num), in2(num), out(num);
  T *d_in1 = nullptr, *d_in2 = nullptr;
  const size_t size = sizeof(T) * num;

#ifdef PADDLE_WITH_HIP
  ASSERT_EQ(hipSuccess, hipMalloc(reinterpret_cast<void**>(&d_in1), size));
  ASSERT_EQ(hipSuccess, hipMalloc(reinterpret_cast<void**>(&d_in2), size));
#else
  ASSERT_EQ(cudaSuccess, cudaMalloc(reinterpret_cast<void**>(&d_in1), size));
  ASSERT_EQ(cudaSuccess, cudaMalloc(reinterpret_cast<void**>(&d_in2), size));
#endif

  std::minstd_rand engine;
  std::uniform_real_distribution<double> dist(0.0, 1.0);
  for (size_t i = 0; i < num; ++i) {
    in1[i] = static_cast<T>(dist(engine));
    in2[i] = static_cast<T>(dist(engine));
  }

#ifdef PADDLE_WITH_HIP
  ASSERT_EQ(hipSuccess,
            hipMemcpy(d_in1, in1.data(), size, hipMemcpyHostToDevice));
  ASSERT_EQ(hipSuccess,
            hipMemcpy(d_in2, in2.data(), size, hipMemcpyHostToDevice));
  hipLaunchKernelGGL(HIP_KERNEL_NAME(AddKernel<T>), dim3(1),
                     dim3(PADDLE_CUDA_NUM_THREADS), 0, 0, d_in1, d_in2, num);
  // Launch-configuration errors only surface through the last-error state.
  ASSERT_EQ(hipSuccess, hipGetLastError());
  ASSERT_EQ(hipSuccess, hipDeviceSynchronize());
  ASSERT_EQ(hipSuccess,
            hipMemcpy(out.data(), d_in2, size, hipMemcpyDeviceToHost));
#else
  ASSERT_EQ(cudaSuccess,
            cudaMemcpy(d_in1, in1.data(), size, cudaMemcpyHostToDevice));
  ASSERT_EQ(cudaSuccess,
            cudaMemcpy(d_in2, in2.data(), size, cudaMemcpyHostToDevice));
  AddKernel<T><<<1, PADDLE_CUDA_NUM_THREADS>>>(d_in1, d_in2, num);
  // Launch-configuration errors only surface through the last-error state.
  ASSERT_EQ(cudaSuccess, cudaGetLastError());
  ASSERT_EQ(cudaSuccess, cudaDeviceSynchronize());
  ASSERT_EQ(cudaSuccess,
            cudaMemcpy(out.data(), d_in2, size, cudaMemcpyDeviceToHost));
#endif

  for (size_t i = 0; i < num; ++i) {
    // NOTE(dzhwinter): the float16 add has small underflow/overflow
    // so we use EXPECT_NEAR to check the result.
    EXPECT_NEAR(static_cast<float>(out[i]),
                static_cast<float>(AddFunctor<T>()(in1[i], in2[i])), 0.001);
  }

#ifdef PADDLE_WITH_HIP
  EXPECT_EQ(hipSuccess, hipFree(d_in1));
  EXPECT_EQ(hipSuccess, hipFree(d_in2));
#else
  EXPECT_EQ(cudaSuccess, cudaFree(d_in1));
  EXPECT_EQ(cudaSuccess, cudaFree(d_in2));
#endif
}

// cuda primitives: CudaAtomicAdd on float, then on double, each with a small
// and a large element count.
TEST(CudaAtomic, Add) {
  const size_t kSizes[] = {10, 1024 * 1024};
  for (size_t n : kSizes) {
    TestCase<float>(n);
  }
  for (size_t n : kSizes) {
    TestCase<double>(n);
  }
}

TEST(CudaAtomic, float16) {
  // Very small counts (including the odd counts 1 and 3), then a medium and a
  // large one.
  const size_t kSizes[] = {1, 2, 3, 10, 1024 * 1024};
  for (size_t n : kSizes) {
    TestCase<float16>(n);
  }
}

// Checks CudaAtomicAdd on float16 when the host input data starts at an
// address offset by `shift_bit` bytes from the allocation.
//
// num:       payload size in bytes. It must be even and yields num / 2
//            float16 elements.
// shift_bit: offset in BYTES (despite the name) added to the host input
//            pointers before the values are generated and copied.
//
// NOTE(review): only the host pointers are shifted. d_in1/d_in2 are used
// exactly as returned by the allocator, so the kernel itself never sees a
// shifted address. With an odd shift, the host-side float16 accesses through
// r_in1/r_in2 are presumably misaligned (float16 is likely 2-byte aligned).
// That is undefined in standard C++ even though common hosts tolerate it.
// Confirm this is the intended coverage.
//
// Every runtime call is checked. Host buffers are std::vector, so an early
// return from a failing ASSERT_* does not leak them. The device buffers may
// still leak on that path, but only after the test has already failed.
void TestUnalign(size_t num, const int shift_bit) {
  ASSERT_EQ(0u, num % 2);
  const size_t size = sizeof(uint8_t) * (num + shift_bit);
  const size_t array_size = sizeof(float16) * (num / 2);
  std::vector<uint8_t> in1(size), in2(size);
  std::vector<float16> out(num / 2);
  float16 *d_in1 = nullptr, *d_in2 = nullptr;

#ifdef PADDLE_WITH_HIP
  ASSERT_EQ(hipSuccess, hipMalloc(reinterpret_cast<void**>(&d_in1), size));
  ASSERT_EQ(hipSuccess, hipMalloc(reinterpret_cast<void**>(&d_in2), size));
#else
  ASSERT_EQ(cudaSuccess, cudaMalloc(reinterpret_cast<void**>(&d_in1), size));
  ASSERT_EQ(cudaSuccess, cudaMalloc(reinterpret_cast<void**>(&d_in2), size));
#endif

  // Shift the host pointers by shift_bit bytes to mimic an unaligned address.
  float16* r_in1 = reinterpret_cast<float16*>(in1.data() + shift_bit);
  float16* r_in2 = reinterpret_cast<float16*>(in2.data() + shift_bit);

  std::minstd_rand engine;
  std::uniform_real_distribution<double> dist(0.0, 1.0);
  for (size_t i = 0; i < num / 2; ++i) {
    r_in1[i] = static_cast<float16>(dist(engine));
    r_in2[i] = static_cast<float16>(dist(engine));
  }

#ifdef PADDLE_WITH_HIP
  ASSERT_EQ(hipSuccess,
            hipMemcpy(d_in1, r_in1, array_size, hipMemcpyHostToDevice));
  ASSERT_EQ(hipSuccess,
            hipMemcpy(d_in2, r_in2, array_size, hipMemcpyHostToDevice));
  hipLaunchKernelGGL(HIP_KERNEL_NAME(AddKernel<float16>), dim3(1),
                     dim3(PADDLE_CUDA_NUM_THREADS), 0, 0, d_in1, d_in2,
                     num / 2);
  ASSERT_EQ(hipSuccess, hipGetLastError());
  ASSERT_EQ(hipSuccess, hipDeviceSynchronize());
  ASSERT_EQ(hipSuccess,
            hipMemcpy(out.data(), d_in2, array_size, hipMemcpyDeviceToHost));
#else
  ASSERT_EQ(cudaSuccess,
            cudaMemcpy(d_in1, r_in1, array_size, cudaMemcpyHostToDevice));
  ASSERT_EQ(cudaSuccess,
            cudaMemcpy(d_in2, r_in2, array_size, cudaMemcpyHostToDevice));
  AddKernel<float16><<<1, PADDLE_CUDA_NUM_THREADS>>>(d_in1, d_in2, num / 2);
  ASSERT_EQ(cudaSuccess, cudaGetLastError());
  ASSERT_EQ(cudaSuccess, cudaDeviceSynchronize());
  ASSERT_EQ(cudaSuccess,
            cudaMemcpy(out.data(), d_in2, array_size, cudaMemcpyDeviceToHost));
#endif

  for (size_t i = 0; i < num / 2; ++i) {
    // NOTE(dzhwinter): the float16 add has small truncate error.
    // so we use EXPECT_NEAR to check the result.
    EXPECT_NEAR(static_cast<float>(out[i]),
                static_cast<float>(AddFunctor<float16>()(r_in1[i], r_in2[i])),
                0.001);
  }

#ifdef PADDLE_WITH_HIP
  EXPECT_EQ(hipSuccess, hipFree(d_in1));
  EXPECT_EQ(hipSuccess, hipFree(d_in2));
#else
  EXPECT_EQ(cudaSuccess, cudaFree(d_in1));
  EXPECT_EQ(cudaSuccess, cudaFree(d_in2));
#endif
}

TEST(CudaAtomic, float16Unalign) {
  // Byte offsets applied to the host input buffers:
  // - 2 keeps the same layout as the plain float16 test case;
  // - 1 and 3 shift the address.
  // Each offset is run against every payload size, in this order.
  const int kShifts[] = {2, 1, 3};
  const size_t kSizes[] = {2, 1024, 1024 * 1024};
  for (int shift : kShifts) {
    for (size_t n : kSizes) {
      TestUnalign(n, /*shift_bit*/ shift);
    }
  }
}

// https://devblogs.nvidia.com/faster-parallel-reductions-kepler/
template <typename T>
static __forceinline__ __device__ T WarpReduceSum(T val) {
  // Lanes taking part in the shuffles, as computed by CREATE_SHFL_MASK.
  unsigned mask = 0u;
  CREATE_SHFL_MASK(mask, true);
  // The shuffle distance halves every step: warpSize/2, warpSize/4, ..., 1.
  // Assuming CudaShuffleDownSync behaves like __shfl_down_sync, lane 0 ends
  // up holding the sum over the participating lanes.
  for (int delta = warpSize >> 1; delta >= 1; delta >>= 1) {
    val += paddle::platform::CudaShuffleDownSync(mask, val, delta);
  }
  return val;
}

// Two-stage block-wide sum built on WarpReduceSum.
//
// Stage 1: each warp reduces its own values, and lane 0 of each warp writes
//          the partial result into a per-warp shared-memory slot.
// Stage 2: after the barrier, the first warp reads those slots and reduces
//          them again.
//
// Preconditions: blockDim.x / warpSize <= 32 (size of the slot array), and
// blockDim.x should be a multiple of warpSize. Slots are only read for
// threadIdx.x < blockDim.x / warpSize, so a trailing partial warp would be
// dropped. Must be reached by every thread of the block (contains
// __syncthreads).
template <typename T>
__forceinline__ __device__ T BlockReduce(T val) {
  static __shared__ T warp_sums[32];  // One partial sum per warp.
  const int lane_id = threadIdx.x % warpSize;
  const int warp_id = threadIdx.x / warpSize;

  // Stage 1: reduce within each warp and publish the per-warp result.
  val = WarpReduceSum(val);
  if (lane_id == 0) {
    warp_sums[warp_id] = val;
  }
  __syncthreads();

  // Stage 2: only slots that were written by an existing warp contribute.
  const bool has_partial = threadIdx.x < blockDim.x / warpSize;
  val = has_partial ? warp_sums[lane_id] : static_cast<T>(0);
  if (warp_id == 0) {
    val = WarpReduceSum(val);
  }
  return val;
}

// Sums `N` elements of `in` and writes this block's total to out[blockIdx.x].
// With a single-block launch (as TestReduce does), out[0] is the full sum.
template <typename T>
__global__ void DeviceReduceSum(T* in, T* out, size_t N) {
  // Per-thread partial sum over the indices CUDA_KERNEL_LOOP hands this
  // thread.
  T partial(0);
  CUDA_KERNEL_LOOP(idx, N) { partial += in[idx]; }

  // Combine the partials across the block; thread 0 holds the block total.
  partial = BlockReduce<T>(partial);
  __syncthreads();
  if (threadIdx.x == 0) {
    out[blockIdx.x] = partial;
  }
}

// Checks DeviceReduceSum (and through it BlockReduce / WarpReduceSum) for
// element type T.
//
// Sums `num` pseudo-random values in [0, 1) on the device, using a single
// block of PADDLE_CUDA_NUM_THREADS threads. The result is compared with a
// host-side std::accumulate over the same values, accumulated in T.
//
// num:  number of input elements.
// atol: absolute tolerance of the comparison. The float16 callers pass a
//       larger value because rounding error grows with `num`.
//
// Every runtime call (including the kernel launch) is checked. The host buffer
// is a std::vector, so a failing ASSERT_* (which returns early) does not leak
// it. The device buffers may still leak on that path, but only after the test
// has already failed.
template <typename T>
void TestReduce(size_t num, float atol = 0.01) {
  std::vector<T> in1(num);
  T *d_in1 = nullptr, *d_in2 = nullptr;
  const size_t size = sizeof(T) * num;

#ifdef PADDLE_WITH_HIP
  ASSERT_EQ(hipSuccess, hipMalloc(reinterpret_cast<void**>(&d_in1), size));
  ASSERT_EQ(hipSuccess,
            hipMalloc(reinterpret_cast<void**>(&d_in2), sizeof(T)));
#else
  ASSERT_EQ(cudaSuccess, cudaMalloc(reinterpret_cast<void**>(&d_in1), size));
  ASSERT_EQ(cudaSuccess,
            cudaMalloc(reinterpret_cast<void**>(&d_in2), sizeof(T)));
#endif

  std::minstd_rand engine;
  std::uniform_real_distribution<double> dist(0.0, 1.0);
  for (size_t i = 0; i < num; ++i) {
    in1[i] = static_cast<T>(dist(engine));
  }
  // Host reference, accumulated in T so float16 rounding matches the device.
  T expected = std::accumulate(in1.begin(), in1.end(), static_cast<T>(0));

  // Device result, kept separate from the input buffer.
  T actual = static_cast<T>(0);
#ifdef PADDLE_WITH_HIP
  ASSERT_EQ(hipSuccess,
            hipMemcpy(d_in1, in1.data(), size, hipMemcpyHostToDevice));
  hipLaunchKernelGGL(HIP_KERNEL_NAME(DeviceReduceSum<T>), dim3(1),
                     dim3(PADDLE_CUDA_NUM_THREADS), 0, 0, d_in1, d_in2, num);
  ASSERT_EQ(hipSuccess, hipGetLastError());
  ASSERT_EQ(hipSuccess, hipDeviceSynchronize());
  ASSERT_EQ(hipSuccess,
            hipMemcpy(&actual, d_in2, sizeof(T), hipMemcpyDeviceToHost));
#else
  ASSERT_EQ(cudaSuccess,
            cudaMemcpy(d_in1, in1.data(), size, cudaMemcpyHostToDevice));
  DeviceReduceSum<T><<<1, PADDLE_CUDA_NUM_THREADS>>>(d_in1, d_in2, num);
  ASSERT_EQ(cudaSuccess, cudaGetLastError());
  ASSERT_EQ(cudaSuccess, cudaDeviceSynchronize());
  ASSERT_EQ(cudaSuccess,
            cudaMemcpy(&actual, d_in2, sizeof(T), cudaMemcpyDeviceToHost));
#endif

  // NOTE(dzhwinter): the float16 add has small underflow/overflow
  // so we use EXPECT_NEAR to check the result.
  EXPECT_NEAR(static_cast<float>(actual), static_cast<float>(expected), atol);

#ifdef PADDLE_WITH_HIP
  EXPECT_EQ(hipSuccess, hipFree(d_in1));
  EXPECT_EQ(hipSuccess, hipFree(d_in2));
#else
  EXPECT_EQ(cudaSuccess, cudaFree(d_in1));
  EXPECT_EQ(cudaSuccess, cudaFree(d_in2));
#endif
}

TEST(CudaShuffleSync, float16) {
  // float32 sums are checked with the default tolerance.
  const size_t kFloatSizes[] = {10, 1000};
  for (size_t n : kFloatSizes) {
    TestReduce<float>(n);
  }

  // float16 will overflow or accumulate truncation errors at larger sizes,
  // so the bigger case uses a much looser absolute tolerance.
  TestReduce<float16>(10);
  TestReduce<float16>(100, /*atol error*/ 1.0);
}