未验证 提交 d88e77a7 编写于 作者: R ronnywang 提交者: GitHub

[CustomDevice] add custom ccl 1/2 (#44294)

* [CustomDevice] add custom ccl api

* add ut
上级 c446ab7b
......@@ -180,6 +180,15 @@ FUNCTION_SET_DEVICE_TEMPLATE = \
#else
PADDLE_THROW(paddle::platform::errors::PreconditionNotMet(
"PaddlePaddle should compile with GPU if use CUDAPlace."));
#endif
}}
if (paddle::platform::is_custom_place(place)) {{
#if defined(PADDLE_WITH_CUSTOM_DEVICE)
phi::DeviceManager::SetDevice(place);
VLOG(1) <<"CurrentDeviceId: " << phi::DeviceManager::GetDevice(place.GetDeviceType()) << " from " << (int)place.device;
#else
PADDLE_THROW(paddle::platform::errors::PreconditionNotMet(
"PaddlePaddle should compile with CUSTOM_DEVICE if use CustomPlace."));
#endif
}}
"""
......@@ -200,6 +209,7 @@ PYTHON_C_WRAPPER_TEMPLATE = \
#include <Python.h>
#include "paddle/fluid/platform/enforce.h"
#include "paddle/phi/api/include/strings_api.h"
#include "paddle/phi/backends/device_manager.h"
#include "paddle/fluid/pybind/eager_utils.h"
#include "paddle/fluid/pybind/exception.h"
#include "paddle/fluid/platform/profiler/event_tracing.h"
......
......@@ -764,7 +764,7 @@ class BuddyAllocatorList {
private:
explicit BuddyAllocatorList(const std::string &device_type)
: device_type_(device_type) {
auto devices = phi::DeviceManager::GetDeviceList(device_type);
auto devices = phi::DeviceManager::GetSelectedDeviceList(device_type);
for (auto dev_id : devices) {
init_flags_[dev_id].reset(new std::once_flag());
}
......
......@@ -264,11 +264,11 @@ void InitDevices(const std::vector<int> devices) {
auto device_types = phi::DeviceManager::GetAllCustomDeviceTypes();
for (auto &dev_type : device_types) {
auto device_count = phi::DeviceManager::GetDeviceCount(dev_type);
auto device_list = phi::DeviceManager::GetSelectedDeviceList(dev_type);
LOG(INFO) << "CustomDevice: " << dev_type
<< ", visible devices count: " << device_count;
for (size_t i = 0; i < device_count; i++) {
places.push_back(platform::CustomPlace(dev_type, i));
<< ", visible devices count: " << device_list.size();
for (auto &dev_id : device_list) {
places.push_back(platform::CustomPlace(dev_type, dev_id));
}
}
} else {
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/backends/c_comm_lib.h"
namespace phi {
// Even this source file does not contains any code, it is better to keep this
// source file for cmake dependency.
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <vector>
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/errors.h"
#include "paddle/phi/core/macros.h"
namespace phi {
namespace ccl {
using CCLComm = void*;
using CCLRootId = std::vector<uint8_t>;
enum CCLReduceOp { SUM = 0, AVG, MAX, MIN, PRODUCT };
enum CCLDataType {
CCL_DATA_TYPE_FP64 = 0,
CCL_DATA_TYPE_FP32,
CCL_DATA_TYPE_FP16,
CCL_DATA_TYPE_INT64,
CCL_DATA_TYPE_INT32,
CCL_DATA_TYPE_INT16,
CCL_DATA_TYPE_INT8
};
inline CCLDataType ToCCLDataType(paddle::experimental::DataType type) {
if (type == paddle::experimental::DataType::FLOAT64) {
return CCL_DATA_TYPE_FP64;
} else if (type == paddle::experimental::DataType::FLOAT32) {
return CCL_DATA_TYPE_FP32;
} else if (type == paddle::experimental::DataType::FLOAT16) {
return CCL_DATA_TYPE_FP16;
} else if (type == paddle::experimental::DataType::INT64) {
return CCL_DATA_TYPE_INT64;
} else if (type == paddle::experimental::DataType::INT32) {
return CCL_DATA_TYPE_INT32;
} else if (type == paddle::experimental::DataType::INT8) {
return CCL_DATA_TYPE_INT8;
} else {
PADDLE_THROW(
phi::errors::Unimplemented("This datatype in CCL is not supported."));
}
}
} // namespace ccl
} // namespace phi
......@@ -18,6 +18,7 @@
#include "paddle/fluid/platform/device/device_wrapper.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/phi/backends/device_guard.h"
namespace phi {
......@@ -33,12 +34,13 @@ void CallbackManager::AddCallback(std::function<void()> callback) const {
(*callback_func)();
});
});
phi::DeviceGuard guard(stream_->GetPlace());
phi::DeviceManager::GetDeviceWithPlace(stream_->GetPlace())
->AddCallback(stream_, func);
}
void CallbackManager::Wait() const {
phi::DeviceGuard guard(stream_->GetPlace());
phi::DeviceManager::GetDeviceWithPlace(stream_->GetPlace())
->SynchronizeStream(stream_);
......
......@@ -27,6 +27,14 @@ static bool operator==(const C_Device_st& d1, const C_Device_st& d2) {
namespace phi {
#define INTERFACE_UNIMPLEMENT \
PADDLE_THROW(phi::errors::Unimplemented( \
"%s is not implemented on %s device.", __func__, Type()));
#define CHECK_PTR(x) \
if (x == nullptr) { \
INTERFACE_UNIMPLEMENT; \
}
class CustomDevice : public DeviceInterface {
public:
CustomDevice(const std::string& type,
......@@ -561,6 +569,208 @@ class CustomDevice : public DeviceInterface {
return version;
}
C_DataType ToXCCLDataType(ccl::CCLDataType data_type) {
#define return_result(in, ret) \
case ccl::CCLDataType::in: \
return C_DataType::ret
switch (data_type) {
return_result(CCL_DATA_TYPE_FP64, FLOAT64);
return_result(CCL_DATA_TYPE_FP32, FLOAT32);
return_result(CCL_DATA_TYPE_FP16, FLOAT16);
return_result(CCL_DATA_TYPE_INT64, INT64);
return_result(CCL_DATA_TYPE_INT32, INT32);
return_result(CCL_DATA_TYPE_INT16, INT16);
return_result(CCL_DATA_TYPE_INT8, INT8);
default: {
PADDLE_THROW(phi::errors::Unavailable(
"DataType is not supported on %s.", Type()));
return C_DataType::UNDEFINED;
}
}
#undef return_result
}
C_CCLReduceOp ToXCCLReduceOp(ccl::CCLReduceOp reduce_op) {
#define return_result(in, ret) \
case ccl::CCLReduceOp::in: \
return C_CCLReduceOp::ret
switch (reduce_op) {
return_result(SUM, SUM);
return_result(AVG, AVG);
return_result(MAX, MAX);
return_result(MIN, MIN);
return_result(PRODUCT, PRODUCT);
default: {
PADDLE_THROW(phi::errors::Unavailable(
"ReduceOp is not supported on %s.", Type()));
}
}
#undef return_result
}
void CCLGetUniqueId(ccl::CCLRootId* unique_id) override {
CHECK_PTR(pimpl_->xccl_get_unique_id_size);
CHECK_PTR(pimpl_->xccl_get_unique_id);
C_CCLRootId root_id;
PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS(
pimpl_->xccl_get_unique_id_size(&(root_id.sz)));
root_id.data = new uint8_t[root_id.sz];
PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS(pimpl_->xccl_get_unique_id(&root_id));
uint8_t* ptr = reinterpret_cast<uint8_t*>(root_id.data);
*unique_id = std::vector<uint8_t>(ptr, ptr + root_id.sz);
delete[] ptr;
}
void CCLCommInitRank(size_t nranks,
ccl::CCLRootId* unique_id,
size_t rank,
ccl::CCLComm* comm) override {
CHECK_PTR(pimpl_->xccl_comm_init_rank);
C_CCLRootId root_id;
root_id.sz = unique_id->size();
root_id.data = unique_id->data();
PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS(pimpl_->xccl_comm_init_rank(
nranks, &root_id, rank, reinterpret_cast<C_CCLComm*>(comm)));
}
void CCLDestroyComm(ccl::CCLComm comm) override {
CHECK_PTR(pimpl_->xccl_destroy_comm);
PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS(
pimpl_->xccl_destroy_comm(reinterpret_cast<C_CCLComm>(comm)));
}
void CCLAllReduce(void* send_buf,
void* recv_buf,
size_t count,
ccl::CCLDataType data_type,
ccl::CCLReduceOp op,
const ccl::CCLComm& comm,
const stream::Stream& stream) override {
CHECK_PTR(pimpl_->xccl_all_reduce);
PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS(pimpl_->xccl_all_reduce(
send_buf,
recv_buf,
count,
ToXCCLDataType(data_type),
ToXCCLReduceOp(op),
reinterpret_cast<C_CCLComm>(comm),
reinterpret_cast<C_Stream>(stream.raw_stream())));
}
void CCLBroadcast(void* buf,
size_t count,
ccl::CCLDataType data_type,
size_t root,
const ccl::CCLComm& comm,
const stream::Stream& stream) override {
CHECK_PTR(pimpl_->xccl_broadcast);
PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS(pimpl_->xccl_broadcast(
buf,
count,
ToXCCLDataType(data_type),
root,
reinterpret_cast<C_CCLComm>(comm),
reinterpret_cast<C_Stream>(stream.raw_stream())));
}
void CCLReduce(void* in_data,
void* out_data,
size_t num,
ccl::CCLDataType data_type,
ccl::CCLReduceOp reduce_op,
const ccl::CCLComm& comm,
const stream::Stream& stream) override {
CHECK_PTR(pimpl_->xccl_reduce);
PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS(
pimpl_->xccl_reduce(in_data,
out_data,
num,
ToXCCLDataType(data_type),
ToXCCLReduceOp(reduce_op),
reinterpret_cast<C_CCLComm>(comm),
reinterpret_cast<C_Stream>(stream.raw_stream())));
}
void CCLAllGather(void* send_buf,
void* recv_buf,
size_t count,
ccl::CCLDataType data_type,
const ccl::CCLComm& comm,
const stream::Stream& stream) override {
CHECK_PTR(pimpl_->xccl_all_gather);
PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS(pimpl_->xccl_all_gather(
send_buf,
recv_buf,
count,
ToXCCLDataType(data_type),
reinterpret_cast<C_CCLComm>(comm),
reinterpret_cast<C_Stream>(stream.raw_stream())));
}
void CCLReduceScatter(void* send_buf,
void* recv_buf,
size_t count,
ccl::CCLDataType data_type,
ccl::CCLReduceOp reduce_op,
const ccl::CCLComm& comm,
const stream::Stream& stream) override {
CHECK_PTR(pimpl_->xccl_reduce_scatter);
PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS(pimpl_->xccl_reduce_scatter(
send_buf,
recv_buf,
count,
ToXCCLDataType(data_type),
ToXCCLReduceOp(reduce_op),
reinterpret_cast<C_CCLComm>(comm),
reinterpret_cast<C_Stream>(stream.raw_stream())));
}
void CCLGroupStart() override {
CHECK_PTR(pimpl_->xccl_group_start);
PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS(pimpl_->xccl_group_start());
}
void CCLGroupEnd() override {
CHECK_PTR(pimpl_->xccl_group_end);
PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS(pimpl_->xccl_group_end());
}
void CCLSend(void* send_buf,
size_t count,
ccl::CCLDataType data_type,
size_t dest_rank,
const ccl::CCLComm& comm,
const stream::Stream& stream) override {
CHECK_PTR(pimpl_->xccl_send);
PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS(
pimpl_->xccl_send(send_buf,
count,
ToXCCLDataType(data_type),
dest_rank,
reinterpret_cast<C_CCLComm>(comm),
reinterpret_cast<C_Stream>(stream.raw_stream())));
}
void CCLRecv(void* recv_buf,
size_t count,
ccl::CCLDataType data_type,
size_t src_rank,
const ccl::CCLComm& comm,
const stream::Stream& stream) override {
CHECK_PTR(pimpl_->xccl_recv);
PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS(
pimpl_->xccl_recv(recv_buf,
count,
ToXCCLDataType(data_type),
src_rank,
reinterpret_cast<C_CCLComm>(comm),
reinterpret_cast<C_Stream>(stream.raw_stream())));
}
private:
inline int PlaceToIdNoCheck(const Place& place) {
int dev_id = place.GetDeviceId();
......@@ -584,7 +794,7 @@ class CustomDevice : public DeviceInterface {
};
bool ValidCustomCustomRuntimeParams(const CustomRuntimeParams* params) {
#define CHECK_PTR(ptr, required) \
#define CHECK_INTERFACE(ptr, required) \
if (params->interface->ptr == nullptr && required) { \
LOG(WARNING) << "CustomRuntime [type: " << params->device_type \
<< "] pointer: " << #ptr << " is not set."; \
......@@ -604,58 +814,71 @@ bool ValidCustomCustomRuntimeParams(const CustomRuntimeParams* params) {
return false;
}
CHECK_PTR(initialize, false);
CHECK_PTR(finalize, false)
CHECK_PTR(init_device, false);
CHECK_PTR(set_device, true);
CHECK_PTR(get_device, true);
CHECK_PTR(deinit_device, false);
CHECK_PTR(create_stream, true);
CHECK_PTR(destroy_stream, true);
CHECK_PTR(query_stream, false);
CHECK_PTR(stream_add_callback, false);
CHECK_PTR(create_event, true);
CHECK_PTR(record_event, true);
CHECK_PTR(destroy_event, true);
CHECK_PTR(query_event, false);
CHECK_PTR(synchronize_device, false);
CHECK_PTR(synchronize_stream, true);
CHECK_PTR(synchronize_event, true);
CHECK_PTR(stream_wait_event, true);
CHECK_PTR(device_memory_allocate, true);
CHECK_PTR(device_memory_deallocate, true);
CHECK_PTR(host_memory_allocate, false);
CHECK_PTR(host_memory_deallocate, false);
CHECK_PTR(unified_memory_allocate, false);
CHECK_PTR(unified_memory_deallocate, false);
CHECK_PTR(memory_copy_h2d, true);
CHECK_PTR(memory_copy_d2h, true);
CHECK_PTR(memory_copy_d2d, true);
CHECK_PTR(memory_copy_p2p, false);
CHECK_PTR(async_memory_copy_h2d, false);
CHECK_PTR(async_memory_copy_d2h, false);
CHECK_PTR(async_memory_copy_d2d, false);
CHECK_PTR(async_memory_copy_p2p, false);
CHECK_PTR(get_device_count, true);
CHECK_PTR(get_device_list, true);
CHECK_PTR(device_memory_stats, true);
CHECK_PTR(device_min_chunk_size, true);
CHECK_PTR(device_max_chunk_size, false);
CHECK_PTR(device_max_alloc_size, false);
CHECK_PTR(device_extra_padding_size, false);
CHECK_PTR(get_compute_capability, false);
CHECK_PTR(get_runtime_version, false);
CHECK_PTR(get_driver_version, false);
CHECK_INTERFACE(initialize, false);
CHECK_INTERFACE(finalize, false)
CHECK_INTERFACE(init_device, false);
CHECK_INTERFACE(set_device, true);
CHECK_INTERFACE(get_device, true);
CHECK_INTERFACE(deinit_device, false);
CHECK_INTERFACE(create_stream, true);
CHECK_INTERFACE(destroy_stream, true);
CHECK_INTERFACE(query_stream, false);
CHECK_INTERFACE(stream_add_callback, false);
CHECK_INTERFACE(create_event, true);
CHECK_INTERFACE(record_event, true);
CHECK_INTERFACE(destroy_event, true);
CHECK_INTERFACE(query_event, false);
CHECK_INTERFACE(synchronize_device, false);
CHECK_INTERFACE(synchronize_stream, true);
CHECK_INTERFACE(synchronize_event, true);
CHECK_INTERFACE(stream_wait_event, true);
CHECK_INTERFACE(device_memory_allocate, true);
CHECK_INTERFACE(device_memory_deallocate, true);
CHECK_INTERFACE(host_memory_allocate, false);
CHECK_INTERFACE(host_memory_deallocate, false);
CHECK_INTERFACE(unified_memory_allocate, false);
CHECK_INTERFACE(unified_memory_deallocate, false);
CHECK_INTERFACE(memory_copy_h2d, true);
CHECK_INTERFACE(memory_copy_d2h, true);
CHECK_INTERFACE(memory_copy_d2d, true);
CHECK_INTERFACE(memory_copy_p2p, false);
CHECK_INTERFACE(async_memory_copy_h2d, false);
CHECK_INTERFACE(async_memory_copy_d2h, false);
CHECK_INTERFACE(async_memory_copy_d2d, false);
CHECK_INTERFACE(async_memory_copy_p2p, false);
CHECK_INTERFACE(get_device_count, true);
CHECK_INTERFACE(get_device_list, true);
CHECK_INTERFACE(device_memory_stats, true);
CHECK_INTERFACE(device_min_chunk_size, true);
CHECK_INTERFACE(device_max_chunk_size, false);
CHECK_INTERFACE(device_max_alloc_size, false);
CHECK_INTERFACE(device_extra_padding_size, false);
CHECK_INTERFACE(get_compute_capability, false);
CHECK_INTERFACE(get_runtime_version, false);
CHECK_INTERFACE(get_driver_version, false);
CHECK_INTERFACE(xccl_get_unique_id, false);
CHECK_INTERFACE(xccl_get_unique_id_size, false);
CHECK_INTERFACE(xccl_comm_init_rank, false);
CHECK_INTERFACE(xccl_destroy_comm, false);
CHECK_INTERFACE(xccl_all_reduce, false);
CHECK_INTERFACE(xccl_broadcast, false);
CHECK_INTERFACE(xccl_reduce, false);
CHECK_INTERFACE(xccl_all_gather, false);
CHECK_INTERFACE(xccl_reduce_scatter, false);
CHECK_INTERFACE(xccl_group_start, false);
CHECK_INTERFACE(xccl_group_end, false);
CHECK_INTERFACE(xccl_send, false);
CHECK_INTERFACE(xccl_recv, false);
return true;
#undef CHECK_PTR
#undef CHECK_INTERFACE
}
typedef bool (*RegisterDevicePluginFn)(CustomRuntimeParams* runtime_params);
......@@ -712,4 +935,6 @@ void LoadCustomRuntimeLib(const std::string& dso_lib_path, void* dso_handle) {
LOG(INFO) << "Successed in loading custom runtime in lib: " << dso_lib_path;
}
#undef INTERFACE_UNIMPLEMENT
} // namespace phi
......@@ -107,6 +107,7 @@ void TestTensorShareDataWith(const paddle::platform::Place& place) {
}
void TestTensorUtils(const paddle::platform::Place& place) {
std::cout << "TestTensorUtils on " << place << std::endl;
if (paddle::platform::is_custom_place(place) == false) {
return;
}
......@@ -166,6 +167,76 @@ void TestTensorUtils(const paddle::platform::Place& place) {
#endif
}
void TestCustomCCL(const paddle::platform::Place& place) {
std::cout << "TestCustomCCL on " << place << std::endl;
if (paddle::platform::is_custom_place(place) == false) {
return;
}
std::string dev_type = place.GetDeviceType();
phi::ccl::CCLComm comm;
phi::stream::Stream stream(place, nullptr);
phi::ccl::CCLRootId root_id;
phi::DeviceManager::CCLDestroyComm(dev_type, nullptr);
phi::DeviceManager::CCLGetUniqueId(dev_type, &root_id);
phi::DeviceManager::CCLCommInitRank(dev_type, 0, &root_id, 0, nullptr);
phi::DeviceManager::CCLBroadcast(dev_type,
nullptr,
0,
phi::ccl::CCLDataType::CCL_DATA_TYPE_FP32,
0,
comm,
stream);
phi::DeviceManager::CCLAllReduce(dev_type,
nullptr,
nullptr,
0,
phi::ccl::CCLDataType::CCL_DATA_TYPE_FP32,
phi::ccl::CCLReduceOp::SUM,
comm,
stream);
phi::DeviceManager::CCLReduce(dev_type,
nullptr,
nullptr,
0,
phi::ccl::CCLDataType::CCL_DATA_TYPE_FP32,
phi::ccl::CCLReduceOp::SUM,
comm,
stream);
phi::DeviceManager::CCLAllGather(dev_type,
nullptr,
nullptr,
0,
phi::ccl::CCLDataType::CCL_DATA_TYPE_FP32,
comm,
stream);
phi::DeviceManager::CCLReduceScatter(
dev_type,
nullptr,
nullptr,
0,
phi::ccl::CCLDataType::CCL_DATA_TYPE_FP32,
phi::ccl::CCLReduceOp::SUM,
comm,
stream);
phi::DeviceManager::CCLGroupStart(dev_type);
phi::DeviceManager::CCLGroupEnd(dev_type);
phi::DeviceManager::CCLSend(dev_type,
nullptr,
0,
phi::ccl::CCLDataType::CCL_DATA_TYPE_FP32,
0,
comm,
stream);
phi::DeviceManager::CCLRecv(dev_type,
nullptr,
0,
phi::ccl::CCLDataType::CCL_DATA_TYPE_FP32,
0,
comm,
stream);
}
TEST(CustomDevice, Tensor) {
InitDevice();
auto dev_types = phi::DeviceManager::GetAllDeviceTypes();
......@@ -179,6 +250,7 @@ TEST(CustomDevice, Tensor) {
TestTensorMutableData(place);
TestTensorShareDataWith(place);
TestTensorUtils(place);
TestCustomCCL(place);
}
}
......
......@@ -136,6 +136,80 @@ C_Status DeviceMaxAllocSize(const C_Device device, size_t *size) {
return C_SUCCESS;
}
C_Status XcclGetUniqueIdSize(size_t *size) {
*size = sizeof(size_t);
return C_SUCCESS;
}
C_Status XcclGetUniqueId(C_CCLRootId *unique_id) { return C_SUCCESS; }
C_Status XcclCommInitRank(size_t ranks,
C_CCLRootId *unique_id,
size_t rank,
C_CCLComm *comm) {
return C_SUCCESS;
}
C_Status XcclDestroyComm(C_CCLComm comm) { return C_SUCCESS; }
C_Status XcclAllReduce(void *send_buf,
void *recv_buf,
size_t count,
C_DataType data_type,
C_CCLReduceOp op,
C_CCLComm comm,
C_Stream stream) {
return C_SUCCESS;
}
C_Status XcclBroadcast(void *buf,
size_t count,
C_DataType data_type,
size_t root,
C_CCLComm comm,
C_Stream stream) {
return C_SUCCESS;
}
C_Status XcclReduce(void *send_buf,
void *recv_buf,
size_t count,
C_DataType data_type,
C_CCLReduceOp op,
C_CCLComm comm,
C_Stream stream) {
return C_SUCCESS;
}
C_Status XcclAllGather(void *send_buf,
void *recv_buf,
size_t count,
C_DataType data_type,
C_CCLComm comm,
C_Stream stream) {
return C_SUCCESS;
}
C_Status XcclReduceScatter(void *send_buf,
void *recv_buf,
size_t count,
C_DataType data_type,
C_CCLReduceOp op,
C_CCLComm comm,
C_Stream stream) {
return C_SUCCESS;
}
C_Status XcclGroupStart() { return C_SUCCESS; }
C_Status XcclGroupEnd() { return C_SUCCESS; }
C_Status XcclSend(void *send_buf,
size_t count,
C_DataType data_type,
size_t dest_rank,
C_CCLComm comm,
C_Stream stream) {
return C_SUCCESS;
}
C_Status XcclRecv(void *recv_buf,
size_t count,
C_DataType data_type,
size_t src_rank,
C_CCLComm comm,
C_Stream stream) {
return C_SUCCESS;
}
#define DEVICE_TYPE "FakeCPU"
#define SUB_DEVICE_TYPE "V100"
......@@ -190,4 +264,18 @@ void InitFakeCPUDevice(CustomRuntimeParams *params) {
params->interface->device_max_chunk_size = DeviceMaxChunkSize;
params->interface->device_min_chunk_size = DeviceMinChunkSize;
params->interface->device_max_alloc_size = DeviceMaxAllocSize;
params->interface->xccl_get_unique_id_size = XcclGetUniqueIdSize;
params->interface->xccl_get_unique_id = XcclGetUniqueId;
params->interface->xccl_all_reduce = XcclAllReduce;
params->interface->xccl_all_gather = XcclAllGather;
params->interface->xccl_broadcast = XcclBroadcast;
params->interface->xccl_comm_init_rank = XcclCommInitRank;
params->interface->xccl_destroy_comm = XcclDestroyComm;
params->interface->xccl_group_end = XcclGroupEnd;
params->interface->xccl_group_start = XcclGroupStart;
params->interface->xccl_reduce = XcclReduce;
params->interface->xccl_reduce_scatter = XcclReduceScatter;
params->interface->xccl_send = XcclSend;
params->interface->xccl_recv = XcclRecv;
}
......@@ -270,4 +270,91 @@ size_t DeviceInterface::GetExtraPaddingSize(size_t dev_id) {
return 0;
}
void DeviceInterface::CCLDestroyComm(ccl::CCLComm ccl_comm) {
INTERFACE_UNIMPLEMENT;
}
void DeviceInterface::CCLCommInitRank(size_t num_ranks,
ccl::CCLRootId* root_id,
size_t rank_id,
ccl::CCLComm* ccl_comm) {
INTERFACE_UNIMPLEMENT;
}
void DeviceInterface::CCLGetUniqueId(ccl::CCLRootId* root_id) {
INTERFACE_UNIMPLEMENT;
}
void DeviceInterface::CCLBroadcast(void* data,
size_t num,
ccl::CCLDataType data_type,
size_t root,
const ccl::CCLComm& ccl_comm,
const stream::Stream& stream) {
INTERFACE_UNIMPLEMENT;
}
void DeviceInterface::CCLAllReduce(void* in_data,
void* out_data,
size_t num,
ccl::CCLDataType data_type,
ccl::CCLReduceOp reduce_op,
const ccl::CCLComm& ccl_comm,
const stream::Stream& stream) {
INTERFACE_UNIMPLEMENT;
}
void DeviceInterface::CCLReduce(void* in_data,
void* out_data,
size_t num,
ccl::CCLDataType data_type,
ccl::CCLReduceOp reduce_op,
const ccl::CCLComm& ccl_comm,
const stream::Stream& stream) {
INTERFACE_UNIMPLEMENT;
}
void DeviceInterface::CCLAllGather(void* in_data,
void* out_data,
size_t num,
ccl::CCLDataType data_type,
const ccl::CCLComm& ccl_comm,
const stream::Stream& stream) {
INTERFACE_UNIMPLEMENT;
}
void DeviceInterface::CCLReduceScatter(void* in_data,
void* out_data,
size_t num,
ccl::CCLDataType data_type,
ccl::CCLReduceOp op,
const ccl::CCLComm& ccl_comm,
const stream::Stream& stream) {
INTERFACE_UNIMPLEMENT;
}
void DeviceInterface::CCLGroupStart() { INTERFACE_UNIMPLEMENT; }
void DeviceInterface::CCLGroupEnd() { INTERFACE_UNIMPLEMENT; }
void DeviceInterface::CCLSend(void* sendbuf,
size_t num,
ccl::CCLDataType data_type,
size_t dst_rank,
const ccl::CCLComm& ccl_comm,
const stream::Stream& stream) {
INTERFACE_UNIMPLEMENT;
}
void DeviceInterface::CCLRecv(void* recvbuf,
size_t num,
ccl::CCLDataType data_type,
size_t src_rank,
const ccl::CCLComm& ccl_comm,
const stream::Stream& stream) {
INTERFACE_UNIMPLEMENT;
}
#undef INTERFACE_UNIMPLEMENT
} // namespace phi
......@@ -16,6 +16,7 @@
#ifdef PADDLE_WITH_CUSTOM_DEVICE
#include <vector>
#include "paddle/phi/backends/c_comm_lib.h"
#include "paddle/phi/backends/event.h"
#include "paddle/phi/backends/stream.h"
......@@ -165,6 +166,65 @@ class DeviceInterface { // Driver / Runtime
virtual size_t GetExtraPaddingSize(size_t dev_id);
// CCL
virtual void CCLDestroyComm(ccl::CCLComm ccl_comm);
virtual void CCLCommInitRank(size_t num_ranks,
ccl::CCLRootId* root_id,
size_t rank_id,
ccl::CCLComm* ccl_comm);
virtual void CCLGetUniqueId(ccl::CCLRootId* root_id);
virtual void CCLBroadcast(void* data,
size_t num,
ccl::CCLDataType data_type,
size_t root,
const ccl::CCLComm& ccl_comm,
const stream::Stream& stream);
virtual void CCLAllReduce(void* in_data,
void* out_data,
size_t num,
ccl::CCLDataType data_type,
ccl::CCLReduceOp reduce_op,
const ccl::CCLComm& ccl_comm,
const stream::Stream& stream);
virtual void CCLReduce(void* in_data,
void* out_data,
size_t num,
ccl::CCLDataType data_type,
ccl::CCLReduceOp reduce_op,
const ccl::CCLComm& ccl_comm,
const stream::Stream& stream);
virtual void CCLAllGather(void* in_data,
void* out_data,
size_t num,
ccl::CCLDataType data_type,
const ccl::CCLComm& ccl_comm,
const stream::Stream& stream);
virtual void CCLReduceScatter(void* in_data,
void* out_data,
size_t num,
ccl::CCLDataType data_type,
ccl::CCLReduceOp op,
const ccl::CCLComm& ccl_comm,
const stream::Stream& stream);
virtual void CCLGroupStart();
virtual void CCLGroupEnd();
virtual void CCLSend(void* sendbuf,
size_t num,
ccl::CCLDataType data_type,
size_t dst_rank,
const ccl::CCLComm& ccl_comm,
const stream::Stream& stream);
virtual void CCLRecv(void* recvbuf,
size_t num,
ccl::CCLDataType data_type,
size_t src_rank,
const ccl::CCLComm& ccl_comm,
const stream::Stream& stream);
private:
const std::string type_;
const uint8_t priority_;
......
......@@ -74,6 +74,15 @@ typedef void (*C_Callback)(C_Device device,
void* user_data,
C_Status* status);
typedef struct {
size_t sz;
void* data;
} C_CCLRootId;
typedef struct C_CCLComm_st* C_CCLComm;
typedef enum { SUM = 0, AVG, MAX, MIN, PRODUCT } C_CCLReduceOp;
struct C_DeviceInterface {
// Core fill it and plugin must to check it
size_t size;
......@@ -526,6 +535,102 @@ struct C_DeviceInterface {
void* reserved_info_api[8];
//////////////
// ccl api //
//////////////
/**
* @brief Get size of unique id
*
* @param[size_t*] size
*/
C_Status (*xccl_get_unique_id_size)(size_t* size);
/**
* @brief Get unique id
*
* @param[C_CCLRootId*] unique_id
*/
C_Status (*xccl_get_unique_id)(C_CCLRootId* unique_id);
/**
* @brief Initialize communicator
*
* @param[size_t] ranks
* @param[C_CCLRootId*] unique_id
* @param[size_t] rank
* @param[C_CCLComm*] comm
*/
C_Status (*xccl_comm_init_rank)(size_t ranks,
C_CCLRootId* unique_id,
size_t rank,
C_CCLComm* comm);
/**
* @brief Destroy communicator
*
* @param[C_CCLComm] comm
*/
C_Status (*xccl_destroy_comm)(C_CCLComm comm);
C_Status (*xccl_all_reduce)(void* send_buf,
void* recv_buf,
size_t count,
C_DataType data_type,
C_CCLReduceOp op,
C_CCLComm comm,
C_Stream stream);
C_Status (*xccl_broadcast)(void* buf,
size_t count,
C_DataType data_type,
size_t root,
C_CCLComm comm,
C_Stream stream);
C_Status (*xccl_reduce)(void* send_buf,
void* recv_buf,
size_t count,
C_DataType data_type,
C_CCLReduceOp op,
C_CCLComm comm,
C_Stream stream);
C_Status (*xccl_all_gather)(void* send_buf,
void* recv_buf,
size_t count,
C_DataType data_type,
C_CCLComm comm,
C_Stream stream);
C_Status (*xccl_reduce_scatter)(void* send_buf,
void* recv_buf,
size_t count,
C_DataType data_type,
C_CCLReduceOp op,
C_CCLComm comm,
C_Stream stream);
C_Status (*xccl_group_start)();
C_Status (*xccl_group_end)();
C_Status (*xccl_send)(void* send_buf,
size_t count,
C_DataType data_type,
size_t dest_rank,
C_CCLComm comm,
C_Stream stream);
C_Status (*xccl_recv)(void* recv_buf,
size_t count,
C_DataType data_type,
size_t src_rank,
C_CCLComm comm,
C_Stream stream);
void* reserved_ccl_api[8];
///////////////
// other api //
///////////////
......
......@@ -25,6 +25,7 @@
#include <regex>
#include "glog/logging.h"
#include "paddle/utils/string/split.h"
namespace phi {
......@@ -390,6 +391,139 @@ std::vector<size_t> DeviceManager::GetDeviceList(
return dev_impl->GetDeviceList();
}
std::vector<size_t> DeviceManager::GetSelectedDeviceList(
const std::string& device_type) {
std::vector<size_t> devices;
std::string FLAGS = "FLAGS_selected_" + device_type + "s";
auto FLAGS_selected_devices = getenv(FLAGS.c_str());
if (FLAGS_selected_devices) {
auto devices_str = paddle::string::Split(FLAGS_selected_devices, ',');
for (auto id : devices_str) {
devices.push_back(atoi(id.c_str()));
}
} else {
int count = DeviceManager::GetDeviceCount(device_type);
for (int i = 0; i < count; ++i) {
devices.push_back(i);
}
}
return devices;
}
void DeviceManager::CCLDestroyComm(const std::string& device_type,
ccl::CCLComm ccl_comm) {
auto dev_impl = GetDeviceInterfaceWithType(device_type);
dev_impl->CCLDestroyComm(ccl_comm);
}
void DeviceManager::CCLCommInitRank(const std::string& device_type,
size_t num_ranks,
ccl::CCLRootId* root_id,
size_t rank_id,
ccl::CCLComm* ccl_comm) {
auto dev_impl = GetDeviceInterfaceWithType(device_type);
dev_impl->CCLCommInitRank(num_ranks, root_id, rank_id, ccl_comm);
}
void DeviceManager::CCLGetUniqueId(const std::string& device_type,
ccl::CCLRootId* root_id) {
auto dev_impl = GetDeviceInterfaceWithType(device_type);
dev_impl->CCLGetUniqueId(root_id);
}
void DeviceManager::CCLBroadcast(const std::string& device_type,
void* data,
size_t num,
ccl::CCLDataType data_type,
size_t root_id,
const ccl::CCLComm& ccl_comm,
const stream::Stream& stream) {
auto dev_impl = GetDeviceInterfaceWithType(device_type);
dev_impl->CCLBroadcast(data, num, data_type, root_id, ccl_comm, stream);
}
void DeviceManager::CCLAllReduce(const std::string& device_type,
void* in_data,
void* out_data,
size_t num,
ccl::CCLDataType data_type,
ccl::CCLReduceOp reduce_op,
const ccl::CCLComm& ccl_comm,
const stream::Stream& stream) {
auto dev_impl = GetDeviceInterfaceWithType(device_type);
dev_impl->CCLAllReduce(
in_data, out_data, num, data_type, reduce_op, ccl_comm, stream);
}
void DeviceManager::CCLReduce(const std::string& device_type,
void* in_data,
void* out_data,
size_t num,
ccl::CCLDataType data_type,
ccl::CCLReduceOp reduce_op,
const ccl::CCLComm& ccl_comm,
const stream::Stream& stream) {
auto dev_impl = GetDeviceInterfaceWithType(device_type);
dev_impl->CCLReduce(
in_data, out_data, num, data_type, reduce_op, ccl_comm, stream);
}
void DeviceManager::CCLAllGather(const std::string& device_type,
void* in_data,
void* out_data,
size_t num,
ccl::CCLDataType data_type,
const ccl::CCLComm& ccl_comm,
const stream::Stream& stream) {
auto dev_impl = GetDeviceInterfaceWithType(device_type);
dev_impl->CCLAllGather(in_data, out_data, num, data_type, ccl_comm, stream);
}
void DeviceManager::CCLReduceScatter(const std::string& device_type,
void* in_data,
void* out_data,
size_t num,
ccl::CCLDataType data_type,
ccl::CCLReduceOp op,
const ccl::CCLComm& ccl_comm,
const stream::Stream& stream) {
auto dev_impl = GetDeviceInterfaceWithType(device_type);
dev_impl->CCLReduceScatter(
in_data, out_data, num, data_type, op, ccl_comm, stream);
}
void DeviceManager::CCLGroupStart(const std::string& device_type) {
auto dev_impl = GetDeviceInterfaceWithType(device_type);
dev_impl->CCLGroupStart();
}
void DeviceManager::CCLGroupEnd(const std::string& device_type) {
auto dev_impl = GetDeviceInterfaceWithType(device_type);
dev_impl->CCLGroupEnd();
}
void DeviceManager::CCLSend(const std::string& device_type,
void* sendbuf,
size_t num,
ccl::CCLDataType data_type,
size_t dst_rank,
const ccl::CCLComm& ccl_comm,
const stream::Stream& stream) {
auto dev_impl = GetDeviceInterfaceWithType(device_type);
dev_impl->CCLSend(sendbuf, num, data_type, dst_rank, ccl_comm, stream);
}
void DeviceManager::CCLRecv(const std::string& device_type,
void* recvbuf,
size_t num,
ccl::CCLDataType data_type,
size_t src_rank,
const ccl::CCLComm& ccl_comm,
const stream::Stream& stream) {
auto dev_impl = GetDeviceInterfaceWithType(device_type);
dev_impl->CCLRecv(recvbuf, num, data_type, src_rank, ccl_comm, stream);
}
DeviceManager& DeviceManager::Instance() {
static DeviceManager platform_manager;
return platform_manager;
......
......@@ -17,6 +17,7 @@
#include <unordered_map>
#include "paddle/phi/backends/c_comm_lib.h"
#include "paddle/phi/backends/device_base.h"
#include "paddle/phi/backends/device_ext.h"
#include "paddle/phi/backends/dynload/port.h"
......@@ -159,6 +160,74 @@ class DeviceManager {
static std::vector<size_t> GetDeviceList(const std::string& device_type);
static std::vector<size_t> GetSelectedDeviceList(
const std::string& device_type);
// CCL
static void CCLDestroyComm(const std::string& device_type,
ccl::CCLComm ccl_comm);
static void CCLCommInitRank(const std::string& device_type,
size_t num_ranks,
ccl::CCLRootId* root_id,
size_t rank_id,
ccl::CCLComm* ccl_comm);
static void CCLGetUniqueId(const std::string& device_type,
ccl::CCLRootId* root_id);
static void CCLBroadcast(const std::string& device_type,
void* data,
size_t num,
ccl::CCLDataType data_type,
size_t root,
const ccl::CCLComm& ccl_comm,
const stream::Stream& stream);
static void CCLAllReduce(const std::string& device_type,
void* in_data,
void* out_data,
size_t num,
ccl::CCLDataType data_type,
ccl::CCLReduceOp reduce_op,
const ccl::CCLComm& ccl_comm,
const stream::Stream& stream);
static void CCLReduce(const std::string& device_type,
void* in_data,
void* out_data,
size_t num,
ccl::CCLDataType data_type,
ccl::CCLReduceOp reduce_op,
const ccl::CCLComm& ccl_comm,
const stream::Stream& stream);
static void CCLAllGather(const std::string& device_type,
void* in_data,
void* out_data,
size_t num,
ccl::CCLDataType data_type,
const ccl::CCLComm& ccl_comm,
const stream::Stream& stream);
static void CCLReduceScatter(const std::string& device_type,
void* in_data,
void* out_data,
size_t num,
ccl::CCLDataType data_type,
ccl::CCLReduceOp op,
const ccl::CCLComm& ccl_comm,
const stream::Stream& stream);
static void CCLGroupStart(const std::string& device_type);
static void CCLGroupEnd(const std::string& device_type);
static void CCLSend(const std::string& device_type,
void* sendbuf,
size_t num,
ccl::CCLDataType data_type,
size_t dst_rank,
const ccl::CCLComm& ccl_comm,
const stream::Stream& stream);
static void CCLRecv(const std::string& device_type,
void* recvbuf,
size_t num,
ccl::CCLDataType data_type,
size_t src_rank,
const ccl::CCLComm& ccl_comm,
const stream::Stream& stream);
static void Clear();
private:
......
......@@ -230,7 +230,10 @@ def _convert_to_place(device):
device_id = int(selected_mlus[0])
place = core.MLUPlace(device_id)
elif device in core.get_all_custom_device_type():
place = core.CustomPlace(device, 0)
selected_devices = os.getenv("FLAGS_selected_{}s".format(device),
"0").split(",")
device_id = int(selected_devices[0])
place = core.CustomPlace(device, device_id)
else:
avaliable_gpu_device = re.match(r'gpu:\d+', lower_device)
avaliable_xpu_device = re.match(r'xpu:\d+', lower_device)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册