未验证 提交 0aa9546e 编写于 作者: Y Yancey 提交者: GitHub

fix dist train error (#11281)

* fix dist train error

* update by comment
上级 8fa457f9
...@@ -464,9 +464,8 @@ void MultiDevSSAGraphBuilder::CreateDistTrainOp(SSAGraph *result, ...@@ -464,9 +464,8 @@ void MultiDevSSAGraphBuilder::CreateDistTrainOp(SSAGraph *result,
void MultiDevSSAGraphBuilder::CreateRPCOp(SSAGraph *result, void MultiDevSSAGraphBuilder::CreateRPCOp(SSAGraph *result,
const OpDesc &op) const { const OpDesc &op) const {
auto &p = places_[0]; result->ops_.emplace_back(
auto *s = local_scopes_[0]; new RPCOpHandle(op, local_scopes_[0], op.Type(), places_[0]));
result->ops_.emplace_back(new RPCOpHandle(op, s, p, op.Type()));
if (op.Type() == "send_barrier") { if (op.Type() == "send_barrier") {
ConnectOp(result, result->ops_.back().get(), "send_vars"); ConnectOp(result, result->ops_.back().get(), "send_vars");
......
...@@ -19,12 +19,12 @@ namespace framework { ...@@ -19,12 +19,12 @@ namespace framework {
namespace details { namespace details {
RPCOpHandle::RPCOpHandle(const framework::OpDesc &op_desc, RPCOpHandle::RPCOpHandle(const framework::OpDesc &op_desc,
const Scope *local_scope, const platform::Place &place, const Scope *local_scope, const std::string &name,
const std::string &name) const platform::Place &place)
: op_(framework::OpRegistry::CreateOp(op_desc)), : op_(framework::OpRegistry::CreateOp(op_desc)),
local_scope_(local_scope), local_scope_(local_scope),
place_(place), name_(name),
name_(name) {} place_(place) {}
void RPCOpHandle::RunImpl() { void RPCOpHandle::RunImpl() {
// TODO(wuyi): need further analysis whether wait VarDummyHandle. // TODO(wuyi): need further analysis whether wait VarDummyHandle.
......
...@@ -29,7 +29,7 @@ namespace details { ...@@ -29,7 +29,7 @@ namespace details {
struct RPCOpHandle : public OpHandleBase { struct RPCOpHandle : public OpHandleBase {
RPCOpHandle(const framework::OpDesc& op_desc, const Scope* local_scope, RPCOpHandle(const framework::OpDesc& op_desc, const Scope* local_scope,
const platform::Place& place, const std::string& name); const std::string& name, const platform::Place& place);
std::string Name() const override; std::string Name() const override;
...@@ -43,8 +43,8 @@ struct RPCOpHandle : public OpHandleBase { ...@@ -43,8 +43,8 @@ struct RPCOpHandle : public OpHandleBase {
private: private:
std::unique_ptr<OperatorBase> op_; std::unique_ptr<OperatorBase> op_;
const Scope* local_scope_; const Scope* local_scope_;
const platform::Place& place_;
const std::string name_; const std::string name_;
platform::Place place_;
}; };
} // namespace details } // namespace details
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册