// fused_gemm_epilogue_op.cu
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Copyright (c) 2022 NVIDIA Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/fused/fused_gemm_epilogue_op.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/framework/scope_guard.h"
#include "paddle/fluid/platform/dynload/cublasLt.h"
#include "paddle/fluid/platform/float16.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

template <typename DeviceContext, typename T>
class FusedGemmEpilogueKernel : public framework::OpKernel<T> {
 public:
  // Computes Out = act(op(X) * op(Y) + Bias) with a single cuBLASLt matmul
  // whose epilogue adds Bias and applies the activation selected by the
  // "activation" attribute ("none", "relu" or "gelu").
  //
  // X is flattened to a 2-D matrix of shape [M, K] ([K, M] when trans_x);
  // Y is [K, N] ([N, K] when trans_y); Out holds M * N elements.
  // When the optional "ReserveSpace" output is present and activation is not
  // "none", the epilogue also writes an auxiliary buffer into it (a bit mask
  // for "relu", one T per output element otherwise).
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto& dev_ctx = ctx.template device_context<phi::GPUContext>();

    const Tensor* x = ctx.Input<Tensor>("X");
    const Tensor* y = ctx.Input<Tensor>("Y");
    const Tensor* bias = ctx.Input<Tensor>("Bias");

    Tensor* out = ctx.Output<Tensor>("Out");
    Tensor* reserve_space = ctx.Output<Tensor>("ReserveSpace");

    bool trans_x = ctx.Attr<bool>("trans_x");
    bool trans_y = ctx.Attr<bool>("trans_y");

    std::string activation = ctx.Attr<std::string>("activation");
    VLOG(10) << "trans_x = " << trans_x << " , trans_y = " << trans_y
             << " , activation = " << activation;

    // The auxiliary buffer is only produced when ReserveSpace is requested.
    bool enable_auxiliary = reserve_space != nullptr;

    // dev_ctx is a reference, so Alloc is reached with '.', not '->'.
    dev_ctx.Alloc<T>(out, out->numel() * sizeof(T));
    auto* out_data = out->data<T>();

    auto x_mat_dims =
        phi::flatten_to_2d(x->dims(), trans_x ? 1 : x->dims().size() - 1);

    // (M * K) * (K * N)
    int64_t M = trans_x ? x_mat_dims[1] : x_mat_dims[0];
    int64_t K = trans_y ? y->dims()[1] : y->dims()[0];
    int64_t N = trans_y ? y->dims()[0] : y->dims()[1];

    // float and float16 use fp32 scale/compute types; double uses fp64.
    cudaDataType_t mat_type = CUDA_R_32F;
    cudaDataType_t scale_type = CUDA_R_32F;
    cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F;
    if (std::is_same<T, paddle::platform::float16>::value) {
      mat_type = CUDA_R_16F;
    }
    if (std::is_same<T, double>::value) {
      mat_type = CUDA_R_64F;
      scale_type = CUDA_R_64F;
      compute_type = CUBLAS_COMPUTE_64F;
    }

    cublasLtMatmulDesc_t operation_desc = nullptr;
    cublasLtMatrixLayout_t x_desc = nullptr, y_desc = nullptr,
                           out_desc = nullptr;
    // Destroy whichever descriptors were created when this scope exits, so
    // they are not leaked if a later call throws.
    DEFINE_PADDLE_SCOPE_GUARD([&] {
      if (operation_desc) {
        PADDLE_ENFORCE_GPU_SUCCESS(
            platform::dynload::cublasLtMatmulDescDestroy(operation_desc));
      }
      for (auto desc : {y_desc, x_desc, out_desc}) {
        if (desc) {
          PADDLE_ENFORCE_GPU_SUCCESS(
              platform::dynload::cublasLtMatrixLayoutDestroy(desc));
        }
      }
    });

    PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatmulDescCreate(
        &operation_desc, compute_type, scale_type));
    // cuBLASLt is column-major: the row-major product X * Y is issued as the
    // column-major product Y^T * X^T, so Y is matmul operand A and X is
    // operand B. TRANSA therefore follows trans_y and TRANSB follows trans_x.
    cublasOperation_t transx = trans_x ? CUBLAS_OP_T : CUBLAS_OP_N;
    cublasOperation_t transy = trans_y ? CUBLAS_OP_T : CUBLAS_OP_N;
    PADDLE_ENFORCE_GPU_SUCCESS(
        platform::dynload::cublasLtMatmulDescSetAttribute(
            operation_desc,
            CUBLASLT_MATMUL_DESC_TRANSB,
            &transx,
            sizeof(transx)));
    PADDLE_ENFORCE_GPU_SUCCESS(
        platform::dynload::cublasLtMatmulDescSetAttribute(
            operation_desc,
            CUBLASLT_MATMUL_DESC_TRANSA,
            &transy,
            sizeof(transy)));

    cublasLtEpilogue_t epilogue_func =
        get_epilogue_type_(activation, enable_auxiliary);
    PADDLE_ENFORCE_GPU_SUCCESS(
        platform::dynload::cublasLtMatmulDescSetAttribute(
            operation_desc,
            CUBLASLT_MATMUL_DESC_EPILOGUE,
            &epilogue_func,
            sizeof(epilogue_func)));
    const T* bias_data = bias->data<T>();
    PADDLE_ENFORCE_GPU_SUCCESS(
        platform::dynload::cublasLtMatmulDescSetAttribute(
            operation_desc,
            CUBLASLT_MATMUL_DESC_BIAS_POINTER,
            &bias_data,
            sizeof(bias_data)));

    if (enable_auxiliary && activation != "none") {
      size_t reserve_space_size = 0;
      if (activation == "relu") {
        // ReLU saves one bit per output element; round up so that a partial
        // trailing byte is still allocated.
        reserve_space_size = (phi::product(out->dims()) + 7) / 8;
      } else {
        reserve_space_size = phi::product(out->dims()) * sizeof(T);
      }
      dev_ctx.Alloc(reserve_space, out->type(), reserve_space_size);
      void* aux_data = reinterpret_cast<void*>(reserve_space->data<T>());

      PADDLE_ENFORCE_GPU_SUCCESS(
          platform::dynload::cublasLtMatmulDescSetAttribute(
              operation_desc,
              CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER,
              &aux_data,
              sizeof(aux_data)));
      // The auxiliary buffer uses the same leading dimension (N) as Out.
      int64_t aux_ld = N;
      PADDLE_ENFORCE_GPU_SUCCESS(
          platform::dynload::cublasLtMatmulDescSetAttribute(
              operation_desc,
              CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD,
              &aux_ld,
              sizeof(aux_ld)));
    }

    // Layouts are given as (rows, cols, leading dimension) in cuBLASLt's
    // column-major view of the row-major buffers.
    if (trans_x) {
      PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate(
          &x_desc, mat_type, M, K, M));
    } else {
      PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate(
          &x_desc, mat_type, K, M, K));
    }
    if (trans_y) {
      PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate(
          &y_desc, mat_type, K, N, K));
    } else {
      PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate(
          &y_desc, mat_type, N, K, N));
    }
    PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate(
        &out_desc, mat_type, N, M, N));

    cublasLtHandle_t lt_handle = dev_ctx.cublaslt_handle();
    // NOTE(zengjinle): I do not know whether the 4MB workspace size is
    // "enough". I just followed the settings from the NVIDIA MLPerf BERT code.
    size_t workspace_size = static_cast<size_t>(4) * 1024 * 1024;
    cudaStream_t stream = dev_ctx.stream();
    memory::allocation::AllocationPtr workspace = memory::Alloc(
        dev_ctx.GetPlace(),
        workspace_size,
        phi::Stream(reinterpret_cast<phi::StreamId>(dev_ctx.stream())));

    // alpha/beta must match scale_type: double for fp64, float otherwise.
    double alpha64 = 1.0, beta64 = 0.0;
    float alpha32 = 1.0f, beta32 = 0.0f;
    void *alpha = nullptr, *beta = nullptr;
    if (std::is_same<T, double>::value) {
      alpha = &alpha64;
      beta = &beta64;
    } else {
      alpha = &alpha32;
      beta = &beta32;
    }

    const auto* y_data = y->data<T>();
    const auto* x_data = x->data<T>();

    auto algo = GemmEpilogueAlgoCache::Instance().GetGemmAlgo(lt_handle,
                                                              operation_desc,
                                                              y_desc,
                                                              x_desc,
                                                              out_desc,
                                                              alpha,
                                                              beta,
                                                              y_data,
                                                              x_data,
                                                              out_data,
                                                              stream,
                                                              workspace->ptr(),
                                                              workspace_size);

    PADDLE_ENFORCE_GPU_SUCCESS(
        platform::dynload::cublasLtMatmul(lt_handle,
                                          operation_desc,
                                          alpha,
                                          y_data,
                                          y_desc,
                                          x_data,
                                          x_desc,
                                          beta,
                                          out_data,
                                          out_desc,
                                          out_data,
                                          out_desc,
                                          algo,
                                          workspace->ptr(),
                                          workspace_size,
                                          stream));
    // The descriptors are released by the scope guard above.
  }

 private:
  // Maps the "activation" attribute to a cuBLASLt bias epilogue; the *_AUX_*
  // variants are chosen when the auxiliary buffer is requested.
  // Throws InvalidArgument for anything other than "none", "relu", "gelu".
  static cublasLtEpilogue_t get_epilogue_type_(const std::string& activation,
                                               bool enable_auxiliary) {
    if (activation == "relu") {
      return enable_auxiliary ? CUBLASLT_EPILOGUE_RELU_AUX_BIAS
                              : CUBLASLT_EPILOGUE_RELU_BIAS;
    }
    if (activation == "gelu") {
      return enable_auxiliary ? CUBLASLT_EPILOGUE_GELU_AUX_BIAS
                              : CUBLASLT_EPILOGUE_GELU_BIAS;
    }
    // Any remaining value must be "none". Checking it here instead of in a
    // trailing else guarantees every control path returns a value.
    PADDLE_ENFORCE_EQ(
        activation == "none",
        true,
        platform::errors::InvalidArgument(
            "The activation attribute of fused_gemm_epilogue op should be"
            " one of {\"none\", \"relu\", \"gelu\"}. But received "
            "activation=%s.",
            activation));
    return CUBLASLT_EPILOGUE_BIAS;
  }
};

