From 65bcaeb004945440d2595df1b118a54b5b0809bc Mon Sep 17 00:00:00 2001
From: Qi Li
Date: Tue, 2 Mar 2021 17:51:31 +0800
Subject: [PATCH] [ROCM] update fluid operators for rocm (part5), test=develop (#31258)

* [ROCM] update fluid operators for rocm (part5), test=develop

* address review comments, test=develop

* fix typo, test=develop
---
 cmake/hip.cmake                                    |  1 +
 .../cudf/concurrent_unordered_map.cuh.h            |  6 +-
 .../fluid/operators/array_to_lod_tensor_op.cc      |  2 +-
 paddle/fluid/operators/assign_op.cc                |  2 +-
 .../operators/math/bert_encoder_functor.cu         | 37 ++++++--
 .../operators/math/bert_encoder_functor.h          | 15 ++-
 paddle/fluid/operators/math/depthwise_conv.cu      | 15 ++-
 .../math/detail/activation_functions.h             | 95 +++++++++++++++++++
 paddle/fluid/operators/math/fc.cu                  |  2 +-
 paddle/fluid/operators/math/gru_compute.cc         |  8 +-
 paddle/fluid/operators/math/im2col_test.cc         |  2 +-
 paddle/fluid/operators/math/math_cuda_utils.h      | 18 +++-
 paddle/fluid/operators/math/math_function.cc       |  2 +-
 paddle/fluid/operators/math/prelu.cu               |  6 +-
 paddle/fluid/operators/math/prelu.h                | 18 ++--
 paddle/fluid/operators/math/sample_prob.cu         | 14 +++
 paddle/fluid/operators/math/sample_prob.h          |  2 +-
 .../math/selected_rows_functor_test.cu.cc          |  6 ++
 paddle/fluid/operators/math/vol2col_test.cc        |  2 +-
 19 files changed, 214 insertions(+), 39 deletions(-)

diff --git a/cmake/hip.cmake b/cmake/hip.cmake
index 523540c9794..4c492d7cc48 100644
--- a/cmake/hip.cmake
+++ b/cmake/hip.cmake
@@ -45,6 +45,7 @@ set(THRUST_DEVICE_SYSTEM THRUST_DEVICE_SYSTEM_HIP)
 # define HIP_CXX_FLAGS
 list(APPEND HIP_CXX_FLAGS -fPIC)
 list(APPEND HIP_CXX_FLAGS -D__HIP_PLATFORM_HCC__=1)
+# Note(qili93): HIP has compile conflicts of float16.h as platform::float16 overload std::is_floating_point and std::is_integer
 list(APPEND HIP_CXX_FLAGS -D__HIP_NO_HALF_CONVERSIONS__=1)
 list(APPEND HIP_CXX_FLAGS -Wno-macro-redefined)
 list(APPEND HIP_CXX_FLAGS -Wno-inconsistent-missing-override)
diff --git a/paddle/fluid/framework/fleet/heter_ps/cudf/concurrent_unordered_map.cuh.h b/paddle/fluid/framework/fleet/heter_ps/cudf/concurrent_unordered_map.cuh.h
index d14abd218c2..c5647f2cdcf 100644
--- a/paddle/fluid/framework/fleet/heter_ps/cudf/concurrent_unordered_map.cuh.h
+++ b/paddle/fluid/framework/fleet/heter_ps/cudf/concurrent_unordered_map.cuh.h
@@ -737,7 +737,7 @@ x.second );
   }

   int assign_async(const concurrent_unordered_map& other,
-                   gpuStream_t stream = 0) {
+                   cudaStream_t stream = 0) {
     m_collisions = other.m_collisions;
     if (other.m_hashtbl_size <= m_hashtbl_capacity) {
       m_hashtbl_size = other.m_hashtbl_size;
@@ -754,7 +754,7 @@ x.second );
     return 0;
   }

-  void clear_async(gpuStream_t stream = 0) {
+  void clear_async(cudaStream_t stream = 0) {
     constexpr int block_size = 128;
     init_hashtbl<<<((m_hashtbl_size - 1) / block_size) + 1, block_size, 0,
                    stream>>>(m_hashtbl_values, m_hashtbl_size, unused_key,
@@ -771,7 +771,7 @@ x.second );
     }
   }

-  int prefetch(const int dev_id, gpuStream_t stream = 0) {
+  int prefetch(const int dev_id, cudaStream_t stream = 0) {
     cudaPointerAttributes hashtbl_values_ptr_attributes;
     cudaError_t status = cudaPointerGetAttributes(
         &hashtbl_values_ptr_attributes, m_hashtbl_values);
diff --git a/paddle/fluid/operators/array_to_lod_tensor_op.cc b/paddle/fluid/operators/array_to_lod_tensor_op.cc
index 30ac662c567..1680ad528ab 100644
--- a/paddle/fluid/operators/array_to_lod_tensor_op.cc
+++ b/paddle/fluid/operators/array_to_lod_tensor_op.cc
@@ -51,7 +51,7 @@ struct ArrayToLoDFunctor : public boost::static_visitor<void> {
     if (std::is_same<Place, platform::CPUPlace>::value) {
Apply(static_cast(pool.Get(place))); } else { -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) Apply(static_cast(pool.Get(place))); #else PADDLE_THROW( diff --git a/paddle/fluid/operators/assign_op.cc b/paddle/fluid/operators/assign_op.cc index e5bceae1c95..add533bafcb 100644 --- a/paddle/fluid/operators/assign_op.cc +++ b/paddle/fluid/operators/assign_op.cc @@ -164,7 +164,7 @@ REGISTER_OP_CPU_KERNEL_FUNCTOR(assign, float, ops::AssignKernel, double, ops::AssignKernel, plat::float16, ops::AssignKernel); -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) REGISTER_OP_CUDA_KERNEL_FUNCTOR(assign, float, ops::AssignKernel, double, ops::AssignKernel, int, ops::AssignKernel, int64_t, ops::AssignKernel, bool, diff --git a/paddle/fluid/operators/math/bert_encoder_functor.cu b/paddle/fluid/operators/math/bert_encoder_functor.cu index 2373042815c..bd7f71cd131 100644 --- a/paddle/fluid/operators/math/bert_encoder_functor.cu +++ b/paddle/fluid/operators/math/bert_encoder_functor.cu @@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include #include #include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/framework/tensor_util.h" @@ -145,6 +144,8 @@ __global__ void EmbEltwiseLayernormKernel(int hidden, const int64_t *ids, LayerNorm(thread_data, hidden, out_offset, bias, scale, output, eps); } +// HIP defined __HIP_NO_HALF_CONVERSIONS__ in hip.cmake +#ifndef __HIPCC__ // @{ Half kernel: EmbEltwiseLayernormKernel template <> __global__ void EmbEltwiseLayernormKernel( int hidden, const int64_t *ids, const float *scale, const float *bias, @@ -188,12 +189,13 @@ __global__ void EmbEltwiseLayernormKernel( eps); #endif } +#endif // @} End Half kernel: EmbEltwiseLayernormKernel template void EmbEltwiseLayerNormFunctor::operator()( int batch, int seq_len, int hidden, const int64_t *ids, const float *scale, const float *bias, const int64_t *embs, T *output, float eps, int input_num, - cudaStream_t stream) { + gpuStream_t stream) { const unsigned tpb = 256; const dim3 grid(seq_len, batch, 1); const dim3 block(tpb, 1, 1); @@ -205,7 +207,8 @@ void EmbEltwiseLayerNormFunctor::operator()( template class EmbEltwiseLayerNormFunctor; // device function 'operator()' is not supportted until cuda 10.0 -#if CUDA_VERSION >= 10000 +// HIP defined __HIP_NO_HALF_CONVERSIONS__ in hip.cmake +#if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 10000 template class EmbEltwiseLayerNormFunctor; #endif @@ -230,6 +233,8 @@ __global__ void SoftmaxKernelWithEltadd(T *qk_buf_, const T *bias_qk_, qk_buf_[threadIdx.x + qk_offset] = (T)(qk_tmp / sum_val); } +// HIP defined __HIP_NO_HALF_CONVERSIONS__ +#ifndef __HIPCC__ // @{ Half kernel: SoftmaxKernelWithEltadd template <> __global__ void SoftmaxKernelWithEltadd( half *qk_buf_, const half *bias_qk_, const int batch_size, @@ -251,6 +256,7 @@ __global__ void SoftmaxKernelWithEltadd( qk_buf_[threadIdx.x + qk_offset] = (half)(qk_tmp / sum_val); #endif } +#endif // @} End Half kernel: SoftmaxKernelWithEltadd template __global__ void SoftmaxKernelWithEltadd2(T *qk_buf_, const T *bias_qk_, @@ -282,7 +288,9 @@ __global__ void SoftmaxKernelWithEltadd2( half2 *qk_buf_, const half2 *bias_qk_, const int batch_size, const int head_num, const int seq_len, const unsigned mask) { // operator "+" of half only suppotted after cuda version 10.0 -#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__) && 
CUDA_VERSION >= 10000 +// HIP defined __HIP_NO_HALF_CONVERSIONS__ in hip.cmake +#if defined(PADDLE_WITH_CUDA) || \ + (CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__) && CUDA_VERSION >= 10000) int qk_offset = blockIdx.x * seq_len; int idx = threadIdx.x; assert(blockDim.x % 32 == 0); @@ -398,7 +406,8 @@ void MultiHeadGPUComputeFunctor::operator()( template class MultiHeadGPUComputeFunctor; // device function 'operator()' is not supportted until cuda 10.0 -#if CUDA_VERSION >= 10000 +// HIP defined __HIP_NO_HALF_CONVERSIONS__ in hip.cmake +#if defined(PADDLE_WITH_CUDA) || CUDA_VERSION >= 10000 template class MultiHeadGPUComputeFunctor; #endif @@ -422,6 +431,8 @@ __global__ void SkipLayerNormSmallKernel(int num, int hidden, const T *input1, eps); } +// HIP defined __HIP_NO_HALF_CONVERSIONS__ in hip.cmake +#ifndef __HIPCC__ // @{ Half kernel: SkipLayerNormSmallKernel template <> __global__ void SkipLayerNormSmallKernel( int num, int hidden, const half *input1, const half *input2, half *output, @@ -484,6 +495,7 @@ __global__ void SkipLayerNormSmallKernel( eps); #endif } +#endif // @} End Half kernel: SkipLayerNormSmallKernel template __global__ void SkipLayerNormKernel(int num, int hidden, const T *input1, @@ -505,6 +517,8 @@ __global__ void SkipLayerNormKernel(int num, int hidden, const T *input1, LayerNorm(thread_data, hidden, offset, bias, scale, output, eps); } +// HIP defined __HIP_NO_HALF_CONVERSIONS__ in hip.cmake +#ifndef __HIPCC__ // @{ Half kernel: SkipLayerNormKernel template <> __global__ void SkipLayerNormKernel(int num, int hidden, const half *input1, @@ -527,6 +541,7 @@ __global__ void SkipLayerNormKernel(int num, int hidden, LayerNorm(thread_data, hidden, offset, bias, scale, output, eps); #endif } +#endif // @} End Half kernel: SkipLayerNormKernel template __global__ void SkipLayerNormKernel2(int num, int hidden, const T2 *input1, @@ -549,6 +564,8 @@ __global__ void SkipLayerNormKernel2(int num, int hidden, const T2 *input1, LayerNorm2(thread_data, hidden, offset, bias, scale, output, eps); } +// HIP defined __HIP_NO_HALF_CONVERSIONS__ in hip.cmake +#ifndef __HIPCC__ // @{ Half kernel: SkipLayerNormKernel2 template <> __global__ void SkipLayerNormKernel2( int num, int hidden, const half2 *input1, const half2 *input2, @@ -572,13 +589,13 @@ __global__ void SkipLayerNormKernel2( eps); #endif } +#endif // @} End Half kernel: SkipLayerNormKernel2 template void SkipLayerNormFunctor::operator()(const int num, const int hidden, const T *input1, const T *input2, const float *scale, const float *bias, - T *output, T eps, - cudaStream_t stream) { + T *output, T eps, gpuStream_t stream) { int block = num / hidden; if (hidden <= 32) { const int threads = 32; @@ -603,6 +620,8 @@ void SkipLayerNormFunctor::operator()(const int num, const int hidden, reinterpret_cast(output), reinterpret_cast(scale), reinterpret_cast(bias), eps); +// HIP defined __HIP_NO_HALF_CONVERSIONS__ in hip.cmake +#ifndef __HIPCC__ } else if (std::is_same::value) { SkipLayerNormKernel2<__half, __half2, threads><<>>( @@ -611,6 +630,7 @@ void SkipLayerNormFunctor::operator()(const int num, const int hidden, reinterpret_cast<__half2 *>(output), reinterpret_cast(scale), reinterpret_cast(bias), eps); +#endif } else { assert(false); // should not be here @@ -625,7 +645,8 @@ void SkipLayerNormFunctor::operator()(const int num, const int hidden, template class SkipLayerNormFunctor; // device function 'operator()' is not supportted until cuda 10.0 -#if CUDA_VERSION >= 10000 +// HIP defined __HIP_NO_HALF_CONVERSIONS__ in hip.cmake +#if 
defined(PADDLE_WITH_CUDA) || CUDA_VERSION >= 10000 template class SkipLayerNormFunctor; #endif diff --git a/paddle/fluid/operators/math/bert_encoder_functor.h b/paddle/fluid/operators/math/bert_encoder_functor.h index fdbddd96a57..683606ec733 100644 --- a/paddle/fluid/operators/math/bert_encoder_functor.h +++ b/paddle/fluid/operators/math/bert_encoder_functor.h @@ -13,9 +13,18 @@ See the License for the specific language governing permissions and limitations under the License. */ #pragma once + +#ifdef PADDLE_WITH_CUDA #include #include #include // NOLINT +#endif +#ifdef PADDLE_WITH_HIP +#include +#include +namespace cub = hipcub; +#endif + #include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/float16.h" @@ -36,7 +45,7 @@ struct CUDATypeTraits { typedef float TYPE; }; -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) // This functor involves a fusion calculation in Ernie or Bert. // The fusion mode is as follows: // @@ -55,7 +64,7 @@ class EmbEltwiseLayerNormFunctor { public: void operator()(int batch, int seq_len, int hidden, const int64_t *ids, const float *scale, const float *bias, const int64_t *embs, - T *output, float eps, int input_num, cudaStream_t stream); + T *output, float eps, int input_num, gpuStream_t stream); }; // This functor involves a fusion calculation in Ernie or Bert. @@ -97,7 +106,7 @@ class SkipLayerNormFunctor { public: void operator()(const int num, const int hidden, const T *input1, const T *input2, const float *scale, const float *bias, - T *output, T eps, cudaStream_t stream); + T *output, T eps, gpuStream_t stream); }; #endif diff --git a/paddle/fluid/operators/math/depthwise_conv.cu b/paddle/fluid/operators/math/depthwise_conv.cu index 882b914f94f..7439a959d38 100644 --- a/paddle/fluid/operators/math/depthwise_conv.cu +++ b/paddle/fluid/operators/math/depthwise_conv.cu @@ -14,7 +14,13 @@ limitations under the License. 
*/ #include #include -#include "cub/cub.cuh" +#ifdef __NVCC__ +#include +#endif +#ifdef __HIPCC__ +#include +namespace cub = hipcub; +#endif #include "paddle/fluid/operators/math/depthwise_conv.h" #include "paddle/fluid/platform/cuda_device_function.h" #include "paddle/fluid/platform/cuda_primitives.h" @@ -27,7 +33,14 @@ template __device__ __inline__ void CudaAtomicAddWithWarp(T* sum, T value) { typedef cub::WarpReduce WarpReduce; typename WarpReduce::TempStorage temp_storage; + +#ifdef __HIPCC__ + int block_size = min(blockDim.x * blockDim.y * blockDim.z, warpSize); + value = WarpReduce(temp_storage).Sum(value, block_size); +#else value = WarpReduce(temp_storage).Sum(value); +#endif + if (cub::LaneId() == 0) platform::CudaAtomicAdd(sum, value); } diff --git a/paddle/fluid/operators/math/detail/activation_functions.h b/paddle/fluid/operators/math/detail/activation_functions.h index 883ddec8fa1..38bd1a3dadb 100644 --- a/paddle/fluid/operators/math/detail/activation_functions.h +++ b/paddle/fluid/operators/math/detail/activation_functions.h @@ -130,6 +130,8 @@ struct Active { typedef T (*ActGrad)(T, T); }; +#ifdef PADDLE_WITH_CUDA + static DEVICE Active::Act kActFloat[] = { &forward::Sigmoid, &forward::SigmoidV2, &forward::Relu, &forward::Tanh, @@ -171,6 +173,99 @@ inline DEVICE double activation(double a, double b, int index) { } } // namespace backward +#else // PADDLE_WITH_CUDA + +// Note(qili93): The above implementing not work in HIP +// It will throw compile error when calling detail::forward::lstm() +// Which used ActivationType in lstm_kernel.h, compile error is: +// lstm_gpu_kernel.h:33:17: error: unsupported indirect call to function +// + +// To-do(qili93): fix this after HIP issue fixed: +// https://github.com/ROCm-Developer-Tools/HIP/issues/2186 + +namespace forward { +inline DEVICE float activation(float a, int index) { + switch (index) { + case 0: + return Sigmoid(a); + case 1: + return SigmoidV2(a); + case 2: + return Relu(a); + case 3: + return Tanh(a); + case 4: + return TanhV2(a); + case 5: + return Identity(a); + default: + return 0.0f; + } +} + +inline DEVICE double activation(double a, int index) { + switch (index) { + case 0: + return Sigmoid(a); + case 1: + return SigmoidV2(a); + case 2: + return Relu(a); + case 3: + return Tanh(a); + case 4: + return TanhV2(a); + case 5: + return Identity(a); + default: + return 0.0f; + } +} +} // namespace forward + +namespace backward { +inline DEVICE float activation(float a, float b, int index) { + switch (index) { + case 0: + return Sigmoid(a, b); + case 1: + return Sigmoid(a, b); + case 2: + return Relu(a, b); + case 3: + return Tanh(a, b); + case 4: + return Tanh(a, b); + case 5: + return Identity(a, b); + default: + return 0.0f; + } +} + +inline DEVICE double activation(double a, double b, int index) { + switch (index) { + case 0: + return Sigmoid(a, b); + case 1: + return Sigmoid(a, b); + case 2: + return Relu(a, b); + case 3: + return Tanh(a, b); + case 4: + return Tanh(a, b); + case 5: + return Identity(a, b); + default: + return 0.0f; + } +} +} // namespace backward + +#endif // PADDLE_WITH_CUDA + #ifdef __AVX__ namespace forward { namespace avx { diff --git a/paddle/fluid/operators/math/fc.cu b/paddle/fluid/operators/math/fc.cu index 1de3fa44faf..69f62d1d53d 100644 --- a/paddle/fluid/operators/math/fc.cu +++ b/paddle/fluid/operators/math/fc.cu @@ -61,7 +61,7 @@ __global__ void InplaceAddReluKernel(const int N, const T* bias, T* data) { for (int i = threadIdx.x; i < N; i += BlockDim) { T temp; -#if __CUDA_ARCH__ >= 350 
+#if defined(__HIPCC__) || __CUDA_ARCH__ >= 350
       temp = __ldg(data + offset + i) + __ldg(bias + i);
 #else
       temp = data[offset + i] + bias[i];
diff --git a/paddle/fluid/operators/math/gru_compute.cc b/paddle/fluid/operators/math/gru_compute.cc
index 6468296546c..b7a3974ae33 100644
--- a/paddle/fluid/operators/math/gru_compute.cc
+++ b/paddle/fluid/operators/math/gru_compute.cc
@@ -32,7 +32,7 @@ struct GRUUnitFunctor {
                       const detail::ActivationType active_node,
                       const detail::ActivationType active_gate,
                       bool origin_mode) {
-#ifndef __NVCC__
+#if !defined(__NVCC__) && !defined(__HIPCC___)
     auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context);
     if (value.prev_out_value) {
       blas.GEMM(false, false, batch_size, frame_size * 2, frame_size, 1,
@@ -66,7 +66,7 @@ struct GRUUnitGradFunctor {
                       const detail::ActivationType active_node,
                       const detail::ActivationType active_gate,
                       bool origin_mode) {
-#ifndef __NVCC__
+#if !defined(__NVCC__) && !defined(__HIPCC___)
     detail::backward_state_grad(detail::backward::gru_stateGrad<T>(), value,
                                 grad, frame_size, batch_size, active_node,
                                 origin_mode);
@@ -108,7 +108,7 @@ struct GRUUnitFunctorV2 {
                       GRUMetaValue<T> value, int frame_size, int batch_size,
                       const detail::ActivationType active_node,
                       const detail::ActivationType active_gate) {
-#ifndef __NVCC__
+#if !defined(__NVCC__) && !defined(__HIPCC___)
     auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context);
     if (value.prev_out_value) {
       blas.GEMM(CblasNoTrans, CblasTrans, batch_size, frame_size, frame_size, 1,
@@ -142,7 +142,7 @@ struct GRUUnitGradFunctorV2 {
                       int frame_size, int batch_size,
                       const detail::ActivationType active_node,
                       const detail::ActivationType active_gate) {
-#ifndef __NVCC__
+#if !defined(__NVCC__) && !defined(__HIPCC___)
     // calculate grad_update_gate, grad_frame_state,
     // grad_reset_output, grad_reset_gate
     detail::cpu_gru_backward(context, detail::backward::gru<T>(), value, grad,
diff --git a/paddle/fluid/operators/math/im2col_test.cc b/paddle/fluid/operators/math/im2col_test.cc
index e65bda44b3b..0122e6cdeb4 100644
--- a/paddle/fluid/operators/math/im2col_test.cc
+++ b/paddle/fluid/operators/math/im2col_test.cc
@@ -162,7 +162,7 @@ void testIm2col() {
 TEST(math, im2col) {
   testIm2col<paddle::platform::CPUDeviceContext, paddle::platform::CPUPlace>();
-#ifdef PADDLE_WITH_CUDA
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
   testIm2col<paddle::platform::CUDADeviceContext,
              paddle::platform::CUDAPlace>();
 #endif
diff --git a/paddle/fluid/operators/math/math_cuda_utils.h b/paddle/fluid/operators/math/math_cuda_utils.h
index 65961f33aa4..b9afd2d39d0 100644
--- a/paddle/fluid/operators/math/math_cuda_utils.h
+++ b/paddle/fluid/operators/math/math_cuda_utils.h
@@ -13,7 +13,14 @@ See the License for the specific language governing permissions and
 limitations under the License.
*/ #pragma once + +#ifdef PADDLE_WITH_CUDA #include +#endif +#ifdef PADDLE_WITH_HIP +#include +#endif + #include namespace paddle { @@ -96,7 +103,7 @@ __device__ __forceinline__ float exp_func(float a) { template <> __device__ __forceinline__ half exp_func(half a) { -#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__) +#if defined(__HIPCC__) || CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__) return hexp(a); #else return FromFloat(expf(ToFloat(a))); @@ -137,6 +144,7 @@ struct KeyValuePair { operator+(const KeyValuePair &a) const { const half2 a2 = __halves2half2(key, value); const half2 b2 = __halves2half2(a.key, a.value); +#ifdef PADDLE_WITH_CUDA #if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__) const half2 res = __hadd2(a2, b2); #else @@ -149,6 +157,10 @@ struct KeyValuePair { const half2 res = __floats2half2_rn(r1, r2); #endif return KeyValuePair(res.x, res.y); +#else // PADDLE_WITH_HIP + const half2 res = __hadd2(a2, b2); + return KeyValuePair(__low2half(res), __high2half(res)); +#endif } }; @@ -159,7 +171,7 @@ struct KeyValuePair { template __inline__ __device__ T warpReduceSum(T val, unsigned lane_mask) { for (int mask = HALF_WARP; mask > 0; mask >>= 1) -#if __CUDA_ARCH__ >= 350 && CUDA_VERSION >= 9000 +#if defined(PADDLE_WITH_CUDA) && (__CUDA_ARCH__ >= 350 && CUDA_VERSION >= 9000) val += __shfl_xor_sync(lane_mask, val, mask, warpSize); #else val += __shfl_xor(val, mask, warpSize); @@ -191,7 +203,7 @@ __inline__ __device__ T blockReduceSum(T val, unsigned mask) { template __inline__ __device__ T warpReduceMax(T val, unsigned lane_mask) { for (int mask = HALF_WARP; mask > 0; mask >>= 1) -#if __CUDA_ARCH__ >= 350 && CUDA_VERSION >= 9000 +#if defined(PADDLE_WITH_CUDA) && (__CUDA_ARCH__ >= 350 && CUDA_VERSION >= 9000) val = max(val, __shfl_xor_sync(lane_mask, val, mask, warpSize)); #else val = max(val, __shfl_xor(val, mask, warpSize)); diff --git a/paddle/fluid/operators/math/math_function.cc b/paddle/fluid/operators/math/math_function.cc index 5afda787339..a61b50faa75 100644 --- a/paddle/fluid/operators/math/math_function.cc +++ b/paddle/fluid/operators/math/math_function.cc @@ -180,7 +180,7 @@ struct TensorSetConstantWithPlace : public boost::static_visitor { void set_constant(const platform::DeviceContext& context, framework::Tensor* tensor, float value) { TensorSetConstantWithPlace func(context, tensor, value); -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) tensor->place().apply_visitor(func); #else func(platform::CPUPlace()); diff --git a/paddle/fluid/operators/math/prelu.cu b/paddle/fluid/operators/math/prelu.cu index 323c3ad3064..42c4c799c57 100644 --- a/paddle/fluid/operators/math/prelu.cu +++ b/paddle/fluid/operators/math/prelu.cu @@ -61,7 +61,7 @@ __global__ void PReluScalarKernel(const T *input, const T *alpha, T *output, template void PreluChannelWiseDirectCUDAFunctor::operator()( - cudaStream_t stream, const T *input, const T *alpha, T *output, + gpuStream_t stream, const T *input, const T *alpha, T *output, size_t batch_size, size_t channel, size_t numel) { PReluChannelWiseKernel<<>>(input, alpha, output, channel, @@ -69,7 +69,7 @@ void PreluChannelWiseDirectCUDAFunctor::operator()( } template -void PreluElementWiseDirectCUDAFunctor::operator()(cudaStream_t stream, +void PreluElementWiseDirectCUDAFunctor::operator()(gpuStream_t stream, const T *input, const T *alpha, T *output, size_t batch_size, @@ -80,7 +80,7 @@ void PreluElementWiseDirectCUDAFunctor::operator()(cudaStream_t stream, } template -void PreluScalarDirectCUDAFunctor::operator()(cudaStream_t 
stream, +void PreluScalarDirectCUDAFunctor::operator()(gpuStream_t stream, const T *input, const T *alpha, T *output, size_t numel) { PReluScalarKernel<<>>( diff --git a/paddle/fluid/operators/math/prelu.h b/paddle/fluid/operators/math/prelu.h index 93c7035d449..efa493a06c4 100644 --- a/paddle/fluid/operators/math/prelu.h +++ b/paddle/fluid/operators/math/prelu.h @@ -16,32 +16,36 @@ limitations under the License. */ #include #include "paddle/fluid/operators/math/math_function.h" +#ifdef PADDLE_WITH_HIP +#include "paddle/fluid/platform/miopen_helper.h" +#else #include "paddle/fluid/platform/cudnn_helper.h" +#endif namespace paddle { namespace operators { namespace math { -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) template class PreluChannelWiseDirectCUDAFunctor { public: - void operator()(cudaStream_t stream, const T *input, const T *alpha, - T *output, size_t batch_size, size_t channel, size_t numel); + void operator()(gpuStream_t stream, const T *input, const T *alpha, T *output, + size_t batch_size, size_t channel, size_t numel); }; template class PreluElementWiseDirectCUDAFunctor { public: - void operator()(cudaStream_t stream, const T *input, const T *alpha, - T *output, size_t batch_size, size_t numel); + void operator()(gpuStream_t stream, const T *input, const T *alpha, T *output, + size_t batch_size, size_t numel); }; template class PreluScalarDirectCUDAFunctor { public: - void operator()(cudaStream_t stream, const T *input, const T *alpha, - T *output, size_t numel); + void operator()(gpuStream_t stream, const T *input, const T *alpha, T *output, + size_t numel); }; #endif diff --git a/paddle/fluid/operators/math/sample_prob.cu b/paddle/fluid/operators/math/sample_prob.cu index 6aabfb06945..446acc033eb 100644 --- a/paddle/fluid/operators/math/sample_prob.cu +++ b/paddle/fluid/operators/math/sample_prob.cu @@ -142,16 +142,30 @@ void GPUSampleWithProb::operator()( int num_tries = UniqSampler(sampler, num_samples, s_data); VLOG(1) << "num_tries: " << num_tries; + +#ifdef PADDLE_WITH_HIP + PADDLE_ENFORCE_CUDA_SUCCESS(hipMemcpy(samples_data + num_true, s_data, + sizeof(int64_t) * num_samples, + hipMemcpyHostToDevice)); +#else PADDLE_ENFORCE_CUDA_SUCCESS(cudaMemcpy(samples_data + num_true, s_data, sizeof(int64_t) * num_samples, cudaMemcpyHostToDevice)); +#endif int threads = 512; const size_t size = batch_size * num_sampled_classes; int grid = (batch_size * num_sampled_classes + threads - 1) / threads; +#ifdef PADDLE_WITH_HIP + hipLaunchKernelGGL(HIP_KERNEL_NAME(SamplingCondidate), dim3(grid), + dim3(threads), 0, context.stream(), size, num_tries, range, + log_range, num_true, num_samples, label_data, samples_data, + probabilities_data); +#else SamplingCondidate<<>>( size, num_tries, range, log_range, num_true, num_samples, label_data, samples_data, probabilities_data); +#endif } template class GPUSampleWithProb; diff --git a/paddle/fluid/operators/math/sample_prob.h b/paddle/fluid/operators/math/sample_prob.h index 3653ccb693c..8968ba546ad 100644 --- a/paddle/fluid/operators/math/sample_prob.h +++ b/paddle/fluid/operators/math/sample_prob.h @@ -110,7 +110,7 @@ class SampleWithProb { } }; -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) template class GPUSampleWithProb { public: diff --git a/paddle/fluid/operators/math/selected_rows_functor_test.cu.cc b/paddle/fluid/operators/math/selected_rows_functor_test.cu.cc index 5cb1cc5dc03..ebcd97b32c4 100644 --- 
a/paddle/fluid/operators/math/selected_rows_functor_test.cu.cc
+++ b/paddle/fluid/operators/math/selected_rows_functor_test.cu.cc
@@ -37,9 +37,15 @@ TEST(selected_rows_functor, gpu_add) {
           {static_cast<int64_t>(rows1.size()), row_numel}),
       gpu_place);
   functor(ctx, in1_value, 1.0);
+#ifdef PADDLE_WITH_HIP
+  PADDLE_ENFORCE_EQ(hipDeviceSynchronize(), 0,
+                    paddle::platform::errors::PreconditionNotMet(
+                        "The all synchronization on the cuda is error!"));
+#else
   PADDLE_ENFORCE_EQ(cudaDeviceSynchronize(), 0,
                     paddle::platform::errors::PreconditionNotMet(
                         "The all synchronization on the cuda is error!"));
+#endif

   std::vector<int64_t> rows2{0, 5, 7, 9};
   std::unique_ptr<paddle::framework::SelectedRows> selected_rows2{
diff --git a/paddle/fluid/operators/math/vol2col_test.cc b/paddle/fluid/operators/math/vol2col_test.cc
index 6ed5a0943eb..cc3b838cbcf 100644
--- a/paddle/fluid/operators/math/vol2col_test.cc
+++ b/paddle/fluid/operators/math/vol2col_test.cc
@@ -120,7 +120,7 @@ void testVol2col() {
 TEST(math, vol2col) {
   testVol2col<paddle::platform::CPUDeviceContext, paddle::platform::CPUPlace>();
-#ifdef PADDLE_WITH_CUDA
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
   testVol2col<paddle::platform::CUDADeviceContext,
               paddle::platform::CUDAPlace>();
 #endif  // PADDLE_WITH_CUDA
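
The changes above repeat one portability idiom: code that must build for both GPU backends is guarded with "#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)", host-side entry points take gpuStream_t rather than cudaStream_t, hipcub is aliased to cub, and only genuinely divergent runtime calls (hipMemcpy vs. cudaMemcpy, the extra valid-items argument to WarpReduce::Sum, the explicit hipLaunchKernelGGL launch) sit behind "#ifdef PADDLE_WITH_HIP" / "__HIPCC__" branches. The sketch below, which is not part of the patch, shows the idiom in isolation; the kernel, the launcher and the local gpuStream_t alias are illustrative stand-ins for Paddle's own definitions, assuming compilation with either nvcc or hipcc.

// Minimal sketch of the CUDA/HIP dual-backend idiom used throughout this patch.
// Assumptions: built with either nvcc or hipcc; the local gpuStream_t alias
// stands in for the alias Paddle defines in its platform headers; ScaleKernel
// and LaunchScale are illustrative names, not Paddle APIs.
#if defined(__HIPCC__)
#include <hip/hip_runtime.h>
using gpuStream_t = hipStream_t;
#else
#include <cuda_runtime.h>
using gpuStream_t = cudaStream_t;
#endif

// A single kernel source serves both backends; hipcc accepts the same
// __global__ qualifier, built-in index variables and <<<...>>> launch syntax.
__global__ void ScaleKernel(const float* in, float* out, float alpha, int n) {
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx < n) out[idx] = alpha * in[idx];
}

// Host-side launcher exposed with the portable stream type, mirroring how the
// functor signatures above moved from cudaStream_t to gpuStream_t.
void LaunchScale(gpuStream_t stream, const float* in, float* out, float alpha,
                 int n) {
  const int threads = 256;
  const int blocks = (n + threads - 1) / threads;
  ScaleKernel<<<blocks, threads, 0, stream>>>(in, out, alpha, n);
}

Anything beyond this shared surface (device-side half arithmetic, cub-based reductions, the function-pointer dispatch replaced by a switch in activation_functions.h) still needs the explicit per-backend branches shown in the diff.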