blaslt_impl.cu.h 19.8 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 11060
18 19 20
#include <cuda_runtime_api.h>  // NOLINT
#include "cuda.h"              // NOLINT
#include "paddle/phi/backends/dynload/cublasLt.h"
21
#include "paddle/phi/common/amp_type_traits.h"
22
#include "paddle/phi/common/memory_utils.h"
23
#include "paddle/phi/kernels/autotune/gpu_timer.h"
24 25
#include "paddle/phi/kernels/autotune/switch_autotune.h"
#endif
26 27 28 29

namespace phi {
namespace funcs {

30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75
#if (defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 11060)
// Fused post-processing variants supported by MatmulWithCublasLt. Each
// enumerator is defined as the corresponding cublasLtEpilogue_t value so it
// can be static_cast to cublasLtEpilogue_t and handed to cuBLASLt unchanged.
// Reference: https://docs.nvidia.com/cuda/cublas/index.html#cublasltepilogue-t
enum MatmulFusedType {
  // Plain GEMM; no epilogue.
  kMatmul = CUBLASLT_EPILOGUE_DEFAULT,
  // GEMM followed by a bias add.
  kMatmulBias = CUBLASLT_EPILOGUE_BIAS,
  // GEMM followed by ReLU (no bias).
  kMatmulRelu = CUBLASLT_EPILOGUE_RELU,
  // GEMM, then bias add, then ReLU.
  kMatmulBiasRelu = CUBLASLT_EPILOGUE_RELU_BIAS,
  // GEMM, then bias add, then GELU.
  kMatmulBiasGelu = CUBLASLT_EPILOGUE_GELU_BIAS,
  // Bias + activation variants that also produce the auxiliary output
  // cuBLASLt defines for its *_AUX_* epilogues; the destination buffer is
  // supplied through MatmulPlanner::aux_data.
  kMatmulBiasReluWithReservedData = CUBLASLT_EPILOGUE_RELU_AUX_BIAS,
  kMatmulBiasGeluWithReservedData = CUBLASLT_EPILOGUE_GELU_AUX_BIAS
};

struct MatmulPlanner {
 public:
  const void* bias{nullptr};
  void* aux_data{nullptr};

  MatmulPlanner() {}
  MatmulPlanner(const std::vector<int64_t>& x_dims,
                const std::vector<int64_t>& y_dims,
                const bool trans_x,
                const bool trans_y,
                phi::DataType dtype,
                MatmulFusedType impl_type,
                const void* bias_data = nullptr,
                void* reserve_data = nullptr)
      : bias(bias_data), aux_data(reserve_data) {
    type = impl_type;
    key = phi::autotune::GenKey(x_dims,
                                y_dims,
                                static_cast<int64_t>(trans_x),
                                static_cast<int64_t>(trans_y),
                                static_cast<int64_t>(dtype));
  }

  MatmulFusedType ImplType() const { return type; }
  size_t GetKey() const { return key; }
  size_t GenSubKey(int idx) const { return phi::autotune::GenKey(key, idx); }

 private:
  MatmulFusedType type;
  size_t key;
};
76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91

// Selects the cuBLASLt compute (accumulation) type for element type T:
// double is computed in FP64, every other element type in FP32.
template <typename T>
cublasComputeType_t GetCudaComputeType() {
  return std::is_same<T, double>::value ? CUBLAS_COMPUTE_64F
                                        : CUBLAS_COMPUTE_32F;
}

struct MatmulDescriptor {
 public:
  cublasLtMatmulDesc_t op_desc{nullptr};
  cublasLtMatrixLayout_t x_desc{nullptr};
  cublasLtMatrixLayout_t y_desc{nullptr};
  cublasLtMatrixLayout_t out_desc{nullptr};
92
  cublasLtMatmulAlgo_t* algo{nullptr};
93
  bool is_cached{false};
94 95 96 97 98 99 100 101

  MatmulDescriptor() {}
  MatmulDescriptor(const MatmulDescriptor& obj) {
    algo = obj.algo;
    x_desc = obj.x_desc;
    y_desc = obj.y_desc;
    op_desc = obj.op_desc;
    out_desc = obj.out_desc;
102
    is_cached = obj.is_cached;
103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120
  }

