/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Copyright (c) 2022 NVIDIA Corporation. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
15

16
#include "paddle/phi/backends/gpu/gpu_context.h"
17

W
Wilber 已提交
18
#include <algorithm>
W
Wilber 已提交
19 20 21 22 23 24
#include <array>
#include <functional>
#include <future>
#include <memory>
#include <mutex>

25 26
#include "glog/logging.h"
#include "paddle/phi/api/ext/exception.h"
27 28
#include "paddle/phi/backends/gpu/gpu_decls.h"
#include "paddle/phi/backends/gpu/gpu_info.h"
W
Wilber 已提交
29
#include "paddle/phi/backends/gpu/gpu_resources.h"
30 31 32
#include "paddle/phi/common/float16.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/allocator.h"
L
Leo Chen 已提交
33
#include "paddle/phi/core/cuda_stream.h"
W
Wilber 已提交
34 35

#ifdef PADDLE_WITH_CUDA
36 37 38 39
#include "paddle/phi/backends/dynload/cublas.h"
#include "paddle/phi/backends/dynload/cudnn.h"
#include "paddle/phi/backends/dynload/cusolver.h"
#include "paddle/phi/backends/dynload/cusparse.h"
W
Wilber 已提交
40
#if !defined(__APPLE__) && defined(PADDLE_WITH_NCCL)
41
#include "paddle/phi/backends/dynload/nccl.h"
W
Wilber 已提交
42 43 44 45
#endif  // !defined(__APPLE__) && defined(PADDLE_WITH_NCCL)
#endif  // PADDLE_WITH_CUDA

#ifdef PADDLE_WITH_HIP
46 47
#include "paddle/phi/backends/dynload/miopen.h"
#include "paddle/phi/backends/dynload/rocblas.h"
W
Wilber 已提交
48
#if !defined(__APPLE__) && defined(PADDLE_WITH_RCCL)
49
#include "paddle/phi/backends/dynload/rccl.h"
W
Wilber 已提交
50 51 52 53 54 55 56
#endif  // !defined(__APPLE__) && defined(PADDLE_WITH_RCCL)
#endif  // PADDLE_WITH_HIP

// NOTE: The paddle framework should add WITH_EIGEN option to support compile
// without eigen.
#include "unsupported/Eigen/CXX11/Tensor"

57
// TODO(phi): remove fluid header.
W
Wilber 已提交
58 59
#include "paddle/fluid/platform/enforce.h"

60
namespace phi {
W
Wilber 已提交
61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120

namespace internal {

class EigenGpuStreamDevice : public Eigen::StreamInterface {
 public:
  EigenGpuStreamDevice() : scratch_(nullptr), semaphore_(nullptr) {
    Eigen::initializeDeviceProp();
  }
  ~EigenGpuStreamDevice() override {}

  void Reinitialize(gpuStream_t cuda_stream,
                    Allocator* allocator,
                    GPUPlace place) {
    stream_ = cuda_stream;
    place_ = place;
    allocator_ = allocator;
    device_prop_ = &Eigen::m_deviceProperties[place.device];
  }

  const gpuStream_t& stream() const override { return stream_; }

  const gpuDeviceProp& deviceProperties() const override {
    return *device_prop_;
  }

  void* allocate(size_t num_bytes) const override {
    if (UNLIKELY(num_bytes == 0)) {
      return nullptr;
    }
    auto buf = allocator_->Allocate(num_bytes);
    VLOG(4) << "Eigen allocated at " << buf->ptr() << " requested "
            << num_bytes;
    void* retv = buf->ptr();
    {
      std::lock_guard<std::mutex> lock(mtx_);
      allocations_.emplace(retv, std::move(buf));
    }
    return retv;
  }

  void deallocate(void* buffer) const override {
    if (LIKELY(buffer)) {
      std::lock_guard<std::mutex> lock(mtx_);
      allocations_.erase(buffer);
    }
  }

  void* scratchpad() const override {
    if (scratch_ == NULL) {
      scratch_ = allocate(Eigen::kGpuScratchSize + sizeof(unsigned int));
    }
    return scratch_;
  }

  unsigned int* semaphore() const override {
    if (semaphore_ == NULL) {
      char* scratch = static_cast<char*>(scratchpad()) + Eigen::kGpuScratchSize;
      semaphore_ = reinterpret_cast<unsigned int*>(scratch);
#ifdef PADDLE_WITH_HIP
      PADDLE_ENFORCE_GPU_SUCCESS(
L
Leo Chen 已提交
121
          hipMemsetAsync(semaphore_, 0, sizeof(unsigned int), stream()));
W
Wilber 已提交
122 123
#else
      PADDLE_ENFORCE_GPU_SUCCESS(
L
Leo Chen 已提交
124
          cudaMemsetAsync(semaphore_, 0, sizeof(unsigned int), stream()));
W
Wilber 已提交
125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161
#endif
    }
    return semaphore_;
  }

