未验证 提交 7d53a288 编写于 作者: Y Yuang Liu 提交者: GitHub

[fleet exe] Update comm init for dist model (#39603)

上级 e254e7c6
......@@ -53,7 +53,6 @@ bool LoadDataFromDistModelTensor(const DistModelTensor &input_data,
} else if (input_data.dtype == DistModelDataType::INT32) {
input_tensor_ptr = input_tensor->mutable_data<int32_t>(dims, place);
} else {
// Q(fleet exe dev): for input/output, should we support fp16
LOG(ERROR) << "unsupported feed type " << input_data.dtype;
return false;
}
......@@ -113,14 +112,6 @@ std::string DistModelDTypeToString(DistModelDataType dtype) {
return "NOT SUPPORT DTYPE";
}
bool IsPPFirstStage(const DistModelConfig &config) {
return config.local_rank - config.mp_degree < 0;
}
bool IsPPLastStage(const DistModelConfig &config) {
return config.local_rank + config.mp_degree >= config.nranks;
}
class DistModelTimer {
public:
void tic() { tic_time = std::chrono::high_resolution_clock::now(); }
......@@ -197,65 +188,34 @@ bool DistModel::PreparePlace() {
}
bool DistModel::CommInit() {
// NOTE (Yuang Liu): The peer endpoints will be obtained with the assumption
// that mp part is always on inner side and pp part is always on outer side.
// TODO(fleet exe dev): The peer endpoints could be configured by users.
PADDLE_ENFORCE_EQ(
config_.pp_degree * config_.mp_degree, config_.nranks,
platform::errors::InvalidArgument(
"The mp_degree multiplies pp_degree is not equal with nranks"));
std::unique_ptr<framework::ProgramDesc> comm_init_program(
new framework::ProgramDesc());
framework::BlockDesc *comm_init_block = comm_init_program->MutableBlock(0);
if (config_.mp_degree > 1) {
PADDLE_ENFORCE_GE(
config_.mp_ring_id, 0,
platform::errors::InvalidArgument(
"mp ring id must be provided for inference under mp."));
VLOG(3) << "Init comm group for mp.";
std::vector<int64_t> &ring_ids =
config_.rank_to_ring_ids_[config_.local_rank];
int64_t order = 0;
std::string var_name_base = "comm_init_";
for (int64_t ring_id : ring_ids) {
VLOG(3) << "Init comm for ring id: " << ring_id;
int64_t ranks_in_group = config_.ring_id_to_ranks_[ring_id].size();
int64_t rank_in_group = 0;
std::vector<int64_t> &ranks = config_.ring_id_to_ranks_[ring_id];
for (int64_t rank : ranks) {
if (config_.local_rank == rank) {
break;
}
rank_in_group += 1;
}
std::vector<std::string> peer_endpoints;
for (int64_t
idx = (config_.local_rank / config_.mp_degree) * config_.mp_degree,
i = 0;
i < config_.mp_degree; ++idx, ++i) {
if (config_.trainer_endpoints[idx] == config_.current_endpoint) {
for (int64_t rank : ranks) {
if (config_.local_rank == rank) {
continue;
}
peer_endpoints.emplace_back(config_.trainer_endpoints[idx]);
}
// get nranks in a mp group and inner group rank for local rank
int64_t mp_group_nranks = config_.nranks / config_.pp_degree;
int64_t mp_group_rank = config_.local_rank % config_.mp_degree;
InsertCommOp("mp_comm_id", mp_group_nranks, mp_group_rank, peer_endpoints,
comm_init_block, config_.mp_ring_id);
}
if (config_.pp_degree > 1) {
VLOG(3) << "Init comm group for pp.";
if (!IsPPFirstStage(config_)) {
PADDLE_ENFORCE_EQ(config_.pp_upstream_ring_id >= 0, true,
platform::errors::InvalidArgument(
"pp upstream ring id must be provided for "
"non-first pp stage if inference under pp."));
// not the first pp stage, has upstream
std::vector<std::string> upstream_peer_endpoints;
upstream_peer_endpoints.emplace_back(
config_.trainer_endpoints[config_.local_rank - config_.mp_degree]);
InsertCommOp("pp_upstream_comm_id", 2, 1, upstream_peer_endpoints,
comm_init_block, config_.pp_upstream_ring_id);
}
if (!IsPPLastStage(config_)) {
PADDLE_ENFORCE_EQ(config_.pp_downstream_ring_id >= 0, true,
platform::errors::InvalidArgument(
"pp downstream ring id must be provided for "
"non-last pp stage if inference under pp."));
// not the last pp stage, has downstream
std::vector<std::string> downstream_peer_endpoints;
downstream_peer_endpoints.emplace_back(
config_.trainer_endpoints[config_.local_rank + config_.mp_degree]);
InsertCommOp("pp_downstream_comm_id", 2, 0, downstream_peer_endpoints,
comm_init_block, config_.pp_downstream_ring_id);
peer_endpoints.emplace_back(config_.trainer_endpoints[rank]);
}
InsertCommOp(var_name_base + std::to_string(order), ranks_in_group,
rank_in_group, peer_endpoints, comm_init_block, ring_id);
order += 1;
}
framework::NaiveExecutor e(place_);
e.CreateVariables(*comm_init_program, 0, true, scope_.get());
......@@ -409,12 +369,7 @@ bool DistModel::LoadParameters() {
bool DistModel::PrepareFleetExe() {
task_node_.reset(new TaskNode(program_.get(), config_.local_rank));
if (config_.local_rank - config_.mp_degree >= 0) {
task_node_->AddUpstreamTask(config_.local_rank - config_.mp_degree);
}
if (config_.local_rank + config_.mp_degree < config_.nranks) {
task_node_->AddDownstreamTask(config_.local_rank + config_.mp_degree);
}
// With auto cut, there is no concept of pp, no need to add dependency.
task_node_->SetType("Compute");
task_node_->Init();
executor_desc_ = FleetExecutorDesc();
......@@ -473,40 +428,13 @@ bool DistModel::PrepareFeedAndFetch() {
}
}
if (config_.pp_degree == 1) {
if (feeds_.size() == 0) {
LOG(ERROR) << "No feed ops in the inf program, please check the program.";
return false;
}
if (fetches_.size() == 0) {
LOG(ERROR) << "No fetch op in the inf program, please check the program.";
return false;
}
} else {
if (IsPPFirstStage(config_)) {
if (feeds_.size() == 0) {
LOG(ERROR) << "Feed ops are needed for the first pp stage.";
return false;
}
} else {
if (feeds_.size() > 0) {
LOG(WARNING) << "Feed op is found in the non-first stage of pp.";
} else {
LOG(INFO) << "No feed ops in non-first pp stage.";
}
}
if (IsPPLastStage(config_)) {
if (fetches_.size() == 0) {
LOG(WARNING) << "No fetch op was found in the last pp stage. Make sure "
"the result has been sent to frist pp stage.";
}
} else {
if (fetches_.size() > 0) {
LOG(WARNING) << "Fetch op is found in the non-last stage of pp.";
} else {
LOG(INFO) << "No fetch op in non-last pp stage.";
}
}
if (feeds_.size() == 0) {
LOG(ERROR) << "No feed ops in the inf program, please check the program.";
return false;
}
if (fetches_.size() == 0) {
LOG(ERROR) << "No fetch op in the inf program, please check the program.";
return false;
}
return true;
}
......@@ -606,7 +534,6 @@ bool DistModel::FetchResult(const framework::LoDTensor &fetch,
bool DistModel::Run(const std::vector<DistModelTensor> &input_data,
std::vector<DistModelTensor> *output_data) {
// TODO(fleet exe dev): support pipeline inf mode
VLOG(3) << "DistModel run for once.";
DistModelTimer timer;
......
......@@ -13,6 +13,7 @@
// limitations under the License.
#pragma once
#include <map>
#include <memory>
#include <string>
#include <vector>
......@@ -47,12 +48,9 @@ struct DistModelConfig {
std::string current_endpoint{};
int64_t nranks{1};
int64_t local_rank{0};
int64_t mp_degree{1};
int64_t pp_degree{1};
int64_t mp_ring_id{-1};
int64_t pp_upstream_ring_id{-1};
int64_t pp_downstream_ring_id{-1};
bool enable_timer{false};
std::map<int64_t, std::vector<int64_t>> ring_id_to_ranks_{};
std::map<int64_t, std::vector<int64_t>> rank_to_ring_ids_{};
};
class DistModel {
......
......@@ -151,14 +151,9 @@ void BindFleetExecutor(py::module* m) {
.def_readwrite("current_endpoint", &DistModelConfig::current_endpoint)
.def_readwrite("nranks", &DistModelConfig::nranks)
.def_readwrite("local_rank", &DistModelConfig::local_rank)
.def_readwrite("mp_degree", &DistModelConfig::mp_degree)
.def_readwrite("pp_degree", &DistModelConfig::pp_degree)
.def_readwrite("mp_ring_id", &DistModelConfig::mp_ring_id)
.def_readwrite("enable_timer", &DistModelConfig::enable_timer)
.def_readwrite("pp_upstream_ring_id",
&DistModelConfig::pp_upstream_ring_id)
.def_readwrite("pp_downstream_ring_id",
&DistModelConfig::pp_downstream_ring_id);
.def_readwrite("ring_id_to_ranks", &DistModelConfig::ring_id_to_ranks_)
.def_readwrite("rank_to_ring_ids", &DistModelConfig::rank_to_ring_ids_)
.def_readwrite("enable_timer", &DistModelConfig::enable_timer);
py::class_<DistModel>(*m, "DistModel")
.def(py::init<const DistModelConfig&>())
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册