Commit 32d5a160 authored by Xin Pan

resolve conflicts

test=develop
Parent 26e32e09
@@ -206,8 +206,7 @@ std::unique_ptr<ir::Graph> BuildStrategy::Apply(
       graph->Erase(kAllOpDescs);
     }
-    graph->SetNotOwned<const std::vector<OpDesc *>>(kAllOpDescs,
-                                                    &all_ops);  // take ownership
+    graph->SetNotOwned<const std::vector<OpDesc *>>(kAllOpDescs, &all_ops);
     pass->Erase(kAllOpDescs);
     pass->SetNotOwned<const std::vector<OpDesc *>>(kAllOpDescs, &all_ops);
...
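Note: the deleted trailing comment `// take ownership` was misleading — `SetNotOwned` deliberately stores a raw, non-owning pointer, and `all_ops` stays owned by the caller. A minimal model of the owned vs. non-owned attribute contract (a hypothetical stand-in for illustration, not Paddle's actual `ir::Graph` code, which presumably lives in `graph.h`):

```cpp
// attr_map_demo.cc -- hypothetical model of Set() vs. SetNotOwned().
#include <functional>
#include <iostream>
#include <map>
#include <string>
#include <vector>

class AttrMap {
 public:
  // Set(): the map takes ownership and deletes the value on Erase().
  template <typename T>
  void Set(const std::string &name, T *value) {
    attrs_[name] = value;
    deleters_[name] = [value] { delete value; };
  }

  // SetNotOwned(): the map only aliases the value; the caller keeps
  // ownership and must keep the object alive while the attr is registered.
  template <typename T>
  void SetNotOwned(const std::string &name, T *value) {
    attrs_[name] = value;
    deleters_[name] = [] {};  // no-op: nothing to free
  }

  void Erase(const std::string &name) {
    deleters_.at(name)();  // frees only if owned
    deleters_.erase(name);
    attrs_.erase(name);
  }

 private:
  std::map<std::string, void *> attrs_;
  std::map<std::string, std::function<void()>> deleters_;
};

int main() {
  AttrMap graph_attrs;
  std::vector<int> all_ops = {1, 2, 3};           // owned by this scope
  graph_attrs.SetNotOwned("kAllOpDescs", &all_ops);
  graph_attrs.Erase("kAllOpDescs");               // must not delete all_ops
  std::cout << all_ops.size() << "\n";            // still valid: prints 3
}
```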
@@ -20,7 +20,7 @@ namespace framework {
 namespace details {

 std::vector<std::unique_ptr<ir::Graph>>
-ParallelSSAGraphExecutor::SeparateMultiDevicesGraph(ir::Graph* graph) {
+ParallelSSAGraphExecutor::SeparateMultiDevicesGraph(ir::Graph *graph) {
   std::vector<std::unique_ptr<ir::Graph>> graphs;
   graphs.reserve(places_.size());
   for (size_t i = 0; i < places_.size(); ++i) {
@@ -76,13 +76,12 @@ ParallelSSAGraphExecutor::SeparateMultiDevicesGraph(ir::Graph* graph) {
 ParallelSSAGraphExecutor::ParallelSSAGraphExecutor(
     const ExecutionStrategy &strategy, const std::vector<Scope *> &local_scopes,
-    const std::vector<platform::Place> &places,
-    const framework::ProgramDesc &main_prog, ir::Graph* graph)
+    const std::vector<platform::Place> &places, ir::Graph *graph)
     : strategy_(std::move(strategy)),
       local_scopes_(std::move(local_scopes)),
       pool_(places.size() >= 2 ? new ::ThreadPool(places.size()) : nullptr),
       places_(std::move(places)),
-      main_prog_(main_prog),
+      main_prog_(graph->OriginProgram()),
       // TODO(Yancey1989): Copying graphs is not safely since it deleted the
       // attrs.
       graphs_(SeparateMultiDevicesGraph(graph)) {
...
@@ -31,8 +31,7 @@ class ParallelSSAGraphExecutor : public SSAGraphExecutor {
   ParallelSSAGraphExecutor(const ExecutionStrategy &strategy,
                            const std::vector<Scope *> &local_scopes,
                            const std::vector<platform::Place> &places,
-                           const framework::ProgramDesc &main_prog,
-                           ir::Graph* graph);
+                           ir::Graph *graph);

   ~ParallelSSAGraphExecutor() final = default;
   const ir::Graph &Graph() const override { return *graphs_[0]; }
@@ -41,7 +40,7 @@ class ParallelSSAGraphExecutor : public SSAGraphExecutor {
  private:
   std::vector<std::unique_ptr<ir::Graph>> SeparateMultiDevicesGraph(
-      ir::Graph* graph);
+      ir::Graph *graph);

   ExecutionStrategy strategy_;
   std::vector<Scope *> local_scopes_;
...
@@ -195,22 +195,12 @@ class Graph {
     return nullptr;
   }

-<<<<<<< HEAD
-=======
   // Returns reference to the original program.
   // WARN: After a series of passes, the current graph can be quite
   // different from OriginProgram. Caller shouldn't assume much from
   // the returned OriginProgram.
   const ProgramDesc &OriginProgram() const { return program_; }

-  void ResolveHazard(
-      const std::map<std::string, std::vector<ir::Node *>> &var_nodes);
-
- private:
-  std::map<std::string, std::vector<ir::Node *>> InitFromProgram(
-      const ProgramDesc &program);
->>>>>>> polish
   // This method takes ownership of `node`.
   ir::Node *AddNode(ir::Node *node) {
     PADDLE_ENFORCE(node_set_.find(node) == node_set_.end());
...
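`OriginProgram()` is the accessor that lets `ParallelSSAGraphExecutor` (above) recover `main_prog_` straight from the graph instead of taking a separate `ProgramDesc` parameter. A minimal sketch of the pattern with stand-in types (illustrative only; the real class is Paddle's `ir::Graph`), keeping the WARN caveat in mind — after passes run, the returned program describes the graph's origin, not its current state:

```cpp
// origin_program_demo.cc -- simplified stand-ins; illustration only.
#include <iostream>
#include <string>

struct ProgramDesc {  // stand-in for framework::ProgramDesc
  std::string name;
};

class Graph {
 public:
  // The graph keeps its own copy, so the reference returned below
  // stays valid for the graph's whole lifetime.
  explicit Graph(const ProgramDesc &program) : program_(program) {}

  // Historical reference only: passes may have rewritten the graph
  // so that it no longer matches this program.
  const ProgramDesc &OriginProgram() const { return program_; }

 private:
  const ProgramDesc program_;
};

int main() {
  Graph g(ProgramDesc{"main"});
  // What the executor's initializer list does, in spirit:
  //   main_prog_(graph->OriginProgram())
  std::cout << g.OriginProgram().name << "\n";  // prints "main"
}
```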
@@ -184,9 +184,10 @@ std::vector<Scope *> &ParallelExecutor::GetLocalScopes() {
 ParallelExecutor::ParallelExecutor(
     const std::vector<platform::Place> &places,
     const std::unordered_set<std::string> &bcast_vars,
-    const std::vector<ir::Graph *> &graphs, const std::string &loss_var_name,
-    Scope *scope, const std::vector<Scope *> &local_scopes,
-    const ExecutionStrategy &exec_strategy, const BuildStrategy &build_strategy)
+    const std::string &loss_var_name, Scope *scope,
+    const std::vector<Scope *> &local_scopes,
+    const ExecutionStrategy &exec_strategy, const BuildStrategy &build_strategy,
+    ir::Graph *graph)
     : member_(new ParallelExecutorPrivate(places)) {
   member_->global_scope_ = scope;
   member_->use_cuda_ = exec_strategy.use_cuda_;
@@ -216,34 +217,17 @@ ParallelExecutor::ParallelExecutor(
     }
   }

-<<<<<<< HEAD
   std::unique_ptr<ir::Graph> temp_owned_graph(graph);

   // FIXME(Yancey1989): parallel graph mode get better performance
   // in GPU allreduce distributed training. Need an elegant way to
   // choice the execution strategy.
-  build_strategy.enable_parallel_graph_ =
-      EnableParallelGraphExecution(*temp_owned_graph, exec_strategy, build_strategy);
+  build_strategy.enable_parallel_graph_ = EnableParallelGraphExecution(
+      *temp_owned_graph, exec_strategy, build_strategy);
   if (build_strategy.enable_parallel_graph_)
     VLOG(0) << "The Executor would execute the graph by ParallelGraph "
                "Execution which can get better performance,"
             << "you can force it off by env FLAGS_enable_parallel_graph=0";
-=======
-  // TODO(panyx0718): Update pass interface so we don't need this here.
-  std::vector<std::unique_ptr<ir::Graph>> temp_owned_graphs;
-  for (ir::Graph *g : graphs) {
-    temp_owned_graphs.emplace_back(g);
-  }
-<<<<<<< HEAD
->>>>>>> fix parallel graph mode program
-=======
-  bool parallel_graphs = (temp_owned_graphs.size() > 1);
-  if (parallel_graphs) {
-    PADDLE_ENFORCE_EQ(temp_owned_graphs.size(), places.size());
-  }
-  VLOG(1) << "Enable ParallelGraph Execution: " << parallel_graphs;
->>>>>>> polish

   if (member_->use_cuda_) {
     // Bcast Parameters to all GPUs
@@ -255,7 +239,7 @@ ParallelExecutor::ParallelExecutor(
       if (nccl_id_var != nullptr) {
         nccl_id = nccl_id_var->GetMutable<ncclUniqueId>();
       }
-      if (parallel_graphs && member_->nranks_ > 1UL) {
+      if (build_strategy.enable_parallel_graph_ && member_->nranks_ > 1UL) {
         if (nccl_id == nullptr) {
           local_nccl_id_.reset(new ncclUniqueId());
           platform::dynload::ncclGetUniqueId(local_nccl_id_.get());
@@ -273,105 +257,54 @@ ParallelExecutor::ParallelExecutor(
   if (member_->local_scopes_.size() != 1 && local_scopes.empty()) {
     BCastParamsToDevices(bcast_vars);
   }
   // Startup Program has been run. All local scopes has correct parameters.

   // Step 2. Convert main_program to SSA form and dependency graph. Also, insert
   // ncclOp
-<<<<<<< HEAD
-  std::unique_ptr<ir::Graph> graph;
 #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
-  temp_owned_graph = build_strategy.Apply(std::move(temp_owned_graph), member_->places_, loss_var_name,
-                                          member_->local_scopes_, member_->nranks_,
-                                          member_->use_cuda_, member_->nccl_ctxs_.get());
-#else
-  temp_owned_graph = build_strategy.Apply(std::move(temp_owned_graph), member_->places_, loss_var_name,
-                                          member_->local_scopes_, member_->nranks_,
-                                          member_->use_cuda_);
-=======
-  std::vector<ir::Graph *> compiled_graphs;
-#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
-  if (parallel_graphs) {
-    for (size_t i = 0; i < member_->places_.size(); ++i) {
-      auto temp_owned_graph = build_strategy.Apply(
-          std::move(temp_owned_graphs[i]), {member_->places_[i]}, loss_var_name,
-          {member_->local_scopes_[i]}, member_->nranks_, member_->use_cuda_,
-          member_->nccl_ctxs_.get());
-      compiled_graphs.push_back(temp_owned_graph.release());
-    }
-  } else {
-    auto temp_owned_graph = build_strategy.Apply(
-        std::move(temp_owned_graphs[0]), member_->places_, loss_var_name,
-        member_->local_scopes_, member_->nranks_, member_->use_cuda_,
-        member_->nccl_ctxs_.get());
-    compiled_graphs.push_back(temp_owned_graph.release());
-  }
+  temp_owned_graph = build_strategy.Apply(
+      std::move(temp_owned_graph), member_->places_, loss_var_name,
+      member_->local_scopes_, member_->nranks_, member_->use_cuda_,
+      member_->nccl_ctxs_.get());
 #else
-  auto temp_owned_graph = build_strategy.Apply(
-      std::move(temp_owned_graphs[0]), member_->places_, loss_var_name,
-      member_->local_scopes_, member_->nranks_, member_->use_cuda_);
-  compiled_graphs.push_back(temp_owned_graph.release());
->>>>>>> fix parallel graph mode program
+  temp_owned_graph = build_strategy.Apply(
+      std::move(temp_owned_graph), member_->places_, loss_var_name,
+      member_->local_scopes_, member_->nranks_, member_->use_cuda_);
 #endif

   auto max_memory_size = GetEagerDeletionThreshold();
   VLOG(10) << "Eager Deletion Threshold "
            << static_cast<float>(max_memory_size) / (1 << 30);
   if (max_memory_size >= 0) {
-<<<<<<< HEAD
-    graph = member_->PrepareGCAndRefCnts(std::move(graph),
-                                         static_cast<size_t>(max_memory_size)).release();
-=======
-    for (size_t i = 0; i < graphs.size(); ++i) {
-      compiled_graphs[i] =
-          member_
-              ->PrepareGCAndRefCnts(
-                  std::unique_ptr<ir::Graph>(compiled_graphs[i]),
-                  static_cast<size_t>(max_memory_size))
-              .release();
-    }
->>>>>>> fix parallel graph mode program
+    graph = member_
                ->PrepareGCAndRefCnts(std::move(temp_owned_graph),
+                                      static_cast<size_t>(max_memory_size))
+                .release();
+  } else {
+    graph = temp_owned_graph.release();
   }

   // Step 3. Create vars in each scope. Passes may also create new vars.
   // skip control vars and empty vars
   std::vector<details::VariableInfo> var_infos;
-<<<<<<< HEAD
   for (auto &node : graph->Nodes()) {
     if (node->IsVar() && !node->IsCtrlVar() && node->Var()) {
       var_infos.emplace_back();
       var_infos.back().name_ = node->Var()->Name();
       var_infos.back().type_ = node->Var()->GetType();
       var_infos.back().persistable_ = node->Var()->Persistable();
-=======
-  for (auto &graph : compiled_graphs) {
-    for (auto &node : graph->Nodes()) {
-      if (node->IsVar() && !node->IsCtrlVar() && node->Var()) {
-        var_infos.emplace_back();
-        var_infos.back().name_ = node->Var()->Name();
-        var_infos.back().type_ = node->Var()->GetType();
-        var_infos.back().persistable_ = node->Var()->Persistable();
-      }
->>>>>>> fix parallel graph mode program
     }
   }

   // If the loss_var_name is given, the number of graph should be only one.
   if (loss_var_name.size()) {
-<<<<<<< HEAD
     size_t graph_num = ir::GraphNum(*graph);
-=======
-    size_t graph_num = ir::GraphNum(*compiled_graphs[0]);
->>>>>>> fix parallel graph mode program
     if (graph_num > 1) {
       LOG(WARNING)
           << "The number of graph should be only one, "
              "but the current graph has "
-<<<<<<< HEAD
           << ir::GraphNum(*graph)
-=======
-          << ir::GraphNum(*compiled_graphs[0])
->>>>>>> fix parallel graph mode program
           << " sub_graphs. If you want to see the nodes of the "
              "sub_graphs, you should use 'FLAGS_print_sub_graph_dir' "
             "to specify the output dir. NOTES: if you not do training, "
@@ -379,18 +312,12 @@ ParallelExecutor::ParallelExecutor(
     }
   }

-<<<<<<< HEAD
   if (build_strategy.enable_parallel_graph_) {
 #ifdef PADDLE_WITH_CUDA
     // TODO(Yancey1989): Remove passing in the main_program when
     // allreduce_seq_pass doesn't need it as the attr.
-=======
-  if (parallel_graphs) {
->>>>>>> polish
     member_->executor_.reset(new details::ParallelSSAGraphExecutor(
-<<<<<<< HEAD
-        exec_strategy, member_->local_scopes_, member_->places_, main_program,
-        graph));
+        exec_strategy, member_->local_scopes_, member_->places_, graph));
 #else
     PADDLE_THROW(
         "Paddle should be compiled with CUDA for ParallelGraph Execution.");
@@ -402,19 +329,6 @@ ParallelExecutor::ParallelExecutor(
   } else {
     member_->executor_.reset(new details::FastThreadedSSAGraphExecutor(
         exec_strategy, member_->local_scopes_, member_->places_, graph));
-=======
-        exec_strategy, member_->local_scopes_, member_->places_,
-        compiled_graphs));
-  } else {
-    if (exec_strategy.type_ == ExecutionStrategy::kDefault) {
-      member_->executor_.reset(new details::ThreadedSSAGraphExecutor(
-          exec_strategy, member_->local_scopes_, member_->places_,
-          compiled_graphs[0]));
-    } else {
-      member_->executor_.reset(new details::FastThreadedSSAGraphExecutor(
-          exec_strategy, member_->local_scopes_, member_->places_,
-          compiled_graphs[0]));
->>>>>>> fix parallel graph mode program
-    }
   }
@@ -551,9 +465,9 @@ ParallelExecutor::~ParallelExecutor() {
   delete member_;
 }

-bool EnableParallelGraphExecution(const ir::Graph &graph,
-                                  const ExecutionStrategy &exec_strategy,
-                                  const BuildStrategy &build_strategy) {
+bool ParallelExecutor::EnableParallelGraphExecution(
+    const ir::Graph &graph, const ExecutionStrategy &exec_strategy,
+    const BuildStrategy &build_strategy) const {
   if (!FLAGS_enable_parallel_graph) return false;

   bool enable_parallel_graph = true;
...
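The shape of the refactor above: the constructor now receives one raw `ir::Graph *`, wraps it in a `unique_ptr` (`temp_owned_graph`) while the build-strategy passes run, and `release()`s it back to a raw pointer for the executors. A compilable sketch of that temporary-ownership round trip (names are illustrative, not Paddle's API):

```cpp
// temp_ownership_demo.cc -- models the temp_owned_graph pattern above.
#include <iostream>
#include <memory>
#include <utility>

struct Graph {
  int pass_count = 0;
};

// Passes consume and return ownership, mirroring BuildStrategy::Apply.
std::unique_ptr<Graph> ApplyPass(std::unique_ptr<Graph> g) {
  ++g->pass_count;  // rewrite the graph "in place"
  return g;
}

// Takes a raw pointer it does NOT own permanently: ownership is borrowed
// for the duration of the pipeline, then handed back via release().
Graph *BuildGraph(Graph *raw) {
  std::unique_ptr<Graph> temp_owned_graph(raw);
  temp_owned_graph = ApplyPass(std::move(temp_owned_graph));
  temp_owned_graph = ApplyPass(std::move(temp_owned_graph));
  return temp_owned_graph.release();  // caller's pointer is valid again
}

int main() {
  Graph *graph = new Graph;  // caller owns, as with ParallelExecutor's caller
  graph = BuildGraph(graph);
  std::cout << graph->pass_count << "\n";  // prints 2
  delete graph;
}
```

The point of the dance is that `BuildStrategy::Apply` and `PrepareGCAndRefCnts` consume and return `unique_ptr`s, while the SSA graph executors constructed afterwards still take a raw pointer.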
@@ -46,11 +46,11 @@ class ParallelExecutor {
  public:
   explicit ParallelExecutor(const std::vector<platform::Place> &places,
                             const std::unordered_set<std::string> &bcast_vars,
-                            const std::vector<ir::Graph *> &graphs,
                             const std::string &loss_var_name, Scope *scope,
                             const std::vector<Scope *> &local_scopes,
                             const ExecutionStrategy &exec_strategy,
-                            const BuildStrategy &build_strategy);
+                            const BuildStrategy &build_strategy,
+                            ir::Graph *graph);

   ~ParallelExecutor();
@@ -71,6 +71,9 @@ class ParallelExecutor {
  private:
   void BCastParamsToDevices(const std::unordered_set<std::string> &vars) const;

+  bool EnableParallelGraphExecution(const ir::Graph &graph,
+                                    const ExecutionStrategy &exec_strategy,
+                                    const BuildStrategy &build_strategy) const;

   ParallelExecutorPrivate *member_;
 #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
@@ -78,9 +81,5 @@ class ParallelExecutor {
 #endif
 };

-bool EnableParallelGraphExecution(const ir::Graph &graph,
-                                  const ExecutionStrategy &exec_strategy,
-                                  const BuildStrategy &build_strategy);
-
 }  // namespace framework
 }  // namespace paddle
@@ -976,8 +976,6 @@ All parameter, weight, gradient are variables in Paddle.
       [](ir::PassBuilder &self, size_t idx) { self.RemovePass(idx); });

   // -- python binds for parallel executor.
-  m.def("_enable_parallel_graph_execution",
-        framework::EnableParallelGraphExecution);
   py::class_<ParallelExecutor> pe(m, "ParallelExecutor");
   py::class_<ExecutionStrategy> exec_strategy(pe, "ExecutionStrategy", R"DOC(
@@ -1216,10 +1214,9 @@ All parameter, weight, gradient are variables in Paddle.
         cannot be updated after being finalized.)DOC");

   pe.def(py::init<const std::vector<platform::Place> &,
-                  const std::unordered_set<std::string> &,
-                  const std::vector<ir::Graph *> &, const std::string &,
+                  const std::unordered_set<std::string> &, const std::string &,
                   Scope *, std::vector<Scope *> &, const ExecutionStrategy &,
-                  const BuildStrategy &>())
+                  const BuildStrategy &, ir::Graph *>())
       // NOTE: even we return a vec<Scope*>* to Python use reference policy.
       // We still cannot get local_scope from this vector, since the element
       // of vec<Scope*> will be freed by Python GC. We can only return Scope*
...
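Because the bound constructor now takes a raw `ir::Graph *`, the Python side must keep the graph object alive for as long as the executor uses it — which is exactly why the Python changes below store it on `self._graph`. A minimal pybind11 sketch of binding such a constructor (hypothetical module and types, not Paddle's actual binding; `py::keep_alive` is shown as one alternative to caller-side retention):

```cpp
// binding_demo.cc -- hypothetical pybind11 module; illustration only.
#include <pybind11/pybind11.h>

namespace py = pybind11;

struct Graph {};

struct Executor {
  // Non-owning: the Graph must outlive the Executor.
  explicit Executor(Graph *graph) : graph_(graph) {}
  Graph *graph_;
};

PYBIND11_MODULE(demo_core, m) {
  py::class_<Graph>(m, "Graph").def(py::init<>());
  py::class_<Executor>(m, "Executor")
      // keep_alive<1, 2> ties the Graph argument's lifetime to the
      // Executor instance, so Python cannot collect it too early.
      .def(py::init<Graph *>(), py::keep_alive<1, 2>());
}
```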
@@ -198,7 +198,6 @@ class CompiledProgram(object):
         if self._build_strategy.enable_inplace is None:
             self._build_strategy.enable_inplace = False if self._program and self._program._is_mem_optimized else True
-
         # TODO(wuyi): trainer endpoings should be passed in through
         # build_strategy, not program.xxx.
         if self._program and self._build_strategy.num_trainers > 1 and \
         places = list(map(_place_obj, self._places))

-        # FIXME(Yancey1989): parallel graph mode get better performance
-        # in GPU allreduce distributed training. Need an elegant way to
-        # choice the execution strategy.
-        enable_parallel_graph = \
-            core._enable_parallel_graph_execution(self._graph,
-                                                  self._exec_strategy,
-                                                  self._build_strategy) and \
-            self._program  # only supported if compile program not graph.
-
-        self._pe_graphs = [self._graph]
-        if enable_parallel_graph:
-            for _ in range(len(places) - 1):
-                self._pe_graphs.append(core.Graph(self._program_desc))
-
-        return core.ParallelExecutor(
+        pe = core.ParallelExecutor(
             places,
-            set(self._persistable_vars), self._pe_graphs,
+            set(self._persistable_vars),
             cpt.to_text(self._loss_name)
             if self._loss_name else six.u(''), self._scope, self._local_scopes,
-            self._exec_strategy, self._build_strategy)
+            self._exec_strategy, self._build_strategy, self._graph)
+        return pe

     def _compile_inference(self):
         return core.create_paddle_predictor(self._infer_config)
...
@@ -186,12 +186,12 @@ class ParallelExecutor(object):
         # step7: init ParallelExecutor
         # ParallelExecutor API will be deprecated, don't support parallel graph.
-        self._graphs = [core.Graph(main.desc)]
+        self._graph = core.Graph(main.desc)

         self.executor = core.ParallelExecutor(
-            places, persistable_vars, self._graphs,
+            places, persistable_vars,
             cpt.to_text(loss_name) if loss_name else six.u(''), scope,
-            local_scopes, exec_strategy, build_strategy)
+            local_scopes, exec_strategy, build_strategy, self._graph)

         self.scope = scope
...