diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.cc b/paddle/fluid/framework/details/multi_devices_graph_builder.cc index 0ebcd627bded9d91d8d3aca3be5400c2d7bb53fa..e0dd9e6068174a4b0348d503f4082bee6ff68dac 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_builder.cc +++ b/paddle/fluid/framework/details/multi_devices_graph_builder.cc @@ -57,8 +57,7 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder( void MultiDevSSAGraphBuilder::CreateOpHandleIOs(SSAGraph *result, OpDesc *op, const platform::Place &p, - const size_t &i, - bool create_output) const { + const size_t &i) const { auto *op_handle = result->ops_.back().get(); op_handle->dev_ctxes_[p] = const_cast<platform::DeviceContext *>( platform::DeviceContextPool::Instance().Get(p)); @@ -69,12 +68,11 @@ void MultiDevSSAGraphBuilder::CreateOpHandleIOs(SSAGraph *result, OpDesc *op, VarHandle *var = CreateOrGetLatestVarHandle(result, each_var_name, p, i); op_handle->AddInput(var); } - if (create_output) { - var_names = op->OutputArgumentNames(); - for (auto &each_var_name : var_names) { - CreateOpOutput(result, op_handle, each_var_name, p, i); - } + var_names = op->OutputArgumentNames(); + + for (auto &each_var_name : var_names) { + CreateOpOutput(result, op_handle, each_var_name, p, i); } } @@ -106,10 +104,10 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build( auto &p = places_[0]; auto *s = local_scopes_[0]; // FIXME(wuyi): send op always copy from GPU 0 - result.ops_.emplace_back(new SendOpHandle(*op, s)); + result.ops_.emplace_back(new SendOpHandle(*op, s, p)); // Create inputs for output on original place and no ssa output // is created for send op. 
- CreateOpHandleIOs(&result, op, p, 0, false); + CreateOpHandleIOs(&result, op, p, 0); continue; } diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.h b/paddle/fluid/framework/details/multi_devices_graph_builder.h index 137c817fde0fa655e9071f3872b39c120afd8234..de34caab1be85eecb741a5003f026eb982e178ea 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_builder.h +++ b/paddle/fluid/framework/details/multi_devices_graph_builder.h @@ -46,7 +46,7 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder { private: void CreateOpHandleIOs(SSAGraph *result, OpDesc *op, const platform::Place &p, - const size_t &i, bool create_output = true) const; + const size_t &i) const; private: std::string loss_var_name_; diff --git a/paddle/fluid/framework/details/send_op_handle.cc b/paddle/fluid/framework/details/send_op_handle.cc index caacfa6b1ee1917743f566de092ba484c67b7f07..d181607e86372f4872c38bc35db786ac142ccc65 100644 --- a/paddle/fluid/framework/details/send_op_handle.cc +++ b/paddle/fluid/framework/details/send_op_handle.cc @@ -19,18 +19,22 @@ namespace framework { namespace details { SendOpHandle::SendOpHandle(const framework::OpDesc &op_desc, - const Scope *local_scope) + const Scope *local_scope, + const platform::Place &place) : op_(framework::OpRegistry::CreateOp(op_desc)), - local_scope_(local_scope) {} + local_scope_(local_scope), + place_(place) {} void SendOpHandle::RunImpl() { // Wait input done for (auto *in : inputs_) { auto &p = static_cast<VarHandle *>(in)->place_; + if (in->DebugString() == "dummy") { // HACK + continue; + } in->generated_op_->Wait(dev_ctxes_[p]); } - platform::CPUPlace cpu; - op_->Run(*local_scope_, cpu); + op_->Run(*local_scope_, place_); } std::string SendOpHandle::Name() const { return "send"; } diff --git a/paddle/fluid/framework/details/send_op_handle.h b/paddle/fluid/framework/details/send_op_handle.h index 8a7b62ba1c4a0c50e556719834d479a7cfdd2421..e7857c1f234fc4617462b8b065cfc4ea68e8c3aa 100644 --- 
a/paddle/fluid/framework/details/send_op_handle.h +++ b/paddle/fluid/framework/details/send_op_handle.h @@ -31,8 +31,10 @@ namespace details { struct SendOpHandle : public OpHandleBase { std::unique_ptr<OperatorBase> op_; const Scope* local_scope_; + const platform::Place& place_; - SendOpHandle(const framework::OpDesc& op_desc, const Scope* local_scope); + SendOpHandle(const framework::OpDesc& op_desc, const Scope* local_scope, + const platform::Place& place); std::string Name() const override;