未验证 提交 ab91046b 编写于 作者: Y Yu Yang 提交者: GitHub

Merge pull request #9934 from reyoung/feature/PolishCreateOpHandleIOs

CreateOpHandleIOs
...@@ -55,21 +55,21 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder( ...@@ -55,21 +55,21 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
} }
} }
void MultiDevSSAGraphBuilder::CreateOpHandleIOs(SSAGraph *result, OpDesc *op, void MultiDevSSAGraphBuilder::CreateOpHandleIOs(SSAGraph *result,
const OpDesc &op,
const platform::Place &p, const platform::Place &p,
const size_t &i) const { const size_t &i) const {
auto *op_handle = result->ops_.back().get(); auto *op_handle = result->ops_.back().get();
op_handle->dev_ctxes_[p] = const_cast<platform::DeviceContext *>( op_handle->dev_ctxes_[p] = platform::DeviceContextPool::Instance().Get(p);
platform::DeviceContextPool::Instance().Get(p));
auto var_names = op->InputArgumentNames(); auto var_names = op.InputArgumentNames();
for (auto &each_var_name : var_names) { for (auto &each_var_name : var_names) {
VarHandle *var = CreateOrGetLatestVarHandle(result, each_var_name, p, i); VarHandle *var = CreateOrGetLatestVarHandle(result, each_var_name, p, i);
op_handle->AddInput(var); op_handle->AddInput(var);
} }
var_names = op->OutputArgumentNames(); var_names = op.OutputArgumentNames();
for (auto &each_var_name : var_names) { for (auto &each_var_name : var_names) {
CreateOpOutput(result, op_handle, each_var_name, p, i); CreateOpOutput(result, op_handle, each_var_name, p, i);
...@@ -107,7 +107,7 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build( ...@@ -107,7 +107,7 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
result.ops_.emplace_back(new SendOpHandle(*op, s, p)); result.ops_.emplace_back(new SendOpHandle(*op, s, p));
// Create inputs for output on original place and no ssa output // Create inputs for output on original place and no ssa output
// is created for send op. // is created for send op.
CreateOpHandleIOs(&result, op, p, 0); CreateOpHandleIOs(&result, *op, p, 0);
continue; continue;
} }
...@@ -117,7 +117,7 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build( ...@@ -117,7 +117,7 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
result.ops_.emplace_back(new ComputationOpHandle(*op, s, p)); result.ops_.emplace_back(new ComputationOpHandle(*op, s, p));
auto *op_handle = result.ops_.back().get(); auto *op_handle = result.ops_.back().get();
CreateOpHandleIOs(&result, op, p, i); CreateOpHandleIOs(&result, *op, p, i);
auto var_names = op->OutputArgumentNames(); auto var_names = op->OutputArgumentNames();
......
...@@ -45,8 +45,8 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder { ...@@ -45,8 +45,8 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
std::unique_ptr<SSAGraph> Build(const ProgramDesc &program) const override; std::unique_ptr<SSAGraph> Build(const ProgramDesc &program) const override;
private: private:
void CreateOpHandleIOs(SSAGraph *result, OpDesc *op, const platform::Place &p, void CreateOpHandleIOs(SSAGraph *result, const OpDesc &op,
const size_t &i) const; const platform::Place &p, const size_t &i) const;
private: private:
std::string loss_var_name_; std::string loss_var_name_;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册