/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 11060

#include <cstddef>
#include <cstdint>
#include <limits>
#include <sstream>
#include <string>
#include <type_traits>
#include <vector>

#include "glog/logging.h"

#include <cuda_runtime_api.h>  // NOLINT
#include "cuda.h"              // NOLINT
#include "paddle/phi/backends/dynload/cublasLt.h"
#include "paddle/phi/backends/gpu/cuda/cuda_helper.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/common/memory_utils.h"
#include "paddle/phi/kernels/autotune/gpu_timer.h"
#include "paddle/phi/kernels/autotune/switch_autotune.h"
#endif

namespace phi {
namespace funcs {

#if (defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 11060)
// Fused epilogues supported by the cuBLASLt matmul path. Every enumerator has
// the same numeric value as the matching cublasLtEpilogue_t constant, so it can
// be cast straight to cublasLtEpilogue_t when configuring the epilogue. See
// https://docs.nvidia.com/cuda/cublas/index.html#cublasltepilogue-t
enum MatmulFusedType {
  // Plain matmul, no post-processing.
  kMatmul = CUBLASLT_EPILOGUE_DEFAULT,
  // Add a bias vector to the product.
  kMatmulBias = CUBLASLT_EPILOGUE_BIAS,
  // ReLU on the product (no bias).
  kMatmulRelu = CUBLASLT_EPILOGUE_RELU,
  // Add the bias, then apply ReLU.
  kMatmulBiasRelu = CUBLASLT_EPILOGUE_RELU_BIAS,
  // Add the bias, then apply GELU.
  kMatmulBiasGelu = CUBLASLT_EPILOGUE_GELU_BIAS,
  // Bias + ReLU, additionally producing the auxiliary (reserve) buffer.
  kMatmulBiasReluWithReservedData = CUBLASLT_EPILOGUE_RELU_AUX_BIAS,
  // Bias + GELU, additionally producing the auxiliary (reserve) buffer.
  kMatmulBiasGeluWithReservedData = CUBLASLT_EPILOGUE_GELU_AUX_BIAS
};

// Describes one cuBLASLt matmul request: which fused epilogue to apply, the
// buffers the epilogue reads/writes, and a key (hashed from the operand dims,
// transpose flags and dtype) used to look up tuned descriptors in the
// autotune cache.
struct MatmulPlanner {
 public:
  // Bias consumed by the bias epilogues; nullptr when no bias is fused.
  const void* bias{nullptr};
  // Auxiliary (reserve) buffer for the *WithReservedData epilogues; nullptr
  // when unused.
  void* aux_data{nullptr};

  // A default-constructed planner describes a plain matmul (kMatmul) with
  // key 0; every member has a well-defined value.
  MatmulPlanner() = default;
  MatmulPlanner(const std::vector<int64_t>& x_dims,
                const std::vector<int64_t>& y_dims,
                const bool trans_x,
                const bool trans_y,
                phi::DataType dtype,
                MatmulFusedType impl_type,
                const void* bias_data = nullptr,
                void* reserve_data = nullptr)
      : bias(bias_data),
        aux_data(reserve_data),
        type(impl_type),
        key(phi::autotune::GenKey(x_dims,
                                  y_dims,
                                  static_cast<int64_t>(trans_x),
                                  static_cast<int64_t>(trans_y),
                                  static_cast<int64_t>(dtype))) {}

  MatmulFusedType ImplType() const { return type; }
  size_t GetKey() const { return key; }
  // Derives a secondary cache key from the planner key and `idx`.
  size_t GenSubKey(int idx) const { return phi::autotune::GenKey(key, idx); }

