cuda_helper_test.cu 7.7 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <gtest/gtest.h>
#include <algorithm>
#include <cstdlib>
#include <iostream>
#include <numeric>
#include <random>

#define PADDLE_CUDA_FP16
#include "paddle/fluid/platform/cuda_device_function.h"
#include "paddle/fluid/platform/cuda_primitives.h"
#include "paddle/fluid/platform/float16.h"

28 29
#include "paddle/fluid/platform/cuda_helper.h"

30 31 32
using paddle::platform::PADDLE_CUDA_NUM_THREADS;
using paddle::platform::float16;

D
dzhwinter 已提交
33 34
// Accumulates data_a into data_b element-wise: for every index i produced by
// CUDA_KERNEL_LOOP over `num`, data_a[i] is added to data_b[i] through
// paddle::platform::CudaAtomicAdd.
//
//   data_a: addends (read only).
//   data_b: accumulators, updated in place.
//   num:    element count handed to CUDA_KERNEL_LOOP.
//
// NOTE(review): CUDA_KERNEL_LOOP is defined outside this file; presumably it
// is a grid-stride loop, which is what lets the tests launch a single block
// for any `num` -- confirm against cuda_helper.h.
template <typename T>
__global__ void AddKernel(const T* data_a, T* data_b, size_t num) {
  CUDA_KERNEL_LOOP(i, num) {
    T* const target = data_b + i;
    const T addend = data_a[i];
    paddle::platform::CudaAtomicAdd(target, addend);
  }
}
39 40 41 42 43 44 45

// Host-side reference for the addition performed by AddKernel: returns a + b
// using T's own operator+.
//
// The call operator is const because the functor carries no state; this lets
// it be invoked on const instances and through const references as well as on
// the temporaries used by the tests in this file.
template <typename T>
struct AddFunctor {
  T operator()(const T& a, const T& b) const { return a + b; }
};

// Verifies CudaAtomicAdd for element type T on `num` elements.
//
// Two host arrays are filled from std::uniform_real_distribution(0.0, 1.0)
// (default-seeded std::minstd_rand, so the data is the same on every run),
// copied to the device, and AddKernel is launched with one block of
// PADDLE_CUDA_NUM_THREADS threads so that d_in2[i] += d_in1[i]. The result is
// copied back and each element is compared against the host-side
// AddFunctor<T> sum with an absolute tolerance of 0.001.
//
// Every CUDA runtime call and the kernel launch are checked; a failure aborts
// this test case instead of comparing against garbage.
template <typename T>
void TestCase(size_t num) {
  // Owns every buffer so that an early return from a failed ASSERT_* does not
  // leak host or device memory. free(nullptr) and cudaFree(nullptr) are
  // no-ops, so partially filled state is released safely.
  struct Buffers {
    T* in1 = nullptr;
    T* in2 = nullptr;
    T* out = nullptr;
    T* d_in1 = nullptr;
    T* d_in2 = nullptr;
    ~Buffers() {
      free(in1);
      free(in2);
      free(out);
      EXPECT_EQ(cudaFree(d_in1), cudaSuccess);
      EXPECT_EQ(cudaFree(d_in2), cudaSuccess);
    }
  } buf;

  const size_t size = sizeof(T) * num;
  ASSERT_EQ(cudaMalloc(reinterpret_cast<void**>(&buf.d_in1), size),
            cudaSuccess);
  ASSERT_EQ(cudaMalloc(reinterpret_cast<void**>(&buf.d_in2), size),
            cudaSuccess);
  buf.in1 = reinterpret_cast<T*>(malloc(size));
  buf.in2 = reinterpret_cast<T*>(malloc(size));
  buf.out = reinterpret_cast<T*>(malloc(size));
  // malloc(0) may legitimately return nullptr, so only insist on non-null
  // pointers when there is something to store.
  ASSERT_TRUE(size == 0 || (buf.in1 != nullptr && buf.in2 != nullptr &&
                            buf.out != nullptr));

  std::minstd_rand engine;
  std::uniform_real_distribution<double> dist(0.0, 1.0);
  for (size_t i = 0; i < num; ++i) {
    buf.in1[i] = static_cast<T>(dist(engine));
    buf.in2[i] = static_cast<T>(dist(engine));
  }

  ASSERT_EQ(cudaMemcpy(buf.d_in1, buf.in1, size, cudaMemcpyHostToDevice),
            cudaSuccess);
  ASSERT_EQ(cudaMemcpy(buf.d_in2, buf.in2, size, cudaMemcpyHostToDevice),
            cudaSuccess);

  AddKernel<T><<<1, PADDLE_CUDA_NUM_THREADS>>>(buf.d_in1, buf.d_in2, num);
  // A kernel launch returns nothing: configuration errors are reported by
  // cudaGetLastError(), execution errors by the following synchronization.
  ASSERT_EQ(cudaGetLastError(), cudaSuccess);
  ASSERT_EQ(cudaDeviceSynchronize(), cudaSuccess);

  // Blocking copy: the data is in `out` once this call returns, so no further
  // synchronization is required before reading it.
  ASSERT_EQ(cudaMemcpy(buf.out, buf.d_in2, size, cudaMemcpyDeviceToHost),
            cudaSuccess);

  for (size_t i = 0; i < num; ++i) {
    // NOTE(dzhwinter): the float16 add has small underflow/overflow
    // so we use EXPECT_NEAR to check the result.
    EXPECT_NEAR(static_cast<float>(buf.out[i]),
                static_cast<float>(AddFunctor<T>()(buf.in1[i], buf.in2[i])),
                0.001);
  }
}
79 80 81