 private:
  GPUPlace place_;
  gpuStream_t stream_;                // not owned;
  Allocator* allocator_;              // not owned;
  const gpuDeviceProp* device_prop_;  // not owned;
  mutable void* scratch_;
  mutable unsigned int* semaphore_;
  mutable std::mutex mtx_;  // to protect allocations_
  mutable std::unordered_map<void*, Allocator::AllocationPtr> allocations_;
};

#ifdef PADDLE_WITH_HIP
static void StreamCallbackFunc(gpuStream_t stream,
                               gpuError_t status,
                               void* user_data)
#endif
#ifdef PADDLE_WITH_CUDA
#if CUDA_VERSION >= 10000
    static void CUDART_CB StreamCallbackFunc(void* user_data)
#else
    static void CUDART_CB
    StreamCallbackFunc(cudaStream_t stream, cudaError_t status, void* user_data)
#endif
#endif
{
  std::unique_ptr<std::function<void()>> func(
      reinterpret_cast<std::function<void()>*>(user_data));
  (*func)();
}

}  // namespace internal

162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194
void DnnWorkspaceHandle::RunFuncSync(
    const std::function<void(void*)>& cudnn_func,
    size_t required_workspace_bytes,
    bool use_cached_allocation) {
  bool need_realloc = required_workspace_bytes > WorkspaceSize();
  if (need_realloc && !use_cached_allocation) {
    void* workspace_ptr = nullptr;
    size_t size = ((required_workspace_bytes + 255) >> 8) << 8;
    std::lock_guard<std::mutex> guard(*mtx_);
#ifdef PADDLE_WITH_HIP
    auto status = hipMalloc(&workspace_ptr, size);
#else
    auto status = cudaMalloc(&workspace_ptr, size);
#endif
    if (status == gpuSuccess) {
      cudnn_func(workspace_ptr);
      phi::backends::gpu::GpuStreamSync(stream_);
#ifdef PADDLE_WITH_HIP
      PADDLE_ENFORCE_GPU_SUCCESS(hipFree(workspace_ptr));
#else
      PADDLE_ENFORCE_GPU_SUCCESS(cudaFree(workspace_ptr));
#endif
      return;
    }
  }

  RunFunc(cudnn_func, required_workspace_bytes);
  if (need_realloc) {
    // Release the workspace allocated in this running.
    ResetWorkspace();
  }
}

W
Wilber 已提交
195
// Drops the cached workspace allocation, if any.
void DnnWorkspaceHandle::ResetWorkspace() {
  allocation_.reset();
}
W
Wilber 已提交
196

W
Wilber 已提交
197 198 199 200 201 202
// Grows the cached workspace to `required_workspace_bytes`. A workspace that
// is already large enough is kept unchanged.
void DnnWorkspaceHandle::ReallocWorkspace(size_t required_workspace_bytes) {
  const bool too_small = required_workspace_bytes > WorkspaceSize();
  if (too_small) {
    // Free the old buffer before requesting the new one so the two never
    // coexist, which keeps peak memory lower.
    allocation_.reset();
    allocation_ = allocator_->Allocate(required_workspace_bytes);
  }
}
W
Wilber 已提交
203 204 205 206 207

struct GPUContext::Impl {
  // Fully initializes an owning context for place_: queries the device
  // properties, creates a new CUDAStream, then builds the Eigen device and the
  // DNN workspace (both of which require allocator_ to be set already).
  void Init() {
    owned_ = true;
    // The stream created below belongs to this context. The destructor only
    // deletes stream_ when stream_owned_ is true, so without this flag the
    // stream would leak.
    stream_owned_ = true;
    backends::gpu::GPUDeviceGuard guard(place_.device);
    InitDeviceProperties();
    stream_ = new CUDAStream(place_);
    InitEigenDevice();
    InitDnnWorkspace();
  }

  // First initialization phase, usable before an allocator is available:
  // queries device properties and creates an owned CUDAStream.
  void PartialInitWithoutAllocator() {
    owned_ = true;
    stream_owned_ = true;
    backends::gpu::GPUDeviceGuard guard(place_.device);
    InitDeviceProperties();
    stream_ = new CUDAStream(place_);
  }

  // Second initialization phase, run once allocator_ has been set: creates the
  // internal DNN workspace bound to the current stream.
  void PartialInitWithAllocator() {
    owned_ = true;
    stream_owned_ = true;
    backends::gpu::GPUDeviceGuard guard(place_.device);
    InitDnnWorkspace();
  }

  explicit Impl(const GPUPlace& place) : place_(place) {}

  // Destroys the workspace, Eigen device and library handles when owned_ is
  // set, and deletes the stream when stream_owned_ is set.
  ~Impl() {
    backends::gpu::GPUDeviceGuard guard(place_.device);
    if (owned_) {
      DestoryInternalWorkspace();
      DestoryInternalEigenDevice();
      phi::DestroySparseHandle(sparse_handle_);
      phi::DestroySolverHandle(solver_handle_);
      phi::DestroyDnnHandle(dnn_handle_);
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
      if (nccl_comm_) {
        // NOTE(liyurui): It is not recommend calling CUDA runtime API
        // in destructor. Since we can not ensure the release order of
        // static object, calling ncclCommDestroy in static object destructor
        // is a undefined behavior, CUDA driver may be already unloaded
        // from process.
        // If you really need to release the resource of nccl_comm,
        // try to get the nccl_comm out and use ncclCommDestroy outside.
      }
#endif
      phi::DestroyBlasHandle(blas_handle_);
      phi::DestroyBlasHandle(blas_tensor_core_handle_);
      phi::DestroyBlasHandle(blas_tf32_tensor_core_handle_);
      phi::DestroyBlasLtHandle(blaslt_handle_);
    }
    if (stream_owned_ && stream_) {
      delete stream_;
    }
  }