  ~MatmulDescriptor() {
    if (!is_cached) {
      PADDLE_ENFORCE_GPU_SUCCESS(dynload::cublasLtMatmulDescDestroy(op_desc));
      PADDLE_ENFORCE_GPU_SUCCESS(dynload::cublasLtMatrixLayoutDestroy(y_desc));
      PADDLE_ENFORCE_GPU_SUCCESS(dynload::cublasLtMatrixLayoutDestroy(x_desc));
      PADDLE_ENFORCE_GPU_SUCCESS(
          dynload::cublasLtMatrixLayoutDestroy(out_desc));
      delete algo;

      op_desc = nullptr;
      x_desc = nullptr;
      y_desc = nullptr;
      out_desc = nullptr;
      algo = nullptr;
    }
  }
121

122
  // x_desc, y_desc, op_desc are allocated in heap memory.
123 124 125 126 127 128
  template <typename T>
  void Create(const int M,
              const int N,
              const int K,
              const bool trans_x,
              const bool trans_y,
129
              phi::funcs::MatmulPlanner* planner,
130 131 132 133 134 135
              const int batch_size = 1,
              int64_t stride_x = 0,
              int64_t stride_y = 0,
              int64_t stride_out = 0) {
    using MT = typename phi::dtype::MPTypeTrait<T>::Type;

136 137
    cudaDataType_t mat_type = phi::backends::gpu::ToCudaDataType<T>();
    cudaDataType_t scale_type = phi::backends::gpu::ToCudaDataType<MT>();
138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167
    cublasComputeType_t compute_type = GetCudaComputeType<T>();

    // Create operation desciriptor; see cublasLtMatmulDescAttributes_t for
    // details about defaults; just need to set the transforms for A and B
    PADDLE_ENFORCE_GPU_SUCCESS(
        dynload::cublasLtMatmulDescCreate(&op_desc, compute_type, scale_type));
    cublasOperation_t cublas_trans_x = trans_x ? CUBLAS_OP_T : CUBLAS_OP_N;
    cublasOperation_t cublas_trans_y = trans_y ? CUBLAS_OP_T : CUBLAS_OP_N;
    PADDLE_ENFORCE_GPU_SUCCESS(
        dynload::cublasLtMatmulDescSetAttribute(op_desc,
                                                CUBLASLT_MATMUL_DESC_TRANSB,
                                                &cublas_trans_x,
                                                sizeof(cublas_trans_x)));
    PADDLE_ENFORCE_GPU_SUCCESS(
        dynload::cublasLtMatmulDescSetAttribute(op_desc,
                                                CUBLASLT_MATMUL_DESC_TRANSA,
                                                &cublas_trans_y,
                                                sizeof(cublas_trans_y)));

    // Create matrix descriptors
    CreateMatrixLayout(&x_desc, mat_type, M, K, trans_x);
    CreateMatrixLayout(&y_desc, mat_type, K, N, trans_y);
    CreateMatrixLayout(&out_desc, mat_type, M, N, false);

    // Config batch size and stride.
    if (batch_size > 1) {
      SetBatchAndStride(x_desc, batch_size, stride_x);
      SetBatchAndStride(y_desc, batch_size, stride_y);
      SetBatchAndStride(out_desc, batch_size, stride_out);
    }
168
    SetFusedEpilogueOpDescriptor(planner, N);
169 170
  }

171
  cublasLtMatmulAlgo_t* SetAlgo() {
172 173
    // while entering this function, the desc shall be cached.
    is_cached = true;
174 175 176 177
    algo = new cublasLtMatmulAlgo_t;
    return algo;
  }

178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220
  template <typename T>
  void SetFusedEpiloguePtr(phi::funcs::MatmulPlanner* planner) {
    if (planner->bias != nullptr) {
      const T* bias_data = static_cast<const T*>(planner->bias);
      PADDLE_ENFORCE_GPU_SUCCESS(dynload::cublasLtMatmulDescSetAttribute(
          op_desc,
          CUBLASLT_MATMUL_DESC_BIAS_POINTER,
          &bias_data,
          sizeof(bias_data)));

      if (planner->aux_data != nullptr) {
        PADDLE_ENFORCE_GPU_SUCCESS(dynload::cublasLtMatmulDescSetAttribute(
            op_desc,
            CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER,
            &(planner->aux_data),
            sizeof(planner->aux_data)));
      }
    }
  }

