test_float16.cu 8.4 KB
Newer Older
K
Kexin Zhao 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/math/float16.h"

K
Kexin Zhao 已提交
14 15 16 17
#include <gtest/gtest.h>

#include "paddle/utils/Logging.h"

18 19 20
// Defines a kernel named `op_type` that applies the binary operator `sign` to
// the first half element of each input and stores the half result in the first
// element of the output. Every thread that runs it writes the same element.
#define ARITHMETIC_KERNEL(op_type, sign)                                    \
  __global__ void op_type(const half* lhs, const half* rhs, half* result) { \
    (*result) = (*lhs)sign(*rhs);                                           \
  }

23 24
// Defines a kernel named `op_type` that applies the compound-assignment
// operator `sign` (for example +=) to the first half element of the first
// argument, using the first half element of the second argument as operand.
#define COMPOUND_KERNEL(op_type, sign)                         \
  __global__ void op_type(half* target, const half* operand) { \
    (*target) sign(*operand);                                  \
  }
K
Kexin Zhao 已提交
25

26 27 28
// Defines a kernel named `op_type` that compares the first half element of
// each input with the relational operator `sign` and stores the boolean
// outcome in the first element of the output.
#define COMPARISON_KERNEL(op_type, sign)                                    \
  __global__ void op_type(const half* lhs, const half* rhs, bool* result) { \
    (*result) = (*lhs)sign(*rhs);                                           \
  }

// Defines `void Test<op_type>(float v_in1, float v_in2, float v_out)`, which
// converts both inputs to half, runs kernel `op_type` on the GPU with a
// <<<1, 1>>> launch and expects the result, converted back to float, to equal
// v_out.
//
// Every CUDA call is checked: the first failure stops further CUDA work and is
// reported through gtest together with its error string, and the value check
// only runs when the whole round trip succeeded. The host operands live on the
// stack and the device pointers start as NULL, so the trailing cudaFree calls
// are always safe and nothing leaks on a failure path.
#define ARITHMETIC_KERNEL_LAUNCH(op_type)                              \
  void Test##op_type(float v_in1, float v_in2, float v_out) {          \
    LOG(INFO) << "Test " << #op_type << " on GPU!";                    \
    const size_t size = sizeof(half);                                  \
    half h_in1 = half(float16(v_in1));                                 \
    half h_in2 = half(float16(v_in2));                                 \
    half h_out;                                                        \
    half *d_in1 = NULL, *d_in2 = NULL, *d_out = NULL;                  \
    cudaError_t err = cudaMalloc((void**)&d_in1, size);                \
    if (err == cudaSuccess) err = cudaMalloc((void**)&d_in2, size);    \
    if (err == cudaSuccess) err = cudaMalloc((void**)&d_out, size);    \
    if (err == cudaSuccess)                                            \
      err = cudaMemcpy(d_in1, &h_in1, size, cudaMemcpyHostToDevice);   \
    if (err == cudaSuccess)                                            \
      err = cudaMemcpy(d_in2, &h_in2, size, cudaMemcpyHostToDevice);   \
    if (err == cudaSuccess) {                                          \
      op_type<<<1, 1>>>(d_in1, d_in2, d_out);                          \
      err = cudaGetLastError();                                        \
    }                                                                  \
    if (err == cudaSuccess)                                            \
      err = cudaMemcpy(&h_out, d_out, size, cudaMemcpyDeviceToHost);   \
    EXPECT_EQ(cudaSuccess, err) << cudaGetErrorString(err);            \
    if (err == cudaSuccess) {                                          \
      EXPECT_EQ(float(float16(h_out)), v_out);                         \
    }                                                                  \
    cudaFree(d_in1);                                                   \
    cudaFree(d_in2);                                                   \
    cudaFree(d_out);                                                   \
  }

