// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "lite/kernels/cuda/conv_compute.h"

#include <gtest/gtest.h>

#include <memory>
#include <random>
#include <utility>
#include <vector>

#include "lite/api/test_helper.h"
#include "lite/utils/float16.h"

namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {

// Returns a pseudo-random value drawn uniformly from [low, high), computed in
// double precision and narrowed to float (so the narrowed value may round up
// to `high`).
//
// The Mersenne Twister engine is a function-local static with a fixed seed
// (100): the sequence is reproducible for a given binary, every call advances
// the same shared state, and concurrent callers would race on it.
static float random_num(float low, float high) {
  static std::mt19937 mt(100);
  std::uniform_real_distribution<double> dist(low, high);
  // Make the double -> float narrowing explicit instead of implicit.
  return static_cast<float>(dist(mt));
}

class Conv2dTest : public ::testing::Test {
 protected:
  Conv2dTest()
      : batch(16),
        in_channels(32),
        out_channels(128),
        height(64),
        width(64),
        kernel_h(5),
        kernel_w(5),
        stride_h(2),
        stride_w(2),
        pad_h(1),
        pad_w(1),
        dilation_h(2),
        dilation_w(2),
        groups(1),
        x_shape({batch, in_channels, height, width}),
        w_shape({out_channels, in_channels, kernel_h, kernel_w}),
        b_shape({out_channels}) {
    calc_output_shape();
59

60 61
    X_gpu.Resize(lite::DDim(x_shape));
    X_ref.Resize(lite::DDim(x_shape));
62

63 64
    W_gpu.Resize(lite::DDim(w_shape));
    W_ref.Resize(lite::DDim(w_shape));
65

66 67
    b_gpu.Resize(lite::DDim(b_shape));
    b_ref.Resize(lite::DDim(b_shape));
68

69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88
    auto x_ref_data = X_ref.mutable_data<float>();
    auto w_ref_data = W_ref.mutable_data<float>();
    auto b_ref_data = b_ref.mutable_data<float>();

    // prepare input
    for (int64_t i = 0; i < X_ref.numel(); i++) {
      x_ref_data[i] = static_cast<float>(i % 10 * 0.2);
    }
    for (int64_t i = 0; i < W_ref.numel(); i++) {
      w_ref_data[i] = static_cast<float>(i % 10 * 0.2);
    }
    for (int64_t i = 0; i < b_ref.numel(); i++) {
      b_ref_data[i] = static_cast<float>(i % 10 * 0.2);
    }

    Out_ref.Resize(lite::DDim(out_shape));
    Out_gpu.Resize(lite::DDim(out_shape));
    Out_cpu.Resize(lite::DDim(out_shape));

    device_init();
89
  }
90 91 92 93 94 95

  int ConvOutputSize(
      int input_size, int filter_size, int dilation, int pad, int stride) {
    const int dkernel = dilation * (filter_size - 1) + 1;
    int output_size = (input_size + pad * 2 - dkernel) / stride + 1;
    return output_size;
96 97
  }

98 99 100 101 102 103 104 105 106
  void calc_output_shape() {
    out_shape.clear();
    out_shape.push_back(batch);
    out_shape.push_back(out_channels);
    out_shape.push_back(
        ConvOutputSize(height, kernel_h, dilation_h, pad_h, stride_h));
    out_shape.push_back(
        ConvOutputSize(width, kernel_w, dilation_w, pad_w, stride_w));
  }
107

108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135
  void device_init() {
    ctx.reset(new KernelContext);
    cudaStreamCreate(&stream);
    param.x = &X_gpu;
    param.filter = &W_gpu;
    param.output = &Out_gpu;
    param.bias = &b_gpu;
    param.paddings.reset(new std::vector<int>);
    param.paddings->push_back(pad_h);
    param.paddings->push_back(pad_h);
    param.paddings->push_back(pad_w);
    param.paddings->push_back(pad_w);
    param.dilations.reset(new std::vector<int>);
    param.dilations->push_back(dilation_h);
    param.dilations->push_back(dilation_w);
    param.strides[0] = stride_h;
    param.strides[1] = stride_w;
  }

  void float_data_init() {
    X_gpu.Assign<float, lite::DDim, TARGET(kCUDA)>(X_ref.data<float>(),
                                                   X_gpu.dims());
    X_gpu.set_lod(X_ref.lod());
    W_gpu.Assign<float, lite::DDim, TARGET(kCUDA)>(W_ref.data<float>(),
                                                   W_gpu.dims());
    b_gpu.Assign<float, lite::DDim, TARGET(kCUDA)>(b_ref.data<float>(),
                                                   b_gpu.dims());
  }
136

137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178
  void half_data_init() {
    X_half.Resize(lite::DDim(x_shape));
    auto x_half_data = X_half.mutable_data<half>();
    for (int64_t i = 0; i < X_half.numel(); i++) {
      x_half_data[i] = half(lite::float16(X_ref.data<float>()[i]));
    }
    X_gpu.Assign<half, lite::DDim, TARGET(kCUDA)>(x_half_data, X_gpu.dims());
    X_gpu.set_lod(X_ref.lod());

    W_half.Resize(W_ref.dims());
    auto w_half_data = W_half.mutable_data<half>();
    for (int64_t i = 0; i < W_half.numel(); i++) {
      w_half_data[i] = half(lite::float16(W_ref.data<float>()[i]));
    }
    W_gpu.Assign<half, lite::DDim, TARGET(kCUDA)>(w_half_data, W_gpu.dims());

    b_half.Resize(b_ref.dims());
    auto b_half_data = b_half.mutable_data<half>();
    for (int64_t i = 0; i < b_half.numel(); i++) {
      b_half_data[i] = half(lite::float16(b_ref.data<float>()[i]));
    }
    b_gpu.Assign<half, lite::DDim, TARGET(kCUDA)>(b_half_data, b_gpu.dims());
  }

  void conv_cpu_base(const lite::Tensor* X,
                     const lite::Tensor* W,
                     lite::Tensor* Out,
                     lite::Tensor* Col) {}

  int batch, in_channels, out_channels, height, width;
  int kernel_h, kernel_w;
  int stride_h, stride_w;
  int pad_h, pad_w;
  int dilation_h, dilation_w, groups;
  std::vector<int64_t> x_shape, w_shape, b_shape, out_shape;
  lite::Tensor X_ref, W_ref, b_ref, Out_ref;
  lite::Tensor X_gpu, W_gpu, b_gpu;
  lite::Tensor X_half, W_half, b_half;
  lite::Tensor Out_cpu, Out_gpu;

  operators::ConvParam param;
  std::unique_ptr<KernelContext> ctx;
179
  cudaStream_t stream;
180 181 182 183 184
};

// Benchmarks the fp32 ConvCompute kernel on the fixture's problem and logs
// the average wall-clock time per Run() call. No numerical check is done.
TEST_F(Conv2dTest, fp32) {
  float_data_init();
  auto& context = ctx->As<CUDAContext>();
  context.SetExecStream(stream);
  ConvCompute<float, PRECISION(kFloat)> conv_2d_kernel;
  conv_2d_kernel.SetParam(param);
  conv_2d_kernel.SetContext(std::move(ctx));

  // Warm-up iterations, each synchronized and error-checked.
  for (int i = 0; i < FLAGS_warmup; ++i) {
    conv_2d_kernel.Launch();
    ASSERT_EQ(cudaDeviceSynchronize(), cudaSuccess);
  }

  // Setup happens before the clock starts so the reported average covers
  // Run() only; synchronize so no setup work leaks into the timed region.
  conv_2d_kernel.PrepareForRun();
  ASSERT_EQ(cudaDeviceSynchronize(), cudaSuccess);

  auto start = GetCurrentUS();
  for (int i = 0; i < FLAGS_repeats; ++i) {
    conv_2d_kernel.Run();
  }
  // Launches are asynchronous: wait for completion before stopping the clock,
  // and surface any execution error.
  ASSERT_EQ(cudaDeviceSynchronize(), cudaSuccess);
  auto duration = (GetCurrentUS() - start) / 1000.0;
  // Avoid dividing by zero when no repeats were requested.
  const int repeats = FLAGS_repeats > 0 ? FLAGS_repeats : 1;
  LOG(INFO) << "fp32, warmup: " << FLAGS_warmup
            << ", repeats: " << FLAGS_repeats << ", spend "
            << duration / repeats << " ms in average.";
}

// Benchmarks the fp16 ConvCompute kernel on the fixture's problem and logs
// the average wall-clock time per Run() call. No numerical check is done.
TEST_F(Conv2dTest, fp16) {
  half_data_init();
  auto& context = ctx->As<CUDAContext>();
  context.SetExecStream(stream);
  ConvCompute<half, PRECISION(kFP16)> conv_2d_kernel;
  conv_2d_kernel.SetParam(param);
  conv_2d_kernel.SetContext(std::move(ctx));

  // Warm-up iterations, each synchronized and error-checked.
  for (int i = 0; i < FLAGS_warmup; ++i) {
    conv_2d_kernel.Launch();
    ASSERT_EQ(cudaDeviceSynchronize(), cudaSuccess);
  }

  // Setup happens before the clock starts so the reported average covers
  // Run() only; synchronize so no setup work leaks into the timed region.
  conv_2d_kernel.PrepareForRun();
  ASSERT_EQ(cudaDeviceSynchronize(), cudaSuccess);

  auto start = GetCurrentUS();
  for (int i = 0; i < FLAGS_repeats; ++i) {
    conv_2d_kernel.Run();
  }
  // Launches are asynchronous: wait for completion before stopping the clock,
  // and surface any execution error.
  ASSERT_EQ(cudaDeviceSynchronize(), cudaSuccess);
  auto duration = (GetCurrentUS() - start) / 1000.0;
  // Avoid dividing by zero when no repeats were requested.
  const int repeats = FLAGS_repeats > 0 ? FLAGS_repeats : 1;
  LOG(INFO) << "fp16, warmup: " << FLAGS_warmup
            << ", repeats: " << FLAGS_repeats << ", spend "
            << duration / repeats << " ms in average.";
}

// Runs ConvComputeInt8 with fp32 output on a single 3x3 window.
//
// Every input and filter element is 1 and c = 4, so each output channel sums
// 3 * 3 * 4 = 36 products; real_results {36, 72, 108, 144} appears to assume
// output[k] = 36 * weight_scale[k] with no bias (param.bias is not set).
TEST(conv_compute, int8) {
  ConvComputeInt8<PRECISION(kFloat)> int8_conv_fp32out;
  std::unique_ptr<KernelContext> ctx(new KernelContext);
  auto& context = ctx->As<CUDAContext>();

  operators::ActivationParam act_param;
  act_param.has_active = true;
  act_param.active_type = lite_api::ActivationType::kRelu;
  operators::ConvParam param;
  // param.activation_param = act_param;
  param.groups = 1;

  Tensor x, filter, bias, y, x_cpu, filter_cpu, bias_cpu, y_cpu;
  int n = 1, c = 4, h = 3, w = 3;
  // Dims are listed as {n, h, w, c} (presumably NHWC — confirm with kernel).
  y.Resize({1, 1, 1, c});
  x_cpu.Resize({n, h, w, c});
  filter_cpu.Resize({c, 3, 3, c / param.groups});
  y_cpu.Resize({1, 1, 1, c});
  bias_cpu.Resize({c});

  auto* y_data = y.mutable_data<float>(TARGET(kCUDA));
  auto* x_cpu_data = x_cpu.mutable_data<int8_t>();
  auto* filter_cpu_data = filter_cpu.mutable_data<int8_t>();
  // The host result buffer must come from y_cpu; requesting float storage
  // from x_cpu would retype x_cpu and may invalidate x_cpu_data.
  auto* y_cpu_data = y_cpu.mutable_data<float>();
  auto* bias_cpu_data = bias_cpu.mutable_data<float>();

  for (int i = 0; i < x_cpu.numel(); i++) {
    x_cpu_data[i] = static_cast<int8_t>(1);
  }
  for (int i = 0; i < filter_cpu.numel(); i++) {
    filter_cpu_data[i] = static_cast<int8_t>(1);
  }
  for (int i = 0; i < bias_cpu.numel(); i++) {
    bias_cpu_data[i] = i + 1.0;
  }

  x.Assign<int8_t, lite::DDim, TARGET(kCUDA)>(x_cpu_data, x_cpu.dims());
  filter.Assign<int8_t, lite::DDim, TARGET(kCUDA)>(filter_cpu_data,
                                                   filter_cpu.dims());
  // Copy exactly bias_cpu.numel() floats; the filter dims describe far more
  // elements than the host bias buffer holds.
  bias.Assign<float, lite::DDim, TARGET(kCUDA)>(bias_cpu_data,
                                                bias_cpu.dims());

  std::vector<int> pads = {0, 0, 0, 0};
  std::vector<int> dilations = {1, 1, 1, 1};
  param.paddings = std::make_shared<std::vector<int>>(pads);
  param.dilations = std::make_shared<std::vector<int>>(dilations);
  param.x = &x;
  param.filter = &filter;
  param.output = &y;
  param.weight_scale = {1, 2, 3, 4};

  int8_conv_fp32out.SetParam(param);
  cudaStream_t stream;
  ASSERT_EQ(cudaStreamCreate(&stream), cudaSuccess);
  context.SetExecStream(stream);

  int8_conv_fp32out.SetContext(std::move(ctx));
  int8_conv_fp32out.Launch();
  EXPECT_EQ(cudaDeviceSynchronize(), cudaSuccess);

  CopySync<TARGET(kCUDA)>(
      y_cpu_data, y_data, sizeof(float) * y.numel(), IoDirection::DtoH);
  std::vector<float> real_results = {36, 72, 108, 144};
  // NOTE(review): the value check is disabled; re-enable once the kernel's
  // output is confirmed to match real_results.
  // for (int i = 0; i < y.numel(); i++) {
  //   EXPECT_NEAR(y_cpu_data[i], real_results[i], 1e-5);
  // }

  // The stream was created by this test and is not owned by the context.
  cudaStreamDestroy(stream);
}

// Runs ConvComputeInt8 with int8 output and a ReLU activation on a single
// 3x3 window. Inputs and filters are random (fixed-seed random_num); the
// results are only logged, not checked.
TEST(conv_compute, int8_int8_out) {
  ConvComputeInt8<PRECISION(kInt8)> int8_conv_int8out;
  std::unique_ptr<KernelContext> ctx(new KernelContext);
  auto& context = ctx->As<CUDAContext>();

  operators::ActivationParam act_param;
  act_param.has_active = true;
  act_param.active_type = lite_api::ActivationType::kRelu;
  // act_param.active_type = lite_api::ActivationType::kLeakyRelu;
  act_param.Leaky_relu_alpha = 0.1;
  operators::ConvParam param;
  param.activation_param = act_param;
  param.groups = 1;

  Tensor x, filter, bias, y, x_cpu, filter_cpu, bias_cpu, y_cpu;
  int c_i = 3, h_i = 3, w_i = 3;
  int n = 1, c = 4;
  // Dims are listed as {n, h, w, c} (presumably NHWC — confirm with kernel).
  y.Resize({1, 1, 1, c});
  x_cpu.Resize({n, h_i, w_i, c_i});
  filter_cpu.Resize({c, 3, 3, c_i / param.groups});
  y_cpu.Resize({1, 1, 1, c});
  bias_cpu.Resize({c});

  auto* y_data = y.mutable_data<int8_t>(TARGET(kCUDA));
  auto* x_cpu_data = x_cpu.mutable_data<int8_t>();
  auto* filter_cpu_data = filter_cpu.mutable_data<int8_t>();
  // The host result buffer belongs to y_cpu, not to the input tensor x_cpu.
  auto* y_cpu_data = y_cpu.mutable_data<int8_t>();
  auto* bias_cpu_data = bias_cpu.mutable_data<float>();

  for (int i = 0; i < x_cpu.numel(); i++) {
    x_cpu_data[i] = static_cast<int8_t>(random_num(-36, 36));
  }
  for (int i = 0; i < filter_cpu.numel(); i++) {
    filter_cpu_data[i] = static_cast<int8_t>(random_num(-10, 10));
  }
  for (int i = 0; i < bias_cpu.numel(); i++) {
    bias_cpu_data[i] = i + 1.0;
  }

  x.Assign<int8_t, lite::DDim, TARGET(kCUDA)>(x_cpu_data, x_cpu.dims());
  filter.Assign<int8_t, lite::DDim, TARGET(kCUDA)>(filter_cpu_data,
                                                   filter_cpu.dims());
  // Copy exactly bias_cpu.numel() floats; the filter dims describe far more
  // elements than the host bias buffer holds.
  bias.Assign<float, lite::DDim, TARGET(kCUDA)>(bias_cpu_data,
                                                bias_cpu.dims());

  std::vector<int> pads = {0, 0, 0, 0};
  std::vector<int> dilations = {1, 1, 1, 1};
  param.paddings = std::make_shared<std::vector<int>>(pads);
  param.dilations = std::make_shared<std::vector<int>>(dilations);
  param.x = &x;
  param.filter = &filter;
  param.output = &y;
  param.weight_scale = {0.01, 0.02, 0.03, 0.04};
  param.output_scale = 2;
  param.bias = &bias;

  int8_conv_int8out.SetParam(param);
  cudaStream_t stream;
  ASSERT_EQ(cudaStreamCreate(&stream), cudaSuccess);
  context.SetExecStream(stream);

  int8_conv_int8out.SetContext(std::move(ctx));
  int8_conv_int8out.Launch();
  EXPECT_EQ(cudaDeviceSynchronize(), cudaSuccess);

  CopySync<TARGET(kCUDA)>(
      y_cpu_data, y_data, sizeof(int8_t) * y.numel(), IoDirection::DtoH);

  std::vector<float> real_results = {0, 7, 8, 1};
  for (int i = 0; i < y.numel(); i++) {
    // EXPECT_NEAR(y_cpu_data[i], real_results[i], 1e-5);
    LOG(INFO) << static_cast<float>(y_cpu_data[i]);
  }

  // The stream was created by this test and is not owned by the context.
  cudaStreamDestroy(stream);
}

}  // namespace cuda
}  // namespace kernels
}  // namespace lite
}  // namespace paddle