  std::string GetDescResultString(std::string prefix,
                                  bool has_algo = true) const {
    std::ostringstream out;
    out << prefix << " \n";
#define GET_DESC_DATA_INFO(src)                      \
  do {                                               \
    out << #src << "= [";                            \
    int num = sizeof((*src)) / sizeof(src->data[0]); \
    for (int i = 0; i < num; ++i) {                  \
      out << src->data[i] << ", ";                   \
    }                                                \
    out << "]\n";                                    \
  } while (0);

    if (has_algo) {
      GET_DESC_DATA_INFO(&algo);
    }
    GET_DESC_DATA_INFO(x_desc);
    GET_DESC_DATA_INFO(y_desc);
    GET_DESC_DATA_INFO(out_desc);
    GET_DESC_DATA_INFO(op_desc);
    return out.str();
  }
221 222

 private:
223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250
  void CreateMatrixLayout(cublasLtMatrixLayout_t* desc,
                          cudaDataType type,
                          uint64_t rows,
                          uint64_t cols,
                          bool trans) {
    if (trans) {
      PADDLE_ENFORCE_GPU_SUCCESS(
          dynload::cublasLtMatrixLayoutCreate(desc, type, rows, cols, rows));
    } else {
      PADDLE_ENFORCE_GPU_SUCCESS(
          dynload::cublasLtMatrixLayoutCreate(desc, type, cols, rows, cols));
    }
  }

  void SetBatchAndStride(cublasLtMatrixLayout_t desc,
                         int batch_size,
                         int64_t stride) {
    PADDLE_ENFORCE_GPU_SUCCESS(dynload::cublasLtMatrixLayoutSetAttribute(
        desc,
        CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT,
        &batch_size,
        sizeof(batch_size)));
    PADDLE_ENFORCE_GPU_SUCCESS(dynload::cublasLtMatrixLayoutSetAttribute(
        desc,
        CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET,
        &stride,
        sizeof(stride)));
  }
251

252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268
  void SetFusedEpilogueOpDescriptor(phi::funcs::MatmulPlanner* planner,
                                    int64_t lead_dim) {
    if (planner->bias) {
      auto fuse_type = static_cast<cublasLtEpilogue_t>(planner->ImplType());
      PADDLE_ENFORCE_GPU_SUCCESS(
          dynload::cublasLtMatmulDescSetAttribute(op_desc,
                                                  CUBLASLT_MATMUL_DESC_EPILOGUE,
                                                  &fuse_type,
                                                  sizeof(fuse_type)));
      if (planner->aux_data) {
        PADDLE_ENFORCE_GPU_SUCCESS(dynload::cublasLtMatmulDescSetAttribute(
            op_desc,
            CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD,
            &lead_dim,
            sizeof(lead_dim)));
      }
    }
269
  }
270
};
271

272 273
template <typename T>
struct DescriptorSetter {
274
  MatmulDescriptor desc;
275 276
  size_t sub_key{std::numeric_limits<size_t>::min()};

277
  DescriptorSetter(phi::funcs::MatmulPlanner* planner,
278 279 280 281 282 283 284 285 286
                   const int M,
                   const int N,
                   const int K,
                   const bool trans_x,
                   const bool trans_y,
                   const int batch_size = 1,
                   int64_t stride_x = 0,
                   int64_t stride_y = 0,
                   int64_t stride_out = 0) {
287 288
    if (planner != nullptr) {
      sub_key = planner->GenSubKey(static_cast<size_t>(planner->ImplType()));
289
    }
290

291 292
    auto& mamtul_cache = phi::autotune::AutoTuneCache::Instance().GetMatmul();
    if (mamtul_cache.FindSubKey(sub_key)) {
293 294 295 296
      desc = *(
          reinterpret_cast<MatmulDescriptor*>(mamtul_cache.GetSubKey(sub_key)));
      desc.SetFusedEpiloguePtr<T>(planner);
      VLOG(6) << desc.GetDescResultString("[Heap MatmulDescriptor] ");
297
    } else {
298 299 300 301 302 303 304 305 306 307 308 309 310 311
      desc.Create<T>(M,
                     N,
                     K,
                     trans_x,
                     trans_y,
                     planner,
                     batch_size,
                     stride_x,
                     stride_y,
                     stride_out);
      if (planner != nullptr) {
        desc.SetFusedEpiloguePtr<T>(planner);
      }
      VLOG(6) << desc.GetDescResultString("[Stack MatmulDescriptor] ", false);
312
    }
313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329
  }
};

template <typename T>
struct MatmulWithCublasLt {
 public:
  using MT = typename phi::dtype::MPTypeTrait<T>::Type;