// Selects which tensor feeds a GEMM operand in the backward pass. Despite the
// "d" prefix, kDX and kDY pick the forward inputs X and Y, while kDZ picks the
// upstream gradient DOut (see how FusedGemmEpilogueGradKernel maps them).
enum FusedGEMMGradInType {
  kDX = 0,
  kDY = 1,
  kDZ = 2,
};

// Per (trans_x, trans_y) combination, describes each gradient as a row-major
// product op(A) * op(B):
//   kXGradA / kXGradB           -> operands of the DX product,
//   kXGradATrans / kXGradBTrans -> whether each of those is transposed,
//   kYGrad*                     -> the same for the DY product.
template <bool TransX, bool TransY>
struct FusedGEMMGradTrait;

// Z = X * Y
template <>
struct FusedGEMMGradTrait<false, false> {
  // DX = DZ * Y^T
  static constexpr FusedGEMMGradInType kXGradA = FusedGEMMGradInType::kDZ;
  static constexpr bool kXGradATrans = false;
  static constexpr FusedGEMMGradInType kXGradB = FusedGEMMGradInType::kDY;
  static constexpr bool kXGradBTrans = true;

  // DY = X^T * DZ
  static constexpr FusedGEMMGradInType kYGradA = FusedGEMMGradInType::kDX;
  static constexpr bool kYGradATrans = true;
  static constexpr FusedGEMMGradInType kYGradB = FusedGEMMGradInType::kDZ;
  static constexpr bool kYGradBTrans = false;
};

// Z = X^T * Y
template <>
struct FusedGEMMGradTrait<true, false> {
  // DX = Y * DZ^T
  static constexpr FusedGEMMGradInType kXGradA = FusedGEMMGradInType::kDY;
  static constexpr bool kXGradATrans = false;
  static constexpr FusedGEMMGradInType kXGradB = FusedGEMMGradInType::kDZ;
  static constexpr bool kXGradBTrans = true;

