Unverified Commit 3bad26ec authored by xiaoxiaohehe001, committed by GitHub

convfusion_cache (#45902)

Parent d6b5d91c
@@ -55,7 +55,8 @@ class ConvSearchCache {
   AlgorithmsCache<cudnnConvolutionBwdFilterAlgo_t>* GetBackwardFilter() {
     return &backward_filter_cache_;
   }
-  AlgorithmsCache<cudnnConvolutionFwdAlgo_t>* GetConvFusion() {
+  AlgorithmsCache<SearchFuseResult<cudnnConvolutionFwdAlgo_t>>*
+  GetConvFusion() {
     return &fusion_forward_cache_;
   }
 #endif
@@ -75,7 +76,8 @@ class ConvSearchCache {
   AlgorithmsCache<cudnnConvolutionFwdAlgo_t> forward_cache_;
   AlgorithmsCache<cudnnConvolutionBwdDataAlgo_t> backward_data_cache_;
   AlgorithmsCache<cudnnConvolutionBwdFilterAlgo_t> backward_filter_cache_;
-  AlgorithmsCache<cudnnConvolutionFwdAlgo_t> fusion_forward_cache_;
+  AlgorithmsCache<SearchFuseResult<cudnnConvolutionFwdAlgo_t>>
+      fusion_forward_cache_;
 #endif
 };
......
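The hunks above are in `ConvSearchCache`, the process-wide singleton that hands out one `AlgorithmsCache` per convolution direction; the commit only widens the value type of the fusion cache from a bare `cudnnConvolutionFwdAlgo_t` to `SearchFuseResult<cudnnConvolutionFwdAlgo_t>` (defined in the next file). A minimal, self-contained sketch of the accessor pattern, with hypothetical stand-in types in place of the cuDNN and Paddle ones:

```cpp
#include <cstddef>

// Stand-ins for cudnnConvolutionFwdAlgo_t, SearchFuseResult, and
// AlgorithmsCache -- hypothetical, only to keep the sketch compilable.
enum FakeFwdAlgo { kFakeImplicitGemm = 0 };
template <typename AlgoT>
struct FakeFuseResult {
  AlgoT algo = static_cast<AlgoT>(0);
  size_t workspace_size = 0;
};
template <typename T>
class FakeAlgorithmsCache { /* keyed storage elided */ };

class ConvSearchCacheSketch {
 public:
  // Meyers singleton: one instance per process, thread-safe since C++11.
  static ConvSearchCacheSketch& Instance() {
    static ConvSearchCacheSketch instance;
    return instance;
  }
  // After this commit, the fusion cache stores the whole search result.
  FakeAlgorithmsCache<FakeFuseResult<FakeFwdAlgo>>* GetConvFusion() {
    return &fusion_forward_cache_;
  }

 private:
  ConvSearchCacheSketch() = default;
  FakeAlgorithmsCache<FakeFuseResult<FakeFwdAlgo>> fusion_forward_cache_;
};
```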
@@ -24,6 +24,16 @@ limitations under the License. */
 namespace paddle {
 namespace framework {
 
+template <typename AlgoT>
+struct SearchFuseResult {
+  SearchFuseResult() {}
+  explicit SearchFuseResult(AlgoT a) : algo(a) {}
+  AlgoT algo = static_cast<AlgoT>(0);
+  float time = -1.f;
+  size_t workspace_size = 0;
+};
+
 // thread-safe.
 template <typename TAlgorithm>
 class AlgorithmsCache {
......
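`SearchFuseResult` bundles the chosen algorithm with its benchmark time (`-1.f` meaning "not measured") and its workspace requirement; `AlgorithmsCache` is only declared here. A hypothetical simplification of such a thread-safe memoizing cache, storing the struct rather than a bare enum (the real `GetAlgorithm` also folds dims, strides, paddings, and dtype into the key):

```cpp
#include <cstdint>
#include <functional>
#include <mutex>
#include <unordered_map>

// Hypothetical simplification in the spirit of framework::AlgorithmsCache;
// not Paddle's actual implementation.
template <typename TAlgorithm>
class SimpleAlgorithmsCache {
 public:
  TAlgorithm GetAlgorithm(int64_t key,
                          const std::function<TAlgorithm()>& search_func) {
    std::lock_guard<std::mutex> guard(mutex_);
    auto it = map_.find(key);
    if (it != map_.end()) return it->second;  // hit: reuse the cached result
    TAlgorithm result = search_func();        // miss: run the expensive search
    map_.emplace(key, result);
    return result;
  }

 private:
  std::mutex mutex_;
  std::unordered_map<int64_t, TAlgorithm> map_;
};
```

With `TAlgorithm = SearchFuseResult<cudnnConvolutionFwdAlgo_t>`, a cache hit now returns the workspace size measured at search time together with the algorithm, which is what lets the kernel below drop its per-call `cudnnGetConvolutionForwardWorkspaceSize` query.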
@@ -35,6 +35,7 @@ using ScopedActivationDescriptor = platform::ScopedActivationDescriptor;
 using DataLayout = platform::DataLayout;
 using framework::AlgorithmsCache;
 using framework::ConvSearchCache;
+using framework::SearchFuseResult;
 
 template <typename T>
 using ScalingParamType = typename platform::CudnnDataType<T>::ScalingParamType;
@@ -348,34 +349,35 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel<T> {
               &perf_count,
               perf_results.get()));
       algo = (perf_results.get())[best_algo_idx].algo;
-      PADDLE_ENFORCE_GPU_SUCCESS(
-          platform::dynload::cudnnGetConvolutionForwardWorkspaceSize(
-              handle,
-              cudnn_input_desc,
-              cudnn_filter_desc,
-              cudnn_conv_desc,
-              cudnn_output_desc,
-              algo,
-              &workspace_size_in_bytes));
-      if (workspace_size_in_bytes > workspace_size_limit)
-        workspace_size_limit = workspace_size_in_bytes;
 #else
       PADDLE_ENFORCE_GPU_SUCCESS(
           platform::dynload::cudnnGetConvolutionForwardAlgorithm(
               handle,
               cudnn_input_desc,
               cudnn_filter_desc,
               cudnn_conv_desc,
               cudnn_output_desc,
               CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT,
               workspace_size_limit,
               &algo));
-      VLOG(3) << "cuDNN forward algo " << algo;
 #endif
+      PADDLE_ENFORCE_GPU_SUCCESS(
+          platform::dynload::cudnnGetConvolutionForwardWorkspaceSize(
+              handle,
+              cudnn_input_desc,
+              cudnn_filter_desc,
+              cudnn_conv_desc,
+              cudnn_output_desc,
+              algo,
+              &workspace_size_in_bytes));
+      if (workspace_size_in_bytes > workspace_size_limit)
+        workspace_size_limit = workspace_size_in_bytes;
+      VLOG(3) << "cuDNN forward algo " << algo;
     } else {
-      std::function<cudnnConvolutionFwdAlgo_t()> search_func =
-          [&]() -> cudnnConvolutionFwdAlgo_t {
+      std::function<SearchFuseResult<cudnnConvolutionFwdAlgo_t>()> search_func =
+          [&]() -> SearchFuseResult<cudnnConvolutionFwdAlgo_t> {
         int returned_algo_count;
+        SearchFuseResult<cudnnConvolutionFwdAlgo_t> fwd_result;
         std::array<cudnnConvolutionFwdAlgoPerf_t, kNUM_CUDNN_FWD_ALGS>
             fwd_perf_stat;
         auto cudnn_find_func = [&](void* cudnn_workspace) {
@@ -402,11 +404,34 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel<T> {
           VLOG(3) << stat.algo << ": " << stat.status << " " << stat.time << " "
                   << stat.memory;
         }
-        return fwd_perf_stat[0].algo;
+
+        PADDLE_ENFORCE_GPU_SUCCESS(
+            platform::dynload::cudnnGetConvolutionForwardWorkspaceSize(
+                handle,
+                cudnn_input_desc,
+                cudnn_filter_desc,
+                cudnn_conv_desc,
+                cudnn_output_desc,
+                fwd_perf_stat[0].algo,
+                &workspace_size_in_bytes));
+        // PADDLE_ENFORCE_LE(
+        //     workspace_size_in_bytes,
+        //     workspace_size_limit,
+        //     platform::errors::InvalidArgument(
+        //         "The actual workspace size to be allocated for cuDNN is "
+        //         "expected to be less than the limit. But received: the "
+        //         "actual workspace size = %d, limit = %d.",
+        //         workspace_size_in_bytes,
+        //         workspace_size_limit));
+        fwd_result.algo = fwd_perf_stat[0].algo;
+        fwd_result.workspace_size = workspace_size_in_bytes;
+        return fwd_result;
       };
-      AlgorithmsCache<cudnnConvolutionFwdAlgo_t>& algo_cache =
+      AlgorithmsCache<SearchFuseResult<cudnnConvolutionFwdAlgo_t>>& algo_cache =
           *(framework::ConvSearchCache::Instance().GetConvFusion());
       int search_times = ctx.Attr<int>("search_times");
+      SearchFuseResult<cudnnConvolutionFwdAlgo_t> algo_result;
       search_times = std::max(
           static_cast<int>(FLAGS_cudnn_exhaustive_search_times), search_times);
       // TODO(dangqingqing): Unify this if-else.
@@ -414,10 +439,12 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel<T> {
         // The searched algo will be cached by `search_times` times for
         // different input dimension. For other dimensions, select the algo
         // of closest area.
-        algo = algo_cache.GetAlgorithm(
+        algo_result = algo_cache.GetAlgorithm(
             x_dims[2] * x_dims[3], search_times, 0, search_func);
+        algo = algo_result.algo;
+        workspace_size_in_bytes = algo_result.workspace_size;
       } else {
-        algo = algo_cache.GetAlgorithm(x_dims,
+        algo_result = algo_cache.GetAlgorithm(x_dims,
                                        f_dims,
                                        strides,
                                        paddings,
@@ -425,28 +452,11 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel<T> {
                                        0,
                                        dtype,
                                        search_func);
+        algo = algo_result.algo;
+        workspace_size_in_bytes = algo_result.workspace_size;
       }
       VLOG(3) << "choose algo " << algo;
     }
-    PADDLE_ENFORCE_GPU_SUCCESS(
-        platform::dynload::cudnnGetConvolutionForwardWorkspaceSize(
-            handle,
-            cudnn_input_desc,
-            cudnn_filter_desc,
-            cudnn_conv_desc,
-            cudnn_output_desc,
-            algo,
-            &workspace_size_in_bytes));
-    // PADDLE_ENFORCE_LE(
-    //     workspace_size_in_bytes,
-    //     workspace_size_limit,
-    //     platform::errors::InvalidArgument(
-    //         "The actual workspace size to be allocated for cuDNN is "
-    //         "expected to be less than the limit. But received: the "
-    //         "actual workspace size = %d, limit = %d.",
-    //         workspace_size_in_bytes,
-    //         workspace_size_limit));
 
     if ((activation == "identity") && (!residual)) {
       // Only the CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM algo is
       // enabled with CUDNN_ACTIVATION_IDENTITY in cuDNN lib.
......
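Net effect in the kernel: the exhaustive-search functor now returns the algorithm together with the workspace size it measured, so the unconditional `cudnnGetConvolutionForwardWorkspaceSize` call at the bottom could be deleted, and on a cache hit both values come straight from the cache. A toy, self-contained rendering of that contract change (hypothetical stand-ins for the cuDNN types; the real code calls `cudnnFindConvolutionForwardAlgorithmEx`):

```cpp
#include <cstddef>
#include <functional>

// Hypothetical stand-ins for cudnnConvolutionFwdAlgo_t / SearchFuseResult.
enum class FwdAlgo { kImplicitGemm, kWinograd };

template <typename AlgoT>
struct FuseResult {
  AlgoT algo{};
  float time = -1.f;          // -1 marks "not benchmarked yet"
  size_t workspace_size = 0;  // bytes, measured once at search time
};

int main() {
  // Old contract: std::function<FwdAlgo()>; the new one returns the whole
  // result, so the caller never re-queries the workspace size.
  std::function<FuseResult<FwdAlgo>()> search_func = [] {
    FuseResult<FwdAlgo> r;
    r.algo = FwdAlgo::kWinograd;  // fastest algorithm found by the benchmark
    r.time = 0.37f;               // its measured runtime in ms (made up)
    r.workspace_size = 1 << 20;   // its workspace requirement (made up)
    return r;
  };
  FuseResult<FwdAlgo> algo_result = search_func();
  // Mirrors the kernel: unpack both fields from the (cached) result.
  FwdAlgo algo = algo_result.algo;
  size_t workspace_size_in_bytes = algo_result.workspace_size;
  return (algo == FwdAlgo::kWinograd && workspace_size_in_bytes > 0) ? 0 : 1;
}
```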