float16_test.cu 17.1 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
K
Kexin Zhao 已提交
2 3 4 5 6 7 8 9 10 11
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

12
#include "paddle/fluid/platform/float16.h"
K
Kexin Zhao 已提交
13

P
peizhilin 已提交
14
#define GLOG_NO_ABBREVIATED_SEVERITIES  // msvc conflict logging with windows.h
15
#include <glog/logging.h>
K
Kexin Zhao 已提交
16
#include <gtest/gtest.h>
17 18
#include <bitset>
#include <iostream>
K
Kexin Zhao 已提交
19

K
kexinzhao 已提交
20 21
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/tensor_util.h"
22
#include "paddle/fluid/platform/enforce.h"
K
Kexin Zhao 已提交
23

24
// Generates a kernel named `op_type` that applies the binary operator `sign`
// to lhs[0] and rhs[0] and stores the half result in dst[0]. Only element 0
// is touched, so the kernel is meant to be launched with a single thread.
#define ARITHMETIC_KERNEL(op_type, sign) \
  __global__ void op_type(const half *lhs, const half *rhs, half *dst) { \
    dst[0] = lhs[0] sign rhs[0]; \
  }

29
// Generates a kernel named `op_type` that applies the compound-assignment
// operator `sign` to target[0], using operand[0] as the right-hand side.
#define COMPOUND_KERNEL(op_type, sign) \
  __global__ void op_type(half *target, const half *operand) { \
    target[0] sign operand[0]; \
  }
K
Kexin Zhao 已提交
31

32
// Generates a kernel named `op_type` that compares lhs[0] with rhs[0] using
// the relational operator `sign` and stores the boolean outcome in verdict[0].
#define COMPARISON_KERNEL(op_type, sign) \
  __global__ void op_type(const half *lhs, const half *rhs, bool *verdict) { \
    verdict[0] = lhs[0] sign rhs[0]; \
  }

37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115
#ifdef PADDLE_WITH_HIP
// Defines Test<op_type>(v_in1, v_in2, v_out) for HIP builds.
//
// The generated function converts v_in1 and v_in2 to half, runs the `op_type`
// kernel on them with one block of one thread, and expects the result,
// converted back to float, to equal v_out.
//
// Every HIP runtime call is checked so that an allocation, copy or launch
// failure is reported where it happens instead of showing up as a wrong
// value. Host values live on the stack, so no host allocation can fail or
// leak. Device buffers are always released, including after a failure.
#define ARITHMETIC_KERNEL_LAUNCH(op_type) \
  void Test##op_type(float v_in1, float v_in2, float v_out) { \
    LOG(INFO) << "Test " << #op_type << " on GPU!"; \
    const size_t size = sizeof(half); \
    half in1 = half(float16(v_in1)); \
    half in2 = half(float16(v_in2)); \
    half out = half(float16(0.0f)); \
    half *d_in1 = nullptr, *d_in2 = nullptr, *d_out = nullptr; \
    /* Short-circuits on the first failure; later pointers stay null. */ \
    const bool allocated = \
        hipMalloc(reinterpret_cast<void **>(&d_in1), size) == hipSuccess && \
        hipMalloc(reinterpret_cast<void **>(&d_in2), size) == hipSuccess && \
        hipMalloc(reinterpret_cast<void **>(&d_out), size) == hipSuccess; \
    EXPECT_TRUE(allocated) << "hipMalloc failed: " \
                           << hipGetErrorString(hipGetLastError()); \
    if (allocated) { \
      EXPECT_EQ(hipMemcpy(d_in1, &in1, size, hipMemcpyHostToDevice), \
                hipSuccess); \
      EXPECT_EQ(hipMemcpy(d_in2, &in2, size, hipMemcpyHostToDevice), \
                hipSuccess); \
      hipLaunchKernelGGL(op_type, dim3(1), dim3(1), 0, 0, d_in1, d_in2, \
                         d_out); \
      /* A launch does not return a status; query it explicitly. */ \
      EXPECT_EQ(hipGetLastError(), hipSuccess); \
      /* The blocking copy also waits for the kernel to finish. */ \
      EXPECT_EQ(hipMemcpy(&out, d_out, size, hipMemcpyDeviceToHost), \
                hipSuccess); \
      EXPECT_EQ(static_cast<float>(float16(out)), v_out); \
    } \
    /* Freeing a null pointer is a no-op, so this is safe after a failure. */ \
    EXPECT_EQ(hipFree(d_in1), hipSuccess); \
    EXPECT_EQ(hipFree(d_in2), hipSuccess); \
    EXPECT_EQ(hipFree(d_out), hipSuccess); \
  }

