/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/platform/float16.h"

#include <gtest/gtest.h>

#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/legacy/utils/Logging.h"

// Defines a __global__ kernel `op_type(in1, in2, out)` that computes
// out[0] = in1[0] <sign> in2[0] on `half` values. Only element 0 is read or
// written, so a single-thread launch is sufficient.
#define ARITHMETIC_KERNEL(op_type, sign)                                 \
  __global__ void op_type(const half* in1, const half* in2, half* out) { \
    out[0] = in1[0] sign in2[0];                                         \
  }

// Defines a __global__ kernel `op_type(lhs, rhs)` that applies the compound
// assignment operator `sign` (e.g. +=) in place: lhs[0] <sign> rhs[0].
// Only the first element of each buffer is touched.
#define COMPOUND_KERNEL(op_type, sign)                  \
  __global__ void op_type(half* lhs, const half* rhs) { \
    *lhs sign *rhs;                                     \
  }

// Defines a __global__ kernel `op_type(in1, in2, out)` that stores the bool
// result of in1[0] <sign> in2[0] (compared as `half` values) into out[0].
// Only element 0 is read or written, so a single-thread launch is sufficient.
#define COMPARISON_KERNEL(op_type, sign)                                 \
  __global__ void op_type(const half* in1, const half* in2, bool* out) { \
    out[0] = in1[0] sign in2[0];                                         \
  }

// Defines `void Test<op_type>(v_in1, v_in2, v_out)`: copies the two operands
// to the device as `half`, runs the `op_type` kernel with a single thread,
// copies the result back and expects it to equal `v_out` once converted to
// float. Every CUDA call (including the launch, via cudaGetLastError) is
// checked so a failed allocation/copy/launch is reported directly instead of
// surfacing as a comparison against garbage. The device-to-host cudaMemcpy is
// blocking, so it also waits for the kernel to finish.
#define ARITHMETIC_KERNEL_LAUNCH(op_type)                                     \
  void Test##op_type(float v_in1, float v_in2, float v_out) {                 \
    LOG(INFO) << "Test " << #op_type << " on GPU!";                           \
    const size_t size = sizeof(half);                                         \
    half in1 = half(float16(v_in1));                                          \
    half in2 = half(float16(v_in2));                                          \
    half out = half(float16(0.0f));                                           \
    half *d_in1 = nullptr, *d_in2 = nullptr, *d_out = nullptr;                \
    EXPECT_EQ(cudaSuccess,                                                    \
              cudaMalloc(reinterpret_cast<void**>(&d_in1), size));            \
    EXPECT_EQ(cudaSuccess,                                                    \
              cudaMalloc(reinterpret_cast<void**>(&d_in2), size));            \
    EXPECT_EQ(cudaSuccess,                                                    \
              cudaMalloc(reinterpret_cast<void**>(&d_out), size));            \
    EXPECT_EQ(cudaSuccess,                                                    \
              cudaMemcpy(d_in1, &in1, size, cudaMemcpyHostToDevice));         \
    EXPECT_EQ(cudaSuccess,                                                    \
              cudaMemcpy(d_in2, &in2, size, cudaMemcpyHostToDevice));         \
    op_type<<<1, 1>>>(d_in1, d_in2, d_out);                                   \
    EXPECT_EQ(cudaSuccess, cudaGetLastError());                               \
    EXPECT_EQ(cudaSuccess,                                                    \
              cudaMemcpy(&out, d_out, size, cudaMemcpyDeviceToHost));         \
    EXPECT_EQ(static_cast<float>(float16(out)), v_out);                       \
    EXPECT_EQ(cudaSuccess, cudaFree(d_in1));                                  \
    EXPECT_EQ(cudaSuccess, cudaFree(d_in2));                                  \
    EXPECT_EQ(cudaSuccess, cudaFree(d_out));                                  \
  }

