fused_dropout_act_bias_test.cu 11.6 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <time.h>

#include <random>
#include <vector>

#include "paddle/fluid/operators/fused/fused_dropout_act_bias.h"
21
#include "paddle/phi/common/amp_type_traits.h"
22
#include "paddle/phi/core/kernel_registry.h"
23
#include "paddle/phi/kernels/funcs/functors.h"
24
#include "test/cpp/fluid/fused/fused_dropout_test.h"
25

26 27 28 29 30
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PD_DECLARE_KERNEL(dropout, GPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(dropout_grad, GPU, ALL_LAYOUT);
#endif

31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50
namespace framework = paddle::framework;
namespace platform = paddle::platform;

/**
 * @brief the unittest of fused_dropout_act_bias
 * 1. random input data
 * 2. add bias, call activation, call paddle dropout, and get the base result
 * 3. call FusedDropoutActBias function get fused result
 * 4. compare ther base result and fused result
 */

template <typename T, typename Functor, typename GradFunctor>
struct TestFusedDropoutActBias {
  uint32_t rows;
  uint32_t cols;
  uint64_t seed;
  float dropout_prob;
  bool is_upscale_in_train;
  bool is_test;  // default false,  Set to true for inference only
  bool has_bias = true;
51 52
  phi::DenseTensor src, bias, out, mask;
  phi::DenseTensor dsrc, dbias;
53 54 55 56 57 58

  std::vector<T> src_vec, bias_vec, out_vec, mask_vec;
  std::vector<T> correct_out, correct_dsrc, correct_dbias;
  std::vector<uint8_t> correct_mask;

  platform::CUDAPlace place;
L
Leo Chen 已提交
59
  phi::GPUContext *ctx;
60 61 62 63 64 65 66 67 68 69 70

  TestFusedDropoutActBias() {
    rows = 32;
    cols = 32;
    seed = 0;
    dropout_prob = 0.0;
    is_upscale_in_train = false;
    is_test = false;
    has_bias = true;
    platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
    auto devicectx = pool.Get(place);
L
Leo Chen 已提交
71
    ctx = reinterpret_cast<phi::GPUContext *>(devicectx);
72 73
  }

74 75 76
  TestFusedDropoutActBias(int rows_,
                          int cols_,
                          uint64_t seed_ = 0,
77 78 79 80 81 82 83 84 85 86 87 88
                          float dropout_prob_ = 0.0,
                          bool is_upscale_in_train_ = false,
                          bool is_test_ = false) {
    rows = rows_;
    cols = cols_;
    seed = seed_;
    dropout_prob = dropout_prob_;
    is_upscale_in_train = is_upscale_in_train_;
    is_test = is_test_;
    has_bias = true;
    platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
    auto devicectx = pool.Get(place);
L
Leo Chen 已提交
89
    ctx = reinterpret_cast<phi::GPUContext *>(devicectx);
90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142
  }

  ~TestFusedDropoutActBias() {}

  void SetUp() {
    const int n = rows * cols;
    correct_out.resize(n);
    correct_mask.resize(n);
    correct_dsrc.resize(n);
    correct_dbias.resize(cols);

    src_vec.resize(n);
    bias_vec.resize(cols);
    std::default_random_engine random(time(NULL));
    std::uniform_real_distribution<float> dis(0.0, 1.0);

    for (int i = 0; i < rows; i++) {
      for (int j = 0; j < cols; j++) {
        src_vec[i * cols + j] = static_cast<T>(dis(random));
        if (i == 0) bias_vec[j] = dis(random);
      }
    }

    framework::TensorFromVector<T>(src_vec, *ctx, &src);
    src.Resize({rows, cols});
    if (has_bias) {
      framework::TensorFromVector<T>(bias_vec, *ctx, &bias);
      bias.Resize({cols});
    }

    {
      out.mutable_data<T>({rows, cols}, place);
      mask.mutable_data<uint8_t>({rows, cols}, place);
      dsrc.mutable_data<T>({rows, cols}, place);

      if (has_bias) {
        dbias.mutable_data<T>({cols}, place);
      }
    }
  }

  void BaseForward() {
    std::vector<T> out1(rows * cols);
    Functor act;
    if (has_bias) {
      // add bias and call activation
      for (int i = 0; i < rows; i++) {
        for (int j = 0; j < cols; j++) {
          const T tmp = src_vec[i * cols + j] + bias_vec[j];
          out1[i * cols + j] = act(tmp);
        }
      }
      // call dropout
143 144 145 146 147 148 149 150 151
      Dropout<T>(out1,
                 src.dims(),
                 &correct_out,
                 &correct_mask,
                 *ctx,
                 seed,
                 dropout_prob,
                 is_upscale_in_train,
                 is_test);
152 153 154 155 156 157 158 159
    } else {
      for (int i = 0; i < rows; i++) {
        for (int j = 0; j < cols; j++) {
          const T tmp = src_vec[i * cols + j];
          out1[i * cols + j] = act(tmp);
        }
      }

160 161 162 163 164 165 166 167 168
      Dropout<T>(out1,
                 src.dims(),
                 &correct_out,
                 &correct_mask,
                 *ctx,
                 seed,
                 dropout_prob,
                 is_upscale_in_train,
                 is_test);
169 170 171 172 173 174 175
    }
    ctx->Wait();
  }

  void BaseBackward() {
    std::vector<T> _out(rows * cols);
    // call dropout_grad
176 177 178 179 180 181 182
    DropoutGrad<T>(&_out,
                   src.dims(),
                   correct_out,
                   correct_mask,
                   *ctx,
                   dropout_prob,
                   is_upscale_in_train);
183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208

    // calculate dbias
    memset(&correct_dbias[0], 0, cols * sizeof(T));
    GradFunctor act_grad;
    for (int i = 0; i < rows; i++) {
      for (int j = 0; j < cols; j++) {
        T args[2];
        args[0] = _out[i * cols + j];
        if (has_bias) {
          args[1] = src_vec[i * cols + j] + bias_vec[j];
        } else {
          args[1] = src_vec[i * cols + j];
        }
        T val = args[0] * act_grad.UseOut(args[1]);
        correct_dsrc[i * cols + j] = val;
      }
    }

    if (has_bias) {
      // reduce_sum: keep the same calculate order as the GPU
      ReduceSum<T>(correct_dsrc, &correct_dbias, rows, cols);
    }
  }

  void FusedForward() {
    const int VecSize = MAX_CACHE_BYTES / sizeof(T);
209 210 211 212 213
    auto config =
        paddle::operators::Get1DBlocksAnd2DGrids(*ctx,
                                                 static_cast<uint64_t>(rows),
                                                 static_cast<uint64_t>(cols),
                                                 VecSize);
214 215 216 217 218 219 220 221 222 223 224
    const int increment = ((cols - 1) / (config.thread_per_block.x *
                                         config.block_per_grid.x * VecSize) +
                           1) *
                          VecSize;

    T *bias_ptr = nullptr;
    if (has_bias) {
      bias_ptr = bias.data<T>();
    }
    Functor act;
    paddle::operators::LaunchDropoutActBias<T, uint8_t, Functor>(
225 226 227 228 229 230 231 232 233 234 235 236
        act,
        seed,
        rows,
        cols,
        increment,
        dropout_prob,
        is_upscale_in_train,
        is_test,
        src.data<T>(),
        bias_ptr,
        out.data<T>(),
        mask.data<uint8_t>(),
237 238 239 240 241 242 243 244 245 246 247 248 249 250 251
        *ctx);
    ctx->Wait();
  }

  void FusedBackward() {
    if (is_test) return;

    T *bias_ptr = nullptr;
    T *dbias_ptr = nullptr;
    if (has_bias) {
      dbias_ptr = dbias.data<T>();
      bias_ptr = bias.data<T>();
    }
    GradFunctor act_grad;
    paddle::operators::LaunchDropoutActBiasGrad<T, uint8_t, GradFunctor>(
252 253 254 255 256 257 258 259 260 261 262 263
        act_grad,
        out.data<T>(),
        mask.data<uint8_t>(),
        src.data<T>(),
        bias_ptr,
        dropout_prob,
        is_upscale_in_train,
        rows,
        cols,
        dsrc.data<T>(),
        dbias_ptr,
        *ctx);
264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331
  }

  void Run() {
    SetUp();
    BaseForward();
    FusedForward();
    BaseBackward();
    FusedBackward();
  }

  void CheckOut(const T diff) {
    const int n = rows * cols;
    std::vector<T> _out(n);
    std::vector<uint8_t> _mask(n);
    framework::TensorToVector(out, *ctx, &_out);
    if (!is_test) {
      framework::TensorToVector<uint8_t>(mask, *ctx, &_mask);
    }
    ctx->Wait();

    for (int i = 0; i < n; i++) {
      EXPECT_LT(std::abs(_out[i] - correct_out[i]), diff);
      if (!is_test) EXPECT_EQ(_mask[i], correct_mask[i]);
    }
  }

  void CheckGrad(const T diff) {
    if (is_test) return;

    const int n = rows * cols;

    std::vector<T> _dsrc(n);
    framework::TensorToVector(dsrc, *ctx, &_dsrc);

    for (int i = 0; i < n; i++) {
      EXPECT_LT(std::abs(_dsrc[i] - correct_dsrc[i]), diff);
    }

    if (has_bias) {
      std::vector<T> _dbias(cols);
      framework::TensorToVector(dbias, *ctx, &_dbias);
      ctx->Wait();
      for (int i = 0; i < cols; i++) {
        EXPECT_LT(std::abs(_dbias[i] - correct_dbias[i]), diff);
      }
    }
  }
};

// test the shape , bias, activation
template <typename T, typename Functor, typename GradFunctor>
static void BaseTest(const bool is_fp16 = false) {
  const int rows = 16;
  std::vector<int> cols_list = {16, 17};
  bool has_bias[2] = {true, false};
  T default_diff = !is_fp16 ? static_cast<T>(1e-5) : static_cast<T>(1e-1);
  for (auto cols : {16, 17}) {
    for (auto has_bias : {true, false}) {
      TestFusedDropoutActBias<T, Functor, GradFunctor> test(rows, cols);
      test.has_bias = has_bias;
      test.Run();
      test.CheckOut(default_diff);
      test.CheckGrad(default_diff);
    }
  }
}

// float: ReLU and GELU activations.
// NOTE(review): the test name misspells "Dropout"; left unchanged so any
// existing gtest filters keep matching.
TEST(FusedDropout, GPUFusedDorpoutActBias) {
  using Relu = phi::funcs::ReluFunctor<float>;
  using ReluGrad = phi::funcs::ReluGradFunctor<float>;
  BaseTest<float, Relu, ReluGrad>();

  using Gelu = paddle::operators::GeluFunctor<float>;
  using GeluGrad = paddle::operators::GeluGradFunctor<float>;
  BaseTest<float, Gelu, GeluGrad>();
}
// double: ReLU and GELU activations.
TEST(FusedDropout, GPUFusedDropoutActBiasDouble) {
  using Relu = phi::funcs::ReluFunctor<double>;
  using ReluGrad = phi::funcs::ReluGradFunctor<double>;
  BaseTest<double, Relu, ReluGrad>();

  using Gelu = paddle::operators::GeluFunctor<double>;
  using GeluGrad = paddle::operators::GeluGradFunctor<double>;
  BaseTest<double, Gelu, GeluGrad>();
}

// fp16 with ReLU. BaseTest is called with is_fp16 = true, which selects the
// looser 1e-1 tolerance. It runs in training mode (is_test = false), so
// CheckGrad also compares gradients. ref: test_dropout_op.py
TEST(FusedDropout, GPUFusedDropoutActBiasFp16) {
  using fp16 = platform::float16;
  using Relu = phi::funcs::ReluFunctor<fp16>;
  using ReluGrad = phi::funcs::ReluGradFunctor<fp16>;
  const bool is_fp16 = true;
  BaseTest<fp16, Relu, ReluGrad>(is_fp16);
}

// dropout_prob = 1.0 in training mode, with and without upscale_in_train.
TEST(FusedDropout, GPUFusedDropoutActBiasIsUpscaleInTrain) {
  using ReluTest =
      TestFusedDropoutActBias<float,
                              phi::funcs::ReluFunctor<float>,
                              phi::funcs::ReluGradFunctor<float>>;
  constexpr int kRows = 16;
  constexpr int kCols = 16;
  constexpr uint64_t kSeed = 0;
  constexpr float kDropoutProb = 1.0;
  const bool upscale_modes[] = {true, false};
  for (const bool upscale : upscale_modes) {
    ReluTest test(kRows, kCols, kSeed, kDropoutProb, upscale, false);
    test.Run();
    test.CheckOut(static_cast<float>(1e-5));
    test.CheckGrad(static_cast<float>(1e-3));
  }
}

// Inference mode (is_test = true) with dropout_prob = 0.35 and upscaling.
TEST(FusedDropout, GPUFusedDropoutActBiasIsTest) {
  using ReluTest =
      TestFusedDropoutActBias<float,
                              phi::funcs::ReluFunctor<float>,
                              phi::funcs::ReluGradFunctor<float>>;
  constexpr int kRows = 16;
  constexpr int kCols = 16;
  constexpr uint64_t kSeed = 0;
  constexpr float kDropoutProb = 0.35;
  ReluTest test(kRows, kCols, kSeed, kDropoutProb, true, true);
  test.Run();
  test.CheckOut(static_cast<float>(1e-5));
  test.CheckGrad(static_cast<float>(1e-3));
}

// Non-zero dropout seed (125) with dropout_prob = 0 in training mode.
TEST(FusedDropout, GPUFusedDropoutActBiasSeed) {
  using ReluTest =
      TestFusedDropoutActBias<float,
                              phi::funcs::ReluFunctor<float>,
                              phi::funcs::ReluGradFunctor<float>>;
  constexpr int kRows = 16;
  constexpr int kCols = 16;
  constexpr uint64_t kSeed = 125;
  constexpr float kDropoutProb = 0.0;
  ReluTest test(kRows, kCols, kSeed, kDropoutProb, false, false);
  test.Run();
  test.CheckOut(static_cast<float>(1e-5));
  test.CheckGrad(static_cast<float>(1e-3));
}

// Large 256 x 4096 problem with the default seed/probability/mode.
TEST(FusedDropout, GPUFusedDropoutActBiasLargeShape) {
  using ReluTest =
      TestFusedDropoutActBias<float,
                              phi::funcs::ReluFunctor<float>,
                              phi::funcs::ReluGradFunctor<float>>;
  constexpr int kRows = 256;
  constexpr int kCols = 4096;
  ReluTest test(kRows, kCols);
  test.Run();
  test.CheckOut(static_cast<float>(1e-5));
  test.CheckGrad(static_cast<float>(1e-3));
}