// Defines Test<op_type>(v_in1, v_in2, v_out) for HIP builds.
//
// The generated function converts v_in1 and v_in2 to half, runs the in-place
// `op_type` kernel with one block of one thread, and expects the updated
// first operand, converted back to float, to equal v_out.
//
// Every HIP runtime call is checked so that an allocation, copy or launch
// failure is reported where it happens. Host values live on the stack and
// device buffers are always released, including after a failure.
#define COMPOUND_KERNEL_LAUNCH(op_type) \
  void Test##op_type(float v_in1, float v_in2, float v_out) { \
    LOG(INFO) << "Test " << #op_type << " on GPU!"; \
    const size_t size = sizeof(half); \
    half in1 = half(float16(v_in1)); \
    half in2 = half(float16(v_in2)); \
    half *d_in1 = nullptr, *d_in2 = nullptr; \
    /* Short-circuits on the first failure; later pointers stay null. */ \
    const bool allocated = \
        hipMalloc(reinterpret_cast<void **>(&d_in1), size) == hipSuccess && \
        hipMalloc(reinterpret_cast<void **>(&d_in2), size) == hipSuccess; \
    EXPECT_TRUE(allocated) << "hipMalloc failed: " \
                           << hipGetErrorString(hipGetLastError()); \
    if (allocated) { \
      EXPECT_EQ(hipMemcpy(d_in1, &in1, size, hipMemcpyHostToDevice), \
                hipSuccess); \
      EXPECT_EQ(hipMemcpy(d_in2, &in2, size, hipMemcpyHostToDevice), \
                hipSuccess); \
      hipLaunchKernelGGL(op_type, dim3(1), dim3(1), 0, 0, d_in1, d_in2); \
      /* A launch does not return a status; query it explicitly. */ \
      EXPECT_EQ(hipGetLastError(), hipSuccess); \
      /* The kernel updates d_in1 in place, so that is what is read back. */ \
      EXPECT_EQ(hipMemcpy(&in1, d_in1, size, hipMemcpyDeviceToHost), \
                hipSuccess); \
      EXPECT_EQ(static_cast<float>(float16(in1)), v_out); \
    } \
    /* Freeing a null pointer is a no-op, so this is safe after a failure. */ \
    EXPECT_EQ(hipFree(d_in1), hipSuccess); \
    EXPECT_EQ(hipFree(d_in2), hipSuccess); \
  }

