From 16a9dfe4805fa88670338b52bf898f60043fc16f Mon Sep 17 00:00:00 2001 From: typhoonzero Date: Wed, 11 Apr 2018 15:12:58 +0800 Subject: [PATCH] finish --- .../details/multi_devices_graph_builder.cc | 16 +++++++--------- .../details/multi_devices_graph_builder.h | 2 +- paddle/fluid/framework/details/send_op_handle.cc | 12 ++++++++---- paddle/fluid/framework/details/send_op_handle.h | 4 +++- 4 files changed, 19 insertions(+), 15 deletions(-) diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.cc b/paddle/fluid/framework/details/multi_devices_graph_builder.cc index 0ebcd627bd..e0dd9e6068 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_builder.cc +++ b/paddle/fluid/framework/details/multi_devices_graph_builder.cc @@ -57,8 +57,7 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder( void MultiDevSSAGraphBuilder::CreateOpHandleIOs(SSAGraph *result, OpDesc *op, const platform::Place &p, - const size_t &i, - bool create_output) const { + const size_t &i) const { auto *op_handle = result->ops_.back().get(); op_handle->dev_ctxes_[p] = const_cast( platform::DeviceContextPool::Instance().Get(p)); @@ -69,12 +68,11 @@ void MultiDevSSAGraphBuilder::CreateOpHandleIOs(SSAGraph *result, OpDesc *op, VarHandle *var = CreateOrGetLatestVarHandle(result, each_var_name, p, i); op_handle->AddInput(var); } - if (create_output) { - var_names = op->OutputArgumentNames(); - for (auto &each_var_name : var_names) { - CreateOpOutput(result, op_handle, each_var_name, p, i); - } + var_names = op->OutputArgumentNames(); + + for (auto &each_var_name : var_names) { + CreateOpOutput(result, op_handle, each_var_name, p, i); } } @@ -106,10 +104,10 @@ std::unique_ptr MultiDevSSAGraphBuilder::Build( auto &p = places_[0]; auto *s = local_scopes_[0]; // FIXME(wuyi): send op always copy from GPU 0 - result.ops_.emplace_back(new SendOpHandle(*op, s)); + result.ops_.emplace_back(new SendOpHandle(*op, s, p)); // Create inputs for output on original place and no ssa output // is created for send op. - CreateOpHandleIOs(&result, op, p, 0, false); + CreateOpHandleIOs(&result, op, p, 0); continue; } diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.h b/paddle/fluid/framework/details/multi_devices_graph_builder.h index 137c817fde..de34caab1b 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_builder.h +++ b/paddle/fluid/framework/details/multi_devices_graph_builder.h @@ -46,7 +46,7 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder { private: void CreateOpHandleIOs(SSAGraph *result, OpDesc *op, const platform::Place &p, - const size_t &i, bool create_output = true) const; + const size_t &i) const; private: std::string loss_var_name_; diff --git a/paddle/fluid/framework/details/send_op_handle.cc b/paddle/fluid/framework/details/send_op_handle.cc index caacfa6b1e..d181607e86 100644 --- a/paddle/fluid/framework/details/send_op_handle.cc +++ b/paddle/fluid/framework/details/send_op_handle.cc @@ -19,18 +19,22 @@ namespace framework { namespace details { SendOpHandle::SendOpHandle(const framework::OpDesc &op_desc, - const Scope *local_scope) + const Scope *local_scope, + const platform::Place &place) : op_(framework::OpRegistry::CreateOp(op_desc)), - local_scope_(local_scope) {} + local_scope_(local_scope), + place_(place) {} void SendOpHandle::RunImpl() { // Wait input done for (auto *in : inputs_) { auto &p = static_cast(in)->place_; + if (in->DebugString() == "dummy") { // HACK + continue; + } in->generated_op_->Wait(dev_ctxes_[p]); } - platform::CPUPlace cpu; - op_->Run(*local_scope_, cpu); + op_->Run(*local_scope_, place_); } std::string SendOpHandle::Name() const { return "send"; } diff --git a/paddle/fluid/framework/details/send_op_handle.h b/paddle/fluid/framework/details/send_op_handle.h index 8a7b62ba1c..e7857c1f23 100644 --- a/paddle/fluid/framework/details/send_op_handle.h +++ b/paddle/fluid/framework/details/send_op_handle.h @@ -31,8 +31,10 @@ namespace details { struct SendOpHandle : public OpHandleBase { std::unique_ptr op_; const Scope* local_scope_; + const platform::Place& place_; - SendOpHandle(const framework::OpDesc& op_desc, const Scope* local_scope); + SendOpHandle(const framework::OpDesc& op_desc, const Scope* local_scope, + const platform::Place& place); std::string Name() const override; -- GitLab