Unverified commit 65bcaeb0, authored by Qi Li, committed by GitHub

[ROCM] update fluid operators for rocm (part5), test=develop (#31258)

* [ROCM] update fluid operators for rocm (part5), test=develop

* address review comments, test=develop

* fix typo, test=develop
Parent 2111d912
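The hunks below repeat one portability pattern: GPU-only code is guarded with both the CUDA and HIP macros, and stream parameters are typed gpuStream_t so a single signature serves both runtimes. A minimal sketch of that pattern, assuming Paddle's gpuStream_t alias (the helper below is illustrative, not part of this commit):

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
#ifdef PADDLE_WITH_HIP
#include <hip/hip_runtime.h>
using gpuStream_t = hipStream_t;   // assumption: mirrors Paddle's platform alias
#else
#include <cuda_runtime.h>
using gpuStream_t = cudaStream_t;  // assumption: mirrors Paddle's platform alias
#endif

// A device-agnostic helper can now take one stream type for both toolchains.
template <typename T>
void LaunchOnStream(gpuStream_t stream, const T *in, T *out, size_t n);
#endif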
@@ -45,6 +45,7 @@ set(THRUST_DEVICE_SYSTEM THRUST_DEVICE_SYSTEM_HIP)
 # define HIP_CXX_FLAGS
 list(APPEND HIP_CXX_FLAGS -fPIC)
 list(APPEND HIP_CXX_FLAGS -D__HIP_PLATFORM_HCC__=1)
+# Note(qili93): HIP has compile conflicts of float16.h as platform::float16 overload std::is_floating_point and std::is_integer
 list(APPEND HIP_CXX_FLAGS -D__HIP_NO_HALF_CONVERSIONS__=1)
 list(APPEND HIP_CXX_FLAGS -Wno-macro-redefined)
 list(APPEND HIP_CXX_FLAGS -Wno-inconsistent-missing-override)
...
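Because -D__HIP_NO_HALF_CONVERSIONS__ disables implicit float/half conversions under hipcc, the half-precision kernels later in this diff are fenced with #ifndef __HIPCC__. A small sketch of what the flag changes (an assumption about hipcc behavior, not code from the commit):

#include <hip/hip_fp16.h>

__device__ void HalfConversionExample(float x) {
  // __half h = x;               // rejected by hipcc when the flag is defined
  __half h = __float2half(x);    // explicit conversion still compiles
  float y = __half2float(h);     // explicit conversion back
  (void)y;
}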
@@ -737,7 +737,7 @@ x.second );
   }
 
   int assign_async(const concurrent_unordered_map& other,
-                   gpuStream_t stream = 0) {
+                   cudaStream_t stream = 0) {
     m_collisions = other.m_collisions;
     if (other.m_hashtbl_size <= m_hashtbl_capacity) {
       m_hashtbl_size = other.m_hashtbl_size;
@@ -754,7 +754,7 @@ x.second );
     return 0;
   }
 
-  void clear_async(gpuStream_t stream = 0) {
+  void clear_async(cudaStream_t stream = 0) {
     constexpr int block_size = 128;
     init_hashtbl<<<((m_hashtbl_size - 1) / block_size) + 1, block_size, 0,
                    stream>>>(m_hashtbl_values, m_hashtbl_size, unused_key,
@@ -771,7 +771,7 @@ x.second );
     }
   }
 
-  int prefetch(const int dev_id, gpuStream_t stream = 0) {
+  int prefetch(const int dev_id, cudaStream_t stream = 0) {
     cudaPointerAttributes hashtbl_values_ptr_attributes;
     cudaError_t status = cudaPointerGetAttributes(
         &hashtbl_values_ptr_attributes, m_hashtbl_values);
...
@@ -51,7 +51,7 @@ struct ArrayToLoDFunctor : public boost::static_visitor<void> {
     if (std::is_same<Place, platform::CPUPlace>::value) {
       Apply(static_cast<platform::CPUDeviceContext *>(pool.Get(place)));
     } else {
-#ifdef PADDLE_WITH_CUDA
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
       Apply(static_cast<platform::CUDADeviceContext *>(pool.Get(place)));
 #else
       PADDLE_THROW(
...
@@ -164,7 +164,7 @@ REGISTER_OP_CPU_KERNEL_FUNCTOR(assign, float, ops::AssignKernel, double,
                                ops::AssignKernel, plat::float16,
                                ops::AssignKernel);
 
-#ifdef PADDLE_WITH_CUDA
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
 REGISTER_OP_CUDA_KERNEL_FUNCTOR(assign, float, ops::AssignKernel, double,
                                 ops::AssignKernel, int, ops::AssignKernel,
                                 int64_t, ops::AssignKernel, bool,
...
@@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#include <cuda_runtime.h>
 #include <algorithm>
 #include "paddle/fluid/framework/tensor.h"
 #include "paddle/fluid/framework/tensor_util.h"
@@ -145,6 +144,8 @@ __global__ void EmbEltwiseLayernormKernel(int hidden, const int64_t *ids,
   LayerNorm<T, TPB>(thread_data, hidden, out_offset, bias, scale, output, eps);
 }
 
+// HIP defined __HIP_NO_HALF_CONVERSIONS__ in hip.cmake
+#ifndef __HIPCC__  // @{ Half kernel: EmbEltwiseLayernormKernel
 template <>
 __global__ void EmbEltwiseLayernormKernel<half, 256>(
     int hidden, const int64_t *ids, const float *scale, const float *bias,
@@ -188,12 +189,13 @@ __global__ void EmbEltwiseLayernormKernel<half, 256>(
                           eps);
 #endif
 }
+#endif  // @} End Half kernel: EmbEltwiseLayernormKernel
 
 template <typename T>
 void EmbEltwiseLayerNormFunctor<T>::operator()(
     int batch, int seq_len, int hidden, const int64_t *ids, const float *scale,
     const float *bias, const int64_t *embs, T *output, float eps, int input_num,
-    cudaStream_t stream) {
+    gpuStream_t stream) {
   const unsigned tpb = 256;
   const dim3 grid(seq_len, batch, 1);
   const dim3 block(tpb, 1, 1);
@@ -205,7 +207,8 @@ void EmbEltwiseLayerNormFunctor<T>::operator()(
 template class EmbEltwiseLayerNormFunctor<float>;
 
 // device function 'operator()' is not supportted until cuda 10.0
-#if CUDA_VERSION >= 10000
+// HIP defined __HIP_NO_HALF_CONVERSIONS__ in hip.cmake
+#if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 10000
 template class EmbEltwiseLayerNormFunctor<half>;
 #endif
@@ -230,6 +233,8 @@ __global__ void SoftmaxKernelWithEltadd(T *qk_buf_, const T *bias_qk_,
   qk_buf_[threadIdx.x + qk_offset] = (T)(qk_tmp / sum_val);
 }
 
+// HIP defined __HIP_NO_HALF_CONVERSIONS__
+#ifndef __HIPCC__  // @{ Half kernel: SoftmaxKernelWithEltadd
 template <>
 __global__ void SoftmaxKernelWithEltadd<half>(
     half *qk_buf_, const half *bias_qk_, const int batch_size,
@@ -251,6 +256,7 @@ __global__ void SoftmaxKernelWithEltadd<half>(
   qk_buf_[threadIdx.x + qk_offset] = (half)(qk_tmp / sum_val);
 #endif
 }
+#endif  // @} End Half kernel: SoftmaxKernelWithEltadd
 
 template <typename T>
 __global__ void SoftmaxKernelWithEltadd2(T *qk_buf_, const T *bias_qk_,
@@ -282,7 +288,9 @@ __global__ void SoftmaxKernelWithEltadd2<half2>(
     half2 *qk_buf_, const half2 *bias_qk_, const int batch_size,
     const int head_num, const int seq_len, const unsigned mask) {
 // operator "+" of half only suppotted after cuda version 10.0
-#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__) && CUDA_VERSION >= 10000
+// HIP defined __HIP_NO_HALF_CONVERSIONS__ in hip.cmake
+#if defined(PADDLE_WITH_CUDA) || \
+    (CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__) && CUDA_VERSION >= 10000)
   int qk_offset = blockIdx.x * seq_len;
   int idx = threadIdx.x;
   assert(blockDim.x % 32 == 0);
@@ -398,7 +406,8 @@ void MultiHeadGPUComputeFunctor<T>::operator()(
 template class MultiHeadGPUComputeFunctor<float>;
 
 // device function 'operator()' is not supportted until cuda 10.0
-#if CUDA_VERSION >= 10000
+// HIP defined __HIP_NO_HALF_CONVERSIONS__ in hip.cmake
+#if defined(PADDLE_WITH_CUDA) || CUDA_VERSION >= 10000
 template class MultiHeadGPUComputeFunctor<half>;
 #endif
@@ -422,6 +431,8 @@ __global__ void SkipLayerNormSmallKernel(int num, int hidden, const T *input1,
                                eps);
 }
 
+// HIP defined __HIP_NO_HALF_CONVERSIONS__ in hip.cmake
+#ifndef __HIPCC__  // @{ Half kernel: SkipLayerNormSmallKernel
 template <>
 __global__ void SkipLayerNormSmallKernel<half, 32>(
     int num, int hidden, const half *input1, const half *input2, half *output,
@@ -484,6 +495,7 @@ __global__ void SkipLayerNormSmallKernel<half, 384>(
                                eps);
 #endif
 }
+#endif  // @} End Half kernel: SkipLayerNormSmallKernel
 
 template <typename T, unsigned TPB>
 __global__ void SkipLayerNormKernel(int num, int hidden, const T *input1,
@@ -505,6 +517,8 @@ __global__ void SkipLayerNormKernel(int num, int hidden, const T *input1,
   LayerNorm<T, TPB>(thread_data, hidden, offset, bias, scale, output, eps);
 }
 
+// HIP defined __HIP_NO_HALF_CONVERSIONS__ in hip.cmake
+#ifndef __HIPCC__  // @{ Half kernel: SkipLayerNormKernel
 template <>
 __global__ void SkipLayerNormKernel<half, 256>(int num, int hidden,
                                                const half *input1,
@@ -527,6 +541,7 @@ __global__ void SkipLayerNormKernel<half, 256>(int num, int hidden,
   LayerNorm<half, 256>(thread_data, hidden, offset, bias, scale, output, eps);
 #endif
 }
+#endif  // @} End Half kernel: SkipLayerNormKernel
 
 template <typename T, typename T2, unsigned TPB>
 __global__ void SkipLayerNormKernel2(int num, int hidden, const T2 *input1,
@@ -549,6 +564,8 @@ __global__ void SkipLayerNormKernel2(int num, int hidden, const T2 *input1,
   LayerNorm2<T, T2, TPB>(thread_data, hidden, offset, bias, scale, output, eps);
 }
 
+// HIP defined __HIP_NO_HALF_CONVERSIONS__ in hip.cmake
+#ifndef __HIPCC__  // @{ Half kernel: SkipLayerNormKernel2
 template <>
 __global__ void SkipLayerNormKernel2<half, half2, 256>(
     int num, int hidden, const half2 *input1, const half2 *input2,
@@ -572,13 +589,13 @@ __global__ void SkipLayerNormKernel2<half, half2, 256>(
                           eps);
 #endif
 }
+#endif  // @} End Half kernel: SkipLayerNormKernel2
 
 template <typename T>
 void SkipLayerNormFunctor<T>::operator()(const int num, const int hidden,
                                          const T *input1, const T *input2,
                                          const float *scale, const float *bias,
-                                         T *output, T eps,
-                                         cudaStream_t stream) {
+                                         T *output, T eps, gpuStream_t stream) {
   int block = num / hidden;
   if (hidden <= 32) {
     const int threads = 32;
@@ -603,6 +620,8 @@ void SkipLayerNormFunctor<T>::operator()(const int num, const int hidden,
         reinterpret_cast<float2 *>(output),
         reinterpret_cast<const float2 *>(scale),
         reinterpret_cast<const float2 *>(bias), eps);
+// HIP defined __HIP_NO_HALF_CONVERSIONS__ in hip.cmake
+#ifndef __HIPCC__
   } else if (std::is_same<T, __half>::value) {
     SkipLayerNormKernel2<__half, __half2,
                          threads><<<block, threads, 0, stream>>>(
@@ -611,6 +630,7 @@ void SkipLayerNormFunctor<T>::operator()(const int num, const int hidden,
         reinterpret_cast<__half2 *>(output),
         reinterpret_cast<const float2 *>(scale),
         reinterpret_cast<const float2 *>(bias), eps);
+#endif
   } else {
     assert(false);
     // should not be here
@@ -625,7 +645,8 @@ void SkipLayerNormFunctor<T>::operator()(const int num, const int hidden,
 template class SkipLayerNormFunctor<float>;
 
 // device function 'operator()' is not supportted until cuda 10.0
-#if CUDA_VERSION >= 10000
+// HIP defined __HIP_NO_HALF_CONVERSIONS__ in hip.cmake
+#if defined(PADDLE_WITH_CUDA) || CUDA_VERSION >= 10000
 template class SkipLayerNormFunctor<half>;
 #endif
...
@@ -13,9 +13,18 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #pragma once
+
+#ifdef PADDLE_WITH_CUDA
 #include <cuda.h>
 #include <cuda_runtime.h>
 #include <cub/cub.cuh>  // NOLINT
+#endif
+
+#ifdef PADDLE_WITH_HIP
+#include <hip/hip_runtime.h>
+#include <hipcub/hipcub.hpp>
+namespace cub = hipcub;
+#endif
+
 #include "paddle/fluid/platform/device_context.h"
 #include "paddle/fluid/platform/float16.h"
@@ -36,7 +45,7 @@ struct CUDATypeTraits<float> {
   typedef float TYPE;
 };
 
-#ifdef PADDLE_WITH_CUDA
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
 // This functor involves a fusion calculation in Ernie or Bert.
 // The fusion mode is as follows:
 //
@@ -55,7 +64,7 @@ class EmbEltwiseLayerNormFunctor {
  public:
   void operator()(int batch, int seq_len, int hidden, const int64_t *ids,
                   const float *scale, const float *bias, const int64_t *embs,
-                  T *output, float eps, int input_num, cudaStream_t stream);
+                  T *output, float eps, int input_num, gpuStream_t stream);
 };
 
 // This functor involves a fusion calculation in Ernie or Bert.
@@ -97,7 +106,7 @@ class SkipLayerNormFunctor {
  public:
   void operator()(const int num, const int hidden, const T *input1,
                   const T *input2, const float *scale, const float *bias,
-                  T *output, T eps, cudaStream_t stream);
+                  T *output, T eps, gpuStream_t stream);
 };
 
 #endif
...
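The cub/hipcub aliasing above lets existing cub:: call sites compile unchanged under ROCm. A minimal sketch of the same pattern, assuming hipcub mirrors the CUB API (the kernel below is illustrative, not part of the commit):

#ifdef PADDLE_WITH_HIP
#include <hipcub/hipcub.hpp>
namespace cub = hipcub;
#else
#include <cub/cub.cuh>
#endif

template <typename T, int kBlockDim>
__global__ void BlockSumKernel(const T *in, T *out) {
  // The cub:: spelling resolves to hipcub:: when the alias is in effect.
  typedef cub::BlockReduce<T, kBlockDim> BlockReduce;
  __shared__ typename BlockReduce::TempStorage temp_storage;
  T sum = BlockReduce(temp_storage).Sum(in[blockIdx.x * kBlockDim + threadIdx.x]);
  if (threadIdx.x == 0) out[blockIdx.x] = sum;
}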
@@ -14,7 +14,13 @@ limitations under the License. */
 
 #include <algorithm>
 #include <vector>
-#include "cub/cub.cuh"
+#ifdef __NVCC__
+#include <cub/cub.cuh>
+#endif
+#ifdef __HIPCC__
+#include <hipcub/hipcub.hpp>
+namespace cub = hipcub;
+#endif
 #include "paddle/fluid/operators/math/depthwise_conv.h"
 #include "paddle/fluid/platform/cuda_device_function.h"
 #include "paddle/fluid/platform/cuda_primitives.h"
@@ -27,7 +33,14 @@ template <typename T>
 __device__ __inline__ void CudaAtomicAddWithWarp(T* sum, T value) {
   typedef cub::WarpReduce<T> WarpReduce;
   typename WarpReduce::TempStorage temp_storage;
+
+#ifdef __HIPCC__
+  int block_size = min(blockDim.x * blockDim.y * blockDim.z, warpSize);
+  value = WarpReduce(temp_storage).Sum(value, block_size);
+#else
   value = WarpReduce(temp_storage).Sum(value);
+#endif
+
   if (cub::LaneId() == 0) platform::CudaAtomicAdd(sum, value);
 }
...
@@ -130,6 +130,8 @@ struct Active {
   typedef T (*ActGrad)(T, T);
 };
 
+#ifdef PADDLE_WITH_CUDA
+
 static DEVICE Active<float>::Act kActFloat[] = {
     &forward::Sigmoid<float>, &forward::SigmoidV2<float>,
     &forward::Relu<float>, &forward::Tanh<float>,
@@ -171,6 +173,99 @@ inline DEVICE double activation(double a, double b, int index) {
 }
 }  // namespace backward
 
+#else  // PADDLE_WITH_CUDA
+
+// Note(qili93): The above implementing not work in HIP
+// It will throw compile error when calling detail::forward::lstm<T>()
+// Which used ActivationType in lstm_kernel.h, compile error is:
+// lstm_gpu_kernel.h:33:17: error: unsupported indirect call to function
+// <unknown>
+// To-do(qili93): fix this after HIP issue fixed:
+// https://github.com/ROCm-Developer-Tools/HIP/issues/2186
+
+namespace forward {
+inline DEVICE float activation(float a, int index) {
+  switch (index) {
+    case 0:
+      return Sigmoid<float>(a);
+    case 1:
+      return SigmoidV2<float>(a);
+    case 2:
+      return Relu<float>(a);
+    case 3:
+      return Tanh<float>(a);
+    case 4:
+      return TanhV2<float>(a);
+    case 5:
+      return Identity<float>(a);
+    default:
+      return 0.0f;
+  }
+}
+
+inline DEVICE double activation(double a, int index) {
+  switch (index) {
+    case 0:
+      return Sigmoid<double>(a);
+    case 1:
+      return SigmoidV2<double>(a);
+    case 2:
+      return Relu<double>(a);
+    case 3:
+      return Tanh<double>(a);
+    case 4:
+      return TanhV2<double>(a);
+    case 5:
+      return Identity<double>(a);
+    default:
+      return 0.0f;
+  }
+}
+}  // namespace forward
+
+namespace backward {
+inline DEVICE float activation(float a, float b, int index) {
+  switch (index) {
+    case 0:
+      return Sigmoid<float>(a, b);
+    case 1:
+      return Sigmoid<float>(a, b);
+    case 2:
+      return Relu<float>(a, b);
+    case 3:
+      return Tanh<float>(a, b);
+    case 4:
+      return Tanh<float>(a, b);
+    case 5:
+      return Identity<float>(a, b);
+    default:
+      return 0.0f;
+  }
+}
+
+inline DEVICE double activation(double a, double b, int index) {
+  switch (index) {
+    case 0:
+      return Sigmoid<double>(a, b);
+    case 1:
+      return Sigmoid<double>(a, b);
+    case 2:
+      return Relu<double>(a, b);
+    case 3:
+      return Tanh<double>(a, b);
+    case 4:
+      return Tanh<double>(a, b);
+    case 5:
+      return Identity<double>(a, b);
+    default:
+      return 0.0f;
+  }
+}
+}  // namespace backward
+
+#endif  // PADDLE_WITH_CUDA
+
 #ifdef __AVX__
 namespace forward {
 namespace avx {
...
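The switch-based dispatch above replaces the device function-pointer tables used on the CUDA path because, per the Note, hipcc rejects the resulting indirect call in this context. An illustrative, simplified contrast of the two dispatch styles (hypothetical helper names, not code from the commit):

typedef float (*ActFn)(float);

__device__ float MyRelu(float a) { return a > 0.0f ? a : 0.0f; }
__device__ float MyIdentity(float a) { return a; }

// CUDA-style table dispatch: calling kTable[index] is an indirect call.
static __device__ ActFn kTable[] = {&MyRelu, &MyIdentity};
__device__ float DispatchByTable(float a, int index) { return kTable[index](a); }

// HIP workaround mirrored above: a direct switch, no function pointers involved.
__device__ float DispatchBySwitch(float a, int index) {
  switch (index) {
    case 0:
      return MyRelu(a);
    default:
      return MyIdentity(a);
  }
}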
@@ -61,7 +61,7 @@ __global__ void InplaceAddReluKernel(const int N, const T* bias, T* data) {
 
     for (int i = threadIdx.x; i < N; i += BlockDim) {
       T temp;
-#if __CUDA_ARCH__ >= 350
+#if defined(__HIPCC__) || __CUDA_ARCH__ >= 350
       temp = __ldg(data + offset + i) + __ldg(bias + i);
 #else
       temp = data[offset + i] + bias[i];
...
@@ -32,7 +32,7 @@ struct GRUUnitFunctor<platform::CPUDeviceContext, T> {
                       const detail::ActivationType active_node,
                       const detail::ActivationType active_gate,
                       bool origin_mode) {
-#ifndef __NVCC__
+#if !defined(__NVCC__) && !defined(__HIPCC___)
     auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context);
     if (value.prev_out_value) {
       blas.GEMM(false, false, batch_size, frame_size * 2, frame_size, 1,
@@ -66,7 +66,7 @@ struct GRUUnitGradFunctor<platform::CPUDeviceContext, T> {
                       const detail::ActivationType active_node,
                       const detail::ActivationType active_gate,
                       bool origin_mode) {
-#ifndef __NVCC__
+#if !defined(__NVCC__) && !defined(__HIPCC___)
     detail::backward_state_grad(detail::backward::gru_stateGrad<T>(), value,
                                 grad, frame_size, batch_size, active_node,
                                 origin_mode);
@@ -108,7 +108,7 @@ struct GRUUnitFunctorV2<platform::CPUDeviceContext, T> {
                       GRUMetaValue<T> value, int frame_size, int batch_size,
                       const detail::ActivationType active_node,
                       const detail::ActivationType active_gate) {
-#ifndef __NVCC__
+#if !defined(__NVCC__) && !defined(__HIPCC___)
     auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context);
     if (value.prev_out_value) {
       blas.GEMM(CblasNoTrans, CblasTrans, batch_size, frame_size, frame_size, 1,
@@ -142,7 +142,7 @@ struct GRUUnitGradFunctorV2<platform::CPUDeviceContext, T> {
                       int frame_size, int batch_size,
                       const detail::ActivationType active_node,
                       const detail::ActivationType active_gate) {
-#ifndef __NVCC__
+#if !defined(__NVCC__) && !defined(__HIPCC___)
     // calculate grad_update_gate, grad_frame_state,
     // grad_reset_output, grad_reset_gate
     detail::cpu_gru_backward(context, detail::backward::gru<T>(), value, grad,
...
@@ -162,7 +162,7 @@ void testIm2col() {
 
 TEST(math, im2col) {
   testIm2col<paddle::platform::CPUDeviceContext, paddle::platform::CPUPlace>();
-#ifdef PADDLE_WITH_CUDA
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
   testIm2col<paddle::platform::CUDADeviceContext,
              paddle::platform::CUDAPlace>();
 #endif
...
@@ -13,7 +13,14 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #pragma once
+
+#ifdef PADDLE_WITH_CUDA
 #include <cuda_fp16.h>
+#endif
+
+#ifdef PADDLE_WITH_HIP
+#include <hip/hip_fp16.h>
+#endif
 #include <algorithm>
 
 namespace paddle {
@@ -96,7 +103,7 @@ __device__ __forceinline__ float exp_func<float>(float a) {
 
 template <>
 __device__ __forceinline__ half exp_func<half>(half a) {
-#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
+#if defined(__HIPCC__) || CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
   return hexp(a);
 #else
   return FromFloat<half>(expf(ToFloat<half>(a)));
@@ -137,6 +144,7 @@ struct KeyValuePair<half> {
   operator+(const KeyValuePair &a) const {
     const half2 a2 = __halves2half2(key, value);
     const half2 b2 = __halves2half2(a.key, a.value);
+#ifdef PADDLE_WITH_CUDA
 #if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
     const half2 res = __hadd2(a2, b2);
 #else
@@ -149,6 +157,10 @@ struct KeyValuePair<half> {
     const half2 res = __floats2half2_rn(r1, r2);
 #endif
     return KeyValuePair(res.x, res.y);
+#else  // PADDLE_WITH_HIP
+    const half2 res = __hadd2(a2, b2);
+    return KeyValuePair(__low2half(res), __high2half(res));
+#endif
   }
 };
@@ -159,7 +171,7 @@ struct KeyValuePair<half> {
 template <typename T>
 __inline__ __device__ T warpReduceSum(T val, unsigned lane_mask) {
   for (int mask = HALF_WARP; mask > 0; mask >>= 1)
-#if __CUDA_ARCH__ >= 350 && CUDA_VERSION >= 9000
+#if defined(PADDLE_WITH_CUDA) && (__CUDA_ARCH__ >= 350 && CUDA_VERSION >= 9000)
     val += __shfl_xor_sync(lane_mask, val, mask, warpSize);
 #else
     val += __shfl_xor(val, mask, warpSize);
@@ -191,7 +203,7 @@ __inline__ __device__ T blockReduceSum(T val, unsigned mask) {
 template <typename T>
 __inline__ __device__ T warpReduceMax(T val, unsigned lane_mask) {
   for (int mask = HALF_WARP; mask > 0; mask >>= 1)
-#if __CUDA_ARCH__ >= 350 && CUDA_VERSION >= 9000
+#if defined(PADDLE_WITH_CUDA) && (__CUDA_ARCH__ >= 350 && CUDA_VERSION >= 9000)
     val = max(val, __shfl_xor_sync(lane_mask, val, mask, warpSize));
 #else
     val = max(val, __shfl_xor(val, mask, warpSize));
...
@@ -180,7 +180,7 @@ struct TensorSetConstantWithPlace : public boost::static_visitor<void> {
 void set_constant(const platform::DeviceContext& context,
                   framework::Tensor* tensor, float value) {
   TensorSetConstantWithPlace func(context, tensor, value);
-#ifdef PADDLE_WITH_CUDA
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
   tensor->place().apply_visitor(func);
 #else
   func(platform::CPUPlace());
...
@@ -61,7 +61,7 @@ __global__ void PReluScalarKernel(const T *input, const T *alpha, T *output,
 
 template <typename T>
 void PreluChannelWiseDirectCUDAFunctor<T>::operator()(
-    cudaStream_t stream, const T *input, const T *alpha, T *output,
+    gpuStream_t stream, const T *input, const T *alpha, T *output,
     size_t batch_size, size_t channel, size_t numel) {
   PReluChannelWiseKernel<<<PADDLE_GET_BLOCKS(numel), CUDA_NUM_THREADS, 0,
                            stream>>>(input, alpha, output, channel,
@@ -69,7 +69,7 @@ void PreluChannelWiseDirectCUDAFunctor<T>::operator()(
 }
 
 template <typename T>
-void PreluElementWiseDirectCUDAFunctor<T>::operator()(cudaStream_t stream,
+void PreluElementWiseDirectCUDAFunctor<T>::operator()(gpuStream_t stream,
                                                       const T *input,
                                                       const T *alpha, T *output,
                                                       size_t batch_size,
@@ -80,7 +80,7 @@ void PreluElementWiseDirectCUDAFunctor<T>::operator()(cudaStream_t stream,
 }
 
 template <typename T>
-void PreluScalarDirectCUDAFunctor<T>::operator()(cudaStream_t stream,
+void PreluScalarDirectCUDAFunctor<T>::operator()(gpuStream_t stream,
                                                  const T *input, const T *alpha,
                                                  T *output, size_t numel) {
   PReluScalarKernel<<<PADDLE_GET_BLOCKS(numel), CUDA_NUM_THREADS, 0, stream>>>(
...
@@ -16,32 +16,36 @@ limitations under the License. */
 #include <vector>
 #include "paddle/fluid/operators/math/math_function.h"
+#ifdef PADDLE_WITH_HIP
+#include "paddle/fluid/platform/miopen_helper.h"
+#else
 #include "paddle/fluid/platform/cudnn_helper.h"
+#endif
 
 namespace paddle {
 namespace operators {
 namespace math {
 
-#ifdef PADDLE_WITH_CUDA
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
 template <typename T>
 class PreluChannelWiseDirectCUDAFunctor {
  public:
-  void operator()(cudaStream_t stream, const T *input, const T *alpha,
-                  T *output, size_t batch_size, size_t channel, size_t numel);
+  void operator()(gpuStream_t stream, const T *input, const T *alpha, T *output,
+                  size_t batch_size, size_t channel, size_t numel);
 };
 
 template <typename T>
 class PreluElementWiseDirectCUDAFunctor {
  public:
-  void operator()(cudaStream_t stream, const T *input, const T *alpha,
-                  T *output, size_t batch_size, size_t numel);
+  void operator()(gpuStream_t stream, const T *input, const T *alpha, T *output,
+                  size_t batch_size, size_t numel);
 };
 
 template <typename T>
 class PreluScalarDirectCUDAFunctor {
  public:
-  void operator()(cudaStream_t stream, const T *input, const T *alpha,
-                  T *output, size_t numel);
+  void operator()(gpuStream_t stream, const T *input, const T *alpha, T *output,
+                  size_t numel);
 };
 
 #endif
...
@@ -142,16 +142,30 @@ void GPUSampleWithProb<T>::operator()(
   int num_tries = UniqSampler<T>(sampler, num_samples, s_data);
   VLOG(1) << "num_tries: " << num_tries;
 
+#ifdef PADDLE_WITH_HIP
+  PADDLE_ENFORCE_CUDA_SUCCESS(hipMemcpy(samples_data + num_true, s_data,
+                                        sizeof(int64_t) * num_samples,
+                                        hipMemcpyHostToDevice));
+#else
   PADDLE_ENFORCE_CUDA_SUCCESS(cudaMemcpy(samples_data + num_true, s_data,
                                          sizeof(int64_t) * num_samples,
                                          cudaMemcpyHostToDevice));
+#endif
 
   int threads = 512;
   const size_t size = batch_size * num_sampled_classes;
   int grid = (batch_size * num_sampled_classes + threads - 1) / threads;
+#ifdef PADDLE_WITH_HIP
+  hipLaunchKernelGGL(HIP_KERNEL_NAME(SamplingCondidate<T>), dim3(grid),
+                     dim3(threads), 0, context.stream(), size, num_tries, range,
+                     log_range, num_true, num_samples, label_data, samples_data,
+                     probabilities_data);
+#else
   SamplingCondidate<T><<<grid, threads, 0, context.stream()>>>(
       size, num_tries, range, log_range, num_true, num_samples, label_data,
       samples_data, probabilities_data);
+#endif
 }
 
 template class GPUSampleWithProb<float>;
...
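The HIP branch above launches the same kernel through hipLaunchKernelGGL instead of the triple-chevron syntax. A sketch of that mapping with a made-up kernel (ScaleKernel and LaunchScale are hypothetical; gpuStream_t is assumed to be Paddle's stream alias):

template <typename T>
__global__ void ScaleKernel(T *data, T factor, size_t n) {
  size_t i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) data[i] *= factor;
}

template <typename T>
void LaunchScale(T *data, T factor, size_t n, gpuStream_t stream) {
  int threads = 512;
  int grid = (n + threads - 1) / threads;
#ifdef PADDLE_WITH_HIP
  // Grid/block/shared-mem/stream become explicit arguments; HIP_KERNEL_NAME
  // wraps templated kernels so the launch macro parses them.
  hipLaunchKernelGGL(HIP_KERNEL_NAME(ScaleKernel<T>), dim3(grid), dim3(threads),
                     0, stream, data, factor, n);
#else
  ScaleKernel<T><<<grid, threads, 0, stream>>>(data, factor, n);
#endif
}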
@@ -110,7 +110,7 @@ class SampleWithProb {
   }
 };
 
-#ifdef PADDLE_WITH_CUDA
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
 template <typename T>
 class GPUSampleWithProb {
  public:
...
@@ -37,9 +37,15 @@ TEST(selected_rows_functor, gpu_add) {
                       {static_cast<int64_t>(rows1.size()), row_numel}),
       gpu_place);
   functor(ctx, in1_value, 1.0);
+#ifdef PADDLE_WITH_HIP
+  PADDLE_ENFORCE_EQ(hipDeviceSynchronize(), 0,
+                    paddle::platform::errors::PreconditionNotMet(
+                        "The all synchronization on the cuda is error!"));
+#else
   PADDLE_ENFORCE_EQ(cudaDeviceSynchronize(), 0,
                     paddle::platform::errors::PreconditionNotMet(
                         "The all synchronization on the cuda is error!"));
+#endif
 
   std::vector<int64_t> rows2{0, 5, 7, 9};
   std::unique_ptr<paddle::framework::SelectedRows> selected_rows2{
...
@@ -120,7 +120,7 @@ void testVol2col() {
 
 TEST(math, vol2col) {
   testVol2col<paddle::platform::CPUDeviceContext, paddle::platform::CPUPlace>();
-#ifdef PADDLE_WITH_CUDA
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
   testVol2col<paddle::platform::CUDADeviceContext,
               paddle::platform::CUDAPlace>();
 #endif  // PADDLE_WITH_CUDA
...