// Defines `void Test<op_type>(v_in1, v_in2, v_out)`: copies both operands to
// the device as `half`, runs the in-place `op_type` kernel with a single
// thread, copies the updated first operand back and expects it to equal
// `v_out` once converted to float. Every CUDA call (including the launch, via
// cudaGetLastError) is checked so failures are reported directly. The
// device-to-host cudaMemcpy is blocking, so it also waits for the kernel.
#define COMPOUND_KERNEL_LAUNCH(op_type)                                       \
  void Test##op_type(float v_in1, float v_in2, float v_out) {                 \
    LOG(INFO) << "Test " << #op_type << " on GPU!";                           \
    const size_t size = sizeof(half);                                         \
    half in1 = half(float16(v_in1));                                          \
    half in2 = half(float16(v_in2));                                          \
    half *d_in1 = nullptr, *d_in2 = nullptr;                                  \
    EXPECT_EQ(cudaSuccess,                                                    \
              cudaMalloc(reinterpret_cast<void**>(&d_in1), size));            \
    EXPECT_EQ(cudaSuccess,                                                    \
              cudaMalloc(reinterpret_cast<void**>(&d_in2), size));            \
    EXPECT_EQ(cudaSuccess,                                                    \
              cudaMemcpy(d_in1, &in1, size, cudaMemcpyHostToDevice));         \
    EXPECT_EQ(cudaSuccess,                                                    \
              cudaMemcpy(d_in2, &in2, size, cudaMemcpyHostToDevice));         \
    op_type<<<1, 1>>>(d_in1, d_in2);                                          \
    EXPECT_EQ(cudaSuccess, cudaGetLastError());                               \
    EXPECT_EQ(cudaSuccess,                                                    \
              cudaMemcpy(&in1, d_in1, size, cudaMemcpyDeviceToHost));         \
    EXPECT_EQ(static_cast<float>(float16(in1)), v_out);                       \
    EXPECT_EQ(cudaSuccess, cudaFree(d_in1));                                  \
    EXPECT_EQ(cudaSuccess, cudaFree(d_in2));                                  \
  }

// Defines `void Test<op_type>(v_in1, v_in2, v_out)` for comparison kernels:
// copies both operands to the device as `half`, runs `op_type` with a single
// thread and expects the bool result to equal `v_out`. The bool buffers are
// sized with sizeof(bool) instead of a hard-coded 1 byte. The host result
// starts as `!v_out`, so the check cannot pass unless the device value is
// actually copied back. Every CUDA call (including the launch, via
// cudaGetLastError) is checked; the device-to-host cudaMemcpy is blocking.
#define COMPARISON_KERNEL_LAUNCH(op_type)                                     \
  void Test##op_type(float v_in1, float v_in2, bool v_out) {                  \
    LOG(INFO) << "Test " << #op_type << " on GPU!";                           \
    const size_t size = sizeof(half);                                         \
    half in1 = half(float16(v_in1));                                          \
    half in2 = half(float16(v_in2));                                          \
    bool out = !v_out;                                                        \
    half *d_in1 = nullptr, *d_in2 = nullptr;                                  \
    bool* d_out = nullptr;                                                    \
    EXPECT_EQ(cudaSuccess,                                                    \
              cudaMalloc(reinterpret_cast<void**>(&d_in1), size));            \
    EXPECT_EQ(cudaSuccess,                                                    \
              cudaMalloc(reinterpret_cast<void**>(&d_in2), size));            \
    EXPECT_EQ(cudaSuccess,                                                    \
              cudaMalloc(reinterpret_cast<void**>(&d_out), sizeof(bool)));    \
    EXPECT_EQ(cudaSuccess,                                                    \
              cudaMemcpy(d_in1, &in1, size, cudaMemcpyHostToDevice));         \
    EXPECT_EQ(cudaSuccess,                                                    \
              cudaMemcpy(d_in2, &in2, size, cudaMemcpyHostToDevice));         \
    op_type<<<1, 1>>>(d_in1, d_in2, d_out);                                   \
    EXPECT_EQ(cudaSuccess, cudaGetLastError());                               \
    EXPECT_EQ(cudaSuccess,                                                    \
              cudaMemcpy(&out, d_out, sizeof(bool), cudaMemcpyDeviceToHost)); \
    EXPECT_EQ(out, v_out);                                                    \
    EXPECT_EQ(cudaSuccess, cudaFree(d_in1));                                  \
    EXPECT_EQ(cudaSuccess, cudaFree(d_in2));                                  \
    EXPECT_EQ(cudaSuccess, cudaFree(d_out));                                  \
  }