// cuda primitives
// Runs the atomic-add check for float and then double, each on a tiny array
// and on a 1M-element array.
TEST(CudaAtomic, Add) {
  const size_t kSizes[] = {10, 1024 * 1024};

  for (size_t n : kSizes) {
    TestCase<float>(n);
  }

  for (size_t n : kSizes) {
    TestCase<double>(n);
  }
}

// Runs the atomic-add check for float16 on very small element counts (1, 2
// and 3), then on 10 elements and on a 1M-element array.
TEST(CudaAtomic, float16) {
  const size_t kSizes[] = {1, 2, 3, 10, 1024 * 1024};
  for (size_t n : kSizes) {
    TestCase<float16>(n);
  }
}

// float16 atomic add with host input buffers shifted by a byte offset
// Verifies the float16 CudaAtomicAdd path when the host input arrays start at
// a shifted address.
//
//   num:       size of the float16 payload in bytes as far as the buffers are
//              concerned; num / 2 float16 elements are processed. Must be
//              even.
//   shift_bit: offset applied to the host input pointers. Despite the name it
//              is a count of BYTES (it is added to a uint8_t*). Must be >= 0.
//
// Only the host input pointers (r_in1 / r_in2) are offset; the device
// allocations and the host `out` buffer are used at the addresses returned by
// the allocator.
//
// NOTE(review): with an odd shift the host float16 accesses are misaligned,
// which the test relies on the host platform tolerating.
//
// Every CUDA runtime call and the kernel launch are checked; a failure aborts
// this test case instead of comparing against garbage.
void TestUnalign(size_t num, const int shift_bit) {
  ASSERT_EQ(num % 2, 0u);
  // A negative shift would move the host pointers in front of their buffers.
  ASSERT_GE(shift_bit, 0);

  // Owns every buffer so that an early return from a failed ASSERT_* does not
  // leak host or device memory. free(nullptr) and cudaFree(nullptr) are
  // no-ops, so partially filled state is released safely.
  struct Buffers {
    float16* in1 = nullptr;
    float16* in2 = nullptr;
    float16* out = nullptr;
    float16* d_in1 = nullptr;
    float16* d_in2 = nullptr;
    ~Buffers() {
      free(in1);
      free(in2);
      free(out);
      EXPECT_EQ(cudaFree(d_in1), cudaSuccess);
      EXPECT_EQ(cudaFree(d_in2), cudaSuccess);
    }
  } buf;

  // `size` is the allocation size including the shift; `array_size` is the
  // number of payload bytes actually transferred.
  const size_t size = sizeof(uint8_t) * (num + shift_bit);
  const size_t array_size = sizeof(float16) * (num / 2);

  ASSERT_EQ(cudaMalloc(reinterpret_cast<void**>(&buf.d_in1), size),
            cudaSuccess);
  ASSERT_EQ(cudaMalloc(reinterpret_cast<void**>(&buf.d_in2), size),
            cudaSuccess);
  buf.in1 = reinterpret_cast<float16*>(malloc(size));
  buf.in2 = reinterpret_cast<float16*>(malloc(size));
  buf.out = reinterpret_cast<float16*>(malloc(size));
  // malloc(0) may legitimately return nullptr, so only insist on non-null
  // pointers when there is something to store.
  ASSERT_TRUE(size == 0 || (buf.in1 != nullptr && buf.in2 != nullptr &&
                            buf.out != nullptr));

  // Shift the host inputs by shift_bit bytes to mimic unaligned addresses.
  float16* r_in1 = reinterpret_cast<float16*>(
      reinterpret_cast<uint8_t*>(buf.in1) + shift_bit);
  float16* r_in2 = reinterpret_cast<float16*>(
      reinterpret_cast<uint8_t*>(buf.in2) + shift_bit);

  std::minstd_rand engine;
  std::uniform_real_distribution<double> dist(0.0, 1.0);
  for (size_t i = 0; i < num / 2; ++i) {
    r_in1[i] = static_cast<float16>(dist(engine));
    r_in2[i] = static_cast<float16>(dist(engine));
  }

  ASSERT_EQ(cudaMemcpy(buf.d_in1, r_in1, array_size, cudaMemcpyHostToDevice),
            cudaSuccess);
  ASSERT_EQ(cudaMemcpy(buf.d_in2, r_in2, array_size, cudaMemcpyHostToDevice),
            cudaSuccess);

  AddKernel<float16><<<1, PADDLE_CUDA_NUM_THREADS>>>(buf.d_in1, buf.d_in2,
                                                     num / 2);
  // A kernel launch returns nothing: configuration errors are reported by
  // cudaGetLastError(), execution errors by the following synchronization.
  ASSERT_EQ(cudaGetLastError(), cudaSuccess);
  ASSERT_EQ(cudaDeviceSynchronize(), cudaSuccess);

  // Blocking copy: the data is in `out` once this call returns, so no further
  // synchronization is required before reading it.
  ASSERT_EQ(
      cudaMemcpy(buf.out, buf.d_in2, array_size, cudaMemcpyDeviceToHost),
      cudaSuccess);

  for (size_t i = 0; i < num / 2; ++i) {
    // NOTE(dzhwinter): the float16 add has small truncate error.
    // so we use EXPECT_NEAR to check the result.
    EXPECT_NEAR(static_cast<float>(buf.out[i]),
                static_cast<float>(AddFunctor<float16>()(r_in1[i], r_in2[i])),
                0.001);
  }
}

