From 0aa9546eed6afee30e8e168a509d3d32810d6b2f Mon Sep 17 00:00:00 2001 From: Yancey Date: Fri, 8 Jun 2018 09:28:43 +0800 Subject: [PATCH] fix dist train error (#11281) * fix dist train error * update by comment --- .../framework/details/multi_devices_graph_builder.cc | 5 ++--- paddle/fluid/framework/details/rpc_op_handle.cc | 8 ++++---- paddle/fluid/framework/details/rpc_op_handle.h | 4 ++-- 3 files changed, 8 insertions(+), 9 deletions(-) diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.cc b/paddle/fluid/framework/details/multi_devices_graph_builder.cc index 0c4d369e8..97242ebf2 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_builder.cc +++ b/paddle/fluid/framework/details/multi_devices_graph_builder.cc @@ -464,9 +464,8 @@ void MultiDevSSAGraphBuilder::CreateDistTrainOp(SSAGraph *result, void MultiDevSSAGraphBuilder::CreateRPCOp(SSAGraph *result, const OpDesc &op) const { - auto &p = places_[0]; - auto *s = local_scopes_[0]; - result->ops_.emplace_back(new RPCOpHandle(op, s, p, op.Type())); + result->ops_.emplace_back( + new RPCOpHandle(op, local_scopes_[0], op.Type(), places_[0])); if (op.Type() == "send_barrier") { ConnectOp(result, result->ops_.back().get(), "send_vars"); diff --git a/paddle/fluid/framework/details/rpc_op_handle.cc b/paddle/fluid/framework/details/rpc_op_handle.cc index 7f4da4c01..586465f99 100644 --- a/paddle/fluid/framework/details/rpc_op_handle.cc +++ b/paddle/fluid/framework/details/rpc_op_handle.cc @@ -19,12 +19,12 @@ namespace framework { namespace details { RPCOpHandle::RPCOpHandle(const framework::OpDesc &op_desc, - const Scope *local_scope, const platform::Place &place, - const std::string &name) + const Scope *local_scope, const std::string &name, + const platform::Place &place) : op_(framework::OpRegistry::CreateOp(op_desc)), local_scope_(local_scope), - place_(place), - name_(name) {} + name_(name), + place_(place) {} void RPCOpHandle::RunImpl() { // TODO(wuyi): need further analysis whether wait VarDummyHandle. diff --git a/paddle/fluid/framework/details/rpc_op_handle.h b/paddle/fluid/framework/details/rpc_op_handle.h index d28b77217..ae38c7fe1 100644 --- a/paddle/fluid/framework/details/rpc_op_handle.h +++ b/paddle/fluid/framework/details/rpc_op_handle.h @@ -29,7 +29,7 @@ namespace details { struct RPCOpHandle : public OpHandleBase { RPCOpHandle(const framework::OpDesc& op_desc, const Scope* local_scope, - const platform::Place& place, const std::string& name); + const std::string& name, const platform::Place& place); std::string Name() const override; @@ -43,8 +43,8 @@ struct RPCOpHandle : public OpHandleBase { private: std::unique_ptr op_; const Scope* local_scope_; - const platform::Place& place_; const std::string name_; + platform::Place place_; }; } // namespace details -- GitLab