// Defines Test<op_type>(v_in1, v_in2, v_out) for HIP builds.
//
// The generated function converts v_in1 and v_in2 to half, runs the `op_type`
// comparison kernel with one block of one thread, and expects the boolean it
// produces to equal v_out.
//
// Every HIP runtime call is checked so that an allocation, copy or launch
// failure is reported where it happens. The result buffer is sized with
// sizeof(bool) rather than a literal 1. Host values live on the stack and
// device buffers are always released, including after a failure.
#define COMPARISON_KERNEL_LAUNCH(op_type) \
  void Test##op_type(float v_in1, float v_in2, bool v_out) { \
    LOG(INFO) << "Test " << #op_type << " on GPU!"; \
    const size_t size = sizeof(half); \
    half in1 = half(float16(v_in1)); \
    half in2 = half(float16(v_in2)); \
    bool out = false; \
    half *d_in1 = nullptr, *d_in2 = nullptr; \
    bool *d_out = nullptr; \
    /* Short-circuits on the first failure; later pointers stay null. */ \
    const bool allocated = \
        hipMalloc(reinterpret_cast<void **>(&d_in1), size) == hipSuccess && \
        hipMalloc(reinterpret_cast<void **>(&d_in2), size) == hipSuccess && \
        hipMalloc(reinterpret_cast<void **>(&d_out), sizeof(bool)) == \
            hipSuccess; \
    EXPECT_TRUE(allocated) << "hipMalloc failed: " \
                           << hipGetErrorString(hipGetLastError()); \
    if (allocated) { \
      EXPECT_EQ(hipMemcpy(d_in1, &in1, size, hipMemcpyHostToDevice), \
                hipSuccess); \
      EXPECT_EQ(hipMemcpy(d_in2, &in2, size, hipMemcpyHostToDevice), \
                hipSuccess); \
      hipLaunchKernelGGL(op_type, dim3(1), dim3(1), 0, 0, d_in1, d_in2, \
                         d_out); \
      /* A launch does not return a status; query it explicitly. */ \
      EXPECT_EQ(hipGetLastError(), hipSuccess); \
      /* The blocking copy also waits for the kernel to finish. */ \
      EXPECT_EQ(hipMemcpy(&out, d_out, sizeof(bool), hipMemcpyDeviceToHost), \
                hipSuccess); \
      EXPECT_EQ(out, v_out); \
    } \
    /* Freeing a null pointer is a no-op, so this is safe after a failure. */ \
    EXPECT_EQ(hipFree(d_in1), hipSuccess); \
    EXPECT_EQ(hipFree(d_in2), hipSuccess); \
    EXPECT_EQ(hipFree(d_out), hipSuccess); \
  }
#else
K
Kexin Zhao 已提交
116 117 118
// Defines Test<op_type>(v_in1, v_in2, v_out) for CUDA builds.
//
// The generated function converts v_in1 and v_in2 to half, runs the `op_type`
// kernel on them with one block of one thread, and expects the result,
// converted back to float, to equal v_out.
//
// Every CUDA runtime call is checked so that an allocation, copy or launch
// failure is reported where it happens instead of showing up as a wrong
// value. Host values live on the stack, so no host allocation can fail or
// leak. Device buffers are always released, including after a failure.
#define ARITHMETIC_KERNEL_LAUNCH(op_type) \
  void Test##op_type(float v_in1, float v_in2, float v_out) { \
    LOG(INFO) << "Test " << #op_type << " on GPU!"; \
    const size_t size = sizeof(half); \
    half in1 = half(float16(v_in1)); \
    half in2 = half(float16(v_in2)); \
    half out = half(float16(0.0f)); \
    half *d_in1 = nullptr, *d_in2 = nullptr, *d_out = nullptr; \
    /* Short-circuits on the first failure; later pointers stay null. */ \
    const bool allocated = \
        cudaMalloc(reinterpret_cast<void **>(&d_in1), size) == cudaSuccess && \
        cudaMalloc(reinterpret_cast<void **>(&d_in2), size) == cudaSuccess && \
        cudaMalloc(reinterpret_cast<void **>(&d_out), size) == cudaSuccess; \
    EXPECT_TRUE(allocated) << "cudaMalloc failed: " \
                           << cudaGetErrorString(cudaGetLastError()); \
    if (allocated) { \
      EXPECT_EQ(cudaMemcpy(d_in1, &in1, size, cudaMemcpyHostToDevice), \
                cudaSuccess); \
      EXPECT_EQ(cudaMemcpy(d_in2, &in2, size, cudaMemcpyHostToDevice), \
                cudaSuccess); \
      op_type<<<1, 1>>>(d_in1, d_in2, d_out); \
      /* A launch does not return a status; query it explicitly. */ \
      EXPECT_EQ(cudaGetLastError(), cudaSuccess); \
      /* The blocking copy also waits for the kernel to finish. */ \
      EXPECT_EQ(cudaMemcpy(&out, d_out, size, cudaMemcpyDeviceToHost), \
                cudaSuccess); \
      EXPECT_EQ(static_cast<float>(float16(out)), v_out); \
    } \
    /* Freeing a null pointer is a no-op, so this is safe after a failure. */ \
    EXPECT_EQ(cudaFree(d_in1), cudaSuccess); \
    EXPECT_EQ(cudaFree(d_in2), cudaSuccess); \
    EXPECT_EQ(cudaFree(d_out), cudaSuccess); \
  }

