From d88e77a7b4acfecb58d8eefcfc994a56f94b301a Mon Sep 17 00:00:00 2001 From: ronnywang Date: Thu, 14 Jul 2022 14:39:37 +0800 Subject: [PATCH] [CustomDevice] add custom ccl 1/2 (#44294) * [CustomDevice] add custom ccl api * add ut --- .../final_state_generator/python_c_gen.py | 10 + .../allocation/naive_best_fit_allocator.cc | 2 +- paddle/fluid/platform/init.cc | 8 +- paddle/phi/backends/c_comm_lib.cc | 20 ++ paddle/phi/backends/c_comm_lib.h | 60 ++++ paddle/phi/backends/callback_manager.cc | 4 +- paddle/phi/backends/custom/custom_device.cc | 329 +++++++++++++++--- .../phi/backends/custom/custom_device_test.cc | 72 ++++ paddle/phi/backends/custom/fake_cpu_device.h | 88 +++++ paddle/phi/backends/device_base.cc | 87 +++++ paddle/phi/backends/device_base.h | 60 ++++ paddle/phi/backends/device_ext.h | 105 ++++++ paddle/phi/backends/device_manager.cc | 134 +++++++ paddle/phi/backends/device_manager.h | 69 ++++ python/paddle/device/__init__.py | 5 +- 15 files changed, 994 insertions(+), 59 deletions(-) create mode 100644 paddle/phi/backends/c_comm_lib.cc create mode 100644 paddle/phi/backends/c_comm_lib.h diff --git a/paddle/fluid/eager/auto_code_generator/final_state_generator/python_c_gen.py b/paddle/fluid/eager/auto_code_generator/final_state_generator/python_c_gen.py index c6ac5a12f5..9d5706f65b 100644 --- a/paddle/fluid/eager/auto_code_generator/final_state_generator/python_c_gen.py +++ b/paddle/fluid/eager/auto_code_generator/final_state_generator/python_c_gen.py @@ -180,6 +180,15 @@ FUNCTION_SET_DEVICE_TEMPLATE = \ #else PADDLE_THROW(paddle::platform::errors::PreconditionNotMet( "PaddlePaddle should compile with GPU if use CUDAPlace.")); +#endif + }} + if (paddle::platform::is_custom_place(place)) {{ +#if defined(PADDLE_WITH_CUSTOM_DEVICE) + phi::DeviceManager::SetDevice(place); + VLOG(1) <<"CurrentDeviceId: " << phi::DeviceManager::GetDevice(place.GetDeviceType()) << " from " << (int)place.device; +#else + PADDLE_THROW(paddle::platform::errors::PreconditionNotMet( + "PaddlePaddle should compile with CUSTOM_DEVICE if use CustomPlace.")); #endif }} """ @@ -200,6 +209,7 @@ PYTHON_C_WRAPPER_TEMPLATE = \ #include #include "paddle/fluid/platform/enforce.h" #include "paddle/phi/api/include/strings_api.h" +#include "paddle/phi/backends/device_manager.h" #include "paddle/fluid/pybind/eager_utils.h" #include "paddle/fluid/pybind/exception.h" #include "paddle/fluid/platform/profiler/event_tracing.h" diff --git a/paddle/fluid/memory/allocation/naive_best_fit_allocator.cc b/paddle/fluid/memory/allocation/naive_best_fit_allocator.cc index 4553c80e74..d696b8bffd 100644 --- a/paddle/fluid/memory/allocation/naive_best_fit_allocator.cc +++ b/paddle/fluid/memory/allocation/naive_best_fit_allocator.cc @@ -764,7 +764,7 @@ class BuddyAllocatorList { private: explicit BuddyAllocatorList(const std::string &device_type) : device_type_(device_type) { - auto devices = phi::DeviceManager::GetDeviceList(device_type); + auto devices = phi::DeviceManager::GetSelectedDeviceList(device_type); for (auto dev_id : devices) { init_flags_[dev_id].reset(new std::once_flag()); } diff --git a/paddle/fluid/platform/init.cc b/paddle/fluid/platform/init.cc index b6f6deb80d..6e28c775a3 100644 --- a/paddle/fluid/platform/init.cc +++ b/paddle/fluid/platform/init.cc @@ -264,11 +264,11 @@ void InitDevices(const std::vector devices) { auto device_types = phi::DeviceManager::GetAllCustomDeviceTypes(); for (auto &dev_type : device_types) { - auto device_count = phi::DeviceManager::GetDeviceCount(dev_type); + auto device_list = phi::DeviceManager::GetSelectedDeviceList(dev_type); LOG(INFO) << "CustomDevice: " << dev_type - << ", visible devices count: " << device_count; - for (size_t i = 0; i < device_count; i++) { - places.push_back(platform::CustomPlace(dev_type, i)); + << ", visible devices count: " << device_list.size(); + for (auto &dev_id : device_list) { + places.push_back(platform::CustomPlace(dev_type, dev_id)); } } } else { diff --git a/paddle/phi/backends/c_comm_lib.cc b/paddle/phi/backends/c_comm_lib.cc new file mode 100644 index 0000000000..7f86ac6eff --- /dev/null +++ b/paddle/phi/backends/c_comm_lib.cc @@ -0,0 +1,20 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/backends/c_comm_lib.h" + +namespace phi { +// Even this source file does not contains any code, it is better to keep this +// source file for cmake dependency. +} // namespace phi diff --git a/paddle/phi/backends/c_comm_lib.h b/paddle/phi/backends/c_comm_lib.h new file mode 100644 index 0000000000..f2987996cb --- /dev/null +++ b/paddle/phi/backends/c_comm_lib.h @@ -0,0 +1,60 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include + +#include "paddle/phi/common/data_type.h" +#include "paddle/phi/common/place.h" +#include "paddle/phi/core/enforce.h" +#include "paddle/phi/core/errors.h" +#include "paddle/phi/core/macros.h" + +namespace phi { +namespace ccl { +using CCLComm = void*; +using CCLRootId = std::vector; + +enum CCLReduceOp { SUM = 0, AVG, MAX, MIN, PRODUCT }; +enum CCLDataType { + CCL_DATA_TYPE_FP64 = 0, + CCL_DATA_TYPE_FP32, + CCL_DATA_TYPE_FP16, + CCL_DATA_TYPE_INT64, + CCL_DATA_TYPE_INT32, + CCL_DATA_TYPE_INT16, + CCL_DATA_TYPE_INT8 +}; + +inline CCLDataType ToCCLDataType(paddle::experimental::DataType type) { + if (type == paddle::experimental::DataType::FLOAT64) { + return CCL_DATA_TYPE_FP64; + } else if (type == paddle::experimental::DataType::FLOAT32) { + return CCL_DATA_TYPE_FP32; + } else if (type == paddle::experimental::DataType::FLOAT16) { + return CCL_DATA_TYPE_FP16; + } else if (type == paddle::experimental::DataType::INT64) { + return CCL_DATA_TYPE_INT64; + } else if (type == paddle::experimental::DataType::INT32) { + return CCL_DATA_TYPE_INT32; + } else if (type == paddle::experimental::DataType::INT8) { + return CCL_DATA_TYPE_INT8; + } else { + PADDLE_THROW( + phi::errors::Unimplemented("This datatype in CCL is not supported.")); + } +} + +} // namespace ccl +} // namespace phi diff --git a/paddle/phi/backends/callback_manager.cc b/paddle/phi/backends/callback_manager.cc index 295f70fc65..7ce5988038 100644 --- a/paddle/phi/backends/callback_manager.cc +++ b/paddle/phi/backends/callback_manager.cc @@ -18,6 +18,7 @@ #include "paddle/fluid/platform/device/device_wrapper.h" #include "paddle/fluid/platform/enforce.h" +#include "paddle/phi/backends/device_guard.h" namespace phi { @@ -33,12 +34,13 @@ void CallbackManager::AddCallback(std::function callback) const { (*callback_func)(); }); }); - + phi::DeviceGuard guard(stream_->GetPlace()); phi::DeviceManager::GetDeviceWithPlace(stream_->GetPlace()) ->AddCallback(stream_, func); } void CallbackManager::Wait() const { + phi::DeviceGuard guard(stream_->GetPlace()); phi::DeviceManager::GetDeviceWithPlace(stream_->GetPlace()) ->SynchronizeStream(stream_); diff --git a/paddle/phi/backends/custom/custom_device.cc b/paddle/phi/backends/custom/custom_device.cc index 541acd9eca..1a92868dd0 100644 --- a/paddle/phi/backends/custom/custom_device.cc +++ b/paddle/phi/backends/custom/custom_device.cc @@ -27,6 +27,14 @@ static bool operator==(const C_Device_st& d1, const C_Device_st& d2) { namespace phi { +#define INTERFACE_UNIMPLEMENT \ + PADDLE_THROW(phi::errors::Unimplemented( \ + "%s is not implemented on %s device.", __func__, Type())); +#define CHECK_PTR(x) \ + if (x == nullptr) { \ + INTERFACE_UNIMPLEMENT; \ + } + class CustomDevice : public DeviceInterface { public: CustomDevice(const std::string& type, @@ -561,6 +569,208 @@ class CustomDevice : public DeviceInterface { return version; } + C_DataType ToXCCLDataType(ccl::CCLDataType data_type) { +#define return_result(in, ret) \ + case ccl::CCLDataType::in: \ + return C_DataType::ret + switch (data_type) { + return_result(CCL_DATA_TYPE_FP64, FLOAT64); + return_result(CCL_DATA_TYPE_FP32, FLOAT32); + return_result(CCL_DATA_TYPE_FP16, FLOAT16); + return_result(CCL_DATA_TYPE_INT64, INT64); + return_result(CCL_DATA_TYPE_INT32, INT32); + return_result(CCL_DATA_TYPE_INT16, INT16); + return_result(CCL_DATA_TYPE_INT8, INT8); + default: { + PADDLE_THROW(phi::errors::Unavailable( + "DataType is not supported on %s.", Type())); + return C_DataType::UNDEFINED; + } + } +#undef return_result + } + + C_CCLReduceOp ToXCCLReduceOp(ccl::CCLReduceOp reduce_op) { +#define return_result(in, ret) \ + case ccl::CCLReduceOp::in: \ + return C_CCLReduceOp::ret + switch (reduce_op) { + return_result(SUM, SUM); + return_result(AVG, AVG); + return_result(MAX, MAX); + return_result(MIN, MIN); + return_result(PRODUCT, PRODUCT); + default: { + PADDLE_THROW(phi::errors::Unavailable( + "ReduceOp is not supported on %s.", Type())); + } + } +#undef return_result + } + + void CCLGetUniqueId(ccl::CCLRootId* unique_id) override { + CHECK_PTR(pimpl_->xccl_get_unique_id_size); + CHECK_PTR(pimpl_->xccl_get_unique_id); + + C_CCLRootId root_id; + PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS( + pimpl_->xccl_get_unique_id_size(&(root_id.sz))); + root_id.data = new uint8_t[root_id.sz]; + PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS(pimpl_->xccl_get_unique_id(&root_id)); + + uint8_t* ptr = reinterpret_cast(root_id.data); + *unique_id = std::vector(ptr, ptr + root_id.sz); + delete[] ptr; + } + + void CCLCommInitRank(size_t nranks, + ccl::CCLRootId* unique_id, + size_t rank, + ccl::CCLComm* comm) override { + CHECK_PTR(pimpl_->xccl_comm_init_rank); + + C_CCLRootId root_id; + root_id.sz = unique_id->size(); + root_id.data = unique_id->data(); + + PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS(pimpl_->xccl_comm_init_rank( + nranks, &root_id, rank, reinterpret_cast(comm))); + } + + void CCLDestroyComm(ccl::CCLComm comm) override { + CHECK_PTR(pimpl_->xccl_destroy_comm); + PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS( + pimpl_->xccl_destroy_comm(reinterpret_cast(comm))); + } + + void CCLAllReduce(void* send_buf, + void* recv_buf, + size_t count, + ccl::CCLDataType data_type, + ccl::CCLReduceOp op, + const ccl::CCLComm& comm, + const stream::Stream& stream) override { + CHECK_PTR(pimpl_->xccl_all_reduce); + PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS(pimpl_->xccl_all_reduce( + send_buf, + recv_buf, + count, + ToXCCLDataType(data_type), + ToXCCLReduceOp(op), + reinterpret_cast(comm), + reinterpret_cast(stream.raw_stream()))); + } + + void CCLBroadcast(void* buf, + size_t count, + ccl::CCLDataType data_type, + size_t root, + const ccl::CCLComm& comm, + const stream::Stream& stream) override { + CHECK_PTR(pimpl_->xccl_broadcast); + PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS(pimpl_->xccl_broadcast( + buf, + count, + ToXCCLDataType(data_type), + root, + reinterpret_cast(comm), + reinterpret_cast(stream.raw_stream()))); + } + + void CCLReduce(void* in_data, + void* out_data, + size_t num, + ccl::CCLDataType data_type, + ccl::CCLReduceOp reduce_op, + const ccl::CCLComm& comm, + const stream::Stream& stream) override { + CHECK_PTR(pimpl_->xccl_reduce); + PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS( + pimpl_->xccl_reduce(in_data, + out_data, + num, + ToXCCLDataType(data_type), + ToXCCLReduceOp(reduce_op), + reinterpret_cast(comm), + reinterpret_cast(stream.raw_stream()))); + } + + void CCLAllGather(void* send_buf, + void* recv_buf, + size_t count, + ccl::CCLDataType data_type, + const ccl::CCLComm& comm, + const stream::Stream& stream) override { + CHECK_PTR(pimpl_->xccl_all_gather); + PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS(pimpl_->xccl_all_gather( + send_buf, + recv_buf, + count, + ToXCCLDataType(data_type), + reinterpret_cast(comm), + reinterpret_cast(stream.raw_stream()))); + } + + void CCLReduceScatter(void* send_buf, + void* recv_buf, + size_t count, + ccl::CCLDataType data_type, + ccl::CCLReduceOp reduce_op, + const ccl::CCLComm& comm, + const stream::Stream& stream) override { + CHECK_PTR(pimpl_->xccl_reduce_scatter); + PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS(pimpl_->xccl_reduce_scatter( + send_buf, + recv_buf, + count, + ToXCCLDataType(data_type), + ToXCCLReduceOp(reduce_op), + reinterpret_cast(comm), + reinterpret_cast(stream.raw_stream()))); + } + + void CCLGroupStart() override { + CHECK_PTR(pimpl_->xccl_group_start); + PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS(pimpl_->xccl_group_start()); + } + + void CCLGroupEnd() override { + CHECK_PTR(pimpl_->xccl_group_end); + PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS(pimpl_->xccl_group_end()); + } + + void CCLSend(void* send_buf, + size_t count, + ccl::CCLDataType data_type, + size_t dest_rank, + const ccl::CCLComm& comm, + const stream::Stream& stream) override { + CHECK_PTR(pimpl_->xccl_send); + PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS( + pimpl_->xccl_send(send_buf, + count, + ToXCCLDataType(data_type), + dest_rank, + reinterpret_cast(comm), + reinterpret_cast(stream.raw_stream()))); + } + + void CCLRecv(void* recv_buf, + size_t count, + ccl::CCLDataType data_type, + size_t src_rank, + const ccl::CCLComm& comm, + const stream::Stream& stream) override { + CHECK_PTR(pimpl_->xccl_recv); + PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS( + pimpl_->xccl_recv(recv_buf, + count, + ToXCCLDataType(data_type), + src_rank, + reinterpret_cast(comm), + reinterpret_cast(stream.raw_stream()))); + } + private: inline int PlaceToIdNoCheck(const Place& place) { int dev_id = place.GetDeviceId(); @@ -584,7 +794,7 @@ class CustomDevice : public DeviceInterface { }; bool ValidCustomCustomRuntimeParams(const CustomRuntimeParams* params) { -#define CHECK_PTR(ptr, required) \ +#define CHECK_INTERFACE(ptr, required) \ if (params->interface->ptr == nullptr && required) { \ LOG(WARNING) << "CustomRuntime [type: " << params->device_type \ << "] pointer: " << #ptr << " is not set."; \ @@ -604,58 +814,71 @@ bool ValidCustomCustomRuntimeParams(const CustomRuntimeParams* params) { return false; } - CHECK_PTR(initialize, false); - CHECK_PTR(finalize, false) - - CHECK_PTR(init_device, false); - CHECK_PTR(set_device, true); - CHECK_PTR(get_device, true); - CHECK_PTR(deinit_device, false); - - CHECK_PTR(create_stream, true); - CHECK_PTR(destroy_stream, true); - CHECK_PTR(query_stream, false); - CHECK_PTR(stream_add_callback, false); - - CHECK_PTR(create_event, true); - CHECK_PTR(record_event, true); - CHECK_PTR(destroy_event, true); - CHECK_PTR(query_event, false); - - CHECK_PTR(synchronize_device, false); - CHECK_PTR(synchronize_stream, true); - CHECK_PTR(synchronize_event, true); - CHECK_PTR(stream_wait_event, true); - - CHECK_PTR(device_memory_allocate, true); - CHECK_PTR(device_memory_deallocate, true); - CHECK_PTR(host_memory_allocate, false); - CHECK_PTR(host_memory_deallocate, false); - CHECK_PTR(unified_memory_allocate, false); - CHECK_PTR(unified_memory_deallocate, false); - CHECK_PTR(memory_copy_h2d, true); - CHECK_PTR(memory_copy_d2h, true); - CHECK_PTR(memory_copy_d2d, true); - CHECK_PTR(memory_copy_p2p, false); - CHECK_PTR(async_memory_copy_h2d, false); - CHECK_PTR(async_memory_copy_d2h, false); - CHECK_PTR(async_memory_copy_d2d, false); - CHECK_PTR(async_memory_copy_p2p, false); - - CHECK_PTR(get_device_count, true); - CHECK_PTR(get_device_list, true); - CHECK_PTR(device_memory_stats, true); - - CHECK_PTR(device_min_chunk_size, true); - CHECK_PTR(device_max_chunk_size, false); - CHECK_PTR(device_max_alloc_size, false); - CHECK_PTR(device_extra_padding_size, false); - CHECK_PTR(get_compute_capability, false); - CHECK_PTR(get_runtime_version, false); - CHECK_PTR(get_driver_version, false); - + CHECK_INTERFACE(initialize, false); + CHECK_INTERFACE(finalize, false) + + CHECK_INTERFACE(init_device, false); + CHECK_INTERFACE(set_device, true); + CHECK_INTERFACE(get_device, true); + CHECK_INTERFACE(deinit_device, false); + + CHECK_INTERFACE(create_stream, true); + CHECK_INTERFACE(destroy_stream, true); + CHECK_INTERFACE(query_stream, false); + CHECK_INTERFACE(stream_add_callback, false); + + CHECK_INTERFACE(create_event, true); + CHECK_INTERFACE(record_event, true); + CHECK_INTERFACE(destroy_event, true); + CHECK_INTERFACE(query_event, false); + + CHECK_INTERFACE(synchronize_device, false); + CHECK_INTERFACE(synchronize_stream, true); + CHECK_INTERFACE(synchronize_event, true); + CHECK_INTERFACE(stream_wait_event, true); + + CHECK_INTERFACE(device_memory_allocate, true); + CHECK_INTERFACE(device_memory_deallocate, true); + CHECK_INTERFACE(host_memory_allocate, false); + CHECK_INTERFACE(host_memory_deallocate, false); + CHECK_INTERFACE(unified_memory_allocate, false); + CHECK_INTERFACE(unified_memory_deallocate, false); + CHECK_INTERFACE(memory_copy_h2d, true); + CHECK_INTERFACE(memory_copy_d2h, true); + CHECK_INTERFACE(memory_copy_d2d, true); + CHECK_INTERFACE(memory_copy_p2p, false); + CHECK_INTERFACE(async_memory_copy_h2d, false); + CHECK_INTERFACE(async_memory_copy_d2h, false); + CHECK_INTERFACE(async_memory_copy_d2d, false); + CHECK_INTERFACE(async_memory_copy_p2p, false); + + CHECK_INTERFACE(get_device_count, true); + CHECK_INTERFACE(get_device_list, true); + CHECK_INTERFACE(device_memory_stats, true); + + CHECK_INTERFACE(device_min_chunk_size, true); + CHECK_INTERFACE(device_max_chunk_size, false); + CHECK_INTERFACE(device_max_alloc_size, false); + CHECK_INTERFACE(device_extra_padding_size, false); + CHECK_INTERFACE(get_compute_capability, false); + CHECK_INTERFACE(get_runtime_version, false); + CHECK_INTERFACE(get_driver_version, false); + + CHECK_INTERFACE(xccl_get_unique_id, false); + CHECK_INTERFACE(xccl_get_unique_id_size, false); + CHECK_INTERFACE(xccl_comm_init_rank, false); + CHECK_INTERFACE(xccl_destroy_comm, false); + CHECK_INTERFACE(xccl_all_reduce, false); + CHECK_INTERFACE(xccl_broadcast, false); + CHECK_INTERFACE(xccl_reduce, false); + CHECK_INTERFACE(xccl_all_gather, false); + CHECK_INTERFACE(xccl_reduce_scatter, false); + CHECK_INTERFACE(xccl_group_start, false); + CHECK_INTERFACE(xccl_group_end, false); + CHECK_INTERFACE(xccl_send, false); + CHECK_INTERFACE(xccl_recv, false); return true; -#undef CHECK_PTR +#undef CHECK_INTERFACE } typedef bool (*RegisterDevicePluginFn)(CustomRuntimeParams* runtime_params); @@ -712,4 +935,6 @@ void LoadCustomRuntimeLib(const std::string& dso_lib_path, void* dso_handle) { LOG(INFO) << "Successed in loading custom runtime in lib: " << dso_lib_path; } +#undef INTERFACE_UNIMPLEMENT + } // namespace phi diff --git a/paddle/phi/backends/custom/custom_device_test.cc b/paddle/phi/backends/custom/custom_device_test.cc index 51fa74b4dc..930750e864 100644 --- a/paddle/phi/backends/custom/custom_device_test.cc +++ b/paddle/phi/backends/custom/custom_device_test.cc @@ -107,6 +107,7 @@ void TestTensorShareDataWith(const paddle::platform::Place& place) { } void TestTensorUtils(const paddle::platform::Place& place) { + std::cout << "TestTensorUtils on " << place << std::endl; if (paddle::platform::is_custom_place(place) == false) { return; } @@ -166,6 +167,76 @@ void TestTensorUtils(const paddle::platform::Place& place) { #endif } +void TestCustomCCL(const paddle::platform::Place& place) { + std::cout << "TestCustomCCL on " << place << std::endl; + if (paddle::platform::is_custom_place(place) == false) { + return; + } + std::string dev_type = place.GetDeviceType(); + phi::ccl::CCLComm comm; + phi::stream::Stream stream(place, nullptr); + phi::ccl::CCLRootId root_id; + + phi::DeviceManager::CCLDestroyComm(dev_type, nullptr); + phi::DeviceManager::CCLGetUniqueId(dev_type, &root_id); + phi::DeviceManager::CCLCommInitRank(dev_type, 0, &root_id, 0, nullptr); + phi::DeviceManager::CCLBroadcast(dev_type, + nullptr, + 0, + phi::ccl::CCLDataType::CCL_DATA_TYPE_FP32, + 0, + comm, + stream); + phi::DeviceManager::CCLAllReduce(dev_type, + nullptr, + nullptr, + 0, + phi::ccl::CCLDataType::CCL_DATA_TYPE_FP32, + phi::ccl::CCLReduceOp::SUM, + comm, + stream); + phi::DeviceManager::CCLReduce(dev_type, + nullptr, + nullptr, + 0, + phi::ccl::CCLDataType::CCL_DATA_TYPE_FP32, + phi::ccl::CCLReduceOp::SUM, + comm, + stream); + phi::DeviceManager::CCLAllGather(dev_type, + nullptr, + nullptr, + 0, + phi::ccl::CCLDataType::CCL_DATA_TYPE_FP32, + comm, + stream); + phi::DeviceManager::CCLReduceScatter( + dev_type, + nullptr, + nullptr, + 0, + phi::ccl::CCLDataType::CCL_DATA_TYPE_FP32, + phi::ccl::CCLReduceOp::SUM, + comm, + stream); + phi::DeviceManager::CCLGroupStart(dev_type); + phi::DeviceManager::CCLGroupEnd(dev_type); + phi::DeviceManager::CCLSend(dev_type, + nullptr, + 0, + phi::ccl::CCLDataType::CCL_DATA_TYPE_FP32, + 0, + comm, + stream); + phi::DeviceManager::CCLRecv(dev_type, + nullptr, + 0, + phi::ccl::CCLDataType::CCL_DATA_TYPE_FP32, + 0, + comm, + stream); +} + TEST(CustomDevice, Tensor) { InitDevice(); auto dev_types = phi::DeviceManager::GetAllDeviceTypes(); @@ -179,6 +250,7 @@ TEST(CustomDevice, Tensor) { TestTensorMutableData(place); TestTensorShareDataWith(place); TestTensorUtils(place); + TestCustomCCL(place); } } diff --git a/paddle/phi/backends/custom/fake_cpu_device.h b/paddle/phi/backends/custom/fake_cpu_device.h index 22c344a0e0..41c7acc446 100644 --- a/paddle/phi/backends/custom/fake_cpu_device.h +++ b/paddle/phi/backends/custom/fake_cpu_device.h @@ -136,6 +136,80 @@ C_Status DeviceMaxAllocSize(const C_Device device, size_t *size) { return C_SUCCESS; } +C_Status XcclGetUniqueIdSize(size_t *size) { + *size = sizeof(size_t); + return C_SUCCESS; +} +C_Status XcclGetUniqueId(C_CCLRootId *unique_id) { return C_SUCCESS; } +C_Status XcclCommInitRank(size_t ranks, + C_CCLRootId *unique_id, + size_t rank, + C_CCLComm *comm) { + return C_SUCCESS; +} +C_Status XcclDestroyComm(C_CCLComm comm) { return C_SUCCESS; } +C_Status XcclAllReduce(void *send_buf, + void *recv_buf, + size_t count, + C_DataType data_type, + C_CCLReduceOp op, + C_CCLComm comm, + C_Stream stream) { + return C_SUCCESS; +} +C_Status XcclBroadcast(void *buf, + size_t count, + C_DataType data_type, + size_t root, + C_CCLComm comm, + C_Stream stream) { + return C_SUCCESS; +} +C_Status XcclReduce(void *send_buf, + void *recv_buf, + size_t count, + C_DataType data_type, + C_CCLReduceOp op, + C_CCLComm comm, + C_Stream stream) { + return C_SUCCESS; +} +C_Status XcclAllGather(void *send_buf, + void *recv_buf, + size_t count, + C_DataType data_type, + C_CCLComm comm, + C_Stream stream) { + return C_SUCCESS; +} +C_Status XcclReduceScatter(void *send_buf, + void *recv_buf, + size_t count, + C_DataType data_type, + C_CCLReduceOp op, + C_CCLComm comm, + C_Stream stream) { + return C_SUCCESS; +} +C_Status XcclGroupStart() { return C_SUCCESS; } +C_Status XcclGroupEnd() { return C_SUCCESS; } +C_Status XcclSend(void *send_buf, + size_t count, + C_DataType data_type, + size_t dest_rank, + C_CCLComm comm, + C_Stream stream) { + return C_SUCCESS; +} +C_Status XcclRecv(void *recv_buf, + size_t count, + C_DataType data_type, + size_t src_rank, + C_CCLComm comm, + C_Stream stream) { + return C_SUCCESS; +} + #define DEVICE_TYPE "FakeCPU" #define SUB_DEVICE_TYPE "V100" @@ -190,4 +264,18 @@ void InitFakeCPUDevice(CustomRuntimeParams *params) { params->interface->device_max_chunk_size = DeviceMaxChunkSize; params->interface->device_min_chunk_size = DeviceMinChunkSize; params->interface->device_max_alloc_size = DeviceMaxAllocSize; + + params->interface->xccl_get_unique_id_size = XcclGetUniqueIdSize; + params->interface->xccl_get_unique_id = XcclGetUniqueId; + params->interface->xccl_all_reduce = XcclAllReduce; + params->interface->xccl_all_gather = XcclAllGather; + params->interface->xccl_broadcast = XcclBroadcast; + params->interface->xccl_comm_init_rank = XcclCommInitRank; + params->interface->xccl_destroy_comm = XcclDestroyComm; + params->interface->xccl_group_end = XcclGroupEnd; + params->interface->xccl_group_start = XcclGroupStart; + params->interface->xccl_reduce = XcclReduce; + params->interface->xccl_reduce_scatter = XcclReduceScatter; + params->interface->xccl_send = XcclSend; + params->interface->xccl_recv = XcclRecv; } diff --git a/paddle/phi/backends/device_base.cc b/paddle/phi/backends/device_base.cc index e57653702c..4b82f4a340 100644 --- a/paddle/phi/backends/device_base.cc +++ b/paddle/phi/backends/device_base.cc @@ -270,4 +270,91 @@ size_t DeviceInterface::GetExtraPaddingSize(size_t dev_id) { return 0; } +void DeviceInterface::CCLDestroyComm(ccl::CCLComm ccl_comm) { + INTERFACE_UNIMPLEMENT; +} + +void DeviceInterface::CCLCommInitRank(size_t num_ranks, + ccl::CCLRootId* root_id, + size_t rank_id, + ccl::CCLComm* ccl_comm) { + INTERFACE_UNIMPLEMENT; +} + +void DeviceInterface::CCLGetUniqueId(ccl::CCLRootId* root_id) { + INTERFACE_UNIMPLEMENT; +} + +void DeviceInterface::CCLBroadcast(void* data, + size_t num, + ccl::CCLDataType data_type, + size_t root, + const ccl::CCLComm& ccl_comm, + const stream::Stream& stream) { + INTERFACE_UNIMPLEMENT; +} + +void DeviceInterface::CCLAllReduce(void* in_data, + void* out_data, + size_t num, + ccl::CCLDataType data_type, + ccl::CCLReduceOp reduce_op, + const ccl::CCLComm& ccl_comm, + const stream::Stream& stream) { + INTERFACE_UNIMPLEMENT; +} + +void DeviceInterface::CCLReduce(void* in_data, + void* out_data, + size_t num, + ccl::CCLDataType data_type, + ccl::CCLReduceOp reduce_op, + const ccl::CCLComm& ccl_comm, + const stream::Stream& stream) { + INTERFACE_UNIMPLEMENT; +} + +void DeviceInterface::CCLAllGather(void* in_data, + void* out_data, + size_t num, + ccl::CCLDataType data_type, + const ccl::CCLComm& ccl_comm, + const stream::Stream& stream) { + INTERFACE_UNIMPLEMENT; +} + +void DeviceInterface::CCLReduceScatter(void* in_data, + void* out_data, + size_t num, + ccl::CCLDataType data_type, + ccl::CCLReduceOp op, + const ccl::CCLComm& ccl_comm, + const stream::Stream& stream) { + INTERFACE_UNIMPLEMENT; +} + +void DeviceInterface::CCLGroupStart() { INTERFACE_UNIMPLEMENT; } + +void DeviceInterface::CCLGroupEnd() { INTERFACE_UNIMPLEMENT; } + +void DeviceInterface::CCLSend(void* sendbuf, + size_t num, + ccl::CCLDataType data_type, + size_t dst_rank, + const ccl::CCLComm& ccl_comm, + const stream::Stream& stream) { + INTERFACE_UNIMPLEMENT; +} + +void DeviceInterface::CCLRecv(void* recvbuf, + size_t num, + ccl::CCLDataType data_type, + size_t src_rank, + const ccl::CCLComm& ccl_comm, + const stream::Stream& stream) { + INTERFACE_UNIMPLEMENT; +} + +#undef INTERFACE_UNIMPLEMENT + } // namespace phi diff --git a/paddle/phi/backends/device_base.h b/paddle/phi/backends/device_base.h index 8cc6e49806..84249261d1 100644 --- a/paddle/phi/backends/device_base.h +++ b/paddle/phi/backends/device_base.h @@ -16,6 +16,7 @@ #ifdef PADDLE_WITH_CUSTOM_DEVICE #include +#include "paddle/phi/backends/c_comm_lib.h" #include "paddle/phi/backends/event.h" #include "paddle/phi/backends/stream.h" @@ -165,6 +166,65 @@ class DeviceInterface { // Driver / Runtime virtual size_t GetExtraPaddingSize(size_t dev_id); + // CCL + virtual void CCLDestroyComm(ccl::CCLComm ccl_comm); + + virtual void CCLCommInitRank(size_t num_ranks, + ccl::CCLRootId* root_id, + size_t rank_id, + ccl::CCLComm* ccl_comm); + + virtual void CCLGetUniqueId(ccl::CCLRootId* root_id); + + virtual void CCLBroadcast(void* data, + size_t num, + ccl::CCLDataType data_type, + size_t root, + const ccl::CCLComm& ccl_comm, + const stream::Stream& stream); + + virtual void CCLAllReduce(void* in_data, + void* out_data, + size_t num, + ccl::CCLDataType data_type, + ccl::CCLReduceOp reduce_op, + const ccl::CCLComm& ccl_comm, + const stream::Stream& stream); + virtual void CCLReduce(void* in_data, + void* out_data, + size_t num, + ccl::CCLDataType data_type, + ccl::CCLReduceOp reduce_op, + const ccl::CCLComm& ccl_comm, + const stream::Stream& stream); + virtual void CCLAllGather(void* in_data, + void* out_data, + size_t num, + ccl::CCLDataType data_type, + const ccl::CCLComm& ccl_comm, + const stream::Stream& stream); + virtual void CCLReduceScatter(void* in_data, + void* out_data, + size_t num, + ccl::CCLDataType data_type, + ccl::CCLReduceOp op, + const ccl::CCLComm& ccl_comm, + const stream::Stream& stream); + virtual void CCLGroupStart(); + virtual void CCLGroupEnd(); + virtual void CCLSend(void* sendbuf, + size_t num, + ccl::CCLDataType data_type, + size_t dst_rank, + const ccl::CCLComm& ccl_comm, + const stream::Stream& stream); + virtual void CCLRecv(void* recvbuf, + size_t num, + ccl::CCLDataType data_type, + size_t src_rank, + const ccl::CCLComm& ccl_comm, + const stream::Stream& stream); + private: const std::string type_; const uint8_t priority_; diff --git a/paddle/phi/backends/device_ext.h b/paddle/phi/backends/device_ext.h index 77c9ee6185..a4dc9176e1 100644 --- a/paddle/phi/backends/device_ext.h +++ b/paddle/phi/backends/device_ext.h @@ -74,6 +74,15 @@ typedef void (*C_Callback)(C_Device device, void* user_data, C_Status* status); +typedef struct { + size_t sz; + void* data; +} C_CCLRootId; + +typedef struct C_CCLComm_st* C_CCLComm; + +typedef enum { SUM = 0, AVG, MAX, MIN, PRODUCT } C_CCLReduceOp; + struct C_DeviceInterface { // Core fill it and plugin must to check it size_t size; @@ -526,6 +535,102 @@ struct C_DeviceInterface { void* reserved_info_api[8]; + ////////////// + // ccl api // + ////////////// + + /** + * @brief Get size of unique id + * + * @param[size_t*] size + */ + C_Status (*xccl_get_unique_id_size)(size_t* size); + + /** + * @brief Get unique id + * + * @param[C_CCLRootId*] unique_id + */ + C_Status (*xccl_get_unique_id)(C_CCLRootId* unique_id); + + /** + * @brief Initialize communicator + * + * @param[size_t] ranks + * @param[C_CCLRootId*] unique_id + * @param[size_t] rank + * @param[C_CCLComm*] comm + */ + C_Status (*xccl_comm_init_rank)(size_t ranks, + C_CCLRootId* unique_id, + size_t rank, + C_CCLComm* comm); + + /** + * @brief Destroy communicator + * + * @param[C_CCLComm] comm + */ + C_Status (*xccl_destroy_comm)(C_CCLComm comm); + + C_Status (*xccl_all_reduce)(void* send_buf, + void* recv_buf, + size_t count, + C_DataType data_type, + C_CCLReduceOp op, + C_CCLComm comm, + C_Stream stream); + + C_Status (*xccl_broadcast)(void* buf, + size_t count, + C_DataType data_type, + size_t root, + C_CCLComm comm, + C_Stream stream); + + C_Status (*xccl_reduce)(void* send_buf, + void* recv_buf, + size_t count, + C_DataType data_type, + C_CCLReduceOp op, + C_CCLComm comm, + C_Stream stream); + + C_Status (*xccl_all_gather)(void* send_buf, + void* recv_buf, + size_t count, + C_DataType data_type, + C_CCLComm comm, + C_Stream stream); + + C_Status (*xccl_reduce_scatter)(void* send_buf, + void* recv_buf, + size_t count, + C_DataType data_type, + C_CCLReduceOp op, + C_CCLComm comm, + C_Stream stream); + + C_Status (*xccl_group_start)(); + + C_Status (*xccl_group_end)(); + + C_Status (*xccl_send)(void* send_buf, + size_t count, + C_DataType data_type, + size_t dest_rank, + C_CCLComm comm, + C_Stream stream); + + C_Status (*xccl_recv)(void* recv_buf, + size_t count, + C_DataType data_type, + size_t src_rank, + C_CCLComm comm, + C_Stream stream); + + void* reserved_ccl_api[8]; + /////////////// // other api // /////////////// diff --git a/paddle/phi/backends/device_manager.cc b/paddle/phi/backends/device_manager.cc index 5b1022794a..405a87f749 100644 --- a/paddle/phi/backends/device_manager.cc +++ b/paddle/phi/backends/device_manager.cc @@ -25,6 +25,7 @@ #include #include "glog/logging.h" +#include "paddle/utils/string/split.h" namespace phi { @@ -390,6 +391,139 @@ std::vector DeviceManager::GetDeviceList( return dev_impl->GetDeviceList(); } +std::vector DeviceManager::GetSelectedDeviceList( + const std::string& device_type) { + std::vector devices; + std::string FLAGS = "FLAGS_selected_" + device_type + "s"; + auto FLAGS_selected_devices = getenv(FLAGS.c_str()); + if (FLAGS_selected_devices) { + auto devices_str = paddle::string::Split(FLAGS_selected_devices, ','); + for (auto id : devices_str) { + devices.push_back(atoi(id.c_str())); + } + } else { + int count = DeviceManager::GetDeviceCount(device_type); + for (int i = 0; i < count; ++i) { + devices.push_back(i); + } + } + return devices; +} + +void DeviceManager::CCLDestroyComm(const std::string& device_type, + ccl::CCLComm ccl_comm) { + auto dev_impl = GetDeviceInterfaceWithType(device_type); + dev_impl->CCLDestroyComm(ccl_comm); +} + +void DeviceManager::CCLCommInitRank(const std::string& device_type, + size_t num_ranks, + ccl::CCLRootId* root_id, + size_t rank_id, + ccl::CCLComm* ccl_comm) { + auto dev_impl = GetDeviceInterfaceWithType(device_type); + dev_impl->CCLCommInitRank(num_ranks, root_id, rank_id, ccl_comm); +} + +void DeviceManager::CCLGetUniqueId(const std::string& device_type, + ccl::CCLRootId* root_id) { + auto dev_impl = GetDeviceInterfaceWithType(device_type); + dev_impl->CCLGetUniqueId(root_id); +} + +void DeviceManager::CCLBroadcast(const std::string& device_type, + void* data, + size_t num, + ccl::CCLDataType data_type, + size_t root_id, + const ccl::CCLComm& ccl_comm, + const stream::Stream& stream) { + auto dev_impl = GetDeviceInterfaceWithType(device_type); + dev_impl->CCLBroadcast(data, num, data_type, root_id, ccl_comm, stream); +} + +void DeviceManager::CCLAllReduce(const std::string& device_type, + void* in_data, + void* out_data, + size_t num, + ccl::CCLDataType data_type, + ccl::CCLReduceOp reduce_op, + const ccl::CCLComm& ccl_comm, + const stream::Stream& stream) { + auto dev_impl = GetDeviceInterfaceWithType(device_type); + dev_impl->CCLAllReduce( + in_data, out_data, num, data_type, reduce_op, ccl_comm, stream); +} + +void DeviceManager::CCLReduce(const std::string& device_type, + void* in_data, + void* out_data, + size_t num, + ccl::CCLDataType data_type, + ccl::CCLReduceOp reduce_op, + const ccl::CCLComm& ccl_comm, + const stream::Stream& stream) { + auto dev_impl = GetDeviceInterfaceWithType(device_type); + dev_impl->CCLReduce( + in_data, out_data, num, data_type, reduce_op, ccl_comm, stream); +} + +void DeviceManager::CCLAllGather(const std::string& device_type, + void* in_data, + void* out_data, + size_t num, + ccl::CCLDataType data_type, + const ccl::CCLComm& ccl_comm, + const stream::Stream& stream) { + auto dev_impl = GetDeviceInterfaceWithType(device_type); + dev_impl->CCLAllGather(in_data, out_data, num, data_type, ccl_comm, stream); +} + +void DeviceManager::CCLReduceScatter(const std::string& device_type, + void* in_data, + void* out_data, + size_t num, + ccl::CCLDataType data_type, + ccl::CCLReduceOp op, + const ccl::CCLComm& ccl_comm, + const stream::Stream& stream) { + auto dev_impl = GetDeviceInterfaceWithType(device_type); + dev_impl->CCLReduceScatter( + in_data, out_data, num, data_type, op, ccl_comm, stream); +} + +void DeviceManager::CCLGroupStart(const std::string& device_type) { + auto dev_impl = GetDeviceInterfaceWithType(device_type); + dev_impl->CCLGroupStart(); +} + +void DeviceManager::CCLGroupEnd(const std::string& device_type) { + auto dev_impl = GetDeviceInterfaceWithType(device_type); + dev_impl->CCLGroupEnd(); +} + +void DeviceManager::CCLSend(const std::string& device_type, + void* sendbuf, + size_t num, + ccl::CCLDataType data_type, + size_t dst_rank, + const ccl::CCLComm& ccl_comm, + const stream::Stream& stream) { + auto dev_impl = GetDeviceInterfaceWithType(device_type); + dev_impl->CCLSend(sendbuf, num, data_type, dst_rank, ccl_comm, stream); +} + +void DeviceManager::CCLRecv(const std::string& device_type, + void* recvbuf, + size_t num, + ccl::CCLDataType data_type, + size_t src_rank, + const ccl::CCLComm& ccl_comm, + const stream::Stream& stream) { + auto dev_impl = GetDeviceInterfaceWithType(device_type); + dev_impl->CCLRecv(recvbuf, num, data_type, src_rank, ccl_comm, stream); +} + DeviceManager& DeviceManager::Instance() { static DeviceManager platform_manager; return platform_manager; diff --git a/paddle/phi/backends/device_manager.h b/paddle/phi/backends/device_manager.h index 56d99ba43b..4ad7643c33 100644 --- a/paddle/phi/backends/device_manager.h +++ b/paddle/phi/backends/device_manager.h @@ -17,6 +17,7 @@ #include +#include "paddle/phi/backends/c_comm_lib.h" #include "paddle/phi/backends/device_base.h" #include "paddle/phi/backends/device_ext.h" #include "paddle/phi/backends/dynload/port.h" @@ -159,6 +160,74 @@ class DeviceManager { static std::vector GetDeviceList(const std::string& device_type); + static std::vector GetSelectedDeviceList( + const std::string& device_type); + + // CCL + static void CCLDestroyComm(const std::string& device_type, + ccl::CCLComm ccl_comm); + static void CCLCommInitRank(const std::string& device_type, + size_t num_ranks, + ccl::CCLRootId* root_id, + size_t rank_id, + ccl::CCLComm* ccl_comm); + static void CCLGetUniqueId(const std::string& device_type, + ccl::CCLRootId* root_id); + static void CCLBroadcast(const std::string& device_type, + void* data, + size_t num, + ccl::CCLDataType data_type, + size_t root, + const ccl::CCLComm& ccl_comm, + const stream::Stream& stream); + static void CCLAllReduce(const std::string& device_type, + void* in_data, + void* out_data, + size_t num, + ccl::CCLDataType data_type, + ccl::CCLReduceOp reduce_op, + const ccl::CCLComm& ccl_comm, + const stream::Stream& stream); + static void CCLReduce(const std::string& device_type, + void* in_data, + void* out_data, + size_t num, + ccl::CCLDataType data_type, + ccl::CCLReduceOp reduce_op, + const ccl::CCLComm& ccl_comm, + const stream::Stream& stream); + static void CCLAllGather(const std::string& device_type, + void* in_data, + void* out_data, + size_t num, + ccl::CCLDataType data_type, + const ccl::CCLComm& ccl_comm, + const stream::Stream& stream); + static void CCLReduceScatter(const std::string& device_type, + void* in_data, + void* out_data, + size_t num, + ccl::CCLDataType data_type, + ccl::CCLReduceOp op, + const ccl::CCLComm& ccl_comm, + const stream::Stream& stream); + static void CCLGroupStart(const std::string& device_type); + static void CCLGroupEnd(const std::string& device_type); + static void CCLSend(const std::string& device_type, + void* sendbuf, + size_t num, + ccl::CCLDataType data_type, + size_t dst_rank, + const ccl::CCLComm& ccl_comm, + const stream::Stream& stream); + static void CCLRecv(const std::string& device_type, + void* recvbuf, + size_t num, + ccl::CCLDataType data_type, + size_t src_rank, + const ccl::CCLComm& ccl_comm, + const stream::Stream& stream); + static void Clear(); private: diff --git a/python/paddle/device/__init__.py b/python/paddle/device/__init__.py index 4fcf9c5d21..aa959150ce 100644 --- a/python/paddle/device/__init__.py +++ b/python/paddle/device/__init__.py @@ -230,7 +230,10 @@ def _convert_to_place(device): device_id = int(selected_mlus[0]) place = core.MLUPlace(device_id) elif device in core.get_all_custom_device_type(): - place = core.CustomPlace(device, 0) + selected_devices = os.getenv("FLAGS_selected_{}s".format(device), + "0").split(",") + device_id = int(selected_devices[0]) + place = core.CustomPlace(device, device_id) else: avaliable_gpu_device = re.match(r'gpu:\d+', lower_device) avaliable_xpu_device = re.match(r'xpu:\d+', lower_device) -- GitLab