test_float16.cu 8.6 KB
Newer Older
K
Kexin Zhao 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/math/float16.h"

K
Kexin Zhao 已提交
14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110
#include <gtest/gtest.h>

#include "paddle/utils/Logging.h"

// Generates a kernel `kernel_name(in1, in2, out)` that applies the binary
// float16 operator `op` to the first element of each input and writes the
// result to out[0]. Only element 0 is touched, so one thread is enough.
#define ARITHMETIC_KERNEL(kernel_name, op)                    \
  __global__ void kernel_name(                                \
      const float16* in1, const float16* in2, float16* out) { \
    const float16 lhs = in1[0];                               \
    const float16 rhs = in2[0];                               \
    out[0] = lhs op rhs;                                      \
  }

// Generates a kernel `kernel_name(in1, in2)` that applies the float16
// compound-assignment operator `op` (e.g. `+=`) in place, updating in1[0]
// with in2[0]. Only element 0 is touched, so one thread is enough.
#define COMPOUND_KERNEL(kernel_name, op)                          \
  __global__ void kernel_name(float16* in1, const float16* in2) { \
    float16& target = in1[0];                                     \
    target op in2[0];                                             \
  }

// Generates a kernel `kernel_name(in1, in2, out)` that compares in1[0] with
// in2[0] using the float16 comparison operator `op` and stores the outcome in
// out[0]. Only element 0 is touched, so one thread is enough.
#define COMPARISON_KERNEL(kernel_name, op)                            \
  __global__ void kernel_name(                                        \
      const float16* in1, const float16* in2, bool* out) {            \
    const bool result = (in1[0] op in2[0]);                           \
    out[0] = result;                                                  \
  }

// Defines `void Test<op_type>(float v_in1, float v_in2, float v_out)`, the host
// driver for a kernel generated by ARITHMETIC_KERNEL. The driver:
//   - converts both inputs to float16;
//   - runs `op_type` on a single GPU thread;
//   - expects the float16 result, converted back to float, to equal `v_out`.
// Every CUDA runtime call and the kernel launch are checked:
//   - the first error skips the remaining steps;
//   - the device buffers are always released;
//   - the error is reported as a test failure, and the result comparison is
//     skipped, so an uninitialized value is never compared.
// Host-side values live on the stack, so there is no host heap to free.
#define ARITHMETIC_KERNEL_LAUNCH(op_type)                                  \
  void Test##op_type(float v_in1, float v_in2, float v_out) {              \
    LOG(INFO) << "Test " << #op_type << " on GPU!";                        \
    const size_t size = sizeof(float16);                                   \
    float16 in1 = float16(v_in1);                                          \
    float16 in2 = float16(v_in2);                                          \
    float16 out;                                                           \
    float16 *d_in1 = NULL, *d_in2 = NULL, *d_out = NULL;                   \
    cudaError_t err = cudaMalloc((void**)&d_in1, size);                    \
    if (err == cudaSuccess) err = cudaMalloc((void**)&d_in2, size);        \
    if (err == cudaSuccess) err = cudaMalloc((void**)&d_out, size);        \
    if (err == cudaSuccess)                                                \
      err = cudaMemcpy(d_in1, &in1, size, cudaMemcpyHostToDevice);         \
    if (err == cudaSuccess)                                                \
      err = cudaMemcpy(d_in2, &in2, size, cudaMemcpyHostToDevice);         \
    if (err == cudaSuccess) {                                              \
      op_type<<<1, 1>>>(d_in1, d_in2, d_out);                              \
      err = cudaGetLastError(); /* launch errors are not returned */       \
    }                                                                      \
    if (err == cudaSuccess) /* blocking copy also surfaces kernel faults */ \
      err = cudaMemcpy(&out, d_out, size, cudaMemcpyDeviceToHost);         \
    cudaFree(d_in1); /* cudaFree(NULL) is a no-op */                       \
    cudaFree(d_in2);                                                       \
    cudaFree(d_out);                                                       \
    ASSERT_EQ(cudaSuccess, err) << cudaGetErrorString(err);                \
    EXPECT_EQ(float(out), v_out);                                          \
  }

