Unverified commit 6241913b authored by mhhhh1, committed by GitHub

[MLU] add cncl parallel context and mlu resource pool (#39803)

* [MLU] add cncl parallel context and mlu resource pool

* [MLU] fix the cncl_context_test
Parent b9675acc
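For orientation, a minimal usage sketch of the new parallel context (modeled on cncl_context_test.cc further down in this diff; the endpoints, rank count, and device id are illustrative assumptions, not part of this commit):

// Sketch: construct and initialize a CNCLParallelContext on one MLU.
// Assumes a 2-rank job; endpoints and device id are placeholders.
#include "paddle/fluid/imperative/cncl_context.h"

void InitCnclContext(int local_rank, int device_id) {
  namespace imperative = paddle::imperative;
  namespace platform = paddle::platform;

  imperative::ParallelStrategy strategy;
  strategy.trainer_endpoints_ = {"127.0.0.1:9866", "127.0.0.1:9867"};
  strategy.current_endpoint_ = strategy.trainer_endpoints_[local_rank];
  strategy.nranks_ = 2;
  strategy.local_rank_ = local_rank;
  strategy.nrings_ = 1;

  platform::MLUPlace place(device_id);
  imperative::CNCLParallelContext cpc(strategy, place);
  cpc.Init();  // rank 0 generates cnclCliqueIds and broadcasts them over TCP
}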
......
@@ -31,6 +31,9 @@ if(NOT WIN32)
cc_library(hccl_context SRCS hccl_context.cc DEPS collective_helper device_context tensor var_type_traits)
cc_library(reducer SRCS reducer.cc DEPS layer)
endif()
if(WITH_CNCL)
cc_library(cncl_context SRCS cncl_context.cc DEPS collective_helper device_context tensor var_type_traits)
endif()
if(WITH_NCCL OR WITH_RCCL OR WITH_XPU_BKCL OR WITH_ASCEND_CL)
cc_library(heter_ccl_context SRCS heter_ccl_context.cc DEPS collective_helper device_context tensor var_type_traits)
endif()
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#if defined(PADDLE_WITH_CNCL)
#include "paddle/fluid/imperative/cncl_context.h"
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/gen_comm_id_helper.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/device/mlu/cncl_helper.h"
#include "paddle/fluid/platform/device/mlu/mlu_info.h"
namespace paddle {
namespace framework {
class Variable;
} // namespace framework
} // namespace paddle
namespace paddle {
namespace imperative {
static void AllReduce(const framework::Tensor &src, framework::Tensor *dst,
const mluStream stream, const platform::CNCLComm *comm) {
const auto &place = src.place();
PADDLE_ENFORCE_EQ(
platform::is_mlu_place(place), true,
platform::errors::Unimplemented(
"Imperative mode does not support multi-CPU training yet."));
const void *src_ptr = src.data();
dst->Resize(src.dims());
auto *dst_ptr = dst->mutable_data(src.place(), src.dtype());
auto cncl_dtype =
platform::ToCNCLDataType(framework::TransToProtoVarType(src.dtype()));
PADDLE_ENFORCE_MLU_SUCCESS(cnclAllReduce(src_ptr, dst_ptr, src.numel(),
cncl_dtype, cnclSum, comm->comm(),
stream));
}
void CNCLParallelContext::BcastCNCLId(
std::vector<cnclCliqueId> &cncl_ids, // NOLINT
int root, int server_fd) {
if (strategy_.local_rank_ == root) {
std::vector<std::string> other_trainers;
for (auto &ep : strategy_.trainer_endpoints_) {
if (ep != strategy_.current_endpoint_) {
other_trainers.push_back(ep);
}
}
platform::SendBroadCastCommID(other_trainers, &cncl_ids);
} else {
platform::RecvBroadCastCommID(server_fd, strategy_.current_endpoint_,
&cncl_ids);
}
}
void CNCLParallelContext::Init() {
int server_fd = -1;
std::vector<cnclCliqueId> cncl_ids;
cncl_ids.resize(strategy_.nrings_);
if (strategy_.local_rank_ == 0) {
// generate the unique cnclCliqueId on the root worker
for (size_t i = 0; i < cncl_ids.size(); ++i) {
PADDLE_ENFORCE_MLU_SUCCESS(cnclGetCliqueId(&cncl_ids[i]));
}
} else {
server_fd = platform::SocketServer::GetInstance(strategy_.current_endpoint_)
.socket();
}
BcastCNCLId(cncl_ids, 0, server_fd);
int mlu_id = place_.device;
for (int ring_id = 0; ring_id < strategy_.nrings_; ++ring_id) {
VLOG(0) << "init cncl context nranks: " << strategy_.nranks_
<< " local rank: " << strategy_.local_rank_ << " mlu id: " << mlu_id
<< " ring id: " << ring_id;
// CreateComm assigns the cncl communicator to the MLUDeviceContext for this ring_id
platform::CNCLCommContext::Instance().CreateComm(
&cncl_ids[ring_id], strategy_.nranks_, strategy_.local_rank_, mlu_id,
ring_id);
compute_events_.emplace_back(
platform::MluEventResourcePool::Instance().New(place_.device));
comm_events_.emplace_back(
platform::MluEventResourcePool::Instance().New(place_.device));
}
}
void CNCLParallelContext::InitWithRingID(int ring_id) {
int server_fd = -1;
std::vector<cnclCliqueId> cncl_ids;
cncl_ids.resize(1);
if (strategy_.local_rank_ == 0) {
// generate the unique cnclCliqueId on the root worker
PADDLE_ENFORCE_MLU_SUCCESS(cnclGetCliqueId(&cncl_ids[0]));
} else {
server_fd = platform::SocketServer::GetInstance(strategy_.current_endpoint_)
.socket();
}
BcastCNCLId(cncl_ids, 0, server_fd);
int mlu_id = place_.device;
VLOG(0) << "init cncl context nranks: " << strategy_.nranks_
<< " local rank: " << strategy_.local_rank_ << " mlu id: " << mlu_id
<< " ring id: " << ring_id;
// CreateComm assigns the cncl communicator to the MLUDeviceContext for this ring_id
platform::CNCLCommContext::Instance().CreateComm(
&cncl_ids[0], strategy_.nranks_, strategy_.local_rank_, mlu_id, ring_id);
compute_events_.emplace_back(
platform::MluEventResourcePool::Instance().New(place_.device));
comm_events_.emplace_back(
platform::MluEventResourcePool::Instance().New(place_.device));
}
void CNCLParallelContext::AllReduceByStream(const framework::Variable &src,
framework::Variable *dst,
int ring_id, bool use_calc_stream) {
PADDLE_ENFORCE_EQ(
platform::is_mlu_place(place_), true,
platform::errors::Unimplemented(
"Dynamic graph mode does not support multi-CPU training yet."));
auto *dev_ctx = static_cast<platform::MLUDeviceContext *>(
platform::DeviceContextPool::Instance().Get(place_));
platform::CNCLComm *comm =
platform::CNCLCommContext::Instance().Get(ring_id, place_);
mluStream stream = (use_calc_stream ? dev_ctx->stream() : comm->stream());
if (src.IsType<framework::LoDTensor>()) {
if (!dst->IsType<framework::LoDTensor>()) {
dst->Clear();
}
AllReduce(src.Get<framework::LoDTensor>(),
dst->GetMutable<framework::LoDTensor>(), stream, comm);
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"Unsupported variable type %s for imperative allreduce, only "
"LoDTensor is supported.",
platform::demangle(framework::ToTypeName(src.Type()))));
}
}
void CNCLParallelContext::Broadcast(framework::Variable *src, int ring_id) {
VLOG(3) << "/// DEBUG /// start inter broadcast with ring_id: " << ring_id;
framework::Tensor *src_tensor = src->GetMutable<framework::LoDTensor>();
const auto &place = src_tensor->place();
platform::CNCLComm *comm =
platform::CNCLCommContext::Instance().Get(ring_id, place);
mluStream stream = comm->stream();
void *src_ptr = src_tensor->data();
auto cncl_dtype = platform::ToCNCLDataType(
framework::TransToProtoVarType(src_tensor->dtype()));
PADDLE_ENFORCE_MLU_SUCCESS(cnclBcast(src_ptr, src_tensor->numel(), cncl_dtype,
0, comm->comm(), stream));
}
paddle::platform::DeviceContext *CNCLParallelContext::GetDeviceContext(
int ring_id) {
return static_cast<platform::DeviceContext *>(
platform::CNCLCommContext::Instance()
.Get(ring_id, place_)
->dev_context());
}
void CNCLParallelContext::WaitCompute(int ring_id) {
PADDLE_ENFORCE_GE(ring_id, 0, platform::errors::OutOfRange(
"ring id must >= 0, but got %d", ring_id));
PADDLE_ENFORCE_LT(ring_id, compute_events_.size(),
platform::errors::OutOfRange(
"ring id must < compute events size,"
"but got ring id = %d, compute events size = %d",
ring_id, compute_events_.size()));
auto compute_stream = static_cast<platform::MLUDeviceContext *>(
platform::DeviceContextPool::Instance().Get(place_))
->stream();
auto comm_stream =
platform::CNCLCommContext::Instance().Get(ring_id, place_)->stream();
auto event = compute_events_[ring_id].get();
// compute_stream-->event-->comm_stream
PADDLE_ENFORCE_MLU_SUCCESS(cnrtPlaceNotifier(event, compute_stream));
PADDLE_ENFORCE_MLU_SUCCESS(cnrtQueueWaitNotifier(event, comm_stream, 0));
}
void CNCLParallelContext::WaitComm(int ring_id) {
PADDLE_ENFORCE_GE(ring_id, 0, platform::errors::OutOfRange(
"ring id must >= 0, but got %d", ring_id));
PADDLE_ENFORCE_LT(ring_id, comm_events_.size(),
platform::errors::OutOfRange(
"ring id must < comm events size,"
"but got ring id = %d, comm events size = %d",
ring_id, comm_events_.size()));
auto compute_stream = static_cast<platform::MLUDeviceContext *>(
platform::DeviceContextPool::Instance().Get(place_))
->stream();
auto comm_stream =
platform::CNCLCommContext::Instance().Get(ring_id, place_)->stream();
auto event = comm_events_[ring_id].get();
// comm_stream-->event-->compute_stream
PADDLE_ENFORCE_MLU_SUCCESS(cnrtPlaceNotifier(event, comm_stream));
PADDLE_ENFORCE_MLU_SUCCESS(cnrtQueueWaitNotifier(event, compute_stream, 0));
}
void CNCLParallelContext::SynchronizeCompute() {
auto *compute_dev_ctx = static_cast<platform::MLUDeviceContext *>(
platform::DeviceContextPool::Instance().Get(place_));
compute_dev_ctx->Wait();
}
} // namespace imperative
} // namespace paddle
#endif
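The WaitCompute/WaitComm pair above expresses a one-way ordering between the compute stream and a ring's comm stream through a notifier, avoiding any host-side blocking. A stripped-down sketch of that pattern using the same cnrt calls (the producer, consumer, and event handles here are hypothetical variables assumed to be valid):

// Record an event on the producer stream, then make the consumer stream
// wait on it: producer --> event --> consumer, with no host synchronization.
static void MakeConsumerWait(mluStream producer, mluStream consumer,
                             mluEventHandle event) {
  PADDLE_ENFORCE_MLU_SUCCESS(cnrtPlaceNotifier(event, producer));
  PADDLE_ENFORCE_MLU_SUCCESS(cnrtQueueWaitNotifier(event, consumer, 0));
}

WaitCompute(ring_id) instantiates this with the compute stream as producer; WaitComm(ring_id) reverses the roles.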
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#if defined(PADDLE_WITH_CNCL)
#include <cncl.h>
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/imperative/parallel_context.h"
#include "paddle/fluid/platform/device/mlu/mlu_resource_pool.h"
namespace paddle {
namespace framework {
class Variable;
} // namespace framework
} // namespace paddle
namespace paddle {
namespace imperative {
class CNCLParallelContext : public ParallelContext {
public:
explicit CNCLParallelContext(const ParallelStrategy& strategy,
const platform::Place& place)
: ParallelContext(strategy, place) {}
~CNCLParallelContext() override = default;
void BcastCNCLId(std::vector<cnclCliqueId>& cncl_ids, int root, // NOLINT
int server_fd);
void Init() override;
void InitWithRingID(int ring_id) override;
void AllReduceByStream(const framework::Variable& src,
framework::Variable* dst, int ring_id,
bool use_calc_stream) override;
void Broadcast(framework::Variable* src, int ring_id) override;
paddle::platform::DeviceContext* GetDeviceContext(int ring_id) override;
void WaitCompute(int ring_id) override;
void WaitComm(int ring_id) override;
void SynchronizeCompute() override;
private:
// used for comm wait compute, compute_stream-->event-->comm_stream[ring_id]
std::vector<std::shared_ptr<platform::MluEventObject>> compute_events_;
// used for compute wait comm, comm_stream[ring_id]-->event-->compute_stream
std::vector<std::shared_ptr<platform::MluEventObject>> comm_events_;
};
} // namespace imperative
} // namespace paddle
#endif
......
@@ -9,6 +9,9 @@ else()
if (WITH_XPU_BKCL)
cc_test(bkcl_context_test SRCS bkcl_context_test.cc DEPS bkcl_context)
endif()
if (WITH_CNCL)
cc_test(cncl_context_test SRCS cncl_context_test.cc DEPS cncl_context)
endif()
endif(WIN32)
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <thread> // NOLINT
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/imperative/cncl_context.h"
#include "paddle/fluid/platform/gen_comm_id_helper.h"
#include "gtest/gtest.h"
namespace imperative = paddle::imperative;
namespace platform = paddle::platform;
namespace framework = paddle::framework;
// Node1: FLAGS_selected_mlus=0 PADDLE_TRAINER_ID=0 ./cncl_context_test
// Node2: FLAGS_selected_mlus=1 PADDLE_TRAINER_ID=1 ./cncl_context_test
int nrings = 1;
imperative::ParallelStrategy GetStrategy(int local_rank) {
std::vector<std::string> eps = {"127.0.0.1:9866", "localhost:9867"};
imperative::ParallelStrategy strategy;
strategy.trainer_endpoints_ = eps;
strategy.current_endpoint_ = eps[local_rank];
strategy.nranks_ = 2;
strategy.local_rank_ = local_rank;
strategy.nrings_ = nrings;
return strategy;
}
#if defined(PADDLE_WITH_CNCL)
void Broadcast(int local_rank, int device_id) {
int data_size = 4;
float test_data = 7;
const auto& place = platform::MLUPlace(device_id);
platform::MLUDeviceContext ctx(place);
imperative::CNCLParallelContext cpc(GetStrategy(local_rank), place);
// init
cpc.Init();
framework::Variable* src_dev_var(new framework::Variable());
auto* src_dev_tensor = src_dev_var->GetMutable<framework::LoDTensor>();
src_dev_tensor->mutable_data<float>(phi::make_ddim({data_size}), place);
// fill data for rank 0 only
std::vector<float> src_vec;
if (local_rank == 0) {
for (int i = 0; i < data_size; ++i) {
src_vec.push_back(test_data);
}
framework::TensorFromVector(src_vec, ctx, src_dev_tensor);
}
ctx.Wait();
// call broadcast
cpc.Broadcast(src_dev_var, 0);
std::this_thread::sleep_for(std::chrono::milliseconds(1000));
// check result
std::vector<float> dst_vec;
framework::TensorToVector(*src_dev_tensor, ctx, &dst_vec);
ctx.Wait();
for (int i = 0; i < data_size; ++i) {
EXPECT_EQ(dst_vec[i], test_data);
}
}
TEST(Broadcast, Run) {
if (platform::GetMLUDeviceCount() >= 2) {
int local_rank = atoi(getenv("PADDLE_TRAINER_ID"));
int device_id = atoi(getenv("FLAGS_selected_mlus"));
Broadcast(local_rank, device_id);
}
}
void AllReduceByStream(int local_rank, int device_id) {
int data_size = 32;
const auto& place = platform::MLUPlace(device_id);
platform::MLUDeviceContext ctx(place);
imperative::CNCLParallelContext cpc(GetStrategy(local_rank), place);
// init
cpc.Init();
// input data
framework::Variable* src_dev_var(new framework::Variable());
auto* src_dev_tensor = src_dev_var->GetMutable<framework::LoDTensor>();
src_dev_tensor->mutable_data<float>(phi::make_ddim({data_size}), place);
// fill input data
std::vector<float> src_vec;
for (int i = 0; i < data_size; ++i) {
src_vec.push_back(1.0 + local_rank);
}
framework::TensorFromVector(src_vec, ctx, src_dev_tensor);
ctx.Wait();
// output data
framework::Variable* dst_dev_var(new framework::Variable());
auto* dst_dev_tensor = dst_dev_var->GetMutable<framework::LoDTensor>();
dst_dev_tensor->mutable_data<float>(phi::make_ddim({data_size}), place);
// call allreduce
cpc.AllReduceByStream(*src_dev_var, dst_dev_var, 0, false);
std::this_thread::sleep_for(std::chrono::milliseconds(1000));
// check result
std::vector<float> dst_vec;
framework::TensorToVector(*dst_dev_tensor, ctx, &dst_vec);
ctx.Wait();
EXPECT_EQ(dst_vec.size(), src_vec.size());
for (int i = 0; i < data_size; ++i) {
EXPECT_EQ(dst_vec[i], 3.0);
}
}
TEST(AllReduceByStream, Run) {
if (platform::GetMLUDeviceCount() >= 2) {
int local_rank = atoi(getenv("PADDLE_TRAINER_ID"));
int device_id = atoi(getenv("FLAGS_selected_mlus"));
AllReduceByStream(local_rank, device_id);
}
}
#endif
......
@@ -148,6 +148,10 @@ if(WITH_ASCEND_CL)
target_link_libraries(device_context npu_resource_pool)
endif()
if(WITH_MLU)
target_link_libraries(device_context mlu_resource_pool)
endif()
if(WITH_CUSTOM_DEVICE)
target_link_libraries(device_context custom_context)
endif()
......
......
@@ -9,3 +9,4 @@ cc_library(mlu_stream SRCS mlu_stream.cc DEPS boost mlu_info stream_callback_man
cc_library(mlu_device_context SRCS device_context.cc DEPS mlu_stream)
cc_test(mlu_device_context_test SRCS device_context_test.cc DEPS mlu_device_context)
cc_library(mlu_collective_helper SRCS mlu_collective_helper.cc DEPS mlu_stream mlu_info)
cc_library(mlu_resource_pool SRCS mlu_resource_pool.cc DEPS mlu_info)
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#if defined(PADDLE_WITH_MLU)
#include "paddle/fluid/platform/device/mlu/mlu_resource_pool.h"
namespace paddle {
namespace platform {
MluStreamResourcePool::MluStreamResourcePool() {
int dev_cnt = platform::GetMLUDeviceCount();
pool_.reserve(dev_cnt);
for (int dev_idx = 0; dev_idx < dev_cnt; ++dev_idx) {
auto creator = [dev_idx] {
platform::SetMLUDeviceId(dev_idx);
mluStream stream;
cnrtQueueCreate(&stream);
return stream;
};
auto deleter = [dev_idx](mluStream stream) {
platform::SetMLUDeviceId(dev_idx);
cnrtQueueDestroy(stream);
};
pool_.emplace_back(ResourcePool<MluStreamObject>::Create(creator, deleter));
}
}
MluStreamResourcePool& MluStreamResourcePool::Instance() {
static MluStreamResourcePool pool;
return pool;
}
std::shared_ptr<MluStreamObject> MluStreamResourcePool::New(int dev_idx) {
PADDLE_ENFORCE_GE(
dev_idx, 0,
platform::errors::InvalidArgument(
"The dev_idx should be not less than 0, but got %d.", dev_idx));
PADDLE_ENFORCE_LT(
dev_idx, pool_.size(),
platform::errors::OutOfRange(
"The dev_idx should be less than device count %d, but got %d.",
pool_.size(), dev_idx));
return pool_[dev_idx]->New();
}
MluEventResourcePool::MluEventResourcePool() {
int dev_cnt = platform::GetMLUDeviceCount();
pool_.reserve(dev_cnt);
for (int dev_idx = 0; dev_idx < dev_cnt; ++dev_idx) {
auto creator = [dev_idx] {
platform::SetMLUDeviceId(dev_idx);
mluEventHandle event;
cnrtNotifierCreate(&event);
return event;
};
auto deleter = [dev_idx](mluEventHandle event) {
platform::SetMLUDeviceId(dev_idx);
cnrtNotifierDestroy(event);
};
pool_.emplace_back(ResourcePool<MluEventObject>::Create(creator, deleter));
}
}
MluEventResourcePool& MluEventResourcePool::Instance() {
static MluEventResourcePool pool;
return pool;
}
std::shared_ptr<MluEventObject> MluEventResourcePool::New(int dev_idx) {
PADDLE_ENFORCE_GE(
dev_idx, 0,
platform::errors::InvalidArgument(
"The dev_idx should be not less than 0, but got %d.", dev_idx));
PADDLE_ENFORCE_LT(
dev_idx, pool_.size(),
platform::errors::OutOfRange(
"The dev_idx should be less than device count %d, but got %d.",
pool_.size(), dev_idx));
return pool_[dev_idx]->New();
}
} // namespace platform
} // namespace paddle
#endif
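As an illustration of the pool API declared in mlu_resource_pool.h below, borrowing a stream and an event might look like this. This is a sketch only; it assumes the shared_ptr deleter recycles the handle back into the per-device pool (the usual ResourcePool contract) rather than destroying it:

#include "paddle/fluid/platform/device/mlu/mlu_resource_pool.h"

void UsePooledMluResources() {
  namespace platform = paddle::platform;
  // Borrow from the per-device pools for MLU device 0.
  std::shared_ptr<platform::MluStreamObject> stream =
      platform::MluStreamResourcePool::Instance().New(0);
  std::shared_ptr<platform::MluEventObject> event =
      platform::MluEventResourcePool::Instance().New(0);
  // The raw handles are usable with cnrt* APIs while the shared_ptrs live.
  mluStream raw_stream = stream.get();
  mluEventHandle raw_event = event.get();
  (void)raw_stream;
  (void)raw_event;
}  // handles are assumed to be returned to the pool here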
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#if defined(PADDLE_WITH_MLU)
#include <memory>
#include <type_traits>
#include <vector>
#include "paddle/fluid/platform/device/mlu/mlu_info.h"
#include "paddle/fluid/platform/resource_pool.h"
namespace paddle {
namespace platform {
using MluStreamObject = std::remove_pointer<mluStream>::type;
using MluEventObject = std::remove_pointer<mluEventHandle>::type;
class MluStreamResourcePool {
public:
std::shared_ptr<MluStreamObject> New(int dev_idx);
static MluStreamResourcePool &Instance();
private:
MluStreamResourcePool();
DISABLE_COPY_AND_ASSIGN(MluStreamResourcePool);
private:
std::vector<std::shared_ptr<ResourcePool<MluStreamObject>>> pool_;
};
class MluEventResourcePool {
public:
std::shared_ptr<MluEventObject> New(int dev_idx);
static MluEventResourcePool &Instance();
private:
MluEventResourcePool();
DISABLE_COPY_AND_ASSIGN(MluEventResourcePool);
private:
std::vector<std::shared_ptr<ResourcePool<MluEventObject>>> pool_;
};
} // namespace platform
} // namespace paddle
#endif
......
@@ -37,6 +37,10 @@ if (WITH_ASCEND_CL)
set(PYBIND_DEPS ${PYBIND_DEPS} heter_ccl_context)
endif()
if (WITH_CNCL)
set(PYBIND_DEPS ${PYBIND_DEPS} cncl_context)
endif()
if(NOT WIN32)
set(PYBIND_DEPS ${PYBIND_DEPS} data_loader)
set(PYBIND_DEPS ${PYBIND_DEPS} mmap_allocator)
......
@@ -134,6 +138,10 @@ if(WITH_PYTHON)
list(APPEND OP_FUNCTION_GENERETOR_DEPS hccl_context)
endif(WITH_ASCEND_CL)
if(WITH_CNCL)
list(APPEND OP_FUNCTION_GENERETOR_DEPS cncl_context)
endif(WITH_CNCL)
add_executable(op_function_generator op_function_generator.cc)
target_link_libraries(op_function_generator ${OP_FUNCTION_GENERETOR_DEPS})
add_executable(eager_op_function_generator eager_op_function_generator.cc)
......
......
@@ -36,6 +36,7 @@ limitations under the License. */
#include "paddle/fluid/imperative/amp_auto_cast.h"
#include "paddle/fluid/imperative/basic_engine.h"
#include "paddle/fluid/imperative/bkcl_context.h"
#include "paddle/fluid/imperative/cncl_context.h"
#include "paddle/fluid/imperative/data_loader.h"
#include "paddle/fluid/imperative/gloo_context.h"
#include "paddle/fluid/imperative/hccl_context.h"
......
@@ -2559,6 +2560,18 @@ void BindImperative(py::module *m_ptr) {
py::arg("ring_id"));
#endif
#if defined(PADDLE_WITH_CNCL)
py::class_<imperative::CNCLParallelContext, imperative::ParallelContext,
std::shared_ptr<imperative::CNCLParallelContext>>(
m, "CNCLParallelContext")
.def(py::init<const imperative::ParallelStrategy &,
const platform::MLUPlace &>())
.def("init", [](imperative::CNCLParallelContext &self) { self.Init(); })
.def("init_with_ring_id",
&imperative::CNCLParallelContext::InitWithRingID,
py::arg("ring_id"));
#endif
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \
defined(PADDLE_WITH_XPU_BKCL) || defined(PADDLE_WITH_ASCEND_CL)
py::class_<imperative::HeterParallelContext, imperative::ParallelContext,
......