Unverified commit 307ad60d, authored by ronnywang, committed by GitHub

[CustomDevice] fix custom ccl (#45276)

Parent bba13e21
@@ -47,5 +47,14 @@ bool CheckTensorsInCudaPlace(const std::vector<phi::DenseTensor>& tensors) {
});
}
bool CheckTensorsInCustomPlace(const std::vector<phi::DenseTensor>& tensors,
const std::string& dev_type) {
return std::all_of(
tensors.cbegin(), tensors.cend(), [&](const phi::DenseTensor& t) {
return platform::places_are_same_class(
t.place(), paddle::platform::CustomPlace(dev_type));
});
}
} // namespace distributed
} // namespace paddle
@@ -28,5 +28,8 @@ std::string GetKeyFromPlaces(const std::vector<Place>& places);
bool CheckTensorsInCudaPlace(const std::vector<phi::DenseTensor>& tensors);
bool CheckTensorsInCustomPlace(const std::vector<phi::DenseTensor>& tensors,
const std::string& dev_type);
} // namespace distributed
} // namespace paddle
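
Side note: the new CheckTensorsInCustomPlace mirrors the existing CUDA check — it is just a std::all_of over the tensors' places. A standalone sketch of the same pattern (plain strings stand in for phi::Place; none of this is Paddle code):

    #include <algorithm>
    #include <string>
    #include <vector>

    // All tensors must report the same device class before a collective runs.
    bool AllOnDevice(const std::vector<std::string>& tensor_places,
                     const std::string& dev_type) {
      return std::all_of(
          tensor_places.cbegin(), tensor_places.cend(),
          [&](const std::string& p) { return p == dev_type; });
    }
    // Usage: AllOnDevice({"npu", "npu"}, "npu") == true.
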
@@ -207,10 +207,111 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Collective(
return task;
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllGather(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors) {
PADDLE_ENFORCE_EQ(
CheckTensorsInCustomPlace(in_tensors, device_type_),
true,
platform::errors::InvalidArgument(
"All inputs should be in CustomPlace(%s).", device_type_));
PADDLE_ENFORCE_EQ(
CheckTensorsInCustomPlace(out_tensors, device_type_),
true,
platform::errors::InvalidArgument(
"All outputs should be in CustomPlace(%s).", device_type_));
return Collective(
in_tensors,
out_tensors,
[&](phi::DenseTensor& input,
phi::DenseTensor& output,
phi::ccl::CCLComm comm,
const phi::stream::Stream& stream) {
return phi::DeviceManager::CCLAllGather(
device_type_,
input.data(),
output.data(),
input.numel(),
phi::ccl::ToCCLDataType(input.dtype()),
comm,
stream);
},
CommType::ALLGATHER);
}
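
Caller-side, the new override is reached through the ProcessGroup interface. A hedged sketch, assuming `pg` is an initialized ProcessGroupCustom and that the base ProcessGroup::Task exposes a blocking Wait():

    // in_tensors/out_tensors must already live on CustomPlace(device_type),
    // or the PADDLE_ENFORCE_EQ checks above will fire.
    std::vector<phi::DenseTensor> in_tensors;   // filled with device tensors
    std::vector<phi::DenseTensor> out_tensors;  // sized to nranks * input numel
    auto task = pg->AllGather(in_tensors, out_tensors);
    task->Wait();  // assumed to block until the XCCL all-gather completes
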
void* XcclGetPointerByOffset(void* raw_pointer,
size_t offset,
experimental::DataType type) {
if (type == experimental::DataType::FLOAT32) {
return reinterpret_cast<void*>(reinterpret_cast<float*>(raw_pointer) +
offset);
} else if (type == experimental::DataType::FLOAT64) {
return reinterpret_cast<void*>(reinterpret_cast<double*>(raw_pointer) +
offset);
} else if (type == experimental::DataType::INT32) {
return reinterpret_cast<void*>(reinterpret_cast<int32_t*>(raw_pointer) +
offset);
} else if (type == experimental::DataType::INT64) {
return reinterpret_cast<void*>(reinterpret_cast<int64_t*>(raw_pointer) +
offset);
} else if (type == experimental::DataType::FLOAT16) {
return reinterpret_cast<void*>(reinterpret_cast<int16_t*>(raw_pointer) +
offset);
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"This datatype in xccl is not supported."));
}
return nullptr;
}
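
Note that `offset` here is an element count, not a byte count, which is why FLOAT16 is stepped through an int16_t pointer (any 2-byte type works as a stride). A self-contained illustration of the same pointer arithmetic (names are ours, not part of the commit):

    #include <cstdint>
    #include <cstdio>

    int main() {
      float buf[8] = {};
      void* raw = buf;
      size_t offset = 3;  // 3 elements
      // Same arithmetic as the FLOAT32 branch: advance by 3 floats == 12 bytes.
      void* p = reinterpret_cast<void*>(reinterpret_cast<float*>(raw) + offset);
      std::printf("byte offset = %td\n",
                  reinterpret_cast<char*>(p) - reinterpret_cast<char*>(raw));
      return 0;  // prints "byte offset = 12"
    }
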
std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllGather_Partial(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors,
int offset,
int length) {
PADDLE_ENFORCE_EQ(
CheckTensorsInCustomPlace(in_tensors, device_type_),
true,
platform::errors::InvalidArgument(
"All inputs should be in CustomPlace(%s).", device_type_));
PADDLE_ENFORCE_EQ(
CheckTensorsInCustomPlace(out_tensors, device_type_),
true,
platform::errors::InvalidArgument(
"All outputs should be in CustomPlace(%s).", device_type_));
return Collective(
in_tensors,
out_tensors,
[&](phi::DenseTensor& input,
phi::DenseTensor& output,
phi::ccl::CCLComm comm,
const phi::stream::Stream& stream) {
return phi::DeviceManager::CCLAllGather(
device_type_,
XcclGetPointerByOffset(input.data(), offset, input.dtype()),
output.data(),
length,
phi::ccl::ToCCLDataType(input.dtype()),
comm,
stream);
},
CommType::ALLGATHER);
}
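
How offset and length are chosen is up to the caller; this commit only plumbs them through. A hedged sketch of the common equal-split convention (rank_id, nranks, pg and the tensor vectors are assumptions, not taken from this diff):

    // Each rank contributes `length` elements of its input starting at element
    // `offset`; an equal split across ranks is the typical choice.
    const int nranks = 4;                  // hypothetical world size
    const int rank_id = 1;                 // hypothetical rank of this process
    const int64_t numel = in_tensors[0].numel();
    const int length = static_cast<int>(numel / nranks);
    const int offset = rank_id * length;
    auto task = pg->AllGather_Partial(in_tensors, out_tensors, offset, length);
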
std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllReduce(
std::vector<phi::DenseTensor>& in_tensors, // NOLINT
std::vector<phi::DenseTensor>& out_tensors, // NOLINT
const AllreduceOptions& opts) {
PADDLE_ENFORCE_EQ(
CheckTensorsInCustomPlace(in_tensors, device_type_),
true,
platform::errors::InvalidArgument(
"All inputs should be in CustomPlace(%s).", device_type_));
PADDLE_ENFORCE_EQ(
CheckTensorsInCustomPlace(out_tensors, device_type_),
true,
platform::errors::InvalidArgument(
"All outputs should be in CustomPlace(%s).", device_type_));
return Collective(
in_tensors,
out_tensors,
@@ -235,6 +336,16 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Broadcast(
std::vector<phi::DenseTensor>& in_tensors, // NOLINT
std::vector<phi::DenseTensor>& out_tensors, // NOLINT
const BroadcastOptions& opts) {
PADDLE_ENFORCE_EQ(
CheckTensorsInCustomPlace(in_tensors, device_type_),
true,
platform::errors::InvalidArgument(
"All inputs should be in CustomPlace(%s).", device_type_));
PADDLE_ENFORCE_EQ(
CheckTensorsInCustomPlace(out_tensors, device_type_),
true,
platform::errors::InvalidArgument(
"All outputs should be in CustomPlace(%s).", device_type_));
return Collective(
in_tensors,
out_tensors,
......
@@ -73,6 +73,16 @@ class ProcessGroupCustom : public ProcessGroup {
return "XCCL_" + device_type_;
}
std::shared_ptr<ProcessGroup::Task> AllGather(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors) override;
std::shared_ptr<ProcessGroup::Task> AllGather_Partial(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors,
int offset,
int length) override;
std::shared_ptr<ProcessGroup::Task> AllReduce(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors,
......
@@ -705,6 +705,7 @@ class CustomDevice : public DeviceInterface {
size_t num,
ccl::CCLDataType data_type,
ccl::CCLReduceOp reduce_op,
size_t root_id,
const ccl::CCLComm& comm,
const stream::Stream& stream) override {
CHECK_PTR(pimpl_->xccl_reduce);
@@ -714,6 +715,7 @@ class CustomDevice : public DeviceInterface {
num,
ToXCCLDataType(data_type),
ToXCCLReduceOp(reduce_op),
root_id,
reinterpret_cast<C_CCLComm>(comm),
reinterpret_cast<C_Stream>(stream.raw_stream())));
}
......
@@ -203,6 +203,7 @@ void TestCustomCCL(const paddle::platform::Place& place) {
0,
phi::ccl::CCLDataType::CCL_DATA_TYPE_FP32,
phi::ccl::CCLReduceOp::SUM,
0,
comm,
stream);
phi::DeviceManager::CCLAllGather(dev_type,
......
@@ -170,6 +170,7 @@ C_Status XcclReduce(void *send_buf,
size_t count,
C_DataType data_type,
C_CCLReduceOp op,
size_t root_id,
C_CCLComm comm,
C_Stream stream) {
return C_SUCCESS;
......
@@ -309,6 +309,7 @@ void DeviceInterface::CCLReduce(void* in_data,
size_t num,
ccl::CCLDataType data_type,
ccl::CCLReduceOp reduce_op,
size_t root_id,
const ccl::CCLComm& ccl_comm,
const stream::Stream& stream) {
INTERFACE_UNIMPLEMENT;
......
@@ -195,6 +195,7 @@ class DeviceInterface { // Driver / Runtime
size_t num,
ccl::CCLDataType data_type,
ccl::CCLReduceOp reduce_op,
size_t root_id,
const ccl::CCLComm& ccl_comm,
const stream::Stream& stream);
virtual void CCLAllGather(void* in_data,
......
@@ -593,6 +593,7 @@ struct C_DeviceInterface {
size_t count,
C_DataType data_type,
C_CCLReduceOp op,
size_t root,
C_CCLComm comm,
C_Stream stream);
......
@@ -536,11 +536,12 @@ void DeviceManager::CCLReduce(const std::string& device_type,
size_t num,
ccl::CCLDataType data_type,
ccl::CCLReduceOp reduce_op,
size_t root_id,
const ccl::CCLComm& ccl_comm,
const stream::Stream& stream) {
auto dev_impl = GetDeviceInterfaceWithType(device_type);
dev_impl->CCLReduce(
in_data, out_data, num, data_type, reduce_op, ccl_comm, stream);
in_data, out_data, num, data_type, reduce_op, root_id, ccl_comm, stream);
}
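
After this change every layer of the reduce path carries the root rank. A hedged sketch of a call through DeviceManager with the new argument order (buffer names are placeholders; the order matches the forwarding call above):

    // send_buf/recv_buf: void* device buffers; num: element count.
    // The 7th argument is the newly added root_id that receives the result.
    phi::DeviceManager::CCLReduce(dev_type,
                                  send_buf,
                                  recv_buf,
                                  num,
                                  phi::ccl::CCLDataType::CCL_DATA_TYPE_FP32,
                                  phi::ccl::CCLReduceOp::SUM,
                                  0,  // root_id
                                  comm,
                                  stream);
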
void DeviceManager::CCLAllGather(const std::string& device_type,
......
@@ -206,6 +206,7 @@ class DeviceManager {
size_t num,
ccl::CCLDataType data_type,
ccl::CCLReduceOp reduce_op,
size_t root_id,
const ccl::CCLComm& ccl_comm,
const stream::Stream& stream);
static void CCLAllGather(const std::string& device_type,
......