/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/platform/float16.h"

#define GLOG_NO_ABBREVIATED_SEVERITIES  // msvc conflict logging with windows.h
#include <glog/logging.h>
#include <gtest/gtest.h>

#include <bitset>
#include <cmath>
#include <cstdint>
#include <cstring>
#include <iostream>
#include <typeindex>
#include <typeinfo>

#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/platform/enforce.h"

// Defines a kernel `op_type` that writes `in1[0] sign in2[0]` to `out[0]`,
// exercising the device-side binary operator `sign` on `half`. Only element 0
// of each pointer is touched, so a single-thread <<<1, 1>>> launch suffices.
#define ARITHMETIC_KERNEL(op_type, sign)                                 \
  __global__ void op_type(const half* in1, const half* in2, half* out) { \
    out[0] = in1[0] sign in2[0];                                         \
  }

// Defines a kernel `op_type` that applies the compound assignment `sign`
// in place, i.e. `in1[0] sign in2[0]`. Only element 0 is touched.
#define COMPOUND_KERNEL(op_type, sign) \
  __global__ void op_type(half* in1, const half* in2) { in1[0] sign in2[0]; }

// Defines a kernel `op_type` that stores the bool result of the half
// comparison `in1[0] sign in2[0]` into `out[0]`. Only element 0 is touched.
#define COMPARISON_KERNEL(op_type, sign)                                 \
  __global__ void op_type(const half* in1, const half* in2, bool* out) { \
    out[0] = in1[0] sign in2[0];                                         \
  }

// Defines `void Test<op_type>(v_in1, v_in2, v_out)`, the host driver for the
// ARITHMETIC_KERNEL of the same name. It uploads float16(v_in1) and
// float16(v_in2) as `half`, runs the kernel on one thread, and expects the
// result, converted back to float, to equal `v_out` exactly. Every CUDA call
// and the launch are checked, so a device failure is reported as such rather
// than surfacing as a confusing value mismatch.
#define ARITHMETIC_KERNEL_LAUNCH(op_type)                                    \
  void Test##op_type(float v_in1, float v_in2, float v_out) {                \
    LOG(INFO) << "Test " << #op_type << " on GPU!";                          \
    const size_t size = sizeof(half);                                        \
    half in1 = half(float16(v_in1));                                         \
    half in2 = half(float16(v_in2));                                         \
    half out = half(float16(0.0f));                                          \
    half *d_in1 = nullptr, *d_in2 = nullptr, *d_out = nullptr;               \
    EXPECT_EQ(cudaMalloc(reinterpret_cast<void**>(&d_in1), size),            \
              cudaSuccess);                                                  \
    EXPECT_EQ(cudaMalloc(reinterpret_cast<void**>(&d_in2), size),            \
              cudaSuccess);                                                  \
    EXPECT_EQ(cudaMalloc(reinterpret_cast<void**>(&d_out), size),            \
              cudaSuccess);                                                  \
    EXPECT_EQ(cudaMemcpy(d_in1, &in1, size, cudaMemcpyHostToDevice),         \
              cudaSuccess);                                                  \
    EXPECT_EQ(cudaMemcpy(d_in2, &in2, size, cudaMemcpyHostToDevice),         \
              cudaSuccess);                                                  \
    op_type<<<1, 1>>>(d_in1, d_in2, d_out);                                  \
    EXPECT_EQ(cudaGetLastError(), cudaSuccess);                              \
    EXPECT_EQ(cudaMemcpy(&out, d_out, size, cudaMemcpyDeviceToHost),         \
              cudaSuccess);                                                  \
    EXPECT_EQ(static_cast<float>(float16(out)), v_out);                      \
    EXPECT_EQ(cudaFree(d_in1), cudaSuccess);                                 \
    EXPECT_EQ(cudaFree(d_in2), cudaSuccess);                                 \
    EXPECT_EQ(cudaFree(d_out), cudaSuccess);                                 \
  }

