Unverified commit dd1d3789, authored by Qi Li, committed by GitHub

[ROCM] add custom op support, test=develop (#36771)

* [ROCM] add custom op support, test=develop

* remove debug codes, test=develop
Parent 5c569aef
@@ -17,6 +17,6 @@ limitations under the License. */
 namespace paddle {

 // TODO(yangjiabin): Add other place support in next PR
-enum class PlaceType { kUNK = -1, kCPU, kGPU, kHIP };
+enum class PlaceType { kUNK = -1, kCPU, kGPU };

 } // namespace paddle
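With kHIP folded into kGPU, user-side custom operator code only needs to distinguish kCPU from kGPU, and the same branch covers CUDA and ROCm builds. As a rough illustration (the op, its names, and its trivial behaviour are hypothetical and not part of this commit; it assumes a float32 input for brevity), a kernel function written against this header might dispatch like this:

    #include "paddle/extension.h"

    // Hypothetical pass-through kernel: dispatches on the unified PlaceType.
    // A HIP tensor now reports PlaceType::kGPU, so no kHIP branch is needed.
    std::vector<paddle::Tensor> custom_identity(const paddle::Tensor& x) {
      if (x.place() == paddle::PlaceType::kCPU) {
        return {x.copy_to<float>(paddle::PlaceType::kCPU)};
      } else if (x.place() == paddle::PlaceType::kGPU) {
        return {x.copy_to<float>(paddle::PlaceType::kGPU)};
      }
      PD_THROW("custom_identity only supports CPU and GPU tensors.");
    }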
@@ -16,8 +16,15 @@ limitations under the License. */
 #include <memory>
 #include <vector>
 #ifdef PADDLE_WITH_CUDA
 #include <cuda_runtime.h>
+using gpuStream_t = cudaStream_t;
+#endif
+
+#ifdef PADDLE_WITH_HIP
+#include <hip/hip_runtime.h>
+using gpuStream_t = hipStream_t;
 #endif
 #include "ext_dll_decl.h" // NOLINT
@@ -126,11 +133,9 @@ class PD_DLL_DECL Tensor {
   /// \brief Check Tensor is initialized
   bool is_initialized() const;

-#if defined(PADDLE_WITH_CUDA)
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
   /// \bref Get current stream of Tensor
-  cudaStream_t stream() const;
-#elif defined(PADDLE_WITH_HIP)
-  hipStream_t stream() const;
+  gpuStream_t stream() const;
 #endif

  private:
......
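Because ext_tensor.h now exposes a single gpuStream_t alias, a device kernel in a user extension can be launched on the stream attached to the input tensor without branching on the backend; hipcc accepts the same triple-chevron launch syntax as nvcc, so one source file can serve both. A minimal sketch, with illustrative kernel and function names that are not part of this commit:

    #include "paddle/extension.h"

    // Illustrative scale kernel; a .cu source like this compiles under nvcc on
    // CUDA and under hipcc on ROCm.
    __global__ void scale_kernel(const float* x, float* y, int64_t n, float a) {
      int64_t i = blockIdx.x * blockDim.x + threadIdx.x;
      if (i < n) y[i] = a * x[i];
    }

    std::vector<paddle::Tensor> scale_gpu(const paddle::Tensor& x, float a) {
      auto out = paddle::Tensor(paddle::PlaceType::kGPU, x.shape());
      int64_t n = x.size();
      int block = 256;
      int grid = static_cast<int>((n + block - 1) / block);
      // x.stream() now returns the unified gpuStream_t on both backends.
      scale_kernel<<<grid, block, 0, x.stream()>>>(
          x.data<float>(), out.mutable_data<float>(), n, a);
      return {out};
    }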
@@ -69,9 +69,9 @@ struct CastDataType {
 };

 template <typename T>
-void DeviceCopy(T *src, T *dst, PlaceType src_plc, PlaceType dst_plc,
-                int64_t ele_size) {
-#if defined(PADDLE_WITH_CUDA)
+void GpuCopy(T *src, T *dst, PlaceType src_plc, PlaceType dst_plc,
+             int64_t ele_size) {
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
   platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
   int device_num = paddle::platform::GetCurrentDeviceId();
   platform::CUDAPlace gpu_place(device_num);
@@ -90,29 +90,11 @@ void DeviceCopy(T *src, T *dst, PlaceType src_plc, PlaceType dst_plc,
     PADDLE_THROW(platform::errors::Unavailable(
         "Only GPU related Copy can reach this func."));
   }
-#elif defined(PADDLE_WITH_HIP)
-  platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
-  int device_num = paddle::platform::GetCurrentDeviceId();
-  platform::CUDAPlace gpu_place(device_num);
-  auto *dev_ctx =
-      static_cast<const platform::CUDADeviceContext *>(pool.Get(gpu_place));
-  if ((src_plc == PlaceType::kHIP) && (dst_plc == PlaceType::kCPU)) {
-    memory::Copy(platform::CPUPlace(), static_cast<void *>(dst), gpu_place, src,
-                 ele_size, dev_ctx->stream());
-  } else if ((src_plc == PlaceType::kHIP) && (dst_plc == PlaceType::kHIP)) {
-    memory::Copy(gpu_place, static_cast<void *>(dst), gpu_place, src, ele_size,
-                 dev_ctx->stream());
-  } else if ((src_plc == PlaceType::kCPU) && (dst_plc == PlaceType::kHIP)) {
-    memory::Copy(gpu_place, static_cast<void *>(dst), platform::CPUPlace(), src,
-                 ele_size, dev_ctx->stream());
-  } else {
-    PADDLE_THROW(platform::errors::Unavailable(
-        "Only GPU related Copy can reach this func."));
-  }
-#else
-  PADDLE_THROW(platform::errors::Unavailable(
-      "This function can only be used if compiled with"
-      "either -DWITH_ROCM=ON or -DWITH_GPU=ON"));
+#ifdef PADDLE_WITH_HIP
+  hipStreamSynchronize(dev_ctx->stream());
+#else
+  cudaStreamSynchronize(dev_ctx->stream());
+#endif
 #endif
 }
@@ -175,16 +157,11 @@ T *Tensor::mutable_data() {
     case static_cast<int>(PlaceType::kCPU): {
       return tensor->mutable_data<T>(platform::CPUPlace());
     }
-#if defined(PADDLE_WITH_CUDA)
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
     case static_cast<int>(PlaceType::kGPU): {
       int device_num = platform::GetCurrentDeviceId();
       return tensor->mutable_data<T>(platform::CUDAPlace(device_num));
     }
-#elif defined(PADDLE_WITH_HIP)
-    case static_cast<int>(PlaceType::kHIP): {
-      int device_num = platform::GetCurrentDeviceId();
-      return tensor->mutable_data<T>(platform::CUDAPlace(device_num));
-    }
 #endif
     default:
       PADDLE_THROW(platform::errors::Unavailable(
@@ -245,23 +222,17 @@ Tensor Tensor::copy_to(const PlaceType &target_place) const {
   target.reshape(shape());
   auto *p_target_data = target.template mutable_data<T>();

-  bool supported_gpu_transform = false;
-#if defined(PADDLE_WITH_CUDA)
-  supported_gpu_transform =
-      (src_place == PlaceType::kGPU && target_place == PlaceType::kCPU) ||
-      (src_place == PlaceType::kCPU && target_place == PlaceType::kGPU) ||
-      (src_place == PlaceType::kGPU && target_place == PlaceType::kGPU);
-#elif defined(PADDLE_WITH_HIP)
-  supported_gpu_transform =
-      (src_place == PlaceType::kHIP && target_place == PlaceType::kCPU) ||
-      (src_place == PlaceType::kCPU && target_place == PlaceType::kHIP) ||
-      (src_place == PlaceType::kHIP && target_place == PlaceType::kHIP);
-#endif
-
   if ((src_place == PlaceType::kCPU) && (target_place == PlaceType::kCPU)) {
     std::memcpy(static_cast<void *>(p_target_data), p_src_data, ele_size);
-  } else if (supported_gpu_transform) {
-    DeviceCopy<T>(p_src_data, p_target_data, src_place, target_place, ele_size);
+  } else if ((src_place == PlaceType::kGPU) &&
+             (target_place == PlaceType::kCPU)) {
+    GpuCopy<T>(p_src_data, p_target_data, src_place, target_place, ele_size);
+  } else if ((src_place == PlaceType::kCPU) &&
+             (target_place == PlaceType::kGPU)) {
+    GpuCopy<T>(p_src_data, p_target_data, src_place, target_place, ele_size);
+  } else if ((src_place == PlaceType::kGPU) &&
+             (target_place == PlaceType::kGPU)) {
+    GpuCopy<T>(p_src_data, p_target_data, src_place, target_place, ele_size);
   } else {
     PADDLE_THROW(platform::errors::Unavailable(
         "Not supported place transform of place: %d to place: %d",
@@ -363,18 +334,15 @@ const PlaceType &Tensor::place() const {
   GET_CASTED_TENSOR;
   if (platform::is_cpu_place(tensor->place())) {
     place_ = PlaceType::kCPU;
-#if defined(PADDLE_WITH_CUDA)
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
   } else if (platform::is_gpu_place(tensor->place())) {
     place_ = PlaceType::kGPU;
-#elif defined(PADDLE_WITH_HIP)
-  } else if (platform::is_gpu_place(tensor->place())) {
-    place_ = PlaceType::kHIP;
 #endif
   } else {
     PADDLE_THROW(platform::errors::Unimplemented(
         "Current Tensor hold unsupported Place Type, Please Init it"
         "using Tensor::mutable_data<T>(PaddlePlace) with T among:"
-        "Place::kCPU or Place::kGPU or Place::kHIP"));
+        "Place::kCPU or Place::kGPU"));
   }
   return place_;
 }
@@ -456,21 +424,16 @@ bool Tensor::is_initialized() const {
   }
 }

-#define DEFINE_STREAM(_stream_t_) \
-  _stream_t_ Tensor::stream() const { \
-    if (!stream_.IsStreamSet()) { \
-      PADDLE_THROW(platform::errors::PreconditionNotMet( \
-          "Stream is not Set, only input tensor will have " \
-          "stream which is set by framework ")); \
-    } else { \
-      return reinterpret_cast<_stream_t_>(stream_.GetStream()); \
-    } \
-  }
-
-#if defined(PADDLE_WITH_CUDA)
-DEFINE_STREAM(cudaStream_t)
-#elif defined(PADDLE_WITH_HIP)
-DEFINE_STREAM(hipStream_t)
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
+gpuStream_t Tensor::stream() const {
+  if (!stream_.IsStreamSet()) {
+    PADDLE_THROW(platform::errors::PreconditionNotMet(
+        "Stream is not Set, only input tensor will have "
+        "stream which is set by framework "));
+  } else {
+    return reinterpret_cast<gpuStream_t>(stream_.GetStream());
+  }
+}
 #endif

 namespace framework {
......
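The unified GpuCopy path means Tensor::copy_to behaves the same way on CUDA and ROCm: kCPU-to-kGPU, kGPU-to-kCPU, and kGPU-to-kGPU transfers all go through the same branches, and the device stream is synchronized before the call returns. A small sketch of host-side usage (the helper name is made up and assumes a float32 tensor):

    #include "paddle/extension.h"

    // Hypothetical helper: bring a float tensor back to host memory. On a ROCm
    // build the same call runs the HIP branch of GpuCopy and synchronizes via
    // hipStreamSynchronize before returning.
    paddle::Tensor to_host(const paddle::Tensor& x) {
      if (x.place() == paddle::PlaceType::kGPU) {
        return x.copy_to<float>(paddle::PlaceType::kCPU);
      }
      return x;
    }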
@@ -503,7 +503,7 @@ void RegisterOperatorKernel(const std::string& name,
   // but call api in gpu device, it will cause error.
   RegisterOperatorKernelWithPlace(name, kernel_func, proto::VarType::RAW,
                                   PlaceType::kCPU, inputs, outputs, attrs);
-#ifdef PADDLE_WITH_CUDA
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
   RegisterOperatorKernelWithPlace(name, kernel_func, proto::VarType::RAW,
                                   PlaceType::kGPU, inputs, outputs, attrs);
 #endif
......
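With this guard widened, a single PD_BUILD_OP registration exposes both the CPU and the GPU kernel of an operator on CUDA as well as ROCm builds. Roughly, using a hypothetical op name and a trivial float32 pass-through kernel purely to show the registration:

    #include "paddle/extension.h"

    // Hypothetical pass-through op used only to illustrate registration.
    std::vector<paddle::Tensor> my_copy(const paddle::Tensor& x) {
      return {x.copy_to<float>(x.place())};
    }

    PD_BUILD_OP(my_copy)
        .Inputs({"X"})
        .Outputs({"Out"})
        .SetKernelFn(PD_KERNEL(my_copy));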
@@ -38,7 +38,7 @@ void TestCopyTensor() {
   for (int64_t i = 0; i < t1.size(); i++) {
     CHECK_EQ(t1_cpu_cp.template data<T>()[i], T(5));
   }
-#ifdef PADDLE_WITH_CUDA
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
   VLOG(2) << "Do GPU copy test";
   auto t1_gpu_cp = t1_cpu_cp.template copy_to<T>(paddle::PlaceType::kGPU);
   CHECK((paddle::PlaceType::kGPU == t1_gpu_cp.place()));
@@ -50,33 +50,16 @@ void TestCopyTensor() {
   for (int64_t i = 0; i < t1.size(); i++) {
     CHECK_EQ(t1_gpu_cp_cp_cpu.template data<T>()[i], T(5));
   }
-#elif defined(PADDLE_WITH_HIP)
-  VLOG(2) << "Do HIP copy test";
-  auto t1_gpu_cp = t1_cpu_cp.template copy_to<T>(paddle::PlaceType::kHIP);
-  CHECK((paddle::PlaceType::kHIP == t1_gpu_cp.place()));
-  auto t1_gpu_cp_cp = t1_gpu_cp.template copy_to<T>(paddle::PlaceType::kHIP);
-  CHECK((paddle::PlaceType::kHIP == t1_gpu_cp_cp.place()));
-  auto t1_gpu_cp_cp_cpu =
-      t1_gpu_cp_cp.template copy_to<T>(paddle::PlaceType::kCPU);
-  CHECK((paddle::PlaceType::kCPU == t1_gpu_cp_cp_cpu.place()));
-  for (int64_t i = 0; i < t1.size(); i++) {
-    CHECK_EQ(t1_gpu_cp_cp_cpu.template data<T>()[i], T(5));
-  }
 #endif
 }

 void TestAPIPlace() {
   std::vector<int64_t> tensor_shape = {5, 5};
-#ifdef PADDLE_WITH_CUDA
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
   auto t1 = paddle::Tensor(paddle::PlaceType::kGPU);
   t1.reshape(tensor_shape);
   t1.mutable_data<float>();
   CHECK((paddle::PlaceType::kGPU == t1.place()));
-#elif defined(PADDLE_WITH_HIP)
-  auto t1 = paddle::Tensor(paddle::PlaceType::kHIP);
-  t1.reshape(tensor_shape);
-  t1.mutable_data<float>();
-  CHECK((paddle::PlaceType::kHIP == t1.place()));
 #endif
   auto t2 = paddle::Tensor(paddle::PlaceType::kCPU);
   t2.reshape(tensor_shape);
@@ -97,7 +80,7 @@ void TestAPISlice() {
   std::vector<int64_t> tensor_shape_sub1 = {3, 5};
   std::vector<int64_t> tensor_shape_origin2 = {5, 5, 5};
   std::vector<int64_t> tensor_shape_sub2 = {1, 5, 5};
-#ifdef PADDLE_WITH_CUDA
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
   auto t1 = paddle::Tensor(paddle::PlaceType::kGPU, tensor_shape_origin1);
   t1.mutable_data<float>();
   CHECK(t1.slice(0, 5).shape() == tensor_shape_origin1);
@@ -144,7 +127,7 @@ void TestCast(paddle::DataType data_type) {
   t1.template mutable_data<T>();
   auto t2 = t1.cast(data_type);
   CHECK(t2.type() == data_type);
-#ifdef PADDLE_WITH_CUDA
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
   auto tg1 = paddle::Tensor(paddle::PlaceType::kGPU);
   tg1.reshape(tensor_shape);
   tg1.template mutable_data<T>();
......
@@ -18,11 +18,9 @@ limitations under the License. */
 #include "paddle/fluid/extension/include/ext_tensor.h"
 #include "paddle/fluid/framework/data_type.h"
+#include "paddle/fluid/platform/device_context.h"
 #include "paddle/fluid/platform/gpu_info.h"
 #include "paddle/fluid/platform/place.h"
-#ifdef PADDLE_WITH_CUDA
-#endif
-#include "paddle/fluid/platform/device_context.h"

 namespace paddle {
 namespace framework {
@@ -110,7 +108,7 @@ class CustomTensorUtils {
     if (pc == PlaceType::kCPU) {
       return platform::Place(platform::CPUPlace());
     } else if (pc == PlaceType::kGPU) {
-#ifdef PADDLE_WITH_CUDA
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
       return platform::Place(
           platform::CUDAPlace(platform::GetCurrentDeviceId()));
 #endif
@@ -127,7 +125,7 @@ class CustomTensorUtils {
     if (platform::is_cpu_place(pc)) {
       return PlaceType::kCPU;
     } else if (platform::is_gpu_place(pc)) {
-#ifdef PADDLE_WITH_CUDA
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
       return PlaceType::kGPU;
 #endif
     } else {
@@ -142,7 +140,7 @@ class CustomTensorUtils {
   static void SetTensorCurrentStream(paddle::Tensor* src,
                                      const platform::Place& pc) {
     if (platform::is_gpu_place(pc)) {
-#ifdef PADDLE_WITH_CUDA
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
       auto* dev_ctx = static_cast<platform::CUDADeviceContext*>(
           platform::DeviceContextPool::Instance().Get(pc));
       src->stream_.SetStream(reinterpret_cast<void*>(dev_ctx->stream()));
......
@@ -403,7 +403,7 @@ class BuildExtension(build_ext, object):
            cflags = copy.deepcopy(extra_postargs)
            try:
                original_compiler = self.compiler.compiler_so
-               # nvcc compile CUDA source
+               # nvcc or hipcc compile CUDA source
                if is_cuda_file(src):
                    if core.is_compiled_with_rocm():
                        assert ROCM_HOME is not None, "Not found ROCM runtime, \
@@ -429,6 +429,13 @@ class BuildExtension(build_ext, object):
                elif isinstance(cflags, dict):
                    cflags = cflags['cxx']

+               # Note(qili93): HIP require some additional flags for CMAKE_C_FLAGS
+               if core.is_compiled_with_rocm():
+                   cflags.append('-D__HIP_PLATFORM_HCC__')
+                   cflags.append('-D__HIP_NO_HALF_CONVERSIONS__=1')
+                   cflags.append(
+                       '-DTHRUST_DEVICE_SYSTEM=THRUST_DEVICE_SYSTEM_HIP')
+
                # NOTE(Aurelius84): Since Paddle 2.0, we require gcc version > 5.x,
                # so we add this flag to ensure the symbol names from user compiled
                # shared library have same ABI suffix with core_(no)avx.so.
@@ -436,7 +443,10 @@ class BuildExtension(build_ext, object):
                add_compile_flag(['-D_GLIBCXX_USE_CXX11_ABI=1'], cflags)
                # Append this macor only when jointly compiling .cc with .cu
                if not is_cuda_file(src) and self.contain_cuda_file:
-                   cflags.append('-DPADDLE_WITH_CUDA')
+                   if core.is_compiled_with_rocm():
+                       cflags.append('-DPADDLE_WITH_HIP')
+                   else:
+                       cflags.append('-DPADDLE_WITH_CUDA')
                add_std_without_repeat(
                    cflags, self.compiler.compiler_type, use_std14=True)
......
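On ROCm, the wrapper compiler above now injects -DPADDLE_WITH_HIP (plus the __HIP_PLATFORM_HCC__ and Thrust defines) when building the C++ half of a mixed .cc/.cu extension, so user sources that must differ per backend can key off that macro. A small hedged C++ sketch of such a guard:

    // Sketch of backend-specific code in a user .cc file; the macros come from
    // the flags added above (PADDLE_WITH_HIP on ROCm, PADDLE_WITH_CUDA otherwise).
    #if defined(PADDLE_WITH_HIP)
    #include <hip/hip_runtime.h>
    #elif defined(PADDLE_WITH_CUDA)
    #include <cuda_runtime.h>
    #endif

    inline void synchronize_device() {
    #if defined(PADDLE_WITH_HIP)
      hipDeviceSynchronize();
    #elif defined(PADDLE_WITH_CUDA)
      cudaDeviceSynchronize();
    #endif
    }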
@@ -56,7 +56,12 @@ CLANG_LINK_FLAGS = [
 MSVC_LINK_FLAGS = ['/MACHINE:X64']

-COMMON_NVCC_FLAGS = ['-DPADDLE_WITH_CUDA', '-DEIGEN_USE_GPU']
+if core.is_compiled_with_rocm():
+    COMMON_HIPCC_FLAGS = [
+        '-DPADDLE_WITH_HIP', '-DEIGEN_USE_GPU', '-DEIGEN_USE_HIP'
+    ]
+else:
+    COMMON_NVCC_FLAGS = ['-DPADDLE_WITH_CUDA', '-DEIGEN_USE_GPU']

 GCC_MINI_VERSION = (5, 4, 0)
 MSVC_MINI_VERSION = (19, 0, 24215)
@@ -319,10 +324,14 @@ def prepare_unix_cudaflags(cflags):
    """
    Prepare all necessary compiled flags for nvcc compiling CUDA files.
    """
-   cflags = COMMON_NVCC_FLAGS + [
-       '-ccbin', 'cc', '-Xcompiler', '-fPIC', '--expt-relaxed-constexpr',
-       '-DNVCC'
-   ] + cflags + get_cuda_arch_flags(cflags)
+   if core.is_compiled_with_rocm():
+       cflags = COMMON_HIPCC_FLAGS + ['-Xcompiler', '-fPIC'
+                                      ] + cflags + get_rocm_arch_flags(cflags)
+   else:
+       cflags = COMMON_NVCC_FLAGS + [
+           '-ccbin', 'cc', '-Xcompiler', '-fPIC', '--expt-relaxed-constexpr',
+           '-DNVCC'
+       ] + cflags + get_cuda_arch_flags(cflags)

    return cflags
@@ -358,6 +367,14 @@ def get_cuda_arch_flags(cflags):
        return []


+def get_rocm_arch_flags(cflags):
+    """
+    For ROCm platform, amdgpu target should be added for HIPCC.
+    """
+    cflags = cflags + ['-fno-gpu-rdc', '-amdgpu-target=gfx906']
+    return cflags
+
+
 def _get_fluid_path():
    """
    Return installed fluid dir path.
@@ -471,7 +488,10 @@ def normalize_extension_kwargs(kwargs, use_cuda=False):
        add_compile_flag(extra_compile_args, ['-w'])  # disable warning

        if use_cuda:
-           extra_link_args.append('-lcudart')
+           if core.is_compiled_with_rocm():
+               extra_link_args.append('-lamdhip64')
+           else:
+               extra_link_args.append('-lcudart')

        kwargs['extra_link_args'] = extra_link_args
......