From e335f01826143452c8733495f02a60f7d668d3c7 Mon Sep 17 00:00:00 2001
From: Yu Yang
Date: Tue, 20 Mar 2018 19:20:37 +0800
Subject: [PATCH] Add more logs

---
 paddle/fluid/framework/parallel_executor.cc | 54 ++++++++++++---------
 1 file changed, 30 insertions(+), 24 deletions(-)

diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc
index 8ee2e57324..82df86bebd 100644
--- a/paddle/fluid/framework/parallel_executor.cc
+++ b/paddle/fluid/framework/parallel_executor.cc
@@ -125,30 +125,6 @@ struct OpHandle {
   virtual void RunImpl() = 0;
 };
 
-struct ComputationOpHandle : public OpHandle {
-  std::unique_ptr<OperatorBase> op_;
-  Scope *scope_;
-  platform::Place place_;
-
-  explicit ComputationOpHandle(const OpDesc &op_desc, Scope *scope,
-                               platform::Place place)
-      : op_(framework::OpRegistry::CreateOp(op_desc)),
-        scope_(scope),
-        place_(place) {}
-
- protected:
-  void RunImpl() override {
-    auto *cur_ctx = dev_ctx_[place_];
-    for (auto *in : inputs_) {
-      if (in->generated_op_ && in->generated_op_->dev_ctx_[place_] != cur_ctx) {
-        in->generated_op_->Wait(cur_ctx);
-      }
-    }
-
-    op_->Run(*scope_, place_);
-  }
-};
-
 struct ScaleLossGradOpHandle : public OpHandle {
   float coeff_;
   Scope *scope_;
@@ -396,6 +372,36 @@ struct NCCLAllReduceOpHandle : public OpHandle {
   }
 };
 
+struct ComputationOpHandle : public OpHandle {
+  std::unique_ptr<OperatorBase> op_;
+  Scope *scope_;
+  platform::Place place_;
+
+  explicit ComputationOpHandle(const OpDesc &op_desc, Scope *scope,
+                               platform::Place place)
+      : op_(framework::OpRegistry::CreateOp(op_desc)),
+        scope_(scope),
+        place_(place) {}
+
+ protected:
+  void RunImpl() override {
+    auto *cur_ctx = dev_ctx_[place_];
+    for (auto *in : inputs_) {
+      bool need_wait =
+          in->generated_op_ && in->generated_op_->dev_ctx_[place_] != cur_ctx;
+      if (dynamic_cast<NCCLAllReduceOpHandle *>(in->generated_op_)) {
+        VLOG(3) << "Input is nccl all reduce, need to wait" << need_wait;
+      }
+
+      if (need_wait) {
+        in->generated_op_->Wait(cur_ctx);
+      }
+    }
+
+    op_->Run(*scope_, place_);
+  }
+};
+
 ParallelExecutor::ParallelExecutor(
     size_t num_threads, const std::vector<platform::Place> &places,
     const std::unordered_set<std::string> &params,
-- 
GitLab