fused_residual_dropout_bias_test.cu 11.3 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <time.h>

#include <random>
#include <vector>

#include "paddle/fluid/operators/fused/fused_residual_dropout_bias.h"
21
#include "paddle/phi/core/kernel_registry.h"
22
#include "test/cpp/fluid/fused/fused_dropout_test.h"
23 24 25 26 27

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PD_DECLARE_KERNEL(dropout, GPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(dropout_grad, GPU, ALL_LAYOUT);
#endif
28 29 30 31

namespace framework = paddle::framework;
namespace platform = paddle::platform;

// Returns true when |value - ref| is strictly less than `tolerance`.
// The default tolerance (a double, as in the original 1e-5 literal) keeps the
// exact comparison semantics of the two-argument form.
bool CheckEqual(float value, float ref, double tolerance = 1e-5) {
  return std::abs(value - ref) < tolerance;
}

34
/**
35
 * @brief the unittest of FusedResidualDropoutBias
36 37 38 39 40 41 42
 * 1. random input data
 * 2. add bias, call paddle dropout op, add residual, and get the base result
 * 3. call FusedResidualDropoutBias function get fused result
 * 4. compare ther base result and fused result
 */

template <typename T>
43
struct FusedResidualDropoutBiasTester {
44 45 46 47 48 49 50
  uint32_t rows;
  uint32_t cols;
  uint64_t seed;
  float dropout_prob;
  bool is_upscale_in_train;
  bool is_test;  // default false,  Set to true for inference only
  bool has_bias = true;
51 52
  bool add_residual = true;

53 54
  phi::DenseTensor src, residual, bias, out, mask;
  phi::DenseTensor dsrc, dbias;
55 56 57 58 59 60

  std::vector<T> src_vec, residual_vec, bias_vec;
  std::vector<T> correct_out, correct_dsrc, correct_dbias;
  std::vector<uint8_t> correct_mask;

  platform::CUDAPlace place;
L
Leo Chen 已提交
61
  phi::GPUContext *ctx;
62

63
  FusedResidualDropoutBiasTester() {
64 65 66 67 68 69 70 71
    rows = 32;
    cols = 32;
    seed = 0;
    dropout_prob = 0.0;
    is_upscale_in_train = false;
    is_test = false;
    platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
    auto device_ctx = pool.Get(place);
L
Leo Chen 已提交
72
    ctx = reinterpret_cast<phi::GPUContext *>(device_ctx);
73 74
  }

75 76 77
  FusedResidualDropoutBiasTester(int rows,
                                 int cols,
                                 uint64_t seed = 0,
78 79 80 81 82 83 84 85 86
                                 float dropout_prob = 0.0,
                                 bool is_upscale_in_train = false,
                                 bool is_test = false)
      : rows(rows),
        cols(cols),
        seed(seed),
        dropout_prob(dropout_prob),
        is_upscale_in_train(is_upscale_in_train),
        is_test(is_test) {
87 88
    platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
    auto device_ctx = pool.Get(place);
L
Leo Chen 已提交
89
    ctx = reinterpret_cast<phi::GPUContext *>(device_ctx);
90 91 92 93 94 95 96 97 98 99
  }

  void SetUp() {
    const int n = rows * cols;
    correct_out.resize(n);
    correct_mask.resize(n);
    correct_dsrc.resize(n);
    correct_dbias.resize(cols);

    src_vec.resize(n);
100 101 102
    if (add_residual) {
      residual_vec.resize(n);
    }
103 104 105 106 107 108 109
    bias_vec.resize(cols);
    std::default_random_engine random(time(NULL));
    std::uniform_real_distribution<float> dis(0.0, 1.0);

    for (int i = 0; i < rows; i++) {
      for (int j = 0; j < cols; j++) {
        src_vec[i * cols + j] = static_cast<T>(dis(random));
110 111 112
        if (add_residual) {
          residual_vec[i * cols + j] = static_cast<T>(dis(random));
        }
113 114 115 116 117 118 119 120
        if (i == 0) {
          bias_vec[j] = dis(random);
        }
      }
    }

    framework::TensorFromVector<T>(src_vec, *ctx, &src);
    src.Resize({rows, cols});
121 122 123 124
    if (add_residual) {
      framework::TensorFromVector<T>(residual_vec, *ctx, &residual);
      residual.Resize({rows, cols});
    }
125 126 127 128 129
    if (has_bias) {
      framework::TensorFromVector<T>(bias_vec, *ctx, &bias);
      bias.Resize({cols});
    }

130 131 132
    out.mutable_data<T>({rows, cols}, place);
    mask.mutable_data<uint8_t>({rows, cols}, place);
    dsrc.mutable_data<T>({rows, cols}, place);
133

134 135
    if (has_bias) {
      dbias.mutable_data<T>({cols}, place);
136 137 138 139 140 141
    }
  }

  void BaseForward() {
    if (has_bias) {
      // add bias
142
      std::vector<T> bias_out(rows * cols);
143 144
      for (int i = 0; i < rows; i++) {
        for (int j = 0; j < cols; j++) {
145
          bias_out[i * cols + j] = src_vec[i * cols + j] + bias_vec[j];
146 147 148
        }
      }
      // call dropout
149 150 151 152 153 154 155 156 157
      Dropout<T>(bias_out,
                 src.dims(),
                 &correct_out,
                 &correct_mask,
                 *ctx,
                 seed,
                 dropout_prob,
                 is_upscale_in_train,
                 is_test);
158
    } else {
159 160 161 162 163 164 165 166 167
      Dropout<T>(src_vec,
                 src.dims(),
                 &correct_out,
                 &correct_mask,
                 *ctx,
                 seed,
                 dropout_prob,
                 is_upscale_in_train,
                 is_test);
168 169
    }
    ctx->Wait();
170
    PADDLE_ENFORCE_GPU_SUCCESS(platform::GpuGetLastError());
171 172 173 174 175 176 177
    if (add_residual) {
      // add residual
      for (int i = 0; i < rows; i++) {
        for (int j = 0; j < cols; j++) {
          int idx = i * cols + j;
          correct_out[idx] = residual_vec[idx] + correct_out[idx];
        }
178 179 180 181 182
      }
    }
  }

  void BaseBackward() {
183 184 185 186 187 188 189
    DropoutGrad<T>(&correct_dsrc,
                   src.dims(),
                   correct_out,
                   correct_mask,
                   *ctx,
                   dropout_prob,
                   is_upscale_in_train);
190 191
    // calc dbias
    memset(&correct_dbias[0], 0, cols * sizeof(T));
192 193
    if (has_bias) {
      ReduceSum<T>(correct_out, &correct_dbias, rows, cols);
194 195 196 197 198
    }
  }

  void FusedForward() {
    const int VecSize = MAX_CACHE_BYTES / sizeof(T);
199 200 201 202 203
    auto config =
        paddle::operators::Get1DBlocksAnd2DGrids(*ctx,
                                                 static_cast<uint64_t>(rows),
                                                 static_cast<uint64_t>(cols),
                                                 VecSize);
204

205 206 207 208 209
    const int increment = ((cols - 1) / (config.thread_per_block.x *
                                         config.block_per_grid.x * VecSize) +
                           1) *
                          VecSize;

210 211
    T *bias_ptr = has_bias ? bias.data<T>() : nullptr;
    T *residual_ptr = add_residual ? residual.data<T>() : nullptr;
212
    paddle::operators::LaunchResidualDropoutBias<T, uint8_t>(
213 214 215 216 217 218 219 220 221 222 223 224 225
        rows,
        cols,
        increment,
        seed,
        dropout_prob,
        is_test,
        is_upscale_in_train,
        src.data<T>(),
        residual_ptr,
        bias_ptr,
        mask.data<uint8_t>(),
        out.data<T>(),
        *ctx);
226
    ctx->Wait();
227
    PADDLE_ENFORCE_GPU_SUCCESS(platform::GpuGetLastError());
228 229 230 231 232 233 234
  }

  void FusedBackward() {
    if (is_test) {
      return;
    }

235
    T *bias_ptr = has_bias ? dbias.data<T>() : nullptr;
236
    paddle::operators::LaunchResidualDropoutBiasGrad<T, uint8_t>(
237 238 239 240 241 242 243 244 245
        out.data<T>(),
        mask.data<uint8_t>(),
        dropout_prob,
        is_upscale_in_train,
        rows,
        cols,
        dsrc.data<T>(),
        bias_ptr,
        *ctx);
246 247 248 249 250 251 252 253 254 255 256 257
  }

  void Run() {
    SetUp();
    BaseForward();
    FusedForward();
    BaseBackward();
    FusedBackward();
  }

  void CheckOut(const T diff) {
    const int n = rows * cols;
258 259 260
    std::vector<T> fused_out(n);
    std::vector<uint8_t> fused_mask(n);
    framework::TensorToVector(out, *ctx, &fused_out);
261
    if (!is_test && dropout_prob != 0.0f) {
262
      framework::TensorToVector<uint8_t>(mask, *ctx, &fused_mask);
263 264 265 266
    }
    ctx->Wait();

    for (int i = 0; i < n; i++) {
267
      EXPECT_LT(std::abs(fused_out[i] - correct_out[i]), diff);
268
      if (!is_test && dropout_prob != 0.0f) {
269 270
        EXPECT_EQ(fused_mask[i], correct_mask[i]);
      }
271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300
    }
  }

  void CheckGrad(const T diff) {
    if (is_test) {
      return;
    }

    const int n = rows * cols;

    std::vector<T> _dsrc(n);
    framework::TensorToVector(dsrc, *ctx, &_dsrc);

    for (int i = 0; i < n; i++) {
      EXPECT_LT(std::abs(_dsrc[i] - correct_dsrc[i]), diff);
    }

    if (has_bias) {
      std::vector<T> _dbias(cols);
      framework::TensorToVector(dbias, *ctx, &_dbias);
      ctx->Wait();
      for (int i = 0; i < cols; i++) {
        EXPECT_LT(std::abs(_dbias[i] - correct_dbias[i]), diff);
      }
    }
  }
};

// test the shape and bias
template <typename T>
301
static void BaseTest() {
302
  const int rows = 16;
303 304 305 306 307 308
  T max_diff = static_cast<T>(0);
  if (std::is_same<T, paddle::platform::float16>::value) {
    max_diff = static_cast<T>(1e-1);
  } else {
    max_diff = static_cast<T>(1e-5);
  }
309 310
  for (auto cols : {16, 17}) {
    for (auto has_bias : {true, false}) {
311
      FusedResidualDropoutBiasTester<T> test(rows, cols);
312
      test.has_bias = has_bias;
313
      test.Run();
314 315
      test.CheckOut(max_diff);
      test.CheckGrad(max_diff);
316 317 318 319 320 321 322 323 324
    }
  }
}

// Run the shape/bias sweep for each supported element type.
TEST(FusedDropout, GPUFusedResidualDropoutBias) {
  BaseTest<float>();
}

TEST(FusedDropout, GPUFusedResidualDropoutBiasDouble) {
  BaseTest<double>();
}

TEST(FusedDropout, GPUFusedResidualDropoutBiasFp16) {
  BaseTest<platform::float16>();
}

328
TEST(FusedDropout, GPUFusedResidualDropoutBiasIsUpscaleInTrain) {
329 330
  const int rows = 16;
  const int cols = 16;
331
  for (auto is_upscale_in_train : {true, false}) {
332 333
    FusedResidualDropoutBiasTester<float> test(
        rows, cols, 0, 1.0, is_upscale_in_train, false);
334 335 336 337
    test.Run();
    test.CheckOut(static_cast<float>(1e-5));
    test.CheckGrad(static_cast<float>(1e-5));
  }
338 339
}

340
TEST(FusedDropout, GPUFusedResidualDropoutBiasIsTest) {
341 342
  const int rows = 16;
  const int cols = 16;
343
  FusedResidualDropoutBiasTester<float> test(rows, cols, 0, 0.35, true, true);
344 345 346 347 348
  test.Run();
  test.CheckOut(static_cast<float>(1e-5));
  test.CheckGrad(static_cast<float>(1e-5));
}

349
TEST(FusedDropout, GPUFusedResidualDropoutBiasSeed) {
350 351
  const int rows = 16;
  const int cols = 16;
352 353
  FusedResidualDropoutBiasTester<float> test(
      rows, cols, 125, 0.0, false, false);
354 355 356 357 358
  test.Run();
  test.CheckOut(static_cast<float>(1e-5));
  test.CheckGrad(static_cast<float>(1e-5));
}

359 360 361 362 363 364 365 366 367 368 369 370 371 372 373
TEST(FusedDropout, NoResidual) {
  const int rows = 16;
  const int cols = 16;
  for (float p : {0.0f, 0.5f, 1.0f}) {
    FusedResidualDropoutBiasTester<float> test(rows, cols, 0, p, false, false);
    test.add_residual = false;
    test.Run();
    // For a non 0 or 1 dropout_prob, just test whether it can run successly.
    if (CheckEqual(p, 0.0f) || CheckEqual(p, 1.0f)) {
      test.CheckOut(static_cast<float>(1e-5));
      test.CheckGrad(static_cast<float>(1e-5));
    }
  }
}

374
TEST(FusedDropout, GPUFusedResidualDropoutBiasLargeShape) {
375 376
  const int rows = 256;
  const int cols = 4096;
377
  FusedResidualDropoutBiasTester<float> test(rows, cols);
378 379 380 381
  test.Run();
  test.CheckOut(static_cast<float>(1e-5));
  test.CheckGrad(static_cast<float>(1e-3));
}
382 383 384 385 386 387 388 389 390 391 392

TEST(FusedDropout, GPUFusedResidualDropoutBiasLargeShapeFp16) {
  // Used to test that `cudaErrorLaunchOutOfResources` will not occur
  int rows = 1;
  int cols = 12288;
  if (std::getenv("_rows") != nullptr) {
    rows = atoi(std::getenv("_rows"));
  }
  if (std::getenv("_cols") != nullptr) {
    cols = atoi(std::getenv("_cols"));
  }
393 394
  FusedResidualDropoutBiasTester<platform::float16> test(
      rows, cols, 0, 0.0, true, true);
395 396 397 398
  test.Run();
  test.CheckOut(static_cast<platform::float16>(1e-1));
  test.CheckGrad(static_cast<platform::float16>(1e-1));
}