From 59940cb383bb9db453e470495e05428063e80154 Mon Sep 17 00:00:00 2001 From: Qi Li Date: Tue, 2 Mar 2021 19:40:40 +0800 Subject: [PATCH] [ROCM] update fluid operators for rocm (part8), test=develop (#31309) --- .../operators/grid_sampler_cudnn_op.cu.cc | 5 + paddle/fluid/operators/grid_sampler_op.cc | 7 +- paddle/fluid/operators/group_norm_op.cu | 23 +- paddle/fluid/operators/index_select_op.cu | 16 ++ paddle/fluid/operators/inplace_abn_op.cu | 9 + paddle/fluid/operators/instance_norm_op.cu | 118 +++++++- paddle/fluid/operators/layer_norm_op.cu | 41 ++- paddle/fluid/operators/layer_norm_op.h | 8 +- .../fluid/operators/lod_tensor_to_array_op.cc | 2 +- paddle/fluid/operators/matmul_op.cc | 20 +- paddle/fluid/operators/mean_op.cu | 6 + paddle/fluid/operators/merge_lod_tensor_op.cc | 2 +- paddle/fluid/operators/miopen_lstm_cache.h | 141 +++++++++ paddle/fluid/operators/miopen_rnn_cache.h | 267 ++++++++++++++++++ .../fluid/operators/modified_huber_loss_op.h | 4 +- paddle/fluid/operators/multinomial_op.cu | 21 +- paddle/fluid/operators/nll_loss_op.cu | 9 +- paddle/fluid/operators/norm_op.cu | 6 + paddle/fluid/operators/norm_utils.cu.h | 10 + 19 files changed, 687 insertions(+), 28 deletions(-) create mode 100644 paddle/fluid/operators/miopen_lstm_cache.h create mode 100644 paddle/fluid/operators/miopen_rnn_cache.h diff --git a/paddle/fluid/operators/grid_sampler_cudnn_op.cu.cc b/paddle/fluid/operators/grid_sampler_cudnn_op.cu.cc index f0903bdfce9..d2002b487ca 100644 --- a/paddle/fluid/operators/grid_sampler_cudnn_op.cu.cc +++ b/paddle/fluid/operators/grid_sampler_cudnn_op.cu.cc @@ -12,6 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#ifndef PADDLE_WITH_HIP +// HIP not support cudnnSpatialTfGridGeneratorForward + #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/platform/cudnn_helper.h" @@ -140,3 +143,5 @@ REGISTER_OP_KERNEL(grid_sampler, CUDNN, plat::CUDAPlace, REGISTER_OP_KERNEL(grid_sampler_grad, CUDNN, plat::CUDAPlace, paddle::operators::CUDNNGridSampleGradOpKernel, paddle::operators::CUDNNGridSampleGradOpKernel); + +#endif // PADDLE_WITH_HIP diff --git a/paddle/fluid/operators/grid_sampler_op.cc b/paddle/fluid/operators/grid_sampler_op.cc index e357133be44..a75ea538f25 100644 --- a/paddle/fluid/operators/grid_sampler_op.cc +++ b/paddle/fluid/operators/grid_sampler_op.cc @@ -20,6 +20,9 @@ limitations under the License. */ #ifdef PADDLE_WITH_CUDA #include "paddle/fluid/platform/cudnn_helper.h" #endif +#ifdef PADDLE_WITH_HIP +#include "paddle/fluid/platform/miopen_helper.h" +#endif namespace paddle { namespace operators { @@ -71,7 +74,7 @@ class GridSampleOp : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { framework::LibraryType library_{framework::LibraryType::kPlain}; -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) if (platform::CanCUDNNBeUsed(ctx)) { library_ = framework::LibraryType::kCUDNN; } @@ -191,7 +194,7 @@ class GridSampleOpGrad : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { framework::LibraryType library_{framework::LibraryType::kPlain}; -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) if (platform::CanCUDNNBeUsed(ctx)) { library_ = framework::LibraryType::kCUDNN; } diff --git a/paddle/fluid/operators/group_norm_op.cu b/paddle/fluid/operators/group_norm_op.cu index b7f79be45be..2a550486929 100644 --- a/paddle/fluid/operators/group_norm_op.cu +++ b/paddle/fluid/operators/group_norm_op.cu @@ -12,9 +12,16 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#ifdef __NVCC__ #include "cub/cub.cuh" +#endif +#ifdef __HIPCC__ +#include +#endif + #include "paddle/fluid/operators/group_norm_op.h" #include "paddle/fluid/platform/cuda_device_function.h" +#include "paddle/fluid/platform/cuda_primitives.h" namespace paddle { namespace operators { @@ -39,10 +46,18 @@ enum GroupNormKernelFlags { kHasScale = 1, kHasBias = 2 }; template __device__ __inline__ void CudaAtomicAddWithWarp(T* sum, T value) { +#ifdef PADDLE_WITH_CUDA typedef cub::WarpReduce WarpReduce; +#else + typedef hipcub::WarpReduce WarpReduce; +#endif typename WarpReduce::TempStorage temp_storage; value = WarpReduce(temp_storage).Sum(value); +#ifdef PADDLE_WITH_CUDA if (cub::LaneId() == 0) platform::CudaAtomicAdd(sum, value); +#else + if (hipcub::LaneId() == 0) platform::CudaAtomicAdd(sum, value); +#endif } template @@ -217,10 +232,10 @@ __global__ void GroupNormBackwardGetMeanAndVar( d_bias_data += dval; d_scale_data += val * dval; } - CudaAtomicAddWithWarp(&d_mean[bid * groups + gid], d_mean_data); - CudaAtomicAddWithWarp(&d_var[bid * groups + gid], d_var_data); - if (flags & kHasScale) CudaAtomicAddWithWarp(&d_scale[ccid], d_scale_data); - if (flags & kHasBias) CudaAtomicAddWithWarp(&d_bias[ccid], d_bias_data); + CudaAtomicAddWithWarp(&(d_mean[bid * groups + gid]), d_mean_data); + CudaAtomicAddWithWarp(&(d_var[bid * groups + gid]), d_var_data); + if (flags & kHasScale) CudaAtomicAddWithWarp(&(d_scale[ccid]), d_scale_data); + if (flags & kHasBias) CudaAtomicAddWithWarp(&(d_bias[ccid]), d_bias_data); } template diff --git a/paddle/fluid/operators/index_select_op.cu b/paddle/fluid/operators/index_select_op.cu index 752e8b277da..43761d97962 100644 --- a/paddle/fluid/operators/index_select_op.cu +++ b/paddle/fluid/operators/index_select_op.cu @@ -106,14 +106,22 @@ class IndexSelectCUDAKernel : public framework::OpKernel { (numel + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS, PADDLE_CUDA_NUM_THREADS, 0, stream>>>(in_data, out_data, index_data, numel, stride, size, delta); +#ifdef PADDLE_WITH_HIP + PADDLE_ENFORCE_CUDA_SUCCESS(hipStreamSynchronize(stream)); +#else PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamSynchronize(stream)); +#endif } else { const int* index_data = index->data(); index_select_cuda_kernel<<<(numel + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS, PADDLE_CUDA_NUM_THREADS, 0, stream>>>( in_data, out_data, index_data, numel, stride, size, delta); +#ifdef PADDLE_WITH_HIP + PADDLE_ENFORCE_CUDA_SUCCESS(hipStreamSynchronize(stream)); +#else PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamSynchronize(stream)); +#endif } } }; @@ -164,7 +172,11 @@ class IndexSelectGradCUDAKernel : public framework::OpKernel { PADDLE_CUDA_NUM_THREADS, 0, stream>>>(output_grad_data, in_grad_data, index_data, index_nums, numel, stride, size, delta); +#ifdef PADDLE_WITH_HIP + PADDLE_ENFORCE_CUDA_SUCCESS(hipStreamSynchronize(stream)); +#else PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamSynchronize(stream)); +#endif } else { const int* index_data = index->data(); index_select_grad_cuda_kernel<<< @@ -172,7 +184,11 @@ class IndexSelectGradCUDAKernel : public framework::OpKernel { PADDLE_CUDA_NUM_THREADS, 0, stream>>>(output_grad_data, in_grad_data, index_data, index_nums, numel, stride, size, delta); +#ifdef PADDLE_WITH_HIP + PADDLE_ENFORCE_CUDA_SUCCESS(hipStreamSynchronize(stream)); +#else PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamSynchronize(stream)); +#endif } } }; diff --git a/paddle/fluid/operators/inplace_abn_op.cu b/paddle/fluid/operators/inplace_abn_op.cu index 9e12a8291c0..be7a7bd7171 100644 --- a/paddle/fluid/operators/inplace_abn_op.cu +++ b/paddle/fluid/operators/inplace_abn_op.cu @@ -84,9 +84,18 @@ class InplaceABNGradKernel namespace ops = paddle::operators; namespace plat = paddle::platform; +#ifdef PADDLE_WITH_HIP +// MIOPEN do not support double +REGISTER_OP_CUDA_KERNEL(inplace_abn, + ops::InplaceABNKernel); +REGISTER_OP_CUDA_KERNEL( + inplace_abn_grad, + ops::InplaceABNGradKernel); +#else REGISTER_OP_CUDA_KERNEL(inplace_abn, ops::InplaceABNKernel, ops::InplaceABNKernel); REGISTER_OP_CUDA_KERNEL( inplace_abn_grad, ops::InplaceABNGradKernel, ops::InplaceABNGradKernel); +#endif diff --git a/paddle/fluid/operators/instance_norm_op.cu b/paddle/fluid/operators/instance_norm_op.cu index 51313835eba..affd0b7e1ed 100644 --- a/paddle/fluid/operators/instance_norm_op.cu +++ b/paddle/fluid/operators/instance_norm_op.cu @@ -16,11 +16,22 @@ limitations under the License. */ #include #include #include +#ifdef __NVCC__ #include "cub/cub.cuh" +#endif +#ifdef __HIPCC__ +#include +namespace cub = hipcub; +#endif #include "paddle/fluid/framework/data_layout.h" #include "paddle/fluid/operators/instance_norm_op.h" #include "paddle/fluid/operators/math/math_function.h" +#ifdef PADDLE_WITH_CUDA #include "paddle/fluid/platform/cudnn_helper.h" +#endif +#ifdef PADDLE_WITH_HIP +#include "paddle/fluid/platform/miopen_helper.h" +#endif namespace paddle { namespace operators { @@ -99,6 +110,15 @@ class InstanceNormKernel auto *y = ctx.Output("Y"); y->mutable_data(ctx.GetPlace()); +#ifdef PADDLE_WITH_HIP + miopenTensorDescriptor_t data_desc_; + miopenTensorDescriptor_t in_param_desc_; + + PADDLE_ENFORCE_CUDA_SUCCESS( + platform::dynload::miopenCreateTensorDescriptor(&data_desc_)); + PADDLE_ENFORCE_CUDA_SUCCESS( + platform::dynload::miopenCreateTensorDescriptor(&in_param_desc_)); +#else cudnnTensorDescriptor_t data_desc_; cudnnTensorDescriptor_t in_param_desc_; @@ -106,7 +126,7 @@ class InstanceNormKernel platform::dynload::cudnnCreateTensorDescriptor(&data_desc_)); PADDLE_ENFORCE_CUDA_SUCCESS( platform::dynload::cudnnCreateTensorDescriptor(&in_param_desc_)); - +#endif if (epsilon <= CUDNN_BN_MIN_EPSILON - FLT_EPSILON) { LOG(ERROR) << "Provided epsilon is smaller than " << "CUDNN_BN_MIN_EPSILON. Setting it to " @@ -122,12 +142,22 @@ class InstanceNormKernel auto &dev_ctx = ctx.template device_context(); +#ifdef PADDLE_WITH_HIP + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::miopenSetTensorDescriptor( + data_desc_, CudnnDataType::type, + x_dims.size() > 3 ? x_dims.size() : 4, const_cast(dims.data()), + const_cast(strides.data()))); + PADDLE_ENFORCE_CUDA_SUCCESS( + platform::dynload::miopenDeriveBNTensorDescriptor( + in_param_desc_, data_desc_, miopenBNSpatial)); +#else PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSetTensorNdDescriptor( data_desc_, CudnnDataType::type, x_dims.size() > 3 ? x_dims.size() : 4, dims.data(), strides.data())); PADDLE_ENFORCE_CUDA_SUCCESS( platform::dynload::cudnnDeriveBNTensorDescriptor( in_param_desc_, data_desc_, CUDNN_BATCHNORM_SPATIAL)); +#endif const auto *scale = ctx.Input("Scale"); const auto *bias = ctx.Input("Bias"); @@ -171,6 +201,35 @@ class InstanceNormKernel functor(dev_ctx, saved_mean, static_cast>(0)); functor(dev_ctx, saved_variance, static_cast>(0)); +#ifdef PADDLE_WITH_HIP + PADDLE_ENFORCE_CUDA_SUCCESS( + platform::dynload::miopenBatchNormalizationForwardTraining( + handle, miopenBNSpatial, + const_cast( + static_cast(CudnnDataType::kOne())), + const_cast( + static_cast(CudnnDataType::kZero())), + data_desc_, static_cast(x_tmp.template data()), + data_desc_, + static_cast(y->template mutable_data(ctx.GetPlace())), + in_param_desc_, + const_cast(static_cast( + scale_tmp.template data>())), + const_cast(static_cast( + bias_tmp.template data>())), + 0, nullptr, nullptr, epsilon, + static_cast( + saved_mean->template mutable_data>( + ctx.GetPlace())), + static_cast( + saved_variance->template mutable_data>( + ctx.GetPlace())))); + + PADDLE_ENFORCE_CUDA_SUCCESS( + platform::dynload::miopenDestroyTensorDescriptor(data_desc_)); + PADDLE_ENFORCE_CUDA_SUCCESS( + platform::dynload::miopenDestroyTensorDescriptor(in_param_desc_)); +#else PADDLE_ENFORCE_CUDA_SUCCESS( platform::dynload::cudnnBatchNormalizationForwardTraining( handle, CUDNN_BATCHNORM_SPATIAL, CudnnDataType::kOne(), @@ -188,6 +247,7 @@ class InstanceNormKernel platform::dynload::cudnnDestroyTensorDescriptor(data_desc_)); PADDLE_ENFORCE_CUDA_SUCCESS( platform::dynload::cudnnDestroyTensorDescriptor(in_param_desc_)); +#endif } }; @@ -332,6 +392,15 @@ class InstanceNormGradKernel return; } +#ifdef PADDLE_WITH_HIP + miopenTensorDescriptor_t data_desc_; + miopenTensorDescriptor_t in_param_desc_; + + PADDLE_ENFORCE_CUDA_SUCCESS( + platform::dynload::miopenCreateTensorDescriptor(&data_desc_)); + PADDLE_ENFORCE_CUDA_SUCCESS( + platform::dynload::miopenCreateTensorDescriptor(&in_param_desc_)); +#else cudnnTensorDescriptor_t data_desc_; cudnnTensorDescriptor_t in_param_desc_; @@ -339,6 +408,8 @@ class InstanceNormGradKernel platform::dynload::cudnnCreateTensorDescriptor(&data_desc_)); PADDLE_ENFORCE_CUDA_SUCCESS( platform::dynload::cudnnCreateTensorDescriptor(&in_param_desc_)); +#endif + if (epsilon <= CUDNN_BN_MIN_EPSILON - FLT_EPSILON) { LOG(ERROR) << "Provided epsilon is smaller than " << "CUDNN_BN_MIN_EPSILON. Setting it to " @@ -346,12 +417,22 @@ class InstanceNormGradKernel } epsilon = std::max(epsilon, CUDNN_BN_MIN_EPSILON); +#ifdef PADDLE_WITH_HIP + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::miopenSetTensorDescriptor( + data_desc_, CudnnDataType::type, + x_dims.size() > 3 ? x_dims.size() : 4, const_cast(dims.data()), + const_cast(strides.data()))); + PADDLE_ENFORCE_CUDA_SUCCESS( + platform::dynload::miopenDeriveBNTensorDescriptor( + in_param_desc_, data_desc_, miopenBNSpatial)); +#else PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSetTensorNdDescriptor( data_desc_, CudnnDataType::type, x_dims.size() > 3 ? x_dims.size() : 4, dims.data(), strides.data())); PADDLE_ENFORCE_CUDA_SUCCESS( platform::dynload::cudnnDeriveBNTensorDescriptor( in_param_desc_, data_desc_, CUDNN_BATCHNORM_SPATIAL)); +#endif const auto *saved_mean = ctx.Input("SavedMean"); const auto *saved_var = ctx.Input("SavedVariance"); @@ -360,6 +441,21 @@ class InstanceNormGradKernel const auto *saved_var_data = saved_var->template data>(); if (d_scale && d_bias) { +#ifdef PADDLE_WITH_HIP + PADDLE_ENFORCE_CUDA_SUCCESS( + platform::dynload::miopenBatchNormalizationBackward( + dev_ctx.cudnn_handle(), miopenBNSpatial, CudnnDataType::kOne(), + CudnnDataType::kZero(), CudnnDataType::kOne(), + CudnnDataType::kZero(), data_desc_, x_tmp.template data(), + data_desc_, d_y_tmp.template data(), data_desc_, + d_x->template mutable_data(ctx.GetPlace()), in_param_desc_, + scale_tmp.template data>(), + d_scale_tmp.template mutable_data>( + ctx.GetPlace()), + d_bias_tmp.template mutable_data>( + ctx.GetPlace()), + epsilon, saved_mean_data, saved_var_data)); +#else PADDLE_ENFORCE_CUDA_SUCCESS( platform::dynload::cudnnBatchNormalizationBackward( dev_ctx.cudnn_handle(), CUDNN_BATCHNORM_SPATIAL, @@ -373,6 +469,7 @@ class InstanceNormGradKernel d_bias_tmp.template mutable_data>( ctx.GetPlace()), epsilon, saved_mean_data, saved_var_data)); +#endif } else { if (d_x) { GradComputeDX<<>>( @@ -389,10 +486,17 @@ class InstanceNormGradKernel d_bias_tmp.data(), d_bias->data(), N, C); } +#ifdef PADDLE_WITH_HIP + PADDLE_ENFORCE_CUDA_SUCCESS( + platform::dynload::miopenDestroyTensorDescriptor(data_desc_)); + PADDLE_ENFORCE_CUDA_SUCCESS( + platform::dynload::miopenDestroyTensorDescriptor(in_param_desc_)); +#else PADDLE_ENFORCE_CUDA_SUCCESS( platform::dynload::cudnnDestroyTensorDescriptor(data_desc_)); PADDLE_ENFORCE_CUDA_SUCCESS( platform::dynload::cudnnDestroyTensorDescriptor(in_param_desc_)); +#endif } }; @@ -693,6 +797,17 @@ class InstanceNormDoubleGradKernel namespace ops = paddle::operators; namespace plat = paddle::platform; +#ifdef PADDLE_WITH_HIP +// MIOPEN do not support double +REGISTER_OP_CUDA_KERNEL( + instance_norm, ops::InstanceNormKernel); +REGISTER_OP_CUDA_KERNEL( + instance_norm_grad, + ops::InstanceNormGradKernel); +REGISTER_OP_CUDA_KERNEL(instance_norm_grad_grad, + ops::InstanceNormDoubleGradKernel< + paddle::platform::CUDADeviceContext, float>); +#else REGISTER_OP_CUDA_KERNEL( instance_norm, ops::InstanceNormKernel, ops::InstanceNormKernel); @@ -706,3 +821,4 @@ REGISTER_OP_CUDA_KERNEL( float>, ops::InstanceNormDoubleGradKernel); +#endif diff --git a/paddle/fluid/operators/layer_norm_op.cu b/paddle/fluid/operators/layer_norm_op.cu index 6883ba009c5..d0f7dca98af 100644 --- a/paddle/fluid/operators/layer_norm_op.cu +++ b/paddle/fluid/operators/layer_norm_op.cu @@ -12,14 +12,25 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include +#ifdef __NVCC__ +#include "cub/cub.cuh" +#endif +#ifdef __HIPCC__ +#include +namespace cub = hipcub; +#endif #include #include #include "paddle/fluid/framework/ddim.h" #include "paddle/fluid/operators/layer_norm_op.h" -#include "paddle/fluid/platform/cudnn_helper.h" #include "paddle/fluid/platform/float16.h" +#ifdef PADDLE_WITH_CUDA +#include "paddle/fluid/platform/cudnn_helper.h" +#endif +#ifdef PADDLE_WITH_HIP +#include "paddle/fluid/platform/miopen_helper.h" +#endif namespace paddle { namespace operators { @@ -348,7 +359,11 @@ __global__ void LayerNormBackwardComputeGradInput( // epsilon, const T* gamma, const U *__restrict__ mean, const U *__restrict__ var, const float epsilon, const U *gamma, T *grad_input) { +#ifdef __HIPCC__ + for (auto i1 = hipBlockIdx_y; i1 < n1; i1 += hipGridDim_y) { +#else for (auto i1 = blockIdx.y; i1 < n1; i1 += gridDim.y) { +#endif U sum_loss1 = U(0); U sum_loss2 = U(0); const U c_mean = mean[i1]; @@ -392,12 +407,19 @@ __global__ void LayerNormBackwardComputeGradInput( } // intra-warp reductions for (int mask = BDIMX / 2; mask > 0; mask /= 2) { +#ifdef PADDLE_WITH_HIP + sum_loss1 += __shfl_xor(sum_loss1, mask, + warpSize); // WARP_SHFL_XOR(sum_loss1, mask); + sum_loss2 += __shfl_xor(sum_loss2, mask, + warpSize); // WARP_SHFL_XOR(sum_loss2, mask); +#else sum_loss1 += __shfl_xor_sync(0xffffffff, sum_loss1, mask, warpSize); // WARP_SHFL_XOR(sum_loss1, mask); sum_loss2 += __shfl_xor_sync(0xffffffff, sum_loss2, mask, warpSize); // WARP_SHFL_XOR(sum_loss2, mask); +#endif } // inter-warp reductions if (BDIMY > 1) { @@ -821,7 +843,7 @@ static void LayerNormBackward(const T *x, const T *d_y, const U *scale, } template -void LayerNormDirectCUDAFunctor::operator()(cudaStream_t stream, +void LayerNormDirectCUDAFunctor::operator()(gpuStream_t stream, const T *input, std::vector input_shape, const T *bias, const T *scale, @@ -942,6 +964,18 @@ template class LayerNormDirectCUDAFunctor; namespace ops = paddle::operators; namespace plat = paddle::platform; +#ifdef PADDLE_WITH_HIP +// MIOPEN do not support double +REGISTER_OP_CUDA_KERNEL( + layer_norm, + ops::LayerNormKernel, + ops::LayerNormKernel); +REGISTER_OP_CUDA_KERNEL( + layer_norm_grad, + ops::LayerNormGradKernel, + ops::LayerNormGradKernel); +#else REGISTER_OP_CUDA_KERNEL( layer_norm, ops::LayerNormKernel, @@ -953,3 +987,4 @@ REGISTER_OP_CUDA_KERNEL( ops::LayerNormGradKernel, ops::LayerNormGradKernel); +#endif diff --git a/paddle/fluid/operators/layer_norm_op.h b/paddle/fluid/operators/layer_norm_op.h index 931cd6d1794..c9ba37d0008 100644 --- a/paddle/fluid/operators/layer_norm_op.h +++ b/paddle/fluid/operators/layer_norm_op.h @@ -51,7 +51,7 @@ struct RowwiseMean2D { const framework::Tensor& input, framework::Tensor* vec); }; -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) template class RowwiseMean2D { public: @@ -97,7 +97,7 @@ struct ColwiseSum2D { const framework::Tensor& input, framework::Tensor* vec); }; -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) template class ColwiseSum2D { public: @@ -163,11 +163,11 @@ using Tensor = framework::Tensor; using LoDTensor = framework::LoDTensor; using DataLayout = framework::DataLayout; -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) template class LayerNormDirectCUDAFunctor { public: - void operator()(cudaStream_t stream, const T* input, + void operator()(gpuStream_t stream, const T* input, std::vector input_shape, const T* bias, const T* scale, T* output, T* mean, T* variance, int begin_norm_axis, float eps); diff --git a/paddle/fluid/operators/lod_tensor_to_array_op.cc b/paddle/fluid/operators/lod_tensor_to_array_op.cc index cb857e5d906..e02972bd753 100644 --- a/paddle/fluid/operators/lod_tensor_to_array_op.cc +++ b/paddle/fluid/operators/lod_tensor_to_array_op.cc @@ -63,7 +63,7 @@ struct LoDTensorToArrayFunctor : public boost::static_visitor { if (std::is_same::value) { Apply(static_cast(dev_ctx)); } else { -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) Apply(static_cast(dev_ctx)); #else PADDLE_THROW( diff --git a/paddle/fluid/operators/matmul_op.cc b/paddle/fluid/operators/matmul_op.cc index e97565a6623..9b64e99c944 100644 --- a/paddle/fluid/operators/matmul_op.cc +++ b/paddle/fluid/operators/matmul_op.cc @@ -76,7 +76,8 @@ class MatMulKernel : public framework::OpKernel { auto scale = static_cast(context.Attr("alpha")); int head_number = 1; -#if defined(PADDLE_WITH_MKLML) && !defined(PADDLE_WITH_CUDA) +#if defined(PADDLE_WITH_MKLML) && !defined(PADDLE_WITH_CUDA) && \ + !defined(PADDLE_WITH_HIP) head_number = context.Attr("head_number"); #endif @@ -89,7 +90,8 @@ class MatMulKernel : public framework::OpKernel { mat_dim_a.batch_size_ = 0; } } -#if defined(PADDLE_WITH_MKLML) && !defined(PADDLE_WITH_CUDA) +#if defined(PADDLE_WITH_MKLML) && !defined(PADDLE_WITH_CUDA) && \ + !defined(PADDLE_WITH_HIP) bool split_vertical_y = (mat_dim_a.width_ != mat_dim_b.height_); if (head_number > 1) { @@ -228,7 +230,8 @@ class MatMulGradKernel : public framework::OpKernel { auto mat_dim_b = math::CreateMatrixDescriptor(b.dims(), 0, trans_b); int head_number = 1; -#if defined(PADDLE_WITH_MKLML) && !defined(PADDLE_WITH_CUDA) +#if defined(PADDLE_WITH_MKLML) && !defined(PADDLE_WITH_CUDA) && \ + !defined(PADDLE_WITH_HIP) head_number = context.Attr("head_number"); #endif @@ -362,7 +365,8 @@ class MatMulDoubleGradKernel : public framework::OpKernel { auto mat_dim_b = math::CreateMatrixDescriptor(b.dims(), 0, trans_b); int head_number = 1; -#if defined(PADDLE_WITH_MKLML) && !defined(PADDLE_WITH_CUDA) +#if defined(PADDLE_WITH_MKLML) && !defined(PADDLE_WITH_CUDA) && \ + !defined(PADDLE_WITH_HIP) head_number = context.Attr("head_number"); #endif @@ -562,7 +566,8 @@ class MatMulOp : public framework::OperatorWithKernel { DumpMatrixShape(mat_dim_y).c_str())); } int64_t dim_out_y = mat_dim_y.width_; -#if defined(PADDLE_WITH_MKLML) && !defined(PADDLE_WITH_CUDA) +#if defined(PADDLE_WITH_MKLML) && !defined(PADDLE_WITH_CUDA) && \ + !defined(PADDLE_WITH_HIP) int head_number = context->Attrs().Get("head_number"); bool split_vertical_y = (mat_dim_x.width_ != mat_dim_y.height_); if (context->IsRuntime()) { @@ -750,7 +755,8 @@ class MatMulOpMaker : public framework::OpProtoAndCheckerMaker { "used in MKL-DNN INT8") .SetDefault(false); -#if defined(PADDLE_WITH_MKLML) && !defined(PADDLE_WITH_CUDA) +#if defined(PADDLE_WITH_MKLML) && !defined(PADDLE_WITH_CUDA) && \ + !defined(PADDLE_WITH_HIP) AddAttr("head_number", "The number of heads of the matrix") .SetDefault(1); #endif @@ -916,7 +922,7 @@ REGISTER_OP_CPU_KERNEL( ops::MatMulDoubleGradKernel, ops::MatMulDoubleGradKernel); -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) REGISTER_OP_CUDA_KERNEL( matmul, ops::MatMulKernel, ops::MatMulKernel, diff --git a/paddle/fluid/operators/mean_op.cu b/paddle/fluid/operators/mean_op.cu index 081c077ab73..430036bc67d 100644 --- a/paddle/fluid/operators/mean_op.cu +++ b/paddle/fluid/operators/mean_op.cu @@ -11,7 +11,13 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#ifdef __NVCC__ #include "cub/cub.cuh" +#endif +#ifdef __HIPCC__ +#include +namespace cub = hipcub; +#endif #include "paddle/fluid/operators/mean_op.h" #include "paddle/fluid/platform/cuda_primitives.h" #include "paddle/fluid/platform/float16.h" diff --git a/paddle/fluid/operators/merge_lod_tensor_op.cc b/paddle/fluid/operators/merge_lod_tensor_op.cc index 584de34c5d3..5024148fe58 100644 --- a/paddle/fluid/operators/merge_lod_tensor_op.cc +++ b/paddle/fluid/operators/merge_lod_tensor_op.cc @@ -65,7 +65,7 @@ class MergeLoDTensorOp : public framework::OperatorBase { if (platform::is_cpu_place(mask.place())) { cpu_mask->ShareDataWith(mask); } else if (platform::is_gpu_place(mask.place())) { -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) framework::TensorCopy(mask, platform::CPUPlace(), dev_ctx, cpu_mask.get()); #else diff --git a/paddle/fluid/operators/miopen_lstm_cache.h b/paddle/fluid/operators/miopen_lstm_cache.h new file mode 100644 index 00000000000..7c0faa86be0 --- /dev/null +++ b/paddle/fluid/operators/miopen_lstm_cache.h @@ -0,0 +1,141 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include "paddle/fluid/framework/tensor.h" +#include "paddle/fluid/platform/miopen_helper.h" + +namespace paddle { +namespace operators { + +class ScopedRNNBase { + public: + ScopedRNNBase(int seq_length, int batch_size, int input_size, int hidden_size, + int num_layers, float dropout_prob, int seed, int weight_numel, + bool initialized, bool is_bidirec) + : seq_length_(seq_length), + batch_size_(batch_size), + input_size_(input_size), + hidden_size_(hidden_size), + num_layers_(num_layers), + dropout_prob_(dropout_prob), + seed_(seed), + weight_numel_(weight_numel), + initialized_(initialized), + is_bidirec_(is_bidirec) {} + + template + void Create(const miopenHandle_t& handle, const platform::Place& place, + const std::vector& sequence_length, size_t* workspace_size, + size_t* reserve_size, framework::Tensor* dropout_state) { + int numDirections = is_bidirec_ ? 2 : 1; + miopenDataType_t miopen_type = platform::CudnnDataType::type; + + // ------------------- miopen x, y descriptors --------------------- + std::vector dims_x = {batch_size_, input_size_, 1}; + std::vector strides_x = {input_size_, 1, 1}; + std::vector dims_y = {batch_size_, hidden_size_ * numDirections, 1}; + std::vector strides_y = {hidden_size_ * numDirections, 1, 1}; + for (int i = 0; i < seq_length_; ++i) { + x_descs_.emplace_back(x_desc_.descriptor(dims_x, strides_x)); + y_descs_.emplace_back(y_desc_.descriptor(dims_y, strides_y)); + } + + // ------------------- miopen hx, hy, cx, cy descriptors---------- + std::vector dims_hx = {num_layers_ * numDirections, batch_size_, + hidden_size_}; + std::vector strides_hx = {hidden_size_ * batch_size_, hidden_size_, 1}; + init_h_desc_.descriptor(dims_hx, strides_hx); + init_c_desc_.descriptor(dims_hx, strides_hx); + last_h_desc_.descriptor(dims_hx, strides_hx); + last_c_desc_.descriptor(dims_hx, strides_hx); + + // ------------------- miopen dropout descriptors --------------------- + size_t state_size; + if (!initialized_) { + PADDLE_ENFORCE_CUDA_SUCCESS( + platform::dynload::miopenDropoutGetStatesSize(handle, &state_size)); + dropout_state->mutable_data({static_cast(state_size)}, + place); + } + dropout_desc_.descriptor(handle, place, initialized_, dropout_prob_, + dropout_state, seed_, state_size); + + // ------------------- miopen rnn descriptors --------------------- + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::miopenSetRNNDescriptor( + rnn_desc_.desc(), hidden_size_, num_layers_, miopenRNNlinear, + is_bidirec_ ? miopenRNNbidirection : miopenRNNunidirection, miopenLSTM, + miopenRNNNoBias, miopenRNNdefault, miopen_type)); + + // ------------------- miopen weights_size --------------------- + size_t weights_size_; + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::miopenGetRNNParamsSize( + handle, rnn_desc_.desc(), x_descs_[0], &weights_size_, miopen_type)); + PADDLE_ENFORCE_EQ( + weights_size_, sizeof(T) * weight_numel_, + platform::errors::InvalidArgument( + "The miopen lstm and setting weight size should be same.")); + // ------------------- miopen weight descriptors --------------------- + platform::DataLayout layout = platform::DataLayout::kNCHW; + int dim_tmp = weights_size_ / sizeof(T); + std::vector dim_w = {dim_tmp, 1, 1}; + weight_desc_.descriptor(layout, dim_w); + // ------------------- miopen workspace, reserve size --------------------- + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::miopenGetRNNWorkspaceSize( + handle, rnn_desc_.desc(), seq_length_, x_descs_.data(), + workspace_size)); + PADDLE_ENFORCE_CUDA_SUCCESS( + platform::dynload::miopenGetRNNTrainingReserveSize( + handle, rnn_desc_.desc(), seq_length_, x_descs_.data(), + reserve_size)); + } + miopenTensorDescriptor_t* x_descs() { return x_descs_.data(); } + miopenTensorDescriptor_t* y_descs() { return y_descs_.data(); } + miopenTensorDescriptor_t init_h_desc() { return init_h_desc_.desc(); } + miopenTensorDescriptor_t init_c_desc() { return init_c_desc_.desc(); } + miopenTensorDescriptor_t last_h_desc() { return last_h_desc_.desc(); } + miopenTensorDescriptor_t last_c_desc() { return last_c_desc_.desc(); } + miopenRNNDescriptor_t rnn_desc() { return rnn_desc_.desc(); } + miopenDropoutDescriptor_t dropout_desc() { return dropout_desc_.desc(); } + miopenTensorDescriptor_t weight_desc() { return weight_desc_.desc(); } + + private: + int seq_length_; + int batch_size_; + int input_size_; + int hidden_size_; + int num_layers_; + float dropout_prob_; + int seed_; + int weight_numel_; + bool initialized_; + bool is_bidirec_; + std::vector x_descs_; + std::vector y_descs_; + + platform::ScopedTensorDescriptor x_desc_; + platform::ScopedTensorDescriptor y_desc_; + platform::ScopedTensorDescriptor init_h_desc_; + platform::ScopedTensorDescriptor init_c_desc_; + platform::ScopedTensorDescriptor last_h_desc_; + platform::ScopedTensorDescriptor last_c_desc_; + platform::ScopedDropoutDescriptor dropout_desc_; + platform::ScopedFilterDescriptor weight_desc_; + platform::ScopedRNNDescriptor rnn_desc_; +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/miopen_rnn_cache.h b/paddle/fluid/operators/miopen_rnn_cache.h new file mode 100644 index 00000000000..97d608331cc --- /dev/null +++ b/paddle/fluid/operators/miopen_rnn_cache.h @@ -0,0 +1,267 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include "paddle/fluid/framework/tensor.h" +#include "paddle/fluid/platform/miopen_helper.h" + +namespace paddle { +namespace operators { + +struct CudnnRNNCache { + CudnnRNNCache() { + x_desc_ = NULL; + y_desc_ = NULL; + } + ~CudnnRNNCache() { release(); } + + miopenRNNDescriptor_t rnn_desc_; + miopenTensorDescriptor_t *x_desc_; + miopenTensorDescriptor_t *y_desc_; + + miopenTensorDescriptor_t hx_desc_; + miopenTensorDescriptor_t cx_desc_; + miopenTensorDescriptor_t hy_desc_; + miopenTensorDescriptor_t cy_desc_; + + miopenTensorDescriptor_t dhx_desc_; + miopenTensorDescriptor_t dcx_desc_; + miopenTensorDescriptor_t dhy_desc_; + miopenTensorDescriptor_t dcy_desc_; + + miopenTensorDescriptor_t output_x_desc_; + miopenTensorDescriptor_t output_y_desc_; + + miopenDropoutDescriptor_t dropout_desc_; + + size_t weights_size_; + miopenTensorDescriptor_t w_desc_; + miopenTensorDescriptor_t dw_desc_; + + size_t workspace_size_; + framework::Tensor workspace_data_; + + size_t seq_length_; + + float dropout_prob_; + bool is_bidirec_; + + int batch_size_; + int input_size_; + int hidden_size_; + int num_layers_; + int seed_; + + void init(miopenHandle_t handle, const platform::Place &place, size_t seq_len, + int batch_size, int input_size, int hidden_size, int num_layers, + float dropout_prob, bool is_bidirec, int seed, int weight_numel, + size_t *reserve_size_, framework::Tensor *dropout_state_, + bool initialized, miopenDataType_t miopen_type) { + seq_length_ = seq_len; + batch_size_ = batch_size; + input_size_ = input_size; + hidden_size_ = hidden_size; + num_layers_ = num_layers; + dropout_prob_ = dropout_prob; + is_bidirec_ = is_bidirec; + seed_ = seed; + + const auto numDirections = is_bidirec_ ? 2 : 1; + + PADDLE_ENFORCE_EQ(miopen_type, miopenFloat, + platform::errors::InvalidArgument( + "MIOPEN do not support double datatype.")); + auto miopen_size = sizeof(float); + + x_desc_ = new miopenTensorDescriptor_t[seq_length_]; + y_desc_ = new miopenTensorDescriptor_t[seq_length_]; + std::vector dims = {batch_size_, input_size_, 1}; + std::vector strides = {input_size_, 1, 1}; + + std::vector dims_y = {batch_size_, hidden_size_ * numDirections, 1}; + std::vector strides_y = {hidden_size_ * numDirections, 1, 1}; + + for (size_t i = 0; i < seq_length_; ++i) { + PADDLE_ENFORCE_CUDA_SUCCESS( + platform::dynload::miopenCreateTensorDescriptor(&x_desc_[i])); + PADDLE_ENFORCE_CUDA_SUCCESS( + platform::dynload::miopenCreateTensorDescriptor(&y_desc_[i])); + + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::miopenSetTensorDescriptor( + x_desc_[i], miopen_type, 3, const_cast(dims.data()), + const_cast(strides.data()))); + + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::miopenSetTensorDescriptor( + y_desc_[i], miopen_type, 3, const_cast(dims_y.data()), + const_cast(strides_y.data()))); + } + + std::vector dims_hx = {num_layers_ * numDirections, batch_size_, + hidden_size_}; + std::vector strides_hx = {hidden_size_ * batch_size_, hidden_size_, 1}; + + PADDLE_ENFORCE_CUDA_SUCCESS( + platform::dynload::miopenCreateTensorDescriptor(&hx_desc_)); + PADDLE_ENFORCE_CUDA_SUCCESS( + platform::dynload::miopenCreateTensorDescriptor(&cx_desc_)); + PADDLE_ENFORCE_CUDA_SUCCESS( + platform::dynload::miopenCreateTensorDescriptor(&hy_desc_)); + PADDLE_ENFORCE_CUDA_SUCCESS( + platform::dynload::miopenCreateTensorDescriptor(&cy_desc_)); + PADDLE_ENFORCE_CUDA_SUCCESS( + platform::dynload::miopenCreateTensorDescriptor(&dhx_desc_)); + PADDLE_ENFORCE_CUDA_SUCCESS( + platform::dynload::miopenCreateTensorDescriptor(&dcx_desc_)); + PADDLE_ENFORCE_CUDA_SUCCESS( + platform::dynload::miopenCreateTensorDescriptor(&dhy_desc_)); + PADDLE_ENFORCE_CUDA_SUCCESS( + platform::dynload::miopenCreateTensorDescriptor(&dcy_desc_)); + + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::miopenSetTensorDescriptor( + hx_desc_, miopen_type, 3, const_cast(dims_hx.data()), + const_cast(strides_hx.data()))); + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::miopenSetTensorDescriptor( + cx_desc_, miopen_type, 3, const_cast(dims_hx.data()), + const_cast(strides_hx.data()))); + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::miopenSetTensorDescriptor( + hy_desc_, miopen_type, 3, const_cast(dims_hx.data()), + const_cast(strides_hx.data()))); + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::miopenSetTensorDescriptor( + cy_desc_, miopen_type, 3, const_cast(dims_hx.data()), + const_cast(strides_hx.data()))); + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::miopenSetTensorDescriptor( + dhx_desc_, miopen_type, 3, const_cast(dims_hx.data()), + const_cast(strides_hx.data()))); + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::miopenSetTensorDescriptor( + dcx_desc_, miopen_type, 3, const_cast(dims_hx.data()), + const_cast(strides_hx.data()))); + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::miopenSetTensorDescriptor( + dhy_desc_, miopen_type, 3, const_cast(dims_hx.data()), + const_cast(strides_hx.data()))); + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::miopenSetTensorDescriptor( + dcy_desc_, miopen_type, 3, const_cast(dims_hx.data()), + const_cast(strides_hx.data()))); + + PADDLE_ENFORCE_CUDA_SUCCESS( + platform::dynload::miopenCreateDropoutDescriptor(&dropout_desc_)); + + size_t state_size; + if (!initialized) { + PADDLE_ENFORCE_CUDA_SUCCESS( + platform::dynload::miopenDropoutGetStatesSize(handle, &state_size)); + dropout_state_->Resize({static_cast(state_size)}); + uint8_t *dropout_state_data = + dropout_state_->mutable_data(place); + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::miopenSetDropoutDescriptor( + dropout_desc_, handle, dropout_prob_, dropout_state_data, state_size, + seed_, false, false, MIOPEN_RNG_PSEUDO_XORWOW)); + } else { + uint8_t *dropout_state_data = dropout_state_->data(); + auto dropout_state_dims = dropout_state_->dims(); + state_size = dropout_state_dims[0]; + PADDLE_ENFORCE_CUDA_SUCCESS( + platform::dynload::miopenRestoreDropoutDescriptor( + dropout_desc_, handle, dropout_prob_, dropout_state_data, + state_size, 0, false, false, MIOPEN_RNG_PSEUDO_XORWOW)); + } + + PADDLE_ENFORCE_CUDA_SUCCESS( + platform::dynload::miopenCreateRNNDescriptor(&rnn_desc_)); + + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::miopenSetRNNDescriptor( + rnn_desc_, hidden_size_, num_layers_, miopenRNNlinear, + is_bidirec_ ? miopenRNNbidirection : miopenRNNunidirection, miopenLSTM, + miopenRNNNoBias, miopenRNNdefault, miopen_type)); + + PADDLE_ENFORCE_CUDA_SUCCESS( + platform::dynload::miopenCreateTensorDescriptor(&w_desc_)); + PADDLE_ENFORCE_CUDA_SUCCESS( + platform::dynload::miopenCreateTensorDescriptor(&dw_desc_)); + + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::miopenGetRNNParamsSize( + handle, rnn_desc_, x_desc_[0], &weights_size_, miopen_type)); + + PADDLE_ENFORCE_EQ( + weights_size_, miopen_size * weight_numel, + platform::errors::InvalidArgument( + "The miopen lstm and setting weight size should be same.")); + + int dim_w[3]; + dim_w[0] = weights_size_ / miopen_size; + dim_w[1] = 1; + dim_w[2] = 1; + + int dim_s[2]; + dim_s[1] = 1; + dim_s[0] = dim_w[1]; + + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::miopenSetTensorDescriptor( + w_desc_, miopen_type, 3, dim_w, dim_s)); + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::miopenSetTensorDescriptor( + dw_desc_, miopen_type, 3, dim_w, dim_s)); + + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::miopenGetRNNWorkspaceSize( + handle, rnn_desc_, seq_length_, x_desc_, &workspace_size_)); + PADDLE_ENFORCE_CUDA_SUCCESS( + platform::dynload::miopenGetRNNTrainingReserveSize( + handle, rnn_desc_, seq_length_, x_desc_, reserve_size_)); + + workspace_data_.Resize({static_cast(workspace_size_)}); + workspace_data_.mutable_data(place); + } + + void release() { + for (size_t i = 0; i < seq_length_; ++i) { + PADDLE_ENFORCE_CUDA_SUCCESS( + platform::dynload::miopenDestroyTensorDescriptor(x_desc_[i])); + PADDLE_ENFORCE_CUDA_SUCCESS( + platform::dynload::miopenDestroyTensorDescriptor(y_desc_[i])); + } + + delete[] x_desc_; + delete[] y_desc_; + + PADDLE_ENFORCE_CUDA_SUCCESS( + platform::dynload::miopenDestroyTensorDescriptor(hx_desc_)); + PADDLE_ENFORCE_CUDA_SUCCESS( + platform::dynload::miopenDestroyTensorDescriptor(cx_desc_)); + PADDLE_ENFORCE_CUDA_SUCCESS( + platform::dynload::miopenDestroyTensorDescriptor(hy_desc_)); + PADDLE_ENFORCE_CUDA_SUCCESS( + platform::dynload::miopenDestroyTensorDescriptor(cy_desc_)); + PADDLE_ENFORCE_CUDA_SUCCESS( + platform::dynload::miopenDestroyTensorDescriptor(dhx_desc_)); + PADDLE_ENFORCE_CUDA_SUCCESS( + platform::dynload::miopenDestroyTensorDescriptor(dcx_desc_)); + PADDLE_ENFORCE_CUDA_SUCCESS( + platform::dynload::miopenDestroyTensorDescriptor(dhy_desc_)); + PADDLE_ENFORCE_CUDA_SUCCESS( + platform::dynload::miopenDestroyTensorDescriptor(dcy_desc_)); + + PADDLE_ENFORCE_CUDA_SUCCESS( + platform::dynload::miopenDestroyDropoutDescriptor(dropout_desc_)); + PADDLE_ENFORCE_CUDA_SUCCESS( + platform::dynload::miopenDestroyRNNDescriptor(rnn_desc_)); + + PADDLE_ENFORCE_CUDA_SUCCESS( + platform::dynload::miopenDestroyTensorDescriptor(w_desc_)); + PADDLE_ENFORCE_CUDA_SUCCESS( + platform::dynload::miopenDestroyTensorDescriptor(dw_desc_)); + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/modified_huber_loss_op.h b/paddle/fluid/operators/modified_huber_loss_op.h index 17621095c49..398676ba741 100644 --- a/paddle/fluid/operators/modified_huber_loss_op.h +++ b/paddle/fluid/operators/modified_huber_loss_op.h @@ -29,8 +29,8 @@ using EigenVector = framework::EigenVector; template struct CheckLabelValue { HOSTDEVICE T operator()(const T& val) const { - PADDLE_ENFORCE( - val == static_cast(0) || val == static_cast(1), + PADDLE_ENFORCE_EQ( + val == static_cast(0) || val == static_cast(1), true, platform::errors::InvalidArgument( "Input(label) value of modified_huber_loss_op expected to be 0 " "or 1, but got %ld. Please check label value.", diff --git a/paddle/fluid/operators/multinomial_op.cu b/paddle/fluid/operators/multinomial_op.cu index 92f7c992ed9..2d97111709a 100644 --- a/paddle/fluid/operators/multinomial_op.cu +++ b/paddle/fluid/operators/multinomial_op.cu @@ -12,6 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#ifndef PADDLE_WITH_HIP +// To-do(qili93): fix this after issue resolved +// https://github.com/ROCmSoftwarePlatform/rocPRIM/issues/202 + #include #include #include @@ -155,13 +159,24 @@ class MultinomialOpKernel T* cpu_in_data = new T[in_data_numel]; int64_t* cpu_out_data = new int64_t[out_data_numel]; +#ifdef PADDLE_WITH_HIP + hipMemcpy(cpu_in_data, in_data, in_data_numel * sizeof(T), + hipMemcpyDeviceToHost); +#else cudaMemcpy(cpu_in_data, in_data, in_data_numel * sizeof(T), cudaMemcpyDeviceToHost); +#endif MultinomialFunctor(cpu_out_data, cpu_in_data, num_samples, replacement, num_categories, num_distributions); + +#ifdef PADDLE_WITH_HIP + hipMemcpy(out_data, cpu_out_data, out_data_numel * sizeof(int64_t), + hipMemcpyHostToDevice); +#else cudaMemcpy(out_data, cpu_out_data, out_data_numel * sizeof(int64_t), cudaMemcpyHostToDevice); +#endif delete[] cpu_in_data; delete[] cpu_out_data; @@ -250,5 +265,7 @@ namespace ops = paddle::operators; namespace plat = paddle::platform; REGISTER_OP_CUDA_KERNEL( - multinomial, ops::MultinomialOpKernel, - ops::MultinomialOpKernel); + multinomial, ops::MultinomialOpKernel, + ops::MultinomialOpKernel); + +#endif diff --git a/paddle/fluid/operators/nll_loss_op.cu b/paddle/fluid/operators/nll_loss_op.cu index 531c175e03e..b6e7cd256e1 100644 --- a/paddle/fluid/operators/nll_loss_op.cu +++ b/paddle/fluid/operators/nll_loss_op.cu @@ -11,7 +11,6 @@ limitations under the License. */ #include #include #include -#include "cub/cub.cuh" #include "paddle/fluid/operators/math.h" #include "paddle/fluid/operators/nll_loss_op.h" #include "paddle/fluid/platform/cuda_primitives.h" @@ -361,7 +360,11 @@ class NLLLossCUDAKernel : public framework::OpKernel { auto total_weight_data = total_weight->mutable_data(ctx.GetPlace()); auto label_data = labels->data(); auto weight_data = weight ? weight->data() : nullptr; +#ifdef PADDLE_WITH_HIP + hipMemset(total_weight_data, 0, sizeof(T)); +#else cudaMemset(total_weight_data, 0, sizeof(T)); +#endif auto x_dims = x->dims(); auto batch_size = x_dims[0]; auto n_classes = x_dims[1]; @@ -429,7 +432,11 @@ class NLLLossGradCUDAKernel : public framework::OpKernel { auto total_weight_data = total_weight->data(); auto ignore_index = ctx.Attr("ignore_index"); auto reduction = ctx.Attr("reduction"); +#ifdef PADDLE_WITH_HIP + hipMemset(dx_data, 0, dx->numel() * sizeof(T)); +#else cudaMemset(dx_data, 0, dx->numel() * sizeof(T)); +#endif int64_t size_average = (int64_t)(reduction == "mean"); auto x_dims = x->dims(); diff --git a/paddle/fluid/operators/norm_op.cu b/paddle/fluid/operators/norm_op.cu index 67449aa4c67..6b5c70c9258 100644 --- a/paddle/fluid/operators/norm_op.cu +++ b/paddle/fluid/operators/norm_op.cu @@ -13,7 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. */ #include +#ifdef __NVCC__ #include "cub/cub.cuh" +#endif +#ifdef __HIPCC__ +#include +namespace cub = hipcub; +#endif #include "paddle/fluid/operators/norm_op.h" namespace paddle { diff --git a/paddle/fluid/operators/norm_utils.cu.h b/paddle/fluid/operators/norm_utils.cu.h index 02dcb4045f4..9fcc6292338 100644 --- a/paddle/fluid/operators/norm_utils.cu.h +++ b/paddle/fluid/operators/norm_utils.cu.h @@ -17,10 +17,20 @@ limitations under the License. */ #include #include #include +#ifdef __NVCC__ #include "cub/cub.cuh" +#endif +#ifdef __HIPCC__ +#include +namespace cub = hipcub; +#endif #include "paddle/fluid/framework/data_layout.h" #include "paddle/fluid/operators/math/math_function.h" +#ifdef PADDLE_WITH_HIP +#include "paddle/fluid/platform/miopen_helper.h" +#else #include "paddle/fluid/platform/cudnn_helper.h" +#endif namespace paddle { namespace operators { -- GitLab