diff --git a/paddle/phi/backends/c_comm_lib.h b/paddle/phi/backends/c_comm_lib.h
index e67530add58da126b198f773fd4f831f7b9bb418..e52e0fb586295fc8270ee2557623e3e5306906c0 100644
--- a/paddle/phi/backends/c_comm_lib.h
+++ b/paddle/phi/backends/c_comm_lib.h
@@ -56,5 +56,26 @@ inline CCLDataType ToCCLDataType(phi::DataType type) {
   }
 }
 
+// Converts a CCLDataType to the matching phi::DataType. Handles FP64, FP32,
+// FP16, INT64, INT32 and INT8; any other value throws Unimplemented.
+inline phi::DataType ToPhiDataType(CCLDataType type) {
+  if (type == CCLDataType::CCL_DATA_TYPE_FP64) {
+    return phi::DataType::FLOAT64;
+  } else if (type == CCLDataType::CCL_DATA_TYPE_FP32) {
+    return phi::DataType::FLOAT32;
+  } else if (type == CCLDataType::CCL_DATA_TYPE_FP16) {
+    return phi::DataType::FLOAT16;
+  } else if (type == CCLDataType::CCL_DATA_TYPE_INT64) {
+    return phi::DataType::INT64;
+  } else if (type == CCLDataType::CCL_DATA_TYPE_INT32) {
+    return phi::DataType::INT32;
+  } else if (type == CCLDataType::CCL_DATA_TYPE_INT8) {
+    return phi::DataType::INT8;
+  } else {
+    PADDLE_THROW(
+        phi::errors::Unimplemented("This datatype in CCL is not supported."));
+  }
+}
+
 }  // namespace ccl
 }  // namespace phi
diff --git a/paddle/phi/backends/custom/custom_device.cc b/paddle/phi/backends/custom/custom_device.cc
index 46535ee84777a3d88b28f2efaa92b9d16c9a9897..ed56b5d4ad6a8086542f9ab45bea16f086263369 100644
--- a/paddle/phi/backends/custom/custom_device.cc
+++ b/paddle/phi/backends/custom/custom_device.cc
@@ -802,6 +802,103 @@ class CustomDevice : public DeviceInterface {
                           reinterpret_cast<C_Stream>(stream.raw_stream())));
   }
 
