From f89da4ab4532461903221bc37f97e916fdefcb3d Mon Sep 17 00:00:00 2001 From: Qi Li Date: Thu, 28 Jan 2021 20:32:14 +0800 Subject: [PATCH] [ROCM] update fluid platform for rocm35 (part1), test=develop (#30639) * [ROCM] update fluid platform for rocm35 (part1), test=develop * address review comments, test=develop --- paddle/fluid/platform/bfloat16.h | 10 + paddle/fluid/platform/complex128.h | 43 ++-- paddle/fluid/platform/complex64.h | 43 ++-- .../details/cuda_transform_iterator_cast.h | 2 +- paddle/fluid/platform/dynload/CMakeLists.txt | 12 +- .../fluid/platform/dynload/dynamic_loader.cc | 27 ++- paddle/fluid/platform/dynload/miopen.h | 10 +- paddle/fluid/platform/dynload/rccl.cc | 8 + paddle/fluid/platform/dynload/rccl.h | 12 + paddle/fluid/platform/dynload/rocblas.h | 84 ++++--- paddle/fluid/platform/dynload/rocm_driver.h | 1 + paddle/fluid/platform/enforce.h | 218 +++++++++++++++++- paddle/fluid/platform/enforce_test.cc | 31 ++- paddle/fluid/platform/float16.h | 161 ++++++++----- paddle/fluid/platform/stream/CMakeLists.txt | 2 +- paddle/fluid/platform/stream/cuda_stream.cc | 30 +++ paddle/fluid/platform/stream/cuda_stream.h | 29 ++- paddle/fluid/platform/type_defs.h | 37 +++ tools/dockerfile/Dockerfile.rocm | 18 +- 19 files changed, 626 insertions(+), 152 deletions(-) create mode 100644 paddle/fluid/platform/type_defs.h diff --git a/paddle/fluid/platform/bfloat16.h b/paddle/fluid/platform/bfloat16.h index 4460139219..f373e5ddb6 100644 --- a/paddle/fluid/platform/bfloat16.h +++ b/paddle/fluid/platform/bfloat16.h @@ -47,7 +47,17 @@ struct PADDLE_ALIGN(2) bfloat16 { ~bfloat16() = default; HOSTDEVICE inline explicit bfloat16(float val) { +#ifdef PADDLE_WITH_HIP + uint32_t res = 0; + uint32_t* tempRes; + // We should be using memcpy in order to respect the strict aliasing rule + // but it fails in the HIP environment. 
+ tempRes = reinterpret_cast(&val); + res = *tempRes; + x = res >> 16; +#else std::memcpy(&x, reinterpret_cast(&val) + 2, 2); +#endif } template diff --git a/paddle/fluid/platform/complex128.h b/paddle/fluid/platform/complex128.h index 58753527c0..c50ff2f810 100644 --- a/paddle/fluid/platform/complex128.h +++ b/paddle/fluid/platform/complex128.h @@ -28,6 +28,11 @@ #include #endif // PADDLE_WITH_CUDA +#ifdef PADDLE_WITH_HIP +#include +#include // NOLINT +#endif + #include #include "paddle/fluid/platform/hostdevice.h" @@ -54,7 +59,7 @@ struct PADDLE_ALIGN(16) complex128 { ~complex128() = default; HOSTDEVICE complex128(double real, double imag) : real(real), imag(imag) {} -#if defined(PADDLE_WITH_CUDA) +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) HOSTDEVICE inline explicit complex128(const thrust::complex& c) { real = c.real(); @@ -65,9 +70,15 @@ struct PADDLE_ALIGN(16) complex128 { return thrust::complex(real, imag); } +#ifdef PADDLE_WITH_HIP + HOSTDEVICE inline explicit operator hipDoubleComplex() const { + return make_hipDoubleComplex(real, imag); + } +#else HOSTDEVICE inline explicit operator cuDoubleComplex() const { return make_cuDoubleComplex(real, imag); } +#endif #endif HOSTDEVICE complex128(const float& val) @@ -202,7 +213,7 @@ struct PADDLE_ALIGN(16) complex128 { HOSTDEVICE inline complex128 operator+(const complex128& a, const complex128& b) { -#if defined(__CUDA_ARCH__) +#if defined(__CUDA_ARCH__) || defined(__HIPCC__) return complex128(thrust::complex(a.real, a.imag) + thrust::complex(b.real, b.imag)); #else @@ -212,7 +223,7 @@ HOSTDEVICE inline complex128 operator+(const complex128& a, HOSTDEVICE inline complex128 operator-(const complex128& a, const complex128& b) { -#if defined(__CUDA_ARCH__) +#if defined(__CUDA_ARCH__) || defined(__HIPCC__) return complex128(thrust::complex(a.real, a.imag) - thrust::complex(b.real, b.imag)); #else @@ -222,7 +233,7 @@ HOSTDEVICE inline complex128 operator-(const complex128& a, HOSTDEVICE inline 
complex128 operator*(const complex128& a, const complex128& b) { -#if defined(__CUDA_ARCH__) +#if defined(__CUDA_ARCH__) || defined(__HIPCC__) return complex128(thrust::complex(a.real, a.imag) * thrust::complex(b.real, b.imag)); #else @@ -233,7 +244,7 @@ HOSTDEVICE inline complex128 operator*(const complex128& a, HOSTDEVICE inline complex128 operator/(const complex128& a, const complex128& b) { -#if defined(__CUDA_ARCH__) +#if defined(__CUDA_ARCH__) || defined(__HIPCC__) return complex128(thrust::complex(a.real, a.imag) / thrust::complex(b.real, b.imag)); #else @@ -244,7 +255,7 @@ HOSTDEVICE inline complex128 operator/(const complex128& a, } HOSTDEVICE inline complex128 operator-(const complex128& a) { -#if defined(__CUDA_ARCH__) +#if defined(__CUDA_ARCH__) || defined(__HIPCC__) return complex128(-thrust::complex(a.real, a.imag)); #else complex128 res; @@ -256,7 +267,7 @@ HOSTDEVICE inline complex128 operator-(const complex128& a) { HOSTDEVICE inline complex128& operator+=(complex128& a, // NOLINT const complex128& b) { -#if defined(__CUDA_ARCH__) +#if defined(__CUDA_ARCH__) || defined(__HIPCC__) a = complex128(thrust::complex(a.real, a.imag) += thrust::complex(b.real, b.imag)); return a; @@ -269,7 +280,7 @@ HOSTDEVICE inline complex128& operator+=(complex128& a, // NOLINT HOSTDEVICE inline complex128& operator-=(complex128& a, // NOLINT const complex128& b) { -#if defined(__CUDA_ARCH__) +#if defined(__CUDA_ARCH__) || defined(__HIPCC__) a = complex128(thrust::complex(a.real, a.imag) -= thrust::complex(b.real, b.imag)); return a; @@ -282,7 +293,7 @@ HOSTDEVICE inline complex128& operator-=(complex128& a, // NOLINT HOSTDEVICE inline complex128& operator*=(complex128& a, // NOLINT const complex128& b) { -#if defined(__CUDA_ARCH__) +#if defined(__CUDA_ARCH__) || defined(__HIPCC__) a = complex128(thrust::complex(a.real, a.imag) *= thrust::complex(b.real, b.imag)); return a; @@ -295,7 +306,7 @@ HOSTDEVICE inline complex128& operator*=(complex128& a, // NOLINT HOSTDEVICE 
inline complex128& operator/=(complex128& a, // NOLINT const complex128& b) { -#if defined(__CUDA_ARCH__) +#if defined(__CUDA_ARCH__) || defined(__HIPCC__) a = complex128(thrust::complex(a.real, a.imag) /= thrust::complex(b.real, b.imag)); return a; @@ -339,6 +350,7 @@ HOSTDEVICE inline bool operator>=(const complex128& a, const complex128& b) { HOSTDEVICE inline bool(isnan)(const complex128& a) { #if defined(__CUDA_ARCH__) + // __isnanf not supported on HIP platform return __isnan(a.real) || __isnan(a.imag); #else return std::isnan(a.real) || std::isnan(a.imag); @@ -347,6 +359,7 @@ HOSTDEVICE inline bool(isnan)(const complex128& a) { HOSTDEVICE inline bool(isinf)(const complex128& a) { #if defined(__CUDA_ARCH__) + // __isinf not supported on HIP platform return __isinf(a.real) || __isinf(a.imag); #else return std::isinf(a.real) || std::isinf(a.imag); @@ -358,7 +371,7 @@ HOSTDEVICE inline bool(isfinite)(const complex128& a) { } HOSTDEVICE inline double(abs)(const complex128& a) { -#if defined(__CUDA_ARCH__) +#if defined(__CUDA_ARCH__) || defined(__HIPCC__) return thrust::abs(thrust::complex(a.real, a.imag)); #else return std::abs(std::complex(a.real, a.imag)); @@ -366,7 +379,7 @@ HOSTDEVICE inline double(abs)(const complex128& a) { } HOSTDEVICE inline complex128(pow)(const complex128& a, const complex128& b) { -#if defined(__CUDA_ARCH__) +#if defined(__CUDA_ARCH__) || defined(__HIPCC__) return complex128(thrust::pow(thrust::complex(a.real, a.imag), thrust::complex(b.real, b.imag))); #else @@ -375,7 +388,7 @@ HOSTDEVICE inline complex128(pow)(const complex128& a, const complex128& b) { } HOSTDEVICE inline complex128(sqrt)(const complex128& a) { -#if defined(__CUDA_ARCH__) +#if defined(__CUDA_ARCH__) || defined(__HIPCC__) return complex128(thrust::sqrt(thrust::complex(a.real, a.imag))); #else return std::sqrt(std::complex(a)); @@ -383,7 +396,7 @@ HOSTDEVICE inline complex128(sqrt)(const complex128& a) { } HOSTDEVICE inline complex128(tanh)(const complex128& a) { -#if 
defined(__CUDA_ARCH__) +#if defined(__CUDA_ARCH__) || defined(__HIPCC__) return complex128(thrust::tanh(thrust::complex(a.real, a.imag))); #else return std::tanh(std::complex(a)); @@ -391,7 +404,7 @@ HOSTDEVICE inline complex128(tanh)(const complex128& a) { } HOSTDEVICE inline complex128(log)(const complex128& a) { -#if defined(__CUDA_ARCH__) +#if defined(__CUDA_ARCH__) || defined(__HIPCC__) return complex128(thrust::log(thrust::complex(a.real, a.imag))); #else return complex128(std::log(std::complex(a))); diff --git a/paddle/fluid/platform/complex64.h b/paddle/fluid/platform/complex64.h index 5f9b3c1118..b91fdbab28 100644 --- a/paddle/fluid/platform/complex64.h +++ b/paddle/fluid/platform/complex64.h @@ -27,6 +27,11 @@ #include #endif // PADDLE_WITH_CUDA +#ifdef PADDLE_WITH_HIP +#include +#include // NOLINT +#endif + #include #include "paddle/fluid/platform/complex128.h" @@ -54,7 +59,7 @@ struct PADDLE_ALIGN(8) complex64 { ~complex64() = default; HOSTDEVICE complex64(float real, float imag) : real(real), imag(imag) {} -#if defined(PADDLE_WITH_CUDA) +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) HOSTDEVICE inline explicit complex64(const thrust::complex& c) { real = c.real(); @@ -65,9 +70,15 @@ struct PADDLE_ALIGN(8) complex64 { return thrust::complex(real, imag); } +#ifdef PADDLE_WITH_HIP + HOSTDEVICE inline explicit operator hipFloatComplex() const { + return make_hipFloatComplex(real, imag); + } +#else HOSTDEVICE inline explicit operator cuFloatComplex() const { return make_cuFloatComplex(real, imag); } +#endif #endif HOSTDEVICE complex64(const float& val) : real(val), imag(0) {} @@ -207,7 +218,7 @@ struct PADDLE_ALIGN(8) complex64 { }; HOSTDEVICE inline complex64 operator+(const complex64& a, const complex64& b) { -#if defined(__CUDA_ARCH__) +#if defined(__CUDA_ARCH__) || defined(__HIPCC__) return complex64(thrust::complex(a.real, a.imag) + thrust::complex(b.real, b.imag)); #else @@ -216,7 +227,7 @@ HOSTDEVICE inline complex64 operator+(const 
complex64& a, const complex64& b) { } HOSTDEVICE inline complex64 operator-(const complex64& a, const complex64& b) { -#if defined(__CUDA_ARCH__) +#if defined(__CUDA_ARCH__) || defined(__HIPCC__) return complex64(thrust::complex(a.real, a.imag) - thrust::complex(b.real, b.imag)); #else @@ -225,7 +236,7 @@ HOSTDEVICE inline complex64 operator-(const complex64& a, const complex64& b) { } HOSTDEVICE inline complex64 operator*(const complex64& a, const complex64& b) { -#if defined(__CUDA_ARCH__) +#if defined(__CUDA_ARCH__) || defined(__HIPCC__) return complex64(thrust::complex(a.real, a.imag) * thrust::complex(b.real, b.imag)); #else @@ -235,7 +246,7 @@ HOSTDEVICE inline complex64 operator*(const complex64& a, const complex64& b) { } HOSTDEVICE inline complex64 operator/(const complex64& a, const complex64& b) { -#if defined(__CUDA_ARCH__) +#if defined(__CUDA_ARCH__) || defined(__HIPCC__) return complex64(thrust::complex(a.real, a.imag) / thrust::complex(b.real, b.imag)); #else @@ -246,7 +257,7 @@ HOSTDEVICE inline complex64 operator/(const complex64& a, const complex64& b) { } HOSTDEVICE inline complex64 operator-(const complex64& a) { -#if defined(__CUDA_ARCH__) +#if defined(__CUDA_ARCH__) || defined(__HIPCC__) return complex64(-thrust::complex(a.real, a.imag)); #else complex64 res; @@ -258,7 +269,7 @@ HOSTDEVICE inline complex64 operator-(const complex64& a) { HOSTDEVICE inline complex64& operator+=(complex64& a, // NOLINT const complex64& b) { -#if defined(__CUDA_ARCH__) +#if defined(__CUDA_ARCH__) || defined(__HIPCC__) a = complex64(thrust::complex(a.real, a.imag) += thrust::complex(b.real, b.imag)); return a; @@ -271,7 +282,7 @@ HOSTDEVICE inline complex64& operator+=(complex64& a, // NOLINT HOSTDEVICE inline complex64& operator-=(complex64& a, // NOLINT const complex64& b) { -#if defined(__CUDA_ARCH__) +#if defined(__CUDA_ARCH__) || defined(__HIPCC__) a = complex64(thrust::complex(a.real, a.imag) -= thrust::complex(b.real, b.imag)); return a; @@ -284,7 +295,7 @@ 
HOSTDEVICE inline complex64& operator-=(complex64& a, // NOLINT HOSTDEVICE inline complex64& operator*=(complex64& a, // NOLINT const complex64& b) { -#if defined(__CUDA_ARCH__) +#if defined(__CUDA_ARCH__) || defined(__HIPCC__) a = complex64(thrust::complex(a.real, a.imag) *= thrust::complex(b.real, b.imag)); return a; @@ -297,7 +308,7 @@ HOSTDEVICE inline complex64& operator*=(complex64& a, // NOLINT HOSTDEVICE inline complex64& operator/=(complex64& a, // NOLINT const complex64& b) { -#if defined(__CUDA_ARCH__) +#if defined(__CUDA_ARCH__) || defined(__HIPCC__) a = complex64(thrust::complex(a.real, a.imag) /= thrust::complex(b.real, b.imag)); return a; @@ -341,6 +352,7 @@ HOSTDEVICE inline bool operator>=(const complex64& a, const complex64& b) { HOSTDEVICE inline bool(isnan)(const complex64& a) { #if defined(__CUDA_ARCH__) + // __isnanf not supported on HIP platform return __isnanf(a.real) || __isnanf(a.imag); #else return std::isnan(a.real) || std::isnan(a.imag); @@ -349,6 +361,7 @@ HOSTDEVICE inline bool(isnan)(const complex64& a) { HOSTDEVICE inline bool(isinf)(const complex64& a) { #if defined(__CUDA_ARCH__) + // __isinff not supported on HIP platform return __isinff(a.real) || __isinff(a.imag); #else return std::isinf(a.real) || std::isinf(a.imag); @@ -360,7 +373,7 @@ HOSTDEVICE inline bool(isfinite)(const complex64& a) { } HOSTDEVICE inline float(abs)(const complex64& a) { -#if defined(__CUDA_ARCH__) +#if defined(__CUDA_ARCH__) || defined(__HIPCC__) return complex64(thrust::abs(thrust::complex(a.real, a.imag))); #else return std::abs(std::complex(a.real, a.imag)); @@ -368,7 +381,7 @@ HOSTDEVICE inline float(abs)(const complex64& a) { } HOSTDEVICE inline complex64(pow)(const complex64& a, const complex64& b) { -#if defined(__CUDA_ARCH__) +#if defined(__CUDA_ARCH__) || defined(__HIPCC__) return complex64(thrust::pow(thrust::complex(a.real, a.imag), thrust::complex(b.real, b.imag))); #else @@ -377,7 +390,7 @@ HOSTDEVICE inline complex64(pow)(const complex64& 
a, const complex64& b) { } HOSTDEVICE inline complex64(sqrt)(const complex64& a) { -#if defined(__CUDA_ARCH__) +#if defined(__CUDA_ARCH__) || defined(__HIPCC__) return complex64(thrust::sqrt(thrust::complex(a.real, a.imag))); #else return std::sqrt(std::complex(a)); @@ -385,7 +398,7 @@ HOSTDEVICE inline complex64(sqrt)(const complex64& a) { } HOSTDEVICE inline complex64(tanh)(const complex64& a) { -#if defined(__CUDA_ARCH__) +#if defined(__CUDA_ARCH__) || defined(__HIPCC__) return complex64(thrust::tanh(thrust::complex(a.real, a.imag))); #else return std::tanh(std::complex(a)); @@ -393,7 +406,7 @@ HOSTDEVICE inline complex64(tanh)(const complex64& a) { } HOSTDEVICE inline complex64(log)(const complex64& a) { -#if defined(__CUDA_ARCH__) +#if defined(__CUDA_ARCH__) || defined(__HIPCC__) return complex64(thrust::log(thrust::complex(a.real, a.imag))); #else return std::log(std::complex(a)); diff --git a/paddle/fluid/platform/details/cuda_transform_iterator_cast.h b/paddle/fluid/platform/details/cuda_transform_iterator_cast.h index 06afc44c25..5101c78aee 100644 --- a/paddle/fluid/platform/details/cuda_transform_iterator_cast.h +++ b/paddle/fluid/platform/details/cuda_transform_iterator_cast.h @@ -14,7 +14,7 @@ limitations under the License. 
*/ #pragma once -#ifndef __NVCC__ +#if !defined(__NVCC__) && !defined(__HIPCC__) #error device_ptr_cast must be include by .cu file #endif diff --git a/paddle/fluid/platform/dynload/CMakeLists.txt b/paddle/fluid/platform/dynload/CMakeLists.txt index 725b7fcf9d..e65a38cd32 100644 --- a/paddle/fluid/platform/dynload/CMakeLists.txt +++ b/paddle/fluid/platform/dynload/CMakeLists.txt @@ -1,9 +1,9 @@ cc_library(dynamic_loader SRCS dynamic_loader.cc DEPS glog gflags enforce) list(APPEND CUDA_SRCS cublas.cc cudnn.cc curand.cc cusolver.cc nvtx.cc) -#hip -if (WITH_ROCM_PLATFORM) - list(APPEND HIP_SRCS rocblas.cc miopen.cc hiprand.cc) + +if (WITH_ROCM) + list(APPEND HIP_SRCS rocblas.cc miopen.cc hiprand.cc) endif() # There is no macOS version of NCCL. @@ -13,7 +13,7 @@ if (NOT APPLE AND NOT WIN32) if (WITH_NCCL) list(APPEND CUDA_SRCS nccl.cc) endif() - if (WITH_ROCM_PLATFORM) + if (WITH_ROCM) list(APPEND HIP_SRCS hiprtc.cc rocm_driver.cc) if (WITH_RCCL) list(APPEND HIP_SRCS rccl.cc) @@ -29,9 +29,9 @@ configure_file(cupti_lib_path.h.in ${CMAKE_CURRENT_BINARY_DIR}/cupti_lib_path.h) if (CUPTI_FOUND) list(APPEND CUDA_SRCS cupti.cc) endif(CUPTI_FOUND) -if(WITH_ROCM_PLATFORM) +if(WITH_ROCM) hip_library(dynload_cuda SRCS ${HIP_SRCS} DEPS dynamic_loader) - hip_library(dynload_warpctc SRCS warpctc.cc DEPS dynamic_loader warpctc) + cc_library(dynload_warpctc SRCS warpctc.cc DEPS dynamic_loader warpctc) else() nv_library(dynload_cuda SRCS ${CUDA_SRCS} DEPS dynamic_loader) cc_library(dynload_warpctc SRCS warpctc.cc DEPS dynamic_loader warpctc) diff --git a/paddle/fluid/platform/dynload/dynamic_loader.cc b/paddle/fluid/platform/dynload/dynamic_loader.cc index e713054468..45616e8bf5 100644 --- a/paddle/fluid/platform/dynload/dynamic_loader.cc +++ b/paddle/fluid/platform/dynload/dynamic_loader.cc @@ -55,7 +55,7 @@ DEFINE_string(miopen_dir, "", DEFINE_string(rocm_dir, "", "Specify path for loading rocm library, such as librocblas, " - "libcurand, libcusolver. For instance, /opt/rocm/lib. 
" + "libmiopen, libhipsparse. For instance, /opt/rocm/lib. " "If default, dlopen will search rocm from LD_LIBRARY_PATH"); DEFINE_string(rccl_dir, "", @@ -264,7 +264,7 @@ void* GetCublasDsoHandle() { #elif defined(_WIN32) && defined(PADDLE_WITH_CUDA) return GetDsoHandleFromSearchPath(FLAGS_cuda_dir, win_cublas_lib, true, {cuda_lib_path}); -#elif PADDLE_WITH_HIP +#elif defined(PADDLE_WITH_HIP) return GetDsoHandleFromSearchPath(FLAGS_rocm_dir, "librocblas.so"); #else return GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "libcublas.so"); @@ -292,7 +292,7 @@ void* GetCUDNNDsoHandle() { "CUDNN version."); return GetDsoHandleFromSearchPath(FLAGS_cudnn_dir, win_cudnn_lib, true, {cuda_lib_path}, win_warn_meg); -#elif PADDLE_WITH_HIP +#elif defined(PADDLE_WITH_HIP) return GetDsoHandleFromSearchPath(FLAGS_miopen_dir, "libMIOpen.so", false); #else return GetDsoHandleFromSearchPath(FLAGS_cudnn_dir, "libcudnn.so", false, @@ -316,7 +316,7 @@ void* GetCurandDsoHandle() { #elif defined(_WIN32) && defined(PADDLE_WITH_CUDA) return GetDsoHandleFromSearchPath(FLAGS_cuda_dir, win_curand_lib, true, {cuda_lib_path}); -#elif PADDLE_WITH_HIP +#elif defined(PADDLE_WITH_HIP) return GetDsoHandleFromSearchPath(FLAGS_rocm_dir, "libhiprand.so"); #else return GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "libcurand.so"); @@ -337,8 +337,8 @@ void* GetCusolverDsoHandle() { void* GetNVRTCDsoHandle() { #if defined(__APPLE__) || defined(__OSX__) return GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "libnvrtc.dylib", false); -#elif PADDLE_WITH_HIP - return GetDsoHandleFromSearchPath(FLAGS_rocm_dir, "libhiprtc.so"); +#elif defined(PADDLE_WITH_HIP) + return GetDsoHandleFromSearchPath(FLAGS_rocm_dir, "libhiprtc.so", false); #else return GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "libnvrtc.so", false); #endif @@ -347,8 +347,8 @@ void* GetNVRTCDsoHandle() { void* GetCUDADsoHandle() { #if defined(__APPLE__) || defined(__OSX__) return GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "libcuda.dylib", false); -#elif 
PADDLE_WITH_HIP - return GetDsoHandleFromSearchPath(FLAGS_rocm_dir, "libhip_hcc.so"); +#elif defined(PADDLE_WITH_HIP) + return GetDsoHandleFromSearchPath(FLAGS_rocm_dir, "libhip_hcc.so", false); #else return GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "libcuda.so", false); #endif @@ -369,15 +369,24 @@ void* GetWarpCTCDsoHandle() { } void* GetNCCLDsoHandle() { +#ifdef PADDLE_WITH_HIP + std::string warning_msg( + "You may need to install 'rccl' from ROCM official website: " + "https://rocmdocs.amd.com/en/latest/Installation_Guide/" + "Installation-Guide.html before install PaddlePaddle."); +#else std::string warning_msg( "You may need to install 'nccl2' from NVIDIA official website: " "https://developer.nvidia.com/nccl/nccl-download" "before install PaddlePaddle."); +#endif + #if defined(__APPLE__) || defined(__OSX__) return GetDsoHandleFromSearchPath(FLAGS_nccl_dir, "libnccl.dylib", true, {}, warning_msg); #elif defined(PADDLE_WITH_HIP) && defined(PADDLE_WITH_RCCL) - return GetDsoHandleFromSearchPath(FLAGS_rccl_dir, "librccl.so", true); + return GetDsoHandleFromSearchPath(FLAGS_rccl_dir, "librccl.so", true, {}, + warning_msg); #else return GetDsoHandleFromSearchPath(FLAGS_nccl_dir, "libnccl.so", true, {}, warning_msg); diff --git a/paddle/fluid/platform/dynload/miopen.h b/paddle/fluid/platform/dynload/miopen.h index 2de6429805..57fec91ffb 100644 --- a/paddle/fluid/platform/dynload/miopen.h +++ b/paddle/fluid/platform/dynload/miopen.h @@ -44,6 +44,8 @@ inline const char* miopenGetErrorString(miopenStatus_t status) { return "MIOPEN_STATUS_INTERNAL_ERROR"; case miopenStatusNotImplemented: return "MIOPEN_STATUS_NOT_IMPLEMENTED"; + case miopenStatusUnsupportedOp: + return "MIOPEN_STATUS_UNSUPPORTED_OP"; case miopenStatusUnknownError: default: return "MIOPEN_STATUS_UNKNOWN_ERROR"; @@ -70,6 +72,7 @@ extern void EnforceCUDNNLoaded(const char* fn_name); * include all needed miopen functions in HPPL **/ #define MIOPEN_DNN_ROUTINE_EACH(__macro) \ + __macro(miopenGetVersion); \ 
__macro(miopenSet4dTensorDescriptor); \ __macro(miopenSetTensorDescriptor); \ __macro(miopenInitConvolutionNdDescriptor); \ @@ -80,6 +83,7 @@ extern void EnforceCUDNNLoaded(const char* fn_name); __macro(miopenGetTensorDescriptor); \ __macro(miopenCreateTensorDescriptor); \ __macro(miopenDestroyTensorDescriptor); \ + __macro(miopenGetTensorDescriptorSize); \ __macro(miopenSet2dPoolingDescriptor); \ __macro(miopenGet2dPoolingDescriptor); \ __macro(miopenGetPoolingNdForwardOutputDim); \ @@ -109,9 +113,12 @@ extern void EnforceCUDNNLoaded(const char* fn_name); __macro(miopenSoftmaxBackward); \ __macro(miopenSoftmaxForward); \ __macro(miopenCreateDropoutDescriptor); \ + __macro(miopenDestroyDropoutDescriptor); \ + __macro(miopenRestoreDropoutDescriptor); \ __macro(miopenDropoutGetStatesSize); \ __macro(miopenSetDropoutDescriptor); \ __macro(miopenCreateRNNDescriptor); \ + __macro(miopenDestroyRNNDescriptor); \ __macro(miopenSetRNNDescriptor); \ __macro(miopenGetRNNParamsSize); \ __macro(miopenGetRNNWorkspaceSize); \ @@ -120,8 +127,7 @@ extern void EnforceCUDNNLoaded(const char* fn_name); __macro(miopenRNNBackwardData); \ __macro(miopenRNNBackwardWeights); \ __macro(miopenRNNForwardInference); \ - __macro(miopenDestroyDropoutDescriptor); \ - __macro(miopenDestroyRNNDescriptor); + __macro(miopenGetTensorNumBytes); MIOPEN_DNN_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_MIOPEN_WRAP) diff --git a/paddle/fluid/platform/dynload/rccl.cc b/paddle/fluid/platform/dynload/rccl.cc index a3043ead83..e19c22ba6d 100644 --- a/paddle/fluid/platform/dynload/rccl.cc +++ b/paddle/fluid/platform/dynload/rccl.cc @@ -25,6 +25,14 @@ void *rccl_dso_handle; RCCL_RAND_ROUTINE_EACH(DEFINE_WRAP); +#if NCCL_VERSION_CODE >= 2212 +RCCL_RAND_ROUTINE_EACH_AFTER_2212(DEFINE_WRAP) +#endif + +#if NCCL_VERSION_CODE >= 2703 +RCCL_RAND_ROUTINE_EACH_AFTER_2703(DEFINE_WRAP) +#endif + } // namespace dynload } // namespace platform } // namespace paddle diff --git a/paddle/fluid/platform/dynload/rccl.h 
b/paddle/fluid/platform/dynload/rccl.h index 1d61e330c2..ac9ab657d5 100644 --- a/paddle/fluid/platform/dynload/rccl.h +++ b/paddle/fluid/platform/dynload/rccl.h @@ -59,6 +59,18 @@ extern void* rccl_dso_handle; RCCL_RAND_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_RCCL_WRAP) +#if NCCL_VERSION_CODE >= 2212 +#define RCCL_RAND_ROUTINE_EACH_AFTER_2212(__macro) __macro(ncclBroadcast); +RCCL_RAND_ROUTINE_EACH_AFTER_2212(DECLARE_DYNAMIC_LOAD_RCCL_WRAP) +#endif + +#if NCCL_VERSION_CODE >= 2703 +#define RCCL_RAND_ROUTINE_EACH_AFTER_2703(__macro) \ + __macro(ncclSend); \ + __macro(ncclRecv); +RCCL_RAND_ROUTINE_EACH_AFTER_2703(DECLARE_DYNAMIC_LOAD_RCCL_WRAP) +#endif + } // namespace dynload } // namespace platform } // namespace paddle diff --git a/paddle/fluid/platform/dynload/rocblas.h b/paddle/fluid/platform/dynload/rocblas.h index f78ed00ac6..45614f2209 100644 --- a/paddle/fluid/platform/dynload/rocblas.h +++ b/paddle/fluid/platform/dynload/rocblas.h @@ -36,12 +36,11 @@ extern void *rocblas_dso_handle; * * note: default dynamic linked libs */ -#define DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP(__name) \ +#define DECLARE_DYNAMIC_LOAD_ROCBLAS_WRAP(__name) \ struct DynLoad__##__name { \ template \ - inline auto operator()(Args... args) -> DECLARE_TYPE(__name, args...) { \ - using rocblas_func = \ - decltype(::__name(std::declval()...)) (*)(Args...); \ + rocblas_status operator()(Args... 
args) { \ + using rocblas_func = decltype(&::__name); \ std::call_once(rocblas_dso_flag, []() { \ rocblas_dso_handle = paddle::platform::dynload::GetCublasDsoHandle(); \ }); \ @@ -51,56 +50,65 @@ extern void *rocblas_dso_handle; }; \ extern DynLoad__##__name __name -#define ROCBLAS_BLAS_ROUTINE_EACH(__macro) \ - __macro(rocblas_saxpy); \ - __macro(rocblas_daxpy); \ - __macro(rocblas_sscal); \ - __macro(rocblas_dscal); \ - __macro(rocblas_scopy); \ - __macro(rocblas_dcopy); \ - __macro(rocblas_sgemv); \ - __macro(rocblas_dgemv); \ - __macro(rocblas_sgemm); \ - __macro(rocblas_dgemm); \ - __macro(rocblas_hgemm); \ - __macro(rocblas_dgeam); \ - /*rocblas_gemm_ex function not support at rocm3.5*/ \ - /*__macro(rocblas_gemm_ex); */ \ - __macro(rocblas_sgemm_batched); \ - __macro(rocblas_dgemm_batched); \ - __macro(rocblas_cgemm_batched); \ - __macro(rocblas_zgemm_batched); \ - __macro(rocblas_create_handle); \ - __macro(rocblas_destroy_handle); \ - __macro(rocblas_add_stream); \ - __macro(rocblas_set_stream); \ - __macro(rocblas_get_stream); \ - __macro(rocblas_set_pointer_mode); \ +#define ROCBLAS_BLAS_ROUTINE_EACH(__macro) \ + __macro(rocblas_caxpy); \ + __macro(rocblas_saxpy); \ + __macro(rocblas_daxpy); \ + __macro(rocblas_zaxpy); \ + __macro(rocblas_sscal); \ + __macro(rocblas_dscal); \ + __macro(rocblas_scopy); \ + __macro(rocblas_dcopy); \ + __macro(rocblas_cgemv); \ + __macro(rocblas_sgemv); \ + __macro(rocblas_zgemv); \ + __macro(rocblas_dgemv); \ + __macro(rocblas_cgemm); \ + __macro(rocblas_sgemm); \ + __macro(rocblas_dgemm); \ + __macro(rocblas_hgemm); \ + __macro(rocblas_zgemm); \ + __macro(rocblas_sgeam); \ + __macro(rocblas_strsm); \ + __macro(rocblas_dtrsm); \ + __macro(rocblas_dgeam); \ + __macro(rocblas_sgemm_batched); \ + __macro(rocblas_dgemm_batched); \ + __macro(rocblas_cgemm_batched); \ + __macro(rocblas_zgemm_batched); \ + __macro(rocblas_create_handle); \ + __macro(rocblas_destroy_handle); \ + __macro(rocblas_set_stream); \ + 
__macro(rocblas_get_stream); \ + __macro(rocblas_set_pointer_mode); \ __macro(rocblas_get_pointer_mode); -ROCBLAS_BLAS_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP) +ROCBLAS_BLAS_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_ROCBLAS_WRAP) +// APIs available after CUDA 8.0 #define ROCBLAS_BLAS_ROUTINE_EACH_R2(__macro) \ + __macro(rocblas_gemm_ex); \ __macro(rocblas_sgemm_strided_batched); \ __macro(rocblas_dgemm_strided_batched); \ __macro(rocblas_cgemm_strided_batched); \ __macro(rocblas_zgemm_strided_batched); \ __macro(rocblas_hgemm_strided_batched); -ROCBLAS_BLAS_ROUTINE_EACH_R2(DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP) +ROCBLAS_BLAS_ROUTINE_EACH_R2(DECLARE_DYNAMIC_LOAD_ROCBLAS_WRAP) -#define ROCBLAS_BLAS_ROUTINE_EACH_R3(__macro) - -ROCBLAS_BLAS_ROUTINE_EACH_R3(DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP) +// HIP not supported in ROCM3.5 +// #define ROCBLAS_BLAS_ROUTINE_EACH_R3(__macro) +// __macro(cublasSetMathMode); +// __macro(cublasGetMathMode); +// ROCBLAS_BLAS_ROUTINE_EACH_R3(DECLARE_DYNAMIC_LOAD_ROCBLAS_WRAP) #define ROCBLAS_BLAS_ROUTINE_EACH_R4(__macro) \ __macro(rocblas_gemm_batched_ex); \ -// rocm not support now(rocm3.5) -// __macro(rocblas_gemm_strided_batched_ex); + __macro(rocblas_gemm_strided_batched_ex); -ROCBLAS_BLAS_ROUTINE_EACH_R4(DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP) +ROCBLAS_BLAS_ROUTINE_EACH_R4(DECLARE_DYNAMIC_LOAD_ROCBLAS_WRAP) -#undef DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP +#undef DECLARE_DYNAMIC_LOAD_ROCBLAS_WRAP } // namespace dynload } // namespace platform } // namespace paddle diff --git a/paddle/fluid/platform/dynload/rocm_driver.h b/paddle/fluid/platform/dynload/rocm_driver.h index dc9c18e732..7633e84c85 100644 --- a/paddle/fluid/platform/dynload/rocm_driver.h +++ b/paddle/fluid/platform/dynload/rocm_driver.h @@ -55,6 +55,7 @@ extern bool HasCUDADriver(); __macro(hipModuleLaunchKernel); \ __macro(hipLaunchKernel); \ __macro(hipGetDevice); \ + __macro(hipGetDeviceCount); \ __macro(hipDevicePrimaryCtxGetState) ROCM_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_ROCM_WRAP); diff --git 
a/paddle/fluid/platform/enforce.h b/paddle/fluid/platform/enforce.h index 0b8a361abb..d873ac619f 100644 --- a/paddle/fluid/platform/enforce.h +++ b/paddle/fluid/platform/enforce.h @@ -34,10 +34,18 @@ limitations under the License. */ #include #include #include - #include "paddle/fluid/platform/cuda_error.pb.h" #endif // PADDLE_WITH_CUDA +#ifdef PADDLE_WITH_HIP +#include +#include +#include +#include +#include // NOLINT +#include "paddle/fluid/platform/cuda_error.pb.h" // NOLINT +#endif + #include #include #include @@ -72,9 +80,23 @@ limitations under the License. */ #endif // __APPLE__ #endif // PADDLE_WITH_CUDA +#ifdef PADDLE_WITH_HIP +#include "paddle/fluid/platform/dynload/hiprand.h" +#include "paddle/fluid/platform/dynload/miopen.h" +#include "paddle/fluid/platform/dynload/rocblas.h" +#if !defined(__APPLE__) && defined(PADDLE_WITH_RCCL) +#include // NOLINT +#include "paddle/fluid/platform/dynload/rccl.h" +#endif // __APPLE__ +#endif // PADDLE_WITH_HIP + // Note: these headers for simplify demangle type string #include "paddle/fluid/framework/type_defs.h" #include "paddle/fluid/imperative/type_defs.h" +// Note: this header for simplify HIP and CUDA type string +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) +#include "paddle/fluid/platform/type_defs.h" +#endif namespace paddle { namespace platform { @@ -82,7 +104,7 @@ class ErrorSummary; } // namespace platform } // namespace paddle -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) DECLARE_int64(gpu_allocator_retry_time); #endif DECLARE_int32(call_stack_level); @@ -406,6 +428,15 @@ struct EnforceNotMet : public std::exception { asm("trap;"); \ } \ } while (0) +#elif defined(__HIPCC__) +#define PADDLE_ENFORCE(_IS_NOT_ERROR, __FORMAT, ...) \ + do { \ + if (!(_IS_NOT_ERROR)) { \ + printf("Error: %s:%d Assertion `%s` failed. " __FORMAT "\n", __FILE__, \ + __LINE__, #_IS_NOT_ERROR, ##__VA_ARGS__); \ + abort(); \ + } \ + } while (0) #else #define PADDLE_ENFORCE(COND, ...) 
\ do { \ @@ -996,5 +1027,188 @@ inline void retry_sleep(unsigned milliseconds) { #undef DEFINE_CUDA_STATUS_TYPE #endif // PADDLE_WITH_CUDA +/** HIP PADDLE ENFORCE FUNCTIONS AND MACROS **/ +#ifdef PADDLE_WITH_HIP + +/***** HIP ERROR *****/ +inline bool is_error(hipError_t e) { return e != hipSuccess; } + +inline std::string build_rocm_error_msg(hipError_t e) { +#if defined(PADDLE_WITH_HIP) + int32_t cuda_version = 100; +#else + int32_t cuda_version = -1; +#endif + std::ostringstream sout; + sout << " Hip error(" << e << "), " << hipGetErrorString(e) << "."; + return sout.str(); +} + +/** HIPRAND ERROR **/ +inline bool is_error(hiprandStatus_t stat) { + return stat != HIPRAND_STATUS_SUCCESS; +} + +inline const char* hiprandGetErrorString(hiprandStatus_t stat) { + switch (stat) { + case HIPRAND_STATUS_SUCCESS: + return "HIPRAND_STATUS_SUCCESS"; + case HIPRAND_STATUS_VERSION_MISMATCH: + return "HIPRAND_STATUS_VERSION_MISMATCH"; + case HIPRAND_STATUS_NOT_INITIALIZED: + return "HIPRAND_STATUS_NOT_INITIALIZED"; + case HIPRAND_STATUS_ALLOCATION_FAILED: + return "HIPRAND_STATUS_ALLOCATION_FAILED"; + case HIPRAND_STATUS_TYPE_ERROR: + return "HIPRAND_STATUS_TYPE_ERROR"; + case HIPRAND_STATUS_OUT_OF_RANGE: + return "HIPRAND_STATUS_OUT_OF_RANGE"; + case HIPRAND_STATUS_LENGTH_NOT_MULTIPLE: + return "HIPRAND_STATUS_LENGTH_NOT_MULTIPLE"; + case HIPRAND_STATUS_DOUBLE_PRECISION_REQUIRED: + return "HIPRAND_STATUS_DOUBLE_PRECISION_REQUIRED"; + case HIPRAND_STATUS_LAUNCH_FAILURE: + return "HIPRAND_STATUS_LAUNCH_FAILURE"; + case HIPRAND_STATUS_PREEXISTING_FAILURE: + return "HIPRAND_STATUS_PREEXISTING_FAILURE"; + case HIPRAND_STATUS_INITIALIZATION_FAILED: + return "HIPRAND_STATUS_INITIALIZATION_FAILED"; + case HIPRAND_STATUS_ARCH_MISMATCH: + return "HIPRAND_STATUS_ARCH_MISMATCH"; + case HIPRAND_STATUS_INTERNAL_ERROR: + return "HIPRAND_STATUS_INTERNAL_ERROR"; + case HIPRAND_STATUS_NOT_IMPLEMENTED: + return "HIPRAND_STATUS_NOT_IMPLEMENTED"; + default: + return "Unknown hiprand status"; + } 
+} + +inline std::string build_rocm_error_msg(hiprandStatus_t stat) { + std::string msg(" Hiprand error, "); + return msg + hiprandGetErrorString(stat) + " "; +} + +/***** MIOPEN ERROR *****/ +inline bool is_error(miopenStatus_t stat) { + return stat != miopenStatusSuccess; +} + +inline std::string build_rocm_error_msg(miopenStatus_t stat) { + std::string msg(" Miopen error, "); + return msg + platform::dynload::miopenGetErrorString(stat) + " "; +} + +/***** ROCBLAS ERROR *****/ +inline bool is_error(rocblas_status stat) { + return stat != rocblas_status_success; +} + +inline const char* rocblasGetErrorString(rocblas_status stat) { + switch (stat) { + case rocblas_status_invalid_handle: + return "rocblas_status_invalid_handle"; + case rocblas_status_memory_error: + return "rocblas_status_memory_error"; + case rocblas_status_invalid_value: + return "rocblas_status_invalid_value"; + case rocblas_status_not_implemented: + return "rocblas_status_not_implemented"; + case rocblas_status_invalid_pointer: + return "rocblas_status_invalid_pointer"; + case rocblas_status_invalid_size: + return "rocblas_status_invalid_size"; + case rocblas_status_internal_error: + return "rocblas_status_internal_error"; + default: + return "Unknown rocblas status"; + } +} + +inline std::string build_rocm_error_msg(rocblas_status stat) { + std::string msg(" Rocblas error, "); + return msg + rocblasGetErrorString(stat) + " "; +} + +/****** RCCL ERROR ******/ +#if !defined(__APPLE__) && defined(PADDLE_WITH_RCCL) +inline bool is_error(ncclResult_t nccl_result) { + return nccl_result != ncclSuccess; +} + +inline std::string build_rocm_error_msg(ncclResult_t nccl_result) { + std::string msg(" Rccl error, "); + return msg + platform::dynload::ncclGetErrorString(nccl_result) + " "; +} +#endif // not(__APPLE__) and PADDLE_WITH_RCCL + +namespace details { + +template +struct CudaStatusType {}; + +#define DEFINE_CUDA_STATUS_TYPE(type, success_value) \ + template <> \ + struct CudaStatusType { \ + using 
Type = type; \ + static constexpr Type kSuccess = success_value; \ + } + +DEFINE_CUDA_STATUS_TYPE(hipError_t, hipSuccess); +DEFINE_CUDA_STATUS_TYPE(hiprandStatus_t, HIPRAND_STATUS_SUCCESS); +DEFINE_CUDA_STATUS_TYPE(miopenStatus_t, miopenStatusSuccess); +DEFINE_CUDA_STATUS_TYPE(rocblas_status, rocblas_status_success); + +#if !defined(__APPLE__) && defined(PADDLE_WITH_RCCL) +DEFINE_CUDA_STATUS_TYPE(ncclResult_t, ncclSuccess); +#endif + +} // namespace details + +#define PADDLE_ENFORCE_CUDA_SUCCESS(COND) \ + do { \ + auto __cond__ = (COND); \ + using __CUDA_STATUS_TYPE__ = decltype(__cond__); \ + constexpr auto __success_type__ = \ + ::paddle::platform::details::CudaStatusType< \ + __CUDA_STATUS_TYPE__>::kSuccess; \ + if (UNLIKELY(__cond__ != __success_type__)) { \ + auto __summary__ = ::paddle::platform::errors::External( \ + ::paddle::platform::build_rocm_error_msg(__cond__)); \ + __THROW_ERROR_INTERNAL__(__summary__); \ + } \ + } while (0) + +inline void retry_sleep(unsigned millisecond) { +#ifdef _WIN32 + Sleep(millisecond); +#else + sleep(millisecond); +#endif +} + +#define PADDLE_RETRY_CUDA_SUCCESS(COND) \ + do { \ + auto __cond__ = (COND); \ + int retry_count = 1; \ + using __CUDA_STATUS_TYPE__ = decltype(__cond__); \ + constexpr auto __success_type__ = \ + ::paddle::platform::details::CudaStatusType< \ + __CUDA_STATUS_TYPE__>::kSuccess; \ + while (UNLIKELY(__cond__ != __success_type__) && retry_count < 5) { \ + retry_sleep(FLAGS_gpu_allocator_retry_time); \ + __cond__ = (COND); \ + ++retry_count; \ + } \ + if (UNLIKELY(__cond__ != __success_type__)) { \ + auto __summary__ = ::paddle::platform::errors::External( \ + ::paddle::platform::build_rocm_error_msg(__cond__)); \ + __THROW_ERROR_INTERNAL__(__summary__); \ + } \ + } while (0) + +#undef DEFINE_CUDA_STATUS_TYPE +#endif // PADDLE_WITH_HIP + } // namespace platform } // namespace paddle diff --git a/paddle/fluid/platform/enforce_test.cc b/paddle/fluid/platform/enforce_test.cc index f086c3f823..549b0d50d9 
100644 --- a/paddle/fluid/platform/enforce_test.cc +++ b/paddle/fluid/platform/enforce_test.cc @@ -295,7 +295,7 @@ TEST(EOF_EXCEPTION, THROW_EOF) { EXPECT_TRUE(caught_eof); } -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) template bool CheckCudaStatusSuccess(T value, const std::string& msg = "success") { PADDLE_ENFORCE_CUDA_SUCCESS(value); @@ -312,7 +312,35 @@ bool CheckCudaStatusFailure(T value, const std::string& msg) { return ex_msg.find(msg) != std::string::npos; } } +#ifdef PADDLE_WITH_HIP +TEST(enforce, hip_success) { + EXPECT_TRUE(CheckCudaStatusSuccess(hipSuccess)); + EXPECT_TRUE(CheckCudaStatusFailure(hipErrorInvalidValue, "Hip error")); + EXPECT_TRUE(CheckCudaStatusFailure(hipErrorOutOfMemory, "Hip error")); + EXPECT_TRUE(CheckCudaStatusSuccess(HIPRAND_STATUS_SUCCESS)); + EXPECT_TRUE( + CheckCudaStatusFailure(HIPRAND_STATUS_VERSION_MISMATCH, "Hiprand error")); + EXPECT_TRUE( + CheckCudaStatusFailure(HIPRAND_STATUS_NOT_INITIALIZED, "Hiprand error")); + + EXPECT_TRUE(CheckCudaStatusSuccess(miopenStatusSuccess)); + EXPECT_TRUE( + CheckCudaStatusFailure(miopenStatusNotInitialized, "Miopen error")); + EXPECT_TRUE(CheckCudaStatusFailure(miopenStatusAllocFailed, "Miopen error")); + + EXPECT_TRUE(CheckCudaStatusSuccess(rocblas_status_success)); + EXPECT_TRUE( + CheckCudaStatusFailure(rocblas_status_invalid_handle, "Rocblas error")); + EXPECT_TRUE( + CheckCudaStatusFailure(rocblas_status_invalid_value, "Rocblas error")); +#if !defined(__APPLE__) && defined(PADDLE_WITH_RCCL) + EXPECT_TRUE(CheckCudaStatusSuccess(ncclSuccess)); + EXPECT_TRUE(CheckCudaStatusFailure(ncclUnhandledCudaError, "Rccl error")); + EXPECT_TRUE(CheckCudaStatusFailure(ncclSystemError, "Rccl error")); +#endif +} +#else TEST(enforce, cuda_success) { EXPECT_TRUE(CheckCudaStatusSuccess(cudaSuccess)); EXPECT_TRUE(CheckCudaStatusFailure(cudaErrorInvalidValue, "Cuda error")); @@ -341,6 +369,7 @@ TEST(enforce, cuda_success) { #endif } #endif +#endif struct 
CannotToStringType { explicit CannotToStringType(int num) : num_(num) {} diff --git a/paddle/fluid/platform/float16.h b/paddle/fluid/platform/float16.h index d4b308e6bc..f57da65179 100644 --- a/paddle/fluid/platform/float16.h +++ b/paddle/fluid/platform/float16.h @@ -20,8 +20,8 @@ limitations under the License. */ #ifdef PADDLE_WITH_CUDA #include #endif // PADDLE_WITH_CUDA + #ifdef PADDLE_WITH_HIP -#define CUDA_VERSION 10000 #include #endif @@ -41,6 +41,7 @@ limitations under the License. */ #define PADDLE_CUDA_FP16 #include #endif + #ifdef __HIPCC__ #define PADDLE_CUDA_FP16 #include @@ -90,7 +91,7 @@ struct PADDLE_ALIGN(2) float16 { #ifdef PADDLE_CUDA_FP16 HOSTDEVICE inline explicit float16(const half& h) { #if (defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)) -#if CUDA_VERSION >= 9000 +#if defined(PADDLE_WITH_HIP) || CUDA_VERSION >= 9000 x = reinterpret_cast<__half_raw*>(const_cast(&h))->x; #else x = h.x; @@ -110,9 +111,8 @@ struct PADDLE_ALIGN(2) float16 { #endif HOSTDEVICE inline explicit float16(float val) { -#if ((defined(PADDLE_CUDA_FP16)) && \ - ((defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 300) || \ - (defined(__HIP_DEVICE_COMPILE__)))) +#if defined(PADDLE_CUDA_FP16) && \ + (defined(__HIPCC__) || (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 300)) half tmp = __float2half(val); x = *reinterpret_cast(&tmp); @@ -154,7 +154,7 @@ struct PADDLE_ALIGN(2) float16 { // Assignment operators #ifdef PADDLE_CUDA_FP16 HOSTDEVICE inline float16& operator=(const half& rhs) { -#if CUDA_VERSION >= 9000 +#if defined(PADDLE_WITH_HIP) || CUDA_VERSION >= 9000 x = reinterpret_cast<__half_raw*>(const_cast(&rhs))->x; #else x = rhs.x; @@ -233,7 +233,7 @@ struct PADDLE_ALIGN(2) float16 { // Conversion opertors #ifdef PADDLE_CUDA_FP16 HOSTDEVICE inline explicit operator half() const { -#if CUDA_VERSION >= 9000 +#if defined(PADDLE_WITH_HIP) || CUDA_VERSION >= 9000 __half_raw h; h.x = x; return half(h); @@ -258,9 +258,8 @@ struct PADDLE_ALIGN(2) float16 { #endif HOSTDEVICE inline 
explicit operator float() const { -#if (defined(PADDLE_CUDA_FP16) && \ - ((defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 300) || \ - (defined(__HIP_DEVICE_COMPILE__)))) +#if defined(PADDLE_CUDA_FP16) && \ + (defined(__HIPCC__) || (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 300)) half tmp = *reinterpret_cast(this); return __half2float(tmp); @@ -370,8 +369,7 @@ struct PADDLE_ALIGN(2) float16 { // xuan[TODO] change for rocm #if defined(PADDLE_CUDA_FP16) && CUDA_VERSION < 9000 DEVICE inline half operator+(const half& a, const half& b) { -#if ((defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530) || \ - (defined(__HIP_DEVICE_COMPILE__))) +#if defined(__HIPCC__) || (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530) return __hadd(a, b); #else float res = static_cast(float16(a)) + static_cast(float16(b)); @@ -380,8 +378,7 @@ DEVICE inline half operator+(const half& a, const half& b) { } DEVICE inline half operator-(const half& a, const half& b) { -#if ((defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530) || \ - (defined(__HIP_DEVICE_COMPILE__))) +#if defined(__HIPCC__) || (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530) return __hsub(a, b); #else float res = static_cast(float16(a)) - static_cast(float16(b)); @@ -390,8 +387,7 @@ DEVICE inline half operator-(const half& a, const half& b) { } DEVICE inline half operator*(const half& a, const half& b) { -#if ((defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530) || \ - (defined(__HIP_DEVICE_COMPILE__))) +#if defined(__HIPCC__) || (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530) return __hmul(a, b); #else float res = static_cast(float16(a)) * static_cast(float16(b)); @@ -400,8 +396,7 @@ DEVICE inline half operator*(const half& a, const half& b) { } DEVICE inline half operator/(const half& a, const half& b) { -#if ((defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530) || \ - (defined(__HIP_DEVICE_COMPILE__))) +#if defined(__HIPCC__) || (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530) float num = __half2float(a); float denom = __half2float(b); return 
__float2half(num / denom); @@ -412,8 +407,7 @@ DEVICE inline half operator/(const half& a, const half& b) { } DEVICE inline half operator-(const half& a) { -#if ((defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530) || \ - (defined(__HIP_DEVICE_COMPILE__))) +#if defined(__HIPCC__) || (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530) return __hneg(a); #else float res = -static_cast(float16(a)); @@ -421,6 +415,7 @@ DEVICE inline half operator-(const half& a) { #endif } +#ifndef PADDLE_WITH_HIP // not defined __HIP_NO_HALF_OPERATORS__ DEVICE inline half& operator+=(half& a, const half& b) { // NOLINT a = a + b; return a; @@ -440,10 +435,10 @@ DEVICE inline half& operator/=(half& a, const half& b) { // NOLINT a = a / b; return a; } +#endif DEVICE inline bool operator==(const half& a, const half& b) { -#if ((defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530) || \ - (defined(__HIP_DEVICE_COMPILE__))) +#if defined(__HIPCC__) || (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530) return __heq(a, b); #else return static_cast(float16(a)) == static_cast(float16(b)); @@ -451,8 +446,7 @@ DEVICE inline bool operator==(const half& a, const half& b) { } DEVICE inline bool operator!=(const half& a, const half& b) { -#if ((defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530) || \ - (defined(__HIP_DEVICE_COMPILE__))) +#if defined(__HIPCC__) || (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530) return __hne(a, b); #else return static_cast(float16(a)) != static_cast(float16(b)); @@ -460,8 +454,7 @@ DEVICE inline bool operator!=(const half& a, const half& b) { } DEVICE inline bool operator<(const half& a, const half& b) { -#if ((defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530) || \ - (defined(__HIP_DEVICE_COMPILE__))) +#if defined(__HIPCC__) || (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530) return __hlt(a, b); #else return static_cast(float16(a)) < static_cast(float16(b)); @@ -469,8 +462,7 @@ DEVICE inline bool operator<(const half& a, const half& b) { } DEVICE inline bool operator<=(const half& a, const 
half& b) { -#if ((defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530) || \ - (defined(__HIP_DEVICE_COMPILE__))) +#if defined(__HIPCC__) || (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530) return __hle(a, b); #else return static_cast(float16(a)) <= static_cast(float16(b)); @@ -478,8 +470,7 @@ DEVICE inline bool operator<=(const half& a, const half& b) { } DEVICE inline bool operator>(const half& a, const half& b) { -#if ((defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530) || \ - (defined(__HIP_DEVICE_COMPILE__))) +#if defined(__HIPCC__) || (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530) return __hgt(a, b); #else return static_cast(float16(a)) > static_cast(float16(b)); @@ -487,8 +478,7 @@ DEVICE inline bool operator>(const half& a, const half& b) { } DEVICE inline bool operator>=(const half& a, const half& b) { -#if ((defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530) || \ - (defined(__HIP_DEVICE_COMPILE__))) +#if defined(__HIPCC__) || (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530) return __hge(a, b); #else return static_cast(float16(a)) >= static_cast(float16(b)); @@ -499,36 +489,66 @@ DEVICE inline bool operator>=(const half& a, const half& b) { // Arithmetic operators for float16 on GPU #if defined(PADDLE_CUDA_FP16) + +// HIPCC has compile error if call __device__ function __hadd in __host__ +// __device__ function +#if defined(__HIPCC__) +DEVICE inline float16 operator+(const float16& a, const float16& b) { + return float16(__hadd(half(a), half(b))); +} +HOST inline float16 operator+(const float16& a, const float16& b) { + return float16(static_cast(a) + static_cast(b)); +} +#else HOSTDEVICE inline float16 operator+(const float16& a, const float16& b) { -#if ((defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530) || \ - (defined(__HIP_DEVICE_COMPILE__))) +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530 return float16(__hadd(half(a), half(b))); #else return float16(static_cast(a) + static_cast(b)); #endif } +#endif +// HIPCC has compile error if call __device__ function 
__hsub in __host__ +// __device__ function +#if defined(__HIPCC__) +DEVICE inline float16 operator-(const float16& a, const float16& b) { + return float16(__hsub(half(a), half(b))); +} +HOST inline float16 operator-(const float16& a, const float16& b) { + return float16(static_cast(a) - static_cast(b)); +} +#else HOSTDEVICE inline float16 operator-(const float16& a, const float16& b) { -#if ((defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530) || \ - (defined(__HIP_DEVICE_COMPILE__))) +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530 return float16(__hsub(half(a), half(b))); #else return float16(static_cast(a) - static_cast(b)); #endif } +#endif +// HIPCC has compile error if call __device__ function __hmul in __host__ +// __device__ function +#if defined(__HIPCC__) +DEVICE inline float16 operator*(const float16& a, const float16& b) { + return float16(__hmul(half(a), half(b))); +} +HOST inline float16 operator*(const float16& a, const float16& b) { + return float16(static_cast(a) * static_cast(b)); +} +#else HOSTDEVICE inline float16 operator*(const float16& a, const float16& b) { -#if ((defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530) || \ - (defined(__HIP_DEVICE_COMPILE__))) +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530 return float16(__hmul(half(a), half(b))); #else return float16(static_cast(a) * static_cast(b)); #endif } +#endif HOSTDEVICE inline float16 operator/(const float16& a, const float16& b) { -#if ((defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 300) || \ - (defined(__HIP_DEVICE_COMPILE__))) +#if defined(__HIPCC__) || (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530) // TODO(kexinzhao): check which cuda version starts to support __hdiv float num = __half2float(half(a)); float denom = __half2float(half(b)); @@ -538,9 +558,20 @@ HOSTDEVICE inline float16 operator/(const float16& a, const float16& b) { #endif } +// HIPCC has compile error if call __device__ function __hneg in __host__ +// __device__ function +#if defined(__HIPCC__) +DEVICE inline float16 
operator-(const float16& a) { + return float16(__hneg(half(a))); +} +HOST inline float16 operator-(const float16& a) { + float16 res; + res.x = a.x ^ 0x8000; + return res; +} +#else HOSTDEVICE inline float16 operator-(const float16& a) { -#if ((defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530) || \ - (defined(__HIP_DEVICE_COMPILE__))) +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530 return float16(__hneg(half(a))); #else float16 res; @@ -548,6 +579,7 @@ HOSTDEVICE inline float16 operator-(const float16& a) { return res; #endif } +#endif HOSTDEVICE inline float16& operator+=(float16& a, const float16& b) { // NOLINT a = a + b; @@ -569,18 +601,27 @@ HOSTDEVICE inline float16& operator/=(float16& a, const float16& b) { // NOLINT return a; } +// HIPCC has compile error if call __device__ function __heq in __host__ +// __device__ function +#if defined(__HIPCC__) +DEVICE inline bool operator==(const float16& a, const float16& b) { + return __heq(half(a), half(b)); +} +HOST inline bool operator==(const float16& a, const float16& b) { + return static_cast(a) == static_cast(b); +} +#else // CUDA HOSTDEVICE inline bool operator==(const float16& a, const float16& b) { -#if ((defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530) || \ - (defined(__HIP_DEVICE_COMPILE__))) +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530 return __heq(half(a), half(b)); #else return static_cast(a) == static_cast(b); #endif } +#endif HOSTDEVICE inline bool operator!=(const float16& a, const float16& b) { -#if ((defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530) || \ - (defined(__HIP_DEVICE_COMPILE__))) +#if defined(__HIPCC__) || (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530) return __hne(half(a), half(b)); #else return static_cast(a) != static_cast(b); @@ -588,8 +629,7 @@ HOSTDEVICE inline bool operator!=(const float16& a, const float16& b) { } HOSTDEVICE inline bool operator<(const float16& a, const float16& b) { -#if ((defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530) || \ - 
(defined(__HIP_DEVICE_COMPILE__))) +#if defined(__HIPCC__) || (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530) return __hlt(half(a), half(b)); #else return static_cast(a) < static_cast(b); @@ -597,8 +637,7 @@ HOSTDEVICE inline bool operator<(const float16& a, const float16& b) { } HOSTDEVICE inline bool operator<=(const float16& a, const float16& b) { -#if ((defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530) || \ - (defined(__HIP_DEVICE_COMPILE__))) +#if defined(__HIPCC__) || (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530) return __hle(half(a), half(b)); #else return static_cast(a) <= static_cast(b); @@ -606,8 +645,7 @@ HOSTDEVICE inline bool operator<=(const float16& a, const float16& b) { } HOSTDEVICE inline bool operator>(const float16& a, const float16& b) { -#if ((defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530) || \ - (defined(__HIP_DEVICE_COMPILE__))) +#if defined(__HIPCC__) || (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530) return __hgt(half(a), half(b)); #else return static_cast(a) > static_cast(b); @@ -615,8 +653,7 @@ HOSTDEVICE inline bool operator>(const float16& a, const float16& b) { } HOSTDEVICE inline bool operator>=(const float16& a, const float16& b) { -#if ((defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530) || \ - (defined(__HIP_DEVICE_COMPILE__))) +#if defined(__HIPCC__) || (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530) return __hge(half(a), half(b)); #else return static_cast(a) >= static_cast(b); @@ -881,15 +918,20 @@ HOSTDEVICE inline float16 raw_uint16_to_float16(uint16_t a) { return res; } +// HIPCC has compile error if call __device__ function __hisnan in __host__ +// __device__ function +#if defined(PADDLE_CUDA_FP16) && defined(__HIPCC__) +DEVICE inline bool(isnan)(const float16& a) { return __hisnan(half(a)); } +HOST inline bool(isnan)(const float16& a) { return (a.x & 0x7fff) > 0x7c00; } +#else HOSTDEVICE inline bool(isnan)(const float16& a) { -#if (defined(PADDLE_CUDA_FP16) && \ - ((defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530) || \ - 
(defined(__HIP_DEVICE_COMPILE__)))) +#if defined(PADDLE_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530 return __hisnan(half(a)); #else return (a.x & 0x7fff) > 0x7c00; #endif } +#endif HOSTDEVICE inline bool(isinf)(const float16& a) { return (a.x & 0x7fff) == 0x7c00; @@ -900,9 +942,8 @@ HOSTDEVICE inline bool(isfinite)(const float16& a) { } HOSTDEVICE inline float16(abs)(const float16& a) { -#if (defined(PADDLE_CUDA_FP16) && \ - ((defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530) || \ - (defined(__HIP_DEVICE_COMPILE__)))) +#if defined(PADDLE_CUDA_FP16) && \ + (defined(__HIPCC__) || (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530)) return float16(::fabs(static_cast(a))); #else return float16(std::abs(static_cast(a))); diff --git a/paddle/fluid/platform/stream/CMakeLists.txt b/paddle/fluid/platform/stream/CMakeLists.txt index 78a7313bde..c0595eb415 100644 --- a/paddle/fluid/platform/stream/CMakeLists.txt +++ b/paddle/fluid/platform/stream/CMakeLists.txt @@ -1,3 +1,3 @@ -IF(WITH_GPU) +IF(WITH_GPU OR WITH_ROCM) cc_library(cuda_stream SRCS cuda_stream.cc DEPS enforce boost) ENDIF() diff --git a/paddle/fluid/platform/stream/cuda_stream.cc b/paddle/fluid/platform/stream/cuda_stream.cc index 4543f367ba..fc51a08c2a 100644 --- a/paddle/fluid/platform/stream/cuda_stream.cc +++ b/paddle/fluid/platform/stream/cuda_stream.cc @@ -20,7 +20,11 @@ namespace paddle { namespace platform { namespace stream { +#ifdef PADDLE_WITH_HIP +constexpr unsigned int kDefaultFlag = hipStreamDefault; +#else constexpr unsigned int kDefaultFlag = cudaStreamDefault; +#endif bool CUDAStream::Init(const Place& place, const Priority& priority) { PADDLE_ENFORCE_EQ(is_gpu_place(place), true, @@ -29,11 +33,21 @@ bool CUDAStream::Init(const Place& place, const Priority& priority) { place_ = place; CUDADeviceGuard guard(BOOST_GET_CONST(CUDAPlace, place_).device); if (priority == Priority::kHigh) { +#ifdef PADDLE_WITH_HIP + PADDLE_ENFORCE_CUDA_SUCCESS( + hipStreamCreateWithPriority(&stream_, 
kDefaultFlag, -1)); +#else PADDLE_ENFORCE_CUDA_SUCCESS( cudaStreamCreateWithPriority(&stream_, kDefaultFlag, -1)); +#endif } else if (priority == Priority::kNormal) { +#ifdef PADDLE_WITH_HIP + PADDLE_ENFORCE_CUDA_SUCCESS( + hipStreamCreateWithPriority(&stream_, kDefaultFlag, 0)); +#else PADDLE_ENFORCE_CUDA_SUCCESS( cudaStreamCreateWithPriority(&stream_, kDefaultFlag, 0)); +#endif } callback_manager_.reset(new StreamCallbackManager(stream_)); VLOG(3) << "CUDAStream Init stream: " << stream_ @@ -46,12 +60,27 @@ void CUDAStream::Destroy() { Wait(); WaitCallback(); if (stream_) { +#ifdef PADDLE_WITH_HIP + PADDLE_ENFORCE_CUDA_SUCCESS(hipStreamDestroy(stream_)); +#else PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamDestroy(stream_)); +#endif } stream_ = nullptr; } void CUDAStream::Wait() const { +#ifdef PADDLE_WITH_HIP + hipError_t e_sync = hipSuccess; +#if !defined(_WIN32) + e_sync = hipStreamSynchronize(stream_); +#else + while (e_sync = hipStreamQuery(stream_)) { + if (e_sync == hipErrorNotReady) continue; + break; + } +#endif +#else cudaError_t e_sync = cudaSuccess; #if !defined(_WIN32) e_sync = cudaStreamSynchronize(stream_); @@ -61,6 +90,7 @@ void CUDAStream::Wait() const { break; } #endif +#endif // PADDLE_WITH_HIP PADDLE_ENFORCE_CUDA_SUCCESS(e_sync); } diff --git a/paddle/fluid/platform/stream/cuda_stream.h b/paddle/fluid/platform/stream/cuda_stream.h index c65d107cf4..d937549251 100644 --- a/paddle/fluid/platform/stream/cuda_stream.h +++ b/paddle/fluid/platform/stream/cuda_stream.h @@ -26,7 +26,7 @@ namespace paddle { namespace platform { namespace stream { -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) enum class Priority : uint8_t { kNull = 0x0, @@ -51,28 +51,55 @@ class CUDAStream final { } template +#ifdef PADDLE_WITH_HIP + void RecordEvent(hipEvent_t ev, Callback callback) const { + callback(); + PADDLE_ENFORCE_CUDA_SUCCESS(hipEventRecord(ev, stream_)); + } +#else void RecordEvent(cudaEvent_t ev, Callback callback) const { 
callback(); PADDLE_ENFORCE_CUDA_SUCCESS(cudaEventRecord(ev, stream_)); } +#endif +#ifdef PADDLE_WITH_HIP + void RecordEvent(hipEvent_t ev) const { + PADDLE_ENFORCE_CUDA_SUCCESS(hipEventRecord(ev, stream_)); + } +#else void RecordEvent(cudaEvent_t ev) const { PADDLE_ENFORCE_CUDA_SUCCESS(cudaEventRecord(ev, stream_)); } +#endif +#ifdef PADDLE_WITH_HIP + void WaitEvent(hipEvent_t ev) const { + PADDLE_ENFORCE_CUDA_SUCCESS(hipStreamWaitEvent(stream_, ev, 0)); + } +#else void WaitEvent(cudaEvent_t ev) const { PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamWaitEvent(stream_, ev, 0)); } +#endif void Wait() const; void WaitCallback() const { callback_manager_->Wait(); } +#ifdef PADDLE_WITH_HIP + const hipStream_t& raw_stream() const { return stream_; } +#else const cudaStream_t& raw_stream() const { return stream_; } +#endif void Destroy(); private: Place place_; +#ifdef PADDLE_WITH_HIP + hipStream_t stream_{nullptr}; +#else cudaStream_t stream_{nullptr}; +#endif Priority priority_{Priority::kNormal}; std::unique_ptr callback_manager_; diff --git a/paddle/fluid/platform/type_defs.h b/paddle/fluid/platform/type_defs.h new file mode 100644 index 0000000000..31784a0426 --- /dev/null +++ b/paddle/fluid/platform/type_defs.h @@ -0,0 +1,37 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#pragma once + +#ifdef PADDLE_WITH_HIP +#include +#else +#include +#endif + +namespace paddle { + +#ifdef PADDLE_WITH_HIP +#define gpuSuccess hipSuccess +using gpuStream_t = hipStream_t; +using gpuError_t = hipError_t; +using gpuEvent_t = hipEvent_t; +#else +#define gpuSuccess cudaSuccess +using gpuStream_t = cudaStream_t; +using gpuError_t = cudaError_t; +using gpuEvent_t = cudaEvent_t; +#endif + +} // namespace paddle diff --git a/tools/dockerfile/Dockerfile.rocm b/tools/dockerfile/Dockerfile.rocm index fad20fbaea..2f624b2d97 100644 --- a/tools/dockerfile/Dockerfile.rocm +++ b/tools/dockerfile/Dockerfile.rocm @@ -30,7 +30,7 @@ ENV LC_ALL en_US.UTF-8 ENV LANG en_US.UTF-8 ENV LANGUAGE en_US.UTF-8 -RUN yum install -y epel-release deltarpm sudo openssh-server openssl-devel gettext-devel sqlite-devel \ +RUN yum install -y epel-release deltarpm sudo openssh-server gettext-devel sqlite-devel \ zlib-devel openssl-devel pcre-devel vim tk-devel tkinter libtool xz graphviz wget curl-devel \ make bzip2 git patch unzip bison yasm diffutils automake which file kernel-headers kernel-devel @@ -65,6 +65,15 @@ RUN echo "[ROCm]" > /etc/yum.repos.d/rocm.repo && \ RUN yum install -y rocm-dev rocm-utils rocfft miopen-hip rocblas hipsparse rocrand rccl hipcub rocthrust rocprofiler-dev roctracer-dev # fix rocthrust RUN sed -i '21 a #include ' /opt/rocm/include/thrust/system/hip/detail/error.inl +# export ROCM env +ENV ROCM_PATH=/opt/rocm +ENV HIP_PATH=/opt/rocm/hip +ENV HIP_CLANG_PATH=/opt/rocm/llvm/bin +ENV PATH=/opt/rocm/bin:$PATH +ENV PATH=/opt/rocm/hcc/bin:$PATH +ENV PATH=/opt/rocm/hip/bin:$PATH +ENV PATH=/opt/rocm/opencl/bin:$PATH +ENV PATH=/opt/rocm/llvm/bin:$PATH # git 2.17.1 RUN cd /opt && wget -q https://paddle-ci.gz.bcebos.com/git-2.17.1.tar.gz && \ @@ -117,6 +126,13 @@ RUN sed -i "s/^#PermitRootLogin/PermitRootLogin/" /etc/ssh/sshd_config && \ sed -i "s/^#PubkeyAuthentication/PubkeyAuthentication/" /etc/ssh/sshd_config && \ sed -i 
"s/^#RSAAuthentication/RSAAuthentication/" /etc/ssh/sshd_config +# patchelf +RUN yum install -y patchelf && \ + yum clean all && \ + rm -rf /var/cache/yum && \ + rm -rf /var/lib/yum/yumdb && \ + rm -rf /var/lib/yum/history + # swig 2.0.12 RUN wget -O /opt/swig-2.0.12.tar.gz https://sourceforge.net/projects/swig/files/swig/swig-2.0.12/swig-2.0.12.tar.gz/download && \ cd /opt && tar xzf swig-2.0.12.tar.gz && cd /opt/swig-2.0.12 && ./configure && make && make install && \ -- GitLab