// Exercises TestUnalign with byte shifts of 2, 1 and 3 (in that order), each
// over payloads of 2, 1024 and 1024 * 1024 bytes.
TEST(CudaAtomic, float16Unalign) {
  const size_t kSizes[] = {2, 1024, 1024 * 1024};
  // The first shift (2) behaves the same as the plain float16 test case; the
  // remaining ones shift the address by an odd number of bytes.
  const int kShifts[] = {2, 1, 3};

  for (int shift : kShifts) {
    for (size_t n : kSizes) {
      TestUnalign(n, /*shift_bit*/ shift);
    }
  }
}
159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194

// https://devblogs.nvidia.com/faster-parallel-reductions-kepler/
// Warp-level sum: repeatedly adds to `val` the value obtained from
// CudaShuffleDownSync at offsets warpSize / 2, warpSize / 4, ..., 1, and
// returns the accumulated value.
//
// The participation mask is filled in by the CREATE_SHFL_MASK macro.
// NOTE(review): both CREATE_SHFL_MASK and CudaShuffleDownSync are defined
// outside this file; presumably they wrap __ballot_sync / __shfl_down_sync,
// in which case lane 0 ends up with the sum of the whole warp -- confirm.
template <typename T>
static __forceinline__ __device__ T WarpReduceSum(T val) {
  unsigned lane_mask = 0u;
  CREATE_SHFL_MASK(lane_mask, true);

  int offset = warpSize / 2;
  while (offset > 0) {
    val += paddle::platform::CudaShuffleDownSync(lane_mask, val, offset);
    offset /= 2;
  }
  return val;
}

// Block-level sum built from two rounds of WarpReduceSum.
//
// Round 1: every warp reduces its own values; lane 0 of each warp stores the
// warp's partial sum in a 32-slot shared-memory array indexed by warp id.
// Round 2: after a block-wide barrier, the first blockDim.x / warpSize
// threads each pick up one partial sum (every other thread contributes zero)
// and warp 0 reduces those values once more.
//
// Only warp 0 runs the second round; the caller in this file reads the result
// from threadIdx.x == 0.
template <typename T>
__forceinline__ __device__ T BlockReduce(T val) {
  static __shared__ T warp_sums[32];  // One slot per warp's partial sum
  const int lane_id = threadIdx.x % warpSize;
  const int warp_id = threadIdx.x / warpSize;

  // Round 1: partial reduction inside each warp.
  val = WarpReduceSum(val);

  // The first lane of each warp publishes that warp's partial sum.
  if (lane_id == 0) {
    warp_sums[warp_id] = val;
  }

  // All partial sums must be written before any of them is read.
  __syncthreads();

  // Load a partial sum only if the corresponding warp existed.
  if (threadIdx.x < blockDim.x / warpSize) {
    val = warp_sums[lane_id];
  } else {
    val = static_cast<T>(0);
  }

  // Round 2: the first warp folds the partial sums together.
  if (warp_id == 0) {
    val = WarpReduceSum(val);
  }

  return val;
}