  const Place& GetPlace() const { return place_; }

  // True once a tensor-core BLAS handle exists. BLAS handles are created
  // lazily, so this reports false until something has populated the handle.
  bool IsTensorCoreAvailable() const {
    return blas_tensor_core_handle_ != nullptr;
  }

  void InitDnnWorkspace() {
    PD_CHECK(allocator_ != nullptr,
             "the device allocator for gpu context is nullptr.");
    workspace_ = new DnnWorkspaceHandle(allocator_, stream());
  }

  void DestoryInternalWorkspace() {
    if (owned_ && workspace_ != nullptr) {
      delete workspace_;
      workspace_ = nullptr;
    }
  }

  // TODO(wilber): The return type is a pointer, to be modified later.
  // DnnWorkspaceHandle* GetDnnWorkspace() {
  //   PD_CHECK(workspace_ != nullptr, "the gpu cudnn workspace is nullptr.");
  //   return workspace_;
  // }
  DnnWorkspaceHandle GetDnnWorkspace() {
    PD_CHECK(allocator_ != nullptr,
             "the device allocator for gpu context is nullptr.");
    return DnnWorkspaceHandle(allocator_, stream());
  }

  // Points the context at a raw stream. If no CUDAStream wrapper exists yet,
  // an owned one is created first.
  void SetStream(gpuStream_t stream) {
    if (stream_ == nullptr) {
      auto s = Stream(reinterpret_cast<StreamId>(stream));
      stream_ = new CUDAStream(place_, s);
      stream_owned_ = true;
    }
    stream_->set_raw_stream(stream);
  }

  // Replaces the stream with an externally owned one. With `clear`, a
  // previously owned stream is deleted first.
  void SetCUDAStream(CUDAStream* stream, bool clear = true) {
    if (clear && stream_owned_ && stream_) {
      delete stream_;
    }
    stream_owned_ = false;
    stream_ = stream;
    // TODO(phi): reset related handles?
  }

  // Returns the raw stream. Fails with a clear error (instead of
  // dereferencing null) when no stream has been set.
  gpuStream_t stream() const {
    PD_CHECK(stream_ != nullptr, "the gpu stream is nullptr.");
    auto s = stream_->raw_stream();
    PD_CHECK(s != nullptr, "the gpu stream is nullptr.");
    return s;
  }

  CUDAStream* cuda_stream() const {
    PD_CHECK(stream_ != nullptr, "the gpu stream is nullptr.");
    return stream_;
  }

  void InitEigenDevice() {
    PD_CHECK(allocator_ != nullptr,
             "the allocator for eigen device is nullptr.");
    eigen_stream_.reset(new internal::EigenGpuStreamDevice());
    eigen_stream_->Reinitialize(stream(), allocator_, place_);
    eigen_device_ = new Eigen::GpuDevice(eigen_stream_.get());
  }

  void DestoryInternalEigenDevice() {
    if (owned_ && eigen_device_ != nullptr) {
      delete eigen_device_;
      eigen_device_ = nullptr;
    }
  }

  void SetEigenDevice(Eigen::GpuDevice* device) { eigen_device_ = device; }

  void SetEigenDevice(std::function<Eigen::GpuDevice*()>&& creator) {
    eigen_device_creator_ = std::move(creator);
  }

  // Returns the Eigen device, creating it on first use (via the registered
  // creator if one was set, otherwise internally).
  Eigen::GpuDevice* eigen_device() {
    std::call_once(flag_eigen_device_, [&]() {
      if (!eigen_device_) {
        if (!eigen_device_creator_) {
          InitEigenDevice();
        } else {
          eigen_device_ = eigen_device_creator_();
        }
      }
    });
    PD_CHECK(eigen_device_ != nullptr, "the gpu eigen_device is nullptr.");
    return eigen_device_;
  }

  blasHandle_t GetBlasHandle() {
    InitBlasHandlesOnce();
    PD_CHECK(blas_handle_ != nullptr, "the gpu blas handle is nullptr.");
    return blas_handle_;
  }

  void SetBlasHandle(blasHandle_t blas) { blas_handle_ = blas; }

  void SetBlasHandle(std::function<blasHandle_t()>&& handle_creator) {
    blas_handle_creator_ = std::move(handle_creator);
  }

  void SetBlasTensorCoreHandle(blasHandle_t handle) {
    blas_tensor_core_handle_ = handle;
  }

  void SetBlasTensorCoreHandle(std::function<blasHandle_t()>&& handle_creator) {
    blas_tensor_core_handle_creator_ = std::move(handle_creator);
  }

  void SetBlasTF32Handle(blasHandle_t handle) {
    blas_tf32_tensor_core_handle_ = handle;
  }

  void SetBlasTF32Handle(std::function<blasHandle_t()>&& handle_creator) {
    blas_tf32_tensor_core_handle_creator_ = std::move(handle_creator);
  }

  void SetBlasLtHandle(blasLtHandle_t blaslt) { blaslt_handle_ = blaslt; }

  void SetBlasLtHandle(std::function<blasLtHandle_t()>&& handle_creator) {
    blaslt_handle_creator_ = std::move(handle_creator);
  }

