From ce08dc8751b5f605ce6aece70ce6f16af72f4759 Mon Sep 17 00:00:00 2001
From: typhoonzero
Date: Tue, 10 Apr 2018 20:40:43 +0800
Subject: [PATCH] Fix stream removed error

---
 .../details/multi_devices_graph_builder.cc    | 34 ++++++++-----------
 .../details/multi_devices_graph_builder.h     |  2 +-
 .../fluid/framework/details/send_op_handle.cc | 10 +++---
 .../fluid/framework/details/send_op_handle.h  |  4 +--
 python/paddle/fluid/distribute_transpiler.py  |  1 +
 python/paddle/fluid/parallel_executor.py      |  4 ++-
 6 files changed, 24 insertions(+), 31 deletions(-)

diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.cc b/paddle/fluid/framework/details/multi_devices_graph_builder.cc
index 8a53270110..0ebcd627bd 100644
--- a/paddle/fluid/framework/details/multi_devices_graph_builder.cc
+++ b/paddle/fluid/framework/details/multi_devices_graph_builder.cc
@@ -57,8 +57,11 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
 
 void MultiDevSSAGraphBuilder::CreateOpHandleIOs(SSAGraph *result, OpDesc *op,
                                                 const platform::Place &p,
-                                                const size_t &i) const {
+                                                const size_t &i,
+                                                bool create_output) const {
   auto *op_handle = result->ops_.back().get();
+  op_handle->dev_ctxes_[p] = const_cast<platform::DeviceContext *>(
+      platform::DeviceContextPool::Instance().Get(p));
 
   auto var_names = op->InputArgumentNames();
 
@@ -66,10 +69,12 @@ void MultiDevSSAGraphBuilder::CreateOpHandleIOs(SSAGraph *result, OpDesc *op,
     VarHandle *var = CreateOrGetLatestVarHandle(result, each_var_name, p, i);
     op_handle->AddInput(var);
   }
-  var_names = op->OutputArgumentNames();
+  if (create_output) {
+    var_names = op->OutputArgumentNames();
 
-  for (auto &each_var_name : var_names) {
-    CreateOpOutput(result, op_handle, each_var_name, p, i);
+    for (auto &each_var_name : var_names) {
+      CreateOpOutput(result, op_handle, each_var_name, p, i);
+    }
   }
 }
 
@@ -100,9 +105,11 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
     if (!is_forwarding && op->Type() == "send") {
       auto &p = places_[0];
       auto *s = local_scopes_[0];
-      size_t i = 0;
-      result.ops_.emplace_back(new SendOpHandle(*op, s, p));
-      CreateOpHandleIOs(&result, op, p, i);
+      // FIXME(wuyi): the send op always copies from GPU 0.
+      result.ops_.emplace_back(new SendOpHandle(*op, s));
+      // Create inputs on the original place; no SSA output is
+      // created for the send op.
+      CreateOpHandleIOs(&result, op, p, 0, false);
       continue;
     }
 
@@ -112,23 +119,10 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
 
       result.ops_.emplace_back(new ComputationOpHandle(*op, s, p));
       auto *op_handle = result.ops_.back().get();
-      op_handle->dev_ctxes_[p] = const_cast<platform::DeviceContext *>(
-          platform::DeviceContextPool::Instance().Get(p));
-
       CreateOpHandleIOs(&result, op, p, i);
 
-      // auto var_names = op->InputArgumentNames();
-      // for (auto &each_var_name : var_names) {
-      //   VarHandle *var =
-      //       CreateOrGetLatestVarHandle(&result, each_var_name, p, i);
-      //   op_handle->AddInput(var);
-      // }
       auto var_names = op->OutputArgumentNames();
 
-      // for (auto &each_var_name : var_names) {
-      //   CreateOpOutput(&result, op_handle, each_var_name, p, i);
-      // }
-
       if (is_forwarding) {
         if (var_names.size() == 1 && var_names[0] == loss_var_name_) {
           // Insert ScaleCost OpHandle
diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.h b/paddle/fluid/framework/details/multi_devices_graph_builder.h
index de34caab1b..137c817fde 100644
--- a/paddle/fluid/framework/details/multi_devices_graph_builder.h
+++ b/paddle/fluid/framework/details/multi_devices_graph_builder.h
@@ -46,7 +46,7 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
 
  private:
  void CreateOpHandleIOs(SSAGraph *result, OpDesc *op, const platform::Place &p,
-                        const size_t &i) const;
+                        const size_t &i, bool create_output = true) const;
 
  private:
  std::string loss_var_name_;
diff --git a/paddle/fluid/framework/details/send_op_handle.cc b/paddle/fluid/framework/details/send_op_handle.cc
index ae5637b804..caacfa6b1e 100644
--- a/paddle/fluid/framework/details/send_op_handle.cc
+++ b/paddle/fluid/framework/details/send_op_handle.cc
@@ -19,11 +19,9 @@ namespace framework {
 namespace details {
 
 SendOpHandle::SendOpHandle(const framework::OpDesc &op_desc,
-                           const Scope *local_scope,
-                           const platform::Place &place)
+                           const Scope *local_scope)
     : op_(framework::OpRegistry::CreateOp(op_desc)),
-      local_scope_(local_scope),
-      place_(place) {}
+      local_scope_(local_scope) {}
 
 void SendOpHandle::RunImpl() {
   // Wait input done
@@ -31,8 +29,8 @@ void SendOpHandle::RunImpl() {
     auto &p = static_cast<VarHandle *>(in)->place_;
     in->generated_op_->Wait(dev_ctxes_[p]);
   }
-
-  op_->Run(*local_scope_, place_);
+  platform::CPUPlace cpu;
+  op_->Run(*local_scope_, cpu);
 }
 
 std::string SendOpHandle::Name() const { return "send"; }
diff --git a/paddle/fluid/framework/details/send_op_handle.h b/paddle/fluid/framework/details/send_op_handle.h
index e7857c1f23..8a7b62ba1c 100644
--- a/paddle/fluid/framework/details/send_op_handle.h
+++ b/paddle/fluid/framework/details/send_op_handle.h
@@ -31,10 +31,8 @@ namespace details {
 
 struct SendOpHandle : public OpHandleBase {
   std::unique_ptr<OperatorBase> op_;
   const Scope* local_scope_;
-  const platform::Place& place_;
 
-  SendOpHandle(const framework::OpDesc& op_desc, const Scope* local_scope,
-               const platform::Place& place);
+  SendOpHandle(const framework::OpDesc& op_desc, const Scope* local_scope);
 
   std::string Name() const override;
diff --git a/python/paddle/fluid/distribute_transpiler.py b/python/paddle/fluid/distribute_transpiler.py
index 0ec3ebc7e3..e18ace844e 100644
--- a/python/paddle/fluid/distribute_transpiler.py
+++ b/python/paddle/fluid/distribute_transpiler.py
@@ -255,6 +255,7 @@ class DistributeTranspiler:
     def get_trainer_program(self):
         # remove optimize ops and add a send op to main_program
         self.program.global_block().delete_ops(self.optimize_ops)
+        self.program.sync_with_cpp()
         # FIXME(typhoonzero): serializing once fixes an error that occurs when cloning.
         self.program.__str__()
         return self.program
diff --git a/python/paddle/fluid/parallel_executor.py b/python/paddle/fluid/parallel_executor.py
index a23cc9b772..c709f364c1 100644
--- a/python/paddle/fluid/parallel_executor.py
+++ b/python/paddle/fluid/parallel_executor.py
@@ -101,7 +101,9 @@ class ParallelExecutor(object):
 
         self.persistable_vars = [
             v.name
-            for v in filter(lambda var: var.persistable, main.list_vars())
+            for v in filter(lambda var: \
+                var.persistable and var.type != core.VarDesc.VarType.RAW,
+                main.list_vars())
         ]
 
         self.executor = core.ParallelExecutor(
--
GitLab
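
For reference, a minimal sketch of the persistable-variable filter that the parallel_executor.py hunk above introduces: persistable variables are collected for the executor, but RAW-typed ones are skipped because they hold process-local state rather than tensors. The VarType enum, Var class, and the RPC_CLIENT_VAR name below are hypothetical stand-ins (not the real core.VarDesc.VarType / Program.list_vars() API), used only so the sketch runs without Paddle installed.

# Hypothetical stand-ins for core.VarDesc.VarType and the objects returned by
# Program.list_vars(); this only illustrates the RAW-type filter shown above.
from enum import Enum


class VarType(Enum):
    LOD_TENSOR = 0
    RAW = 1


class Var(object):
    def __init__(self, name, persistable, var_type):
        self.name = name
        self.persistable = persistable
        self.type = var_type


def persistable_var_names(all_vars):
    # Keep persistable variables but skip RAW-typed ones, which hold
    # process-local handles (e.g. an RPC client) rather than tensor data.
    return [
        v.name for v in all_vars
        if v.persistable and v.type != VarType.RAW
    ]


if __name__ == "__main__":
    demo = [
        Var("fc_0.w_0", True, VarType.LOD_TENSOR),   # parameter: kept
        Var("RPC_CLIENT_VAR", True, VarType.RAW),    # raw handle: skipped
        Var("tmp_0", False, VarType.LOD_TENSOR),     # non-persistable: skipped
    ]
    print(persistable_var_names(demo))  # ['fc_0.w_0']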