From 5b84c9b59ce55bfeef6e474905b41475dfb07b36 Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Mon, 16 Apr 2018 13:12:20 +0800 Subject: [PATCH] CreateOpHandleIOs --- .../details/multi_devices_graph_builder.cc | 14 +++++++------- .../details/multi_devices_graph_builder.h | 4 ++-- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.cc b/paddle/fluid/framework/details/multi_devices_graph_builder.cc index e0dd9e6068..5a95cbc536 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_builder.cc +++ b/paddle/fluid/framework/details/multi_devices_graph_builder.cc @@ -55,21 +55,21 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder( } } -void MultiDevSSAGraphBuilder::CreateOpHandleIOs(SSAGraph *result, OpDesc *op, +void MultiDevSSAGraphBuilder::CreateOpHandleIOs(SSAGraph *result, + const OpDesc &op, const platform::Place &p, const size_t &i) const { auto *op_handle = result->ops_.back().get(); - op_handle->dev_ctxes_[p] = const_cast( - platform::DeviceContextPool::Instance().Get(p)); + op_handle->dev_ctxes_[p] = platform::DeviceContextPool::Instance().Get(p); - auto var_names = op->InputArgumentNames(); + auto var_names = op.InputArgumentNames(); for (auto &each_var_name : var_names) { VarHandle *var = CreateOrGetLatestVarHandle(result, each_var_name, p, i); op_handle->AddInput(var); } - var_names = op->OutputArgumentNames(); + var_names = op.OutputArgumentNames(); for (auto &each_var_name : var_names) { CreateOpOutput(result, op_handle, each_var_name, p, i); @@ -107,7 +107,7 @@ std::unique_ptr MultiDevSSAGraphBuilder::Build( result.ops_.emplace_back(new SendOpHandle(*op, s, p)); // Create inputs for output on original place and no ssa output // is created for send op. - CreateOpHandleIOs(&result, op, p, 0); + CreateOpHandleIOs(&result, *op, p, 0); continue; } @@ -117,7 +117,7 @@ std::unique_ptr MultiDevSSAGraphBuilder::Build( result.ops_.emplace_back(new ComputationOpHandle(*op, s, p)); auto *op_handle = result.ops_.back().get(); - CreateOpHandleIOs(&result, op, p, i); + CreateOpHandleIOs(&result, *op, p, i); auto var_names = op->OutputArgumentNames(); diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.h b/paddle/fluid/framework/details/multi_devices_graph_builder.h index de34caab1b..f1518d75b4 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_builder.h +++ b/paddle/fluid/framework/details/multi_devices_graph_builder.h @@ -45,8 +45,8 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder { std::unique_ptr Build(const ProgramDesc &program) const override; private: - void CreateOpHandleIOs(SSAGraph *result, OpDesc *op, const platform::Place &p, - const size_t &i) const; + void CreateOpHandleIOs(SSAGraph *result, const OpDesc &op, + const platform::Place &p, const size_t &i) const; private: std::string loss_var_name_; -- GitLab