From 3b9db17199dd772b7067d4e2f165bdd3180b4133 Mon Sep 17 00:00:00 2001 From: Qi Li Date: Wed, 3 Mar 2021 15:46:20 +0800 Subject: [PATCH] [ROCM] update fluid operators for rocm (part7), test=develop (#31307) --- paddle/fluid/operators/CMakeLists.txt | 13 +++- paddle/fluid/operators/bmm_op.cu | 2 +- paddle/fluid/operators/cholesky_op.cu | 5 ++ paddle/fluid/operators/clip_op.h | 4 +- paddle/fluid/operators/coalesce_tensor_op.cc | 2 +- paddle/fluid/operators/correlation_op.cu | 5 ++ paddle/fluid/operators/cudnn_lstm_op.cu.cc | 73 ++++++++++++++++--- paddle/fluid/operators/cumsum_op.cu | 8 +- paddle/fluid/operators/data_norm_op.cu | 8 +- paddle/fluid/operators/diag_embed_op.h | 2 +- paddle/fluid/operators/dot_op.h | 4 +- paddle/fluid/operators/dropout_op.cu | 54 +++++++++++++- paddle/fluid/operators/dropout_op.h | 4 +- paddle/fluid/operators/fake_quantize_op.cu | 4 + .../fill_constant_batch_size_like_op.h | 2 +- paddle/fluid/operators/fill_constant_op.h | 4 +- paddle/fluid/operators/filter_by_instag_op.h | 2 +- paddle/fluid/operators/gelu_op.h | 6 +- .../get_tensor_from_selected_rows_op.cc | 2 +- 19 files changed, 167 insertions(+), 37 deletions(-) diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index 598e417526f..467a5ff9063 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -73,9 +73,11 @@ register_operators(EXCLUDES py_func_op warpctc_op dgc_op lstm_op run_program_op op_library(run_program_op SRCS run_program_op.cc run_program_op.cu.cc DEPS executor_cache ${OP_HEADER_DEPS}) -if (WITH_GPU) +if (WITH_GPU OR WITH_ROCM) + if(WITH_ROCM) + op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale SRCS warpctc_op.cc warpctc_op.cu.cc) # warpctc_op needs cudnn 7 above - if (${CUDNN_MAJOR_VERSION} VERSION_LESS 7) + elseif(${CUDNN_MAJOR_VERSION} VERSION_LESS 7) op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale SRCS warpctc_op.cc warpctc_op.cu.cc) else() op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale) @@ -108,7 +110,7 @@ set(COMMON_OP_DEPS ${COMMON_OP_DEPS} sequence_padding sequence_scale cos_sim_fun set(COMMON_OP_DEPS ${COMMON_OP_DEPS} sequence2batch lstm_compute matrix_bit_code gru_compute activation_functions beam_search fc matrix_inverse) set(COMMON_OP_DEPS ${COMMON_OP_DEPS} box_wrapper boost ps_gpu_wrapper) set(COMMON_OP_DEPS ${COMMON_OP_DEPS} common_infer_shape_functions) -if (WITH_GPU) +if (WITH_GPU OR WITH_ROCM) set(COMMON_OP_DEPS ${COMMON_OP_DEPS} depthwise_conv prelu bert_encoder_functor) endif() set(COMMON_OP_DEPS ${COMMON_OP_DEPS} device_memory_aligment) @@ -139,9 +141,12 @@ cc_test(beam_search_decode_op_test SRCS beam_search_decode_op_test.cc DEPS lod_t cc_test(strided_memcpy_test SRCS strided_memcpy_test.cc DEPS tensor memory) cc_test(save_load_op_test SRCS save_load_op_test.cc DEPS save_op load_op) cc_test(save_load_combine_op_test SRCS save_load_combine_op_test.cc DEPS save_combine_op load_combine_op) -nv_test(dropout_op_test SRCS dropout_op_test.cc DEPS dropout_op tensor generator) if (WITH_GPU) + nv_test(dropout_op_test SRCS dropout_op_test.cc DEPS dropout_op tensor generator) nv_test(test_leaky_relu_grad_grad_functor SRCS test_leaky_relu_grad_grad_functor.cc test_leaky_relu_grad_grad_functor.cu DEPS tensor device_context eigen3) +elseif(WITH_ROCM) + hip_test(dropout_op_test SRCS dropout_op_test.cc DEPS dropout_op tensor generator) + hip_test(test_leaky_relu_grad_grad_functor SRCS test_leaky_relu_grad_grad_functor.cc test_leaky_relu_grad_grad_functor.cu DEPS tensor device_context eigen3) else() cc_test(test_leaky_relu_grad_grad_functor SRCS test_leaky_relu_grad_grad_functor.cc DEPS tensor device_context eigen3) endif() diff --git a/paddle/fluid/operators/bmm_op.cu b/paddle/fluid/operators/bmm_op.cu index 961d74b7ad4..15a7506a8f5 100644 --- a/paddle/fluid/operators/bmm_op.cu +++ b/paddle/fluid/operators/bmm_op.cu @@ -11,7 +11,7 @@ #include "paddle/fluid/operators/bmm_op.h" -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) namespace ops = paddle::operators; REGISTER_OP_CUDA_KERNEL( bmm, ops::BmmKernel, diff --git a/paddle/fluid/operators/cholesky_op.cu b/paddle/fluid/operators/cholesky_op.cu index 530147609fe..44260573052 100644 --- a/paddle/fluid/operators/cholesky_op.cu +++ b/paddle/fluid/operators/cholesky_op.cu @@ -12,6 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#ifndef PADDLE_WITH_HIP +// HIP not support cusolver + #include #include #include @@ -164,3 +167,5 @@ REGISTER_OP_CUDA_KERNEL( cholesky_grad, ops::CholeskyGradKernel, ops::CholeskyGradKernel); + +#endif // not PADDLE_WITH_HIP diff --git a/paddle/fluid/operators/clip_op.h b/paddle/fluid/operators/clip_op.h index 097b5e4863d..93157ed9d47 100644 --- a/paddle/fluid/operators/clip_op.h +++ b/paddle/fluid/operators/clip_op.h @@ -25,7 +25,7 @@ namespace operators { using framework::Tensor; using platform::Transform; -#ifdef __NVCC__ +#if defined(__NVCC__) || defined(__HIPCC__) template __global__ void ClipCudaKernel(const T* input, T* out, int num, UnaryOperation op) { @@ -105,7 +105,7 @@ class ClipKernel : public framework::OpKernel { const T* x_data = x->data(); int64_t numel = x->numel(); if (platform::is_gpu_place(context.GetPlace())) { -#ifdef __NVCC__ +#if defined(__NVCC__) || defined(__HIPCC__) int threads = 256; int blocks = (numel + threads - 1) / threads; ClipCudaKernel><<< diff --git a/paddle/fluid/operators/coalesce_tensor_op.cc b/paddle/fluid/operators/coalesce_tensor_op.cc index ad255b18826..153fa529f96 100644 --- a/paddle/fluid/operators/coalesce_tensor_op.cc +++ b/paddle/fluid/operators/coalesce_tensor_op.cc @@ -289,7 +289,7 @@ REGISTER_OP_CPU_KERNEL( ops::CoalesceTensorOpKernel, ops::CoalesceTensorOpKernel); -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) REGISTER_OP_CUDA_KERNEL( coalesce_tensor, ops::CoalesceTensorOpKernel #include #include "paddle/fluid/framework/op_registry.h" @@ -480,3 +483,5 @@ REGISTER_OP_CUDA_KERNEL(correlation, ops::CorrelationCUDAKernel, ops::CorrelationCUDAKernel); REGISTER_OP_CUDA_KERNEL(correlation_grad, ops::CorrelationCUDAGradKernel, ops::CorrelationCUDAGradKernel); + +#endif // not PADDLE_WITH_HIP diff --git a/paddle/fluid/operators/cudnn_lstm_op.cu.cc b/paddle/fluid/operators/cudnn_lstm_op.cu.cc index e935a3c0aac..27f64b41948 100644 --- a/paddle/fluid/operators/cudnn_lstm_op.cu.cc +++ b/paddle/fluid/operators/cudnn_lstm_op.cu.cc @@ -14,9 +14,14 @@ limitations under the License. */ #include "paddle/fluid/framework/generator.h" #include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/operators/cudnn_lstm_cache.h" #include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/utils.h" +#ifdef PADDLE_WITH_CUDA +#include "paddle/fluid/operators/cudnn_lstm_cache.h" +#endif +#ifdef PADDLE_WITH_HIP +#include "paddle/fluid/operators/miopen_lstm_cache.h" +#endif namespace paddle { namespace platform { @@ -54,7 +59,7 @@ int size_sum(const std::vector &weight_list) { } template -void weight_to_tensor(const platform::Place &place, cudaStream_t stream, +void weight_to_tensor(const platform::Place &place, gpuStream_t stream, const std::vector &weight_list, Tensor *weight) { auto weight_data = weight->data(); @@ -72,7 +77,7 @@ void weight_to_tensor(const platform::Place &place, cudaStream_t stream, } template -void weight_to_tensor_list(const platform::Place &place, cudaStream_t stream, +void weight_to_tensor_list(const platform::Place &place, gpuStream_t stream, std::vector *weight_grad, const std::vector &weight_input, const Tensor *weight) { @@ -92,23 +97,36 @@ void weight_to_tensor_list(const platform::Place &place, cudaStream_t stream, } template +#ifdef PADDLE_WITH_HIP +void LSTMInferece(const bool &has_seq_length, const miopenHandle_t &handle, +#else void LSTMInferece(const bool &has_seq_length, const cudnnHandle_t &handle, +#endif const int &seq_length, ScopedRNNBase *rnn, const T *x_data, const T *init_h_data, const T *init_c_data, const T *w_data, T *out_data, T *last_h_data, T *last_c_data, framework::Tensor *workspace_data, const size_t &workspace_size) { if (!has_seq_length) { - // for inference - // This interface is used when the input/output is unpadded. +// for inference +// This interface is used when the input/output is unpadded. +#ifdef PADDLE_WITH_HIP + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::miopenRNNForwardInference( + handle, rnn->rnn_desc(), seq_length, rnn->x_descs(), x_data, + rnn->init_h_desc(), init_h_data, rnn->init_c_desc(), init_c_data, + rnn->weight_desc(), w_data, rnn->y_descs(), out_data, + rnn->last_h_desc(), last_h_data, rnn->last_c_desc(), last_c_data, + workspace_data->data(), workspace_size)); +#else PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnRNNForwardInference( handle, rnn->rnn_desc(), seq_length, rnn->x_descs(), x_data, rnn->init_h_desc(), init_h_data, rnn->init_c_desc(), init_c_data, rnn->weight_desc(), w_data, rnn->y_descs(), out_data, rnn->last_h_desc(), last_h_data, rnn->last_c_desc(), last_c_data, workspace_data->data(), workspace_size)); +#endif } else { -#if CUDNN_VERSION >= 7201 +#if !defined(PADDLE_WITH_HIP) && CUDNN_VERSION >= 7201 // for inference // This interface is used when the input/output is padded. PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnRNNForwardInferenceEx( @@ -256,8 +274,17 @@ class CudnnLSTMGPUKernel : public framework::OpKernel { last_c_data, &workspace_data_, workspace_size); } else { if (!has_seq_length) { - // for train - // This interface is used when the input/output is unpadded. +// for train +// This interface is used when the input/output is unpadded. +#ifdef PADDLE_WITH_HIP + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::miopenRNNForwardTraining( + handle, rnn.rnn_desc(), seq_length, rnn.x_descs(), x_data, + rnn.init_h_desc(), init_h_data, rnn.init_c_desc(), init_c_data, + rnn.weight_desc(), w_data, rnn.y_descs(), out_data, + rnn.last_h_desc(), last_h_data, rnn.last_c_desc(), last_c_data, + workspace_data_.data(), workspace_size, reserve_data, + reserve_size)); +#else PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnRNNForwardTraining( handle, rnn.rnn_desc(), seq_length, rnn.x_descs(), x_data, rnn.init_h_desc(), init_h_data, rnn.init_c_desc(), init_c_data, @@ -265,8 +292,9 @@ class CudnnLSTMGPUKernel : public framework::OpKernel { rnn.last_h_desc(), last_h_data, rnn.last_c_desc(), last_c_data, workspace_data_.data(), workspace_size, reserve_data, reserve_size)); +#endif } else { -#if CUDNN_VERSION >= 7201 +#if !defined(PADDLE_WITH_HIP) && CUDNN_VERSION >= 7201 // for train // This interface is used when the input/output is padded. PADDLE_ENFORCE_CUDA_SUCCESS( @@ -403,7 +431,23 @@ class CudnnLSTMGPUGradKernel : public framework::OpKernel { const uint8_t *reserve_data = reserve->data(); if (!has_seq_length) { - // This interface is used when the input/output is unpadded. +// This interface is used when the input/output is unpadded. +#ifdef PADDLE_WITH_HIP + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::miopenRNNBackwardData( + handle, rnn.rnn_desc(), seq_length, rnn.y_descs(), out_data, + rnn.y_descs(), out_grad_data, rnn.last_h_desc(), last_h_grad_data, + rnn.last_c_desc(), last_c_grad_data, rnn.weight_desc(), weight_data, + rnn.init_h_desc(), init_h_data, rnn.init_c_desc(), init_c_data, + rnn.x_descs(), in_grad_data, rnn.init_h_desc(), init_h_grad_data, + rnn.init_c_desc(), init_c_grad_data, workspace_data_.data(), + workspace_size, const_cast(reserve_data), reserve_size)); + + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::miopenRNNBackwardWeights( + handle, rnn.rnn_desc(), seq_length, rnn.x_descs(), input->data(), + rnn.init_h_desc(), init_h->data(), rnn.y_descs(), out->data(), + rnn.weight_desc(), weight_grad_data, workspace_data_.data(), + workspace_size, const_cast(reserve_data), reserve_size)); +#else PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnRNNBackwardData( handle, rnn.rnn_desc(), seq_length, rnn.y_descs(), out_data, rnn.y_descs(), out_grad_data, rnn.last_h_desc(), last_h_grad_data, @@ -418,8 +462,9 @@ class CudnnLSTMGPUGradKernel : public framework::OpKernel { rnn.init_h_desc(), init_h->data(), rnn.y_descs(), out->data(), workspace_data_.data(), workspace_size, rnn.weight_desc(), weight_grad_data, const_cast(reserve_data), reserve_size)); +#endif } else { -#if CUDNN_VERSION >= 7201 +#if !defined(PADDLE_WITH_HIP) && CUDNN_VERSION >= 7201 // for train // This interface is used when the input/output is padded. PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnRNNBackwardDataEx( @@ -452,7 +497,13 @@ class CudnnLSTMGPUGradKernel : public framework::OpKernel { } // namespace paddle namespace ops = paddle::operators; +#ifdef PADDLE_WITH_HIP +// MIOPEN do not support double +REGISTER_OP_CUDA_KERNEL(cudnn_lstm, ops::CudnnLSTMGPUKernel); +REGISTER_OP_CUDA_KERNEL(cudnn_lstm_grad, ops::CudnnLSTMGPUGradKernel); +#else REGISTER_OP_CUDA_KERNEL(cudnn_lstm, ops::CudnnLSTMGPUKernel, ops::CudnnLSTMGPUKernel); REGISTER_OP_CUDA_KERNEL(cudnn_lstm_grad, ops::CudnnLSTMGPUGradKernel, ops::CudnnLSTMGPUGradKernel); +#endif diff --git a/paddle/fluid/operators/cumsum_op.cu b/paddle/fluid/operators/cumsum_op.cu index f75eb7fd967..854be76f24e 100644 --- a/paddle/fluid/operators/cumsum_op.cu +++ b/paddle/fluid/operators/cumsum_op.cu @@ -16,7 +16,13 @@ limitations under the License. */ #include #include #include -#include "cub/cub.cuh" +#ifdef __NVCC__ +#include +#endif +#ifdef __HIPCC__ +#include +namespace cub = hipcub; +#endif #include "paddle/fluid/operators/cum_op.h" #include "paddle/fluid/platform/gpu_launch_config.h" diff --git a/paddle/fluid/operators/data_norm_op.cu b/paddle/fluid/operators/data_norm_op.cu index 9e284b1dcda..1043faa56f0 100644 --- a/paddle/fluid/operators/data_norm_op.cu +++ b/paddle/fluid/operators/data_norm_op.cu @@ -17,7 +17,7 @@ limitations under the License. */ #include "paddle/fluid/framework/data_layout.h" #include "paddle/fluid/operators/data_norm_op.h" #include "paddle/fluid/platform/cuda_primitives.h" -#if defined(PADDLE_WITH_NCCL) +#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) #include "paddle/fluid/platform/collective_helper.h" #include "paddle/fluid/platform/nccl_helper.h" #endif @@ -174,7 +174,7 @@ class DataNormGradKernel d_batch_sum, d_batch_square_sum); if (need_sync_stats) { -#if defined(PADDLE_WITH_NCCL) +#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) auto comm = platform::NCCLCommContext::Instance().Get(0, ctx.GetPlace()); PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclAllReduce( reinterpret_cast(d_batch_size), @@ -188,7 +188,11 @@ class DataNormGradKernel reinterpret_cast(d_batch_square_sum), reinterpret_cast(d_batch_square_sum), C, platform::ToNCCLDataType(x->type()), ncclSum, comm->comm(), stream)); +#ifdef PADDLE_WITH_RCCL + PADDLE_ENFORCE_CUDA_SUCCESS(hipStreamSynchronize(stream)); +#else PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamSynchronize(stream)); +#endif #else PADDLE_THROW(platform::errors::PreconditionNotMet( "PaddlePaddle should compile with GPU, and need_sync_stats connot be " diff --git a/paddle/fluid/operators/diag_embed_op.h b/paddle/fluid/operators/diag_embed_op.h index 8c4c68fb1ff..aff7d7e48a8 100644 --- a/paddle/fluid/operators/diag_embed_op.h +++ b/paddle/fluid/operators/diag_embed_op.h @@ -100,7 +100,7 @@ class DiagEmbedKernel : public framework::OpKernel { strides.push_back(stride[dim1_] + stride[dim2_]); const auto dims = vectorize(input->dims()); -#ifdef __NVCC__ +#if defined(__NVCC__) || defined(__HIPCC__) thrust::device_vector dims_vec(dims); const int64_t* dims_arr = thrust::raw_pointer_cast(dims_vec.data()); thrust::device_vector strides_vec(strides); diff --git a/paddle/fluid/operators/dot_op.h b/paddle/fluid/operators/dot_op.h index a197e2149ee..0b0b7f69b9d 100644 --- a/paddle/fluid/operators/dot_op.h +++ b/paddle/fluid/operators/dot_op.h @@ -45,7 +45,7 @@ struct DotGradFunction> { const Tensor* tensor_dout, Tensor* tensor_dx, Tensor* tensor_dy, const paddle::framework::ExecutionContext& ctx) { -#ifdef __NVCC__ +#if defined(__NVCC__) || defined(__HIPCC__) if (1 == tensor_dout->dims().size()) { auto dout = framework::EigenVector::Flatten(*tensor_dout); @@ -249,7 +249,7 @@ class DotKernel : public framework::OpKernel { auto* tensor_out = ctx.Output("Out"); tensor_out->mutable_data(ctx.GetPlace()); -#ifdef __NVCC__ +#if defined(__NVCC__) || defined(__HIPCC__) if (1 == tensor_out->dims().size()) { auto out = framework::EigenScalar::From(*tensor_out); auto x = framework::EigenVector::Flatten(*tensor_x); diff --git a/paddle/fluid/operators/dropout_op.cu b/paddle/fluid/operators/dropout_op.cu index cf90b9eb52b..fbc145d3123 100644 --- a/paddle/fluid/operators/dropout_op.cu +++ b/paddle/fluid/operators/dropout_op.cu @@ -11,8 +11,17 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ + +#ifdef PADDLE_WITH_CUDA #include #include +#include "paddle/fluid/platform/dynload/curand.h" +#endif +#ifdef PADDLE_WITH_HIP +#include +#include +#include "paddle/fluid/platform/dynload/hiprand.h" +#endif #include #include #include @@ -21,7 +30,6 @@ limitations under the License. */ #include #include "paddle/fluid/memory/memcpy.h" #include "paddle/fluid/operators/dropout_op.h" -#include "paddle/fluid/platform/dynload/curand.h" #include "paddle/fluid/platform/float16.h" namespace paddle { @@ -32,15 +40,24 @@ __global__ void RandomGenerator(const size_t n, uint64_t seed, const float dropout_prob, const T* src, MaskType* mask_data, T* dst, bool is_upscale_in_train, uint64_t increment) { - curandStatePhilox4_32_10_t state; int idx = blockDim.x * blockIdx.x + threadIdx.x; +#ifdef PADDLE_WITH_HIP + hiprandStatePhilox4_32_10_t state; + hiprand_init(seed, idx, increment, &state); +#else + curandStatePhilox4_32_10_t state; curand_init(seed, idx, increment, &state); +#endif MaskType mask; T dest; for (; idx < n; idx += blockDim.x * gridDim.x) { T s = src[idx]; +#ifdef PADDLE_WITH_HIP + if (hiprand_uniform(&state) < dropout_prob) { +#else if (curand_uniform(&state) < dropout_prob) { +#endif mask = 0; dest = 0; } else { @@ -62,9 +79,15 @@ __global__ void VectorizedRandomGenerator(const size_t n, uint64_t seed, const T* src, MaskType* mask_data, T* dst, bool is_upscale_in_train, uint64_t increment) { +#ifdef PADDLE_WITH_HIP + int64_t idx = hipBlockDim_x * hipBlockIdx_x + hipThreadIdx_x; + hiprandStatePhilox4_32_10_t state; + hiprand_init(seed, idx, increment, &state); +#else int64_t idx = blockDim.x * blockIdx.x + threadIdx.x; curandStatePhilox4_32_10_t state; curand_init(seed, idx, increment, &state); +#endif MaskType mask; T dest; @@ -75,7 +98,11 @@ __global__ void VectorizedRandomGenerator(const size_t n, uint64_t seed, T src_vec[VecSize]; LoadT* value = reinterpret_cast(&src_vec); *value = *reinterpret_cast(&src[i]); +#ifdef PADDLE_WITH_HIP + float4 rand = hiprand_uniform4(&state); +#else float4 rand = curand_uniform4(&state); +#endif T dest_vec[VecSize]; MaskType mask_vec[VecSize]; @@ -131,10 +158,17 @@ class GPUDropoutKernel : public framework::OpKernel { auto* x_data = x->data(); auto* y_data = y->mutable_data(context.GetPlace()); if (dropout_prob == 1.0f) { +#ifdef PADDLE_WITH_HIP + PADDLE_ENFORCE_CUDA_SUCCESS( + hipMemsetAsync(y_data, 0, x_numel * sizeof(T), stream)); + PADDLE_ENFORCE_CUDA_SUCCESS( + hipMemsetAsync(mask_data, 0, x_numel * sizeof(*mask_data), stream)); +#else PADDLE_ENFORCE_CUDA_SUCCESS( cudaMemsetAsync(y_data, 0, x_numel * sizeof(T), stream)); PADDLE_ENFORCE_CUDA_SUCCESS(cudaMemsetAsync( mask_data, 0, x_numel * sizeof(*mask_data), stream)); +#endif return; } @@ -180,6 +214,20 @@ class GPUDropoutKernel : public framework::OpKernel { increment = offset; } +#ifdef __HIPCC__ + if (vec_size == 4 && size % 4 == 0) { + hipLaunchKernelGGL( + HIP_KERNEL_NAME(VectorizedRandomGenerator), + config.block_per_grid, config.thread_per_block, 0, stream, size, + seed_data, dropout_prob, x_data, mask_data, y_data, + upscale_in_train, increment); + } else { + hipLaunchKernelGGL(HIP_KERNEL_NAME(RandomGenerator), + config.block_per_grid, config.thread_per_block, 0, + stream, size, seed_data, dropout_prob, x_data, + mask_data, y_data, upscale_in_train, increment); + } +#else if (vec_size == 4 && size % 4 == 0) { VectorizedRandomGenerator< T, uint8_t, @@ -192,7 +240,7 @@ class GPUDropoutKernel : public framework::OpKernel { size, seed_data, dropout_prob, x_data, mask_data, y_data, upscale_in_train, increment); } - +#endif } else { auto X = EigenMatrix::Reshape(*x, 1); auto Y = EigenMatrix::Reshape(*y, 1); diff --git a/paddle/fluid/operators/dropout_op.h b/paddle/fluid/operators/dropout_op.h index d77193e4851..69c420e2c93 100644 --- a/paddle/fluid/operators/dropout_op.h +++ b/paddle/fluid/operators/dropout_op.h @@ -42,7 +42,7 @@ inline int VectorizedSize(const T* pointer) { return 1; } -#ifdef __NVCC__ +#if defined(__NVCC__) || defined(__HIPCC__) template __global__ void DropoutGradCUDAKernel(const T* dout, const MaskType* mask, const T factor, const int64_t size, @@ -186,7 +186,7 @@ class DropoutGradKernel : public framework::OpKernel { int vec_size = VectorizedSize(grad_y->data()); if (platform::is_gpu_place(context.GetPlace()) && vec_size == 4 && size % 4 == 0) { -#ifdef __NVCC__ +#if defined(__NVCC__) || defined(__HIPCC__) auto factor = static_cast(1.0f / (1.0f - dropout_prob)); auto stream = context.cuda_device_context().stream(); platform::GpuLaunchConfig config = platform::GetGpuLaunchConfig1D( diff --git a/paddle/fluid/operators/fake_quantize_op.cu b/paddle/fluid/operators/fake_quantize_op.cu index 26dcf8bf39c..92127f9aebd 100644 --- a/paddle/fluid/operators/fake_quantize_op.cu +++ b/paddle/fluid/operators/fake_quantize_op.cu @@ -162,7 +162,11 @@ struct FindChannelAbsMaxFunctor { int grid = cout; int max_threads = 1024; +#ifdef PADDLE_WITH_HIP + hipMemset(out_abs_max, 0, sizeof(T) * cout); +#else cudaMemset(out_abs_max, 0, sizeof(T) * cout); +#endif for (int i = 0; i < cin / max_threads; i++) { int block = max_threads; diff --git a/paddle/fluid/operators/fill_constant_batch_size_like_op.h b/paddle/fluid/operators/fill_constant_batch_size_like_op.h index e8a35d22277..432a9968ab0 100644 --- a/paddle/fluid/operators/fill_constant_batch_size_like_op.h +++ b/paddle/fluid/operators/fill_constant_batch_size_like_op.h @@ -65,7 +65,7 @@ class FillConstantBatchSizeLikeOpKernel : public framework::OpKernel { functor(reinterpret_cast(dev_ctx), out, static_cast(value)); } -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) if (!cpu_place) { math::SetConstant functor; out->mutable_data(ctx.GetPlace(), data_type); diff --git a/paddle/fluid/operators/fill_constant_op.h b/paddle/fluid/operators/fill_constant_op.h index 5d1f1fa781d..4608f167548 100644 --- a/paddle/fluid/operators/fill_constant_op.h +++ b/paddle/fluid/operators/fill_constant_op.h @@ -121,7 +121,7 @@ class FillConstantKernel : public framework::OpKernel { functor(reinterpret_cast(dev_ctx), tensor, static_cast(value)); } else if (actual_place == 1) { -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) tensor->mutable_data(ctx.GetPlace(), data_type); math::SetConstant functor; functor(reinterpret_cast(dev_ctx), @@ -131,7 +131,7 @@ class FillConstantKernel : public framework::OpKernel { "PaddlePaddle should compile with GPU.")); #endif } else if (actual_place == 2) { -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) tensor->mutable_data(platform::CUDAPinnedPlace(), data_type); math::SetConstant functor; functor(reinterpret_cast(dev_ctx), diff --git a/paddle/fluid/operators/filter_by_instag_op.h b/paddle/fluid/operators/filter_by_instag_op.h index 9234f9be474..77bc9e466e8 100644 --- a/paddle/fluid/operators/filter_by_instag_op.h +++ b/paddle/fluid/operators/filter_by_instag_op.h @@ -31,7 +31,7 @@ namespace operators { using Tensor = framework::Tensor; using SelectedRows = framework::SelectedRows; using LoDTensor = framework::LoDTensor; -#if defined(PADDLE_WITH_CUDA) +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) template using Vector = framework::Vector; #else diff --git a/paddle/fluid/operators/gelu_op.h b/paddle/fluid/operators/gelu_op.h index 936da8dee85..0446d7d284b 100644 --- a/paddle/fluid/operators/gelu_op.h +++ b/paddle/fluid/operators/gelu_op.h @@ -54,7 +54,8 @@ struct GeluFunctor { } } else { #if defined(PADDLE_WITH_MKLML) && !defined(_WIN32) && !defined(__APPLE__) && \ - !defined(__OSX__) && !defined(PADDLE_WITH_CUDA) + !defined(__OSX__) && !defined(PADDLE_WITH_CUDA) && \ + !defined(PADDLE_WITH_HIP) auto x_data = x.data(); auto out_data = out.data(); int n = std::min(x.size(), out.size()); @@ -121,7 +122,8 @@ struct GeluGradFunctor { } } else { #if defined(PADDLE_WITH_MKLML) && !defined(_WIN32) && !defined(__APPLE__) && \ - !defined(__OSX__) && !defined(PADDLE_WITH_CUDA) + !defined(__OSX__) && !defined(PADDLE_WITH_CUDA) && \ + !defined(PADDLE_WITH_HIP) auto x_data = x.data(); auto dx_data = dx.data(); auto dout_data = dout.data(); diff --git a/paddle/fluid/operators/get_tensor_from_selected_rows_op.cc b/paddle/fluid/operators/get_tensor_from_selected_rows_op.cc index 89a5d81a227..8ce7df7eec1 100644 --- a/paddle/fluid/operators/get_tensor_from_selected_rows_op.cc +++ b/paddle/fluid/operators/get_tensor_from_selected_rows_op.cc @@ -107,7 +107,7 @@ REGISTER_OP_CPU_KERNEL_FUNCTOR(get_tensor_from_selected_rows, float, ops::GetTensorFromSelectedRowsKernel, int64_t, ops::GetTensorFromSelectedRowsKernel); -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) REGISTER_OP_CUDA_KERNEL_FUNCTOR(get_tensor_from_selected_rows, float, ops::GetTensorFromSelectedRowsKernel, double, ops::GetTensorFromSelectedRowsKernel, int, -- GitLab