From e312a1ff6ed897743f29a64a61ad03b6275ceed7 Mon Sep 17 00:00:00 2001 From: Qi Li Date: Wed, 3 Mar 2021 15:02:43 +0800 Subject: [PATCH] [ROCM] update fluid operators for rocm (part9), test=develop (#31338) --- paddle/fluid/operators/p_norm_op.cu | 6 + paddle/fluid/operators/prroi_pool_op.cu | 27 +-- paddle/fluid/operators/prroi_pool_op.h | 131 ++++++++----- paddle/fluid/operators/pull_box_sparse_op.h | 6 +- paddle/fluid/operators/random_crop_op.h | 4 +- paddle/fluid/operators/rank_attention.cu.h | 6 +- paddle/fluid/operators/rank_attention_op.cu | 1 - paddle/fluid/operators/reshape_op.cc | 2 +- paddle/fluid/operators/rnn_op.cu.cc | 179 ++++++++++++++++-- paddle/fluid/operators/seed_op.cu | 1 - paddle/fluid/operators/segment_pool_op.h | 8 +- paddle/fluid/operators/select_op_helper.h | 2 +- paddle/fluid/operators/shuffle_batch_op.h | 2 +- .../sigmoid_cross_entropy_with_logits_op.cu | 6 + paddle/fluid/operators/softmax_cudnn_op.cu | 43 ++++- paddle/fluid/operators/softmax_op.cc | 8 +- .../fluid/operators/split_selected_rows_op.h | 2 +- paddle/fluid/operators/strided_memcpy.h | 2 +- paddle/fluid/operators/strided_memcpy_test.cc | 2 +- 19 files changed, 334 insertions(+), 104 deletions(-) diff --git a/paddle/fluid/operators/p_norm_op.cu b/paddle/fluid/operators/p_norm_op.cu index ba0d46f4c73..918f0bb1e49 100644 --- a/paddle/fluid/operators/p_norm_op.cu +++ b/paddle/fluid/operators/p_norm_op.cu @@ -13,7 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. */ #include +#ifdef __NVCC__ #include "cub/cub.cuh" +#endif +#ifdef __HIPCC__ +#include +namespace cub = hipcub; +#endif #include "paddle/fluid/operators/p_norm_op.h" namespace paddle { diff --git a/paddle/fluid/operators/prroi_pool_op.cu b/paddle/fluid/operators/prroi_pool_op.cu index b85352ae650..a21f565dae7 100644 --- a/paddle/fluid/operators/prroi_pool_op.cu +++ b/paddle/fluid/operators/prroi_pool_op.cu @@ -13,7 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/prroi_pool_op.h" -#include "paddle/fluid/platform/cuda_primitives.h" namespace paddle { namespace operators { @@ -29,22 +28,6 @@ static inline int NumBlocks(const int N) { kNumMaximumNumBlocks); } -template -DEVICE void PrRoIPoolingDistributeDiffCUDA(T* diff, const T top_diff, - const int h, const int w, - const int height, const int width, - const T coeff) { - bool overflow = (h < 0) || (w < 0) || (h >= height) || (w >= width); - if (!overflow) { - paddle::platform::CudaAtomicAdd(diff + h * width + w, top_diff * coeff); - } -} - -template -DEVICE void GPUAccumulateRois(T* offset, T data) { - paddle::platform::CudaAtomicAdd(offset, data); -} - template __global__ void GPUPRROIPoolForward( const int nthreads, const T* input_data, const T* input_rois, @@ -170,25 +153,23 @@ __global__ void GPUPRROIPoolBackward( for (int w_iter = s_w; w_iter < e_w; ++w_iter) { for (int h_iter = s_h; h_iter < e_h; ++h_iter) { - PrRoIPoolingMatDistributeDiff( + PrRoIPoolingMatDistributeDiff( offset_input_grad_data, sum_out, h_iter, w_iter, h_iter + 1, w_iter + 1, max(win_start_h, static_cast(h_iter)), max(win_start_w, static_cast(w_iter)), min(win_end_h, static_cast(h_iter) + static_cast(1.0)), min(win_end_w, static_cast(w_iter) + static_cast(1.0)), - height, width, PrRoIPoolingDistributeDiffCUDA); + height, width); } } const T* offset_out_data = out_data + i; const T* offset_in_data = in_data + input_offset; - PrRoIPoolingCoorBackward( + PrRoIPoolingCoorBackward( s_w, e_w, s_h, e_h, width, height, win_start_w, win_start_h, win_end_w, win_end_h, pw, ph, pooled_width, pooled_height, win_size, spatial_scale, offset_in_data, offset_out_data, offset_input_roi_grad_data, - offset_output_grad_data, GPUAccumulateRois, - [](const T x, const T y) { return max(x, y); }, - [](const T x, const T y) { return min(x, y); }); + offset_output_grad_data); } } diff --git a/paddle/fluid/operators/prroi_pool_op.h b/paddle/fluid/operators/prroi_pool_op.h index 11ecff88452..f9e2b78d5d3 100644 --- a/paddle/fluid/operators/prroi_pool_op.h +++ b/paddle/fluid/operators/prroi_pool_op.h @@ -16,6 +16,9 @@ limitations under the License. */ #include #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/math/math_function.h" +#if defined(__NVCC__) || defined(__HIPCC__) +#include "paddle/fluid/platform/cuda_primitives.h" +#endif namespace paddle { namespace operators { @@ -73,6 +76,17 @@ inline HOSTDEVICE T PrRoIPoolingMatCalculation(const T* this_data, return sum_out; } +#if defined(__NVCC__) || defined(__HIPCC__) +template +DEVICE void PrRoIPoolingDistributeDiff(T* diff, const T top_diff, const int h, + const int w, const int height, + const int width, const T coeff) { + bool overflow = (h < 0) || (w < 0) || (h >= height) || (w >= width); + if (!overflow) { + paddle::platform::CudaAtomicAdd(diff + h * width + w, top_diff * coeff); + } +} +#else template inline HOSTDEVICE void PrRoIPoolingDistributeDiff(T* diff, const T top_diff, const int h, const int w, @@ -84,12 +98,15 @@ inline HOSTDEVICE void PrRoIPoolingDistributeDiff(T* diff, const T top_diff, *(diff + h * width + w) += top_diff * coeff; } } +#endif -template -HOSTDEVICE void PrRoIPoolingMatDistributeDiff( - T* diff, const T top_diff, const int s_h, const int s_w, const int e_h, - const int e_w, const T y0, const T x0, const T y1, const T x1, const int h0, - const int w0, Functor functor) { +template +HOSTDEVICE void PrRoIPoolingMatDistributeDiff(T* diff, const T top_diff, + const int s_h, const int s_w, + const int e_h, const int e_w, + const T y0, const T x0, + const T y1, const T x1, + const int h0, const int w0) { T alpha, beta, lim_alpha, lim_beta, tmp; alpha = x0 - static_cast(s_w); @@ -99,14 +116,14 @@ HOSTDEVICE void PrRoIPoolingMatDistributeDiff( tmp = (lim_alpha - 0.5f * lim_alpha * lim_alpha - alpha + 0.5f * alpha * alpha) * (lim_beta - 0.5f * lim_beta * lim_beta - beta + 0.5f * beta * beta); - functor(diff, top_diff, s_h, s_w, h0, w0, tmp); + PrRoIPoolingDistributeDiff(diff, top_diff, s_h, s_w, h0, w0, tmp); alpha = static_cast(e_w) - x1; lim_alpha = static_cast(e_w) - x0; tmp = (lim_alpha - 0.5f * lim_alpha * lim_alpha - alpha + 0.5f * alpha * alpha) * (lim_beta - 0.5f * lim_beta * lim_beta - beta + 0.5f * beta * beta); - functor(diff, top_diff, s_h, e_w, h0, w0, tmp); + PrRoIPoolingDistributeDiff(diff, top_diff, s_h, e_w, h0, w0, tmp); alpha = x0 - static_cast(s_w); beta = static_cast(e_h) - y1; @@ -115,20 +132,47 @@ HOSTDEVICE void PrRoIPoolingMatDistributeDiff( tmp = (lim_alpha - 0.5f * lim_alpha * lim_alpha - alpha + 0.5f * alpha * alpha) * (lim_beta - 0.5f * lim_beta * lim_beta - beta + 0.5f * beta * beta); - functor(diff, top_diff, e_h, s_w, h0, w0, tmp); + PrRoIPoolingDistributeDiff(diff, top_diff, e_h, s_w, h0, w0, tmp); alpha = static_cast(e_w) - x1; lim_alpha = static_cast(e_w) - x0; tmp = (lim_alpha - 0.5f * lim_alpha * lim_alpha - alpha + 0.5f * alpha * alpha) * (lim_beta - 0.5f * lim_beta * lim_beta - beta + 0.5f * beta * beta); - functor(diff, top_diff, e_h, e_w, h0, w0, tmp); + PrRoIPoolingDistributeDiff(diff, top_diff, e_h, e_w, h0, w0, tmp); } +#if defined(__NVCC__) || defined(__HIPCC__) +template +DEVICE void AccumulateRois(T* offset, T data) { + paddle::platform::CudaAtomicAdd(offset, data); +} +#else template -inline HOSTDEVICE void CPUAccumulateRois(T* offset, T data) { +inline HOSTDEVICE void AccumulateRois(T* offset, T data) { *offset += data; } +#endif + +#if defined(__NVCC__) || defined(__HIPCC__) +template +DEVICE T MaxFunctor(const T x, const T y) { + return max(x, y); +} +template +DEVICE T MinFunctor(const T x, const T y) { + return min(x, y); +} +#else +template +inline HOSTDEVICE T MaxFunctor(const T x, const T y) { + return std::max(x, y); +} +template +inline HOSTDEVICE T MinFunctor(const T x, const T y) { + return std::max(x, y); +} +#endif template inline HOSTDEVICE static T PrRoIPoolingGetCoeff(T dh, T dw) { @@ -172,15 +216,13 @@ inline HOSTDEVICE T PrRoIPoolingSingleCoorIntegral(T s, T t, T c1, T c2) { (t - 0.5f * t * t - s + 0.5f * s * s) * c1; } -template +template inline HOSTDEVICE void PrRoIPoolingCoorBackward( int s_w, int e_w, int s_h, int e_h, int width, int height, T win_start_w, T win_start_h, T win_end_w, T win_end_h, int pw, int ph, const int pooled_width, const int pooled_height, T win_size, const float spatial_scale, const T* this_bottom_data, - const T* this_top_data, T* this_data_grad, const T* this_out_grad, - Functor functor, MaxFunctor maxFunctor, MinFunctor minFunctor) { + const T* this_top_data, T* this_data_grad, const T* this_out_grad) { T g_x1_y = 0.f; T g_x2_y = 0.f; T g_x_y1 = 0.f; @@ -188,16 +230,16 @@ inline HOSTDEVICE void PrRoIPoolingCoorBackward( for (int h_iter = s_h; h_iter < e_h; ++h_iter) { g_x1_y += PrRoIPoolingSingleCoorIntegral( - maxFunctor(win_start_h, static_cast(h_iter)) - h_iter, - minFunctor(win_end_h, static_cast(h_iter + 1)) - h_iter, + MaxFunctor(win_start_h, static_cast(h_iter)) - h_iter, + MinFunctor(win_end_h, static_cast(h_iter + 1)) - h_iter, PrRoIPoolingInterpolation(this_bottom_data, h_iter, win_start_w, height, width), PrRoIPoolingInterpolation(this_bottom_data, h_iter + 1, win_start_w, height, width)); g_x2_y += PrRoIPoolingSingleCoorIntegral( - maxFunctor(win_start_h, static_cast(h_iter)) - h_iter, - minFunctor(win_end_h, static_cast(h_iter + 1)) - h_iter, + MaxFunctor(win_start_h, static_cast(h_iter)) - h_iter, + MinFunctor(win_end_h, static_cast(h_iter + 1)) - h_iter, PrRoIPoolingInterpolation(this_bottom_data, h_iter, win_end_w, height, width), PrRoIPoolingInterpolation(this_bottom_data, h_iter + 1, win_end_w, @@ -206,16 +248,16 @@ inline HOSTDEVICE void PrRoIPoolingCoorBackward( for (int w_iter = s_w; w_iter < e_w; ++w_iter) { g_x_y1 += PrRoIPoolingSingleCoorIntegral( - maxFunctor(win_start_w, static_cast(w_iter)) - w_iter, - minFunctor(win_end_w, static_cast(w_iter + 1)) - w_iter, + MaxFunctor(win_start_w, static_cast(w_iter)) - w_iter, + MinFunctor(win_end_w, static_cast(w_iter + 1)) - w_iter, PrRoIPoolingInterpolation(this_bottom_data, win_start_h, w_iter, height, width), PrRoIPoolingInterpolation(this_bottom_data, win_start_h, w_iter + 1, height, width)); g_x_y2 += PrRoIPoolingSingleCoorIntegral( - maxFunctor(win_start_w, static_cast(w_iter)) - w_iter, - minFunctor(win_end_w, static_cast(w_iter + 1)) - w_iter, + MaxFunctor(win_start_w, static_cast(w_iter)) - w_iter, + MinFunctor(win_end_w, static_cast(w_iter + 1)) - w_iter, PrRoIPoolingInterpolation(this_bottom_data, win_end_h, w_iter, height, width), PrRoIPoolingInterpolation(this_bottom_data, win_end_h, w_iter + 1, @@ -232,22 +274,24 @@ inline HOSTDEVICE void PrRoIPoolingCoorBackward( partial_y1 = partial_y1 / win_size * spatial_scale; partial_y2 = partial_y2 / win_size * spatial_scale; - functor(this_data_grad + 0, - (partial_x1 * (1.0 - static_cast(pw) / pooled_width) + - partial_x2 * (1.0 - static_cast(pw + 1) / pooled_width)) * - (*this_out_grad)); - functor(this_data_grad + 1, - (partial_y1 * (1.0 - static_cast(ph) / pooled_height) + - partial_y2 * (1.0 - static_cast(ph + 1) / pooled_height)) * - (*this_out_grad)); - functor(this_data_grad + 2, - (partial_x2 * static_cast(pw + 1) / pooled_width + - partial_x1 * static_cast(pw) / pooled_width) * - (*this_out_grad)); - functor(this_data_grad + 3, - (partial_y2 * static_cast(ph + 1) / pooled_height + - partial_y1 * static_cast(ph) / pooled_height) * - (*this_out_grad)); + AccumulateRois( + this_data_grad + 0, + (partial_x1 * (1.0 - static_cast(pw) / pooled_width) + + partial_x2 * (1.0 - static_cast(pw + 1) / pooled_width)) * + (*this_out_grad)); + AccumulateRois( + this_data_grad + 1, + (partial_y1 * (1.0 - static_cast(ph) / pooled_height) + + partial_y2 * (1.0 - static_cast(ph + 1) / pooled_height)) * + (*this_out_grad)); + AccumulateRois(this_data_grad + 2, + (partial_x2 * static_cast(pw + 1) / pooled_width + + partial_x1 * static_cast(pw) / pooled_width) * + (*this_out_grad)); + AccumulateRois(this_data_grad + 3, + (partial_y2 * static_cast(ph + 1) / pooled_height + + partial_y1 * static_cast(ph) / pooled_height) * + (*this_out_grad)); } template @@ -516,7 +560,7 @@ class CPUPRROIPoolGradOpKernel : public framework::OpKernel { for (int w_iter = s_w; w_iter < e_w; ++w_iter) { for (int h_iter = s_h; h_iter < e_h; ++h_iter) { - PrRoIPoolingMatDistributeDiff( + PrRoIPoolingMatDistributeDiff( offset_input_grad_data, sum_out, h_iter, w_iter, h_iter + 1, w_iter + 1, std::max(win_start_h, static_cast(h_iter)), std::max(win_start_w, static_cast(w_iter)), @@ -524,19 +568,16 @@ class CPUPRROIPoolGradOpKernel : public framework::OpKernel { static_cast(h_iter) + static_cast(1.0)), std::min(win_end_w, static_cast(w_iter) + static_cast(1.0)), - height, width, PrRoIPoolingDistributeDiff); + height, width); } } const T* offset_in_data = in_data + input_offset; - PrRoIPoolingCoorBackward( + PrRoIPoolingCoorBackward( s_w, e_w, s_h, e_h, width, height, win_start_w, win_start_h, win_end_w, win_end_h, pw, ph, pooled_width, pooled_height, win_size, spatial_scale, offset_in_data, offset_out_data, - offset_input_roi_grad_data, offset_output_grad_data, - CPUAccumulateRois, - [](const T x, const T y) { return std::max(x, y); }, - [](const T x, const T y) { return std::min(x, y); }); + offset_input_roi_grad_data, offset_output_grad_data); } } } diff --git a/paddle/fluid/operators/pull_box_sparse_op.h b/paddle/fluid/operators/pull_box_sparse_op.h index 48e42c32324..48903012b59 100644 --- a/paddle/fluid/operators/pull_box_sparse_op.h +++ b/paddle/fluid/operators/pull_box_sparse_op.h @@ -47,7 +47,8 @@ static void PullBoxSparseFunctor(const framework::ExecutionContext &ctx) { box_ptr->PullSparse(ctx.GetPlace(), all_keys, all_values, slot_lengths, hidden_size, 0); #endif -#if (defined PADDLE_WITH_NCCL) && (defined PADDLE_WITH_PSLIB) +#if (defined PADDLE_WITH_NCCL || defined PADDLE_WITH_RCCL) && \ + (defined PADDLE_WITH_PSLIB) auto hidden_size = ctx.Attr("size"); auto gpu_ps_ptr = paddle::framework::PSGPUWrapper::GetInstance(); gpu_ps_ptr->PullSparse(ctx.GetPlace(), 0, all_keys, all_values, slot_lengths, @@ -90,7 +91,8 @@ static void PushBoxSparseFunctor(const framework::ExecutionContext &ctx) { box_ptr->PushSparseGrad(ctx.GetPlace(), all_keys, all_grad_values, slot_lengths, hidden_size, 0, batch_size); #endif -#if (defined PADDLE_WITH_NCCL) && (defined PADDLE_WITH_PSLIB) +#if (defined PADDLE_WITH_NCCL || defined PADDLE_WITH_RCCL) && \ + (defined PADDLE_WITH_PSLIB) auto hidden_size = ctx.Attr("size"); auto gpu_ps_ptr = paddle::framework::PSGPUWrapper::GetInstance(); gpu_ps_ptr->PushSparseGrad(ctx.GetPlace(), 0, all_keys, all_grad_values, diff --git a/paddle/fluid/operators/random_crop_op.h b/paddle/fluid/operators/random_crop_op.h index 62edb298d1a..ee111a0ec7c 100644 --- a/paddle/fluid/operators/random_crop_op.h +++ b/paddle/fluid/operators/random_crop_op.h @@ -18,7 +18,7 @@ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/for_range.h" -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) #include #endif @@ -36,7 +36,7 @@ struct Random { using UniformIntDist = std::uniform_int_distribution; }; -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) template <> struct Random { using Engine = thrust::minstd_rand; diff --git a/paddle/fluid/operators/rank_attention.cu.h b/paddle/fluid/operators/rank_attention.cu.h index 27fe67e73cd..8ec138c8824 100644 --- a/paddle/fluid/operators/rank_attention.cu.h +++ b/paddle/fluid/operators/rank_attention.cu.h @@ -50,7 +50,7 @@ __global__ void expand_input_by_rank_kernel( } template -void expand_rank_attention_input(cudaStream_t stream, const T* input, +void expand_rank_attention_input(gpuStream_t stream, const T* input, int input_row, int input_col, T* output, int output_row, int output_col, const int* rank_offset, int rank_offset_row, @@ -93,7 +93,7 @@ __global__ void expand_rank_attention_param_kernel( } template -void expand_rank_attention_param(cudaStream_t stream, const T* input, +void expand_rank_attention_param(gpuStream_t stream, const T* input, int input_row, int input_col, const int* rank_offset, int rank_offset_row, int rank_offset_col, const T* param, @@ -133,7 +133,7 @@ __global__ void merge_param_gradient_kernel( } template -void merge_rank_attention_param_grad(cudaStream_t stream, T* expanded_grad, +void merge_rank_attention_param_grad(gpuStream_t stream, T* expanded_grad, int expanded_grad_row, int expanded_grad_col, T* param_grad, int param_grad_row, int param_grad_col, diff --git a/paddle/fluid/operators/rank_attention_op.cu b/paddle/fluid/operators/rank_attention_op.cu index 6c242e156a5..aaa4eec7c1b 100644 --- a/paddle/fluid/operators/rank_attention_op.cu +++ b/paddle/fluid/operators/rank_attention_op.cu @@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include #include #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/operators/math/blas.h" diff --git a/paddle/fluid/operators/reshape_op.cc b/paddle/fluid/operators/reshape_op.cc index 0e11771d87c..94efa70e467 100644 --- a/paddle/fluid/operators/reshape_op.cc +++ b/paddle/fluid/operators/reshape_op.cc @@ -654,7 +654,7 @@ REGISTER_OP_CPU_KERNEL_FUNCTOR( ops::ReshapeDoubleGradKernel, paddle::platform::complex128, ops::ReshapeDoubleGradKernel); -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape, float, ops::ReshapeKernel, double, ops::ReshapeKernel, int, ops::ReshapeKernel, uint8_t, ops::ReshapeKernel, int64_t, diff --git a/paddle/fluid/operators/rnn_op.cu.cc b/paddle/fluid/operators/rnn_op.cu.cc index 91d7d0f6783..ccf619a074a 100644 --- a/paddle/fluid/operators/rnn_op.cu.cc +++ b/paddle/fluid/operators/rnn_op.cu.cc @@ -16,7 +16,12 @@ limitations under the License. */ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/utils.h" +#ifdef PADDLE_WITH_CUDA #include "paddle/fluid/platform/cudnn_helper.h" +#endif +#ifdef PADDLE_WITH_HIP +#include "paddle/fluid/platform/miopen_helper.h" +#endif namespace paddle { namespace operators { @@ -28,7 +33,11 @@ class RNNDescriptors { public: RNNDescriptors(int seq_length, int batch_size, int input_size, int hidden_size, int num_layers, float dropout_prob, int seed, +#ifdef PADDLE_WITH_HIP + int weight_numel, miopenRNNMode_t mode, bool is_bidirec, +#else int weight_numel, cudnnRNNMode_t mode, bool is_bidirec, +#endif bool is_test) : seq_length_(seq_length), batch_size_(batch_size), @@ -40,15 +49,23 @@ class RNNDescriptors { weight_numel_(weight_numel), mode_(mode), is_bidirec_(is_bidirec), - is_test_(is_test) {} + is_test_(is_test) { + } template +#ifdef PADDLE_WITH_HIP + void Create(const miopenHandle_t &handle, const platform::Place &place, +#else void Create(const cudnnHandle_t &handle, const platform::Place &place, +#endif const std::vector &sequence_length, size_t *workspace_size, size_t *reserve_size, framework::Tensor *dropout_state) { int numDirections = is_bidirec_ ? 2 : 1; +#ifdef PADDLE_WITH_HIP + miopenDataType_t cudnn_type = platform::CudnnDataType::type; +#else cudnnDataType_t cudnn_type = platform::CudnnDataType::type; - +#endif // ------------------- cudnn x, y descriptors --------------------- std::vector dims_x = {batch_size_, input_size_, 1}; std::vector strides_x = {input_size_, 1, 1}; @@ -59,7 +76,7 @@ class RNNDescriptors { y_descs_.emplace_back(y_desc_.descriptor(dims_y, strides_y)); } -#if CUDNN_VERSION >= 7201 +#if defined(PADDLE_WITH_CUDA) && CUDNN_VERSION >= 7201 if (!sequence_length.empty()) { x_seq_desc_.descriptor(seq_length_, batch_size_, input_size_, true, sequence_length); @@ -82,17 +99,29 @@ class RNNDescriptors { size_t state_size; bool is_initialized = dropout_state->IsInitialized(); if (!is_test_ && !is_initialized) { +#ifdef PADDLE_WITH_HIP + PADDLE_ENFORCE_CUDA_SUCCESS( + platform::dynload::miopenDropoutGetStatesSize(handle, &state_size)); + dropout_state->mutable_data({static_cast(state_size)}, + place); +#else PADDLE_ENFORCE_CUDA_SUCCESS( platform::dynload::cudnnDropoutGetStatesSize(handle, &state_size)); dropout_state->mutable_data({static_cast(state_size)}, place); +#endif } dropout_desc_.descriptor(handle, place, is_initialized, dropout_prob_, is_test_ ? nullptr : dropout_state, seed_, state_size); // ------------------- cudnn rnn descriptors --------------------- -#if CUDNN_VERSION >= 6000 +#ifdef PADDLE_WITH_HIP + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::miopenSetRNNDescriptor( + rnn_desc_.desc(), hidden_size_, num_layers_, miopenRNNlinear, + is_bidirec_ ? miopenRNNbidirection : miopenRNNunidirection, mode_, + miopenRNNNoBias, miopenRNNdefault, cudnn_type)); +#elif CUDNN_VERSION >= 6000 PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSetRNNDescriptor_v6( handle, rnn_desc_.desc(), hidden_size_, num_layers_, dropout_desc_.desc(), CUDNN_LINEAR_INPUT, @@ -106,7 +135,7 @@ class RNNDescriptors { cudnn_type)); #endif -#if CUDNN_VERSION >= 7201 +#if defined(PADDLE_WITH_CUDA) && CUDNN_VERSION >= 7201 if (!sequence_length.empty()) { PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSetRNNPaddingMode( rnn_desc_.desc(), CUDNN_RNN_PADDED_IO_ENABLED)); @@ -115,8 +144,13 @@ class RNNDescriptors { // ------------------- cudnn weights_size --------------------- size_t weights_size_; +#ifdef PADDLE_WITH_HIP + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::miopenGetRNNParamsSize( + handle, rnn_desc_.desc(), x_descs_[0], &weights_size_, cudnn_type)); +#else PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnGetRNNParamsSize( handle, rnn_desc_.desc(), x_descs_[0], &weights_size_, cudnn_type)); +#endif PADDLE_ENFORCE_EQ( weights_size_, sizeof(T) * weight_numel_, platform::errors::InvalidArgument( @@ -126,7 +160,16 @@ class RNNDescriptors { int dim_tmp = weights_size_ / sizeof(T); std::vector dim_w = {dim_tmp, 1, 1}; weight_desc_.descriptor(layout, dim_w); - // ------------------- cudnn workspace, reserve size --------------------- +// ------------------- cudnn workspace, reserve size --------------------- +#ifdef PADDLE_WITH_HIP + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::miopenGetRNNWorkspaceSize( + handle, rnn_desc_.desc(), seq_length_, x_descs_.data(), + workspace_size)); + PADDLE_ENFORCE_CUDA_SUCCESS( + platform::dynload::miopenGetRNNTrainingReserveSize( + handle, rnn_desc_.desc(), seq_length_, x_descs_.data(), + reserve_size)); +#else PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnGetRNNWorkspaceSize( handle, rnn_desc_.desc(), seq_length_, x_descs_.data(), workspace_size)); @@ -134,7 +177,19 @@ class RNNDescriptors { platform::dynload::cudnnGetRNNTrainingReserveSize( handle, rnn_desc_.desc(), seq_length_, x_descs_.data(), reserve_size)); +#endif } +#ifdef PADDLE_WITH_HIP + miopenTensorDescriptor_t *x_descs() { return x_descs_.data(); } + miopenTensorDescriptor_t *y_descs() { return y_descs_.data(); } + miopenTensorDescriptor_t init_h_desc() { return init_h_desc_.desc(); } + miopenTensorDescriptor_t init_c_desc() { return init_c_desc_.desc(); } + miopenTensorDescriptor_t last_h_desc() { return last_h_desc_.desc(); } + miopenTensorDescriptor_t last_c_desc() { return last_c_desc_.desc(); } + miopenRNNDescriptor_t rnn_desc() { return rnn_desc_.desc(); } + miopenDropoutDescriptor_t dropout_desc() { return dropout_desc_.desc(); } + miopenTensorDescriptor_t weight_desc() { return weight_desc_.desc(); } +#else cudnnTensorDescriptor_t *x_descs() { return x_descs_.data(); } cudnnTensorDescriptor_t *y_descs() { return y_descs_.data(); } #if CUDNN_VERSION >= 7201 @@ -148,6 +203,7 @@ class RNNDescriptors { cudnnRNNDescriptor_t rnn_desc() { return rnn_desc_.desc(); } cudnnDropoutDescriptor_t dropout_desc() { return dropout_desc_.desc(); } cudnnFilterDescriptor_t weight_desc() { return weight_desc_.desc(); } +#endif private: int seq_length_; @@ -158,15 +214,24 @@ class RNNDescriptors { float dropout_prob_; int seed_; int weight_numel_; +#ifdef PADDLE_WITH_HIP + miopenRNNMode_t mode_; +#else cudnnRNNMode_t mode_; +#endif bool is_bidirec_; bool is_test_; +#ifdef PADDLE_WITH_HIP + std::vector x_descs_; + std::vector y_descs_; +#else std::vector x_descs_; std::vector y_descs_; +#endif platform::ScopedTensorDescriptor x_desc_; platform::ScopedTensorDescriptor y_desc_; -#if CUDNN_VERSION >= 7201 +#if defined(PADDLE_WITH_CUDA) && CUDNN_VERSION >= 7201 platform::ScopedRNNTensorDescriptor x_seq_desc_; platform::ScopedRNNTensorDescriptor y_seq_desc_; #endif @@ -193,7 +258,7 @@ bool is_continuous(const Type &weight_list) { } template -void weight_to_tensor(const platform::Place &place, cudaStream_t stream, +void weight_to_tensor(const platform::Place &place, gpuStream_t stream, const std::vector &weight_list, Tensor *weight) { auto weight_data = weight->data(); @@ -211,7 +276,7 @@ void weight_to_tensor(const platform::Place &place, cudaStream_t stream, } template -void weight_to_tensor_list(const platform::Place &place, cudaStream_t stream, +void weight_to_tensor_list(const platform::Place &place, gpuStream_t stream, std::vector *weight_grad, const std::vector &weight_input, const Tensor *weight) { @@ -247,6 +312,17 @@ class RNNCudnnKernel : public framework::OpKernel { int hidden_size = ctx.Attr("hidden_size"); int num_layers = ctx.Attr("num_layers"); auto mode = ctx.Attr("mode"); +#ifdef PADDLE_WITH_HIP + miopenRNNMode_t rnn_mode = miopenLSTM; + if (mode == "LSTM") + rnn_mode = miopenLSTM; + else if (mode == "GRU") + rnn_mode = miopenGRU; + else if (mode == "RNN_RELU") + rnn_mode = miopenRNNRELU; + else if (mode == "RNN_TANH") + rnn_mode = miopenRNNTANH; +#else cudnnRNNMode_t rnn_mode = CUDNN_LSTM; if (mode == "LSTM") rnn_mode = CUDNN_LSTM; @@ -256,6 +332,7 @@ class RNNCudnnKernel : public framework::OpKernel { rnn_mode = CUDNN_RNN_RELU; else if (mode == "RNN_TANH") rnn_mode = CUDNN_RNN_TANH; +#endif else PADDLE_THROW(platform::errors::InvalidArgument( "rnn_mode should be LSTM, GRU, RNN_RELU or RNN_TANH, but received: " @@ -285,7 +362,11 @@ class RNNCudnnKernel : public framework::OpKernel { T *out_data = out->mutable_data(ctx.GetPlace()); T *last_h_data = state[0]->mutable_data(ctx.GetPlace()); T *last_c_data = nullptr; +#ifdef PADDLE_WITH_HIP + if (rnn_mode == miopenLSTM) { +#else if (rnn_mode == CUDNN_LSTM) { +#endif init_c_data = pre_state[1]->data(); last_c_data = state[1]->mutable_data(ctx.GetPlace()); } @@ -362,8 +443,17 @@ class RNNCudnnKernel : public framework::OpKernel { &workspace_data_, workspace_size); } else { if (!has_seq_length) { - // for train - // This interface is used when the input/output is unpadded. +// for train +// This interface is used when the input/output is unpadded. +#ifdef PADDLE_WITH_HIP + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::miopenRNNForwardTraining( + handle, rnn.rnn_desc(), seq_length, rnn.x_descs(), x_data, + rnn.init_h_desc(), init_h_data, rnn.init_c_desc(), init_c_data, + rnn.weight_desc(), w_data, rnn.y_descs(), out_data, + rnn.last_h_desc(), last_h_data, rnn.last_c_desc(), last_c_data, + workspace_data_.data(), workspace_size, reserve_data, + reserve_size)); +#else PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnRNNForwardTraining( handle, rnn.rnn_desc(), seq_length, rnn.x_descs(), x_data, rnn.init_h_desc(), init_h_data, rnn.init_c_desc(), init_c_data, @@ -371,8 +461,9 @@ class RNNCudnnKernel : public framework::OpKernel { rnn.last_h_desc(), last_h_data, rnn.last_c_desc(), last_c_data, workspace_data_.data(), workspace_size, reserve_data, reserve_size)); +#endif } else { -#if CUDNN_VERSION >= 7201 +#if defined(PADDLE_WITH_CUDA) && CUDNN_VERSION >= 7201 // for train // This interface is used when the input/output is padded. PADDLE_ENFORCE_CUDA_SUCCESS( @@ -394,23 +485,36 @@ class RNNCudnnKernel : public framework::OpKernel { } } +#ifdef PADDLE_WITH_HIP + void RNNInferece(const bool &has_seq_length, const miopenHandle_t &handle, +#else void RNNInferece(const bool &has_seq_length, const cudnnHandle_t &handle, +#endif const int &seq_length, RNNDescriptors *rnn, const T *x_data, const T *init_h_data, const T *init_c_data, const T *w_data, T *out_data, T *last_h_data, T *last_c_data, framework::Tensor *workspace_data, const size_t &workspace_size) const { if (!has_seq_length) { - // for inference - // This interface is used when the input/output is unpadded. +// for inference +// This interface is used when the input/output is unpadded. +#ifdef PADDLE_WITH_HIP + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::miopenRNNForwardInference( + handle, rnn->rnn_desc(), seq_length, rnn->x_descs(), x_data, + rnn->init_h_desc(), init_h_data, rnn->init_c_desc(), init_c_data, + rnn->weight_desc(), w_data, rnn->y_descs(), out_data, + rnn->last_h_desc(), last_h_data, rnn->last_c_desc(), last_c_data, + workspace_data->data(), workspace_size)); +#else PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnRNNForwardInference( handle, rnn->rnn_desc(), seq_length, rnn->x_descs(), x_data, rnn->init_h_desc(), init_h_data, rnn->init_c_desc(), init_c_data, rnn->weight_desc(), w_data, rnn->y_descs(), out_data, rnn->last_h_desc(), last_h_data, rnn->last_c_desc(), last_c_data, workspace_data->data(), workspace_size)); +#endif } else { -#if CUDNN_VERSION >= 7201 +#if defined(PADDLE_WITH_CUDA) && CUDNN_VERSION >= 7201 // for inference // This interface is used when the input/output is padded. PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnRNNForwardInferenceEx( @@ -457,6 +561,17 @@ class RNNGradCudnnKernel : public framework::OpKernel { int hidden_size = ctx.Attr("hidden_size"); int num_layers = ctx.Attr("num_layers"); auto mode = ctx.Attr("mode"); +#ifdef PADDLE_WITH_HIP + miopenRNNMode_t rnn_mode = miopenLSTM; + if (mode == "LSTM") + rnn_mode = miopenLSTM; + else if (mode == "GRU") + rnn_mode = miopenGRU; + else if (mode == "RNN_RELU") + rnn_mode = miopenRNNRELU; + else if (mode == "RNN_TANH") + rnn_mode = miopenRNNTANH; +#else cudnnRNNMode_t rnn_mode = CUDNN_LSTM; if (mode == "LSTM") rnn_mode = CUDNN_LSTM; @@ -466,6 +581,7 @@ class RNNGradCudnnKernel : public framework::OpKernel { rnn_mode = CUDNN_RNN_RELU; else if (mode == "RNN_TANH") rnn_mode = CUDNN_RNN_TANH; +#endif else PADDLE_THROW(platform::errors::InvalidArgument( "rnn_mode should be LSTM, GRU, RNN_RELU or RNN_TANH, but received: " @@ -532,7 +648,11 @@ class RNNGradCudnnKernel : public framework::OpKernel { ? pre_state_grad[0]->mutable_data(ctx.GetPlace()) : nullptr; T *init_c_grad_data = nullptr; +#ifdef PADDLE_WITH_HIP + if (rnn_mode == miopenLSTM) { +#else if (rnn_mode == CUDNN_LSTM) { +#endif init_c_data = pre_state[1]->data(); // last_c_data = state[1]->data(); last_c_grad_data = state_grad[1]->data(); @@ -579,6 +699,17 @@ class RNNGradCudnnKernel : public framework::OpKernel { if (!has_seq_length) { if (in_grad) { +#ifdef PADDLE_WITH_HIP + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::miopenRNNBackwardData( + handle, rnn.rnn_desc(), seq_length, rnn.y_descs(), out_data, + rnn.y_descs(), out_grad_data, rnn.last_h_desc(), last_h_grad_data, + rnn.last_c_desc(), last_c_grad_data, rnn.weight_desc(), weight_data, + rnn.init_h_desc(), init_h_data, rnn.init_c_desc(), init_c_data, + rnn.x_descs(), in_grad_data, rnn.init_h_desc(), init_h_grad_data, + rnn.init_c_desc(), init_c_grad_data, + workspace_data_.data(), workspace_size, + const_cast(reserve_data), reserve_size)); +#else // This interface is used when the input/output is unpadded. PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnRNNBackwardData( handle, rnn.rnn_desc(), seq_length, rnn.y_descs(), out_data, @@ -589,17 +720,27 @@ class RNNGradCudnnKernel : public framework::OpKernel { rnn.init_c_desc(), init_c_grad_data, workspace_data_.data(), workspace_size, const_cast(reserve_data), reserve_size)); +#endif } if (!weight_grad_list.empty()) { +#ifdef PADDLE_WITH_HIP + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::miopenRNNBackwardWeights( + handle, rnn.rnn_desc(), seq_length, rnn.x_descs(), input->data(), + rnn.init_h_desc(), init_h_data, rnn.y_descs(), out->data(), + rnn.weight_desc(), weight_grad_data, + workspace_data_.data(), workspace_size, + const_cast(reserve_data), reserve_size)); +#else PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnRNNBackwardWeights( handle, rnn.rnn_desc(), seq_length, rnn.x_descs(), input->data(), rnn.init_h_desc(), init_h_data, rnn.y_descs(), out->data(), workspace_data_.data(), workspace_size, rnn.weight_desc(), weight_grad_data, const_cast(reserve_data), reserve_size)); +#endif } } else { -#if CUDNN_VERSION >= 7201 +#if defined(PADDLE_WITH_CUDA) && CUDNN_VERSION >= 7201 // for train // This interface is used when the input/output is padded. if (in_grad) { @@ -638,7 +779,13 @@ class RNNGradCudnnKernel : public framework::OpKernel { } // namespace paddle namespace ops = paddle::operators; +#ifdef PADDLE_WITH_HIP +// MIOPEN do not support double +REGISTER_OP_CUDA_KERNEL(rnn, ops::RNNCudnnKernel); +REGISTER_OP_CUDA_KERNEL(rnn_grad, ops::RNNGradCudnnKernel); +#else REGISTER_OP_CUDA_KERNEL(rnn, ops::RNNCudnnKernel, ops::RNNCudnnKernel); REGISTER_OP_CUDA_KERNEL(rnn_grad, ops::RNNGradCudnnKernel, ops::RNNGradCudnnKernel); +#endif diff --git a/paddle/fluid/operators/seed_op.cu b/paddle/fluid/operators/seed_op.cu index 8070f01e9b5..c84407ba52d 100644 --- a/paddle/fluid/operators/seed_op.cu +++ b/paddle/fluid/operators/seed_op.cu @@ -12,7 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include #include "paddle/fluid/operators/seed_op.h" namespace paddle { diff --git a/paddle/fluid/operators/segment_pool_op.h b/paddle/fluid/operators/segment_pool_op.h index 23b0c31608d..5f9635c8ae1 100644 --- a/paddle/fluid/operators/segment_pool_op.h +++ b/paddle/fluid/operators/segment_pool_op.h @@ -63,7 +63,7 @@ void SegmentKernelLaunchHelper(const framework::ExecutionContext& context) { auto& dev_ctx = context.template device_context(); set_zero(dev_ctx, output, static_cast(0)); } -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) if (!cpu_place) { Tensor length; length.mutable_data(framework::make_ddim({1}), @@ -71,9 +71,15 @@ void SegmentKernelLaunchHelper(const framework::ExecutionContext& context) { IndexT* length_data = length.data(); const IndexT* segment_ids = segment->data(); +#ifdef PADDLE_WITH_HIP + PADDLE_ENFORCE_CUDA_SUCCESS( + hipMemcpy(length_data, segment_ids + num_indices - 1, sizeof(IndexT), + hipMemcpyDeviceToHost)); +#else PADDLE_ENFORCE_CUDA_SUCCESS( cudaMemcpy(length_data, segment_ids + num_indices - 1, sizeof(IndexT), cudaMemcpyDeviceToHost)); +#endif IndexT length_host = length_data[0]; length_host++; diff --git a/paddle/fluid/operators/select_op_helper.h b/paddle/fluid/operators/select_op_helper.h index 5df4f8c4a54..32284302176 100644 --- a/paddle/fluid/operators/select_op_helper.h +++ b/paddle/fluid/operators/select_op_helper.h @@ -37,7 +37,7 @@ inline int GetBranchNumber(const framework::LoDTensor &mask) { } // when platform::is_gpu_place(mask.place()) is ture std::unique_ptr cpu_mask{new framework::LoDTensor()}; -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) framework::TensorCopySync(mask, platform::CPUPlace(), cpu_mask.get()); #else PADDLE_THROW(platform::errors::PreconditionNotMet( diff --git a/paddle/fluid/operators/shuffle_batch_op.h b/paddle/fluid/operators/shuffle_batch_op.h index ac8e3f0538f..f05af3f249c 100644 --- a/paddle/fluid/operators/shuffle_batch_op.h +++ b/paddle/fluid/operators/shuffle_batch_op.h @@ -33,7 +33,7 @@ namespace paddle { namespace operators { using Tensor = framework::Tensor; using LoDTensor = framework::LoDTensor; -#if defined(PADDLE_WITH_CUDA) +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) template using Vector = framework::Vector; #else diff --git a/paddle/fluid/operators/sigmoid_cross_entropy_with_logits_op.cu b/paddle/fluid/operators/sigmoid_cross_entropy_with_logits_op.cu index cdcd51904e8..b9300f1b23b 100644 --- a/paddle/fluid/operators/sigmoid_cross_entropy_with_logits_op.cu +++ b/paddle/fluid/operators/sigmoid_cross_entropy_with_logits_op.cu @@ -11,7 +11,13 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#ifdef __NVCC__ #include "cub/cub.cuh" +#endif +#ifdef __HIPCC__ +#include +namespace cub = hipcub; +#endif #include "paddle/fluid/memory/malloc.h" #include "paddle/fluid/operators/math.h" #include "paddle/fluid/operators/sigmoid_cross_entropy_with_logits_op.h" diff --git a/paddle/fluid/operators/softmax_cudnn_op.cu b/paddle/fluid/operators/softmax_cudnn_op.cu index ac7963dd8ad..b62d71bdbc4 100644 --- a/paddle/fluid/operators/softmax_cudnn_op.cu +++ b/paddle/fluid/operators/softmax_cudnn_op.cu @@ -16,7 +16,11 @@ limitations under the License. */ #include "paddle/fluid/operators/math/math_cuda_utils.h" #include "paddle/fluid/operators/softmax_op.h" #include "paddle/fluid/platform/cuda_device_function.h" +#ifdef PADDLE_WITH_HIP +#include "paddle/fluid/platform/miopen_helper.h" +#else #include "paddle/fluid/platform/cudnn_helper.h" +#endif #include "paddle/fluid/platform/gpu_launch_config.h" namespace paddle { @@ -388,18 +392,30 @@ class SoftmaxCUDNNKernel : public framework::OpKernel { ScopedTensorDescriptor desc; std::vector tensor_dims = {N, dim, D, 1}; DataLayout layout = DataLayout::kNCHW; +#ifdef PADDLE_WITH_HIP + miopenTensorDescriptor_t desc_ = desc.descriptor(layout, tensor_dims); +#else cudnnTensorDescriptor_t desc_ = desc.descriptor(layout, tensor_dims); +#endif auto& dev_ctx = ctx.template device_context(); auto handle = dev_ctx.cudnn_handle(); + +#ifdef PADDLE_WITH_HIP + auto mode = axis == rank - 1 ? MIOPEN_SOFTMAX_MODE_INSTANCE + : MIOPEN_SOFTMAX_MODE_CHANNEL; + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::miopenSoftmaxForward( + handle, platform::CudnnDataType::kOne(), desc_, x->data(), + platform::CudnnDataType::kZero(), desc_, out_data)); +#else auto mode = axis == rank - 1 ? CUDNN_SOFTMAX_MODE_INSTANCE : CUDNN_SOFTMAX_MODE_CHANNEL; - PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSoftmaxForward( handle, CUDNN_SOFTMAX_ACCURATE, mode, platform::CudnnDataType::kOne(), desc_, x->data(), platform::CudnnDataType::kZero(), desc_, out_data)); +#endif } } }; @@ -496,19 +512,32 @@ class SoftmaxGradCUDNNKernel : public framework::OpKernel { ScopedTensorDescriptor desc; std::vector tensor_dims = {N, dim, D, 1}; DataLayout layout = DataLayout::kNCHW; +#ifdef PADDLE_WITH_HIP + miopenTensorDescriptor_t desc_ = desc.descriptor(layout, tensor_dims); +#else cudnnTensorDescriptor_t desc_ = desc.descriptor(layout, tensor_dims); +#endif auto& dev_ctx = ctx.template device_context(); auto handle = dev_ctx.cudnn_handle(); + +#ifdef PADDLE_WITH_HIP + auto mode = axis == rank - 1 ? MIOPEN_SOFTMAX_MODE_INSTANCE + : MIOPEN_SOFTMAX_MODE_CHANNEL; + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::miopenSoftmaxBackward( + handle, platform::CudnnDataType::kOne(), desc_, out->data(), + desc_, dout->data(), platform::CudnnDataType::kZero(), desc_, + dx_data)); +#else auto mode = axis == rank - 1 ? CUDNN_SOFTMAX_MODE_INSTANCE : CUDNN_SOFTMAX_MODE_CHANNEL; - PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSoftmaxBackward( handle, CUDNN_SOFTMAX_ACCURATE, mode, platform::CudnnDataType::kOne(), desc_, out->data(), desc_, dout->data(), platform::CudnnDataType::kZero(), desc_, dx_data)); +#endif } } }; @@ -518,6 +547,15 @@ class SoftmaxGradCUDNNKernel : public framework::OpKernel { namespace ops = paddle::operators; namespace plat = paddle::platform; +#ifdef PADDLE_WITH_HIP +// MIOPEN do not support double +REGISTER_OP_KERNEL(softmax, CUDNN, plat::CUDAPlace, + ops::SoftmaxCUDNNKernel, + ops::SoftmaxCUDNNKernel); +REGISTER_OP_KERNEL(softmax_grad, CUDNN, plat::CUDAPlace, + ops::SoftmaxGradCUDNNKernel, + ops::SoftmaxGradCUDNNKernel); +#else REGISTER_OP_KERNEL(softmax, CUDNN, plat::CUDAPlace, ops::SoftmaxCUDNNKernel, ops::SoftmaxCUDNNKernel, @@ -526,3 +564,4 @@ REGISTER_OP_KERNEL(softmax_grad, CUDNN, plat::CUDAPlace, ops::SoftmaxGradCUDNNKernel, ops::SoftmaxGradCUDNNKernel, ops::SoftmaxGradCUDNNKernel); +#endif diff --git a/paddle/fluid/operators/softmax_op.cc b/paddle/fluid/operators/softmax_op.cc index 64030486eb4..a21ef252c03 100644 --- a/paddle/fluid/operators/softmax_op.cc +++ b/paddle/fluid/operators/softmax_op.cc @@ -22,6 +22,10 @@ limitations under the License. */ #include "paddle/fluid/platform/cudnn_helper.h" #endif +#ifdef PADDLE_WITH_HIP +#include "paddle/fluid/platform/miopen_helper.h" +#endif + #ifdef PADDLE_WITH_MKLDNN #include "paddle/fluid/platform/mkldnn_helper.h" #endif @@ -66,7 +70,7 @@ class SoftmaxOp : public framework::OperatorWithKernel { framework::DataLayout layout_ = framework::StringToDataLayout(data_format); auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) if (platform::CanCUDNNBeUsed(ctx)) { library_ = framework::LibraryType::kCUDNN; } @@ -190,7 +194,7 @@ class SoftmaxOpGrad : public framework::OperatorWithKernel { framework::DataLayout layout_ = framework::StringToDataLayout(data_format); auto input_data_type = OperatorWithKernel::IndicateVarDataType( ctx, framework::GradVarName("Out")); -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) if (platform::CanCUDNNBeUsed(ctx)) { library_ = framework::LibraryType::kCUDNN; } diff --git a/paddle/fluid/operators/split_selected_rows_op.h b/paddle/fluid/operators/split_selected_rows_op.h index 8d88da24c63..281f9fb7e59 100644 --- a/paddle/fluid/operators/split_selected_rows_op.h +++ b/paddle/fluid/operators/split_selected_rows_op.h @@ -82,7 +82,7 @@ class SplitSelectedRowsOpKernel : public framework::OpKernel { platform::CPUPlace(), dst + j * row_numel, platform::CPUPlace(), src + outs_dense_idx[i][j] * row_numel, sizeof(T) * row_numel); } else { -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) auto stream = ctx.cuda_device_context().stream(); memory::Copy(platform::CUDAPlace(), dst + j * row_numel, platform::CUDAPlace(), diff --git a/paddle/fluid/operators/strided_memcpy.h b/paddle/fluid/operators/strided_memcpy.h index 48d6cf8b361..eb15fe016d9 100644 --- a/paddle/fluid/operators/strided_memcpy.h +++ b/paddle/fluid/operators/strided_memcpy.h @@ -98,7 +98,7 @@ inline void StridedNumelCopyWithAxis(const platform::DeviceContext& ctx, memory::Copy(cpu_place, dst + i * dst_after, cpu_place, src + i * src_after, sizeof(T) * size); } else { -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) auto& gpu_place = BOOST_GET_CONST(platform::CUDAPlace, place); auto& cuda_ctx = reinterpret_cast(ctx); diff --git a/paddle/fluid/operators/strided_memcpy_test.cc b/paddle/fluid/operators/strided_memcpy_test.cc index 83480b44d5b..1ab036e8692 100644 --- a/paddle/fluid/operators/strided_memcpy_test.cc +++ b/paddle/fluid/operators/strided_memcpy_test.cc @@ -72,7 +72,7 @@ TEST(StridedMemcpy, CPUConcat) { } } -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) TEST(StridedMemcpy, GPUCrop) { // clang-format off int src[] = { -- GitLab