// Defines Test<op_type>(v_in1, v_in2, v_out) for CUDA builds.
//
// The generated function converts v_in1 and v_in2 to half, runs the in-place
// `op_type` kernel with one block of one thread, and expects the updated
// first operand, converted back to float, to equal v_out.
//
// Every CUDA runtime call is checked so that an allocation, copy or launch
// failure is reported where it happens. Host values live on the stack and
// device buffers are always released, including after a failure.
#define COMPOUND_KERNEL_LAUNCH(op_type) \
  void Test##op_type(float v_in1, float v_in2, float v_out) { \
    LOG(INFO) << "Test " << #op_type << " on GPU!"; \
    const size_t size = sizeof(half); \
    half in1 = half(float16(v_in1)); \
    half in2 = half(float16(v_in2)); \
    half *d_in1 = nullptr, *d_in2 = nullptr; \
    /* Short-circuits on the first failure; later pointers stay null. */ \
    const bool allocated = \
        cudaMalloc(reinterpret_cast<void **>(&d_in1), size) == cudaSuccess && \
        cudaMalloc(reinterpret_cast<void **>(&d_in2), size) == cudaSuccess; \
    EXPECT_TRUE(allocated) << "cudaMalloc failed: " \
                           << cudaGetErrorString(cudaGetLastError()); \
    if (allocated) { \
      EXPECT_EQ(cudaMemcpy(d_in1, &in1, size, cudaMemcpyHostToDevice), \
                cudaSuccess); \
      EXPECT_EQ(cudaMemcpy(d_in2, &in2, size, cudaMemcpyHostToDevice), \
                cudaSuccess); \
      op_type<<<1, 1>>>(d_in1, d_in2); \
      /* A launch does not return a status; query it explicitly. */ \
      EXPECT_EQ(cudaGetLastError(), cudaSuccess); \
      /* The kernel updates d_in1 in place, so that is what is read back. */ \
      EXPECT_EQ(cudaMemcpy(&in1, d_in1, size, cudaMemcpyDeviceToHost), \
                cudaSuccess); \
      EXPECT_EQ(static_cast<float>(float16(in1)), v_out); \
    } \
    /* Freeing a null pointer is a no-op, so this is safe after a failure. */ \
    EXPECT_EQ(cudaFree(d_in1), cudaSuccess); \
    EXPECT_EQ(cudaFree(d_in2), cudaSuccess); \
  }