  blasLtHandle_t GetBlasLtHandle() {
    std::call_once(flag_blaslt_, [&]() {
      if (!blaslt_handle_) {
        if (!blaslt_handle_creator_) {
          phi::InitBlasLtHandle(&blaslt_handle_);
        } else {
          blaslt_handle_ = blaslt_handle_creator_();
        }
      }
    });
    PD_CHECK(blaslt_handle_ != nullptr, "the gpu blasLt handle is nullptr.");
    return blaslt_handle_;
  }

  dnnHandle_t GetDnnHandle() {
    std::call_once(flag_dnn_, [&]() {
      if (!dnn_handle_) {
        if (!dnn_handle_creator_) {
          phi::InitDnnHandle(&dnn_handle_, stream(), place_);
        } else {
          dnn_handle_ = dnn_handle_creator_();
        }
      }
    });
    PD_CHECK(dnn_handle_ != nullptr, "the gpu dnn handle is nullptr.");
    return dnn_handle_;
  }

  void DestroyInternalDnnHandle() {
#ifdef PADDLE_WITH_HIP
    if (owned_ && dnn_handle_ != nullptr) {
      PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::miopenDestroy(dnn_handle_));
      dnn_handle_ = nullptr;
    }
#else
    if (owned_ && dnn_handle_ != nullptr) {
      PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cudnnDestroy(dnn_handle_));
      dnn_handle_ = nullptr;
    }
#endif  // PADDLE_WITH_HIP
  }

  void SetDnnHandle(dnnHandle_t handle) { dnn_handle_ = handle; }

  void SetDnnHandle(std::function<dnnHandle_t()>&& handle_creator) {
    dnn_handle_creator_ = std::move(handle_creator);
  }

  solverHandle_t GetSolverHandle() {
    std::call_once(flag_slover_, [&]() {
      if (!solver_handle_) {
        if (!solver_handle_creator_) {
          phi::InitSolverHandle(&solver_handle_, stream());
        } else {
          solver_handle_ = solver_handle_creator_();
        }
      }
    });
    PD_CHECK(solver_handle_ != nullptr, "the gpu solver handle is nullptr.");
    return solver_handle_;
  }

  void SetSolverHandle(solverHandle_t handle) { solver_handle_ = handle; }

  void SetSolverHandle(std::function<solverHandle_t()>&& handle_creator) {
    solver_handle_creator_ = std::move(handle_creator);
  }

  sparseHandle_t GetSparseHandle() {
    InitSparseHandleOnce();
    PD_CHECK(sparse_handle_ != nullptr, "the gpu sparse handle is nullptr.");
    return sparse_handle_;
  }

  void SetSparseHandle(sparseHandle_t handle) { sparse_handle_ = handle; }

  void SetSparseHandle(std::function<sparseHandle_t()>&& handle_creator) {
    sparse_handle_creator_ = std::move(handle_creator);
  }

  // Blocks until all work queued on the stream has finished. On Windows the
  // stream is polled instead of synchronized.
  void Wait() const {
#ifdef PADDLE_WITH_HIP
    hipError_t e_sync = hipSuccess;
#if !defined(_WIN32)
    e_sync = hipStreamSynchronize(stream());
#else
    while ((e_sync = hipStreamQuery(stream()))) {
      if (e_sync == hipErrorNotReady) continue;
      break;
    }
#endif  // !defined(_WIN32)
#else   // PADDLE_WITH_HIP
    cudaError_t e_sync = cudaSuccess;
#if !defined(_WIN32)
    e_sync = cudaStreamSynchronize(stream());
#else
    while ((e_sync = cudaStreamQuery(stream()))) {
      if (e_sync == cudaErrorNotReady) continue;
      break;
    }
#endif  // !defined(_WIN32)
#endif  // PADDLE_WITH_HIP

    PADDLE_ENFORCE_GPU_SUCCESS(e_sync);
  }

  void WaitEvent(gpuEvent_t ev) const {
#ifdef PADDLE_WITH_HIP
    PADDLE_ENFORCE_GPU_SUCCESS(hipStreamWaitEvent(stream(), ev, 0));
#else
    PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamWaitEvent(stream(), ev, 0));
