fused_gemm_epilogue_op.cu 27.4 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Copyright (c) 2022 NVIDIA Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

16
#include "paddle/fluid/operators/fused/fused_gemm_epilogue_op.h"
17 18
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
19
#include "paddle/fluid/framework/scope_guard.h"
20
#include "paddle/fluid/platform/bfloat16.h"
21 22 23 24 25 26 27 28 29 30
#include "paddle/fluid/platform/dynload/cublasLt.h"
#include "paddle/fluid/platform/float16.h"

namespace paddle {
namespace operators {

template <typename DeviceContext, typename T>
class FusedGemmEpilogueKernel : public framework::OpKernel<T> {
 public:
  // Forward of fused_gemm_epilogue: Out = activation(X * Y + Bias), computed
  // by a single cublasLtMatmul call whose epilogue adds Bias and applies the
  // activation ("none", "relu" or "gelu").
  //
  // If the optional ReserveSpace output is present and an activation is used,
  // the epilogue also writes its auxiliary buffer into ReserveSpace.
  //
  // X is flattened to 2-D with phi::flatten_to_2d. Y is indexed through
  // y->dims()[0] / y->dims()[1] (assumed 2-D -- TODO confirm in InferShape).
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto& dev_ctx = ctx.template device_context<phi::GPUContext>();

    const phi::DenseTensor* x = ctx.Input<phi::DenseTensor>("X");
    const phi::DenseTensor* y = ctx.Input<phi::DenseTensor>("Y");
    const phi::DenseTensor* bias = ctx.Input<phi::DenseTensor>("Bias");

    phi::DenseTensor* out = ctx.Output<phi::DenseTensor>("Out");
    phi::DenseTensor* reserve_space =
        ctx.Output<phi::DenseTensor>("ReserveSpace");

    bool trans_x = ctx.Attr<bool>("trans_x");
    bool trans_y = ctx.Attr<bool>("trans_y");

    std::string activation = ctx.Attr<std::string>("activation");
    VLOG(10) << "trans_x = " << trans_x << " , trans_y = " << trans_y
             << " , activation = " << activation;
    // The auxiliary epilogue output is only produced when the caller asked
    // for ReserveSpace.
    bool enable_auxiliary = reserve_space != nullptr;

    dev_ctx.Alloc<T>(out, out->numel() * sizeof(T));
    auto* out_data = out->data<T>();

    auto x_mat_dims =
        phi::flatten_to_2d(x->dims(), trans_x ? 1 : x->dims().size() - 1);
    // (M * K) * (K * N)
    int64_t M = trans_x ? x_mat_dims[1] : x_mat_dims[0];
    int64_t K = trans_y ? y->dims()[1] : y->dims()[0];
    int64_t N = trans_y ? y->dims()[0] : y->dims()[1];

    // fp16/bf16 inputs use fp32 scale and compute; double uses fp64
    // throughout.
    cudaDataType_t mat_type = CUDA_R_32F;
    cudaDataType_t scale_type = CUDA_R_32F;
    cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F;
    if (std::is_same<T, paddle::platform::float16>::value) {
      mat_type = CUDA_R_16F;
    }
    if (std::is_same<T, platform::bfloat16>::value) {
      mat_type = CUDA_R_16BF;
    }
    if (std::is_same<T, double>::value) {
      mat_type = CUDA_R_64F;
      scale_type = CUDA_R_64F;
      compute_type = CUBLAS_COMPUTE_64F;
    }

    cublasLtMatmulDesc_t operation_desc = nullptr;
    cublasLtMatrixLayout_t x_desc = nullptr, y_desc = nullptr,
                           out_desc = nullptr;
    // Destroy whatever descriptors were created, including when a later
    // cuBLASLt call fails and the enforce macros throw. Previously the
    // descriptors were only destroyed on the success path.
    DEFINE_PADDLE_SCOPE_GUARD([&] {
      for (auto desc : {y_desc, x_desc, out_desc}) {
        if (desc) {
          PADDLE_ENFORCE_GPU_SUCCESS(
              platform::dynload::cublasLtMatrixLayoutDestroy(desc));
        }
      }
      if (operation_desc) {
        PADDLE_ENFORCE_GPU_SUCCESS(
            platform::dynload::cublasLtMatmulDescDestroy(operation_desc));
      }
    });

    PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatmulDescCreate(
        &operation_desc, compute_type, scale_type));
    // cuBLASLt is column-major. The row-major product X * Y is therefore
    // issued as Y^T * X^T: Y is passed as cuBLAS operand A and X as operand
    // B, so trans_x controls TRANSB and trans_y controls TRANSA.
    cublasOperation_t transx = trans_x ? CUBLAS_OP_T : CUBLAS_OP_N;
    cublasOperation_t transy = trans_y ? CUBLAS_OP_T : CUBLAS_OP_N;
    PADDLE_ENFORCE_GPU_SUCCESS(
        platform::dynload::cublasLtMatmulDescSetAttribute(
            operation_desc,
            CUBLASLT_MATMUL_DESC_TRANSB,
            &transx,
            sizeof(transx)));
    PADDLE_ENFORCE_GPU_SUCCESS(
        platform::dynload::cublasLtMatmulDescSetAttribute(
            operation_desc,
            CUBLASLT_MATMUL_DESC_TRANSA,
            &transy,
            sizeof(transy)));

    cublasLtEpilogue_t epiloque_func =
        get_epilogue_type_(activation, enable_auxiliary);
    PADDLE_ENFORCE_GPU_SUCCESS(
        platform::dynload::cublasLtMatmulDescSetAttribute(
            operation_desc,
            CUBLASLT_MATMUL_DESC_EPILOGUE,
            &epiloque_func,
            sizeof(epiloque_func)));
    const T* bias_data = bias->data<T>();
    PADDLE_ENFORCE_GPU_SUCCESS(
        platform::dynload::cublasLtMatmulDescSetAttribute(
            operation_desc,
            CUBLASLT_MATMUL_DESC_BIAS_POINTER,
            &bias_data,
            sizeof(bias_data)));

    if (enable_auxiliary && activation != "none") {
      size_t reserve_space_size = 0;
      if (activation == "relu") {
        // Count in bits: the relu auxiliary output is sized at one bit per
        // output element.
        reserve_space_size = phi::product(out->dims()) / 8;
      } else {
        reserve_space_size = phi::product(out->dims()) * sizeof(T);
      }
      dev_ctx.Alloc(reserve_space, out->type(), reserve_space_size);
      void* aux_data = reinterpret_cast<void*>(reserve_space->data<T>());

      PADDLE_ENFORCE_GPU_SUCCESS(
          platform::dynload::cublasLtMatmulDescSetAttribute(
              operation_desc,
              CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER,
              &aux_data,
              sizeof(aux_data)));
      int64_t aux_ld = N;
      PADDLE_ENFORCE_GPU_SUCCESS(
          platform::dynload::cublasLtMatmulDescSetAttribute(
              operation_desc,
              CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD,
              &aux_ld,
              sizeof(aux_ld)));
    }

