Unverified commit b9494058, authored by Huihuang Zheng, committed by GitHub

Use CudnnWorkspaceHandle in exhaustive search (#17082)

1. Use CudnnWorkspaceHandle in the exhaustive search of conv_cudnn.
2. For Ops that use CudnnWorkspaceHandle in exhaustive search, release their GPU workspace memory after the search finishes.

test=develop
Parent 2192e7bb
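For context, the change below swaps the manually allocated temporary workspace tensor for the device context's CudnnWorkspaceHandle. A minimal sketch of the resulting call pattern, simplified from the diff that follows (all identifiers are taken from the diff; the surrounding descriptor and data setup in the kernel is elided), might look like this:

```cpp
// Sketch only: identifiers come from the diff below; the descriptor/data
// setup earlier in CUDNNConvOpKernel::Compute is elided.
auto handle = dev_ctx.cudnn_handle();
auto workspace_handle = dev_ctx.cudnn_workspace_handle();

int returned_algo_count = 0;
std::array<cudnnConvolutionFwdAlgoPerf_t, kNUM_CUDNN_FWD_ALGS> fwd_perf_stat;

// The lambda receives the workspace pointer owned by the handle instead of a
// temporary tensor allocated by the kernel itself.
auto cudnn_find_func = [&](void* cudnn_workspace) {
  CUDNN_ENFORCE(platform::dynload::cudnnFindConvolutionForwardAlgorithmEx(
      handle, cudnn_input_desc, input_data, cudnn_filter_desc, filter_data,
      cudnn_conv_desc, cudnn_output_desc, output_data, kNUM_CUDNN_FWD_ALGS,
      &returned_algo_count, fwd_perf_stat.data(), cudnn_workspace,
      workspace_size_limit));
};

// RunFuncSync runs the search under the holder's lock and releases the
// workspace memory afterwards (point 2 of the commit message).
workspace_handle.RunFuncSync(cudnn_find_func, workspace_size_limit);
```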
@@ -139,9 +139,8 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
// ------------------- cudnn conv algorithm ---------------------
cudnnConvolutionFwdAlgo_t algo;
auto handle = dev_ctx.cudnn_handle();
bool half_float = false;
#if CUDA_VERSION >= 9000 && CUDNN_VERSION_MIN(7, 0, 1)
// Tensor core is supported since the volta GPU and
// is only enabled when input and filter data are float16
@@ -160,9 +159,9 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
VLOG(5) << "NOT use cudnn_tensor_op_math";
}
#endif
Tensor cudnn_workspace;
void* cudnn_workspace_ptr = nullptr;
auto handle = dev_ctx.cudnn_handle();
auto workspace_handle = dev_ctx.cudnn_workspace_handle();
auto x_dims = framework::vectorize(input->dims());
auto f_dims = framework::vectorize(filter->dims());
if ((!exhaustive_search) && (!half_float)) {
@@ -174,12 +173,6 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
} else if (exhaustive_search && (!half_float)) {
AlgorithmsCache<cudnnConvolutionFwdAlgo_t>& algo_cache =
ctx.GetKernelConfig<AlgorithmsCache<cudnnConvolutionFwdAlgo_t>>(0);
cudnn_workspace =
ctx.AllocateTmpTensor<int8_t, platform::CUDADeviceContext>(
framework::make_ddim(
{static_cast<int64_t>(workspace_size_limit)}),
dev_ctx);
cudnn_workspace_ptr = static_cast<void*>(cudnn_workspace.data<int8_t>());
algo = algo_cache.GetAlgorithm(
x_dims, f_dims, strides, paddings, dilations, 0, [&]() {
@@ -187,13 +180,16 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
std::array<cudnnConvolutionFwdAlgoPerf_t, kNUM_CUDNN_FWD_ALGS>
fwd_perf_stat;
CUDNN_ENFORCE(
platform::dynload::cudnnFindConvolutionForwardAlgorithmEx(
handle, cudnn_input_desc, input_data, cudnn_filter_desc,
filter_data, cudnn_conv_desc, cudnn_output_desc,
output_data, kNUM_CUDNN_FWD_ALGS, &returned_algo_count,
fwd_perf_stat.data(), cudnn_workspace_ptr,
workspace_size_limit));
auto cudnn_find_func = [&](void* cudnn_workspace) {
CUDNN_ENFORCE(
platform::dynload::cudnnFindConvolutionForwardAlgorithmEx(
handle, cudnn_input_desc, input_data, cudnn_filter_desc,
filter_data, cudnn_conv_desc, cudnn_output_desc,
output_data, kNUM_CUDNN_FWD_ALGS, &returned_algo_count,
fwd_perf_stat.data(), cudnn_workspace,
workspace_size_limit));
};
workspace_handle.RunFuncSync(cudnn_find_func, workspace_size_limit);
VLOG(3) << "Perf result: (algo: stat, time, memory)";
for (int i = 0; i < returned_algo_count; ++i) {
@@ -219,14 +215,13 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
"workspace_size to be allocated exceeds the limit");
// Allocate on GPU memory
if (!cudnn_workspace_ptr) {
cudnn_workspace =
ctx.AllocateTmpTensor<int8_t, platform::CUDADeviceContext>(
framework::make_ddim(
{static_cast<int64_t>(workspace_size_in_bytes)}),
dev_ctx);
cudnn_workspace_ptr = static_cast<void*>(cudnn_workspace.data<int8_t>());
}
Tensor cudnn_workspace =
ctx.AllocateTmpTensor<int8_t, platform::CUDADeviceContext>(
framework::make_ddim(
{static_cast<int64_t>(workspace_size_in_bytes)}),
dev_ctx);
void* cudnn_workspace_ptr =
static_cast<void*>(cudnn_workspace.data<int8_t>());
// ------------------- cudnn conv forward ---------------------
ScalingParamType<T> alpha = 1.0f, beta = 0.0f;
for (int i = 0; i < groups; i++) {
......
@@ -18,7 +18,7 @@ limitations under the License. */
DEFINE_int64(cudnn_exhaustive_search_times, -1,
"Exhaustive search times for cuDNN convolution, "
"defalut is 1, only search once.");
"default is -1, no exhaustive search");
namespace paddle {
namespace operators {
@@ -132,7 +132,7 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel<T> {
kNUM_CUDNN_FWD_ALGS, &returned_algo_count,
fwd_perf_stat.data(), cudnn_workspace, workspace_size_limit));
};
workspace_handle.RunFunc(cudnn_find_func, workspace_size_limit);
workspace_handle.RunFuncSync(cudnn_find_func, workspace_size_limit);
VLOG(3) << "Perf result: (algo: stat, time, memory)";
for (int i = 0; i < returned_algo_count; ++i) {
const auto& stat = fwd_perf_stat[i];
......
@@ -163,6 +163,15 @@ class CudnnHolder {
cudnn_func(WorkspacePtr());
}
/*! \brief Reset the workspace, thus releasing the memory */
inline void ResetWorkspace() {
if (workspace_) {
// Maybe someone is using the current workspace
PADDLE_ENFORCE(cudaStreamSynchronize(*stream_));
workspace_ = nullptr;
}
}
inline void* WorkspacePtr() {
if (workspace_) {
return workspace_->ptr();
@@ -207,6 +216,22 @@ class CudnnWorkspaceHandle {
required_workspace_len);
}
/*! \brief A thread that calls RunFuncSync() acquires the lock first
* before invoking the cudnn function, and releases the GPU memory after
* running the function. Currently this function is only used for cudnn
* exhaustive search, and callers have to guarantee that the input
* function is host blocking */
template <typename Callback>
inline void RunFuncSync(Callback&& cudnn_func,
size_t required_workspace_len) {
if (!guard_) {
guard_.reset(new std::lock_guard<std::mutex>(holder_->Mutex()));
}
holder_->RunFuncImpl(std::forward<Callback>(cudnn_func),
required_workspace_len);
holder_->ResetWorkspace();
}
CudnnWorkspaceHandle(CudnnWorkspaceHandle&&) = default;
CudnnWorkspaceHandle& operator=(CudnnWorkspaceHandle&&) = delete;
......
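To make the new synchronization behavior concrete, here is a standalone, illustrative sketch of the lock / run / reset discipline that RunFuncSync adds on top of RunFunc. The WorkspaceHolder type below is a hypothetical simplification for illustration only, not the actual Paddle CudnnHolder or CudnnWorkspaceHandle; error handling for the CUDA calls is omitted.

```cpp
#include <cuda_runtime.h>

#include <cstddef>
#include <mutex>

// Hypothetical, simplified stand-in for the workspace holder: one buffer,
// grown on demand, shared by all cudnn calls issued on the same stream.
struct WorkspaceHolder {
  std::mutex mutex;
  void* workspace = nullptr;
  size_t len = 0;
  cudaStream_t stream = nullptr;  // default stream for this sketch

  void EnsureWorkspace(size_t required) {
    if (len < required) {
      ResetWorkspace();
      cudaMalloc(&workspace, required);  // error handling omitted
      len = required;
    }
  }

  // Mirrors CudnnHolder::ResetWorkspace in the diff: wait for in-flight work
  // on the stream, then drop the buffer so the memory can be reused elsewhere.
  void ResetWorkspace() {
    if (workspace != nullptr) {
      cudaStreamSynchronize(stream);
      cudaFree(workspace);
      workspace = nullptr;
      len = 0;
    }
  }

  // Mirrors CudnnWorkspaceHandle::RunFuncSync: serialize access, run the
  // host-blocking search with the workspace, then release the GPU memory.
  template <typename Callback>
  void RunFuncSync(Callback&& cudnn_func, size_t required_workspace_len) {
    std::lock_guard<std::mutex> guard(mutex);
    EnsureWorkspace(required_workspace_len);
    cudnn_func(workspace);  // caller guarantees this call is host blocking
    ResetWorkspace();
  }
};
```

Releasing the buffer after the search matters because the exhaustive search runs with workspace_size_limit rather than the workspace size of the algorithm that is eventually chosen, so keeping that allocation alive would pin a large block of GPU memory well past the point where it is needed.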