#ifdef PADDLE_CUDA_FP16
namespace paddle {
namespace platform {

// The `half` arithmetic, compound-assignment and comparison tests are only
// built when CUDA_VERSION < 9000. NOTE(review): presumably float16.h only
// provides these `half` operators for pre-9.0 toolkits — confirm there.
#if CUDA_VERSION < 9000
ARITHMETIC_KERNEL(Add, +)
ARITHMETIC_KERNEL(Sub, -)
ARITHMETIC_KERNEL(Mul, *)
ARITHMETIC_KERNEL(Div, /)

ARITHMETIC_KERNEL_LAUNCH(Add)
ARITHMETIC_KERNEL_LAUNCH(Sub)
ARITHMETIC_KERNEL_LAUNCH(Mul)
ARITHMETIC_KERNEL_LAUNCH(Div)

// Negative sign kernel: negates in[0] in place; launched with one thread.
__global__ void Neg(half* in) { in[0] = -in[0]; }

// Copies `v_in` to the device as `half`, negates it with the Neg kernel and
// expects the value copied back (converted to float) to equal `v_out`.
// Every CUDA call is checked so failures are reported directly.
void TestNeg(float v_in, float v_out) {
  LOG(INFO) << "Test Neg on GPU!";
  const size_t size = sizeof(half);
  half in = half(float16(v_in));
  half* d_in = nullptr;
  EXPECT_EQ(cudaSuccess, cudaMalloc(reinterpret_cast<void**>(&d_in), size));
  EXPECT_EQ(cudaSuccess, cudaMemcpy(d_in, &in, size, cudaMemcpyHostToDevice));
  Neg<<<1, 1>>>(d_in);
  // A kernel launch returns no status; launch errors must be queried.
  EXPECT_EQ(cudaSuccess, cudaGetLastError());
  // The blocking device-to-host copy also waits for the kernel to finish.
  EXPECT_EQ(cudaSuccess, cudaMemcpy(&in, d_in, size, cudaMemcpyDeviceToHost));
  EXPECT_EQ(static_cast<float>(float16(in)), v_out);
  EXPECT_EQ(cudaSuccess, cudaFree(d_in));
}

COMPOUND_KERNEL(AddAssign, +=)
COMPOUND_KERNEL(SubAssign, -=)
COMPOUND_KERNEL(MulAssign, *=)
COMPOUND_KERNEL(DivAssign, /=)

COMPOUND_KERNEL_LAUNCH(AddAssign)
COMPOUND_KERNEL_LAUNCH(SubAssign)
COMPOUND_KERNEL_LAUNCH(MulAssign)
COMPOUND_KERNEL_LAUNCH(DivAssign)

COMPARISON_KERNEL(Equal, ==)
COMPARISON_KERNEL(NotEqual, !=)
COMPARISON_KERNEL(Less, <)
COMPARISON_KERNEL(LessEqual, <=)
COMPARISON_KERNEL(Greater, >)
COMPARISON_KERNEL(GreaterEqual, >=)

COMPARISON_KERNEL_LAUNCH(Equal)
COMPARISON_KERNEL_LAUNCH(NotEqual)
COMPARISON_KERNEL_LAUNCH(Less)
COMPARISON_KERNEL_LAUNCH(LessEqual)
COMPARISON_KERNEL_LAUNCH(Greater)
COMPARISON_KERNEL_LAUNCH(GreaterEqual)

TEST(float16, arithmetic_on_gpu) {
  TestAdd(1, 2, 3);
  TestSub(2, 1, 1);
  TestMul(2, 3, 6);
  TestDiv(6, 2, 3);
  TestNeg(1, -1);
}

TEST(float16, compound_on_gpu) {
  TestAddAssign(1, 2, 3);
  TestSubAssign(2, 1, 1);
  TestMulAssign(2, 3, 6);
  TestDivAssign(6, 2, 3);
}

TEST(float16, comparision_on_gpu) {
  TestEqual(1, 1, true);
  TestEqual(1, 2, false);
  TestNotEqual(2, 3, true);
  TestNotEqual(2, 2, false);
  TestLess(3, 4, true);
  TestLess(3, 3, false);
  TestLessEqual(3, 3, true);
  TestLessEqual(3, 2, false);
  TestGreater(4, 3, true);
  TestGreater(4, 4, false);
  TestGreaterEqual(4, 4, true);
  TestGreaterEqual(4, 5, false);
}
#endif  // CUDA_VERSION

// Round-trips float16 values through CUDA's `half` and checks the raw bits,
// e.g. 1.0 -> 0x3c00, 65504 (largest finite half) -> 0x7bff, and the
// out-of-range 65536 -> 0x7c00 (+infinity in IEEE binary16).
TEST(float16, conversion_on_gpu) {
  // Explicit conversion to and from cuda half
  EXPECT_EQ(float16(half(float16(1.0f))).x, 0x3c00);
  EXPECT_EQ(float16(half(float16(0.5f))).x, 0x3800);
  EXPECT_EQ(float16(half(float16(0.33333f))).x, 0x3555);
  EXPECT_EQ(float16(half(float16(0.0f))).x, 0x0000);
  EXPECT_EQ(float16(half(float16(-0.0f))).x, 0x8000);
  EXPECT_EQ(float16(half(float16(65504.0f))).x, 0x7bff);
  EXPECT_EQ(float16(half(float16(65536.0f))).x, 0x7c00);

  // Assignment operator
  float16 v_assign;
  v_assign = half(float16(1.0f));
  EXPECT_EQ(v_assign.x, 0x3c00);
}

// Copies a 2x2 float16 LoDTensor CPU -> GPU -> CPU and checks that every
// element's raw bits survive the round trip.
TEST(float16, lod_tensor_on_gpu) {
  framework::LoDTensor src_tensor;
  framework::LoDTensor gpu_tensor;
  framework::LoDTensor dst_tensor;

  float16* src_ptr = src_tensor.mutable_data<float16>(
      framework::make_ddim({2, 2}), CPUPlace());

  float16 arr[4] = {float16(1.0f), float16(0.5f), float16(0.33333f),
                    float16(0.0f)};
  // Derive the element count from `arr` so the copy and the check below
  // cannot drift apart.
  const size_t numel = sizeof(arr) / sizeof(arr[0]);
  memcpy(src_ptr, arr, sizeof(arr));

  // CPU LoDTensor to GPU LoDTensor
  CUDAPlace gpu_place(0);
  CUDADeviceContext gpu_ctx(gpu_place);
  framework::TensorCopy(src_tensor, gpu_place, gpu_ctx, &gpu_tensor);

  // GPU LoDTensor to CPU LoDTensor
  framework::TensorCopy(gpu_tensor, CPUPlace(), gpu_ctx, &dst_tensor);

  // The copies are issued on gpu_ctx and may be asynchronous; wait for them
  // before reading dst_tensor on the host.
  gpu_ctx.Wait();
  const float16* dst_ptr = dst_tensor.data<float16>();
  ASSERT_NE(src_ptr, dst_ptr);
  for (size_t i = 0; i < numel; ++i) {
    EXPECT_EQ(src_ptr[i].x, dst_ptr[i].x);
  }
}

}  // namespace platform
}  // namespace paddle
#endif  // PADDLE_CUDA_FP16