// Defines `void Test<op_type>(float v_in1, float v_in2, float v_out)`, which
// converts both inputs to half, runs the in-place kernel `op_type` on the GPU
// with a <<<1, 1>>> launch and expects the updated first operand, converted
// back to float, to equal v_out.
//
// Every CUDA call is checked: the first failure stops further CUDA work and is
// reported through gtest together with its error string, and the value check
// only runs when the whole round trip succeeded. The host operands live on the
// stack and the device pointers start as NULL, so the trailing cudaFree calls
// are always safe and nothing leaks on a failure path.
#define COMPOUND_KERNEL_LAUNCH(op_type)                                \
  void Test##op_type(float v_in1, float v_in2, float v_out) {          \
    LOG(INFO) << "Test " << #op_type << " on GPU!";                    \
    const size_t size = sizeof(half);                                  \
    half h_in1 = half(float16(v_in1));                                 \
    half h_in2 = half(float16(v_in2));                                 \
    half *d_in1 = NULL, *d_in2 = NULL;                                 \
    cudaError_t err = cudaMalloc((void**)&d_in1, size);                \
    if (err == cudaSuccess) err = cudaMalloc((void**)&d_in2, size);    \
    if (err == cudaSuccess)                                            \
      err = cudaMemcpy(d_in1, &h_in1, size, cudaMemcpyHostToDevice);   \
    if (err == cudaSuccess)                                            \
      err = cudaMemcpy(d_in2, &h_in2, size, cudaMemcpyHostToDevice);   \
    if (err == cudaSuccess) {                                          \
      op_type<<<1, 1>>>(d_in1, d_in2);                                 \
      err = cudaGetLastError();                                        \
    }                                                                  \
    if (err == cudaSuccess)                                            \
      err = cudaMemcpy(&h_in1, d_in1, size, cudaMemcpyDeviceToHost);   \
    EXPECT_EQ(cudaSuccess, err) << cudaGetErrorString(err);            \
    if (err == cudaSuccess) {                                          \
      EXPECT_EQ(float(float16(h_in1)), v_out);                         \
    }                                                                  \
    cudaFree(d_in1);                                                   \
    cudaFree(d_in2);                                                   \
  }

// Defines `void Test<op_type>(float v_in1, float v_in2, bool v_out)`, which
// converts both inputs to half, runs the comparison kernel `op_type` on the
// GPU with a <<<1, 1>>> launch and expects the boolean result to equal v_out.
//
// Every CUDA call is checked: the first failure stops further CUDA work and is
// reported through gtest together with its error string, and the value check
// only runs when the whole round trip succeeded. The result buffer is sized
// with sizeof(bool) rather than a literal 1. The host operands live on the
// stack and the device pointers start as NULL, so the trailing cudaFree calls
// are always safe and nothing leaks on a failure path.
#define COMPARISON_KERNEL_LAUNCH(op_type)                                  \
  void Test##op_type(float v_in1, float v_in2, bool v_out) {               \
    LOG(INFO) << "Test " << #op_type << " on GPU!";                        \
    const size_t size = sizeof(half);                                      \
    const size_t out_size = sizeof(bool);                                  \
    half h_in1 = half(float16(v_in1));                                     \
    half h_in2 = half(float16(v_in2));                                     \
    bool h_out = false;                                                    \
    half *d_in1 = NULL, *d_in2 = NULL;                                     \
    bool* d_out = NULL;                                                    \
    cudaError_t err = cudaMalloc((void**)&d_in1, size);                    \
    if (err == cudaSuccess) err = cudaMalloc((void**)&d_in2, size);        \
    if (err == cudaSuccess) err = cudaMalloc((void**)&d_out, out_size);    \
    if (err == cudaSuccess)                                                \
      err = cudaMemcpy(d_in1, &h_in1, size, cudaMemcpyHostToDevice);       \
    if (err == cudaSuccess)                                                \
      err = cudaMemcpy(d_in2, &h_in2, size, cudaMemcpyHostToDevice);       \
    if (err == cudaSuccess) {                                              \
      op_type<<<1, 1>>>(d_in1, d_in2, d_out);                              \
      err = cudaGetLastError();                                            \
    }                                                                      \
    if (err == cudaSuccess)                                                \
      err = cudaMemcpy(&h_out, d_out, out_size, cudaMemcpyDeviceToHost);   \
    EXPECT_EQ(cudaSuccess, err) << cudaGetErrorString(err);                \
    if (err == cudaSuccess) {                                              \
      EXPECT_EQ(h_out, v_out);                                             \
    }                                                                      \
    cudaFree(d_in1);                                                       \
    cudaFree(d_in2);                                                       \
    cudaFree(d_out);                                                       \
  }
K
Kexin Zhao 已提交
108 109

