From b123e43bf99fa84b68c91e16d92a8aac5508e88e Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Sat, 24 Mar 2018 12:28:14 +0800 Subject: [PATCH] extract multi devices graph builder --- paddle/fluid/framework/CMakeLists.txt | 9 +- paddle/fluid/framework/details/CMakeLists.txt | 3 + .../details/multi_devices_graph_builder.cc | 140 ++++++++++ .../details/multi_devices_graph_builder.h | 46 ++++ .../framework/details/ssa_graph_builder.cc | 88 ++++++ .../framework/details/ssa_graph_builder.h | 56 ++++ paddle/fluid/framework/parallel_executor.cc | 254 ++---------------- 7 files changed, 354 insertions(+), 242 deletions(-) create mode 100644 paddle/fluid/framework/details/multi_devices_graph_builder.cc create mode 100644 paddle/fluid/framework/details/multi_devices_graph_builder.h create mode 100644 paddle/fluid/framework/details/ssa_graph_builder.cc create mode 100644 paddle/fluid/framework/details/ssa_graph_builder.h diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt index f1d19efa97d..d3f69ee9d84 100644 --- a/paddle/fluid/framework/CMakeLists.txt +++ b/paddle/fluid/framework/CMakeLists.txt @@ -88,14 +88,9 @@ cc_library(feed_fetch_method SRCS feed_fetch_method.cc DEPS lod_tensor scope glo cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto backward glog lod_rank_table feed_fetch_method) -if(WITH_GPU) - set(parallel_executor_cuda_deps nccl_all_reduce_op_handle) -else() - set(parallel_executor_cuda_deps) -endif() + cc_library(parallel_executor SRCS parallel_executor.cc DEPS op_registry device_context scope - backward glog lod_rank_table simple_threadpool scale_loss_grad_op_handle - fetch_op_handle computation_op_handle ssa_graph ${parallel_executor_cuda_deps}) + backward glog lod_rank_table simple_threadpool multi_devices_graph_builder fetch_op_handle) cc_library(prune SRCS prune.cc DEPS framework_proto) cc_test(prune_test SRCS prune_test.cc DEPS op_info prune recurrent_op device_context) diff --git a/paddle/fluid/framework/details/CMakeLists.txt b/paddle/fluid/framework/details/CMakeLists.txt index 9ed41ab94c3..4432bc0245e 100644 --- a/paddle/fluid/framework/details/CMakeLists.txt +++ b/paddle/fluid/framework/details/CMakeLists.txt @@ -7,3 +7,6 @@ nv_library(nccl_all_reduce_op_handle SRCS nccl_all_reduce_op_handle.cc DEPS op_h cc_library(computation_op_handle SRCS computation_op_handle.cc DEPS framework_proto scope place operator op_registry) cc_library(ssa_graph SRCS ssa_graph.cc DEPS var_handle op_handle_base) +cc_library(ssa_graph_builder SRCS ssa_graph_builder.cc DEPS ssa_graph) +cc_library(multi_devices_graph_builder SRCS multi_devices_graph_builder.cc DEPS ssa_graph_builder computation_op_handle + nccl_all_reduce_op_handle scale_loss_grad_op_handle) diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.cc b/paddle/fluid/framework/details/multi_devices_graph_builder.cc new file mode 100644 index 00000000000..3fab6adf0f8 --- /dev/null +++ b/paddle/fluid/framework/details/multi_devices_graph_builder.cc @@ -0,0 +1,140 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/details/multi_devices_graph_builder.h" +#include "paddle/fluid/framework/details/computation_op_handle.h" +#include "paddle/fluid/framework/details/nccl_all_reduce_op_handle.h" +#include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h" +#include "paddle/fluid/framework/scope.h" +#include "paddle/fluid/platform/nccl_helper.h" + +namespace paddle { +namespace framework { +namespace details { +MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder( + const std::vector &places, + const std::string &loss_var_name, + const std::unordered_set ¶ms, + const std::vector &local_scopes, + platform::NCCLContextMap *nccl_ctxs) + : loss_var_name_(loss_var_name), + places_(places), + local_scopes_(local_scopes), + nccl_ctxs_(nccl_ctxs) { + for (auto &p : params) { + grad_names_.insert(GradVarName(p)); + } +} + +void MultiDevSSAGraphBuilder::Build(const ProgramDesc &program, + SSAGraph *graph) const { + SSAGraph &result = *graph; + result.vars_.resize(places_.size()); + + bool is_forwarding = true; + for (auto *op : program.Block(0).AllOps()) { + bool change_forward = false; + if (!is_forwarding) { + // FIXME(yy): Do not hard code like this + if (op->OutputArgumentNames().size() == 1 && + op->OutputArgumentNames()[0] == GradVarName(loss_var_name_)) { + continue; // Drop fill 1. for backward coeff; + } + } + + for (size_t i = 0; i < places_.size(); ++i) { + auto &p = places_[i]; + auto *s = local_scopes_[i]; + + result.ops_.emplace_back(new ComputationOpHandle(*op, s, p)); + auto *op_handle = result.ops_.back().get(); + op_handle->dev_ctx_[p] = const_cast( + platform::DeviceContextPool::Instance().Get(p)); + + auto var_names = op->InputArgumentNames(); + + for (auto &each_var_name : var_names) { + VarHandle *var = + CreateOrGetLatestVarHandle(&result, each_var_name, p, i); + op_handle->AddInput(var); + } + var_names = op->OutputArgumentNames(); + + for (auto &each_var_name : var_names) { + CreateOpOutput(&result, op_handle, each_var_name, p, i); + } + + if (is_forwarding) { + if (var_names.size() == 1 && var_names[0] == loss_var_name_) { + // Insert ScaleCost OpHandle + op_handle = new ScaleLossGradOpHandle(local_scopes_.size(), s, p, + nccl_ctxs_->DevCtx(p)); + result.ops_.emplace_back(op_handle); + + // FIXME: Currently ScaleLossGradOp only use device_count as scale + // factor. So it does not depend on any other operators. + // VarHandle *loss = GetVarHandle(loss_var_name, place); + // loss->pending_ops_.emplace_back(op_handle); + // op_handle->inputs_.emplace_back(loss); + + CreateOpOutput(&result, op_handle, GradVarName(loss_var_name_), p, i); + change_forward = true; + } + } + } + + if (change_forward) { + is_forwarding = false; + } + + if (!is_forwarding) { + auto var_names = op->OutputArgumentNames(); + for (auto &og : var_names) { + if (grad_names_.count(og) != 0) { // is param grad + // Insert NCCL AllReduce Op + result.ops_.emplace_back( + new NCCLAllReduceOpHandle(local_scopes_, places_, *nccl_ctxs_)); + auto *op_handle = result.ops_.back().get(); + + for (size_t i = 0; i < places_.size(); ++i) { + auto &p = places_[i]; + auto &vars = result.vars_[i][og]; + + if (vars.empty()) { // This device has no data. continue. + continue; + } + auto *prev_grad = &vars[vars.size() - 1]; + op_handle->AddInput(prev_grad); + + auto &var = vars[vars.size()]; + var.place_ = p; + var.name_ = og; + var.version_ = vars.size() - 1; + + op_handle->AddOutput(&var); + } + } + } + } + } + + /* + Dependency graph has been constructed. However, there are still data + harzaeds need to be handled. + */ + PolishGraphToSupportDataHazards(&result); +} +} // namespace details +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.h b/paddle/fluid/framework/details/multi_devices_graph_builder.h new file mode 100644 index 00000000000..510f85bc877 --- /dev/null +++ b/paddle/fluid/framework/details/multi_devices_graph_builder.h @@ -0,0 +1,46 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/fluid/framework/details/ssa_graph_builder.h" + +namespace paddle { +namespace platform { +class NCCLContextMap; +} + +namespace framework { +class Scope; +namespace details { +class MultiDevSSAGraphBuilder : public SSAGraphBuilder { + public: + MultiDevSSAGraphBuilder(const std::vector &places, + const std::string &loss_var_name, + const std::unordered_set ¶ms, + const std::vector &local_scopes, + platform::NCCLContextMap *nccl_ctxs); + + void Build(const ProgramDesc &program, SSAGraph *graph) const override; + + private: + std::string loss_var_name_; + const std::vector &places_; + const std::vector &local_scopes_; + platform::NCCLContextMap *nccl_ctxs_; + std::unordered_set grad_names_; +}; +} // namespace details +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/details/ssa_graph_builder.cc b/paddle/fluid/framework/details/ssa_graph_builder.cc new file mode 100644 index 00000000000..7a80a4b1e73 --- /dev/null +++ b/paddle/fluid/framework/details/ssa_graph_builder.cc @@ -0,0 +1,88 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/details/ssa_graph_builder.h" + +namespace paddle { +namespace framework { +namespace details { +void SSAGraphBuilder::PolishGraphToSupportDataHazards(SSAGraph *graph) { + for (auto &var_map : graph->vars_) { + for (auto &name_pair : var_map) { + if (name_pair.second.size() <= 1) { + return; + } + auto it_new = name_pair.second.rbegin(); + auto it_old = name_pair.second.rbegin(); + ++it_old; + for (; it_old != name_pair.second.rend(); it_new = it_old, ++it_old) { + auto *write_op = it_new->second.generated_op_; + auto &read_ops = it_old->second.pending_ops_; + auto *ex_write_op = it_old->second.generated_op_; + + if (ex_write_op == nullptr) { // Nobody write this var. + continue; + } + + for (auto *read_op : read_ops) { + // Manually add a dependency var from read_op to write_op; + if (read_op == write_op) { + // Read Write is the same op. + continue; + } + + auto *dep_var = new DummyVarHandle(); + read_op->AddOutput(dep_var); + write_op->AddInput(dep_var); + graph->dep_vars_.emplace(dep_var); + } + } + } + } +} + +VarHandle *SSAGraphBuilder::CreateOrGetLatestVarHandle( + SSAGraph *graph, const std::string &each_var_name, + const platform::Place &place, size_t place_offset) { + auto &var_holders = graph->vars_[place_offset]; + auto &var_holder = var_holders[each_var_name]; + VarHandle *var = nullptr; + if (var_holder.empty()) { + auto &init_var = var_holder[0]; + init_var.place_ = place; + init_var.name_ = each_var_name; + init_var.generated_op_ = nullptr; + init_var.version_ = 0; + var = &init_var; + } else { + var = &var_holder.rbegin()->second; + } + return var; +} + +void SSAGraphBuilder::CreateOpOutput(SSAGraph *graph, OpHandleBase *op_handle, + const std::string &each_var_name, + const platform::Place &place, + size_t place_offset) { + auto &vars = graph->vars_[place_offset][each_var_name]; + size_t version = vars.size(); + auto &var = vars[version]; + var.version_ = version; + var.name_ = each_var_name; + var.place_ = place; + op_handle->AddOutput(&var); +} +} // namespace details +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/details/ssa_graph_builder.h b/paddle/fluid/framework/details/ssa_graph_builder.h new file mode 100644 index 00000000000..848b90293a3 --- /dev/null +++ b/paddle/fluid/framework/details/ssa_graph_builder.h @@ -0,0 +1,56 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/fluid/framework/details/ssa_graph.h" +#include "paddle/fluid/framework/program_desc.h" +#include "paddle/fluid/platform/place.h" + +#include + +namespace paddle { +namespace framework { +namespace details { + +class SSAGraphBuilder { + public: + SSAGraphBuilder() {} + virtual ~SSAGraphBuilder() {} + virtual void Build(const ProgramDesc &program, SSAGraph *graph) const = 0; + + DISABLE_COPY_AND_ASSIGN(SSAGraphBuilder); + + protected: + /** + * We only handle write after read(WAR), since it should not have a write + * after write in program. If there are write after write operators, we need + * prune them. + * + * https://en.wikipedia.org/wiki/Hazard_(computer_architecture)#Write_after_read_(WAR) + */ + static void PolishGraphToSupportDataHazards(SSAGraph *graph); + + static VarHandle *CreateOrGetLatestVarHandle(SSAGraph *graph, + const std::string &each_var_name, + const platform::Place &place, + size_t place_offset); + + static void CreateOpOutput(SSAGraph *graph, OpHandleBase *op_handle, + const std::string &each_var_name, + const platform::Place &place, size_t place_offset); +}; +} // namespace details +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc index 5c10595db9c..4ebb89181cd 100644 --- a/paddle/fluid/framework/parallel_executor.cc +++ b/paddle/fluid/framework/parallel_executor.cc @@ -16,231 +16,14 @@ limitations under the License. */ #include "ThreadPool.h" #include "lod_tensor.h" #include "op_registry.h" -#include "paddle/fluid/framework/details/computation_op_handle.h" #include "paddle/fluid/framework/details/fetch_op_handle.h" -#include "paddle/fluid/framework/details/nccl_all_reduce_op_handle.h" -#include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h" +#include "paddle/fluid/framework/details/multi_devices_graph_builder.h" #include "paddle/fluid/framework/details/ssa_graph.h" +#include "paddle/fluid/platform/nccl_helper.h" namespace paddle { namespace framework { -using details::ComputationOpHandle; -using details::DummyVarHandle; -using details::FetchOpHandle; -using details::NCCLAllReduceOpHandle; -using details::OpHandleBase; -using details::ScaleLossGradOpHandle; -using details::SSAGraph; -using details::VarHandle; -using details::VarHandleBase; - -class SSAGraphBuilder { - public: - virtual ~SSAGraphBuilder() {} - virtual void Build(const ProgramDesc &program, SSAGraph *graph) const = 0; - - protected: - /** - * We only handle write after read(WAR), since it should not have a write - * after write in program. If there are write after write operators, we need - * prune them. - * - * https://en.wikipedia.org/wiki/Hazard_(computer_architecture)#Write_after_read_(WAR) - */ - static void PolishGraphToSupportDataHazards(SSAGraph *graph) { - for (auto &var_map : graph->vars_) { - for (auto &name_pair : var_map) { - if (name_pair.second.size() <= 1) { - return; - } - auto it_new = name_pair.second.rbegin(); - auto it_old = name_pair.second.rbegin(); - ++it_old; - for (; it_old != name_pair.second.rend(); it_new = it_old, ++it_old) { - auto *write_op = it_new->second.generated_op_; - auto &read_ops = it_old->second.pending_ops_; - auto *ex_write_op = it_old->second.generated_op_; - - if (ex_write_op == nullptr) { // Nobody write this var. - continue; - } - - for (auto *read_op : read_ops) { - // Manually add a dependency var from read_op to write_op; - if (read_op == write_op) { - // Read Write is the same op. - continue; - } - - auto *dep_var = new DummyVarHandle(); - read_op->AddOutput(dep_var); - write_op->AddInput(dep_var); - graph->dep_vars_.emplace(dep_var); - } - } - } - } - } - - static VarHandle *CreateOrGetLatestVarHandle(SSAGraph *graph, - const std::string &each_var_name, - const platform::Place &place, - size_t place_offset) { - auto &var_holders = graph->vars_[place_offset]; - auto &var_holder = var_holders[each_var_name]; - VarHandle *var = nullptr; - if (var_holder.empty()) { - auto &init_var = var_holder[0]; - init_var.place_ = place; - init_var.name_ = each_var_name; - init_var.generated_op_ = nullptr; - init_var.version_ = 0; - var = &init_var; - } else { - var = &var_holder.rbegin()->second; - } - return var; - } - - static void CreateOpOutput(SSAGraph *graph, OpHandleBase *op_handle, - const std::string &each_var_name, - const platform::Place &place, - size_t place_offset) { - auto &vars = graph->vars_[place_offset][each_var_name]; - size_t version = vars.size(); - auto &var = vars[version]; - var.version_ = version; - var.name_ = each_var_name; - var.place_ = place; - op_handle->AddOutput(&var); - } -}; - -class MultiDevSSAGraphBuilder : public SSAGraphBuilder { - public: - MultiDevSSAGraphBuilder(const std::vector &places, - const std::string &loss_var_name, - const std::unordered_set ¶ms, - const std::vector &local_scopes, - platform::NCCLContextMap *nccl_ctxs) - : loss_var_name_(loss_var_name), - places_(places), - local_scopes_(local_scopes), - nccl_ctxs_(nccl_ctxs) { - for (auto &p : params) { - grad_names_.insert(GradVarName(p)); - } - } - - void Build(const ProgramDesc &program, SSAGraph *graph) const override { - SSAGraph &result = *graph; - result.vars_.resize(places_.size()); - - bool is_forwarding = true; - for (auto *op : program.Block(0).AllOps()) { - bool change_forward = false; - if (!is_forwarding) { - // FIXME(yy): Do not hard code like this - if (op->OutputArgumentNames().size() == 1 && - op->OutputArgumentNames()[0] == GradVarName(loss_var_name_)) { - continue; // Drop fill 1. for backward coeff; - } - } - - for (size_t i = 0; i < places_.size(); ++i) { - auto &p = places_[i]; - auto *s = local_scopes_[i]; - - result.ops_.emplace_back(new ComputationOpHandle(*op, s, p)); - auto *op_handle = result.ops_.back().get(); - op_handle->dev_ctx_[p] = const_cast( - platform::DeviceContextPool::Instance().Get(p)); - - auto var_names = op->InputArgumentNames(); - - for (auto &each_var_name : var_names) { - VarHandle *var = - CreateOrGetLatestVarHandle(&result, each_var_name, p, i); - op_handle->AddInput(var); - } - var_names = op->OutputArgumentNames(); - - for (auto &each_var_name : var_names) { - CreateOpOutput(&result, op_handle, each_var_name, p, i); - } - - if (is_forwarding) { - if (var_names.size() == 1 && var_names[0] == loss_var_name_) { - // Insert ScaleCost OpHandle - op_handle = new ScaleLossGradOpHandle(local_scopes_.size(), s, p, - nccl_ctxs_->DevCtx(p)); - result.ops_.emplace_back(op_handle); - - // FIXME: Currently ScaleLossGradOp only use device_count as scale - // factor. So it does not depend on any other operators. - // VarHandle *loss = GetVarHandle(loss_var_name, place); - // loss->pending_ops_.emplace_back(op_handle); - // op_handle->inputs_.emplace_back(loss); - - CreateOpOutput(&result, op_handle, GradVarName(loss_var_name_), p, - i); - change_forward = true; - } - } - } - - if (change_forward) { - is_forwarding = false; - } - - if (!is_forwarding) { - auto var_names = op->OutputArgumentNames(); - for (auto &og : var_names) { - if (grad_names_.count(og) != 0) { // is param grad - // Insert NCCL AllReduce Op - result.ops_.emplace_back( - new NCCLAllReduceOpHandle(local_scopes_, places_, *nccl_ctxs_)); - auto *op_handle = result.ops_.back().get(); - - for (size_t i = 0; i < places_.size(); ++i) { - auto &p = places_[i]; - auto &vars = result.vars_[i][og]; - - if (vars.empty()) { // This device has no data. continue. - continue; - } - auto *prev_grad = &vars[vars.size() - 1]; - op_handle->AddInput(prev_grad); - - auto &var = vars[vars.size()]; - var.place_ = p; - var.name_ = og; - var.version_ = vars.size() - 1; - - op_handle->AddOutput(&var); - } - } - } - } - } - - /* - Dependency graph has been constructed. However, there are still data - harzaeds need to be handled. - */ - PolishGraphToSupportDataHazards(&result); - } - - private: - std::string loss_var_name_; - const std::vector &places_; - const std::vector &local_scopes_; - platform::NCCLContextMap *nccl_ctxs_; - - std::unordered_set grad_names_; -}; - class ParallelExecutorPrivate { public: explicit ParallelExecutorPrivate(size_t num_threads, @@ -256,17 +39,17 @@ class ParallelExecutorPrivate { std::unique_ptr nccl_ctxs_; - SSAGraph graph_; + details::SSAGraph graph_; // Use a simpler thread pool, might be faster. std::unique_ptr pool_; std::unique_ptr exception_; - void RunOp( - bool use_event, - std::unordered_map> &pending_vars, - OpHandleBase *op) { + void RunOp(bool use_event, + std::unordered_map> + &pending_vars, + details::OpHandleBase *op) { std::vector *> *ready_buffer = new std::vector *>(); for (auto *var : op->outputs_) { @@ -321,9 +104,9 @@ ParallelExecutor::ParallelExecutor( // Step 2. Convert main_program to SSA form and dependency graph. Also, insert // ncclOp - MultiDevSSAGraphBuilder builder(member_->places_, loss_var_name, params, - member_->local_scopes_, - member_->nccl_ctxs_.get()); + details::MultiDevSSAGraphBuilder builder(member_->places_, loss_var_name, + params, member_->local_scopes_, + member_->nccl_ctxs_.get()); builder.Build(main_program, &member_->graph_); // Step 3. Create vars in each scope; @@ -389,9 +172,9 @@ void ParallelExecutor::Run(const std::vector &fetch_tensors, FeedFetchList fetched_data(fetch_tensors.size()); // Version --> VarHandle member_->exception_.reset(); - std::unordered_map> pending_vars; - std::unordered_map pending_ops; - std::vector dummy_vars; + std::unordered_map> pending_vars; + std::unordered_map pending_ops; + std::vector dummy_vars; for (auto &var_map : member_->graph_.vars_) { for (auto &name_pair : var_map) { @@ -406,7 +189,7 @@ void ParallelExecutor::Run(const std::vector &fetch_tensors, pending_vars[var.get()] = var->generated_op_ == nullptr; } - std::vector to_run; + std::vector to_run; for (auto &op : member_->graph_.ops_) { if (op->inputs_.empty()) { // Special case, Op has no input. @@ -416,7 +199,8 @@ void ParallelExecutor::Run(const std::vector &fetch_tensors, } } - std::unordered_map> fetched_vars; + std::unordered_map> + fetched_vars; for (auto &fetch_var_name : fetch_tensors) { for (auto &var_map : member_->graph_.vars_) { @@ -427,13 +211,13 @@ void ParallelExecutor::Run(const std::vector &fetch_tensors, } } - std::vector fetch_ops; + std::vector fetch_ops; for (size_t i = 0; i < fetch_tensors.size(); ++i) { auto &var_name = fetch_tensors[i]; auto &vars = fetched_vars[var_name]; fetch_ops.emplace_back(&fetched_data, i, &member_->local_scopes_); - FetchOpHandle *op = &fetch_ops.back(); + details::FetchOpHandle *op = &fetch_ops.back(); // FIXME: Use new device context for (auto &p : member_->places_) { @@ -457,7 +241,7 @@ void ParallelExecutor::Run(const std::vector &fetch_tensors, } while (!pending_vars.empty()) { - VarHandleBase *ready_var = nullptr; + details::VarHandleBase *ready_var = nullptr; for (auto &pair : pending_vars) { if (pair.second.load(std::memory_order_acquire)) { ready_var = pair.first; -- GitLab