Unverified commit 307ad60d, authored by ronnywang, committed by GitHub

[CustomDevice] fix custom ccl (#45276)

Parent bba13e21
@@ -47,5 +47,14 @@ bool CheckTensorsInCudaPlace(const std::vector<phi::DenseTensor>& tensors) {
      });
}

bool CheckTensorsInCustomPlace(const std::vector<phi::DenseTensor>& tensors,
                               const std::string& dev_type) {
  return std::all_of(
      tensors.cbegin(), tensors.cend(), [&](const phi::DenseTensor& t) {
        return platform::places_are_same_class(
            t.place(), paddle::platform::CustomPlace(dev_type));
      });
}

} // namespace distributed
} // namespace paddle
@@ -28,5 +28,8 @@ std::string GetKeyFromPlaces(const std::vector<Place>& places);
bool CheckTensorsInCudaPlace(const std::vector<phi::DenseTensor>& tensors);

bool CheckTensorsInCustomPlace(const std::vector<phi::DenseTensor>& tensors,
                               const std::string& dev_type);

} // namespace distributed
} // namespace paddle
@@ -207,10 +207,111 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Collective(
  return task;
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllGather(
    std::vector<phi::DenseTensor>& in_tensors,
    std::vector<phi::DenseTensor>& out_tensors) {
  PADDLE_ENFORCE_EQ(
      CheckTensorsInCustomPlace(in_tensors, device_type_),
      true,
      platform::errors::InvalidArgument(
          "All inputs should be in CustomPlace(%s).", device_type_));
  PADDLE_ENFORCE_EQ(
      CheckTensorsInCustomPlace(out_tensors, device_type_),
      true,
      platform::errors::InvalidArgument(
          "All outputs should be in CustomPlace(%s).", device_type_));
  return Collective(
      in_tensors,
      out_tensors,
      [&](phi::DenseTensor& input,
          phi::DenseTensor& output,
          phi::ccl::CCLComm comm,
          const phi::stream::Stream& stream) {
        return phi::DeviceManager::CCLAllGather(
            device_type_,
            input.data(),
            output.data(),
            input.numel(),
            phi::ccl::ToCCLDataType(input.dtype()),
            comm,
            stream);
      },
      CommType::ALLGATHER);
}
void* XcclGetPointerByOffset(void* raw_pointer,
                             size_t offset,
                             experimental::DataType type) {
  if (type == experimental::DataType::FLOAT32) {
    return reinterpret_cast<void*>(reinterpret_cast<float*>(raw_pointer) +
                                   offset);
  } else if (type == experimental::DataType::FLOAT64) {
    return reinterpret_cast<void*>(reinterpret_cast<double*>(raw_pointer) +
                                   offset);
  } else if (type == experimental::DataType::INT32) {
    return reinterpret_cast<void*>(reinterpret_cast<int32_t*>(raw_pointer) +
                                   offset);
  } else if (type == experimental::DataType::INT64) {
    return reinterpret_cast<void*>(reinterpret_cast<int64_t*>(raw_pointer) +
                                   offset);
  } else if (type == experimental::DataType::FLOAT16) {
    return reinterpret_cast<void*>(reinterpret_cast<int16_t*>(raw_pointer) +
                                   offset);
  } else {
    PADDLE_THROW(platform::errors::Unimplemented(
        "This datatype in xccl is not supported."));
  }
  return nullptr;
}
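The helper above turns an element offset into the correct byte offset by casting the raw buffer to a pointer type whose width matches the element size; FLOAT16 reuses int16_t because both types are two bytes wide. A minimal standalone sketch of the same arithmetic, not part of the patch (AdvanceByElements is a hypothetical name):

#include <cstddef>
#include <cstdint>

// Advance a raw buffer by `offset` elements of `elem_size` bytes each,
// mirroring what XcclGetPointerByOffset does with typed pointers.
void* AdvanceByElements(void* raw_pointer, size_t offset, size_t elem_size) {
  return static_cast<void*>(static_cast<char*>(raw_pointer) +
                            offset * elem_size);
}

// Example: for FLOAT16 (2-byte elements), offset 8 advances the pointer by
// 16 bytes, the same result as reinterpret_cast<int16_t*>(raw_pointer) + 8.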
std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllGather_Partial(
    std::vector<phi::DenseTensor>& in_tensors,
    std::vector<phi::DenseTensor>& out_tensors,
    int offset,
    int length) {
  PADDLE_ENFORCE_EQ(
      CheckTensorsInCustomPlace(in_tensors, device_type_),
      true,
      platform::errors::InvalidArgument(
          "All inputs should be in CustomPlace(%s).", device_type_));
  PADDLE_ENFORCE_EQ(
      CheckTensorsInCustomPlace(out_tensors, device_type_),
      true,
      platform::errors::InvalidArgument(
          "All outputs should be in CustomPlace(%s).", device_type_));
  return Collective(
      in_tensors,
      out_tensors,
      [&](phi::DenseTensor& input,
          phi::DenseTensor& output,
          phi::ccl::CCLComm comm,
          const phi::stream::Stream& stream) {
        return phi::DeviceManager::CCLAllGather(
            device_type_,
            XcclGetPointerByOffset(input.data(), offset, input.dtype()),
            output.data(),
            length,
            phi::ccl::ToCCLDataType(input.dtype()),
            comm,
            stream);
      },
      CommType::ALLGATHER);
}
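AllGather_Partial gathers only a slice of each rank's input: every rank contributes `length` elements starting at element `offset`, so the output holds `world_size * length` elements ordered by rank. A host-side sketch of that data layout for reference only (SimulatePartialAllGather and the per-rank input container are illustrative, not part of the patch):

#include <cstddef>
#include <vector>

// Simulate the data movement of a partial all-gather on the host:
// out[rank * length + i] = input_of_rank[offset + i] for every rank.
std::vector<float> SimulatePartialAllGather(
    const std::vector<std::vector<float>>& per_rank_inputs,
    size_t offset,
    size_t length) {
  const size_t world_size = per_rank_inputs.size();
  std::vector<float> out(world_size * length);
  for (size_t rank = 0; rank < world_size; ++rank) {
    for (size_t i = 0; i < length; ++i) {
      out[rank * length + i] = per_rank_inputs[rank][offset + i];
    }
  }
  return out;
}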
std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllReduce(
    std::vector<phi::DenseTensor>& in_tensors,   // NOLINT
    std::vector<phi::DenseTensor>& out_tensors,  // NOLINT
    const AllreduceOptions& opts) {
  PADDLE_ENFORCE_EQ(
      CheckTensorsInCustomPlace(in_tensors, device_type_),
      true,
      platform::errors::InvalidArgument(
          "All inputs should be in CustomPlace(%s).", device_type_));
  PADDLE_ENFORCE_EQ(
      CheckTensorsInCustomPlace(out_tensors, device_type_),
      true,
      platform::errors::InvalidArgument(
          "All outputs should be in CustomPlace(%s).", device_type_));
  return Collective(
      in_tensors,
      out_tensors,
@@ -235,6 +336,16 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Broadcast(
    std::vector<phi::DenseTensor>& in_tensors,   // NOLINT
    std::vector<phi::DenseTensor>& out_tensors,  // NOLINT
    const BroadcastOptions& opts) {
  PADDLE_ENFORCE_EQ(
      CheckTensorsInCustomPlace(in_tensors, device_type_),
      true,
      platform::errors::InvalidArgument(
          "All inputs should be in CustomPlace(%s).", device_type_));
  PADDLE_ENFORCE_EQ(
      CheckTensorsInCustomPlace(out_tensors, device_type_),
      true,
      platform::errors::InvalidArgument(
          "All outputs should be in CustomPlace(%s).", device_type_));
  return Collective(
      in_tensors,
      out_tensors,
......
@@ -73,6 +73,16 @@ class ProcessGroupCustom : public ProcessGroup {
    return "XCCL_" + device_type_;
  }
  std::shared_ptr<ProcessGroup::Task> AllGather(
      std::vector<phi::DenseTensor>& in_tensors,
      std::vector<phi::DenseTensor>& out_tensors) override;

  std::shared_ptr<ProcessGroup::Task> AllGather_Partial(
      std::vector<phi::DenseTensor>& in_tensors,
      std::vector<phi::DenseTensor>& out_tensors,
      int offset,
      int length) override;
  std::shared_ptr<ProcessGroup::Task> AllReduce(
      std::vector<phi::DenseTensor>& in_tensors,
      std::vector<phi::DenseTensor>& out_tensors,
......
@@ -705,6 +705,7 @@ class CustomDevice : public DeviceInterface {
size_t num,
ccl::CCLDataType data_type,
ccl::CCLReduceOp reduce_op,
size_t root_id,
const ccl::CCLComm& comm,
const stream::Stream& stream) override {
CHECK_PTR(pimpl_->xccl_reduce);
@@ -714,6 +715,7 @@ class CustomDevice : public DeviceInterface {
num,
ToXCCLDataType(data_type),
ToXCCLReduceOp(reduce_op),
root_id,
reinterpret_cast<C_CCLComm>(comm),
reinterpret_cast<C_Stream>(stream.raw_stream())));
}
......
@@ -203,6 +203,7 @@ void TestCustomCCL(const paddle::platform::Place& place) {
0,
phi::ccl::CCLDataType::CCL_DATA_TYPE_FP32,
phi::ccl::CCLReduceOp::SUM,
0,
comm,
stream);
phi::DeviceManager::CCLAllGather(dev_type,
......
@@ -170,6 +170,7 @@ C_Status XcclReduce(void *send_buf,
size_t count,
C_DataType data_type,
C_CCLReduceOp op,
size_t root_id,
C_CCLComm comm,
C_Stream stream) {
return C_SUCCESS;
......
@@ -309,6 +309,7 @@ void DeviceInterface::CCLReduce(void* in_data,
size_t num,
ccl::CCLDataType data_type,
ccl::CCLReduceOp reduce_op,
size_t root_id,
const ccl::CCLComm& ccl_comm,
const stream::Stream& stream) {
INTERFACE_UNIMPLEMENT;
......
@@ -195,6 +195,7 @@ class DeviceInterface { // Driver / Runtime
size_t num,
ccl::CCLDataType data_type,
ccl::CCLReduceOp reduce_op,
size_t root_id,
const ccl::CCLComm& ccl_comm,
const stream::Stream& stream);
virtual void CCLAllGather(void* in_data,
......
@@ -593,6 +593,7 @@ struct C_DeviceInterface {
size_t count,
C_DataType data_type,
C_CCLReduceOp op,
size_t root,
C_CCLComm comm,
C_Stream stream);
......
@@ -536,11 +536,12 @@ void DeviceManager::CCLReduce(const std::string& device_type,
size_t num,
ccl::CCLDataType data_type,
ccl::CCLReduceOp reduce_op,
size_t root_id,
const ccl::CCLComm& ccl_comm,
const stream::Stream& stream) {
auto dev_impl = GetDeviceInterfaceWithType(device_type);
dev_impl->CCLReduce(
    in_data, out_data, num, data_type, reduce_op, root_id, ccl_comm, stream);
}
void DeviceManager::CCLAllGather(const std::string& device_type,
......
@@ -206,6 +206,7 @@ class DeviceManager {
size_t num,
ccl::CCLDataType data_type,
ccl::CCLReduceOp reduce_op,
size_t root_id,
const ccl::CCLComm& ccl_comm,
const stream::Stream& stream);
static void CCLAllGather(const std::string& device_type,
......
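Every layer (DeviceManager, DeviceInterface, C_DeviceInterface) now forwards root_id down to the device plugin, so a plugin's reduce hook must accept the extra parameter, as the fake device runtime above already does. A minimal plugin-side sketch under that assumption; MyDeviceReduce, the include path, and the registration snippet are illustrative, not part of this patch:

// Sketch of a plugin-side reduce hook against the updated C_DeviceInterface
// signature; MyDeviceReduce stands in for the vendor collective library call.
#include "device_ext.h"  // header defining C_Status, C_DeviceInterface, etc. (assumed path)

C_Status XcclReduce(void* send_buf,
                    void* recv_buf,
                    size_t count,
                    C_DataType data_type,
                    C_CCLReduceOp op,
                    size_t root_id,  // newly threaded through: rank that receives the result
                    C_CCLComm comm,
                    C_Stream stream) {
  // e.g. return MyDeviceReduce(send_buf, recv_buf, count, data_type, op,
  //                            root_id, comm, stream);
  return C_SUCCESS;
}

// Hooked up inside the plugin's init routine (field name taken from the
// struct C_DeviceInterface above; the params variable is an assumption):
// params->interface->xccl_reduce = XcclReduce;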