// Defines `void Test<op_type>(float v_in1, float v_in2, float v_out)`, the host
// driver for a kernel generated by COMPOUND_KERNEL. The driver:
//   - converts both inputs to float16;
//   - runs `op_type` on a single GPU thread, updating the first operand in
//     place;
//   - copies that operand back and expects it, converted to float, to equal
//     `v_out`.
// Every CUDA runtime call and the kernel launch are checked:
//   - the first error skips the remaining steps;
//   - the device buffers are always released;
//   - the error is reported as a test failure, and the result comparison is
//     skipped.
// Host-side values live on the stack, so there is no host heap to free.
#define COMPOUND_KERNEL_LAUNCH(op_type)                                    \
  void Test##op_type(float v_in1, float v_in2, float v_out) {              \
    LOG(INFO) << "Test " << #op_type << " on GPU!";                        \
    const size_t size = sizeof(float16);                                   \
    float16 in1 = float16(v_in1);                                          \
    float16 in2 = float16(v_in2);                                          \
    float16 *d_in1 = NULL, *d_in2 = NULL;                                  \
    cudaError_t err = cudaMalloc((void**)&d_in1, size);                    \
    if (err == cudaSuccess) err = cudaMalloc((void**)&d_in2, size);        \
    if (err == cudaSuccess)                                                \
      err = cudaMemcpy(d_in1, &in1, size, cudaMemcpyHostToDevice);         \
    if (err == cudaSuccess)                                                \
      err = cudaMemcpy(d_in2, &in2, size, cudaMemcpyHostToDevice);         \
    if (err == cudaSuccess) {                                              \
      op_type<<<1, 1>>>(d_in1, d_in2);                                     \
      err = cudaGetLastError(); /* launch errors are not returned */       \
    }                                                                      \
    if (err == cudaSuccess) /* blocking copy also surfaces kernel faults */ \
      err = cudaMemcpy(&in1, d_in1, size, cudaMemcpyDeviceToHost);         \
    cudaFree(d_in1); /* cudaFree(NULL) is a no-op */                       \
    cudaFree(d_in2);                                                       \
    ASSERT_EQ(cudaSuccess, err) << cudaGetErrorString(err);                \
    EXPECT_EQ(float(in1), v_out);                                          \
  }

// Defines `void Test<op_type>(float v_in1, float v_in2, bool v_out)`, the host
// driver for a kernel generated by COMPARISON_KERNEL. The driver:
//   - converts both inputs to float16;
//   - runs `op_type` on a single GPU thread;
//   - expects the boolean outcome to equal `v_out`.
// The result buffer is sized with sizeof(bool) rather than a literal byte
// count. Every CUDA runtime call and the kernel launch are checked:
//   - the first error skips the remaining steps;
//   - the device buffers are always released;
//   - the error is reported as a test failure, and the result comparison is
//     skipped.
#define COMPARISON_KERNEL_LAUNCH(op_type)                                  \
  void Test##op_type(float v_in1, float v_in2, bool v_out) {               \
    LOG(INFO) << "Test " << #op_type << " on GPU!";                        \
    const size_t size = sizeof(float16);                                   \
    float16 in1 = float16(v_in1);                                          \
    float16 in2 = float16(v_in2);                                          \
    bool out = false;                                                      \
    float16 *d_in1 = NULL, *d_in2 = NULL;                                  \
    bool* d_out = NULL;                                                    \
    cudaError_t err = cudaMalloc((void**)&d_in1, size);                    \
    if (err == cudaSuccess) err = cudaMalloc((void**)&d_in2, size);        \
    if (err == cudaSuccess) err = cudaMalloc((void**)&d_out, sizeof(bool)); \
    if (err == cudaSuccess)                                                \
      err = cudaMemcpy(d_in1, &in1, size, cudaMemcpyHostToDevice);         \
    if (err == cudaSuccess)                                                \
      err = cudaMemcpy(d_in2, &in2, size, cudaMemcpyHostToDevice);         \
    if (err == cudaSuccess) {                                              \
      op_type<<<1, 1>>>(d_in1, d_in2, d_out);                              \
      err = cudaGetLastError(); /* launch errors are not returned */       \
    }                                                                      \
    if (err == cudaSuccess) /* blocking copy also surfaces kernel faults */ \
      err = cudaMemcpy(&out, d_out, sizeof(bool), cudaMemcpyDeviceToHost); \
    cudaFree(d_in1); /* cudaFree(NULL) is a no-op */                       \
    cudaFree(d_in2);                                                       \
    cudaFree(d_out);                                                       \
    ASSERT_EQ(cudaSuccess, err) << cudaGetErrorString(err);                \
    EXPECT_EQ(out, v_out);                                                 \
  }
K
Kexin Zhao 已提交
111 112

