From 307ad60db23e49ae96ccd82dfb7cbb8d947d4604 Mon Sep 17 00:00:00 2001
From: ronnywang
Date: Mon, 22 Aug 2022 10:21:18 +0800
Subject: [PATCH] [CustomDevice] fix custom ccl (#45276)

---
 paddle/fluid/distributed/collective/Common.cc |   9 ++
 paddle/fluid/distributed/collective/Common.h  |   3 +
 .../collective/ProcessGroupCustom.cc          | 111 ++++++++++++++++++
 .../collective/ProcessGroupCustom.h           |  10 ++
 paddle/phi/backends/custom/custom_device.cc   |   2 +
 .../phi/backends/custom/custom_device_test.cc |   1 +
 paddle/phi/backends/custom/fake_cpu_device.h  |   1 +
 paddle/phi/backends/device_base.cc            |   1 +
 paddle/phi/backends/device_base.h             |   1 +
 paddle/phi/backends/device_ext.h              |   1 +
 paddle/phi/backends/device_manager.cc         |   3 +-
 paddle/phi/backends/device_manager.h          |   1 +
 12 files changed, 143 insertions(+), 1 deletion(-)

diff --git a/paddle/fluid/distributed/collective/Common.cc b/paddle/fluid/distributed/collective/Common.cc
index 4eba22da8b0..e9572f28d32 100644
--- a/paddle/fluid/distributed/collective/Common.cc
+++ b/paddle/fluid/distributed/collective/Common.cc
@@ -47,5 +47,14 @@ bool CheckTensorsInCudaPlace(const std::vector<phi::DenseTensor>& tensors) {
   });
 }
 
+bool CheckTensorsInCustomPlace(const std::vector<phi::DenseTensor>& tensors,
+                               const std::string& dev_type) {
+  return std::all_of(
+      tensors.cbegin(), tensors.cend(), [&](const phi::DenseTensor& t) {
+        return platform::places_are_same_class(
+            t.place(), paddle::platform::CustomPlace(dev_type));
+      });
+}
+
 }  // namespace distributed
 }  // namespace paddle
diff --git a/paddle/fluid/distributed/collective/Common.h b/paddle/fluid/distributed/collective/Common.h
index 4c6c42bd86d..8d5db886989 100644
--- a/paddle/fluid/distributed/collective/Common.h
+++ b/paddle/fluid/distributed/collective/Common.h
@@ -28,5 +28,8 @@ std::string GetKeyFromPlaces(const std::vector<Place>& places);
 
 bool CheckTensorsInCudaPlace(const std::vector<phi::DenseTensor>& tensors);
 
+bool CheckTensorsInCustomPlace(const std::vector<phi::DenseTensor>& tensors,
+                               const std::string& dev_type);
+
 }  // namespace distributed
 }  // namespace paddle
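A note on the new helper: like CheckTensorsInCudaPlace directly above it, CheckTensorsInCustomPlace goes through platform::places_are_same_class, which compares the place class (the device type), not individual device ids, so tensors spread across several devices of the same custom device type all pass the check. Below is a minimal standalone sketch of that std::all_of pattern; FakePlace and AllInDeviceType are illustrative stand-ins, not Paddle APIs.

    #include <algorithm>
    #include <string>
    #include <vector>

    // Illustrative stand-in for paddle::platform::CustomPlace.
    struct FakePlace {
      std::string dev_type;  // e.g. "npu"
      int dev_id;
    };

    // True only if every place carries the requested device type; device
    // ids are deliberately ignored, mirroring places_are_same_class.
    bool AllInDeviceType(const std::vector<FakePlace>& places,
                         const std::string& dev_type) {
      return std::all_of(places.cbegin(), places.cend(),
                         [&](const FakePlace& p) { return p.dev_type == dev_type; });
    }

    int main() {
      std::vector<FakePlace> places{{"npu", 0}, {"npu", 1}};
      return AllInDeviceType(places, "npu") ? 0 : 1;  // ids differ, types match
    }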
diff --git a/paddle/fluid/distributed/collective/ProcessGroupCustom.cc b/paddle/fluid/distributed/collective/ProcessGroupCustom.cc
index 73a0b631eef..ad9356b368e 100644
--- a/paddle/fluid/distributed/collective/ProcessGroupCustom.cc
+++ b/paddle/fluid/distributed/collective/ProcessGroupCustom.cc
@@ -207,10 +207,111 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Collective(
   return task;
 }
 
+std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllGather(
+    std::vector<phi::DenseTensor>& in_tensors,
+    std::vector<phi::DenseTensor>& out_tensors) {
+  PADDLE_ENFORCE_EQ(
+      CheckTensorsInCustomPlace(in_tensors, device_type_),
+      true,
+      platform::errors::InvalidArgument(
+          "All inputs should be in CustomPlace(%s).", device_type_));
+  PADDLE_ENFORCE_EQ(
+      CheckTensorsInCustomPlace(out_tensors, device_type_),
+      true,
+      platform::errors::InvalidArgument(
+          "All outputs should be in CustomPlace(%s).", device_type_));
+  return Collective(
+      in_tensors,
+      out_tensors,
+      [&](phi::DenseTensor& input,
+          phi::DenseTensor& output,
+          phi::ccl::CCLComm comm,
+          const phi::stream::Stream& stream) {
+        return phi::DeviceManager::CCLAllGather(
+            device_type_,
+            input.data(),
+            output.data(),
+            input.numel(),
+            phi::ccl::ToCCLDataType(input.dtype()),
+            comm,
+            stream);
+      },
+      CommType::ALLGATHER);
+}
+
+void* XcclGetPointerByOffset(void* raw_pointer,
+                             size_t offset,
+                             experimental::DataType type) {
+  if (type == experimental::DataType::FLOAT32) {
+    return reinterpret_cast<void*>(reinterpret_cast<float*>(raw_pointer) +
+                                   offset);
+  } else if (type == experimental::DataType::FLOAT64) {
+    return reinterpret_cast<void*>(reinterpret_cast<double*>(raw_pointer) +
+                                   offset);
+  } else if (type == experimental::DataType::INT32) {
+    return reinterpret_cast<void*>(reinterpret_cast<int32_t*>(raw_pointer) +
+                                   offset);
+  } else if (type == experimental::DataType::INT64) {
+    return reinterpret_cast<void*>(reinterpret_cast<int64_t*>(raw_pointer) +
+                                   offset);
+  } else if (type == experimental::DataType::FLOAT16) {
+    return reinterpret_cast<void*>(reinterpret_cast<int16_t*>(raw_pointer) +
+                                   offset);
+  } else {
+    PADDLE_THROW(platform::errors::Unimplemented(
+        "This datatype in xccl is not supported."));
+  }
+  return nullptr;
+}
+
+std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllGather_Partial(
+    std::vector<phi::DenseTensor>& in_tensors,
+    std::vector<phi::DenseTensor>& out_tensors,
+    int offset,
+    int length) {
+  PADDLE_ENFORCE_EQ(
+      CheckTensorsInCustomPlace(in_tensors, device_type_),
+      true,
+      platform::errors::InvalidArgument(
+          "All inputs should be in CustomPlace(%s).", device_type_));
+  PADDLE_ENFORCE_EQ(
+      CheckTensorsInCustomPlace(out_tensors, device_type_),
+      true,
+      platform::errors::InvalidArgument(
+          "All outputs should be in CustomPlace(%s).", device_type_));
+  return Collective(
+      in_tensors,
+      out_tensors,
+      [&](phi::DenseTensor& input,
+          phi::DenseTensor& output,
+          phi::ccl::CCLComm comm,
+          const phi::stream::Stream& stream) {
+        return phi::DeviceManager::CCLAllGather(
+            device_type_,
+            XcclGetPointerByOffset(input.data(), offset, input.dtype()),
+            output.data(),
+            length,
+            phi::ccl::ToCCLDataType(input.dtype()),
+            comm,
+            stream);
+      },
+      CommType::ALLGATHER);
+}
+
 std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllReduce(
     std::vector<phi::DenseTensor>& in_tensors,   // NOLINT
     std::vector<phi::DenseTensor>& out_tensors,  // NOLINT
     const AllreduceOptions& opts) {
+  PADDLE_ENFORCE_EQ(
+      CheckTensorsInCustomPlace(in_tensors, device_type_),
+      true,
+      platform::errors::InvalidArgument(
+          "All inputs should be in CustomPlace(%s).", device_type_));
+  PADDLE_ENFORCE_EQ(
+      CheckTensorsInCustomPlace(out_tensors, device_type_),
+      true,
+      platform::errors::InvalidArgument(
+          "All outputs should be in CustomPlace(%s).", device_type_));
   return Collective(
       in_tensors,
       out_tensors,
@@ -235,6 +336,16 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Broadcast(
     std::vector<phi::DenseTensor>& in_tensors,   // NOLINT
     std::vector<phi::DenseTensor>& out_tensors,  // NOLINT
     const BroadcastOptions& opts) {
+  PADDLE_ENFORCE_EQ(
+      CheckTensorsInCustomPlace(in_tensors, device_type_),
+      true,
+      platform::errors::InvalidArgument(
+          "All inputs should be in CustomPlace(%s).", device_type_));
+  PADDLE_ENFORCE_EQ(
+      CheckTensorsInCustomPlace(out_tensors, device_type_),
+      true,
+      platform::errors::InvalidArgument(
+          "All outputs should be in CustomPlace(%s).", device_type_));
   return Collective(
       in_tensors,
       out_tensors,
diff --git a/paddle/fluid/distributed/collective/ProcessGroupCustom.h b/paddle/fluid/distributed/collective/ProcessGroupCustom.h
index fca6f127c38..ccce66603af 100644
--- a/paddle/fluid/distributed/collective/ProcessGroupCustom.h
+++ b/paddle/fluid/distributed/collective/ProcessGroupCustom.h
@@ -73,6 +73,16 @@ class ProcessGroupCustom : public ProcessGroup {
     return "XCCL_" + device_type_;
   }
 
+  std::shared_ptr<ProcessGroup::Task> AllGather(
+      std::vector<phi::DenseTensor>& in_tensors,
+      std::vector<phi::DenseTensor>& out_tensors) override;
+
+  std::shared_ptr<ProcessGroup::Task> AllGather_Partial(
+      std::vector<phi::DenseTensor>& in_tensors,
+      std::vector<phi::DenseTensor>& out_tensors,
+      int offset,
+      int length) override;
+
   std::shared_ptr<ProcessGroup::Task> AllReduce(
       std::vector<phi::DenseTensor>& in_tensors,
       std::vector<phi::DenseTensor>& out_tensors,
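One subtlety worth calling out: XcclGetPointerByOffset interprets offset as an element count, not a byte count. The raw pointer is first cast to the element type, so the byte displacement is offset * sizeof(element); this is why AllGather_Partial can pass element offsets straight through. Below is a self-contained sketch of the equivalent byte arithmetic; OffsetByElements is a hypothetical helper, not part of the patch.

    #include <cstddef>
    #include <cstdio>

    // Advance a type-erased pointer by `offset` elements of `elem_size`
    // bytes each: the byte-level equivalent of the typed casts above.
    void* OffsetByElements(void* raw, size_t offset, size_t elem_size) {
      return static_cast<void*>(static_cast<char*>(raw) + offset * elem_size);
    }

    int main() {
      float buf[8] = {};
      // For FLOAT32 the patch computes reinterpret_cast<float*>(raw) + offset,
      // which is the same address as advancing offset * sizeof(float) bytes.
      void* typed = static_cast<void*>(buf + 3);
      void* bytes = OffsetByElements(buf, 3, sizeof(float));
      std::printf("%s\n", typed == bytes ? "match" : "mismatch");  // "match"
      return 0;
    }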
diff --git a/paddle/phi/backends/custom/custom_device.cc b/paddle/phi/backends/custom/custom_device.cc
index 2567857bca1..75f5433a640 100644
--- a/paddle/phi/backends/custom/custom_device.cc
+++ b/paddle/phi/backends/custom/custom_device.cc
@@ -705,6 +705,7 @@ class CustomDevice : public DeviceInterface {
                  size_t num,
                  ccl::CCLDataType data_type,
                  ccl::CCLReduceOp reduce_op,
+                 size_t root_id,
                  const ccl::CCLComm& comm,
                  const stream::Stream& stream) override {
     CHECK_PTR(pimpl_->xccl_reduce);
@@ -714,6 +715,7 @@ class CustomDevice : public DeviceInterface {
                             num,
                             ToXCCLDataType(data_type),
                             ToXCCLReduceOp(reduce_op),
+                            root_id,
                             reinterpret_cast<C_CCLComm>(comm),
                             reinterpret_cast<C_Stream>(stream.raw_stream())));
   }
diff --git a/paddle/phi/backends/custom/custom_device_test.cc b/paddle/phi/backends/custom/custom_device_test.cc
index 425d7bde617..2458241c3c8 100644
--- a/paddle/phi/backends/custom/custom_device_test.cc
+++ b/paddle/phi/backends/custom/custom_device_test.cc
@@ -203,6 +203,7 @@ void TestCustomCCL(const paddle::platform::Place& place) {
                               0,
                               phi::ccl::CCLDataType::CCL_DATA_TYPE_FP32,
                               phi::ccl::CCLReduceOp::SUM,
+                              0,
                               comm,
                               stream);
   phi::DeviceManager::CCLAllGather(dev_type,
diff --git a/paddle/phi/backends/custom/fake_cpu_device.h b/paddle/phi/backends/custom/fake_cpu_device.h
index a4eaa834a60..1fcbeab89ab 100644
--- a/paddle/phi/backends/custom/fake_cpu_device.h
+++ b/paddle/phi/backends/custom/fake_cpu_device.h
@@ -170,6 +170,7 @@ C_Status XcclReduce(void *send_buf,
                     size_t count,
                     C_DataType data_type,
                     C_CCLReduceOp op,
+                    size_t root_id,
                     C_CCLComm comm,
                     C_Stream stream) {
   return C_SUCCESS;
diff --git a/paddle/phi/backends/device_base.cc b/paddle/phi/backends/device_base.cc
index 41871f69c77..fca6a32e4f8 100644
--- a/paddle/phi/backends/device_base.cc
+++ b/paddle/phi/backends/device_base.cc
@@ -309,6 +309,7 @@ void DeviceInterface::CCLReduce(void* in_data,
                                 size_t num,
                                 ccl::CCLDataType data_type,
                                 ccl::CCLReduceOp reduce_op,
+                                size_t root_id,
                                 const ccl::CCLComm& ccl_comm,
                                 const stream::Stream& stream) {
   INTERFACE_UNIMPLEMENT;
diff --git a/paddle/phi/backends/device_base.h b/paddle/phi/backends/device_base.h
index b823a4a9832..e5bdc6c8126 100644
--- a/paddle/phi/backends/device_base.h
+++ b/paddle/phi/backends/device_base.h
@@ -195,6 +195,7 @@ class DeviceInterface {  // Driver / Runtime
                          size_t num,
                          ccl::CCLDataType data_type,
                          ccl::CCLReduceOp reduce_op,
+                         size_t root_id,
                          const ccl::CCLComm& ccl_comm,
                          const stream::Stream& stream);
   virtual void CCLAllGather(void* in_data,
diff --git a/paddle/phi/backends/device_ext.h b/paddle/phi/backends/device_ext.h
index 5bb5def9c2b..ca254f8235a 100644
--- a/paddle/phi/backends/device_ext.h
+++ b/paddle/phi/backends/device_ext.h
@@ -593,6 +593,7 @@ struct C_DeviceInterface {
                           size_t count,
                           C_DataType data_type,
                           C_CCLReduceOp op,
+                          size_t root,
                           C_CCLComm comm,
                           C_Stream stream);
 
diff --git a/paddle/phi/backends/device_manager.cc b/paddle/phi/backends/device_manager.cc
index dbdbce13d4f..224bd0a1ff1 100644
--- a/paddle/phi/backends/device_manager.cc
+++ b/paddle/phi/backends/device_manager.cc
@@ -536,11 +536,12 @@ void DeviceManager::CCLReduce(const std::string& device_type,
                               size_t num,
                               ccl::CCLDataType data_type,
                               ccl::CCLReduceOp reduce_op,
+                              size_t root_id,
                               const ccl::CCLComm& ccl_comm,
                               const stream::Stream& stream) {
   auto dev_impl = GetDeviceInterfaceWithType(device_type);
   dev_impl->CCLReduce(
-      in_data, out_data, num, data_type, reduce_op, ccl_comm, stream);
+      in_data, out_data, num, data_type, reduce_op, root_id, ccl_comm, stream);
 }
 
 void DeviceManager::CCLAllGather(const std::string& device_type,
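The root_id parameter threaded through CCLReduce down to the plugin's xccl_reduce hook names the rank that receives the reduced result, matching the root argument of ncclReduce-style APIs; the updated test simply passes 0. Below is a minimal single-process sketch of that contract; ToyReduce is illustrative, and a real plugin would forward root to its vendor collective library instead.

    #include <cstddef>
    #include <cstring>

    // Single-process illustration: only the root rank's receive buffer is
    // written, which is exactly what the new root parameter expresses.
    void ToyReduce(const void* send_buf, void* recv_buf, size_t bytes,
                   size_t root, size_t my_rank) {
      if (my_rank == root) {
        std::memcpy(recv_buf, send_buf, bytes);  // with one rank, reduce == copy
      }
    }

    int main() {
      const float send[4] = {1.f, 2.f, 3.f, 4.f};
      float recv[4] = {};
      ToyReduce(send, recv, sizeof(send), /*root=*/0, /*my_rank=*/0);
      return recv[3] == 4.f ? 0 : 1;
    }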
diff --git a/paddle/phi/backends/device_manager.h b/paddle/phi/backends/device_manager.h
index 6d621b6a432..fc8529e5813 100644
--- a/paddle/phi/backends/device_manager.h
+++ b/paddle/phi/backends/device_manager.h
@@ -206,6 +206,7 @@ class DeviceManager {
                         size_t num,
                         ccl::CCLDataType data_type,
                         ccl::CCLReduceOp reduce_op,
+                        size_t root_id,
                         const ccl::CCLComm& ccl_comm,
                         const stream::Stream& stream);
   static void CCLAllGather(const std::string& device_type,
-- 
GitLab