// Sums the N elements of `in` and stores one result per block in
// out[blockIdx.x].
//
// Each thread first accumulates the elements that CUDA_KERNEL_LOOP assigns to
// it, the per-thread totals are then combined with BlockReduce, and thread 0
// of the block writes the block's value.
//
//   in:  input values.
//   out: destination, indexed by blockIdx.x.
//   N:   element count handed to CUDA_KERNEL_LOOP.
template <typename T>
__global__ void DeviceReduceSum(T* in, T* out, size_t N) {
  T partial(0);
  CUDA_KERNEL_LOOP(i, N) {
    partial += in[i];
  }

  partial = BlockReduce<T>(partial);
  __syncthreads();

  // Only the first thread of the block publishes the result.
  if (threadIdx.x != 0) {
    return;
  }
  out[blockIdx.x] = partial;
}

// Verifies the shuffle-based block reduction (DeviceReduceSum) for element
// type T.
//
//   num:  number of input elements.
//   atol: absolute tolerance used when comparing the device sum with the host
//         reference.
//
// The input is filled from std::uniform_real_distribution(0.0, 1.0) with a
// default-seeded std::minstd_rand. The host reference is std::accumulate over
// the same data, accumulating in T. The kernel is launched with a single
// block of PADDLE_CUDA_NUM_THREADS threads, so exactly one output element is
// produced and read back.
//
// Every CUDA runtime call and the kernel launch are checked; a failure aborts
// this test case instead of comparing against garbage.
template <typename T>
void TestReduce(size_t num, float atol = 0.01) {
  // Owns every buffer so that an early return from a failed ASSERT_* does not
  // leak host or device memory. free(nullptr) and cudaFree(nullptr) are
  // no-ops, so partially filled state is released safely.
  struct Buffers {
    T* in1 = nullptr;
    T* d_in1 = nullptr;
    T* d_in2 = nullptr;
    ~Buffers() {
      free(in1);
      EXPECT_EQ(cudaFree(d_in1), cudaSuccess);
      EXPECT_EQ(cudaFree(d_in2), cudaSuccess);
    }
  } buf;

  const size_t size = sizeof(T) * num;
  ASSERT_EQ(cudaMalloc(reinterpret_cast<void**>(&buf.d_in1), size),
            cudaSuccess);
  ASSERT_EQ(cudaMalloc(reinterpret_cast<void**>(&buf.d_in2), sizeof(T)),
            cudaSuccess);
  buf.in1 = reinterpret_cast<T*>(malloc(size));
  // malloc(0) may legitimately return nullptr, so only insist on a non-null
  // pointer when there is something to store.
  ASSERT_TRUE(size == 0 || buf.in1 != nullptr);

  std::minstd_rand engine;
  std::uniform_real_distribution<double> dist(0.0, 1.0);
  for (size_t i = 0; i < num; ++i) {
    buf.in1[i] = static_cast<T>(dist(engine));
  }
  const T expected =
      std::accumulate(buf.in1, buf.in1 + num, static_cast<T>(0));

  ASSERT_EQ(cudaMemcpy(buf.d_in1, buf.in1, size, cudaMemcpyHostToDevice),
            cudaSuccess);
  ASSERT_EQ(cudaDeviceSynchronize(), cudaSuccess);

  DeviceReduceSum<T><<<1, PADDLE_CUDA_NUM_THREADS>>>(buf.d_in1, buf.d_in2,
                                                     num);
  // A kernel launch returns nothing: configuration errors are reported by
  // cudaGetLastError(), execution errors by the following synchronization.
  ASSERT_EQ(cudaGetLastError(), cudaSuccess);
  ASSERT_EQ(cudaDeviceSynchronize(), cudaSuccess);

  // Read the single result into its own scalar rather than into in1[0], which
  // does not exist when num == 0.
  T result = static_cast<T>(0);
  ASSERT_EQ(cudaMemcpy(&result, buf.d_in2, sizeof(T), cudaMemcpyDeviceToHost),
            cudaSuccess);

  // NOTE(dzhwinter): the float16 add has small underflow/overflow
  // so we use EXPECT_NEAR to check the result.
  EXPECT_NEAR(static_cast<float>(result), static_cast<float>(expected), atol);
}

// Checks the shuffle-based reduction for float (10 and 1000 elements, default
// tolerance) and for float16 (10 elements with the default tolerance, 100
// elements with a tolerance of 1.0).
TEST(CudaShuffleSync, float16) {
  const size_t kTiny = 10;
  const size_t kSmall = 100;
  const size_t kMedium = 1000;

  TestReduce<float>(kTiny);
  TestReduce<float>(kMedium);

  // float16 will overflow or accumulate truncate errors in big size, hence the
  // modest element counts and the relaxed tolerance for the larger one.
  TestReduce<float16>(kTiny);
  TestReduce<float16>(kSmall, /*atol error*/ 1.0);
}