#endif
  }

  ncclComm_t GetNcclComm() const {
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
    // PD_CHECK(nccl_comm_ != nullptr, "the gpu nccl_comm is nullptr.");
    return nccl_comm_;
#else
    return nullptr;
#endif
  }

  void SetNcclComm(ncclComm_t comm) {
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
    nccl_comm_ = comm;
#endif
  }

  // Runs `callback` with the TF32 tensor-core handle if it exists, otherwise
  // with the default BLAS handle, holding that handle's mutex.
  inline void CublasCall(const std::function<void(blasHandle_t)>& callback) {
    InitBlasHandlesOnce();
    if (blas_tf32_tensor_core_handle_ != nullptr) {
      std::lock_guard<std::mutex> guard(blas_tf32_mtx_);
      callback(blas_tf32_tensor_core_handle_);
    } else {
      std::lock_guard<std::mutex> guard(blas_mtx_);
      callback(blas_handle_);
    }
  }

  // Runs `callback` with the tensor-core handle if it exists, otherwise with
  // the default BLAS handle, holding that handle's mutex.
  inline void TensorCoreCublasCallIfAvailable(
      const std::function<void(blasHandle_t)>& callback) {
    InitBlasHandlesOnce();
    if (blas_tensor_core_handle_ != nullptr) {
      std::lock_guard<std::mutex> guard(blas_tensor_core_mtx_);
      callback(blas_tensor_core_handle_);
    } else {
      std::lock_guard<std::mutex> guard(blas_mtx_);
      callback(blas_handle_);
    }
  }

  inline void CusparseCall(
      const std::function<void(sparseHandle_t)>& callback) {
    InitSparseHandleOnce();
    std::lock_guard<std::mutex> guard(sparse_mtx_);
    callback(sparse_handle_);
  }

  void RecordEvent(gpuEvent_t ev, const std::function<void()>& callback) const {
    callback();
    RecordEvent(ev);
  }

  void RecordEvent(gpuEvent_t ev) const {
#ifdef PADDLE_WITH_HIP
    PADDLE_ENFORCE_GPU_SUCCESS(hipEventRecord(ev, stream()));
#else
    PADDLE_ENFORCE_GPU_SUCCESS(cudaEventRecord(ev, stream()));
#endif
  }

  // Enqueues a host callback on the stream. When the stream reaches it, the
  // callback is dispatched via std::async; the resulting future is stored in
  // last_future_ so WaitStreamCallback() can wait for it.
  void AddStreamCallback(const std::function<void()>& callback) const {
    // NOTE(zhiqiu): better use threadpool here, otherwise "std::async" may
    // launch too many threads and result in thread oversubscription.
    auto* callback_func = new std::function<void()>(callback);
    auto* func = new std::function<void()>([this, callback_func] {
      std::lock_guard<std::mutex> lock(stream_call_back_mtx_);
      VLOG(4) << "Stream callback";
      last_future_ = std::async(std::launch::async, [callback_func]() {
        std::unique_ptr<std::function<void()>> releaser(callback_func);
        (*callback_func)();
      });
    });

#ifdef PADDLE_WITH_HIP
    PADDLE_ENFORCE_GPU_SUCCESS(
        hipStreamAddCallback(stream(), internal::StreamCallbackFunc, func, 0));
#endif
#ifdef PADDLE_WITH_CUDA
#if CUDA_VERSION >= 10000
    PADDLE_ENFORCE_GPU_SUCCESS(
        cudaLaunchHostFunc(stream(), internal::StreamCallbackFunc, func));
#else
    PADDLE_ENFORCE_GPU_SUCCESS(
        cudaStreamAddCallback(stream(), internal::StreamCallbackFunc, func, 0));
#endif
#endif
  }

  void WaitStreamCallback() const {
#if defined(PADDLE_WITH_HIP) || defined(PADDLE_WITH_CUDA)
    phi::backends::gpu::GpuStreamSync(stream());
#endif
    {
      std::lock_guard<std::mutex> lock(stream_call_back_mtx_);
      if (last_future_.valid()) {
        last_future_.wait();
      }
    }
  }

  bool HasDnnAttr(const std::string& attr_name) const {
    return dnn_attrs_.count(attr_name) != 0UL;
  }

  const Attribute& GetDnnAttr(const std::string& attr_name) const {
    auto iter = dnn_attrs_.find(attr_name);
    // The message names this context and actually supplies the `%s` argument.
    PADDLE_ENFORCE_NE(
        iter,
        dnn_attrs_.end(),
        phi::errors::NotFound("Attribute `%s` is not found in GPUContext.",
                              attr_name));
    return iter->second;
  }

  void SetDnnAttr(const std::string& attr_name, Attribute attr) {
    dnn_attrs_[attr_name] = std::move(attr);
  }

  // Queries compute capability, versions and launch limits for place_.
  void InitDeviceProperties() {
    phi::InitGpuProperties(place_,
                           &compute_capability_,
                           &runtime_version_,
                           &driver_version_,
                           &multi_process_,
                           &max_threads_per_mp_,
                           &max_threads_per_block_,
                           &max_grid_dim_size_);
  }

  // Lazily creates the default, tensor-core and TF32 BLAS handles (preferring
  // registered creators). Every BLAS entry point funnels through this single
  // once_flag. Separate flags per entry point would let concurrent callers
  // initialize the same handle twice, racing on the handle member.
  void InitBlasHandlesOnce() {
    std::call_once(flag_blas_, [&]() {
      if (!blas_handle_) {
        if (!blas_handle_creator_) {
          phi::InitBlasHandle(&blas_handle_, stream());
        } else {
          blas_handle_ = blas_handle_creator_();
        }
      }
#ifdef PADDLE_WITH_CUDA
#if CUDA_VERSION >= 9000
      if (!blas_tensor_core_handle_) {
        if (!blas_tensor_core_handle_creator_) {
          phi::InitBlasHandle(&blas_tensor_core_handle_, stream());
        } else {
          blas_tensor_core_handle_ = blas_tensor_core_handle_creator_();
        }
        PADDLE_RETRY_CUDA_SUCCESS(phi::dynload::cublasSetMathMode(
            blas_tensor_core_handle_, CUBLAS_TENSOR_OP_MATH));
      }
#endif
#if CUDA_VERSION >= 11000
      if (!blas_tf32_tensor_core_handle_) {
        if (!blas_tf32_tensor_core_handle_creator_) {
          phi::InitBlasHandle(&blas_tf32_tensor_core_handle_, stream());
        } else {
          blas_tf32_tensor_core_handle_ =
              blas_tf32_tensor_core_handle_creator_();
        }
        PADDLE_RETRY_CUDA_SUCCESS(phi::dynload::cublasSetMathMode(
            blas_tf32_tensor_core_handle_, CUBLAS_TF32_TENSOR_OP_MATH));
      }
#endif
#endif
    });
  }

  // Lazily creates the sparse handle (preferring a registered creator),
  // shared by GetSparseHandle() and CusparseCall().
  void InitSparseHandleOnce() {
    std::call_once(flag_sparse_, [&]() {
      if (!sparse_handle_) {
        if (!sparse_handle_creator_) {
          phi::InitSparseHandle(&sparse_handle_, stream());
        } else {
          sparse_handle_ = sparse_handle_creator_();
        }
      }
    });
  }

  // use one flag for all handles?
  // they should be accessed consistently
  bool owned_{false};
  bool stream_owned_{false};
  Place place_;
  // Zero until InitDeviceProperties() runs, rather than indeterminate.
  int compute_capability_{0};
  int runtime_version_{0};
  int driver_version_{0};
  int multi_process_{0};
  int max_threads_per_mp_{0};
  int max_threads_per_block_{0};
  std::array<int, 3> max_grid_dim_size_{};

  CUDAStream* stream_{nullptr};
  Eigen::GpuDevice* eigen_device_{nullptr};
  std::function<Eigen::GpuDevice*()> eigen_device_creator_{nullptr};
  blasHandle_t blas_handle_{nullptr};
  std::function<blasHandle_t()> blas_handle_creator_{nullptr};
  blasHandle_t blas_tensor_core_handle_{nullptr};
  std::function<blasHandle_t()> blas_tensor_core_handle_creator_{nullptr};
  blasHandle_t blas_tf32_tensor_core_handle_{nullptr};
  std::function<blasHandle_t()> blas_tf32_tensor_core_handle_creator_{nullptr};
  blasLtHandle_t blaslt_handle_{nullptr};
  std::function<blasLtHandle_t()> blaslt_handle_creator_{nullptr};
  dnnHandle_t dnn_handle_{nullptr};
  std::function<dnnHandle_t()> dnn_handle_creator_{nullptr};
  solverHandle_t solver_handle_{nullptr};
  std::function<solverHandle_t()> solver_handle_creator_{nullptr};
  sparseHandle_t sparse_handle_{nullptr};
  std::function<sparseHandle_t()> sparse_handle_creator_{nullptr};
  DnnWorkspaceHandle* workspace_{nullptr};

  std::once_flag flag_sparse_;
  std::once_flag flag_blas_;  // guards all three BLAS handles
  std::once_flag flag_blaslt_;
  std::once_flag flag_dnn_;
  std::once_flag flag_slover_;
  std::once_flag flag_eigen_device_;

