diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.cc b/paddle/fluid/framework/details/multi_devices_graph_builder.cc index 0c4d369e889cf2cca7722dac14a5268fdacabeb4..97242ebf2af304e1498e2ef37cd87d1ef07fb6df 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_builder.cc +++ b/paddle/fluid/framework/details/multi_devices_graph_builder.cc @@ -464,9 +464,8 @@ void MultiDevSSAGraphBuilder::CreateDistTrainOp(SSAGraph *result, void MultiDevSSAGraphBuilder::CreateRPCOp(SSAGraph *result, const OpDesc &op) const { - auto &p = places_[0]; - auto *s = local_scopes_[0]; - result->ops_.emplace_back(new RPCOpHandle(op, s, p, op.Type())); + result->ops_.emplace_back( + new RPCOpHandle(op, local_scopes_[0], op.Type(), places_[0])); if (op.Type() == "send_barrier") { ConnectOp(result, result->ops_.back().get(), "send_vars"); diff --git a/paddle/fluid/framework/details/rpc_op_handle.cc b/paddle/fluid/framework/details/rpc_op_handle.cc index 7f4da4c01de1010467d839ee5490c5e0d02d8c24..586465f99fd94117c821be2952bffda385fbcf75 100644 --- a/paddle/fluid/framework/details/rpc_op_handle.cc +++ b/paddle/fluid/framework/details/rpc_op_handle.cc @@ -19,12 +19,12 @@ namespace framework { namespace details { RPCOpHandle::RPCOpHandle(const framework::OpDesc &op_desc, - const Scope *local_scope, const platform::Place &place, - const std::string &name) + const Scope *local_scope, const std::string &name, + const platform::Place &place) : op_(framework::OpRegistry::CreateOp(op_desc)), local_scope_(local_scope), - place_(place), - name_(name) {} + name_(name), + place_(place) {} void RPCOpHandle::RunImpl() { // TODO(wuyi): need further analysis whether wait VarDummyHandle. diff --git a/paddle/fluid/framework/details/rpc_op_handle.h b/paddle/fluid/framework/details/rpc_op_handle.h index d28b7721720d808a8d81701c3811eae16121fb41..ae38c7fe19e102a330455d89a1068414a7835fab 100644 --- a/paddle/fluid/framework/details/rpc_op_handle.h +++ b/paddle/fluid/framework/details/rpc_op_handle.h @@ -29,7 +29,7 @@ namespace details { struct RPCOpHandle : public OpHandleBase { RPCOpHandle(const framework::OpDesc& op_desc, const Scope* local_scope, - const platform::Place& place, const std::string& name); + const std::string& name, const platform::Place& place); std::string Name() const override; @@ -43,8 +43,8 @@ struct RPCOpHandle : public OpHandleBase { private: std::unique_ptr op_; const Scope* local_scope_; - const platform::Place& place_; const std::string name_; + platform::Place place_; }; } // namespace details