Unverified · Commit c497b43f authored by ronnywang, committed by GitHub

[CustomDevice] add inference MP support, PART1 (#53702)

Parent 60cf9b50
@@ -76,6 +76,13 @@ bool CondInterceptor::GetCondResult() {
     framework::TensorCopy(cond_tensor, platform::CPUPlace(), &cpu_tensor);
     platform::DeviceContextPool::Instance().Get(cond_tensor.place())->Wait();
     res = cpu_tensor.data<bool>()[0];
 #endif
+  } else if (platform::is_custom_place(cond_tensor.place())) {
+#if defined(PADDLE_WITH_CUSTOM_DEVICE)
+    phi::DenseTensor cpu_tensor;
+    framework::TensorCopy(cond_tensor, platform::CPUPlace(), &cpu_tensor);
+    platform::DeviceContextPool::Instance().Get(cond_tensor.place())->Wait();
+    res = cpu_tensor.data<bool>()[0];
+#endif
   } else if (platform::is_cpu_place(cond_tensor.place())) {
     res = cond_tensor.data<bool>()[0];
......
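For context, the added CustomDevice branch reuses the same read-back idiom as the existing GPU branch. A minimal sketch of that idiom, distilled from the lines above (the helper name is illustrative, not part of the patch):

```cpp
// Illustrative only: copy the device-side bool flag to host, wait for the
// asynchronous TensorCopy to finish, then read it on the CPU.
bool ReadCondFlag(const phi::DenseTensor &cond_tensor) {
  phi::DenseTensor cpu_tensor;
  framework::TensorCopy(cond_tensor, platform::CPUPlace(), &cpu_tensor);
  platform::DeviceContextPool::Instance().Get(cond_tensor.place())->Wait();
  return cpu_tensor.data<bool>()[0];
}
```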
@@ -102,10 +102,27 @@ bool LoadDataFromDistModelTensor(const DistModelTensor &input_data,
 #else
     PADDLE_THROW(paddle::platform::errors::Fatal(
         "Paddle wasn't compiled with XPU, but place is XPU."));
 #endif
+  } else if (platform::is_custom_place(place)) {
+    VLOG(3) << "Loading data for CustomDevice: " << place;
+#if defined(PADDLE_WITH_CUSTOM_DEVICE)
+    platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
+    auto *dev_ctx = dynamic_cast<const phi::CustomContext *>(pool.Get(place));
+    auto custom_place = place;
+    memory::Copy(custom_place,
+                 static_cast<void *>(input_tensor_ptr),
+                 platform::CPUPlace(),
+                 input_data.data.data(),
+                 input_data.data.length(),
+                 dev_ctx->stream());
+#else
+    PADDLE_THROW(paddle::platform::errors::Fatal(
+        "Paddle wasn't compiled with custom_device, but place is "
+        "CustomPlace."));
+#endif
   } else {
     PADDLE_THROW(paddle::platform::errors::InvalidArgument(
-        "DistModel only supports CPU and GPU and XPU."));
+        "DistModel only supports CPU and GPU and XPU and CustomDevice."));
   }
   framework::LoD dst_lod;
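The new branch follows the same host-to-device staging idiom as the CUDA/XPU branches, except that it pulls a phi::CustomContext from the pool and issues the copy on that context's stream. A condensed sketch under the same includes as dist_model.cc (the helper name is hypothetical, not from the patch):

```cpp
// Hypothetical helper showing the staging idiom used above: the copy is
// asynchronous with respect to the host and is ordered on the CustomContext
// stream, analogous to the cudaMemcpyAsync path on GPU.
void CopyHostToCustomPlace(const platform::CustomPlace &place,
                           void *dst,
                           const void *src,
                           size_t num_bytes) {
  auto &pool = platform::DeviceContextPool::Instance();
  auto *dev_ctx = dynamic_cast<const phi::CustomContext *>(pool.Get(place));
  memory::Copy(place, dst, platform::CPUPlace(), src, num_bytes,
               dev_ctx->stream());
}
```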
@@ -204,6 +221,9 @@ bool DistModel::PreparePlace() {
     place_ = paddle::platform::CPUPlace();
   } else if (config_.place == "XPU") {
     place_ = paddle::platform::XPUPlace(config_.device_id);
+  } else if (config_.place == "CUSTOM_DEVICE") {
+    place_ =
+        paddle::platform::CustomPlace(config_.device_type, config_.device_id);
   } else {
     PADDLE_THROW(platform::errors::InvalidArgument(
         "Place must be choosen from GPU or CPU or XPU, but got %s.",
@@ -324,6 +344,29 @@ void DistModel::InsertCommOp(std::string tmp_var_name,
     comm_init_op->SetAttr("op_role",
                           static_cast<int>(framework::OpRole::kForward));
     comm_init_op->CheckAttrs();
+  } else if (config_.place == "CUSTOM_DEVICE") {
+    framework::VarDesc *new_var = block->Var(tmp_var_name);
+    new_var->SetType(framework::proto::VarType::RAW);
+    new_var->SetPersistable(true);
+    framework::OpDesc *gen_bkcl_id_op = block->AppendOp();
+    gen_bkcl_id_op->SetType("c_gen_xccl_id");
+    gen_bkcl_id_op->SetOutput("Out", {tmp_var_name});
+    gen_bkcl_id_op->SetAttr("rank", rank);
+    gen_bkcl_id_op->SetAttr("endpoint", config_.current_endpoint);
+    gen_bkcl_id_op->SetAttr("other_endpoints", peer_endpoints);
+    gen_bkcl_id_op->SetAttr("ring_id", ring_id);
+    gen_bkcl_id_op->SetAttr("op_role",
+                            static_cast<int>(framework::OpRole::kForward));
+    gen_bkcl_id_op->CheckAttrs();
+    framework::OpDesc *comm_init_op = block->AppendOp();
+    comm_init_op->SetType("c_comm_init");
+    comm_init_op->SetInput("X", {tmp_var_name});
+    comm_init_op->SetAttr("rank", rank);
+    comm_init_op->SetAttr("nranks", nranks);
+    comm_init_op->SetAttr("ring_id", ring_id);
+    comm_init_op->SetAttr("op_role",
+                          static_cast<int>(framework::OpRole::kForward));
+    comm_init_op->CheckAttrs();
   } else {
     LOG(WARNING) << "DistModelInf doesn't init comm.";
     // TODO(fleet exe dev): comm init for more devices
......
@@ -44,6 +44,7 @@ struct DistModelConfig {
   framework::Scope* scope{nullptr};
   std::string place{};
   int64_t device_id{0};
+  std::string device_type{};
   std::vector<std::string> trainer_endpoints{};
   std::string current_endpoint{};
   int64_t nranks{1};
......
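Taken together with the PreparePlace() branch above, the new device_type field is what lets a DistModelConfig target a plug-in device. A hypothetical fill-in using only the fields visible in this diff (the namespace, device type string, and endpoints are placeholders for a real deployment):

```cpp
// Hypothetical configuration sketch; all values are placeholders.
paddle::distributed::DistModelConfig config;
config.place = "CUSTOM_DEVICE";   // mapped to platform::CustomPlace in PreparePlace()
config.device_type = "npu";       // name registered by the CustomDevice plugin
config.device_id = 0;
config.nranks = 2;                // two-way model parallelism
config.trainer_endpoints = {"127.0.0.1:6170", "127.0.0.1:6171"};
config.current_endpoint = "127.0.0.1:6170";
```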
@@ -52,7 +52,7 @@ void MessageBus::Init(
   }
 #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \
-    defined(PADDLE_WITH_XPU_BKCL)
+    defined(PADDLE_WITH_XPU_BKCL) || defined(PADDLE_WITH_CUSTOM_DEVICE)
   // NOTE: To make the brpc is compatible with collective,
   // need release the handler holding the ip address.
   if (addr_ != "") {
......
@@ -43,6 +43,7 @@ static std::unordered_set<std::string> kMultiDeviceOps{
     "c_comm_init_multitrainer",
     "c_gen_nccl_id",
     "c_gen_bkcl_id",
+    "c_gen_xccl_id",
     "c_sync_comm_stream",
     "send",
     "recv",
......
@@ -43,6 +43,10 @@
 #include "xpu/bkcl.h"
 #endif
 
+#if defined(PADDLE_WITH_CUSTOM_DEVICE)
+#include "paddle/phi/backends/c_comm_lib.h"
+#endif
+
 namespace phi {
 class DenseTensor;
 class SelectedRows;
@@ -195,6 +199,9 @@ using VarTypeRegistry = detail::VarTypeRegistryImpl<
 #if defined(PADDLE_WITH_XPU_BKCL)
     BKCLUniqueId,
     platform::BKCLCommunicator,
 #endif
+#if defined(PADDLE_WITH_CUSTOM_DEVICE)
+    phi::ccl::CCLRootId,
+#endif
     std::vector<std::unique_ptr<operators::CUDAGraphWithInOuts>>,
     int,
......
@@ -139,9 +139,8 @@ void IrParamsSyncAmongDevicesPass::CopyParamsToCustomDevice(
   repetitive_params = graph.Get<std::vector<std::string>>(
       framework::ir::kRepetitiveParamAttr);
-  LOG(INFO) << "Sync params from CPU to CustomDevice"
-            << argument->custom_device_type() << "/"
-            << argument->custom_device_id();
+  LOG(INFO) << "Sync params from CPU to " << argument->custom_device_type()
+            << ":" << argument->custom_device_id();
   platform::Place place = platform::CustomPlace(argument->custom_device_type(),
                                                 argument->custom_device_id());
......
@@ -475,7 +475,8 @@ void AnalysisPredictor::InitPlace() {
 #endif
   } else if (config_.use_custom_device()) {
 #ifdef PADDLE_WITH_CUSTOM_DEVICE
-    place_ = paddle::platform::CustomPlace(config_.custom_device_type());
+    place_ = paddle::platform::CustomPlace(config_.custom_device_type(),
+                                           config_.custom_device_id());
 #else
     PADDLE_THROW(platform::errors::Unavailable(
         "You tried to use CustomDevice forward propagation, but Paddle was not "
@@ -564,6 +565,14 @@ void *AnalysisPredictor::GetExecStream() const {
         ->stream();
   }
 }
 #endif
+#if defined(PADDLE_WITH_CUSTOM_DEVICE)
+  if (place_.GetType() == phi::AllocationType::CUSTOM) {
+    paddle::platform::DeviceContextPool &pool =
+        paddle::platform::DeviceContextPool::Instance();
+    return reinterpret_cast<const phi::CustomContext *>(pool.Get(place_))
+        ->stream();
+  }
+#endif
   // TODO(inference): Support other backends.
   return nullptr;
@@ -679,12 +688,16 @@ static void DisablePrepareDataOpt(
 }
 
 bool AnalysisPredictor::PrepareExecutor() {
-#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE)
   if (config_.dist_config().use_dist_model()) {
+#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE)
     VLOG(3) << "use_dist_model is enabled, will init FleetExecutor.";
     return PrepareFleetExecutor();
-  }
+#else
+    PADDLE_THROW(platform::errors::PermissionDenied(
+        "Paddle can't use FleetExecutor since it's not compiled with PSCORE,"
+        "Please recompile or reinstall Paddle with PSCORE support."));
 #endif
+  }
   DisablePrepareDataOpt(inference_program_, 0, false);
 
   executor_->Prepare(
@@ -875,6 +888,30 @@ void AnalysisPredictor::InsertCommOp(
     comm_init_op->SetAttr("op_role",
                           static_cast<int>(framework::OpRole::kForward));
     comm_init_op->CheckAttrs();
+  } else if (config_.use_custom_device()) {
+    framework::VarDesc *new_var = block->Var(tmp_var_name);
+    new_var->SetType(framework::proto::VarType::RAW);
+    new_var->SetPersistable(true);
+    framework::OpDesc *gen_bkcl_id_op = block->AppendOp();
+    gen_bkcl_id_op->SetType("c_gen_xccl_id");
+    gen_bkcl_id_op->SetOutput("Out", {tmp_var_name});
+    gen_bkcl_id_op->SetAttr("rank", rank);
+    gen_bkcl_id_op->SetAttr("endpoint",
+                            config_.dist_config().current_endpoint());
+    gen_bkcl_id_op->SetAttr("other_endpoints", peer_endpoints);
+    gen_bkcl_id_op->SetAttr("ring_id", ring_id);
+    gen_bkcl_id_op->SetAttr("op_role",
+                            static_cast<int>(framework::OpRole::kForward));
+    gen_bkcl_id_op->CheckAttrs();
+    framework::OpDesc *comm_init_op = block->AppendOp();
+    comm_init_op->SetType("c_comm_init");
+    comm_init_op->SetInput("X", {tmp_var_name});
+    comm_init_op->SetAttr("rank", rank);
+    comm_init_op->SetAttr("nranks", nranks);
+    comm_init_op->SetAttr("ring_id", ring_id);
+    comm_init_op->SetAttr("op_role",
+                          static_cast<int>(framework::OpRole::kForward));
+    comm_init_op->CheckAttrs();
   } else {
     LOG(WARNING) << "DistModelInf doesn't init comm.";
     // TODO(fleet exe dev): comm init for more devices
......
@@ -819,7 +819,7 @@ PHI_DEFINE_EXPORTED_bool(use_fast_math,
  * Note: Get host by name time.
  */
 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_XPU) || \
-    defined(PADDLE_WITH_HIP)
+    defined(PADDLE_WITH_HIP) || defined(PADDLE_WITH_CUSTOM_DEVICE)
 PHI_DEFINE_EXPORTED_int32(get_host_by_name_time,
                           120,
                           "The maximum time for get host by name time");
......