 private:
  // Default member initializers keep ImplType()/GetKey() well-defined even
  // for a default-constructed planner.
  MatmulFusedType type{kMatmul};
  size_t key{0};
};

// Picks the cuBLAS compute type for element type T: CUBLAS_COMPUTE_64F for
// double, CUBLAS_COMPUTE_32F for every other type.
template <typename T>
cublasComputeType_t GetCudaComputeType() {
  constexpr bool kIsDouble = std::is_same<T, double>::value;
  return kIsDouble ? CUBLAS_COMPUTE_64F : CUBLAS_COMPUTE_32F;
}

// Holds the cuBLASLt handles describing one matmul: the operation descriptor,
// the x / y / out matrix layouts and, once autotuning has chosen one, the
// algorithm. The descriptor/layout handles are created by cuBLASLt in Create().
//
// Ownership: while `is_cached` is false the destructor releases every handle
// it holds. SetAlgo() sets `is_cached` to true, after which the handles are
// treated as shared and are never released by this object or its copies.
struct MatmulDescriptor {
 public:
  cublasLtMatmulDesc_t op_desc{nullptr};
  cublasLtMatrixLayout_t x_desc{nullptr};
  cublasLtMatrixLayout_t y_desc{nullptr};
  cublasLtMatrixLayout_t out_desc{nullptr};
  cublasLtMatmulAlgo_t* algo{nullptr};
  bool is_cached{false};

  MatmulDescriptor() {}

  // Copies are shallow: they alias the same handles and inherit `is_cached`.
  // Copying a descriptor that is not cached yields two owners of the same
  // handles, so only cached descriptors should be copied and kept.
  MatmulDescriptor(const MatmulDescriptor& obj) = default;
  MatmulDescriptor& operator=(const MatmulDescriptor& obj) = default;

  ~MatmulDescriptor() PADDLE_MAY_THROW {
    if (is_cached) {
      return;
    }
    // Handles may still be null if Create() was never called or stopped
    // part-way, so only destroy the ones that were actually created.
    if (op_desc != nullptr) {
      PADDLE_WARN_GPU_SUCCESS(dynload::cublasLtMatmulDescDestroy(op_desc));
    }
    if (y_desc != nullptr) {
      PADDLE_WARN_GPU_SUCCESS(dynload::cublasLtMatrixLayoutDestroy(y_desc));
    }
    if (x_desc != nullptr) {
      PADDLE_WARN_GPU_SUCCESS(dynload::cublasLtMatrixLayoutDestroy(x_desc));
    }
    if (out_desc != nullptr) {
      PADDLE_WARN_GPU_SUCCESS(dynload::cublasLtMatrixLayoutDestroy(out_desc));
    }
    delete algo;

    op_desc = nullptr;
    x_desc = nullptr;
    y_desc = nullptr;
    out_desc = nullptr;
    algo = nullptr;
  }