  static void Run(const phi::GPUContext& ctx,
                  const T* x_data,
                  const T* y_data,
                  T* out_data,
                  const int M,
                  const int N,
                  const int K,
                  const bool trans_x,
                  const bool trans_y,
330 331
                  phi::funcs::MatmulPlanner* planner = nullptr) {
    auto setter = DescriptorSetter<T>(planner, M, N, K, trans_x, trans_y);
332
    RunImpl(
333
        ctx, &setter.desc, setter.sub_key, x_data, y_data, out_data, planner);
334 335
  }

336 337 338 339 340 341 342 343 344 345 346 347 348 349 350
  static void RunWithBatch(const phi::GPUContext& ctx,
                           const T* x_data,
                           const T* y_data,
                           T* out_data,
                           const int M,
                           const int N,
                           const int K,
                           bool trans_x,
                           bool trans_y,
                           int batch_size,
                           int64_t stride_x,
                           int64_t stride_y,
                           int64_t stride_out,
                           phi::funcs::MatmulPlanner* planner = nullptr) {
    auto setter = DescriptorSetter<T>(planner,
351 352 353 354 355 356 357 358 359 360
                                      M,
                                      N,
                                      K,
                                      trans_x,
                                      trans_y,
                                      batch_size,
                                      stride_x,
                                      stride_y,
                                      stride_out);
    RunImpl(
361
        ctx, &setter.desc, setter.sub_key, x_data, y_data, out_data, planner);
362 363
  }

364 365 366 367 368 369 370 371 372 373 374
  static void RunWithBatch(const phi::GPUContext& ctx,
                           const T** x_data,
                           const T** y_data,
                           T** out_data,
                           const int M,
                           const int N,
                           const int K,
                           bool trans_x,
                           bool trans_y,
                           int batch_size,
                           phi::funcs::MatmulPlanner* planner = nullptr) {
375 376 377 378 379 380 381 382 383 384
    for (int i = 0; i < batch_size; ++i) {
      Run(ctx,
          x_data[i],
          y_data[i],
          out_data[i],
          M,
          N,
          K,
          trans_x,
          trans_y,
385
          planner);
386 387 388 389 390 391
    }
  }

 private:
  static phi::Allocator::AllocationPtr GetWorkspace(const phi::GPUContext& ctx,
                                                    size_t workspace_size) {
392
    return phi::memory_utils::Alloc(
393 394 395 396 397 398
        ctx.GetPlace(),
        workspace_size,
        phi::Stream(reinterpret_cast<phi::StreamId>(ctx.stream())));
  }

  static void RunImpl(const phi::GPUContext& ctx,
399
                      MatmulDescriptor* desc,
400
                      const size_t sub_key,
401 402 403
                      const T* x_ptr,
                      const T* y_ptr,
                      T* out_ptr,
404
                      phi::funcs::MatmulPlanner* planner) {
405 406 407 408 409 410 411
    MT alpha = static_cast<MT>(1);
    MT beta = static_cast<MT>(0);

    cublasLtHandle_t cublaslt_handle = ctx.cublaslt_handle();
    size_t workspace_size = static_cast<size_t>(4) * 1024 * 1024;
    phi::Allocator::AllocationPtr workspace = GetWorkspace(ctx, workspace_size);

412
    if (planner != nullptr) {
413
      if (phi::autotune::AutoTuneStatus::Instance().UseAutoTune() &&
414
          (!desc->is_cached)) {
415 416
        SearchBestAlgo(ctx,
                       cublaslt_handle,
417
                       desc,
418 419 420 421 422 423
                       static_cast<void*>(&alpha),
                       static_cast<void*>(&beta),
                       y_ptr,
                       x_ptr,
                       out_ptr,
                       workspace->ptr(),
424 425
                       workspace_size);
        MatmulDescriptor* best_desc = new MatmulDescriptor(*desc);
426 427 428 429
        VLOG(6) << best_desc->GetDescResultString(
            "[Searched MatmulDescriptor] ");

        auto& cache = phi::autotune::AutoTuneCache::Instance().GetMatmul();
430
        cache.SetSubKey(sub_key, reinterpret_cast<void*>(best_desc));
431 432 433
      }
    }

434
    VLOG(6) << desc->GetDescResultString("[Impl MatmulDescriptor] ");
435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451
    PADDLE_ENFORCE_GPU_SUCCESS(
        dynload::cublasLtMatmul(cublaslt_handle,
                                desc->op_desc,
                                static_cast<void*>(&alpha),
                                y_ptr,
                                desc->y_desc,
                                x_ptr,
                                desc->x_desc,
                                static_cast<void*>(&beta),
                                out_ptr,
                                desc->out_desc,
                                out_ptr,
                                desc->out_desc,
                                desc->algo,
                                workspace->ptr(),
                                workspace_size,
                                ctx.stream()));
452 453 454 455
  }