// Defines Test<op_type>(v_in1, v_in2, v_out) for CUDA builds.
//
// The generated function converts v_in1 and v_in2 to half, runs the `op_type`
// comparison kernel with one block of one thread, and expects the boolean it
// produces to equal v_out.
//
// Every CUDA runtime call is checked so that an allocation, copy or launch
// failure is reported where it happens. The result buffer is sized with
// sizeof(bool) rather than a literal 1. Host values live on the stack and
// device buffers are always released, including after a failure.
#define COMPARISON_KERNEL_LAUNCH(op_type) \
  void Test##op_type(float v_in1, float v_in2, bool v_out) { \
    LOG(INFO) << "Test " << #op_type << " on GPU!"; \
    const size_t size = sizeof(half); \
    half in1 = half(float16(v_in1)); \
    half in2 = half(float16(v_in2)); \
    bool out = false; \
    half *d_in1 = nullptr, *d_in2 = nullptr; \
    bool *d_out = nullptr; \
    /* Short-circuits on the first failure; later pointers stay null. */ \
    const bool allocated = \
        cudaMalloc(reinterpret_cast<void **>(&d_in1), size) == cudaSuccess && \
        cudaMalloc(reinterpret_cast<void **>(&d_in2), size) == cudaSuccess && \
        cudaMalloc(reinterpret_cast<void **>(&d_out), sizeof(bool)) == \
            cudaSuccess; \
    EXPECT_TRUE(allocated) << "cudaMalloc failed: " \
                           << cudaGetErrorString(cudaGetLastError()); \
    if (allocated) { \
      EXPECT_EQ(cudaMemcpy(d_in1, &in1, size, cudaMemcpyHostToDevice), \
                cudaSuccess); \
      EXPECT_EQ(cudaMemcpy(d_in2, &in2, size, cudaMemcpyHostToDevice), \
                cudaSuccess); \
      op_type<<<1, 1>>>(d_in1, d_in2, d_out); \
      /* A launch does not return a status; query it explicitly. */ \
      EXPECT_EQ(cudaGetLastError(), cudaSuccess); \
      /* The blocking copy also waits for the kernel to finish. */ \
      EXPECT_EQ(cudaMemcpy(&out, d_out, sizeof(bool), cudaMemcpyDeviceToHost), \
                cudaSuccess); \
      EXPECT_EQ(out, v_out); \
    } \
    /* Freeing a null pointer is a no-op, so this is safe after a failure. */ \
    EXPECT_EQ(cudaFree(d_in1), cudaSuccess); \
    EXPECT_EQ(cudaFree(d_in2), cudaSuccess); \
    EXPECT_EQ(cudaFree(d_out), cudaSuccess); \
  }
193
#endif
K
Kexin Zhao 已提交
194 195

