Unverified commit c497b43f, authored by ronnywang and committed by GitHub

[CustomDevice] add inference MP support, PART1 (#53702)

Parent 60cf9b50
@@ -76,6 +76,13 @@ bool CondInterceptor::GetCondResult() {
framework::TensorCopy(cond_tensor, platform::CPUPlace(), &cpu_tensor);
platform::DeviceContextPool::Instance().Get(cond_tensor.place())->Wait();
res = cpu_tensor.data<bool>()[0];
#endif
} else if (platform::is_custom_place(cond_tensor.place())) {
#if defined(PADDLE_WITH_CUSTOM_DEVICE)
phi::DenseTensor cpu_tensor;
framework::TensorCopy(cond_tensor, platform::CPUPlace(), &cpu_tensor);
platform::DeviceContextPool::Instance().Get(cond_tensor.place())->Wait();
res = cpu_tensor.data<bool>()[0];
#endif
} else if (platform::is_cpu_place(cond_tensor.place())) {
res = cond_tensor.data<bool>()[0];
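Note: the new CustomDevice branch repeats the CUDA/XPU pattern. The condition tensor lives on the device, so it is staged through a host tensor before the bool can be read. A minimal sketch of that pattern (assumes Paddle's internal headers; the helper name is ours, not the commit's):

```cpp
// Device-to-host scalar read, as used by every non-CPU branch above.
// TensorCopy may enqueue an asynchronous copy on the device stream, so the
// device context must be drained before the host buffer is dereferenced.
bool ReadCondOnHost(const phi::DenseTensor &cond_tensor) {
  phi::DenseTensor cpu_tensor;
  framework::TensorCopy(cond_tensor, platform::CPUPlace(), &cpu_tensor);
  platform::DeviceContextPool::Instance().Get(cond_tensor.place())->Wait();
  return cpu_tensor.data<bool>()[0];
}
```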
@@ -102,10 +102,27 @@ bool LoadDataFromDistModelTensor(const DistModelTensor &input_data,
#else
PADDLE_THROW(paddle::platform::errors::Fatal(
"Paddle wasn't compiled with XPU, but place is XPU."));
#endif
} else if (platform::is_custom_place(place)) {
VLOG(3) << "Loading data for CustomDevice: " << place;
#if defined(PADDLE_WITH_CUSTOM_DEVICE)
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto *dev_ctx = dynamic_cast<const phi::CustomContext *>(pool.Get(place));
auto custom_place = place;
memory::Copy(custom_place,
static_cast<void *>(input_tensor_ptr),
platform::CPUPlace(),
input_data.data.data(),
input_data.data.length(),
dev_ctx->stream());
#else
PADDLE_THROW(paddle::platform::errors::Fatal(
"Paddle wasn't compiled with custom_device, but place is "
"CustomPlace."));
#endif
} else {
PADDLE_THROW(paddle::platform::errors::InvalidArgument(
"DistModel only supports CPU and GPU and XPU."));
"DistModel only supports CPU and GPU and XPU and CustomDevice."));
}
framework::LoD dst_lod;
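The CustomDevice load above obtains a phi::CustomContext from the DeviceContextPool so that memory::Copy can run on the device's stream. The dynamic_cast is unchecked; a defensive variant (a sketch, not part of this commit) would be:

```cpp
// Sketch: guard the context cast used in the CustomDevice branch.
auto *dev_ctx = dynamic_cast<const phi::CustomContext *>(pool.Get(place));
PADDLE_ENFORCE_NOT_NULL(
    dev_ctx,
    paddle::platform::errors::PreconditionNotMet(
        "No CustomContext is registered for the given place."));
memory::Copy(place,
             static_cast<void *>(input_tensor_ptr),
             platform::CPUPlace(),
             input_data.data.data(),
             input_data.data.length(),
             dev_ctx->stream());
```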
@@ -204,6 +221,9 @@ bool DistModel::PreparePlace() {
place_ = paddle::platform::CPUPlace();
} else if (config_.place == "XPU") {
place_ = paddle::platform::XPUPlace(config_.device_id);
} else if (config_.place == "CUSTOM_DEVICE") {
place_ =
paddle::platform::CustomPlace(config_.device_type, config_.device_id);
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"Place must be choosen from GPU or CPU or XPU, but got %s.",
@@ -324,6 +344,29 @@ void DistModel::InsertCommOp(std::string tmp_var_name,
comm_init_op->SetAttr("op_role",
static_cast<int>(framework::OpRole::kForward));
comm_init_op->CheckAttrs();
} else if (config_.place == "CUSTOM_DEVICE") {
framework::VarDesc *new_var = block->Var(tmp_var_name);
new_var->SetType(framework::proto::VarType::RAW);
new_var->SetPersistable(true);
framework::OpDesc *gen_bkcl_id_op = block->AppendOp();
gen_bkcl_id_op->SetType("c_gen_xccl_id");
gen_bkcl_id_op->SetOutput("Out", {tmp_var_name});
gen_bkcl_id_op->SetAttr("rank", rank);
gen_bkcl_id_op->SetAttr("endpoint", config_.current_endpoint);
gen_bkcl_id_op->SetAttr("other_endpoints", peer_endpoints);
gen_bkcl_id_op->SetAttr("ring_id", ring_id);
gen_bkcl_id_op->SetAttr("op_role",
static_cast<int>(framework::OpRole::kForward));
gen_bkcl_id_op->CheckAttrs();
framework::OpDesc *comm_init_op = block->AppendOp();
comm_init_op->SetType("c_comm_init");
comm_init_op->SetInput("X", {tmp_var_name});
comm_init_op->SetAttr("rank", rank);
comm_init_op->SetAttr("nranks", nranks);
comm_init_op->SetAttr("ring_id", ring_id);
comm_init_op->SetAttr("op_role",
static_cast<int>(framework::OpRole::kForward));
comm_init_op->CheckAttrs();
} else {
LOG(WARNING) << "DistModelInf doesn't init comm.";
// TODO(fleet exe dev): comm init for more devices
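All three branches of DistModel::InsertCommOp (NCCL, BKCL, and the new XCCL one) append the same variable plus two-op sequence and differ only in the id-generation op type; note the local variable is still named gen_bkcl_id_op even though the op it creates is c_gen_xccl_id. A hypothetical factoring (not part of this commit; signatures follow the diff above) makes the shared shape explicit:

```cpp
// Sketch: one helper covering "c_gen_nccl_id" / "c_gen_bkcl_id" /
// "c_gen_xccl_id". Hypothetical refactor, not the commit's code.
static void AppendCommInitOps(framework::BlockDesc *block,
                              const std::string &gen_id_op_type,
                              const std::string &tmp_var_name,
                              int rank,
                              int nranks,
                              int ring_id,
                              const std::string &endpoint,
                              const std::vector<std::string> &peer_endpoints) {
  framework::VarDesc *new_var = block->Var(tmp_var_name);
  new_var->SetType(framework::proto::VarType::RAW);
  new_var->SetPersistable(true);

  // Phase 1: generate the unique communicator id and exchange it across
  // the listed endpoints.
  framework::OpDesc *gen_id_op = block->AppendOp();
  gen_id_op->SetType(gen_id_op_type);
  gen_id_op->SetOutput("Out", {tmp_var_name});
  gen_id_op->SetAttr("rank", rank);
  gen_id_op->SetAttr("endpoint", endpoint);
  gen_id_op->SetAttr("other_endpoints", peer_endpoints);
  gen_id_op->SetAttr("ring_id", ring_id);
  gen_id_op->SetAttr("op_role", static_cast<int>(framework::OpRole::kForward));
  gen_id_op->CheckAttrs();

  // Phase 2: build the communicator for this ring from the exchanged id.
  framework::OpDesc *comm_init_op = block->AppendOp();
  comm_init_op->SetType("c_comm_init");
  comm_init_op->SetInput("X", {tmp_var_name});
  comm_init_op->SetAttr("rank", rank);
  comm_init_op->SetAttr("nranks", nranks);
  comm_init_op->SetAttr("ring_id", ring_id);
  comm_init_op->SetAttr("op_role",
                        static_cast<int>(framework::OpRole::kForward));
  comm_init_op->CheckAttrs();
}
```

With this helper, the CUSTOM_DEVICE branch would reduce to AppendCommInitOps(block, "c_gen_xccl_id", tmp_var_name, rank, nranks, ring_id, config_.current_endpoint, peer_endpoints).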
@@ -44,6 +44,7 @@ struct DistModelConfig {
framework::Scope* scope{nullptr};
std::string place{};
int64_t device_id{0};
std::string device_type{};
std::vector<std::string> trainer_endpoints{};
std::string current_endpoint{};
int64_t nranks{1};
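With device_type added, a CustomDevice inference setup would populate the config roughly as follows (a sketch; the "npu" device-type string and the endpoints are illustrative and depend on the installed plugin):

```cpp
DistModelConfig config;
config.place = "CUSTOM_DEVICE";  // routed to the new PreparePlace() branch
config.device_type = "npu";      // illustrative plugin device name
config.device_id = 0;
config.nranks = 2;               // model-parallel world size
config.current_endpoint = "127.0.0.1:6170";
config.trainer_endpoints = {"127.0.0.1:6170", "127.0.0.1:6171"};
```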
@@ -52,7 +52,7 @@ void MessageBus::Init(
}
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \
-    defined(PADDLE_WITH_XPU_BKCL)
+    defined(PADDLE_WITH_XPU_BKCL) || defined(PADDLE_WITH_CUSTOM_DEVICE)
// NOTE: To make the brpc is compatible with collective,
// need release the handler holding the ip address.
if (addr_ != "") {
@@ -43,6 +43,7 @@ static std::unordered_set<std::string> kMultiDeviceOps{
"c_comm_init_multitrainer",
"c_gen_nccl_id",
"c_gen_bkcl_id",
"c_gen_xccl_id",
"c_sync_comm_stream",
"send",
"recv",
@@ -43,6 +43,10 @@
#include "xpu/bkcl.h"
#endif
#if defined(PADDLE_WITH_CUSTOM_DEVICE)
#include "paddle/phi/backends/c_comm_lib.h"
#endif
namespace phi {
class DenseTensor;
class SelectedRows;
@@ -195,6 +199,9 @@ using VarTypeRegistry = detail::VarTypeRegistryImpl<
#if defined(PADDLE_WITH_XPU_BKCL)
BKCLUniqueId,
platform::BKCLCommunicator,
#endif
#if defined(PADDLE_WITH_CUSTOM_DEVICE)
phi::ccl::CCLRootId,
#endif
std::vector<std::unique_ptr<operators::CUDAGraphWithInOuts>>,
int,
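Registering phi::ccl::CCLRootId here is what allows the RAW variable produced by c_gen_xccl_id to hold the exchanged id inside a framework::Variable. A sketch of the access this enables (assuming CCLRootId is the byte-buffer alias declared in c_comm_lib.h):

```cpp
#include "paddle/fluid/framework/variable.h"
#include "paddle/phi/backends/c_comm_lib.h"

// Store the unique id generated on rank 0 into a scope variable; the other
// ranks receive the same bytes via the exchanged endpoints.
void StoreRootId(paddle::framework::Variable *var,
                 const std::vector<uint8_t> &id_bytes) {
  auto *root_id = var->GetMutable<phi::ccl::CCLRootId>();
  root_id->assign(id_bytes.begin(), id_bytes.end());
}
```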
@@ -139,9 +139,8 @@ void IrParamsSyncAmongDevicesPass::CopyParamsToCustomDevice(
repetitive_params = graph.Get<std::vector<std::string>>(
framework::ir::kRepetitiveParamAttr);
LOG(INFO) << "Sync params from CPU to CustomDevice"
<< argument->custom_device_type() << "/"
<< argument->custom_device_id();
LOG(INFO) << "Sync params from CPU to " << argument->custom_device_type()
<< ":" << argument->custom_device_id();
platform::Place place = platform::CustomPlace(argument->custom_device_type(),
argument->custom_device_id());
@@ -475,7 +475,8 @@ void AnalysisPredictor::InitPlace() {
#endif
} else if (config_.use_custom_device()) {
#ifdef PADDLE_WITH_CUSTOM_DEVICE
-    place_ = paddle::platform::CustomPlace(config_.custom_device_type());
+    place_ = paddle::platform::CustomPlace(config_.custom_device_type(),
+                                           config_.custom_device_id());
#else
PADDLE_THROW(platform::errors::Unavailable(
"You tried to use CustomDevice forward propagation, but Paddle was not "
@@ -564,6 +565,14 @@ void *AnalysisPredictor::GetExecStream() const {
->stream();
}
}
#endif
#if defined(PADDLE_WITH_CUSTOM_DEVICE)
if (place_.GetType() == phi::AllocationType::CUSTOM) {
paddle::platform::DeviceContextPool &pool =
paddle::platform::DeviceContextPool::Instance();
return reinterpret_cast<const phi::CustomContext *>(pool.Get(place_))
->stream();
}
#endif
// TODO(inference): Support other backends.
return nullptr;
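The custom-device case mirrors the CUDA one above it: resolve the place's device context from the pool and hand back the underlying stream handle. Condensed into a standalone sketch (internal headers assumed; not the commit's code):

```cpp
// Sketch: resolve the raw execution stream for a CUSTOM place.
void *GetCustomExecStream(const phi::Place &place) {
  auto &pool = paddle::platform::DeviceContextPool::Instance();
  return reinterpret_cast<const phi::CustomContext *>(pool.Get(place))
      ->stream();
}
```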
@@ -679,12 +688,16 @@ static void DisablePrepareDataOpt(
}
bool AnalysisPredictor::PrepareExecutor() {
-#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE)
  if (config_.dist_config().use_dist_model()) {
+#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE)
    VLOG(3) << "use_dist_model is enabled, will init FleetExecutor.";
    return PrepareFleetExecutor();
-  }
+#else
+    PADDLE_THROW(platform::errors::PermissionDenied(
+        "Paddle can't use FleetExecutor since it's not compiled with PSCORE,"
+        "Please recompile or reinstall Paddle with PSCORE support."));
#endif
+  }
DisablePrepareDataOpt(inference_program_, 0, false);
executor_->Prepare(
@@ -875,6 +888,30 @@ void AnalysisPredictor::InsertCommOp(
comm_init_op->SetAttr("op_role",
static_cast<int>(framework::OpRole::kForward));
comm_init_op->CheckAttrs();
} else if (config_.use_custom_device()) {
framework::VarDesc *new_var = block->Var(tmp_var_name);
new_var->SetType(framework::proto::VarType::RAW);
new_var->SetPersistable(true);
framework::OpDesc *gen_bkcl_id_op = block->AppendOp();
gen_bkcl_id_op->SetType("c_gen_xccl_id");
gen_bkcl_id_op->SetOutput("Out", {tmp_var_name});
gen_bkcl_id_op->SetAttr("rank", rank);
gen_bkcl_id_op->SetAttr("endpoint",
config_.dist_config().current_endpoint());
gen_bkcl_id_op->SetAttr("other_endpoints", peer_endpoints);
gen_bkcl_id_op->SetAttr("ring_id", ring_id);
gen_bkcl_id_op->SetAttr("op_role",
static_cast<int>(framework::OpRole::kForward));
gen_bkcl_id_op->CheckAttrs();
framework::OpDesc *comm_init_op = block->AppendOp();
comm_init_op->SetType("c_comm_init");
comm_init_op->SetInput("X", {tmp_var_name});
comm_init_op->SetAttr("rank", rank);
comm_init_op->SetAttr("nranks", nranks);
comm_init_op->SetAttr("ring_id", ring_id);
comm_init_op->SetAttr("op_role",
static_cast<int>(framework::OpRole::kForward));
comm_init_op->CheckAttrs();
} else {
LOG(WARNING) << "DistModelInf doesn't init comm.";
// TODO(fleet exe dev): comm init for more devices
@@ -819,7 +819,7 @@ PHI_DEFINE_EXPORTED_bool(use_fast_math,
* Note: Get host by name time.
*/
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_XPU) || \
-    defined(PADDLE_WITH_HIP)
+    defined(PADDLE_WITH_HIP) || defined(PADDLE_WITH_CUSTOM_DEVICE)
PHI_DEFINE_EXPORTED_int32(get_host_by_name_time,
120,
"The maximum time for get host by name time");
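With PADDLE_WITH_CUSTOM_DEVICE added to this guard, plugin-device builds also get the get_host_by_name_time flag, which bounds how long endpoint hostname resolution may retry while communicators bootstrap. Like other exported flags it can be overridden at runtime, for example through the FLAGS_get_host_by_name_time environment variable.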