未验证 提交 abe20923 编写于 作者: Q qingqing01 提交者: GitHub

Exhaustive search for cuDNN conv. (#14286)

* exhaustive search for cuDNN conv.
* Refine code and add unit testing.
* Fix model load in fluid/inference and unit testing in conv2d
* Follow comments.
* Fix compiling test=develop
上级 f215534e
......@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include <algorithm>
#include <array>
#include <string>
#include <vector>
......
......@@ -13,6 +13,8 @@
// limitations under the License.
#pragma once
#include <algorithm>
#include <map>
#include <string>
#include <vector>
#include "paddle/fluid/framework/naive_executor.h"
......
......@@ -16,7 +16,6 @@
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/inference/api/paddle_inference_api.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle_inference_api.h"
namespace paddle {
......
......@@ -16,13 +16,14 @@
#include <glog/logging.h>
#include <sys/time.h>
#include <algorithm>
#include <chrono> // NOLINT
#include <numeric>
#include <sstream>
#include <string>
#include <vector>
#include "paddle/fluid/inference/api/paddle_inference_api.h"
#include "paddle/fluid/string/printf.h"
#include "paddle_inference_api.h"
namespace paddle {
namespace inference {
......
......@@ -59,7 +59,8 @@ void ReadBinaryFile(const std::string& filename, std::string* contents) {
bool IsPersistable(const framework::VarDesc* var) {
if (var->Persistable() &&
var->GetType() != framework::proto::VarType::FEED_MINIBATCH &&
var->GetType() != framework::proto::VarType::FETCH_LIST) {
var->GetType() != framework::proto::VarType::FETCH_LIST &&
var->GetType() != framework::proto::VarType::RAW) {
return true;
}
return false;
......
......@@ -134,7 +134,7 @@ class TensorRTEngine : public EngineBase {
std::unordered_map<std::string /*name*/, std::unique_ptr<framework::Tensor>>
weight_map;
// TODO: (NHZLX)
// TODO(NHZLX)
// In the normal case, the paddle-trt exists bug when runing the googlenet.
// When there are more than two convolutions of 1 * 1 with the same input, the
// paddle-tensorrt will do the merging optimization, which fuse those conv
......
......@@ -66,9 +66,10 @@ class AddPositionEncodingKernel : public framework::OpKernel<T> {
x_lod.empty() ? max_seq_len : x_lod[0][i + 1] - x_lod[0][i];
for (int j = 0; j < max_length; ++j) {
for (int k = 0; k < half_size; ++k) {
const double val = (half_size > 1)
? j / pow(10000.0, double(k) / (half_size - 1))
: j / 10000.0;
const double val =
(half_size > 1)
? j / pow(10000.0, static_cast<double>(k) / (half_size - 1))
: j / 10000.0;
dst_ptr[k] = src_ptr[k] * alpha + sin(val) * beta;
dst_ptr[half_size + k] =
src_ptr[half_size + k] * alpha + cos(val) * beta;
......
......@@ -15,15 +15,22 @@ limitations under the License. */
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/memory/memory.h"
#include "paddle/fluid/operators/conv_cudnn_op_cache.h"
#include "paddle/fluid/operators/conv_op.h"
#include "paddle/fluid/platform/assert.h"
#include "paddle/fluid/platform/cudnn_helper.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/fluid/platform/profiler.h"
DEFINE_bool(cudnn_deterministic, false,
"Whether allow using an autotuning algorithm for convolution "
"operator. The autotuning algorithm may be non-deterministic. If "
"true, the algorithm is deterministic.");
DEFINE_uint64(conv_workspace_size_limit, 4096,
"cuDNN convolution workspace limit in MB unit.");
DEFINE_bool(cudnn_exhaustive_search, false,
"Whether enable exhaustive search for cuDNN convolution or "
"not, defalut is False.");
namespace paddle {
namespace operators {
......@@ -36,13 +43,25 @@ using DataLayout = platform::DataLayout;
template <typename T>
using ScalingParamType = typename platform::CudnnDataType<T>::ScalingParamType;
static constexpr char kCUDNNFwdAlgoCache[] = "kCUDNNFwdAlgoCache";
static constexpr char kCUDNNBwdDataAlgoCache[] = "kCUDNNBwdDataAlgoCache";
static constexpr char kCUDNNBwdFilterAlgoCache[] = "kCUDNNBwdFilterAlgoCache";
static constexpr size_t kCONV_CUDNN_WORKSPACE_LIMIT_BYTES =
static_cast<size_t>(1024) * 1024 * 1024;
static constexpr size_t kNUM_CUDNN_FWD_ALGS =
CUDNN_CONVOLUTION_BWD_FILTER_ALGO_COUNT;
static constexpr size_t kNUM_CUDNN_BWD_FILTER_ALGS =
CUDNN_CONVOLUTION_BWD_FILTER_ALGO_COUNT;
static constexpr size_t kNUM_CUDNN_BWD_DATA_ALGS =
CUDNN_CONVOLUTION_BWD_DATA_ALGO_COUNT;
template <typename T>
class CUDNNConvOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
"It must use CUDAPlace.");
auto* input = ctx.Input<Tensor>("Input");
......@@ -55,6 +74,8 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
int groups = ctx.Attr<int>("groups");
int64_t user_workspace_size =
static_cast<size_t>(ctx.Attr<int>("workspace_size_MB"));
bool exhaustive_search =
FLAGS_cudnn_exhaustive_search || ctx.Attr<bool>("exhaustive_search");
const T* input_data = input->data<T>();
const T* filter_data = filter->data<T>();
......@@ -120,19 +141,19 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
// ------------------- cudnn conv workspace ---------------------
size_t workspace_size_in_bytes; // final workspace to allocate.
size_t workspace_size_limit = kCONV_CUDNN_WORKSPACE_LIMIT_BYTES;
if (user_workspace_size > 0) {
workspace_size_limit = user_workspace_size * 1024 * 1024;
if (FLAGS_conv_workspace_size_limit > 0 || user_workspace_size > 0) {
int64_t max_user_size =
std::max(static_cast<int64_t>(FLAGS_conv_workspace_size_limit),
user_workspace_size);
workspace_size_limit = max_user_size * 1024 * 1024;
}
// ------------------- cudnn conv algorithm ---------------------
cudnnConvolutionFwdAlgo_t algo;
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
auto handle = dev_ctx.cudnn_handle();
auto workspace_handle = dev_ctx.cudnn_workspace_handle();
CUDNN_ENFORCE(platform::dynload::cudnnGetConvolutionForwardAlgorithm(
handle, cudnn_input_desc, cudnn_filter_desc, cudnn_conv_desc,
cudnn_output_desc, CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT,
workspace_size_limit, &algo));
bool half_float = false;
#if CUDA_VERSION >= 9000 && CUDNN_VERSION_MIN(7, 0, 1)
// Tensor core is supported since the volta GPU and
// is only enabled when input and filter data are float16
......@@ -143,6 +164,7 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
cudnn_conv_desc, CUDNN_TENSOR_OP_MATH));
// Currently tensor core is only enabled using this algo
algo = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM;
half_float = true;
VLOG(5) << "use cudnn_tensor_op_math";
} else {
CUDNN_ENFORCE(platform::dynload::cudnnSetConvolutionMathType(
......@@ -151,6 +173,57 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
}
#endif
auto x_dims = framework::vectorize(input->dims());
auto f_dims = framework::vectorize(filter->dims());
if ((!exhaustive_search) && (!half_float)) {
CUDNN_ENFORCE(platform::dynload::cudnnGetConvolutionForwardAlgorithm(
handle, cudnn_input_desc, cudnn_filter_desc, cudnn_conv_desc,
cudnn_output_desc, CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT,
workspace_size_limit, &algo));
VLOG(3) << "cuDNN forward algo " << algo;
} else if (exhaustive_search && (!half_float)) {
AlgorithmsCache<cudnnConvolutionFwdAlgo_t>* algo_cache = nullptr;
if (ctx.scope().FindVar(kCUDNNFwdAlgoCache)) {
algo_cache =
ctx.scope()
.FindVar(kCUDNNFwdAlgoCache)
->GetMutable<AlgorithmsCache<cudnnConvolutionFwdAlgo_t>>();
} else {
algo_cache =
const_cast<framework::Scope&>(ctx.scope())
.Var(kCUDNNFwdAlgoCache)
->GetMutable<AlgorithmsCache<cudnnConvolutionFwdAlgo_t>>();
}
algo = algo_cache->GetAlgorithm(
x_dims, f_dims, strides, paddings, dilations, 0, [&]() {
int returned_algo_count;
std::array<cudnnConvolutionFwdAlgoPerf_t, kNUM_CUDNN_FWD_ALGS>
fwd_perf_stat;
auto cudnn_find_func = [&](void* cudnn_workspace) {
CUDNN_ENFORCE(
platform::dynload::cudnnFindConvolutionForwardAlgorithmEx(
handle, cudnn_input_desc, input_data, cudnn_filter_desc,
filter_data, cudnn_conv_desc, cudnn_output_desc,
output_data, kNUM_CUDNN_FWD_ALGS, &returned_algo_count,
fwd_perf_stat.data(), cudnn_workspace,
workspace_size_limit));
};
workspace_handle.RunFunc(cudnn_find_func, workspace_size_limit);
VLOG(3) << "Perf result: (algo: stat, time, memory)";
for (int i = 0; i < returned_algo_count; ++i) {
const auto& stat = fwd_perf_stat[i];
VLOG(3) << stat.algo << ": " << stat.status << " " << stat.time
<< " " << stat.memory;
}
return fwd_perf_stat[0].algo;
});
VLOG(3) << "choose algo " << algo;
} else {
PADDLE_ENFORCE(half_float,
"cuDNN exhaustive search doesn't support half float.");
}
// get workspace size able to allocate
CUDNN_ENFORCE(platform::dynload::cudnnGetConvolutionForwardWorkspaceSize(
handle, cudnn_input_desc, cudnn_filter_desc, cudnn_conv_desc,
......@@ -162,7 +235,6 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
// ------------------- cudnn conv forward ---------------------
ScalingParamType<T> alpha = 1.0f, beta = 0.0f;
auto workspace_handle = dev_ctx.cudnn_workspace_handle();
for (int i = 0; i < groups; i++) {
auto cudnn_func = [&](void* cudnn_workspace) {
CUDNN_ENFORCE(platform::dynload::cudnnConvolutionForward(
......@@ -180,6 +252,7 @@ template <typename T>
class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
"It must use CUDAPlace.");
auto input = ctx.Input<Tensor>("Input");
......@@ -198,6 +271,13 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
int groups = ctx.Attr<int>("groups");
int64_t user_workspace_size =
static_cast<size_t>(ctx.Attr<int>("workspace_size_MB"));
bool exhaustive_search =
FLAGS_cudnn_exhaustive_search || ctx.Attr<bool>("exhaustive_search");
if (exhaustive_search && FLAGS_cudnn_deterministic) {
PADDLE_THROW(
"Cann't set exhaustive_search True and "
"FLAGS_cudnn_deterministic True at same time.");
}
// ------------------- cudnn descriptors ---------------------
ScopedTensorDescriptor input_desc;
......@@ -265,14 +345,66 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
cudnnConvolutionBwdFilterAlgo_t filter_algo;
size_t workspace_size_in_bytes = 0, tmp_size = 0;
size_t workspace_size_limit = kCONV_CUDNN_WORKSPACE_LIMIT_BYTES;
if (user_workspace_size > 0) {
workspace_size_limit = user_workspace_size * 1024 * 1024;
if (FLAGS_conv_workspace_size_limit > 0 || user_workspace_size > 0) {
int64_t max_user_size =
std::max(static_cast<int64_t>(FLAGS_conv_workspace_size_limit),
user_workspace_size);
workspace_size_limit = max_user_size * 1024 * 1024;
}
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
auto x_dims = framework::vectorize(input->dims());
auto f_dims = framework::vectorize(filter->dims());
auto handle = dev_ctx.cudnn_handle();
auto workspace_handle = dev_ctx.cudnn_workspace_handle();
if (input_grad) {
if (!FLAGS_cudnn_deterministic) {
T* input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace());
if (exhaustive_search) {
AlgorithmsCache<cudnnConvolutionBwdDataAlgo_t>* data_algo_cache;
if (ctx.scope().FindVar(kCUDNNBwdDataAlgoCache)) {
data_algo_cache =
ctx.scope()
.FindVar(kCUDNNBwdDataAlgoCache)
->GetMutable<
AlgorithmsCache<cudnnConvolutionBwdDataAlgo_t>>();
} else {
data_algo_cache =
const_cast<framework::Scope&>(ctx.scope())
.Var(kCUDNNBwdDataAlgoCache)
->GetMutable<
AlgorithmsCache<cudnnConvolutionBwdDataAlgo_t>>();
}
data_algo = data_algo_cache->GetAlgorithm(
x_dims, f_dims, strides, paddings, dilations, 0, [&]() {
int returned_algo_count;
std::array<cudnnConvolutionBwdDataAlgoPerf_t,
kNUM_CUDNN_BWD_DATA_ALGS>
data_perf_stat;
auto cudnn_find_bd_data_func = [&](void* cudnn_workspace) {
CUDNN_ENFORCE(
platform::dynload::
cudnnFindConvolutionBackwardDataAlgorithmEx(
handle, cudnn_filter_desc, filter_data,
cudnn_output_grad_desc, output_grad_data,
cudnn_conv_desc, cudnn_input_desc, input_grad_data,
kNUM_CUDNN_BWD_DATA_ALGS, &returned_algo_count,
data_perf_stat.data(), cudnn_workspace,
workspace_size_limit));
};
workspace_handle.RunFunc(cudnn_find_bd_data_func,
workspace_size_limit);
VLOG(3) << "Perf result: (algo: stat, time, memory)";
for (int i = 0; i < returned_algo_count; ++i) {
const auto& stat = data_perf_stat[i];
VLOG(3) << stat.algo << ": " << stat.status << " " << stat.time
<< " " << stat.memory;
}
return data_perf_stat[0].algo;
});
VLOG(3) << "cuDNN backward data algo " << data_algo;
} else if (FLAGS_cudnn_deterministic) {
data_algo = CUDNN_CONVOLUTION_BWD_DATA_ALGO_1;
} else {
CUDNN_ENFORCE(
platform::dynload::cudnnGetConvolutionBackwardDataAlgorithm(
handle, cudnn_filter_desc,
......@@ -285,10 +417,7 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
cudnn_input_desc,
CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT,
workspace_size_limit, &data_algo));
} else {
data_algo = CUDNN_CONVOLUTION_BWD_DATA_ALGO_1;
}
CUDNN_ENFORCE(
platform::dynload::cudnnGetConvolutionBackwardDataWorkspaceSize(
handle, cudnn_filter_desc, cudnn_output_grad_desc,
......@@ -297,17 +426,54 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
}
if (filter_grad) {
if (!FLAGS_cudnn_deterministic) {
T* filter_grad_data = filter_grad->mutable_data<T>(ctx.GetPlace());
if (exhaustive_search) {
AlgorithmsCache<cudnnConvolutionBwdFilterAlgo_t>* f_algo_cache;
if (ctx.scope().FindVar(kCUDNNBwdFilterAlgoCache)) {
f_algo_cache =
ctx.scope()
.FindVar(kCUDNNBwdFilterAlgoCache)
->GetMutable<
AlgorithmsCache<cudnnConvolutionBwdFilterAlgo_t>>();
} else {
f_algo_cache =
const_cast<framework::Scope&>(ctx.scope())
.Var(kCUDNNBwdFilterAlgoCache)
->GetMutable<
AlgorithmsCache<cudnnConvolutionBwdFilterAlgo_t>>();
}
filter_algo = f_algo_cache->GetAlgorithm(
x_dims, f_dims, strides, paddings, dilations, 0, [&]() {
int returned_algo_count;
std::array<cudnnConvolutionBwdFilterAlgoPerf_t,
kNUM_CUDNN_BWD_FILTER_ALGS>
filter_perf_stat;
auto cudnn_find_bd_f_func = [&](void* cudnn_workspace) {
CUDNN_ENFORCE(
platform::dynload::
cudnnFindConvolutionBackwardFilterAlgorithmEx(
handle, cudnn_input_desc, input_data,
cudnn_output_grad_desc, output_grad_data,
cudnn_conv_desc, cudnn_filter_desc,
filter_grad_data, kNUM_CUDNN_BWD_FILTER_ALGS,
&returned_algo_count, filter_perf_stat.data(),
cudnn_workspace, workspace_size_limit));
};
workspace_handle.RunFunc(cudnn_find_bd_f_func,
workspace_size_limit);
return filter_perf_stat[0].algo;
});
VLOG(3) << "cuDNN backward filter algo " << filter_algo;
} else if (FLAGS_cudnn_deterministic) {
filter_algo = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1;
} else {
CUDNN_ENFORCE(
platform::dynload::cudnnGetConvolutionBackwardFilterAlgorithm(
handle, cudnn_input_desc, cudnn_output_grad_desc,
cudnn_conv_desc, cudnn_filter_desc,
CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT,
workspace_size_limit, &filter_algo));
} else {
filter_algo = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1;
}
CUDNN_ENFORCE(
platform::dynload::cudnnGetConvolutionBackwardFilterWorkspaceSize(
handle, cudnn_input_desc, cudnn_output_grad_desc, cudnn_conv_desc,
......@@ -317,7 +483,6 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
// ------------------- cudnn conv backward data ---------------------
ScalingParamType<T> alpha = 1.0f, beta = 0.0f;
auto workspace_handle = dev_ctx.cudnn_workspace_handle();
if (input_grad) {
T* input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace());
// Because beta is zero, it is unnecessary to reset input_grad.
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <functional>
#include <unordered_map>
#include <vector>
namespace paddle {
namespace operators {
template <typename TAlgorithm>
class AlgorithmsCache {
public:
// Caches the best algorithm for a given
// combination of tensor dimensions & compute data type.
TAlgorithm GetAlgorithm(
const std::vector<int64_t>& dims1, const std::vector<int64_t>& dims2,
const std::vector<int>& strides, const std::vector<int>& paddings,
const std::vector<int>& dilations,
int algorithmFlags, // can set for different data type
std::function<TAlgorithm()> gen_func);
private:
std::unordered_map<int64_t, TAlgorithm> hash_;
std::mutex mutex_;
};
template <typename TAlgorithm>
TAlgorithm AlgorithmsCache<TAlgorithm>::GetAlgorithm(
const std::vector<int64_t>& dims1, const std::vector<int64_t>& dims2,
const std::vector<int>& strides, const std::vector<int>& paddings,
const std::vector<int>& dilations, int algorithmFlags,
std::function<TAlgorithm()> gen_func) {
std::lock_guard<std::mutex> lock(mutex_);
int64_t seed = 0;
// Hash all of the inputs, use to try and look up a previously
// discovered algorithm, or fall back to generating a new one.
std::hash<int64_t> hashFn;
// do hash like boost
// https://stackoverflow.com/questions/2590677/how-do-i-combine-hash-values-in-c0x
for (const auto num : dims1) {
seed ^= hashFn(num) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
}
for (const auto num : dims2) {
seed ^= hashFn(num) + 0x9e3779b9 + (seed << 6) + (seed >> 2) + 1;
}
for (const auto num : strides) {
seed ^= hashFn(static_cast<int64_t>(num)) + 0x9e3779b9 + (seed << 6) +
(seed >> 2) + 2;
}
for (const auto num : paddings) {
seed ^= hashFn(static_cast<int64_t>(num)) + 0x9e3779b9 + (seed << 6) +
(seed >> 2) + 3;
}
for (const auto num : dilations) {
seed ^= hashFn(static_cast<int64_t>(num)) + 0x9e3779b9 + (seed << 6) +
(seed >> 2) + 4;
}
seed ^= hashFn(static_cast<int64_t>(algorithmFlags)) + 0x9e3779b9 +
(seed << 6) + (seed >> 2) + 5;
if (seed == 0) return gen_func();
if (hash_.find(seed) == hash_.end()) {
TAlgorithm value = gen_func();
hash_[seed] = value;
}
return hash_[seed];
}
} // namespace operators
} // namespace paddle
......@@ -189,6 +189,11 @@ void Conv2DOpMaker::Make() {
"workspace size can increase performance but also requires "
"better hardware. This size should be chosen carefully.")
.SetDefault(4096);
AddAttr<bool>("exhaustive_search",
"(bool, default false) cuDNN has many algorithm to calculation "
"convolution, whether enable exhaustive search ",
"for cuDNN convolution or not, defalut is False.")
.SetDefault(false);
AddComment(R"DOC(
Convolution Operator.
......@@ -283,7 +288,11 @@ void Conv3DOpMaker::Make() {
"workspace size can increase performance but also requires "
"better hardware. This size should be chosen carefully.")
.SetDefault(4096);
AddAttr<bool>("exhaustive_search",
"(bool, default false) cuDNN has many algorithm to calculation "
"convolution, whether enable exhaustive search ",
"for cuDNN convolution or not, defalut is False.")
.SetDefault(false);
AddComment(R"DOC(
Convolution3D Operator.
......
......@@ -34,7 +34,7 @@ namespace operators {
using FluidDT = framework::proto::VarType_Type;
using TRT_DT = nvinfer1::DataType;
namespace {
namespace { // NOLINT
TRT_DT FluidDataType2TRT(FluidDT type) {
switch (type) {
......
......@@ -204,7 +204,10 @@ CUDADeviceContext::CUDADeviceContext(CUDAPlace place)
<< "." << (driver_version_ % 100) / 10
<< ", Runtime Version: " << runtime_version_ / 1000
<< "." << (runtime_version_ % 100) / 10;
size_t cudnn_dso_ver = dynload::cudnnGetVersion();
LOG_FIRST_N(WARNING, 1) << "device: " << place_.device
<< ", cuDNN Version: " << cudnn_dso_ver / 1000 << "."
<< (cudnn_dso_ver % 100) / 10 << ".";
callback_manager_.reset(new StreamCallbackManager(stream_));
}
......
......@@ -65,51 +65,54 @@ extern void EnforceCUDNNLoaded(const char* fn_name);
* include all needed cudnn functions in HPPL
* different cudnn version has different interfaces
**/
#define CUDNN_DNN_ROUTINE_EACH(__macro) \
__macro(cudnnSetTensor4dDescriptor); \
__macro(cudnnSetTensor4dDescriptorEx); \
__macro(cudnnSetTensorNdDescriptor); \
__macro(cudnnGetTensorNdDescriptor); \
__macro(cudnnGetConvolutionNdForwardOutputDim); \
__macro(cudnnGetConvolutionForwardAlgorithm); \
__macro(cudnnCreateTensorDescriptor); \
__macro(cudnnDestroyTensorDescriptor); \
__macro(cudnnCreateFilterDescriptor); \
__macro(cudnnSetFilter4dDescriptor); \
__macro(cudnnSetFilterNdDescriptor); \
__macro(cudnnGetFilterNdDescriptor); \
__macro(cudnnSetPooling2dDescriptor); \
__macro(cudnnSetPoolingNdDescriptor); \
__macro(cudnnGetPoolingNdDescriptor); \
__macro(cudnnDestroyFilterDescriptor); \
__macro(cudnnCreateConvolutionDescriptor); \
__macro(cudnnCreatePoolingDescriptor); \
__macro(cudnnDestroyPoolingDescriptor); \
__macro(cudnnSetConvolution2dDescriptor); \
__macro(cudnnDestroyConvolutionDescriptor); \
__macro(cudnnSetConvolutionNdDescriptor); \
__macro(cudnnGetConvolutionNdDescriptor); \
__macro(cudnnDeriveBNTensorDescriptor); \
__macro(cudnnCreateSpatialTransformerDescriptor); \
__macro(cudnnSetSpatialTransformerNdDescriptor); \
__macro(cudnnDestroySpatialTransformerDescriptor); \
__macro(cudnnSpatialTfGridGeneratorForward); \
__macro(cudnnSpatialTfGridGeneratorBackward); \
__macro(cudnnSpatialTfSamplerForward); \
__macro(cudnnSpatialTfSamplerBackward); \
__macro(cudnnCreate); \
__macro(cudnnDestroy); \
__macro(cudnnSetStream); \
__macro(cudnnActivationForward); \
__macro(cudnnConvolutionForward); \
__macro(cudnnConvolutionBackwardBias); \
__macro(cudnnGetConvolutionForwardWorkspaceSize); \
__macro(cudnnTransformTensor); \
__macro(cudnnPoolingForward); \
__macro(cudnnPoolingBackward); \
__macro(cudnnSoftmaxBackward); \
__macro(cudnnSoftmaxForward); \
__macro(cudnnGetVersion); \
#define CUDNN_DNN_ROUTINE_EACH(__macro) \
__macro(cudnnSetTensor4dDescriptor); \
__macro(cudnnSetTensor4dDescriptorEx); \
__macro(cudnnSetTensorNdDescriptor); \
__macro(cudnnGetTensorNdDescriptor); \
__macro(cudnnGetConvolutionNdForwardOutputDim); \
__macro(cudnnGetConvolutionForwardAlgorithm); \
__macro(cudnnCreateTensorDescriptor); \
__macro(cudnnDestroyTensorDescriptor); \
__macro(cudnnCreateFilterDescriptor); \
__macro(cudnnSetFilter4dDescriptor); \
__macro(cudnnSetFilterNdDescriptor); \
__macro(cudnnGetFilterNdDescriptor); \
__macro(cudnnSetPooling2dDescriptor); \
__macro(cudnnSetPoolingNdDescriptor); \
__macro(cudnnGetPoolingNdDescriptor); \
__macro(cudnnDestroyFilterDescriptor); \
__macro(cudnnCreateConvolutionDescriptor); \
__macro(cudnnCreatePoolingDescriptor); \
__macro(cudnnDestroyPoolingDescriptor); \
__macro(cudnnSetConvolution2dDescriptor); \
__macro(cudnnDestroyConvolutionDescriptor); \
__macro(cudnnSetConvolutionNdDescriptor); \
__macro(cudnnGetConvolutionNdDescriptor); \
__macro(cudnnDeriveBNTensorDescriptor); \
__macro(cudnnCreateSpatialTransformerDescriptor); \
__macro(cudnnSetSpatialTransformerNdDescriptor); \
__macro(cudnnDestroySpatialTransformerDescriptor); \
__macro(cudnnSpatialTfGridGeneratorForward); \
__macro(cudnnSpatialTfGridGeneratorBackward); \
__macro(cudnnSpatialTfSamplerForward); \
__macro(cudnnSpatialTfSamplerBackward); \
__macro(cudnnCreate); \
__macro(cudnnDestroy); \
__macro(cudnnSetStream); \
__macro(cudnnActivationForward); \
__macro(cudnnConvolutionForward); \
__macro(cudnnConvolutionBackwardBias); \
__macro(cudnnGetConvolutionForwardWorkspaceSize); \
__macro(cudnnTransformTensor); \
__macro(cudnnPoolingForward); \
__macro(cudnnPoolingBackward); \
__macro(cudnnSoftmaxBackward); \
__macro(cudnnSoftmaxForward); \
__macro(cudnnGetVersion); \
__macro(cudnnFindConvolutionForwardAlgorithmEx); \
__macro(cudnnFindConvolutionBackwardFilterAlgorithmEx); \
__macro(cudnnFindConvolutionBackwardDataAlgorithmEx); \
__macro(cudnnGetErrorString);
CUDNN_DNN_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_CUDNN_WRAP)
......
......@@ -126,7 +126,8 @@ def __bootstrap__():
if core.is_compiled_with_cuda():
read_env_flags += [
'fraction_of_gpu_memory_to_use', 'cudnn_deterministic'
'fraction_of_gpu_memory_to_use', 'cudnn_deterministic',
'conv_workspace_size_limit', 'cudnn_exhaustive_search'
]
core.init_gflags([sys.argv[0]] +
["--tryfromenv=" + ",".join(read_env_flags)])
......
......@@ -27,6 +27,7 @@ from .tensor import concat
from . import utils
from .. import unique_name
from functools import reduce
from .. import core
__all__ = [
'fc',
......@@ -1666,6 +1667,20 @@ def conv2d(input,
pre_bias = helper.create_variable_for_type_inference(dtype)
if use_cudnn:
helper.create_variable(
name="kCUDNNFwdAlgoCache",
persistable=True,
type=core.VarDesc.VarType.RAW)
helper.create_variable(
name="kCUDNNBwdDataAlgoCache",
persistable=True,
type=core.VarDesc.VarType.RAW)
helper.create_variable(
name="kCUDNNBwdFilterAlgoCache",
persistable=True,
type=core.VarDesc.VarType.RAW)
helper.append_op(
type=l_type,
inputs={
......@@ -1679,7 +1694,7 @@ def conv2d(input,
'dilations': dilation,
'groups': groups,
'use_cudnn': use_cudnn,
'use_mkldnn': False
'use_mkldnn': False,
})
pre_act = helper.append_bias_op(pre_bias, dim_start=1, dim_end=2)
......
......@@ -67,6 +67,7 @@ class TestConv2dOp(OpTest):
def setUp(self):
self.op_type = "conv2d"
self.use_cudnn = False
self.exhaustive_search = False
self.use_cuda = False
self.use_mkldnn = False
self.data_format = "AnyLayout"
......@@ -98,7 +99,8 @@ class TestConv2dOp(OpTest):
'dilations': self.dilations,
'use_cudnn': self.use_cudnn,
'use_mkldnn': self.use_mkldnn,
'data_format': self.data_format
'data_format': self.data_format,
'exhaustive_search': self.exhaustive_search
}
self.outputs = {'Output': output}
......@@ -361,6 +363,12 @@ class TestDepthwiseConvWithDilation2(TestConv2dOp):
self.op_type = "depthwise_conv2d"
class TestCUDNNExhaustiveSearch(TestConv2dOp):
def init_kernel_type(self):
self.use_cudnn = True
self.exhaustive_search = True
# Please Don't remove the following code.
# Currently, CI use cudnn V5.0 which not support dilation conv.
# class TestCUDNNWithDilation(TestWithDilation):
......
......@@ -335,6 +335,12 @@ class TestFP16WithInput1x1Filter1x1CUDNN(TestWithInput1x1Filter1x1):
self.check_output_with_place(place, atol=2e-2)
class TestCUDNNExhaustiveSearch(TestCUDNN):
def init_kernel_type(self):
self.use_cudnn = True
self.exhaustive_search = True
# FIXME(typhoonzero): find a way to determine if
# using cudnn > 6 in python
# class TestWithDilationCUDNN(TestWithDilation):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册