Unverified · Commit e04f8d4a authored by ronnywang, committed by GitHub

[CustomDevice] add inference MP support, PART2 (#53701)

Parent cc9aedaf
...@@ -25,7 +25,7 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \
    defined(PADDLE_WITH_XPU_BKCL) || defined(PADDLE_WITH_CUSTOM_DEVICE)
#include "paddle/fluid/platform/collective_helper.h"
#endif
...@@ -48,6 +48,29 @@ class CCommInitOp : public framework::OperatorBase {
  void RunImpl(const framework::Scope& scope,
               const platform::Place& place) const override {
if (platform::is_custom_place(place)) {
#if defined(PADDLE_WITH_CUSTOM_DEVICE)
auto var = scope.FindVar(Input("X"));
PADDLE_ENFORCE_NOT_NULL(
          var, platform::errors::InvalidArgument("Input cannot be empty."));
phi::ccl::CCLRootId* comm_id = var->GetMutable<phi::ccl::CCLRootId>();
int nranks = Attr<int>("nranks");
int rid = Attr<int>("ring_id");
int device_id = place.device;
if (Attr<int>("device_id") >= 0) {
device_id = Attr<int>("device_id");
}
int rank_id = Attr<int>("rank");
platform::XCCLCommContext::Instance(place.GetDeviceType())
.CreateComm(comm_id, nranks, rank_id, device_id, rid);
#else
PADDLE_THROW(platform::errors::PreconditionNotMet(
"PaddlePaddle should compile with custom device."));
#endif
} else {
      // TODO(wangxi): Put this in the unified header file
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
      using UniqueId = ncclUniqueId;
...@@ -86,6 +109,7 @@ class CCommInitOp : public framework::OperatorBase {
          comm_id, nranks, rank_id, device_id, rid);
#endif
    }
  }
};
class CCommInitOpMaker : public framework::OpProtoAndCheckerMaker {
...
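The custom-device branch above reads the CCLRootId from input X plus the nranks, rank, ring_id, and device_id attributes, then creates the communicator through XCCLCommContext. Below is a minimal, hypothetical usage sketch (not part of this commit) of how c_comm_init is typically appended to a startup program; the variable name, attribute values, and endpoints are placeholders mirroring the existing NCCL/BKCL wiring:

# Hypothetical usage sketch (not part of this commit): wiring c_comm_init in a
# startup program for a custom device.  Names and values are placeholders.
import paddle
from paddle.framework import core

paddle.enable_static()
block = paddle.static.default_startup_program().global_block()

# The unique id produced by c_gen_xccl_id (see the sketch after the next file)
# is stored in a RAW variable and consumed here as input X.
xccl_id_var = block.create_var(
    name="xccl_id_0", persistable=True, type=core.VarDesc.VarType.RAW)

block.append_op(
    type="c_comm_init",
    inputs={"X": xccl_id_var},
    outputs={},
    attrs={
        "nranks": 2,      # number of ranks in this ring
        "rank": 0,        # rank of the current trainer
        "ring_id": 0,     # communicator ring to create
        "device_id": -1,  # -1 keeps the device id taken from the place
    },
)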
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <string>
#include "paddle/fluid/framework/op_proto_maker.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/var_type_traits.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/gen_comm_id_helper.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/phi/backends/device_manager.h"
namespace paddle {
namespace operators {
#ifdef PADDLE_WITH_CUSTOM_DEVICE
static void CopyXCCLIDToVar(const std::vector<phi::ccl::CCLRootId>& xccl_ids,
std::function<std::string(size_t)> func,
const framework::Scope& scope) {
for (size_t i = 0; i < xccl_ids.size(); ++i) {
std::string var_name = func(i);
auto var = scope.FindVar(var_name);
PADDLE_ENFORCE_NOT_NULL(
var,
platform::errors::NotFound("Variable with name %s is not found",
var_name.c_str()));
auto xccl_id = var->GetMutable<phi::ccl::CCLRootId>();
*xccl_id = xccl_ids[i];
}
}
class CGenXCCLIdOp : public framework::OperatorBase {
public:
CGenXCCLIdOp(const std::string& type,
const framework::VariableNameMap& inputs,
const framework::VariableNameMap& outputs,
const framework::AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
void RunImpl(const framework::Scope& scope,
const platform::Place& dev_place) const override {
int rank = Attr<int>("rank");
int ring_id = Attr<int>("ring_id");
std::function<std::string(size_t)> func = [&](size_t i) -> std::string {
return Output("Out");
};
std::string endpoint = Attr<std::string>("endpoint");
int server_fd = platform::SocketServer::GetInstance(endpoint).socket();
std::vector<phi::ccl::CCLRootId> xccl_ids;
xccl_ids.resize(1);
if (rank == 0) {
for (size_t i = 0; i < xccl_ids.size(); ++i) {
phi::DeviceManager::CCLGetUniqueId(dev_place.GetDeviceType(),
&xccl_ids[i]);
}
std::vector<std::string> endpoint_list =
Attr<std::vector<std::string>>("other_endpoints");
platform::SendBroadCastCommID(endpoint_list, &xccl_ids, ring_id);
} else {
platform::RecvBroadCastCommID(server_fd, endpoint, &xccl_ids, ring_id);
}
CopyXCCLIDToVar(xccl_ids, func, scope);
}
};
#else
class CGenXCCLIdOp : public framework::OperatorBase {
public:
CGenXCCLIdOp(const std::string& type,
const framework::VariableNameMap& inputs,
const framework::VariableNameMap& outputs,
const framework::AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
void RunImpl(const framework::Scope& scope,
const platform::Place& dev_place) const override {}
};
#endif
class CGenXCCLIdOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddOutput("Out", "Raw variable contains a XCCL UniqueId instaces.");
AddComment(R"DOC(
CGenXCCLId operator
For trainer 0: generate a new UniqueId and send it to all the other trainers.
For trainers 1~n: start a gRPC server to get the UniqueId; once received, stop the server.
)DOC");
AddAttr<std::string>("endpoint",
"(string), e.g. 127.0.0.1:6175 "
"current listen endpoint");
AddAttr<std::vector<std::string>>(
"other_endpoints",
"['trainer1_ip:port', 'trainer2_ip:port', ...] "
"list of other trainer endpoints")
.SetDefault({});
AddAttr<int>("rank",
"(int default 0) "
"The rank of the trainer in distributed training.")
.SetDefault(0);
AddAttr<int>("ring_id", "(int default 0) user specified ring id")
.SetDefault(0);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(c_gen_xccl_id, ops::CGenXCCLIdOp, ops::CGenXCCLIdOpMaker);
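The new c_gen_xccl_id operator implements the root-id handshake: rank 0 calls phi::DeviceManager::CCLGetUniqueId and broadcasts the id to the other endpoints, while the remaining ranks receive it through the socket server. A hedged Python sketch (not part of this commit) of appending this op in front of the c_comm_init step shown earlier; endpoints and variable names are placeholders:

# Hypothetical usage sketch (not part of this commit): rank 0 generates the
# XCCL unique id and broadcasts it; other ranks receive it over the socket
# server.  Endpoints and names are placeholders.
import paddle
from paddle.framework import core

paddle.enable_static()
block = paddle.static.default_startup_program().global_block()

xccl_id_var = block.create_var(
    name="xccl_id_0", persistable=True, type=core.VarDesc.VarType.RAW)

block.append_op(
    type="c_gen_xccl_id",
    inputs={},
    outputs={"Out": xccl_id_var},
    attrs={
        "rank": 0,                              # this trainer's rank
        "endpoint": "127.0.0.1:6175",           # endpoint this trainer listens on
        "other_endpoints": ["127.0.0.1:6176"],  # the remaining trainers
        "ring_id": 0,
    },
)
# A c_comm_init op consuming xccl_id_var (see the earlier sketch) would follow.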
...@@ -19,6 +19,7 @@ limitations under the License. */
#include "paddle/fluid/operators/load_combine_op.h"
#include "paddle/fluid/operators/run_program_op.h"
#include "paddle/fluid/operators/save_combine_op.h"
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/phi/api/backward/backward_api.h"
#include "paddle/phi/api/include/api.h"
#include "paddle/phi/backends/device_manager.h"
...@@ -497,6 +498,131 @@ class CSyncCalcStreamCustomDeviceKernel : public framework::OpKernel<T> {
  }
};
template <typename DeviceContext, typename T, phi::ccl::CCLReduceOp red_type>
class CAllReduceOpCustomDeviceKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto in = ctx.Input<phi::DenseTensor>("X");
auto out = ctx.Output<phi::DenseTensor>("Out");
int rid = ctx.Attr<int>("ring_id");
auto place = ctx.GetPlace();
auto dtype = phi::ccl::ToCCLDataType(in->dtype());
int64_t numel = in->numel();
const void* sendbuff = in->data<T>();
out->Resize(in->dims());
void* recvbuff = ctx.device_context().Alloc<T>(out);
auto map = paddle::distributed::ProcessGroupMapFromGid::getInstance();
if (map->has(rid)) {
// Use ProcessGroup
paddle::distributed::ProcessGroup* pg = map->get(rid);
std::vector<phi::DenseTensor> in_tensor;
std::vector<phi::DenseTensor> out_tensor;
in_tensor.push_back(*in);
out_tensor.push_back(*out);
paddle::distributed::AllreduceOptions opts;
switch (red_type) {
case phi::ccl::CCLReduceOp::SUM:
opts.reduce_op = paddle::distributed::ReduceOp::SUM;
break;
case phi::ccl::CCLReduceOp::MAX:
opts.reduce_op = paddle::distributed::ReduceOp::MAX;
break;
case phi::ccl::CCLReduceOp::MIN:
opts.reduce_op = paddle::distributed::ReduceOp::MIN;
break;
case phi::ccl::CCLReduceOp::PRODUCT:
opts.reduce_op = paddle::distributed::ReduceOp::PRODUCT;
break;
default:
PADDLE_THROW(paddle::platform::errors::InvalidArgument(
"Invalid reduce type: %d", red_type));
}
auto task = pg->AllReduce(in_tensor, out_tensor, opts);
task->Wait();
return;
}
auto comm =
paddle::platform::XCCLCommContext::Instance(place.GetDeviceType())
.Get(rid, place);
std::shared_ptr<phi::stream::Stream> stream;
if (ctx.Attr<bool>("use_calc_stream")) {
auto dev_ctx = paddle::platform::DeviceContextPool::Instance().Get(place);
stream = static_cast<paddle::platform::CustomDeviceContext*>(dev_ctx)
->GetStream();
} else {
stream = comm->stream();
}
phi::DeviceManager::CCLAllReduce(place.GetDeviceType(),
const_cast<void*>(sendbuff),
recvbuff,
numel,
dtype,
red_type,
comm->comm(),
*stream);
}
};
template <typename T>
class CBroadcastOpCustomDeviceKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto x = ctx.Input<phi::DenseTensor>("X");
auto out = ctx.Output<phi::DenseTensor>("Out");
const auto& place = ctx.GetPlace();
ctx.device_context().Alloc<T>(out);
int root = ctx.Attr<int>("root");
int rid = ctx.Attr<int>("ring_id");
auto stream = static_cast<const phi::CustomContext&>(ctx.device_context())
.GetStream();
int numel = x->numel();
auto dtype = phi::ccl::ToCCLDataType(x->dtype());
auto comm = platform::XCCLCommContext::Instance(place.GetDeviceType())
.Get(rid, place);
if (root == comm->rank()) {
phi::DeviceManager::CCLBroadcast(place.GetDeviceType(),
const_cast<void*>(x->data()),
numel,
dtype,
root,
comm->comm(),
*stream);
VLOG(3) << "rank " << comm->rank() << " invoke Bcast. sent "
<< x->numel();
if (out != x) {
framework::TensorCopy(
*static_cast<const phi::DenseTensor*>(x),
place,
*platform::DeviceContextPool::Instance().Get(place),
static_cast<phi::DenseTensor*>(out));
}
} else {
phi::DeviceManager::CCLBroadcast(place.GetDeviceType(),
out->data(),
numel,
dtype,
root,
comm->comm(),
*stream);
VLOG(3) << "rank " << comm->rank() << " invoke Bcast. received "
<< phi::product(out->dims());
}
out->set_lod(x->lod());
}
};
template <typename Context>
void FeedDenseTensorKernel(const Context& dev_ctx,
                           const phi::ExtendedTensor& x,
...@@ -636,6 +762,107 @@ void RegisterCustomDeviceCommonKernel(const std::string& dev_type) {
      paddle::operators::CSyncCalcStreamCustomDeviceKernel<
          paddle::platform::CustomDeviceContext,
          paddle::platform::float16>) {}
REGISTER_OP_CUSTOM_DEVICE_KERNEL(
c_allreduce_sum,
device_type,
paddle::operators::CAllReduceOpCustomDeviceKernel<
paddle::platform::CustomDeviceContext,
float,
phi::ccl::CCLReduceOp::SUM>,
paddle::operators::CAllReduceOpCustomDeviceKernel<
paddle::platform::CustomDeviceContext,
double,
phi::ccl::CCLReduceOp::SUM>,
paddle::operators::CAllReduceOpCustomDeviceKernel<
paddle::platform::CustomDeviceContext,
paddle::platform::float16,
phi::ccl::CCLReduceOp::SUM>,
paddle::operators::CAllReduceOpCustomDeviceKernel<
paddle::platform::CustomDeviceContext,
int32_t,
phi::ccl::CCLReduceOp::SUM>,
paddle::operators::CAllReduceOpCustomDeviceKernel<
paddle::platform::CustomDeviceContext,
int64_t,
phi::ccl::CCLReduceOp::SUM>) {}
REGISTER_OP_CUSTOM_DEVICE_KERNEL(
c_allreduce_min,
device_type,
paddle::operators::CAllReduceOpCustomDeviceKernel<
paddle::platform::CustomDeviceContext,
float,
phi::ccl::CCLReduceOp::MIN>,
paddle::operators::CAllReduceOpCustomDeviceKernel<
paddle::platform::CustomDeviceContext,
double,
phi::ccl::CCLReduceOp::MIN>,
paddle::operators::CAllReduceOpCustomDeviceKernel<
paddle::platform::CustomDeviceContext,
paddle::platform::float16,
phi::ccl::CCLReduceOp::MIN>,
paddle::operators::CAllReduceOpCustomDeviceKernel<
paddle::platform::CustomDeviceContext,
int32_t,
phi::ccl::CCLReduceOp::MIN>,
paddle::operators::CAllReduceOpCustomDeviceKernel<
paddle::platform::CustomDeviceContext,
int64_t,
phi::ccl::CCLReduceOp::MIN>) {}
REGISTER_OP_CUSTOM_DEVICE_KERNEL(
c_allreduce_max,
device_type,
paddle::operators::CAllReduceOpCustomDeviceKernel<
paddle::platform::CustomDeviceContext,
float,
phi::ccl::CCLReduceOp::MAX>,
paddle::operators::CAllReduceOpCustomDeviceKernel<
paddle::platform::CustomDeviceContext,
double,
phi::ccl::CCLReduceOp::MAX>,
paddle::operators::CAllReduceOpCustomDeviceKernel<
paddle::platform::CustomDeviceContext,
paddle::platform::float16,
phi::ccl::CCLReduceOp::MAX>,
paddle::operators::CAllReduceOpCustomDeviceKernel<
paddle::platform::CustomDeviceContext,
int32_t,
phi::ccl::CCLReduceOp::MAX>,
paddle::operators::CAllReduceOpCustomDeviceKernel<
paddle::platform::CustomDeviceContext,
int64_t,
phi::ccl::CCLReduceOp::MAX>) {}
REGISTER_OP_CUSTOM_DEVICE_KERNEL(
c_allreduce_prod,
device_type,
paddle::operators::CAllReduceOpCustomDeviceKernel<
paddle::platform::CustomDeviceContext,
float,
phi::ccl::CCLReduceOp::PRODUCT>,
paddle::operators::CAllReduceOpCustomDeviceKernel<
paddle::platform::CustomDeviceContext,
double,
phi::ccl::CCLReduceOp::PRODUCT>,
paddle::operators::CAllReduceOpCustomDeviceKernel<
paddle::platform::CustomDeviceContext,
paddle::platform::float16,
phi::ccl::CCLReduceOp::PRODUCT>,
paddle::operators::CAllReduceOpCustomDeviceKernel<
paddle::platform::CustomDeviceContext,
int32_t,
phi::ccl::CCLReduceOp::PRODUCT>,
paddle::operators::CAllReduceOpCustomDeviceKernel<
paddle::platform::CustomDeviceContext,
int64_t,
phi::ccl::CCLReduceOp::PRODUCT>) {}
REGISTER_OP_CUSTOM_DEVICE_KERNEL(
c_broadcast,
device_type,
paddle::operators::CBroadcastOpCustomDeviceKernel<int32_t>,
paddle::operators::CBroadcastOpCustomDeviceKernel<int64_t>,
paddle::operators::CBroadcastOpCustomDeviceKernel<float>,
paddle::operators::CBroadcastOpCustomDeviceKernel<double>,
paddle::operators::CBroadcastOpCustomDeviceKernel<
paddle::platform::float16>) {}
#endif
}
...
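The registrations above expose the standard static-graph collectives (c_allreduce_sum/min/max/prod and c_broadcast) on custom devices, which is what lets a model-parallel inference program keep its communication ops unchanged. A hedged sketch (not part of this commit) of the ops these kernels serve; ring_id, root, and the tensor shape are placeholders:

# Hypothetical usage sketch (not part of this commit): the static-graph
# collective ops backed by the kernels registered above.
import paddle

paddle.enable_static()
block = paddle.static.default_main_program().global_block()

x = paddle.static.data(name="x", shape=[4, 8], dtype="float32")

# All-reduce (SUM) over ring 0, issued on the compute stream.
block.append_op(
    type="c_allreduce_sum",
    inputs={"X": x},
    outputs={"Out": x},
    attrs={"ring_id": 0, "use_calc_stream": True},
)

# Broadcast from rank 0 of ring 0.
block.append_op(
    type="c_broadcast",
    inputs={"X": x},
    outputs={"Out": x},
    attrs={"ring_id": 0, "root": 0},
)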
...@@ -91,6 +91,17 @@ if(WITH_XPU_BKCL)
  set(PYBIND_DEPS ${PYBIND_DEPS} heter_ccl_context)
endif()
if(WITH_CUSTOM_DEVICE)
set(PYBIND_DEPS ${PYBIND_DEPS} xccl_context)
if(NOT
(WITH_NCCL
OR WITH_RCCL
OR WITH_XPU_BKCL))
set(PYBIND_DEPS ${PYBIND_DEPS} reducer)
set(PYBIND_DEPS ${PYBIND_DEPS} heter_ccl_context)
endif()
endif()
if(NOT WIN32)
  set(PYBIND_DEPS ${PYBIND_DEPS} data_loader)
  if(WITH_NCCL OR WITH_RCCL)
...
...@@ -52,6 +52,7 @@ limitations under the License. */
#include "paddle/fluid/imperative/reducer.h"
#include "paddle/fluid/imperative/tracer.h"
#include "paddle/fluid/imperative/type_defs.h"
#include "paddle/fluid/imperative/xccl_context.h"
#include "paddle/fluid/memory/allocation/mmap_allocator.h"
#include "paddle/fluid/operators/utils.h"
#include "paddle/fluid/pybind/cuda_streams_py.h"
...@@ -2477,7 +2478,8 @@ void BindImperative(py::module *m_ptr) {
      py::call_guard<py::gil_scoped_release>());
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \
    defined(PADDLE_WITH_XPU_BKCL) || defined(PADDLE_WITH_GLOO) || \
    defined(PADDLE_WITH_CUSTOM_DEVICE)
  py::class_<imperative::ParallelContext,
             std::shared_ptr<imperative::ParallelContext>>(m,
                                                           "ParallelContext");
...@@ -2517,6 +2519,19 @@ void BindImperative(py::module *m_ptr) {
      py::arg("ring_id"));
#endif
#if defined(PADDLE_WITH_CUSTOM_DEVICE)
py::class_<imperative::XCCLParallelContext,
imperative::ParallelContext,
std::shared_ptr<imperative::XCCLParallelContext>>(
m, "XCCLParallelContext")
.def(py::init<const imperative::ParallelStrategy &,
const platform::CustomPlace &>())
.def("init", [](imperative::XCCLParallelContext &self) { self.Init(); })
.def("init_with_ring_id",
&imperative::XCCLParallelContext::InitWithRingID,
py::arg("ring_id"));
#endif
#if defined(PADDLE_WITH_XPU_BKCL)
  py::class_<imperative::BKCLParallelContext,
             imperative::ParallelContext,
...@@ -2545,7 +2560,7 @@ void BindImperative(py::module *m_ptr) {
#endif
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \
    defined(PADDLE_WITH_XPU_BKCL) || defined(PADDLE_WITH_CUSTOM_DEVICE)
  py::class_<imperative::HeterParallelContext,
             imperative::ParallelContext,
             std::shared_ptr<imperative::HeterParallelContext>>(
...
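The new XCCLParallelContext binding mirrors NCCLParallelContext and BKCLParallelContext so the dygraph data-parallel setup path can be handed a CustomPlace. A hedged sketch (not part of this commit) of driving the binding directly from Python; the "npu" device type, endpoints, and ranks are placeholders:

# Hypothetical usage sketch (not part of this commit): constructing the new
# XCCLParallelContext directly.  In practice the dygraph launch path is
# expected to pick it for custom devices; all values below are placeholders.
from paddle.framework import core

strategy = core.ParallelStrategy()
strategy.nranks = 2
strategy.local_rank = 0
strategy.trainer_endpoints = ["127.0.0.1:6175", "127.0.0.1:6176"]
strategy.current_endpoint = "127.0.0.1:6175"

place = core.CustomPlace("npu", 0)  # CustomPlace(device_type, device_id)
ctx = core.XCCLParallelContext(strategy, place)
ctx.init()  # builds the XCCL communicators for this process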
...@@ -158,6 +158,7 @@ limitations under the License. */
#ifdef PADDLE_WITH_CUSTOM_DEVICE
#include "paddle/fluid/operators/custom_device_common_op_registry.h"
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/phi/capi/capi.h"
#endif
...@@ -990,6 +991,7 @@ PYBIND11_MODULE(libpaddle, m) {
      []() { phi::KernelFactory::Instance().kernels().clear(); });
  m.def("clear_device_manager", []() {
#ifdef PADDLE_WITH_CUSTOM_DEVICE
    platform::XCCLCommContext::Release();
    phi::DeviceManager::Clear();
#endif
  });
...