diff --git a/paddle/fluid/distributed/fleet_executor/dist_model.cc b/paddle/fluid/distributed/fleet_executor/dist_model.cc index 40b0a8b55e17a2eca26bb2c4d94221054724c530..941d470f87935f95abe5d599c9b7fa7a2730228b 100644 --- a/paddle/fluid/distributed/fleet_executor/dist_model.cc +++ b/paddle/fluid/distributed/fleet_executor/dist_model.cc @@ -53,7 +53,6 @@ bool LoadDataFromDistModelTensor(const DistModelTensor &input_data, } else if (input_data.dtype == DistModelDataType::INT32) { input_tensor_ptr = input_tensor->mutable_data(dims, place); } else { - // Q(fleet exe dev): for input/output, should we support fp16 LOG(ERROR) << "unsupported feed type " << input_data.dtype; return false; } @@ -113,14 +112,6 @@ std::string DistModelDTypeToString(DistModelDataType dtype) { return "NOT SUPPORT DTYPE"; } -bool IsPPFirstStage(const DistModelConfig &config) { - return config.local_rank - config.mp_degree < 0; -} - -bool IsPPLastStage(const DistModelConfig &config) { - return config.local_rank + config.mp_degree >= config.nranks; -} - class DistModelTimer { public: void tic() { tic_time = std::chrono::high_resolution_clock::now(); } @@ -197,65 +188,34 @@ bool DistModel::PreparePlace() { } bool DistModel::CommInit() { - // NOTE (Yuang Liu): The peer endpoints will be obtained with the assumption - // that mp part is always on inner side and pp part is always on outer side. - // TODO(fleet exe dev): The peer endpoints could be configured by users. - PADDLE_ENFORCE_EQ( - config_.pp_degree * config_.mp_degree, config_.nranks, - platform::errors::InvalidArgument( - "The mp_degree multiplies pp_degree is not equal with nranks")); std::unique_ptr comm_init_program( new framework::ProgramDesc()); framework::BlockDesc *comm_init_block = comm_init_program->MutableBlock(0); - if (config_.mp_degree > 1) { - PADDLE_ENFORCE_GE( - config_.mp_ring_id, 0, - platform::errors::InvalidArgument( - "mp ring id must be provided for inference under mp.")); - VLOG(3) << "Init comm group for mp."; + std::vector &ring_ids = + config_.rank_to_ring_ids_[config_.local_rank]; + int64_t order = 0; + std::string var_name_base = "comm_init_"; + for (int64_t ring_id : ring_ids) { + VLOG(3) << "Init comm for ring id: " << ring_id; + int64_t ranks_in_group = config_.ring_id_to_ranks_[ring_id].size(); + int64_t rank_in_group = 0; + std::vector &ranks = config_.ring_id_to_ranks_[ring_id]; + for (int64_t rank : ranks) { + if (config_.local_rank == rank) { + break; + } + rank_in_group += 1; + } std::vector peer_endpoints; - for (int64_t - idx = (config_.local_rank / config_.mp_degree) * config_.mp_degree, - i = 0; - i < config_.mp_degree; ++idx, ++i) { - if (config_.trainer_endpoints[idx] == config_.current_endpoint) { + for (int64_t rank : ranks) { + if (config_.local_rank == rank) { continue; } - peer_endpoints.emplace_back(config_.trainer_endpoints[idx]); - } - // get nranks in a mp group and inner group rank for local rank - int64_t mp_group_nranks = config_.nranks / config_.pp_degree; - int64_t mp_group_rank = config_.local_rank % config_.mp_degree; - InsertCommOp("mp_comm_id", mp_group_nranks, mp_group_rank, peer_endpoints, - comm_init_block, config_.mp_ring_id); - } - if (config_.pp_degree > 1) { - VLOG(3) << "Init comm group for pp."; - if (!IsPPFirstStage(config_)) { - PADDLE_ENFORCE_EQ(config_.pp_upstream_ring_id >= 0, true, - platform::errors::InvalidArgument( - "pp upstream ring id must be provided for " - "non-first pp stage if inference under pp.")); - // not the first pp stage, has upstream - std::vector upstream_peer_endpoints; - upstream_peer_endpoints.emplace_back( - config_.trainer_endpoints[config_.local_rank - config_.mp_degree]); - InsertCommOp("pp_upstream_comm_id", 2, 1, upstream_peer_endpoints, - comm_init_block, config_.pp_upstream_ring_id); - } - - if (!IsPPLastStage(config_)) { - PADDLE_ENFORCE_EQ(config_.pp_downstream_ring_id >= 0, true, - platform::errors::InvalidArgument( - "pp downstream ring id must be provided for " - "non-last pp stage if inference under pp.")); - // not the last pp stage, has downstream - std::vector downstream_peer_endpoints; - downstream_peer_endpoints.emplace_back( - config_.trainer_endpoints[config_.local_rank + config_.mp_degree]); - InsertCommOp("pp_downstream_comm_id", 2, 0, downstream_peer_endpoints, - comm_init_block, config_.pp_downstream_ring_id); + peer_endpoints.emplace_back(config_.trainer_endpoints[rank]); } + InsertCommOp(var_name_base + std::to_string(order), ranks_in_group, + rank_in_group, peer_endpoints, comm_init_block, ring_id); + order += 1; } framework::NaiveExecutor e(place_); e.CreateVariables(*comm_init_program, 0, true, scope_.get()); @@ -409,12 +369,7 @@ bool DistModel::LoadParameters() { bool DistModel::PrepareFleetExe() { task_node_.reset(new TaskNode(program_.get(), config_.local_rank)); - if (config_.local_rank - config_.mp_degree >= 0) { - task_node_->AddUpstreamTask(config_.local_rank - config_.mp_degree); - } - if (config_.local_rank + config_.mp_degree < config_.nranks) { - task_node_->AddDownstreamTask(config_.local_rank + config_.mp_degree); - } + // With auto cut, there is no concept of pp, no need to add dependency. task_node_->SetType("Compute"); task_node_->Init(); executor_desc_ = FleetExecutorDesc(); @@ -473,40 +428,13 @@ bool DistModel::PrepareFeedAndFetch() { } } - if (config_.pp_degree == 1) { - if (feeds_.size() == 0) { - LOG(ERROR) << "No feed ops in the inf program, please check the program."; - return false; - } - if (fetches_.size() == 0) { - LOG(ERROR) << "No fetch op in the inf program, please check the program."; - return false; - } - } else { - if (IsPPFirstStage(config_)) { - if (feeds_.size() == 0) { - LOG(ERROR) << "Feed ops are needed for the first pp stage."; - return false; - } - } else { - if (feeds_.size() > 0) { - LOG(WARNING) << "Feed op is found in the non-first stage of pp."; - } else { - LOG(INFO) << "No feed ops in non-first pp stage."; - } - } - if (IsPPLastStage(config_)) { - if (fetches_.size() == 0) { - LOG(WARNING) << "No fetch op was found in the last pp stage. Make sure " - "the result has been sent to frist pp stage."; - } - } else { - if (fetches_.size() > 0) { - LOG(WARNING) << "Fetch op is found in the non-last stage of pp."; - } else { - LOG(INFO) << "No fetch op in non-last pp stage."; - } - } + if (feeds_.size() == 0) { + LOG(ERROR) << "No feed ops in the inf program, please check the program."; + return false; + } + if (fetches_.size() == 0) { + LOG(ERROR) << "No fetch op in the inf program, please check the program."; + return false; } return true; } @@ -606,7 +534,6 @@ bool DistModel::FetchResult(const framework::LoDTensor &fetch, bool DistModel::Run(const std::vector &input_data, std::vector *output_data) { - // TODO(fleet exe dev): support pipeline inf mode VLOG(3) << "DistModel run for once."; DistModelTimer timer; diff --git a/paddle/fluid/distributed/fleet_executor/dist_model.h b/paddle/fluid/distributed/fleet_executor/dist_model.h index c980178b67c5244e751a8e89b945f353110a7456..d0203c131357c749b7df20a345982d2ddd025783 100644 --- a/paddle/fluid/distributed/fleet_executor/dist_model.h +++ b/paddle/fluid/distributed/fleet_executor/dist_model.h @@ -13,6 +13,7 @@ // limitations under the License. #pragma once +#include #include #include #include @@ -47,12 +48,9 @@ struct DistModelConfig { std::string current_endpoint{}; int64_t nranks{1}; int64_t local_rank{0}; - int64_t mp_degree{1}; - int64_t pp_degree{1}; - int64_t mp_ring_id{-1}; - int64_t pp_upstream_ring_id{-1}; - int64_t pp_downstream_ring_id{-1}; bool enable_timer{false}; + std::map> ring_id_to_ranks_{}; + std::map> rank_to_ring_ids_{}; }; class DistModel { diff --git a/paddle/fluid/pybind/bind_fleet_executor.cc b/paddle/fluid/pybind/bind_fleet_executor.cc index 0422a9cf8cc0ad984621fe09ee28bb7d624897d6..7bb7f03983eb9e8c88f46174a40664f1110682d1 100644 --- a/paddle/fluid/pybind/bind_fleet_executor.cc +++ b/paddle/fluid/pybind/bind_fleet_executor.cc @@ -151,14 +151,9 @@ void BindFleetExecutor(py::module* m) { .def_readwrite("current_endpoint", &DistModelConfig::current_endpoint) .def_readwrite("nranks", &DistModelConfig::nranks) .def_readwrite("local_rank", &DistModelConfig::local_rank) - .def_readwrite("mp_degree", &DistModelConfig::mp_degree) - .def_readwrite("pp_degree", &DistModelConfig::pp_degree) - .def_readwrite("mp_ring_id", &DistModelConfig::mp_ring_id) - .def_readwrite("enable_timer", &DistModelConfig::enable_timer) - .def_readwrite("pp_upstream_ring_id", - &DistModelConfig::pp_upstream_ring_id) - .def_readwrite("pp_downstream_ring_id", - &DistModelConfig::pp_downstream_ring_id); + .def_readwrite("ring_id_to_ranks", &DistModelConfig::ring_id_to_ranks_) + .def_readwrite("rank_to_ring_ids", &DistModelConfig::rank_to_ring_ids_) + .def_readwrite("enable_timer", &DistModelConfig::enable_timer); py::class_(*m, "DistModel") .def(py::init())