From d5090c892d609bf1d394d3c755cc4bafb80ba6f7 Mon Sep 17 00:00:00 2001 From: Yancey1989 Date: Tue, 19 Feb 2019 15:22:25 +0800 Subject: [PATCH] polish code test=develop --- paddle/fluid/framework/details/build_strategy.cc | 2 +- .../details/multi_devices_graph_pass.cc | 16 +++++++--------- .../details/parallel_ssa_graph_executor.cc | 3 ++- 3 files changed, 10 insertions(+), 11 deletions(-) diff --git a/paddle/fluid/framework/details/build_strategy.cc b/paddle/fluid/framework/details/build_strategy.cc index 45c2c734152..3a5e41ef3ca 100644 --- a/paddle/fluid/framework/details/build_strategy.cc +++ b/paddle/fluid/framework/details/build_strategy.cc @@ -34,7 +34,7 @@ namespace details { static inline bool SeqOnlyAllReduceOps(const BuildStrategy &strategy) { // Should fix the allreduce op order if scheduling // them in multiple threads or processes to avoid hang. - // NOTE: ParallelExecutor would execute this pass on each graph, so + // NOTE: ParallelGraph would execute this pass on each graph, so // don't need to append it here. return (!strategy.enable_sequential_execution_ && strategy.num_trainers_ > 1) && diff --git a/paddle/fluid/framework/details/multi_devices_graph_pass.cc b/paddle/fluid/framework/details/multi_devices_graph_pass.cc index 27bc7718147..3c0a8d7020a 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_pass.cc +++ b/paddle/fluid/framework/details/multi_devices_graph_pass.cc @@ -389,8 +389,8 @@ void MultiDevSSAGraphBuilderBase::CreateAllReduceOp( OpHandleBase *op_handle = nullptr; auto append_allreduce_op = [&]( - std::vector &scopes, - std::vector &places) -> OpHandleBase * { + const std::vector &scopes, + const std::vector &places) -> OpHandleBase * { #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) result->Get(kGraphOps).emplace_back(new AllReduceOpHandle( result->CreateEmptyNode("allreduce", ir::Node::Type::kOperation), @@ -407,13 +407,11 @@ void MultiDevSSAGraphBuilderBase::CreateAllReduceOp( op_handle = append_allreduce_op(local_scopes_, places_); for (size_t i = 0; i < places_.size(); ++i) { - auto p = places_[i]; - std::vector ss{local_scopes_[i]}; - std::vector ps{p}; - if (strategy_.enable_parallel_graph_) - op_handle = append_allreduce_op(ss, ps); + if (strategy_.enable_parallel_graph_) { + op_handle = append_allreduce_op({local_scopes_[i]}, {places_[i]}); + } - SetCommunicationContext(op_handle, p); + SetCommunicationContext(op_handle, places_[i]); auto &vars = result->Get(kGraphVars)[i][og]; PADDLE_ENFORCE(!vars.empty()); auto &prev_grad = vars.back(); @@ -421,7 +419,7 @@ void MultiDevSSAGraphBuilderBase::CreateAllReduceOp( auto var = new VarHandle(result->CreateEmptyNode(og, ir::Node::Type::kVariable), - vars.size(), i, og, p); + vars.size(), i, og, places_[i]); vars.emplace_back(var); op_handle->AddOutput(var); } diff --git a/paddle/fluid/framework/details/parallel_ssa_graph_executor.cc b/paddle/fluid/framework/details/parallel_ssa_graph_executor.cc index c36618016be..3740b795fa4 100644 --- a/paddle/fluid/framework/details/parallel_ssa_graph_executor.cc +++ b/paddle/fluid/framework/details/parallel_ssa_graph_executor.cc @@ -32,8 +32,9 @@ ParallelSSAGraphExecutor::SeparateMultiDevicesGraph( g->Set(kGraphDepVars, new GraphDepVars); g->Set(kGraphOps, new GraphOps); } + auto op_handles = ir::FilterByNodeWrapper(*graph); - for (auto &op : graph->Get(kGraphOps)) { + for (auto &op : op_handles) { auto &dev_ctx = op->DeviceContext(); auto &p = dev_ctx.begin()->first; int dev_id = boost::get(p).device; -- GitLab