From 72d99c5dcd081f61371986d98907fcaa2c5fdaba Mon Sep 17 00:00:00 2001 From: Qi Li Date: Tue, 2 Mar 2021 10:27:15 +0800 Subject: [PATCH] [ROCM] update fluid operators for rocm (part4), test=develop (#31225) --- .../tensorrt/plugin/pool_op_plugin.cu | 8 +- paddle/fluid/operators/conv_cudnn_op.cu | 227 +++++++++++- paddle/fluid/operators/conv_cudnn_op_cache.h | 11 +- paddle/fluid/operators/conv_miopen_helper.h | 325 ++++++++++++++++++ paddle/fluid/operators/conv_op.cc | 12 +- .../operators/conv_transpose_cudnn_op.cu | 177 +++++++++- paddle/fluid/operators/conv_transpose_op.cc | 6 +- paddle/fluid/operators/math/CMakeLists.txt | 11 +- paddle/fluid/operators/math/concat_test.cc | 2 +- paddle/fluid/operators/math/pooling.cc | 34 +- paddle/fluid/operators/math/pooling.cu | 66 ++-- paddle/fluid/operators/math/pooling.h | 44 +-- paddle/fluid/operators/pool_cudnn_op.cu.cc | 165 ++++++++- paddle/fluid/operators/pool_op.cc | 7 +- paddle/fluid/operators/pool_op.h | 25 +- paddle/fluid/operators/spp_op.h | 10 +- 16 files changed, 1006 insertions(+), 124 deletions(-) create mode 100644 paddle/fluid/operators/conv_miopen_helper.h diff --git a/paddle/fluid/inference/tensorrt/plugin/pool_op_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/pool_op_plugin.cu index 1fa5b3228e1..154f61a2b7c 100644 --- a/paddle/fluid/inference/tensorrt/plugin/pool_op_plugin.cu +++ b/paddle/fluid/inference/tensorrt/plugin/pool_op_plugin.cu @@ -59,14 +59,14 @@ int PoolPlugin::enqueue(int batchSize, const void *const *inputs, paddle::operators::math::MaxPool, float> pool2d_forward; pool2d_forward(idata, input_shape, output_shape, ksize_, strides_, - paddings_, pool_process, true, adaptive_, odatas[0], stream); + paddings_, true, adaptive_, odatas[0], stream, pool_process); } else if (pool_type_ == PoolType::avg) { paddle::operators::math::AvgPool pool_process; paddle::operators::math::Pool2dDirectCUDAFunctor< paddle::operators::math::AvgPool, float> pool2d_forward; pool2d_forward(idata, input_shape, output_shape, ksize_, strides_, - paddings_, pool_process, true, adaptive_, odatas[0], stream); + paddings_, true, adaptive_, odatas[0], stream, pool_process); } return cudaGetLastError() != cudaSuccess; @@ -224,14 +224,14 @@ int PoolPluginDynamic::enqueue(const nvinfer1::PluginTensorDesc *input_desc, paddle::operators::math::MaxPool, float> pool2d_forward; pool2d_forward(input, input_shape, output_shape, ksize, strides_, paddings, - pool_process, true, adaptive_, output, stream); + true, adaptive_, output, stream, pool_process); } else if (pool_type_ == "avg") { paddle::operators::math::AvgPool pool_process; paddle::operators::math::Pool2dDirectCUDAFunctor< paddle::operators::math::AvgPool, float> pool2d_forward; pool2d_forward(input, input_shape, output_shape, ksize, strides_, paddings, - pool_process, true, adaptive_, output, stream); + true, adaptive_, output, stream, pool_process); } return cudaGetLastError() != cudaSuccess; diff --git a/paddle/fluid/operators/conv_cudnn_op.cu b/paddle/fluid/operators/conv_cudnn_op.cu index 5ef22b81869..110bb69a140 100644 --- a/paddle/fluid/operators/conv_cudnn_op.cu +++ b/paddle/fluid/operators/conv_cudnn_op.cu @@ -19,11 +19,13 @@ limitations under the License. 
*/ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/memory/memory.h" +#ifdef PADDLE_WITH_HIP +#include "paddle/fluid/operators/conv_miopen_helper.h" +#else #include "paddle/fluid/operators/conv_cudnn_helper.h" -#include "paddle/fluid/operators/conv_cudnn_op_cache.h" +#endif #include "paddle/fluid/operators/conv_op.h" #include "paddle/fluid/operators/math/padding.h" -#include "paddle/fluid/platform/cudnn_helper.h" #include "paddle/fluid/platform/cudnn_workspace_helper.h" #include "paddle/fluid/platform/float16.h" #include "paddle/fluid/platform/profiler.h" @@ -78,6 +80,10 @@ class CUDNNConvOpKernel : public framework::OpKernel { auto dtype = platform::CudnnDataType::type; +#ifdef PADDLE_WITH_HIP + // HIP MIOPEN ONLY SUPPORT NCHW format + auto compute_format = DataLayout::kNCHW; +#else // Tensor Core introduced from Volta GPUs supports more faster conv op // with FP16 in NHWC data format. const bool compute_in_nhwc = @@ -86,6 +92,7 @@ class CUDNNConvOpKernel : public framework::OpKernel { // cudnn will convert NCHW to NHWC automatically on Tensor Core. auto compute_format = compute_in_nhwc && channel_last ? DataLayout::kNHWC : DataLayout::kNCHW; +#endif VLOG(3) << "Compute ConvOp with cuDNN:" << " data_format=" << data_format << " compute_format=" << (compute_format == DataLayout::kNHWC ? "NHWC" : "NCHW"); @@ -240,10 +247,16 @@ class CUDNNConvOpKernel : public framework::OpKernel { auto layout_format = GetCudnnTensorFormat(layout); args.handle = handle; + +#ifdef PADDLE_WITH_HIP + args.cdesc.set(dtype, padding_common, strides, dilations, + platform::AllowTF32Cudnn(), groups); +#else args.cdesc.set(dtype, padding_common, strides, dilations, platform::AllowTF32Cudnn()); +#endif -#if CUDNN_VERSION_MIN(7, 0, 1) +#if defined(PADDLE_WITH_CUDA) && CUDNN_VERSION_MIN(7, 0, 1) // cudnn 7 can support groups, no need to do it manually // FIXME(typhoonzero): find a better way to disable groups // rather than setting it to 1. @@ -275,14 +288,18 @@ class CUDNNConvOpKernel : public framework::OpKernel { int group_offset_filter = transformed_filter_channel.numel() / groups; // ------------------- cudnn conv workspace --------------------- size_t workspace_size = 0; // final workspace to allocate. - // ------------------- cudnn conv algorithm --------------------- +// ------------------- cudnn conv algorithm --------------------- +#ifdef PADDLE_WITH_HIP + miopenConvFwdAlgorithm_t algo{}; + using search = SearchAlgorithm; +#else cudnnConvolutionFwdAlgo_t algo{}; - using search = SearchAlgorithm; +#endif algo = search::Find(args, exhaustive_search, false, ctx); workspace_size = search::GetWorkspaceSize(args, algo); -#if CUDNN_VERSION_MIN(7, 0, 1) +#if defined(PADDLE_WITH_CUDA) && CUDNN_VERSION_MIN(7, 0, 1) // when groups > 1, SearchAlgorithm find algo is CUDNN_CONVOLUTION_\ // FWD_ALGO_WINOGRAD_NONFUSED, but this kind of algorithm is unstable // in forward computation, so change the algorithm to CUDNN_CONVOLUTION_\ @@ -296,10 +313,22 @@ class CUDNNConvOpKernel : public framework::OpKernel { ScalingParamType alpha = 1.0f; ScalingParamType beta = 0.0f; - // NOTE(zhiqiu): inplace addto is not supportted in double grad yet. - // ScalingParamType beta = ctx.Attr("use_addto") ? 1.0f : 0.0f; - // VLOG(4) << "Conv: use_addto = " << ctx.Attr("use_addto"); - +// NOTE(zhiqiu): inplace addto is not supportted in double grad yet. +// ScalingParamType beta = ctx.Attr("use_addto") ? 
1.0f : 0.0f; +// VLOG(4) << "Conv: use_addto = " << ctx.Attr("use_addto"); + +#ifdef PADDLE_WITH_HIP + workspace_handle.RunFunc( + [&](void* workspace_ptr) { + PADDLE_ENFORCE_CUDA_SUCCESS( + platform::dynload::miopenConvolutionForward( + handle, &alpha, args.idesc.desc(), input_data, + args.wdesc.desc(), filter_data, args.cdesc.desc(), algo, + &beta, args.odesc.desc(), output_data, workspace_ptr, + workspace_size)); + }, + workspace_size); +#else for (int i = 0; i < groups; i++) { workspace_handle.RunFunc( [&](void* workspace_ptr) { @@ -313,6 +342,7 @@ class CUDNNConvOpKernel : public framework::OpKernel { }, workspace_size); } +#endif if (channel_last && compute_format == DataLayout::kNCHW) { TransToChannelLast( @@ -361,10 +391,16 @@ class CUDNNConvGradOpKernel : public framework::OpKernel { const bool channel_last = (data_format == "NHWC" || data_format == "NDHWC"); auto dtype = platform::CudnnDataType::type; + +#ifdef PADDLE_WITH_HIP + // HIP MIOPEN ONLY SUPPORT NCHW format + auto compute_format = DataLayout::kNCHW; +#else const bool compute_in_nhwc = dtype == CUDNN_DATA_HALF && IsVoltaOrLater(dev_ctx); auto compute_format = compute_in_nhwc && channel_last ? DataLayout::kNHWC : DataLayout::kNCHW; +#endif VLOG(3) << "Compute ConvGradOp with cuDNN:" << " data_format=" << data_format << " compute_format=" << (compute_format == DataLayout::kNHWC ? "NHWC" : "NCHW"); @@ -581,16 +617,23 @@ class CUDNNConvGradOpKernel : public framework::OpKernel { int group_offset_in = i_c / groups * i_h * i_w * i_d; int group_offset_out = o_c / groups * o_h * o_w * o_d; int group_offset_filter = transformed_filter_channel.numel() / groups; - // ------------------- cudnn backward algorithm --------------------- +// ------------------- cudnn backward algorithm --------------------- +#ifdef PADDLE_WITH_HIP + miopenConvBwdDataAlgorithm_t data_algo = + static_cast(0); + miopenConvBwdWeightsAlgorithm_t filter_algo = + static_cast(0); +#else cudnnConvolutionBwdDataAlgo_t data_algo = static_cast(0); cudnnConvolutionBwdFilterAlgo_t filter_algo = static_cast(0); +#endif size_t workspace_size = 0; int iwo_groups = groups; int c_groups = 1; -#if CUDNN_VERSION_MIN(7, 0, 1) +#if defined(PADDLE_WITH_HIP) || CUDNN_VERSION_MIN(7, 0, 1) iwo_groups = 1; c_groups = groups; groups = 1; @@ -607,7 +650,11 @@ class CUDNNConvGradOpKernel : public framework::OpKernel { args1.cdesc.set(dtype, padding_common, strides, dilations, platform::AllowTF32Cudnn(), c_groups); +#ifdef PADDLE_WITH_HIP + using search1 = SearchAlgorithm; +#else using search1 = SearchAlgorithm; +#endif data_algo = search1::Find(args1, exhaustive_search, deterministic, ctx); workspace_size = @@ -624,8 +671,11 @@ class CUDNNConvGradOpKernel : public framework::OpKernel { args2.odesc.set(transformed_output_grad_channel, layout_tensor); args2.cdesc.set(dtype, padding_common, strides, dilations, platform::AllowTF32Cudnn(), c_groups); - +#ifdef PADDLE_WITH_HIP + using search2 = SearchAlgorithm; +#else using search2 = SearchAlgorithm; +#endif filter_algo = search2::Find(args2, exhaustive_search, deterministic, ctx); workspace_size = std::max(workspace_size, @@ -641,6 +691,20 @@ class CUDNNConvGradOpKernel : public framework::OpKernel { // When beta is 0, it is unnecessary to reset input_grad. // When beta is 1, the output cannot be reset since addt strategy used. 
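Every cuDNN call site in this patch gains a MIOpen twin selected at compile time, and the group loop that follows is one instance of that pattern. Below is a minimal sketch of the dispatch, assuming the surrounding Paddle headers (ConvArgs, ScalingParamType, PADDLE_ENFORCE_CUDA_SUCCESS) are in scope; the RunConvBwdData wrapper and its parameter names are illustrative only, not part of the patch.

// Sketch only: compile-time choice between the MIOpen and cuDNN backward-data
// entry points. The argument order mirrors the real calls in the hunks below:
// MIOpen takes beta before the output descriptor and the workspace last,
// while cuDNN takes the workspace pointer/size before beta and the output tensor.
template <typename T, typename AlgoT>
void RunConvBwdData(const ConvArgs& args, AlgoT algo, const T* dy, const T* w,
                    T* dx, void* workspace, size_t workspace_size) {
  ScalingParamType<T> alpha = 1.0f, beta = 0.0f;
#ifdef PADDLE_WITH_HIP
  PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::miopenConvolutionBackwardData(
      args.handle, &alpha, args.odesc.desc(), dy, args.wdesc.desc(), w,
      args.cdesc.desc(), algo, &beta, args.idesc.desc(), dx, workspace,
      workspace_size));
#else
  PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnConvolutionBackwardData(
      args.handle, &alpha, args.wdesc.desc(), w, args.odesc.desc(), dy,
      args.cdesc.desc(), algo, workspace, workspace_size, &beta,
      args.idesc.desc(), dx));
#endif
}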
for (int i = 0; i < groups; i++) { +#ifdef PADDLE_WITH_HIP + workspace_handle.RunFunc( + [&](void* cudnn_workspace_ptr) { + PADDLE_ENFORCE_CUDA_SUCCESS( + platform::dynload::miopenConvolutionBackwardData( + handle, &alpha, args1.odesc.desc(), + output_grad_data + i * group_offset_out, + args1.wdesc.desc(), filter_data + i * group_offset_filter, + args1.cdesc.desc(), data_algo, &beta, args1.idesc.desc(), + transformed_input_grad_data + i * group_offset_in, + cudnn_workspace_ptr, workspace_size)); + }, + workspace_size); +#else workspace_handle.RunFunc( [&](void* cudnn_workspace_ptr) { PADDLE_ENFORCE_CUDA_SUCCESS( @@ -653,6 +717,7 @@ class CUDNNConvGradOpKernel : public framework::OpKernel { transformed_input_grad_data + i * group_offset_in)); }, workspace_size); +#endif } if (!is_sys_pad) { @@ -688,6 +753,21 @@ class CUDNNConvGradOpKernel : public framework::OpKernel { if (filter_grad) { // Because beta is zero, it is unnecessary to reset filter_grad. for (int i = 0; i < groups; i++) { +#ifdef PADDLE_WITH_HIP + workspace_handle.RunFunc( + [&](void* cudnn_workspace_ptr) { + PADDLE_ENFORCE_CUDA_SUCCESS( + platform::dynload::miopenConvolutionBackwardWeights( + handle, &alpha, args2.odesc.desc(), + output_grad_data + i * group_offset_out, + args2.idesc.desc(), input_data + i * group_offset_in, + args2.cdesc.desc(), filter_algo, &beta, + args2.wdesc.desc(), + filter_grad_data + i * group_offset_filter, + cudnn_workspace_ptr, workspace_size)); + }, + workspace_size); +#else workspace_handle.RunFunc( [&](void* cudnn_workspace_ptr) { PADDLE_ENFORCE_CUDA_SUCCESS( @@ -700,6 +780,7 @@ class CUDNNConvGradOpKernel : public framework::OpKernel { filter_grad_data + i * group_offset_filter)); }, workspace_size); +#endif } if (compute_format == DataLayout::kNHWC) { @@ -930,7 +1011,7 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel { int iwo_group = groups; int c_group = 1; -#if CUDNN_VERSION_MIN(7, 0, 1) +#if defined(PADDLE_WITH_HIP) || CUDNN_VERSION_MIN(7, 0, 1) iwo_group = 1; c_group = groups; groups = 1; @@ -960,6 +1041,16 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel { &transformed_dX, ddW, &transformed_dO_channel, strides, padding_common, dilations, dtype}; +#ifdef PADDLE_WITH_HIP + miopenConvFwdAlgorithm_t fwd_algo1 = + static_cast(0); + miopenConvFwdAlgorithm_t fwd_algo2 = + static_cast(0); + miopenConvBwdDataAlgorithm_t data_algo = + static_cast(0); + miopenConvBwdWeightsAlgorithm_t filter_algo = + static_cast(0); +#else cudnnConvolutionFwdAlgo_t fwd_algo1 = static_cast(0); cudnnConvolutionFwdAlgo_t fwd_algo2 = @@ -968,6 +1059,7 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel { static_cast(0); cudnnConvolutionBwdFilterAlgo_t filter_algo = static_cast(0); +#endif auto layout = GetCudnnTensorFormat(DataLayout::kNCHW); @@ -986,7 +1078,11 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel { args1.cdesc.set(dtype, padding_common, strides, dilations, platform::AllowTF32Cudnn(), c_group); +#ifdef PADDLE_WITH_HIP + using search1 = SearchAlgorithm; +#else using search1 = SearchAlgorithm; +#endif fwd_algo1 = search1::Find(args1, exhaustive_search, false, ctx); workspace_size = search1::GetWorkspaceSize(args1, fwd_algo1); } @@ -1002,7 +1098,11 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel { args2.cdesc.set(dtype, padding_common, strides, dilations, platform::AllowTF32Cudnn(), c_group); +#ifdef PADDLE_WITH_HIP + using search2 = SearchAlgorithm; +#else using search2 = SearchAlgorithm; +#endif fwd_algo2 = 
search2::Find(args2, exhaustive_search, false, ctx); workspace_size = std::max(workspace_size, search2::GetWorkspaceSize(args2, fwd_algo2)); @@ -1020,7 +1120,11 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel { args3.cdesc.set(dtype, padding_common, strides, dilations, platform::AllowTF32Cudnn(), c_group); +#ifdef PADDLE_WITH_HIP + using search3 = SearchAlgorithm; +#else using search3 = SearchAlgorithm; +#endif filter_algo = search3::Find(args3, exhaustive_search, deterministic, ctx); workspace_size = std::max(workspace_size, @@ -1037,7 +1141,11 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel { args4.cdesc.set(dtype, padding_common, strides, dilations, platform::AllowTF32Cudnn(), c_group); +#ifdef PADDLE_WITH_HIP + using search4 = SearchAlgorithm; +#else using search4 = SearchAlgorithm; +#endif data_algo = search4::Find(args4, exhaustive_search, deterministic, ctx); workspace_size = @@ -1063,13 +1171,26 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel { // ScalingParamType beta = ctx.Attr("use_addto") ? 1.0f : // 0.0f; // VLOG(4) << "Conv_grad_grad: use_addto = " << ctx.Attr("use_addto"); - auto wkspace_handle = dev_ctx.cudnn_workspace_handle(); if (ddO) { if (ddX) { ddx = transformed_ddX.data(); for (int i = 0; i < groups; i++) { +#ifdef PADDLE_WITH_HIP + wkspace_handle.RunFunc( + [&](void* workspace_ptr) { + PADDLE_ENFORCE_CUDA_SUCCESS( + platform::dynload::miopenConvolutionForward( + handle, &alpha, args1.idesc.desc(), + ddx + i * group_offset_in, args1.wdesc.desc(), + w + i * group_offset_filter, args1.cdesc.desc(), + fwd_algo1, &beta, args1.odesc.desc(), + transformed_ddy_channel + i * group_offset_out, + workspace_ptr, workspace_size)); + }, + workspace_size); +#else wkspace_handle.RunFunc( [&](void* workspace_ptr) { PADDLE_ENFORCE_CUDA_SUCCESS( @@ -1082,10 +1203,26 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel { transformed_ddy_channel + i * group_offset_out)); }, workspace_size); +#endif } } if (ddW) { for (int i = 0; i < groups; i++) { +#ifdef PADDLE_WITH_HIP + // MIOPEN ONLY support beta to be 0.0f + wkspace_handle.RunFunc( + [&](void* workspace_ptr) { + PADDLE_ENFORCE_CUDA_SUCCESS( + platform::dynload::miopenConvolutionForward( + handle, &alpha, args2.idesc.desc(), + x + i * group_offset_in, args2.wdesc.desc(), + ddw + i * group_offset_filter, args2.cdesc.desc(), + fwd_algo2, &beta, args2.odesc.desc(), + transformed_ddy_channel + i * group_offset_out, + workspace_ptr, workspace_size)); + }, + workspace_size); +#else wkspace_handle.RunFunc( [&](void* workspace_ptr) { PADDLE_ENFORCE_CUDA_SUCCESS( @@ -1098,6 +1235,7 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel { transformed_ddy_channel + i * group_offset_out)); }, workspace_size); +#endif } } if (channel_last) { @@ -1109,6 +1247,20 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel { if (dW && ddX) { ddx = transformed_ddX.data(); for (int i = 0; i < groups; i++) { +#ifdef PADDLE_WITH_HIP + wkspace_handle.RunFunc( + [&](void* workspace_ptr) { + PADDLE_ENFORCE_CUDA_SUCCESS( + platform::dynload::miopenConvolutionBackwardWeights( + handle, &alpha, args3.odesc.desc(), + transformed_dy_channel + i * group_offset_out, + args3.idesc.desc(), ddx + i * group_offset_in, + args3.cdesc.desc(), filter_algo, &beta, + args3.wdesc.desc(), dw + i * group_offset_filter, + workspace_ptr, workspace_size)); + }, + workspace_size); +#else wkspace_handle.RunFunc( [&](void* workspace_ptr) { PADDLE_ENFORCE_CUDA_SUCCESS( @@ -1121,12 +1273,27 @@ class 
CUDNNConvDoubleGradOpKernel : public framework::OpKernel { dw + i * group_offset_filter)); }, workspace_size); +#endif } } if (dX && ddW) { ddw = ddW->data(); for (int i = 0; i < groups; i++) { +#ifdef PADDLE_WITH_HIP + wkspace_handle.RunFunc( + [&](void* workspace_ptr) { + PADDLE_ENFORCE_CUDA_SUCCESS( + platform::dynload::miopenConvolutionBackwardData( + handle, &alpha, args4.odesc.desc(), + transformed_dy_channel + i * group_offset_out, + args4.wdesc.desc(), ddw + i * group_offset_filter, + args4.cdesc.desc(), data_algo, &beta, args4.idesc.desc(), + transformed_dx + i * group_offset_in, workspace_ptr, + workspace_size)); + }, + workspace_size); +#else wkspace_handle.RunFunc( [&](void* workspace_ptr) { PADDLE_ENFORCE_CUDA_SUCCESS( @@ -1139,6 +1306,7 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel { transformed_dx + i * group_offset_in)); }, workspace_size); +#endif } if (!is_sys_pad) { @@ -1170,6 +1338,34 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel { } // namespace paddle namespace plat = paddle::platform; +#ifdef PADDLE_WITH_HIP +// MIOPEN do not support double +REGISTER_OP_KERNEL(conv2d, CUDNN, plat::CUDAPlace, + paddle::operators::CUDNNConvOpKernel, + paddle::operators::CUDNNConvOpKernel); +REGISTER_OP_KERNEL(conv2d_grad, CUDNN, plat::CUDAPlace, + paddle::operators::CUDNNConvGradOpKernel, + paddle::operators::CUDNNConvGradOpKernel); +REGISTER_OP_KERNEL( + conv2d_grad_grad, CUDNN, plat::CUDAPlace, + paddle::operators::CUDNNConvDoubleGradOpKernel, + paddle::operators::CUDNNConvDoubleGradOpKernel); + +REGISTER_OP_CUDA_KERNEL( + depthwise_conv2d_grad_grad, + paddle::operators::CUDNNConvDoubleGradOpKernel, + paddle::operators::CUDNNConvDoubleGradOpKernel); + +REGISTER_OP_KERNEL(conv3d, CUDNN, plat::CUDAPlace, + paddle::operators::CUDNNConvOpKernel, + paddle::operators::CUDNNConvOpKernel); +REGISTER_OP_KERNEL(conv3d_grad, CUDNN, plat::CUDAPlace, + paddle::operators::CUDNNConvGradOpKernel); +REGISTER_OP_KERNEL( + conv3d_grad_grad, CUDNN, plat::CUDAPlace, + paddle::operators::CUDNNConvDoubleGradOpKernel, + paddle::operators::CUDNNConvDoubleGradOpKernel); +#else REGISTER_OP_KERNEL(conv2d, CUDNN, plat::CUDAPlace, paddle::operators::CUDNNConvOpKernel, paddle::operators::CUDNNConvOpKernel, @@ -1202,3 +1398,4 @@ REGISTER_OP_KERNEL( paddle::operators::CUDNNConvDoubleGradOpKernel, paddle::operators::CUDNNConvDoubleGradOpKernel, paddle::operators::CUDNNConvDoubleGradOpKernel); +#endif diff --git a/paddle/fluid/operators/conv_cudnn_op_cache.h b/paddle/fluid/operators/conv_cudnn_op_cache.h index de883580dc0..ddddb7f8641 100644 --- a/paddle/fluid/operators/conv_cudnn_op_cache.h +++ b/paddle/fluid/operators/conv_cudnn_op_cache.h @@ -18,7 +18,11 @@ limitations under the License. 
*/ #include #include #include "paddle/fluid/framework/operator.h" +#ifdef PADDLE_WITH_HIP +#include "paddle/fluid/platform/miopen_helper.h" +#else #include "paddle/fluid/platform/cudnn_helper.h" +#endif DECLARE_uint64(conv_workspace_size_limit); DECLARE_bool(cudnn_exhaustive_search); @@ -26,8 +30,11 @@ DECLARE_int64(cudnn_exhaustive_search_times); namespace paddle { namespace operators { - -#if CUDNN_VERSION_MIN(6, 0, 5) +#ifdef PADDLE_WITH_HIP +static constexpr size_t kNUM_CUDNN_FWD_ALGS = 1; +static constexpr size_t kNUM_CUDNN_BWD_FILTER_ALGS = 1; +static constexpr size_t kNUM_CUDNN_BWD_DATA_ALGS = 1; +#elif CUDNN_VERSION_MIN(6, 0, 5) static constexpr size_t kNUM_CUDNN_FWD_ALGS = CUDNN_CONVOLUTION_FWD_ALGO_COUNT; static constexpr size_t kNUM_CUDNN_BWD_FILTER_ALGS = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_COUNT; diff --git a/paddle/fluid/operators/conv_miopen_helper.h b/paddle/fluid/operators/conv_miopen_helper.h new file mode 100644 index 00000000000..44ead95a355 --- /dev/null +++ b/paddle/fluid/operators/conv_miopen_helper.h @@ -0,0 +1,325 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include +#include +#include +#include + +#include "paddle/fluid/framework/conv_search_cache.h" +#include "paddle/fluid/framework/operator_kernel_configs.h" +#include "paddle/fluid/operators/conv_cudnn_op_cache.h" +#include "paddle/fluid/platform/miopen_desc.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; +using DataLayout = platform::DataLayout; +template +using ScalingParamType = typename platform::CudnnDataType::ScalingParamType; +using framework::AlgorithmsCache; +static inline void GetNCDHW(const framework::DDim& dims, + const DataLayout& layout, int* N, int* C, int* D, + int* H, int* W) { + *N = dims[0]; + *C = layout == DataLayout::kNCHW ? dims[1] : dims[dims.size() - 1]; + int i = layout == DataLayout::kNCHW ? 
0 : 1; + if (dims.size() == 5) { + *D = dims[2 - i]; + *H = dims[3 - i]; + *W = dims[4 - i]; + } else { + *D = 1; + *H = dims[2 - i]; + *W = dims[3 - i]; + } +} + +template +static void RemovePaddingSlice(const framework::ExecutionContext& context, + const Tensor* input, Tensor* out, + const std::vector& starts, + const std::vector& axes) { + auto& place = + *context.template device_context().eigen_device(); + auto in_dims = input->dims(); + auto new_out_dims = out->dims(); + auto offsets = Eigen::array(); + auto extents = Eigen::array(); + for (size_t i = 0; i < D; ++i) { + offsets[i] = 0; + extents[i] = new_out_dims[i]; + } + + int start; + for (size_t i = 0; i < axes.size(); ++i) { + start = starts[i]; + if (start < 0) { + start = (start + in_dims[axes[i]]); + } + start = std::max(start, 0); + offsets[axes[i]] = start; + } + auto in_t = + framework::EigenTensor::From( + *input); + + auto out_t = + framework::EigenTensor::From( + *out, new_out_dims); + out_t.device(place) = in_t.slice(offsets, extents); +} + +template +std::ostream& operator<<(std::ostream& out, const std::vector& v) { + out << "["; + for (auto const& tmp : v) out << tmp << ","; + out << "]"; + return out; +} + +using framework::ConvSearchCache; + +struct ConvArgs { + miopenHandle_t handle; + platform::TensorDescriptor idesc, odesc; + platform::FilterDescriptor wdesc; + platform::ConvolutionDescriptor cdesc; + const framework::Tensor *x, *w, *o; + miopenDataType_t cudnn_dtype; + + // strides + std::vector s; + // paddings + std::vector p; + // dilations + std::vector d; + + ConvArgs(const framework::Tensor* x, const framework::Tensor* w, + const framework::Tensor* o, const std::vector s, + const std::vector p, const std::vector d, + miopenDataType_t dtype) + : x(x), w(w), o(o), s(s), p(p), d(d), cudnn_dtype(dtype) {} +}; + +template +struct SearchAlgorithm {}; + +template <> +struct SearchAlgorithm { + using perf_t = miopenConvAlgoPerf_t; + using algo_t = miopenConvFwdAlgorithm_t; + + template + static algo_t Find(const ConvArgs& args, bool exhaustive_search, + bool deterministic, + const framework::ExecutionContext& ctx) { + auto dtype = platform::CudnnDataType::type; + bool has_got_workspace_size = true; + size_t workspace_size_limit = FLAGS_conv_workspace_size_limit * 1024 * 1024; + size_t workspace_size = 0; + algo_t algo; + + auto& dev_ctx = ctx.template device_context(); + auto workspace_handle = dev_ctx.cudnn_workspace_handle(); + + auto& temp = ctx.cuda_device_context(); + AlgorithmsCache& algo_cache = + *(framework::ConvSearchCache::Instance().GetForward()); + + auto x_dims = framework::vectorize(args.x->dims()); + auto w_dims = framework::vectorize(args.w->dims()); + + VLOG(10) << "miopenConvolutionFwdAlgoPerf_t:" + << ", x_dims:" << x_dims << ", w_dims:" << w_dims << ", args.s" + << args.s << ", args.p" << args.p << ", args.d" << args.d; + + algo = algo_cache.GetAlgorithm( + x_dims, w_dims, args.s, args.p, args.d, 0, + static_cast(args.cudnn_dtype), [&]() { + int returned_algo_count; + std::array perf_stat; + + auto cudnn_find_func = [&](void* cudnn_workspace_ptr) { + PADDLE_ENFORCE_CUDA_SUCCESS( + platform::dynload::miopenFindConvolutionForwardAlgorithm( + args.handle, args.idesc.desc(), args.x->data(), + args.wdesc.desc(), args.w->data(), args.cdesc.desc(), + args.odesc.desc(), const_cast(args.o->data()), + kNUM_CUDNN_FWD_ALGS, &returned_algo_count, perf_stat.data(), + cudnn_workspace_ptr, workspace_size_limit, false)); + }; + workspace_handle.RunFuncSync(cudnn_find_func, workspace_size_limit); + + VLOG(3) 
<< "FwdAlgo Perf result: (algo: stat, time, memory)"; + for (int i = 0; i < returned_algo_count; ++i) { + const auto& stat = perf_stat[i]; + VLOG(3) << stat.fwd_algo; + } + return perf_stat[0].fwd_algo; + }); + VLOG(3) << "choose algo " << algo; + return algo; + } + + static size_t GetWorkspaceSize(const ConvArgs& args, algo_t algo) { + size_t workspace_size = 0; + PADDLE_ENFORCE_CUDA_SUCCESS( + platform::dynload::miopenConvolutionForwardGetWorkSpaceSize( + args.handle, args.wdesc.desc(), args.idesc.desc(), + args.cdesc.desc(), args.odesc.desc(), &workspace_size)); + return workspace_size; + } +}; + +template <> +struct SearchAlgorithm { + using perf_t = miopenConvAlgoPerf_t; + using algo_t = miopenConvBwdDataAlgorithm_t; + + template + static algo_t Find(const ConvArgs& args, bool exhaustive_search, + bool deterministic, + const framework::ExecutionContext& ctx) { + auto dtype = platform::CudnnDataType::type; + size_t workspace_size_limit = FLAGS_conv_workspace_size_limit * 1024 * 1024; + size_t workspace_size = 0; + bool has_got_workspace_size = true; + algo_t algo; + + auto& dev_ctx = ctx.template device_context(); + auto workspace_handle = dev_ctx.cudnn_workspace_handle(); + + AlgorithmsCache& algo_cache = + *(framework::ConvSearchCache::Instance().GetBackwardData()); + + auto x_dims = framework::vectorize(args.x->dims()); + auto w_dims = framework::vectorize(args.w->dims()); + + VLOG(10) << "miopenConvolutionFwdAlgoPerf_t" + << ", x_dims:" << x_dims << ", w_dims:" << w_dims << ", args.s" + << args.s << ", args.p" << args.p << ", args.d" << args.d; + + algo = algo_cache.GetAlgorithm( + x_dims, w_dims, args.s, args.p, args.d, 0, + static_cast(args.cudnn_dtype), [&]() { + int returned_algo_count; + std::array perf_stat; + + auto cudnn_find_func = [&](void* cudnn_workspace_ptr) { + PADDLE_ENFORCE_CUDA_SUCCESS( + platform::dynload::miopenFindConvolutionBackwardDataAlgorithm( + args.handle, args.odesc.desc(), args.o->data(), + args.wdesc.desc(), args.w->data(), args.cdesc.desc(), + args.idesc.desc(), const_cast(args.x->data()), + kNUM_CUDNN_BWD_DATA_ALGS, &returned_algo_count, + perf_stat.data(), cudnn_workspace_ptr, workspace_size_limit, + false)); + }; + workspace_handle.RunFuncSync(cudnn_find_func, workspace_size_limit); + + VLOG(3) << "BwdDataAlgo Perf result: (algo: stat, time, memory)"; + for (int i = 0; i < returned_algo_count; ++i) { + const auto& stat = perf_stat[i]; + VLOG(3) << stat.bwd_data_algo; + } + + return perf_stat[0].bwd_data_algo; + }); + VLOG(3) << "choose algo " << algo; + return algo; + } + + static size_t GetWorkspaceSize(const ConvArgs& args, algo_t algo) { + size_t workspace_size = 0; + PADDLE_ENFORCE_CUDA_SUCCESS( + platform::dynload::miopenConvolutionBackwardDataGetWorkSpaceSize( + args.handle, args.odesc.desc(), args.wdesc.desc(), + args.cdesc.desc(), args.idesc.desc(), &workspace_size)); + return workspace_size; + } +}; + +template <> +struct SearchAlgorithm { + using perf_t = miopenConvAlgoPerf_t; + using algo_t = miopenConvBwdWeightsAlgorithm_t; + + template + static algo_t Find(const ConvArgs& args, bool exhaustive_search, + bool deterministic, + const framework::ExecutionContext& ctx) { + auto dtype = platform::CudnnDataType::type; + size_t workspace_size_limit = FLAGS_conv_workspace_size_limit * 1024 * 1024; + size_t workspace_size = 0; + bool has_got_workspace_size = true; + algo_t algo; + + auto& dev_ctx = ctx.template device_context(); + auto workspace_handle = dev_ctx.cudnn_workspace_handle(); + AlgorithmsCache& algo_cache = + 
*(framework::ConvSearchCache::Instance().GetBackwardFilter()); + + auto x_dims = framework::vectorize(args.x->dims()); + auto w_dims = framework::vectorize(args.w->dims()); + + VLOG(10) << "miopenConvolutionFwdAlgoPerf_t:" + << ", x_dims:" << x_dims << ", w_dims:" << w_dims << ", args.s" + << args.s << ", args.p" << args.p << ", args.d" << args.d; + + algo = algo_cache.GetAlgorithm( + x_dims, w_dims, args.s, args.p, args.d, 0, + static_cast(args.cudnn_dtype), [&]() { + int returned_algo_count; + std::array perf_stat; + auto cudnn_find_func = [&](void* cudnn_workspace_ptr) { + PADDLE_ENFORCE_CUDA_SUCCESS( + platform::dynload:: + miopenFindConvolutionBackwardWeightsAlgorithm( + args.handle, args.odesc.desc(), args.o->data(), + args.idesc.desc(), args.x->data(), args.cdesc.desc(), + args.wdesc.desc(), const_cast(args.w->data()), + kNUM_CUDNN_BWD_FILTER_ALGS, &returned_algo_count, + perf_stat.data(), cudnn_workspace_ptr, + workspace_size_limit, false)); + }; + workspace_handle.RunFuncSync(cudnn_find_func, workspace_size_limit); + + VLOG(3) << "BwdFilterAlgo Perf result: (algo: stat, time, memory)"; + for (int i = 0; i < returned_algo_count; ++i) { + const auto& stat = perf_stat[i]; + VLOG(3) << stat.bwd_weights_algo; + } + return perf_stat[0].bwd_weights_algo; + }); + VLOG(3) << "choose algo " << algo; + return algo; + } + + static size_t GetWorkspaceSize(const ConvArgs& args, algo_t algo) { + size_t workspace_size = 0; + PADDLE_ENFORCE_CUDA_SUCCESS( + platform::dynload::miopenConvolutionBackwardWeightsGetWorkSpaceSize( + args.handle, args.odesc.desc(), args.idesc.desc(), + args.cdesc.desc(), args.wdesc.desc(), &workspace_size)); + return workspace_size; + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/conv_op.cc b/paddle/fluid/operators/conv_op.cc index dd7bfbdaefe..f3dd0dcb46c 100644 --- a/paddle/fluid/operators/conv_op.cc +++ b/paddle/fluid/operators/conv_op.cc @@ -21,9 +21,13 @@ limitations under the License. 
*/ #include "paddle/fluid/framework/op_version_registry.h" #ifdef PADDLE_WITH_CUDA -#include "paddle/fluid/operators/conv_cudnn_op_cache.h" #include "paddle/fluid/platform/cudnn_helper.h" #endif + +#ifdef PADDLE_WITH_HIP +#include "paddle/fluid/platform/miopen_helper.h" +#endif + #ifdef PADDLE_WITH_MKLDNN #include "paddle/fluid/platform/mkldnn_helper.h" #endif @@ -149,7 +153,7 @@ framework::OpKernelType ConvOp::GetExpectedKernelType( "AnyLayout"; // todo enable data layout when it's ready framework::DataLayout layout = framework::StringToDataLayout(data_format); -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) if (platform::CanCUDNNBeUsed(ctx)) { library = framework::LibraryType::kCUDNN; } @@ -559,7 +563,7 @@ framework::OpKernelType ConvOpGrad::GetExpectedKernelType( framework::DataLayout layout_ = framework::StringToDataLayout(data_format); auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Input"); -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) if (platform::CanCUDNNBeUsed(ctx)) { library_ = framework::LibraryType::kCUDNN; } @@ -744,7 +748,7 @@ framework::OpKernelType ConvOpDoubleGrad::GetExpectedKernelType( std::string data_format = "AnyLayout"; framework::DataLayout layout_ = framework::StringToDataLayout(data_format); -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) if (platform::CanCUDNNBeUsed(ctx)) { library_ = framework::LibraryType::kCUDNN; } diff --git a/paddle/fluid/operators/conv_transpose_cudnn_op.cu b/paddle/fluid/operators/conv_transpose_cudnn_op.cu index edf00eb2ba9..376cefe5025 100644 --- a/paddle/fluid/operators/conv_transpose_cudnn_op.cu +++ b/paddle/fluid/operators/conv_transpose_cudnn_op.cu @@ -15,11 +15,14 @@ limitations under the License. 
*/ #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/memory/memory.h" +#ifdef PADDLE_WITH_HIP +#include "paddle/fluid/operators/conv_miopen_helper.h" +#else #include "paddle/fluid/operators/conv_cudnn_helper.h" +#endif #include "paddle/fluid/operators/conv_transpose_op.h" #include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/math/padding.h" -#include "paddle/fluid/platform/cudnn_helper.h" namespace paddle { namespace operators { @@ -212,7 +215,11 @@ class CUDNNConvTransposeOpKernel : public framework::OpKernel { } size_t workspace_size = 0; +#ifdef PADDLE_WITH_HIP + miopenConvBwdDataAlgorithm_t algo{}; +#else cudnnConvolutionBwdDataAlgo_t algo{}; +#endif // ------------------- cudnn conv algorithm --------------------- auto& dev_ctx = ctx.template device_context(); auto handle = dev_ctx.cudnn_handle(); @@ -235,7 +242,12 @@ class CUDNNConvTransposeOpKernel : public framework::OpKernel { args.cdesc.set(dtype, padding_common, strides, dilations, platform::AllowTF32Cudnn(), c_groups); +#ifdef PADDLE_WITH_HIP + using search = SearchAlgorithm; +#else using search = SearchAlgorithm; +#endif + algo = search::Find(args, false, deterministic, ctx); workspace_size = std::max(workspace_size, search::GetWorkspaceSize(args, algo)); @@ -250,6 +262,17 @@ class CUDNNConvTransposeOpKernel : public framework::OpKernel { ScalingParamType beta = 0.0f; auto workspace_handle = dev_ctx.cudnn_workspace_handle(); for (int g = 0; g < groups; g++) { +#ifdef PADDLE_WITH_HIP + auto cudnn_func = [&](void* cudnn_workspace) { + PADDLE_ENFORCE_CUDA_SUCCESS( + platform::dynload::miopenConvolutionBackwardData( + handle, &alpha, args.odesc.desc(), + input_data + input_offset * g, args.wdesc.desc(), + filter_data + filter_offset * g, args.cdesc.desc(), algo, &beta, + args.idesc.desc(), transformed_output_data + output_offset * g, + cudnn_workspace, workspace_size)); + }; +#else // PADDLE_WITH_HIP auto cudnn_func = [&](void* cudnn_workspace) { PADDLE_ENFORCE_CUDA_SUCCESS( platform::dynload::cudnnConvolutionBackwardData( @@ -259,6 +282,7 @@ class CUDNNConvTransposeOpKernel : public framework::OpKernel { cudnn_workspace, workspace_size, &beta, args.idesc.desc(), transformed_output_data + output_offset * g)); }; +#endif // PADDLE_WITH_HIP workspace_handle.RunFunc(cudnn_func, workspace_size); } if (!is_sys_pad && strides.size() == 2U) { @@ -449,8 +473,14 @@ class CUDNNConvTransposeGradOpKernel : public framework::OpKernel { padding_common, dilations, dtype}; + +#ifdef PADDLE_WITH_HIP + miopenConvFwdAlgorithm_t data_algo{}; + miopenConvBwdWeightsAlgorithm_t filter_algo{}; +#else cudnnConvolutionFwdAlgo_t data_algo{}; cudnnConvolutionBwdFilterAlgo_t filter_algo{}; +#endif auto layout_tensor = GetCudnnTensorFormat(layout); size_t workspace_size = 0; @@ -472,7 +502,11 @@ class CUDNNConvTransposeGradOpKernel : public framework::OpKernel { args1.odesc.set(input_transpose, iwo_groups); args1.cdesc.set(dtype, padding_common, strides, dilations, platform::AllowTF32Cudnn(), c_groups); +#ifdef PADDLE_WITH_HIP + using search1 = SearchAlgorithm; +#else using search1 = SearchAlgorithm; +#endif data_algo = search1::Find(args1, false, deterministic, ctx); workspace_size = std::max(workspace_size, search1::GetWorkspaceSize(args1, data_algo)); @@ -486,7 +520,11 @@ class CUDNNConvTransposeGradOpKernel : public framework::OpKernel { args2.odesc.set(input_transpose, iwo_groups); args2.cdesc.set(dtype, padding_common, strides, dilations, 
platform::AllowTF32Cudnn(), c_groups); +#ifdef PADDLE_WITH_HIP + using search2 = SearchAlgorithm; +#else using search2 = SearchAlgorithm; +#endif filter_algo = search2::Find(args2, false, deterministic, ctx); workspace_size = std::max(workspace_size, search2::GetWorkspaceSize(args2, filter_algo)); @@ -504,6 +542,18 @@ class CUDNNConvTransposeGradOpKernel : public framework::OpKernel { if (input_grad) { // Because beta is zero, it is unnecessary to reset input_grad. for (int g = 0; g < groups; g++) { +#ifdef PADDLE_WITH_HIP + auto cudnn_func = [&](void* cudnn_workspace) { + PADDLE_ENFORCE_CUDA_SUCCESS( + platform::dynload::miopenConvolutionForward( + handle, &alpha, args1.idesc.desc(), + output_grad_data + output_grad_offset * g, args1.wdesc.desc(), + filter_data + filter_offset * g, args1.cdesc.desc(), + data_algo, &beta, args1.odesc.desc(), + input_grad_data + input_offset * g, cudnn_workspace, + workspace_size)); + }; +#else // PADDLE_WITH_HIP auto cudnn_func = [&](void* cudnn_workspace) { PADDLE_ENFORCE_CUDA_SUCCESS( platform::dynload::cudnnConvolutionForward( @@ -513,6 +563,7 @@ class CUDNNConvTransposeGradOpKernel : public framework::OpKernel { data_algo, cudnn_workspace, workspace_size, &beta, args1.odesc.desc(), input_grad_data + input_offset * g)); }; +#endif // PADDLE_WITH_HIP workspace_handle.RunFunc(cudnn_func, workspace_size); } @@ -540,6 +591,18 @@ class CUDNNConvTransposeGradOpKernel : public framework::OpKernel { // Because beta is zero, it is unnecessary to reset filter_grad. // Gradient with respect to the filter for (int g = 0; g < groups; g++) { +#ifdef PADDLE_WITH_HIP + auto cudnn_func = [&](void* cudnn_workspace) { + PADDLE_ENFORCE_CUDA_SUCCESS( + platform::dynload::miopenConvolutionBackwardWeights( + handle, &alpha, args2.odesc.desc(), + input_data + input_offset * g, args2.idesc.desc(), + output_grad_data + output_grad_offset * g, args2.cdesc.desc(), + filter_algo, &beta, args2.wdesc.desc(), + filter_grad_data + filter_offset * g, cudnn_workspace, + workspace_size)); + }; +#else // PADDLE_WITH_HIP auto cudnn_func = [&](void* cudnn_workspace) { PADDLE_ENFORCE_CUDA_SUCCESS( platform::dynload::cudnnConvolutionBackwardFilter( @@ -549,6 +612,7 @@ class CUDNNConvTransposeGradOpKernel : public framework::OpKernel { filter_algo, cudnn_workspace, workspace_size, &beta, args2.wdesc.desc(), filter_grad_data + filter_offset * g)); }; +#endif // PADDLE_WITH_HIP workspace_handle.RunFunc(cudnn_func, workspace_size); } } @@ -840,7 +904,16 @@ class CUDNNConvTransposeDoubleGradOpKernel : public framework::OpKernel { ConvArgs args4{ &transformed_dO, ddW, &transformed_dX_channel, strides, padding_common, dilations, dtype}; - +#ifdef PADDLE_WITH_HIP + miopenConvBwdDataAlgorithm_t bwd_algo1 = + static_cast(0); + miopenConvBwdDataAlgorithm_t bwd_algo2 = + static_cast(0); + miopenConvFwdAlgorithm_t data_algo = + static_cast(0); + miopenConvBwdWeightsAlgorithm_t filter_algo = + static_cast(0); +#else cudnnConvolutionBwdDataAlgo_t bwd_algo1 = static_cast(0); cudnnConvolutionBwdDataAlgo_t bwd_algo2 = @@ -849,6 +922,7 @@ class CUDNNConvTransposeDoubleGradOpKernel : public framework::OpKernel { static_cast(0); cudnnConvolutionBwdFilterAlgo_t filter_algo = static_cast(0); +#endif auto layout = GetCudnnTensorFormat(platform::DataLayout::kNCHW); @@ -866,7 +940,11 @@ class CUDNNConvTransposeDoubleGradOpKernel : public framework::OpKernel { args1.wdesc.set(*W, layout, iwo_group); args1.odesc.set(transformed_ddX, iwo_group); args1.cdesc.set(dtype, padding_common, strides, dilations, c_group); +#ifdef 
PADDLE_WITH_HIP + using search1 = SearchAlgorithm; +#else using search1 = SearchAlgorithm; +#endif bwd_algo1 = search1::Find(args1, false, deterministic, ctx); workspace_size = search1::GetWorkspaceSize(args1, bwd_algo1); } @@ -878,7 +956,11 @@ class CUDNNConvTransposeDoubleGradOpKernel : public framework::OpKernel { args2.wdesc.set(*ddW, layout, iwo_group); args2.odesc.set(transformed_X, iwo_group); args2.cdesc.set(dtype, padding_common, strides, dilations, c_group); +#ifdef PADDLE_WITH_HIP + using search2 = SearchAlgorithm; +#else using search2 = SearchAlgorithm; +#endif bwd_algo2 = search2::Find(args2, false, deterministic, ctx); workspace_size = std::max(workspace_size, search2::GetWorkspaceSize(args2, bwd_algo2)); @@ -894,8 +976,11 @@ class CUDNNConvTransposeDoubleGradOpKernel : public framework::OpKernel { args3.odesc.set(transformed_ddX_channel, iwo_group); args3.cdesc.set(dtype, padding_common, strides, dilations, c_group); - +#ifdef PADDLE_WITH_HIP + using search3 = SearchAlgorithm; +#else using search3 = SearchAlgorithm; +#endif filter_algo = search3::Find(args3, false, deterministic, ctx); workspace_size = std::max(workspace_size, search3::GetWorkspaceSize(args3, filter_algo)); @@ -909,8 +994,11 @@ class CUDNNConvTransposeDoubleGradOpKernel : public framework::OpKernel { args4.wdesc.set(*ddW, layout, iwo_group); args4.odesc.set(transformed_dX_channel, iwo_group); args4.cdesc.set(dtype, padding_common, strides, dilations, c_group); - +#ifdef PADDLE_WITH_HIP + using search4 = SearchAlgorithm; +#else using search4 = SearchAlgorithm; +#endif data_algo = search4::Find(args4, false, deterministic, ctx); workspace_size = std::max(workspace_size, search4::GetWorkspaceSize(args4, data_algo)); @@ -939,6 +1027,20 @@ class CUDNNConvTransposeDoubleGradOpKernel : public framework::OpKernel { if (ddX) { ddx = transformed_ddX.data(); for (int i = 0; i < groups; i++) { +#ifdef PADDLE_WITH_HIP + wkspace_handle.RunFunc( + [&](void* workspace_ptr) { + PADDLE_ENFORCE_CUDA_SUCCESS( + platform::dynload::miopenConvolutionBackwardData( + handle, &alpha, args1.odesc.desc(), + ddx + i * group_offset_in, args1.wdesc.desc(), + w + i * group_offset_filter, args1.cdesc.desc(), + bwd_algo1, &beta, args1.idesc.desc(), + transformed_ddy_channel + i * group_offset_out, + workspace_ptr, workspace_size)); + }, + workspace_size); +#else // PADDLE_WITH_HIP wkspace_handle.RunFunc( [&](void* workspace_ptr) { PADDLE_ENFORCE_CUDA_SUCCESS( @@ -951,10 +1053,25 @@ class CUDNNConvTransposeDoubleGradOpKernel : public framework::OpKernel { transformed_ddy_channel + i * group_offset_out)); }, workspace_size); +#endif // PADDLE_WITH_HIP } } if (ddW) { for (int i = 0; i < groups; i++) { +#ifdef PADDLE_WITH_HIP + wkspace_handle.RunFunc( + [&](void* workspace_ptr) { + PADDLE_ENFORCE_CUDA_SUCCESS( + platform::dynload::miopenConvolutionBackwardData( + handle, &alpha, args2.odesc.desc(), + x + i * group_offset_in, args2.wdesc.desc(), + ddw + i * group_offset_filter, args2.cdesc.desc(), + bwd_algo2, &alpha, args2.idesc.desc(), + transformed_ddy_channel + i * group_offset_out, + workspace_ptr, workspace_size)); + }, + workspace_size); +#else // PADDLE_WITH_HIP wkspace_handle.RunFunc( [&](void* workspace_ptr) { PADDLE_ENFORCE_CUDA_SUCCESS( @@ -967,6 +1084,7 @@ class CUDNNConvTransposeDoubleGradOpKernel : public framework::OpKernel { transformed_ddy_channel + i * group_offset_out)); }, workspace_size); +#endif // PADDLE_WITH_HIP } } if ((!is_sys_pad) && (!channel_last)) { @@ -997,6 +1115,20 @@ class CUDNNConvTransposeDoubleGradOpKernel : 
public framework::OpKernel { if (dW && ddX) { ddx = transformed_ddX_channel.data(); for (int i = 0; i < groups; i++) { +#ifdef PADDLE_WITH_HIP + wkspace_handle.RunFunc( + [&](void* workspace_ptr) { + PADDLE_ENFORCE_CUDA_SUCCESS( + platform::dynload::miopenConvolutionBackwardWeights( + handle, &alpha, args3.odesc.desc(), + ddx + i * group_offset_in, args3.idesc.desc(), + transformed_dy_channel + i * group_offset_out, + args3.cdesc.desc(), filter_algo, &beta, + args3.wdesc.desc(), dw + i * group_offset_filter, + workspace_ptr, workspace_size)); + }, + workspace_size); +#else // PADDLE_WITH_HIP wkspace_handle.RunFunc( [&](void* workspace_ptr) { PADDLE_ENFORCE_CUDA_SUCCESS( @@ -1009,12 +1141,27 @@ class CUDNNConvTransposeDoubleGradOpKernel : public framework::OpKernel { dw + i * group_offset_filter)); }, workspace_size); +#endif // PADDLE_WITH_HIP } } if (dX && ddW) { ddw = ddW->data(); for (int i = 0; i < groups; i++) { +#ifdef PADDLE_WITH_HIP + wkspace_handle.RunFunc( + [&](void* workspace_ptr) { + PADDLE_ENFORCE_CUDA_SUCCESS( + platform::dynload::miopenConvolutionForward( + handle, &alpha, args4.idesc.desc(), + transformed_dy_channel + i * group_offset_out, + args4.wdesc.desc(), ddw + i * group_offset_filter, + args4.cdesc.desc(), data_algo, &beta, args4.odesc.desc(), + transformed_dx + i * group_offset_in, workspace_ptr, + workspace_size)); + }, + workspace_size); +#else // PADDLE_WITH_HIP wkspace_handle.RunFunc( [&](void* workspace_ptr) { PADDLE_ENFORCE_CUDA_SUCCESS( @@ -1027,6 +1174,7 @@ class CUDNNConvTransposeDoubleGradOpKernel : public framework::OpKernel { transformed_dx + i * group_offset_in)); }, workspace_size); +#endif // PADDLE_WITH_HIP } if (channel_last) { TransToChannelLast( @@ -1042,6 +1190,26 @@ class CUDNNConvTransposeDoubleGradOpKernel : public framework::OpKernel { namespace ops = paddle::operators; namespace plat = paddle::platform; +#ifdef PADDLE_WITH_HIP +// MIOPEN do not support double +REGISTER_OP_KERNEL(conv2d_transpose, CUDNN, ::paddle::platform::CUDAPlace, + ops::CUDNNConvTransposeOpKernel, + ops::CUDNNConvTransposeOpKernel); +REGISTER_OP_KERNEL(conv2d_transpose_grad, CUDNN, ::paddle::platform::CUDAPlace, + ops::CUDNNConvTransposeGradOpKernel, + ops::CUDNNConvTransposeGradOpKernel); +REGISTER_OP_KERNEL( + conv2d_transpose_grad_grad, CUDNN, plat::CUDAPlace, + paddle::operators::CUDNNConvTransposeDoubleGradOpKernel, + paddle::operators::CUDNNConvTransposeDoubleGradOpKernel); + +REGISTER_OP_KERNEL(conv3d_transpose, CUDNN, ::paddle::platform::CUDAPlace, + ops::CUDNNConvTransposeOpKernel, + ops::CUDNNConvTransposeOpKernel); +REGISTER_OP_KERNEL(conv3d_transpose_grad, CUDNN, ::paddle::platform::CUDAPlace, + ops::CUDNNConvTransposeGradOpKernel, + ops::CUDNNConvTransposeGradOpKernel); +#else REGISTER_OP_KERNEL(conv2d_transpose, CUDNN, ::paddle::platform::CUDAPlace, ops::CUDNNConvTransposeOpKernel, ops::CUDNNConvTransposeOpKernel, @@ -1064,3 +1232,4 @@ REGISTER_OP_KERNEL(conv3d_transpose_grad, CUDNN, ::paddle::platform::CUDAPlace, ops::CUDNNConvTransposeGradOpKernel, ops::CUDNNConvTransposeGradOpKernel, ops::CUDNNConvTransposeGradOpKernel); +#endif diff --git a/paddle/fluid/operators/conv_transpose_op.cc b/paddle/fluid/operators/conv_transpose_op.cc index dc4b416a609..4ea936d5104 100644 --- a/paddle/fluid/operators/conv_transpose_op.cc +++ b/paddle/fluid/operators/conv_transpose_op.cc @@ -183,7 +183,7 @@ framework::OpKernelType ConvTransposeOp::GetExpectedKernelType( bool use_cudnn = ctx.Attr("use_cudnn"); use_cudnn &= platform::is_gpu_place(ctx.GetPlace()); auto data_type 
= OperatorWithKernel::IndicateVarDataType(ctx, "Input"); -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) if (platform::is_gpu_place(ctx.GetPlace())) { auto& dev_ctx = ctx.template device_context(); use_cudnn &= dev_ctx.cudnn_handle() != nullptr; @@ -481,7 +481,7 @@ framework::OpKernelType ConvTransposeOpGrad::GetExpectedKernelType( const framework::ExecutionContext& ctx) const { bool use_cudnn = ctx.Attr("use_cudnn"); use_cudnn &= platform::is_gpu_place(ctx.GetPlace()); -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) if (platform::is_gpu_place(ctx.GetPlace())) { auto& dev_ctx = ctx.template device_context(); use_cudnn &= dev_ctx.cudnn_handle() != nullptr; @@ -581,7 +581,7 @@ framework::OpKernelType ConvTransposeOpDoubleGrad::GetExpectedKernelType( const framework::ExecutionContext& ctx) const { bool use_cudnn = ctx.Attr("use_cudnn"); use_cudnn &= platform::is_gpu_place(ctx.GetPlace()); -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) if (platform::is_gpu_place(ctx.GetPlace())) { auto& dev_ctx = ctx.template device_context(); use_cudnn &= dev_ctx.cudnn_handle() != nullptr; diff --git a/paddle/fluid/operators/math/CMakeLists.txt b/paddle/fluid/operators/math/CMakeLists.txt index 2430e68225c..fdbc0c68525 100644 --- a/paddle/fluid/operators/math/CMakeLists.txt +++ b/paddle/fluid/operators/math/CMakeLists.txt @@ -28,15 +28,12 @@ function(math_library TARGET) if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.cu.cc) list(APPEND cu_srcs ${TARGET}.cu.cc) endif() - if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.hip.cu) - list(APPEND hip_srcs ${TARGET}.hip.cu) - endif() list(LENGTH cc_srcs cc_srcs_len) if (WITH_GPU) nv_library(${TARGET} SRCS ${cc_srcs} ${cu_srcs} DEPS ${math_library_DEPS} ${math_common_deps}) - elseif (WITH_ROCM_PLATFORM AND (${hip_srcs} MATCHES ".*\\.hip.cu$")) - hip_library_ops(${TARGET} SRCS ${cc_srcs} ${hip_srcs} DEPS ${math_library_DEPS} ${math_common_deps}) + elseif (WITH_ROCM) + hip_library(${TARGET} SRCS ${cc_srcs} ${cu_srcs} DEPS ${math_library_DEPS} ${math_common_deps}) elseif(${cc_srcs_len} GREATER 0) cc_library(${TARGET} SRCS ${cc_srcs} DEPS ${math_library_DEPS} ${math_common_deps}) endif() @@ -89,6 +86,10 @@ if(WITH_GPU) nv_test(math_function_gpu_test SRCS math_function_test.cu DEPS math_function) nv_test(selected_rows_functor_gpu_test SRCS selected_rows_functor_test.cu.cc DEPS selected_rows_functor math_function) endif() +if(WITH_ROCM) + hip_test(math_function_gpu_test SRCS math_function_test.cu DEPS math_function tensor) + hip_test(selected_rows_functor_gpu_test SRCS selected_rows_functor_test.cu.cc DEPS selected_rows_functor math_function) +endif() cc_test(concat_test SRCS concat_test.cc DEPS concat_and_split) cc_test(cpu_vec_test SRCS cpu_vec_test.cc DEPS blas cpu_info) if(WITH_TESTING AND TEST im2col_test) diff --git a/paddle/fluid/operators/math/concat_test.cc b/paddle/fluid/operators/math/concat_test.cc index 094e2059c4d..011c85caf04 100644 --- a/paddle/fluid/operators/math/concat_test.cc +++ b/paddle/fluid/operators/math/concat_test.cc @@ -442,7 +442,7 @@ void TestConcatMain() { TEST(math, concat) { TestConcatMain(); -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) TestConcatMain(); #endif diff --git a/paddle/fluid/operators/math/pooling.cc b/paddle/fluid/operators/math/pooling.cc index 4df49a1b698..f2e5e955ec4 100644 --- a/paddle/fluid/operators/math/pooling.cc +++ b/paddle/fluid/operators/math/pooling.cc @@ -30,8 +30,9 
@@ class Pool2dFunctor { void operator()(const platform::CPUDeviceContext& context, const framework::Tensor& input, const std::vector& ksize, const std::vector& strides, - const std::vector& paddings, PoolProcess pool_process, - bool exclusive, bool adaptive, framework::Tensor* output) { + const std::vector& paddings, bool exclusive, + bool adaptive, framework::Tensor* output, + PoolProcess pool_process) { const int batch_size = input.dims()[0]; const int input_height = input.dims()[2]; const int input_width = input.dims()[3]; @@ -104,8 +105,8 @@ class Pool2dFunctor { const framework::Tensor& input, const std::vector& ksize, const std::vector& strides, const std::vector& paddings, - const std::string data_format, PoolProcess pool_process, - bool exclusive, bool adaptive, framework::Tensor* output) { + const std::string data_format, bool exclusive, bool adaptive, + framework::Tensor* output, PoolProcess pool_process) { bool channel_last = (data_format == "NHWC"); const int batch_size = input.dims()[0]; @@ -249,8 +250,8 @@ class Pool2dGradFunctor { const platform::CPUDeviceContext& context, const framework::Tensor& input, const framework::Tensor& output, const framework::Tensor& output_grad, const std::vector& ksize, const std::vector& strides, - const std::vector& paddings, PoolProcess pool_grad_process, - bool exclusive, bool adaptive, framework::Tensor* input_grad) { + const std::vector& paddings, bool exclusive, bool adaptive, + framework::Tensor* input_grad, PoolProcess pool_grad_process) { const int batch_size = input.dims()[0]; const int input_height = input.dims()[2]; const int input_width = input.dims()[3]; @@ -328,8 +329,8 @@ class Pool2dGradFunctor { const framework::Tensor& output, const framework::Tensor& output_grad, const std::vector& ksize, const std::vector& strides, const std::vector& paddings, const std::string data_format, - PoolProcess pool_grad_process, bool exclusive, bool adaptive, - framework::Tensor* input_grad) { + bool exclusive, bool adaptive, framework::Tensor* input_grad, + PoolProcess pool_grad_process) { bool channel_last = (data_format == "NHWC"); const int batch_size = input.dims()[0]; @@ -678,8 +679,9 @@ class Pool3dFunctor { void operator()(const platform::CPUDeviceContext& context, const framework::Tensor& input, const std::vector& ksize, const std::vector& strides, - const std::vector& paddings, PoolProcess pool_process, - bool exclusive, bool adaptive, framework::Tensor* output) { + const std::vector& paddings, bool exclusive, + bool adaptive, framework::Tensor* output, + PoolProcess pool_process) { const int batch_size = input.dims()[0]; const int input_depth = input.dims()[2]; const int input_height = input.dims()[3]; @@ -773,8 +775,8 @@ class Pool3dFunctor { const framework::Tensor& input, const std::vector& ksize, const std::vector& strides, const std::vector& paddings, - const std::string data_format, PoolProcess pool_process, - bool exclusive, bool adaptive, framework::Tensor* output) { + const std::string data_format, bool exclusive, bool adaptive, + framework::Tensor* output, PoolProcess pool_process) { bool channel_last = (data_format == "NDHWC"); const int batch_size = input.dims()[0]; @@ -970,8 +972,8 @@ class Pool3dGradFunctor { const platform::CPUDeviceContext& context, const framework::Tensor& input, const framework::Tensor& output, const framework::Tensor& output_grad, const std::vector& ksize, const std::vector& strides, - const std::vector& paddings, PoolProcess pool_grad_process, - bool exclusive, bool adaptive, framework::Tensor* 
input_grad) { + const std::vector& paddings, bool exclusive, bool adaptive, + framework::Tensor* input_grad, PoolProcess pool_grad_process) { const int batch_size = input.dims()[0]; const int input_depth = input.dims()[2]; const int input_height = input.dims()[3]; @@ -1071,8 +1073,8 @@ class Pool3dGradFunctor { const framework::Tensor& output, const framework::Tensor& output_grad, const std::vector& ksize, const std::vector& strides, const std::vector& paddings, const std::string data_format, - PoolProcess pool_grad_process, bool exclusive, bool adaptive, - framework::Tensor* input_grad) { + bool exclusive, bool adaptive, framework::Tensor* input_grad, + PoolProcess pool_grad_process) { bool channel_last = (data_format == "NDHWC"); const int batch_size = input.dims()[0]; diff --git a/paddle/fluid/operators/math/pooling.cu b/paddle/fluid/operators/math/pooling.cu index b64dbb771a3..e51fb4204b8 100644 --- a/paddle/fluid/operators/math/pooling.cu +++ b/paddle/fluid/operators/math/pooling.cu @@ -237,8 +237,8 @@ void Pool2dDirectCUDAFunctor::operator()( const T* input, const std::vector& input_shape, const std::vector& output_shape, const std::vector& ksize, const std::vector& strides, const std::vector& paddings, - PoolProcess pool_compute, bool exclusive, bool adaptive, T* output, - cudaStream_t stream) { + bool exclusive, bool adaptive, T* output, gpuStream_t stream, + PoolProcess pool_compute) { const int batch_size = input_shape[0]; const int input_channels = input_shape[1]; const int input_height = input_shape[2]; @@ -277,8 +277,9 @@ class Pool2dFunctor { void operator()(const platform::CUDADeviceContext& context, const framework::Tensor& input, const std::vector& ksize, const std::vector& strides, - const std::vector& paddings, PoolProcess pool_process, - bool exclusive, bool adaptive, framework::Tensor* output) { + const std::vector& paddings, bool exclusive, + bool adaptive, framework::Tensor* output, + PoolProcess pool_process) { const int batch_size = input.dims()[0]; const int input_channels = input.dims()[1]; const int input_height = input.dims()[2]; @@ -311,8 +312,8 @@ class Pool2dFunctor { const framework::Tensor& input, const std::vector& ksize, const std::vector& strides, const std::vector& paddings, - const std::string data_format, PoolProcess pool_process, - bool exclusive, bool adaptive, framework::Tensor* output) { + const std::string data_format, bool exclusive, bool adaptive, + framework::Tensor* output, PoolProcess pool_process) { bool channel_last = (data_format == "NHWC"); const int batch_size = input.dims()[0]; @@ -367,9 +368,9 @@ class Pool2dGradFunctor { const framework::Tensor& output_grad, const std::vector& ksize, const std::vector& strides, - const std::vector& paddings, PoolProcess pool_process, - bool exclusive, bool adaptive, - framework::Tensor* input_grad) { + const std::vector& paddings, bool exclusive, + bool adaptive, framework::Tensor* input_grad, + PoolProcess pool_process) { const int batch_size = input.dims()[0]; const int input_channels = input.dims()[1]; const int input_height = input.dims()[2]; @@ -399,13 +400,15 @@ class Pool2dGradFunctor { ksize_width, stride_height, stride_width, padding_height, padding_width, pool_process, exclusive, adaptive, input_grad_data); } - void operator()( - const platform::CUDADeviceContext& context, - const framework::Tensor& input, const framework::Tensor& output, - const framework::Tensor& output_grad, const std::vector& ksize, - const std::vector& strides, const std::vector& paddings, - const std::string 
diff --git a/paddle/fluid/operators/math/pooling.cu b/paddle/fluid/operators/math/pooling.cu
index b64dbb771a3..e51fb4204b8 100644
--- a/paddle/fluid/operators/math/pooling.cu
+++ b/paddle/fluid/operators/math/pooling.cu
@@ -237,8 +237,8 @@ void Pool2dDirectCUDAFunctor::operator()(
     const T* input, const std::vector& input_shape,
     const std::vector& output_shape, const std::vector& ksize,
     const std::vector& strides, const std::vector& paddings,
-    PoolProcess pool_compute, bool exclusive, bool adaptive, T* output,
-    cudaStream_t stream) {
+    bool exclusive, bool adaptive, T* output, gpuStream_t stream,
+    PoolProcess pool_compute) {
   const int batch_size = input_shape[0];
   const int input_channels = input_shape[1];
   const int input_height = input_shape[2];
@@ -277,8 +277,9 @@ class Pool2dFunctor {
   void operator()(const platform::CUDADeviceContext& context,
                   const framework::Tensor& input, const std::vector& ksize,
                   const std::vector& strides,
-                  const std::vector& paddings, PoolProcess pool_process,
-                  bool exclusive, bool adaptive, framework::Tensor* output) {
+                  const std::vector& paddings, bool exclusive,
+                  bool adaptive, framework::Tensor* output,
+                  PoolProcess pool_process) {
     const int batch_size = input.dims()[0];
     const int input_channels = input.dims()[1];
     const int input_height = input.dims()[2];
@@ -311,8 +312,8 @@ class Pool2dFunctor {
                   const framework::Tensor& input, const std::vector& ksize,
                   const std::vector& strides, const std::vector& paddings,
-                  const std::string data_format, PoolProcess pool_process,
-                  bool exclusive, bool adaptive, framework::Tensor* output) {
+                  const std::string data_format, bool exclusive, bool adaptive,
+                  framework::Tensor* output, PoolProcess pool_process) {
     bool channel_last = (data_format == "NHWC");
     const int batch_size = input.dims()[0];
@@ -367,9 +368,9 @@ class Pool2dGradFunctor {
                   const framework::Tensor& output_grad,
                   const std::vector& ksize, const std::vector& strides,
-                  const std::vector& paddings, PoolProcess pool_process,
-                  bool exclusive, bool adaptive,
-                  framework::Tensor* input_grad) {
+                  const std::vector& paddings, bool exclusive,
+                  bool adaptive, framework::Tensor* input_grad,
+                  PoolProcess pool_process) {
     const int batch_size = input.dims()[0];
     const int input_channels = input.dims()[1];
     const int input_height = input.dims()[2];
@@ -399,13 +400,15 @@ class Pool2dGradFunctor {
         ksize_width, stride_height, stride_width, padding_height,
         padding_width, pool_process, exclusive, adaptive, input_grad_data);
   }
-  void operator()(
-      const platform::CUDADeviceContext& context,
-      const framework::Tensor& input, const framework::Tensor& output,
-      const framework::Tensor& output_grad, const std::vector& ksize,
-      const std::vector& strides, const std::vector& paddings,
-      const std::string data_format, PoolProcess pool_process, bool exclusive,
-      bool adaptive, framework::Tensor* input_grad) {
+  void operator()(const platform::CUDADeviceContext& context,
+                  const framework::Tensor& input,
+                  const framework::Tensor& output,
+                  const framework::Tensor& output_grad,
+                  const std::vector& ksize,
+                  const std::vector& strides,
+                  const std::vector& paddings,
+                  const std::string data_format, bool exclusive, bool adaptive,
+                  framework::Tensor* input_grad, PoolProcess pool_process) {
     bool channel_last = (data_format == "NHWC");
     const int batch_size = input.dims()[0];
@@ -881,8 +884,9 @@ class Pool3dFunctor {
   void operator()(const platform::CUDADeviceContext& context,
                   const framework::Tensor& input, const std::vector& ksize,
                   const std::vector& strides,
-                  const std::vector& paddings, PoolProcess pool_process,
-                  bool exclusive, bool adaptive, framework::Tensor* output) {
+                  const std::vector& paddings, bool exclusive,
+                  bool adaptive, framework::Tensor* output,
+                  PoolProcess pool_process) {
     const int batch_size = input.dims()[0];
     const int input_channels = input.dims()[1];
     const int input_depth = input.dims()[2];
@@ -922,8 +926,8 @@ class Pool3dFunctor {
                   const framework::Tensor& input, const std::vector& ksize,
                   const std::vector& strides, const std::vector& paddings,
-                  const std::string data_format, PoolProcess pool_process,
-                  bool exclusive, bool adaptive, framework::Tensor* output) {
+                  const std::string data_format, bool exclusive, bool adaptive,
+                  framework::Tensor* output, PoolProcess pool_process) {
     bool channel_last = (data_format == "NDHWC");
     const int batch_size = input.dims()[0];
@@ -988,9 +992,9 @@ class Pool3dGradFunctor {
                   const framework::Tensor& output_grad,
                   const std::vector& ksize, const std::vector& strides,
-                  const std::vector& paddings, PoolProcess pool_process,
-                  bool exclusive, bool adaptive,
-                  framework::Tensor* input_grad) {
+                  const std::vector& paddings, bool exclusive,
+                  bool adaptive, framework::Tensor* input_grad,
+                  PoolProcess pool_process) {
     const int batch_size = input.dims()[0];
     const int input_channels = input.dims()[1];
     const int input_depth = input.dims()[2];
@@ -1028,13 +1032,15 @@ class Pool3dGradFunctor {
         stride_height, stride_width, padding_depth, padding_height,
         padding_width, pool_process, exclusive, adaptive, input_grad_data);
   }
-  void operator()(
-      const platform::CUDADeviceContext& context,
-      const framework::Tensor& input, const framework::Tensor& output,
-      const framework::Tensor& output_grad, const std::vector& ksize,
-      const std::vector& strides, const std::vector& paddings,
-      const std::string data_format, PoolProcess pool_process, bool exclusive,
-      bool adaptive, framework::Tensor* input_grad) {
+  void operator()(const platform::CUDADeviceContext& context,
+                  const framework::Tensor& input,
+                  const framework::Tensor& output,
+                  const framework::Tensor& output_grad,
+                  const std::vector& ksize,
+                  const std::vector& strides,
+                  const std::vector& paddings,
+                  const std::string data_format, bool exclusive, bool adaptive,
+                  framework::Tensor* input_grad, PoolProcess pool_process) {
     bool channel_last = (data_format == "NDHWC");
     const int batch_size = input.dims()[0];
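Note: the direct GPU functor now takes a gpuStream_t instead of a cudaStream_t, so one declaration serves both CUDA and ROCm builds. Paddle defines this alias in its platform headers; the standalone sketch below only illustrates the idea and is not the exact Paddle definition.

    // Sketch of a portable stream alias (assumption: shown outside Paddle's headers).
    #ifdef PADDLE_WITH_HIP
    #include <hip/hip_runtime.h>
    using gpuStream_t = hipStream_t;
    #else
    #include <cuda_runtime.h>
    using gpuStream_t = cudaStream_t;
    #endif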
diff --git a/paddle/fluid/operators/math/pooling.h b/paddle/fluid/operators/math/pooling.h
index 5a6ae224789..21d588cc01f 100644
--- a/paddle/fluid/operators/math/pooling.h
+++ b/paddle/fluid/operators/math/pooling.h
@@ -97,7 +97,7 @@ HOSTDEVICE inline int AdaptEndIndex(int ph, int input_size, int output_size) {
 * This is different from average pooling. So we rewrite the max_pool_grad:
 * MaxPool2dGradFunctor, MaxPool3dGradFunctor.
 */
-#ifdef PADDLE_WITH_CUDA
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
 template
 class Pool2dDirectCUDAFunctor {
  public:
@@ -105,9 +105,9 @@ class Pool2dDirectCUDAFunctor {
                   const std::vector& output_shape,
                   const std::vector& ksize,
                   const std::vector& strides,
-                  const std::vector& paddings, PoolProcess pool_compute,
-                  bool exclusive, bool adaptive, T* output,
-                  cudaStream_t stream);
+                  const std::vector& paddings, bool exclusive,
+                  bool adaptive, T* output, gpuStream_t stream,
+                  PoolProcess pool_compute);
 };
 #endif
@@ -117,16 +117,17 @@ class Pool2dFunctor {
   void operator()(const DeviceContext& context, const framework::Tensor& input,
                   const std::vector& ksize,
                   const std::vector& strides,
-                  const std::vector& paddings, PoolProcess pool_compute,
-                  bool exclusive, bool adaptive, framework::Tensor* output);
+                  const std::vector& paddings, bool exclusive,
+                  bool adaptive, framework::Tensor* output,
+                  PoolProcess pool_compute);
   // overload operator() to support argument data_format
   void operator()(const DeviceContext& context, const framework::Tensor& input,
                   const std::vector& ksize,
                   const std::vector& strides,
                   const std::vector& paddings,
-                  const std::string data_format, PoolProcess pool_compute,
-                  bool exclusive, bool adaptive, framework::Tensor* output);
+                  const std::string data_format, bool exclusive, bool adaptive,
+                  framework::Tensor* output, PoolProcess pool_compute);
 };
 template
@@ -137,8 +138,9 @@ class Pool2dGradFunctor {
                   const framework::Tensor& output_grad,
                   const std::vector& ksize,
                   const std::vector& strides,
-                  const std::vector& paddings, PoolProcess pool_compute,
-                  bool exclusive, bool adaptive, framework::Tensor* input_grad);
+                  const std::vector& paddings, bool exclusive,
+                  bool adaptive, framework::Tensor* input_grad,
+                  PoolProcess pool_compute);
   // overload operator() to support argument data_format
   void operator()(const DeviceContext& context, const framework::Tensor& input,
                   const framework::Tensor& output,
@@ -146,8 +148,8 @@ class Pool2dGradFunctor {
                   const std::vector& ksize,
                   const std::vector& strides,
                   const std::vector& paddings,
-                  const std::string data_format, PoolProcess pool_compute,
-                  bool exclusive, bool adaptive, framework::Tensor* input_grad);
+                  const std::string data_format, bool exclusive, bool adaptive,
+                  framework::Tensor* input_grad, PoolProcess pool_compute);
 };
 template
@@ -176,15 +178,16 @@ class Pool3dFunctor {
   void operator()(const DeviceContext& context, const framework::Tensor& input,
                   const std::vector& ksize,
                   const std::vector& strides,
-                  const std::vector& paddings, PoolProcess pool_compute,
-                  bool exclusive, bool adaptive, framework::Tensor* output);
+                  const std::vector& paddings, bool exclusive,
+                  bool adaptive, framework::Tensor* output,
+                  PoolProcess pool_compute);
   // overload operator() to support argument data_format
   void operator()(const DeviceContext& context, const framework::Tensor& input,
                   const std::vector& ksize,
                   const std::vector& strides,
                   const std::vector& paddings,
-                  const std::string data_format, PoolProcess pool_compute,
-                  bool exclusive, bool adaptive, framework::Tensor* output);
+                  const std::string data_format, bool exclusive, bool adaptive,
+                  framework::Tensor* output, PoolProcess pool_compute);
 };
 template
@@ -195,8 +198,9 @@ class Pool3dGradFunctor {
                   const framework::Tensor& output_grad,
                   const std::vector& ksize,
                   const std::vector& strides,
-                  const std::vector& paddings, PoolProcess pool_compute,
-                  bool exclusive, bool adaptive, framework::Tensor* input_grad);
+                  const std::vector& paddings, bool exclusive,
+                  bool adaptive, framework::Tensor* input_grad,
+                  PoolProcess pool_compute);
   // overload operator() to support argument data_format
   void operator()(const DeviceContext& context, const framework::Tensor& input,
                   const framework::Tensor& output,
@@ -204,8 +208,8 @@ class Pool3dGradFunctor {
                   const std::vector& ksize,
                   const std::vector& strides,
                   const std::vector& paddings,
-                  const std::string data_format, PoolProcess pool_compute,
-                  bool exclusive, bool adaptive, framework::Tensor* input_grad);
+                  const std::string data_format, bool exclusive, bool adaptive,
+                  framework::Tensor* input_grad, PoolProcess pool_compute);
 };
 template
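Note: the header guard around the direct GPU functor widens from CUDA-only to CUDA-or-HIP so the declaration is visible to both GPU backends while staying hidden from CPU-only builds. A minimal sketch of the same guard pattern, with a hypothetical declaration as the guarded body:

    // Sketch: expose a GPU-only declaration to both CUDA and ROCm builds.
    #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
    template <typename PoolProcess, typename T>
    class Pool2dDirectCUDAFunctor;  // declaration compiled for either backend
    #endif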
diff --git a/paddle/fluid/operators/pool_cudnn_op.cu.cc b/paddle/fluid/operators/pool_cudnn_op.cu.cc
index 3dc184facc7..8ceb22d8cc4 100644
--- a/paddle/fluid/operators/pool_cudnn_op.cu.cc
+++ b/paddle/fluid/operators/pool_cudnn_op.cu.cc
@@ -16,7 +16,12 @@ limitations under the License. */
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/operators/math/math_function.h"
 #include "paddle/fluid/operators/pool_op.h"
+#ifdef PADDLE_WITH_CUDA
 #include "paddle/fluid/platform/cudnn_helper.h"
+#endif
+#ifdef PADDLE_WITH_HIP
+#include "paddle/fluid/platform/miopen_helper.h"
+#endif
 namespace paddle {
 namespace operators {
@@ -122,7 +127,32 @@ class PoolCUDNNOpKernel : public framework::OpKernel {
       out_dims_vec[3] = output->dims()[2];
       out_dims_vec[4] = output->dims()[3];
       transformed_output.Resize(framework::make_ddim(out_dims_vec));
+#ifdef PADDLE_WITH_HIP
+      // MIOPEN does not support the NHWC data layout
+    } else if (data_format == str_NHWC) {
+      layout = DataLayout::kNCHW;
+      auto &dev_ctx =
+          ctx.template device_context();
+      std::vector axis{0, 3, 1, 2};
+
+      transformed_input.Resize(input->dims());
+      auto in_dims_vec = framework::vectorize(input->dims());
+      in_dims_vec[1] = input->dims()[3];
+      in_dims_vec[2] = input->dims()[1];
+      in_dims_vec[3] = input->dims()[2];
+      transformed_input.Resize(framework::make_ddim(in_dims_vec));
+      transformed_input.mutable_data(ctx.GetPlace(), input->type());
+      math::Transpose trans;
+      trans(dev_ctx, *input, &transformed_input, axis);
+
+      transformed_output.Resize(output->dims());
+      auto out_dims_vec = framework::vectorize(output->dims());
+      out_dims_vec[1] = output->dims()[3];
+      out_dims_vec[2] = output->dims()[1];
+      out_dims_vec[3] = output->dims()[2];
+      transformed_output.Resize(framework::make_ddim(out_dims_vec));
+#endif
     } else {
       layout = getLayoutFromStr(data_format);
       transformed_input = *input;
@@ -138,11 +168,17 @@ class PoolCUDNNOpKernel : public framework::OpKernel {
     ScopedTensorDescriptor output_desc;
     ScopedPoolingDescriptor pool_desc;
+#ifdef PADDLE_WITH_HIP
+    miopenTensorDescriptor_t cudnn_input_desc = input_desc.descriptor(
+        layout, framework::vectorize(transformed_input.dims()));
+    miopenTensorDescriptor_t cudnn_output_desc = output_desc.descriptor(
+        layout, framework::vectorize(transformed_output.dims()));
+#else
     cudnnTensorDescriptor_t cudnn_input_desc = input_desc.descriptor(
         layout, framework::vectorize(transformed_input.dims()));
     cudnnTensorDescriptor_t cudnn_output_desc = output_desc.descriptor(
         layout, framework::vectorize(transformed_output.dims()));
-
+#endif
     PoolingMode pooling_mode;
     if (pooling_type == "max") {
       pooling_mode = PoolingMode::kMaximum;
@@ -151,17 +187,36 @@ class PoolCUDNNOpKernel : public framework::OpKernel {
                          : PoolingMode::kAverageInclusive;
     }
+#ifdef PADDLE_WITH_HIP
+    miopenPoolingDescriptor_t cudnn_pool_desc =
+        pool_desc.descriptor(pooling_mode, ksize, paddings, strides);
+#else
     cudnnPoolingDescriptor_t cudnn_pool_desc =
         pool_desc.descriptor(pooling_mode, ksize, paddings, strides);
+#endif
     // ------------------- cudnn pool algorithm ---------------------
     auto handle = ctx.cuda_device_context().cudnn_handle();
     ScalingParamType alpha = 1.0f, beta = 0.0f;
+#ifdef PADDLE_WITH_HIP
+    char *pool_workspace;
+    size_t pool_worksize = 0;
+    PADDLE_ENFORCE_CUDA_SUCCESS(
+        platform::dynload::miopenPoolingGetWorkSpaceSizeV2(
+            cudnn_pool_desc, cudnn_output_desc, &pool_worksize));
+    PADDLE_ENFORCE_CUDA_SUCCESS(hipMalloc(&pool_workspace, pool_worksize));
+    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::miopenPoolingForward(
+        handle, cudnn_pool_desc, &alpha, cudnn_input_desc,
+        tranformed_input_data, &beta, cudnn_output_desc, tranformed_output_data,
+        false, pool_workspace, pool_worksize));
+    PADDLE_ENFORCE_CUDA_SUCCESS(hipFree(pool_workspace));
+#else
     PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnPoolingForward(
         handle, cudnn_pool_desc, &alpha, cudnn_input_desc,
         tranformed_input_data, &beta, cudnn_output_desc,
         tranformed_output_data));
+#endif
     // add
     if (data_format == str_NDHWC) {
       auto &dev_ctx =
@@ -170,6 +225,16 @@ class PoolCUDNNOpKernel : public framework::OpKernel {
       math::Transpose trans5_v2;
       trans5_v2(dev_ctx, transformed_output, output, axis);
     }
+#ifdef PADDLE_WITH_HIP
+    // MIOPEN does not support the NHWC data layout
+    if (data_format == str_NHWC) {
+      auto &dev_ctx =
+          ctx.template device_context();
+      std::vector axis{0, 2, 3, 1};
+      math::Transpose trans;
+      trans(dev_ctx, transformed_output, output, axis);
+    }
+#endif
   }
 };
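Note: unlike cudnnPoolingForward, the MIOpen pooling call used above needs an explicit workspace, which the kernel sizes with miopenPoolingGetWorkSpaceSizeV2 and allocates with hipMalloc around each invocation. A rough sketch of that pattern, with error handling elided and the descriptor/data names standing in for the ones created above:

    #ifdef PADDLE_WITH_HIP
    // Sketch: workspace-managed MIOpen pooling (names are placeholders).
    size_t workspace_bytes = 0;
    platform::dynload::miopenPoolingGetWorkSpaceSizeV2(pool_desc, out_desc,
                                                       &workspace_bytes);
    char *workspace = nullptr;
    hipMalloc(&workspace, workspace_bytes);
    platform::dynload::miopenPoolingForward(handle, pool_desc, &alpha, in_desc,
                                            in_data, &beta, out_desc, out_data,
                                            /*do_backward=*/false, workspace,
                                            workspace_bytes);
    hipFree(workspace);
    #endif

Allocating with hipMalloc on every call is simple but synchronous; reusing the framework's existing workspace allocator would be a possible refinement (an assumption, not something this patch does).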
@@ -272,6 +337,49 @@ class PoolCUDNNGradOpKernel : public framework::OpKernel {
       // input grad
       transformed_input_grad.Resize(framework::make_ddim(in_dims_vec));
+#ifdef PADDLE_WITH_HIP
+      // MIOPEN does not support the NHWC data layout
+    } else if (data_format == str_NHWC) {
+      layout = DataLayout::kNCHW;
+      auto &dev_ctx =
+          ctx.template device_context();
+      std::vector axis{0, 3, 1, 2};
+
+      // input
+      transformed_input.Resize(input->dims());
+      auto in_dims_vec = framework::vectorize(input->dims());
+      in_dims_vec[1] = input->dims()[3];
+      in_dims_vec[2] = input->dims()[1];
+      in_dims_vec[3] = input->dims()[2];
+      transformed_input.Resize(framework::make_ddim(in_dims_vec));
+      transformed_input.mutable_data(ctx.GetPlace(), input->type());
+
+      math::Transpose trans4;
+      trans4(dev_ctx, *input, &transformed_input, axis);
+
+      // output
+      transformed_output.Resize(output->dims());
+      auto out_dims_vec = framework::vectorize(output->dims());
+      out_dims_vec[1] = output->dims()[3];
+      out_dims_vec[2] = output->dims()[1];
+      out_dims_vec[3] = output->dims()[2];
+      transformed_output.Resize(framework::make_ddim(out_dims_vec));
+
+      transformed_output.mutable_data(ctx.GetPlace(), output->type());
+
+      math::Transpose trans4_v2;
+      trans4_v2(dev_ctx, *output, &transformed_output, axis);
+
+      // output grad
+      transformed_output_grad.Resize(framework::make_ddim(out_dims_vec));
+      transformed_output_grad.mutable_data(ctx.GetPlace(), output_grad->type());
+
+      math::Transpose trans4_v3;
+      trans4_v3(dev_ctx, *output_grad, &transformed_output_grad, axis);
+
+      // input grad
+      transformed_input_grad.Resize(framework::make_ddim(in_dims_vec));
+#endif
     } else {
       layout = getLayoutFromStr(data_format);
       transformed_input = *input;
@@ -289,11 +397,17 @@ class PoolCUDNNGradOpKernel : public framework::OpKernel {
     ScopedTensorDescriptor output_desc;
     ScopedPoolingDescriptor pool_desc;
+#ifdef PADDLE_WITH_HIP
+    miopenTensorDescriptor_t cudnn_input_desc = input_desc.descriptor(
+        layout, framework::vectorize(transformed_input.dims()));
+    miopenTensorDescriptor_t cudnn_output_desc = output_desc.descriptor(
+        layout, framework::vectorize(transformed_output.dims()));
+#else
     cudnnTensorDescriptor_t cudnn_input_desc = input_desc.descriptor(
         layout, framework::vectorize(transformed_input.dims()));
     cudnnTensorDescriptor_t cudnn_output_desc = output_desc.descriptor(
         layout, framework::vectorize(transformed_output.dims()));
-
+#endif
     PoolingMode pooling_mode;
     if (pooling_type == "max") {
       if (FLAGS_cudnn_deterministic) {
@@ -306,8 +420,13 @@ class PoolCUDNNGradOpKernel : public framework::OpKernel {
                          : PoolingMode::kAverageInclusive;
     }
+#ifdef PADDLE_WITH_HIP
+    miopenPoolingDescriptor_t cudnn_pool_desc =
+        pool_desc.descriptor(pooling_mode, ksize, paddings, strides);
+#else
     cudnnPoolingDescriptor_t cudnn_pool_desc =
         pool_desc.descriptor(pooling_mode, ksize, paddings, strides);
+#endif
     // ------------------- cudnn pool algorithm ---------------------
     auto handle = ctx.cuda_device_context().cudnn_handle();
@@ -315,11 +434,25 @@ class PoolCUDNNGradOpKernel : public framework::OpKernel {
     if (input_grad) {
       T *input_grad_data = transformed_input_grad.mutable_data(
           transformed_input_grad.dims(), ctx.GetPlace());
-      // Because beta is zero, it is unnecessary to reset input_grad.
+// Because beta is zero, it is unnecessary to reset input_grad.
+#ifdef PADDLE_WITH_HIP
+      char *pool_workspace;
+      size_t pool_worksize = 0;
+      PADDLE_ENFORCE_CUDA_SUCCESS(
+          platform::dynload::miopenPoolingGetWorkSpaceSizeV2(
+              cudnn_pool_desc, cudnn_output_desc, &pool_worksize));
+      PADDLE_ENFORCE_CUDA_SUCCESS(hipMalloc(&pool_workspace, pool_worksize));
+      PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::miopenPoolingBackward(
+          handle, cudnn_pool_desc, &alpha, cudnn_output_desc, output_data,
+          cudnn_output_desc, output_grad_data, cudnn_input_desc, input_data,
+          &beta, cudnn_input_desc, input_grad_data, pool_workspace));
+      PADDLE_ENFORCE_CUDA_SUCCESS(hipFree(pool_workspace));
+#else
       PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnPoolingBackward(
           handle, cudnn_pool_desc, &alpha, cudnn_output_desc, output_data,
           cudnn_output_desc, output_grad_data, cudnn_input_desc, input_data,
          &beta, cudnn_input_desc, input_grad_data));
+#endif
       if (data_format == str_NDHWC) {
         auto &dev_ctx =
@@ -328,6 +461,16 @@ class PoolCUDNNGradOpKernel : public framework::OpKernel {
         math::Transpose trans5_v4;
         trans5_v4(dev_ctx, transformed_input_grad, input_grad, axis);
       }
+#ifdef PADDLE_WITH_HIP
+      // MIOPEN does not support the NHWC data layout
+      if (data_format == str_NHWC) {
+        auto &dev_ctx =
+            ctx.template device_context();
+        std::vector axis{0, 2, 3, 1};
+        math::Transpose trans4_v4;
+        trans4_v4(dev_ctx, transformed_input_grad, input_grad, axis);
+      }
+#endif
     }
   }
 };
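Note: because MIOpen only handles NCHW here, both the forward and the gradient kernel fall back to transposing NHWC tensors into NCHW before the MIOpen call and transposing the result back afterwards. A condensed sketch of that round trip; the tensor names are placeholders for the transformed_* tensors built above:

    #ifdef PADDLE_WITH_HIP
    // Sketch: NHWC -> NCHW before MIOpen, NCHW -> NHWC after.
    std::vector<int> to_nchw{0, 3, 1, 2};
    std::vector<int> to_nhwc{0, 2, 3, 1};
    math::Transpose<platform::CUDADeviceContext, T, 4> trans;
    trans(dev_ctx, nhwc_input, &nchw_input, to_nchw);    // before pooling
    // ... run the MIOpen pooling (or pooling-grad) call on the NCHW tensors ...
    trans(dev_ctx, nchw_output, &nhwc_output, to_nhwc);  // back to NHWC
    #endif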
@@ -338,6 +481,21 @@ class PoolCUDNNGradOpKernel : public framework::OpKernel {
 namespace ops = paddle::operators;
 namespace plat = paddle::platform;
+#ifdef PADDLE_WITH_HIP
+// MIOPEN does not support double
+REGISTER_OP_KERNEL(pool2d, CUDNN, plat::CUDAPlace,
+                   ops::PoolCUDNNOpKernel,
+                   ops::PoolCUDNNOpKernel);
+REGISTER_OP_KERNEL(pool2d_grad, CUDNN, plat::CUDAPlace,
+                   ops::PoolCUDNNGradOpKernel,
+                   ops::PoolCUDNNGradOpKernel);
+
+REGISTER_OP_KERNEL(pool3d, CUDNN, plat::CUDAPlace,
+                   ops::PoolCUDNNOpKernel,
+                   ops::PoolCUDNNOpKernel);
+REGISTER_OP_KERNEL(pool3d_grad, CUDNN, plat::CUDAPlace,
+                   ops::PoolCUDNNGradOpKernel);
+#else
 REGISTER_OP_KERNEL(pool2d, CUDNN, plat::CUDAPlace,
                    ops::PoolCUDNNOpKernel,
                    ops::PoolCUDNNOpKernel,
@@ -354,3 +512,4 @@ REGISTER_OP_KERNEL(pool3d, CUDNN, plat::CUDAPlace,
 REGISTER_OP_KERNEL(pool3d_grad, CUDNN, plat::CUDAPlace,
                    ops::PoolCUDNNGradOpKernel,
                    ops::PoolCUDNNGradOpKernel);
+#endif
diff --git a/paddle/fluid/operators/pool_op.cc b/paddle/fluid/operators/pool_op.cc
index 2d4ef64cc89..feb47a73ee4 100644
--- a/paddle/fluid/operators/pool_op.cc
+++ b/paddle/fluid/operators/pool_op.cc
@@ -18,6 +18,9 @@ limitations under the License. */
 #ifdef PADDLE_WITH_CUDA
 #include "paddle/fluid/platform/cudnn_helper.h"
 #endif
+#ifdef PADDLE_WITH_HIP
+#include "paddle/fluid/platform/miopen_helper.h"
+#endif
 #ifdef PADDLE_WITH_MKLDNN
 #include "paddle/fluid/platform/mkldnn_helper.h"
 #endif
@@ -180,7 +183,7 @@ framework::OpKernelType PoolOp::GetExpectedKernelType(
   framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
   auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
-#ifdef PADDLE_WITH_CUDA
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
   if (platform::CanCUDNNBeUsed(ctx)) {
     library_ = framework::LibraryType::kCUDNN;
   }
@@ -235,7 +238,7 @@ framework::OpKernelType PoolOpGrad::GetExpectedKernelType(
   framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
   auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
-#ifdef PADDLE_WITH_CUDA
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
   if (platform::CanCUDNNBeUsed(ctx)) {
     library_ = framework::LibraryType::kCUDNN;
   }
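Note: since MIOpen has no double-precision pooling, the HIP branch above registers only float and float16 kernels, while the CUDA branch keeps its double registrations. A trimmed sketch of the two registration sets; the angle-bracketed template arguments are reconstructed assumptions, since they are not visible in this copy of the patch:

    #ifdef PADDLE_WITH_HIP
    // MIOpen build: float and float16 only.
    REGISTER_OP_KERNEL(pool2d, CUDNN, plat::CUDAPlace,
                       ops::PoolCUDNNOpKernel<float>,
                       ops::PoolCUDNNOpKernel<plat::float16>);
    #else
    // CUDA build: double stays available.
    REGISTER_OP_KERNEL(pool2d, CUDNN, plat::CUDAPlace,
                       ops::PoolCUDNNOpKernel<float>,
                       ops::PoolCUDNNOpKernel<double>,
                       ops::PoolCUDNNOpKernel<plat::float16>);
    #endif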
diff --git a/paddle/fluid/operators/pool_op.h b/paddle/fluid/operators/pool_op.h
index 6b0dbd2d83a..4bb0e1d582e 100644
--- a/paddle/fluid/operators/pool_op.h
+++ b/paddle/fluid/operators/pool_op.h
@@ -205,7 +205,7 @@ class PoolKernel : public framework::OpKernel {
             pool2d_forward;
         paddle::operators::math::MaxPool pool_process;
         pool2d_forward(dev_ctx, *in_x, ksize, strides, paddings, data_format,
-                       pool_process, true, false, out);
+                       true, false, out, pool_process);
       } else if (pooling_type == "avg") {
         std::vector reduce_dim;
@@ -213,7 +213,12 @@
         if (reduce_num > 0 && adaptive) {  // for adaptive_avg_pool2d && output_size == 1
-#ifdef __NVCC__
+#ifdef __HIPCC__
+          auto stream = dev_ctx.stream();
+          TensorReduce>(
+              *in_x, out, reduce_dim, static_cast(0), hipcub::Sum(),
+              DivideFunctor(reduce_num), stream);
+#elif defined(__NVCC__)
           auto stream = dev_ctx.stream();
           TensorReduce>(
               *in_x, out, reduce_dim, static_cast(0), cub::Sum(),
@@ -224,7 +229,7 @@
               pool2d_forward;
           paddle::operators::math::AvgPool pool_process;
           pool2d_forward(dev_ctx, *in_x, ksize, strides, paddings,
-                         data_format, pool_process, exclusive, adaptive, out);
+                         data_format, exclusive, adaptive, out, pool_process);
 #endif
         } else {  // avgpool_2d or adaptive_avg_pool2d && output_size != 1
           paddle::operators::math::Pool2dFunctor<
@@ -232,7 +237,7 @@
               pool2d_forward;
           paddle::operators::math::AvgPool pool_process;
           pool2d_forward(dev_ctx, *in_x, ksize, strides, paddings,
-                         data_format, pool_process, exclusive, adaptive, out);
+                         data_format, exclusive, adaptive, out, pool_process);
         }
       }
     } break;
@@ -243,7 +248,7 @@
             pool3d_forward;
         paddle::operators::math::MaxPool pool_process;
         pool3d_forward(dev_ctx, *in_x, ksize, strides, paddings, data_format,
-                       pool_process, true, false, out);
+                       true, false, out, pool_process);
       } else if (pooling_type == "avg") {
         paddle::operators::math::Pool3dFunctor<
@@ -251,7 +256,7 @@ class PoolKernel : public framework::OpKernel {
             pool3d_forward;
         paddle::operators::math::AvgPool pool_process;
         pool3d_forward(dev_ctx, *in_x, ksize, strides, paddings, data_format,
-                       pool_process, exclusive, adaptive, out);
+                       exclusive, adaptive, out, pool_process);
       }
     } break;
     default: {
@@ -324,8 +329,8 @@ class PoolGradKernel : public framework::OpKernel {
               pool2d_backward;
           paddle::operators::math::AvgPoolGrad pool_process;
           pool2d_backward(dev_ctx, *in_x, *out, *out_grad, ksize, strides,
-                          paddings, data_format, pool_process, exclusive,
-                          adaptive, in_x_grad);
+                          paddings, data_format, exclusive, adaptive,
+                          in_x_grad, pool_process);
         }
       } break;
       case 3: {
@@ -340,8 +345,8 @@ class PoolGradKernel : public framework::OpKernel {
               pool3d_backward;
           paddle::operators::math::AvgPoolGrad pool_process;
           pool3d_backward(dev_ctx, *in_x, *out, *out_grad, ksize, strides,
-                          paddings, data_format, pool_process, exclusive,
-                          adaptive, in_x_grad);
+                          paddings, data_format, exclusive, adaptive,
+                          in_x_grad, pool_process);
         }
       } break;
       default: {
diff --git a/paddle/fluid/operators/spp_op.h b/paddle/fluid/operators/spp_op.h
index 3c2d51ec911..6f78b885734 100644
--- a/paddle/fluid/operators/spp_op.h
+++ b/paddle/fluid/operators/spp_op.h
@@ -56,14 +56,14 @@ class SppKernel : public framework::OpKernel {
       math::Pool2dFunctor, T> pool_forward;
       math::MaxPool max_process;
       pool_forward(context.template device_context(), *in_x,
-                   kernel_size, strides, paddings, max_process, true, false,
-                   &out_level);
+                   kernel_size, strides, paddings, true, false, &out_level,
+                   max_process);
     } else if (pooling_type == "avg") {
       math::Pool2dFunctor, T> pool_forward;
       math::AvgPool avg_process;
       pool_forward(context.template device_context(), *in_x,
-                   kernel_size, strides, paddings, avg_process, true, false,
-                   &out_level);
+                   kernel_size, strides, paddings, true, false, &out_level,
+                   avg_process);
     }
     // flatten pooling output shape
     int output_flatten_w = in_x->dims()[1] * bins * bins;
@@ -156,7 +156,7 @@ class SppGradKernel : public framework::OpKernel {
       math::AvgPoolGrad avg_process;
       pool_backward(context.template device_context(), *in_x,
                     *&out_level, *&outgrad_level, kernel_size, strides,
-                    paddings, avg_process, true, false, in_x_grad);
+                    paddings, true, false, in_x_grad, avg_process);
     }
   }
 }
--
GitLab