From 6241913b8bf3f6259a38cad29ea8ba9cd598ff4a Mon Sep 17 00:00:00 2001 From: maxhuiy <1508399706@qq.com> Date: Wed, 23 Feb 2022 17:33:17 +0800 Subject: [PATCH] [MLU] add cncl parallel context and mlu resource pool (#39803) * [MLU] add cncl parallel context and mlu resource pool * [MLU] fix the cncl_context_test --- paddle/fluid/imperative/CMakeLists.txt | 3 + paddle/fluid/imperative/cncl_context.cc | 237 ++++++++++++++++++ paddle/fluid/imperative/cncl_context.h | 75 ++++++ paddle/fluid/imperative/tests/CMakeLists.txt | 3 + .../imperative/tests/cncl_context_test.cc | 141 +++++++++++ paddle/fluid/platform/CMakeLists.txt | 4 + .../fluid/platform/device/mlu/CMakeLists.txt | 1 + .../platform/device/mlu/mlu_resource_pool.cc | 99 ++++++++ .../platform/device/mlu/mlu_resource_pool.h | 64 +++++ paddle/fluid/pybind/CMakeLists.txt | 8 + paddle/fluid/pybind/imperative.cc | 13 + 11 files changed, 648 insertions(+) create mode 100644 paddle/fluid/imperative/cncl_context.cc create mode 100644 paddle/fluid/imperative/cncl_context.h create mode 100644 paddle/fluid/imperative/tests/cncl_context_test.cc create mode 100644 paddle/fluid/platform/device/mlu/mlu_resource_pool.cc create mode 100644 paddle/fluid/platform/device/mlu/mlu_resource_pool.h diff --git a/paddle/fluid/imperative/CMakeLists.txt b/paddle/fluid/imperative/CMakeLists.txt index 90cf0e76e00..72f7e5af9a9 100644 --- a/paddle/fluid/imperative/CMakeLists.txt +++ b/paddle/fluid/imperative/CMakeLists.txt @@ -31,6 +31,9 @@ if(NOT WIN32) cc_library(hccl_context SRCS hccl_context.cc DEPS collective_helper device_context tensor var_type_traits) cc_library(reducer SRCS reducer.cc DEPS layer) endif() + if(WITH_CNCL) + cc_library(cncl_context SRCS cncl_context.cc DEPS collective_helper device_context tensor var_type_traits) + endif() if(WITH_NCCL OR WITH_RCCL OR WITH_XPU_BKCL OR WITH_ASCEND_CL) cc_library(heter_ccl_context SRCS heter_ccl_context.cc DEPS collective_helper device_context tensor var_type_traits) endif() diff --git 
a/paddle/fluid/imperative/cncl_context.cc b/paddle/fluid/imperative/cncl_context.cc
new file mode 100644
index 00000000000..779b748c2d2
--- /dev/null
+++ b/paddle/fluid/imperative/cncl_context.cc
@@ -0,0 +1,262 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#if defined(PADDLE_WITH_CNCL)
+#include "paddle/fluid/imperative/cncl_context.h"
+
+#include "paddle/fluid/framework/convert_utils.h"
+#include "paddle/fluid/framework/scope.h"
+#include "paddle/fluid/framework/variable.h"
+
+#include "paddle/fluid/platform/device_context.h"
+#include "paddle/fluid/platform/gen_comm_id_helper.h"
+#include "paddle/fluid/platform/place.h"
+
+#include "paddle/fluid/platform/collective_helper.h"
+#include "paddle/fluid/platform/device/mlu/cncl_helper.h"
+#include "paddle/fluid/platform/device/mlu/mlu_info.h"
+
+namespace paddle {
+namespace framework {
+class Variable;
+}  // namespace framework
+}  // namespace paddle
+
+namespace paddle {
+namespace imperative {
+
+// Sum-all-reduces `src` over `comm` into `dst`, enqueued on `stream`.
+// `dst` is resized to the dims of `src` and allocated on the same place with
+// the same dtype before the collective is issued.
+static void AllReduce(const framework::Tensor &src, framework::Tensor *dst,
+                      const mluStream stream, const platform::CNCLComm *comm) {
+  const auto &place = src.place();
+  PADDLE_ENFORCE_EQ(
+      platform::is_mlu_place(place), true,
+      platform::errors::Unimplemented(
+          "Imperative mode does not support multi-CPU training yet."));
+
+  const void *src_ptr = src.data();
+  dst->Resize(src.dims());
+  auto *dst_ptr = dst->mutable_data(src.place(), src.dtype());
+  auto cncl_dtype =
+      platform::ToCNCLDataType(framework::TransToProtoVarType(src.dtype()));
+  PADDLE_ENFORCE_MLU_SUCCESS(cnclAllReduce(src_ptr, dst_ptr, src.numel(),
+                                           cncl_dtype, cnclSum, comm->comm(),
+                                           stream));
+}
+
+// Shares `cncl_ids` with every trainer: rank `root` sends them to all
+// endpoints other than its own, every other rank receives them through
+// `server_fd`.
+void CNCLParallelContext::BcastCNCLId(
+    std::vector<cnclCliqueId> &cncl_ids,  // NOLINT
+    int root, int server_fd) {
+  if (strategy_.local_rank_ == root) {
+    std::vector<std::string> other_trainers;
+    for (auto &ep : strategy_.trainer_endpoints_) {
+      if (ep != strategy_.current_endpoint_) {
+        other_trainers.push_back(ep);
+      }
+    }
+    platform::SendBroadCastCommID(other_trainers, &cncl_ids);
+  } else {
+    platform::RecvBroadCastCommID(server_fd, strategy_.current_endpoint_,
+                                  &cncl_ids);
+  }
+}
+
+// Creates one CNCL communicator per ring (strategy_.nrings_ in total) plus a
+// compute event and a comm event for each ring. Rank 0 generates the clique
+// ids; the other ranks receive them via BcastCNCLId.
+void CNCLParallelContext::Init() {
+  int server_fd = -1;
+
+  std::vector<cnclCliqueId> cncl_ids;
+  cncl_ids.resize(strategy_.nrings_);
+
+  if (strategy_.local_rank_ == 0) {
+    // generate the unique cnclid on the root worker
+    for (size_t i = 0; i < cncl_ids.size(); ++i) {
+      PADDLE_ENFORCE_MLU_SUCCESS(cnclGetCliqueId(&cncl_ids[i]));
+    }
+  } else {
+    server_fd = platform::SocketServer::GetInstance(strategy_.current_endpoint_)
+                    .socket();
+  }
+  BcastCNCLId(cncl_ids, 0, server_fd);
+
+  int mlu_id = place_.device;
+  for (int ring_id = 0; ring_id < strategy_.nrings_; ++ring_id) {
+    VLOG(0) << "init cncl context nranks: " << strategy_.nranks_
+            << " local rank: " << strategy_.local_rank_ << " mlu id: " << mlu_id
+            << " ring id: " << ring_id;
+    // it will assign cncl_comm in MLUDeviceContext within ring_id
+    platform::CNCLCommContext::Instance().CreateComm(
+        &cncl_ids[ring_id], strategy_.nranks_, strategy_.local_rank_, mlu_id,
+        ring_id);
+
+    compute_events_.emplace_back(
+        platform::MluEventResourcePool::Instance().New(place_.device));
+    comm_events_.emplace_back(
+        platform::MluEventResourcePool::Instance().New(place_.device));
+  }
+}
+
+// Creates the CNCL communicator for the single ring `ring_id` and appends one
+// compute event and one comm event.
+// NOTE(review): the events are appended, whereas WaitCompute/WaitComm index
+// the event vectors by ring id - confirm callers initialise rings in order
+// starting from 0.
+void CNCLParallelContext::InitWithRingID(int ring_id) {
+  int server_fd = -1;
+  std::vector<cnclCliqueId> cncl_ids;
+  cncl_ids.resize(1);
+
+  if (strategy_.local_rank_ == 0) {
+    // generate the unique cnclid on the root worker
+    PADDLE_ENFORCE_MLU_SUCCESS(cnclGetCliqueId(&cncl_ids[0]));
+  } else {
+    server_fd = platform::SocketServer::GetInstance(strategy_.current_endpoint_)
+                    .socket();
+  }
+  BcastCNCLId(cncl_ids, 0, server_fd);
+
+  int mlu_id = place_.device;
+  VLOG(0) << "init cncl context nranks: " << strategy_.nranks_
+          << " local rank: " << strategy_.local_rank_ << " mlu id: " << mlu_id
+          << " ring id: " << ring_id;
+  // it will assign cncl_comm in MLUDeviceContext within ring_id
+  platform::CNCLCommContext::Instance().CreateComm(
+      &cncl_ids[0], strategy_.nranks_, strategy_.local_rank_, mlu_id, ring_id);
+
+  compute_events_.emplace_back(
+      platform::MluEventResourcePool::Instance().New(place_.device));
+  comm_events_.emplace_back(
+      platform::MluEventResourcePool::Instance().New(place_.device));
+}
+
+// All-reduces the LoDTensor held by `src` into `dst` over ring `ring_id`,
+// on the compute stream when `use_calc_stream` is true and on the ring's
+// communication stream otherwise. Other variable types are rejected.
+void CNCLParallelContext::AllReduceByStream(const framework::Variable &src,
+                                            framework::Variable *dst,
+                                            int ring_id, bool use_calc_stream) {
+  PADDLE_ENFORCE_EQ(
+      platform::is_mlu_place(place_), true,
+      platform::errors::Unimplemented(
+          "Dynamic graph mode does not support multi-CPU training yet."));
+  auto *dev_ctx = static_cast<platform::MLUDeviceContext *>(
+      platform::DeviceContextPool::Instance().Get(place_));
+  platform::CNCLComm *comm =
+      platform::CNCLCommContext::Instance().Get(ring_id, place_);
+  mluStream stream = (use_calc_stream ? dev_ctx->stream() : comm->stream());
+
+  if (src.IsType<framework::LoDTensor>()) {
+    if (!dst->IsType<framework::LoDTensor>()) {
+      dst->Clear();
+    }
+    AllReduce(src.Get<framework::LoDTensor>(),
+              dst->GetMutable<framework::LoDTensor>(), stream, comm);
+  } else {
+    PADDLE_THROW(platform::errors::InvalidArgument(
+        "Unsupported variable type %s for imperative allreduce, only "
+        "LoDTensor is supported.",
+        platform::demangle(framework::ToTypeName(src.Type()))));
+  }
+}
+
+// Broadcasts the tensor held by `src` in place on the communication stream of
+// ring `ring_id` (the literal 0 passed to cnclBcast is presumably the root).
+void CNCLParallelContext::Broadcast(framework::Variable *src, int ring_id) {
+  VLOG(3) << "/// DEBUG /// start inter broadcast with ring_id: " << ring_id;
+  framework::Tensor *src_tensor = src->GetMutable<framework::LoDTensor>();
+  const auto &place = src_tensor->place();
+  platform::CNCLComm *comm =
+      platform::CNCLCommContext::Instance().Get(ring_id, place);
+  mluStream stream = comm->stream();
+
+  void *src_ptr = src_tensor->data();
+  auto cncl_dtype = platform::ToCNCLDataType(
+      framework::TransToProtoVarType(src_tensor->dtype()));
+  PADDLE_ENFORCE_MLU_SUCCESS(cnclBcast(src_ptr, src_tensor->numel(), cncl_dtype,
+                                       0, comm->comm(), stream));
+}
+
+// Returns the device context of the communicator of ring `ring_id`.
+paddle::platform::DeviceContext *CNCLParallelContext::GetDeviceContext(
+    int ring_id) {
+  return static_cast<platform::DeviceContext *>(
+      platform::CNCLCommContext::Instance()
+          .Get(ring_id, place_)
+          ->dev_context());
+}
+
+// Places an event on the compute stream and makes the communication stream of
+// ring `ring_id` wait for it.
+void CNCLParallelContext::WaitCompute(int ring_id) {
+  PADDLE_ENFORCE_GE(ring_id, 0, platform::errors::OutOfRange(
+                                    "ring id must >= 0, but got %d", ring_id));
+  PADDLE_ENFORCE_LT(ring_id, compute_events_.size(),
+                    platform::errors::OutOfRange(
+                        "ring id must < compute events size,"
+                        "but got ring id = %d, compute events size = %d",
+                        ring_id, compute_events_.size()));
+
+  auto compute_stream = static_cast<platform::MLUDeviceContext *>(
+                            platform::DeviceContextPool::Instance().Get(place_))
+                            ->stream();
+  auto comm_stream =
+      platform::CNCLCommContext::Instance().Get(ring_id, place_)->stream();
+  auto event = compute_events_[ring_id].get();
+
+  // compute_stream-->event-->comm_stream
+  PADDLE_ENFORCE_MLU_SUCCESS(cnrtPlaceNotifier(event, compute_stream));
+  PADDLE_ENFORCE_MLU_SUCCESS(cnrtQueueWaitNotifier(event, comm_stream, 0));
+}
+
+// Places an event on the communication stream of ring `ring_id` and makes the
+// compute stream wait for it.
+void CNCLParallelContext::WaitComm(int ring_id) {
+  PADDLE_ENFORCE_GE(ring_id, 0, platform::errors::OutOfRange(
+                                    "ring id must >= 0, but got %d", ring_id));
+  PADDLE_ENFORCE_LT(ring_id, comm_events_.size(),
+                    platform::errors::OutOfRange(
+                        "ring id must < comm events size,"
+                        "but got ring id = %d, comm events size = %d",
+                        ring_id, comm_events_.size()));
+
+  auto compute_stream = static_cast<platform::MLUDeviceContext *>(
+                            platform::DeviceContextPool::Instance().Get(place_))
+                            ->stream();
+  auto comm_stream =
+      platform::CNCLCommContext::Instance().Get(ring_id, place_)->stream();
+  auto event = comm_events_[ring_id].get();
+
+  // comm_stream-->event-->compute_stream
+  PADDLE_ENFORCE_MLU_SUCCESS(cnrtPlaceNotifier(event, comm_stream));
+  PADDLE_ENFORCE_MLU_SUCCESS(cnrtQueueWaitNotifier(event, compute_stream, 0));
+}
+
+// Waits on the compute device context of place_.
+void CNCLParallelContext::SynchronizeCompute() {
+  auto *compute_dev_ctx = static_cast<platform::MLUDeviceContext *>(
+      platform::DeviceContextPool::Instance().Get(place_));
+  compute_dev_ctx->Wait();
+}
+
+}  // namespace imperative
+}  // namespace paddle
+
+#endif
diff --git a/paddle/fluid/imperative/cncl_context.h b/paddle/fluid/imperative/cncl_context.h
new file mode 100644
index 00000000000..85f53319bfc
--- /dev/null
+++ b/paddle/fluid/imperative/cncl_context.h
@@ -0,0 +1,77 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#if defined(PADDLE_WITH_CNCL)
+#include <cncl.h>
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "paddle/fluid/imperative/parallel_context.h"
+#include "paddle/fluid/platform/device/mlu/mlu_resource_pool.h"
+
+namespace paddle {
+namespace framework {
+class Variable;
+}  // namespace framework
+}  // namespace paddle
+
+namespace paddle {
+namespace imperative {
+
+// ParallelContext implementation that uses CNCL communicators (see
+// cncl_context.cc) for imperative-mode collectives on MLU places.
+class CNCLParallelContext : public ParallelContext {
+ public:
+  explicit CNCLParallelContext(const ParallelStrategy& strategy,
+                               const platform::Place& place)
+      : ParallelContext(strategy, place) {}
+
+  ~CNCLParallelContext() override = default;
+
+  void BcastCNCLId(std::vector<cnclCliqueId>& cncl_ids, int root,  // NOLINT
+                   int server_fd);
+
+  void Init() override;
+
+  void InitWithRingID(int ring_id) override;
+
+  void AllReduceByStream(const framework::Variable& src,
+                         framework::Variable* dst, int ring_id,
+                         bool use_calc_stream) override;
+
+  void Broadcast(framework::Variable* src, int ring_id) override;
+
+  paddle::platform::DeviceContext* GetDeviceContext(int ring_id) override;
+
+  void WaitCompute(int ring_id) override;
+
+  void WaitComm(int ring_id) override;
+
+  void SynchronizeCompute() override;
+
+ private:
+  // used for comm wait compute, compute_stream-->event-->comm_stream[ring_id]
+  std::vector<std::shared_ptr<platform::MluEventObject>> compute_events_;
+
+  // used for compute wait comm, comm_stream[ring_id]-->event-->compute_stream
+  std::vector<std::shared_ptr<platform::MluEventObject>> comm_events_;
+};
+
+}  // namespace imperative
+}  // namespace paddle
+#endif
diff --git a/paddle/fluid/imperative/tests/CMakeLists.txt b/paddle/fluid/imperative/tests/CMakeLists.txt index 774bb9653e2..a9c81cb8779 100644 --- a/paddle/fluid/imperative/tests/CMakeLists.txt +++ b/paddle/fluid/imperative/tests/CMakeLists.txt @@ -9,6 +9,9 @@ else() if (WITH_XPU_BKCL) cc_test(bkcl_context_test SRCS bkcl_context_test.cc DEPS bkcl_context) endif() + if (WITH_CNCL) + cc_test(cncl_context_test SRCS cncl_context_test.cc DEPS cncl_context) + endif() endif(WIN32) diff --git
a/paddle/fluid/imperative/tests/cncl_context_test.cc b/paddle/fluid/imperative/tests/cncl_context_test.cc
new file mode 100644
index 00000000000..1d5ee8e7fc8
--- /dev/null
+++ b/paddle/fluid/imperative/tests/cncl_context_test.cc
@@ -0,0 +1,159 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include <chrono>  // NOLINT
+#include <cstdlib>
+#include <string>
+#include <thread>  // NOLINT
+#include <vector>
+
+#include "paddle/fluid/framework/tensor_util.h"
+#include "paddle/fluid/framework/variable.h"
+#include "paddle/fluid/imperative/cncl_context.h"
+#include "paddle/fluid/platform/gen_comm_id_helper.h"
+
+#include "gtest/gtest.h"
+
+namespace imperative = paddle::imperative;
+namespace platform = paddle::platform;
+namespace framework = paddle::framework;
+
+// Node1: FLAGS_selected_mlus=0 PADDLE_TRAINER_ID=0 ./cncl_context_test
+// Node2: FLAGS_selected_mlus=1 PADDLE_TRAINER_ID=1 ./cncl_context_test
+
+int nrings = 1;
+// Builds the two-trainer strategy used by both tests; `local_rank` selects
+// which of the two endpoints is the current one.
+imperative::ParallelStrategy GetStrategy(int local_rank) {
+  std::vector<std::string> eps = {"127.0.0.1:9866", "localhost:9867"};
+  imperative::ParallelStrategy strategy;
+  strategy.trainer_endpoints_ = eps;
+  strategy.current_endpoint_ = eps[local_rank];
+  strategy.nranks_ = 2;
+  strategy.local_rank_ = local_rank;
+  strategy.nrings_ = nrings;
+  return strategy;
+}
+
+#if defined(PADDLE_WITH_CNCL)
+// Broadcasts a tensor filled with `test_data` on rank 0 and checks that the
+// local copy holds the same values afterwards.
+void Broadcast(int local_rank, int device_id) {
+  int data_size = 4;
+  float test_data = 7;
+  const auto& place = platform::MLUPlace(device_id);
+  platform::MLUDeviceContext ctx(place);
+
+  imperative::CNCLParallelContext cpc(GetStrategy(local_rank), place);
+
+  // init
+  cpc.Init();
+
+  framework::Variable src_dev_var;
+  auto* src_dev_tensor = src_dev_var.GetMutable<framework::LoDTensor>();
+  src_dev_tensor->mutable_data<float>(phi::make_ddim({data_size}), place);
+
+  // fill data for rank 0 only
+  std::vector<float> src_vec;
+  if (local_rank == 0) {
+    for (int i = 0; i < data_size; ++i) {
+      src_vec.push_back(test_data);
+    }
+    framework::TensorFromVector(src_vec, ctx, src_dev_tensor);
+  }
+  ctx.Wait();
+
+  // call broadcast
+  cpc.Broadcast(&src_dev_var, 0);
+  std::this_thread::sleep_for(std::chrono::milliseconds(1000));
+
+  // check result
+  std::vector<float> dst_vec;
+  framework::TensorToVector(*src_dev_tensor, ctx, &dst_vec);
+  ctx.Wait();
+
+  for (int i = 0; i < data_size; ++i) {
+    EXPECT_EQ(dst_vec[i], test_data);
+  }
+}
+
+TEST(Broadcast, Run) {
+  if (platform::GetMLUDeviceCount() >= 2) {
+    // getenv() returns nullptr for an unset variable; fail with a message
+    // instead of handing nullptr to atoi().
+    const char* trainer_id = std::getenv("PADDLE_TRAINER_ID");
+    const char* selected_mlus = std::getenv("FLAGS_selected_mlus");
+    ASSERT_TRUE(trainer_id != nullptr) << "PADDLE_TRAINER_ID is not set";
+    ASSERT_TRUE(selected_mlus != nullptr) << "FLAGS_selected_mlus is not set";
+    Broadcast(std::atoi(trainer_id), std::atoi(selected_mlus));
+  }
+}
+
+// All-reduces a tensor filled with (1 + local_rank) and checks the sum over
+// both ranks, i.e. 1 + 2 = 3.
+void AllReduceByStream(int local_rank, int device_id) {
+  int data_size = 32;
+  const auto& place = platform::MLUPlace(device_id);
+  platform::MLUDeviceContext ctx(place);
+
+  imperative::CNCLParallelContext cpc(GetStrategy(local_rank), place);
+
+  // init
+  cpc.Init();
+
+  // input data
+  framework::Variable src_dev_var;
+  auto* src_dev_tensor = src_dev_var.GetMutable<framework::LoDTensor>();
+  src_dev_tensor->mutable_data<float>(phi::make_ddim({data_size}), place);
+
+  // fill input data
+  std::vector<float> src_vec;
+  for (int i = 0; i < data_size; ++i) {
+    src_vec.push_back(1.0 + local_rank);
+  }
+  framework::TensorFromVector(src_vec, ctx, src_dev_tensor);
+  ctx.Wait();
+
+  // output data
+  framework::Variable dst_dev_var;
+  auto* dst_dev_tensor = dst_dev_var.GetMutable<framework::LoDTensor>();
+  dst_dev_tensor->mutable_data<float>(phi::make_ddim({data_size}), place);
+
+  // call allreduce
+  cpc.AllReduceByStream(src_dev_var, &dst_dev_var, 0, false);
+  std::this_thread::sleep_for(std::chrono::milliseconds(1000));
+
+  // check result
+  std::vector<float> dst_vec;
+  framework::TensorToVector(*dst_dev_tensor, ctx, &dst_vec);
+  ctx.Wait();
+
+  EXPECT_EQ(dst_vec.size(), src_vec.size());
+  for (int i = 0; i < data_size; ++i) {
+    EXPECT_EQ(dst_vec[i], 3.0);
+  }
+}
+
+TEST(AllReduceByStream, Run) {
+  if (platform::GetMLUDeviceCount() >= 2) {
+    // getenv() returns nullptr for an unset variable; fail with a message
+    // instead of handing nullptr to atoi().
+    const char* trainer_id = std::getenv("PADDLE_TRAINER_ID");
+    const char* selected_mlus = std::getenv("FLAGS_selected_mlus");
+    ASSERT_TRUE(trainer_id != nullptr) << "PADDLE_TRAINER_ID is not set";
+    ASSERT_TRUE(selected_mlus != nullptr) << "FLAGS_selected_mlus is not set";
+    AllReduceByStream(std::atoi(trainer_id), std::atoi(selected_mlus));
+  }
+}
+#endif
diff --git a/paddle/fluid/platform/CMakeLists.txt b/paddle/fluid/platform/CMakeLists.txt index 478b71745e4..37709c953e1 100644 --- a/paddle/fluid/platform/CMakeLists.txt +++ b/paddle/fluid/platform/CMakeLists.txt @@ -148,6 +148,10 @@ if(WITH_ASCEND_CL) target_link_libraries(device_context npu_resource_pool) endif() +if(WITH_MLU) + target_link_libraries(device_context mlu_resource_pool) +endif() + if(WITH_CUSTOM_DEVICE) target_link_libraries(device_context custom_context) endif() diff --git a/paddle/fluid/platform/device/mlu/CMakeLists.txt b/paddle/fluid/platform/device/mlu/CMakeLists.txt index 724776bfad2..1f3a7670849 100644 --- a/paddle/fluid/platform/device/mlu/CMakeLists.txt +++ b/paddle/fluid/platform/device/mlu/CMakeLists.txt @@ -9,3 +9,4 @@ cc_library(mlu_stream SRCS mlu_stream.cc DEPS boost mlu_info stream_callback_man cc_library(mlu_device_context SRCS device_context.cc DEPS mlu_stream) cc_test(mlu_device_context_test SRCS device_context_test.cc DEPS mlu_device_context) cc_library(mlu_collective_helper SRCS mlu_collective_helper.cc DEPS mlu_stream mlu_info) +cc_library(mlu_resource_pool SRCS mlu_resource_pool.cc DEPS mlu_info) diff --git a/paddle/fluid/platform/device/mlu/mlu_resource_pool.cc b/paddle/fluid/platform/device/mlu/mlu_resource_pool.cc new file mode 100644 index 00000000000..fbe3eca1c4d --- /dev/null +++
b/paddle/fluid/platform/device/mlu/mlu_resource_pool.cc
@@ -0,0 +1,113 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#if defined(PADDLE_WITH_MLU)
+#include "paddle/fluid/platform/device/mlu/mlu_resource_pool.h"
+
+namespace paddle {
+namespace platform {
+
+// Builds one stream pool per MLU device reported by GetMLUDeviceCount(). The
+// creator/deleter of each pool select that pool's device before the call.
+// NOTE(review): the cnrtQueueCreate/cnrtQueueDestroy return codes are not
+// checked here, unlike the cnrt calls in cncl_context.cc - confirm intended.
+MluStreamResourcePool::MluStreamResourcePool() {
+  int dev_cnt = platform::GetMLUDeviceCount();
+  pool_.reserve(dev_cnt);
+  for (int dev_idx = 0; dev_idx < dev_cnt; ++dev_idx) {
+    auto creator = [dev_idx] {
+      platform::SetMLUDeviceId(dev_idx);
+      mluStream stream;
+      cnrtQueueCreate(&stream);
+      return stream;
+    };
+
+    auto deleter = [dev_idx](mluStream stream) {
+      platform::SetMLUDeviceId(dev_idx);
+      cnrtQueueDestroy(stream);
+    };
+
+    pool_.emplace_back(ResourcePool<MluStreamObject>::Create(creator, deleter));
+  }
+}
+
+// Returns the single MluStreamResourcePool instance (function-local static).
+MluStreamResourcePool& MluStreamResourcePool::Instance() {
+  static MluStreamResourcePool pool;
+  return pool;
+}
+
+// Returns a stream from the pool of device `dev_idx`; `dev_idx` must be in
+// [0, pool_.size()), i.e. below the device count seen at construction.
+std::shared_ptr<MluStreamObject> MluStreamResourcePool::New(int dev_idx) {
+  PADDLE_ENFORCE_GE(
+      dev_idx, 0,
+      platform::errors::InvalidArgument(
+          "The dev_idx should be not less than 0, but got %d.", dev_idx));
+  PADDLE_ENFORCE_LT(
+      dev_idx, pool_.size(),
+      platform::errors::OutOfRange(
+          "The dev_idx should be less than device count %d, but got %d.",
+          pool_.size(), dev_idx));
+  return pool_[dev_idx]->New();
+}
+
+// Builds one event pool per MLU device reported by GetMLUDeviceCount(). The
+// creator/deleter of each pool select that pool's device before the call.
+// NOTE(review): the cnrtNotifierCreate/cnrtNotifierDestroy return codes are
+// not checked here, unlike the cnrt calls in cncl_context.cc - confirm.
+MluEventResourcePool::MluEventResourcePool() {
+  int dev_cnt = platform::GetMLUDeviceCount();
+  pool_.reserve(dev_cnt);
+  for (int dev_idx = 0; dev_idx < dev_cnt; ++dev_idx) {
+    auto creator = [dev_idx] {
+      platform::SetMLUDeviceId(dev_idx);
+      mluEventHandle event;
+      cnrtNotifierCreate(&event);
+      return event;
+    };
+
+    auto deleter = [dev_idx](mluEventHandle event) {
+      platform::SetMLUDeviceId(dev_idx);
+      cnrtNotifierDestroy(event);
+    };
+
+    pool_.emplace_back(ResourcePool<MluEventObject>::Create(creator, deleter));
+  }
+}
+
+// Returns the single MluEventResourcePool instance (function-local static).
+MluEventResourcePool& MluEventResourcePool::Instance() {
+  static MluEventResourcePool pool;
+  return pool;
+}
+
+// Returns an event from the pool of device `dev_idx`; `dev_idx` must be in
+// [0, pool_.size()), i.e. below the device count seen at construction.
+std::shared_ptr<MluEventObject> MluEventResourcePool::New(int dev_idx) {
+  PADDLE_ENFORCE_GE(
+      dev_idx, 0,
+      platform::errors::InvalidArgument(
+          "The dev_idx should be not less than 0, but got %d.", dev_idx));
+  PADDLE_ENFORCE_LT(
+      dev_idx, pool_.size(),
+      platform::errors::OutOfRange(
+          "The dev_idx should be less than device count %d, but got %d.",
+          pool_.size(), dev_idx));
+  return pool_[dev_idx]->New();
+}
+
+}  // namespace platform
+}  // namespace paddle
+#endif
diff --git a/paddle/fluid/platform/device/mlu/mlu_resource_pool.h b/paddle/fluid/platform/device/mlu/mlu_resource_pool.h
new file mode 100644
index 00000000000..b0e2af7f024
--- /dev/null
+++ b/paddle/fluid/platform/device/mlu/mlu_resource_pool.h
@@ -0,0 +1,68 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#if defined(PADDLE_WITH_MLU)
+#include <memory>
+#include <type_traits>
+#include <vector>
+
+#include "paddle/fluid/platform/device/mlu/mlu_info.h"
+#include "paddle/fluid/platform/resource_pool.h"
+
+namespace paddle {
+namespace platform {
+
+// Pointee types of the mluStream / mluEventHandle handles, used as the element
+// types of the pools below.
+using MluStreamObject = std::remove_pointer<mluStream>::type;
+using MluEventObject = std::remove_pointer<mluEventHandle>::type;
+
+// Singleton holding one stream ResourcePool per MLU device.
+class MluStreamResourcePool {
+ public:
+  std::shared_ptr<MluStreamObject> New(int dev_idx);
+
+  static MluStreamResourcePool &Instance();
+
+ private:
+  MluStreamResourcePool();
+
+  DISABLE_COPY_AND_ASSIGN(MluStreamResourcePool);
+
+ private:
+  std::vector<std::shared_ptr<ResourcePool<MluStreamObject>>> pool_;
+};
+
+// Singleton holding one event (notifier) ResourcePool per MLU device.
+class MluEventResourcePool {
+ public:
+  std::shared_ptr<MluEventObject> New(int dev_idx);
+
+  static MluEventResourcePool &Instance();
+
+ private:
+  MluEventResourcePool();
+
+  DISABLE_COPY_AND_ASSIGN(MluEventResourcePool);
+
+ private:
+  std::vector<std::shared_ptr<ResourcePool<MluEventObject>>> pool_;
+};
+
+}  // namespace platform
+}  // namespace paddle
+
+#endif
diff --git a/paddle/fluid/pybind/CMakeLists.txt b/paddle/fluid/pybind/CMakeLists.txt index 26c35167f40..01b21d02ea0 100644 --- a/paddle/fluid/pybind/CMakeLists.txt +++ b/paddle/fluid/pybind/CMakeLists.txt @@ -37,6 +37,10 @@ if (WITH_ASCEND_CL) set(PYBIND_DEPS ${PYBIND_DEPS} heter_ccl_context) endif() +if (WITH_CNCL) + set(PYBIND_DEPS ${PYBIND_DEPS} cncl_context) +endif() + if(NOT WIN32) set(PYBIND_DEPS ${PYBIND_DEPS} data_loader) set(PYBIND_DEPS ${PYBIND_DEPS} mmap_allocator) @@ -134,6 +138,10 @@ if(WITH_PYTHON) list(APPEND OP_FUNCTION_GENERETOR_DEPS hccl_context) endif(WITH_ASCEND_CL) + if(WITH_CNCL) + list(APPEND OP_FUNCTION_GENERETOR_DEPS cncl_context) + endif(WITH_CNCL) + add_executable(op_function_generator op_function_generator.cc) target_link_libraries(op_function_generator ${OP_FUNCTION_GENERETOR_DEPS}) add_executable(eager_op_function_generator eager_op_function_generator.cc) diff --git a/paddle/fluid/pybind/imperative.cc b/paddle/fluid/pybind/imperative.cc index 387addda9ed..8c5ed2d1183 100644 --- a/paddle/fluid/pybind/imperative.cc +++
b/paddle/fluid/pybind/imperative.cc @@ -36,6 +36,7 @@ limitations under the License. */ #include "paddle/fluid/imperative/amp_auto_cast.h" #include "paddle/fluid/imperative/basic_engine.h" #include "paddle/fluid/imperative/bkcl_context.h" +#include "paddle/fluid/imperative/cncl_context.h" #include "paddle/fluid/imperative/data_loader.h" #include "paddle/fluid/imperative/gloo_context.h" #include "paddle/fluid/imperative/hccl_context.h" @@ -2559,6 +2560,18 @@ void BindImperative(py::module *m_ptr) { py::arg("ring_id")); #endif +#if defined(PADDLE_WITH_CNCL) + py::class_>( + m, "CNCLParallelContext") + .def(py::init()) + .def("init", [](imperative::CNCLParallelContext &self) { self.Init(); }) + .def("init_with_ring_id", + &imperative::CNCLParallelContext::InitWithRingID, + py::arg("ring_id")); +#endif + #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \ defined(PADDLE_WITH_XPU_BKCL) || defined(PADDLE_WITH_ASCEND_CL) py::class_