  static void SearchBestAlgo(const phi::GPUContext& ctx,
                             const cublasLtHandle_t& lt_handle,
456
                             MatmulDescriptor* desc,
457 458 459 460 461 462
                             const void* alpha,
                             const void* beta,
                             const void* y_data,
                             const void* x_data,
                             void* out_data,
                             void* workspace_ptr,
463 464
                             size_t workspace_size) {
    cublasLtMatmulAlgo_t* best_algo = desc->SetAlgo();
465 466 467 468 469 470 471 472 473 474 475 476 477 478 479
    const auto& stream = ctx.stream();
    int returned_results = 0;
    constexpr int requested_algo_count = 10;
    cublasLtMatmulPreference_t preference;
    PADDLE_ENFORCE_GPU_SUCCESS(
        dynload::cublasLtMatmulPreferenceCreate(&preference));
    PADDLE_ENFORCE_GPU_SUCCESS(dynload::cublasLtMatmulPreferenceSetAttribute(
        preference,
        CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
        &workspace_size,
        sizeof(workspace_size)));
    std::vector<cublasLtMatmulHeuristicResult_t> heuristic_results(
        requested_algo_count);
    PADDLE_ENFORCE_GPU_SUCCESS(
        dynload::cublasLtMatmulAlgoGetHeuristic(lt_handle,
480 481 482 483 484
                                                desc->op_desc,
                                                desc->y_desc,
                                                desc->x_desc,
                                                desc->out_desc,
                                                desc->out_desc,
485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502
                                                preference,
                                                requested_algo_count,
                                                heuristic_results.data(),
                                                &returned_results));
    PADDLE_ENFORCE_GT(returned_results,
                      0,
                      phi::errors::Unavailable("No GEMM algorithm avaliable."));
    phi::GpuTimer timer;
    int best_algo_idx = -1;
    constexpr int repeats = 6;
    float min_time_cost = std::numeric_limits<float>::max();
    for (int algo_idx = 0; algo_idx < returned_results; ++algo_idx) {
      ctx.Wait();
      float cur_time = 0.f;
      for (int i = 0; i < repeats; ++i) {
        timer.Start(stream);
        PADDLE_ENFORCE_GPU_SUCCESS(
            dynload::cublasLtMatmul(lt_handle,
503
                                    desc->op_desc,
504 505
                                    alpha,
                                    y_data,
506
                                    desc->y_desc,
507
                                    x_data,
508
                                    desc->x_desc,
509 510
                                    beta,
                                    out_data,
511
                                    desc->out_desc,
512
                                    out_data,
513
                                    desc->out_desc,
514 515 516 517 518 519 520 521 522 523 524 525 526 527 528 529 530 531 532 533 534 535 536 537 538
                                    &(heuristic_results[algo_idx].algo),
                                    workspace_ptr,
                                    workspace_size,
                                    stream));
        timer.Stop(stream);
        auto time = timer.ElapsedTime();
        if (i > 0) {
          cur_time += time;
        }
      }
      float time_cnt = (cur_time / (repeats - 1));
      VLOG(4) << "Time cost in MatmulWithCublaslt algo[" << algo_idx << "]"
              << "is : " << time_cnt << " s";

      if (cur_time < min_time_cost) {
        best_algo_idx = algo_idx;
        min_time_cost = cur_time;
      }
    }
    VLOG(4) << "Best_algo_idx in MatmulWithCublaslt is : " << best_algo_idx;
    *best_algo = heuristic_results[best_algo_idx].algo;
    PADDLE_ENFORCE_GPU_SUCCESS(
        dynload::cublasLtMatmulPreferenceDestroy(preference));
  }
};
539 540 541 542
#else
// Empty placeholder so that code naming MatmulPlanner still compiles when this
// file is built without PADDLE_WITH_CUDA or with CUDA older than 11.6.
struct MatmulPlanner {};
#endif  // (PADDLE_WITH_CUDA) && CUDA_VERSION >= 11060
543 544 545

}  // namespace funcs
}  // namespace phi