From 3d5aa9d10a70b7e68b3cded9b2720f662c952016 Mon Sep 17 00:00:00 2001 From: Qi Li Date: Fri, 12 Mar 2021 13:55:14 +0800 Subject: [PATCH] [ROCM] fix conv2d and conv3d op, test=develop (#31553) --- paddle/fluid/operators/conv_cudnn_op.cu | 215 ++++++++-------- paddle/fluid/operators/conv_miopen_helper.h | 231 ++++++++---------- .../operators/conv_transpose_cudnn_op.cu | 40 ++- paddle/fluid/platform/miopen_desc.h | 25 +- .../fluid/tests/unittests/test_conv2d_op.py | 15 +- .../fluid/tests/unittests/test_conv3d_op.py | 14 ++ .../unittests/test_sync_batch_norm_op.py | 7 +- 7 files changed, 298 insertions(+), 249 deletions(-) diff --git a/paddle/fluid/operators/conv_cudnn_op.cu b/paddle/fluid/operators/conv_cudnn_op.cu index 110bb69a14..39e9d37ddc 100644 --- a/paddle/fluid/operators/conv_cudnn_op.cu +++ b/paddle/fluid/operators/conv_cudnn_op.cu @@ -249,6 +249,7 @@ class CUDNNConvOpKernel : public framework::OpKernel { args.handle = handle; #ifdef PADDLE_WITH_HIP + // MIOPEN need to set groups in cdesc in miopen_desc.h args.cdesc.set(dtype, padding_common, strides, dilations, platform::AllowTF32Cudnn(), groups); #else @@ -264,6 +265,10 @@ class CUDNNConvOpKernel : public framework::OpKernel { platform::dynload::cudnnSetConvolutionGroupCount(args.cdesc.desc(), groups)); groups = 1; +#endif +#ifdef PADDLE_WITH_HIP + // MIOPEN do not set groups in wdesc after set groups in cdesc + groups = 1; #endif args.idesc.set(transformed_input, layout_format); args.wdesc.set(transformed_filter_channel, layout_format, groups); @@ -292,12 +297,14 @@ class CUDNNConvOpKernel : public framework::OpKernel { #ifdef PADDLE_WITH_HIP miopenConvFwdAlgorithm_t algo{}; using search = SearchAlgorithm; + workspace_size = search::GetWorkspaceSize(args); + algo = search::Find(args, exhaustive_search, false, workspace_size, ctx); #else cudnnConvolutionFwdAlgo_t algo{}; using search = SearchAlgorithm; -#endif algo = search::Find(args, exhaustive_search, false, ctx); workspace_size = search::GetWorkspaceSize(args, algo); +#endif #if defined(PADDLE_WITH_CUDA) && CUDNN_VERSION_MIN(7, 0, 1) // when groups > 1, SearchAlgorithm find algo is CUDNN_CONVOLUTION_\ @@ -652,13 +659,17 @@ class CUDNNConvGradOpKernel : public framework::OpKernel { #ifdef PADDLE_WITH_HIP using search1 = SearchAlgorithm; + workspace_size = + std::max(workspace_size, search1::GetWorkspaceSize(args1)); + data_algo = search1::Find(args1, exhaustive_search, deterministic, + workspace_size, ctx); #else using search1 = SearchAlgorithm; -#endif data_algo = search1::Find(args1, exhaustive_search, deterministic, ctx); workspace_size = std::max(workspace_size, search1::GetWorkspaceSize(args1, data_algo)); +#endif } if (filter_grad) { @@ -673,13 +684,17 @@ class CUDNNConvGradOpKernel : public framework::OpKernel { platform::AllowTF32Cudnn(), c_groups); #ifdef PADDLE_WITH_HIP using search2 = SearchAlgorithm; + workspace_size = + std::max(workspace_size, search2::GetWorkspaceSize(args2)); + filter_algo = search2::Find(args2, exhaustive_search, deterministic, + workspace_size, ctx); #else using search2 = SearchAlgorithm; -#endif filter_algo = search2::Find(args2, exhaustive_search, deterministic, ctx); workspace_size = std::max(workspace_size, search2::GetWorkspaceSize(args2, filter_algo)); +#endif } // ------------------- cudnn conv backward data --------------------- @@ -688,23 +703,22 @@ class CUDNNConvGradOpKernel : public framework::OpKernel { VLOG(4) << "Conv_grad: use_addto = " << ctx.Attr("use_addto"); if (input_grad) { - // When beta is 0, it is unnecessary to reset input_grad. - // When beta is 1, the output cannot be reset since addt strategy used. - for (int i = 0; i < groups; i++) { +// When beta is 0, it is unnecessary to reset input_grad. +// When beta is 1, the output cannot be reset since addt strategy used. #ifdef PADDLE_WITH_HIP - workspace_handle.RunFunc( - [&](void* cudnn_workspace_ptr) { - PADDLE_ENFORCE_CUDA_SUCCESS( - platform::dynload::miopenConvolutionBackwardData( - handle, &alpha, args1.odesc.desc(), - output_grad_data + i * group_offset_out, - args1.wdesc.desc(), filter_data + i * group_offset_filter, - args1.cdesc.desc(), data_algo, &beta, args1.idesc.desc(), - transformed_input_grad_data + i * group_offset_in, - cudnn_workspace_ptr, workspace_size)); - }, - workspace_size); + workspace_handle.RunFunc( + [&](void* cudnn_workspace_ptr) { + PADDLE_ENFORCE_CUDA_SUCCESS( + platform::dynload::miopenConvolutionBackwardData( + handle, &alpha, args1.odesc.desc(), output_grad_data, + args1.wdesc.desc(), filter_data, args1.cdesc.desc(), + data_algo, &beta, args1.idesc.desc(), + transformed_input_grad_data, cudnn_workspace_ptr, + workspace_size)); + }, + workspace_size); #else + for (int i = 0; i < groups; i++) { workspace_handle.RunFunc( [&](void* cudnn_workspace_ptr) { PADDLE_ENFORCE_CUDA_SUCCESS( @@ -717,9 +731,8 @@ class CUDNNConvGradOpKernel : public framework::OpKernel { transformed_input_grad_data + i * group_offset_in)); }, workspace_size); -#endif } - +#endif if (!is_sys_pad) { std::vector starts(transformed_input_channel.dims().size(), 0); std::vector axes(transformed_input_channel.dims().size(), 0); @@ -751,23 +764,20 @@ class CUDNNConvGradOpKernel : public framework::OpKernel { ScalingParamType beta_filter = 0.0f; // ------------------- cudnn conv backward filter --------------------- if (filter_grad) { - // Because beta is zero, it is unnecessary to reset filter_grad. - for (int i = 0; i < groups; i++) { +// Because beta is zero, it is unnecessary to reset filter_grad. #ifdef PADDLE_WITH_HIP - workspace_handle.RunFunc( - [&](void* cudnn_workspace_ptr) { - PADDLE_ENFORCE_CUDA_SUCCESS( - platform::dynload::miopenConvolutionBackwardWeights( - handle, &alpha, args2.odesc.desc(), - output_grad_data + i * group_offset_out, - args2.idesc.desc(), input_data + i * group_offset_in, - args2.cdesc.desc(), filter_algo, &beta, - args2.wdesc.desc(), - filter_grad_data + i * group_offset_filter, - cudnn_workspace_ptr, workspace_size)); - }, - workspace_size); + workspace_handle.RunFunc( + [&](void* cudnn_workspace_ptr) { + PADDLE_ENFORCE_CUDA_SUCCESS( + platform::dynload::miopenConvolutionBackwardWeights( + handle, &alpha, args2.odesc.desc(), output_grad_data, + args2.idesc.desc(), input_data, args2.cdesc.desc(), + filter_algo, &beta, args2.wdesc.desc(), filter_grad_data, + cudnn_workspace_ptr, workspace_size)); + }, + workspace_size); #else + for (int i = 0; i < groups; i++) { workspace_handle.RunFunc( [&](void* cudnn_workspace_ptr) { PADDLE_ENFORCE_CUDA_SUCCESS( @@ -780,8 +790,8 @@ class CUDNNConvGradOpKernel : public framework::OpKernel { filter_grad_data + i * group_offset_filter)); }, workspace_size); -#endif } +#endif if (compute_format == DataLayout::kNHWC) { TransToChannelFirst( @@ -1080,32 +1090,37 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel { #ifdef PADDLE_WITH_HIP using search1 = SearchAlgorithm; + workspace_size = search1::GetWorkspaceSize(args1); + fwd_algo1 = search1::Find(args1, exhaustive_search, false, + workspace_size, ctx); #else using search1 = SearchAlgorithm; -#endif fwd_algo1 = search1::Find(args1, exhaustive_search, false, ctx); workspace_size = search1::GetWorkspaceSize(args1, fwd_algo1); +#endif } if (ddW) { ddw = ddW->data(); args2.handle = handle; args2.idesc.set(transformed_X, iwo_group); - args2.wdesc.set(*ddW, layout, iwo_group); - args2.odesc.set(transformed_ddO_channel, iwo_group); args2.cdesc.set(dtype, padding_common, strides, dilations, platform::AllowTF32Cudnn(), c_group); #ifdef PADDLE_WITH_HIP using search2 = SearchAlgorithm; + workspace_size = + std::max(workspace_size, search2::GetWorkspaceSize(args2)); + fwd_algo2 = search2::Find(args2, exhaustive_search, false, + workspace_size, ctx); #else using search2 = SearchAlgorithm; -#endif fwd_algo2 = search2::Find(args2, exhaustive_search, false, ctx); workspace_size = std::max(workspace_size, search2::GetWorkspaceSize(args2, fwd_algo2)); +#endif } } @@ -1114,21 +1129,23 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel { args3.handle = handle; args3.idesc.set(transformed_ddX, iwo_group); args3.wdesc.set(*dW, layout, iwo_group); - args3.odesc.set(transformed_dO_channel, iwo_group); - args3.cdesc.set(dtype, padding_common, strides, dilations, platform::AllowTF32Cudnn(), c_group); #ifdef PADDLE_WITH_HIP using search3 = SearchAlgorithm; + workspace_size = + std::max(workspace_size, search3::GetWorkspaceSize(args3)); + filter_algo = search3::Find(args3, exhaustive_search, deterministic, + workspace_size, ctx); #else using search3 = SearchAlgorithm; -#endif filter_algo = search3::Find(args3, exhaustive_search, deterministic, ctx); workspace_size = std::max(workspace_size, search3::GetWorkspaceSize(args3, filter_algo)); +#endif } if (ddW && dX) { @@ -1143,13 +1160,17 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel { #ifdef PADDLE_WITH_HIP using search4 = SearchAlgorithm; + workspace_size = + std::max(workspace_size, search4::GetWorkspaceSize(args4)); + data_algo = search4::Find(args4, exhaustive_search, deterministic, + workspace_size, ctx); #else using search4 = SearchAlgorithm; -#endif data_algo = search4::Find(args4, exhaustive_search, deterministic, ctx); workspace_size = std::max(workspace_size, search4::GetWorkspaceSize(args4, data_algo)); +#endif } int i_n, i_c, i_d, i_h, i_w; @@ -1176,21 +1197,19 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel { if (ddO) { if (ddX) { ddx = transformed_ddX.data(); - for (int i = 0; i < groups; i++) { #ifdef PADDLE_WITH_HIP - wkspace_handle.RunFunc( - [&](void* workspace_ptr) { - PADDLE_ENFORCE_CUDA_SUCCESS( - platform::dynload::miopenConvolutionForward( - handle, &alpha, args1.idesc.desc(), - ddx + i * group_offset_in, args1.wdesc.desc(), - w + i * group_offset_filter, args1.cdesc.desc(), - fwd_algo1, &beta, args1.odesc.desc(), - transformed_ddy_channel + i * group_offset_out, - workspace_ptr, workspace_size)); - }, - workspace_size); + wkspace_handle.RunFunc( + [&](void* workspace_ptr) { + PADDLE_ENFORCE_CUDA_SUCCESS( + platform::dynload::miopenConvolutionForward( + handle, &alpha, args1.idesc.desc(), ddx, + args1.wdesc.desc(), w, args1.cdesc.desc(), fwd_algo1, + &beta, args1.odesc.desc(), transformed_ddy_channel, + workspace_ptr, workspace_size)); + }, + workspace_size); #else + for (int i = 0; i < groups; i++) { wkspace_handle.RunFunc( [&](void* workspace_ptr) { PADDLE_ENFORCE_CUDA_SUCCESS( @@ -1203,26 +1222,24 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel { transformed_ddy_channel + i * group_offset_out)); }, workspace_size); -#endif } +#endif } if (ddW) { - for (int i = 0; i < groups; i++) { #ifdef PADDLE_WITH_HIP - // MIOPEN ONLY support beta to be 0.0f - wkspace_handle.RunFunc( - [&](void* workspace_ptr) { - PADDLE_ENFORCE_CUDA_SUCCESS( - platform::dynload::miopenConvolutionForward( - handle, &alpha, args2.idesc.desc(), - x + i * group_offset_in, args2.wdesc.desc(), - ddw + i * group_offset_filter, args2.cdesc.desc(), - fwd_algo2, &beta, args2.odesc.desc(), - transformed_ddy_channel + i * group_offset_out, - workspace_ptr, workspace_size)); - }, - workspace_size); + // MIOPEN ONLY support beta to be 0.0f + wkspace_handle.RunFunc( + [&](void* workspace_ptr) { + PADDLE_ENFORCE_CUDA_SUCCESS( + platform::dynload::miopenConvolutionForward( + handle, &alpha, args2.idesc.desc(), x, args2.wdesc.desc(), + ddw, args2.cdesc.desc(), fwd_algo2, &beta, + args2.odesc.desc(), transformed_ddy_channel, + workspace_ptr, workspace_size)); + }, + workspace_size); #else + for (int i = 0; i < groups; i++) { wkspace_handle.RunFunc( [&](void* workspace_ptr) { PADDLE_ENFORCE_CUDA_SUCCESS( @@ -1235,8 +1252,8 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel { transformed_ddy_channel + i * group_offset_out)); }, workspace_size); -#endif } +#endif } if (channel_last) { TransToChannelLast( @@ -1246,21 +1263,19 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel { T* transformed_dy_channel = transformed_dO_channel.data(); if (dW && ddX) { ddx = transformed_ddX.data(); - for (int i = 0; i < groups; i++) { #ifdef PADDLE_WITH_HIP - wkspace_handle.RunFunc( - [&](void* workspace_ptr) { - PADDLE_ENFORCE_CUDA_SUCCESS( - platform::dynload::miopenConvolutionBackwardWeights( - handle, &alpha, args3.odesc.desc(), - transformed_dy_channel + i * group_offset_out, - args3.idesc.desc(), ddx + i * group_offset_in, - args3.cdesc.desc(), filter_algo, &beta, - args3.wdesc.desc(), dw + i * group_offset_filter, - workspace_ptr, workspace_size)); - }, - workspace_size); + wkspace_handle.RunFunc( + [&](void* workspace_ptr) { + PADDLE_ENFORCE_CUDA_SUCCESS( + platform::dynload::miopenConvolutionBackwardWeights( + handle, &alpha, args3.odesc.desc(), transformed_dy_channel, + args3.idesc.desc(), ddx, args3.cdesc.desc(), filter_algo, + &beta, args3.wdesc.desc(), dw, workspace_ptr, + workspace_size)); + }, + workspace_size); #else + for (int i = 0; i < groups; i++) { wkspace_handle.RunFunc( [&](void* workspace_ptr) { PADDLE_ENFORCE_CUDA_SUCCESS( @@ -1273,27 +1288,25 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel { dw + i * group_offset_filter)); }, workspace_size); -#endif } +#endif } if (dX && ddW) { ddw = ddW->data(); - for (int i = 0; i < groups; i++) { #ifdef PADDLE_WITH_HIP - wkspace_handle.RunFunc( - [&](void* workspace_ptr) { - PADDLE_ENFORCE_CUDA_SUCCESS( - platform::dynload::miopenConvolutionBackwardData( - handle, &alpha, args4.odesc.desc(), - transformed_dy_channel + i * group_offset_out, - args4.wdesc.desc(), ddw + i * group_offset_filter, - args4.cdesc.desc(), data_algo, &beta, args4.idesc.desc(), - transformed_dx + i * group_offset_in, workspace_ptr, - workspace_size)); - }, - workspace_size); + wkspace_handle.RunFunc( + [&](void* workspace_ptr) { + PADDLE_ENFORCE_CUDA_SUCCESS( + platform::dynload::miopenConvolutionBackwardData( + handle, &alpha, args4.odesc.desc(), transformed_dy_channel, + args4.wdesc.desc(), ddw, args4.cdesc.desc(), data_algo, + &beta, args4.idesc.desc(), transformed_dx, workspace_ptr, + workspace_size)); + }, + workspace_size); #else + for (int i = 0; i < groups; i++) { wkspace_handle.RunFunc( [&](void* workspace_ptr) { PADDLE_ENFORCE_CUDA_SUCCESS( @@ -1306,8 +1319,8 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel { transformed_dx + i * group_offset_in)); }, workspace_size); -#endif } +#endif if (!is_sys_pad) { // reverse padded input diff --git a/paddle/fluid/operators/conv_miopen_helper.h b/paddle/fluid/operators/conv_miopen_helper.h index 44ead95a35..3ab27e1ec4 100644 --- a/paddle/fluid/operators/conv_miopen_helper.h +++ b/paddle/fluid/operators/conv_miopen_helper.h @@ -127,57 +127,52 @@ struct SearchAlgorithm { template static algo_t Find(const ConvArgs& args, bool exhaustive_search, - bool deterministic, + bool deterministic, size_t workspace_size, const framework::ExecutionContext& ctx) { - auto dtype = platform::CudnnDataType::type; - bool has_got_workspace_size = true; - size_t workspace_size_limit = FLAGS_conv_workspace_size_limit * 1024 * 1024; - size_t workspace_size = 0; algo_t algo; auto& dev_ctx = ctx.template device_context(); auto workspace_handle = dev_ctx.cudnn_workspace_handle(); - auto& temp = ctx.cuda_device_context(); - AlgorithmsCache& algo_cache = - *(framework::ConvSearchCache::Instance().GetForward()); - - auto x_dims = framework::vectorize(args.x->dims()); - auto w_dims = framework::vectorize(args.w->dims()); - - VLOG(10) << "miopenConvolutionFwdAlgoPerf_t:" - << ", x_dims:" << x_dims << ", w_dims:" << w_dims << ", args.s" - << args.s << ", args.p" << args.p << ", args.d" << args.d; - - algo = algo_cache.GetAlgorithm( - x_dims, w_dims, args.s, args.p, args.d, 0, - static_cast(args.cudnn_dtype), [&]() { - int returned_algo_count; - std::array perf_stat; - - auto cudnn_find_func = [&](void* cudnn_workspace_ptr) { - PADDLE_ENFORCE_CUDA_SUCCESS( - platform::dynload::miopenFindConvolutionForwardAlgorithm( - args.handle, args.idesc.desc(), args.x->data(), - args.wdesc.desc(), args.w->data(), args.cdesc.desc(), - args.odesc.desc(), const_cast(args.o->data()), - kNUM_CUDNN_FWD_ALGS, &returned_algo_count, perf_stat.data(), - cudnn_workspace_ptr, workspace_size_limit, false)); - }; - workspace_handle.RunFuncSync(cudnn_find_func, workspace_size_limit); - - VLOG(3) << "FwdAlgo Perf result: (algo: stat, time, memory)"; - for (int i = 0; i < returned_algo_count; ++i) { - const auto& stat = perf_stat[i]; - VLOG(3) << stat.fwd_algo; - } - return perf_stat[0].fwd_algo; - }); + int find_count; + miopenConvAlgoPerf_t find_result; + auto cudnn_find_func = [&](void* cudnn_workspace_ptr) { + PADDLE_ENFORCE_CUDA_SUCCESS( + platform::dynload::miopenFindConvolutionForwardAlgorithm( + args.handle, args.idesc.desc(), args.x->data(), + args.wdesc.desc(), args.w->data(), args.cdesc.desc(), + args.odesc.desc(), const_cast(args.o->data()), + kNUM_CUDNN_FWD_ALGS, &find_count, &find_result, + cudnn_workspace_ptr, workspace_size, false)); + }; + + if (!exhaustive_search && !deterministic) { + workspace_handle.RunFuncSync(cudnn_find_func, workspace_size); + algo = find_result.fwd_algo; + } else { + auto& temp = ctx.cuda_device_context(); + AlgorithmsCache& algo_cache = + *(framework::ConvSearchCache::Instance().GetForward()); + + auto x_dims = framework::vectorize(args.x->dims()); + auto w_dims = framework::vectorize(args.w->dims()); + + VLOG(10) << "miopenConvolutionFwdAlgoPerf_t:" + << ", x_dims:" << x_dims << ", w_dims:" << w_dims << ", args.s" + << args.s << ", args.p" << args.p << ", args.d" << args.d; + + algo = algo_cache.GetAlgorithm( + x_dims, w_dims, args.s, args.p, args.d, 0, + static_cast(args.cudnn_dtype), [&]() { + workspace_handle.RunFuncSync(cudnn_find_func, workspace_size); + return find_result.fwd_algo; + }); + } VLOG(3) << "choose algo " << algo; return algo; } - static size_t GetWorkspaceSize(const ConvArgs& args, algo_t algo) { + static size_t GetWorkspaceSize(const ConvArgs& args) { size_t workspace_size = 0; PADDLE_ENFORCE_CUDA_SUCCESS( platform::dynload::miopenConvolutionForwardGetWorkSpaceSize( @@ -194,58 +189,51 @@ struct SearchAlgorithm { template static algo_t Find(const ConvArgs& args, bool exhaustive_search, - bool deterministic, + bool deterministic, size_t workspace_size, const framework::ExecutionContext& ctx) { - auto dtype = platform::CudnnDataType::type; - size_t workspace_size_limit = FLAGS_conv_workspace_size_limit * 1024 * 1024; - size_t workspace_size = 0; - bool has_got_workspace_size = true; algo_t algo; auto& dev_ctx = ctx.template device_context(); auto workspace_handle = dev_ctx.cudnn_workspace_handle(); - AlgorithmsCache& algo_cache = - *(framework::ConvSearchCache::Instance().GetBackwardData()); - - auto x_dims = framework::vectorize(args.x->dims()); - auto w_dims = framework::vectorize(args.w->dims()); - - VLOG(10) << "miopenConvolutionFwdAlgoPerf_t" - << ", x_dims:" << x_dims << ", w_dims:" << w_dims << ", args.s" - << args.s << ", args.p" << args.p << ", args.d" << args.d; - - algo = algo_cache.GetAlgorithm( - x_dims, w_dims, args.s, args.p, args.d, 0, - static_cast(args.cudnn_dtype), [&]() { - int returned_algo_count; - std::array perf_stat; - - auto cudnn_find_func = [&](void* cudnn_workspace_ptr) { - PADDLE_ENFORCE_CUDA_SUCCESS( - platform::dynload::miopenFindConvolutionBackwardDataAlgorithm( - args.handle, args.odesc.desc(), args.o->data(), - args.wdesc.desc(), args.w->data(), args.cdesc.desc(), - args.idesc.desc(), const_cast(args.x->data()), - kNUM_CUDNN_BWD_DATA_ALGS, &returned_algo_count, - perf_stat.data(), cudnn_workspace_ptr, workspace_size_limit, - false)); - }; - workspace_handle.RunFuncSync(cudnn_find_func, workspace_size_limit); - - VLOG(3) << "BwdDataAlgo Perf result: (algo: stat, time, memory)"; - for (int i = 0; i < returned_algo_count; ++i) { - const auto& stat = perf_stat[i]; - VLOG(3) << stat.bwd_data_algo; - } - - return perf_stat[0].bwd_data_algo; - }); + int find_count; + miopenConvAlgoPerf_t find_result; + auto cudnn_find_func = [&](void* cudnn_workspace_ptr) { + PADDLE_ENFORCE_CUDA_SUCCESS( + platform::dynload::miopenFindConvolutionBackwardDataAlgorithm( + args.handle, args.odesc.desc(), args.o->data(), + args.wdesc.desc(), args.w->data(), args.cdesc.desc(), + args.idesc.desc(), const_cast(args.x->data()), + kNUM_CUDNN_BWD_DATA_ALGS, &find_count, &find_result, + cudnn_workspace_ptr, workspace_size, false)); + }; + + if (!exhaustive_search && !deterministic) { + workspace_handle.RunFuncSync(cudnn_find_func, workspace_size); + algo = find_result.bwd_data_algo; + } else { + AlgorithmsCache& algo_cache = + *(framework::ConvSearchCache::Instance().GetBackwardData()); + + auto x_dims = framework::vectorize(args.x->dims()); + auto w_dims = framework::vectorize(args.w->dims()); + + VLOG(10) << "miopenConvolutionFwdAlgoPerf_t" + << ", x_dims:" << x_dims << ", w_dims:" << w_dims << ", args.s" + << args.s << ", args.p" << args.p << ", args.d" << args.d; + + algo = algo_cache.GetAlgorithm( + x_dims, w_dims, args.s, args.p, args.d, 0, + static_cast(args.cudnn_dtype), [&]() { + workspace_handle.RunFuncSync(cudnn_find_func, workspace_size); + return find_result.bwd_data_algo; + }); + } VLOG(3) << "choose algo " << algo; return algo; } - static size_t GetWorkspaceSize(const ConvArgs& args, algo_t algo) { + static size_t GetWorkspaceSize(const ConvArgs& args) { size_t workspace_size = 0; PADDLE_ENFORCE_CUDA_SUCCESS( platform::dynload::miopenConvolutionBackwardDataGetWorkSpaceSize( @@ -262,56 +250,51 @@ struct SearchAlgorithm { template static algo_t Find(const ConvArgs& args, bool exhaustive_search, - bool deterministic, + bool deterministic, size_t workspace_size, const framework::ExecutionContext& ctx) { - auto dtype = platform::CudnnDataType::type; - size_t workspace_size_limit = FLAGS_conv_workspace_size_limit * 1024 * 1024; - size_t workspace_size = 0; - bool has_got_workspace_size = true; algo_t algo; auto& dev_ctx = ctx.template device_context(); auto workspace_handle = dev_ctx.cudnn_workspace_handle(); - AlgorithmsCache& algo_cache = - *(framework::ConvSearchCache::Instance().GetBackwardFilter()); - - auto x_dims = framework::vectorize(args.x->dims()); - auto w_dims = framework::vectorize(args.w->dims()); - - VLOG(10) << "miopenConvolutionFwdAlgoPerf_t:" - << ", x_dims:" << x_dims << ", w_dims:" << w_dims << ", args.s" - << args.s << ", args.p" << args.p << ", args.d" << args.d; - - algo = algo_cache.GetAlgorithm( - x_dims, w_dims, args.s, args.p, args.d, 0, - static_cast(args.cudnn_dtype), [&]() { - int returned_algo_count; - std::array perf_stat; - auto cudnn_find_func = [&](void* cudnn_workspace_ptr) { - PADDLE_ENFORCE_CUDA_SUCCESS( - platform::dynload:: - miopenFindConvolutionBackwardWeightsAlgorithm( - args.handle, args.odesc.desc(), args.o->data(), - args.idesc.desc(), args.x->data(), args.cdesc.desc(), - args.wdesc.desc(), const_cast(args.w->data()), - kNUM_CUDNN_BWD_FILTER_ALGS, &returned_algo_count, - perf_stat.data(), cudnn_workspace_ptr, - workspace_size_limit, false)); - }; - workspace_handle.RunFuncSync(cudnn_find_func, workspace_size_limit); - - VLOG(3) << "BwdFilterAlgo Perf result: (algo: stat, time, memory)"; - for (int i = 0; i < returned_algo_count; ++i) { - const auto& stat = perf_stat[i]; - VLOG(3) << stat.bwd_weights_algo; - } - return perf_stat[0].bwd_weights_algo; - }); + + int find_count; + miopenConvAlgoPerf_t find_result; + auto cudnn_find_func = [&](void* cudnn_workspace_ptr) { + PADDLE_ENFORCE_CUDA_SUCCESS( + platform::dynload::miopenFindConvolutionBackwardWeightsAlgorithm( + args.handle, args.odesc.desc(), args.o->data(), + args.idesc.desc(), args.x->data(), args.cdesc.desc(), + args.wdesc.desc(), const_cast(args.w->data()), + kNUM_CUDNN_BWD_FILTER_ALGS, &find_count, &find_result, + cudnn_workspace_ptr, workspace_size, false)); + }; + + if (!exhaustive_search && !deterministic) { + workspace_handle.RunFuncSync(cudnn_find_func, workspace_size); + algo = find_result.bwd_weights_algo; + } else { + AlgorithmsCache& algo_cache = + *(framework::ConvSearchCache::Instance().GetBackwardFilter()); + + auto x_dims = framework::vectorize(args.x->dims()); + auto w_dims = framework::vectorize(args.w->dims()); + + VLOG(10) << "miopenConvolutionFwdAlgoPerf_t:" + << ", x_dims:" << x_dims << ", w_dims:" << w_dims << ", args.s" + << args.s << ", args.p" << args.p << ", args.d" << args.d; + + algo = algo_cache.GetAlgorithm( + x_dims, w_dims, args.s, args.p, args.d, 0, + static_cast(args.cudnn_dtype), [&]() { + workspace_handle.RunFuncSync(cudnn_find_func, workspace_size); + return find_result.bwd_weights_algo; + }); + } VLOG(3) << "choose algo " << algo; return algo; } - static size_t GetWorkspaceSize(const ConvArgs& args, algo_t algo) { + static size_t GetWorkspaceSize(const ConvArgs& args) { size_t workspace_size = 0; PADDLE_ENFORCE_CUDA_SUCCESS( platform::dynload::miopenConvolutionBackwardWeightsGetWorkSpaceSize( diff --git a/paddle/fluid/operators/conv_transpose_cudnn_op.cu b/paddle/fluid/operators/conv_transpose_cudnn_op.cu index 376cefe502..5781dd18b7 100644 --- a/paddle/fluid/operators/conv_transpose_cudnn_op.cu +++ b/paddle/fluid/operators/conv_transpose_cudnn_op.cu @@ -244,13 +244,14 @@ class CUDNNConvTransposeOpKernel : public framework::OpKernel { #ifdef PADDLE_WITH_HIP using search = SearchAlgorithm; + workspace_size = std::max(workspace_size, search::GetWorkspaceSize(args)); + algo = search::Find(args, false, deterministic, workspace_size, ctx); #else using search = SearchAlgorithm; -#endif - algo = search::Find(args, false, deterministic, ctx); workspace_size = std::max(workspace_size, search::GetWorkspaceSize(args, algo)); +#endif // ------------------- cudnn conv transpose forward --------------------- int input_offset = @@ -504,12 +505,16 @@ class CUDNNConvTransposeGradOpKernel : public framework::OpKernel { platform::AllowTF32Cudnn(), c_groups); #ifdef PADDLE_WITH_HIP using search1 = SearchAlgorithm; + workspace_size = + std::max(workspace_size, search1::GetWorkspaceSize(args1)); + data_algo = + search1::Find(args1, false, deterministic, workspace_size, ctx); #else using search1 = SearchAlgorithm; -#endif data_algo = search1::Find(args1, false, deterministic, ctx); workspace_size = std::max(workspace_size, search1::GetWorkspaceSize(args1, data_algo)); +#endif } if (filter_grad) { @@ -522,12 +527,16 @@ class CUDNNConvTransposeGradOpKernel : public framework::OpKernel { platform::AllowTF32Cudnn(), c_groups); #ifdef PADDLE_WITH_HIP using search2 = SearchAlgorithm; + workspace_size = + std::max(workspace_size, search2::GetWorkspaceSize(args2)); + filter_algo = + search2::Find(args2, false, deterministic, workspace_size, ctx); #else using search2 = SearchAlgorithm; -#endif filter_algo = search2::Find(args2, false, deterministic, ctx); workspace_size = std::max(workspace_size, search2::GetWorkspaceSize(args2, filter_algo)); +#endif } // ------------------- cudnn conv backward data --------------------- @@ -942,11 +951,14 @@ class CUDNNConvTransposeDoubleGradOpKernel : public framework::OpKernel { args1.cdesc.set(dtype, padding_common, strides, dilations, c_group); #ifdef PADDLE_WITH_HIP using search1 = SearchAlgorithm; + workspace_size = search1::GetWorkspaceSize(args1); + bwd_algo1 = + search1::Find(args1, false, deterministic, workspace_size, ctx); #else using search1 = SearchAlgorithm; -#endif bwd_algo1 = search1::Find(args1, false, deterministic, ctx); workspace_size = search1::GetWorkspaceSize(args1, bwd_algo1); +#endif } if (ddW) { @@ -958,12 +970,16 @@ class CUDNNConvTransposeDoubleGradOpKernel : public framework::OpKernel { args2.cdesc.set(dtype, padding_common, strides, dilations, c_group); #ifdef PADDLE_WITH_HIP using search2 = SearchAlgorithm; + workspace_size = + std::max(workspace_size, search2::GetWorkspaceSize(args2)); + bwd_algo2 = + search2::Find(args2, false, deterministic, workspace_size, ctx); #else using search2 = SearchAlgorithm; -#endif bwd_algo2 = search2::Find(args2, false, deterministic, ctx); workspace_size = std::max(workspace_size, search2::GetWorkspaceSize(args2, bwd_algo2)); +#endif } } @@ -978,12 +994,16 @@ class CUDNNConvTransposeDoubleGradOpKernel : public framework::OpKernel { args3.cdesc.set(dtype, padding_common, strides, dilations, c_group); #ifdef PADDLE_WITH_HIP using search3 = SearchAlgorithm; + workspace_size = + std::max(workspace_size, search3::GetWorkspaceSize(args3)); + filter_algo = + search3::Find(args3, false, deterministic, workspace_size, ctx); #else using search3 = SearchAlgorithm; -#endif filter_algo = search3::Find(args3, false, deterministic, ctx); workspace_size = std::max(workspace_size, search3::GetWorkspaceSize(args3, filter_algo)); +#endif } if (ddW && dX) { @@ -996,12 +1016,16 @@ class CUDNNConvTransposeDoubleGradOpKernel : public framework::OpKernel { args4.cdesc.set(dtype, padding_common, strides, dilations, c_group); #ifdef PADDLE_WITH_HIP using search4 = SearchAlgorithm; + workspace_size = + std::max(workspace_size, search4::GetWorkspaceSize(args4)); + data_algo = + search4::Find(args4, false, deterministic, workspace_size, ctx); #else using search4 = SearchAlgorithm; -#endif data_algo = search4::Find(args4, false, deterministic, ctx); workspace_size = std::max(workspace_size, search4::GetWorkspaceSize(args4, data_algo)); +#endif } int i_n, i_c, i_d, i_h, i_w; diff --git a/paddle/fluid/platform/miopen_desc.h b/paddle/fluid/platform/miopen_desc.h index 7de713559a..c82e61ceb1 100644 --- a/paddle/fluid/platform/miopen_desc.h +++ b/paddle/fluid/platform/miopen_desc.h @@ -199,19 +199,24 @@ class FilterDescriptor { void set(const Tensor& tensor, const miopenTensorFormat_t format, const int groups = 1) { - auto dims = framework::vectorize(tensor.dims()); - std::vector transformed_dims; PADDLE_ENFORCE_EQ(format, MIOPEN_TENSOR_NCHW, platform::errors::InvalidArgument( "format should ONLY be NCHW in MIOPEN.")); - transformed_dims = dims; - // if (groups > 1) { - // transformed_dims[1] = transformed_dims[1] / groups; - // } - PADDLE_ENFORCE_CUDA_SUCCESS(dynload::miopenSet4dTensorDescriptor( - (miopenTensorDescriptor_t)desc_.get(), ToCudnnDataType(tensor.type()), - transformed_dims[0], transformed_dims[1], transformed_dims[2], - transformed_dims[3])); + auto dims = framework::vectorize(tensor.dims()); + std::vector strides(dims.size()); + strides[dims.size() - 1] = 1; + for (int i = dims.size() - 2; i >= 0; i--) { + strides[i] = dims[i + 1] * strides[i + 1]; + } + std::vector dims_with_group(dims.begin(), dims.end()); + if (groups > 1) { + dims_with_group[1] = dims_with_group[1] / groups; + } + PADDLE_ENFORCE_CUDA_SUCCESS(dynload::miopenSetTensorDescriptor( + (miopenTensorDescriptor_t)(desc_.get()), ToCudnnDataType(tensor.type()), + static_cast(dims_with_group.size()), + const_cast(dims_with_group.data()), + const_cast(strides.data()))); } private: diff --git a/python/paddle/fluid/tests/unittests/test_conv2d_op.py b/python/paddle/fluid/tests/unittests/test_conv2d_op.py index 9992efee1b..29c35d28d4 100644 --- a/python/paddle/fluid/tests/unittests/test_conv2d_op.py +++ b/python/paddle/fluid/tests/unittests/test_conv2d_op.py @@ -128,6 +128,8 @@ def create_test_cudnn_class(parent): class TestCUDNNCase(parent): def init_kernel_type(self): self.use_cudnn = True + self.dtype = np.float32 if core.is_compiled_with_rocm( + ) else np.float64 cls_name = "{0}_{1}".format(parent.__name__, "CUDNN") TestCUDNNCase.__name__ = cls_name @@ -185,6 +187,8 @@ def create_test_cudnn_channel_last_class(parent): class TestCudnnChannelLastCase(parent): def init_kernel_type(self): self.use_cudnn = True + self.dtype = np.float32 if core.is_compiled_with_rocm( + ) else np.float64 def init_data_format(self): self.data_format = "NHWC" @@ -264,6 +268,8 @@ def create_test_cudnn_padding_SAME_class(parent): class TestCUDNNPaddingSMAECase(parent): def init_kernel_type(self): self.use_cudnn = True + self.dtype = np.float32 if core.is_compiled_with_rocm( + ) else np.float64 def init_paddings(self): self.pad = [1, 1] @@ -280,6 +286,8 @@ def create_test_cudnn_padding_VALID_class(parent): class TestCUDNNPaddingVALIDCase(parent): def init_kernel_type(self): self.use_cudnn = True + self.dtype = np.float32 if core.is_compiled_with_rocm( + ) else np.float64 def init_paddings(self): self.pad = [1, 1] @@ -299,8 +307,7 @@ class TestConv2DOp(OpTest): self.use_mkldnn = False self.fuse_relu_before_depthwise_conv = False self.data_format = "AnyLayout" - # explicilty use float32 for ROCm, as MIOpen does not yet support float64 - self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64 + self.dtype = np.float64 self.init_kernel_type() self.init_group() self.init_dilation() @@ -693,6 +700,7 @@ class TestCUDNNExhaustiveSearch(TestConv2DOp): def init_kernel_type(self): self.use_cudnn = True self.exhaustive_search = True + self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64 class TestConv2DOpError(unittest.TestCase): @@ -734,8 +742,7 @@ class TestConv2DOp_v2(OpTest): self.use_cuda = False self.use_mkldnn = False self.fuse_relu_before_depthwise_conv = False - # explicilty use float32 for ROCm, as MIOpen does not yet support float64 - self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64 + self.dtype = np.float64 self.init_kernel_type() self.init_group() self.init_dilation() diff --git a/python/paddle/fluid/tests/unittests/test_conv3d_op.py b/python/paddle/fluid/tests/unittests/test_conv3d_op.py index 1636019a62..59d1f3216e 100644 --- a/python/paddle/fluid/tests/unittests/test_conv3d_op.py +++ b/python/paddle/fluid/tests/unittests/test_conv3d_op.py @@ -135,6 +135,8 @@ def create_test_cudnn_class(parent): class TestCUDNNCase(parent): def init_kernel_type(self): self.use_cudnn = True + self.dtype = np.float32 if core.is_compiled_with_rocm( + ) else np.float64 cls_name = "{0}_{1}".format(parent.__name__, "CUDNN") TestCUDNNCase.__name__ = cls_name @@ -169,6 +171,8 @@ def create_test_cudnn_padding_SAME_class(parent): class TestCUDNNPaddingSMAECase(parent): def init_kernel_type(self): self.use_cudnn = True + self.dtype = np.float32 if core.is_compiled_with_rocm( + ) else np.float64 def init_paddings(self): self.pad = [1, 1, 1] @@ -185,6 +189,8 @@ def create_test_cudnn_padding_VALID_class(parent): class TestCUDNNPaddingVALIDCase(parent): def init_kernel_type(self): self.use_cudnn = True + self.dtype = np.float32 if core.is_compiled_with_rocm( + ) else np.float64 def init_paddings(self): self.pad = [1, 1, 1] @@ -215,6 +221,8 @@ def create_test_cudnn_channel_last_class(parent): class TestCudnnChannelLastCase(parent): def init_kernel_type(self): self.use_cudnn = True + self.dtype = np.float32 if core.is_compiled_with_rocm( + ) else np.float64 def init_data_format(self): self.data_format = "NDHWC" @@ -410,6 +418,7 @@ class TestWithDilation(TestConv3DOp): class TestCUDNN(TestConv3DOp): def init_kernel_type(self): self.use_cudnn = True + self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64 @unittest.skipIf(not core.is_compiled_with_cuda(), @@ -431,6 +440,7 @@ class TestFP16CUDNN(TestConv3DOp): class TestWithGroup1CUDNN(TestWithGroup1): def init_kernel_type(self): self.use_cudnn = True + self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64 @unittest.skipIf(not core.is_compiled_with_cuda(), @@ -452,6 +462,7 @@ class TestFP16WithGroup1CUDNN(TestWithGroup1): class TestWithGroup2CUDNN(TestWithGroup2): def init_kernel_type(self): self.use_cudnn = True + self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64 @unittest.skipIf(not core.is_compiled_with_cuda(), @@ -473,6 +484,7 @@ class TestFP16WithGroup2CUDNN(TestWithGroup2): class TestWith1x1CUDNN(TestWith1x1): def init_kernel_type(self): self.use_cudnn = True + self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64 @unittest.skipIf(not core.is_compiled_with_cuda(), @@ -494,6 +506,7 @@ class TestFP16With1x1CUDNN(TestWith1x1): class TestWithInput1x1Filter1x1CUDNN(TestWithInput1x1Filter1x1): def init_kernel_type(self): self.use_cudnn = True + self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64 @unittest.skipIf(not core.is_compiled_with_cuda(), @@ -514,6 +527,7 @@ class TestCUDNNExhaustiveSearch(TestCUDNN): def init_kernel_type(self): self.use_cudnn = True self.exhaustive_search = True + self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64 # ---- test asymmetric padding ---- diff --git a/python/paddle/fluid/tests/unittests/test_sync_batch_norm_op.py b/python/paddle/fluid/tests/unittests/test_sync_batch_norm_op.py index 4649323b5b..13aa7d3d37 100644 --- a/python/paddle/fluid/tests/unittests/test_sync_batch_norm_op.py +++ b/python/paddle/fluid/tests/unittests/test_sync_batch_norm_op.py @@ -50,7 +50,7 @@ class TestSyncBatchNormOpTraining(unittest.TestCase): def setUp(self): """Setup.""" #self.dtype = np.float32 - self.dtype = np.float64 + self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64 self.N = 8 self.C = 16 self.H = 32 @@ -92,7 +92,10 @@ class TestSyncBatchNormOpTraining(unittest.TestCase): moving_variance_name='bn_moving_variance', data_layout=layout, is_test=only_forward) - bn = fluid.layers.cast(bn, 'float64') + if core.is_compiled_with_rocm(): + bn = fluid.layers.cast(bn, 'float32') + else: + bn = fluid.layers.cast(bn, 'float64') sigmoid = fluid.layers.sigmoid(bn) out = fluid.layers.reduce_sum(sigmoid) if not sync_bn: -- GitLab