Unverified commit 20f30dd6, authored by ceci3, committed by GitHub

add benchmark flag for conv_transpose (#22389)

Parent: 6d325a94
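
The hunks below gate cuDNN algorithm selection on FLAGS_cudnn_deterministic. For background, this is a GFlags boolean read by the CUDA kernels; the sketch below shows how such a flag is declared and consumed (the declaration site and help text are assumptions here, only the flag name comes from the diff):

#include <gflags/gflags.h>

// Assumed declaration; in Paddle the flag is defined in a central flags file.
DEFINE_bool(cudnn_deterministic, false,
            "If true, force deterministic cuDNN algorithms instead of the "
            "heuristically chosen, possibly non-deterministic ones.");

// Kernel code then tests the generated FLAGS_cudnn_deterministic variable,
// exactly as the hunks below do.
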
@@ -272,6 +272,16 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
algo = search::Find<T>(args, exhaustive_search, false, 0, ctx);
workspace_size = search::GetWorkspaceSize(args, algo);
#if CUDNN_VERSION_MIN(7, 0, 1)
// When groups > 1, the algorithm found by SearchAlgorithm is
// CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED, but that algorithm is
// unstable in forward computation, so switch to
// CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM manually.
if (ctx.Attr<int>("groups") > 1) {
algo = static_cast<cudnnConvolutionFwdAlgo_t>(0);
}
#endif
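
The bare static_cast<cudnnConvolutionFwdAlgo_t>(0) is only readable if one remembers the numeric values cuDNN assigns to its forward-algorithm enum: value 0 is CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM, and value 7 is the CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED choice being avoided. A minimal sketch of the same assignment spelled out with the named enumerator (reference only, not part of the patch):

#include <cudnn.h>

// Equivalent, explicit form of the cast above: IMPLICIT_GEMM has the
// value 0 in cudnnConvolutionFwdAlgo_t, so the two spellings are identical.
cudnnConvolutionFwdAlgo_t PickGroupedFwdAlgo() {
  return CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM;
}
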
// ------------------- cudnn conv forward ---------------------
ScalingParamType<T> alpha = 1.0f, beta = 0.0f;
for (int i = 0; i < groups; i++) {
@@ -881,6 +891,7 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel<T> {
#if CUDNN_VERSION_MIN(7, 0, 1)
iwo_group = 1;
c_group = groups;
groups = 1;
#endif
auto dtype = platform::CudnnDataType<T>::type;
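
The CUDNNConvDoubleGradOpKernel hunk reflects cuDNN 7's native grouped-convolution support: the group count is pushed into the convolution descriptor (c_group = groups), while the host-side bookkeeping (iwo_group, and now groups itself) drops to 1 so the later per-group loop runs a single iteration. A minimal sketch of the underlying cuDNN call, with error handling elided and placeholder names (this is not Paddle's wrapper code):

#include <cudnn.h>

// cuDNN >= 7 performs the per-group split internally once the group
// count is recorded on the convolution descriptor.
void EnableNativeGroups(cudnnConvolutionDescriptor_t conv_desc, int groups) {
  cudnnSetConvolutionGroupCount(conv_desc, groups);  // status check elided in this sketch
}
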
@@ -245,7 +245,7 @@ class CUDNNConvTransposeOpKernel : public framework::OpKernel<T> {
CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT,
workspace_size_limit, &algo));
-if (algo == 0 && FLAGS_cudnn_deterministic) {
+if (FLAGS_cudnn_deterministic) {
algo = static_cast<cudnnConvolutionBwdDataAlgo_t>(1);
}
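
Two things happen in this CUDNNConvTransposeOpKernel hunk. First, the deterministic override is widened: previously the chosen algorithm was replaced only when the heuristic happened to return algorithm 0, now any choice is overridden whenever FLAGS_cudnn_deterministic is set. Second, the magic number 1 denotes CUDNN_CONVOLUTION_BWD_DATA_ALGO_1, the deterministic backward-data algorithm (ALGO_0 is not deterministic). The same assignment with the named enumerator, as a sketch:

#include <cudnn.h>

// Explicit form of static_cast<cudnnConvolutionBwdDataAlgo_t>(1):
// ALGO_1 has the value 1 and, unlike ALGO_0, gives reproducible results.
cudnnConvolutionBwdDataAlgo_t DeterministicBwdDataAlgo() {
  return CUDNN_CONVOLUTION_BWD_DATA_ALGO_1;
}
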
@@ -476,6 +476,10 @@ class CUDNNConvTransposeGradOpKernel : public framework::OpKernel<T> {
handle, cudnn_output_desc, cudnn_filter_desc, cudnn_conv_desc,
cudnn_input_desc, CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT,
workspace_size_limit, &data_algo));
if (FLAGS_cudnn_deterministic) {
data_algo = static_cast<cudnnConvolutionFwdAlgo_t>(1);
}
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cudnnGetConvolutionForwardWorkspaceSize(
handle, cudnn_output_desc, cudnn_filter_desc, cudnn_conv_desc,
@@ -492,6 +496,9 @@ class CUDNNConvTransposeGradOpKernel : public framework::OpKernel<T> {
CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT,
workspace_size_limit, &filter_algo));
if (FLAGS_cudnn_deterministic) {
filter_algo = static_cast<cudnnConvolutionBwdFilterAlgo_t>(1);
}
// get workspace for backwards filter algorithm
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cudnnGetConvolutionBackwardFilterWorkspaceSize(
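
The last two hunks apply the same deterministic override inside CUDNNConvTransposeGradOpKernel. The input-gradient path of a transposed convolution runs a forward convolution, so index 1 there maps to CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM, while the filter-gradient path maps index 1 to CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1; both are deterministic cuDNN algorithms. A sketch of the two casts written with named enumerators:

#include <cudnn.h>

// Explicit forms of the static_cast<...>(1) fallbacks in the hunks above.
cudnnConvolutionFwdAlgo_t DeterministicDataAlgo() {
  return CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM;  // value 1, deterministic
}

cudnnConvolutionBwdFilterAlgo_t DeterministicFilterAlgo() {
  return CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1;  // value 1; ALGO_0 and ALGO_3 are non-deterministic
}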