未验证 提交 0aa9546e 编写于 作者: Y Yancey 提交者: GitHub

fix dist train error (#11281)

* fix dist train error

* update by comment
上级 8fa457f9
......@@ -464,9 +464,8 @@ void MultiDevSSAGraphBuilder::CreateDistTrainOp(SSAGraph *result,
void MultiDevSSAGraphBuilder::CreateRPCOp(SSAGraph *result,
const OpDesc &op) const {
auto &p = places_[0];
auto *s = local_scopes_[0];
result->ops_.emplace_back(new RPCOpHandle(op, s, p, op.Type()));
result->ops_.emplace_back(
new RPCOpHandle(op, local_scopes_[0], op.Type(), places_[0]));
if (op.Type() == "send_barrier") {
ConnectOp(result, result->ops_.back().get(), "send_vars");
......
......@@ -19,12 +19,12 @@ namespace framework {
namespace details {
RPCOpHandle::RPCOpHandle(const framework::OpDesc &op_desc,
const Scope *local_scope, const platform::Place &place,
const std::string &name)
const Scope *local_scope, const std::string &name,
const platform::Place &place)
: op_(framework::OpRegistry::CreateOp(op_desc)),
local_scope_(local_scope),
place_(place),
name_(name) {}
name_(name),
place_(place) {}
void RPCOpHandle::RunImpl() {
// TODO(wuyi): need further analysis whether wait VarDummyHandle.
......
......@@ -29,7 +29,7 @@ namespace details {
struct RPCOpHandle : public OpHandleBase {
RPCOpHandle(const framework::OpDesc& op_desc, const Scope* local_scope,
const platform::Place& place, const std::string& name);
const std::string& name, const platform::Place& place);
std::string Name() const override;
......@@ -43,8 +43,8 @@ struct RPCOpHandle : public OpHandleBase {
private:
std::unique_ptr<OperatorBase> op_;
const Scope* local_scope_;
const platform::Place& place_;
const std::string name_;
platform::Place place_;
};
} // namespace details
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册