#ifdef PADDLE_CUDA_FP16
K
Kexin Zhao 已提交
113
namespace paddle {
K
Kexin Zhao 已提交
114

K
Kexin Zhao 已提交
115 116 117 118
// Single-thread kernels Add/Sub/Mul/Div: each computes
// out[0] = in1[0] <op> in2[0] with float16's binary arithmetic operators.
ARITHMETIC_KERNEL(Add, +)
ARITHMETIC_KERNEL(Sub, -)
ARITHMETIC_KERNEL(Mul, *)
ARITHMETIC_KERNEL(Div, /)
K
Kexin Zhao 已提交
119

K
Kexin Zhao 已提交
120 121 122 123
// Host drivers TestAdd/TestSub/TestMul/TestDiv(v_in1, v_in2, v_out): each runs
// the same-named kernel on float16 copies of the inputs and expects v_out.
ARITHMETIC_KERNEL_LAUNCH(Add)
ARITHMETIC_KERNEL_LAUNCH(Sub)
ARITHMETIC_KERNEL_LAUNCH(Mul)
ARITHMETIC_KERNEL_LAUNCH(Div)
K
Kexin Zhao 已提交
124

K
Kexin Zhao 已提交
125 126
// Negation kernel: flips the sign of in[0] in place using float16's unary
// minus. Only element 0 is touched, so one thread is enough.
__global__ void Neg(float16* in) {
  float16& value = in[0];
  value = -value;
}
K
Kexin Zhao 已提交
127

K
Kexin Zhao 已提交
128 129 130 131 132 133 134 135 136 137 138 139 140 141
// Host driver for the Neg kernel. It:
//   - converts `v_in` to float16;
//   - negates it on the GPU with a single thread;
//   - expects the result, converted back to float, to equal `v_out`.
// CUDA runtime errors (including launch errors) are collected into one status:
//   - the first error skips the remaining steps;
//   - the device buffer is always released;
//   - the error is reported as a test failure, and the result comparison is
//     skipped.
void TestNeg(float v_in, float v_out) {
  LOG(INFO) << "Test Neg on GPU!";
  const size_t size = sizeof(float16);
  float16 value = float16(v_in);  // host copy on the stack; nothing to free
  float16* d_in = NULL;

  cudaError_t err = cudaMalloc((void**)&d_in, size);
  if (err == cudaSuccess) {
    err = cudaMemcpy(d_in, &value, size, cudaMemcpyHostToDevice);
  }
  if (err == cudaSuccess) {
    Neg<<<1, 1>>>(d_in);
    err = cudaGetLastError();  // <<<>>> itself does not return launch errors
  }
  if (err == cudaSuccess) {
    // Blocking copy: also surfaces faults raised while the kernel ran.
    err = cudaMemcpy(&value, d_in, size, cudaMemcpyDeviceToHost);
  }
  cudaFree(d_in);  // cudaFree(NULL) is a no-op

  ASSERT_EQ(cudaSuccess, err) << cudaGetErrorString(err);
  EXPECT_EQ(float(value), v_out);
}
K
Kexin Zhao 已提交
142

K
Kexin Zhao 已提交
143 144 145 146
// Single-thread kernels AddAssign/SubAssign/MulAssign/DivAssign: each applies
// in1[0] <op>= in2[0] in place with float16's compound-assignment operators.
COMPOUND_KERNEL(AddAssign, +=)
COMPOUND_KERNEL(SubAssign, -=)
COMPOUND_KERNEL(MulAssign, *=)
COMPOUND_KERNEL(DivAssign, /=)
K
Kexin Zhao 已提交
147

K
Kexin Zhao 已提交
148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172
// Host drivers TestAddAssign/.../TestDivAssign(v_in1, v_in2, v_out): each runs
// the same-named compound kernel and expects the updated first operand to be
// v_out.
COMPOUND_KERNEL_LAUNCH(AddAssign)
COMPOUND_KERNEL_LAUNCH(SubAssign)
COMPOUND_KERNEL_LAUNCH(MulAssign)
COMPOUND_KERNEL_LAUNCH(DivAssign)

// Single-thread comparison kernels: each writes out[0] = (in1[0] <op> in2[0])
// using float16's comparison operators.
COMPARISON_KERNEL(Equal, ==)
COMPARISON_KERNEL(NotEqual, !=)
COMPARISON_KERNEL(Less, <)
COMPARISON_KERNEL(LessEqual, <=)
COMPARISON_KERNEL(Greater, >)
COMPARISON_KERNEL(GreaterEqual, >=)

// Host drivers TestEqual/.../TestGreaterEqual(v_in1, v_in2, v_out): each runs
// the same-named comparison kernel and expects the boolean result v_out.
COMPARISON_KERNEL_LAUNCH(Equal)
COMPARISON_KERNEL_LAUNCH(NotEqual)
COMPARISON_KERNEL_LAUNCH(Less)
COMPARISON_KERNEL_LAUNCH(LessEqual)
COMPARISON_KERNEL_LAUNCH(Greater)
COMPARISON_KERNEL_LAUNCH(GreaterEqual)

// Exercises float16 binary arithmetic and unary negation on the GPU through
// the TestAdd/TestSub/TestMul/TestDiv/TestNeg host drivers.
TEST(float16, arithmetic_on_gpu) {
  // Small integer operands.
  TestAdd(1, 2, 3);
  TestSub(2, 1, 1);
  TestMul(2, 3, 6);
  TestDiv(6, 2, 3);
  TestNeg(1, -1);

  // Negative and fractional operands. Every input and every result is exactly
  // representable in float16, so exact equality is still the right check.
  TestAdd(-1.5f, 0.25f, -1.25f);
  TestSub(0.5f, 2, -1.5f);
  TestMul(-0.5f, 4, -2);
  TestDiv(1, 4, 0.25f);
  TestNeg(-2.5f, 2.5f);
}

K
Kexin Zhao 已提交
175 176 177 178 179 180
// Exercises float16 compound assignment (+=, -=, *=, /=) on the GPU through
// the TestAddAssign/.../TestDivAssign host drivers.
TEST(float16, compound_on_gpu) {
  // Small integer operands.
  TestAddAssign(1, 2, 3);
  TestSubAssign(2, 1, 1);
  TestMulAssign(2, 3, 6);
  TestDivAssign(6, 2, 3);

  // Negative and fractional operands. Every input and every result is exactly
  // representable in float16, so exact equality is still the right check.
  TestAddAssign(-1.5f, 0.25f, -1.25f);
  TestSubAssign(0.5f, 2, -1.5f);
  TestMulAssign(-0.5f, 4, -2);
  TestDivAssign(1, 4, 0.25f);
}
K
Kexin Zhao 已提交
181

K
Kexin Zhao 已提交
182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216
// Exercises float16 comparison operators on the GPU through the
// TestEqual/.../TestGreaterEqual host drivers. (The misspelled test name is
// kept so existing gtest filters keep matching.)
TEST(float16, comparision_on_gpu) {
  // Small positive operands.
  TestEqual(1, 1, true);
  TestEqual(1, 2, false);
  TestNotEqual(2, 3, true);
  TestNotEqual(2, 2, false);
  TestLess(3, 4, true);
  TestLess(3, 3, false);
  TestLessEqual(3, 3, true);
  TestLessEqual(3, 2, false);
  TestGreater(4, 3, true);
  TestGreater(4, 4, false);
  TestGreaterEqual(4, 4, true);
  TestGreaterEqual(4, 5, false);

  // Negative and fractional operands. A comparison that ordered raw float16
  // bit patterns as unsigned integers would pass every case above but get
  // several of these wrong.
  TestEqual(-0.5f, -0.5f, true);
  TestNotEqual(-1, 1, true);
  TestLess(-3, 2, true);
  TestLess(-2, -3, false);
  TestLessEqual(-4, -4, true);
  TestGreater(-2, -3, true);
  TestGreater(0.25f, 0.5f, false);
  TestGreaterEqual(-1, 1, false);
}

// Checks conversions between float16 and CUDA's `half` type by comparing raw
// float16 bit patterns (`.x`). Despite the name, the body launches no kernel;
// every conversion here is executed on the host.
TEST(float16, conversion_on_gpu) {
  // Explicit conversion to and from cuda half: each value goes
  // float -> float16 -> half -> float16 and must yield the listed
  // IEEE 754 binary16 bit pattern.
  struct ConversionCase {
    float value;
    int bits;
  };
  const ConversionCase kCases[] = {
      {1.0f, 0x3c00},
      {0.5f, 0x3800},
      {0.33333f, 0x3555},
      {0.0f, 0x0000},
      {-0.0f, 0x8000},
      {65504.0f, 0x7bff},   // largest finite float16
      {65536.0f, 0x7c00},   // overflows to +infinity
      {-1.0f, 0xbc00},
      {-65504.0f, 0xfbff},  // most negative finite float16
      {-65536.0f, 0xfc00},  // overflows to -infinity
  };
  for (size_t i = 0; i < sizeof(kCases) / sizeof(kCases[0]); ++i) {
    SCOPED_TRACE(kCases[i].value);  // identify the failing input in the log
    EXPECT_EQ(float16(half(float16(kCases[i].value))).x, kCases[i].bits);
  }

  // Implicit conversion to and from cuda half
  half tmp = float16(1.0f);
  float16 val = tmp;
  EXPECT_EQ(val.x, 0x3c00);

  // Assignment operator
  float16 v_assign;
  v_assign = tmp;
  EXPECT_EQ(v_assign.x, 0x3c00);
}
K
Kexin Zhao 已提交
217 218

}  // namespace paddle
K
Kexin Zhao 已提交
219
#endif