提交 b123e43b 编写于 作者: Y Yu Yang

extract multi devices graph builder

上级 dd73d18b
...@@ -88,14 +88,9 @@ cc_library(feed_fetch_method SRCS feed_fetch_method.cc DEPS lod_tensor scope glo ...@@ -88,14 +88,9 @@ cc_library(feed_fetch_method SRCS feed_fetch_method.cc DEPS lod_tensor scope glo
cc_library(executor SRCS executor.cc DEPS op_registry device_context scope cc_library(executor SRCS executor.cc DEPS op_registry device_context scope
framework_proto backward glog lod_rank_table feed_fetch_method) framework_proto backward glog lod_rank_table feed_fetch_method)
if(WITH_GPU)
set(parallel_executor_cuda_deps nccl_all_reduce_op_handle)
else()
set(parallel_executor_cuda_deps)
endif()
cc_library(parallel_executor SRCS parallel_executor.cc DEPS op_registry device_context scope cc_library(parallel_executor SRCS parallel_executor.cc DEPS op_registry device_context scope
backward glog lod_rank_table simple_threadpool scale_loss_grad_op_handle backward glog lod_rank_table simple_threadpool multi_devices_graph_builder fetch_op_handle)
fetch_op_handle computation_op_handle ssa_graph ${parallel_executor_cuda_deps})
cc_library(prune SRCS prune.cc DEPS framework_proto) cc_library(prune SRCS prune.cc DEPS framework_proto)
cc_test(prune_test SRCS prune_test.cc DEPS op_info prune recurrent_op device_context) cc_test(prune_test SRCS prune_test.cc DEPS op_info prune recurrent_op device_context)
......
...@@ -7,3 +7,6 @@ nv_library(nccl_all_reduce_op_handle SRCS nccl_all_reduce_op_handle.cc DEPS op_h ...@@ -7,3 +7,6 @@ nv_library(nccl_all_reduce_op_handle SRCS nccl_all_reduce_op_handle.cc DEPS op_h
cc_library(computation_op_handle SRCS computation_op_handle.cc DEPS framework_proto scope place operator op_registry) cc_library(computation_op_handle SRCS computation_op_handle.cc DEPS framework_proto scope place operator op_registry)
cc_library(ssa_graph SRCS ssa_graph.cc DEPS var_handle op_handle_base) cc_library(ssa_graph SRCS ssa_graph.cc DEPS var_handle op_handle_base)
cc_library(ssa_graph_builder SRCS ssa_graph_builder.cc DEPS ssa_graph)
cc_library(multi_devices_graph_builder SRCS multi_devices_graph_builder.cc DEPS ssa_graph_builder computation_op_handle
nccl_all_reduce_op_handle scale_loss_grad_op_handle)
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/multi_devices_graph_builder.h"
#include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/nccl_all_reduce_op_handle.h"
#include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/platform/nccl_helper.h"
namespace paddle {
namespace framework {
namespace details {
MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
const std::vector<platform::Place> &places,
const std::string &loss_var_name,
const std::unordered_set<std::string> &params,
const std::vector<Scope *> &local_scopes,
platform::NCCLContextMap *nccl_ctxs)
: loss_var_name_(loss_var_name),
places_(places),
local_scopes_(local_scopes),
nccl_ctxs_(nccl_ctxs) {
for (auto &p : params) {
grad_names_.insert(GradVarName(p));
}
}
void MultiDevSSAGraphBuilder::Build(const ProgramDesc &program,
SSAGraph *graph) const {
SSAGraph &result = *graph;
result.vars_.resize(places_.size());
bool is_forwarding = true;
for (auto *op : program.Block(0).AllOps()) {
bool change_forward = false;
if (!is_forwarding) {
// FIXME(yy): Do not hard code like this
if (op->OutputArgumentNames().size() == 1 &&
op->OutputArgumentNames()[0] == GradVarName(loss_var_name_)) {
continue; // Drop fill 1. for backward coeff;
}
}
for (size_t i = 0; i < places_.size(); ++i) {
auto &p = places_[i];
auto *s = local_scopes_[i];
result.ops_.emplace_back(new ComputationOpHandle(*op, s, p));
auto *op_handle = result.ops_.back().get();
op_handle->dev_ctx_[p] = const_cast<platform::DeviceContext *>(
platform::DeviceContextPool::Instance().Get(p));
auto var_names = op->InputArgumentNames();
for (auto &each_var_name : var_names) {
VarHandle *var =
CreateOrGetLatestVarHandle(&result, each_var_name, p, i);
op_handle->AddInput(var);
}
var_names = op->OutputArgumentNames();
for (auto &each_var_name : var_names) {
CreateOpOutput(&result, op_handle, each_var_name, p, i);
}
if (is_forwarding) {
if (var_names.size() == 1 && var_names[0] == loss_var_name_) {
// Insert ScaleCost OpHandle
op_handle = new ScaleLossGradOpHandle(local_scopes_.size(), s, p,
nccl_ctxs_->DevCtx(p));
result.ops_.emplace_back(op_handle);
// FIXME: Currently ScaleLossGradOp only use device_count as scale
// factor. So it does not depend on any other operators.
// VarHandle *loss = GetVarHandle(loss_var_name, place);
// loss->pending_ops_.emplace_back(op_handle);
// op_handle->inputs_.emplace_back(loss);
CreateOpOutput(&result, op_handle, GradVarName(loss_var_name_), p, i);
change_forward = true;
}
}
}
if (change_forward) {
is_forwarding = false;
}
if (!is_forwarding) {
auto var_names = op->OutputArgumentNames();
for (auto &og : var_names) {
if (grad_names_.count(og) != 0) { // is param grad
// Insert NCCL AllReduce Op
result.ops_.emplace_back(
new NCCLAllReduceOpHandle(local_scopes_, places_, *nccl_ctxs_));
auto *op_handle = result.ops_.back().get();
for (size_t i = 0; i < places_.size(); ++i) {
auto &p = places_[i];
auto &vars = result.vars_[i][og];
if (vars.empty()) { // This device has no data. continue.
continue;
}
auto *prev_grad = &vars[vars.size() - 1];
op_handle->AddInput(prev_grad);
auto &var = vars[vars.size()];
var.place_ = p;
var.name_ = og;
var.version_ = vars.size() - 1;
op_handle->AddOutput(&var);
}
}
}
}
}
/*
Dependency graph has been constructed. However, there are still data
harzaeds need to be handled.
*/
PolishGraphToSupportDataHazards(&result);
}
} // namespace details
} // namespace framework
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/details/ssa_graph_builder.h"
namespace paddle {
namespace platform {
class NCCLContextMap;
}
namespace framework {
class Scope;
namespace details {
class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
public:
MultiDevSSAGraphBuilder(const std::vector<platform::Place> &places,
const std::string &loss_var_name,
const std::unordered_set<std::string> &params,
const std::vector<Scope *> &local_scopes,
platform::NCCLContextMap *nccl_ctxs);
void Build(const ProgramDesc &program, SSAGraph *graph) const override;
private:
std::string loss_var_name_;
const std::vector<platform::Place> &places_;
const std::vector<Scope *> &local_scopes_;
platform::NCCLContextMap *nccl_ctxs_;
std::unordered_set<std::string> grad_names_;
};
} // namespace details
} // namespace framework
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/ssa_graph_builder.h"
namespace paddle {
namespace framework {
namespace details {
void SSAGraphBuilder::PolishGraphToSupportDataHazards(SSAGraph *graph) {
for (auto &var_map : graph->vars_) {
for (auto &name_pair : var_map) {
if (name_pair.second.size() <= 1) {
return;
}
auto it_new = name_pair.second.rbegin();
auto it_old = name_pair.second.rbegin();
++it_old;
for (; it_old != name_pair.second.rend(); it_new = it_old, ++it_old) {
auto *write_op = it_new->second.generated_op_;
auto &read_ops = it_old->second.pending_ops_;
auto *ex_write_op = it_old->second.generated_op_;
if (ex_write_op == nullptr) { // Nobody write this var.
continue;
}
for (auto *read_op : read_ops) {
// Manually add a dependency var from read_op to write_op;
if (read_op == write_op) {
// Read Write is the same op.
continue;
}
auto *dep_var = new DummyVarHandle();
read_op->AddOutput(dep_var);
write_op->AddInput(dep_var);
graph->dep_vars_.emplace(dep_var);
}
}
}
}
}
VarHandle *SSAGraphBuilder::CreateOrGetLatestVarHandle(
SSAGraph *graph, const std::string &each_var_name,
const platform::Place &place, size_t place_offset) {
auto &var_holders = graph->vars_[place_offset];
auto &var_holder = var_holders[each_var_name];
VarHandle *var = nullptr;
if (var_holder.empty()) {
auto &init_var = var_holder[0];
init_var.place_ = place;
init_var.name_ = each_var_name;
init_var.generated_op_ = nullptr;
init_var.version_ = 0;
var = &init_var;
} else {
var = &var_holder.rbegin()->second;
}
return var;
}
void SSAGraphBuilder::CreateOpOutput(SSAGraph *graph, OpHandleBase *op_handle,
const std::string &each_var_name,
const platform::Place &place,
size_t place_offset) {
auto &vars = graph->vars_[place_offset][each_var_name];
size_t version = vars.size();
auto &var = vars[version];
var.version_ = version;
var.name_ = each_var_name;
var.place_ = place;
op_handle->AddOutput(&var);
}
} // namespace details
} // namespace framework
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/details/ssa_graph.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/platform/place.h"
#include <string>
namespace paddle {
namespace framework {
namespace details {
class SSAGraphBuilder {
public:
SSAGraphBuilder() {}
virtual ~SSAGraphBuilder() {}
virtual void Build(const ProgramDesc &program, SSAGraph *graph) const = 0;
DISABLE_COPY_AND_ASSIGN(SSAGraphBuilder);
protected:
/**
* We only handle write after read(WAR), since it should not have a write
* after write in program. If there are write after write operators, we need
* prune them.
*
* https://en.wikipedia.org/wiki/Hazard_(computer_architecture)#Write_after_read_(WAR)
*/
static void PolishGraphToSupportDataHazards(SSAGraph *graph);
static VarHandle *CreateOrGetLatestVarHandle(SSAGraph *graph,
const std::string &each_var_name,
const platform::Place &place,
size_t place_offset);
static void CreateOpOutput(SSAGraph *graph, OpHandleBase *op_handle,
const std::string &each_var_name,
const platform::Place &place, size_t place_offset);
};
} // namespace details
} // namespace framework
} // namespace paddle
...@@ -16,231 +16,14 @@ limitations under the License. */ ...@@ -16,231 +16,14 @@ limitations under the License. */
#include "ThreadPool.h" #include "ThreadPool.h"
#include "lod_tensor.h" #include "lod_tensor.h"
#include "op_registry.h" #include "op_registry.h"
#include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/fetch_op_handle.h" #include "paddle/fluid/framework/details/fetch_op_handle.h"
#include "paddle/fluid/framework/details/nccl_all_reduce_op_handle.h" #include "paddle/fluid/framework/details/multi_devices_graph_builder.h"
#include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h"
#include "paddle/fluid/framework/details/ssa_graph.h" #include "paddle/fluid/framework/details/ssa_graph.h"
#include "paddle/fluid/platform/nccl_helper.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
using details::ComputationOpHandle;
using details::DummyVarHandle;
using details::FetchOpHandle;
using details::NCCLAllReduceOpHandle;
using details::OpHandleBase;
using details::ScaleLossGradOpHandle;
using details::SSAGraph;
using details::VarHandle;
using details::VarHandleBase;
class SSAGraphBuilder {
public:
virtual ~SSAGraphBuilder() {}
virtual void Build(const ProgramDesc &program, SSAGraph *graph) const = 0;
protected:
/**
* We only handle write after read(WAR), since it should not have a write
* after write in program. If there are write after write operators, we need
* prune them.
*
* https://en.wikipedia.org/wiki/Hazard_(computer_architecture)#Write_after_read_(WAR)
*/
static void PolishGraphToSupportDataHazards(SSAGraph *graph) {
for (auto &var_map : graph->vars_) {
for (auto &name_pair : var_map) {
if (name_pair.second.size() <= 1) {
return;
}
auto it_new = name_pair.second.rbegin();
auto it_old = name_pair.second.rbegin();
++it_old;
for (; it_old != name_pair.second.rend(); it_new = it_old, ++it_old) {
auto *write_op = it_new->second.generated_op_;
auto &read_ops = it_old->second.pending_ops_;
auto *ex_write_op = it_old->second.generated_op_;
if (ex_write_op == nullptr) { // Nobody write this var.
continue;
}
for (auto *read_op : read_ops) {
// Manually add a dependency var from read_op to write_op;
if (read_op == write_op) {
// Read Write is the same op.
continue;
}
auto *dep_var = new DummyVarHandle();
read_op->AddOutput(dep_var);
write_op->AddInput(dep_var);
graph->dep_vars_.emplace(dep_var);
}
}
}
}
}
static VarHandle *CreateOrGetLatestVarHandle(SSAGraph *graph,
const std::string &each_var_name,
const platform::Place &place,
size_t place_offset) {
auto &var_holders = graph->vars_[place_offset];
auto &var_holder = var_holders[each_var_name];
VarHandle *var = nullptr;
if (var_holder.empty()) {
auto &init_var = var_holder[0];
init_var.place_ = place;
init_var.name_ = each_var_name;
init_var.generated_op_ = nullptr;
init_var.version_ = 0;
var = &init_var;
} else {
var = &var_holder.rbegin()->second;
}
return var;
}
static void CreateOpOutput(SSAGraph *graph, OpHandleBase *op_handle,
const std::string &each_var_name,
const platform::Place &place,
size_t place_offset) {
auto &vars = graph->vars_[place_offset][each_var_name];
size_t version = vars.size();
auto &var = vars[version];
var.version_ = version;
var.name_ = each_var_name;
var.place_ = place;
op_handle->AddOutput(&var);
}
};
class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
public:
MultiDevSSAGraphBuilder(const std::vector<platform::Place> &places,
const std::string &loss_var_name,
const std::unordered_set<std::string> &params,
const std::vector<Scope *> &local_scopes,
platform::NCCLContextMap *nccl_ctxs)
: loss_var_name_(loss_var_name),
places_(places),
local_scopes_(local_scopes),
nccl_ctxs_(nccl_ctxs) {
for (auto &p : params) {
grad_names_.insert(GradVarName(p));
}
}
void Build(const ProgramDesc &program, SSAGraph *graph) const override {
SSAGraph &result = *graph;
result.vars_.resize(places_.size());
bool is_forwarding = true;
for (auto *op : program.Block(0).AllOps()) {
bool change_forward = false;
if (!is_forwarding) {
// FIXME(yy): Do not hard code like this
if (op->OutputArgumentNames().size() == 1 &&
op->OutputArgumentNames()[0] == GradVarName(loss_var_name_)) {
continue; // Drop fill 1. for backward coeff;
}
}
for (size_t i = 0; i < places_.size(); ++i) {
auto &p = places_[i];
auto *s = local_scopes_[i];
result.ops_.emplace_back(new ComputationOpHandle(*op, s, p));
auto *op_handle = result.ops_.back().get();
op_handle->dev_ctx_[p] = const_cast<platform::DeviceContext *>(
platform::DeviceContextPool::Instance().Get(p));
auto var_names = op->InputArgumentNames();
for (auto &each_var_name : var_names) {
VarHandle *var =
CreateOrGetLatestVarHandle(&result, each_var_name, p, i);
op_handle->AddInput(var);
}
var_names = op->OutputArgumentNames();
for (auto &each_var_name : var_names) {
CreateOpOutput(&result, op_handle, each_var_name, p, i);
}
if (is_forwarding) {
if (var_names.size() == 1 && var_names[0] == loss_var_name_) {
// Insert ScaleCost OpHandle
op_handle = new ScaleLossGradOpHandle(local_scopes_.size(), s, p,
nccl_ctxs_->DevCtx(p));
result.ops_.emplace_back(op_handle);
// FIXME: Currently ScaleLossGradOp only use device_count as scale
// factor. So it does not depend on any other operators.
// VarHandle *loss = GetVarHandle(loss_var_name, place);
// loss->pending_ops_.emplace_back(op_handle);
// op_handle->inputs_.emplace_back(loss);
CreateOpOutput(&result, op_handle, GradVarName(loss_var_name_), p,
i);
change_forward = true;
}
}
}
if (change_forward) {
is_forwarding = false;
}
if (!is_forwarding) {
auto var_names = op->OutputArgumentNames();
for (auto &og : var_names) {
if (grad_names_.count(og) != 0) { // is param grad
// Insert NCCL AllReduce Op
result.ops_.emplace_back(
new NCCLAllReduceOpHandle(local_scopes_, places_, *nccl_ctxs_));
auto *op_handle = result.ops_.back().get();
for (size_t i = 0; i < places_.size(); ++i) {
auto &p = places_[i];
auto &vars = result.vars_[i][og];
if (vars.empty()) { // This device has no data. continue.
continue;
}
auto *prev_grad = &vars[vars.size() - 1];
op_handle->AddInput(prev_grad);
auto &var = vars[vars.size()];
var.place_ = p;
var.name_ = og;
var.version_ = vars.size() - 1;
op_handle->AddOutput(&var);
}
}
}
}
}
/*
Dependency graph has been constructed. However, there are still data
harzaeds need to be handled.
*/
PolishGraphToSupportDataHazards(&result);
}
private:
std::string loss_var_name_;
const std::vector<platform::Place> &places_;
const std::vector<Scope *> &local_scopes_;
platform::NCCLContextMap *nccl_ctxs_;
std::unordered_set<std::string> grad_names_;
};
class ParallelExecutorPrivate { class ParallelExecutorPrivate {
public: public:
explicit ParallelExecutorPrivate(size_t num_threads, explicit ParallelExecutorPrivate(size_t num_threads,
...@@ -256,17 +39,17 @@ class ParallelExecutorPrivate { ...@@ -256,17 +39,17 @@ class ParallelExecutorPrivate {
std::unique_ptr<platform::NCCLContextMap> nccl_ctxs_; std::unique_ptr<platform::NCCLContextMap> nccl_ctxs_;
SSAGraph graph_; details::SSAGraph graph_;
// Use a simpler thread pool, might be faster. // Use a simpler thread pool, might be faster.
std::unique_ptr<ThreadPool> pool_; std::unique_ptr<ThreadPool> pool_;
std::unique_ptr<platform::EnforceNotMet> exception_; std::unique_ptr<platform::EnforceNotMet> exception_;
void RunOp( void RunOp(bool use_event,
bool use_event, std::unordered_map<details::VarHandleBase *, std::atomic<bool>>
std::unordered_map<VarHandleBase *, std::atomic<bool>> &pending_vars, &pending_vars,
OpHandleBase *op) { details::OpHandleBase *op) {
std::vector<std::atomic<bool> *> *ready_buffer = std::vector<std::atomic<bool> *> *ready_buffer =
new std::vector<std::atomic<bool> *>(); new std::vector<std::atomic<bool> *>();
for (auto *var : op->outputs_) { for (auto *var : op->outputs_) {
...@@ -321,9 +104,9 @@ ParallelExecutor::ParallelExecutor( ...@@ -321,9 +104,9 @@ ParallelExecutor::ParallelExecutor(
// Step 2. Convert main_program to SSA form and dependency graph. Also, insert // Step 2. Convert main_program to SSA form and dependency graph. Also, insert
// ncclOp // ncclOp
MultiDevSSAGraphBuilder builder(member_->places_, loss_var_name, params, details::MultiDevSSAGraphBuilder builder(member_->places_, loss_var_name,
member_->local_scopes_, params, member_->local_scopes_,
member_->nccl_ctxs_.get()); member_->nccl_ctxs_.get());
builder.Build(main_program, &member_->graph_); builder.Build(main_program, &member_->graph_);
// Step 3. Create vars in each scope; // Step 3. Create vars in each scope;
...@@ -389,9 +172,9 @@ void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors, ...@@ -389,9 +172,9 @@ void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors,
FeedFetchList fetched_data(fetch_tensors.size()); FeedFetchList fetched_data(fetch_tensors.size());
// Version --> VarHandle // Version --> VarHandle
member_->exception_.reset(); member_->exception_.reset();
std::unordered_map<VarHandleBase *, std::atomic<bool>> pending_vars; std::unordered_map<details::VarHandleBase *, std::atomic<bool>> pending_vars;
std::unordered_map<OpHandleBase *, size_t> pending_ops; std::unordered_map<details::OpHandleBase *, size_t> pending_ops;
std::vector<DummyVarHandle> dummy_vars; std::vector<details::DummyVarHandle> dummy_vars;
for (auto &var_map : member_->graph_.vars_) { for (auto &var_map : member_->graph_.vars_) {
for (auto &name_pair : var_map) { for (auto &name_pair : var_map) {
...@@ -406,7 +189,7 @@ void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors, ...@@ -406,7 +189,7 @@ void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors,
pending_vars[var.get()] = var->generated_op_ == nullptr; pending_vars[var.get()] = var->generated_op_ == nullptr;
} }
std::vector<OpHandleBase *> to_run; std::vector<details::OpHandleBase *> to_run;
for (auto &op : member_->graph_.ops_) { for (auto &op : member_->graph_.ops_) {
if (op->inputs_.empty()) { // Special case, Op has no input. if (op->inputs_.empty()) { // Special case, Op has no input.
...@@ -416,7 +199,8 @@ void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors, ...@@ -416,7 +199,8 @@ void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors,
} }
} }
std::unordered_map<std::string, std::vector<VarHandleBase *>> fetched_vars; std::unordered_map<std::string, std::vector<details::VarHandleBase *>>
fetched_vars;
for (auto &fetch_var_name : fetch_tensors) { for (auto &fetch_var_name : fetch_tensors) {
for (auto &var_map : member_->graph_.vars_) { for (auto &var_map : member_->graph_.vars_) {
...@@ -427,13 +211,13 @@ void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors, ...@@ -427,13 +211,13 @@ void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors,
} }
} }
std::vector<FetchOpHandle> fetch_ops; std::vector<details::FetchOpHandle> fetch_ops;
for (size_t i = 0; i < fetch_tensors.size(); ++i) { for (size_t i = 0; i < fetch_tensors.size(); ++i) {
auto &var_name = fetch_tensors[i]; auto &var_name = fetch_tensors[i];
auto &vars = fetched_vars[var_name]; auto &vars = fetched_vars[var_name];
fetch_ops.emplace_back(&fetched_data, i, &member_->local_scopes_); fetch_ops.emplace_back(&fetched_data, i, &member_->local_scopes_);
FetchOpHandle *op = &fetch_ops.back(); details::FetchOpHandle *op = &fetch_ops.back();
// FIXME: Use new device context // FIXME: Use new device context
for (auto &p : member_->places_) { for (auto &p : member_->places_) {
...@@ -457,7 +241,7 @@ void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors, ...@@ -457,7 +241,7 @@ void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors,
} }
while (!pending_vars.empty()) { while (!pending_vars.empty()) {
VarHandleBase *ready_var = nullptr; details::VarHandleBase *ready_var = nullptr;
for (auto &pair : pending_vars) { for (auto &pair : pending_vars) {
if (pair.second.load(std::memory_order_acquire)) { if (pair.second.load(std::memory_order_acquire)) {
ready_var = pair.first; ready_var = pair.first;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册