// Defines `void Test<op_type>(v_in1, v_in2, v_out)`, the host driver for the
// COMPOUND_KERNEL of the same name. It uploads float16(v_in1) and
// float16(v_in2) as `half`, runs the in-place kernel on one thread, reads the
// updated first operand back, and expects it (as float) to equal `v_out`
// exactly. Every CUDA call and the launch are checked.
#define COMPOUND_KERNEL_LAUNCH(op_type)                                      \
  void Test##op_type(float v_in1, float v_in2, float v_out) {                \
    LOG(INFO) << "Test " << #op_type << " on GPU!";                          \
    const size_t size = sizeof(half);                                        \
    half in1 = half(float16(v_in1));                                         \
    half in2 = half(float16(v_in2));                                         \
    half *d_in1 = nullptr, *d_in2 = nullptr;                                 \
    EXPECT_EQ(cudaMalloc(reinterpret_cast<void**>(&d_in1), size),            \
              cudaSuccess);                                                  \
    EXPECT_EQ(cudaMalloc(reinterpret_cast<void**>(&d_in2), size),            \
              cudaSuccess);                                                  \
    EXPECT_EQ(cudaMemcpy(d_in1, &in1, size, cudaMemcpyHostToDevice),         \
              cudaSuccess);                                                  \
    EXPECT_EQ(cudaMemcpy(d_in2, &in2, size, cudaMemcpyHostToDevice),         \
              cudaSuccess);                                                  \
    op_type<<<1, 1>>>(d_in1, d_in2);                                         \
    EXPECT_EQ(cudaGetLastError(), cudaSuccess);                              \
    EXPECT_EQ(cudaMemcpy(&in1, d_in1, size, cudaMemcpyDeviceToHost),         \
              cudaSuccess);                                                  \
    EXPECT_EQ(static_cast<float>(float16(in1)), v_out);                      \
    EXPECT_EQ(cudaFree(d_in1), cudaSuccess);                                 \
    EXPECT_EQ(cudaFree(d_in2), cudaSuccess);                                 \
  }

// Defines `void Test<op_type>(v_in1, v_in2, v_out)`, the host driver for the
// COMPARISON_KERNEL of the same name. It uploads float16(v_in1) and
// float16(v_in2) as `half`, runs the kernel on one thread, and expects the
// bool it wrote to equal `v_out`. The host result starts as `!v_out`, so a
// failed read-back can never pass by accident. Buffer sizes use sizeof(bool)
// rather than assuming it is 1. Every CUDA call and the launch are checked.
#define COMPARISON_KERNEL_LAUNCH(op_type)                                    \
  void Test##op_type(float v_in1, float v_in2, bool v_out) {                 \
    LOG(INFO) << "Test " << #op_type << " on GPU!";                          \
    const size_t size = sizeof(half);                                        \
    half in1 = half(float16(v_in1));                                         \
    half in2 = half(float16(v_in2));                                         \
    bool out = !v_out;                                                       \
    half *d_in1 = nullptr, *d_in2 = nullptr;                                 \
    bool* d_out = nullptr;                                                   \
    EXPECT_EQ(cudaMalloc(reinterpret_cast<void**>(&d_in1), size),            \
              cudaSuccess);                                                  \
    EXPECT_EQ(cudaMalloc(reinterpret_cast<void**>(&d_in2), size),            \
              cudaSuccess);                                                  \
    EXPECT_EQ(cudaMalloc(reinterpret_cast<void**>(&d_out), sizeof(bool)),    \
              cudaSuccess);                                                  \
    EXPECT_EQ(cudaMemcpy(d_in1, &in1, size, cudaMemcpyHostToDevice),         \
              cudaSuccess);                                                  \
    EXPECT_EQ(cudaMemcpy(d_in2, &in2, size, cudaMemcpyHostToDevice),         \
              cudaSuccess);                                                  \
    op_type<<<1, 1>>>(d_in1, d_in2, d_out);                                  \
    EXPECT_EQ(cudaGetLastError(), cudaSuccess);                              \
    EXPECT_EQ(cudaMemcpy(&out, d_out, sizeof(bool), cudaMemcpyDeviceToHost), \
              cudaSuccess);                                                  \
    EXPECT_EQ(out, v_out);                                                   \
    EXPECT_EQ(cudaFree(d_in1), cudaSuccess);                                 \
    EXPECT_EQ(cudaFree(d_in2), cudaSuccess);                                 \
    EXPECT_EQ(cudaFree(d_out), cudaSuccess);                                 \
  }