    // Column-major layouts (rows, cols, leading dim) describing the
    // row-major buffers.
    if (trans_x) {
      PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate(
          &x_desc, mat_type, M, K, M));
    } else {
      PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate(
          &x_desc, mat_type, K, M, K));
    }
    if (trans_y) {
      PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate(
          &y_desc, mat_type, K, N, K));
    } else {
      PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate(
          &y_desc, mat_type, N, K, N));
    }
    PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate(
        &out_desc, mat_type, N, M, N));

    cublasLtHandle_t lt_handle = dev_ctx.cublaslt_handle();
    // NOTE(zengjinle): I do not know whether the 4MB workspace size is
    // "enough". I just followed the settings from the NVIDIA MLPerf BERT code.
    size_t workspace_size = static_cast<size_t>(4) * 1024 * 1024;
    cudaStream_t stream = dev_ctx.stream();
    memory::allocation::AllocationPtr workspace = memory::Alloc(
        dev_ctx.GetPlace(),
        workspace_size,
        phi::Stream(reinterpret_cast<phi::StreamId>(dev_ctx.stream())));

    // The scale factors must match scale_type: fp64 for double, fp32
    // otherwise.
    double alpha64 = 1.0, beta64 = 0.0;
    float alpha32 = 1.0f, beta32 = 0.0f;
    void *alpha = nullptr, *beta = nullptr;
    if (std::is_same<T, double>::value) {
      alpha = &alpha64;
      beta = &beta64;
    } else {
      alpha = &alpha32;
      beta = &beta32;
    }

    const auto* y_data = y->data<T>();
    const auto* x_data = x->data<T>();

    auto algo = GemmEpilogueAlgoCache::Instance().GetGemmAlgo(lt_handle,
                                                              operation_desc,
                                                              y_desc,
                                                              x_desc,
                                                              out_desc,
                                                              alpha,
                                                              beta,
                                                              y_data,
                                                              x_data,
                                                              out_data,
                                                              stream,
                                                              workspace->ptr(),
                                                              workspace_size);

    PADDLE_ENFORCE_GPU_SUCCESS(
        platform::dynload::cublasLtMatmul(lt_handle,
                                          operation_desc,
                                          alpha,
                                          y_data,
                                          y_desc,
                                          x_data,
                                          x_desc,
                                          beta,
                                          out_data,
                                          out_desc,
                                          out_data,
                                          out_desc,
                                          algo,
                                          workspace->ptr(),
                                          workspace_size,
                                          stream));
  }

 private:
  // Maps the "activation" attribute onto a cuBLASLt bias epilogue. The *_AUX
  // variants are chosen when ReserveSpace is requested.
  // Throws InvalidArgument for any activation other than none/relu/gelu.
  static cublasLtEpilogue_t get_epilogue_type_(const std::string& activation,
                                               bool enable_auxiliary) {
    if (activation == "relu") {
      return enable_auxiliary ? CUBLASLT_EPILOGUE_RELU_AUX_BIAS
                              : CUBLASLT_EPILOGUE_RELU_BIAS;
    } else if (activation == "gelu") {
      return enable_auxiliary ? CUBLASLT_EPILOGUE_GELU_AUX_BIAS
                              : CUBLASLT_EPILOGUE_GELU_BIAS;
    } else if (activation == "none") {
      return CUBLASLT_EPILOGUE_BIAS;
    }
    // PADDLE_THROW always throws, so no value-less fall-through remains.
    PADDLE_THROW(platform::errors::InvalidArgument(
        "The activation attribute of fused_gemm_epilogue op should be"
        " one of {\"none\", \"relu\", \"gelu\"}. But received activation=%s.",
        activation));
  }
};

240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300
// Tags naming which tensor fills an operand slot of a backward GEMM.
// kDZ stands for the incoming output gradient (DOut). kDX and kDY name the
// X and Y inputs respectively.
enum FusedGEMMGradInType {
  kDX = 0,
  kDY = 1,
  kDZ = 2,
};

// Compile-time operand table for the two backward GEMMs of
// fused_gemm_epilogue, one specialization per (trans_x, trans_y) pair.
//
// kXGradA / kXGradB pick the tensors feeding operands A and B of the GEMM
// that produces DX. kYGradA / kYGradB do the same for DY. A value equal to
// kDZ selects DOut; any other value selects the remaining input (Y for the
// DX GEMM, X for the DY GEMM).
//
// The *Trans flags are converted with BoolToCuBlasEnum into
// CUBLAS_OP_T / CUBLAS_OP_N for the corresponding operand. Because the
// gradient kernel passes B before A to cublasLtMatmul, the A flag is written
// to CUBLASLT_MATMUL_DESC_TRANSB and the B flag to
// CUBLASLT_MATMUL_DESC_TRANSA.
template <bool TransX, bool TransY>
struct FusedGEMMGradTrait;

// X not transposed, Y not transposed.
template <>
struct FusedGEMMGradTrait<false, false> {
  static constexpr auto kXGradA = FusedGEMMGradInType::kDZ;
  static constexpr auto kXGradB = FusedGEMMGradInType::kDY;
  static constexpr auto kXGradATrans = false;
  static constexpr auto kXGradBTrans = true;

  static constexpr auto kYGradA = FusedGEMMGradInType::kDX;
  static constexpr auto kYGradB = FusedGEMMGradInType::kDZ;
  static constexpr auto kYGradATrans = true;
  static constexpr auto kYGradBTrans = false;
};