  // Creates the operation descriptor and the x / y / out layouts for
  // out[M, N] = op(x)[M, K] * op(y)[K, N] on row-major data, optionally as a
  // strided batch of `batch_size` matrices, then applies the planner's fused
  // epilogue. `planner` may be nullptr for a plain matmul.
  template <typename T>
  void Create(const int M,
              const int N,
              const int K,
              const bool trans_x,
              const bool trans_y,
              phi::funcs::MatmulPlanner* planner,
              const int batch_size = 1,
              int64_t stride_x = 0,
              int64_t stride_y = 0,
              int64_t stride_out = 0) {
    using MT = typename phi::dtype::MPTypeTrait<T>::Type;

    cudaDataType_t mat_type = phi::backends::gpu::ToCudaDataType<T>();
    cudaDataType_t scale_type = phi::backends::gpu::ToCudaDataType<MT>();
    cublasComputeType_t compute_type = GetCudaComputeType<T>();

    // Create the operation descriptor; see cublasLtMatmulDescAttributes_t for
    // the defaults. Only the transpose flags need to be set here.
    PADDLE_ENFORCE_GPU_SUCCESS(
        dynload::cublasLtMatmulDescCreate(&op_desc, compute_type, scale_type));
    // cuBLASLt is column-major, so the row-major product is computed as
    // out^T = op(y)^T * op(x)^T: y plays matrix A and x plays matrix B.
    cublasOperation_t cublas_trans_x = trans_x ? CUBLAS_OP_T : CUBLAS_OP_N;
    cublasOperation_t cublas_trans_y = trans_y ? CUBLAS_OP_T : CUBLAS_OP_N;
    PADDLE_ENFORCE_GPU_SUCCESS(
        dynload::cublasLtMatmulDescSetAttribute(op_desc,
                                                CUBLASLT_MATMUL_DESC_TRANSB,
                                                &cublas_trans_x,
                                                sizeof(cublas_trans_x)));
    PADDLE_ENFORCE_GPU_SUCCESS(
        dynload::cublasLtMatmulDescSetAttribute(op_desc,
                                                CUBLASLT_MATMUL_DESC_TRANSA,
                                                &cublas_trans_y,
                                                sizeof(cublas_trans_y)));

    // Create matrix descriptors.
    CreateMatrixLayout(&x_desc, mat_type, M, K, trans_x);
    CreateMatrixLayout(&y_desc, mat_type, K, N, trans_y);
    CreateMatrixLayout(&out_desc, mat_type, M, N, false);

    // Configure batch count and stride.
    if (batch_size > 1) {
      SetBatchAndStride(x_desc, batch_size, stride_x);
      SetBatchAndStride(y_desc, batch_size, stride_y);
      SetBatchAndStride(out_desc, batch_size, stride_out);
    }

    SetFusedEpilogueOpDescriptor(planner, N);
  }

  // Allocates the algorithm slot that autotuning fills in and marks the
  // descriptor as cached. From then on the destructor no longer releases the
  // handles.
  cublasLtMatmulAlgo_t* SetAlgo() {
    is_cached = true;
    algo = new cublasLtMatmulAlgo_t;
    return algo;
  }

  // Points the epilogue at this call's bias (and, if present, aux) buffer.
  // These pointers change from call to call, so they are set separately from
  // Create(). Does nothing when there is no planner or no bias.
  template <typename T>
  void SetFusedEpiloguePtr(phi::funcs::MatmulPlanner* planner) {
    if (planner == nullptr || planner->bias == nullptr) {
      return;
    }
    const T* bias_data = static_cast<const T*>(planner->bias);
    PADDLE_ENFORCE_GPU_SUCCESS(dynload::cublasLtMatmulDescSetAttribute(
        op_desc,
        CUBLASLT_MATMUL_DESC_BIAS_POINTER,
        &bias_data,
        sizeof(bias_data)));

    if (planner->aux_data != nullptr) {
      PADDLE_ENFORCE_GPU_SUCCESS(dynload::cublasLtMatmulDescSetAttribute(
          op_desc,
          CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER,
          &(planner->aux_data),
          sizeof(planner->aux_data)));
    }
  }

  // Formats the raw contents of the descriptors for VLOG debugging. `algo` is
  // included only when `has_algo` is true. If no algorithm has been chosen
  // yet, it is printed as an empty list rather than dereferenced.
  std::string GetDescResultString(std::string prefix,
                                  bool has_algo = true) const {
    std::ostringstream out;
    out << prefix << " \n";
    if (has_algo) {
      AppendDescData(&out, "algo", algo);
    }
    AppendDescData(&out, "x_desc", x_desc);
    AppendDescData(&out, "y_desc", y_desc);
    AppendDescData(&out, "out_desc", out_desc);
    AppendDescData(&out, "op_desc", op_desc);
    return out.str();
  }

 private:
  // Appends "name= [w0, w1, ..., ]\n" listing the `data` words of an opaque
  // cuBLASLt struct; a null pointer yields an empty list.
  template <typename OpaqueT>
  static void AppendDescData(std::ostringstream* out,
                             const char* name,
                             const OpaqueT* src) {
    *out << name << "= [";
    if (src != nullptr) {
      constexpr size_t kNumWords = sizeof(OpaqueT) / sizeof(src->data[0]);
      for (size_t i = 0; i < kNumWords; ++i) {
        *out << src->data[i] << ", ";
      }
    }
    *out << "]\n";
  }

