From ce7d9b079947e55f23d7653432732e498f723274 Mon Sep 17 00:00:00 2001 From: qingqing01 Date: Wed, 7 Nov 2018 09:56:06 +0800 Subject: [PATCH] Exhaustive search for cuDNN conv. (#14043) * exhaustive search for cuDNN conv. * Refine code and add unit testing. * Clean code * Fix model load in fluid/inference and unit testing in conv2d * Follow comments. --- .../framework/ir/graph_pattern_detector.cc | 1 + .../fluid/inference/api/analysis_predictor.h | 2 + paddle/fluid/inference/api/helper.h | 3 +- paddle/fluid/inference/io.cc | 3 +- .../operators/add_position_encoding_op.h | 7 +- paddle/fluid/operators/conv_cudnn_op.cu.cc | 204 ++++++++++++++++-- paddle/fluid/operators/conv_cudnn_op_cache.h | 90 ++++++++ paddle/fluid/operators/conv_op.cc | 11 +- paddle/fluid/platform/device_context.cc | 5 +- paddle/fluid/platform/dynload/cudnn.h | 93 ++++---- python/paddle/fluid/__init__.py | 3 +- python/paddle/fluid/layers/nn.py | 17 +- .../fluid/tests/unittests/test_conv2d_op.py | 10 +- .../fluid/tests/unittests/test_conv3d_op.py | 6 + 14 files changed, 381 insertions(+), 74 deletions(-) create mode 100644 paddle/fluid/operators/conv_cudnn_op_cache.h diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc index b20d7013225..fa713fe1dd5 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.cc +++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include #include #include #include diff --git a/paddle/fluid/inference/api/analysis_predictor.h b/paddle/fluid/inference/api/analysis_predictor.h index b7dc2067332..a9f4cce6dfa 100644 --- a/paddle/fluid/inference/api/analysis_predictor.h +++ b/paddle/fluid/inference/api/analysis_predictor.h @@ -13,6 +13,8 @@ // limitations under the License. #pragma once +#include +#include #include #include #include "paddle/fluid/framework/naive_executor.h" diff --git a/paddle/fluid/inference/api/helper.h b/paddle/fluid/inference/api/helper.h index e46dc132695..af21c0095c2 100644 --- a/paddle/fluid/inference/api/helper.h +++ b/paddle/fluid/inference/api/helper.h @@ -16,13 +16,14 @@ #include #include +#include #include // NOLINT #include #include #include #include +#include "paddle/fluid/inference/api/paddle_inference_api.h" #include "paddle/fluid/string/printf.h" -#include "paddle_inference_api.h" namespace paddle { namespace inference { diff --git a/paddle/fluid/inference/io.cc b/paddle/fluid/inference/io.cc index e246a06fd07..31f43bfdcaa 100644 --- a/paddle/fluid/inference/io.cc +++ b/paddle/fluid/inference/io.cc @@ -59,7 +59,8 @@ void ReadBinaryFile(const std::string& filename, std::string* contents) { bool IsPersistable(const framework::VarDesc* var) { if (var->Persistable() && var->GetType() != framework::proto::VarType::FEED_MINIBATCH && - var->GetType() != framework::proto::VarType::FETCH_LIST) { + var->GetType() != framework::proto::VarType::FETCH_LIST && + var->GetType() != framework::proto::VarType::RAW) { return true; } return false; diff --git a/paddle/fluid/operators/add_position_encoding_op.h b/paddle/fluid/operators/add_position_encoding_op.h index 5f371235f16..0b40d3de890 100644 --- a/paddle/fluid/operators/add_position_encoding_op.h +++ b/paddle/fluid/operators/add_position_encoding_op.h @@ -66,9 +66,10 @@ class AddPositionEncodingKernel : public framework::OpKernel { x_lod.empty() ? max_seq_len : x_lod[0][i + 1] - x_lod[0][i]; for (int j = 0; j < max_length; ++j) { for (int k = 0; k < half_size; ++k) { - const double val = (half_size > 1) - ? j / pow(10000.0, double(k) / (half_size - 1)) - : j / 10000.0; + const double val = + (half_size > 1) + ? j / pow(10000.0, static_cast(k) / (half_size - 1)) + : j / 10000.0; dst_ptr[k] = src_ptr[k] * alpha + sin(val) * beta; dst_ptr[half_size + k] = src_ptr[half_size + k] * alpha + cos(val) * beta; diff --git a/paddle/fluid/operators/conv_cudnn_op.cu.cc b/paddle/fluid/operators/conv_cudnn_op.cu.cc index c37032bf090..1f4a95c5e7e 100644 --- a/paddle/fluid/operators/conv_cudnn_op.cu.cc +++ b/paddle/fluid/operators/conv_cudnn_op.cu.cc @@ -15,15 +15,22 @@ limitations under the License. */ #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/memory/memory.h" +#include "paddle/fluid/operators/conv_cudnn_op_cache.h" #include "paddle/fluid/operators/conv_op.h" #include "paddle/fluid/platform/assert.h" #include "paddle/fluid/platform/cudnn_helper.h" #include "paddle/fluid/platform/float16.h" +#include "paddle/fluid/platform/profiler.h" DEFINE_bool(cudnn_deterministic, false, "Whether allow using an autotuning algorithm for convolution " "operator. The autotuning algorithm may be non-deterministic. If " "true, the algorithm is deterministic."); +DEFINE_uint64(conv_workspace_size_limit, 4096, + "cuDNN convolution workspace limit in MB unit."); +DEFINE_bool(cudnn_exhaustive_search, false, + "Whether enable exhaustive search for cuDNN convolution or " + "not, defalut is False."); namespace paddle { namespace operators { @@ -36,13 +43,25 @@ using DataLayout = platform::DataLayout; template using ScalingParamType = typename platform::CudnnDataType::ScalingParamType; +static constexpr char kCUDNNFwdAlgoCache[] = "kCUDNNFwdAlgoCache"; +static constexpr char kCUDNNBwdDataAlgoCache[] = "kCUDNNBwdDataAlgoCache"; +static constexpr char kCUDNNBwdFilterAlgoCache[] = "kCUDNNBwdFilterAlgoCache"; + static constexpr size_t kCONV_CUDNN_WORKSPACE_LIMIT_BYTES = static_cast(1024) * 1024 * 1024; +static constexpr size_t kNUM_CUDNN_FWD_ALGS = + CUDNN_CONVOLUTION_BWD_FILTER_ALGO_COUNT; +static constexpr size_t kNUM_CUDNN_BWD_FILTER_ALGS = + CUDNN_CONVOLUTION_BWD_FILTER_ALGO_COUNT; +static constexpr size_t kNUM_CUDNN_BWD_DATA_ALGS = + CUDNN_CONVOLUTION_BWD_DATA_ALGO_COUNT; + template class CUDNNConvOpKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { + auto& dev_ctx = ctx.template device_context(); PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), "It must use CUDAPlace."); auto* input = ctx.Input("Input"); @@ -55,6 +74,8 @@ class CUDNNConvOpKernel : public framework::OpKernel { int groups = ctx.Attr("groups"); int64_t user_workspace_size = static_cast(ctx.Attr("workspace_size_MB")); + bool exhaustive_search = + FLAGS_cudnn_exhaustive_search || ctx.Attr("exhaustive_search"); const T* input_data = input->data(); const T* filter_data = filter->data(); @@ -120,19 +141,18 @@ class CUDNNConvOpKernel : public framework::OpKernel { // ------------------- cudnn conv workspace --------------------- size_t workspace_size_in_bytes; // final workspace to allocate. size_t workspace_size_limit = kCONV_CUDNN_WORKSPACE_LIMIT_BYTES; - if (user_workspace_size > 0) { - workspace_size_limit = user_workspace_size * 1024 * 1024; + if (FLAGS_conv_workspace_size_limit > 0 || user_workspace_size > 0) { + int64_t max_user_size = + std::max(static_cast(FLAGS_conv_workspace_size_limit), + user_workspace_size); + workspace_size_limit = max_user_size * 1024 * 1024; } + // ------------------- cudnn conv algorithm --------------------- cudnnConvolutionFwdAlgo_t algo; - auto& dev_ctx = ctx.template device_context(); auto handle = dev_ctx.cudnn_handle(); - CUDNN_ENFORCE(platform::dynload::cudnnGetConvolutionForwardAlgorithm( - handle, cudnn_input_desc, cudnn_filter_desc, cudnn_conv_desc, - cudnn_output_desc, CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT, - workspace_size_limit, &algo)); - + bool half_float = false; #if CUDA_VERSION >= 9000 && CUDNN_VERSION_MIN(7, 0, 1) // Tensor core is supported since the volta GPU and // is only enabled when input and filter data are float16 @@ -143,12 +163,65 @@ class CUDNNConvOpKernel : public framework::OpKernel { cudnn_conv_desc, CUDNN_TENSOR_OP_MATH)); // Currently tensor core is only enabled using this algo algo = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM; + half_float = true; } else { CUDNN_ENFORCE(platform::dynload::cudnnSetConvolutionMathType( cudnn_conv_desc, CUDNN_DEFAULT_MATH)); } #endif + auto x_dims = framework::vectorize(input->dims()); + auto f_dims = framework::vectorize(filter->dims()); + if ((!exhaustive_search) && (!half_float)) { + CUDNN_ENFORCE(platform::dynload::cudnnGetConvolutionForwardAlgorithm( + handle, cudnn_input_desc, cudnn_filter_desc, cudnn_conv_desc, + cudnn_output_desc, CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT, + workspace_size_limit, &algo)); + VLOG(3) << "cuDNN forward algo " << algo; + } else if (exhaustive_search && (!half_float)) { + AlgorithmsCache* algo_cache = nullptr; + if (ctx.scope().FindVar(kCUDNNFwdAlgoCache)) { + algo_cache = + ctx.scope() + .FindVar(kCUDNNFwdAlgoCache) + ->GetMutable>(); + } else { + algo_cache = + const_cast(ctx.scope()) + .Var(kCUDNNFwdAlgoCache) + ->GetMutable>(); + } + algo = algo_cache->GetAlgorithm( + x_dims, f_dims, strides, paddings, dilations, 0, [&]() { + int returned_algo_count; + std::array + fwd_perf_stat; + auto cudnn_find_func = [&](void* cudnn_workspace) { + CUDNN_ENFORCE( + platform::dynload::cudnnFindConvolutionForwardAlgorithmEx( + handle, cudnn_input_desc, input_data, cudnn_filter_desc, + filter_data, cudnn_conv_desc, cudnn_output_desc, + output_data, kNUM_CUDNN_FWD_ALGS, &returned_algo_count, + fwd_perf_stat.data(), cudnn_workspace, + workspace_size_limit)); + }; + dev_ctx.RunCudnnFuncWithWorkspace(cudnn_find_func, + workspace_size_limit); + + VLOG(3) << "Perf result: (algo: stat, time, memory)"; + for (int i = 0; i < returned_algo_count; ++i) { + const auto& stat = fwd_perf_stat[i]; + VLOG(3) << stat.algo << ": " << stat.status << " " << stat.time + << " " << stat.memory; + } + return fwd_perf_stat[0].algo; + }); + VLOG(3) << "choose algo " << algo; + } else { + PADDLE_ENFORCE(half_float, + "cuDNN exhaustive search doesn't support half float."); + } + // get workspace size able to allocate CUDNN_ENFORCE(platform::dynload::cudnnGetConvolutionForwardWorkspaceSize( handle, cudnn_input_desc, cudnn_filter_desc, cudnn_conv_desc, @@ -178,6 +251,7 @@ template class CUDNNConvGradOpKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { + auto& dev_ctx = ctx.template device_context(); PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), "It must use CUDAPlace."); auto input = ctx.Input("Input"); @@ -196,6 +270,13 @@ class CUDNNConvGradOpKernel : public framework::OpKernel { int groups = ctx.Attr("groups"); int64_t user_workspace_size = static_cast(ctx.Attr("workspace_size_MB")); + bool exhaustive_search = + FLAGS_cudnn_exhaustive_search || ctx.Attr("exhaustive_search"); + if (exhaustive_search && FLAGS_cudnn_deterministic) { + PADDLE_THROW( + "Cann't set exhaustive_search True and " + "FLAGS_cudnn_deterministic True at same time."); + } // ------------------- cudnn descriptors --------------------- ScopedTensorDescriptor input_desc; @@ -263,14 +344,65 @@ class CUDNNConvGradOpKernel : public framework::OpKernel { cudnnConvolutionBwdFilterAlgo_t filter_algo; size_t workspace_size_in_bytes = 0, tmp_size = 0; size_t workspace_size_limit = kCONV_CUDNN_WORKSPACE_LIMIT_BYTES; - if (user_workspace_size > 0) { - workspace_size_limit = user_workspace_size * 1024 * 1024; + if (FLAGS_conv_workspace_size_limit > 0 || user_workspace_size > 0) { + int64_t max_user_size = + std::max(static_cast(FLAGS_conv_workspace_size_limit), + user_workspace_size); + workspace_size_limit = max_user_size * 1024 * 1024; } - auto& dev_ctx = ctx.template device_context(); + auto x_dims = framework::vectorize(input->dims()); + auto f_dims = framework::vectorize(filter->dims()); auto handle = dev_ctx.cudnn_handle(); if (input_grad) { - if (!FLAGS_cudnn_deterministic) { + T* input_grad_data = input_grad->mutable_data(ctx.GetPlace()); + if (exhaustive_search) { + AlgorithmsCache* data_algo_cache; + if (ctx.scope().FindVar(kCUDNNBwdDataAlgoCache)) { + data_algo_cache = + ctx.scope() + .FindVar(kCUDNNBwdDataAlgoCache) + ->GetMutable< + AlgorithmsCache>(); + } else { + data_algo_cache = + const_cast(ctx.scope()) + .Var(kCUDNNBwdDataAlgoCache) + ->GetMutable< + AlgorithmsCache>(); + } + data_algo = data_algo_cache->GetAlgorithm( + x_dims, f_dims, strides, paddings, dilations, 0, [&]() { + int returned_algo_count; + std::array + data_perf_stat; + auto cudnn_find_func = [&](void* cudnn_workspace) { + CUDNN_ENFORCE( + platform::dynload:: + cudnnFindConvolutionBackwardDataAlgorithmEx( + handle, cudnn_filter_desc, filter_data, + cudnn_output_grad_desc, output_grad_data, + cudnn_conv_desc, cudnn_input_desc, input_grad_data, + kNUM_CUDNN_BWD_DATA_ALGS, &returned_algo_count, + data_perf_stat.data(), cudnn_workspace, + workspace_size_limit)); + }; + dev_ctx.RunCudnnFuncWithWorkspace(cudnn_find_func, + workspace_size_limit); + + VLOG(3) << "Perf result: (algo: stat, time, memory)"; + for (int i = 0; i < returned_algo_count; ++i) { + const auto& stat = data_perf_stat[i]; + VLOG(3) << stat.algo << ": " << stat.status << " " << stat.time + << " " << stat.memory; + } + return data_perf_stat[0].algo; + }); + VLOG(3) << "cuDNN backward data algo " << data_algo; + } else if (FLAGS_cudnn_deterministic) { + data_algo = CUDNN_CONVOLUTION_BWD_DATA_ALGO_1; + } else { CUDNN_ENFORCE( platform::dynload::cudnnGetConvolutionBackwardDataAlgorithm( handle, cudnn_filter_desc, @@ -283,10 +415,7 @@ class CUDNNConvGradOpKernel : public framework::OpKernel { cudnn_input_desc, CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT, workspace_size_limit, &data_algo)); - } else { - data_algo = CUDNN_CONVOLUTION_BWD_DATA_ALGO_1; } - CUDNN_ENFORCE( platform::dynload::cudnnGetConvolutionBackwardDataWorkspaceSize( handle, cudnn_filter_desc, cudnn_output_grad_desc, @@ -295,17 +424,54 @@ class CUDNNConvGradOpKernel : public framework::OpKernel { } if (filter_grad) { - if (!FLAGS_cudnn_deterministic) { + T* filter_grad_data = filter_grad->mutable_data(ctx.GetPlace()); + if (exhaustive_search) { + AlgorithmsCache* f_algo_cache; + if (ctx.scope().FindVar(kCUDNNBwdFilterAlgoCache)) { + f_algo_cache = + ctx.scope() + .FindVar(kCUDNNBwdFilterAlgoCache) + ->GetMutable< + AlgorithmsCache>(); + } else { + f_algo_cache = + const_cast(ctx.scope()) + .Var(kCUDNNBwdFilterAlgoCache) + ->GetMutable< + AlgorithmsCache>(); + } + filter_algo = f_algo_cache->GetAlgorithm( + x_dims, f_dims, strides, paddings, dilations, 0, [&]() { + int returned_algo_count; + std::array + filter_perf_stat; + auto cudnn_find_f_func = [&](void* cudnn_workspace) { + CUDNN_ENFORCE( + platform::dynload:: + cudnnFindConvolutionBackwardFilterAlgorithmEx( + handle, cudnn_input_desc, input_data, + cudnn_output_grad_desc, output_grad_data, + cudnn_conv_desc, cudnn_filter_desc, + filter_grad_data, kNUM_CUDNN_BWD_FILTER_ALGS, + &returned_algo_count, filter_perf_stat.data(), + cudnn_workspace, workspace_size_limit)); + }; + dev_ctx.RunCudnnFuncWithWorkspace(cudnn_find_f_func, + workspace_size_limit); + return filter_perf_stat[0].algo; + }); + VLOG(3) << "cuDNN backward filter algo " << filter_algo; + } else if (FLAGS_cudnn_deterministic) { + filter_algo = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1; + } else { CUDNN_ENFORCE( platform::dynload::cudnnGetConvolutionBackwardFilterAlgorithm( handle, cudnn_input_desc, cudnn_output_grad_desc, cudnn_conv_desc, cudnn_filter_desc, CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT, workspace_size_limit, &filter_algo)); - } else { - filter_algo = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1; } - CUDNN_ENFORCE( platform::dynload::cudnnGetConvolutionBackwardFilterWorkspaceSize( handle, cudnn_input_desc, cudnn_output_grad_desc, cudnn_conv_desc, diff --git a/paddle/fluid/operators/conv_cudnn_op_cache.h b/paddle/fluid/operators/conv_cudnn_op_cache.h new file mode 100644 index 00000000000..4b534321f74 --- /dev/null +++ b/paddle/fluid/operators/conv_cudnn_op_cache.h @@ -0,0 +1,90 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include +#include + +namespace paddle { +namespace operators { + +template +class AlgorithmsCache { + public: + // Caches the best algorithm for a given + // combination of tensor dimensions & compute data type. + TAlgorithm GetAlgorithm( + const std::vector& dims1, const std::vector& dims2, + const std::vector& strides, const std::vector& paddings, + const std::vector& dilations, + int algorithmFlags, // can set for different data type + std::function gen_func); + + private: + std::unordered_map hash_; + std::mutex mutex_; +}; + +template +TAlgorithm AlgorithmsCache::GetAlgorithm( + const std::vector& dims1, const std::vector& dims2, + const std::vector& strides, const std::vector& paddings, + const std::vector& dilations, int algorithmFlags, + std::function gen_func) { + std::lock_guard lock(mutex_); + int64_t seed = 0; + // Hash all of the inputs, use to try and look up a previously + // discovered algorithm, or fall back to generating a new one. + std::hash hashFn; + // do hash like boost + // https://stackoverflow.com/questions/2590677/how-do-i-combine-hash-values-in-c0x + for (const auto num : dims1) { + seed ^= hashFn(num) + 0x9e3779b9 + (seed << 6) + (seed >> 2); + } + + for (const auto num : dims2) { + seed ^= hashFn(num) + 0x9e3779b9 + (seed << 6) + (seed >> 2) + 1; + } + + for (const auto num : strides) { + seed ^= hashFn(static_cast(num)) + 0x9e3779b9 + (seed << 6) + + (seed >> 2) + 2; + } + + for (const auto num : paddings) { + seed ^= hashFn(static_cast(num)) + 0x9e3779b9 + (seed << 6) + + (seed >> 2) + 3; + } + + for (const auto num : dilations) { + seed ^= hashFn(static_cast(num)) + 0x9e3779b9 + (seed << 6) + + (seed >> 2) + 4; + } + + seed ^= hashFn(static_cast(algorithmFlags)) + 0x9e3779b9 + + (seed << 6) + (seed >> 2) + 5; + + if (seed == 0) return gen_func(); + + if (hash_.find(seed) == hash_.end()) { + TAlgorithm value = gen_func(); + hash_[seed] = value; + } + return hash_[seed]; +} + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/conv_op.cc b/paddle/fluid/operators/conv_op.cc index 2cd9979bd34..7401f100d72 100644 --- a/paddle/fluid/operators/conv_op.cc +++ b/paddle/fluid/operators/conv_op.cc @@ -189,6 +189,11 @@ void Conv2DOpMaker::Make() { "workspace size can increase performance but also requires " "better hardware. This size should be chosen carefully.") .SetDefault(4096); + AddAttr("exhaustive_search", + "(bool, default false) cuDNN has many algorithm to calculation " + "convolution, whether enable exhaustive search ", + "for cuDNN convolution or not, defalut is False.") + .SetDefault(false); AddComment(R"DOC( Convolution Operator. @@ -283,7 +288,11 @@ void Conv3DOpMaker::Make() { "workspace size can increase performance but also requires " "better hardware. This size should be chosen carefully.") .SetDefault(4096); - + AddAttr("exhaustive_search", + "(bool, default false) cuDNN has many algorithm to calculation " + "convolution, whether enable exhaustive search ", + "for cuDNN convolution or not, defalut is False.") + .SetDefault(false); AddComment(R"DOC( Convolution3D Operator. diff --git a/paddle/fluid/platform/device_context.cc b/paddle/fluid/platform/device_context.cc index ff49a1d57fd..d62ef933833 100644 --- a/paddle/fluid/platform/device_context.cc +++ b/paddle/fluid/platform/device_context.cc @@ -204,7 +204,10 @@ CUDADeviceContext::CUDADeviceContext(CUDAPlace place) << "." << (driver_version_ % 100) / 10 << ", Runtime Version: " << runtime_version_ / 1000 << "." << (runtime_version_ % 100) / 10; - + size_t cudnn_dso_ver = dynload::cudnnGetVersion(); + LOG(INFO) << "device: " << place_.device + << ", cuDNN Version: " << cudnn_dso_ver / 1000 << "." + << (cudnn_dso_ver % 100) / 10 << "."; callback_manager_.reset(new StreamCallbackManager(stream_)); } diff --git a/paddle/fluid/platform/dynload/cudnn.h b/paddle/fluid/platform/dynload/cudnn.h index d3d754b6f58..c26143d2f27 100644 --- a/paddle/fluid/platform/dynload/cudnn.h +++ b/paddle/fluid/platform/dynload/cudnn.h @@ -65,51 +65,54 @@ extern void EnforceCUDNNLoaded(const char* fn_name); * include all needed cudnn functions in HPPL * different cudnn version has different interfaces **/ -#define CUDNN_DNN_ROUTINE_EACH(__macro) \ - __macro(cudnnSetTensor4dDescriptor); \ - __macro(cudnnSetTensor4dDescriptorEx); \ - __macro(cudnnSetTensorNdDescriptor); \ - __macro(cudnnGetTensorNdDescriptor); \ - __macro(cudnnGetConvolutionNdForwardOutputDim); \ - __macro(cudnnGetConvolutionForwardAlgorithm); \ - __macro(cudnnCreateTensorDescriptor); \ - __macro(cudnnDestroyTensorDescriptor); \ - __macro(cudnnCreateFilterDescriptor); \ - __macro(cudnnSetFilter4dDescriptor); \ - __macro(cudnnSetFilterNdDescriptor); \ - __macro(cudnnGetFilterNdDescriptor); \ - __macro(cudnnSetPooling2dDescriptor); \ - __macro(cudnnSetPoolingNdDescriptor); \ - __macro(cudnnGetPoolingNdDescriptor); \ - __macro(cudnnDestroyFilterDescriptor); \ - __macro(cudnnCreateConvolutionDescriptor); \ - __macro(cudnnCreatePoolingDescriptor); \ - __macro(cudnnDestroyPoolingDescriptor); \ - __macro(cudnnSetConvolution2dDescriptor); \ - __macro(cudnnDestroyConvolutionDescriptor); \ - __macro(cudnnSetConvolutionNdDescriptor); \ - __macro(cudnnGetConvolutionNdDescriptor); \ - __macro(cudnnDeriveBNTensorDescriptor); \ - __macro(cudnnCreateSpatialTransformerDescriptor); \ - __macro(cudnnSetSpatialTransformerNdDescriptor); \ - __macro(cudnnDestroySpatialTransformerDescriptor); \ - __macro(cudnnSpatialTfGridGeneratorForward); \ - __macro(cudnnSpatialTfGridGeneratorBackward); \ - __macro(cudnnSpatialTfSamplerForward); \ - __macro(cudnnSpatialTfSamplerBackward); \ - __macro(cudnnCreate); \ - __macro(cudnnDestroy); \ - __macro(cudnnSetStream); \ - __macro(cudnnActivationForward); \ - __macro(cudnnConvolutionForward); \ - __macro(cudnnConvolutionBackwardBias); \ - __macro(cudnnGetConvolutionForwardWorkspaceSize); \ - __macro(cudnnTransformTensor); \ - __macro(cudnnPoolingForward); \ - __macro(cudnnPoolingBackward); \ - __macro(cudnnSoftmaxBackward); \ - __macro(cudnnSoftmaxForward); \ - __macro(cudnnGetVersion); \ +#define CUDNN_DNN_ROUTINE_EACH(__macro) \ + __macro(cudnnSetTensor4dDescriptor); \ + __macro(cudnnSetTensor4dDescriptorEx); \ + __macro(cudnnSetTensorNdDescriptor); \ + __macro(cudnnGetTensorNdDescriptor); \ + __macro(cudnnGetConvolutionNdForwardOutputDim); \ + __macro(cudnnGetConvolutionForwardAlgorithm); \ + __macro(cudnnCreateTensorDescriptor); \ + __macro(cudnnDestroyTensorDescriptor); \ + __macro(cudnnCreateFilterDescriptor); \ + __macro(cudnnSetFilter4dDescriptor); \ + __macro(cudnnSetFilterNdDescriptor); \ + __macro(cudnnGetFilterNdDescriptor); \ + __macro(cudnnSetPooling2dDescriptor); \ + __macro(cudnnSetPoolingNdDescriptor); \ + __macro(cudnnGetPoolingNdDescriptor); \ + __macro(cudnnDestroyFilterDescriptor); \ + __macro(cudnnCreateConvolutionDescriptor); \ + __macro(cudnnCreatePoolingDescriptor); \ + __macro(cudnnDestroyPoolingDescriptor); \ + __macro(cudnnSetConvolution2dDescriptor); \ + __macro(cudnnDestroyConvolutionDescriptor); \ + __macro(cudnnSetConvolutionNdDescriptor); \ + __macro(cudnnGetConvolutionNdDescriptor); \ + __macro(cudnnDeriveBNTensorDescriptor); \ + __macro(cudnnCreateSpatialTransformerDescriptor); \ + __macro(cudnnSetSpatialTransformerNdDescriptor); \ + __macro(cudnnDestroySpatialTransformerDescriptor); \ + __macro(cudnnSpatialTfGridGeneratorForward); \ + __macro(cudnnSpatialTfGridGeneratorBackward); \ + __macro(cudnnSpatialTfSamplerForward); \ + __macro(cudnnSpatialTfSamplerBackward); \ + __macro(cudnnCreate); \ + __macro(cudnnDestroy); \ + __macro(cudnnSetStream); \ + __macro(cudnnActivationForward); \ + __macro(cudnnConvolutionForward); \ + __macro(cudnnConvolutionBackwardBias); \ + __macro(cudnnGetConvolutionForwardWorkspaceSize); \ + __macro(cudnnTransformTensor); \ + __macro(cudnnPoolingForward); \ + __macro(cudnnPoolingBackward); \ + __macro(cudnnSoftmaxBackward); \ + __macro(cudnnSoftmaxForward); \ + __macro(cudnnGetVersion); \ + __macro(cudnnFindConvolutionForwardAlgorithmEx); \ + __macro(cudnnFindConvolutionBackwardFilterAlgorithmEx); \ + __macro(cudnnFindConvolutionBackwardDataAlgorithmEx); \ __macro(cudnnGetErrorString); CUDNN_DNN_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_CUDNN_WRAP) diff --git a/python/paddle/fluid/__init__.py b/python/paddle/fluid/__init__.py index 737c8be8147..2670fe4b1ba 100644 --- a/python/paddle/fluid/__init__.py +++ b/python/paddle/fluid/__init__.py @@ -127,7 +127,8 @@ def __bootstrap__(): if core.is_compiled_with_cuda(): read_env_flags += [ - 'fraction_of_gpu_memory_to_use', 'cudnn_deterministic' + 'fraction_of_gpu_memory_to_use', 'cudnn_deterministic', + 'conv_workspace_size_limit', 'cudnn_exhaustive_search' ] core.init_gflags([sys.argv[0]] + ["--tryfromenv=" + ",".join(read_env_flags)]) diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index a87f1231174..13a724ac2d9 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -27,6 +27,7 @@ from .tensor import concat from . import utils from .. import unique_name from functools import reduce +from .. import core __all__ = [ 'fc', @@ -1664,6 +1665,20 @@ def conv2d(input, pre_bias = helper.create_variable_for_type_inference(dtype) + if use_cudnn: + helper.create_variable( + name="kCUDNNFwdAlgoCache", + persistable=True, + type=core.VarDesc.VarType.RAW) + helper.create_variable( + name="kCUDNNBwdDataAlgoCache", + persistable=True, + type=core.VarDesc.VarType.RAW) + helper.create_variable( + name="kCUDNNBwdFilterAlgoCache", + persistable=True, + type=core.VarDesc.VarType.RAW) + helper.append_op( type=l_type, inputs={ @@ -1677,7 +1692,7 @@ def conv2d(input, 'dilations': dilation, 'groups': groups, 'use_cudnn': use_cudnn, - 'use_mkldnn': False + 'use_mkldnn': False, }) pre_act = helper.append_bias_op(pre_bias, dim_start=1, dim_end=2) diff --git a/python/paddle/fluid/tests/unittests/test_conv2d_op.py b/python/paddle/fluid/tests/unittests/test_conv2d_op.py index 2ecc2504a8c..a8f80944265 100644 --- a/python/paddle/fluid/tests/unittests/test_conv2d_op.py +++ b/python/paddle/fluid/tests/unittests/test_conv2d_op.py @@ -67,6 +67,7 @@ class TestConv2dOp(OpTest): def setUp(self): self.op_type = "conv2d" self.use_cudnn = False + self.exhaustive_search = False self.use_cuda = False self.use_mkldnn = False self.data_format = "AnyLayout" @@ -98,7 +99,8 @@ class TestConv2dOp(OpTest): 'dilations': self.dilations, 'use_cudnn': self.use_cudnn, 'use_mkldnn': self.use_mkldnn, - 'data_format': self.data_format + 'data_format': self.data_format, + 'exhaustive_search': self.exhaustive_search } self.outputs = {'Output': output} @@ -392,6 +394,12 @@ class TestDepthwiseConvWithDilation2(TestConv2dOp): self.op_type = "depthwise_conv2d" +class TestCUDNNExhaustiveSearch(TestCUDNN): + def init_kernel_type(self): + self.use_cudnn = True + self.exhaustive_search = True + + # Please Don't remove the following code. # Currently, CI use cudnn V5.0 which not support dilation conv. # class TestCUDNNWithDilation(TestWithDilation): diff --git a/python/paddle/fluid/tests/unittests/test_conv3d_op.py b/python/paddle/fluid/tests/unittests/test_conv3d_op.py index ddaf99fe061..69c5ab7a4a4 100644 --- a/python/paddle/fluid/tests/unittests/test_conv3d_op.py +++ b/python/paddle/fluid/tests/unittests/test_conv3d_op.py @@ -335,6 +335,12 @@ class TestFP16WithInput1x1Filter1x1CUDNN(TestWithInput1x1Filter1x1): self.check_output_with_place(place, atol=2e-2) +class TestCUDNNExhaustiveSearch(TestCUDNN): + def init_kernel_type(self): + self.use_cudnn = True + self.exhaustive_search = True + + # FIXME(typhoonzero): find a way to determine if # using cudnn > 6 in python # class TestWithDilationCUDNN(TestWithDilation): -- GitLab