diff --git a/paddle/fluid/framework/details/build_strategy.cc b/paddle/fluid/framework/details/build_strategy.cc index ae17b8df7553aa719cf38338eaca36ce791c3a09..7d2a081e3b1d0ae3501730bef052d483b23d5a62 100644 --- a/paddle/fluid/framework/details/build_strategy.cc +++ b/paddle/fluid/framework/details/build_strategy.cc @@ -119,8 +119,6 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder { // Verify that the graph is correct for multi-device executor. auto multi_devices_pass = AppendPass("multi_devices_check_pass"); - multi_devices_pass->Set(kEnablePG, - new bool(strategy.enable_parallel_graph_)); if (SeqOnlyAllReduceOps(strategy)) { AppendPass("all_reduce_deps_pass"); @@ -194,8 +192,6 @@ std::unique_ptr BuildStrategy::Apply( &local_scopes); pass->Erase(kNRanks); pass->Set(kNRanks, new size_t(nranks)); - pass->Erase(kEnablePG); - pass->Set(kEnablePG, new bool(true)); #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) platform::NCCLContextMap *nctx = use_cuda ? nccl_ctxs : nullptr; diff --git a/paddle/fluid/framework/details/multi_devices_graph_pass.cc b/paddle/fluid/framework/details/multi_devices_graph_pass.cc index dcceaa93d9e00cf6b7a0f5db6bfbd96ba895b699..4f856c6d9eb842add9eb5e6fe30639dc4170358d 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_pass.cc +++ b/paddle/fluid/framework/details/multi_devices_graph_pass.cc @@ -201,7 +201,7 @@ std::unique_ptr MultiDevSSAGraphBuilderBase::ApplyImpl( auto &g_name = backward_vars[i + 1]; VLOG(10) << "Bcast " << g_name << " for parameter " << p_name; - InsertCollectiveOp(&result, node, p_name, g_name); + InsertCollectiveOp(&result, p_name, g_name); } } catch (boost::bad_get e) { } @@ -386,7 +386,7 @@ void MultiDevSSAGraphBuilderBase::CreateComputationalOp(ir::Graph *result, } void MultiDevSSAGraphBuilderBase::CreateAllReduceOp( - ir::Graph *result, ir::Node *node, const std::string &og) const { + ir::Graph *result, const std::string &og) const { OpHandleBase *op_handle = nullptr; auto append_allreduce_op = [&]( @@ -510,13 +510,13 @@ bool MultiDevSSAGraphBuilderBase::IsSparseGradient( } void AllReduceSSAGraphBuilder::InsertCollectiveOp( - ir::Graph *result, ir::Node *node, const std::string &p_name, + ir::Graph *result, const std::string &p_name, const std::string &g_name) const { if (IsSparseGradient(g_name)) { CreateReduceOp(result, g_name, 0); CreateBroadcastOp(result, g_name, 0); } else { - CreateAllReduceOp(result, node, g_name); + CreateAllReduceOp(result, g_name); } } @@ -589,7 +589,7 @@ void ReduceSSAGraphBuilder::ResetState() const { } void ReduceSSAGraphBuilder::InsertCollectiveOp( - ir::Graph *result, ir::Node *node, const std::string &p_name, + ir::Graph *result, const std::string &p_name, const std::string &g_name) const { size_t cur_device_id = GetAppropriateDeviceID({g_name}); CreateReduceOp(result, g_name, cur_device_id); @@ -909,7 +909,7 @@ int DistSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result, return op_dev_id; } -void DistSSAGraphBuilder::InsertCollectiveOp(ir::Graph *result, ir::Node *node, +void DistSSAGraphBuilder::InsertCollectiveOp(ir::Graph *result, const std::string &p_name, const std::string &g_name) const { size_t cur_device_id = 0; @@ -924,7 +924,7 @@ void DistSSAGraphBuilder::InsertCollectiveOp(ir::Graph *result, ir::Node *node, CreateReduceOp(result, g_name, 0); CreateBroadcastOp(result, g_name, 0); } else { - CreateAllReduceOp(result, node, g_name); + CreateAllReduceOp(result, g_name); } break; default: @@ -975,8 +975,7 @@ static int MultiDevSSAGraphBuilderRegister(const std::string &builder_mode) { .RequirePassAttr(paddle::framework::details::kPlaces) \ .RequirePassAttr(paddle::framework::details::kLocalScopes) \ .RequirePassAttr(paddle::framework::details::kStrategy) \ - .RequirePassAttr(paddle::framework::details::kNRanks) \ - .RequirePassAttr(paddle::framework::details::kEnablePG) + .RequirePassAttr(paddle::framework::details::kNRanks) REGISTER_MULTI_DEVICES_PASS(reduce_mode_multi_devices_pass, paddle::framework::details::ReduceSSAGraphBuilder); diff --git a/paddle/fluid/framework/details/multi_devices_graph_pass.h b/paddle/fluid/framework/details/multi_devices_graph_pass.h index e3c1fe711c1cbcb2edfca3d747835fb8705ead15..6d4386538ea7d0cc318647c92282af9d598fa699 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_pass.h +++ b/paddle/fluid/framework/details/multi_devices_graph_pass.h @@ -36,7 +36,6 @@ constexpr char kPlaces[] = "places"; constexpr char kLocalScopes[] = "local_scopes"; constexpr char kStrategy[] = "strategy"; constexpr char kNRanks[] = "nranks"; -constexpr char kEnablePG[] = "enable_pg"; class MultiDevSSAGraphBuilderBase : public ir::Pass { protected: @@ -47,8 +46,7 @@ class MultiDevSSAGraphBuilderBase : public ir::Pass { virtual std::vector SortOperations(const ir::Graph &graph) const; - virtual void InsertCollectiveOp(ir::Graph *result, ir::Node *node, - const std::string &p_name, + virtual void InsertCollectiveOp(ir::Graph *result, const std::string &p_name, const std::string &g_name) const = 0; virtual bool DealWithSpecialOp(ir::Graph *result, ir::Node *node) const = 0; @@ -77,8 +75,7 @@ class MultiDevSSAGraphBuilderBase : public ir::Pass { bool IsSparseGradient(const std::string &og) const; - void CreateAllReduceOp(ir::Graph *result, ir::Node *node, - const std::string &og) const; + void CreateAllReduceOp(ir::Graph *result, const std::string &og) const; void CreateBroadcastOp(ir::Graph *result, const std::string &p_name, size_t src_dev_id) const; @@ -109,8 +106,7 @@ class MultiDevSSAGraphBuilderBase : public ir::Pass { class AllReduceSSAGraphBuilder : public MultiDevSSAGraphBuilderBase { protected: - virtual void InsertCollectiveOp(ir::Graph *result, ir::Node *node, - const std::string &p_name, + virtual void InsertCollectiveOp(ir::Graph *result, const std::string &p_name, const std::string &g_name) const; virtual bool DealWithSpecialOp(ir::Graph *result, ir::Node *node) const { @@ -139,8 +135,7 @@ class ReduceSSAGraphBuilder : public BalanceVarSSAGraphBuilder { protected: virtual void Init() const; - virtual void InsertCollectiveOp(ir::Graph *result, ir::Node *node, - const std::string &p_name, + virtual void InsertCollectiveOp(ir::Graph *result, const std::string &p_name, const std::string &g_name) const; virtual bool DealWithSpecialOp(ir::Graph *result, ir::Node *node) const; @@ -169,8 +164,7 @@ class DistSSAGraphBuilder : public BalanceVarSSAGraphBuilder { virtual void InsertPostprocessOps(ir::Graph *result) const; - virtual void InsertCollectiveOp(ir::Graph *result, ir::Node *node, - const std::string &p_name, + virtual void InsertCollectiveOp(ir::Graph *result, const std::string &p_name, const std::string &g_name) const; virtual void ResetState() const; diff --git a/paddle/fluid/framework/details/parallel_ssa_graph_executor.h b/paddle/fluid/framework/details/parallel_ssa_graph_executor.h index e3abd237538dcd86bedca36de08e406c7bab68a1..c31bba17f6840019660376991145028b9c254933 100644 --- a/paddle/fluid/framework/details/parallel_ssa_graph_executor.h +++ b/paddle/fluid/framework/details/parallel_ssa_graph_executor.h @@ -45,8 +45,6 @@ class ParallelSSAGraphExecutor : public SSAGraphExecutor { FeedFetchList Run(const std::vector &fetch_tensors) override; private: - // std::vector> SeparateMultiDevicesGraph(); - ExecutionStrategy strategy_; std::vector local_scopes_; std::unique_ptr<::ThreadPool> pool_{nullptr}; diff --git a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc index c0edad6f7400d1d83214a031fd815e5daf3b17a2..5bf414324f50700e4d812e325d71644a11dbd34c 100644 --- a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc +++ b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc @@ -56,10 +56,10 @@ FeedFetchList ThreadedSSAGraphExecutor::Run( } } } - for (auto &var : graph_->Get(details::kGraphDepVars)) { InsertPendingVar(&pending_vars, ready_vars.get(), var); } + for (auto &op : ir::FilterByNodeWrapper(*graph_)) { if (op->Inputs().empty()) { // Special case, Op has no input. ready_ops.insert(op); diff --git a/paddle/fluid/framework/ir/graph.h b/paddle/fluid/framework/ir/graph.h index 0d66043a739cf25df59b56c1bdd9b8096e7ffa57..40baae2ffdd6f2d31901bf2c44a3e3fb6d8ad329 100644 --- a/paddle/fluid/framework/ir/graph.h +++ b/paddle/fluid/framework/ir/graph.h @@ -176,12 +176,6 @@ class Graph { return ret; } - void RemoveNode(ir::Node *node) { - PADDLE_ENFORCE(node_set_.find(node) != node_set_.end()); - node_set_.erase(node); - nodes_.erase(node); - } - // NOTE low performance, but simple and secure. Node *RetrieveNode(int id) { for (auto &node : nodes_) { @@ -200,10 +194,6 @@ class Graph { return node; } - bool ContainNode(ir::Node *node) { - return node_set_.find(node) != node_set_.end(); - } - void ResolveHazard( const std::map> &var_nodes); diff --git a/paddle/fluid/framework/ir/graph_helper.h b/paddle/fluid/framework/ir/graph_helper.h index 3b95aa7b86fad36889a646cd521c380d5a58a675..214de9ec7d85aee6021b18866295777e317aa79d 100644 --- a/paddle/fluid/framework/ir/graph_helper.h +++ b/paddle/fluid/framework/ir/graph_helper.h @@ -64,9 +64,7 @@ template std::vector FilterByNodeWrapper(const Graph &graph) { std::vector ret; for (ir::Node *n : graph.Nodes()) { - if (n->IsWrappedBy()) { - ret.push_back(&n->Wrapper()); - } + if (n->IsWrappedBy()) ret.push_back(&n->Wrapper()); } return ret; } diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc index 91d1a9988657ac674dbf5cbfdcbf97e4e3c584ce..dca1a4e53016185d52530a7d68bb620662bde54c 100644 --- a/paddle/fluid/framework/parallel_executor.cc +++ b/paddle/fluid/framework/parallel_executor.cc @@ -478,12 +478,11 @@ bool ParallelExecutor::EnableParallelGraphExecution( } } - // if (!member_->use_all_reduce_ || !member_->use_cuda_) - if (!member_->use_all_reduce_) enable_parallel_graph = false; + if (!member_->use_all_reduce_ || !member_->use_cuda_) - if (build_strategy.enable_sequential_execution_ || - exec_strategy.type_ == ExecutionStrategy::ExecutorType::kExperimental) - enable_parallel_graph = false; + if (build_strategy.enable_sequential_execution_ || + exec_strategy.type_ == ExecutionStrategy::ExecutorType::kExperimental) + enable_parallel_graph = false; return enable_parallel_graph; }