  // Describes a row-major rows x cols matrix (or its transpose when `trans`)
  // as the equivalent column-major layout cuBLASLt expects.
  void CreateMatrixLayout(cublasLtMatrixLayout_t* desc,
                          cudaDataType type,
                          uint64_t rows,
                          uint64_t cols,
                          bool trans) {
    if (trans) {
      PADDLE_ENFORCE_GPU_SUCCESS(
          dynload::cublasLtMatrixLayoutCreate(desc, type, rows, cols, rows));
    } else {
      PADDLE_ENFORCE_GPU_SUCCESS(
          dynload::cublasLtMatrixLayoutCreate(desc, type, cols, rows, cols));
    }
  }

  // Turns `desc` into a strided batch of `batch_size` matrices spaced
  // `stride` elements apart.
  void SetBatchAndStride(cublasLtMatrixLayout_t desc,
                         int batch_size,
                         int64_t stride) {
    PADDLE_ENFORCE_GPU_SUCCESS(dynload::cublasLtMatrixLayoutSetAttribute(
        desc,
        CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT,
        &batch_size,
        sizeof(batch_size)));
    PADDLE_ENFORCE_GPU_SUCCESS(dynload::cublasLtMatrixLayoutSetAttribute(
        desc,
        CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET,
        &stride,
        sizeof(stride)));
  }

  // Sets CUBLASLT_MATMUL_DESC_EPILOGUE from the planner's fused type.
  // - Without a planner, cuBLASLt's default epilogue is kept.
  // - Epilogues that need a bias are enabled only when a bias is supplied.
  // - kMatmul and kMatmulRelu need no bias and are always applied.
  // The aux leading dimension is set when both bias and aux buffer are present.
  void SetFusedEpilogueOpDescriptor(phi::funcs::MatmulPlanner* planner,
                                    int64_t lead_dim) {
    if (planner == nullptr) {
      return;
    }
    const MatmulFusedType impl_type = planner->ImplType();
    const bool needs_bias = impl_type != kMatmul && impl_type != kMatmulRelu;
    if (needs_bias && planner->bias == nullptr) {
      return;
    }

    auto fuse_type = static_cast<cublasLtEpilogue_t>(impl_type);
    PADDLE_ENFORCE_GPU_SUCCESS(
        dynload::cublasLtMatmulDescSetAttribute(op_desc,
                                                CUBLASLT_MATMUL_DESC_EPILOGUE,
                                                &fuse_type,
                                                sizeof(fuse_type)));
    if (planner->bias != nullptr && planner->aux_data != nullptr) {
      PADDLE_ENFORCE_GPU_SUCCESS(dynload::cublasLtMatmulDescSetAttribute(
          op_desc,
          CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD,
          &lead_dim,
          sizeof(lead_dim)));
    }
  }
};
// Prepares the MatmulDescriptor for one matmul call.
// - With a planner whose sub key is already in the autotune matmul cache,
//   `desc` becomes a shallow copy of the cached entry. The shared op_desc is
//   then updated in place with this call's bias/aux pointers.
// - Otherwise fresh descriptors are created.
// `sub_key` identifies the cache slot for (planner key, fused type). It stays
// 0 when no planner is given.
template <typename T>
struct DescriptorSetter {
  MatmulDescriptor desc;
  size_t sub_key{std::numeric_limits<size_t>::min()};

