diff --git a/paddle/fluid/distributed/collective/Common.cc b/paddle/fluid/distributed/collective/Common.cc
index 4eba22da8b030583fb6b6fd2b393c996d19c269c..e9572f28d32824c4e6d61b8afa0ee96d4e6d86a3 100644
--- a/paddle/fluid/distributed/collective/Common.cc
+++ b/paddle/fluid/distributed/collective/Common.cc
@@ -47,5 +47,14 @@ bool CheckTensorsInCudaPlace(const std::vector<phi::DenseTensor>& tensors) {
   });
 }
 
+bool CheckTensorsInCustomPlace(const std::vector<phi::DenseTensor>& tensors,
+                               const std::string& dev_type) {
+  return std::all_of(
+      tensors.cbegin(), tensors.cend(), [&](const phi::DenseTensor& t) {
+        return platform::places_are_same_class(
+            t.place(), paddle::platform::CustomPlace(dev_type));
+      });
+}
+
 }  // namespace distributed
 }  // namespace paddle
diff --git a/paddle/fluid/distributed/collective/Common.h b/paddle/fluid/distributed/collective/Common.h
index 4c6c42bd86d636ef8838c4e33b23e9ac8e1dec2a..8d5db886989fc2851ec8b84c2849c00823258749 100644
--- a/paddle/fluid/distributed/collective/Common.h
+++ b/paddle/fluid/distributed/collective/Common.h
@@ -28,5 +28,8 @@ std::string GetKeyFromPlaces(const std::vector<Place>& places);
 
 bool CheckTensorsInCudaPlace(const std::vector<phi::DenseTensor>& tensors);
 
+bool CheckTensorsInCustomPlace(const std::vector<phi::DenseTensor>& tensors,
+                               const std::string& dev_type);
+
 }  // namespace distributed
 }  // namespace paddle
diff --git a/paddle/fluid/distributed/collective/ProcessGroupCustom.cc b/paddle/fluid/distributed/collective/ProcessGroupCustom.cc
index 73a0b631eefc353cf66bb92883dee966efd81025..ad9356b368ea264dafde34393c945cf522d63210 100644
--- a/paddle/fluid/distributed/collective/ProcessGroupCustom.cc
+++ b/paddle/fluid/distributed/collective/ProcessGroupCustom.cc
@@ -207,10 +207,111 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Collective(
   return task;
 }
 
+std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllGather(
+    std::vector<phi::DenseTensor>& in_tensors,
+    std::vector<phi::DenseTensor>& out_tensors) {
+  PADDLE_ENFORCE_EQ(
+      CheckTensorsInCustomPlace(in_tensors, device_type_),
+      true,
+      platform::errors::InvalidArgument(
+          "All inputs should be in CustomPlace(%s).", device_type_));
+  PADDLE_ENFORCE_EQ(
+      CheckTensorsInCustomPlace(out_tensors, device_type_),
+      true,
+      platform::errors::InvalidArgument(
+          "All outputs should be in CustomPlace(%s).", device_type_));
+  return Collective(
+      in_tensors,
+      out_tensors,
+      [&](phi::DenseTensor& input,
+          phi::DenseTensor& output,
+          phi::ccl::CCLComm comm,
+          const phi::stream::Stream& stream) {
+        return phi::DeviceManager::CCLAllGather(
+            device_type_,
+            input.data(),
+            output.data(),
+            input.numel(),
+            phi::ccl::ToCCLDataType(input.dtype()),
+            comm,
+            stream);
+      },
+      CommType::ALLGATHER);
+}
+
+void* XcclGetPointerByOffset(void* raw_pointer,
+                             size_t offset,
+                             experimental::DataType type) {
+  if (type == experimental::DataType::FLOAT32) {
+    return reinterpret_cast<void*>(reinterpret_cast<float*>(raw_pointer) +
+                                   offset);
+  } else if (type == experimental::DataType::FLOAT64) {
+    return reinterpret_cast<void*>(reinterpret_cast<double*>(raw_pointer) +
+                                   offset);
+  } else if (type == experimental::DataType::INT32) {
+    return reinterpret_cast<void*>(reinterpret_cast<int32_t*>(raw_pointer) +
+                                   offset);
+  } else if (type == experimental::DataType::INT64) {
+    return reinterpret_cast<void*>(reinterpret_cast<int64_t*>(raw_pointer) +
+                                   offset);
+  } else if (type == experimental::DataType::FLOAT16) {
+    return reinterpret_cast<void*>(reinterpret_cast<int16_t*>(raw_pointer) +
+                                   offset);
+  } else {
+    PADDLE_THROW(platform::errors::Unimplemented(
+        "This datatype in xccl is not supported."));
+  }
+  return nullptr;
+}
+
+std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllGather_Partial(
+    std::vector<phi::DenseTensor>& in_tensors,
+    std::vector<phi::DenseTensor>& out_tensors,
+    int offset,
+    int length) {
+  PADDLE_ENFORCE_EQ(
+      CheckTensorsInCustomPlace(in_tensors, device_type_),
+      true,
+      platform::errors::InvalidArgument(
+          "All inputs should be in CustomPlace(%s).", device_type_));
+  PADDLE_ENFORCE_EQ(
+      CheckTensorsInCustomPlace(out_tensors, device_type_),
+      true,
+      platform::errors::InvalidArgument(
+          "All outputs should be in CustomPlace(%s).", device_type_));
+  return Collective(
+      in_tensors,
+      out_tensors,
+      [&](phi::DenseTensor& input,
+          phi::DenseTensor& output,
+          phi::ccl::CCLComm comm,
+          const phi::stream::Stream& stream) {
+        return phi::DeviceManager::CCLAllGather(
+            device_type_,
+            XcclGetPointerByOffset(input.data(), offset, input.dtype()),
+            output.data(),
+            length,
+            phi::ccl::ToCCLDataType(input.dtype()),
+            comm,
+            stream);
+      },
+      CommType::ALLGATHER);
+}
+
 std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllReduce(
     std::vector<phi::DenseTensor>& in_tensors,   // NOLINT
     std::vector<phi::DenseTensor>& out_tensors,  // NOLINT
     const AllreduceOptions& opts) {
+  PADDLE_ENFORCE_EQ(
+      CheckTensorsInCustomPlace(in_tensors, device_type_),
+      true,
+      platform::errors::InvalidArgument(
+          "All inputs should be in CustomPlace(%s).", device_type_));
+  PADDLE_ENFORCE_EQ(
+      CheckTensorsInCustomPlace(out_tensors, device_type_),
+      true,
+      platform::errors::InvalidArgument(
+          "All outputs should be in CustomPlace(%s).", device_type_));
   return Collective(
       in_tensors,
       out_tensors,
@@ -235,6 +336,16 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Broadcast(
     std::vector<phi::DenseTensor>& in_tensors,   // NOLINT
     std::vector<phi::DenseTensor>& out_tensors,  // NOLINT
     const BroadcastOptions& opts) {
+  PADDLE_ENFORCE_EQ(
+      CheckTensorsInCustomPlace(in_tensors, device_type_),
+      true,
+      platform::errors::InvalidArgument(
+          "All inputs should be in CustomPlace(%s).", device_type_));
+  PADDLE_ENFORCE_EQ(
+      CheckTensorsInCustomPlace(out_tensors, device_type_),
+      true,
+      platform::errors::InvalidArgument(
+          "All outputs should be in CustomPlace(%s).", device_type_));
   return Collective(
       in_tensors,
       out_tensors,
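
Note on the offset arithmetic above: XcclGetPointerByOffset advances a typed pointer, so the offset and length arguments of AllGather_Partial count elements, not bytes. A minimal caller-side sketch (not part of the patch; ComputePartialArgs is a hypothetical helper, and the even split across ranks is an assumption following the usual partial-collective convention):

    #include <cassert>
    #include <cstdint>
    #include <utility>

    // Derive the per-rank (offset, length) pair, in elements, for
    // AllGather_Partial, assuming numel divides evenly across nranks.
    std::pair<int, int> ComputePartialArgs(int64_t numel, int rank, int nranks) {
      assert(numel % nranks == 0);
      const int length = static_cast<int>(numel / nranks);
      const int offset = length * rank;  // element offset; scaling by
                                         // sizeof(element) happens inside
                                         // XcclGetPointerByOffset
      return {offset, length};
    }
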
diff --git a/paddle/fluid/distributed/collective/ProcessGroupCustom.h b/paddle/fluid/distributed/collective/ProcessGroupCustom.h
index fca6f127c3806cd92e0736b2f6be87c5c3e5de4f..ccce66603afe69b96fdd11f3e575373284966cc9 100644
--- a/paddle/fluid/distributed/collective/ProcessGroupCustom.h
+++ b/paddle/fluid/distributed/collective/ProcessGroupCustom.h
@@ -73,6 +73,16 @@ class ProcessGroupCustom : public ProcessGroup {
     return "XCCL_" + device_type_;
   }
 
+  std::shared_ptr<ProcessGroup::Task> AllGather(
+      std::vector<phi::DenseTensor>& in_tensors,
+      std::vector<phi::DenseTensor>& out_tensors) override;
+
+  std::shared_ptr<ProcessGroup::Task> AllGather_Partial(
+      std::vector<phi::DenseTensor>& in_tensors,
+      std::vector<phi::DenseTensor>& out_tensors,
+      int offset,
+      int length) override;
+
   std::shared_ptr<ProcessGroup::Task> AllReduce(
       std::vector<phi::DenseTensor>& in_tensors,
       std::vector<phi::DenseTensor>& out_tensors,
diff --git a/paddle/phi/backends/custom/custom_device.cc b/paddle/phi/backends/custom/custom_device.cc
index 2567857bca1cad0596dcf693979c511bb55f4460..75f5433a64012f9462da70e2dfce6faa8724f360 100644
--- a/paddle/phi/backends/custom/custom_device.cc
+++ b/paddle/phi/backends/custom/custom_device.cc
@@ -705,6 +705,7 @@ class CustomDevice : public DeviceInterface {
                  size_t num,
                  ccl::CCLDataType data_type,
                  ccl::CCLReduceOp reduce_op,
+                 size_t root_id,
                  const ccl::CCLComm& comm,
                  const stream::Stream& stream) override {
     CHECK_PTR(pimpl_->xccl_reduce);
@@ -714,6 +715,7 @@
                             num,
                             ToXCCLDataType(data_type),
                             ToXCCLReduceOp(reduce_op),
+                            root_id,
                             reinterpret_cast<C_CCLComm>(comm),
                             reinterpret_cast<C_Stream>(stream.raw_stream())));
   }
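
Because the runtime wrapper above now forwards root_id into the plugin's xccl_reduce slot, out-of-tree device plugins must adopt the widened C signature (see the device_ext.h change later in this patch). A plugin-side sketch, assuming a vendor collective library; vendorCclReduce is a placeholder name, not a real API:

    // Minimal plugin-side stub after this change. A real runtime would call
    // its collective library here and deliver the result to `root` only.
    C_Status XcclReduce(void *send_buf,
                        void *recv_buf,
                        size_t count,
                        C_DataType data_type,
                        C_CCLReduceOp op,
                        size_t root,   // new: rank that receives the result
                        C_CCLComm comm,
                        C_Stream stream) {
      // e.g. return vendorCclReduce(send_buf, recv_buf, count,
      //                             data_type, op, root, comm, stream);
      return C_SUCCESS;
    }
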
diff --git a/paddle/phi/backends/custom/custom_device_test.cc b/paddle/phi/backends/custom/custom_device_test.cc
index 425d7bde6173ccab94da9685af4f63b66e78176c..2458241c3c85ddfba1a7f4f5aa69d033ccf622c0 100644
--- a/paddle/phi/backends/custom/custom_device_test.cc
+++ b/paddle/phi/backends/custom/custom_device_test.cc
@@ -203,6 +203,7 @@ void TestCustomCCL(const paddle::platform::Place& place) {
                               0,
                               phi::ccl::CCLDataType::CCL_DATA_TYPE_FP32,
                               phi::ccl::CCLReduceOp::SUM,
+                              0,
                               comm,
                               stream);
   phi::DeviceManager::CCLAllGather(dev_type,
diff --git a/paddle/phi/backends/custom/fake_cpu_device.h b/paddle/phi/backends/custom/fake_cpu_device.h
index a4eaa834a60f9843a14f56c2d58173a93f2e0a54..1fcbeab89ab30e1a8db64ecac933fe77fecd4f0f 100644
--- a/paddle/phi/backends/custom/fake_cpu_device.h
+++ b/paddle/phi/backends/custom/fake_cpu_device.h
@@ -170,6 +170,7 @@ C_Status XcclReduce(void *send_buf,
                     size_t count,
                     C_DataType data_type,
                     C_CCLReduceOp op,
+                    size_t root_id,
                     C_CCLComm comm,
                     C_Stream stream) {
   return C_SUCCESS;
diff --git a/paddle/phi/backends/device_base.cc b/paddle/phi/backends/device_base.cc
index 41871f69c7790458cad21e92fe0dc6209f7ebb61..fca6a32e4f888ecee3f33b350324a30af7ed2a04 100644
--- a/paddle/phi/backends/device_base.cc
+++ b/paddle/phi/backends/device_base.cc
@@ -309,6 +309,7 @@ void DeviceInterface::CCLReduce(void* in_data,
                                 size_t num,
                                 ccl::CCLDataType data_type,
                                 ccl::CCLReduceOp reduce_op,
+                                size_t root_id,
                                 const ccl::CCLComm& ccl_comm,
                                 const stream::Stream& stream) {
   INTERFACE_UNIMPLEMENT;
diff --git a/paddle/phi/backends/device_base.h b/paddle/phi/backends/device_base.h
index b823a4a983207c4ea902ad1580eeb48eeeb068af..e5bdc6c81268dd7f9a9d1f29d97ca2c17e458839 100644
--- a/paddle/phi/backends/device_base.h
+++ b/paddle/phi/backends/device_base.h
@@ -195,6 +195,7 @@ class DeviceInterface {  // Driver / Runtime
                          size_t num,
                          ccl::CCLDataType data_type,
                          ccl::CCLReduceOp reduce_op,
+                         size_t root_id,
                          const ccl::CCLComm& ccl_comm,
                          const stream::Stream& stream);
   virtual void CCLAllGather(void* in_data,
diff --git a/paddle/phi/backends/device_ext.h b/paddle/phi/backends/device_ext.h
index 5bb5def9c2b19866910817df7c5f465a72b02970..ca254f8235a0c69f6ee554bf84f7b3e6c5e90f5d 100644
--- a/paddle/phi/backends/device_ext.h
+++ b/paddle/phi/backends/device_ext.h
@@ -593,6 +593,7 @@ struct C_DeviceInterface {
                          size_t count,
                          C_DataType data_type,
                          C_CCLReduceOp op,
+                         size_t root,
                          C_CCLComm comm,
                          C_Stream stream);
diff --git a/paddle/phi/backends/device_manager.cc b/paddle/phi/backends/device_manager.cc
index dbdbce13d4f40122027a0737479eb4fbf3630b54..224bd0a1ff1f3093afc86fdd029ccd9de595557d 100644
--- a/paddle/phi/backends/device_manager.cc
+++ b/paddle/phi/backends/device_manager.cc
@@ -536,11 +536,12 @@ void DeviceManager::CCLReduce(const std::string& device_type,
                               size_t num,
                               ccl::CCLDataType data_type,
                               ccl::CCLReduceOp reduce_op,
+                              size_t root_id,
                               const ccl::CCLComm& ccl_comm,
                               const stream::Stream& stream) {
   auto dev_impl = GetDeviceInterfaceWithType(device_type);
   dev_impl->CCLReduce(
-      in_data, out_data, num, data_type, reduce_op, ccl_comm, stream);
+      in_data, out_data, num, data_type, reduce_op, root_id, ccl_comm, stream);
 }
 
 void DeviceManager::CCLAllGather(const std::string& device_type,
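
For reference, a manager-level call with the new argument mirrors the test update above. This is a sketch only: dev_type, send_buf, recv_buf, num, comm, and stream are assumed to be set up as in TestCustomCCL, and rank 0 is chosen as the root:

    phi::DeviceManager::CCLReduce(dev_type,
                                  send_buf,
                                  recv_buf,
                                  num,
                                  phi::ccl::CCLDataType::CCL_DATA_TYPE_FP32,
                                  phi::ccl::CCLReduceOp::SUM,
                                  /*root_id=*/0,  // new parameter
                                  comm,
                                  stream);
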
diff --git a/paddle/phi/backends/device_manager.h b/paddle/phi/backends/device_manager.h
index 6d621b6a43223919c74815a799b4d21177d66736..fc8529e5813f301dfb69b77cfc9d09b5e19e43c3 100644
--- a/paddle/phi/backends/device_manager.h
+++ b/paddle/phi/backends/device_manager.h
@@ -206,6 +206,7 @@ class DeviceManager {
                         size_t num,
                         ccl::CCLDataType data_type,
                         ccl::CCLReduceOp reduce_op,
+                        size_t root_id,
                         const ccl::CCLComm& ccl_comm,
                         const stream::Stream& stream);
   static void CCLAllGather(const std::string& device_type,
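
A hedged usage sketch of the new process-group entry point (the Build* helpers are hypothetical, not from the patch): every tensor must already live on the group's CustomPlace, otherwise the CheckTensorsInCustomPlace enforcement added above raises InvalidArgument.

    // Gather each rank's local tensor into out_tensors across the group.
    std::vector<phi::DenseTensor> in_tensors = BuildInputsOnCustomPlace();    // hypothetical helper
    std::vector<phi::DenseTensor> out_tensors = BuildOutputsOnCustomPlace();  // hypothetical helper
    auto task = process_group->AllGather(in_tensors, out_tensors);
    task->Wait();  // block until the collective completes
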