  // DY = X * DZ
  static constexpr FusedGEMMGradInType kYGradA = FusedGEMMGradInType::kDX;
  static constexpr bool kYGradATrans = false;
  static constexpr FusedGEMMGradInType kYGradB = FusedGEMMGradInType::kDZ;
  static constexpr bool kYGradBTrans = false;
};

// Z = X * Y^T
template <>
struct FusedGEMMGradTrait<false, true> {
  // DX = DZ * Y
  static constexpr FusedGEMMGradInType kXGradA = FusedGEMMGradInType::kDZ;
  static constexpr bool kXGradATrans = false;
  static constexpr FusedGEMMGradInType kXGradB = FusedGEMMGradInType::kDY;
  static constexpr bool kXGradBTrans = false;

  // DY = DZ^T * X
  static constexpr FusedGEMMGradInType kYGradA = FusedGEMMGradInType::kDZ;
  static constexpr bool kYGradATrans = true;
  static constexpr FusedGEMMGradInType kYGradB = FusedGEMMGradInType::kDX;
  static constexpr bool kYGradBTrans = false;
};

// Z = X^T * Y^T
template <>
struct FusedGEMMGradTrait<true, true> {
  // DX = Y^T * DZ^T
  static constexpr FusedGEMMGradInType kXGradA = FusedGEMMGradInType::kDY;
  static constexpr bool kXGradATrans = true;
  static constexpr FusedGEMMGradInType kXGradB = FusedGEMMGradInType::kDZ;
  static constexpr bool kXGradBTrans = true;

  // DY = DZ^T * X^T
  static constexpr FusedGEMMGradInType kYGradA = FusedGEMMGradInType::kDZ;
  static constexpr bool kYGradATrans = true;
  static constexpr FusedGEMMGradInType kYGradB = FusedGEMMGradInType::kDX;
  static constexpr bool kYGradBTrans = true;
};

// Translates a "transpose this operand" flag into the cuBLAS operation enum:
// true -> CUBLAS_OP_T, false -> CUBLAS_OP_N.
static constexpr auto BoolToCuBlasEnum(bool transpose) {
  if (transpose) {
    return CUBLAS_OP_T;
  }
  return CUBLAS_OP_N;
}

template <typename DeviceContext, typename T>
class FusedGemmEpilogueGradKernel : public framework::OpKernel<T> {
 public:
  // Backward of fused_gemm_epilogue. Reads the trans_x / trans_y attributes
  // and dispatches to the matching ComputeImpl instantiation, so operand
  // order and transposes are chosen at compile time by FusedGEMMGradTrait.
  void Compute(const framework::ExecutionContext& ctx) const override {
    bool transpose_x = ctx.Attr<bool>("trans_x");
    bool transpose_y = ctx.Attr<bool>("trans_y");

    if (transpose_x) {
      if (transpose_y) {
        ComputeImpl<true, true>(ctx);
      } else {
        ComputeImpl<true, false>(ctx);
      }
    } else {
      if (transpose_y) {
        ComputeImpl<false, true>(ctx);
      } else {
        ComputeImpl<false, false>(ctx);
      }
    }
  }