#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
  // NCCL communicator (single process version) for NCCL collective operations.
  // NCCL collective operations provides fast collectives over multiple GPUs
  // both within and across nodes.
  // But, this collectives is used for collectives over multiple GPUs within
  // nodes.

  // NOTE: Distributed communicator, distributed framework manages its
  // resources.
  ncclComm_t nccl_comm_{nullptr};
#endif

  mutable std::mutex blas_mtx_;
  mutable std::mutex blas_tensor_core_mtx_;
  mutable std::mutex blas_tf32_mtx_;
  mutable std::mutex sparse_mtx_;
  mutable std::mutex stream_call_back_mtx_;
  mutable std::future<void> last_future_;

  Allocator* allocator_{nullptr};  // external resource.
  // A internal resouce to initinalize eigen_device.
  std::unique_ptr<internal::EigenGpuStreamDevice> eigen_stream_{nullptr};

  // Holds some attributes only used by the gpudnn kernel calculation
  // Because DeviceContext is a global singleton, you need to ensure thread
  // safety, use the thread_local variable
  static thread_local AttributeMap dnn_attrs_;
};

813 814
// Per-thread storage for the gpudnn attributes; starts empty on every thread.
thread_local AttributeMap GPUContext::Impl::dnn_attrs_{};

W
Wilber 已提交
815 816 817 818
// Move construction transfers the implementation pointer (impl_) to the new
// object. It is defaulted out of line here, where Impl is a complete type.
GPUContext::GPUContext(GPUContext&&) = default;

// Move assignment likewise transfers impl_, destroying any Impl previously
// held by the target.
GPUContext& GPUContext::operator=(GPUContext&&) = default;

L
Leo Chen 已提交
819 820 821 822 823 824
// Creates a context for `place`. With `init`, the allocator-independent part
// (device properties and an owned stream) is set up immediately; otherwise the
// caller is expected to finish initialization later.
GPUContext::GPUContext(const GPUPlace& place, bool init)
    : DeviceContext(), impl_(std::make_unique<Impl>(place)) {
  if (!init) {
    return;
  }
  impl_->PartialInitWithoutAllocator();
}
W
Wilber 已提交
825 826 827 828 829

// Destroys impl_ and with it every resource the Impl owns. Defaulted out of
// line here, where Impl is a complete type (presumably it is only forward
// declared in the header — confirm).
GPUContext::~GPUContext() = default;

const Place& GPUContext::GetPlace() const { return impl_->GetPlace(); }

L
Leo Chen 已提交
830 831 832
gpuStream_t GPUContext::stream() const { return impl_->stream(); }

CUDAStream* GPUContext::cuda_stream() const { return impl_->cuda_stream(); }
W
Wilber 已提交
833 834 835 836 837 838 839

dnnHandle_t GPUContext::cudnn_handle() const { return impl_->GetDnnHandle(); }