#ifdef PADDLE_CUDA_FP16
K
Kexin Zhao 已提交
110
namespace paddle {
K
Kexin Zhao 已提交
111

112
#if CUDA_VERSION < 9000
K
Kexin Zhao 已提交
113 114 115 116
// Instantiate the kernels Add, Sub, Mul and Div, which apply +, -, * and / to
// a pair of half operands (see ARITHMETIC_KERNEL).
ARITHMETIC_KERNEL(Add, +)
ARITHMETIC_KERNEL(Sub, -)
ARITHMETIC_KERNEL(Mul, *)
ARITHMETIC_KERNEL(Div, /)
K
Kexin Zhao 已提交
117

K
Kexin Zhao 已提交
118 119 120 121
// Instantiate the host-side drivers TestAdd, TestSub, TestMul and TestDiv for
// the arithmetic kernels (see ARITHMETIC_KERNEL_LAUNCH).
ARITHMETIC_KERNEL_LAUNCH(Add)
ARITHMETIC_KERNEL_LAUNCH(Sub)
ARITHMETIC_KERNEL_LAUNCH(Mul)
ARITHMETIC_KERNEL_LAUNCH(Div)
K
Kexin Zhao 已提交
122

K
Kexin Zhao 已提交
123
// Unary-minus kernel: negates the first half element of `value` in place.
__global__ void Neg(half* value) { (*value) = -(*value); }
K
Kexin Zhao 已提交
125

K
Kexin Zhao 已提交
126 127
// Converts v_in to half, negates it on the GPU with the Neg kernel
// (<<<1, 1>>> launch) and expects the result, converted back to float, to
// equal v_out.
//
// Every CUDA call is checked: the first failure stops further CUDA work and is
// reported through gtest together with its error string, and the value check
// only runs when the whole round trip succeeded.
void TestNeg(float v_in, float v_out) {
  LOG(INFO) << "Test Neg on GPU!";
  const size_t size = sizeof(half);
  // The host operand lives on the stack, so there is no malloc result to check
  // and nothing to free on the host side.
  half h_val = half(float16(v_in));
  // Starts as NULL so the cudaFree below is a no-op when allocation failed.
  half* d_val = NULL;

  cudaError_t err = cudaMalloc((void**)&d_val, size);
  if (err == cudaSuccess) {
    err = cudaMemcpy(d_val, &h_val, size, cudaMemcpyHostToDevice);
  }
  if (err == cudaSuccess) {
    Neg<<<1, 1>>>(d_val);
    // A kernel launch returns nothing; launch errors must be queried.
    err = cudaGetLastError();
  }
  if (err == cudaSuccess) {
    // Blocking copy: also surfaces errors raised while the kernel executed.
    err = cudaMemcpy(&h_val, d_val, size, cudaMemcpyDeviceToHost);
  }

  EXPECT_EQ(cudaSuccess, err) << cudaGetErrorString(err);
  if (err == cudaSuccess) {
    EXPECT_EQ(float(float16(h_val)), v_out);
  }
  cudaFree(d_val);
}
K
Kexin Zhao 已提交
140

K
Kexin Zhao 已提交
141 142 143 144
// Instantiate the in-place kernels AddAssign, SubAssign, MulAssign and
// DivAssign, which apply +=, -=, *= and /= to half operands (see
// COMPOUND_KERNEL).
COMPOUND_KERNEL(AddAssign, +=)
COMPOUND_KERNEL(SubAssign, -=)
COMPOUND_KERNEL(MulAssign, *=)
COMPOUND_KERNEL(DivAssign, /=)
K
Kexin Zhao 已提交
145

K
Kexin Zhao 已提交
146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170
// Host-side drivers TestAddAssign, TestSubAssign, TestMulAssign and
// TestDivAssign for the compound-assignment kernels (see
// COMPOUND_KERNEL_LAUNCH).
COMPOUND_KERNEL_LAUNCH(AddAssign)
COMPOUND_KERNEL_LAUNCH(SubAssign)
COMPOUND_KERNEL_LAUNCH(MulAssign)
COMPOUND_KERNEL_LAUNCH(DivAssign)

// Kernels comparing two half operands with ==, !=, <, <=, > and >= (see
// COMPARISON_KERNEL).
COMPARISON_KERNEL(Equal, ==)
COMPARISON_KERNEL(NotEqual, !=)
COMPARISON_KERNEL(Less, <)
COMPARISON_KERNEL(LessEqual, <=)
COMPARISON_KERNEL(Greater, >)
COMPARISON_KERNEL(GreaterEqual, >=)

// Host-side drivers TestEqual ... TestGreaterEqual for the comparison kernels
// (see COMPARISON_KERNEL_LAUNCH).
COMPARISON_KERNEL_LAUNCH(Equal)
COMPARISON_KERNEL_LAUNCH(NotEqual)
COMPARISON_KERNEL_LAUNCH(Less)
COMPARISON_KERNEL_LAUNCH(LessEqual)
COMPARISON_KERNEL_LAUNCH(Greater)
COMPARISON_KERNEL_LAUNCH(GreaterEqual)

// Exercises the binary operators +, -, *, / and unary minus on half values in
// device code. Each call passes the operand(s) followed by the expected value.
TEST(float16, arithmetic_on_gpu) {
  TestAdd(1.0f, 2.0f, 3.0f);
  TestSub(2.0f, 1.0f, 1.0f);
  TestMul(2.0f, 3.0f, 6.0f);
  TestDiv(6.0f, 2.0f, 3.0f);
  TestNeg(1.0f, -1.0f);
}

K
Kexin Zhao 已提交
173 174 175 176 177 178
// Exercises the compound-assignment operators +=, -=, *= and /= on half values
// in device code. Each call passes both operands and then the expected value
// of the first operand after the update.
TEST(float16, compound_on_gpu) {
  TestAddAssign(1.0f, 2.0f, 3.0f);
  TestSubAssign(2.0f, 1.0f, 1.0f);
  TestMulAssign(2.0f, 3.0f, 6.0f);
  TestDivAssign(6.0f, 2.0f, 3.0f);
}
K
Kexin Zhao 已提交
179

K
Kexin Zhao 已提交
180 181 182 183 184 185 186 187 188 189 190 191 192 193
// Exercises the relational operators on half values in device code. Each
// operator is checked once with a pair that should satisfy it and once with a
// pair that should not.
TEST(float16, comparision_on_gpu) {
  // ==
  TestEqual(1.0f, 1.0f, true);
  TestEqual(1.0f, 2.0f, false);
  // !=
  TestNotEqual(2.0f, 3.0f, true);
  TestNotEqual(2.0f, 2.0f, false);
  // <
  TestLess(3.0f, 4.0f, true);
  TestLess(3.0f, 3.0f, false);
  // <=
  TestLessEqual(3.0f, 3.0f, true);
  TestLessEqual(3.0f, 2.0f, false);
  // >
  TestGreater(4.0f, 3.0f, true);
  TestGreater(4.0f, 4.0f, false);
  // >=
  TestGreaterEqual(4.0f, 4.0f, true);
  TestGreaterEqual(4.0f, 5.0f, false);
}
194
#endif  // CUDA_VERSION
K
Kexin Zhao 已提交
195 196 197 198 199 200 201 202 203 204 205 206 207

// Checks that a float16 survives a round trip through the CUDA half type:
// float -> float16 -> half -> float16, then compares the raw 16-bit pattern.
TEST(float16, conversion_on_gpu) {
  // Round trip through explicit conversions in both directions.
  EXPECT_EQ(float16(half(float16(1.0f))).x, 0x3c00);
  EXPECT_EQ(float16(half(float16(0.5f))).x, 0x3800);
  EXPECT_EQ(float16(half(float16(0.33333f))).x, 0x3555);
  EXPECT_EQ(float16(half(float16(0.0f))).x, 0x0000);
  EXPECT_EQ(float16(half(float16(-0.0f))).x, 0x8000);
  EXPECT_EQ(float16(half(float16(65504.0f))).x, 0x7bff);
  EXPECT_EQ(float16(half(float16(65536.0f))).x, 0x7c00);

  // Assigning a half to an already-constructed float16 (operator=, not a
  // constructor call).
  float16 assigned;
  assigned = half(float16(1.0f));
  EXPECT_EQ(assigned.x, 0x3c00);
}
K
Kexin Zhao 已提交
211 212

}  // namespace paddle
213
#endif  // PADDLE_CUDA_FP16