Unverified commit 62840afa authored by Zhanlue Yang, committed by GitHub

Add DCU backend support for custom ops (#34050)

* Add DCU backend support for custom ops

* Added checks for DeviceCopy and renamed some macros
Parent 9e18114f
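
With this change, a custom operator built against a ROCm (DCU) build of Paddle can move tensors between host and device through the same `copy_to` API already used for CUDA, with `PlaceType::kHIP` naming the device side. Below is a minimal usage sketch, assuming Paddle was built with `-DWITH_ROCM=ON`; the helper name, shape, and values are illustrative only and not part of this commit.

```cpp
#include "paddle/extension.h"

// Illustrative helper (not from this commit): fill a CPU tensor, copy it to
// the DCU via PlaceType::kHIP, then copy it back so it can be read on the host.
paddle::Tensor RoundTripThroughDCU() {
  paddle::Tensor cpu_t(paddle::PlaceType::kCPU);
  cpu_t.reshape({4});
  auto* data = cpu_t.mutable_data<float>();
  for (int64_t i = 0; i < cpu_t.size(); ++i) {
    data[i] = 1.0f;  // known value to verify after the round trip
  }
  // Host -> DCU; dispatches to the new HIP branch of DeviceCopy below.
  auto hip_t = cpu_t.copy_to<float>(paddle::PlaceType::kHIP);
  // DCU -> host.
  return hip_t.copy_to<float>(paddle::PlaceType::kCPU);
}
```
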
......@@ -17,6 +17,6 @@ limitations under the License. */
namespace paddle {
// TODO(yangjiabin): Add other place support in next PR
enum class PlaceType { kUNK = -1, kCPU, kGPU };
enum class PlaceType { kUNK = -1, kCPU, kGPU, kHIP };
} // namespace paddle
......@@ -116,9 +116,11 @@ class PD_DLL_DECL Tensor {
/// \brief Check Tensor is initialized
bool is_initialized() const;
#ifdef PADDLE_WITH_CUDA
#if defined(PADDLE_WITH_CUDA)
/// \brief Get the current stream of the Tensor
cudaStream_t stream() const;
#elif defined(PADDLE_WITH_HIP)
hipStream_t stream() const;
#endif
private:
......
......@@ -53,7 +53,7 @@ struct CastDataType {
auto *context = static_cast<const platform::CPUDeviceContext *>(ctx_);
trans(*context, in_begin, in_end, out_begin,
CastDataTypeFunctor<InType, OutType>());
#ifdef __NVCC__
#if defined(__NVCC__) || defined(__HIPCC__)
} else if (platform::is_gpu_place(in_.place())) {
platform::Transform<platform::CUDADeviceContext> trans;
auto *context = static_cast<const platform::CUDADeviceContext *>(ctx_);
......@@ -67,10 +67,11 @@ struct CastDataType {
}
}
};
template <typename T>
void GpuCopy(T *src, T *dst, PlaceType src_plc, PlaceType dst_plc,
int64_t ele_size) {
#ifdef PADDLE_WITH_CUDA
void DeviceCopy(T *src, T *dst, PlaceType src_plc, PlaceType dst_plc,
int64_t ele_size) {
#if defined(PADDLE_WITH_CUDA)
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
int device_num = paddle::platform::GetCurrentDeviceId();
platform::CUDAPlace gpu_place(device_num);
......@@ -90,6 +91,30 @@ void GpuCopy(T *src, T *dst, PlaceType src_plc, PlaceType dst_plc,
"Only GPU related Copy can reach this func."));
}
cudaStreamSynchronize(dev_ctx->stream());
#elif defined(PADDLE_WITH_HIP)
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
int device_num = paddle::platform::GetCurrentDeviceId();
platform::CUDAPlace gpu_place(device_num);
auto *dev_ctx =
static_cast<const platform::CUDADeviceContext *>(pool.Get(gpu_place));
if ((src_plc == PlaceType::kHIP) && (dst_plc == PlaceType::kCPU)) {
memory::Copy(platform::CPUPlace(), static_cast<void *>(dst), gpu_place, src,
ele_size, dev_ctx->stream());
} else if ((src_plc == PlaceType::kHIP) && (dst_plc == PlaceType::kHIP)) {
memory::Copy(gpu_place, static_cast<void *>(dst), gpu_place, src, ele_size,
dev_ctx->stream());
} else if ((src_plc == PlaceType::kCPU) && (dst_plc == PlaceType::kHIP)) {
memory::Copy(gpu_place, static_cast<void *>(dst), platform::CPUPlace(), src,
ele_size, dev_ctx->stream());
} else {
PADDLE_THROW(platform::errors::Unavailable(
"Only GPU related Copy can reach this func."));
}
hipStreamSynchronize(dev_ctx->stream());
#else
PADDLE_THROW(platform::errors::Unavailable(
"This function can only be used if compiled with"
"either -DWITH_ROCM=ON or -DWITH_GPU=ON"));
#endif
}
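
For context on the HIP branch above: `memory::Copy` with a stream argument ultimately lowers to an asynchronous HIP memcpy, and the trailing `hipStreamSynchronize` is what makes `DeviceCopy` behave as a blocking copy. Here is a standalone sketch of what the kHIP -> kCPU case amounts to in raw HIP runtime calls; the helper and its buffer names are illustrative, not part of this change.

```cpp
#include <hip/hip_runtime.h>

#include <vector>

// Illustrative only: copy `count` floats from a DCU buffer back to the host,
// mirroring the kHIP -> kCPU branch of DeviceCopy above.
std::vector<float> CopyDeviceToHost(const float* device_src, size_t count,
                                    hipStream_t stream) {
  std::vector<float> host_dst(count);
  // Asynchronous device-to-host copy enqueued on the given stream.
  hipMemcpyAsync(host_dst.data(), device_src, count * sizeof(float),
                 hipMemcpyDeviceToHost, stream);
  // Block until the copy has finished, as DeviceCopy does after memory::Copy.
  hipStreamSynchronize(stream);
  return host_dst;
}
```
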
......@@ -137,11 +162,16 @@ T *Tensor::mutable_data() {
case static_cast<int>(PlaceType::kCPU): {
return tensor->mutable_data<T>(platform::CPUPlace());
}
#ifdef PADDLE_WITH_CUDA
#if defined(PADDLE_WITH_CUDA)
case static_cast<int>(PlaceType::kGPU): {
int device_num = platform::GetCurrentDeviceId();
return tensor->mutable_data<T>(platform::CUDAPlace(device_num));
}
#elif defined(PADDLE_WITH_HIP)
case static_cast<int>(PlaceType::kHIP): {
int device_num = platform::GetCurrentDeviceId();
return tensor->mutable_data<T>(platform::CUDAPlace(device_num));
}
#endif
default:
PADDLE_THROW(platform::errors::Unavailable(
......@@ -202,17 +232,23 @@ Tensor Tensor::copy_to(const PlaceType &target_place) const {
target.reshape(shape());
auto *p_target_data = target.template mutable_data<T>();
bool supported_gpu_transform = false;
#if defined(PADDLE_WITH_CUDA)
supported_gpu_transform =
(src_place == PlaceType::kGPU && target_place == PlaceType::kCPU) ||
(src_place == PlaceType::kCPU && target_place == PlaceType::kGPU) ||
(src_place == PlaceType::kGPU && target_place == PlaceType::kGPU);
#elif defined(PADDLE_WITH_HIP)
supported_gpu_transform =
(src_place == PlaceType::kHIP && target_place == PlaceType::kCPU) ||
(src_place == PlaceType::kCPU && target_place == PlaceType::kHIP) ||
(src_place == PlaceType::kHIP && target_place == PlaceType::kHIP);
#endif
if ((src_place == PlaceType::kCPU) && (target_place == PlaceType::kCPU)) {
std::memcpy(static_cast<void *>(p_target_data), p_src_data, ele_size);
} else if ((src_place == PlaceType::kGPU) &&
(target_place == PlaceType::kCPU)) {
GpuCopy<T>(p_src_data, p_target_data, src_place, target_place, ele_size);
} else if ((src_place == PlaceType::kCPU) &&
(target_place == PlaceType::kGPU)) {
GpuCopy<T>(p_src_data, p_target_data, src_place, target_place, ele_size);
} else if ((src_place == PlaceType::kGPU) &&
(target_place == PlaceType::kGPU)) {
GpuCopy<T>(p_src_data, p_target_data, src_place, target_place, ele_size);
} else if (supported_gpu_transform) {
DeviceCopy<T>(p_src_data, p_target_data, src_place, target_place, ele_size);
} else {
PADDLE_THROW(platform::errors::Unavailable(
"Not supported place transform of place: %d to place: %d",
......@@ -304,13 +340,18 @@ const PlaceType &Tensor::place() const {
GET_CASTED_TENSOR;
if (platform::is_cpu_place(tensor->place())) {
place_ = PlaceType::kCPU;
#if defined(PADDLE_WITH_CUDA)
} else if (platform::is_gpu_place(tensor->place())) {
place_ = PlaceType::kGPU;
#elif defined(PADDLE_WITH_HIP)
} else if (platform::is_gpu_place(tensor->place())) {
place_ = PlaceType::kHIP;
#endif
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Current Tensor hold unsupported Place Type, Please Init it"
"using Tensor::mutable_data<T>(PaddlePlace) which T is"
"either Place::kCPU or Place::kGPU"));
"using Tensor::mutable_data<T>(PaddlePlace) with T among:"
"Place::kCPU or Place::kGPU or Place::kHIP"));
}
return place_;
}
......@@ -392,16 +433,21 @@ bool Tensor::is_initialized() const {
}
}
#ifdef PADDLE_WITH_CUDA
cudaStream_t Tensor::stream() const {
if (!stream_.IsStreamSet()) {
PADDLE_THROW(platform::errors::PreconditionNotMet(
"Stream is not Set, only input tensor will have "
"stream which is set by framework "));
} else {
return reinterpret_cast<cudaStream_t>(stream_.GetStream());
#define DEFINE_STREAM(_stream_t_) \
_stream_t_ Tensor::stream() const { \
if (!stream_.IsStreamSet()) { \
PADDLE_THROW(platform::errors::PreconditionNotMet( \
"Stream is not Set, only input tensor will have " \
"stream which is set by framework ")); \
} else { \
return reinterpret_cast<_stream_t_>(stream_.GetStream()); \
} \
}
}
#if defined(PADDLE_WITH_CUDA)
DEFINE_STREAM(cudaStream_t)
#elif defined(PADDLE_WITH_HIP)
DEFINE_STREAM(hipStream_t)
#endif
namespace framework {
......
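
The `DEFINE_STREAM` macro above only changes the declared return type: `Tensor::stream()` yields a `cudaStream_t` on a CUDA build and a `hipStream_t` on a HIP build. Below is a sketch of how an operator implementation might consume it portably; the helper and the commented-out kernel launches are hypothetical, and (per the error message above) only input tensors handed over by the framework have their stream set.

```cpp
#include "paddle/extension.h"

// Hypothetical helper: obtain the framework stream of an input tensor and
// launch device work on it, guarding on the same build macros used above.
void LaunchOnTensorStream(const paddle::Tensor& x) {
#if defined(PADDLE_WITH_CUDA)
  cudaStream_t stream = x.stream();
  // e.g. MyCudaKernel<<<grid, block, 0, stream>>>(...);
  (void)stream;
#elif defined(PADDLE_WITH_HIP)
  hipStream_t stream = x.stream();
  // e.g. hipLaunchKernelGGL(MyHipKernel, grid, block, 0, stream, ...);
  (void)stream;
#endif
}
```
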
......@@ -403,10 +403,20 @@ configure_file(commit.h.in commit.h)
include_directories(${PADDLE_SOURCE_DIR}/paddle/fluid/platform)
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/../extension/include)
cc_library(custom_tensor SRCS ../extension/src/ext_tensor.cc DEPS lod_tensor memory enforce)
if(WITH_ROCM)
hip_library(custom_tensor SRCS ../extension/src/ext_tensor.cc DEPS lod_tensor memory enforce)
else()
cc_library(custom_tensor SRCS ../extension/src/ext_tensor.cc DEPS lod_tensor memory enforce)
endif()
cc_library(op_meta_info SRCS ../extension/src/ext_op_meta_info.cc DEPS custom_tensor)
cc_library(custom_operator SRCS custom_operator.cc DEPS tensor attribute framework_proto op_registry operator dynamic_loader string_helper custom_tensor op_meta_info)
cc_test(custom_tensor_test SRCS custom_tensor_test.cc DEPS custom_tensor glog)
if(WITH_ROCM)
hip_test(custom_tensor_test SRCS custom_tensor_test.cc DEPS custom_tensor glog)
else()
cc_test(custom_tensor_test SRCS custom_tensor_test.cc DEPS custom_tensor glog)
endif()
set(FLUID_FRAMEWORK_MODULES proto_desc memory lod_tensor executor data_feed_proto layer dynamic_loader custom_operator)
......
......@@ -45,7 +45,19 @@ void TestCopyTensor() {
auto t1_gpu_cp_cp = t1_gpu_cp.template copy_to<T>(paddle::PlaceType::kGPU);
CHECK((paddle::PlaceType::kGPU == t1_gpu_cp_cp.place()));
auto t1_gpu_cp_cp_cpu =
t1_gpu_cp.template copy_to<T>(paddle::PlaceType::kCPU);
t1_gpu_cp_cp.template copy_to<T>(paddle::PlaceType::kCPU);
CHECK((paddle::PlaceType::kCPU == t1_gpu_cp_cp_cpu.place()));
for (int64_t i = 0; i < t1.size(); i++) {
CHECK_EQ(t1_gpu_cp_cp_cpu.template data<T>()[i], T(5));
}
#elif defined(PADDLE_WITH_HIP)
VLOG(2) << "Do HIP copy test";
auto t1_gpu_cp = t1_cpu_cp.template copy_to<T>(paddle::PlaceType::kHIP);
CHECK((paddle::PlaceType::kHIP == t1_gpu_cp.place()));
auto t1_gpu_cp_cp = t1_gpu_cp.template copy_to<T>(paddle::PlaceType::kHIP);
CHECK((paddle::PlaceType::kHIP == t1_gpu_cp_cp.place()));
auto t1_gpu_cp_cp_cpu =
t1_gpu_cp_cp.template copy_to<T>(paddle::PlaceType::kCPU);
CHECK((paddle::PlaceType::kCPU == t1_gpu_cp_cp_cpu.place()));
for (int64_t i = 0; i < t1.size(); i++) {
CHECK_EQ(t1_gpu_cp_cp_cpu.template data<T>()[i], T(5));
......@@ -60,6 +72,11 @@ void TestAPIPlace() {
t1.reshape(tensor_shape);
t1.mutable_data<float>();
CHECK((paddle::PlaceType::kGPU == t1.place()));
#elif defined(PADDLE_WITH_HIP)
auto t1 = paddle::Tensor(paddle::PlaceType::kHIP);
t1.reshape(tensor_shape);
t1.mutable_data<float>();
CHECK((paddle::PlaceType::kHIP == t1.place()));
#endif
auto t2 = paddle::Tensor(paddle::PlaceType::kCPU);
t2.reshape(tensor_shape);
......
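
The test changes above also fix the CPU round-trip check to copy back from `t1_gpu_cp_cp` rather than `t1_gpu_cp`, and mirror the CUDA path for HIP. The same pattern can be written once for whichever device backend is compiled in; a sketch of such a helper (hypothetical, not part of this commit):

```cpp
#include <glog/logging.h>

#include "paddle/extension.h"

// Hypothetical helper: round-trip a CPU tensor through the compiled-in device
// place and verify the values survive, following the pattern of TestCopyTensor.
template <typename T>
void CheckRoundTrip(const paddle::Tensor& cpu_src, T expected) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
#if defined(PADDLE_WITH_CUDA)
  constexpr auto kDevice = paddle::PlaceType::kGPU;
#else
  constexpr auto kDevice = paddle::PlaceType::kHIP;
#endif
  auto on_device = cpu_src.template copy_to<T>(kDevice);
  auto back_on_cpu = on_device.template copy_to<T>(paddle::PlaceType::kCPU);
  for (int64_t i = 0; i < back_on_cpu.size(); ++i) {
    CHECK_EQ(back_on_cpu.template data<T>()[i], expected);
  }
#endif
}
```
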