#ifdef PADDLE_CUDA_FP16
namespace paddle {
namespace platform {

#if CUDA_VERSION < 9000
// Half-precision binary arithmetic kernels and their host test drivers
// (TestAdd, TestSub, TestMul, TestDiv).
ARITHMETIC_KERNEL(Add, +)
ARITHMETIC_KERNEL(Sub, -)
ARITHMETIC_KERNEL(Mul, *)
ARITHMETIC_KERNEL(Div, /)

ARITHMETIC_KERNEL_LAUNCH(Add)
ARITHMETIC_KERNEL_LAUNCH(Sub)
ARITHMETIC_KERNEL_LAUNCH(Mul)
ARITHMETIC_KERNEL_LAUNCH(Div)

// Negative sign kernel: negates in[0] in place using the device-side unary
// minus on half. Only element 0 is touched (single-thread launch).
__global__ void Neg(half* in) { in[0] = -in[0]; }

// Host driver for the Neg kernel: uploads float16(v_in) as `half`, negates it
// on the device, and expects the value read back (as float) to equal `v_out`
// exactly. Every CUDA call and the launch are checked.
void TestNeg(float v_in, float v_out) {
  LOG(INFO) << "Test Neg on GPU!";
  const size_t size = sizeof(half);
  half in = half(float16(v_in));
  half* d_in = nullptr;
  EXPECT_EQ(cudaMalloc(reinterpret_cast<void**>(&d_in), size), cudaSuccess);
  EXPECT_EQ(cudaMemcpy(d_in, &in, size, cudaMemcpyHostToDevice), cudaSuccess);
  Neg<<<1, 1>>>(d_in);
  EXPECT_EQ(cudaGetLastError(), cudaSuccess);
  EXPECT_EQ(cudaMemcpy(&in, d_in, size, cudaMemcpyDeviceToHost), cudaSuccess);
  EXPECT_EQ(static_cast<float>(float16(in)), v_out);
  EXPECT_EQ(cudaFree(d_in), cudaSuccess);
}

// Half-precision compound-assignment kernels and their host test drivers
// (TestAddAssign, TestSubAssign, TestMulAssign, TestDivAssign).
COMPOUND_KERNEL(AddAssign, +=)
COMPOUND_KERNEL(SubAssign, -=)
COMPOUND_KERNEL(MulAssign, *=)
COMPOUND_KERNEL(DivAssign, /=)

COMPOUND_KERNEL_LAUNCH(AddAssign)
COMPOUND_KERNEL_LAUNCH(SubAssign)
COMPOUND_KERNEL_LAUNCH(MulAssign)
COMPOUND_KERNEL_LAUNCH(DivAssign)

// Half-precision comparison kernels, each immediately followed by its host
// test driver (Test<Name>); a driver only needs its own kernel declared first.
COMPARISON_KERNEL(Equal, ==)
COMPARISON_KERNEL_LAUNCH(Equal)

COMPARISON_KERNEL(NotEqual, !=)
COMPARISON_KERNEL_LAUNCH(NotEqual)

COMPARISON_KERNEL(Less, <)
COMPARISON_KERNEL_LAUNCH(Less)

COMPARISON_KERNEL(LessEqual, <=)
COMPARISON_KERNEL_LAUNCH(LessEqual)

COMPARISON_KERNEL(Greater, >)
COMPARISON_KERNEL_LAUNCH(Greater)

COMPARISON_KERNEL(GreaterEqual, >=)
COMPARISON_KERNEL_LAUNCH(GreaterEqual)

// Device-side half arithmetic through the TestAdd/TestSub/TestMul/TestDiv/
// TestNeg drivers. All operands and results are small integers, which are
// exactly representable in float16, so exact equality is appropriate.
TEST(float16, arithmetic_on_gpu) {
  TestAdd(1, 2, 3);
  TestSub(2, 1, 1);
  TestMul(2, 3, 6);
  TestDiv(6, 2, 3);
  TestNeg(1, -1);
}

// Device-side half compound assignment (+=, -=, *=, /=) on small integers
// that are exactly representable in float16.
TEST(float16, compound_on_gpu) {
  TestAddAssign(1, 2, 3);
  TestSubAssign(2, 1, 1);
  TestMulAssign(2, 3, 6);
  TestDivAssign(6, 2, 3);
}

// Device-side half comparisons; each operator is checked once where it holds
// and once where it does not. (The test name's spelling is kept as-is so
// existing gtest filters keep matching.)
TEST(float16, comparision_on_gpu) {
  TestEqual(1, 1, true);
  TestEqual(1, 2, false);
  TestNotEqual(2, 3, true);
  TestNotEqual(2, 2, false);
  TestLess(3, 4, true);
  TestLess(3, 3, false);
  TestLessEqual(3, 3, true);
  TestLessEqual(3, 2, false);
  TestGreater(4, 3, true);
  TestGreater(4, 4, false);
  TestGreaterEqual(4, 4, true);
  TestGreaterEqual(4, 5, false);
}
#endif  // CUDA_VERSION

// Host-side checks (no kernel is launched) of explicit float16 <-> half
// conversion and of assigning a half to a float16. Expected values are raw
// 16-bit patterns of the stored half-precision value.
TEST(float16, conversion_on_gpu) {
  // Explicit conversion to and from cuda half
  EXPECT_EQ(float16(half(float16(1.0f))).x, 0x3c00);
  EXPECT_EQ(float16(half(float16(0.5f))).x, 0x3800);
  // 0.33333f is not exactly representable; 0x3555 is the nearest value.
  EXPECT_EQ(float16(half(float16(0.33333f))).x, 0x3555);
  EXPECT_EQ(float16(half(float16(0.0f))).x, 0x0000);
  // Negative zero must keep its sign bit.
  EXPECT_EQ(float16(half(float16(-0.0f))).x, 0x8000);
  // 65504 is the largest finite half; 65536 overflows to +inf (0x7c00).
  EXPECT_EQ(float16(half(float16(65504.0f))).x, 0x7bff);
  EXPECT_EQ(float16(half(float16(65536.0f))).x, 0x7c00);

  // Assignment operator
  float16 v_assign;
  v_assign = half(float16(1.0f));
  EXPECT_EQ(v_assign.x, 0x3c00);
}

// Round-trips a 2x2 float16 LoDTensor CPU -> GPU -> CPU with
// framework::TensorCopy and checks every element's bit pattern is unchanged.
TEST(float16, lod_tensor_on_gpu) {
  framework::LoDTensor src_tensor;
  framework::LoDTensor gpu_tensor;
  framework::LoDTensor dst_tensor;

  float16* src_ptr = src_tensor.mutable_data<float16>(
      framework::make_ddim({2, 2}), CPUPlace());

  float16 arr[4] = {float16(1.0f), float16(0.5f), float16(0.33333f),
                    float16(0.0f)};
  const size_t kNumel = sizeof(arr) / sizeof(arr[0]);
  memcpy(src_ptr, arr, kNumel * sizeof(float16));

  // CPU LoDTensor to GPU LoDTensor
  CUDAPlace gpu_place(0);
  CUDADeviceContext gpu_ctx(gpu_place);
  framework::TensorCopy(src_tensor, gpu_place, gpu_ctx, &gpu_tensor);

  // GPU LoDTensor to CPU LoDTensor
  framework::TensorCopy(gpu_tensor, CPUPlace(), gpu_ctx, &dst_tensor);

  // Both copies were issued on gpu_ctx; wait before reading dst_tensor.
  gpu_ctx.Wait();
  // Make sure the destination really holds kNumel elements before the
  // element-wise loop below dereferences them.
  ASSERT_EQ(dst_tensor.numel(), static_cast<int64_t>(kNumel));
  const float16* dst_ptr = dst_tensor.data<float16>();
  ASSERT_NE(src_ptr, dst_ptr);
  for (size_t i = 0; i < kNumel; ++i) {
    EXPECT_EQ(src_ptr[i].x, dst_ptr[i].x);
  }
}

// Stateless predicate: returns true iff the template argument T is
// platform::float16, decided by comparing std::type_index(typeid(...)).
template <typename T>
struct Functor {
  // The argument is ignored; the answer depends only on T.
  bool operator()(const T& /*val*/) const {
    return std::type_index(typeid(T)) ==
           std::type_index(typeid(platform::float16));
  }
};

// The framework relies heavily on typeid-based type identification, so the
// type_index of float16 must match platform::float16 and differ from int.
// These are runtime checks reported through gtest.
TEST(float16, typeid) {
  Functor<float16> functor;
  float16 a = float16(.0f);
  Functor<int> functor2;
  int b(0);

  EXPECT_TRUE(functor(a)) << "typeid(float16) does not match platform::float16";
  EXPECT_FALSE(functor2(b)) << "typeid(int) wrongly matches platform::float16";
}

// Host-side checks of std::isinf on float16 (no kernel is launched here).
TEST(float16, isinf) {
  float16 a;
  a.x = 0x7c00;  // +inf bit pattern
  float16 b = float16(INFINITY);
  // underflow to 0: 5e-40f is far too small for float16 (checked below)
  float16 native_a(5e-40f);
  EXPECT_TRUE(std::isinf(a));
  EXPECT_TRUE(std::isinf(b));
#ifndef _WIN32
  // overflow to inf: 5e40f already exceeds the float range. NOTE(review):
  // presumably excluded on Windows because MSVC rejects the out-of-range
  // float literal -- TODO confirm.
  float16 native_b(5e40f);
  EXPECT_TRUE(std::isinf(native_b));
#endif
  EXPECT_EQ(native_a, float16(0));
}

// std::isnan must recognise float16 NaNs built from raw bits, from the NAN
// macro, and produced by arithmetic (infinity times zero).
TEST(float16, isnan) {
  float16 raw_bits_nan;
  raw_bits_nan.x = 0x7fff;
  float16 macro_nan = float16(NAN);
  // 5e40 (a double literal) is far beyond the float16 range, so this is
  // expected to be +inf; inf * +-0 gives a NaN.
  float16 too_big = float16(5e40);
  float16 inf_times_zero = too_big * float16(0);

  EXPECT_EQ(std::isnan(raw_bits_nan), true);
  EXPECT_EQ(std::isnan(macro_nan), true);
  EXPECT_EQ(std::isnan(inf_times_zero), true);
}

// float16 must be bit-for-bit interchangeable with its 16-bit payload:
// re-typing a reference keeps the value, and the payload can round-trip
// through the low 16 bits of a uint32_t.
TEST(float16, cast) {
  static_assert(sizeof(float16) == sizeof(uint16_t),
                "float16 is expected to be exactly its 16-bit payload");
  float16 a;
  a.x = 0x0070;
  auto b = a;
  {
    // change semantic, keep the same value
    float16 c = reinterpret_cast<float16&>(reinterpret_cast<unsigned&>(b));
    EXPECT_EQ(b, c);
  }

  {
    // use uint32 low 16 bit store float16. The bits are copied out with
    // memcpy into a 16-bit integer: reading a 2-byte float16 through a
    // uint32_t& would read past the object (undefined behavior, and it only
    // happens to yield the right low bits on little-endian targets).
    uint16_t bits = 0;
    std::memcpy(&bits, &b, sizeof(bits));
    uint32_t c = bits;
    float16 d;
    d.x = static_cast<uint16_t>(c);
    EXPECT_EQ(b, d);
  }
}

}  // namespace platform
}  // namespace paddle
#endif  // PADDLE_CUDA_FP16