blasHandle_t GPUContext::cublas_handle() const {
  return impl_->GetBlasHandle();
}

840 841 842 843
blasLtHandle_t GPUContext::cublaslt_handle() const {
  return impl_->GetBlasLtHandle();
}

W
Wilber 已提交
844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 868 869 870 871 872 873 874 875 876 877 878 879 880 881
// Dense solver library handle, obtained from the implementation.
solverHandle_t GPUContext::cusolver_dn_handle() const {
  solverHandle_t solver = impl_->GetSolverHandle();
  return solver;
}

// Sparse library handle, obtained from the implementation.
sparseHandle_t GPUContext::cusparse_handle() const {
  sparseHandle_t sparse = impl_->GetSparseHandle();
  return sparse;
}

// Delegates to Impl::Wait.
void GPUContext::Wait() const {
  impl_->Wait();
}

// Delegates to Impl::WaitEvent with the given event.
void GPUContext::WaitEvent(gpuEvent_t ev) const {
  impl_->WaitEvent(ev);
}

// ---- Cached device properties -------------------------------------------

// Whether tensor cores can be used, as reported by the implementation.
bool GPUContext::tensor_core_available() const {
  const bool available = impl_->IsTensorCoreAvailable();
  return available;
}

// Compute capability stored in Impl (set via SetComputeCapability()).
int GPUContext::GetComputeCapability() const {
  const auto capability = impl_->compute_capability_;
  return capability;
}

// Total resident-thread capacity: SM count times max threads per SM.
int GPUContext::GetMaxPhysicalThreadCount() const {
  const auto sm_count = impl_->multi_process_;
  const auto threads_per_sm = impl_->max_threads_per_mp_;
  return sm_count * threads_per_sm;
}

// Number of streaming multiprocessors.
int GPUContext::GetSMCount() const {
  const auto sm_count = impl_->multi_process_;
  return sm_count;
}

// Maximum threads per block.
int GPUContext::GetMaxThreadsPerBlock() const {
  const auto limit = impl_->max_threads_per_block_;
  return limit;
}

// Maximum grid dimension sizes (x, y, z), returned by value.
std::array<int, 3> GPUContext::GetCUDAMaxGridDimSize() const {
  const auto dims = impl_->max_grid_dim_size_;
  return dims;
}

Eigen::GpuDevice* GPUContext::eigen_device() const {
  return impl_->eigen_device();
}

W
Wilber 已提交
882
DnnWorkspaceHandle GPUContext::cudnn_workspace_handle() const {
W
Wilber 已提交
883 884 885 886 887 888 889 890 891 892 893 894 895 896 897 898 899 900 901 902 903 904 905 906 907 908 909 910 911 912 913 914 915 916 917 918 919 920 921 922 923
  return impl_->GetDnnWorkspace();
}

// ---- Handle-scoped callbacks ---------------------------------------------
// Each method hands `callback` to the implementation, which invokes it with
// the relevant library handle. Impl owns mutexes for BLAS/sparse handles
// (blas_mtx_, blas_tensor_core_mtx_, sparse_mtx_), presumably taken around
// the callback — confirm in Impl.

void GPUContext::CublasCall(
    const std::function<void(blasHandle_t)>& callback) const {
  impl_->CublasCall(callback);
}

void GPUContext::TensorCoreCublasCallIfAvailable(
    const std::function<void(blasHandle_t)>& callback) const {
  impl_->TensorCoreCublasCallIfAvailable(callback);
}

void GPUContext::CusparseCall(
    const std::function<void(sparseHandle_t)>& callback) const {
  impl_->CusparseCall(callback);
}

// ---- Events and stream callbacks -----------------------------------------

// Records `ev` and forwards `callback` to Impl::RecordEvent.
void GPUContext::RecordEvent(gpuEvent_t ev,
                             const std::function<void()>& callback) const {
  impl_->RecordEvent(ev, callback);
}

// Records `ev` via Impl::RecordEvent without a callback.
void GPUContext::RecordEvent(gpuEvent_t ev) const {
  impl_->RecordEvent(ev);
}

// Enqueues `callback` through Impl::AddStreamCallback.
void GPUContext::AddStreamCallback(
    const std::function<void()>& callback) const {
  impl_->AddStreamCallback(callback);
}

// Delegates to Impl::WaitStreamCallback.
void GPUContext::WaitStreamCallback() const {
  impl_->WaitStreamCallback();
}

// NCCL communicator accessor and mutator; both forward to the
// implementation.
ncclComm_t GPUContext::nccl_comm() const {
  return impl_->GetNcclComm();
}

void GPUContext::set_nccl_comm(ncclComm_t comm) {
  impl_->SetNcclComm(comm);
}

void GPUContext::Init() {
  impl_->allocator_ = const_cast<Allocator*>(&this->GetAllocator());
  impl_->Init();
}

W
Wilber 已提交
924 925 926 927
void GPUContext::SetStream(gpuStream_t stream) {
  impl_->allocator_ = const_cast<Allocator*>(&this->GetAllocator());
  impl_->SetStream(stream);
}
W
Wilber 已提交
928

L
Leo Chen 已提交
929 930 931 932 933
void GPUContext::SetCUDAStream(CUDAStream* stream, bool clear) {
  impl_->allocator_ = const_cast<Allocator*>(&this->GetAllocator());
  impl_->SetCUDAStream(stream, clear);
}