  DescriptorSetter(phi::funcs::MatmulPlanner* planner,
                   const int M,
                   const int N,
                   const int K,
                   const bool trans_x,
                   const bool trans_y,
                   const int batch_size = 1,
                   int64_t stride_x = 0,
                   int64_t stride_y = 0,
                   int64_t stride_out = 0) {
    if (planner != nullptr) {
      sub_key = planner->GenSubKey(static_cast<size_t>(planner->ImplType()));
    }

    // Cache entries are only ever stored for planner-driven calls, so a call
    // without a planner must not probe the cache with the placeholder key.
    auto& matmul_cache = phi::autotune::AutoTuneCache::Instance().GetMatmul();
    const bool use_cached =
        planner != nullptr && matmul_cache.FindSubKey(sub_key);

    if (use_cached) {
      desc = *(
          reinterpret_cast<MatmulDescriptor*>(matmul_cache.GetSubKey(sub_key)));
      desc.SetFusedEpiloguePtr<T>(planner);
      VLOG(6) << desc.GetDescResultString("[Heap MatmulDescriptor] ");
    } else {
      desc.Create<T>(M,
                     N,
                     K,
                     trans_x,
                     trans_y,
                     planner,
                     batch_size,
                     stride_x,
                     stride_y,
                     stride_out);
      if (planner != nullptr) {
        desc.SetFusedEpiloguePtr<T>(planner);
      }
      VLOG(6) << desc.GetDescResultString("[Stack MatmulDescriptor] ", false);
    }
  }
};

// Runs row-major GEMMs out[M, N] = op(x)[M, K] * op(y)[K, N] through cuBLASLt.
// Because cuBLASLt is column-major, every cublasLtMatmul call passes y as
// matrix A and x as matrix B (computing out^T = op(y)^T * op(x)^T).
template <typename T>
struct MatmulWithCublasLt {
 public:
  using MT = typename phi::dtype::MPTypeTrait<T>::Type;

  // Single matmul. `planner` (optional) selects a fused epilogue and enables
  // autotuning/caching of the chosen algorithm.
  static void Run(const phi::GPUContext& ctx,
                  const T* x_data,
                  const T* y_data,
                  T* out_data,
                  const int M,
                  const int N,
                  const int K,
                  const bool trans_x,
                  const bool trans_y,
                  phi::funcs::MatmulPlanner* planner = nullptr) {
    // Direct-initialize: MatmulDescriptor copies are shallow, so the setter
    // must never be copied while it still owns its handles.
    DescriptorSetter<T> setter(planner, M, N, K, trans_x, trans_y);
    RunImpl(
        ctx, &setter.desc, setter.sub_key, x_data, y_data, out_data, planner);
  }

  // Strided-batched matmul: `batch_size` problems whose x / y / out matrices
  // are spaced stride_x / stride_y / stride_out elements apart.
  static void RunWithBatch(const phi::GPUContext& ctx,
                           const T* x_data,
                           const T* y_data,
                           T* out_data,
                           const int M,
                           const int N,
                           const int K,
                           bool trans_x,
                           bool trans_y,
                           int batch_size,
                           int64_t stride_x,
                           int64_t stride_y,
                           int64_t stride_out,
                           phi::funcs::MatmulPlanner* planner = nullptr) {
    DescriptorSetter<T> setter(planner,
                               M,
                               N,
                               K,
                               trans_x,
                               trans_y,
                               batch_size,
                               stride_x,
                               stride_y,
                               stride_out);
    RunImpl(
        ctx, &setter.desc, setter.sub_key, x_data, y_data, out_data, planner);
  }

  // Pointer-array batch: issues one Run() per (x_data[i], y_data[i],
  // out_data[i]) triple.
  static void RunWithBatch(const phi::GPUContext& ctx,
                           const T** x_data,
                           const T** y_data,
                           T** out_data,
                           const int M,
                           const int N,
                           const int K,
                           bool trans_x,
                           bool trans_y,
                           int batch_size,
                           phi::funcs::MatmulPlanner* planner = nullptr) {
    for (int i = 0; i < batch_size; ++i) {
      Run(ctx,
          x_data[i],
          y_data[i],
          out_data[i],
          M,
          N,
          K,
          trans_x,
          trans_y,
          planner);
    }
  }

