From db8c52da5e60117f86cb7581f62d22c98cbfb1eb Mon Sep 17 00:00:00 2001 From: qingqing01 Date: Wed, 7 Nov 2018 10:25:05 +0800 Subject: [PATCH] Revert " Exhaustive search for cuDNN conv. (#14043)" This reverts commit ce7d9b079947e55f23d7653432732e498f723274. --- .../framework/ir/graph_pattern_detector.cc | 1 - .../fluid/inference/api/analysis_predictor.h | 2 - paddle/fluid/inference/api/helper.h | 3 +- paddle/fluid/inference/io.cc | 3 +- .../operators/add_position_encoding_op.h | 7 +- paddle/fluid/operators/conv_cudnn_op.cu.cc | 204 ++---------------- paddle/fluid/operators/conv_cudnn_op_cache.h | 90 -------- paddle/fluid/operators/conv_op.cc | 11 +- paddle/fluid/platform/device_context.cc | 5 +- paddle/fluid/platform/dynload/cudnn.h | 93 ++++---- python/paddle/fluid/__init__.py | 3 +- python/paddle/fluid/layers/nn.py | 17 +- .../fluid/tests/unittests/test_conv2d_op.py | 10 +- .../fluid/tests/unittests/test_conv3d_op.py | 6 - 14 files changed, 74 insertions(+), 381 deletions(-) delete mode 100644 paddle/fluid/operators/conv_cudnn_op_cache.h diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc index fa713fe1dd..b20d701322 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.cc +++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc @@ -12,7 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include #include #include #include diff --git a/paddle/fluid/inference/api/analysis_predictor.h b/paddle/fluid/inference/api/analysis_predictor.h index a9f4cce6df..b7dc206733 100644 --- a/paddle/fluid/inference/api/analysis_predictor.h +++ b/paddle/fluid/inference/api/analysis_predictor.h @@ -13,8 +13,6 @@ // limitations under the License. #pragma once -#include -#include #include #include #include "paddle/fluid/framework/naive_executor.h" diff --git a/paddle/fluid/inference/api/helper.h b/paddle/fluid/inference/api/helper.h index af21c0095c..e46dc13269 100644 --- a/paddle/fluid/inference/api/helper.h +++ b/paddle/fluid/inference/api/helper.h @@ -16,14 +16,13 @@ #include #include -#include #include // NOLINT #include #include #include #include -#include "paddle/fluid/inference/api/paddle_inference_api.h" #include "paddle/fluid/string/printf.h" +#include "paddle_inference_api.h" namespace paddle { namespace inference { diff --git a/paddle/fluid/inference/io.cc b/paddle/fluid/inference/io.cc index 31f43bfdca..e246a06fd0 100644 --- a/paddle/fluid/inference/io.cc +++ b/paddle/fluid/inference/io.cc @@ -59,8 +59,7 @@ void ReadBinaryFile(const std::string& filename, std::string* contents) { bool IsPersistable(const framework::VarDesc* var) { if (var->Persistable() && var->GetType() != framework::proto::VarType::FEED_MINIBATCH && - var->GetType() != framework::proto::VarType::FETCH_LIST && - var->GetType() != framework::proto::VarType::RAW) { + var->GetType() != framework::proto::VarType::FETCH_LIST) { return true; } return false; diff --git a/paddle/fluid/operators/add_position_encoding_op.h b/paddle/fluid/operators/add_position_encoding_op.h index 0b40d3de89..5f371235f1 100644 --- a/paddle/fluid/operators/add_position_encoding_op.h +++ b/paddle/fluid/operators/add_position_encoding_op.h @@ -66,10 +66,9 @@ class AddPositionEncodingKernel : public framework::OpKernel { x_lod.empty() ? max_seq_len : x_lod[0][i + 1] - x_lod[0][i]; for (int j = 0; j < max_length; ++j) { for (int k = 0; k < half_size; ++k) { - const double val = - (half_size > 1) - ? j / pow(10000.0, static_cast(k) / (half_size - 1)) - : j / 10000.0; + const double val = (half_size > 1) + ? j / pow(10000.0, double(k) / (half_size - 1)) + : j / 10000.0; dst_ptr[k] = src_ptr[k] * alpha + sin(val) * beta; dst_ptr[half_size + k] = src_ptr[half_size + k] * alpha + cos(val) * beta; diff --git a/paddle/fluid/operators/conv_cudnn_op.cu.cc b/paddle/fluid/operators/conv_cudnn_op.cu.cc index 1f4a95c5e7..c37032bf09 100644 --- a/paddle/fluid/operators/conv_cudnn_op.cu.cc +++ b/paddle/fluid/operators/conv_cudnn_op.cu.cc @@ -15,22 +15,15 @@ limitations under the License. */ #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/memory/memory.h" -#include "paddle/fluid/operators/conv_cudnn_op_cache.h" #include "paddle/fluid/operators/conv_op.h" #include "paddle/fluid/platform/assert.h" #include "paddle/fluid/platform/cudnn_helper.h" #include "paddle/fluid/platform/float16.h" -#include "paddle/fluid/platform/profiler.h" DEFINE_bool(cudnn_deterministic, false, "Whether allow using an autotuning algorithm for convolution " "operator. The autotuning algorithm may be non-deterministic. If " "true, the algorithm is deterministic."); -DEFINE_uint64(conv_workspace_size_limit, 4096, - "cuDNN convolution workspace limit in MB unit."); -DEFINE_bool(cudnn_exhaustive_search, false, - "Whether enable exhaustive search for cuDNN convolution or " - "not, defalut is False."); namespace paddle { namespace operators { @@ -43,25 +36,13 @@ using DataLayout = platform::DataLayout; template using ScalingParamType = typename platform::CudnnDataType::ScalingParamType; -static constexpr char kCUDNNFwdAlgoCache[] = "kCUDNNFwdAlgoCache"; -static constexpr char kCUDNNBwdDataAlgoCache[] = "kCUDNNBwdDataAlgoCache"; -static constexpr char kCUDNNBwdFilterAlgoCache[] = "kCUDNNBwdFilterAlgoCache"; - static constexpr size_t kCONV_CUDNN_WORKSPACE_LIMIT_BYTES = static_cast(1024) * 1024 * 1024; -static constexpr size_t kNUM_CUDNN_FWD_ALGS = - CUDNN_CONVOLUTION_BWD_FILTER_ALGO_COUNT; -static constexpr size_t kNUM_CUDNN_BWD_FILTER_ALGS = - CUDNN_CONVOLUTION_BWD_FILTER_ALGO_COUNT; -static constexpr size_t kNUM_CUDNN_BWD_DATA_ALGS = - CUDNN_CONVOLUTION_BWD_DATA_ALGO_COUNT; - template class CUDNNConvOpKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { - auto& dev_ctx = ctx.template device_context(); PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), "It must use CUDAPlace."); auto* input = ctx.Input("Input"); @@ -74,8 +55,6 @@ class CUDNNConvOpKernel : public framework::OpKernel { int groups = ctx.Attr("groups"); int64_t user_workspace_size = static_cast(ctx.Attr("workspace_size_MB")); - bool exhaustive_search = - FLAGS_cudnn_exhaustive_search || ctx.Attr("exhaustive_search"); const T* input_data = input->data(); const T* filter_data = filter->data(); @@ -141,18 +120,19 @@ class CUDNNConvOpKernel : public framework::OpKernel { // ------------------- cudnn conv workspace --------------------- size_t workspace_size_in_bytes; // final workspace to allocate. size_t workspace_size_limit = kCONV_CUDNN_WORKSPACE_LIMIT_BYTES; - if (FLAGS_conv_workspace_size_limit > 0 || user_workspace_size > 0) { - int64_t max_user_size = - std::max(static_cast(FLAGS_conv_workspace_size_limit), - user_workspace_size); - workspace_size_limit = max_user_size * 1024 * 1024; + if (user_workspace_size > 0) { + workspace_size_limit = user_workspace_size * 1024 * 1024; } - // ------------------- cudnn conv algorithm --------------------- cudnnConvolutionFwdAlgo_t algo; + auto& dev_ctx = ctx.template device_context(); auto handle = dev_ctx.cudnn_handle(); - bool half_float = false; + CUDNN_ENFORCE(platform::dynload::cudnnGetConvolutionForwardAlgorithm( + handle, cudnn_input_desc, cudnn_filter_desc, cudnn_conv_desc, + cudnn_output_desc, CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT, + workspace_size_limit, &algo)); + #if CUDA_VERSION >= 9000 && CUDNN_VERSION_MIN(7, 0, 1) // Tensor core is supported since the volta GPU and // is only enabled when input and filter data are float16 @@ -163,65 +143,12 @@ class CUDNNConvOpKernel : public framework::OpKernel { cudnn_conv_desc, CUDNN_TENSOR_OP_MATH)); // Currently tensor core is only enabled using this algo algo = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM; - half_float = true; } else { CUDNN_ENFORCE(platform::dynload::cudnnSetConvolutionMathType( cudnn_conv_desc, CUDNN_DEFAULT_MATH)); } #endif - auto x_dims = framework::vectorize(input->dims()); - auto f_dims = framework::vectorize(filter->dims()); - if ((!exhaustive_search) && (!half_float)) { - CUDNN_ENFORCE(platform::dynload::cudnnGetConvolutionForwardAlgorithm( - handle, cudnn_input_desc, cudnn_filter_desc, cudnn_conv_desc, - cudnn_output_desc, CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT, - workspace_size_limit, &algo)); - VLOG(3) << "cuDNN forward algo " << algo; - } else if (exhaustive_search && (!half_float)) { - AlgorithmsCache* algo_cache = nullptr; - if (ctx.scope().FindVar(kCUDNNFwdAlgoCache)) { - algo_cache = - ctx.scope() - .FindVar(kCUDNNFwdAlgoCache) - ->GetMutable>(); - } else { - algo_cache = - const_cast(ctx.scope()) - .Var(kCUDNNFwdAlgoCache) - ->GetMutable>(); - } - algo = algo_cache->GetAlgorithm( - x_dims, f_dims, strides, paddings, dilations, 0, [&]() { - int returned_algo_count; - std::array - fwd_perf_stat; - auto cudnn_find_func = [&](void* cudnn_workspace) { - CUDNN_ENFORCE( - platform::dynload::cudnnFindConvolutionForwardAlgorithmEx( - handle, cudnn_input_desc, input_data, cudnn_filter_desc, - filter_data, cudnn_conv_desc, cudnn_output_desc, - output_data, kNUM_CUDNN_FWD_ALGS, &returned_algo_count, - fwd_perf_stat.data(), cudnn_workspace, - workspace_size_limit)); - }; - dev_ctx.RunCudnnFuncWithWorkspace(cudnn_find_func, - workspace_size_limit); - - VLOG(3) << "Perf result: (algo: stat, time, memory)"; - for (int i = 0; i < returned_algo_count; ++i) { - const auto& stat = fwd_perf_stat[i]; - VLOG(3) << stat.algo << ": " << stat.status << " " << stat.time - << " " << stat.memory; - } - return fwd_perf_stat[0].algo; - }); - VLOG(3) << "choose algo " << algo; - } else { - PADDLE_ENFORCE(half_float, - "cuDNN exhaustive search doesn't support half float."); - } - // get workspace size able to allocate CUDNN_ENFORCE(platform::dynload::cudnnGetConvolutionForwardWorkspaceSize( handle, cudnn_input_desc, cudnn_filter_desc, cudnn_conv_desc, @@ -251,7 +178,6 @@ template class CUDNNConvGradOpKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { - auto& dev_ctx = ctx.template device_context(); PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), "It must use CUDAPlace."); auto input = ctx.Input("Input"); @@ -270,13 +196,6 @@ class CUDNNConvGradOpKernel : public framework::OpKernel { int groups = ctx.Attr("groups"); int64_t user_workspace_size = static_cast(ctx.Attr("workspace_size_MB")); - bool exhaustive_search = - FLAGS_cudnn_exhaustive_search || ctx.Attr("exhaustive_search"); - if (exhaustive_search && FLAGS_cudnn_deterministic) { - PADDLE_THROW( - "Cann't set exhaustive_search True and " - "FLAGS_cudnn_deterministic True at same time."); - } // ------------------- cudnn descriptors --------------------- ScopedTensorDescriptor input_desc; @@ -344,65 +263,14 @@ class CUDNNConvGradOpKernel : public framework::OpKernel { cudnnConvolutionBwdFilterAlgo_t filter_algo; size_t workspace_size_in_bytes = 0, tmp_size = 0; size_t workspace_size_limit = kCONV_CUDNN_WORKSPACE_LIMIT_BYTES; - if (FLAGS_conv_workspace_size_limit > 0 || user_workspace_size > 0) { - int64_t max_user_size = - std::max(static_cast(FLAGS_conv_workspace_size_limit), - user_workspace_size); - workspace_size_limit = max_user_size * 1024 * 1024; + if (user_workspace_size > 0) { + workspace_size_limit = user_workspace_size * 1024 * 1024; } - auto x_dims = framework::vectorize(input->dims()); - auto f_dims = framework::vectorize(filter->dims()); + auto& dev_ctx = ctx.template device_context(); auto handle = dev_ctx.cudnn_handle(); if (input_grad) { - T* input_grad_data = input_grad->mutable_data(ctx.GetPlace()); - if (exhaustive_search) { - AlgorithmsCache* data_algo_cache; - if (ctx.scope().FindVar(kCUDNNBwdDataAlgoCache)) { - data_algo_cache = - ctx.scope() - .FindVar(kCUDNNBwdDataAlgoCache) - ->GetMutable< - AlgorithmsCache>(); - } else { - data_algo_cache = - const_cast(ctx.scope()) - .Var(kCUDNNBwdDataAlgoCache) - ->GetMutable< - AlgorithmsCache>(); - } - data_algo = data_algo_cache->GetAlgorithm( - x_dims, f_dims, strides, paddings, dilations, 0, [&]() { - int returned_algo_count; - std::array - data_perf_stat; - auto cudnn_find_func = [&](void* cudnn_workspace) { - CUDNN_ENFORCE( - platform::dynload:: - cudnnFindConvolutionBackwardDataAlgorithmEx( - handle, cudnn_filter_desc, filter_data, - cudnn_output_grad_desc, output_grad_data, - cudnn_conv_desc, cudnn_input_desc, input_grad_data, - kNUM_CUDNN_BWD_DATA_ALGS, &returned_algo_count, - data_perf_stat.data(), cudnn_workspace, - workspace_size_limit)); - }; - dev_ctx.RunCudnnFuncWithWorkspace(cudnn_find_func, - workspace_size_limit); - - VLOG(3) << "Perf result: (algo: stat, time, memory)"; - for (int i = 0; i < returned_algo_count; ++i) { - const auto& stat = data_perf_stat[i]; - VLOG(3) << stat.algo << ": " << stat.status << " " << stat.time - << " " << stat.memory; - } - return data_perf_stat[0].algo; - }); - VLOG(3) << "cuDNN backward data algo " << data_algo; - } else if (FLAGS_cudnn_deterministic) { - data_algo = CUDNN_CONVOLUTION_BWD_DATA_ALGO_1; - } else { + if (!FLAGS_cudnn_deterministic) { CUDNN_ENFORCE( platform::dynload::cudnnGetConvolutionBackwardDataAlgorithm( handle, cudnn_filter_desc, @@ -415,7 +283,10 @@ class CUDNNConvGradOpKernel : public framework::OpKernel { cudnn_input_desc, CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT, workspace_size_limit, &data_algo)); + } else { + data_algo = CUDNN_CONVOLUTION_BWD_DATA_ALGO_1; } + CUDNN_ENFORCE( platform::dynload::cudnnGetConvolutionBackwardDataWorkspaceSize( handle, cudnn_filter_desc, cudnn_output_grad_desc, @@ -424,54 +295,17 @@ class CUDNNConvGradOpKernel : public framework::OpKernel { } if (filter_grad) { - T* filter_grad_data = filter_grad->mutable_data(ctx.GetPlace()); - if (exhaustive_search) { - AlgorithmsCache* f_algo_cache; - if (ctx.scope().FindVar(kCUDNNBwdFilterAlgoCache)) { - f_algo_cache = - ctx.scope() - .FindVar(kCUDNNBwdFilterAlgoCache) - ->GetMutable< - AlgorithmsCache>(); - } else { - f_algo_cache = - const_cast(ctx.scope()) - .Var(kCUDNNBwdFilterAlgoCache) - ->GetMutable< - AlgorithmsCache>(); - } - filter_algo = f_algo_cache->GetAlgorithm( - x_dims, f_dims, strides, paddings, dilations, 0, [&]() { - int returned_algo_count; - std::array - filter_perf_stat; - auto cudnn_find_f_func = [&](void* cudnn_workspace) { - CUDNN_ENFORCE( - platform::dynload:: - cudnnFindConvolutionBackwardFilterAlgorithmEx( - handle, cudnn_input_desc, input_data, - cudnn_output_grad_desc, output_grad_data, - cudnn_conv_desc, cudnn_filter_desc, - filter_grad_data, kNUM_CUDNN_BWD_FILTER_ALGS, - &returned_algo_count, filter_perf_stat.data(), - cudnn_workspace, workspace_size_limit)); - }; - dev_ctx.RunCudnnFuncWithWorkspace(cudnn_find_f_func, - workspace_size_limit); - return filter_perf_stat[0].algo; - }); - VLOG(3) << "cuDNN backward filter algo " << filter_algo; - } else if (FLAGS_cudnn_deterministic) { - filter_algo = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1; - } else { + if (!FLAGS_cudnn_deterministic) { CUDNN_ENFORCE( platform::dynload::cudnnGetConvolutionBackwardFilterAlgorithm( handle, cudnn_input_desc, cudnn_output_grad_desc, cudnn_conv_desc, cudnn_filter_desc, CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT, workspace_size_limit, &filter_algo)); + } else { + filter_algo = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1; } + CUDNN_ENFORCE( platform::dynload::cudnnGetConvolutionBackwardFilterWorkspaceSize( handle, cudnn_input_desc, cudnn_output_grad_desc, cudnn_conv_desc, diff --git a/paddle/fluid/operators/conv_cudnn_op_cache.h b/paddle/fluid/operators/conv_cudnn_op_cache.h deleted file mode 100644 index 4b534321f7..0000000000 --- a/paddle/fluid/operators/conv_cudnn_op_cache.h +++ /dev/null @@ -1,90 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once - -#include -#include -#include - -namespace paddle { -namespace operators { - -template -class AlgorithmsCache { - public: - // Caches the best algorithm for a given - // combination of tensor dimensions & compute data type. - TAlgorithm GetAlgorithm( - const std::vector& dims1, const std::vector& dims2, - const std::vector& strides, const std::vector& paddings, - const std::vector& dilations, - int algorithmFlags, // can set for different data type - std::function gen_func); - - private: - std::unordered_map hash_; - std::mutex mutex_; -}; - -template -TAlgorithm AlgorithmsCache::GetAlgorithm( - const std::vector& dims1, const std::vector& dims2, - const std::vector& strides, const std::vector& paddings, - const std::vector& dilations, int algorithmFlags, - std::function gen_func) { - std::lock_guard lock(mutex_); - int64_t seed = 0; - // Hash all of the inputs, use to try and look up a previously - // discovered algorithm, or fall back to generating a new one. - std::hash hashFn; - // do hash like boost - // https://stackoverflow.com/questions/2590677/how-do-i-combine-hash-values-in-c0x - for (const auto num : dims1) { - seed ^= hashFn(num) + 0x9e3779b9 + (seed << 6) + (seed >> 2); - } - - for (const auto num : dims2) { - seed ^= hashFn(num) + 0x9e3779b9 + (seed << 6) + (seed >> 2) + 1; - } - - for (const auto num : strides) { - seed ^= hashFn(static_cast(num)) + 0x9e3779b9 + (seed << 6) + - (seed >> 2) + 2; - } - - for (const auto num : paddings) { - seed ^= hashFn(static_cast(num)) + 0x9e3779b9 + (seed << 6) + - (seed >> 2) + 3; - } - - for (const auto num : dilations) { - seed ^= hashFn(static_cast(num)) + 0x9e3779b9 + (seed << 6) + - (seed >> 2) + 4; - } - - seed ^= hashFn(static_cast(algorithmFlags)) + 0x9e3779b9 + - (seed << 6) + (seed >> 2) + 5; - - if (seed == 0) return gen_func(); - - if (hash_.find(seed) == hash_.end()) { - TAlgorithm value = gen_func(); - hash_[seed] = value; - } - return hash_[seed]; -} - -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/conv_op.cc b/paddle/fluid/operators/conv_op.cc index 7401f100d7..2cd9979bd3 100644 --- a/paddle/fluid/operators/conv_op.cc +++ b/paddle/fluid/operators/conv_op.cc @@ -189,11 +189,6 @@ void Conv2DOpMaker::Make() { "workspace size can increase performance but also requires " "better hardware. This size should be chosen carefully.") .SetDefault(4096); - AddAttr("exhaustive_search", - "(bool, default false) cuDNN has many algorithm to calculation " - "convolution, whether enable exhaustive search ", - "for cuDNN convolution or not, defalut is False.") - .SetDefault(false); AddComment(R"DOC( Convolution Operator. @@ -288,11 +283,7 @@ void Conv3DOpMaker::Make() { "workspace size can increase performance but also requires " "better hardware. This size should be chosen carefully.") .SetDefault(4096); - AddAttr("exhaustive_search", - "(bool, default false) cuDNN has many algorithm to calculation " - "convolution, whether enable exhaustive search ", - "for cuDNN convolution or not, defalut is False.") - .SetDefault(false); + AddComment(R"DOC( Convolution3D Operator. diff --git a/paddle/fluid/platform/device_context.cc b/paddle/fluid/platform/device_context.cc index d62ef93383..ff49a1d57f 100644 --- a/paddle/fluid/platform/device_context.cc +++ b/paddle/fluid/platform/device_context.cc @@ -204,10 +204,7 @@ CUDADeviceContext::CUDADeviceContext(CUDAPlace place) << "." << (driver_version_ % 100) / 10 << ", Runtime Version: " << runtime_version_ / 1000 << "." << (runtime_version_ % 100) / 10; - size_t cudnn_dso_ver = dynload::cudnnGetVersion(); - LOG(INFO) << "device: " << place_.device - << ", cuDNN Version: " << cudnn_dso_ver / 1000 << "." - << (cudnn_dso_ver % 100) / 10 << "."; + callback_manager_.reset(new StreamCallbackManager(stream_)); } diff --git a/paddle/fluid/platform/dynload/cudnn.h b/paddle/fluid/platform/dynload/cudnn.h index c26143d2f2..d3d754b6f5 100644 --- a/paddle/fluid/platform/dynload/cudnn.h +++ b/paddle/fluid/platform/dynload/cudnn.h @@ -65,54 +65,51 @@ extern void EnforceCUDNNLoaded(const char* fn_name); * include all needed cudnn functions in HPPL * different cudnn version has different interfaces **/ -#define CUDNN_DNN_ROUTINE_EACH(__macro) \ - __macro(cudnnSetTensor4dDescriptor); \ - __macro(cudnnSetTensor4dDescriptorEx); \ - __macro(cudnnSetTensorNdDescriptor); \ - __macro(cudnnGetTensorNdDescriptor); \ - __macro(cudnnGetConvolutionNdForwardOutputDim); \ - __macro(cudnnGetConvolutionForwardAlgorithm); \ - __macro(cudnnCreateTensorDescriptor); \ - __macro(cudnnDestroyTensorDescriptor); \ - __macro(cudnnCreateFilterDescriptor); \ - __macro(cudnnSetFilter4dDescriptor); \ - __macro(cudnnSetFilterNdDescriptor); \ - __macro(cudnnGetFilterNdDescriptor); \ - __macro(cudnnSetPooling2dDescriptor); \ - __macro(cudnnSetPoolingNdDescriptor); \ - __macro(cudnnGetPoolingNdDescriptor); \ - __macro(cudnnDestroyFilterDescriptor); \ - __macro(cudnnCreateConvolutionDescriptor); \ - __macro(cudnnCreatePoolingDescriptor); \ - __macro(cudnnDestroyPoolingDescriptor); \ - __macro(cudnnSetConvolution2dDescriptor); \ - __macro(cudnnDestroyConvolutionDescriptor); \ - __macro(cudnnSetConvolutionNdDescriptor); \ - __macro(cudnnGetConvolutionNdDescriptor); \ - __macro(cudnnDeriveBNTensorDescriptor); \ - __macro(cudnnCreateSpatialTransformerDescriptor); \ - __macro(cudnnSetSpatialTransformerNdDescriptor); \ - __macro(cudnnDestroySpatialTransformerDescriptor); \ - __macro(cudnnSpatialTfGridGeneratorForward); \ - __macro(cudnnSpatialTfGridGeneratorBackward); \ - __macro(cudnnSpatialTfSamplerForward); \ - __macro(cudnnSpatialTfSamplerBackward); \ - __macro(cudnnCreate); \ - __macro(cudnnDestroy); \ - __macro(cudnnSetStream); \ - __macro(cudnnActivationForward); \ - __macro(cudnnConvolutionForward); \ - __macro(cudnnConvolutionBackwardBias); \ - __macro(cudnnGetConvolutionForwardWorkspaceSize); \ - __macro(cudnnTransformTensor); \ - __macro(cudnnPoolingForward); \ - __macro(cudnnPoolingBackward); \ - __macro(cudnnSoftmaxBackward); \ - __macro(cudnnSoftmaxForward); \ - __macro(cudnnGetVersion); \ - __macro(cudnnFindConvolutionForwardAlgorithmEx); \ - __macro(cudnnFindConvolutionBackwardFilterAlgorithmEx); \ - __macro(cudnnFindConvolutionBackwardDataAlgorithmEx); \ +#define CUDNN_DNN_ROUTINE_EACH(__macro) \ + __macro(cudnnSetTensor4dDescriptor); \ + __macro(cudnnSetTensor4dDescriptorEx); \ + __macro(cudnnSetTensorNdDescriptor); \ + __macro(cudnnGetTensorNdDescriptor); \ + __macro(cudnnGetConvolutionNdForwardOutputDim); \ + __macro(cudnnGetConvolutionForwardAlgorithm); \ + __macro(cudnnCreateTensorDescriptor); \ + __macro(cudnnDestroyTensorDescriptor); \ + __macro(cudnnCreateFilterDescriptor); \ + __macro(cudnnSetFilter4dDescriptor); \ + __macro(cudnnSetFilterNdDescriptor); \ + __macro(cudnnGetFilterNdDescriptor); \ + __macro(cudnnSetPooling2dDescriptor); \ + __macro(cudnnSetPoolingNdDescriptor); \ + __macro(cudnnGetPoolingNdDescriptor); \ + __macro(cudnnDestroyFilterDescriptor); \ + __macro(cudnnCreateConvolutionDescriptor); \ + __macro(cudnnCreatePoolingDescriptor); \ + __macro(cudnnDestroyPoolingDescriptor); \ + __macro(cudnnSetConvolution2dDescriptor); \ + __macro(cudnnDestroyConvolutionDescriptor); \ + __macro(cudnnSetConvolutionNdDescriptor); \ + __macro(cudnnGetConvolutionNdDescriptor); \ + __macro(cudnnDeriveBNTensorDescriptor); \ + __macro(cudnnCreateSpatialTransformerDescriptor); \ + __macro(cudnnSetSpatialTransformerNdDescriptor); \ + __macro(cudnnDestroySpatialTransformerDescriptor); \ + __macro(cudnnSpatialTfGridGeneratorForward); \ + __macro(cudnnSpatialTfGridGeneratorBackward); \ + __macro(cudnnSpatialTfSamplerForward); \ + __macro(cudnnSpatialTfSamplerBackward); \ + __macro(cudnnCreate); \ + __macro(cudnnDestroy); \ + __macro(cudnnSetStream); \ + __macro(cudnnActivationForward); \ + __macro(cudnnConvolutionForward); \ + __macro(cudnnConvolutionBackwardBias); \ + __macro(cudnnGetConvolutionForwardWorkspaceSize); \ + __macro(cudnnTransformTensor); \ + __macro(cudnnPoolingForward); \ + __macro(cudnnPoolingBackward); \ + __macro(cudnnSoftmaxBackward); \ + __macro(cudnnSoftmaxForward); \ + __macro(cudnnGetVersion); \ __macro(cudnnGetErrorString); CUDNN_DNN_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_CUDNN_WRAP) diff --git a/python/paddle/fluid/__init__.py b/python/paddle/fluid/__init__.py index 2670fe4b1b..737c8be814 100644 --- a/python/paddle/fluid/__init__.py +++ b/python/paddle/fluid/__init__.py @@ -127,8 +127,7 @@ def __bootstrap__(): if core.is_compiled_with_cuda(): read_env_flags += [ - 'fraction_of_gpu_memory_to_use', 'cudnn_deterministic', - 'conv_workspace_size_limit', 'cudnn_exhaustive_search' + 'fraction_of_gpu_memory_to_use', 'cudnn_deterministic' ] core.init_gflags([sys.argv[0]] + ["--tryfromenv=" + ",".join(read_env_flags)]) diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 13a724ac2d..a87f123117 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -27,7 +27,6 @@ from .tensor import concat from . import utils from .. import unique_name from functools import reduce -from .. import core __all__ = [ 'fc', @@ -1665,20 +1664,6 @@ def conv2d(input, pre_bias = helper.create_variable_for_type_inference(dtype) - if use_cudnn: - helper.create_variable( - name="kCUDNNFwdAlgoCache", - persistable=True, - type=core.VarDesc.VarType.RAW) - helper.create_variable( - name="kCUDNNBwdDataAlgoCache", - persistable=True, - type=core.VarDesc.VarType.RAW) - helper.create_variable( - name="kCUDNNBwdFilterAlgoCache", - persistable=True, - type=core.VarDesc.VarType.RAW) - helper.append_op( type=l_type, inputs={ @@ -1692,7 +1677,7 @@ def conv2d(input, 'dilations': dilation, 'groups': groups, 'use_cudnn': use_cudnn, - 'use_mkldnn': False, + 'use_mkldnn': False }) pre_act = helper.append_bias_op(pre_bias, dim_start=1, dim_end=2) diff --git a/python/paddle/fluid/tests/unittests/test_conv2d_op.py b/python/paddle/fluid/tests/unittests/test_conv2d_op.py index a8f8094426..2ecc2504a8 100644 --- a/python/paddle/fluid/tests/unittests/test_conv2d_op.py +++ b/python/paddle/fluid/tests/unittests/test_conv2d_op.py @@ -67,7 +67,6 @@ class TestConv2dOp(OpTest): def setUp(self): self.op_type = "conv2d" self.use_cudnn = False - self.exhaustive_search = False self.use_cuda = False self.use_mkldnn = False self.data_format = "AnyLayout" @@ -99,8 +98,7 @@ class TestConv2dOp(OpTest): 'dilations': self.dilations, 'use_cudnn': self.use_cudnn, 'use_mkldnn': self.use_mkldnn, - 'data_format': self.data_format, - 'exhaustive_search': self.exhaustive_search + 'data_format': self.data_format } self.outputs = {'Output': output} @@ -394,12 +392,6 @@ class TestDepthwiseConvWithDilation2(TestConv2dOp): self.op_type = "depthwise_conv2d" -class TestCUDNNExhaustiveSearch(TestCUDNN): - def init_kernel_type(self): - self.use_cudnn = True - self.exhaustive_search = True - - # Please Don't remove the following code. # Currently, CI use cudnn V5.0 which not support dilation conv. # class TestCUDNNWithDilation(TestWithDilation): diff --git a/python/paddle/fluid/tests/unittests/test_conv3d_op.py b/python/paddle/fluid/tests/unittests/test_conv3d_op.py index 69c5ab7a4a..ddaf99fe06 100644 --- a/python/paddle/fluid/tests/unittests/test_conv3d_op.py +++ b/python/paddle/fluid/tests/unittests/test_conv3d_op.py @@ -335,12 +335,6 @@ class TestFP16WithInput1x1Filter1x1CUDNN(TestWithInput1x1Filter1x1): self.check_output_with_place(place, atol=2e-2) -class TestCUDNNExhaustiveSearch(TestCUDNN): - def init_kernel_type(self): - self.use_cudnn = True - self.exhaustive_search = True - - # FIXME(typhoonzero): find a way to determine if # using cudnn > 6 in python # class TestWithDilationCUDNN(TestWithDilation): -- GitLab