Unverified commit 6b78e00d, authored by W wangchaochaohu, committed by GitHub

Cudnn convolution reconstruction (#18284)

* rewrite the conv_op using cudnn_conv_helper

* add workspace limit for v7 test=develop

* fix test=develop

* add half float test=develop

* fix test=develop

* fix test=develop

* revise code style test=develop

* fix test=develop
Parent 157211c4
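
The refactor below replaces each kernel's hand-rolled cuDNN algorithm search with the shared SearchAlgorithm helper declared in conv_cudnn_helper.h. As a rough sketch of the pattern the new kernels converge on (illustrative only, assembled from the diff; args, handle, alpha, beta, input_data, filter_data, output_data, and workspace_handle are kernel locals):

    // 1) Pick an algorithm (heuristic, deterministic, or exhaustive + cached).
    using search = SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t>;
    auto algo = search::Find<T>(args, exhaustive_search, /*deterministic=*/false,
                                /*algo_cache_id=*/0, ctx);
    // 2) Ask cuDNN how much scratch memory that algorithm needs.
    size_t workspace_size = search::GetWorkspaceSize(args, algo);
    // 3) Run the convolution with a workspace leased from the device context.
    workspace_handle.RunFunc(
        [&](void* ws) {
          CUDNN_ENFORCE(platform::dynload::cudnnConvolutionForward(
              handle, &alpha, args.idesc.desc(), input_data, args.wdesc.desc(),
              filter_data, args.cdesc.desc(), algo, ws, workspace_size, &beta,
              args.odesc.desc(), output_data));
        },
        workspace_size);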
@@ -14,11 +14,11 @@ limitations under the License. */
#pragma once
#include <memory>
#include <vector>
#include "paddle/fluid/framework/operator_kernel_configs.h"
#include "paddle/fluid/operators/conv_cudnn_op_cache.h"
#include "paddle/fluid/platform/cudnn_desc.h"
namespace paddle {
namespace operators {
@@ -57,16 +57,57 @@ struct SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t> {
bool deterministic, int algo_cache_id,
const framework::ExecutionContext& ctx) {
auto dtype = platform::CudnnDataType<T>::type;
bool has_got_workspace_size = true;
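// Exhaustive search does not support half-precision data, so it is masked
// off when dtype is CUDNN_DATA_HALF (the old kernel enforced this explicitly).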
bool exhaustive = exhaustive_search && (dtype != CUDNN_DATA_HALF);
size_t workspace_size_limit = FLAGS_conv_workspace_size_limit * 1024 * 1024;
size_t workspace_size = 0;
algo_t algo;
#if CUDA_VERSION >= 9000 && CUDNN_VERSION_MIN(7, 0, 1)
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
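// Tensor Cores are available on Volta (compute capability 7.0) and later,
// and are only enabled here for float16 data.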
if (dev_ctx.GetComputeCapability() >= 70 && dtype == CUDNN_DATA_HALF) {
CUDNN_ENFORCE(platform::dynload::cudnnSetConvolutionMathType(
args.cdesc.desc(), CUDNN_TENSOR_OP_MATH));
VLOG(5) << "use cudnn_tensor_op_math";
} else {
CUDNN_ENFORCE(platform::dynload::cudnnSetConvolutionMathType(
args.cdesc.desc(), CUDNN_DEFAULT_MATH));
VLOG(5) << "NOT use cudnn_tensor_op_math";
}
#endif
if (!exhaustive) {
#if CUDNN_VERSION >= 7001
int perf_count;
int best_algo_idx = 0;
std::unique_ptr<perf_t[]> perf_results(new perf_t[kNUM_CUDNN_FWD_ALGS]);
CUDNN_ENFORCE(platform::dynload::cudnnGetConvolutionForwardAlgorithm_v7(
args.handle, args.idesc.desc(), args.wdesc.desc(), args.cdesc.desc(),
args.odesc.desc(), kNUM_CUDNN_FWD_ALGS, &perf_count,
perf_results.get()));
algo = (perf_results.get())[best_algo_idx].algo;
workspace_size = GetWorkspaceSize(args, algo);
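// NOTE: cudnnGetConvolutionForwardAlgorithm_v7 cannot cap the workspace size.
// If the algorithm it picks needs more than the limit, fall back to the
// non-v7 heuristic below, which respects the limit.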
if (workspace_size > workspace_size_limit) {
has_got_workspace_size = false;
VLOG(1) << "Fallback to non-v7 method to find conv algorithm becasue "
"the workspace size request("
<< workspace_size << ") exceeds the limit("
<< workspace_size_limit << ")";
}
if (!has_got_workspace_size) {
CUDNN_ENFORCE(platform::dynload::cudnnGetConvolutionForwardAlgorithm(
args.handle, args.idesc.desc(), args.wdesc.desc(),
args.cdesc.desc(), args.odesc.desc(),
CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT, workspace_size_limit,
&algo));
}
#else
CUDNN_ENFORCE(platform::dynload::cudnnGetConvolutionForwardAlgorithm(
args.handle, args.idesc.desc(), args.wdesc.desc(), args.cdesc.desc(),
args.odesc.desc(), CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT,
workspace_size_limit, &algo));
#endif
VLOG(3) << "choose algo " << algo;
} else {
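// Exhaustive path: benchmark candidate algorithms with
// cudnnFindConvolutionForwardAlgorithmEx and cache the fastest one, keyed by
// the input/filter dims, strides, paddings, and dilations.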
AlgorithmsCache<algo_t>& algo_cache =
@@ -128,15 +169,72 @@ struct SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t> {
const framework::ExecutionContext& ctx) {
auto dtype = platform::CudnnDataType<T>::type;
bool exhaustive = exhaustive_search && (dtype != CUDNN_DATA_HALF);
size_t workspace_size_limit = FLAGS_conv_workspace_size_limit * 1024 * 1024;
size_t workspace_size = 0;
bool has_got_workspace_size = true;
algo_t algo;
#if CUDA_VERSION >= 9000 && CUDNN_VERSION_MIN(7, 0, 1)
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
if (dev_ctx.GetComputeCapability() >= 70 && dtype == CUDNN_DATA_HALF) {
CUDNN_ENFORCE(platform::dynload::cudnnSetConvolutionMathType(
args.cdesc.desc(), CUDNN_TENSOR_OP_MATH));
VLOG(5) << "use cudnn_tensor_op_math";
} else {
CUDNN_ENFORCE(platform::dynload::cudnnSetConvolutionMathType(
args.cdesc.desc(), CUDNN_DEFAULT_MATH));
VLOG(5) << "NOT use cudnn_tensor_op_math";
}
#endif
if (!exhaustive && !deterministic) {
#if CUDNN_VERSION >= 7001
int perf_count;
int best_algo_idx = 0;
std::unique_ptr<perf_t[]> perf_results(
new perf_t[kNUM_CUDNN_BWD_DATA_ALGS]);
CUDNN_ENFORCE(
platform::dynload::cudnnGetConvolutionBackwardDataAlgorithm_v7(
args.handle, args.wdesc.desc(), args.odesc.desc(),
args.cdesc.desc(), args.idesc.desc(), kNUM_CUDNN_BWD_DATA_ALGS,
&perf_count, perf_results.get()));
algo = (perf_results.get())[best_algo_idx].algo;
#if CUDNN_VERSION < 7500
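// Workaround for cuDNN < 7.5: the FFT-based backward-data algorithms are
// blacklisted for strided convolutions, so fall back to ALGO_1 whenever any
// spatial stride is not 1.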
int stride_dim = args.x->dims().size() - 2;
bool blacklist = std::any_of(args.s.begin(), args.s.begin() + stride_dim,
[=](int n) { return n != 1; });
if (blacklist && (static_cast<cudnnConvolutionBwdDataAlgo_t>(
perf_results[best_algo_idx].algo) ==
CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING ||
static_cast<cudnnConvolutionBwdDataAlgo_t>(
perf_results[best_algo_idx].algo) ==
CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT)) {
algo = CUDNN_CONVOLUTION_BWD_DATA_ALGO_1;
}
#endif
workspace_size = GetWorkspaceSize(args, algo);
if (workspace_size > workspace_size_limit) {
has_got_workspace_size = false;
VLOG(1) << "Fallback to non-v7 method to find conv algorithm becasue "
"the workspace size request("
<< workspace_size << ") exceeds the limit("
<< workspace_size_limit << ")";
}
if (!has_got_workspace_size) {
CUDNN_ENFORCE(
platform::dynload::cudnnGetConvolutionBackwardDataAlgorithm(
args.handle, args.wdesc.desc(), args.odesc.desc(),
args.cdesc.desc(), args.idesc.desc(),
CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT,
workspace_size_limit, &algo));
}
#else
CUDNN_ENFORCE(platform::dynload::cudnnGetConvolutionBackwardDataAlgorithm(
args.handle, args.wdesc.desc(), args.idesc.desc(), args.cdesc.desc(),
args.odesc.desc(), CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT,
args.handle, args.wdesc.desc(), args.odesc.desc(), args.cdesc.desc(),
args.idesc.desc(), CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT,
workspace_size_limit, &algo));
#endif
} else if (deterministic) {
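// ALGO_1 produces reproducible results, which is what deterministic mode
// requires.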
return CUDNN_CONVOLUTION_BWD_DATA_ALGO_1;
} else {
@@ -186,8 +284,8 @@ struct SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t> {
size_t workspace_size = 0;
CUDNN_ENFORCE(
platform::dynload::cudnnGetConvolutionBackwardDataWorkspaceSize(
args.handle, args.wdesc.desc(), args.idesc.desc(),
args.cdesc.desc(), args.odesc.desc(), algo, &workspace_size));
args.handle, args.wdesc.desc(), args.odesc.desc(),
args.cdesc.desc(), args.idesc.desc(), algo, &workspace_size));
return workspace_size;
}
};
@@ -203,17 +301,61 @@ struct SearchAlgorithm<cudnnConvolutionBwdFilterAlgoPerf_t> {
const framework::ExecutionContext& ctx) {
auto dtype = platform::CudnnDataType<T>::type;
bool exhaustive = exhaustive_search && (dtype != CUDNN_DATA_HALF);
size_t workspace_size_limit = FLAGS_conv_workspace_size_limit * 1024 * 1024;
size_t workspace_size = 0;
bool has_got_workspace_size = true;
#if CUDA_VERSION >= 9000 && CUDNN_VERSION_MIN(7, 0, 1)
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
if (dev_ctx.GetComputeCapability() >= 70 && dtype == CUDNN_DATA_HALF) {
CUDNN_ENFORCE(platform::dynload::cudnnSetConvolutionMathType(
args.cdesc.desc(), CUDNN_TENSOR_OP_MATH));
VLOG(5) << "use cudnn_tensor_op_math";
} else {
CUDNN_ENFORCE(platform::dynload::cudnnSetConvolutionMathType(
args.cdesc.desc(), CUDNN_DEFAULT_MATH));
VLOG(5) << "NOT use cudnn_tensor_op_math";
}
#endif
algo_t algo;
if (!exhaustive && !deterministic) {
#if CUDNN_VERSION >= 7001
using perf_t = cudnnConvolutionBwdFilterAlgoPerf_t;
int perf_count;
int best_algo_idx = 0;
std::unique_ptr<perf_t[]> perf_results(
new perf_t[kNUM_CUDNN_BWD_FILTER_ALGS]);
CUDNN_ENFORCE(
platform::dynload::cudnnGetConvolutionBackwardFilterAlgorithm_v7(
args.handle, args.idesc.desc(), args.odesc.desc(),
args.cdesc.desc(), args.wdesc.desc(), kNUM_CUDNN_BWD_FILTER_ALGS,
&perf_count, perf_results.get()));
algo = (perf_results.get())[best_algo_idx].algo;
workspace_size = GetWorkspaceSize(args, algo);
if (workspace_size > workspace_size_limit) {
has_got_workspace_size = false;
VLOG(1) << "Fallback to non-v7 method to find conv algorithm becasue "
"the workspace size request("
<< workspace_size << ") exceeds the limit("
<< workspace_size_limit << ")";
}
if (!has_got_workspace_size) {
CUDNN_ENFORCE(
platform::dynload::cudnnGetConvolutionBackwardFilterAlgorithm(
args.handle, args.idesc.desc(), args.odesc.desc(),
args.cdesc.desc(), args.wdesc.desc(),
CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT,
workspace_size_limit, &algo));
}
#else
CUDNN_ENFORCE(
platform::dynload::cudnnGetConvolutionBackwardFilterAlgorithm(
args.handle, args.idesc.desc(), args.odesc.desc(),
args.cdesc.desc(), args.wdesc.desc(),
CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT,
workspace_size_limit, &algo));
#endif
} else if (deterministic) {
return CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1;
} else {
......
@@ -79,44 +79,42 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings");
std::vector<int> dilations = ctx.Attr<std::vector<int>>("dilations");
int groups = ctx.Attr<int>("groups");
int64_t user_workspace_size =
static_cast<size_t>(ctx.Attr<int>("workspace_size_MB"));
bool exhaustive_search =
FLAGS_cudnn_exhaustive_search || ctx.Attr<bool>("exhaustive_search");
if (exhaustive_search && FLAGS_cudnn_deterministic) {
PADDLE_THROW(
"Cann't set exhaustive_search True and "
"FLAGS_cudnn_deterministic True at same time.");
}
const T* input_data = input->data<T>();
const T* filter_data = filter->data<T>();
T* output_data = output->mutable_data<T>(ctx.GetPlace());
// ------------------- cudnn descriptors ---------------------
ScopedTensorDescriptor input_desc;
ScopedTensorDescriptor output_desc;
ScopedFilterDescriptor filter_desc;
ScopedConvolutionDescriptor conv_desc;
ConvArgs args{input, filter, output, strides, paddings, dilations};
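// ConvArgs bundles the tensors with the cuDNN descriptors (idesc, wdesc,
// odesc, cdesc) that SearchAlgorithm consumes; the descriptors are filled in
// below via args.cdesc.set(), args.idesc.set(), and friends.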
auto handle = dev_ctx.cudnn_handle();
auto workspace_handle = dev_ctx.cudnn_workspace_handle();
auto dtype = platform::CudnnDataType<T>::type;
DataLayout layout = DataLayout::kNCHW;
if (input->dims().size() == 5) {
layout = DataLayout::kNCDHW;
}
auto layout_format = GetCudnnTensorFormat(layout);
cudnnConvolutionDescriptor_t cudnn_conv_desc =
conv_desc.descriptor<T>(paddings, strides, dilations);
args.handle = handle;
args.cdesc.set(dtype, paddings, strides, dilations);
#if CUDNN_VERSION_MIN(7, 0, 1)
// cudnn 7 can support groups, no need to do it manually
// FIXME(typhoonzero): find a better way to disable groups
// rather than setting it to 1.
CUDNN_ENFORCE(platform::dynload::cudnnSetConvolutionGroupCount(
cudnn_conv_desc, groups));
args.cdesc.desc(), groups));
groups = 1;
#endif
cudnnTensorDescriptor_t cudnn_input_desc = input_desc.descriptor<T>(
layout, framework::vectorize2int(input->dims()), groups);
cudnnTensorDescriptor_t cudnn_output_desc = output_desc.descriptor<T>(
layout, framework::vectorize2int(output->dims()), groups);
cudnnFilterDescriptor_t cudnn_filter_desc = filter_desc.descriptor<T>(
layout, framework::vectorize2int(filter->dims()), groups);
args.idesc.set(*input, groups);
args.wdesc.set(*filter, layout_format, groups);
args.odesc.set(*output, groups);
int i_n, i_c, i_d, i_h, i_w;
GetNCDHW(input->dims(), DataLayout::kNCHW, &i_n, &i_c, &i_d, &i_h, &i_w);
int o_n, o_c, o_d, o_h, o_w;
@@ -126,149 +124,27 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
int group_offset_out = o_c / groups * o_h * o_w * o_d;
int group_offset_filter = filter->numel() / groups;
// ------------------- cudnn conv workspace ---------------------
size_t workspace_size_in_bytes; // final workspace to allocate.
size_t workspace_size_limit = 0;
if (FLAGS_conv_workspace_size_limit > 0 || user_workspace_size > 0) {
int64_t max_user_size =
std::min(static_cast<int64_t>(FLAGS_conv_workspace_size_limit),
user_workspace_size);
workspace_size_limit = max_user_size * 1024 * 1024;
}
size_t workspace_size = 0; // final workspace to allocate.
// ------------------- cudnn conv algorithm ---------------------
cudnnConvolutionFwdAlgo_t algo{};
bool half_float = false;
#if CUDA_VERSION >= 9000 && CUDNN_VERSION_MIN(7, 0, 1)
// Tensor core is supported since the volta GPU and
// is only enabled when input and filter data are float16
if (dev_ctx.GetComputeCapability() >= 70 &&
std::type_index(typeid(T)) ==
std::type_index(typeid(platform::float16))) {
CUDNN_ENFORCE(platform::dynload::cudnnSetConvolutionMathType(
cudnn_conv_desc, CUDNN_TENSOR_OP_MATH));
// Currently tensor core is only enabled using this algo
algo = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM;
half_float = true;
VLOG(5) << "use cudnn_tensor_op_math";
} else {
CUDNN_ENFORCE(platform::dynload::cudnnSetConvolutionMathType(
cudnn_conv_desc, CUDNN_DEFAULT_MATH));
VLOG(5) << "NOT use cudnn_tensor_op_math";
}
#endif
auto handle = dev_ctx.cudnn_handle();
auto workspace_handle = dev_ctx.cudnn_workspace_handle();
auto x_dims = framework::vectorize(input->dims());
auto f_dims = framework::vectorize(filter->dims());
// TODO(dangqingqing) simplify the following code by SearchAlgorithm in
// conv_cudnn_helper.h
bool has_got_workspace_size = false;
if ((!exhaustive_search) && (!half_float)) {
#if CUDNN_VERSION >= 7001
using perf_t = cudnnConvolutionFwdAlgoPerf_t;
int perf_count;
int best_algo_idx = 0;
std::unique_ptr<perf_t[]> perf_results(new perf_t[kNUM_CUDNN_FWD_ALGS]);
CUDNN_ENFORCE(platform::dynload::cudnnGetConvolutionForwardAlgorithm_v7(
handle, cudnn_input_desc, cudnn_filter_desc, cudnn_conv_desc,
cudnn_output_desc, kNUM_CUDNN_FWD_ALGS, &perf_count,
perf_results.get()));
algo = (perf_results.get())[best_algo_idx].algo;
// get workspace size able to allocate
CUDNN_ENFORCE(platform::dynload::cudnnGetConvolutionForwardWorkspaceSize(
handle, cudnn_input_desc, cudnn_filter_desc, cudnn_conv_desc,
cudnn_output_desc, algo, &workspace_size_in_bytes));
// NOTE(zjl): cudnnGetConvolutionForwardAlgorithm_v7 cannot limit
// workspace size. If the workspace size found by v7 exceeds the limit,
// we should fallback to non-v7 method to find another algorithm.
if (workspace_size_in_bytes > workspace_size_limit) {
VLOG(1) << "Fallback to non-v7 method to find conv algorithm becasue "
"the workspace size request("
<< workspace_size_in_bytes << ") exceeds the limit("
<< workspace_size_limit << ")";
#endif
CUDNN_ENFORCE(platform::dynload::cudnnGetConvolutionForwardAlgorithm(
handle, cudnn_input_desc, cudnn_filter_desc, cudnn_conv_desc,
cudnn_output_desc, CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT,
workspace_size_limit, &algo));
#if CUDNN_VERSION >= 7001
} else {
has_got_workspace_size = true;
}
#endif
using search = SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t>;
algo = search::Find<T>(args, exhaustive_search, false, 0, ctx);
workspace_size = search::GetWorkspaceSize(args, algo);
VLOG(3) << "cuDNN forward algo " << algo;
} else if (exhaustive_search && (!half_float)) {
AlgorithmsCache<cudnnConvolutionFwdAlgo_t>& algo_cache =
ctx.GetKernelConfig<AlgorithmsCache<cudnnConvolutionFwdAlgo_t>>(0);
algo = algo_cache.GetAlgorithm(
x_dims, f_dims, strides, paddings, dilations, 0, [&]() {
int returned_algo_count;
std::array<cudnnConvolutionFwdAlgoPerf_t, kNUM_CUDNN_FWD_ALGS>
fwd_perf_stat;
auto cudnn_find_func = [&](void* cudnn_workspace) {
CUDNN_ENFORCE(
platform::dynload::cudnnFindConvolutionForwardAlgorithmEx(
handle, cudnn_input_desc, input_data, cudnn_filter_desc,
filter_data, cudnn_conv_desc, cudnn_output_desc,
output_data, kNUM_CUDNN_FWD_ALGS, &returned_algo_count,
fwd_perf_stat.data(), cudnn_workspace,
workspace_size_limit));
};
workspace_handle.RunFuncSync(cudnn_find_func, workspace_size_limit);
VLOG(3) << "Perf result: (algo: stat, time, memory)";
for (int i = 0; i < returned_algo_count; ++i) {
const auto& stat = fwd_perf_stat[i];
VLOG(3) << stat.algo << ": " << stat.status << " " << stat.time
<< " " << stat.memory;
}
return fwd_perf_stat[0].algo;
});
VLOG(3) << "choose algo " << algo;
} else {
PADDLE_ENFORCE(half_float,
"cuDNN exhaustive search doesn't support half float.");
}
if (!has_got_workspace_size) {
// get workspace size able to allocate
CUDNN_ENFORCE(platform::dynload::cudnnGetConvolutionForwardWorkspaceSize(
handle, cudnn_input_desc, cudnn_filter_desc, cudnn_conv_desc,
cudnn_output_desc, algo, &workspace_size_in_bytes));
}
// It is possible for float16 on Volta GPUs to allocate more memory than
// the limit because the algorithm is overridden to use Tensor Cores.
PADDLE_ENFORCE_LE(workspace_size_in_bytes, workspace_size_limit,
"workspace_size to be allocated exceeds the limit");
// Allocate on GPU memory
Tensor cudnn_workspace =
ctx.AllocateTmpTensor<int8_t, platform::CUDADeviceContext>(
framework::make_ddim(
{static_cast<int64_t>(workspace_size_in_bytes)}),
dev_ctx);
void* cudnn_workspace_ptr =
static_cast<void*>(cudnn_workspace.data<int8_t>());
VLOG(2) << "Cudnn workspace size fwd: "
<< static_cast<double>(workspace_size_in_bytes) / (1 << 20)
<< " MB";
// ------------------- cudnn conv forward ---------------------
ScalingParamType<T> alpha = 1.0f, beta = 0.0f;
for (int i = 0; i < groups; i++) {
CUDNN_ENFORCE(platform::dynload::cudnnConvolutionForward(
handle, &alpha, cudnn_input_desc, input_data + i * group_offset_in,
cudnn_filter_desc, filter_data + i * group_offset_filter,
cudnn_conv_desc, algo, cudnn_workspace_ptr, workspace_size_in_bytes,
&beta, cudnn_output_desc, output_data + i * group_offset_out));
workspace_handle.RunFunc(
[&](void* workspace_ptr) {
CUDNN_ENFORCE(platform::dynload::cudnnConvolutionForward(
handle, &alpha, args.idesc.desc(),
input_data + i * group_offset_in, args.wdesc.desc(),
filter_data + i * group_offset_filter, args.cdesc.desc(), algo,
workspace_ptr, workspace_size, &beta, args.odesc.desc(),
output_data + i * group_offset_out));
},
workspace_size);
}
}
};
@@ -294,62 +170,30 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings");
std::vector<int> dilations = ctx.Attr<std::vector<int>>("dilations");
int groups = ctx.Attr<int>("groups");
int64_t user_workspace_size =
static_cast<size_t>(ctx.Attr<int>("workspace_size_MB"));
bool exhaustive_search =
FLAGS_cudnn_exhaustive_search || ctx.Attr<bool>("exhaustive_search");
if (exhaustive_search && FLAGS_cudnn_deterministic) {
bool deterministic = FLAGS_cudnn_deterministic;
if (exhaustive_search && deterministic) {
PADDLE_THROW(
"Can't set exhaustive_search True and "
"FLAGS_cudnn_deterministic True at same time.");
}
// ------------------- cudnn descriptors ---------------------
ScopedTensorDescriptor input_desc;
ScopedTensorDescriptor output_grad_desc;
ScopedFilterDescriptor filter_desc;
ScopedFilterDescriptor filter_grad_desc;
ScopedConvolutionDescriptor conv_desc;
T* filter_grad_data = nullptr;
T* input_grad_data = nullptr;
ConvArgs args1{input_grad, filter, output_grad,
strides, paddings, dilations};
ConvArgs args2{input, filter_grad, output_grad,
strides, paddings, dilations};
auto handle = dev_ctx.cudnn_handle();
auto dtype = platform::CudnnDataType<T>::type;
DataLayout layout = DataLayout::kNCHW;
if (input->dims().size() == 5) {
layout = DataLayout::kNCDHW;
}
cudnnConvolutionDescriptor_t cudnn_conv_desc =
conv_desc.descriptor<T>(paddings, strides, dilations);
#if CUDNN_VERSION_MIN(7, 0, 1)
// cudnn 7 can support groups, no need to do it manually
// FIXME(typhoonzero): find a better way to disable groups
// rather than setting it to 1.
CUDNN_ENFORCE(platform::dynload::cudnnSetConvolutionGroupCount(
cudnn_conv_desc, groups));
groups = 1;
#endif
cudnnTensorDescriptor_t cudnn_input_desc = input_desc.descriptor<T>(
layout, framework::vectorize2int(input->dims()), groups);
cudnnTensorDescriptor_t cudnn_output_grad_desc =
output_grad_desc.descriptor<T>(
layout, framework::vectorize2int(output_grad->dims()), groups);
cudnnFilterDescriptor_t cudnn_filter_desc = filter_desc.descriptor<T>(
layout, framework::vectorize2int(filter->dims()), groups);
#if CUDA_VERSION >= 9000 && CUDNN_VERSION_MIN(7, 0, 1)
// Enable Tensor Core for cudnn backward
if (dev_ctx.GetComputeCapability() >= 70 &&
std::type_index(typeid(T)) ==
std::type_index(typeid(platform::float16))) {
CUDNN_ENFORCE(platform::dynload::cudnnSetConvolutionMathType(
cudnn_conv_desc, CUDNN_TENSOR_OP_MATH));
VLOG(5) << "use cudnn_tensor_op_math for backward";
} else {
CUDNN_ENFORCE(platform::dynload::cudnnSetConvolutionMathType(
cudnn_conv_desc, CUDNN_DEFAULT_MATH));
VLOG(5) << "NOT use cudnn_tensor_op_math for backward";
}
#endif
auto layout_tensor = GetCudnnTensorFormat(layout);
auto workspace_handle = dev_ctx.cudnn_workspace_handle();
int i_n, i_c, i_d, i_h, i_w;
GetNCDHW(input->dims(), DataLayout::kNCHW, &i_n, &i_c, &i_d, &i_h, &i_w);
@@ -361,263 +205,83 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
int group_offset_out = o_c / groups * o_h * o_w * o_d;
int group_offset_filter = filter->numel() / groups;
// ------------------- cudnn backward algorithm ---------------------
cudnnConvolutionBwdDataAlgo_t data_algo{};
cudnnConvolutionBwdFilterAlgo_t filter_algo{};
size_t workspace_size_in_bytes = 0, tmp_size = 0;
size_t workspace_size_limit = 0;
if (FLAGS_conv_workspace_size_limit > 0 || user_workspace_size > 0) {
int64_t max_user_size =
std::min(static_cast<int64_t>(FLAGS_conv_workspace_size_limit),
user_workspace_size);
workspace_size_limit = max_user_size * 1024 * 1024;
}
Tensor cudnn_workspace;
void* cudnn_workspace_ptr = nullptr;
if ((input_data || filter_data) && exhaustive_search) {
cudnn_workspace =
ctx.AllocateTmpTensor<int8_t, platform::CUDADeviceContext>(
framework::make_ddim(
{static_cast<int64_t>(workspace_size_limit)}),
dev_ctx);
cudnn_workspace_ptr = static_cast<void*>(cudnn_workspace.data<int8_t>());
}
cudnnConvolutionBwdDataAlgo_t data_algo =
static_cast<cudnnConvolutionBwdDataAlgo_t>(0);
cudnnConvolutionBwdFilterAlgo_t filter_algo =
static_cast<cudnnConvolutionBwdFilterAlgo_t>(0);
size_t workspace_size = 0;
int iwo_groups, c_groups;
// TODO(dangqingqing) simplify the following code by SearchAlgorithm in
// conv_cudnn_helper.h
auto x_dims = framework::vectorize(input->dims());
auto f_dims = framework::vectorize(filter->dims());
auto handle = dev_ctx.cudnn_handle();
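// With cuDNN >= 7.0.1, grouped convolution is handled by the convolution
// descriptor itself: the real group count is passed to cdesc.set() via
// c_groups, the tensor/filter descriptors are built ungrouped
// (iwo_groups = 1), and the per-group loops below run once (groups = 1).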
#if CUDNN_VERSION_MIN(7, 0, 1)
iwo_groups = 1;
c_groups = groups;
groups = 1;
#endif
bool has_got_bwd_data_ws_size = false;
if (input_grad) {
T* input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace());
if (exhaustive_search) {
AlgorithmsCache<cudnnConvolutionBwdDataAlgo_t>& data_algo_cache =
ctx.GetKernelConfig<AlgorithmsCache<cudnnConvolutionBwdDataAlgo_t>>(
0);
data_algo = data_algo_cache.GetAlgorithm(
x_dims, f_dims, strides, paddings, dilations, 0, [&]() {
int returned_algo_count;
std::array<cudnnConvolutionBwdDataAlgoPerf_t,
kNUM_CUDNN_BWD_DATA_ALGS>
data_perf_stat;
CUDNN_ENFORCE(platform::dynload::
cudnnFindConvolutionBackwardDataAlgorithmEx(
handle, cudnn_filter_desc, filter_data,
cudnn_output_grad_desc, output_grad_data,
cudnn_conv_desc, cudnn_input_desc,
input_grad_data, kNUM_CUDNN_BWD_DATA_ALGS,
&returned_algo_count, data_perf_stat.data(),
cudnn_workspace_ptr, workspace_size_limit));
VLOG(3) << "Perf result: (algo: stat, time, memory)";
for (int i = 0; i < returned_algo_count; ++i) {
const auto& stat = data_perf_stat[i];
VLOG(3) << stat.algo << ": " << stat.status << " " << stat.time
<< " " << stat.memory;
}
return data_perf_stat[0].algo;
});
VLOG(3) << "cuDNN backward data algo " << data_algo;
} else if (FLAGS_cudnn_deterministic) {
data_algo = CUDNN_CONVOLUTION_BWD_DATA_ALGO_1;
} else {
#if CUDNN_VERSION >= 7001
using perf_t = cudnnConvolutionBwdDataAlgoPerf_t;
int perf_count;
int best_algo_idx = 0;
std::unique_ptr<perf_t[]> perf_results(
new perf_t[kNUM_CUDNN_BWD_DATA_ALGS]);
CUDNN_ENFORCE(
platform::dynload::cudnnGetConvolutionBackwardDataAlgorithm_v7(
handle, cudnn_filter_desc,
// dyDesc: Handle to the previously initialized input
// differential
// tensor descriptor.
cudnn_output_grad_desc, cudnn_conv_desc,
// dxDesc: Handle to the previously initialized output tensor
// descriptor.
cudnn_input_desc, kNUM_CUDNN_BWD_DATA_ALGS, &perf_count,
perf_results.get()));
data_algo = (perf_results.get())[best_algo_idx].algo;
int stride_dim = input->dims().size() - 2;
bool blacklist =
std::any_of(strides.begin(), strides.begin() + stride_dim,
[=](int n) { return n != 1; });
if (blacklist && (static_cast<cudnnConvolutionBwdDataAlgo_t>(
perf_results[best_algo_idx].algo) ==
CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING ||
static_cast<cudnnConvolutionBwdDataAlgo_t>(
perf_results[best_algo_idx].algo) ==
CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT)) {
data_algo = CUDNN_CONVOLUTION_BWD_DATA_ALGO_1;
}
CUDNN_ENFORCE(
platform::dynload::cudnnGetConvolutionBackwardDataWorkspaceSize(
handle, cudnn_filter_desc, cudnn_output_grad_desc,
cudnn_conv_desc, cudnn_input_desc, data_algo, &tmp_size));
auto new_workspace_size = std::max(workspace_size_in_bytes, tmp_size);
if (new_workspace_size > workspace_size_limit) {
VLOG(1) << "Fallback to non-v7 method to find conv algorithm becasue "
"the workspace size request("
<< new_workspace_size << ") exceeds the limit("
<< workspace_size_limit << ")";
#endif
CUDNN_ENFORCE(
platform::dynload::cudnnGetConvolutionBackwardDataAlgorithm(
handle, cudnn_filter_desc,
// dyDesc: Handle to the previously initialized input
// differential
// tensor descriptor.
cudnn_output_grad_desc, cudnn_conv_desc,
// dxDesc: Handle to the previously initialized output tensor
// descriptor.
cudnn_input_desc,
CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT,
workspace_size_limit, &data_algo));
#if CUDNN_VERSION >= 7001
} else {
workspace_size_in_bytes = new_workspace_size;
has_got_bwd_data_ws_size = true;
}
#endif
}
// ------------------- cudnn descriptors ---------------------
input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace());
args1.handle = handle;
args1.idesc.set(*input_grad, iwo_groups);
args1.wdesc.set(*filter, layout_tensor, iwo_groups);
args1.odesc.set(*output_grad, iwo_groups);
args1.cdesc.set(dtype, paddings, strides, dilations, c_groups);
if (!has_got_bwd_data_ws_size) {
CUDNN_ENFORCE(
platform::dynload::cudnnGetConvolutionBackwardDataWorkspaceSize(
handle, cudnn_filter_desc, cudnn_output_grad_desc,
cudnn_conv_desc, cudnn_input_desc, data_algo, &tmp_size));
workspace_size_in_bytes = std::max(workspace_size_in_bytes, tmp_size);
}
using search1 = SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t>;
data_algo =
search1::Find<T>(args1, exhaustive_search, deterministic, 0, ctx);
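// workspace_size is maxed across the data-grad search here and the
// filter-grad search below, so one workspace allocation serves both
// backward calls.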
workspace_size =
std::max(workspace_size, search1::GetWorkspaceSize(args1, data_algo));
}
bool has_got_bwd_filter_ws_size = false;
if (filter_grad) {
T* filter_grad_data = filter_grad->mutable_data<T>(ctx.GetPlace());
if (exhaustive_search) {
AlgorithmsCache<cudnnConvolutionBwdFilterAlgo_t>& f_algo_cache =
ctx.GetKernelConfig<
AlgorithmsCache<cudnnConvolutionBwdFilterAlgo_t>>(1);
filter_algo = f_algo_cache.GetAlgorithm(
x_dims, f_dims, strides, paddings, dilations, 0, [&]() {
int returned_algo_count;
std::array<cudnnConvolutionBwdFilterAlgoPerf_t,
kNUM_CUDNN_BWD_FILTER_ALGS>
filter_perf_stat;
CUDNN_ENFORCE(
platform::dynload::
cudnnFindConvolutionBackwardFilterAlgorithmEx(
handle, cudnn_input_desc, input_data,
cudnn_output_grad_desc, output_grad_data,
cudnn_conv_desc, cudnn_filter_desc, filter_grad_data,
kNUM_CUDNN_BWD_FILTER_ALGS, &returned_algo_count,
filter_perf_stat.data(), cudnn_workspace_ptr,
workspace_size_limit));
return filter_perf_stat[0].algo;
});
VLOG(3) << "cuDNN backward filter algo " << filter_algo;
} else if (FLAGS_cudnn_deterministic) {
filter_algo = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1;
} else {
#if CUDNN_VERSION >= 7001
using perf_t = cudnnConvolutionBwdFilterAlgoPerf_t;
int perf_count;
int best_algo_idx = 0;
std::unique_ptr<perf_t[]> perf_results(
new perf_t[kNUM_CUDNN_BWD_FILTER_ALGS]);
CUDNN_ENFORCE(
platform::dynload::cudnnGetConvolutionBackwardFilterAlgorithm_v7(
handle, cudnn_input_desc, cudnn_output_grad_desc,
cudnn_conv_desc, cudnn_filter_desc, kNUM_CUDNN_BWD_FILTER_ALGS,
&perf_count, perf_results.get()));
filter_algo = (perf_results.get())[best_algo_idx].algo;
CUDNN_ENFORCE(
platform::dynload::cudnnGetConvolutionBackwardFilterWorkspaceSize(
handle, cudnn_input_desc, cudnn_output_grad_desc,
cudnn_conv_desc, cudnn_filter_desc, filter_algo, &tmp_size));
auto new_workspace_size = std::max(workspace_size_in_bytes, tmp_size);
if (new_workspace_size > workspace_size_limit) {
VLOG(1) << "Fallback to non-v7 method to find conv algorithm becasue "
"the workspace size request("
<< new_workspace_size << ") exceeds the limit("
<< workspace_size_limit << ")";
#endif
CUDNN_ENFORCE(
platform::dynload::cudnnGetConvolutionBackwardFilterAlgorithm(
handle, cudnn_input_desc, cudnn_output_grad_desc,
cudnn_conv_desc, cudnn_filter_desc,
CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT,
workspace_size_limit, &filter_algo));
#if CUDNN_VERSION >= 7001
} else {
workspace_size_in_bytes = new_workspace_size;
has_got_bwd_filter_ws_size = true;
}
#endif
}
if (!has_got_bwd_filter_ws_size) {
CUDNN_ENFORCE(
platform::dynload::cudnnGetConvolutionBackwardFilterWorkspaceSize(
handle, cudnn_input_desc, cudnn_output_grad_desc,
cudnn_conv_desc, cudnn_filter_desc, filter_algo, &tmp_size));
workspace_size_in_bytes = std::max(workspace_size_in_bytes, tmp_size);
}
}
PADDLE_ENFORCE_LE(workspace_size_in_bytes, workspace_size_limit,
"workspace_size to be allocated exceeds the limit");
// ------------------- cudnn conv workspace ---------------------
if (!cudnn_workspace_ptr) {
cudnn_workspace =
ctx.AllocateTmpTensor<int8_t, platform::CUDADeviceContext>(
framework::make_ddim(
{static_cast<int64_t>(workspace_size_in_bytes)}),
dev_ctx);
cudnn_workspace_ptr = static_cast<void*>(cudnn_workspace.data<int8_t>());
VLOG(2) << "Cudnn workspace size bwd: "
<< static_cast<double>(workspace_size_in_bytes) / (1 << 20)
<< " MB";
// ------------------- cudnn descriptors ---------------------
filter_grad_data = filter_grad->mutable_data<T>(ctx.GetPlace());
args2.handle = handle;
args2.idesc.set(*input, iwo_groups);
args2.wdesc.set(*filter_grad, layout_tensor, iwo_groups);
args2.odesc.set(*output_grad, iwo_groups);
args2.cdesc.set(dtype, paddings, strides, dilations, c_groups);
using search2 = SearchAlgorithm<cudnnConvolutionBwdFilterAlgoPerf_t>;
filter_algo =
search2::Find<T>(args2, exhaustive_search, deterministic, 1, ctx);
workspace_size = std::max(workspace_size,
search2::GetWorkspaceSize(args2, filter_algo));
}
// ------------------- cudnn conv backward data ---------------------
ScalingParamType<T> alpha = 1.0f, beta = 0.0f;
if (input_grad) {
T* input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace());
// Because beta is zero, it is unnecessary to reset input_grad.
for (int i = 0; i < groups; i++) {
CUDNN_ENFORCE(platform::dynload::cudnnConvolutionBackwardData(
handle, &alpha, cudnn_filter_desc,
filter_data + i * group_offset_filter, cudnn_output_grad_desc,
output_grad_data + i * group_offset_out, cudnn_conv_desc, data_algo,
cudnn_workspace_ptr, workspace_size_in_bytes, &beta,
cudnn_input_desc, input_grad_data + i * group_offset_in));
workspace_handle.RunFunc(
[&](void* cudnn_workspace_ptr) {
CUDNN_ENFORCE(platform::dynload::cudnnConvolutionBackwardData(
handle, &alpha, args1.wdesc.desc(),
filter_data + i * group_offset_filter, args1.odesc.desc(),
output_grad_data + i * group_offset_out, args1.cdesc.desc(),
data_algo, cudnn_workspace_ptr, workspace_size, &beta,
args1.idesc.desc(), input_grad_data + i * group_offset_in));
},
workspace_size);
}
}
// ------------------- cudnn conv backward filter ---------------------
if (filter_grad) {
T* filter_grad_data = filter_grad->mutable_data<T>(ctx.GetPlace());
// Because beta is zero, it is unnecessary to reset filter_grad.
for (int i = 0; i < groups; i++) {
CUDNN_ENFORCE(platform::dynload::cudnnConvolutionBackwardFilter(
handle, &alpha, cudnn_input_desc, input_data + i * group_offset_in,
cudnn_output_grad_desc, output_grad_data + i * group_offset_out,
cudnn_conv_desc, filter_algo, cudnn_workspace_ptr,
workspace_size_in_bytes, &beta, cudnn_filter_desc,
filter_grad_data + i * group_offset_filter));
workspace_handle.RunFunc(
[&](void* cudnn_workspace_ptr) {
CUDNN_ENFORCE(platform::dynload::cudnnConvolutionBackwardFilter(
handle, &alpha, args2.idesc.desc(),
input_data + i * group_offset_in, args2.odesc.desc(),
output_grad_data + i * group_offset_out, args2.cdesc.desc(),
filter_algo, cudnn_workspace_ptr, workspace_size, &beta,
args2.wdesc.desc(),
filter_grad_data + i * group_offset_filter));
},
workspace_size);
}
}
}
......