/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/operators/fused/fused_dropout_helper.h"
#include "paddle/fluid/operators/layer_norm_kernel.cu.h"
#include "paddle/fluid/operators/matmul_v2_op.h"
#include "paddle/phi/api/include/tensor.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/broadcast_function.h"
#include "paddle/phi/kernels/funcs/elementwise_functor.h"

#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#include "paddle/fluid/distributed/collective/ProcessGroupNCCL.h"
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/device/gpu/nccl_helper.h"
#endif

namespace paddle {
namespace operators {

using Tensor = phi::DenseTensor;

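// All-reduce `tensor` in place across the tensor-model-parallel group
// identified by `ring_id` (no-op when ring_id == -1). Uses the ProcessGroup
// path when a group is registered for ring_id; otherwise falls back to the
// raw NCCL communicator.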
template <typename T>
static void AllReduce(phi::DenseTensor& tensor,  // NOLINT
                      const int ring_id,
                      const phi::GPUContext& ctx) {
  if (ring_id == -1) return;
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
  auto map = paddle::distributed::ProcessGroupMapFromGid::getInstance();

  if (map->has(ring_id)) {
    paddle::distributed::ProcessGroup* pg = map->get(ring_id);
    auto pg_nccl = static_cast<distributed::ProcessGroupNCCL*>(pg);
    paddle::distributed::AllreduceOptions opts;
    opts.reduce_op = distributed::ReduceOp::SUM;
    auto task = pg_nccl->AllReduce(&tensor, tensor, opts, true, true);
    task->Wait();
  } else {
    auto dtype = platform::ToNCCLDataType(
        framework::TransToProtoVarType(tensor.dtype()));
    int64_t numel = tensor.numel();
    const void* sendbuff = tensor.data<T>();
    auto place = ctx.GetPlace();
    void* recvbuff = ctx.Alloc<T>(&tensor, tensor.numel() * sizeof(T));
    auto comm = platform::NCCLCommContext::Instance().Get(ring_id, place);
    auto stream = ctx.stream();
    PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce(
        sendbuff, recvbuff, numel, dtype, ncclSum, comm->comm(), stream));
  }
#else
  PADDLE_THROW(platform::errors::Unimplemented(
      "PaddlePaddle should compile with NCCL or RCCL when used tensor model "
      "parallel op."));
#endif
}

template <typename DeviceContext, typename T>
class FusedFeedForwardKernel : public framework::OpKernel<T> {
 public:
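  // Computes c = a * b with the BLAS helper; a and b are viewed as 2-D
  // matrices (leading dimensions folded together) when building the
  // matrix descriptors.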
  void MatMul(const phi::GPUContext& ctx,
              const phi::DenseTensor& a,
              const phi::DenseTensor& b,
              phi::DenseTensor* c) const {
    auto blas = phi::funcs::GetBlas<DeviceContext, T>(ctx);
    auto a_2d = FoldInitDims(a);
    auto b_2d = FoldInitDims(b);
    auto mat_dim_a = phi::funcs::CreateMatrixDescriptor(a_2d.dims(), 0, false);
    auto mat_dim_b = phi::funcs::CreateMatrixDescriptor(b_2d.dims(), 0, false);
    T alpha = static_cast<T>(1.0);
    blas.MatMul(a, mat_dim_a, b, mat_dim_b, alpha, c, T(0));
  }

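  // Forward pass of the fused feed-forward block:
  //   (pre-LayerNorm) -> Linear1 -> Bias + Activation + Dropout1
  //   -> Linear2 -> tensor-parallel AllReduce
  //   -> Bias + Residual + Dropout2 (+ post-LayerNorm when !pre_layer_norm).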
  void FFN(const phi::GPUContext& ctx,
           const phi::DenseTensor& x,
           const phi::DenseTensor& linear1_weight,
           const phi::DenseTensor* linear1_bias,
           const phi::DenseTensor& linear2_weight,
           const phi::DenseTensor* linear2_bias,
           const phi::DenseTensor* ln1_scale,
           const phi::DenseTensor* ln1_bias,
           const phi::DenseTensor* ln2_scale,
           const phi::DenseTensor* ln2_bias,
           phi::DenseTensor* out,
           phi::DenseTensor* dropout1_mask,
           phi::DenseTensor* dropout2_mask,
           phi::DenseTensor* ln1_mean,
           phi::DenseTensor* ln1_variance,
           phi::DenseTensor* ln2_mean,
           phi::DenseTensor* ln2_variance,
           phi::DenseTensor* linear1_out,
           phi::DenseTensor* ln1_out,
           phi::DenseTensor* dropout1_out,
           phi::DenseTensor* dropout2_out,
           const int bsz_seq,
           const int d_model,
           const int dim_feedforward,
           const std::string& act_method,
           const bool pre_layer_norm,
           const float epsilon1,
           const float epsilon2,
           const bool add_residual,
           const int ring_id,
           const DropoutParam& dropout_param1,
           const DropoutParam& dropout_param2) const {
    FusedDropoutLayerNormHelper<T, uint8_t> pre_layernorm_helper(
        bsz_seq, d_model, epsilon1);
    FusedDropoutHelper<T, uint8_t> fused_act_dropout_helper(
        ctx, bsz_seq, dim_feedforward, dropout_param1);
    FusedDropoutLayerNormHelper<T, uint8_t> fused_dropout_layernorm_helper(
        ctx, bsz_seq, d_model, dropout_param2, epsilon2);

    using U = LayerNormParamType<T>;
    const phi::DenseTensor* in = &x;

    const U* ln1_scale_ptr =
        ln1_scale == nullptr ? nullptr : ln1_scale->data<U>();
    const U* ln1_bias_ptr = ln1_bias == nullptr ? nullptr : ln1_bias->data<U>();
    const U* ln2_scale_ptr =
        ln2_scale == nullptr ? nullptr : ln2_scale->data<U>();
    const U* ln2_bias_ptr = ln2_bias == nullptr ? nullptr : ln2_bias->data<U>();
    const T* linear1_bias_ptr =
        linear1_bias == nullptr ? nullptr : linear1_bias->data<T>();
    const T* linear2_bias_ptr =
        linear2_bias == nullptr ? nullptr : linear2_bias->data<T>();

    if (pre_layer_norm) {
      pre_layernorm_helper.LayerNorm(ctx,
                                     x.data<T>(),
                                     ln1_scale_ptr,
                                     ln1_bias_ptr,
                                     ln1_out->data<T>(),
                                     ln1_mean->data<U>(),
                                     ln1_variance->data<U>());
      in = ln1_out;
    }
    MatMul(ctx, *in, linear1_weight, linear1_out);
    fused_act_dropout_helper.DropoutActBias(ctx,
                                            linear1_out->data<T>(),
                                            linear1_bias_ptr,
                                            act_method,
                                            dropout1_out->data<T>(),
                                            dropout1_mask->data<uint8_t>());
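    // Temporary buffer for the second GEMM; its result still needs bias,
    // residual, dropout and (optionally) LayerNorm before writing `out`.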
    phi::DenseTensor linear2_out;
    linear2_out.Resize({bsz_seq, d_model});
    ctx.Alloc<T>(&linear2_out, linear2_out.numel() * sizeof(T));
    MatMul(ctx, *dropout1_out, linear2_weight, &linear2_out);

    // tensor model parallel
    AllReduce<T>(linear2_out, ring_id, ctx);

    const T* residual_ptr = add_residual ? x.data<T>() : nullptr;
    if (!pre_layer_norm) {
      // TODO(Xreki): support post layer_norm case when add_residual is false.
      PADDLE_ENFORCE_EQ(add_residual,
                        true,
                        platform::errors::InvalidArgument(
                            "Attribute add_residual is expected to be true "
                            "when pre_layer_norm is false."));

      fused_dropout_layernorm_helper.LayernormResidualDropoutBias(
          ctx,
          linear2_out.data<T>(),
          residual_ptr,
          linear2_bias_ptr,
          ln2_scale_ptr,
          ln2_bias_ptr,
          dropout2_out->data<T>(),
          dropout2_mask->data<uint8_t>(),
          out->data<T>(),
          ln2_mean->data<U>(),
          ln2_variance->data<U>());
    } else {
      fused_dropout_layernorm_helper.ResidualDropoutBias(
          ctx,
          linear2_out.data<T>(),
          residual_ptr,
          linear2_bias_ptr,
          out->data<T>(),
          dropout2_mask->data<uint8_t>());
    }
  }

  void Compute(const framework::ExecutionContext& context) const override {
    auto* x = context.Input<phi::DenseTensor>("X");
    auto* linear1_weight = context.Input<phi::DenseTensor>("Linear1Weight");
    auto* linear1_bias = context.Input<phi::DenseTensor>("Linear1Bias");
    auto* linear2_weight = context.Input<phi::DenseTensor>("Linear2Weight");
    auto* linear2_bias = context.Input<phi::DenseTensor>("Linear2Bias");
    const bool pre_layer_norm = context.Attr<bool>("pre_layer_norm");
    auto& dev_ctx = context.template device_context<phi::GPUContext>();

    auto* ln1_scale =
        pre_layer_norm ? context.Input<phi::DenseTensor>("Ln1Scale") : nullptr;
    auto* ln1_bias =
        pre_layer_norm ? context.Input<phi::DenseTensor>("Ln1Bias") : nullptr;
    auto* ln2_scale =
        !pre_layer_norm ? context.Input<phi::DenseTensor>("Ln2Scale") : nullptr;
    auto* ln2_bias =
        !pre_layer_norm ? context.Input<phi::DenseTensor>("Ln2Bias") : nullptr;

    auto* ln1_mean =
        pre_layer_norm ? context.Output<phi::DenseTensor>("Ln1Mean") : nullptr;
    auto* ln1_variance = pre_layer_norm
                             ? context.Output<phi::DenseTensor>("Ln1Variance")
                             : nullptr;
    auto* ln2_mean =
        !pre_layer_norm ? context.Output<phi::DenseTensor>("Ln2Mean") : nullptr;
    auto* ln2_variance = !pre_layer_norm
                             ? context.Output<phi::DenseTensor>("Ln2Variance")
                             : nullptr;
    auto* out = context.Output<phi::DenseTensor>("Out");
    auto* dropout1_mask = context.Output<phi::DenseTensor>("Dropout1Mask");
    auto* dropout2_mask = context.Output<phi::DenseTensor>("Dropout2Mask");
    auto* linear1_out = context.Output<phi::DenseTensor>("Linear1Out");
    auto* ln1_out =
        pre_layer_norm ? context.Output<phi::DenseTensor>("Ln1Out") : nullptr;
    auto* dropout1_out = context.Output<phi::DenseTensor>("Dropout1Out");
    auto* dropout2_out = context.Output<phi::DenseTensor>("Dropout2Out");

    const std::string act_method = context.Attr<std::string>("act_method");

    const float epsilon1 = context.Attr<float>("ln1_epsilon");
    const float epsilon2 = context.Attr<float>("ln2_epsilon");
    const int ring_id = context.Attr<int>("ring_id");
    const bool add_residual = context.Attr<bool>("add_residual");

    DropoutParam dropout_param1(context, 1);
    DropoutParam dropout_param2(context, 2);

    using U = LayerNormParamType<T>;
    dev_ctx.Alloc<T>(out, out->numel() * sizeof(T));
    dev_ctx.Alloc<uint8_t>(dropout1_mask,
                           dropout1_mask->numel() * sizeof(uint8_t));
    dev_ctx.Alloc<uint8_t>(dropout2_mask,
                           dropout2_mask->numel() * sizeof(uint8_t));
    if (pre_layer_norm) {
      dev_ctx.Alloc<U>(ln1_mean, ln1_mean->numel() * sizeof(U));
      dev_ctx.Alloc<U>(ln1_variance, ln1_variance->numel() * sizeof(U));
      dev_ctx.Alloc<T>(ln1_out, ln1_out->numel() * sizeof(T));
    } else {
      dev_ctx.Alloc<U>(ln2_mean, ln2_mean->numel() * sizeof(U));
      dev_ctx.Alloc<U>(ln2_variance, ln2_variance->numel() * sizeof(U));
    }

    dev_ctx.Alloc<T>(linear1_out, linear1_out->numel() * sizeof(T));
    dev_ctx.Alloc<T>(dropout1_out, dropout1_out->numel() * sizeof(T));
    dev_ctx.Alloc<T>(dropout2_out, dropout2_out->numel() * sizeof(T));

    auto x_dim = x->dims();
    auto mat_dim_x = phi::funcs::CreateMatrixDescriptor(
        RowMatrixFromVector(x_dim), 0, false);

    auto dim = linear1_weight->dims();
    int d_model = dim[0];
    int dim_feedforward = dim[dim.size() - 1];
    int bsz_seq = mat_dim_x.batch_size_ * mat_dim_x.height_;

    FFN(context.cuda_device_context(),
        *x,
        *linear1_weight,
        linear1_bias,
        *linear2_weight,
        linear2_bias,
        ln1_scale,
        ln1_bias,
        ln2_scale,
        ln2_bias,
        out,
        dropout1_mask,
        dropout2_mask,
        ln1_mean,
        ln1_variance,
        ln2_mean,
        ln2_variance,
        linear1_out,
        ln1_out,
        dropout1_out,
        dropout2_out,
        bsz_seq,
        d_model,
        dim_feedforward,
        act_method,
        pre_layer_norm,
        epsilon1,
        epsilon2,
        add_residual,
        ring_id,
        dropout_param1,
        dropout_param2);
  }
};

template <typename DeviceContext, typename T>
class FusedFeedForwardGradKernel : public framework::OpKernel<T> {
 public:
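  // Given d_out for c = a * b, computes d_a = d_out * b^T and
  // d_b = a^T * d_out.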
  void MatMulGrad(const phi::GPUContext& ctx,
                  const phi::DenseTensor& d_out,
                  const phi::DenseTensor& a,
                  const phi::DenseTensor& b,
                  phi::DenseTensor* d_a,
                  phi::DenseTensor* d_b) const {
    auto blas = phi::funcs::GetBlas<DeviceContext, T>(ctx);
    auto a_2d = FoldInitDims(a);
    auto b_2d = FoldInitDims(b);
    auto mat_dim_a = phi::funcs::CreateMatrixDescriptor(a_2d.dims(), 0, true);
    auto mat_dim_b = phi::funcs::CreateMatrixDescriptor(b_2d.dims(), 0, true);
    auto mat_dim_dout =
        phi::funcs::CreateMatrixDescriptor(d_out.dims(), 0, false);
    T alpha = static_cast<T>(1.0);
    blas.MatMul(d_out, mat_dim_dout, b, mat_dim_b, alpha, d_a, T(0));
    blas.MatMul(a, mat_dim_a, d_out, mat_dim_dout, alpha, d_b, T(0));
  }

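  // Backward pass of the fused feed-forward block; mirrors FFN in reverse:
  //   dropout2/residual(/LayerNorm) grad -> Linear2 grad
  //   -> activation/dropout1 grad -> Linear1 grad -> (pre-)LayerNorm grad,
  //   with a tensor-parallel AllReduce on the input gradient and an
  //   elementwise add of the residual gradient into d_x.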
  void FFNGrad(const phi::GPUContext& ctx,
               const phi::DenseTensor& d_out,
               const phi::DenseTensor& x,
               const phi::DenseTensor& dropout1_mask,
               const phi::DenseTensor& dropout2_mask,
               const phi::DenseTensor& linear1_out,
               const phi::DenseTensor* ln1_out,
               const phi::DenseTensor& dropout1_out,
               const phi::DenseTensor* dropout2_out,
               const phi::DenseTensor& linear1_weight,
               const phi::DenseTensor* linear1_bias,
               const phi::DenseTensor& linear2_weight,
               const phi::DenseTensor* ln1_gamma,
               const phi::DenseTensor* ln1_beta,
               const phi::DenseTensor* ln1_mean,
               const phi::DenseTensor* ln1_variance,
               const phi::DenseTensor* ln2_gamma,
               const phi::DenseTensor* ln2_beta,
               const phi::DenseTensor* ln2_mean,
               const phi::DenseTensor* ln2_variance,
               phi::DenseTensor* d_x,
               phi::DenseTensor* d_linear1_weight,
               phi::DenseTensor* d_linear1_bias,
               phi::DenseTensor* d_linear2_weight,
               phi::DenseTensor* d_linear2_bias,
               phi::DenseTensor* d_ln1_gamma,
               phi::DenseTensor* d_ln1_beta,
               phi::DenseTensor* d_ln2_gamma,
               phi::DenseTensor* d_ln2_beta,
               const int bsz_seq,
               const int d_model,
               const int dim_feedforward,
               const DropoutParam& dropout_param1,
               const DropoutParam& dropout_param2,
               const std::string& act_method,
               const bool pre_layer_norm,
               const float epsilon1,
               const float epsilon2,
               const bool add_residual,
               const int ring_id) const {
    FusedDropoutLayerNormHelper<T, uint8_t> pre_layernorm_helper(
        bsz_seq, d_model, epsilon1);
    FusedDropoutHelper<T, uint8_t> fused_act_dropout_helper(
        ctx, bsz_seq, dim_feedforward, dropout_param1);
    FusedDropoutLayerNormHelper<T, uint8_t> fused_dropout_layernorm_helper(
        ctx, bsz_seq, d_model, dropout_param2, epsilon2);

    using U = LayerNormParamType<T>;
    const U* ln1_gamma_ptr =
        ln1_gamma == nullptr ? nullptr : ln1_gamma->data<U>();
    const U* ln1_beta_ptr = ln1_beta == nullptr ? nullptr : ln1_beta->data<U>();
    const U* ln2_gamma_ptr =
        ln2_gamma == nullptr ? nullptr : ln2_gamma->data<U>();
    const U* ln2_beta_ptr = ln2_beta == nullptr ? nullptr : ln2_beta->data<U>();
    const T* linear1_bias_ptr =
        linear1_bias == nullptr ? nullptr : linear1_bias->data<T>();
    T* d_linear1_bias_ptr =
        d_linear1_bias == nullptr ? nullptr : d_linear1_bias->data<T>();
    T* d_linear2_bias_ptr =
        d_linear2_bias == nullptr ? nullptr : d_linear2_bias->data<T>();
    U* d_ln1_gamma_ptr =
        d_ln1_gamma == nullptr ? nullptr : d_ln1_gamma->data<U>();
    U* d_ln1_beta_ptr = d_ln1_beta == nullptr ? nullptr : d_ln1_beta->data<U>();
    U* d_ln2_gamma_ptr =
        d_ln2_gamma == nullptr ? nullptr : d_ln2_gamma->data<U>();
    U* d_ln2_beta_ptr = d_ln2_beta == nullptr ? nullptr : d_ln2_beta->data<U>();

    phi::DenseTensor d_linear2_out, d_dropout2_out, d_residual;
    d_linear2_out.Resize({bsz_seq, d_model});
    ctx.Alloc<T>(&d_linear2_out, d_linear2_out.numel() * sizeof(T));
    d_dropout2_out.Resize({bsz_seq, d_model});
    ctx.Alloc<T>(&d_dropout2_out, d_dropout2_out.numel() * sizeof(T));

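    // Gradient flowing through the residual branch; accumulated into d_x at
    // the end of FFNGrad when add_residual is true.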
    T* d_residual_ptr = nullptr;
    if (add_residual) {
      d_residual.Resize(d_x->dims());
      d_residual_ptr =
          ctx.Alloc<T>(&d_residual, d_residual.numel() * sizeof(T));
    }
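    // With pre_layer_norm the block ends with residual + dropout only, so
    // only ResidualDropoutBiasGrad is needed; otherwise the trailing
    // LayerNorm must be back-propagated as well.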
    if (pre_layer_norm) {
      fused_dropout_layernorm_helper.ResidualDropoutBiasGrad(
          ctx,
          d_out.data<T>(),
          dropout2_mask.data<uint8_t>(),
          d_linear2_out.data<T>(),
          d_residual_ptr,
          d_linear2_bias_ptr);
    } else {
      fused_dropout_layernorm_helper.LayernormResidualDropoutBiasGrad(
          ctx,
          d_out.data<T>(),
          dropout2_out->data<T>(),
          dropout2_mask.data<uint8_t>(),
          ln2_gamma_ptr,
          ln2_mean->data<U>(),
          ln2_variance->data<U>(),
          d_dropout2_out.data<T>(),
          d_ln2_gamma_ptr,
          d_ln2_beta_ptr,
          d_linear2_out.data<T>(),
          d_linear2_bias_ptr,
          d_residual_ptr);
    }

    phi::DenseTensor d_dropout1_out;
    d_dropout1_out.Resize({bsz_seq, dim_feedforward});
    ctx.Alloc<T>(&d_dropout1_out, d_dropout1_out.numel() * sizeof(T));
    MatMulGrad(ctx,
               d_linear2_out,
               dropout1_out,
               linear2_weight,
               &d_dropout1_out,
               d_linear2_weight);

    phi::DenseTensor d_linear1_out;
    d_linear1_out.Resize({bsz_seq, dim_feedforward});
    ctx.Alloc<T>(&d_linear1_out, d_linear1_out.numel() * sizeof(T));
    fused_act_dropout_helper.DropoutActBiasGrad(ctx,
                                                d_dropout1_out.data<T>(),
                                                linear1_out.data<T>(),
                                                linear1_bias_ptr,
                                                dropout1_mask.data<uint8_t>(),
                                                d_linear1_out.data<T>(),
                                                d_linear1_bias_ptr,
                                                act_method);

    if (pre_layer_norm) {
      phi::DenseTensor d_ln1_out;
      d_ln1_out.Resize({bsz_seq, d_model});
      ctx.Alloc<T>(&d_ln1_out, d_ln1_out.numel() * sizeof(T));
      MatMulGrad(ctx,
                 d_linear1_out,
                 *ln1_out,
                 linear1_weight,
                 &d_ln1_out,
                 d_linear1_weight);
      // tensor model parallel
      AllReduce<T>(d_ln1_out, ring_id, ctx);
      pre_layernorm_helper.LayerNormGrad(ctx,
                                         d_ln1_out.data<T>(),
                                         x.data<T>(),
                                         ln1_gamma_ptr,
                                         ln1_mean->data<U>(),
                                         ln1_variance->data<U>(),
                                         d_x->data<T>(),
                                         d_ln1_gamma_ptr,
                                         d_ln1_beta_ptr);
    } else {
      MatMulGrad(ctx, d_linear1_out, x, linear1_weight, d_x, d_linear1_weight);
      // tensor model parallel
      AllReduce<T>(*d_x, ring_id, ctx);
    }

    if (add_residual) {
      // gradient accumulation
      std::vector<const phi::DenseTensor*> ins = {&d_residual, d_x};
      std::vector<phi::DenseTensor*> outs = {d_x};
      phi::funcs::ElementwiseKernel<T>(
          ctx, ins, &outs, phi::funcs::AddFunctor<T>());
    }
  }

  void Compute(const framework::ExecutionContext& context) const override {
    using U = LayerNormParamType<T>;
    auto& dev_ctx = context.template device_context<phi::GPUContext>();
    auto d_out =
        *context.Input<phi::DenseTensor>(framework::GradVarName("Out"));
    auto x = *context.Input<phi::DenseTensor>("X");
    const bool pre_layer_norm = context.Attr<bool>("pre_layer_norm");
    auto dropout1_mask = *context.Input<phi::DenseTensor>("Dropout1Mask");
    auto dropout2_mask = *context.Input<phi::DenseTensor>("Dropout2Mask");
    auto linear1_out = *context.Input<phi::DenseTensor>("Linear1Out");
    auto* ln1_out =
        pre_layer_norm ? context.Input<phi::DenseTensor>("Ln1Out") : nullptr;
    auto dropout1_out = *context.Input<phi::DenseTensor>("Dropout1Out");
    auto* dropout2_out = context.Input<phi::DenseTensor>("Dropout2Out");
    auto linear1_weight = *context.Input<phi::DenseTensor>("Linear1Weight");
    auto* linear1_bias = context.Input<phi::DenseTensor>("Linear1Bias");
    auto linear2_weight = *context.Input<phi::DenseTensor>("Linear2Weight");
    auto* ln1_mean =
        pre_layer_norm ? context.Input<phi::DenseTensor>("Ln1Mean") : nullptr;
    auto* ln1_variance = pre_layer_norm
                             ? context.Input<phi::DenseTensor>("Ln1Variance")
                             : nullptr;
    auto* ln1_scale =
        pre_layer_norm ? context.Input<phi::DenseTensor>("Ln1Scale") : nullptr;
    auto* ln1_bias =
        pre_layer_norm ? context.Input<phi::DenseTensor>("Ln1Bias") : nullptr;
    auto* ln2_mean =
        !pre_layer_norm ? context.Input<phi::DenseTensor>("Ln2Mean") : nullptr;
    auto* ln2_variance = !pre_layer_norm
                             ? context.Input<phi::DenseTensor>("Ln2Variance")
                             : nullptr;
    auto* ln2_scale =
        !pre_layer_norm ? context.Input<phi::DenseTensor>("Ln2Scale") : nullptr;
    auto* ln2_bias =
        !pre_layer_norm ? context.Input<phi::DenseTensor>("Ln2Bias") : nullptr;

    auto* d_x = context.Output<phi::DenseTensor>(framework::GradVarName("X"));
    auto* d_ln1_scale = pre_layer_norm ? context.Output<phi::DenseTensor>(
                                             framework::GradVarName("Ln1Scale"))
                                       : nullptr;
    auto* d_ln1_bias = pre_layer_norm ? context.Output<phi::DenseTensor>(
                                            framework::GradVarName("Ln1Bias"))
                                      : nullptr;
    auto* d_ln2_scale = pre_layer_norm
                            ? nullptr
                            : context.Output<phi::DenseTensor>(
                                  framework::GradVarName("Ln2Scale"));
    auto* d_ln2_bias = pre_layer_norm ? nullptr
                                      : context.Output<phi::DenseTensor>(
                                            framework::GradVarName("Ln2Bias"));
    auto* d_linear1_weight = context.Output<phi::DenseTensor>(
        framework::GradVarName("Linear1Weight"));
    auto* d_linear1_bias =
        context.Output<phi::DenseTensor>(framework::GradVarName("Linear1Bias"));
    auto* d_linear2_weight = context.Output<phi::DenseTensor>(
        framework::GradVarName("Linear2Weight"));
    auto* d_linear2_bias =
        context.Output<phi::DenseTensor>(framework::GradVarName("Linear2Bias"));

    const float epsilon1 = context.Attr<float>("ln1_epsilon");
    const float epsilon2 = context.Attr<float>("ln2_epsilon");
    const bool add_residual = context.Attr<bool>("add_residual");
    const int ring_id = context.Attr<int>("ring_id");
    const std::string act_method = context.Attr<std::string>("act_method");
    DropoutParam dropout_param1(context, 1);
    DropoutParam dropout_param2(context, 2);

    dev_ctx.Alloc<T>(d_x, d_x->numel() * sizeof(T));
    if (d_ln1_scale) {
      dev_ctx.Alloc<U>(d_ln1_scale, d_ln1_scale->numel() * sizeof(U));
    }
    if (d_ln1_bias) {
      dev_ctx.Alloc<U>(d_ln1_bias, d_ln1_bias->numel() * sizeof(U));
    }
    if (d_ln2_scale) {
      dev_ctx.Alloc<U>(d_ln2_scale, d_ln2_scale->numel() * sizeof(U));
    }
    if (d_ln2_bias) {
      dev_ctx.Alloc<U>(d_ln2_bias, d_ln2_bias->numel() * sizeof(U));
    }
    if (d_linear1_bias) {
      dev_ctx.Alloc<T>(d_linear1_bias, d_linear1_bias->numel() * sizeof(T));
    }
    if (d_linear2_bias) {
      dev_ctx.Alloc<T>(d_linear2_bias, d_linear2_bias->numel() * sizeof(T));
    }
    dev_ctx.Alloc<T>(d_linear1_weight, d_linear1_weight->numel() * sizeof(T));
    dev_ctx.Alloc<T>(d_linear2_weight, d_linear2_weight->numel() * sizeof(T));

    auto x_dim = x.dims();
    auto mat_dim_x = phi::funcs::CreateMatrixDescriptor(
        RowMatrixFromVector(x_dim), 0, false);

    auto linear1_weight_dim = linear1_weight.dims();
    int d_model = linear1_weight_dim[0];
    int dim_feedforward = linear1_weight_dim[linear1_weight_dim.size() - 1];
    int bsz_seq = mat_dim_x.batch_size_ * mat_dim_x.height_;

    FFNGrad(context.cuda_device_context(),
            d_out,
            x,
            dropout1_mask,
            dropout2_mask,
            linear1_out,
            ln1_out,
            dropout1_out,
            dropout2_out,
            linear1_weight,
            linear1_bias,
            linear2_weight,
            ln1_scale,
            ln1_bias,
            ln1_mean,
            ln1_variance,
            ln2_scale,
            ln2_bias,
            ln2_mean,
            ln2_variance,
            d_x,
            d_linear1_weight,
            d_linear1_bias,
            d_linear2_weight,
            d_linear2_bias,
            d_ln1_scale,
            d_ln1_bias,
            d_ln2_scale,
            d_ln2_bias,
            bsz_seq,
            d_model,
            dim_feedforward,
            dropout_param1,
            dropout_param2,
            act_method,
            pre_layer_norm,
            epsilon1,
            epsilon2,
            add_residual,
            ring_id);
  }
};
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
    fused_feedforward,
    ops::FusedFeedForwardKernel<phi::GPUContext, float>,
    ops::FusedFeedForwardKernel<phi::GPUContext, double>,
    ops::FusedFeedForwardKernel<phi::GPUContext, paddle::platform::float16>);
REGISTER_OP_CUDA_KERNEL(
    fused_feedforward_grad,
    ops::FusedFeedForwardGradKernel<phi::GPUContext, float>,
    ops::FusedFeedForwardGradKernel<phi::GPUContext, double>,
    ops::FusedFeedForwardGradKernel<phi::GPUContext,
                                    paddle::platform::float16>);