From 830532213a120ce9a5645cb3bba4797b4447b50f Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Wed, 9 May 2018 14:44:02 +0800 Subject: [PATCH] extract method from broadcast::RunImpl --- .../framework/details/broadcast_op_handle.cc | 56 ++++++++++++------- .../framework/details/broadcast_op_handle.h | 3 + .../fluid/framework/details/op_handle_base.h | 8 ++- .../framework/details/reduce_op_handle.cc | 20 ++++--- .../fluid/framework/details/send_op_handle.cc | 2 +- 5 files changed, 56 insertions(+), 33 deletions(-) diff --git a/paddle/fluid/framework/details/broadcast_op_handle.cc b/paddle/fluid/framework/details/broadcast_op_handle.cc index f176e4e159..d5ca061944 100644 --- a/paddle/fluid/framework/details/broadcast_op_handle.cc +++ b/paddle/fluid/framework/details/broadcast_op_handle.cc @@ -48,29 +48,9 @@ void BroadcastOpHandle::RunImpl() { auto *in_var = var_scopes.at(in_var_handle->scope_idx_)->FindVar(in_var_handle->name_); PADDLE_ENFORCE_NOT_NULL(in_var); - Tensor &in_tensor = VariableVisitor::GetMutableTensor(in_var); - // NOTE: The tensors' Place of input and output must be all on GPU or all on - // CPU. - for (auto *out_var_handle : out_var_handles) { - if (out_var_handle->IsTheSameVar(*in_var_handle)) { - continue; - } - auto t_out_p = out_var_handle->place_; - auto *out_var = var_scopes.at(out_var_handle->scope_idx_) - ->FindVar(out_var_handle->name_); - PADDLE_ENFORCE_NOT_NULL(out_var); - if (platform::is_gpu_place(in_tensor.place())) { - PADDLE_ENFORCE(platform::is_gpu_place(t_out_p), - "Places of input and output must be all on GPU."); - } else { - t_out_p = platform::CPUPlace(); - } - VariableVisitor::ShareDimsAndLoD(*in_var, out_var); - VariableVisitor::GetMutableTensor(out_var).mutable_data(t_out_p, - in_tensor.type()); - } + InitOutputValue(*in_var_handle, out_var_handles); if (platform::is_cpu_place(in_tensor.place())) { for (auto *out_var_handle : out_var_handles) { @@ -145,6 +125,40 @@ void BroadcastOpHandle::RunImpl() { } } +void BroadcastOpHandle::InitOutputValue( + const VarHandle &in_var_handle, + const std::vector &out_var_handles) const { + std::vector var_scopes; + for (auto *s : local_scopes_) { + var_scopes.emplace_back(s->FindVar(kLocalExecScopeName)->Get()); + } + auto *in_var = + var_scopes.at(in_var_handle.scope_idx_)->FindVar(in_var_handle.name_); + + Tensor &in_tensor = VariableVisitor::GetMutableTensor(in_var); + + // NOTE: The tensors' Place of input and output must be all on GPU or all on + // CPU. + for (auto *out_var_handle : out_var_handles) { + if (out_var_handle->IsTheSameVar(in_var_handle)) { + continue; + } + auto t_out_p = out_var_handle->place_; + auto *out_var = var_scopes.at(out_var_handle->scope_idx_) + ->FindVar(out_var_handle->name_); + PADDLE_ENFORCE_NOT_NULL(out_var); + if (is_gpu_place(in_tensor.place())) { + PADDLE_ENFORCE(platform::is_gpu_place(t_out_p), + "Places of input and output must be all on GPU."); + } else { + t_out_p = platform::CPUPlace(); + } + VariableVisitor::ShareDimsAndLoD(*in_var, out_var); + VariableVisitor::GetMutableTensor(out_var).mutable_data(t_out_p, + in_tensor.type()); + } +} + std::string BroadcastOpHandle::Name() const { return "broadcast"; } } // namespace details } // namespace framework diff --git a/paddle/fluid/framework/details/broadcast_op_handle.h b/paddle/fluid/framework/details/broadcast_op_handle.h index 48e356af4b..629aa00cb8 100644 --- a/paddle/fluid/framework/details/broadcast_op_handle.h +++ b/paddle/fluid/framework/details/broadcast_op_handle.h @@ -64,6 +64,9 @@ struct BroadcastOpHandle : public OpHandleBase { #ifdef PADDLE_WITH_CUDA const platform::NCCLContextMap *nccl_ctxs_; #endif + + void InitOutputValue(const VarHandle &in_var_handle, + const std::vector &out_var_handles) const; }; } // namespace details } // namespace framework diff --git a/paddle/fluid/framework/details/op_handle_base.h b/paddle/fluid/framework/details/op_handle_base.h index 3d4a093125..fe1735d05d 100644 --- a/paddle/fluid/framework/details/op_handle_base.h +++ b/paddle/fluid/framework/details/op_handle_base.h @@ -44,10 +44,14 @@ class OpHandleBase { void AddOutput(VarHandleBase *out); - // Wait inputs are generated, this Wait is asynchronous operation. + // This method adds the wait events of all the input on all the device + // context. + // NODE: This Wait is asynchronous operation. virtual void WaitInputVarGenerated(); - // Wait inputs are generated, this Wait is asynchronous operation. + // This method adds the wait events of all the input on the specified device + // context. + // NODE: This Wait is asynchronous operation. virtual void WaitInputVarGenerated(const platform::Place &place); virtual bool NeedWait(VarHandleBase *in_var); diff --git a/paddle/fluid/framework/details/reduce_op_handle.cc b/paddle/fluid/framework/details/reduce_op_handle.cc index f653064ade..7160e346da 100644 --- a/paddle/fluid/framework/details/reduce_op_handle.cc +++ b/paddle/fluid/framework/details/reduce_op_handle.cc @@ -80,19 +80,21 @@ void ReduceOpHandle::RunImpl() { } if (pre_in_var->IsType()) { - std::vector in_selected_rows = - GetInputValues(in_var_handles, var_scopes); - - GatherSelectedRows(in_selected_rows, in_places, dev_ctxes_, t_out_p, - out_var->GetMutable()); + this->RunAndRecordEvent([&] { + std::vector in_selected_rows = + GetInputValues(in_var_handles, var_scopes); + GatherSelectedRows(in_selected_rows, in_places, dev_ctxes_, t_out_p, + out_var->GetMutable()); + }); } else { std::vector lod_tensors = GetInputValues(in_var_handles, var_scopes); - if (paddle::platform::is_cpu_place(lod_tensors[0]->place())) { - ReduceLoDTensor func(lod_tensors, - out_var->GetMutable()); - VisitDataType(ToDataType(lod_tensors[0]->type()), func); + this->RunAndRecordEvent([&] { + ReduceLoDTensor func(lod_tensors, + out_var->GetMutable()); + VisitDataType(ToDataType(lod_tensors[0]->type()), func); + }); } else if (paddle::platform::is_gpu_place(lod_tensors[0]->place())) { #ifdef PADDLE_WITH_CUDA auto pre_in = pre_in_var->Get(); diff --git a/paddle/fluid/framework/details/send_op_handle.cc b/paddle/fluid/framework/details/send_op_handle.cc index ee4beb5f9b..01f3a9df76 100644 --- a/paddle/fluid/framework/details/send_op_handle.cc +++ b/paddle/fluid/framework/details/send_op_handle.cc @@ -27,7 +27,7 @@ SendOpHandle::SendOpHandle(const framework::OpDesc &op_desc, void SendOpHandle::RunImpl() { // Wait input done - WaitInputVarGenerated(); + WaitInputVarGenerated(place_); auto &tmp_scope = local_scope_->FindVar(kLocalExecScopeName)->Get(); // FIXME(wuyi): can not use RunAndRecordEvent here, for it will cause dead // lock. -- GitLab