// X transposed, Y not transposed.
template <>
struct FusedGEMMGradTrait<true, false> {
  static constexpr auto kXGradA = FusedGEMMGradInType::kDY;
  static constexpr auto kXGradB = FusedGEMMGradInType::kDZ;
  static constexpr auto kXGradATrans = false;
  static constexpr auto kXGradBTrans = true;

  static constexpr auto kYGradA = FusedGEMMGradInType::kDX;
  static constexpr auto kYGradB = FusedGEMMGradInType::kDZ;
  static constexpr auto kYGradATrans = false;
  static constexpr auto kYGradBTrans = false;
};

// X not transposed, Y transposed.
template <>
struct FusedGEMMGradTrait<false, true> {
  static constexpr auto kXGradA = FusedGEMMGradInType::kDZ;
  static constexpr auto kXGradB = FusedGEMMGradInType::kDY;
  static constexpr auto kXGradATrans = false;
  static constexpr auto kXGradBTrans = false;

  static constexpr auto kYGradA = FusedGEMMGradInType::kDZ;
  static constexpr auto kYGradB = FusedGEMMGradInType::kDX;
  static constexpr auto kYGradATrans = true;
  static constexpr auto kYGradBTrans = false;
};

// X transposed, Y transposed.
template <>
struct FusedGEMMGradTrait<true, true> {
  static constexpr auto kXGradA = FusedGEMMGradInType::kDY;
  static constexpr auto kXGradB = FusedGEMMGradInType::kDZ;
  static constexpr auto kXGradATrans = true;
  static constexpr auto kXGradBTrans = true;

  static constexpr auto kYGradA = FusedGEMMGradInType::kDZ;
  static constexpr auto kYGradB = FusedGEMMGradInType::kDX;
  static constexpr auto kYGradATrans = true;
  static constexpr auto kYGradBTrans = true;
};

// Converts a boolean transpose flag into the cuBLAS operation enum:
// true -> CUBLAS_OP_T, false -> CUBLAS_OP_N.
static constexpr auto BoolToCuBlasEnum(bool transpose) {
  if (transpose) {
    return CUBLAS_OP_T;
  }
  return CUBLAS_OP_N;
}

301 302 303 304
template <typename DeviceContext, typename T>
class FusedGemmEpilogueGradKernel : public framework::OpKernel<T> {
 public:
  // Backward of fused_gemm_epilogue. Produces whichever of DX, DY and DBias
  // are requested. The trans_x / trans_y attributes are lifted into template
  // parameters so that FusedGEMMGradTrait selects operand roles at compile
  // time.
  void Compute(const framework::ExecutionContext& ctx) const override {
    bool transpose_x = ctx.Attr<bool>("trans_x");
    bool transpose_y = ctx.Attr<bool>("trans_y");

    if (transpose_x) {
      if (transpose_y) {
        ComputeImpl<true, true>(ctx);
      } else {
        ComputeImpl<true, false>(ctx);
      }
    } else {
      if (transpose_y) {
        ComputeImpl<false, true>(ctx);
      } else {
        ComputeImpl<false, false>(ctx);
      }
    }
  }

