未验证 提交 1892a441 编写于 作者: W Wilber 提交者: GitHub

inference multi stream support handle lazy init. (#44563)

* multi stream support handle lazy init.

* support eigen lazy init

* update

* fix ci problem
上级 65ad58b6
......@@ -415,14 +415,16 @@ void AnalysisPredictor::InitDeviceContexts() {
gpu_context->SetHostGenerator(framework::DefaultCPUGenerator().get());
gpu_context->SetStream(gpu_resource->GetStream());
gpu_context->SetBlasHandle(gpu_resource->GetBlasHandle());
gpu_context->SetBlasHandle(gpu_resource->GetBlasHandleCreator());
gpu_context->SetBlasTensorCoreHandle(
gpu_resource->GetBlasTensorCoreHandle());
gpu_context->SetBlasTF32Handle(gpu_resource->GetBlasTF32Handle());
gpu_context->SetDnnHandle(gpu_resource->GetDnnHandle());
gpu_context->SetSolverHandle(gpu_resource->GetSolverDnHandle());
gpu_context->SetSparseHandle(gpu_resource->GetSparseHandle());
gpu_context->SetEigenDevice(gpu_resource->GetGpuEigenDevice());
gpu_resource->GetBlasTensorCoreHandleCreator());
gpu_context->SetBlasTF32Handle(
gpu_resource->GetBlasTF32TensorCoreHandleCreator());
gpu_context->SetDnnHandle(gpu_resource->GetDnnHandleCreator());
gpu_context->SetSolverHandle(
gpu_resource->GetSolverDnHandleCreator());
gpu_context->SetSparseHandle(gpu_resource->GetSparseHandleCreator());
gpu_context->SetEigenDevice(gpu_resource->GetGpuEigenDeviceCreator());
gpu_context->SetComputeCapability(
gpu_resource->GetGpuComputeCapability());
gpu_context->SetMaxThreadsPerBlock(
......
......@@ -14,6 +14,7 @@
#include "paddle/fluid/inference/api/resource_manager.h"
#include <functional>
#include <memory>
#include <mutex>
#include <unordered_map>
......@@ -150,12 +151,6 @@ void GPUContextResource::InitGPUResource(void* stream) {
}
InitGpuProperties();
InitGpuEigenDevice();
InitDnnHanlde();
InitBlasHandle();
InitBlasLtHandle();
InitSolverHandle();
InitSparseHandle();
}
void GPUContextResource::DestroyGPUResource() {
......@@ -203,22 +198,6 @@ void GPUContextResource::DestroyDnnHandle() {
phi::DestroyDnnHandle(dnn_handle_);
}
void GPUContextResource::InitBlasHandle() {
phi::InitBlasHandle(&blas_handle_, stream_);
#ifdef PADDLE_WITH_CUDA
#if CUDA_VERSION >= 9000
phi::InitBlasHandle(&blas_tensor_core_handle_, stream_);
PADDLE_RETRY_CUDA_SUCCESS(phi::dynload::cublasSetMathMode(
blas_tensor_core_handle_, CUBLAS_TENSOR_OP_MATH));
#endif
#if CUDA_VERSION >= 11000
phi::InitBlasHandle(&blas_tf32_tensor_core_handle_, stream_);
PADDLE_RETRY_CUDA_SUCCESS(phi::dynload::cublasSetMathMode(
blas_tf32_tensor_core_handle_, CUBLAS_TF32_TENSOR_OP_MATH));
#endif
#endif
}
void GPUContextResource::DestroyBlasHandle() {
phi::DestroyBlasHandle(blas_handle_);
phi::DestroyBlasHandle(blas_tensor_core_handle_);
......@@ -255,32 +234,106 @@ gpuStream_t GPUContextResource::GetStream() const { return stream_; }
dnnHandle_t GPUContextResource::GetDnnHandle() const { return dnn_handle_; }
std::function<phi::dnnHandle_t()> GPUContextResource::GetDnnHandleCreator() {
return [&]() -> phi::dnnHandle_t {
InitDnnHanlde();
return dnn_handle_;
};
}
blasHandle_t GPUContextResource::GetBlasHandle() const { return blas_handle_; }
std::function<phi::blasHandle_t()> GPUContextResource::GetBlasHandleCreator() {
return [&]() -> phi::blasHandle_t {
phi::InitBlasHandle(&blas_handle_, stream_);
return blas_handle_;
};
}
blasHandle_t GPUContextResource::GetBlasTensorCoreHandle() const {
return blas_tensor_core_handle_;
}
std::function<phi::blasHandle_t()>
GPUContextResource::GetBlasTensorCoreHandleCreator() {
return [&]() -> phi::blasHandle_t {
#ifdef PADDLE_WITH_CUDA
#if CUDA_VERSION >= 9000
phi::InitBlasHandle(&blas_tensor_core_handle_, stream_);
PADDLE_RETRY_CUDA_SUCCESS(phi::dynload::cublasSetMathMode(
blas_tensor_core_handle_, CUBLAS_TENSOR_OP_MATH));
#endif
#endif
return blas_tensor_core_handle_;
};
}
blasHandle_t GPUContextResource::GetBlasTF32Handle() const {
return blas_tf32_tensor_core_handle_;
}
std::function<phi::blasHandle_t()>
GPUContextResource::GetBlasTF32TensorCoreHandleCreator() {
return [&]() -> phi::blasHandle_t {
#ifdef PADDLE_WITH_CUDA
#if CUDA_VERSION >= 11000
phi::InitBlasHandle(&blas_tf32_tensor_core_handle_, stream_);
PADDLE_RETRY_CUDA_SUCCESS(phi::dynload::cublasSetMathMode(
blas_tf32_tensor_core_handle_, CUBLAS_TF32_TENSOR_OP_MATH));
#endif
#endif
return blas_tf32_tensor_core_handle_;
};
}
blasLtHandle_t GPUContextResource::GetBlasLtHandle() const {
return blaslt_handle_;
}
std::function<phi::blasLtHandle_t()>
GPUContextResource::GetBlasLtHandleCreator() {
return [&]() {
InitBlasLtHandle();
return blaslt_handle_;
};
}
phi::solverHandle_t GPUContextResource::GetSolverDnHandle() const {
return solver_handle_;
}
std::function<phi::solverHandle_t()>
GPUContextResource::GetSolverDnHandleCreator() {
return [&]() {
InitSolverHandle();
return solver_handle_;
};
}
phi::sparseHandle_t GPUContextResource::GetSparseHandle() const {
return sparse_handle_;
}
std::function<phi::sparseHandle_t()>
GPUContextResource::GetSparseHandleCreator() {
return [&]() {
InitSparseHandle();
return sparse_handle_;
};
}
Eigen::GpuDevice* GPUContextResource::GetGpuEigenDevice() const {
return gpu_eigen_device_.get();
}
std::function<Eigen::GpuDevice*()>
GPUContextResource::GetGpuEigenDeviceCreator() {
return [&]() {
InitGpuEigenDevice();
return gpu_eigen_device_.get();
};
}
int GPUContextResource::GetGpuComputeCapability() const {
return compute_capability_;
}
......@@ -311,15 +364,19 @@ void GPUContextResource::ReBindStream(gpuStream_t stream) {
}
void GPUContextResource::ReBindDnnHandle(gpuStream_t stream) const {
if (dnn_handle_) {
#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_GPU_SUCCESS(
phi::dynload::miopenSetStream(dnn_handle_, stream));
#else
PADDLE_RETRY_CUDA_SUCCESS(phi::dynload::cudnnSetStream(dnn_handle_, stream));
PADDLE_RETRY_CUDA_SUCCESS(
phi::dynload::cudnnSetStream(dnn_handle_, stream));
#endif
}
}
void GPUContextResource::ReBindBlasHandle(gpuStream_t stream) const {
if (blas_handle_) {
#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_GPU_SUCCESS(
phi::dynload::rocblas_set_stream(blas_handle_, stream));
......@@ -327,9 +384,11 @@ void GPUContextResource::ReBindBlasHandle(gpuStream_t stream) const {
PADDLE_RETRY_CUDA_SUCCESS(
phi::dynload::cublasSetStream(blas_handle_, stream));
#endif
}
}
void GPUContextResource::ReBindBlasTensorCoreHandle(gpuStream_t stream) const {
if (blas_tensor_core_handle_) {
#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_GPU_SUCCESS(
phi::dynload::rocblas_set_stream(blas_tensor_core_handle_, stream));
......@@ -337,26 +396,32 @@ void GPUContextResource::ReBindBlasTensorCoreHandle(gpuStream_t stream) const {
PADDLE_RETRY_CUDA_SUCCESS(
phi::dynload::cublasSetStream(blas_tensor_core_handle_, stream));
#endif
}
}
void GPUContextResource::ReBindBlasTF32Handle(gpuStream_t stream) const {
if (blas_tf32_tensor_core_handle_) {
#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_GPU_SUCCESS(
phi::dynload::rocblas_set_stream(blas_tf32_tensor_core_handle_, stream));
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::rocblas_set_stream(
blas_tf32_tensor_core_handle_, stream));
#else
PADDLE_RETRY_CUDA_SUCCESS(
phi::dynload::cublasSetStream(blas_tf32_tensor_core_handle_, stream));
#endif
}
}
void GPUContextResource::ReBindSolverDnHandle(gpuStream_t stream) const {
if (solver_handle_) {
#ifndef PADDLE_WITH_HIP
PADDLE_RETRY_CUDA_SUCCESS(
phi::dynload::cusolverDnSetStream(solver_handle_, stream));
#endif
}
}
void GPUContextResource::ReBindSparseHandle(gpuStream_t stream) const {
if (sparse_handle_) {
#if defined(PADDLE_WITH_CUDA)
// The generic APIs is supported from CUDA10.1
#if CUDA_VERSION >= 11000
......@@ -364,14 +429,17 @@ void GPUContextResource::ReBindSparseHandle(gpuStream_t stream) const {
phi::dynload::cusparseSetStream(sparse_handle_, stream));
#endif
#endif
}
}
void GPUContextResource::ReBindEigenDevice(gpuStream_t stream,
GPUPlace place) const {
if (eigen_stream_) {
auto* allocator = paddle::memory::allocation::AllocatorFacade::Instance()
.GetAllocator(place_)
.get();
eigen_stream_->Reinitialize(stream, allocator, place);
}
}
#endif
......
......@@ -55,6 +55,15 @@ class GPUContextResource {
~GPUContextResource();
phi::Place Place() const;
std::function<phi::dnnHandle_t()> GetDnnHandleCreator();
std::function<phi::blasHandle_t()> GetBlasHandleCreator();
std::function<phi::blasHandle_t()> GetBlasTensorCoreHandleCreator();
std::function<phi::blasHandle_t()> GetBlasTF32TensorCoreHandleCreator();
std::function<phi::blasLtHandle_t()> GetBlasLtHandleCreator();
std::function<phi::solverHandle_t()> GetSolverDnHandleCreator();
std::function<phi::sparseHandle_t()> GetSparseHandleCreator();
std::function<Eigen::GpuDevice*()> GetGpuEigenDeviceCreator();
gpuStream_t GetStream() const;
dnnHandle_t GetDnnHandle() const;
blasHandle_t GetBlasHandle() const;
......@@ -89,7 +98,6 @@ class GPUContextResource {
void InitGpuEigenDevice();
void InitDnnHanlde();
void DestroyDnnHandle();
void InitBlasHandle();
void DestroyBlasHandle();
void InitBlasLtHandle();
void DestroyBlasLtHandle();
......
......@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/api/all.h"
#include "paddle/phi/core/ddim.h"
namespace paddle {
namespace operators {
......
......@@ -213,7 +213,6 @@ struct GPUContext::Impl {
&max_threads_per_block_,
&max_grid_dim_size_);
phi::InitStream(&stream_);
InitEigenDevice();
InitDnnWorkspace();
}
......@@ -234,7 +233,6 @@ struct GPUContext::Impl {
void PartialInitWithAllocator() {
owned_ = true;
backends::gpu::GPUDeviceGuard guard(place_.device);
InitEigenDevice();
InitDnnWorkspace();
}
......@@ -317,27 +315,49 @@ struct GPUContext::Impl {
void SetEigenDevice(Eigen::GpuDevice* device) { eigen_device_ = device; }
Eigen::GpuDevice* eigen_device() const {
void SetEigenDevice(std::function<Eigen::GpuDevice*()>&& creator) {
eigen_device_creator_ = std::move(creator);
}
Eigen::GpuDevice* eigen_device() {
std::call_once(flag_eigen_device_, [&]() {
if (!eigen_device_) {
if (!eigen_device_creator_)
InitEigenDevice();
else
eigen_device_ = eigen_device_creator_();
}
});
PD_CHECK(eigen_device_ != nullptr, "the gpu eigen_device is nullptr.");
return eigen_device_;
}
blasHandle_t GetBlasHandle() {
std::call_once(flag_blas_, [=]() {
std::call_once(flag_blas_, [&]() {
if (!blas_handle_) {
if (!blas_handle_creator_)
phi::InitBlasHandle(&blas_handle_, stream_);
else
blas_handle_ = blas_handle_creator_();
}
#ifdef PADDLE_WITH_CUDA
#if CUDA_VERSION >= 9000
if (!blas_tensor_core_handle_) {
if (!blas_tensor_core_handle_creator_)
phi::InitBlasHandle(&blas_tensor_core_handle_, stream_);
else
blas_tensor_core_handle_ = blas_tensor_core_handle_creator_();
PADDLE_RETRY_CUDA_SUCCESS(phi::dynload::cublasSetMathMode(
blas_tensor_core_handle_, CUBLAS_TENSOR_OP_MATH));
}
#endif
#if CUDA_VERSION >= 11000
if (!blas_tf32_tensor_core_handle_) {
if (!blas_tf32_tensor_core_handle_creator_)
phi::InitBlasHandle(&blas_tf32_tensor_core_handle_, stream_);
else
blas_tf32_tensor_core_handle_ =
blas_tf32_tensor_core_handle_creator_();
PADDLE_RETRY_CUDA_SUCCESS(phi::dynload::cublasSetMathMode(
blas_tf32_tensor_core_handle_, CUBLAS_TF32_TENSOR_OP_MATH));
}
......@@ -350,27 +370,53 @@ struct GPUContext::Impl {
void SetBlasHandle(blasHandle_t blas) { blas_handle_ = blas; }
void SetBlasHandle(std::function<blasHandle_t()>&& handle_creator) {
blas_handle_creator_ = std::move(handle_creator);
}
void SetBlasTensorCoreHandle(blasHandle_t handle) {
blas_tensor_core_handle_ = handle;
}
void SetBlasTensorCoreHandle(std::function<blasHandle_t()>&& handle_creator) {
blas_tensor_core_handle_creator_ = std::move(handle_creator);
}
void SetBlasTF32Handle(blasHandle_t handle) {
blas_tf32_tensor_core_handle_ = handle;
}
void SetBlasTF32Handle(std::function<blasHandle_t()>&& handle_creator) {
blas_tf32_tensor_core_handle_creator_ = std::move(handle_creator);
}
void SetBlasLtHandle(blasLtHandle_t blaslt) { blaslt_handle_ = blaslt; }
void SetBlasLtHandle(std::function<blasLtHandle_t()>&& handle_creator) {
blaslt_handle_creator_ = std::move(handle_creator);
}
blasLtHandle_t GetBlasLtHandle() {
std::call_once(flag_blaslt_, [=]() {
if (!blaslt_handle_) phi::InitBlasLtHandle(&blaslt_handle_);
std::call_once(flag_blaslt_, [&]() {
if (!blaslt_handle_) {
if (!blaslt_handle_creator_)
phi::InitBlasLtHandle(&blaslt_handle_);
else
blaslt_handle_ = blaslt_handle_creator_();
}
});
PD_CHECK(blaslt_handle_ != nullptr, "the gpu blasLt handle is nullptr.");
return blaslt_handle_;
}
dnnHandle_t GetDnnHandle() {
std::call_once(flag_dnn_, [=]() {
if (!dnn_handle_) phi::InitDnnHandle(&dnn_handle_, stream_, place_);
std::call_once(flag_dnn_, [&]() {
if (!dnn_handle_) {
if (!dnn_handle_creator_)
phi::InitDnnHandle(&dnn_handle_, stream_, place_);
else
dnn_handle_ = dnn_handle_creator_();
}
});
PD_CHECK(dnn_handle_ != nullptr, "the gpu dnn handle is nullptr.");
return dnn_handle_;
......@@ -392,9 +438,18 @@ struct GPUContext::Impl {
void SetDnnHandle(dnnHandle_t handle) { dnn_handle_ = handle; }
void SetDnnHandle(std::function<dnnHandle_t()>&& handle_creator) {
dnn_handle_creator_ = std::move(handle_creator);
}
solverHandle_t GetSolverHandle() {
std::call_once(flag_slover_, [=]() {
if (!solver_handle_) phi::InitSolverHandle(&solver_handle_, stream_);
std::call_once(flag_slover_, [&]() {
if (!solver_handle_) {
if (!solver_handle_creator_)
phi::InitSolverHandle(&solver_handle_, stream_);
else
solver_handle_ = solver_handle_creator_();
}
});
PD_CHECK(solver_handle_ != nullptr, "the gpu solver handle is nullptr.");
return solver_handle_;
......@@ -402,9 +457,18 @@ struct GPUContext::Impl {
void SetSolverHandle(solverHandle_t handle) { solver_handle_ = handle; }
void SetSolverHandle(std::function<solverHandle_t()>&& handle_creator) {
solver_handle_creator_ = std::move(handle_creator);
}
sparseHandle_t GetSparseHandle() {
std::call_once(flag_sparse_, [=]() {
if (!sparse_handle_) phi::InitSparseHandle(&sparse_handle_, stream_);
std::call_once(flag_sparse_, [&]() {
if (!sparse_handle_) {
if (!sparse_handle_creator_)
phi::InitSparseHandle(&sparse_handle_, stream_);
else
sparse_handle_ = sparse_handle_creator_();
}
});
PD_CHECK(sparse_handle_ != nullptr, "the gpu sparse handle is nullptr.");
return sparse_handle_;
......@@ -412,6 +476,10 @@ struct GPUContext::Impl {
void SetSparseHandle(sparseHandle_t handle) { sparse_handle_ = handle; }
void SetSparseHandle(std::function<sparseHandle_t()>&& handle_creator) {
sparse_handle_creator_ = std::move(handle_creator);
}
void Wait() const {
#ifdef PADDLE_WITH_HIP
hipError_t e_sync = hipSuccess;
......@@ -461,21 +529,31 @@ struct GPUContext::Impl {
}
inline void CublasCall(const std::function<void(blasHandle_t)>& callback) {
std::call_once(flag_cublas_, [=]() {
std::call_once(flag_cublas_, [&]() {
if (!blas_handle_) {
if (!blas_handle_creator_)
phi::InitBlasHandle(&blas_handle_, stream_);
else
blas_handle_ = blas_handle_creator_();
}
#ifdef PADDLE_WITH_CUDA
#if CUDA_VERSION >= 9000
if (!blas_tensor_core_handle_) {
if (!blas_tensor_core_handle_creator_)
phi::InitBlasHandle(&blas_tensor_core_handle_, stream_);
else
blas_tensor_core_handle_ = blas_tensor_core_handle_creator_();
PADDLE_RETRY_CUDA_SUCCESS(phi::dynload::cublasSetMathMode(
blas_tensor_core_handle_, CUBLAS_TENSOR_OP_MATH));
}
#endif
#if CUDA_VERSION >= 11000
if (!blas_tf32_tensor_core_handle_) {
if (!blas_tf32_tensor_core_handle_creator_)
phi::InitBlasHandle(&blas_tf32_tensor_core_handle_, stream_);
else
blas_tf32_tensor_core_handle_ =
blas_tf32_tensor_core_handle_creator_();
PADDLE_RETRY_CUDA_SUCCESS(phi::dynload::cublasSetMathMode(
blas_tf32_tensor_core_handle_, CUBLAS_TF32_TENSOR_OP_MATH));
}
......@@ -493,19 +571,31 @@ struct GPUContext::Impl {
inline void TensorCoreCublasCallIfAvailable(
const std::function<void(blasHandle_t)>& callback) {
std::call_once(flag_tensorcore_cublas_, [=]() {
if (!blas_handle_) phi::InitBlasHandle(&blas_handle_, stream_);
std::call_once(flag_tensorcore_cublas_, [&]() {
if (!blas_handle_) {
if (!blas_handle_creator_)
phi::InitBlasHandle(&blas_handle_, stream_);
else
blas_handle_ = blas_handle_creator_();
}
#ifdef PADDLE_WITH_CUDA
#if CUDA_VERSION >= 9000
if (!blas_tensor_core_handle_) {
if (!blas_tensor_core_handle_creator_)
phi::InitBlasHandle(&blas_tensor_core_handle_, stream_);
else
blas_tensor_core_handle_ = blas_tensor_core_handle_creator_();
PADDLE_RETRY_CUDA_SUCCESS(phi::dynload::cublasSetMathMode(
blas_tensor_core_handle_, CUBLAS_TENSOR_OP_MATH));
}
#endif
#if CUDA_VERSION >= 11000
if (!blas_tf32_tensor_core_handle_) {
if (!blas_tf32_tensor_core_handle_creator_)
phi::InitBlasHandle(&blas_tf32_tensor_core_handle_, stream_);
else
blas_tf32_tensor_core_handle_ =
blas_tf32_tensor_core_handle_creator_();
PADDLE_RETRY_CUDA_SUCCESS(phi::dynload::cublasSetMathMode(
blas_tf32_tensor_core_handle_, CUBLAS_TF32_TENSOR_OP_MATH));
}
......@@ -523,9 +613,12 @@ struct GPUContext::Impl {
inline void CusparseCall(
const std::function<void(sparseHandle_t)>& callback) {
std::call_once(flag_sparse_, [=]() {
std::call_once(flag_sparse_, [&]() {
if (!sparse_handle_) {
if (!sparse_handle_creator_)
phi::InitSparseHandle(&sparse_handle_, stream_);
else
sparse_handle_ = sparse_handle_creator_();
}
});
std::lock_guard<std::mutex> guard(sparse_mtx_);
......@@ -597,13 +690,21 @@ struct GPUContext::Impl {
gpuStream_t stream_{nullptr};
Eigen::GpuDevice* eigen_device_{nullptr};
std::function<Eigen::GpuDevice*()> eigen_device_creator_{nullptr};
blasHandle_t blas_handle_{nullptr};
std::function<blasHandle_t()> blas_handle_creator_{nullptr};
blasHandle_t blas_tensor_core_handle_{nullptr};
std::function<blasHandle_t()> blas_tensor_core_handle_creator_{nullptr};
blasHandle_t blas_tf32_tensor_core_handle_{nullptr};
std::function<blasHandle_t()> blas_tf32_tensor_core_handle_creator_{nullptr};
blasLtHandle_t blaslt_handle_{nullptr};
std::function<blasLtHandle_t()> blaslt_handle_creator_{nullptr};
dnnHandle_t dnn_handle_{nullptr};
std::function<dnnHandle_t()> dnn_handle_creator_{nullptr};
solverHandle_t solver_handle_{nullptr};
std::function<solverHandle_t()> solver_handle_creator_{nullptr};
sparseHandle_t sparse_handle_{nullptr};
std::function<sparseHandle_t()> sparse_handle_creator_{nullptr};
DnnWorkspaceHandle* workspace_{nullptr};
std::once_flag flag_sparse_;
......@@ -613,6 +714,7 @@ struct GPUContext::Impl {
std::once_flag flag_slover_;
std::once_flag flag_cublas_;
std::once_flag flag_tensorcore_cublas_;
std::once_flag flag_eigen_device_;
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
// NCCL communicator (single process version) for NCCL collective operations.
......@@ -752,34 +854,66 @@ void GPUContext::SetEigenDevice(Eigen::GpuDevice* device) {
impl_->SetEigenDevice(device);
}
void GPUContext::SetEigenDevice(std::function<Eigen::GpuDevice*()>&& creator) {
impl_->SetEigenDevice(std::move(creator));
}
void GPUContext::SetBlasHandle(blasHandle_t blas) {
impl_->SetBlasHandle(blas);
}
void GPUContext::SetBlasHandle(std::function<blasHandle_t()>&& func) {
impl_->SetBlasHandle(std::move(func));
}
void GPUContext::SetBlasTensorCoreHandle(blasHandle_t handle) {
impl_->SetBlasTensorCoreHandle(handle);
}
void GPUContext::SetBlasTensorCoreHandle(std::function<blasHandle_t()>&& func) {
impl_->SetBlasTensorCoreHandle(std::move(func));
}
void GPUContext::SetBlasTF32Handle(blasHandle_t handle) {
impl_->SetBlasTF32Handle(handle);
}
void GPUContext::SetBlasTF32Handle(std::function<blasHandle_t()>&& func) {
impl_->SetBlasTF32Handle(std::move(func));
}
void GPUContext::SetBlasLtHandle(blasLtHandle_t blaslt) {
impl_->SetBlasLtHandle(blaslt);
}
void GPUContext::SetBlasLtHandle(std::function<blasLtHandle_t()>&& func) {
impl_->SetBlasLtHandle(std::move(func));
}
void GPUContext::SetDnnHandle(dnnHandle_t handle) {
impl_->SetDnnHandle(handle);
}
void GPUContext::SetDnnHandle(std::function<dnnHandle_t()>&& func) {
impl_->SetDnnHandle(std::move(func));
}
void GPUContext::SetSolverHandle(solverHandle_t handle) {
impl_->SetSolverHandle(handle);
}
void GPUContext::SetSolverHandle(std::function<solverHandle_t()>&& func) {
impl_->SetSolverHandle(std::move(func));
}
void GPUContext::SetSparseHandle(sparseHandle_t handle) {
impl_->SetSparseHandle(handle);
}
void GPUContext::SetSparseHandle(std::function<sparseHandle_t()>&& func) {
impl_->SetSparseHandle(std::move(func));
}
void GPUContext::SetDnnWorkspaceHandle(DnnWorkspaceHandle* handle) {
impl_->workspace_ = handle;
}
......
......@@ -197,20 +197,28 @@ class PADDLE_API GPUContext : public DeviceContext {
void SetStream(gpuStream_t);
void SetEigenDevice(Eigen::GpuDevice*);
void SetEigenDevice(std::function<Eigen::GpuDevice*()>&&);
void SetBlasHandle(blasHandle_t);
void SetBlasHandle(std::function<blasHandle_t()>&&);
void SetBlasTensorCoreHandle(blasHandle_t);
void SetBlasTensorCoreHandle(std::function<blasHandle_t()>&&);
void SetBlasTF32Handle(blasHandle_t);
void SetBlasTF32Handle(std::function<blasHandle_t()>&&);
void SetBlasLtHandle(blasLtHandle_t);
void SetBlasLtHandle(std::function<blasLtHandle_t()>&&);
void SetDnnHandle(dnnHandle_t);
void SetDnnHandle(std::function<dnnHandle_t()>&&);
void SetSolverHandle(solverHandle_t);
void SetSolverHandle(std::function<solverHandle_t()>&&);
void SetSparseHandle(sparseHandle_t);
void SetSparseHandle(std::function<sparseHandle_t()>&&);
void SetDnnWorkspaceHandle(DnnWorkspaceHandle*);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册