+  // All-to-all exchange with every rank of the communicator.
+  //
+  // Each pointer argument is an array with one entry per rank in
+  // [0, nranks): entry i of send_buf/send_count/send_dtype is what gets
+  // sent to rank i, entry i of recv_buf/recv_count/recv_dtype is what is
+  // received from rank i. Uses the plugin's xccl_all_to_all when it is set,
+  // otherwise falls back to xccl_send/xccl_recv, and throws Unavailable
+  // when those are not set either.
+  void CCLAllToAll(const void** send_buf,
+                   const size_t* send_count,
+                   const ccl::CCLDataType* send_dtype,
+                   void** recv_buf,
+                   const size_t* recv_count,
+                   const ccl::CCLDataType* recv_dtype,
+                   size_t rank,
+                   size_t nranks,
+                   const ccl::CCLComm& comm,
+                   const stream::Stream& stream) override {
+    if (pimpl_->xccl_all_to_all) {
+      // The plugin entry point takes C_DataType, so convert both dtype arrays.
+      std::vector<C_DataType> c_send_dtype;
+      std::vector<C_DataType> c_recv_dtype;
+      c_send_dtype.reserve(nranks);
+      c_recv_dtype.reserve(nranks);
+      for (size_t i = 0; i < nranks; ++i) {
+        c_send_dtype.push_back(ToXCCLDataType(send_dtype[i]));
+        c_recv_dtype.push_back(ToXCCLDataType(recv_dtype[i]));
+      }
+      PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS(pimpl_->xccl_all_to_all(
+          send_buf,
+          send_count,
+          c_send_dtype.data(),
+          recv_buf,
+          recv_count,
+          c_recv_dtype.data(),
+          rank,
+          nranks,
+          reinterpret_cast<C_CCLComm>(comm),
+          reinterpret_cast<C_Stream>(stream.raw_stream())));
+    } else if (pimpl_->xccl_send && pimpl_->xccl_recv) {
+      // NOTE(wangran16): fallback to send and recv, while avoiding some devices
+      // not supporting asynchronous send and recv.
+      // Every per-rank array is indexed with `rank` below, so reject an
+      // out-of-range rank instead of reading past the end of the arrays.
+      if (rank >= nranks) {
+        PADDLE_THROW(phi::errors::InvalidArgument(
+            "CCLAllToAll expects rank (%d) to be less than nranks (%d).",
+            rank,
+            nranks));
+      }
+      // 1) Receive from every lower rank.
+      for (size_t i = 0; i < rank; ++i) {
+        PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS(
+            pimpl_->xccl_recv(recv_buf[i],
+                              recv_count[i],
+                              ToXCCLDataType(recv_dtype[i]),
+                              i,
+                              reinterpret_cast<C_CCLComm>(comm),
+                              reinterpret_cast<C_Stream>(stream.raw_stream())));
+      }
+      // 2) Send to every other rank.
+      for (size_t i = 0; i < nranks; ++i) {
+        if (i != rank) {
+          PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS(pimpl_->xccl_send(
+              const_cast<void*>(send_buf[i]),
+              send_count[i],
+              ToXCCLDataType(send_dtype[i]),
+              i,
+              reinterpret_cast<C_CCLComm>(comm),
+              reinterpret_cast<C_Stream>(stream.raw_stream())));
+        }
+      }
+      // 3) This rank's own slice is copied locally instead of being sent.
+      // NOTE(review): `rank` is passed as the first argument of MemoryCopyD2D;
+      // confirm that argument is not expected to be a device id.
+      MemoryCopyD2D(rank,
+                    recv_buf[rank],
+                    send_buf[rank],
+                    send_count[rank] *
+                        phi::SizeOf(phi::ccl::ToPhiDataType(send_dtype[rank])),
+                    &stream);
+      // 4) Receive from every higher rank.
+      for (size_t i = rank + 1; i < nranks; ++i) {
+        PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS(
+            pimpl_->xccl_recv(recv_buf[i],
+                              recv_count[i],
+                              ToXCCLDataType(recv_dtype[i]),
+                              i,
+                              reinterpret_cast<C_CCLComm>(comm),
+                              reinterpret_cast<C_Stream>(stream.raw_stream())));
+      }
+    } else {
+      PADDLE_THROW(phi::errors::Unavailable(
+          "CCLAllToAll is not supported on %s.", Type()));
+    }
+  }
+
   void BlasAXPBY(size_t dev_id,
                  const stream::Stream& stream,
                  phi::DataType dtype,
diff --git a/paddle/phi/backends/device_base.cc b/paddle/phi/backends/device_base.cc
index 2a1bf8b07437c1c4305d739ff59b439b813c5fbc..2b7d0411fedcabfe4276025f6760af5595a34352 100644
--- a/paddle/phi/backends/device_base.cc
+++ b/paddle/phi/backends/device_base.cc
@@ -356,6 +356,20 @@ void DeviceInterface::CCLRecv(void* recvbuf,
   INTERFACE_UNIMPLEMENT;
 }
 
+// Base-class version of CCLAllToAll; CustomDevice overrides it.
+void DeviceInterface::CCLAllToAll(const void** send_buf,
+                                  const size_t* send_count,
+                                  const ccl::CCLDataType* send_dtype,
+                                  void** recv_buf,
+                                  const size_t* recv_count,
+                                  const ccl::CCLDataType* recv_dtype,
+                                  size_t rank,
+                                  size_t nranks,
+                                  const ccl::CCLComm& comm,
+                                  const stream::Stream& stream) {
+  INTERFACE_UNIMPLEMENT;
+}
+
 // blas
 void DeviceInterface::BlasAXPBY(size_t dev_id,
                                 const stream::Stream& stream,
diff --git a/paddle/phi/backends/device_base.h b/paddle/phi/backends/device_base.h
index f51d5302140f8efc71700f4f99b073ca5a4c76d2..855e77890348ae048f6c5255119f5f8754fb2a8a 100644
--- a/paddle/phi/backends/device_base.h
+++ b/paddle/phi/backends/device_base.h
@@ -228,6 +228,18 @@ class DeviceInterface {  // Driver / Runtime
                        const ccl::CCLComm& ccl_comm,
                        const stream::Stream& stream);
 
+  // All-to-all exchange: every pointer argument is an array holding one
+  // entry per rank in [0, nranks).
+  virtual void CCLAllToAll(const void** send_buf,
+                           const size_t* send_count,
+                           const ccl::CCLDataType* send_dtype,
+                           void** recv_buf,
+                           const size_t* recv_count,
+                           const ccl::CCLDataType* recv_dtype,
+                           size_t rank,
+                           size_t nranks,
+                           const ccl::CCLComm& comm,
+                           const stream::Stream& stream);
   // blas
   virtual void BlasAXPBY(size_t dev_id,
                          const stream::Stream& stream,
diff --git a/paddle/phi/backends/device_ext.h b/paddle/phi/backends/device_ext.h
index 4563160e335a1f0fdf576143ab3c982889ff4758..bd3f5f687f29b130994c5d4e153bbb3a49121873 100644
--- a/paddle/phi/backends/device_ext.h
+++ b/paddle/phi/backends/device_ext.h
@@ -638,6 +638,23 @@ struct C_DeviceInterface {
                         C_CCLComm comm,
                         C_Stream stream);
 
+  // All-to-all collective. The runtime passes arrays holding one entry per
+  // rank (nranks entries each). When this pointer is left unset the runtime
+  // falls back to xccl_send/xccl_recv.
+  // NOTE(review): this member is inserted ahead of reserved_ccl_api without
+  // shrinking that array, so every later member moves; confirm plugins
+  // built against the previous layout are accounted for.
+  C_Status (*xccl_all_to_all)(const void** send_buf,
+                              const size_t* send_count,
+                              const C_DataType* send_dtype,
+                              void** recv_buf,
+                              const size_t* recv_count,
+                              const C_DataType* recv_dtype,
+                              size_t rank,
+                              size_t nranks,
+                              C_CCLComm comm,
+                              C_Stream stream);
+
   void* reserved_ccl_api[8];
 
   //////////////////
diff --git a/paddle/phi/backends/device_manager.cc b/paddle/phi/backends/device_manager.cc
index c95616150d3a6392e65d392ad7427902ec1c06a0..4d01b2aec4dcacc0a2f587a5bb4ced40c10f924e 100644
--- a/paddle/phi/backends/device_manager.cc
+++ b/paddle/phi/backends/device_manager.cc
@@ -596,6 +596,31 @@ void DeviceManager::CCLRecv(const std::string& device_type,
   dev_impl->CCLRecv(recvbuf, num, data_type, src_rank, ccl_comm, stream);
 }
 
+// Forwards to the device interface looked up for device_type.
+void DeviceManager::CCLAllToAll(const std::string& device_type,
+                                const void** send_buf,
+                                const size_t* send_count,
+                                const ccl::CCLDataType* send_dtype,
+                                void** recv_buf,
+                                const size_t* recv_count,
+                                const ccl::CCLDataType* recv_dtype,
+                                size_t rank,
+                                size_t nranks,
+                                const ccl::CCLComm& comm,
+                                const stream::Stream& stream) {
+  auto dev_impl = GetDeviceInterfaceWithType(device_type);
+  dev_impl->CCLAllToAll(send_buf,
+                        send_count,
+                        send_dtype,
+                        recv_buf,
+                        recv_count,
+                        recv_dtype,
+                        rank,
+                        nranks,
+                        comm,
+                        stream);
+}
+
 // profiler
 void DeviceManager::ProfilerInitialize(const std::string& dev_type,
                                        phi::TraceEventCollector* collector,
diff --git a/paddle/phi/backends/device_manager.h b/paddle/phi/backends/device_manager.h
index 420c28720376032b5c9998839fa2547ec046c504..52496f54647e33523654d9bdd5ac7a6d11bd308a 100644
--- a/paddle/phi/backends/device_manager.h
+++ b/paddle/phi/backends/device_manager.h
@@ -241,6 +241,18 @@ class DeviceManager {
                       const ccl::CCLComm& ccl_comm,
                       const stream::Stream& stream);
 
+  // Runs CCLAllToAll on the device interface looked up for device_type.
+  static void CCLAllToAll(const std::string& device_type,
+                          const void** send_buf,
+                          const size_t* send_count,
+                          const ccl::CCLDataType* send_dtype,
+                          void** recv_buf,
+                          const size_t* recv_count,
+                          const ccl::CCLDataType* recv_dtype,
+                          size_t rank,
+                          size_t nranks,
+                          const ccl::CCLComm& comm,
+                          const stream::Stream& stream);
   // profiler
   static void ProfilerInitialize(const std::string& dev_type,
                                  phi::TraceEventCollector* collector,