 private:
  // Issues up to two cuBLASLt matmuls:
  //  * DX from DOut and Y. When activation_grad is not "none", the
  //    DRELU/DGELU epilogue reads the auxiliary buffer in ReserveSpace.
  //  * DY from X and DOut. When DBias is requested, a BGRADA/BGRADB epilogue
  //    also writes the bias gradient into DBias.
  // All created cuBLASLt descriptors are released by a scope guard.
  template <bool TransX, bool TransY>
  static void ComputeImpl(const framework::ExecutionContext& ctx) {
    using Trait = FusedGEMMGradTrait<TransX, TransY>;
    auto& dev_ctx = ctx.template device_context<phi::GPUContext>();
    const phi::DenseTensor* dout = ctx.Input<phi::DenseTensor>("DOut");
    const phi::DenseTensor* x = ctx.Input<phi::DenseTensor>("X");
    const phi::DenseTensor* y = ctx.Input<phi::DenseTensor>("Y");
    const phi::DenseTensor* reserve_space =
        ctx.Input<phi::DenseTensor>("ReserveSpace");

    phi::DenseTensor* dx = ctx.Output<phi::DenseTensor>("DX");
    phi::DenseTensor* dy = ctx.Output<phi::DenseTensor>("DY");
    phi::DenseTensor* dbias = ctx.Output<phi::DenseTensor>("DBias");

    std::string activation_grad = ctx.Attr<std::string>("activation_grad");

    VLOG(10) << "trans_x = " << TransX << " , trans_y = " << TransY
             << " , activation_grad = " << activation_grad;

    auto x_mat_dims =
        phi::flatten_to_2d(x->dims(), TransX ? 1 : x->dims().size() - 1);

    // (M * K) * (K * N)
    int64_t M = TransX ? x_mat_dims[1] : x_mat_dims[0];
    int64_t K = TransY ? y->dims()[1] : y->dims()[0];
    int64_t N = TransY ? y->dims()[0] : y->dims()[1];

    VLOG(10) << "M = " << M << " , K = " << K << " , N = " << N;

    // fp16/bf16 inputs use fp32 scale and compute; double uses fp64
    // throughout.
    cudaDataType_t mat_type = CUDA_R_32F;
    cudaDataType_t scale_type = CUDA_R_32F;
    cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F;
    if (std::is_same<T, paddle::platform::float16>::value) {
      mat_type = CUDA_R_16F;
    }
    if (std::is_same<T, platform::bfloat16>::value) {
      mat_type = CUDA_R_16BF;
    }
    if (std::is_same<T, double>::value) {
      mat_type = CUDA_R_64F;
      scale_type = CUDA_R_64F;
      compute_type = CUBLAS_COMPUTE_64F;
    }

    cublasLtHandle_t lt_handle = dev_ctx.cublaslt_handle();
    // NOTE(zengjinle): I do not know whether the 4MB workspace size is
    // "enough". I just followed the settings from the NVIDIA MLPerf BERT code.
    size_t workspace_size = static_cast<size_t>(4) * 1024 * 1024;
    cudaStream_t stream = dev_ctx.stream();

    // The scale factors must match scale_type: fp64 for double, fp32
    // otherwise.
    double alpha64 = 1.0, beta64 = 0.0;
    float alpha32 = 1.0f, beta32 = 0.0f;
    void *alpha = nullptr, *beta = nullptr;
    if (std::is_same<T, double>::value) {
      alpha = &alpha64;
      beta = &beta64;
    } else {
      alpha = &alpha32;
      beta = &beta32;
    }

    cublasLtMatrixLayout_t dout_desc = nullptr, dout_trans_desc = nullptr;
    cublasLtMatrixLayout_t x_desc = nullptr, x_trans_desc = nullptr;
    cublasLtMatrixLayout_t y_desc = nullptr, y_trans_desc = nullptr;
    cublasLtMatrixLayout_t dx_desc = nullptr, dy_desc = nullptr;
    cublasLtMatmulDesc_t dx_operation_desc = nullptr,
                         dy_operation_desc = nullptr;

    // Destroy every descriptor that was actually created, on every exit path.
    DEFINE_PADDLE_SCOPE_GUARD([&] {
      auto descs = {dout_desc,
                    dout_trans_desc,
                    x_desc,
                    x_trans_desc,
                    y_desc,
                    y_trans_desc,
                    dx_desc,
                    dy_desc};
      for (auto desc : descs) {
        if (desc) {
          PADDLE_ENFORCE_GPU_SUCCESS(
              platform::dynload::cublasLtMatrixLayoutDestroy(desc));
        }
      }

      if (dx_operation_desc) {
        PADDLE_ENFORCE_GPU_SUCCESS(
            platform::dynload::cublasLtMatmulDescDestroy(dx_operation_desc));
      }

      if (dy_operation_desc) {
        PADDLE_ENFORCE_GPU_SUCCESS(
            platform::dynload::cublasLtMatmulDescDestroy(dy_operation_desc));
      }
    });

    auto x_row = TransX ? K : M;
    auto x_col = TransX ? M : K;
    auto y_row = TransY ? N : K;
    auto y_col = TransY ? K : N;
    auto z_row = TransX ? N : M;
    auto z_col = TransX ? M : N;

    // dx = func(dout, y)
    if (dx) {
      constexpr auto kXGradAIsDZ = (Trait::kXGradA == FusedGEMMGradInType::kDZ);
      cublasLtMatrixLayout_t *dx_dout_desc, *dx_y_desc;

      if (TransX) {
        dx_dout_desc = &dout_trans_desc;
        PADDLE_ENFORCE_GPU_SUCCESS(
            platform::dynload::cublasLtMatrixLayoutCreate(
                dx_dout_desc, mat_type, z_row, z_col, z_row));
      } else {
        dx_dout_desc = &dout_desc;
        PADDLE_ENFORCE_GPU_SUCCESS(
            platform::dynload::cublasLtMatrixLayoutCreate(
                dx_dout_desc, mat_type, z_col, z_row, z_col));
      }

      dx_y_desc = &y_trans_desc;
      PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate(
          dx_y_desc, mat_type, y_col, y_row, y_col));

      auto& a_desc = kXGradAIsDZ ? (*dx_dout_desc) : (*dx_y_desc);
      auto& b_desc = kXGradAIsDZ ? (*dx_y_desc) : (*dx_dout_desc);
      auto a_trans = BoolToCuBlasEnum(Trait::kXGradATrans);
      auto b_trans = BoolToCuBlasEnum(Trait::kXGradBTrans);

      PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate(
          &dx_desc, mat_type, x_col, x_row, x_col));

      PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatmulDescCreate(
          &dx_operation_desc, compute_type, scale_type));
      // Operands are passed to cublasLtMatmul as (B, A), so A's transpose
      // flag goes into TRANSB and B's into TRANSA.
      PADDLE_ENFORCE_GPU_SUCCESS(
          platform::dynload::cublasLtMatmulDescSetAttribute(
              dx_operation_desc,
              CUBLASLT_MATMUL_DESC_TRANSB,
              &a_trans,
              sizeof(a_trans)));
      PADDLE_ENFORCE_GPU_SUCCESS(
          platform::dynload::cublasLtMatmulDescSetAttribute(
              dx_operation_desc,
              CUBLASLT_MATMUL_DESC_TRANSA,
              &b_trans,
              sizeof(b_trans)));

      cublasLtEpilogue_t epiloque_func_for_dx =
          get_epilogue_type_(activation_grad);
      PADDLE_ENFORCE_GPU_SUCCESS(
          platform::dynload::cublasLtMatmulDescSetAttribute(
              dx_operation_desc,
              CUBLASLT_MATMUL_DESC_EPILOGUE,
              &epiloque_func_for_dx,
              sizeof(epiloque_func_for_dx)));

      if (activation_grad != "none") {
        // The activation-gradient epilogue reads its auxiliary input from
        // ReserveSpace, so that input must be present.
        PADDLE_ENFORCE_NOT_NULL(
            reserve_space,
            platform::errors::InvalidArgument(
                "The ReserveSpace input of fused_gemm_epilogue_grad op "
                "should not be null when activation_grad=%s.",
                activation_grad));
        auto* aux_data = reserve_space->data<T>();
        PADDLE_ENFORCE_GPU_SUCCESS(
            platform::dynload::cublasLtMatmulDescSetAttribute(
                dx_operation_desc,
                CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER,
                &aux_data,
                sizeof(aux_data)));
        int64_t aux_ld = TransX ? M : K;
        PADDLE_ENFORCE_GPU_SUCCESS(
            platform::dynload::cublasLtMatmulDescSetAttribute(
                dx_operation_desc,
                CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD,
                &aux_ld,
                sizeof(aux_ld)));
      }

      auto dx_workspace = memory::Alloc(
          dev_ctx.GetPlace(),
          workspace_size,
          phi::Stream(reinterpret_cast<phi::StreamId>(dev_ctx.stream())));

      auto* dx_data = dev_ctx.Alloc<T>(dx, dx->numel() * sizeof(T));
      const auto* y_data = y->data<T>();
      const auto* dout_data = dout->data<T>();
      const auto* a_data = kXGradAIsDZ ? dout_data : y_data;
      const auto* b_data = kXGradAIsDZ ? y_data : dout_data;

      auto algo =
          GemmEpilogueAlgoCache::Instance().GetGemmAlgo(lt_handle,
                                                        dx_operation_desc,
                                                        b_desc,
                                                        a_desc,
                                                        dx_desc,
                                                        alpha,
                                                        beta,
                                                        b_data,
                                                        a_data,
                                                        dx_data,
                                                        stream,
                                                        dx_workspace->ptr(),
                                                        workspace_size);

      PADDLE_ENFORCE_GPU_SUCCESS(
          platform::dynload::cublasLtMatmul(lt_handle,
                                            dx_operation_desc,
                                            alpha,
                                            b_data,
                                            b_desc,
                                            a_data,
                                            a_desc,
                                            beta,
                                            dx_data,
                                            dx_desc,
                                            dx_data,
                                            dx_desc,
                                            algo,
                                            dx_workspace->ptr(),
                                            workspace_size,
                                            stream));
    }

    // dy = func(dout, x)
    if (dy) {
      constexpr auto kYGradAIsDZ = (Trait::kYGradA == FusedGEMMGradInType::kDZ);

      cublasLtMatrixLayout_t *dy_dout_desc = nullptr, *dy_x_desc = nullptr;
      // The DOut layout may already exist from the DX branch; reuse it.
      if (TransX) {
        dy_dout_desc = &dout_trans_desc;
        if (dout_trans_desc == nullptr) {
          PADDLE_ENFORCE_GPU_SUCCESS(
              platform::dynload::cublasLtMatrixLayoutCreate(
                  dy_dout_desc, mat_type, z_row, z_col, z_row));
        }
      } else {
        dy_dout_desc = &dout_desc;
        if (dout_desc == nullptr) {
          PADDLE_ENFORCE_GPU_SUCCESS(
              platform::dynload::cublasLtMatrixLayoutCreate(
                  dy_dout_desc, mat_type, z_col, z_row, z_col));
        }
      }

      dy_x_desc = &x_trans_desc;
      PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate(
          dy_x_desc, mat_type, x_col, x_row, x_col));

      auto& a_desc = kYGradAIsDZ ? (*dy_dout_desc) : (*dy_x_desc);
      auto& b_desc = kYGradAIsDZ ? (*dy_x_desc) : (*dy_dout_desc);
      auto a_trans = BoolToCuBlasEnum(Trait::kYGradATrans);
      auto b_trans = BoolToCuBlasEnum(Trait::kYGradBTrans);

      PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate(
          &dy_desc, mat_type, y_col, y_row, y_col));

      PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatmulDescCreate(
          &dy_operation_desc, compute_type, scale_type));

      PADDLE_ENFORCE_GPU_SUCCESS(
          platform::dynload::cublasLtMatmulDescSetAttribute(
              dy_operation_desc,
              CUBLASLT_MATMUL_DESC_TRANSB,
              &a_trans,
              sizeof(a_trans)));
      PADDLE_ENFORCE_GPU_SUCCESS(
          platform::dynload::cublasLtMatmulDescSetAttribute(
              dy_operation_desc,
              CUBLASLT_MATMUL_DESC_TRANSA,
              &b_trans,
              sizeof(b_trans)));

      // When DBias is requested, the bias gradient comes out of the same
      // matmul through a BGRAD epilogue.
      cublasLtEpilogue_t epiloque_func_for_dy;
      if (dbias == nullptr) {
        epiloque_func_for_dy = CUBLASLT_EPILOGUE_DEFAULT;
      } else {
        if (TransY) {
          epiloque_func_for_dy = CUBLASLT_EPILOGUE_BGRADB;
        } else {
          epiloque_func_for_dy = CUBLASLT_EPILOGUE_BGRADA;
        }
      }

      PADDLE_ENFORCE_GPU_SUCCESS(
          platform::dynload::cublasLtMatmulDescSetAttribute(
              dy_operation_desc,
              CUBLASLT_MATMUL_DESC_EPILOGUE,
              &epiloque_func_for_dy,
              sizeof(epiloque_func_for_dy)));

      if (dbias) {
        auto* dbias_data = dev_ctx.Alloc<T>(dbias, dbias->numel() * sizeof(T));
        PADDLE_ENFORCE_GPU_SUCCESS(
            platform::dynload::cublasLtMatmulDescSetAttribute(
                dy_operation_desc,
                CUBLASLT_MATMUL_DESC_BIAS_POINTER,
                &dbias_data,
                sizeof(dbias_data)));
      }

      auto dy_workspace = memory::Alloc(
          dev_ctx.GetPlace(),
          workspace_size,
          phi::Stream(reinterpret_cast<phi::StreamId>(dev_ctx.stream())));
      auto* dy_data = dev_ctx.Alloc<T>(dy, dy->numel() * sizeof(T));
      const auto* dout_data = dout->data<T>();
      const auto* x_data = x->data<T>();
      const auto* a_data = kYGradAIsDZ ? dout_data : x_data;
      const auto* b_data = kYGradAIsDZ ? x_data : dout_data;

      auto algo =
          GemmEpilogueAlgoCache::Instance().GetGemmAlgo(lt_handle,
                                                        dy_operation_desc,
                                                        b_desc,
                                                        a_desc,
                                                        dy_desc,
                                                        alpha,
                                                        beta,
                                                        b_data,
                                                        a_data,
                                                        dy_data,
                                                        stream,
                                                        dy_workspace->ptr(),
                                                        workspace_size);

      PADDLE_ENFORCE_GPU_SUCCESS(
          platform::dynload::cublasLtMatmul(lt_handle,
                                            dy_operation_desc,
                                            alpha,
                                            b_data,
                                            b_desc,
                                            a_data,
                                            a_desc,
                                            beta,
                                            dy_data,
                                            dy_desc,
                                            dy_data,
                                            dy_desc,
                                            algo,
                                            dy_workspace->ptr(),
                                            workspace_size,
                                            stream));
    }
  }

 private:
  // Maps the "activation_grad" attribute onto the epilogue used by the DX
  // matmul. Throws InvalidArgument for anything other than
  // none/relu_grad/gelu_grad.
  static cublasLtEpilogue_t get_epilogue_type_(
      const std::string& activation_grad) {
    if (activation_grad == "relu_grad") {
      return CUBLASLT_EPILOGUE_DRELU;
    } else if (activation_grad == "gelu_grad") {
      return CUBLASLT_EPILOGUE_DGELU;
    } else if (activation_grad == "none") {
      return CUBLASLT_EPILOGUE_DEFAULT;
    }
    // PADDLE_THROW always throws, so no value-less fall-through remains.
    PADDLE_THROW(platform::errors::InvalidArgument(
        "The activation_grad attribute of fused_gemm_epilogue op should be"
        " one of {\"none\", \"relu_grad\", \"gelu_grad\"}. But received "
        "activation_grad=%s.",
        activation_grad));
  }
};

}  // namespace operators
}  // namespace paddle

// Kernel registration. This requires CUDA >= 11.6, which provides the
// cuBLASLt epilogues used above.
#if CUDA_VERSION >= 11060
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
    fused_gemm_epilogue,
    ops::FusedGemmEpilogueKernel<phi::GPUContext, float>,
    ops::FusedGemmEpilogueKernel<phi::GPUContext, double>,
    ops::FusedGemmEpilogueKernel<phi::GPUContext, paddle::platform::float16>,
    ops::FusedGemmEpilogueKernel<phi::GPUContext, paddle::platform::bfloat16>);

// Every entry here must be the grad kernel. The bfloat16 entry previously
// registered the forward FusedGemmEpilogueKernel by mistake.
REGISTER_OP_CUDA_KERNEL(
    fused_gemm_epilogue_grad,
    ops::FusedGemmEpilogueGradKernel<phi::GPUContext, float>,
    ops::FusedGemmEpilogueGradKernel<phi::GPUContext, double>,
    ops::FusedGemmEpilogueGradKernel<phi::GPUContext,
                                     paddle::platform::float16>,
    ops::FusedGemmEpilogueGradKernel<phi::GPUContext,
                                     paddle::platform::bfloat16>);
#endif