 private:
  // Computes the requested gradients with cuBLASLt.
  //  - DX (skipped when absent) uses the DReLU / DGeLU epilogue selected by
  //    "activation_grad"; that epilogue reads ReserveSpace via the AUX
  //    pointer.
  //  - DY (skipped when absent) optionally carries the
  //    CUBLASLT_EPILOGUE_BGRADA / BGRADB epilogue that writes DBias.
  //
  // NOTE(review): DBias is only written inside the DY branch, so requesting
  // DBias without DY leaves it unallocated. Confirm that the op definition
  // rules out that combination.
  template <bool TransX, bool TransY>
  static void ComputeImpl(const framework::ExecutionContext& ctx) {
    using Trait = FusedGEMMGradTrait<TransX, TransY>;

    auto& dev_ctx = ctx.template device_context<phi::GPUContext>();
    const Tensor* dout = ctx.Input<Tensor>("DOut");
    const Tensor* x = ctx.Input<Tensor>("X");
    const Tensor* y = ctx.Input<Tensor>("Y");
    const Tensor* reserve_space = ctx.Input<Tensor>("ReserveSpace");

    Tensor* dx = ctx.Output<Tensor>("DX");
    Tensor* dy = ctx.Output<Tensor>("DY");
    Tensor* dbias = ctx.Output<Tensor>("DBias");

    std::string activation_grad = ctx.Attr<std::string>("activation_grad");

    VLOG(10) << "trans_x = " << TransX << " , trans_y = " << TransY
             << " , activation_grad = " << activation_grad;

    auto x_mat_dims =
        phi::flatten_to_2d(x->dims(), TransX ? 1 : x->dims().size() - 1);

    // (M * K) * (K * N)
    int64_t M = TransX ? x_mat_dims[1] : x_mat_dims[0];
    int64_t K = TransY ? y->dims()[1] : y->dims()[0];
    int64_t N = TransY ? y->dims()[0] : y->dims()[1];

    VLOG(10) << "M = " << M << " , K = " << K << " , N = " << N;

    // float and float16 use fp32 scale/compute types; double uses fp64.
    cudaDataType_t mat_type = CUDA_R_32F;
    cudaDataType_t scale_type = CUDA_R_32F;
    cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F;
    if (std::is_same<T, paddle::platform::float16>::value) {
      mat_type = CUDA_R_16F;
    }
    if (std::is_same<T, double>::value) {
      mat_type = CUDA_R_64F;
      scale_type = CUDA_R_64F;
      compute_type = CUBLAS_COMPUTE_64F;
    }

    cublasLtHandle_t lt_handle = dev_ctx.cublaslt_handle();
    // NOTE(zengjinle): I do not know whether the 4MB workspace size is
    // "enough". I just followed the settings from the NVIDIA MLPerf BERT code.
    size_t workspace_size = static_cast<size_t>(4) * 1024 * 1024;
    cudaStream_t stream = dev_ctx.stream();

    // alpha/beta must match scale_type: double for fp64, float otherwise.
    double alpha64 = 1.0, beta64 = 0.0;
    float alpha32 = 1.0f, beta32 = 0.0f;
    void *alpha = nullptr, *beta = nullptr;
    if (std::is_same<T, double>::value) {
      alpha = &alpha64;
      beta = &beta64;
    } else {
      alpha = &alpha32;
      beta = &beta32;
    }

    cublasLtMatrixLayout_t dout_desc = nullptr, dout_trans_desc = nullptr;
    cublasLtMatrixLayout_t x_desc = nullptr, x_trans_desc = nullptr;
    cublasLtMatrixLayout_t y_desc = nullptr, y_trans_desc = nullptr;
    cublasLtMatrixLayout_t dx_desc = nullptr, dy_desc = nullptr;
    cublasLtMatmulDesc_t dx_operation_desc = nullptr,
                         dy_operation_desc = nullptr;

    // Destroy whichever descriptors were created when this scope exits.
    DEFINE_PADDLE_SCOPE_GUARD([&] {
      auto descs = {dout_desc,
                    dout_trans_desc,
                    x_desc,
                    x_trans_desc,
                    y_desc,
                    y_trans_desc,
                    dx_desc,
                    dy_desc};
      for (auto desc : descs) {
        if (desc) {
          PADDLE_ENFORCE_GPU_SUCCESS(
              platform::dynload::cublasLtMatrixLayoutDestroy(desc));
        }
      }

      if (dx_operation_desc) {
        PADDLE_ENFORCE_GPU_SUCCESS(
            platform::dynload::cublasLtMatmulDescDestroy(dx_operation_desc));
      }

      if (dy_operation_desc) {
        PADDLE_ENFORCE_GPU_SUCCESS(
            platform::dynload::cublasLtMatmulDescDestroy(dy_operation_desc));
      }
    });

    // Stored (row-major) shapes: X is [x_row, x_col], Y is [y_row, y_col].
    // Each is described to cuBLASLt as the column-major layout
    // (cols, rows, ld = cols).
    auto x_row = TransX ? K : M;
    auto x_col = TransX ? M : K;
    auto y_row = TransY ? N : K;
    auto y_col = TransY ? K : N;
    auto z_row = TransX ? N : M;
    auto z_col = TransX ? M : N;

    // dx = func(dout, y)
    if (dx) {
      constexpr auto kXGradAIsDZ = (Trait::kXGradA == FusedGEMMGradInType::kDZ);
      cublasLtMatrixLayout_t *dx_dout_desc, *dx_y_desc;

      // Both branches describe DOut with the same (N, M, N) layout; they
      // differ only in which descriptor slot they fill.
      if (TransX) {
        dx_dout_desc = &dout_trans_desc;
        PADDLE_ENFORCE_GPU_SUCCESS(
            platform::dynload::cublasLtMatrixLayoutCreate(
                dx_dout_desc, mat_type, z_row, z_col, z_row));
      } else {
        dx_dout_desc = &dout_desc;
        PADDLE_ENFORCE_GPU_SUCCESS(
            platform::dynload::cublasLtMatrixLayoutCreate(
                dx_dout_desc, mat_type, z_col, z_row, z_col));
      }

      dx_y_desc = &y_trans_desc;
      PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate(
          dx_y_desc, mat_type, y_col, y_row, y_col));

      auto& a_desc = kXGradAIsDZ ? (*dx_dout_desc) : (*dx_y_desc);
      auto& b_desc = kXGradAIsDZ ? (*dx_y_desc) : (*dx_dout_desc);
      auto a_trans = BoolToCuBlasEnum(Trait::kXGradATrans);
      auto b_trans = BoolToCuBlasEnum(Trait::kXGradBTrans);

      PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate(
          &dx_desc, mat_type, x_col, x_row, x_col));

