未验证 提交 57da105c 编写于 作者: R ronnywang 提交者: GitHub

[ROCM] fix cupti, rccl on rocm (#54807)

* [ROCM] fix cupti, hipcub

* update

* update
上级 6cfe9bfd
...@@ -2,9 +2,15 @@ if(NOT WITH_GPU AND NOT WITH_ROCM) ...@@ -2,9 +2,15 @@ if(NOT WITH_GPU AND NOT WITH_ROCM)
return() return()
endif() endif()
set(CUPTI_ROOT if(WITH_ROCM)
"/usr" set(CUPTI_ROOT
CACHE PATH "CUPTI ROOT") "${ROCM_PATH}/CUPTI"
CACHE PATH "CUPTI ROOT")
else()
set(CUPTI_ROOT
"/usr"
CACHE PATH "CUPTI ROOT")
endif()
find_path( find_path(
CUPTI_INCLUDE_DIR cupti.h CUPTI_INCLUDE_DIR cupti.h
PATHS ${CUPTI_ROOT} PATHS ${CUPTI_ROOT}
......
...@@ -106,7 +106,11 @@ list(APPEND HIP_CXX_FLAGS -Wno-duplicate-decl-specifier) ...@@ -106,7 +106,11 @@ list(APPEND HIP_CXX_FLAGS -Wno-duplicate-decl-specifier)
list(APPEND HIP_CXX_FLAGS -Wno-implicit-int-float-conversion) list(APPEND HIP_CXX_FLAGS -Wno-implicit-int-float-conversion)
list(APPEND HIP_CXX_FLAGS -Wno-pass-failed) list(APPEND HIP_CXX_FLAGS -Wno-pass-failed)
list(APPEND HIP_CXX_FLAGS -DTHRUST_DEVICE_SYSTEM=THRUST_DEVICE_SYSTEM_HIP) list(APPEND HIP_CXX_FLAGS -DTHRUST_DEVICE_SYSTEM=THRUST_DEVICE_SYSTEM_HIP)
list(APPEND HIP_CXX_FLAGS -std=c++14) if(WITH_CINN)
list(APPEND HIP_CXX_FLAGS -std=c++14)
else()
list(APPEND HIP_CXX_FLAGS -std=c++17)
endif()
if(CMAKE_BUILD_TYPE MATCHES Debug) if(CMAKE_BUILD_TYPE MATCHES Debug)
list(APPEND HIP_CXX_FLAGS -g2) list(APPEND HIP_CXX_FLAGS -g2)
......
...@@ -28,9 +28,17 @@ RCCL_RAND_ROUTINE_EACH(DEFINE_WRAP); ...@@ -28,9 +28,17 @@ RCCL_RAND_ROUTINE_EACH(DEFINE_WRAP);
RCCL_RAND_ROUTINE_EACH_AFTER_2212(DEFINE_WRAP) RCCL_RAND_ROUTINE_EACH_AFTER_2212(DEFINE_WRAP)
#endif #endif
#if NCCL_VERSION_CODE >= 2304
RCCL_RAND_ROUTINE_EACH_AFTER_2304(DEFINE_WRAP)
#endif
#if NCCL_VERSION_CODE >= 2703 #if NCCL_VERSION_CODE >= 2703
RCCL_RAND_ROUTINE_EACH_AFTER_2703(DEFINE_WRAP) RCCL_RAND_ROUTINE_EACH_AFTER_2703(DEFINE_WRAP)
#endif #endif
#if NCCL_VERSION_CODE >= 21100
RCCL_RAND_ROUTINE_EACH_AFTER_21100(DEFINE_WRAP)
#endif
} // namespace dynload } // namespace dynload
} // namespace phi } // namespace phi
...@@ -64,6 +64,11 @@ RCCL_RAND_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_RCCL_WRAP) ...@@ -64,6 +64,11 @@ RCCL_RAND_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_RCCL_WRAP)
RCCL_RAND_ROUTINE_EACH_AFTER_2212(DECLARE_DYNAMIC_LOAD_RCCL_WRAP) RCCL_RAND_ROUTINE_EACH_AFTER_2212(DECLARE_DYNAMIC_LOAD_RCCL_WRAP)
#endif #endif
#if NCCL_VERSION_CODE >= 2304
#define RCCL_RAND_ROUTINE_EACH_AFTER_2304(__macro) __macro(ncclGetVersion);
RCCL_RAND_ROUTINE_EACH_AFTER_2304(DECLARE_DYNAMIC_LOAD_RCCL_WRAP)
#endif
#if NCCL_VERSION_CODE >= 2703 #if NCCL_VERSION_CODE >= 2703
#define RCCL_RAND_ROUTINE_EACH_AFTER_2703(__macro) \ #define RCCL_RAND_ROUTINE_EACH_AFTER_2703(__macro) \
__macro(ncclSend); \ __macro(ncclSend); \
...@@ -71,5 +76,11 @@ RCCL_RAND_ROUTINE_EACH_AFTER_2212(DECLARE_DYNAMIC_LOAD_RCCL_WRAP) ...@@ -71,5 +76,11 @@ RCCL_RAND_ROUTINE_EACH_AFTER_2212(DECLARE_DYNAMIC_LOAD_RCCL_WRAP)
RCCL_RAND_ROUTINE_EACH_AFTER_2703(DECLARE_DYNAMIC_LOAD_RCCL_WRAP) RCCL_RAND_ROUTINE_EACH_AFTER_2703(DECLARE_DYNAMIC_LOAD_RCCL_WRAP)
#endif #endif
#if NCCL_VERSION_CODE >= 21100
#define RCCL_RAND_ROUTINE_EACH_AFTER_21100(__macro) \
__macro(ncclRedOpCreatePreMulSum); \
__macro(ncclRedOpDestroy);
RCCL_RAND_ROUTINE_EACH_AFTER_21100(DECLARE_DYNAMIC_LOAD_RCCL_WRAP)
#endif
} // namespace dynload } // namespace dynload
} // namespace phi } // namespace phi
...@@ -999,12 +999,10 @@ inline void Blas<phi::GPUContext>::GEMM(bool transA, ...@@ -999,12 +999,10 @@ inline void Blas<phi::GPUContext>::GEMM(bool transA,
int ldc) const { int ldc) const {
// Note that cublas follows fortran order, so the order is different from // Note that cublas follows fortran order, so the order is different from
// the cblas convention. // the cblas convention.
rocblas_operation cuTransA = (transA == CblasNoTrans) rocblas_operation cuTransA =
? rocblas_operation_none transA ? rocblas_operation_none : rocblas_operation_transpose;
: rocblas_operation_transpose; rocblas_operation cuTransB =
rocblas_operation cuTransB = (transB == CblasNoTrans) transB ? rocblas_operation_none : rocblas_operation_transpose;
? rocblas_operation_none
: rocblas_operation_transpose;
PADDLE_ENFORCE_GE( PADDLE_ENFORCE_GE(
context_.GetComputeCapability(), context_.GetComputeCapability(),
80, 80,
......
...@@ -54,6 +54,15 @@ struct radix_key_codec_base<phi::dtype::float16> ...@@ -54,6 +54,15 @@ struct radix_key_codec_base<phi::dtype::float16>
template <> template <>
struct radix_key_codec_base<phi::dtype::bfloat16> struct radix_key_codec_base<phi::dtype::bfloat16>
: radix_key_codec_integral<phi::dtype::bfloat16, uint16_t> {}; : radix_key_codec_integral<phi::dtype::bfloat16, uint16_t> {};
#if ROCM_VERSION_MAJOR >= 5 && ROCM_VERSION_MINOR >= 4
template <>
struct float_bit_mask<phi::dtype::float16> : float_bit_mask<rocprim::half> {};
template <>
struct float_bit_mask<phi::dtype::bfloat16>
: float_bit_mask<rocprim::bfloat16> {};
#endif
} // namespace detail } // namespace detail
} // namespace rocprim } // namespace rocprim
namespace cub = hipcub; namespace cub = hipcub;
......
...@@ -40,6 +40,19 @@ namespace detail { ...@@ -40,6 +40,19 @@ namespace detail {
template <> template <>
struct radix_key_codec_base<phi::dtype::float16> struct radix_key_codec_base<phi::dtype::float16>
: radix_key_codec_integral<phi::dtype::float16, uint16_t> {}; : radix_key_codec_integral<phi::dtype::float16, uint16_t> {};
template <>
struct radix_key_codec_base<phi::dtype::bfloat16>
: radix_key_codec_integral<phi::dtype::bfloat16, uint16_t> {};
#if ROCM_VERSION_MAJOR >= 5 && ROCM_VERSION_MINOR >= 4
template <>
struct float_bit_mask<phi::dtype::float16> : float_bit_mask<rocprim::half> {};
template <>
struct float_bit_mask<phi::dtype::bfloat16>
: float_bit_mask<rocprim::bfloat16> {};
#endif
} // namespace detail } // namespace detail
} // namespace rocprim } // namespace rocprim
#else #else
......
...@@ -7,8 +7,7 @@ string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}") ...@@ -7,8 +7,7 @@ string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}")
set(CUSTOM_ENVS set(CUSTOM_ENVS
PADDLE_SOURCE_DIR=${PADDLE_SOURCE_DIR} PADDLE_SOURCE_DIR=${PADDLE_SOURCE_DIR}
PADDLE_BINARY_DIR=${PADDLE_BINARY_DIR} PADDLE_BINARY_DIR=${PADDLE_BINARY_DIR}
CUSTOM_DEVICE_ROOT=${CMAKE_BINARY_DIR}/python/paddle/fluid/tests/custom_kernel CUSTOM_DEVICE_ROOT=${CMAKE_BINARY_DIR}/test)
)
foreach(TEST_OP ${TEST_OPS}) foreach(TEST_OP ${TEST_OPS})
py_test(${TEST_OP} SRCS ${TEST_OP}.py ENVS ${CUSTOM_ENVS}) py_test(${TEST_OP} SRCS ${TEST_OP}.py ENVS ${CUSTOM_ENVS})
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册