未验证 提交 20db8602 编写于 作者: R ronnywang 提交者: GitHub

[CustomDevice] add MOE support, PART1 (#54572)

上级 3261b106
...@@ -56,5 +56,24 @@ inline CCLDataType ToCCLDataType(phi::DataType type) { ...@@ -56,5 +56,24 @@ inline CCLDataType ToCCLDataType(phi::DataType type) {
} }
} }
inline phi::DataType ToPhiDataType(CCLDataType type) {
if (type == CCLDataType::CCL_DATA_TYPE_FP64) {
return phi::DataType::FLOAT64;
} else if (type == CCLDataType::CCL_DATA_TYPE_FP32) {
return phi::DataType::FLOAT32;
} else if (type == CCLDataType::CCL_DATA_TYPE_FP16) {
return phi::DataType::FLOAT16;
} else if (type == CCLDataType::CCL_DATA_TYPE_INT64) {
return phi::DataType::INT64;
} else if (type == CCLDataType::CCL_DATA_TYPE_INT32) {
return phi::DataType::INT32;
} else if (type == CCLDataType::CCL_DATA_TYPE_INT8) {
return phi::DataType::INT8;
} else {
PADDLE_THROW(
phi::errors::Unimplemented("This datatype in CCL is not supported."));
}
}
} // namespace ccl } // namespace ccl
} // namespace phi } // namespace phi
...@@ -802,6 +802,77 @@ class CustomDevice : public DeviceInterface { ...@@ -802,6 +802,77 @@ class CustomDevice : public DeviceInterface {
reinterpret_cast<C_Stream>(stream.raw_stream()))); reinterpret_cast<C_Stream>(stream.raw_stream())));
} }
void CCLAllToAll(const void** send_buf,
const size_t* send_count,
const ccl::CCLDataType* send_dtype,
void** recv_buf,
const size_t* recv_count,
const ccl::CCLDataType* recv_dtype,
size_t rank,
size_t nranks,
const ccl::CCLComm& comm,
const stream::Stream& stream) override {
if (pimpl_->xccl_all_to_all) {
std::vector<C_DataType> c_send_dtype, c_recv_dtype;
for (size_t i = 0; i < nranks; ++i) {
c_send_dtype.push_back(ToXCCLDataType(send_dtype[i]));
c_recv_dtype.push_back(ToXCCLDataType(recv_dtype[i]));
}
PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS(pimpl_->xccl_all_to_all(
send_buf,
send_count,
c_send_dtype.data(),
recv_buf,
recv_count,
c_recv_dtype.data(),
rank,
nranks,
reinterpret_cast<C_CCLComm>(comm),
reinterpret_cast<C_Stream>(stream.raw_stream())));
} else if (pimpl_->xccl_send && pimpl_->xccl_recv) {
// NOTE(wangran16): fallback to send and recv, while avoiding some devices
// not supporting asynchronous send and recv.
for (size_t i = 0; i < rank; ++i) {
PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS(
pimpl_->xccl_recv(recv_buf[i],
recv_count[i],
ToXCCLDataType(recv_dtype[i]),
i,
reinterpret_cast<C_CCLComm>(comm),
reinterpret_cast<C_Stream>(stream.raw_stream())));
}
for (size_t i = 0; i < nranks; ++i) {
if (i != rank) {
PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS(pimpl_->xccl_send(
const_cast<void*>(send_buf[i]),
send_count[i],
ToXCCLDataType(send_dtype[i]),
i,
reinterpret_cast<C_CCLComm>(comm),
reinterpret_cast<C_Stream>(stream.raw_stream())));
}
}
MemoryCopyD2D(rank,
recv_buf[rank],
send_buf[rank],
send_count[rank] *
phi::SizeOf(phi::ccl::ToPhiDataType(send_dtype[rank])),
&stream);
for (size_t i = rank + 1; i < nranks; ++i) {
PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS(
pimpl_->xccl_recv(recv_buf[i],
recv_count[i],
ToXCCLDataType(recv_dtype[i]),
i,
reinterpret_cast<C_CCLComm>(comm),
reinterpret_cast<C_Stream>(stream.raw_stream())));
}
} else {
PADDLE_THROW(phi::errors::Unavailable(
"CCLAllToAll is not supported on %s.", Type()));
}
}
void BlasAXPBY(size_t dev_id, void BlasAXPBY(size_t dev_id,
const stream::Stream& stream, const stream::Stream& stream,
phi::DataType dtype, phi::DataType dtype,
......
...@@ -356,6 +356,19 @@ void DeviceInterface::CCLRecv(void* recvbuf, ...@@ -356,6 +356,19 @@ void DeviceInterface::CCLRecv(void* recvbuf,
INTERFACE_UNIMPLEMENT; INTERFACE_UNIMPLEMENT;
} }
void DeviceInterface::CCLAllToAll(const void** send_buf,
const size_t* send_count,
const ccl::CCLDataType* send_dtype,
void** recv_buf,
const size_t* recv_count,
const ccl::CCLDataType* recv_dtype,
size_t rank,
size_t nranks,
const ccl::CCLComm& comm,
const stream::Stream& stream) {
INTERFACE_UNIMPLEMENT;
}
// blas // blas
void DeviceInterface::BlasAXPBY(size_t dev_id, void DeviceInterface::BlasAXPBY(size_t dev_id,
const stream::Stream& stream, const stream::Stream& stream,
......
...@@ -228,6 +228,16 @@ class DeviceInterface { // Driver / Runtime ...@@ -228,6 +228,16 @@ class DeviceInterface { // Driver / Runtime
const ccl::CCLComm& ccl_comm, const ccl::CCLComm& ccl_comm,
const stream::Stream& stream); const stream::Stream& stream);
virtual void CCLAllToAll(const void** send_buf,
const size_t* send_count,
const ccl::CCLDataType* send_dtype,
void** recv_buf,
const size_t* recv_count,
const ccl::CCLDataType* recv_dtype,
size_t rank,
size_t nranks,
const ccl::CCLComm& comm,
const stream::Stream& stream);
// blas // blas
virtual void BlasAXPBY(size_t dev_id, virtual void BlasAXPBY(size_t dev_id,
const stream::Stream& stream, const stream::Stream& stream,
......
...@@ -638,6 +638,17 @@ struct C_DeviceInterface { ...@@ -638,6 +638,17 @@ struct C_DeviceInterface {
C_CCLComm comm, C_CCLComm comm,
C_Stream stream); C_Stream stream);
C_Status (*xccl_all_to_all)(const void** send_buf,
const size_t* send_count,
const C_DataType* send_dtype,
void** recv_buf,
const size_t* recv_count,
const C_DataType* recv_dtype,
size_t rank,
size_t nranks,
C_CCLComm comm,
C_Stream stream);
void* reserved_ccl_api[8]; void* reserved_ccl_api[8];
////////////////// //////////////////
......
...@@ -596,6 +596,30 @@ void DeviceManager::CCLRecv(const std::string& device_type, ...@@ -596,6 +596,30 @@ void DeviceManager::CCLRecv(const std::string& device_type,
dev_impl->CCLRecv(recvbuf, num, data_type, src_rank, ccl_comm, stream); dev_impl->CCLRecv(recvbuf, num, data_type, src_rank, ccl_comm, stream);
} }
void DeviceManager::CCLAllToAll(const std::string& device_type,
const void** send_buf,
const size_t* send_count,
const ccl::CCLDataType* send_dtype,
void** recv_buf,
const size_t* recv_count,
const ccl::CCLDataType* recv_dtype,
size_t rank,
size_t nranks,
const ccl::CCLComm& comm,
const stream::Stream& stream) {
auto dev_impl = GetDeviceInterfaceWithType(device_type);
dev_impl->CCLAllToAll(send_buf,
send_count,
send_dtype,
recv_buf,
recv_count,
recv_dtype,
rank,
nranks,
comm,
stream);
}
// profiler // profiler
void DeviceManager::ProfilerInitialize(const std::string& dev_type, void DeviceManager::ProfilerInitialize(const std::string& dev_type,
phi::TraceEventCollector* collector, phi::TraceEventCollector* collector,
......
...@@ -241,6 +241,17 @@ class DeviceManager { ...@@ -241,6 +241,17 @@ class DeviceManager {
const ccl::CCLComm& ccl_comm, const ccl::CCLComm& ccl_comm,
const stream::Stream& stream); const stream::Stream& stream);
static void CCLAllToAll(const std::string& device_type,
const void** send_buf,
const size_t* send_count,
const ccl::CCLDataType* send_dtype,
void** recv_buf,
const size_t* recv_count,
const ccl::CCLDataType* recv_dtype,
size_t rank,
size_t nranks,
const ccl::CCLComm& comm,
const stream::Stream& stream);
// profiler // profiler
static void ProfilerInitialize(const std::string& dev_type, static void ProfilerInitialize(const std::string& dev_type,
phi::TraceEventCollector* collector, phi::TraceEventCollector* collector,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册