 private:
  // Allocates a cuBLASLt workspace on the context's place and stream.
  static phi::Allocator::AllocationPtr GetWorkspace(const phi::GPUContext& ctx,
                                                    size_t workspace_size) {
    return phi::memory_utils::Alloc(
        ctx.GetPlace(),
        workspace_size,
        phi::Stream(reinterpret_cast<phi::StreamId>(ctx.stream())));
  }

  // Executes the matmul described by `desc` with alpha = 1 and beta = 0, so
  // `out_ptr` is overwritten. The fastest algorithm is searched first when
  // all of the following hold:
  // - a planner is given,
  // - autotuning is enabled,
  // - `desc` is not cached yet.
  // A copy of the tuned descriptor is then stored in the matmul cache under
  // `sub_key`.
  static void RunImpl(const phi::GPUContext& ctx,
                      MatmulDescriptor* desc,
                      const size_t sub_key,
                      const T* x_ptr,
                      const T* y_ptr,
                      T* out_ptr,
                      phi::funcs::MatmulPlanner* planner) {
    MT alpha = static_cast<MT>(1);
    MT beta = static_cast<MT>(0);

    cublasLtHandle_t cublaslt_handle = ctx.cublaslt_handle();
    size_t workspace_size = static_cast<size_t>(4) * 1024 * 1024;
    phi::Allocator::AllocationPtr workspace = GetWorkspace(ctx, workspace_size);

    const bool need_search =
        planner != nullptr &&
        phi::autotune::AutoTuneStatus::Instance().UseAutoTune() &&
        !desc->is_cached;
    if (need_search) {
      SearchBestAlgo(ctx,
                     cublaslt_handle,
                     desc,
                     static_cast<void*>(&alpha),
                     static_cast<void*>(&beta),
                     y_ptr,
                     x_ptr,
                     out_ptr,
                     workspace->ptr(),
                     workspace_size);
      // SearchBestAlgo() marked `desc` as cached, so this shallow copy can
      // share its handles for the lifetime of the cache entry.
      MatmulDescriptor* best_desc = new MatmulDescriptor(*desc);
      VLOG(6) << best_desc->GetDescResultString(
          "[Searched MatmulDescriptor] ");

      auto& cache = phi::autotune::AutoTuneCache::Instance().GetMatmul();
      cache.SetSubKey(sub_key, reinterpret_cast<void*>(best_desc));
    }

    VLOG(6) << desc->GetDescResultString("[Impl MatmulDescriptor] ");
    PADDLE_ENFORCE_GPU_SUCCESS(
        dynload::cublasLtMatmul(cublaslt_handle,
                                desc->op_desc,
                                static_cast<void*>(&alpha),
                                y_ptr,
                                desc->y_desc,
                                x_ptr,
                                desc->x_desc,
                                static_cast<void*>(&beta),
                                out_ptr,
                                desc->out_desc,
                                out_ptr,
                                desc->out_desc,
                                desc->algo,
                                workspace->ptr(),
                                workspace_size,
                                ctx.stream()));
  }

