未验证 提交 4c46eed0 编写于 作者: Y Yuang Liu 提交者: GitHub

[fleet executor] add comm init for dist model inf (#39012)

上级 8b77f870
...@@ -15,7 +15,9 @@ ...@@ -15,7 +15,9 @@
#include <glog/logging.h> #include <glog/logging.h>
#include "paddle/fluid/distributed/fleet_executor/dist_model.h" #include "paddle/fluid/distributed/fleet_executor/dist_model.h"
#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/naive_executor.h" #include "paddle/fluid/framework/naive_executor.h"
#include "paddle/fluid/framework/op_proto_maker.h"
#include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/framework/tensor.h"
...@@ -37,24 +39,173 @@ bool IsPersistable(const framework::VarDesc *var) { ...@@ -37,24 +39,173 @@ bool IsPersistable(const framework::VarDesc *var) {
bool DistModel::Init() { bool DistModel::Init() {
/* TODO(fleet exe dev): implement this funct */ /* TODO(fleet exe dev): implement this funct */
place_ = paddle::platform::CUDAPlace(config_.device_id); bool init_method = (!config_.model_dir.empty() || config_.program_desc);
if (!PrepareScope()) { PADDLE_ENFORCE_EQ(init_method, true,
return false; platform::errors::InvalidArgument(
"One of model dir or program desc must be provided to "
"dist model inference."));
if (config_.program_desc) {
PADDLE_ENFORCE_NOT_NULL(
config_.scope, platform::errors::InvalidArgument(
"Scope must be provided to dist model inference if "
"program desc has been provided."));
} }
if (!PrepareProgram()) { if (!PreparePlace()) {
return false; return false;
} }
if (!config_.program_desc) {
if (config_.scope) {
LOG(WARNING) << "The provided scope will be ignored if model dir has "
"also been provided.";
}
if (!PrepareScope()) {
return false;
}
if (!PrepareProgram()) {
return false;
}
} else {
program_.reset(config_.program_desc);
scope_.reset(config_.scope);
}
if (!CommInit()) { if (!CommInit()) {
return false; return false;
} }
return true; return true;
} }
bool DistModel::PreparePlace() {
if (config_.place == "GPU") {
place_ = paddle::platform::CUDAPlace(config_.device_id);
} else if (config_.place == "CPU") {
place_ = paddle::platform::CPUPlace();
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"Place must be choosen from GPU or CPU, but got %s.", config_.place));
}
return true;
}
bool DistModel::CommInit() { bool DistModel::CommInit() {
// TODO(fleet executor): init the comm // NOTE (Yuang Liu): The peer endpoints will be obtained with the assumption
// that mp part is always on inner side and pp part is always on outer side.
// TODO(fleet exe dev): The peer endpoints could be configured by users.
PADDLE_ENFORCE_EQ(
config_.pp_degree * config_.mp_degree, config_.nranks,
platform::errors::InvalidArgument(
"The mp_degree multiplies pp_degree is not equal with nranks"));
std::unique_ptr<framework::ProgramDesc> comm_init_program(
new framework::ProgramDesc());
framework::BlockDesc *comm_init_block = comm_init_program->MutableBlock(0);
if (config_.mp_degree > 1) {
PADDLE_ENFORCE_GE(
config_.mp_ring_id, 0,
platform::errors::InvalidArgument(
"mp ring id must be provided for inference under mp."));
VLOG(3) << "Init comm group for mp.";
std::vector<std::string> peer_endpoints;
for (int64_t
idx = (config_.local_rank / config_.mp_degree) * config_.mp_degree,
i = 0;
i < config_.mp_degree; ++idx, ++i) {
if (config_.trainer_endpoints[idx] == config_.current_endpoint) {
continue;
}
peer_endpoints.emplace_back(config_.trainer_endpoints[idx]);
}
// get nranks in a mp group and inner group rank for local rank
int64_t mp_group_nranks = config_.nranks / config_.pp_degree;
int64_t mp_group_rank = config_.local_rank % config_.mp_degree;
InsertCommOp("mp_comm_id", mp_group_nranks, mp_group_rank, peer_endpoints,
comm_init_block, config_.mp_ring_id);
}
if (config_.pp_degree) {
// NOTE: the last pp stage doesn't need init pp comm
VLOG(3) << "Init comm group for pp.";
if (config_.local_rank - config_.mp_degree >= 0) {
PADDLE_ENFORCE_EQ(config_.pp_upstream_ring_id >= 0, true,
platform::errors::InvalidArgument(
"pp upstream ring id must be provided for "
"non-first pp stage if inference under pp."));
// not the first pp stage, has upstream
std::vector<std::string> upstream_peer_endpoints;
upstream_peer_endpoints.emplace_back(
config_.trainer_endpoints[config_.local_rank - config_.mp_degree]);
InsertCommOp("pp_upstream_comm_id", 2, 1, upstream_peer_endpoints,
comm_init_block, config_.pp_upstream_ring_id);
}
if (config_.local_rank + config_.mp_degree < config_.nranks) {
PADDLE_ENFORCE_EQ(config_.pp_downstream_ring_id >= 0, true,
platform::errors::InvalidArgument(
"pp downstream ring id must be provided for "
"non-last pp stage if inference under pp."));
// not the last pp stage, has downstream
std::vector<std::string> downstream_peer_endpoints;
downstream_peer_endpoints.emplace_back(
config_.trainer_endpoints[config_.local_rank + config_.mp_degree]);
InsertCommOp("pp_downstream_comm_id", 2, 0, downstream_peer_endpoints,
comm_init_block, config_.pp_downstream_ring_id);
}
}
framework::NaiveExecutor e(place_);
e.CreateVariables(*comm_init_program, 0, true, scope_.get());
e.Prepare(scope_.get(), *comm_init_program, 0, false);
e.Run();
VLOG(3) << "Comm init successful.";
return true; return true;
} }
void DistModel::InsertCommOp(std::string tmp_var_name, int nranks, int rank,
const std::vector<std::string> &peer_endpoints,
framework::BlockDesc *block, int ring_id) {
/*
* tmp_var_name: the var name for var comm_id
* nranks: number of total ranks
* rank: the rank of local rank in the comm group
* peer_endpoints: peer's endpoints
* block: the block where to insert the comm ops
* ring_id: the ring_id to be inited
*/
std::string &endpoint = config_.current_endpoint;
std::stringstream ss;
ss << "Init comm with tmp var: " << tmp_var_name
<< ". The ring id is: " << ring_id << ". The group has: " << nranks
<< " ranks. Current rank in the group is: " << rank
<< ". The endpoint is: " << endpoint << ". Peer endpoints are: ";
for (auto ep : peer_endpoints) {
ss << ep << ", ";
}
VLOG(3) << ss.str();
if (config_.place == "GPU") {
framework::VarDesc *new_var = block->Var(tmp_var_name);
new_var->SetType(framework::proto::VarType::RAW);
new_var->SetPersistable(true);
framework::OpDesc *gen_nccl_id_op = block->AppendOp();
gen_nccl_id_op->SetType("c_gen_nccl_id");
gen_nccl_id_op->SetOutput("Out", {tmp_var_name});
gen_nccl_id_op->SetAttr("rank", rank);
gen_nccl_id_op->SetAttr("endpoint", config_.current_endpoint);
gen_nccl_id_op->SetAttr("other_endpoints", peer_endpoints);
gen_nccl_id_op->SetAttr("ring_id", ring_id);
gen_nccl_id_op->SetAttr("op_role",
static_cast<int>(framework::OpRole::kForward));
gen_nccl_id_op->CheckAttrs();
framework::OpDesc *comm_init_op = block->AppendOp();
comm_init_op->SetType("c_comm_init");
comm_init_op->SetInput("X", {tmp_var_name});
comm_init_op->SetAttr("rank", rank);
comm_init_op->SetAttr("nranks", nranks);
comm_init_op->SetAttr("ring_id", ring_id);
comm_init_op->SetAttr("op_role",
static_cast<int>(framework::OpRole::kForward));
comm_init_op->CheckAttrs();
} else {
LOG(WARNING) << "DistModelInf doesn't init comm.";
// TODO(fleet exe dev): comm init for more devices
}
}
bool DistModel::PrepareScope() { bool DistModel::PrepareScope() {
scope_.reset(new framework::Scope()); scope_.reset(new framework::Scope());
return true; return true;
...@@ -119,6 +270,8 @@ bool DistModel::LoadParameters() { ...@@ -119,6 +270,8 @@ bool DistModel::LoadParameters() {
new_var->SetLoDLevel(var->GetLoDLevel()); new_var->SetLoDLevel(var->GetLoDLevel());
new_var->SetPersistable(true); new_var->SetPersistable(true);
params.push_back(new_var->Name()); params.push_back(new_var->Name());
// NOTE: if the params are stored in different files, 'load' op should be
// added here
} }
} }
......
...@@ -23,22 +23,30 @@ ...@@ -23,22 +23,30 @@
#include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/place.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
class ProgramDesc; class ProgramDesc;
class Scope; class Scope;
class BlockDesc;
} }
namespace distributed { namespace distributed {
struct DistModelConfig { struct DistModelConfig {
std::string model_dir{}; std::string model_dir{};
framework::ProgramDesc* program_desc{nullptr};
framework::Scope* scope{nullptr};
std::string place{};
int64_t device_id{0};
std::vector<std::string> trainer_endpoints{}; std::vector<std::string> trainer_endpoints{};
std::string current_endpoint{}; std::string current_endpoint{};
int64_t nranks{1}; int64_t nranks{1};
int64_t local_rank{0}; int64_t local_rank{0};
int64_t device_id{0};
int64_t mp_degree{1}; int64_t mp_degree{1};
int64_t pp_degree{1}; int64_t pp_degree{1};
int64_t mp_ring_id{-1};
int64_t pp_upstream_ring_id{-1};
int64_t pp_downstream_ring_id{-1};
}; };
class DistModel { class DistModel {
...@@ -56,12 +64,16 @@ class DistModel { ...@@ -56,12 +64,16 @@ class DistModel {
bool PrepareProgram(); bool PrepareProgram();
bool LoadProgram(); bool LoadProgram();
bool LoadParameters(); bool LoadParameters();
bool PreparePlace();
bool CommInit(); bool CommInit();
void InsertCommOp(std::string tmp_var_name, int nranks, int rank,
const std::vector<std::string>& peer_endpoints,
framework::BlockDesc* block, int ring_id);
DistModelConfig config_; DistModelConfig config_;
FleetExecutorDesc executor_desc_; FleetExecutorDesc executor_desc_;
platform::Place place_;
std::shared_ptr<framework::Scope> scope_; std::shared_ptr<framework::Scope> scope_;
paddle::platform::Place place_;
std::shared_ptr<framework::ProgramDesc> program_; std::shared_ptr<framework::ProgramDesc> program_;
}; };
......
...@@ -58,13 +58,21 @@ void BindFleetExecutor(py::module* m) { ...@@ -58,13 +58,21 @@ void BindFleetExecutor(py::module* m) {
py::class_<DistModelConfig>(*m, "DistModelConfig") py::class_<DistModelConfig>(*m, "DistModelConfig")
.def(py::init<>()) .def(py::init<>())
.def_readwrite("model_dir", &DistModelConfig::model_dir) .def_readwrite("model_dir", &DistModelConfig::model_dir)
.def_readwrite("program_desc", &DistModelConfig::program_desc)
.def_readwrite("scope", &DistModelConfig::scope)
.def_readwrite("place", &DistModelConfig::place)
.def_readwrite("device_id", &DistModelConfig::device_id)
.def_readwrite("trainer_endpoints", &DistModelConfig::trainer_endpoints) .def_readwrite("trainer_endpoints", &DistModelConfig::trainer_endpoints)
.def_readwrite("current_endpoint", &DistModelConfig::current_endpoint) .def_readwrite("current_endpoint", &DistModelConfig::current_endpoint)
.def_readwrite("nranks", &DistModelConfig::nranks) .def_readwrite("nranks", &DistModelConfig::nranks)
.def_readwrite("local_rank", &DistModelConfig::local_rank) .def_readwrite("local_rank", &DistModelConfig::local_rank)
.def_readwrite("device_id", &DistModelConfig::device_id)
.def_readwrite("mp_degree", &DistModelConfig::mp_degree) .def_readwrite("mp_degree", &DistModelConfig::mp_degree)
.def_readwrite("pp_degree", &DistModelConfig::pp_degree); .def_readwrite("pp_degree", &DistModelConfig::pp_degree)
.def_readwrite("mp_ring_id", &DistModelConfig::mp_ring_id)
.def_readwrite("pp_upstream_ring_id",
&DistModelConfig::pp_upstream_ring_id)
.def_readwrite("pp_downstream_ring_id",
&DistModelConfig::pp_downstream_ring_id);
py::class_<DistModel>(*m, "DistModel") py::class_<DistModel>(*m, "DistModel")
.def(py::init<const DistModelConfig&>()) .def(py::init<const DistModelConfig&>())
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册