未验证 提交 1280f294 编写于 作者: W Wilber 提交者: GitHub

add gpu resources. (#42723)

上级 757b5d31
......@@ -50,10 +50,10 @@ if(WITH_GPU AND TENSORRT_FOUND)
endif()
if (WITH_ONNXRUNTIME)
cc_library(analysis_predictor SRCS analysis_predictor.cc onnxruntime_predictor.cc ${mkldnn_quantizer_src} DEPS ${inference_deps}
cc_library(analysis_predictor SRCS analysis_predictor.cc onnxruntime_predictor.cc resource_manager.cc infer_context.cc ${mkldnn_quantizer_src} DEPS ${inference_deps}
zero_copy_tensor ir_pass_manager op_compatible_info infer_io_utils onnxruntime paddle2onnx)
else (WITH_ONNXRUNTIME)
cc_library(analysis_predictor SRCS analysis_predictor.cc ${mkldnn_quantizer_src} DEPS ${inference_deps}
cc_library(analysis_predictor SRCS analysis_predictor.cc resource_manager.cc infer_context.cc ${mkldnn_quantizer_src} DEPS ${inference_deps}
zero_copy_tensor ir_pass_manager op_compatible_info infer_io_utils)
endif (WITH_ONNXRUNTIME)
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/inference/api/infer_context.h"
namespace paddle {} // namespace paddle
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/backends/all_context.h"
namespace paddle {
class InferCPUContext : public phi::CPUContext {
public:
using phi::CPUContext::SetEigenDevice;
};
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
class InferGPUContext : public phi::GPUContext {
public:
using phi::GPUContext::SetStream;
using phi::GPUContext::SetEigenDevice;
using phi::GPUContext::SetBlasHandle;
using phi::GPUContext::SetBlasTensorCoreHandle;
using phi::GPUContext::SetBlasTF32Handle;
using phi::GPUContext::SetDnnHandle;
using phi::GPUContext::SetSolverHandle;
using phi::GPUContext::SetSparseHandle;
// using phi::GPUContext::SetDnnWorkspaceHandle;
using phi::GPUContext::SetComputeCapability;
using phi::GPUContext::SetMaxThreadsPerMultiProcessor;
using phi::GPUContext::SetMultiProcessors;
using phi::GPUContext::SetMaxThreadsPerBlock;
using phi::GPUContext::SetMaxGridDimSize;
using phi::GPUContext::SetDriverVersion;
using phi::GPUContext::SetRuntimeVersion;
};
#endif
} // namespace paddle
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/inference/api/resource_manager.h"
#include <unordered_map>
#include "paddle/fluid/memory/allocation/allocator_facade.h"
#include "paddle/phi/backends/gpu/forwards.h"
#include "paddle/phi/backends/gpu/gpu_decls.h"
#include "paddle/phi/backends/gpu/gpu_info.h"
#include "paddle/phi/backends/gpu/gpu_resources.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/allocator.h"
#include "paddle/phi/core/generator.h"
#include "unsupported/Eigen/CXX11/Tensor"
namespace paddle {
namespace internal {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
class EigenGpuStreamDevice : public Eigen::StreamInterface {
public:
EigenGpuStreamDevice() : scratch_(nullptr), semaphore_(nullptr) {
Eigen::initializeDeviceProp();
}
~EigenGpuStreamDevice() override {}
void Reinitialize(gpuStream_t cuda_stream, phi::Allocator* allocator,
GPUPlace place) {
stream_ = cuda_stream;
allocator_ = allocator;
device_prop_ = &Eigen::m_deviceProperties[place.device];
}
const gpuStream_t& stream() const override { return stream_; }
const gpuDeviceProp& deviceProperties() const override {
return *device_prop_;
}
void* allocate(size_t num_bytes) const override {
if (UNLIKELY(num_bytes == 0)) {
return nullptr;
}
auto buf = allocator_->Allocate(num_bytes);
VLOG(4) << "Eigen allocated at " << buf->ptr() << " requested "
<< num_bytes;
void* retv = buf->ptr();
{
std::lock_guard<std::mutex> lock(mtx_);
allocations_.emplace(retv, std::move(buf));
}
return retv;
}
void deallocate(void* buffer) const override {
if (LIKELY(buffer)) {
std::lock_guard<std::mutex> lock(mtx_);
allocations_.erase(buffer);
}
}
void* scratchpad() const override {
if (scratch_ == NULL) {
scratch_ = allocate(Eigen::kGpuScratchSize + sizeof(unsigned int));
}
return scratch_;
}
unsigned int* semaphore() const override {
if (semaphore_ == NULL) {
char* scratch = static_cast<char*>(scratchpad()) + Eigen::kGpuScratchSize;
semaphore_ = reinterpret_cast<unsigned int*>(scratch);
#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_GPU_SUCCESS(
hipMemsetAsync(semaphore_, 0, sizeof(unsigned int), stream_));
#else
PADDLE_ENFORCE_GPU_SUCCESS(
cudaMemsetAsync(semaphore_, 0, sizeof(unsigned int), stream_));
#endif
}
return semaphore_;
}
private:
gpuStream_t stream_; // not owned;
phi::Allocator* allocator_; // not owned;
const gpuDeviceProp* device_prop_; // not owned;
mutable void* scratch_;
mutable unsigned int* semaphore_;
mutable std::mutex mtx_; // to protect allocations_
mutable std::unordered_map<void*, phi::Allocator::AllocationPtr> allocations_;
};
#endif
} // namespace internal
ResourceManager::ResourceManager(const phi::Place& place, void* stream)
: place_(place) {
InitCPUResource();
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
InitGPUResource(stream);
#endif
}
ResourceManager::~ResourceManager() {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
DestroyGPUResource();
#endif
}
void ResourceManager::InitCPUResource() {
cpu_eigen_device_.reset(new Eigen::DefaultDevice());
}
Eigen::DefaultDevice* ResourceManager::GetCpuEigenDevice() {
return cpu_eigen_device_.get();
}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
void ResourceManager::InitGPUResource(void* stream) {
if (stream == nullptr) {
owned_stream_ = true;
phi::InitStream(&stream_);
} else {
owned_stream_ = false;
stream_ = reinterpret_cast<gpuStream_t>(stream);
}
InitGpuProperties();
InitGpuEigenDevice();
InitDnnHanlde();
InitBlasHandle();
InitBlasLtHandle();
InitSolverHandle();
InitSparseHandle();
}
void ResourceManager::DestroyGPUResource() {
if (owned_stream_) {
#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_GPU_SUCCESS(hipStreamDestroy(stream_));
#else
PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamDestroy(stream_));
#endif
stream_ = nullptr;
}
DestroyDnnHandle();
DestroyBlasHandle();
DestroyBlasLtHandle();
DestroySolverHandle();
DestroySparseHandle();
}
void ResourceManager::InitGpuProperties() {
phi::backends::gpu::GPUDeviceGuard guard(place_.device);
phi::InitGpuProperties(place_, &compute_capability_, &runtime_version_,
&driver_version_, &multi_process_,
&max_threads_per_mp_, &max_threads_per_block_,
&max_grid_dim_size_);
}
void ResourceManager::InitGpuEigenDevice() {
auto* allocator = paddle::memory::allocation::AllocatorFacade::Instance()
.GetAllocator(place_)
.get();
eigen_stream_.reset(new internal::EigenGpuStreamDevice());
eigen_stream_->Reinitialize(stream_, allocator, place_);
gpu_eigen_device_.reset(new Eigen::GpuDevice(eigen_stream_.get()));
}
void ResourceManager::InitDnnHanlde() {
phi::InitDnnHandle(&dnn_handle_, stream_, place_);
}
void ResourceManager::DestroyDnnHandle() { phi::DestroyDnnHandle(dnn_handle_); }
void ResourceManager::InitBlasHandle() {
phi::InitBlasHandle(&blas_handle_, stream_);
#ifdef PADDLE_WITH_CUDA
#if CUDA_VERSION >= 9000
phi::InitBlasHandle(&blas_tensor_core_handle_, stream_);
PADDLE_RETRY_CUDA_SUCCESS(phi::dynload::cublasSetMathMode(
blas_tensor_core_handle_, CUBLAS_TENSOR_OP_MATH));
#endif
#if CUDA_VERSION >= 11000
phi::InitBlasHandle(&blas_tf32_tensor_core_handle_, stream_);
PADDLE_RETRY_CUDA_SUCCESS(phi::dynload::cublasSetMathMode(
blas_tf32_tensor_core_handle_, CUBLAS_TF32_TENSOR_OP_MATH));
#endif
#endif
}
void ResourceManager::DestroyBlasHandle() {
phi::DestroyBlasHandle(blas_handle_);
phi::DestroyBlasHandle(blas_tensor_core_handle_);
phi::DestroyBlasHandle(blas_tf32_tensor_core_handle_);
}
void ResourceManager::InitBlasLtHandle() {
phi::InitBlasLtHandle(&blaslt_handle_);
}
void ResourceManager::DestroyBlasLtHandle() {
phi::DestroyBlasLtHandle(blaslt_handle_);
}
void ResourceManager::InitSolverHandle() {
phi::InitSolverHandle(&solver_handle_, stream_);
}
void ResourceManager::DestroySolverHandle() {
phi::DestroySolverHandle(solver_handle_);
}
void ResourceManager::InitSparseHandle() {
phi::InitSparseHandle(&sparse_handle_, stream_);
}
void ResourceManager::DestroySparseHandle() {
phi::DestroySparseHandle(sparse_handle_);
}
gpuStream_t ResourceManager::GetStream() const { return stream_; }
dnnHandle_t ResourceManager::GetDnnHandle() const { return dnn_handle_; }
blasHandle_t ResourceManager::GetBlasHandle() const { return blas_handle_; }
blasHandle_t ResourceManager::GetBlasTensorCoreHandle() const {
return blas_tensor_core_handle_;
}
blasHandle_t ResourceManager::GetBlasTF32Handle() const {
return blas_tf32_tensor_core_handle_;
}
blasLtHandle_t ResourceManager::GetBlasLtHandle() const {
return blaslt_handle_;
}
phi::solverHandle_t ResourceManager::GetSolverDnHandle() const {
return solver_handle_;
}
phi::sparseHandle_t ResourceManager::GetSparseHandle() const {
return sparse_handle_;
}
Eigen::GpuDevice* ResourceManager::GetGpuEigenDevice() const {
return gpu_eigen_device_.get();
}
int ResourceManager::GetGpuComputeCapability() const {
return compute_capability_;
}
int ResourceManager::GetGpuRuntimeVersion() const { return runtime_version_; }
int ResourceManager::GetGpuDriverVersion() const { return driver_version_; }
int ResourceManager::GetGPUMultiProcessors() const { return multi_process_; }
int ResourceManager::GetGpuMaxThreadsPerMp() const {
return max_threads_per_mp_;
}
int ResourceManager::GetGpuMaxThreadsPerBlock() const {
return max_threads_per_block_;
}
std::array<int, 3> ResourceManager::GetGpuMaxGridDimSize() const {
return max_grid_dim_size_;
}
#endif
} // namespace paddle
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <functional>
#include <memory>
#include "paddle/phi/api/include/tensor.h"
#include "paddle/phi/backends/cpu/forwards.h"
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
#include "paddle/fluid/platform/device/gpu/gpu_types.h"
#include "paddle/phi/backends/gpu/forwards.h"
#include "paddle/phi/backends/gpu/gpu_decls.h"
#include "paddle/phi/backends/gpu/gpu_resources.h"
#endif
namespace paddle {
namespace internal {
class EigenGpuStreamDevice;
} // namespace internal
class ResourceManager {
public:
explicit ResourceManager(const phi::Place& place, void* stream);
~ResourceManager();
public:
Eigen::DefaultDevice* GetCpuEigenDevice();
private:
void InitCPUResource();
private:
phi::Place place_;
std::unique_ptr<Eigen::DefaultDevice> cpu_eigen_device_;
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
public:
gpuStream_t GetStream() const;
dnnHandle_t GetDnnHandle() const;
blasHandle_t GetBlasHandle() const;
blasHandle_t GetBlasTensorCoreHandle() const;
blasHandle_t GetBlasTF32Handle() const;
blasLtHandle_t GetBlasLtHandle() const;
phi::solverHandle_t GetSolverDnHandle() const;
phi::sparseHandle_t GetSparseHandle() const;
Eigen::GpuDevice* GetGpuEigenDevice() const;
int GetGpuComputeCapability() const;
int GetGpuRuntimeVersion() const;
int GetGpuDriverVersion() const;
int GetGPUMultiProcessors() const;
int GetGpuMaxThreadsPerMp() const;
int GetGpuMaxThreadsPerBlock() const;
std::array<int, 3> GetGpuMaxGridDimSize() const;
private:
void InitGPUResource(void* stream);
void DestroyGPUResource();
void InitGpuProperties();
void InitGpuEigenDevice();
void InitDnnHanlde();
void DestroyDnnHandle();
void InitBlasHandle();
void DestroyBlasHandle();
void InitBlasLtHandle();
void DestroyBlasLtHandle();
void InitSolverHandle();
void DestroySolverHandle();
void InitSparseHandle();
void DestroySparseHandle();
private:
int compute_capability_;
int runtime_version_;
int driver_version_;
int multi_process_;
int max_threads_per_mp_;
int max_threads_per_block_;
std::array<int, 3> max_grid_dim_size_;
bool owned_stream_{true};
gpuStream_t stream_;
std::unique_ptr<Eigen::GpuDevice> gpu_eigen_device_;
std::unique_ptr<internal::EigenGpuStreamDevice> eigen_stream_;
blasHandle_t blas_handle_{nullptr};
blasHandle_t blas_tensor_core_handle_{nullptr};
blasHandle_t blas_tf32_tensor_core_handle_{nullptr};
blasLtHandle_t blaslt_handle_{nullptr};
dnnHandle_t dnn_handle_{nullptr};
phi::solverHandle_t solver_handle_{nullptr};
phi::sparseHandle_t sparse_handle_{nullptr};
// DnnWorkspaceHandle
#endif
};
} // namespace paddle
......@@ -6,4 +6,5 @@ elseif(WITH_ROCM)
hip_library(phi_gpu_info SRCS gpu_info.cc DEPS phi_rocm_info gflags glog enforce phi_dynload_cuda)
endif()
cc_library(gpu_context SRCS gpu_context.cc DEPS phi_device_context phi_gpu_info eigen3)
cc_library(gpu_resources SRCS gpu_resources.cc DEPS phi_device_context phi_gpu_info)
cc_library(gpu_context SRCS gpu_context.cc DEPS phi_device_context phi_gpu_info eigen3 gpu_resources)
......@@ -25,6 +25,7 @@ limitations under the License. */
#include "paddle/phi/backends/gpu/gpu_decls.h"
#include "paddle/phi/backends/gpu/gpu_info.h"
#include "paddle/phi/backends/gpu/gpu_resources.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/allocator.h"
......@@ -202,27 +203,65 @@ struct GPUContext::Impl {
void Init() {
owned_ = true;
backends::gpu::GPUDeviceGuard guard(place_.device);
InitGpuProperties();
InitStream();
phi::InitGpuProperties(place_,
&compute_capability_,
&runtime_version_,
&driver_version_,
&multi_process_,
&max_threads_per_mp_,
&max_threads_per_block_,
&max_grid_dim_size_);
phi::InitStream(&stream_);
InitEigenDevice();
InitBlasHandle();
InitBlasLtHandle();
InitDNNHandle();
InitSolverHandle();
InitSparseHandle();
phi::InitBlasHandle(&blas_handle_, stream_);
#ifdef PADDLE_WITH_CUDA
#if CUDA_VERSION >= 9000
phi::InitBlasHandle(&blas_tensor_core_handle_, stream_);
PADDLE_RETRY_CUDA_SUCCESS(phi::dynload::cublasSetMathMode(
blas_tensor_core_handle_, CUBLAS_TENSOR_OP_MATH));
#endif
#if CUDA_VERSION >= 11000
phi::InitBlasHandle(&blas_tf32_tensor_core_handle_, stream_);
PADDLE_RETRY_CUDA_SUCCESS(phi::dynload::cublasSetMathMode(
blas_tf32_tensor_core_handle_, CUBLAS_TF32_TENSOR_OP_MATH));
#endif
#endif
phi::InitBlasLtHandle(&blaslt_handle_);
phi::InitDnnHandle(&dnn_handle_, stream_, place_);
phi::InitSolverHandle(&solver_handle_, stream_);
phi::InitSparseHandle(&sparse_handle_, stream_);
InitDnnWorkspace();
}
void PartialInitWithoutAllocator() {
owned_ = true;
backends::gpu::GPUDeviceGuard guard(place_.device);
InitGpuProperties();
InitStream();
InitBlasHandle();
InitBlasLtHandle();
InitDNNHandle();
InitSolverHandle();
InitSparseHandle();
phi::InitGpuProperties(place_,
&compute_capability_,
&runtime_version_,
&driver_version_,
&multi_process_,
&max_threads_per_mp_,
&max_threads_per_block_,
&max_grid_dim_size_);
phi::InitStream(&stream_);
phi::InitBlasHandle(&blas_handle_, stream_);
#ifdef PADDLE_WITH_CUDA
#if CUDA_VERSION >= 9000
phi::InitBlasHandle(&blas_tensor_core_handle_, stream_);
PADDLE_RETRY_CUDA_SUCCESS(phi::dynload::cublasSetMathMode(
blas_tensor_core_handle_, CUBLAS_TENSOR_OP_MATH));
#endif
#if CUDA_VERSION >= 11000
phi::InitBlasHandle(&blas_tf32_tensor_core_handle_, stream_);
PADDLE_RETRY_CUDA_SUCCESS(phi::dynload::cublasSetMathMode(
blas_tf32_tensor_core_handle_, CUBLAS_TF32_TENSOR_OP_MATH));
#endif
#endif
phi::InitBlasLtHandle(&blaslt_handle_);
phi::InitDnnHandle(&dnn_handle_, stream_, place_);
phi::InitSolverHandle(&solver_handle_, stream_);
phi::InitSparseHandle(&sparse_handle_, stream_);
}
void PartialInitWithAllocator() {
......@@ -238,19 +277,23 @@ struct GPUContext::Impl {
~Impl() {
backends::gpu::GPUDeviceGuard guard(place_.device);
DestoryInternalWorkspace();
DestoryInternalEigenDevice();
DestroyInternalSparseHandle();
DestroyInternalSolverHandle();
DestroyInternalDnnHandle();
if (owned_) {
DestoryInternalWorkspace();
DestoryInternalEigenDevice();
phi::DestroySparseHandle(sparse_handle_);
phi::DestroySolverHandle(solver_handle_);
phi::DestroyDnnHandle(dnn_handle_);
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
if (nccl_comm_) {
PADDLE_ENFORCE_GPU_SUCCESS(dynload::ncclCommDestroy(nccl_comm_));
}
if (nccl_comm_) {
PADDLE_ENFORCE_GPU_SUCCESS(dynload::ncclCommDestroy(nccl_comm_));
}
#endif
DestroyInternalBlasHandle();
DestroyInternalBlasLtHandle();
DestoryInternalStream();
phi::DestroyBlasHandle(blas_handle_);
phi::DestroyBlasHandle(blas_tensor_core_handle_);
phi::DestroyBlasHandle(blas_tf32_tensor_core_handle_);
phi::DestroyBlasLtHandle(blaslt_handle_);
phi::DestoryStream(stream_);
}
}
const Place& GetPlace() const { return place_; }
......@@ -259,73 +302,6 @@ struct GPUContext::Impl {
return blas_tensor_core_handle_ != nullptr;
}
void InitGpuProperties() {
backends::gpu::GPUDeviceGuard guard(place_.GetDeviceId());
compute_capability_ =
backends::gpu::GetGPUComputeCapability(place_.GetDeviceId());
multi_process_ = backends::gpu::GetGPUMultiProcessors(place_.GetDeviceId());
max_threads_per_mp_ =
backends::gpu::GetGPUMaxThreadsPerMultiProcessor(place_.GetDeviceId());
max_grid_dim_size_ =
backends::gpu::GetGpuMaxGridDimSize(place_.GetDeviceId());
max_threads_per_block_ =
backends::gpu::GetGPUMaxThreadsPerBlock(place_.GetDeviceId());
driver_version_ = backends::gpu::GetGPUDriverVersion(place_.GetDeviceId());
runtime_version_ =
backends::gpu::GetGPURuntimeVersion(place_.GetDeviceId());
// TODO(wilber): glog may be replaced in the future?
LOG_FIRST_N(WARNING, 1)
<< "Please NOTE: device: " << static_cast<int>(place_.device)
<< ", GPU Compute Capability: " << compute_capability_ / 10 << "."
<< compute_capability_ % 10
<< ", Driver API Version: " << driver_version_ / 1000 << "."
<< (driver_version_ % 100) / 10
<< ", Runtime API Version: " << runtime_version_ / 1000 << "."
<< (runtime_version_ % 100) / 10;
#ifdef PADDLE_WITH_HIP
size_t miopen_major, miopen_minor, miopen_patch;
PADDLE_ENFORCE_GPU_SUCCESS(
dynload::miopenGetVersion(&miopen_major, &miopen_minor, &miopen_patch));
auto cudnn_dso_ver =
(miopen_major * 1000 + miopen_minor * 10 + miopen_patch) / 10;
auto compile_miopen_version = MIOPEN_VERSION / 10;
if (cudnn_dso_ver < static_cast<size_t>(compile_miopen_version)) {
LOG_FIRST_N(WARNING, 1)
<< "WARNING: device: " << static_cast<int>(place_.device)
<< ". The installed Paddle is compiled with MIOPEN "
<< compile_miopen_version / 100 << "." << compile_miopen_version % 100
<< ", but MIOPEN version in your machine is " << cudnn_dso_ver / 100
<< "." << cudnn_dso_ver % 100
<< ", which may cause serious incompatible bug. "
<< "Please recompile or reinstall Paddle with compatible MIOPEN "
"version.";
}
#else
size_t cudnn_dso_ver = dynload::cudnnGetVersion();
LOG_FIRST_N(WARNING, 1) << "device: " << static_cast<int>(place_.device)
<< ", cuDNN Version: " << cudnn_dso_ver / 1000
<< "." << (cudnn_dso_ver % 1000) / 100 << ".";
// Check CUDA/CUDNN version compatiblity
auto local_cuda_version =
(driver_version_ / 1000) * 10 + (driver_version_ % 100) / 10;
auto compile_cuda_version =
(CUDA_VERSION / 1000) * 10 + (CUDA_VERSION % 100) / 10;
if (local_cuda_version < compile_cuda_version) {
LOG_FIRST_N(WARNING, 1)
<< "WARNING: device: " << static_cast<int>(place_.device)
<< ". The installed Paddle is compiled with CUDA "
<< compile_cuda_version / 10 << "." << compile_cuda_version % 10
<< ", but CUDA runtime version in your machine is "
<< local_cuda_version / 10 << "." << local_cuda_version % 10
<< ", which may cause serious incompatible bug. "
<< "Please recompile or reinstall Paddle with compatible CUDA "
"version.";
}
#endif
}
void InitDnnWorkspace() {
PD_CHECK(allocator_ != nullptr,
"the device allocator for gpu context is nullptr.");
......@@ -350,27 +326,6 @@ struct GPUContext::Impl {
return DnnWorkspaceHandle(allocator_, stream_);
}
void InitStream() {
#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_GPU_SUCCESS(
hipStreamCreateWithPriority(&stream_, hipStreamDefault, 0));
#else
PADDLE_ENFORCE_GPU_SUCCESS(
cudaStreamCreateWithPriority(&stream_, cudaStreamDefault, 0));
#endif
}
void DestoryInternalStream() {
if (owned_ && stream_ != nullptr) {
#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_GPU_SUCCESS(hipStreamDestroy(stream_));
#else
PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamDestroy(stream_));
#endif
}
stream_ = nullptr;
}
void SetStream(gpuStream_t stream) { stream_ = stream; }
gpuStream_t GetStream() const {
......@@ -400,55 +355,6 @@ struct GPUContext::Impl {
return eigen_device_;
}
void InitBlasHandle() {
#ifdef PADDLE_WITH_HIP
phi::dynload::rocblas_create_handle(&blas_handle_);
phi::dynload::rocblas_set_stream(blas_handle_, stream_);
#else // PADDLE_WITH_CUDA
PADDLE_RETRY_CUDA_SUCCESS(phi::dynload::cublasCreate(&blas_handle_));
PADDLE_RETRY_CUDA_SUCCESS(
phi::dynload::cublasSetStream(blas_handle_, stream_));
#if CUDA_VERSION >= 9000
PADDLE_RETRY_CUDA_SUCCESS(
phi::dynload::cublasCreate(&blas_tensor_core_handle_));
PADDLE_RETRY_CUDA_SUCCESS(
phi::dynload::cublasSetStream(blas_tensor_core_handle_, stream_));
PADDLE_RETRY_CUDA_SUCCESS(phi::dynload::cublasSetMathMode(
blas_tensor_core_handle_, CUBLAS_TENSOR_OP_MATH));
#if CUDA_VERSION >= 11000
PADDLE_RETRY_CUDA_SUCCESS(
phi::dynload::cublasCreate(&blas_tf32_tensor_core_handle_));
PADDLE_RETRY_CUDA_SUCCESS(
phi::dynload::cublasSetStream(blas_tf32_tensor_core_handle_, stream_));
PADDLE_RETRY_CUDA_SUCCESS(phi::dynload::cublasSetMathMode(
blas_tf32_tensor_core_handle_, CUBLAS_TF32_TENSOR_OP_MATH));
#endif // CUDA_VERSION >= 11000
#endif // CUDA_VERSION >= 9000
#endif // PADDLE_WITH_HIP
}
void DestroyInternalBlasHandle() {
#ifdef PADDLE_WITH_HIP
if (owned_ && blas_handle_ != nullptr) {
phi::dynload::rocblas_destroy_handle(blas_handle_);
blas_handle_ = nullptr;
}
#else
if (owned_ && blas_handle_ != nullptr) {
phi::dynload::cublasDestroy(blas_handle_);
blas_handle_ = nullptr;
}
if (owned_ && blas_tensor_core_handle_ != nullptr) {
phi::dynload::cublasDestroy(blas_tensor_core_handle_);
blas_tensor_core_handle_ = nullptr;
}
if (owned_ && blas_tf32_tensor_core_handle_ != nullptr) {
phi::dynload::cublasDestroy(blas_tf32_tensor_core_handle_);
blas_tf32_tensor_core_handle_ = nullptr;
}
#endif // PADDLE_WITH_HIP
}
blasHandle_t GetBlasHandle() const {
PD_CHECK(blas_handle_ != nullptr, "the gpu blas handle is nullptr.");
return blas_handle_;
......@@ -456,16 +362,12 @@ struct GPUContext::Impl {
void SetBlasHandle(blasHandle_t blas) { blas_handle_ = blas; }
void InitBlasLtHandle() {
#if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 11060
phi::dynload::cublasLtCreate(&blaslt_handle_);
#endif
void SetBlasTensorCoreHandle(blasHandle_t handle) {
blas_tensor_core_handle_ = handle;
}
void DestroyInternalBlasLtHandle() {
#if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 11060
phi::dynload::cublasLtDestroy(blaslt_handle_);
#endif
void SetBlasTF32Handle(blasHandle_t handle) {
blas_tf32_tensor_core_handle_ = handle;
}
void SetBlasLtHandle(blasLtHandle_t blaslt) { blaslt_handle_ = blaslt; }
......@@ -475,53 +377,6 @@ struct GPUContext::Impl {
return blaslt_handle_;
}
void InitDNNHandle() {
if (phi::dynload::HasCUDNN()) {
#ifdef PADDLE_WITH_HIP
size_t miopen_major, miopen_minor, miopen_patch;
PADDLE_ENFORCE_GPU_SUCCESS(dynload::miopenGetVersion(
&miopen_major, &miopen_minor, &miopen_patch));
auto local_miopen_version =
(miopen_major * 1000 + miopen_minor * 10 + miopen_patch) / 10;
auto compile_miopen_version = MIOPEN_VERSION / 10;
if (local_miopen_version < static_cast<size_t>(compile_miopen_version)) {
LOG_FIRST_N(WARNING, 1)
<< "WARNING: device: " << place_.device
<< ". The installed Paddle is compiled with MIOPEN "
<< compile_miopen_version / 100 << "."
<< compile_miopen_version % 100
<< ", but MIOPEN version in your machine is "
<< local_miopen_version / 100 << "." << local_miopen_version % 100
<< ", which may cause serious incompatible bug. "
<< "Please recompile or reinstall Paddle with compatible MIOPEN "
"version.";
}
PADDLE_ENFORCE_GPU_SUCCESS(dynload::miopenCreate(&dnn_handle_));
PADDLE_ENFORCE_GPU_SUCCESS(
dynload::miopenSetStream(dnn_handle_, stream_));
#else
auto local_cudnn_version = phi::dynload::cudnnGetVersion() / 100;
auto compile_cudnn_version = CUDNN_VERSION / 100;
if (local_cudnn_version < static_cast<size_t>(compile_cudnn_version)) {
LOG_FIRST_N(WARNING, 1)
<< "WARNING: device: " << place_.device
<< ". The installed Paddle is compiled with CUDNN "
<< compile_cudnn_version / 10 << "." << compile_cudnn_version % 10
<< ", but CUDNN version in your machine is "
<< local_cudnn_version / 10 << "." << local_cudnn_version % 10
<< ", which may cause serious incompatible bug. "
<< "Please recompile or reinstall Paddle with compatible CUDNN "
"version.";
}
PADDLE_RETRY_CUDA_SUCCESS(phi::dynload::cudnnCreate(&dnn_handle_));
PADDLE_RETRY_CUDA_SUCCESS(
phi::dynload::cudnnSetStream(dnn_handle_, stream_));
#endif
} else {
dnn_handle_ = nullptr;
}
}
dnnHandle_t GetDnnHandle() {
PD_CHECK(dnn_handle_ != nullptr, "the gpu dnn handle is nullptr.");
return dnn_handle_;
......@@ -543,24 +398,6 @@ struct GPUContext::Impl {
void SetDnnHandle(dnnHandle_t handle) { dnn_handle_ = handle; }
void InitSolverHandle() {
#ifndef PADDLE_WITH_HIP
PADDLE_RETRY_CUDA_SUCCESS(phi::dynload::cusolverDnCreate(&solver_handle_));
PADDLE_RETRY_CUDA_SUCCESS(
phi::dynload::cusolverDnSetStream(solver_handle_, stream_));
#endif
}
void DestroyInternalSolverHandle() {
#ifndef PADDLE_WITH_HIP
if (owned_ && solver_handle_ != nullptr) {
PADDLE_ENFORCE_GPU_SUCCESS(
phi::dynload::cusolverDnDestroy(solver_handle_));
solver_handle_ = nullptr;
}
#endif
}
solverHandle_t GetSolverHandle() const {
PD_CHECK(solver_handle_ != nullptr, "the gpu solver handle is nullptr.");
return solver_handle_;
......@@ -568,29 +405,6 @@ struct GPUContext::Impl {
void SetSolverHandle(solverHandle_t handle) { solver_handle_ = handle; }
void InitSparseHandle() {
// ROCM is not yet supported
#if defined(PADDLE_WITH_CUDA)
// The generic APIs is supported from CUDA10.1
#if CUDA_VERSION >= 10010
PADDLE_RETRY_CUDA_SUCCESS(dynload::cusparseCreate(&sparse_handle_));
PADDLE_RETRY_CUDA_SUCCESS(
dynload::cusparseSetStream(sparse_handle_, stream_));
#endif
#endif
}
void DestroyInternalSparseHandle() {
#ifdef PADDLE_WITH_CUDA
#if CUDA_VERSION >= 10010
if (owned_ && sparse_handle_ != nullptr) {
PADDLE_RETRY_CUDA_SUCCESS(dynload::cusparseDestroy(sparse_handle_));
sparse_handle_ = nullptr;
}
#endif
#endif
}
sparseHandle_t GetSparseHandle() const {
PD_CHECK(sparse_handle_ != nullptr, "the gpu sparse handle is nullptr.");
return sparse_handle_;
......@@ -878,7 +692,10 @@ void GPUContext::Init() {
impl_->Init();
}
void GPUContext::SetStream(gpuStream_t stream) { impl_->SetStream(stream); }
void GPUContext::SetStream(gpuStream_t stream) {
impl_->allocator_ = const_cast<Allocator*>(&this->GetAllocator());
impl_->SetStream(stream);
}
void GPUContext::SetEigenDevice(Eigen::GpuDevice* device) {
impl_->SetEigenDevice(device);
......@@ -888,6 +705,14 @@ void GPUContext::SetBlasHandle(blasHandle_t blas) {
impl_->SetBlasHandle(blas);
}
void GPUContext::SetBlasTensorCoreHandle(blasHandle_t handle) {
impl_->SetBlasTensorCoreHandle(handle);
}
void GPUContext::SetBlasTF32Handle(blasHandle_t handle) {
impl_->SetBlasTF32Handle(handle);
}
void GPUContext::SetBlasLtHandle(blasLtHandle_t blaslt) {
impl_->SetBlasLtHandle(blaslt);
}
......
......@@ -199,6 +199,10 @@ class PADDLE_API GPUContext : public DeviceContext {
void SetBlasHandle(blasHandle_t);
void SetBlasTensorCoreHandle(blasHandle_t);
void SetBlasTF32Handle(blasHandle_t);
void SetBlasLtHandle(blasLtHandle_t);
void SetDnnHandle(dnnHandle_t);
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/backends/gpu/gpu_resources.h"
#include "paddle/phi/api/include/tensor.h"
#include "paddle/phi/backends/gpu/gpu_decls.h"
#include "paddle/phi/backends/gpu/gpu_info.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/allocator.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/phi/backends/dynload/cublas.h"
#include "paddle/phi/backends/dynload/cudnn.h"
#include "paddle/phi/backends/dynload/cusolver.h"
#include "paddle/phi/backends/dynload/cusparse.h"
#if !defined(__APPLE__) && defined(PADDLE_WITH_NCCL)
#include "paddle/phi/backends/dynload/nccl.h"
#endif // !defined(__APPLE__) && defined(PADDLE_WITH_NCCL)
#endif // PADDLE_WITH_CUDA
#include "unsupported/Eigen/CXX11/Tensor"
// TODO(phi): remove fluid header.
#include "paddle/fluid/platform/enforce.h"
namespace phi {
void InitGpuProperties(Place place,
int* compute_capability,
int* runtime_version,
int* driver_version,
int* multi_process,
int* max_threads_per_mp,
int* max_threads_per_block,
std::array<int, 3>* max_grid_dim_size) {
backends::gpu::GPUDeviceGuard guard(place.GetDeviceId());
*compute_capability =
backends::gpu::GetGPUComputeCapability(place.GetDeviceId());
*multi_process = backends::gpu::GetGPUMultiProcessors(place.GetDeviceId());
*max_threads_per_mp =
backends::gpu::GetGPUMaxThreadsPerMultiProcessor(place.GetDeviceId());
*max_grid_dim_size = backends::gpu::GetGpuMaxGridDimSize(place.GetDeviceId());
*max_threads_per_block =
backends::gpu::GetGPUMaxThreadsPerBlock(place.GetDeviceId());
*driver_version = backends::gpu::GetGPUDriverVersion(place.GetDeviceId());
*runtime_version = backends::gpu::GetGPURuntimeVersion(place.GetDeviceId());
// TODO(wilber): glog may be replaced in the future?
LOG_FIRST_N(WARNING, 1) << "Please NOTE: device: "
<< static_cast<int>(place.device)
<< ", GPU Compute Capability: "
<< *compute_capability / 10 << "."
<< *compute_capability % 10
<< ", Driver API Version: " << *driver_version / 1000
<< "." << (*driver_version % 100) / 10
<< ", Runtime API Version: "
<< *runtime_version / 1000 << "."
<< (*runtime_version % 100) / 10;
#ifdef PADDLE_WITH_HIP
size_t miopen_major, miopen_minor, miopen_patch;
PADDLE_ENFORCE_GPU_SUCCESS(
dynload::miopenGetVersion(&miopen_major, &miopen_minor, &miopen_patch));
auto cudnn_dso_ver =
(miopen_major * 1000 + miopen_minor * 10 + miopen_patch) / 10;
auto compile_miopen_version = MIOPEN_VERSION / 10;
if (cudnn_dso_ver < static_cast<size_t>(compile_miopen_version)) {
LOG_FIRST_N(WARNING, 1)
<< "WARNING: device: " << static_cast<int>(place.device)
<< ". The installed Paddle is compiled with MIOPEN "
<< compile_miopen_version / 100 << "." << compile_miopen_version % 100
<< ", but MIOPEN version in your machine is " << cudnn_dso_ver / 100
<< "." << cudnn_dso_ver % 100
<< ", which may cause serious incompatible bug. "
<< "Please recompile or reinstall Paddle with compatible MIOPEN "
"version.";
}
#else
size_t cudnn_dso_ver = dynload::cudnnGetVersion();
LOG_FIRST_N(WARNING, 1) << "device: " << static_cast<int>(place.device)
<< ", cuDNN Version: " << cudnn_dso_ver / 1000 << "."
<< (cudnn_dso_ver % 1000) / 100 << ".";
// Check CUDA/CUDNN version compatiblity
auto local_cuda_version =
(*driver_version / 1000) * 10 + (*driver_version % 100) / 10;
auto compile_cuda_version =
(CUDA_VERSION / 1000) * 10 + (CUDA_VERSION % 100) / 10;
if (local_cuda_version < compile_cuda_version) {
LOG_FIRST_N(WARNING, 1)
<< "WARNING: device: " << static_cast<int>(place.device)
<< ". The installed Paddle is compiled with CUDA "
<< compile_cuda_version / 10 << "." << compile_cuda_version % 10
<< ", but CUDA runtime version in your machine is "
<< local_cuda_version / 10 << "." << local_cuda_version % 10
<< ", which may cause serious incompatible bug. "
<< "Please recompile or reinstall Paddle with compatible CUDA "
"version.";
}
#endif
}
void InitStream(gpuStream_t* stream) {
#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_GPU_SUCCESS(
hipStreamCreateWithPriority(stream, hipStreamDefault, 0));
#else
PADDLE_ENFORCE_GPU_SUCCESS(
cudaStreamCreateWithPriority(stream, cudaStreamDefault, 0));
#endif
}
void DestoryStream(gpuStream_t stream) {
if (stream != nullptr) {
#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_GPU_SUCCESS(hipStreamDestroy(stream));
#else
PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamDestroy(stream));
#endif
}
stream = nullptr;
}
void InitBlasHandle(blasHandle_t* blas_handle, gpuStream_t stream) {
#ifdef PADDLE_WITH_HIP
phi::dynload::rocblas_create_handle(blas_handle);
phi::dynload::rocblas_set_stream(*blas_handle, stream);
#else // PADDLE_WITH_CUDA
PADDLE_RETRY_CUDA_SUCCESS(phi::dynload::cublasCreate(blas_handle));
PADDLE_RETRY_CUDA_SUCCESS(
phi::dynload::cublasSetStream(*blas_handle, stream));
#endif // PADDLE_WITH_HIP
}
void DestroyBlasHandle(blasHandle_t handle) {
#ifdef PADDLE_WITH_HIP
if (handle != nullptr) {
phi::dynload::rocblas_destroy_handle(handle);
handle = nullptr;
}
#else
if (handle != nullptr) {
phi::dynload::cublasDestroy(handle);
handle = nullptr;
}
#endif // PADDLE_WITH_HIP
}
void InitBlasLtHandle(blasLtHandle_t* blaslt_handle) {
#if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 11060
phi::dynload::cublasLtCreate(blaslt_handle);
#endif
}
void DestroyBlasLtHandle(blasLtHandle_t handle) {
#if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 11060
if (handle != nullptr) {
phi::dynload::cublasLtDestroy(handle);
handle = nullptr;
}
#endif
}
void InitDnnHandle(dnnHandle_t* handle, gpuStream_t stream, Place place) {
if (phi::dynload::HasCUDNN()) {
#ifdef PADDLE_WITH_HIP
size_t miopen_major, miopen_minor, miopen_patch;
PADDLE_ENFORCE_GPU_SUCCESS(
dynload::miopenGetVersion(&miopen_major, &miopen_minor, &miopen_patch));
auto local_miopen_version =
(miopen_major * 1000 + miopen_minor * 10 + miopen_patch) / 10;
auto compile_miopen_version = MIOPEN_VERSION / 10;
if (local_miopen_version < static_cast<size_t>(compile_miopen_version)) {
LOG_FIRST_N(WARNING, 1)
<< "WARNING: device: " << place.device
<< ". The installed Paddle is compiled with MIOPEN "
<< compile_miopen_version / 100 << "." << compile_miopen_version % 100
<< ", but MIOPEN version in your machine is "
<< local_miopen_version / 100 << "." << local_miopen_version % 100
<< ", which may cause serious incompatible bug. "
<< "Please recompile or reinstall Paddle with compatible MIOPEN "
"version.";
}
PADDLE_ENFORCE_GPU_SUCCESS(dynload::miopenCreate(handle));
PADDLE_ENFORCE_GPU_SUCCESS(dynload::miopenSetStream(*handle, stream));
#else
auto local_cudnn_version = phi::dynload::cudnnGetVersion() / 100;
auto compile_cudnn_version = CUDNN_VERSION / 100;
if (local_cudnn_version < static_cast<size_t>(compile_cudnn_version)) {
LOG_FIRST_N(WARNING, 1)
<< "WARNING: device: " << place.device
<< ". The installed Paddle is compiled with CUDNN "
<< compile_cudnn_version / 10 << "." << compile_cudnn_version % 10
<< ", but CUDNN version in your machine is "
<< local_cudnn_version / 10 << "." << local_cudnn_version % 10
<< ", which may cause serious incompatible bug. "
<< "Please recompile or reinstall Paddle with compatible CUDNN "
"version.";
}
PADDLE_RETRY_CUDA_SUCCESS(phi::dynload::cudnnCreate(handle));
PADDLE_RETRY_CUDA_SUCCESS(phi::dynload::cudnnSetStream(*handle, stream));
#endif
} else {
*handle = nullptr;
}
}
void DestroyDnnHandle(dnnHandle_t handle) {
#ifdef PADDLE_WITH_HIP
if (handle != nullptr) {
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::miopenDestroy(handle));
handle = nullptr;
}
#else
if (handle != nullptr) {
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cudnnDestroy(handle));
handle = nullptr;
}
#endif // PADDLE_WITH_HIP
}
void InitSolverHandle(solverHandle_t* handle, gpuStream_t stream) {
#ifndef PADDLE_WITH_HIP
PADDLE_RETRY_CUDA_SUCCESS(phi::dynload::cusolverDnCreate(handle));
PADDLE_RETRY_CUDA_SUCCESS(phi::dynload::cusolverDnSetStream(*handle, stream));
#endif
}
void DestroySolverHandle(solverHandle_t solver_handle) {
#ifndef PADDLE_WITH_HIP
if (solver_handle != nullptr) {
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cusolverDnDestroy(solver_handle));
solver_handle = nullptr;
}
#endif
}
void InitSparseHandle(sparseHandle_t* handle, gpuStream_t stream) {
// ROCM is not yet supported
#if defined(PADDLE_WITH_CUDA)
// The generic APIs is supported from CUDA10.1
#if CUDA_VERSION >= 10010
PADDLE_RETRY_CUDA_SUCCESS(dynload::cusparseCreate(handle));
PADDLE_RETRY_CUDA_SUCCESS(dynload::cusparseSetStream(*handle, stream));
#endif
#endif
}
void DestroySparseHandle(sparseHandle_t handle) {
#ifdef PADDLE_WITH_CUDA
#if CUDA_VERSION >= 10010
if (handle != nullptr) {
PADDLE_RETRY_CUDA_SUCCESS(dynload::cusparseDestroy(handle));
handle = nullptr;
}
#endif
#endif
}
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <array>
#include "paddle/phi/backends/gpu/gpu_decls.h"
#include "paddle/phi/common/place.h"
namespace phi {
void InitGpuProperties(Place place,
int* compute_capability,
int* runtime_version,
int* driver_version,
int* multi_process,
int* max_threads_per_mp,
int* max_threads_per_block,
std::array<int, 3>* max_grid_dim_size);
void InitStream(gpuStream_t* stream);
void DestoryStream(gpuStream_t stream);
void InitBlasHandle(blasHandle_t* blas_handle, gpuStream_t stream);
void DestroyBlasHandle(blasHandle_t handle);
void InitBlasLtHandle(blasLtHandle_t* blaslt_handle);
void DestroyBlasLtHandle(blasLtHandle_t handle);
void InitDnnHandle(dnnHandle_t* handle, gpuStream_t stream, Place place);
void DestroyDnnHandle(dnnHandle_t handle);
void InitSolverHandle(solverHandle_t* handle, gpuStream_t stream);
void DestroySolverHandle(solverHandle_t solver_handle);
void InitSparseHandle(sparseHandle_t* handle, gpuStream_t stream);
void DestroySparseHandle(sparseHandle_t handle);
// void InitDnnWorkspace();
} // namespace phi
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册