From 20db8602f6ba66268e6dfd81b9282072a5af4f0d Mon Sep 17 00:00:00 2001 From: ronnywang Date: Thu, 15 Jun 2023 09:56:50 +0800 Subject: [PATCH] [CustomDevice] add MOE support, PART1 (#54572) --- paddle/phi/backends/c_comm_lib.h | 19 ++++++ paddle/phi/backends/custom/custom_device.cc | 71 +++++++++++++++++++++ paddle/phi/backends/device_base.cc | 13 ++++ paddle/phi/backends/device_base.h | 10 +++ paddle/phi/backends/device_ext.h | 11 ++++ paddle/phi/backends/device_manager.cc | 24 +++++++ paddle/phi/backends/device_manager.h | 11 ++++ 7 files changed, 159 insertions(+) diff --git a/paddle/phi/backends/c_comm_lib.h b/paddle/phi/backends/c_comm_lib.h index e67530add58..e52e0fb5862 100644 --- a/paddle/phi/backends/c_comm_lib.h +++ b/paddle/phi/backends/c_comm_lib.h @@ -56,5 +56,24 @@ inline CCLDataType ToCCLDataType(phi::DataType type) { } } +inline phi::DataType ToPhiDataType(CCLDataType type) { + if (type == CCLDataType::CCL_DATA_TYPE_FP64) { + return phi::DataType::FLOAT64; + } else if (type == CCLDataType::CCL_DATA_TYPE_FP32) { + return phi::DataType::FLOAT32; + } else if (type == CCLDataType::CCL_DATA_TYPE_FP16) { + return phi::DataType::FLOAT16; + } else if (type == CCLDataType::CCL_DATA_TYPE_INT64) { + return phi::DataType::INT64; + } else if (type == CCLDataType::CCL_DATA_TYPE_INT32) { + return phi::DataType::INT32; + } else if (type == CCLDataType::CCL_DATA_TYPE_INT8) { + return phi::DataType::INT8; + } else { + PADDLE_THROW( + phi::errors::Unimplemented("This datatype in CCL is not supported.")); + } +} + } // namespace ccl } // namespace phi diff --git a/paddle/phi/backends/custom/custom_device.cc b/paddle/phi/backends/custom/custom_device.cc index 46535ee8477..ed56b5d4ad6 100644 --- a/paddle/phi/backends/custom/custom_device.cc +++ b/paddle/phi/backends/custom/custom_device.cc @@ -802,6 +802,77 @@ class CustomDevice : public DeviceInterface { reinterpret_cast(stream.raw_stream()))); } + void CCLAllToAll(const void** send_buf, + const size_t* send_count, + const ccl::CCLDataType* send_dtype, + void** recv_buf, + const size_t* recv_count, + const ccl::CCLDataType* recv_dtype, + size_t rank, + size_t nranks, + const ccl::CCLComm& comm, + const stream::Stream& stream) override { + if (pimpl_->xccl_all_to_all) { + std::vector c_send_dtype, c_recv_dtype; + for (size_t i = 0; i < nranks; ++i) { + c_send_dtype.push_back(ToXCCLDataType(send_dtype[i])); + c_recv_dtype.push_back(ToXCCLDataType(recv_dtype[i])); + } + PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS(pimpl_->xccl_all_to_all( + send_buf, + send_count, + c_send_dtype.data(), + recv_buf, + recv_count, + c_recv_dtype.data(), + rank, + nranks, + reinterpret_cast(comm), + reinterpret_cast(stream.raw_stream()))); + } else if (pimpl_->xccl_send && pimpl_->xccl_recv) { + // NOTE(wangran16): fallback to send and recv, while avoiding some devices + // not supporting asynchronous send and recv. + for (size_t i = 0; i < rank; ++i) { + PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS( + pimpl_->xccl_recv(recv_buf[i], + recv_count[i], + ToXCCLDataType(recv_dtype[i]), + i, + reinterpret_cast(comm), + reinterpret_cast(stream.raw_stream()))); + } + for (size_t i = 0; i < nranks; ++i) { + if (i != rank) { + PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS(pimpl_->xccl_send( + const_cast(send_buf[i]), + send_count[i], + ToXCCLDataType(send_dtype[i]), + i, + reinterpret_cast(comm), + reinterpret_cast(stream.raw_stream()))); + } + } + MemoryCopyD2D(rank, + recv_buf[rank], + send_buf[rank], + send_count[rank] * + phi::SizeOf(phi::ccl::ToPhiDataType(send_dtype[rank])), + &stream); + for (size_t i = rank + 1; i < nranks; ++i) { + PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS( + pimpl_->xccl_recv(recv_buf[i], + recv_count[i], + ToXCCLDataType(recv_dtype[i]), + i, + reinterpret_cast(comm), + reinterpret_cast(stream.raw_stream()))); + } + } else { + PADDLE_THROW(phi::errors::Unavailable( + "CCLAllToAll is not supported on %s.", Type())); + } + } + void BlasAXPBY(size_t dev_id, const stream::Stream& stream, phi::DataType dtype, diff --git a/paddle/phi/backends/device_base.cc b/paddle/phi/backends/device_base.cc index 2a1bf8b0743..2b7d0411fed 100644 --- a/paddle/phi/backends/device_base.cc +++ b/paddle/phi/backends/device_base.cc @@ -356,6 +356,19 @@ void DeviceInterface::CCLRecv(void* recvbuf, INTERFACE_UNIMPLEMENT; } +void DeviceInterface::CCLAllToAll(const void** send_buf, + const size_t* send_count, + const ccl::CCLDataType* send_dtype, + void** recv_buf, + const size_t* recv_count, + const ccl::CCLDataType* recv_dtype, + size_t rank, + size_t nranks, + const ccl::CCLComm& comm, + const stream::Stream& stream) { + INTERFACE_UNIMPLEMENT; +} + // blas void DeviceInterface::BlasAXPBY(size_t dev_id, const stream::Stream& stream, diff --git a/paddle/phi/backends/device_base.h b/paddle/phi/backends/device_base.h index f51d5302140..855e7789034 100644 --- a/paddle/phi/backends/device_base.h +++ b/paddle/phi/backends/device_base.h @@ -228,6 +228,16 @@ class DeviceInterface { // Driver / Runtime const ccl::CCLComm& ccl_comm, const stream::Stream& stream); + virtual void CCLAllToAll(const void** send_buf, + const size_t* send_count, + const ccl::CCLDataType* send_dtype, + void** recv_buf, + const size_t* recv_count, + const ccl::CCLDataType* recv_dtype, + size_t rank, + size_t nranks, + const ccl::CCLComm& comm, + const stream::Stream& stream); // blas virtual void BlasAXPBY(size_t dev_id, const stream::Stream& stream, diff --git a/paddle/phi/backends/device_ext.h b/paddle/phi/backends/device_ext.h index 4563160e335..bd3f5f687f2 100644 --- a/paddle/phi/backends/device_ext.h +++ b/paddle/phi/backends/device_ext.h @@ -638,6 +638,17 @@ struct C_DeviceInterface { C_CCLComm comm, C_Stream stream); + C_Status (*xccl_all_to_all)(const void** send_buf, + const size_t* send_count, + const C_DataType* send_dtype, + void** recv_buf, + const size_t* recv_count, + const C_DataType* recv_dtype, + size_t rank, + size_t nranks, + C_CCLComm comm, + C_Stream stream); + void* reserved_ccl_api[8]; ////////////////// diff --git a/paddle/phi/backends/device_manager.cc b/paddle/phi/backends/device_manager.cc index c95616150d3..4d01b2aec4d 100644 --- a/paddle/phi/backends/device_manager.cc +++ b/paddle/phi/backends/device_manager.cc @@ -596,6 +596,30 @@ void DeviceManager::CCLRecv(const std::string& device_type, dev_impl->CCLRecv(recvbuf, num, data_type, src_rank, ccl_comm, stream); } +void DeviceManager::CCLAllToAll(const std::string& device_type, + const void** send_buf, + const size_t* send_count, + const ccl::CCLDataType* send_dtype, + void** recv_buf, + const size_t* recv_count, + const ccl::CCLDataType* recv_dtype, + size_t rank, + size_t nranks, + const ccl::CCLComm& comm, + const stream::Stream& stream) { + auto dev_impl = GetDeviceInterfaceWithType(device_type); + dev_impl->CCLAllToAll(send_buf, + send_count, + send_dtype, + recv_buf, + recv_count, + recv_dtype, + rank, + nranks, + comm, + stream); +} + // profiler void DeviceManager::ProfilerInitialize(const std::string& dev_type, phi::TraceEventCollector* collector, diff --git a/paddle/phi/backends/device_manager.h b/paddle/phi/backends/device_manager.h index 420c2872037..52496f54647 100644 --- a/paddle/phi/backends/device_manager.h +++ b/paddle/phi/backends/device_manager.h @@ -241,6 +241,17 @@ class DeviceManager { const ccl::CCLComm& ccl_comm, const stream::Stream& stream); + static void CCLAllToAll(const std::string& device_type, + const void** send_buf, + const size_t* send_count, + const ccl::CCLDataType* send_dtype, + void** recv_buf, + const size_t* recv_count, + const ccl::CCLDataType* recv_dtype, + size_t rank, + size_t nranks, + const ccl::CCLComm& comm, + const stream::Stream& stream); // profiler static void ProfilerInitialize(const std::string& dev_type, phi::TraceEventCollector* collector, -- GitLab