/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <cmath>
#include <memory>
#include <random>
#include <string>
#include <vector>

#include "gtest/gtest.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/operators/fused/attn_feed_forward.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/math_function.h"

namespace framework = paddle::framework;
namespace platform = paddle::platform;

// Operators that the baseline helpers in this file create by name through
// framework::OpRegistry::CreateOp ("matmul", "elementwise_add" and their
// gradients).
USE_OP(matmul);
USE_OP_ITSELF(elementwise_add);

// Kernel declarations for the `add` family; the GPU / KPS variants are only
// declared for CUDA or HIP builds.
// NOTE(review): presumably these back the elementwise_add baseline -- confirm
// against the kernel registry.
PD_DECLARE_KERNEL(add, CPU, ALL_LAYOUT);
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PD_DECLARE_KERNEL(add_grad, GPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(add, KPS, ALL_LAYOUT);
#endif

39 40
// get paddle matmul op results as baseline
template <typename T>
41 42 43 44
void GetLinearOp(const std::vector<T> &x,
                 const std::vector<T> &y,
                 const framework::DDim &x_dim,
                 const framework::DDim &y_dim,
L
Leo Chen 已提交
45
                 const phi::GPUContext &ctx,
46 47 48 49
                 bool transpose_a,
                 bool transpose_b,
                 float alpha,
                 std::vector<T> *out) {
50 51
  framework::Scope scope;
  auto var_x = scope.Var("X");
52
  auto tensor_x = var_x->GetMutable<phi::DenseTensor>();
53
  auto var_y = scope.Var("Y");
54
  auto tensor_y = var_y->GetMutable<phi::DenseTensor>();
55
  auto var_out = scope.Var("Out");
56
  auto tensor_out = var_out->GetMutable<phi::DenseTensor>();
57 58 59 60 61 62 63 64

  tensor_x->Resize(x_dim);
  tensor_y->Resize(y_dim);
  tensor_out->Resize({x_dim[0], x_dim[1], y_dim[0]});

  auto x_ptr = tensor_x->mutable_data<T>(ctx.GetPlace());
  auto y_ptr = tensor_y->mutable_data<T>(ctx.GetPlace());
  auto z_ptr = tensor_out->mutable_data<T>(ctx.GetPlace());
65 66
  auto size_x = static_cast<size_t>(phi::product(x_dim));
  auto size_y = static_cast<size_t>(phi::product(y_dim));
67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85
  auto size_z = x_dim[0] * x_dim[1] * y_dim[0];
  cudaMemcpy(x_ptr, x.data(), size_x * sizeof(T), cudaMemcpyHostToDevice);
  cudaMemcpy(y_ptr, y.data(), size_y * sizeof(T), cudaMemcpyHostToDevice);

  framework::AttributeMap attrs;
  attrs.insert({"transpose_X", transpose_a});
  attrs.insert({"transpose_Y", transpose_b});
  attrs.insert({"alpha", alpha});

  auto op = framework::OpRegistry::CreateOp(
      "matmul", {{"X", {"X"}}, {"Y", {"Y"}}}, {{"Out", {"Out"}}}, attrs);
  op->Run(scope, ctx.GetPlace());

  cudaMemcpy(out->data(), z_ptr, size_z * sizeof(T), cudaMemcpyDeviceToHost);
  ctx.Wait();
}

// get paddle elementwise_add op results as baseline
template <typename T>
86 87 88 89
void GetElementwiseAddOp(const std::vector<T> &x,
                         const std::vector<T> &y,
                         const int bsz_seq,
                         const int output_size,
L
Leo Chen 已提交
90
                         const phi::GPUContext &ctx,
91 92 93
                         std::vector<T> *out) {
  framework::Scope scope;
  auto var_x = scope.Var("X");
94
  auto tensor_x = var_x->GetMutable<phi::DenseTensor>();
95
  auto var_y = scope.Var("Y");
96
  auto tensor_y = var_y->GetMutable<phi::DenseTensor>();
97
  auto var_out = scope.Var("Out");
98
  auto tensor_out = var_out->GetMutable<phi::DenseTensor>();
99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115

  tensor_x->Resize({bsz_seq, output_size});
  tensor_y->Resize({output_size});
  tensor_out->Resize({bsz_seq, output_size});

  auto x_ptr = tensor_x->mutable_data<T>(ctx.GetPlace());
  auto y_ptr = tensor_y->mutable_data<T>(ctx.GetPlace());
  auto z_ptr = tensor_out->mutable_data<T>(ctx.GetPlace());
  auto size_x = bsz_seq * output_size;
  auto size_y = output_size;
  auto size_z = bsz_seq * output_size;
  cudaMemcpy(x_ptr, x.data(), size_x * sizeof(T), cudaMemcpyHostToDevice);
  cudaMemcpy(y_ptr, y.data(), size_y * sizeof(T), cudaMemcpyHostToDevice);

  framework::AttributeMap attrs;
  auto op = framework::OpRegistry::CreateOp("elementwise_add",
                                            {{"X", {"X"}}, {"Y", {"Y"}}},
116 117
                                            {{"Out", {"Out"}}},
                                            attrs);
118 119 120 121 122 123 124
  op->Run(scope, ctx.GetPlace());
  cudaMemcpy(out->data(), z_ptr, size_z * sizeof(T), cudaMemcpyDeviceToHost);
  ctx.Wait();
}

// get paddle matmul_grad op results as baseline
template <typename T>
125 126
void GetLinearOpGrad(const std::vector<T> &x_vec,
                     const std::vector<T> &y_vec,
127
                     const std::vector<T> &dout_vec,
128 129
                     const framework::DDim &x_dim,
                     const framework::DDim &y_dim,
130
                     const framework::DDim &out_dim,
L
Leo Chen 已提交
131
                     const phi::GPUContext &ctx,
132 133 134 135
                     bool transpose_a,
                     bool transpose_b,
                     float alpha,
                     std::vector<T> *dinput_vec,
136 137 138
                     std::vector<T> *dweight_vec) {
  framework::Scope scope;
  auto var_x = scope.Var("X");
139
  auto tensor_x = var_x->GetMutable<phi::DenseTensor>();
140
  auto var_y = scope.Var("Y");
141
  auto tensor_y = var_y->GetMutable<phi::DenseTensor>();
142
  auto var_dout = scope.Var("DOut");
143
  auto tensor_dout = var_dout->GetMutable<phi::DenseTensor>();
144 145 146 147 148
  tensor_x->Resize(x_dim);
  tensor_y->Resize(y_dim);
  tensor_dout->Resize(out_dim);

  auto var_dx = scope.Var("DX");
149
  auto tensor_dx = var_dx->GetMutable<phi::DenseTensor>();
150
  auto var_dy = scope.Var("DY");
151
  auto tensor_dy = var_dy->GetMutable<phi::DenseTensor>();
152 153 154 155 156 157 158 159 160
  tensor_dx->Resize(x_dim);
  tensor_dy->Resize(y_dim);

  auto x_ptr = tensor_x->mutable_data<T>(ctx.GetPlace());
  auto y_ptr = tensor_y->mutable_data<T>(ctx.GetPlace());
  auto dout_ptr = tensor_dout->mutable_data<T>(ctx.GetPlace());
  auto dinput_ptr = tensor_dx->mutable_data<T>(ctx.GetPlace());
  auto dweight_ptr = tensor_dy->mutable_data<T>(ctx.GetPlace());

161 162
  auto size_x = static_cast<size_t>(phi::product(x_dim));
  auto size_y = static_cast<size_t>(phi::product(y_dim));
163 164 165
  auto size_z = x_dim[0] * x_dim[1] * y_dim[0];
  cudaMemcpy(x_ptr, x_vec.data(), size_x * sizeof(T), cudaMemcpyHostToDevice);
  cudaMemcpy(y_ptr, y_vec.data(), size_y * sizeof(T), cudaMemcpyHostToDevice);
166 167
  cudaMemcpy(
      dout_ptr, dout_vec.data(), size_z * sizeof(T), cudaMemcpyHostToDevice);
168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186

  bool use_mkldnn = false;
  bool use_quantizer = false, force_fp32_output = false;
  std::string mkldnn_data_type = "float32";
  float Scale_x = 1.0, Scale_y = 1.0, Scale_out = 1.0;

  framework::AttributeMap attrs;
  attrs.insert({"transpose_X", transpose_a});
  attrs.insert({"transpose_Y", transpose_b});
  attrs.insert({"alpha", alpha});
  attrs.insert({"use_mkldnn", use_mkldnn});
  attrs.insert({"use_quantizer", use_quantizer});
  attrs.insert({"mkldnn_data_type", mkldnn_data_type});
  attrs.insert({"Scale_x", Scale_x});
  attrs.insert({"Scale_y", Scale_y});
  attrs.insert({"Scale_out", Scale_out});
  attrs.insert({"force_fp32_output", force_fp32_output});

  auto op = framework::OpRegistry::CreateOp(
187 188 189 190
      "matmul_grad",
      {{"Out@GRAD", {"DOut"}}, {"X", {"X"}}, {"Y", {"Y"}}},
      {{"X@GRAD", {"DX"}}, {"Y@GRAD", {"DY"}}},
      attrs);
191 192
  op->Run(scope, ctx.GetPlace());

193 194 195
  cudaMemcpy(dinput_vec->data(),
             dinput_ptr,
             size_x * sizeof(T),
196
             cudaMemcpyDeviceToHost);
197 198 199
  cudaMemcpy(dweight_vec->data(),
             dweight_ptr,
             size_y * sizeof(T),
200 201 202 203 204 205
             cudaMemcpyDeviceToHost);
  ctx.Wait();
}

// get paddle elementwise_add_grad op results as baseline
template <typename T>
206 207
void GetElementwiseAddOpGrad(const std::vector<T> &dout_vec,
                             const int bsz_seq,
208
                             const int output_size,
L
Leo Chen 已提交
209
                             const phi::GPUContext &ctx,
210 211 212
                             std::vector<T> *dy_vec) {
  framework::Scope scope;
  auto var_x = scope.Var("X");
213
  auto tensor_x = var_x->GetMutable<phi::DenseTensor>();
214
  auto var_y = scope.Var("Y");
215
  auto tensor_y = var_y->GetMutable<phi::DenseTensor>();
216
  auto var_dout = scope.Var("DOut");
217
  auto tensor_dout = var_dout->GetMutable<phi::DenseTensor>();
218 219 220 221 222
  tensor_x->Resize({bsz_seq, output_size});
  tensor_y->Resize({output_size});
  tensor_dout->Resize({bsz_seq, output_size});

  auto var_dx = scope.Var("DX");
223
  auto tensor_dx = var_dx->GetMutable<phi::DenseTensor>();
224
  auto var_dy = scope.Var("DY");
225
  auto tensor_dy = var_dy->GetMutable<phi::DenseTensor>();
226 227 228 229 230 231
  tensor_dx->Resize({bsz_seq, output_size});
  tensor_dy->Resize({output_size});

  auto dout_ptr = tensor_dout->mutable_data<T>(ctx.GetPlace());
  auto tensor_dy_ptr = tensor_dy->mutable_data<T>(ctx.GetPlace());
  auto size_z = static_cast<size_t>(bsz_seq * output_size);
232 233
  cudaMemcpy(
      dout_ptr, dout_vec.data(), size_z * sizeof(T), cudaMemcpyHostToDevice);
234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254

  int axis = -1;
  bool use_mkldnn = false, use_quantizer = false;
  std::string mkldnn_data_type = "float32";
  std::string x_data_format = "", y_data_format = "";
  float Scale_x = 1.0, Scale_y = 1.0, Scale_out = 1.0;

  framework::AttributeMap attrs;
  attrs.insert({"axis", axis});
  attrs.insert({"use_mkldnn", use_mkldnn});
  attrs.insert({"x_data_format", x_data_format});
  attrs.insert({"y_data_format", y_data_format});
  attrs.insert({"use_quantizer", use_quantizer});
  attrs.insert({"mkldnn_data_type", mkldnn_data_type});
  attrs.insert({"Scale_x", Scale_x});
  attrs.insert({"Scale_y", Scale_y});
  attrs.insert({"Scale_out", Scale_out});

  auto op = framework::OpRegistry::CreateOp(
      "elementwise_add_grad",
      {{"Out@GRAD", {"DOut"}}, {"X", {"X"}}, {"Y", {"Y"}}},
255 256
      {{"X@GRAD", {"DX"}}, {"Y@GRAD", {"DY"}}},
      attrs);
257 258 259
  op->Run(scope, ctx.GetPlace());

  auto size_y = static_cast<size_t>(output_size);
260 261 262
  cudaMemcpy(dy_vec->data(),
             tensor_dy_ptr,
             size_y * sizeof(T),
263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278
             cudaMemcpyDeviceToHost);
  ctx.Wait();
}

template <typename T>
class TestFeedForward {
 public:
  TestFeedForward() {
    batch_size_ = 16;
    seq_len_ = 128;
    num_head_ = 16;
    dim_head_ = 64;
    dim_embed_ = 1024;
    has_bias_ = false;
  }

279 280 281 282 283 284
  TestFeedForward(int batch_size,
                  int seq_len,
                  int num_head,
                  int dim_head,
                  int dim_embed,
                  bool has_bias) {
285 286 287 288 289 290 291 292 293 294 295 296 297 298
    batch_size_ = batch_size;
    seq_len_ = seq_len;
    num_head_ = num_head;
    dim_head_ = dim_head;
    dim_embed_ = dim_embed;
    has_bias_ = has_bias;
  }

  ~TestFeedForward() { delete ctx_; }

  void SetUp() {
    bsz_seq_ = batch_size_ * seq_len_;
    output_size_ = 3 * num_head_ * dim_head_;
    input_size_ = dim_embed_;
L
Leo Chen 已提交
299
    ctx_ = new phi::GPUContext(place_);
W
Wilber 已提交
300 301 302 303 304 305 306 307 308 309 310
    ctx_->SetAllocator(paddle::memory::allocation::AllocatorFacade::Instance()
                           .GetAllocator(place_, ctx_->stream())
                           .get());
    ctx_->SetHostAllocator(
        paddle::memory::allocation::AllocatorFacade::Instance()
            .GetAllocator(paddle::platform::CPUPlace())
            .get());
    ctx_->SetZeroAllocator(
        paddle::memory::allocation::AllocatorFacade::Instance()
            .GetZeroAllocator(place_)
            .get());
W
wanghuancoder 已提交
311 312 313 314
    ctx_->SetPinnedAllocator(
        paddle::memory::allocation::AllocatorFacade::Instance()
            .GetAllocator(paddle::platform::CUDAPinnedPlace())
            .get());
W
Wilber 已提交
315
    ctx_->PartialInitWithAllocator();
316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376

    size_src_ = bsz_seq_ * dim_embed_;         // src: [bs, seq_len, em_dim]
    size_weight_ = dim_embed_ * output_size_;  // weight: [output_size, em_dim]
    size_output_ =
        bsz_seq_ * output_size_;  // output: [bs, seq_len, output_size]
    size_bias_ = output_size_;

    base_out_vec_.resize(size_output_);
    base_bias_out_vec_.resize(size_output_);
    base_dinput_vec_.resize(size_src_);
    base_dweight_vec_.resize(size_weight_);
    base_dbias_vec_.resize(size_bias_);

    src_vec_.resize(size_src_);
    weight_vec_.resize(size_weight_);
    bias_vec_.resize(size_bias_);
    doutput_vec_.resize(size_output_);

    std::default_random_engine random(time(NULL));
    std::uniform_real_distribution<float> dis(0.0, 1.0);
    for (int i = 0; i < size_src_; i++) {
      src_vec_[i] = static_cast<T>(dis(random));
    }
    for (int i = 0; i < size_weight_; i++) {
      weight_vec_[i] = static_cast<T>(dis(random));
    }
    for (int i = 0; i < size_bias_; i++) {
      bias_vec_[i] = static_cast<T>(dis(random));
    }
    for (int i = 0; i < size_output_; i++) {
      doutput_vec_[i] = static_cast<T>(dis(random));
    }

    framework::TensorFromVector<T>(src_vec_, *ctx_, &src_);
    src_.Resize({batch_size_, seq_len_, dim_embed_});
    framework::TensorFromVector<T>(weight_vec_, *ctx_, &weight_);
    weight_.Resize({output_size_, dim_embed_});
    out_.Resize({batch_size_, seq_len_, output_size_});
    out_.mutable_data<T>(place_);
    if (has_bias_) {
      framework::TensorFromVector<T>(bias_vec_, *ctx_, &bias_);
      bias_.Resize({output_size_});
      bias_out_.Resize({batch_size_, seq_len_, output_size_});
      bias_out_.mutable_data<T>(place_);
    }
    framework::TensorFromVector<T>(doutput_vec_, *ctx_, &doutput_);
    doutput_.Resize({batch_size_, seq_len_, output_size_});

    dinput_.Resize({batch_size_, seq_len_, dim_embed_});
    dinput_.mutable_data<T>(place_);
    dweight_.Resize({output_size_, dim_embed_});
    dweight_.mutable_data<T>(place_);
    if (has_bias_) {
      dbias_.Resize({output_size_});
      dbias_.mutable_data<T>(place_);
    }
  }

  void BaselineForward() {
    bool transpose_a = false, transpose_b = true;
    float alpha = 1;
377 378 379 380 381 382 383 384 385
    GetLinearOp(src_vec_,
                weight_vec_,
                src_.dims(),
                weight_.dims(),
                *ctx_,
                transpose_a,
                transpose_b,
                alpha,
                &base_out_vec_);
386
    if (has_bias_) {
387 388 389 390 391 392
      GetElementwiseAddOp(base_out_vec_,
                          bias_vec_,
                          bsz_seq_,
                          output_size_,
                          *ctx_,
                          &base_bias_out_vec_);
393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410
    }
    ctx_->Wait();
  }

  // get forward results of feedforward.
  void FusedForward() {
    T *p_weight = weight_.data<T>();
    T *p_src = src_.data<T>();
    T *p_output = out_.data<T>();

    T *p_bias = nullptr;
    T *p_bias_output = nullptr;
    if (has_bias_) {
      p_bias = bias_.data<T>();
      p_bias_output = bias_out_.data<T>();
    }
    auto qkv_compute = paddle::operators::FeedForward<T>(
        *ctx_, bsz_seq_, output_size_, input_size_, has_bias_);
411 412
    qkv_compute.ComputeForward(
        p_weight, p_src, p_bias, p_output, p_bias_output);
413 414 415 416 417 418 419
    ctx_->Wait();
  }

  void BaselineBackward() {
    bool transpose_a = false, transpose_b = true;
    float alpha = 1;

420 421 422 423 424 425 426 427 428 429 430 431
    GetLinearOpGrad(src_vec_,
                    weight_vec_,
                    doutput_vec_,
                    src_.dims(),
                    weight_.dims(),
                    out_.dims(),
                    *ctx_,
                    transpose_a,
                    transpose_b,
                    alpha,
                    &base_dinput_vec_,
                    &base_dweight_vec_);
432
    if (has_bias_) {
433 434
      GetElementwiseAddOpGrad(
          doutput_vec_, bsz_seq_, output_size_, *ctx_, &base_dbias_vec_);
435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452
    }
    ctx_->Wait();
  }

  // get backward results of feedforward.
  void FusedBackward() {
    T *p_weight = weight_.data<T>();
    T *p_src = src_.data<T>();
    T *p_doutput = doutput_.data<T>();
    T *p_dinput = dinput_.data<T>();
    T *p_dweight = dweight_.data<T>();

    T *bias_ptr = nullptr;
    if (has_bias_) {
      bias_ptr = dbias_.data<T>();
    }
    auto qkv_compute = paddle::operators::FeedForward<T>(
        *ctx_, bsz_seq_, output_size_, input_size_, has_bias_);
453 454
    qkv_compute.ComputeBackward(
        p_src, p_weight, p_doutput, p_dinput, p_dweight, bias_ptr);
455 456 457 458 459 460 461 462 463 464 465 466 467 468 469
    ctx_->Wait();
  }

  void Run() {
    SetUp();
    BaselineForward();
    FusedForward();
    BaselineBackward();
    FusedBackward();
  }

  // check forward correctness between baseline and results of feedforward.
  void CheckOut(const T diff, bool is_relative_atol = false) {
    std::vector<T> out(size_output_);
    std::vector<T> bias_out(size_output_);
470
    paddle::framework::TensorToVector(out_, *ctx_, &out);
471
    if (has_bias_) {
472
      paddle::framework::TensorToVector(bias_out_, *ctx_, &bias_out);
473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497
    }
    ctx_->Wait();

    for (int i = 0; i < size_output_; i++) {
      if (is_relative_atol) {
        EXPECT_LT(std::abs((out[i] - base_out_vec_[i]) / base_out_vec_[i]),
                  diff);
      } else {
        EXPECT_LT(std::abs(out[i] - base_out_vec_[i]), diff);
      }
      if (has_bias_) {
        if (is_relative_atol) {
          EXPECT_LT(std::abs((bias_out[i] - base_bias_out_vec_[i]) /
                             base_bias_out_vec_[i]),
                    diff);
        } else {
          EXPECT_LT(std::abs(bias_out[i] - base_bias_out_vec_[i]), diff);
        }
      }
    }
  }

  // check backward correctness between baseline and results of feedforward.
  void CheckGrad(const T diff, bool is_relative_atol = false) {
    std::vector<T> h_dinput(size_src_);
498
    paddle::framework::TensorToVector(dinput_, *ctx_, &h_dinput);
499 500 501 502 503 504 505 506 507 508
    for (int i = 0; i < size_src_; i++) {
      if (is_relative_atol) {
        EXPECT_LT(
            std::abs((h_dinput[i] - base_dinput_vec_[i]) / base_dinput_vec_[i]),
            diff);
      } else {
        EXPECT_LT(std::abs(h_dinput[i] - base_dinput_vec_[i]), diff);
      }
    }
    std::vector<T> h_dweight(size_weight_);
509
    paddle::framework::TensorToVector(dweight_, *ctx_, &h_dweight);
510 511 512 513 514 515 516 517 518 519 520
    for (int i = 0; i < size_weight_; i++) {
      if (is_relative_atol) {
        EXPECT_LT(std::abs((h_dweight[i] - base_dweight_vec_[i]) /
                           base_dweight_vec_[i]),
                  diff);
      } else {
        EXPECT_LT(std::abs(h_dweight[i] - base_dweight_vec_[i]), diff);
      }
    }
    if (has_bias_) {
      std::vector<T> h_dbias(size_bias_);
521
      paddle::framework::TensorToVector(dbias_, *ctx_, &h_dbias);
522 523 524 525 526 527 528 529 530 531 532 533 534 535 536 537 538 539
      for (int i = 0; i < size_bias_; i++) {
        if (is_relative_atol) {
          EXPECT_LT(
              std::abs((h_dbias[i] - base_dbias_vec_[i]) / base_dbias_vec_[i]),
              diff);
        } else {
          EXPECT_LT(std::abs(h_dbias[i] - base_dbias_vec_[i]), diff);
        }
      }
    }
  }

 private:
  int batch_size_, seq_len_, num_head_, dim_head_, dim_embed_;
  int bsz_seq_, output_size_, input_size_;
  bool has_bias_;
  int size_src_, size_weight_, size_bias_, size_output_;

540 541
  phi::DenseTensor src_, weight_, bias_, out_, bias_out_;
  phi::DenseTensor dinput_, dweight_, dbias_, doutput_;
542 543 544 545 546 547 548 549
  std::vector<T> src_vec_, weight_vec_, bias_vec_, out_vec_, bias_out_vec_;
  std::vector<T> dinput_vec_, dweight_vec_, dbias_vec_, doutput_vec_;

  // results of baseline.
  std::vector<T> base_out_vec_, base_bias_out_vec_;
  std::vector<T> base_dinput_vec_, base_dweight_vec_, base_dbias_vec_;

  platform::CUDAPlace place_;
L
Leo Chen 已提交
550
  phi::GPUContext *ctx_;
551 552 553 554 555 556 557 558 559 560
};

// The tests below cover fp32, fp16, fp32 + bias and fp16 + bias.

// fp32 without bias: forward and backward results must match the baseline
// within an absolute error of 1e-5.
TEST(FeedForward, GPUFeedforwardBertLargeSizeFp32) {
  constexpr int kBatchSize = 16;
  constexpr int kSeqLen = 128;
  constexpr int kNumHead = 16;
  constexpr int kDimHead = 64;
  constexpr int kDimEmbed = 1024;
  constexpr bool kHasBias = false;

  TestFeedForward<float> test(
      kBatchSize, kSeqLen, kNumHead, kDimHead, kDimEmbed, kHasBias);
  test.Run();

  const float tolerance = static_cast<float>(1e-5);
  test.CheckOut(tolerance);
  test.CheckGrad(tolerance);
}

// fp16 without bias: forward and backward results must match the baseline
// within an absolute error of 1e-5 (converted to float16).
TEST(FeedForward, GPUFeedforwardBertLargeSizeFp16) {
  using float16 = paddle::platform::float16;

  constexpr int kBatchSize = 16;
  constexpr int kSeqLen = 128;
  constexpr int kNumHead = 16;
  constexpr int kDimHead = 64;
  constexpr int kDimEmbed = 1024;
  constexpr bool kHasBias = false;

  TestFeedForward<float16> test(
      kBatchSize, kSeqLen, kNumHead, kDimHead, kDimEmbed, kHasBias);
  test.Run();

  const float16 tolerance = static_cast<float16>(1e-5);
  test.CheckOut(tolerance);
  test.CheckGrad(tolerance);
}

// fp32 with bias: the forward result is checked with an absolute error of
// 1e-5, the gradients with an absolute error of 1e-3.
TEST(FeedForward, GPUFeedforwardBertLargeSizeFp32Bias) {
  constexpr int kBatchSize = 16;
  constexpr int kSeqLen = 128;
  constexpr int kNumHead = 16;
  constexpr int kDimHead = 64;
  constexpr int kDimEmbed = 1024;
  constexpr bool kHasBias = true;

  TestFeedForward<float> test(
      kBatchSize, kSeqLen, kNumHead, kDimHead, kDimEmbed, kHasBias);
  test.Run();

  test.CheckOut(static_cast<float>(1e-5));
  test.CheckGrad(static_cast<float>(1e-3));
}

// fp16 with bias: the forward result is checked with an absolute error of
// 1e-2, the gradients with a relative error of 1e-2.
TEST(FeedForward, GPUFeedforwardBertLargeSizeFp16Bias) {
  using float16 = paddle::platform::float16;

  constexpr int kBatchSize = 16;
  constexpr int kSeqLen = 128;
  constexpr int kNumHead = 16;
  constexpr int kDimHead = 64;
  constexpr int kDimEmbed = 1024;
  constexpr bool kHasBias = true;

  TestFeedForward<float16> test(
      kBatchSize, kSeqLen, kNumHead, kDimHead, kDimEmbed, kHasBias);
  test.Run();

  const float16 tolerance = static_cast<float16>(1e-2);
  test.CheckOut(tolerance);
  test.CheckGrad(tolerance, /*is_relative_atol=*/true);
}