#ifdef PADDLE_CUDA_FP16
K
Kexin Zhao 已提交
196
namespace paddle {
K
kexinzhao 已提交
197
namespace platform {
K
Kexin Zhao 已提交
198

199 200
#if defined(PADDLE_WITH_HIP) || \
    (defined(PADDLE_WITH_CUDA) && CUDA_VERSION < 9000)
K
Kexin Zhao 已提交
201 202 203 204
// Instantiate one kernel per binary arithmetic operator on half: each
// ARITHMETIC_KERNEL(Name, op) expands to a __global__ function `Name` that
// computes out[0] = in1[0] op in2[0].
ARITHMETIC_KERNEL(Add, +)
ARITHMETIC_KERNEL(Sub, -)
ARITHMETIC_KERNEL(Mul, *)
ARITHMETIC_KERNEL(Div, /)
K
Kexin Zhao 已提交
205

K
Kexin Zhao 已提交
206 207 208 209
// Define the host-side drivers TestAdd, TestSub, TestMul and TestDiv; each
// ARITHMETIC_KERNEL_LAUNCH(Name) expands to a function Test<Name> that runs
// the kernel `Name` and checks its result.
ARITHMETIC_KERNEL_LAUNCH(Add)
ARITHMETIC_KERNEL_LAUNCH(Sub)
ARITHMETIC_KERNEL_LAUNCH(Mul)
ARITHMETIC_KERNEL_LAUNCH(Div)
K
Kexin Zhao 已提交
210

K
Kexin Zhao 已提交
211
// Unary-minus kernel: replaces in[0] with its negation. Only element 0 is
// touched, so it is meant to be launched with a single thread.
__global__ void Neg(half *in) {
  const half negated = -in[0];
  in[0] = negated;
}
K
Kexin Zhao 已提交
213

K
Kexin Zhao 已提交
214 215
// Runs the Neg kernel on a single half value and checks the result.
//
// v_in  - value converted to half and negated on the device.
// v_out - expected value after converting the device result back to float.
//
// Every runtime call is checked so that an allocation, copy or launch failure
// is reported where it happens. The host value lives on the stack, so no host
// allocation can fail or leak. On HIP the kernel is launched through
// hipLaunchKernelGGL, matching the other HIP launches in this file.
void TestNeg(float v_in, float v_out) {
  LOG(INFO) << "Test Neg on GPU!";
  const size_t size = sizeof(half);
  half in = half(float16(v_in));
  half *d_in = nullptr;
#ifdef PADDLE_WITH_HIP
  // Nothing has been allocated yet, so a fatal failure here cannot leak.
  ASSERT_EQ(hipMalloc(reinterpret_cast<void **>(&d_in), size), hipSuccess);
  EXPECT_EQ(hipMemcpy(d_in, &in, size, hipMemcpyHostToDevice), hipSuccess);
  hipLaunchKernelGGL(Neg, dim3(1), dim3(1), 0, 0, d_in);
  // A launch does not return a status; query it explicitly.
  EXPECT_EQ(hipGetLastError(), hipSuccess);
  // The blocking copy also waits for the kernel to finish.
  EXPECT_EQ(hipMemcpy(&in, d_in, size, hipMemcpyDeviceToHost), hipSuccess);
  EXPECT_EQ(hipFree(d_in), hipSuccess);
#else
  // Nothing has been allocated yet, so a fatal failure here cannot leak.
  ASSERT_EQ(cudaMalloc(reinterpret_cast<void **>(&d_in), size), cudaSuccess);
  EXPECT_EQ(cudaMemcpy(d_in, &in, size, cudaMemcpyHostToDevice), cudaSuccess);
  Neg<<<1, 1>>>(d_in);
  // A launch does not return a status; query it explicitly.
  EXPECT_EQ(cudaGetLastError(), cudaSuccess);
  // The blocking copy also waits for the kernel to finish.
  EXPECT_EQ(cudaMemcpy(&in, d_in, size, cudaMemcpyDeviceToHost), cudaSuccess);
  EXPECT_EQ(cudaFree(d_in), cudaSuccess);
#endif
  EXPECT_EQ(static_cast<float>(float16(in)), v_out);
}
K
Kexin Zhao 已提交
244

K
Kexin Zhao 已提交
245 246 247 248
// Instantiate one kernel per compound-assignment operator on half: each
// COMPOUND_KERNEL(Name, op) expands to a __global__ function `Name` that
// executes in1[0] op in2[0].
COMPOUND_KERNEL(AddAssign, +=)
COMPOUND_KERNEL(SubAssign, -=)
COMPOUND_KERNEL(MulAssign, *=)
COMPOUND_KERNEL(DivAssign, /=)
K
Kexin Zhao 已提交
249

K
Kexin Zhao 已提交
250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274
// Host-side drivers Test<Name> for the compound-assignment kernels.
COMPOUND_KERNEL_LAUNCH(AddAssign)
COMPOUND_KERNEL_LAUNCH(SubAssign)
COMPOUND_KERNEL_LAUNCH(MulAssign)
COMPOUND_KERNEL_LAUNCH(DivAssign)

// One kernel per relational operator on half; each writes a bool to out[0].
COMPARISON_KERNEL(Equal, ==)
COMPARISON_KERNEL(NotEqual, !=)
COMPARISON_KERNEL(Less, <)
COMPARISON_KERNEL(LessEqual, <=)
COMPARISON_KERNEL(Greater, >)
COMPARISON_KERNEL(GreaterEqual, >=)

// Host-side drivers Test<Name> for the comparison kernels.
COMPARISON_KERNEL_LAUNCH(Equal)
COMPARISON_KERNEL_LAUNCH(NotEqual)
COMPARISON_KERNEL_LAUNCH(Less)
COMPARISON_KERNEL_LAUNCH(LessEqual)
COMPARISON_KERNEL_LAUNCH(Greater)
COMPARISON_KERNEL_LAUNCH(GreaterEqual)

// Exercises +, -, *, / and unary minus on half values on the device.
// Each call passes (lhs, rhs, expected); TestNeg passes (input, expected).
TEST(float16, arithmetic_on_gpu) {
  TestAdd(1.0f, 2.0f, 3.0f);
  TestSub(2.0f, 1.0f, 1.0f);
  TestMul(2.0f, 3.0f, 6.0f);
  TestDiv(6.0f, 2.0f, 3.0f);
  TestNeg(1.0f, -1.0f);
}

K
Kexin Zhao 已提交
277 278 279 280 281 282
// Exercises +=, -=, *= and /= on half values on the device.
// Each call passes (target, operand, expected value of target afterwards).
TEST(float16, compound_on_gpu) {
  TestAddAssign(1.0f, 2.0f, 3.0f);
  TestSubAssign(2.0f, 1.0f, 1.0f);
  TestMulAssign(2.0f, 3.0f, 6.0f);
  TestDivAssign(6.0f, 2.0f, 3.0f);
}
K
Kexin Zhao 已提交
283

K
Kexin Zhao 已提交
284 285 286 287 288 289 290 291 292 293 294 295 296 297
// Exercises the six relational operators on half values on the device, with
// one case that must hold and one that must not for each operator.
TEST(float16, comparision_on_gpu) {
  // ==
  TestEqual(1.0f, 1.0f, true);
  TestEqual(1.0f, 2.0f, false);
  // !=
  TestNotEqual(2.0f, 3.0f, true);
  TestNotEqual(2.0f, 2.0f, false);
  // <
  TestLess(3.0f, 4.0f, true);
  TestLess(3.0f, 3.0f, false);
  // <=
  TestLessEqual(3.0f, 3.0f, true);
  TestLessEqual(3.0f, 2.0f, false);
  // >
  TestGreater(4.0f, 3.0f, true);
  TestGreater(4.0f, 4.0f, false);
  // >=
  TestGreaterEqual(4.0f, 4.0f, true);
  TestGreaterEqual(4.0f, 5.0f, false);
}
298
#endif  // CUDA_VERSION
K
Kexin Zhao 已提交
299 300 301 302 303 304 305 306 307 308 309 310 311

// Checks that converting float16 -> half -> float16 preserves the expected
// 16-bit patterns, and that a half can be assigned to an existing float16.
TEST(float16, conversion_on_gpu) {
  // Bits of a float after the float16 -> half -> float16 round trip.
  auto roundtrip_bits = [](float value) {
    return float16(half(float16(value))).x;
  };

  // Explicit conversion to and from cuda half
  EXPECT_EQ(roundtrip_bits(1.0f), 0x3c00);
  EXPECT_EQ(roundtrip_bits(0.5f), 0x3800);
  EXPECT_EQ(roundtrip_bits(0.33333f), 0x3555);
  EXPECT_EQ(roundtrip_bits(0.0f), 0x0000);
  EXPECT_EQ(roundtrip_bits(-0.0f), 0x8000);
  EXPECT_EQ(roundtrip_bits(65504.0f), 0x7bff);
  EXPECT_EQ(roundtrip_bits(65536.0f), 0x7c00);

  // Assignment operator
  const half one = half(float16(1.0f));
  float16 assigned;
  assigned = one;
  EXPECT_EQ(assigned.x, 0x3c00);
}
K
Kexin Zhao 已提交
315

K
kexinzhao 已提交
316 317 318 319 320
// Copies a 2x2 float16 LoDTensor from the CPU to GPU 0 and back, then checks
// that the returned tensor uses different storage and holds identical bits.
TEST(float16, lod_tensor_on_gpu) {
  framework::LoDTensor host_src;
  framework::LoDTensor device_copy;
  framework::LoDTensor host_dst;

  // Fill the source tensor on the host.
  const float16 values[4] = {float16(1.0f), float16(0.5f), float16(0.33333f),
                             float16(0.0f)};
  float16 *src_data = host_src.mutable_data<float16>(
      framework::make_ddim({2, 2}), CPUPlace());
  memcpy(src_data, values, 4 * sizeof(float16));

  CUDAPlace gpu_place(0);
  CUDADeviceContext gpu_ctx(gpu_place);

  // CPU LoDTensor to GPU LoDTensor
  framework::TensorCopy(host_src, gpu_place, gpu_ctx, &device_copy);

  // GPU LoDTensor to CPU LoDTensor
  framework::TensorCopy(device_copy, CPUPlace(), gpu_ctx, &host_dst);

  // Sync before comparing LoDTensors
  gpu_ctx.Wait();

  const float16 *dst_data = host_dst.data<float16>();
  ASSERT_NE(src_data, dst_data);
  for (size_t idx = 0; idx < 4; ++idx) {
    EXPECT_EQ(src_data[idx].x, dst_data[idx].x);
  }
}

345 346
// Callable that reports whether its template argument T is
// platform::float16, decided by comparing std::type_index values. The
// argument itself is not inspected.
template <typename T>
struct Functor {
  bool operator()(const T &val) {
    const std::type_index actual(typeid(T));
    const std::type_index wanted(typeid(platform::float16));
    return actual == wanted;
  }
};

// Checks that typeid can tell float16 apart from another type.
TEST(float16, typeid) {
  // the framework heavily used typeid hash
  Functor<float16> functor;
  Functor<int> functor2;
  float16 a(.0f);
  int b = 0;

  // Checked at run time through PADDLE_ENFORCE_EQ rather than a gtest
  // assertion: float16 must match itself and int must not match float16.
  PADDLE_ENFORCE_EQ(
      functor(a), true,
      platform::errors::Unavailable("The float16 support in GPU failed."));
  PADDLE_ENFORCE_EQ(
      functor2(b), false,
      platform::errors::Unavailable("The float16 support in GPU failed."));
}

// GPU test
// Checks std::isinf on float16 values and the behaviour of out-of-range
// float inputs: too small becomes zero, too large becomes infinity.
TEST(float16, isinf) {
  // Infinity written directly as a bit pattern.
  float16 inf_from_bits;
  inf_from_bits.x = 0x7c00;
  // Infinity converted from the INFINITY macro.
  float16 inf_from_macro = float16(INFINITY);
  // underflow to 0
  float16 underflowed(5e-40f);

  EXPECT_EQ(std::isinf(inf_from_bits), true);
  EXPECT_EQ(std::isinf(inf_from_macro), true);
#ifndef _WIN32
  // overflow to inf
  float16 overflowed(5e40f);
  EXPECT_EQ(std::isinf(overflowed), true);
#endif
  EXPECT_EQ(underflowed, float16(0));
}

// Checks std::isnan on float16 values obtained three ways: from a raw bit
// pattern, from the NAN macro, and from multiplying an overflowed value by 0.
TEST(float16, isnan) {
  float16 nan_from_bits;
  nan_from_bits.x = 0x7fff;

  float16 nan_from_macro = float16(NAN);

  // inf * +-0 will get a nan
  float16 overflowed = float16(5e40);
  float16 inf_times_zero = overflowed * float16(0);

  EXPECT_EQ(std::isnan(nan_from_bits), true);
  EXPECT_EQ(std::isnan(nan_from_macro), true);
  EXPECT_EQ(std::isnan(inf_times_zero), true);
}

// Checks that a float16 keeps its value when it is re-viewed through another
// reference type and when its bits are stored in a wider integer.
TEST(float16, cast) {
  float16 a;
  a.x = 0x0070;
  auto b = a;
  {
    // change semantic, keep the same value
    // The intermediate unsigned reference is never read; it is cast straight
    // back to float16 before the copy.
    float16 c = reinterpret_cast<float16 &>(reinterpret_cast<unsigned &>(b));
    EXPECT_EQ(b, c);
  }

  {
    // use uint32 low 16 bit store float16
    // Widen the 16-bit payload explicitly. Reading a uint32_t through a
    // reference to `b` would load 32 bits from an object that (presumably)
    // only holds the 16-bit field x, i.e. an out-of-bounds read and an
    // aliasing violation.
    uint32_t c = b.x;
    float16 d;
    d.x = static_cast<decltype(d.x)>(c);
    EXPECT_EQ(b, d);
  }
}

K
kexinzhao 已提交
417
}  // namespace platform
K
Kexin Zhao 已提交
418
}  // namespace paddle
419
#endif  // PADDLE_CUDA_FP16