From eabb2105fae03db056dd85e50bf4e959417f4c63 Mon Sep 17 00:00:00 2001 From: chengduo Date: Mon, 7 Jan 2019 02:11:01 -0600 Subject: [PATCH] Refactor MultiDevSSAGraphBuilder (#15090) * Refactor ParallelExecutor test=develop * extract Reduce and AllReduce mode from MultiDevSSAGraphBuilder test=develop * Refactor MultiDevSSAGraphBuilder test=developt * Remove enable_data_balance test=develop * code refine test=develop * remove data balance test=develop * refine ScaleLossGradOp test=develop * remove uncessary file test=develop * code refine test=develop * modify function name test=develop * follow comments test=develop * add is_distribution field test=develop * set is_distribution test=develop * fix DistSSAGraphBuilder test=develop --- .../fluid/framework/details/build_strategy.cc | 54 +- .../fluid/framework/details/build_strategy.h | 8 +- .../details/multi_devices_graph_check_pass.cc | 104 ++- .../details/multi_devices_graph_check_pass.h | 38 - .../details/multi_devices_graph_pass.cc | 864 ++++++++++-------- .../details/multi_devices_graph_pass.h | 144 ++- paddle/fluid/pybind/pybind.cc | 11 +- python/paddle/fluid/parallel_executor.py | 14 + .../tests/unittests/test_reader_reset.py | 2 - 9 files changed, 701 insertions(+), 538 deletions(-) delete mode 100644 paddle/fluid/framework/details/multi_devices_graph_check_pass.h diff --git a/paddle/fluid/framework/details/build_strategy.cc b/paddle/fluid/framework/details/build_strategy.cc index 43c2eb717..a68b69e02 100644 --- a/paddle/fluid/framework/details/build_strategy.cc +++ b/paddle/fluid/framework/details/build_strategy.cc @@ -18,7 +18,7 @@ limitations under the License. */ #include #include "paddle/fluid/framework/details/memory_reuse_types.h" -#include "paddle/fluid/framework/details/multi_devices_graph_check_pass.h" +#include "paddle/fluid/framework/details/multi_devices_graph_pass.h" #include "paddle/fluid/framework/details/multi_devices_graph_print_pass.h" #include "paddle/fluid/framework/details/reduce_op_handle.h" #include "paddle/fluid/framework/details/sequential_execution_pass.h" @@ -86,10 +86,8 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder { if (strategy.memory_optimize_) { auto analysis_var_pass = AppendPass("analysis_var_pass"); } - // Convert graph to run on multi-devices. - auto multi_devices_pass = AppendPass("multi_devices_pass"); - multi_devices_pass->SetNotOwned("strategy", - &strategy_); + + AppendMultiDevPass(strategy); // Add a graph print pass to record a graph with device info. if (!strategy_.debug_graphviz_path_.empty()) { @@ -115,6 +113,25 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder { } } + // Convert graph to run on multi-devices. + void AppendMultiDevPass(const BuildStrategy &strategy) { + ir::Pass *multi_devices_pass; + if (strategy_.is_distribution_) { + multi_devices_pass = AppendPass("dist_multi_devices_pass").get(); + } else { + if (strategy.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce) { + multi_devices_pass = + AppendPass("allreduce_mode_multi_devices_pass").get(); + } else if (strategy.reduce_ == BuildStrategy::ReduceStrategy::kReduce) { + multi_devices_pass = AppendPass("reduce_mode_multi_devices_pass").get(); + } else { + PADDLE_THROW("Unknown reduce strategy."); + } + } + multi_devices_pass->SetNotOwned("strategy", + &strategy_); + } + private: BuildStrategy strategy_; }; @@ -131,6 +148,10 @@ std::shared_ptr BuildStrategy::CreatePassesFromStrategy( return pass_builder_; } +bool BuildStrategy::IsMultiDevPass(const std::string &pass_name) const { + return framework::details::MultiDevSSAGraphBuilder().count(pass_name) > 0; +} + std::unique_ptr BuildStrategy::Apply( const ProgramDesc &main_program, const std::vector &places, const std::string &loss_var_name, const std::vector &local_scopes, @@ -145,22 +166,23 @@ std::unique_ptr BuildStrategy::Apply( std::unique_ptr graph(new ir::Graph(main_program)); for (std::shared_ptr &pass : pass_builder_->AllPasses()) { - if (pass->Type() == "multi_devices_pass") { - pass->Erase("places"); - pass->SetNotOwned>("places", &places); - pass->Erase("loss_var_name"); - pass->SetNotOwned("loss_var_name", &loss_var_name); - pass->Erase("local_scopes"); - pass->SetNotOwned>("local_scopes", + if (IsMultiDevPass(pass->Type())) { + pass->Erase(kPlaces); + pass->SetNotOwned>(kPlaces, &places); + pass->Erase(kLossVarName); + pass->SetNotOwned(kLossVarName, &loss_var_name); + pass->Erase(kLocalScopes); + pass->SetNotOwned>(kLocalScopes, &local_scopes); - pass->Erase("nranks"); - pass->Set("nranks", new size_t(nranks)); + pass->Erase(kNRanks); + pass->Set(kNRanks, new size_t(nranks)); #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) platform::NCCLContextMap *nctx = use_cuda ? nccl_ctxs : nullptr; pass->Erase("nccl_ctxs"); pass->SetNotOwned("nccl_ctxs", nctx); #endif + } else if (pass->Type() == "analysis_var_pass") { const std::vector *all_op_descs = new std::vector(main_program.Block(0).AllOps()); @@ -201,7 +223,9 @@ std::unique_ptr BuildStrategy::Apply( USE_PASS(fuse_elewise_add_act_pass); USE_PASS(graph_viz_pass); USE_PASS(multi_batch_merge_pass); -USE_PASS(multi_devices_pass); +USE_PASS(reduce_mode_multi_devices_pass); +USE_PASS(allreduce_mode_multi_devices_pass); +USE_PASS(dist_multi_devices_pass); USE_PASS(multi_devices_check_pass); USE_PASS(multi_devices_print_pass); USE_PASS(analysis_var_pass); diff --git a/paddle/fluid/framework/details/build_strategy.h b/paddle/fluid/framework/details/build_strategy.h index b75c01c48..15c2e01b6 100644 --- a/paddle/fluid/framework/details/build_strategy.h +++ b/paddle/fluid/framework/details/build_strategy.h @@ -74,8 +74,6 @@ struct BuildStrategy { bool fuse_elewise_add_act_ops_{false}; - bool enable_data_balance_{false}; - bool memory_optimize_{false}; bool memory_early_delete_{false}; @@ -84,6 +82,10 @@ struct BuildStrategy { bool fuse_broadcast_op_{false}; + // FIXME(zcd): is_distribution_ is a temporary field, because in pserver mode, + // num_trainers is 1, so the current fields of build_strategy doesn't tell if + // it's distributed model. + bool is_distribution_{false}; int num_trainers_{1}; int trainer_id_{0}; std::vector trainers_endpoints_; @@ -104,6 +106,8 @@ struct BuildStrategy { bool IsFinalized() const { return is_finalized_; } + bool IsMultiDevPass(const std::string &pass_name) const; + // Apply the passes built by the pass_builder_. The passes will be // applied to the Program and output an ir::Graph. std::unique_ptr Apply(const ProgramDesc &main_program, diff --git a/paddle/fluid/framework/details/multi_devices_graph_check_pass.cc b/paddle/fluid/framework/details/multi_devices_graph_check_pass.cc index c8ea18804..a4bb1e26d 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_check_pass.cc +++ b/paddle/fluid/framework/details/multi_devices_graph_check_pass.cc @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/framework/details/multi_devices_graph_check_pass.h" #include +#include "paddle/fluid/framework/details/multi_devices_helper.h" #include "paddle/fluid/framework/ir/graph.h" #include "paddle/fluid/framework/ir/graph_helper.h" @@ -21,68 +21,78 @@ namespace paddle { namespace framework { namespace details { -bool SSAGraghBuilderWithChecker::IsValidGraph(const ir::Graph *graph) const { - std::unordered_map pending_ops; - std::unordered_set pending_vars; - std::unordered_set ready_vars; - std::unordered_set ready_ops; +class SSAGraghBuilderWithChecker : public ir::Pass { + protected: + std::unique_ptr ApplyImpl( + std::unique_ptr graph) const override { + PADDLE_ENFORCE(IsValidGraph(graph.get())); + return graph; + } - auto insert_pending_var = [&](VarHandleBase *var) { - pending_vars.insert(var); - if (var->GeneratedOp() == nullptr) { - ready_vars.emplace(var); - } - }; + bool IsValidGraph(const ir::Graph *graph) const { + std::unordered_map pending_ops; + std::unordered_set pending_vars; + std::unordered_set ready_vars; + std::unordered_set ready_ops; - for (auto &var_map : graph->Get(kGraphVars)) { - for (auto &name_pair : var_map) { - for (auto &version_pair : name_pair.second) { - insert_pending_var(version_pair); + auto insert_pending_var = [&](VarHandleBase *var) { + pending_vars.insert(var); + if (var->GeneratedOp() == nullptr) { + ready_vars.emplace(var); } - } - } + }; - for (auto &var : graph->Get(kGraphDepVars)) { - insert_pending_var(var); - } + for (auto &var_map : graph->Get(kGraphVars)) { + for (auto &name_pair : var_map) { + for (auto &version_pair : name_pair.second) { + insert_pending_var(version_pair); + } + } + } - for (OpHandleBase *op : ir::FilterByNodeWrapper(*graph)) { - if (op->Inputs().empty()) { - ready_ops.insert(op); - } else { - pending_ops.insert({op, op->NoDupInputSize()}); + for (auto &var : graph->Get(kGraphDepVars)) { + insert_pending_var(var); } - } - auto run_all_ops = [&](std::unordered_set &set) { - for (auto *op : set) { - for (auto out : op->Outputs()) { - ready_vars.emplace(out); + for (OpHandleBase *op : ir::FilterByNodeWrapper(*graph)) { + if (op->Inputs().empty()) { + ready_ops.insert(op); + } else { + pending_ops.insert({op, op->NoDupInputSize()}); } } - set.clear(); - }; - while (!pending_vars.empty()) { - run_all_ops(ready_ops); + auto run_all_ops = [&](std::unordered_set &set) { + for (auto *op : set) { + for (auto out : op->Outputs()) { + ready_vars.emplace(out); + } + } + set.clear(); + }; - if (ready_vars.empty()) { - return false; - } + while (!pending_vars.empty()) { + run_all_ops(ready_ops); - for (auto ready_var : ready_vars) { - pending_vars.erase(ready_var); - for (auto *op : ready_var->PendingOps()) { - auto &deps = --pending_ops[op]; - if (deps == 0) { - ready_ops.insert(op); + if (ready_vars.empty()) { + return false; + } + + for (auto ready_var : ready_vars) { + pending_vars.erase(ready_var); + for (auto *op : ready_var->PendingOps()) { + auto &deps = --pending_ops[op]; + if (deps == 0) { + ready_ops.insert(op); + } } } + ready_vars.clear(); } - ready_vars.clear(); + return true; } - return true; -} +}; + } // namespace details } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/details/multi_devices_graph_check_pass.h b/paddle/fluid/framework/details/multi_devices_graph_check_pass.h deleted file mode 100644 index 1e2b1867c..000000000 --- a/paddle/fluid/framework/details/multi_devices_graph_check_pass.h +++ /dev/null @@ -1,38 +0,0 @@ -// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include "paddle/fluid/framework/details/multi_devices_helper.h" - -#include - -namespace paddle { -namespace framework { -namespace details { - -class SSAGraghBuilderWithChecker : public ir::Pass { - protected: - std::unique_ptr ApplyImpl( - std::unique_ptr graph) const override { - PADDLE_ENFORCE(IsValidGraph(graph.get())); - return graph; - } - - bool IsValidGraph(const ir::Graph* graph) const; -}; - -} // namespace details -} // namespace framework -} // namespace paddle diff --git a/paddle/fluid/framework/details/multi_devices_graph_pass.cc b/paddle/fluid/framework/details/multi_devices_graph_pass.cc index 761c9ab90..d91993bd4 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_pass.cc +++ b/paddle/fluid/framework/details/multi_devices_graph_pass.cc @@ -134,15 +134,8 @@ void AddOutputToLeafOps(ir::Graph *graph) { } } // namespace -static const char kLossVarName[] = "loss_var_name"; -static const char kPlaces[] = "places"; -static const char kLocalScopes[] = "local_scopes"; -static const char kStrategy[] = "strategy"; -static const char kNRanks[] = "nranks"; - -void MultiDevSSAGraphBuilder::Init() const { +void MultiDevSSAGraphBuilderBase::Init() const { all_vars_.clear(); - balance_vars_.clear(); loss_var_name_ = Get(kLossVarName); places_ = Get>(kPlaces); @@ -151,31 +144,16 @@ void MultiDevSSAGraphBuilder::Init() const { #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) nccl_ctxs_ = &Get("nccl_ctxs"); #endif - - balance_vars_.resize(places_.size(), 0); - - if (strategy_.enable_data_balance_ && places_.size() == 1) { - LOG(WARNING) << "It is no need to enable data balance when there is only " - "one place. enable_data_balance is set to False."; - strategy_.enable_data_balance_ = false; - } } -std::unique_ptr MultiDevSSAGraphBuilder::ApplyImpl( +std::unique_ptr MultiDevSSAGraphBuilderBase::ApplyImpl( std::unique_ptr graph) const { Init(); - // Give the topology sort order and rebuild the graph structure. - std::vector sorted_ops = ir::TopologySortOperations(*graph); - - if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kReduce) { - sorted_ops = SortForReduceMode(sorted_ops); - } + std::vector sorted_ops = SortOperations(*graph); auto nodes = graph->ReleaseNodes(); ir::Graph &result = *graph; - size_t nranks = Get(kNRanks); - for (auto &node : nodes) { if (node->IsVar() && node->Var()) { all_vars_.emplace(node->Name(), node->Var()); @@ -187,146 +165,61 @@ std::unique_ptr MultiDevSSAGraphBuilder::ApplyImpl( result.Set(kGraphDepVars, new GraphDepVars); result.Set(kGraphOps, new GraphOps); - std::vector> bcast_var_name_set; - bcast_var_name_set.resize(places_.size()); - bool is_forwarding = true; - bool is_dist_train = false; - - std::unordered_map sharded_var_device; + bool insert_collection_ops = NeedCollectiveOps(); for (ir::Node *node : sorted_ops) { - if (OpHaveRole(*node, OpRole::kRPC)) { - int op_dev_id = CreateRPCOp(&result, node, &sharded_var_device); - PADDLE_ENFORCE(op_dev_id != -1, - "Can not schedule the RPC operator to the right place."); - if (node->Op()->Type() == "recv") { - auto recv_vars_attr = - boost::get>(node->Op()->GetNullableAttr( - OpProtoAndCheckerMaker::OpRoleVarAttrName())); - PADDLE_ENFORCE(recv_vars_attr.size() == 2UL); // [parameter, gradient] - if (recv_vars_attr[0].find(".block") == std::string::npos) { - bcast_var_name_set[op_dev_id].emplace(recv_vars_attr[0]); - } - } - is_dist_train = true; - } else if (OpHaveRole(*node, OpRole::kDist)) { - int op_dev_id = CreateDistTrainOp(&result, node, &sharded_var_device); - if (node->Op()->Type() == "concat") { - auto origin_param_name = node->Op()->OutputArgumentNames()[0]; - bcast_var_name_set[op_dev_id].emplace(origin_param_name); - } - } else if (IsScaleLossOp(node)) { - // user can customize loss@grad if not use_default_grad_scale_ - if (strategy_.gradient_scale_ != - BuildStrategy::GradientScaleStrategy::kCustomized) { - // TODO(paddle-dev): Why is there no input for this op_handle? - auto loss_grad_name = node->Op()->OutputArgumentNames()[0]; - auto out_dtype = all_vars_.at(loss_grad_name)->GetDataType(); - CreateScaleLossGradOp(&result, loss_grad_name, node->outputs[0], - out_dtype); - } - // This assumes the backward generating code will ensure IsScaleLossOp - // is true only for the op that scale the final scalar loss. - // It also assumes backward op will always follow the forward op in - // the block. - is_forwarding = false; + if (DealWithSpecialOp(&result, node)) { + continue; } else { - int op_dev_id = GetOpDeviceID(node, sharded_var_device); - if (op_dev_id != -1) { // This op only runs on one specific device. - CreateComputationalOp(&result, node, op_dev_id); - for (ir::Node *n : node->outputs) { - sharded_var_device.emplace(n->Name(), op_dev_id); - } + // This op runs on all devices + if (IsScaleLossOp(node)) { + // user can customize loss@grad if not use_default_grad_scale_ + InsertScaleLossGradOp(&result, node); + // This assumes the backward generating code will ensure IsScaleLossOp + // is true only for the op that scale the final scalar loss. + // It also assumes backward op will always follow the forward op in + // the block. + is_forwarding = false; } else { - // This op runs on all devices, and its output may have parameter's - // gradients. - // TODO(paddle-dev): Why is so special about "read" op? - if (node->Op()->Type() == "read" && strategy_.enable_data_balance_) { - node->Op()->SetAttr("throw_eof_exp", false); - CreateComputationalOps(&result, node, places_.size()); - const auto &data_var_names = node->Op()->Output("Out"); - InsertDataBalanceOp(&result, data_var_names); - } else { - CreateComputationalOps(&result, node, places_.size()); - } + CreateComputationalOps(&result, node, places_.size()); + } - if (!is_forwarding && nranks > 1UL) { + // Insert collection ops + if (!is_forwarding && insert_collection_ops) { + try { bool is_bk_op = static_cast(boost::get(node->Op()->GetAttr( OpProtoAndCheckerMaker::OpRoleAttrName())) & static_cast(OpRole::kBackward)); if (!is_bk_op) continue; + // Currently, we assume that once gradient is generated, it can be // broadcast, and each gradient is only broadcast once. - try { - auto backward_vars = boost::get>( - node->Op()->GetNullableAttr( - OpProtoAndCheckerMaker::OpRoleVarAttrName())); - - PADDLE_ENFORCE_EQ(backward_vars.size() % 2, 0); - - for (size_t i = 0; i < backward_vars.size(); i += 2) { - auto &p_name = backward_vars[i]; - auto &g_name = backward_vars[i + 1]; - VLOG(10) << "Bcast " << g_name << " for parameter " << p_name; - size_t cur_device_id = -1; - switch (strategy_.reduce_) { - case BuildStrategy::ReduceStrategy::kReduce: - cur_device_id = GetAppropriateDeviceID({g_name}); - CreateReduceOp(&result, g_name, cur_device_id); - sharded_var_device.emplace(g_name, cur_device_id); - if (!is_dist_train) { - bcast_var_name_set[cur_device_id].emplace(p_name); - } - break; - case BuildStrategy::ReduceStrategy::kAllReduce: - if (IsSparseGradient(g_name)) { - CreateReduceOp(&result, g_name, 0); - CreateBroadcastOp(&result, g_name, 0); - } else { - InsertAllReduceOp(&result, g_name); - } - break; - default: - LOG(FATAL) << "Unknown reduce strategy "; - break; - } - } - } catch (boost::bad_get e) { + auto backward_vars = + boost::get>(node->Op()->GetNullableAttr( + OpProtoAndCheckerMaker::OpRoleVarAttrName())); + PADDLE_ENFORCE_EQ(backward_vars.size() % 2, 0); + + for (size_t i = 0; i < backward_vars.size(); i += 2) { + auto &p_name = backward_vars[i]; + auto &g_name = backward_vars[i + 1]; + VLOG(10) << "Bcast " << g_name << " for parameter " << p_name; + + InsertCollectiveOp(&result, p_name, g_name); } + } catch (boost::bad_get e) { } } } } - bool use_gpu = false; -#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) - use_gpu = nccl_ctxs_ != nullptr; -#endif - // Insert broadcast operators principle: - // 1. Broadcast optimized parameters in Reduce strategy; - // 2. No need broadcast optimized parameters in AllReduce strategy because of - // the optimization sub-graph would be run on every GPU; - // 3. Allways broadcast received parameters in Distribute Training. - if ((use_gpu && - strategy_.reduce_ == BuildStrategy::ReduceStrategy::kReduce) || - is_dist_train) { - if (strategy_.fuse_broadcast_op_) { - CreateFusedBroadcastOp(&result, bcast_var_name_set); - } else { - for (size_t dev_id = 0; dev_id < bcast_var_name_set.size(); ++dev_id) { - auto &to_bcast_set = bcast_var_name_set[dev_id]; - for (auto &bcast_name : to_bcast_set) { - CreateBroadcastOp(&result, bcast_name, dev_id); - } - } - } - } + InsertPostprocessOps(&result); + /* Dependency graph has been constructed. However, there are still data hazards need to be handled. - */ + */ PolishGraphToSupportDataHazards(&result); /* @@ -337,67 +230,54 @@ std::unique_ptr MultiDevSSAGraphBuilder::ApplyImpl( return graph; } -std::vector MultiDevSSAGraphBuilder::SortForReduceMode( - const std::vector &topo_ops) const { - std::unordered_map sharded_var_device; - std::vector sorted_ops; - std::unordered_map> delayed_op; - sorted_ops.reserve(topo_ops.size()); - - auto insert_delayed_op = [&](const std::string &var_name, int dev_id) { - sharded_var_device.emplace(var_name, dev_id); - if (delayed_op.count(var_name)) { - auto &ops = delayed_op.at(var_name); - sorted_ops.insert(sorted_ops.end(), ops.begin(), ops.end()); - delayed_op.at(var_name).clear(); - } - }; +void MultiDevSSAGraphBuilderBase::InsertScaleLossGradOp( + ir::Graph *result, const ir::Node *node) const { + // user can customize loss@grad if not use_default_grad_scale_ + size_t loss_scale = 0; + switch (this->strategy_.gradient_scale_) { + case BuildStrategy::GradientScaleStrategy::kOne: + loss_scale = 1; + break; + case BuildStrategy::GradientScaleStrategy::kCoeffNumDevice: + loss_scale = Get(kNRanks); + break; + case BuildStrategy::GradientScaleStrategy::kCustomized: + loss_scale = 0; + break; + default: + LOG(FATAL) << "Unknown gradient scale strategy."; + break; + } + + if (loss_scale) { + // TODO(paddle-dev): Why is there no input for this op_handle? + auto loss_grad_name = node->Op()->OutputArgumentNames()[0]; + auto out_dtype = this->all_vars_.at(loss_grad_name)->GetDataType(); + this->CreateScaleLossGradOp(result, loss_grad_name, node->outputs[0], + loss_scale, out_dtype); + } +} - for (ir::Node *node : topo_ops) { - int op_dev_id = GetOpDeviceID(node, sharded_var_device, &delayed_op); - if (op_dev_id > -1) { - // This op only runs on one specific device. - sorted_ops.emplace_back(node); - for (ir::Node *n : node->outputs) { - insert_delayed_op(n->Name(), op_dev_id); - } - } else if (op_dev_id == -1) { - // This op runs on all devices, and its output may have parameter's - // gradients. - sorted_ops.emplace_back(node); - bool is_bk_op = - static_cast(boost::get(node->Op()->GetAttr( - OpProtoAndCheckerMaker::OpRoleAttrName())) & - static_cast(OpRole::kBackward)); - if (!is_bk_op) continue; - // Currently, we assume that once gradient is generated, it can be - // broadcast, and each gradient is only broadcast once. - std::vector backward_vars; - try { - backward_vars = - boost::get>(node->Op()->GetNullableAttr( - OpProtoAndCheckerMaker::OpRoleVarAttrName())); - } catch (boost::bad_get e) { - } - PADDLE_ENFORCE_EQ(backward_vars.size() % 2, 0); +std::vector MultiDevSSAGraphBuilderBase::SortOperations( + const ir::Graph &graph) const { + return ir::TopologySortOperations(graph); +} - for (size_t i = 0; i < backward_vars.size(); i += 2) { - auto &g_name = backward_vars[i + 1]; - size_t cur_device_id = GetAppropriateDeviceID({g_name}); - insert_delayed_op(g_name, static_cast(cur_device_id)); - } - } else if (op_dev_id == -2) { - // The Op on which the Op depends has not yet been generated. - } - } +bool MultiDevSSAGraphBuilderBase::UseGPU() const { + bool use_gpu = false; +#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) + use_gpu = nccl_ctxs_ != nullptr; +#endif + return use_gpu; +} - PADDLE_ENFORCE_EQ(sorted_ops.size(), topo_ops.size()); - return sorted_ops; +bool MultiDevSSAGraphBuilderBase::NeedCollectiveOps() const { + return Get(kNRanks) > 1; } -void MultiDevSSAGraphBuilder::CreateOpHandleIOs(ir::Graph *result, - ir::Node *node, - size_t place_id) const { +void MultiDevSSAGraphBuilderBase::CreateOpHandleIOs(ir::Graph *result, + ir::Node *node, + size_t place_id) const { auto p = places_[place_id]; auto *op_handle = result->Get(kGraphOps).back(); op_handle->SetDeviceContext(p, @@ -420,28 +300,7 @@ void MultiDevSSAGraphBuilder::CreateOpHandleIOs(ir::Graph *result, } } -size_t MultiDevSSAGraphBuilder::GetAppropriateDeviceID( - const std::vector &var_names) const { - int64_t numel_sum = 0; - for (auto var_name : var_names) { - if (all_vars_.find(var_name) == all_vars_.end()) continue; - auto var_desc = all_vars_.at(var_name); - PADDLE_ENFORCE_NOT_NULL(var_desc); - auto dim = framework::make_ddim(var_desc->GetShape()); - int64_t numel = framework::product(dim); - PADDLE_ENFORCE_GT(numel, 0); - numel_sum += numel; - } - - auto smallest = - std::min_element(std::begin(balance_vars_), std::end(balance_vars_)); - size_t dev_id = - static_cast(std::distance(std::begin(balance_vars_), smallest)); - balance_vars_[dev_id] += numel_sum; - return dev_id; -} - -void MultiDevSSAGraphBuilder::SetCommunicationContext( +void MultiDevSSAGraphBuilderBase::SetCommunicationContext( OpHandleBase *op_handle, const platform::Place &p) const { #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) if (nccl_ctxs_ == nullptr) { @@ -454,9 +313,9 @@ void MultiDevSSAGraphBuilder::SetCommunicationContext( #endif } -void MultiDevSSAGraphBuilder::CreateBroadcastOp(ir::Graph *result, - const std::string &p_name, - size_t src_dev_id) const { +void MultiDevSSAGraphBuilderBase::CreateBroadcastOp(ir::Graph *result, + const std::string &p_name, + size_t src_dev_id) const { #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) auto *op_handle = new BroadcastOpHandle( result->CreateEmptyNode("broadcast", ir::Node::Type::kOperation), @@ -484,7 +343,7 @@ void MultiDevSSAGraphBuilder::CreateBroadcastOp(ir::Graph *result, } } -void MultiDevSSAGraphBuilder::CreateFusedBroadcastOp( +void MultiDevSSAGraphBuilderBase::CreateFusedBroadcastOp( ir::Graph *result, const std::vector> &bcast_varnames) const { #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) @@ -522,17 +381,17 @@ void MultiDevSSAGraphBuilder::CreateFusedBroadcastOp( } } -void MultiDevSSAGraphBuilder::CreateComputationalOp(ir::Graph *result, - ir::Node *node, - int dev_id) const { +void MultiDevSSAGraphBuilderBase::CreateComputationalOp(ir::Graph *result, + ir::Node *node, + int dev_id) const { result->Get(kGraphOps).emplace_back( new ComputationOpHandle(result->CreateOpNode(node->Op()), local_scopes_[dev_id], places_[dev_id], dev_id)); CreateOpHandleIOs(result, node, dev_id); } -void MultiDevSSAGraphBuilder::InsertAllReduceOp(ir::Graph *result, - const std::string &og) const { +void MultiDevSSAGraphBuilderBase::CreateAllReduceOp( + ir::Graph *result, const std::string &og) const { #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) result->Get(kGraphOps).emplace_back(new AllReduceOpHandle( result->CreateEmptyNode("allreduce", ir::Node::Type::kOperation), @@ -560,102 +419,15 @@ void MultiDevSSAGraphBuilder::InsertAllReduceOp(ir::Graph *result, } } -void MultiDevSSAGraphBuilder::InsertDataBalanceOp( - ir::Graph *result, const std::vector &datas) const { -#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) - result->Get(kGraphOps).emplace_back(new DataBalanceOpHandle( - result->CreateEmptyNode("data_balance", ir::Node::Type::kOperation), - local_scopes_, places_, nccl_ctxs_)); -#else - result->Get(kGraphOps).emplace_back(new DataBalanceOpHandle( - result->CreateEmptyNode("data_balance", ir::Node::Type::kOperation), - local_scopes_, places_)); -#endif - auto *op_handle = result->Get(kGraphOps).back(); - for (size_t i = 0; i < places_.size(); ++i) { - auto &p = places_[i]; - SetCommunicationContext(op_handle, p); - for (const std::string &d_name : datas) { - auto &vars = result->Get(kGraphVars)[i][d_name]; - PADDLE_ENFORCE(!vars.empty()); - op_handle->AddInput(vars.back()); - auto var = new VarHandle( - result->CreateEmptyNode(d_name, ir::Node::Type::kVariable), - vars.size(), i, d_name, p); - vars.emplace_back(var); - op_handle->AddOutput(var); - } - } -} - -int MultiDevSSAGraphBuilder::GetOpDeviceID( - ir::Node *node, - const std::unordered_map &sharded_var_device, - std::unordered_map> *delay_ops) const { - if (strategy_.reduce_ != BuildStrategy::ReduceStrategy::kReduce) { - return -1; - } - - if (!OpHaveRole(*node, framework::OpRole::kOptimize)) { - return -1; - } - - auto param_grad = boost::get>( - node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleVarAttrName())); - - PADDLE_ENFORCE_EQ(param_grad.size(), 2U); - int dev_id = GetVarDeviceID(param_grad[1], sharded_var_device); - - if (dev_id == -1) { - (*delay_ops)[param_grad[1]].push_back(node); - return -2; - } - return dev_id; -} - -int MultiDevSSAGraphBuilder::GetOpDeviceID( - ir::Node *node, - const std::unordered_map &sharded_var_device) const { - if (strategy_.reduce_ != BuildStrategy::ReduceStrategy::kReduce) { - return -1; - } - - if (!OpHaveRole(*node, framework::OpRole::kOptimize)) { - return -1; - } - auto param_grad = boost::get>( - node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleVarAttrName())); - - PADDLE_ENFORCE_EQ(param_grad.size(), 2U); - int dev_id = GetVarDeviceID(param_grad[1], sharded_var_device); - PADDLE_ENFORCE_NE(dev_id, -1, "dev_id should not be -1.[%s, %s, %s]", - node->Op()->Type(), param_grad[0], param_grad[1]); - return dev_id; -} - -int MultiDevSSAGraphBuilder::GetVarDeviceID( - const std::string &varname, - const std::unordered_map &sharded_var_device) const { - auto got = sharded_var_device.find(varname); - if (got == sharded_var_device.end()) { - auto pos = varname.find(framework::kNewGradSuffix); - if (pos != std::string::npos) { - got = sharded_var_device.find(varname.substr(0, pos)); - } - } - return got == sharded_var_device.end() ? -1 : got->second; -} - -void MultiDevSSAGraphBuilder::CreateScaleLossGradOp( +void MultiDevSSAGraphBuilderBase::CreateScaleLossGradOp( ir::Graph *result, const std::string &loss_grad_name, - ir::Node *out_var_node, proto::VarType::Type dtype) const { - size_t nranks = Get("nranks"); + ir::Node *out_var_node, size_t loss_scale, + proto::VarType::Type dtype) const { for (size_t i = 0; i < places_.size(); ++i) { - // Insert ScaleCost OpHandle auto *dev_ctx = platform::DeviceContextPool::Instance().Get(places_[i]); auto *op_handle = new ScaleLossGradOpHandle( result->CreateEmptyNode("scale_loss_grad", ir::Node::Type::kOperation), - nranks, local_scopes_[i], places_[i], dev_ctx, dtype); + loss_scale, local_scopes_[i], places_[i], dev_ctx, dtype); result->Get(kGraphOps).emplace_back(op_handle); // FIXME: Currently ScaleLossGradOp only use device_count as scale @@ -669,9 +441,8 @@ void MultiDevSSAGraphBuilder::CreateScaleLossGradOp( } } -void MultiDevSSAGraphBuilder::CreateComputationalOps(ir::Graph *result, - ir::Node *node, - size_t num_places) const { +void MultiDevSSAGraphBuilderBase::CreateComputationalOps( + ir::Graph *result, ir::Node *node, size_t num_places) const { for (size_t scope_idx = 0; scope_idx < num_places; ++scope_idx) { auto p = places_[scope_idx]; auto s = local_scopes_[scope_idx]; @@ -681,9 +452,9 @@ void MultiDevSSAGraphBuilder::CreateComputationalOps(ir::Graph *result, } } -VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(ir::Graph *result, - const std::string &og, - int dst_dev_id) const { +VarHandle *MultiDevSSAGraphBuilderBase::CreateReduceOp(ir::Graph *result, + const std::string &og, + int dst_dev_id) const { #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) result->Get(kGraphOps).emplace_back(new ReduceOpHandle( result->CreateEmptyNode("reduce", ir::Node::Type::kOperation), @@ -712,51 +483,273 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(ir::Graph *result, return var; } -int MultiDevSSAGraphBuilder::CreateDistTrainOp( - ir::Graph *result, ir::Node *node, - std::unordered_map *sharded_var_device) const { - int op_dev_id = -1; - std::vector input_var_names; - std::vector output_var_names; - for (ir::Node *input : node->inputs) { - input_var_names.push_back(input->Name()); +bool MultiDevSSAGraphBuilderBase::IsScaleLossOp(ir::Node *node) const { + return boost::get( + node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) == + (static_cast(OpRole::kBackward) | + static_cast(OpRole::kLoss)) && + !loss_var_name_.empty(); // If loss_var is empty. This is test mode +} + +bool MultiDevSSAGraphBuilderBase::IsSparseGradient( + const std::string &og) const { + PADDLE_ENFORCE(all_vars_.count(og) != 0); + if (all_vars_.at(og)->GetType() == proto::VarType::SELECTED_ROWS) { + return true; } - for (ir::Node *output : node->outputs) { - output_var_names.push_back(output->Name()); + return false; +} + +void AllReduceSSAGraphBuilder::InsertCollectiveOp( + ir::Graph *result, const std::string &p_name, + const std::string &g_name) const { + if (IsSparseGradient(g_name)) { + CreateReduceOp(result, g_name, 0); + CreateBroadcastOp(result, g_name, 0); + } else { + CreateAllReduceOp(result, g_name); } +} - if (node->Op()->Type() == "split_byref" || - node->Op()->Type() == "split_selected_rows" || - node->Op()->Type() == "split_ids") { - // TODO(paddle-dev): getting the first var is not safe. - op_dev_id = GetVarDeviceID(input_var_names[0], *sharded_var_device); - if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce) { - op_dev_id = GetAppropriateDeviceID(input_var_names); - for (auto &varname : input_var_names) { - sharded_var_device->emplace(varname, op_dev_id); +int BalanceVarSSAGraphBuilder::GetVarDeviceID( + const std::string &varname) const { + auto got = sharded_var_device_.find(varname); + if (got == sharded_var_device_.end()) { + auto pos = varname.find(framework::kNewGradSuffix); + if (pos != std::string::npos) { + got = sharded_var_device_.find(varname.substr(0, pos)); + } + } + return got == sharded_var_device_.end() ? -1 : got->second; +} + +int BalanceVarSSAGraphBuilder::GetOpDeviceID(ir::Node *node) const { + if (strategy_.reduce_ != BuildStrategy::ReduceStrategy::kReduce) { + return -1; + } + if (!OpHaveRole(*node, framework::OpRole::kOptimize)) { + return -1; + } + auto param_grad = boost::get>( + node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleVarAttrName())); + + PADDLE_ENFORCE_EQ(param_grad.size(), 2U); + int dev_id = GetVarDeviceID(param_grad[1]); + PADDLE_ENFORCE_NE(dev_id, -1, "dev_id should not be -1.[%s, %s, %s]", + node->Op()->Type(), param_grad[0], param_grad[1]); + return dev_id; +} + +size_t BalanceVarSSAGraphBuilder::GetAppropriateDeviceID( + const std::vector &var_names) const { + int64_t numel_sum = 0; + for (auto var_name : var_names) { + if (all_vars_.find(var_name) == all_vars_.end()) continue; + auto var_desc = all_vars_.at(var_name); + PADDLE_ENFORCE_NOT_NULL(var_desc); + auto dim = framework::make_ddim(var_desc->GetShape()); + int64_t numel = framework::product(dim); + PADDLE_ENFORCE_GT(numel, 0); + numel_sum += numel; + } + + auto smallest = + std::min_element(std::begin(balance_vars_), std::end(balance_vars_)); + size_t dev_id = + static_cast(std::distance(std::begin(balance_vars_), smallest)); + balance_vars_[dev_id] += numel_sum; + return dev_id; +} + +void BalanceVarSSAGraphBuilder::ResetState() const { + balance_vars_.clear(); + sharded_var_device_.clear(); + + balance_vars_.resize(places_.size(), 0); +} + +void ReduceSSAGraphBuilder::Init() const { + MultiDevSSAGraphBuilderBase::Init(); + ResetState(); +} + +void ReduceSSAGraphBuilder::ResetState() const { + BalanceVarSSAGraphBuilder::ResetState(); + bcast_var_name_set_.clear(); + bcast_var_name_set_.resize(places_.size()); +} + +void ReduceSSAGraphBuilder::InsertCollectiveOp( + ir::Graph *result, const std::string &p_name, + const std::string &g_name) const { + size_t cur_device_id = GetAppropriateDeviceID({g_name}); + CreateReduceOp(result, g_name, cur_device_id); + sharded_var_device_.emplace(g_name, cur_device_id); + bcast_var_name_set_[cur_device_id].emplace(p_name); +} + +bool ReduceSSAGraphBuilder::DealWithSpecialOp(ir::Graph *result, + ir::Node *node) const { + int op_dev_id = BalanceVarSSAGraphBuilder::GetOpDeviceID(node); + if (op_dev_id != -1) { + // This op only runs on one specific device. + CreateComputationalOp(result, node, op_dev_id); + for (ir::Node *n : node->outputs) { + sharded_var_device_.emplace(n->Name(), op_dev_id); + } + return true; + } + return false; +} + +void ReduceSSAGraphBuilder::InsertPostprocessOps(ir::Graph *result) const { + if (UseGPU()) { + if (strategy_.fuse_broadcast_op_) { + CreateFusedBroadcastOp(result, bcast_var_name_set_); + } else { + for (size_t dev_id = 0; dev_id < bcast_var_name_set_.size(); ++dev_id) { + auto &to_bcast_set = bcast_var_name_set_[dev_id]; + for (auto &bcast_name : to_bcast_set) { + CreateBroadcastOp(result, bcast_name, dev_id); + } } } - for (auto &varname : output_var_names) { - sharded_var_device->emplace(varname, op_dev_id); + } +} + +int ReduceSSAGraphBuilder::GetOpDeviceID( + ir::Node *node, + std::unordered_map> *delay_ops) const { + if (!OpHaveRole(*node, framework::OpRole::kOptimize)) { + return -1; + } + + auto param_grad = boost::get>( + node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleVarAttrName())); + + PADDLE_ENFORCE_EQ(param_grad.size(), 2U); + int dev_id = GetVarDeviceID(param_grad[1]); + + if (dev_id == -1) { + (*delay_ops)[param_grad[1]].push_back(node); + return -2; + } + return dev_id; +} + +std::vector ReduceSSAGraphBuilder::SortOperations( + const ir::Graph &graph) const { + std::vector sorted_ops = ir::TopologySortOperations(graph); + return SortForReduceMode(sorted_ops); +} + +std::vector ReduceSSAGraphBuilder::SortForReduceMode( + const std::vector &topo_ops) const { + std::vector sorted_ops; + std::unordered_map> delayed_op; + sorted_ops.reserve(topo_ops.size()); + ResetState(); + + auto insert_delayed_op = [&](const std::string &var_name, int dev_id) { + sharded_var_device_.emplace(var_name, dev_id); + if (delayed_op.count(var_name)) { + auto &ops = delayed_op.at(var_name); + sorted_ops.insert(sorted_ops.end(), ops.begin(), ops.end()); + delayed_op.at(var_name).clear(); } - } else if (node->Op()->Type() == "concat") { - op_dev_id = GetVarDeviceID(input_var_names[0], *sharded_var_device); - for (auto &varname : output_var_names) { - sharded_var_device->emplace(varname, op_dev_id); + }; + + for (ir::Node *node : topo_ops) { + int op_dev_id = GetOpDeviceID(node, &delayed_op); + if (op_dev_id > -1) { + // This op only runs on one specific device. + sorted_ops.emplace_back(node); + for (ir::Node *n : node->outputs) { + insert_delayed_op(n->Name(), op_dev_id); + } + } else if (op_dev_id == -1) { + // This op runs on all devices, and its output may have parameter's + // gradients. + sorted_ops.emplace_back(node); + bool is_bk_op = + static_cast(boost::get(node->Op()->GetAttr( + OpProtoAndCheckerMaker::OpRoleAttrName())) & + static_cast(OpRole::kBackward)); + if (!is_bk_op) continue; + // Currently, we assume that once gradient is generated, it can be + // broadcast, and each gradient is only broadcast once. + std::vector backward_vars; + try { + backward_vars = + boost::get>(node->Op()->GetNullableAttr( + OpProtoAndCheckerMaker::OpRoleVarAttrName())); + } catch (boost::bad_get e) { + } + PADDLE_ENFORCE_EQ(backward_vars.size() % 2, 0); + + for (size_t i = 0; i < backward_vars.size(); i += 2) { + auto &g_name = backward_vars[i + 1]; + size_t cur_device_id = GetAppropriateDeviceID({g_name}); + insert_delayed_op(g_name, static_cast(cur_device_id)); + } + } else if (op_dev_id == -2) { + // The Op on which the Op depends has not yet been generated. } - } else { - LOG(ERROR) << "got unexpected dist op: " << node->Op()->Type(); - PADDLE_THROW( - "the distribute training related op should be in [split_byref, " - "concat]."); } - PADDLE_ENFORCE(op_dev_id != -1, - "can not find right place for distributed op: %s", - node->Op()->Type()); + PADDLE_ENFORCE_EQ(sorted_ops.size(), topo_ops.size()); - CreateComputationalOp(result, node, op_dev_id); - return op_dev_id; + ResetState(); + return sorted_ops; +} + +void DistSSAGraphBuilder::Init() const { + MultiDevSSAGraphBuilderBase::Init(); + ResetState(); +} + +void DistSSAGraphBuilder::ResetState() const { + BalanceVarSSAGraphBuilder::ResetState(); + bcast_var_name_set_.clear(); + bcast_var_name_set_.resize(places_.size()); +} + +bool DistSSAGraphBuilder::DealWithSpecialOp(ir::Graph *result, + ir::Node *node) const { + bool insert_op = false; + if (OpHaveRole(*node, OpRole::kRPC)) { + int op_dev_id = CreateRPCOp(result, node); + PADDLE_ENFORCE(op_dev_id != -1, + "Can not schedule the RPC operator to the right place."); + if (node->Op()->Type() == "recv") { + auto recv_vars_attr = + boost::get>(node->Op()->GetNullableAttr( + OpProtoAndCheckerMaker::OpRoleVarAttrName())); + PADDLE_ENFORCE(recv_vars_attr.size() == 2UL); // [parameter, gradient] + if (recv_vars_attr[0].find(".block") == std::string::npos) { + bcast_var_name_set_[op_dev_id].emplace(recv_vars_attr[0]); + } + } + insert_op = true; + need_broadcast_var_ = true; + } else if (OpHaveRole(*node, OpRole::kDist)) { + int op_dev_id = CreateDistTrainOp(result, node); + if (node->Op()->Type() == "concat") { + auto origin_param_name = node->Op()->OutputArgumentNames()[0]; + bcast_var_name_set_[op_dev_id].emplace(origin_param_name); + } + insert_op = true; + } else { + int op_dev_id = GetOpDeviceID(node); + if (op_dev_id != -1) { // This op only runs on one specific device. + CreateComputationalOp(result, node, op_dev_id); + for (ir::Node *n : node->outputs) { + sharded_var_device_.emplace(n->Name(), op_dev_id); + } + insert_op = true; + } + } + return insert_op; } void SetOpInputsAllPlaces(ir::Graph *result, ir::Node *node, int num_places) { @@ -775,13 +768,11 @@ void SetOpInputsAllPlaces(ir::Graph *result, ir::Node *node, int num_places) { } // Create RPC related op handles that connects its in ops and out ops. -int MultiDevSSAGraphBuilder::CreateRPCOp( - ir::Graph *result, ir::Node *node, - std::unordered_map *sharded_var_device) const { +int DistSSAGraphBuilder::CreateRPCOp(ir::Graph *result, ir::Node *node) const { int op_dev_id = -1; if (node->Op()->Type() == "send") { // TODO(paddle-dev): getting the first var is not safe. - op_dev_id = GetVarDeviceID(node->inputs[0]->Name(), *sharded_var_device); + op_dev_id = GetVarDeviceID(node->inputs[0]->Name()); PADDLE_ENFORCE(!ir::IsControlDepVar(*node->inputs[0]), "This hack no longer holds, please fix."); // the variable name which contains .block means it was splited by @@ -799,9 +790,9 @@ int MultiDevSSAGraphBuilder::CreateRPCOp( VLOG(10) << "send grad " << input_var_names[0] << " origin " << send_param_grad[1] << " place: " << op_dev_id; for (auto &varname : input_var_names) { - sharded_var_device->emplace(varname, op_dev_id); + sharded_var_device_.emplace(varname, op_dev_id); } - sharded_var_device->emplace(send_param_grad[1], op_dev_id); + sharded_var_device_.emplace(send_param_grad[1], op_dev_id); } } else if (node->Op()->Type() == "recv") { std::vector output_var_names; @@ -811,7 +802,7 @@ int MultiDevSSAGraphBuilder::CreateRPCOp( auto recv_param_grad = boost::get>( node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleVarAttrName())); if (recv_param_grad.size() == 2U) { - op_dev_id = GetVarDeviceID(recv_param_grad[1], *sharded_var_device); + op_dev_id = GetVarDeviceID(recv_param_grad[1]); VLOG(10) << "recv param " << recv_param_grad[0] << " get grad place: " << recv_param_grad[1] << " place: " << op_dev_id; @@ -819,7 +810,7 @@ int MultiDevSSAGraphBuilder::CreateRPCOp( op_dev_id = GetAppropriateDeviceID(output_var_names); } for (auto &varname : output_var_names) { - sharded_var_device->emplace(varname, op_dev_id); + sharded_var_device_.emplace(varname, op_dev_id); } } else { // send_barrier, fetch_barrier will run on place 0; @@ -846,7 +837,7 @@ int MultiDevSSAGraphBuilder::CreateRPCOp( for (ir::Node *output : node->outputs) { int outvar_dev_id = op_dev_id; if (node->Op()->Type() == "fetch_barrier") { - outvar_dev_id = GetVarDeviceID(output->Name(), *sharded_var_device); + outvar_dev_id = GetVarDeviceID(output->Name()); PADDLE_ENFORCE_NE(outvar_dev_id, -1, "output name %s", output->Name()); } p = places_[outvar_dev_id]; @@ -863,29 +854,124 @@ int MultiDevSSAGraphBuilder::CreateRPCOp( return op_dev_id; } -bool MultiDevSSAGraphBuilder::IsSparseGradient(const std::string &og) const { - PADDLE_ENFORCE(all_vars_.count(og) != 0); - if (all_vars_.at(og)->GetType() == proto::VarType::SELECTED_ROWS) { - return true; +int DistSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result, + ir::Node *node) const { + int op_dev_id = -1; + std::vector input_var_names; + std::vector output_var_names; + for (ir::Node *input : node->inputs) { + input_var_names.push_back(input->Name()); } - return false; + for (ir::Node *output : node->outputs) { + output_var_names.push_back(output->Name()); + } + + if (node->Op()->Type() == "split_byref" || + node->Op()->Type() == "split_selected_rows" || + node->Op()->Type() == "split_ids") { + // TODO(paddle-dev): getting the first var is not safe. + op_dev_id = GetVarDeviceID(input_var_names[0]); + if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce) { + op_dev_id = GetAppropriateDeviceID(input_var_names); + for (auto &varname : input_var_names) { + sharded_var_device_.emplace(varname, op_dev_id); + } + } + for (auto &varname : output_var_names) { + sharded_var_device_.emplace(varname, op_dev_id); + } + } else if (node->Op()->Type() == "concat") { + op_dev_id = GetVarDeviceID(input_var_names[0]); + for (auto &varname : output_var_names) { + sharded_var_device_.emplace(varname, op_dev_id); + } + } else { + LOG(ERROR) << "got unexpected dist op: " << node->Op()->Type(); + PADDLE_THROW( + "the distribute training related op should be in [split_byref, " + "concat]."); + } + + PADDLE_ENFORCE(op_dev_id != -1, + "can not find right place for distributed op: %s", + node->Op()->Type()); + + CreateComputationalOp(result, node, op_dev_id); + return op_dev_id; } -bool MultiDevSSAGraphBuilder::IsScaleLossOp(ir::Node *node) const { - return boost::get( - node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) == - (static_cast(OpRole::kBackward) | - static_cast(OpRole::kLoss)) && - !loss_var_name_.empty(); // If loss_var is empty. This is test mode +void DistSSAGraphBuilder::InsertCollectiveOp(ir::Graph *result, + const std::string &p_name, + const std::string &g_name) const { + size_t cur_device_id = 0; + switch (strategy_.reduce_) { + case BuildStrategy::ReduceStrategy::kReduce: + cur_device_id = GetAppropriateDeviceID({g_name}); + CreateReduceOp(result, g_name, cur_device_id); + sharded_var_device_.emplace(g_name, cur_device_id); + break; + case BuildStrategy::ReduceStrategy::kAllReduce: + if (IsSparseGradient(g_name)) { + CreateReduceOp(result, g_name, 0); + CreateBroadcastOp(result, g_name, 0); + } else { + CreateAllReduceOp(result, g_name); + } + break; + default: + LOG(FATAL) << "Unknown reduce strategy."; + break; + } +} + +void DistSSAGraphBuilder::InsertPostprocessOps(ir::Graph *result) const { + if (need_broadcast_var_ || + (UseGPU() && + strategy_.reduce_ == BuildStrategy::ReduceStrategy::kReduce)) { + if (strategy_.fuse_broadcast_op_) { + CreateFusedBroadcastOp(result, bcast_var_name_set_); + } else { + for (size_t dev_id = 0; dev_id < bcast_var_name_set_.size(); ++dev_id) { + auto &to_bcast_set = bcast_var_name_set_[dev_id]; + for (auto &bcast_name : to_bcast_set) { + CreateBroadcastOp(result, bcast_name, dev_id); + } + } + } + } +} + +std::unordered_set &MultiDevSSAGraphBuilder() { + static std::unordered_set regs; + return regs; } + +static int MultiDevSSAGraphBuilderRegister(const std::string &builder_mode) { + MultiDevSSAGraphBuilder().insert(builder_mode); + return 0; +} + } // namespace details } // namespace framework } // namespace paddle -REGISTER_PASS(multi_devices_pass, - paddle::framework::details::MultiDevSSAGraphBuilder) - .RequirePassAttr(paddle::framework::details::kLossVarName) - .RequirePassAttr(paddle::framework::details::kPlaces) - .RequirePassAttr(paddle::framework::details::kLocalScopes) - .RequirePassAttr(paddle::framework::details::kStrategy) - .RequirePassAttr(paddle::framework::details::kNRanks); +#define REGISTER_MULTI_DEVICES_PASS(pass_name, pass_class) \ + STATIC_ASSERT_GLOBAL_NAMESPACE( \ + _reg_ssa_graph_builder_##pass_name, \ + "REGISTER_MULTI_DEVICES_PASS must be called in global namespace."); \ + int _reg_ssa_graph_builder_entry_##pass_name = \ + paddle::framework::details::MultiDevSSAGraphBuilderRegister(#pass_name); \ + REGISTER_PASS(pass_name, pass_class) \ + .RequirePassAttr(paddle::framework::details::kLossVarName) \ + .RequirePassAttr(paddle::framework::details::kPlaces) \ + .RequirePassAttr(paddle::framework::details::kLocalScopes) \ + .RequirePassAttr(paddle::framework::details::kStrategy) \ + .RequirePassAttr(paddle::framework::details::kNRanks) + +REGISTER_MULTI_DEVICES_PASS(reduce_mode_multi_devices_pass, + paddle::framework::details::ReduceSSAGraphBuilder); +REGISTER_MULTI_DEVICES_PASS( + allreduce_mode_multi_devices_pass, + paddle::framework::details::AllReduceSSAGraphBuilder); +REGISTER_MULTI_DEVICES_PASS(dist_multi_devices_pass, + paddle::framework::details::DistSSAGraphBuilder); diff --git a/paddle/fluid/framework/details/multi_devices_graph_pass.h b/paddle/fluid/framework/details/multi_devices_graph_pass.h index 7029e9dc1..6d4386538 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_pass.h +++ b/paddle/fluid/framework/details/multi_devices_graph_pass.h @@ -13,6 +13,7 @@ // limitations under the License. #pragma once + #include #include #include @@ -30,78 +31,70 @@ namespace framework { class Scope; namespace details { -class MultiDevSSAGraphBuilder : public ir::Pass { +constexpr char kLossVarName[] = "loss_var_name"; +constexpr char kPlaces[] = "places"; +constexpr char kLocalScopes[] = "local_scopes"; +constexpr char kStrategy[] = "strategy"; +constexpr char kNRanks[] = "nranks"; + +class MultiDevSSAGraphBuilderBase : public ir::Pass { protected: std::unique_ptr ApplyImpl( std::unique_ptr graph) const override; - private: - void CreateOpHandleIOs(ir::Graph *result, ir::Node *node, - size_t device_id) const; - void Init() const; + virtual void Init() const; -#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) - mutable platform::NCCLContextMap *nccl_ctxs_; -#endif + virtual std::vector SortOperations(const ir::Graph &graph) const; - int GetVarDeviceID( - const std::string &varname, - const std::unordered_map &sharded_var_device) const; + virtual void InsertCollectiveOp(ir::Graph *result, const std::string &p_name, + const std::string &g_name) const = 0; - bool IsScaleLossOp(ir::Node *node) const; + virtual bool DealWithSpecialOp(ir::Graph *result, ir::Node *node) const = 0; + + virtual void InsertPostprocessOps(ir::Graph *result) const = 0; - int CreateRPCOp( - ir::Graph *result, ir::Node *node, - std::unordered_map *sharded_var_device) const; - int CreateDistTrainOp( - ir::Graph *result, ir::Node *node, - std::unordered_map *sharded_var_device) const; + bool UseGPU() const; + + bool NeedCollectiveOps() const; + + bool IsScaleLossOp(ir::Node *node) const; void CreateComputationalOps(ir::Graph *result, ir::Node *node, size_t num_places) const; void CreateScaleLossGradOp(ir::Graph *result, const std::string &loss_grad_name, - ir::Node *out_var_node, + ir::Node *out_var_node, size_t loss_scale, proto::VarType::Type dtype) const; VarHandle *CreateReduceOp(ir::Graph *result, const std::string &og, int dst_dev_id) const; + void CreateComputationalOp(ir::Graph *result, ir::Node *node, int dev_id) const; - int GetOpDeviceID( - ir::Node *node, - const std::unordered_map &sharded_var_device) const; - - void InsertAllReduceOp(ir::Graph *result, const std::string &og) const; + bool IsSparseGradient(const std::string &og) const; - void InsertDataBalanceOp(ir::Graph *result, - const std::vector &datas) const; + void CreateAllReduceOp(ir::Graph *result, const std::string &og) const; void CreateBroadcastOp(ir::Graph *result, const std::string &p_name, size_t src_dev_id) const; + void InsertScaleLossGradOp(ir::Graph *result, const ir::Node *node) const; + void CreateFusedBroadcastOp( ir::Graph *result, const std::vector> &bcast_varnames) const; - bool IsSparseGradient(const std::string &og) const; - - size_t GetAppropriateDeviceID( - const std::vector &var_names) const; - void SetCommunicationContext(OpHandleBase *op_handle, const platform::Place &p) const; - std::vector SortForReduceMode( - const std::vector &) const; + void CreateOpHandleIOs(ir::Graph *result, ir::Node *node, + size_t device_id) const; - int GetOpDeviceID( - ir::Node *node, - const std::unordered_map &shared_var_device, - std::unordered_map> *delay_ops) - const; +#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) + mutable platform::NCCLContextMap *nccl_ctxs_; +#endif mutable std::string loss_var_name_; mutable std::vector places_; @@ -109,8 +102,83 @@ class MultiDevSSAGraphBuilder : public ir::Pass { mutable BuildStrategy strategy_; mutable std::unordered_map all_vars_; +}; + +class AllReduceSSAGraphBuilder : public MultiDevSSAGraphBuilderBase { + protected: + virtual void InsertCollectiveOp(ir::Graph *result, const std::string &p_name, + const std::string &g_name) const; + + virtual bool DealWithSpecialOp(ir::Graph *result, ir::Node *node) const { + return false; + } + + virtual void InsertPostprocessOps(ir::Graph *result) const {} +}; + +class BalanceVarSSAGraphBuilder : public MultiDevSSAGraphBuilderBase { + protected: + int GetVarDeviceID(const std::string &varname) const; + + int GetOpDeviceID(ir::Node *node) const; + + size_t GetAppropriateDeviceID( + const std::vector &var_names) const; + + virtual void ResetState() const; + + mutable std::unordered_map sharded_var_device_; mutable std::vector balance_vars_; }; + +class ReduceSSAGraphBuilder : public BalanceVarSSAGraphBuilder { + protected: + virtual void Init() const; + + virtual void InsertCollectiveOp(ir::Graph *result, const std::string &p_name, + const std::string &g_name) const; + + virtual bool DealWithSpecialOp(ir::Graph *result, ir::Node *node) const; + + virtual void InsertPostprocessOps(ir::Graph *result) const; + + virtual std::vector SortOperations(const ir::Graph &graph) const; + + virtual void ResetState() const; + + int GetOpDeviceID(ir::Node *node, + std::unordered_map> + *delay_ops) const; + + std::vector SortForReduceMode( + const std::vector &topo_ops) const; + + mutable std::vector> bcast_var_name_set_; +}; + +class DistSSAGraphBuilder : public BalanceVarSSAGraphBuilder { + protected: + virtual void Init() const; + + virtual bool DealWithSpecialOp(ir::Graph *result, ir::Node *node) const; + + virtual void InsertPostprocessOps(ir::Graph *result) const; + + virtual void InsertCollectiveOp(ir::Graph *result, const std::string &p_name, + const std::string &g_name) const; + + virtual void ResetState() const; + + int CreateRPCOp(ir::Graph *result, ir::Node *node) const; + + int CreateDistTrainOp(ir::Graph *result, ir::Node *node) const; + + mutable std::vector> bcast_var_name_set_; + mutable bool need_broadcast_var_{false}; +}; + +std::unordered_set &MultiDevSSAGraphBuilder(); + } // namespace details } // namespace framework } // namespace paddle diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index 3b81d59ad..dce755c91 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -946,13 +946,6 @@ All parameter, weight, gradient are variables in Paddle. R"DOC(The type is STR, debug_graphviz_path indicate the path that writing the SSA Graph to file in the form of graphviz, you. It is useful for debugging. Default "")DOC") - .def_property( - "enable_data_balance", - [](const BuildStrategy &self) { return self.enable_data_balance_; }, - [](BuildStrategy &self, bool b) { - PADDLE_ENFORCE(!self.IsFinalized(), "BuildStrategy is finlaized."); - self.enable_data_balance_ = b; - }) // FIXME(chengudo): enable_data_balance seems not important .def_property( "enable_sequential_execution", [](const BuildStrategy &self) { @@ -1007,6 +1000,10 @@ All parameter, weight, gradient are variables in Paddle. "memory_optimize", [](const BuildStrategy &self) { return self.memory_optimize_; }, [](BuildStrategy &self, bool b) { self.memory_optimize_ = b; }) + .def_property( + "is_distribution", + [](const BuildStrategy &self) { return self.is_distribution_; }, + [](BuildStrategy &self, bool b) { self.is_distribution_ = b; }) .def_property( "memory_early_delete", [](const BuildStrategy &self) { return self.memory_early_delete_; }, diff --git a/python/paddle/fluid/parallel_executor.py b/python/paddle/fluid/parallel_executor.py index c97a93ec3..3b066eda1 100644 --- a/python/paddle/fluid/parallel_executor.py +++ b/python/paddle/fluid/parallel_executor.py @@ -29,6 +29,15 @@ ExecutionStrategy = core.ParallelExecutor.ExecutionStrategy BuildStrategy = core.ParallelExecutor.BuildStrategy +def _is_pserver_mode(main_program): + main = main_program if main_program \ + else framework.default_main_program() + for op in main.global_block().ops: + if op.type in ["send", "recv"]: + return True + return False + + class ParallelExecutor(object): """ ParallelExecutor is designed for data parallelism, which focuses on distributing @@ -128,6 +137,11 @@ class ParallelExecutor(object): build_strategy = BuildStrategy() build_strategy.num_trainers = num_trainers build_strategy.trainer_id = trainer_id + # FIXME(zcd): is_distribution_ is a temporary field, because in pserver mode, + # num_trainers is 1, so the current fields of build_strategy doesn't tell if + # it's distributed model. + build_strategy.is_distribution = _is_pserver_mode( + main_program) or num_trainers > 1 # step4: get main_program, scope, local_scopes main = main_program if main_program \ diff --git a/python/paddle/fluid/tests/unittests/test_reader_reset.py b/python/paddle/fluid/tests/unittests/test_reader_reset.py index e97a05b6f..7eeffa103 100644 --- a/python/paddle/fluid/tests/unittests/test_reader_reset.py +++ b/python/paddle/fluid/tests/unittests/test_reader_reset.py @@ -75,8 +75,6 @@ class TestReaderReset(unittest.TestCase): exe.run(startup_prog) build_strategy = fluid.BuildStrategy() - if with_double_buffer: - build_strategy.enable_data_balance = True exec_strategy = fluid.ExecutionStrategy() parallel_exe = fluid.ParallelExecutor( use_cuda=self.use_cuda, -- GitLab