  // Times every usable algorithm proposed by the cuBLASLt heuristic on the
  // real operands and stores the fastest one in `desc`, which marks `desc` as
  // cached. For each algorithm, the first of `repeats` launches is a warm-up
  // and is excluded from the average.
  static void SearchBestAlgo(const phi::GPUContext& ctx,
                             const cublasLtHandle_t& lt_handle,
                             MatmulDescriptor* desc,
                             const void* alpha,
                             const void* beta,
                             const void* y_data,
                             const void* x_data,
                             void* out_data,
                             void* workspace_ptr,
                             size_t workspace_size) {
    const auto& stream = ctx.stream();
    int returned_results = 0;
    constexpr int requested_algo_count = 10;
    std::vector<cublasLtMatmulHeuristicResult_t> heuristic_results(
        requested_algo_count);

    cublasLtMatmulPreference_t preference = nullptr;
    PADDLE_ENFORCE_GPU_SUCCESS(
        dynload::cublasLtMatmulPreferenceCreate(&preference));
    // The preference is only needed for the heuristic query. Record the status
    // of both calls and release the preference before enforcing, so a failure
    // cannot leak the handle.
    cublasStatus_t status = dynload::cublasLtMatmulPreferenceSetAttribute(
        preference,
        CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
        &workspace_size,
        sizeof(workspace_size));
    if (status == CUBLAS_STATUS_SUCCESS) {
      status = dynload::cublasLtMatmulAlgoGetHeuristic(lt_handle,
                                                       desc->op_desc,
                                                       desc->y_desc,
                                                       desc->x_desc,
                                                       desc->out_desc,
                                                       desc->out_desc,
                                                       preference,
                                                       requested_algo_count,
                                                       heuristic_results.data(),
                                                       &returned_results);
    }
    PADDLE_ENFORCE_GPU_SUCCESS(
        dynload::cublasLtMatmulPreferenceDestroy(preference));
    PADDLE_ENFORCE_GPU_SUCCESS(status);
    PADDLE_ENFORCE_GT(returned_results,
                      0,
                      phi::errors::Unavailable("No GEMM algorithm avaliable."));

    phi::GpuTimer timer;
    int best_algo_idx = -1;
    constexpr int repeats = 6;
    float min_time_cost = std::numeric_limits<float>::max();
    for (int algo_idx = 0; algo_idx < returned_results; ++algo_idx) {
      // The other fields of a heuristic result are only valid when its state
      // is CUBLAS_STATUS_SUCCESS.
      if (heuristic_results[algo_idx].state != CUBLAS_STATUS_SUCCESS) {
        continue;
      }
      ctx.Wait();
      float cur_time = 0.f;
      for (int i = 0; i < repeats; ++i) {
        timer.Start(stream);
        PADDLE_ENFORCE_GPU_SUCCESS(
            dynload::cublasLtMatmul(lt_handle,
                                    desc->op_desc,
                                    alpha,
                                    y_data,
                                    desc->y_desc,
                                    x_data,
                                    desc->x_desc,
                                    beta,
                                    out_data,
                                    desc->out_desc,
                                    out_data,
                                    desc->out_desc,
                                    &(heuristic_results[algo_idx].algo),
                                    workspace_ptr,
                                    workspace_size,
                                    stream));
        timer.Stop(stream);
        auto time = timer.ElapsedTime();
        if (i > 0) {
          cur_time += time;
        }
      }
      float time_cnt = (cur_time / (repeats - 1));
      // GpuTimer::ElapsedTime() reports milliseconds.
      VLOG(4) << "Time cost in MatmulWithCublaslt algo[" << algo_idx
              << "] is : " << time_cnt << " ms";

      if (cur_time < min_time_cost) {
        best_algo_idx = algo_idx;
        min_time_cost = cur_time;
      }
    }
    PADDLE_ENFORCE_GE(
        best_algo_idx,
        0,
        phi::errors::Unavailable(
            "No usable GEMM algorithm was returned by the cuBLASLt "
            "heuristic."));
    VLOG(4) << "Best_algo_idx in MatmulWithCublaslt is : " << best_algo_idx;
    // Mark the descriptor as cached only after the search succeeded: if
    // anything above throws, `desc` still owns and releases its handles.
    *(desc->SetAlgo()) = heuristic_results[best_algo_idx].algo;
  }
};
#else
// Placeholder so that code naming MatmulPlanner still compiles when the
// cuBLASLt path (PADDLE_WITH_CUDA and CUDA >= 11.6) is unavailable.
struct MatmulPlanner {
  // Intentionally empty.
};
#endif  // (PADDLE_WITH_CUDA) && CUDA_VERSION >= 11060

}  // namespace funcs
}  // namespace phi