      // Row-major a * b is issued as column-major b^T * a^T: b becomes the
      // cuBLASLt A operand (TRANSA = b_trans) and a the B operand.
      PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatmulDescCreate(
          &dx_operation_desc, compute_type, scale_type));
      PADDLE_ENFORCE_GPU_SUCCESS(
          platform::dynload::cublasLtMatmulDescSetAttribute(
              dx_operation_desc,
              CUBLASLT_MATMUL_DESC_TRANSB,
              &a_trans,
              sizeof(a_trans)));
      PADDLE_ENFORCE_GPU_SUCCESS(
          platform::dynload::cublasLtMatmulDescSetAttribute(
              dx_operation_desc,
              CUBLASLT_MATMUL_DESC_TRANSA,
              &b_trans,
              sizeof(b_trans)));

      cublasLtEpilogue_t epilogue_func_for_dx =
          get_epilogue_type_(activation_grad);
      PADDLE_ENFORCE_GPU_SUCCESS(
          platform::dynload::cublasLtMatmulDescSetAttribute(
              dx_operation_desc,
              CUBLASLT_MATMUL_DESC_EPILOGUE,
              &epilogue_func_for_dx,
              sizeof(epilogue_func_for_dx)));

      if (activation_grad != "none") {
        // The DReLU / DGeLU epilogue reads ReserveSpace through the AUX
        // pointer, so fail clearly instead of dereferencing nullptr.
        PADDLE_ENFORCE_EQ(
            reserve_space != nullptr,
            true,
            platform::errors::InvalidArgument(
                "The ReserveSpace input of fused_gemm_epilogue_grad op must "
                "be provided when activation_grad is not \"none\". But "
                "received activation_grad=%s.",
                activation_grad));
        auto* aux_data = reserve_space->data<T>();
        PADDLE_ENFORCE_GPU_SUCCESS(
            platform::dynload::cublasLtMatmulDescSetAttribute(
                dx_operation_desc,
                CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER,
                &aux_data,
                sizeof(aux_data)));
        // Equal to x_col, the leading dimension used for dx_desc above.
        int64_t aux_ld = TransX ? M : K;
        PADDLE_ENFORCE_GPU_SUCCESS(
            platform::dynload::cublasLtMatmulDescSetAttribute(
                dx_operation_desc,
                CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD,
                &aux_ld,
                sizeof(aux_ld)));
      }

      auto dx_workspace = memory::Alloc(
          dev_ctx.GetPlace(),
          workspace_size,
          phi::Stream(reinterpret_cast<phi::StreamId>(dev_ctx.stream())));

      // dev_ctx is a reference, so Alloc is reached with '.', not '->'.
      auto* dx_data = dev_ctx.Alloc<T>(dx, dx->numel() * sizeof(T));
      const auto* y_data = y->data<T>();
      const auto* dout_data = dout->data<T>();
      const auto* a_data = kXGradAIsDZ ? dout_data : y_data;
      const auto* b_data = kXGradAIsDZ ? y_data : dout_data;

      auto algo =
          GemmEpilogueAlgoCache::Instance().GetGemmAlgo(lt_handle,
                                                        dx_operation_desc,
                                                        b_desc,
                                                        a_desc,
                                                        dx_desc,
                                                        alpha,
                                                        beta,
                                                        b_data,
                                                        a_data,
                                                        dx_data,
                                                        stream,
                                                        dx_workspace->ptr(),
                                                        workspace_size);

      PADDLE_ENFORCE_GPU_SUCCESS(
          platform::dynload::cublasLtMatmul(lt_handle,
                                            dx_operation_desc,
                                            alpha,
                                            b_data,
                                            b_desc,
                                            a_data,
                                            a_desc,
                                            beta,
                                            dx_data,
                                            dx_desc,
                                            dx_data,
                                            dx_desc,
                                            algo,
                                            dx_workspace->ptr(),
                                            workspace_size,
                                            stream));
    }

    // dy = func(dout, x)
    if (dy) {
      constexpr auto kYGradAIsDZ = (Trait::kYGradA == FusedGEMMGradInType::kDZ);

      // Reuse the DOut layout if the DX branch already created it.
      cublasLtMatrixLayout_t *dy_dout_desc = nullptr, *dy_x_desc = nullptr;
      if (TransX) {
        dy_dout_desc = &dout_trans_desc;
        if (dout_trans_desc == nullptr) {
          PADDLE_ENFORCE_GPU_SUCCESS(
              platform::dynload::cublasLtMatrixLayoutCreate(
                  dy_dout_desc, mat_type, z_row, z_col, z_row));
        }
      } else {
        dy_dout_desc = &dout_desc;
        if (dout_desc == nullptr) {
          PADDLE_ENFORCE_GPU_SUCCESS(
              platform::dynload::cublasLtMatrixLayoutCreate(
                  dy_dout_desc, mat_type, z_col, z_row, z_col));
        }
      }

      dy_x_desc = &x_trans_desc;
      PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate(
          dy_x_desc, mat_type, x_col, x_row, x_col));

      auto& a_desc = kYGradAIsDZ ? (*dy_dout_desc) : (*dy_x_desc);
      auto& b_desc = kYGradAIsDZ ? (*dy_x_desc) : (*dy_dout_desc);
      auto a_trans = BoolToCuBlasEnum(Trait::kYGradATrans);
      auto b_trans = BoolToCuBlasEnum(Trait::kYGradBTrans);

      PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate(
          &dy_desc, mat_type, y_col, y_row, y_col));

      PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatmulDescCreate(
          &dy_operation_desc, compute_type, scale_type));

      PADDLE_ENFORCE_GPU_SUCCESS(
          platform::dynload::cublasLtMatmulDescSetAttribute(
              dy_operation_desc,
              CUBLASLT_MATMUL_DESC_TRANSB,
              &a_trans,
              sizeof(a_trans)));
      PADDLE_ENFORCE_GPU_SUCCESS(
          platform::dynload::cublasLtMatmulDescSetAttribute(
              dy_operation_desc,
              CUBLASLT_MATMUL_DESC_TRANSA,
              &b_trans,
              sizeof(b_trans)));

      // DBias comes from the bias-gradient epilogue on the operand slot that
      // holds DOut: cuBLASLt operand B when TransY, operand A otherwise.
      cublasLtEpilogue_t epilogue_func_for_dy;
      if (dbias == nullptr) {
        epilogue_func_for_dy = CUBLASLT_EPILOGUE_DEFAULT;
      } else {
        if (TransY) {
          epilogue_func_for_dy = CUBLASLT_EPILOGUE_BGRADB;
        } else {
          epilogue_func_for_dy = CUBLASLT_EPILOGUE_BGRADA;
        }
      }

      PADDLE_ENFORCE_GPU_SUCCESS(
          platform::dynload::cublasLtMatmulDescSetAttribute(
              dy_operation_desc,
              CUBLASLT_MATMUL_DESC_EPILOGUE,
              &epilogue_func_for_dy,
              sizeof(epilogue_func_for_dy)));

      if (dbias) {
        auto* dbias_data = dev_ctx.Alloc<T>(dbias, dbias->numel() * sizeof(T));
        PADDLE_ENFORCE_GPU_SUCCESS(
            platform::dynload::cublasLtMatmulDescSetAttribute(
                dy_operation_desc,
                CUBLASLT_MATMUL_DESC_BIAS_POINTER,
                &dbias_data,
                sizeof(dbias_data)));
      }

      auto dy_workspace = memory::Alloc(
          dev_ctx.GetPlace(),
          workspace_size,
          phi::Stream(reinterpret_cast<phi::StreamId>(dev_ctx.stream())));
      auto* dy_data = dev_ctx.Alloc<T>(dy, dy->numel() * sizeof(T));
      const auto* dout_data = dout->data<T>();
      const auto* x_data = x->data<T>();
      const auto* a_data = kYGradAIsDZ ? dout_data : x_data;
      const auto* b_data = kYGradAIsDZ ? x_data : dout_data;

      auto algo =
          GemmEpilogueAlgoCache::Instance().GetGemmAlgo(lt_handle,
                                                        dy_operation_desc,
                                                        b_desc,
                                                        a_desc,
                                                        dy_desc,
                                                        alpha,
                                                        beta,
                                                        b_data,
                                                        a_data,
                                                        dy_data,
                                                        stream,
                                                        dy_workspace->ptr(),
                                                        workspace_size);

      PADDLE_ENFORCE_GPU_SUCCESS(
          platform::dynload::cublasLtMatmul(lt_handle,
                                            dy_operation_desc,
                                            alpha,
                                            b_data,
                                            b_desc,
                                            a_data,
                                            a_desc,
                                            beta,
                                            dy_data,
                                            dy_desc,
                                            dy_data,
                                            dy_desc,
                                            algo,
                                            dy_workspace->ptr(),
                                            workspace_size,
                                            stream));
    }
  }

 private:
  // Maps the "activation_grad" attribute to the cuBLASLt epilogue applied to
  // the DX matmul. Throws InvalidArgument for anything other than "none",
  // "relu_grad" or "gelu_grad".
  static cublasLtEpilogue_t get_epilogue_type_(
      const std::string& activation_grad) {
    if (activation_grad == "relu_grad") {
      return CUBLASLT_EPILOGUE_DRELU;
    }
    if (activation_grad == "gelu_grad") {
      return CUBLASLT_EPILOGUE_DGELU;
    }
    // Any remaining value must be "none". Checking it here instead of in a
    // trailing else guarantees every control path returns a value.
    PADDLE_ENFORCE_EQ(
        activation_grad == "none",
        true,
        platform::errors::InvalidArgument(
            "The activation_grad attribute of fused_gemm_epilogue op should "
            "be one of {\"none\", \"relu_grad\", \"gelu_grad\"}. But received "
            "activation_grad=%s.",
            activation_grad));
    return CUBLASLT_EPILOGUE_DEFAULT;
  }
};

}  // namespace operators
}  // namespace paddle

#if CUDA_VERSION >= 11060
// Registered only when building against CUDA_VERSION >= 11060 (CUDA 11.6),
// presumably because the cuBLASLt epilogues used above require it -- confirm.
namespace ops = paddle::operators;
namespace plat = paddle::platform;

// Forward kernel: float, double and float16 instantiations.
REGISTER_OP_CUDA_KERNEL(
    fused_gemm_epilogue,
    ops::FusedGemmEpilogueKernel<phi::GPUContext, float>,
    ops::FusedGemmEpilogueKernel<phi::GPUContext, double>,
    ops::FusedGemmEpilogueKernel<phi::GPUContext, plat::float16>);

// Backward kernel: the same three data types.
REGISTER_OP_CUDA_KERNEL(
    fused_gemm_epilogue_grad,
    ops::FusedGemmEpilogueGradKernel<phi::GPUContext, float>,
    ops::FusedGemmEpilogueGradKernel<phi::GPUContext, double>,
    ops::FusedGemmEpilogueGradKernel<phi::GPUContext, plat::float16>);
#endif