W
Wilber 已提交
934 935 936 937
void GPUContext::SetEigenDevice(Eigen::GpuDevice* device) {
  impl_->SetEigenDevice(device);
}

938 939 940 941
void GPUContext::SetEigenDevice(std::function<Eigen::GpuDevice*()>&& creator) {
  impl_->SetEigenDevice(std::move(creator));
}

W
Wilber 已提交
942 943 944 945
void GPUContext::SetBlasHandle(blasHandle_t blas) {
  impl_->SetBlasHandle(blas);
}

946 947 948 949
void GPUContext::SetBlasHandle(std::function<blasHandle_t()>&& func) {
  impl_->SetBlasHandle(std::move(func));
}

W
Wilber 已提交
950 951 952 953
void GPUContext::SetBlasTensorCoreHandle(blasHandle_t handle) {
  impl_->SetBlasTensorCoreHandle(handle);
}

954 955 956 957
void GPUContext::SetBlasTensorCoreHandle(std::function<blasHandle_t()>&& func) {
  impl_->SetBlasTensorCoreHandle(std::move(func));
}

W
Wilber 已提交
958 959 960 961
void GPUContext::SetBlasTF32Handle(blasHandle_t handle) {
  impl_->SetBlasTF32Handle(handle);
}

962 963 964 965
void GPUContext::SetBlasTF32Handle(std::function<blasHandle_t()>&& func) {
  impl_->SetBlasTF32Handle(std::move(func));
}

966 967 968 969
void GPUContext::SetBlasLtHandle(blasLtHandle_t blaslt) {
  impl_->SetBlasLtHandle(blaslt);
}

970 971 972 973
void GPUContext::SetBlasLtHandle(std::function<blasLtHandle_t()>&& func) {
  impl_->SetBlasLtHandle(std::move(func));
}

W
Wilber 已提交
974 975 976 977
void GPUContext::SetDnnHandle(dnnHandle_t handle) {
  impl_->SetDnnHandle(handle);
}

978 979 980 981
void GPUContext::SetDnnHandle(std::function<dnnHandle_t()>&& func) {
  impl_->SetDnnHandle(std::move(func));
}

W
Wilber 已提交
982 983 984 985
void GPUContext::SetSolverHandle(solverHandle_t handle) {
  impl_->SetSolverHandle(handle);
}

986 987 988 989
void GPUContext::SetSolverHandle(std::function<solverHandle_t()>&& func) {
  impl_->SetSolverHandle(std::move(func));
}

W
Wilber 已提交
990 991 992 993
void GPUContext::SetSparseHandle(sparseHandle_t handle) {
  impl_->SetSparseHandle(handle);
}

994 995 996 997
void GPUContext::SetSparseHandle(std::function<sparseHandle_t()>&& func) {
  impl_->SetSparseHandle(std::move(func));
}

W
Wilber 已提交
998 999 1000 1001 1002 1003 1004 1005 1006 1007 1008 1009 1010 1011 1012 1013 1014 1015 1016 1017 1018 1019 1020 1021 1022 1023 1024 1025 1026 1027 1028 1029 1030 1031 1032
void GPUContext::SetDnnWorkspaceHandle(DnnWorkspaceHandle* handle) {
  impl_->workspace_ = handle;
}

void GPUContext::PartialInitWithoutAllocator() {
  impl_->PartialInitWithoutAllocator();
}

void GPUContext::PartialInitWithAllocator() {
  impl_->allocator_ = const_cast<Allocator*>(&this->GetAllocator());
  impl_->PartialInitWithAllocator();
}

void GPUContext::SetComputeCapability(int val) {
  impl_->compute_capability_ = val;
}

void GPUContext::SetMaxThreadsPerMultiProcessor(int val) {
  impl_->max_threads_per_mp_ = val;
}

void GPUContext::SetMultiProcessors(int val) { impl_->multi_process_ = val; }

void GPUContext::SetMaxThreadsPerBlock(int val) {
  impl_->max_threads_per_block_ = val;
}

void GPUContext::SetMaxGridDimSize(const std::array<int, 3>& val) {
  impl_->max_grid_dim_size_ = val;
}

void GPUContext::SetDriverVersion(int val) { impl_->driver_version_ = val; }

void GPUContext::SetRuntimeVersion(int val) { impl_->runtime_version_ = val; }

1033 1034 1035 1036 1037 1038 1039 1040 1041 1042 1043 1044
// DNN-kernel attribute accessors, forwarded to Impl. Storage is presumably
// the thread_local Impl::dnn_attrs_, so values would be visible only to the
// calling thread — confirm in Impl.

// Returns whether an attribute named `attr_name` is present.
bool GPUContext::HasDnnAttr(const std::string& attr_name) const {
  return impl_->HasDnnAttr(attr_name);
}

// Returns the attribute named `attr_name`; the behavior for a missing name
// is defined by Impl::GetDnnAttr.
const Attribute& GPUContext::GetDnnAttr(const std::string& attr_name) const {
  return impl_->GetDnnAttr(attr_name);
}

// Stores `attr` under `attr_name`; `attr` is taken by value and moved in.
void GPUContext::SetDnnAttr(const std::string& attr_name, Attribute attr) {
  // Plain call: the function returns void, so no `return <void-expr>`.
  impl_->SetDnnAttr(attr_name, std::move(attr));
}

1045
}  // namespace phi