未验证 提交 a6e3cd5e 编写于 作者: X Xin Pan 提交者: GitHub

Merge pull request #15425 from panyx0718/api

Pass graph to parallel executor instead of program
...@@ -43,7 +43,7 @@ paddle.fluid.AsyncExecutor.init_worker ArgSpec(args=['self', 'dist_desc', 'start ...@@ -43,7 +43,7 @@ paddle.fluid.AsyncExecutor.init_worker ArgSpec(args=['self', 'dist_desc', 'start
paddle.fluid.AsyncExecutor.run ArgSpec(args=['self', 'program', 'data_feed', 'filelist', 'thread_num', 'fetch', 'mode', 'debug'], varargs=None, keywords=None, defaults=('', False)) paddle.fluid.AsyncExecutor.run ArgSpec(args=['self', 'program', 'data_feed', 'filelist', 'thread_num', 'fetch', 'mode', 'debug'], varargs=None, keywords=None, defaults=('', False))
paddle.fluid.AsyncExecutor.save_model ArgSpec(args=['self', 'save_path'], varargs=None, keywords=None, defaults=None) paddle.fluid.AsyncExecutor.save_model ArgSpec(args=['self', 'save_path'], varargs=None, keywords=None, defaults=None)
paddle.fluid.AsyncExecutor.stop ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None) paddle.fluid.AsyncExecutor.stop ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None)
paddle.fluid.CompiledProgram.__init__ ArgSpec(args=['self', 'program'], varargs=None, keywords=None, defaults=None) paddle.fluid.CompiledProgram.__init__ ArgSpec(args=['self', 'program_or_graph'], varargs=None, keywords=None, defaults=None)
paddle.fluid.CompiledProgram.with_data_parallel ArgSpec(args=['self', 'loss_name', 'build_strategy', 'exec_strategy', 'share_vars_from'], varargs=None, keywords=None, defaults=(None, None, None, None)) paddle.fluid.CompiledProgram.with_data_parallel ArgSpec(args=['self', 'loss_name', 'build_strategy', 'exec_strategy', 'share_vars_from'], varargs=None, keywords=None, defaults=(None, None, None, None))
paddle.fluid.CompiledProgram.with_inference_optimize ArgSpec(args=['self', 'config'], varargs=None, keywords=None, defaults=None) paddle.fluid.CompiledProgram.with_inference_optimize ArgSpec(args=['self', 'config'], varargs=None, keywords=None, defaults=None)
paddle.fluid.ExecutionStrategy.__init__ __init__(self: paddle.fluid.core.ParallelExecutor.ExecutionStrategy) -> None paddle.fluid.ExecutionStrategy.__init__ __init__(self: paddle.fluid.core.ParallelExecutor.ExecutionStrategy) -> None
......
...@@ -50,7 +50,7 @@ std::unique_ptr<ir::Graph> AllReduceDepsPass::ApplyImpl( ...@@ -50,7 +50,7 @@ std::unique_ptr<ir::Graph> AllReduceDepsPass::ApplyImpl(
std::unordered_map<std::string, int> vars; std::unordered_map<std::string, int> vars;
// TODO(gongwb): use graph topology sort to find the order of operators. // TODO(gongwb): use graph topology sort to find the order of operators.
// Note that must assert topology sort is stable // Note that must assert topology sort is stable
auto& ops = Get<const std::vector<OpDesc*>>(kAllOpDescs); auto& ops = graph->Get<const std::vector<OpDesc*>>(kStaleProgramOpDescs);
for (auto* op_desc : ops) { for (auto* op_desc : ops) {
auto outputs = op_desc->Outputs(); auto outputs = op_desc->Outputs();
for (auto& o_it : outputs) { for (auto& o_it : outputs) {
...@@ -120,4 +120,4 @@ std::unique_ptr<ir::Graph> AllReduceDepsPass::ApplyImpl( ...@@ -120,4 +120,4 @@ std::unique_ptr<ir::Graph> AllReduceDepsPass::ApplyImpl(
REGISTER_PASS(all_reduce_deps_pass, REGISTER_PASS(all_reduce_deps_pass,
paddle::framework::details::AllReduceDepsPass) paddle::framework::details::AllReduceDepsPass)
.RequirePassAttr(paddle::framework::details::kAllOpDescs); .RequireGraphAttr(paddle::framework::details::kStaleProgramOpDescs);
...@@ -174,7 +174,8 @@ bool BuildStrategy::IsMultiDevPass(const std::string &pass_name) const { ...@@ -174,7 +174,8 @@ bool BuildStrategy::IsMultiDevPass(const std::string &pass_name) const {
} }
std::unique_ptr<ir::Graph> BuildStrategy::Apply( std::unique_ptr<ir::Graph> BuildStrategy::Apply(
const ProgramDesc &main_program, const std::vector<platform::Place> &places, std::unique_ptr<ir::Graph> graph,
const std::vector<platform::Place> &places,
const std::string &loss_var_name, const std::vector<Scope *> &local_scopes, const std::string &loss_var_name, const std::vector<Scope *> &local_scopes,
const size_t &nranks, const size_t &nranks,
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
...@@ -185,7 +186,6 @@ std::unique_ptr<ir::Graph> BuildStrategy::Apply( ...@@ -185,7 +186,6 @@ std::unique_ptr<ir::Graph> BuildStrategy::Apply(
// Create a default one if not finalized by user. // Create a default one if not finalized by user.
CreatePassesFromStrategy(false); CreatePassesFromStrategy(false);
std::unique_ptr<ir::Graph> graph(new ir::Graph(main_program));
for (std::shared_ptr<ir::Pass> &pass : pass_builder_->AllPasses()) { for (std::shared_ptr<ir::Pass> &pass : pass_builder_->AllPasses()) {
if (IsMultiDevPass(pass->Type())) { if (IsMultiDevPass(pass->Type())) {
pass->Erase(kPlaces); pass->Erase(kPlaces);
...@@ -203,41 +203,12 @@ std::unique_ptr<ir::Graph> BuildStrategy::Apply( ...@@ -203,41 +203,12 @@ std::unique_ptr<ir::Graph> BuildStrategy::Apply(
pass->Erase("nccl_ctxs"); pass->Erase("nccl_ctxs");
pass->SetNotOwned<platform::NCCLContextMap>("nccl_ctxs", nctx); pass->SetNotOwned<platform::NCCLContextMap>("nccl_ctxs", nctx);
#endif #endif
} else if (pass->Type() == "memory_optimize_pass") {
if (graph->Has(kAllOpDescs)) {
graph->Erase(kAllOpDescs);
}
const std::vector<OpDesc *> *all_op_descs =
new std::vector<OpDesc *>(main_program.Block(0).AllOps());
graph->Set<const std::vector<OpDesc *>>(kAllOpDescs,
all_op_descs); // take ownership
pass->Erase(kAllOpDescs);
pass->SetNotOwned<const std::vector<OpDesc *>>(kAllOpDescs, all_op_descs);
} else if (pass->Type() == "sequential_execution_pass") { } else if (pass->Type() == "sequential_execution_pass") {
LOG(INFO) << "set enable_sequential_execution:" LOG(INFO) << "set enable_sequential_execution:"
<< enable_sequential_execution_; << enable_sequential_execution_;
pass->Erase(kAllOpDescs);
pass->Set<const std::vector<OpDesc *>>(
kAllOpDescs,
new std::vector<OpDesc *>(main_program.Block(0).AllOps()));
} else if (pass->Type() == "all_reduce_deps_pass") { } else if (pass->Type() == "all_reduce_deps_pass") {
LOG(INFO) << "SeqOnlyAllReduceOps:" << SeqOnlyAllReduceOps(*this) LOG(INFO) << "SeqOnlyAllReduceOps:" << SeqOnlyAllReduceOps(*this)
<< ", num_trainers:" << num_trainers_; << ", num_trainers:" << num_trainers_;
pass->Erase(kAllOpDescs);
pass->Set<const std::vector<OpDesc *>>(
kAllOpDescs,
new std::vector<OpDesc *>(main_program.Block(0).AllOps()));
} else if (pass->Type() == "inplace_pass") {
if (graph->Has(kAllOpDescs)) {
graph->Erase(kAllOpDescs);
}
graph->Set<const std::vector<OpDesc *>>(
kAllOpDescs,
new std::vector<OpDesc *>(main_program.Block(0).AllOps()));
} else if (pass->Type() == "fuse_relu_depthwise_conv_pass") { } else if (pass->Type() == "fuse_relu_depthwise_conv_pass") {
if (!use_cuda) { if (!use_cuda) {
LOG(WARNING) << "fuse_relu_depthwise_conv_pass is only supported on " LOG(WARNING) << "fuse_relu_depthwise_conv_pass is only supported on "
......
...@@ -114,7 +114,7 @@ struct BuildStrategy { ...@@ -114,7 +114,7 @@ struct BuildStrategy {
// Apply the passes built by the pass_builder_. The passes will be // Apply the passes built by the pass_builder_. The passes will be
// applied to the Program and output an ir::Graph. // applied to the Program and output an ir::Graph.
std::unique_ptr<ir::Graph> Apply(const ProgramDesc &main_program, std::unique_ptr<ir::Graph> Apply(std::unique_ptr<ir::Graph> graph,
const std::vector<platform::Place> &places, const std::vector<platform::Place> &places,
const std::string &loss_var_name, const std::string &loss_var_name,
const std::vector<Scope *> &local_scopes, const std::vector<Scope *> &local_scopes,
......
...@@ -24,12 +24,11 @@ namespace details { ...@@ -24,12 +24,11 @@ namespace details {
FastThreadedSSAGraphExecutor::FastThreadedSSAGraphExecutor( FastThreadedSSAGraphExecutor::FastThreadedSSAGraphExecutor(
const ExecutionStrategy &strategy, const std::vector<Scope *> &local_scopes, const ExecutionStrategy &strategy, const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places, const std::vector<platform::Place> &places, ir::Graph *graph)
std::unique_ptr<ir::Graph> &&graph)
: strategy_(strategy), : strategy_(strategy),
local_scopes_(local_scopes), local_scopes_(local_scopes),
places_(places), places_(places),
graph_(std::move(graph)), graph_(graph),
pool_(strategy.num_threads_), pool_(strategy.num_threads_),
prepare_pool_(1), // add one more thread for generate op_deps prepare_pool_(1), // add one more thread for generate op_deps
fetch_ctxs_(places) { fetch_ctxs_(places) {
...@@ -110,14 +109,14 @@ FeedFetchList FastThreadedSSAGraphExecutor::Run( ...@@ -110,14 +109,14 @@ FeedFetchList FastThreadedSSAGraphExecutor::Run(
} }
} }
if (exception_.IsCaught()) { if (exception_.IsCaught()) {
ClearFetchOp(graph_.get(), &fetch_ops); ClearFetchOp(graph_, &fetch_ops);
exception_.ReThrow(); exception_.ReThrow();
} }
} }
num_complete += num_comp; num_complete += num_comp;
} }
// Wait FetchOps. // Wait FetchOps.
ClearFetchOp(graph_.get(), &fetch_ops); ClearFetchOp(graph_, &fetch_ops);
return fetches; return fetches;
} }
......
...@@ -32,7 +32,7 @@ class FastThreadedSSAGraphExecutor : public SSAGraphExecutor { ...@@ -32,7 +32,7 @@ class FastThreadedSSAGraphExecutor : public SSAGraphExecutor {
FastThreadedSSAGraphExecutor(const ExecutionStrategy &strategy, FastThreadedSSAGraphExecutor(const ExecutionStrategy &strategy,
const std::vector<Scope *> &local_scopes, const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places, const std::vector<platform::Place> &places,
std::unique_ptr<ir::Graph> &&graph); ir::Graph *graph);
FeedFetchList Run(const std::vector<std::string> &fetch_tensors) override; FeedFetchList Run(const std::vector<std::string> &fetch_tensors) override;
const ir::Graph &Graph() const override; const ir::Graph &Graph() const override;
...@@ -40,7 +40,7 @@ class FastThreadedSSAGraphExecutor : public SSAGraphExecutor { ...@@ -40,7 +40,7 @@ class FastThreadedSSAGraphExecutor : public SSAGraphExecutor {
ExecutionStrategy strategy_; ExecutionStrategy strategy_;
std::vector<Scope *> local_scopes_; std::vector<Scope *> local_scopes_;
std::vector<platform::Place> places_; std::vector<platform::Place> places_;
std::unique_ptr<ir::Graph> graph_; ir::Graph *graph_;
std::unordered_map<OpHandleBase *, int> op_deps_; std::unordered_map<OpHandleBase *, int> op_deps_;
std::vector<OpHandleBase *> bootstrap_ops_; std::vector<OpHandleBase *> bootstrap_ops_;
......
...@@ -33,10 +33,10 @@ namespace details { ...@@ -33,10 +33,10 @@ namespace details {
using paddle::framework::VarDesc; using paddle::framework::VarDesc;
std::vector<ir::Node*> SortOpLikeDescOrder(const ir::Graph& graph) { std::vector<ir::Node*> SortOpLikeDescOrder(const ir::Graph& graph) {
PADDLE_ENFORCE(graph.Has(kAllOpDescs), PADDLE_ENFORCE(graph.Has(kStaleProgramOpDescs),
"Graph has no attribute of kAllOpDescs."); "Graph has no attribute of kStaleProgramOpDescs.");
// 1. get op desc order // 1. get op desc order
auto& op_descs = graph.Get<const std::vector<OpDesc*>>(kAllOpDescs); auto& op_descs = graph.Get<const std::vector<OpDesc*>>(kStaleProgramOpDescs);
// 2. topology sort order // 2. topology sort order
auto nodes = graph.Nodes(); auto nodes = graph.Nodes();
......
...@@ -228,9 +228,6 @@ TEST(CFGGraph, IRGraph) { ...@@ -228,9 +228,6 @@ TEST(CFGGraph, IRGraph) {
// prepare ir graph // prepare ir graph
auto prog = FillProgramDesc(); auto prog = FillProgramDesc();
ir::Graph graph(prog); ir::Graph graph(prog);
const std::vector<OpDesc*>* all_op_descs =
new std::vector<OpDesc*>(prog.Block(0).AllOps());
graph.Set(details::kAllOpDescs, all_op_descs); // take ownership
ControlFlowGraph cfg(graph); ControlFlowGraph cfg(graph);
cfg.LiveVariableAnalysis(); cfg.LiveVariableAnalysis();
...@@ -256,9 +253,6 @@ TEST(CFGGraph, IRGraph) { ...@@ -256,9 +253,6 @@ TEST(CFGGraph, IRGraph) {
TEST(SortOpLikeDescOrder, NormalTest) { TEST(SortOpLikeDescOrder, NormalTest) {
auto prog = FillProgramDesc(); auto prog = FillProgramDesc();
ir::Graph graph(prog); ir::Graph graph(prog);
const std::vector<OpDesc*>* all_op_descs =
new std::vector<OpDesc*>(prog.Block(0).AllOps());
graph.Set(details::kAllOpDescs, all_op_descs); // take ownership
auto nodes = SortOpLikeDescOrder(graph); auto nodes = SortOpLikeDescOrder(graph);
auto op_descs = prog.Block(0).AllOps(); auto op_descs = prog.Block(0).AllOps();
...@@ -273,9 +267,6 @@ TEST(SortOpLikeDescOrder, NormalTest) { ...@@ -273,9 +267,6 @@ TEST(SortOpLikeDescOrder, NormalTest) {
TEST(SortOpLikeDescOrder, RemoveOpDesc) { TEST(SortOpLikeDescOrder, RemoveOpDesc) {
auto prog = FillProgramDesc(); auto prog = FillProgramDesc();
ir::Graph graph(prog); ir::Graph graph(prog);
const std::vector<OpDesc*>* all_op_descs =
new std::vector<OpDesc*>(prog.Block(0).AllOps());
graph.Set(details::kAllOpDescs, all_op_descs); // take ownership
auto nodes = graph.Nodes(); auto nodes = graph.Nodes();
auto op_descs = prog.Block(0).AllOps(); auto op_descs = prog.Block(0).AllOps();
ir::Node* found_node = nullptr; ir::Node* found_node = nullptr;
...@@ -324,8 +315,6 @@ TEST(SortOpLikeDescOrder, RemoveOpDesc) { ...@@ -324,8 +315,6 @@ TEST(SortOpLikeDescOrder, RemoveOpDesc) {
// 3. add some op_desc // 3. add some op_desc
TEST(SortOpLikeDescOrder, AddOpDesc) { TEST(SortOpLikeDescOrder, AddOpDesc) {
auto prog = FillProgramDesc(); auto prog = FillProgramDesc();
const std::vector<OpDesc*>* all_op_descs =
new std::vector<OpDesc*>(prog.Block(0).AllOps());
ir::Graph graph(prog); ir::Graph graph(prog);
auto find_node_in_graph = [&](std::string s) { auto find_node_in_graph = [&](std::string s) {
...@@ -342,9 +331,7 @@ TEST(SortOpLikeDescOrder, AddOpDesc) { ...@@ -342,9 +331,7 @@ TEST(SortOpLikeDescOrder, AddOpDesc) {
// cached desc different with real one // cached desc different with real one
// mimic the intermidiete pass modify the programdesc. // mimic the intermidiete pass modify the programdesc.
graph.Set(details::kAllOpDescs, all_op_descs); // take ownership std::vector<OpDesc*> op_descs = graph.OriginProgram().Block(0).AllOps();
auto op_descs = prog.Block(0).AllOps();
auto op = prog.MutableBlock(0)->AppendOp(); auto op = prog.MutableBlock(0)->AppendOp();
prog.MutableBlock(0)->Var("d1")->SetType(proto::VarType::LOD_TENSOR); prog.MutableBlock(0)->Var("d1")->SetType(proto::VarType::LOD_TENSOR);
...@@ -376,9 +363,6 @@ TEST(SortOpLikeDescOrder, AddOpDesc) { ...@@ -376,9 +363,6 @@ TEST(SortOpLikeDescOrder, AddOpDesc) {
TEST(SortOpLikeDescOrder, AddAndDeleteOpDesc) { TEST(SortOpLikeDescOrder, AddAndDeleteOpDesc) {
auto prog = FillProgramDesc(); auto prog = FillProgramDesc();
ir::Graph graph(prog); ir::Graph graph(prog);
const std::vector<OpDesc*>* all_op_descs =
new std::vector<OpDesc*>(prog.Block(0).AllOps());
graph.Set(details::kAllOpDescs, all_op_descs); // take ownership
auto find_node_in_graph = [&](std::string s) { auto find_node_in_graph = [&](std::string s) {
ir::Node* ret = nullptr; ir::Node* ret = nullptr;
...@@ -392,8 +376,9 @@ TEST(SortOpLikeDescOrder, AddAndDeleteOpDesc) { ...@@ -392,8 +376,9 @@ TEST(SortOpLikeDescOrder, AddAndDeleteOpDesc) {
return ret; return ret;
}; };
std::vector<OpDesc*> op_descs = graph.OriginProgram().Block(0).AllOps();
// remove sum node // remove sum node
auto op_descs = prog.Block(0).AllOps();
ir::Node* found_node = nullptr; ir::Node* found_node = nullptr;
auto nodes = graph.Nodes(); auto nodes = graph.Nodes();
for (auto node : nodes) { for (auto node : nodes) {
...@@ -454,9 +439,7 @@ TEST(SortOpLikeDescOrder, AddAndDeleteOpDesc) { ...@@ -454,9 +439,7 @@ TEST(SortOpLikeDescOrder, AddAndDeleteOpDesc) {
TEST(SortOpLikeDescOrder, AddAndReplaceOpDescInplace) { TEST(SortOpLikeDescOrder, AddAndReplaceOpDescInplace) {
auto prog = FillProgramDesc(); auto prog = FillProgramDesc();
ir::Graph graph(prog); ir::Graph graph(prog);
const std::vector<OpDesc*>* all_op_descs = std::vector<OpDesc*> op_descs = graph.OriginProgram().Block(0).AllOps();
new std::vector<OpDesc*>(prog.Block(0).AllOps());
graph.Set(details::kAllOpDescs, all_op_descs); // take ownership
auto find_node_in_graph = [&](std::string s) { auto find_node_in_graph = [&](std::string s) {
ir::Node* ret = nullptr; ir::Node* ret = nullptr;
...@@ -470,7 +453,6 @@ TEST(SortOpLikeDescOrder, AddAndReplaceOpDescInplace) { ...@@ -470,7 +453,6 @@ TEST(SortOpLikeDescOrder, AddAndReplaceOpDescInplace) {
return ret; return ret;
}; };
auto op_descs = prog.Block(0).AllOps();
// add node // add node
auto op = prog.MutableBlock(0)->AppendOp(); auto op = prog.MutableBlock(0)->AppendOp();
prog.MutableBlock(0)->Var("d1")->SetType(proto::VarType::LOD_TENSOR); prog.MutableBlock(0)->Var("d1")->SetType(proto::VarType::LOD_TENSOR);
......
...@@ -337,4 +337,4 @@ void MemoryOptimizePass::RenameVarInGraphNode(const std::string& var, ...@@ -337,4 +337,4 @@ void MemoryOptimizePass::RenameVarInGraphNode(const std::string& var,
REGISTER_PASS(memory_optimize_pass, REGISTER_PASS(memory_optimize_pass,
paddle::framework::details::MemoryOptimizePass) paddle::framework::details::MemoryOptimizePass)
.RequireGraphAttr(paddle::framework::details::kAllOpDescs); .RequireGraphAttr(paddle::framework::details::kStaleProgramOpDescs);
...@@ -20,8 +20,7 @@ namespace framework { ...@@ -20,8 +20,7 @@ namespace framework {
namespace details { namespace details {
std::vector<std::unique_ptr<ir::Graph>> std::vector<std::unique_ptr<ir::Graph>>
ParallelSSAGraphExecutor::SeparateMultiDevicesGraph( ParallelSSAGraphExecutor::SeparateMultiDevicesGraph(ir::Graph *graph) {
std::unique_ptr<ir::Graph> &&graph) {
std::vector<std::unique_ptr<ir::Graph>> graphs; std::vector<std::unique_ptr<ir::Graph>> graphs;
graphs.reserve(places_.size()); graphs.reserve(places_.size());
for (size_t i = 0; i < places_.size(); ++i) { for (size_t i = 0; i < places_.size(); ++i) {
...@@ -77,24 +76,18 @@ ParallelSSAGraphExecutor::SeparateMultiDevicesGraph( ...@@ -77,24 +76,18 @@ ParallelSSAGraphExecutor::SeparateMultiDevicesGraph(
ParallelSSAGraphExecutor::ParallelSSAGraphExecutor( ParallelSSAGraphExecutor::ParallelSSAGraphExecutor(
const ExecutionStrategy &strategy, const std::vector<Scope *> &local_scopes, const ExecutionStrategy &strategy, const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places, const std::vector<platform::Place> &places, ir::Graph *graph)
const framework::ProgramDesc &main_prog, std::unique_ptr<ir::Graph> &&graph)
: strategy_(std::move(strategy)), : strategy_(std::move(strategy)),
local_scopes_(std::move(local_scopes)), local_scopes_(std::move(local_scopes)),
pool_(places.size() >= 2 ? new ::ThreadPool(places.size()) : nullptr), pool_(places.size() >= 2 ? new ::ThreadPool(places.size()) : nullptr),
places_(std::move(places)), places_(std::move(places)),
main_prog_(main_prog),
// TODO(Yancey1989): Copying graphs is not safely since it deleted the // TODO(Yancey1989): Copying graphs is not safely since it deleted the
// attrs. // attrs.
graphs_(SeparateMultiDevicesGraph(std::move(graph))) { graphs_(SeparateMultiDevicesGraph(graph)) {
PADDLE_ENFORCE_EQ(places_.size(), local_scopes_.size()); PADDLE_ENFORCE_EQ(places_.size(), local_scopes_.size());
auto seq_allreduce_pass = auto seq_allreduce_pass =
ir::PassRegistry::Instance().Get("all_reduce_deps_pass"); ir::PassRegistry::Instance().Get("all_reduce_deps_pass");
seq_allreduce_pass->Erase(details::kAllOpDescs);
seq_allreduce_pass->Set<const std::vector<OpDesc *>>(
details::kAllOpDescs,
new std::vector<OpDesc *>(main_prog_.Block(0).AllOps()));
for (size_t i = 0; i < graphs_.size(); ++i) { for (size_t i = 0; i < graphs_.size(); ++i) {
graphs_[i] = seq_allreduce_pass->Apply(std::move(graphs_[i])); graphs_[i] = seq_allreduce_pass->Apply(std::move(graphs_[i]));
} }
...@@ -107,7 +100,7 @@ ParallelSSAGraphExecutor::ParallelSSAGraphExecutor( ...@@ -107,7 +100,7 @@ ParallelSSAGraphExecutor::ParallelSSAGraphExecutor(
<< " to run the operators of the graph on each device."; << " to run the operators of the graph on each device.";
for (size_t i = 0; i < places.size(); ++i) { for (size_t i = 0; i < places.size(); ++i) {
executors_.emplace_back(new details::ThreadedSSAGraphExecutor( executors_.emplace_back(new details::ThreadedSSAGraphExecutor(
strategy_, local_scopes_, {places_[i]}, std::move(graphs_.at(i)))); strategy_, local_scopes_, {places_[i]}, graphs_.at(i).get()));
} }
} }
......
...@@ -31,8 +31,7 @@ class ParallelSSAGraphExecutor : public SSAGraphExecutor { ...@@ -31,8 +31,7 @@ class ParallelSSAGraphExecutor : public SSAGraphExecutor {
ParallelSSAGraphExecutor(const ExecutionStrategy &strategy, ParallelSSAGraphExecutor(const ExecutionStrategy &strategy,
const std::vector<Scope *> &local_scopes, const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places, const std::vector<platform::Place> &places,
const framework::ProgramDesc &main_prog, ir::Graph *graph);
std::unique_ptr<ir::Graph> &&graph);
~ParallelSSAGraphExecutor() final = default; ~ParallelSSAGraphExecutor() final = default;
const ir::Graph &Graph() const override { return *graphs_[0]; } const ir::Graph &Graph() const override { return *graphs_[0]; }
...@@ -41,13 +40,12 @@ class ParallelSSAGraphExecutor : public SSAGraphExecutor { ...@@ -41,13 +40,12 @@ class ParallelSSAGraphExecutor : public SSAGraphExecutor {
private: private:
std::vector<std::unique_ptr<ir::Graph>> SeparateMultiDevicesGraph( std::vector<std::unique_ptr<ir::Graph>> SeparateMultiDevicesGraph(
std::unique_ptr<ir::Graph> &&graph); ir::Graph *graph);
ExecutionStrategy strategy_; ExecutionStrategy strategy_;
std::vector<Scope *> local_scopes_; std::vector<Scope *> local_scopes_;
std::unique_ptr<::ThreadPool> pool_{nullptr}; std::unique_ptr<::ThreadPool> pool_{nullptr};
std::vector<platform::Place> places_; std::vector<platform::Place> places_;
framework::ProgramDesc main_prog_;
std::vector<std::unique_ptr<ir::Graph>> graphs_; std::vector<std::unique_ptr<ir::Graph>> graphs_;
std::vector<std::unique_ptr<details::ThreadedSSAGraphExecutor>> executors_; std::vector<std::unique_ptr<details::ThreadedSSAGraphExecutor>> executors_;
......
...@@ -40,7 +40,7 @@ std::unique_ptr<ir::Graph> SequentialExecutionPass::ApplyImpl( ...@@ -40,7 +40,7 @@ std::unique_ptr<ir::Graph> SequentialExecutionPass::ApplyImpl(
static std::unordered_set<std::string> skip_dist_ops{ static std::unordered_set<std::string> skip_dist_ops{
"send", "recv", "send_barrier", "fetch_barrier"}; "send", "recv", "send_barrier", "fetch_barrier"};
auto &ops = Get<const std::vector<OpDesc *>>(kAllOpDescs); auto &ops = graph->Get<const std::vector<OpDesc *>>(kStaleProgramOpDescs);
std::vector<ir::Node *> op_node_list; std::vector<ir::Node *> op_node_list;
op_node_list.reserve(ops.size()); op_node_list.reserve(ops.size());
...@@ -107,4 +107,4 @@ std::unique_ptr<ir::Graph> SequentialExecutionPass::ApplyImpl( ...@@ -107,4 +107,4 @@ std::unique_ptr<ir::Graph> SequentialExecutionPass::ApplyImpl(
REGISTER_PASS(sequential_execution_pass, REGISTER_PASS(sequential_execution_pass,
paddle::framework::details::SequentialExecutionPass) paddle::framework::details::SequentialExecutionPass)
.RequirePassAttr(paddle::framework::details::kAllOpDescs); .RequireGraphAttr(paddle::framework::details::kStaleProgramOpDescs);
...@@ -23,9 +23,8 @@ namespace framework { ...@@ -23,9 +23,8 @@ namespace framework {
namespace details { namespace details {
ThreadedSSAGraphExecutor::ThreadedSSAGraphExecutor( ThreadedSSAGraphExecutor::ThreadedSSAGraphExecutor(
const ExecutionStrategy &strategy, const std::vector<Scope *> &local_scopes, const ExecutionStrategy &strategy, const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places, const std::vector<platform::Place> &places, ir::Graph *graph)
std::unique_ptr<ir::Graph> &&graph) : graph_(graph),
: graph_(std::move(graph)),
pool_(strategy.num_threads_ >= 2 ? new ::ThreadPool(strategy.num_threads_) pool_(strategy.num_threads_ >= 2 ? new ::ThreadPool(strategy.num_threads_)
: nullptr), : nullptr),
local_scopes_(local_scopes), local_scopes_(local_scopes),
...@@ -110,7 +109,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run( ...@@ -110,7 +109,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
for (auto &run_op_future : run_op_futures_) { for (auto &run_op_future : run_op_futures_) {
run_op_future.wait(); run_op_future.wait();
} }
ClearFetchOp(graph_.get(), &fetch_ops); ClearFetchOp(graph_, &fetch_ops);
exception_holder_.ReThrow(); exception_holder_.ReThrow();
} else { } else {
continue; continue;
...@@ -135,7 +134,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run( ...@@ -135,7 +134,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
} }
PADDLE_ENFORCE(ready_ops.empty()); PADDLE_ENFORCE(ready_ops.empty());
// Wait FetchOps. // Wait FetchOps.
ClearFetchOp(graph_.get(), &fetch_ops); ClearFetchOp(graph_, &fetch_ops);
return fetch_data; return fetch_data;
} }
......
...@@ -41,7 +41,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor { ...@@ -41,7 +41,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
ThreadedSSAGraphExecutor(const ExecutionStrategy &strategy, ThreadedSSAGraphExecutor(const ExecutionStrategy &strategy,
const std::vector<Scope *> &local_scopes, const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places, const std::vector<platform::Place> &places,
std::unique_ptr<ir::Graph> &&graph); ir::Graph *graph);
const ir::Graph &Graph() const override { return *graph_; } const ir::Graph &Graph() const override { return *graph_; }
// Run a SSAGraph by a thread pool // Run a SSAGraph by a thread pool
...@@ -55,7 +55,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor { ...@@ -55,7 +55,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
details::OpHandleBase *op); details::OpHandleBase *op);
private: private:
std::unique_ptr<ir::Graph> graph_; ir::Graph *graph_;
std::unique_ptr<::ThreadPool> pool_; std::unique_ptr<::ThreadPool> pool_;
std::vector<Scope *> local_scopes_; std::vector<Scope *> local_scopes_;
std::vector<platform::Place> places_; std::vector<platform::Place> places_;
......
...@@ -76,6 +76,9 @@ std::map<std::string, std::vector<ir::Node *>> Graph::InitFromProgram( ...@@ -76,6 +76,9 @@ std::map<std::string, std::vector<ir::Node *>> Graph::InitFromProgram(
var->inputs.push_back(node); var->inputs.push_back(node);
} }
} }
Set<const std::vector<OpDesc *>>(
details::kStaleProgramOpDescs,
new std::vector<OpDesc *>(program.Block(0).AllOps()));
return var_nodes; return var_nodes;
} }
......
...@@ -31,7 +31,7 @@ namespace details { ...@@ -31,7 +31,7 @@ namespace details {
// This attr is not recommended, because the graph should not dependence // This attr is not recommended, because the graph should not dependence
// the program once it is built. // the program once it is built.
constexpr char kAllOpDescs[] = "all_op_descs"; constexpr char kStaleProgramOpDescs[] = "stale_program_op_descs";
} // namespace details } // namespace details
namespace ir { namespace ir {
...@@ -195,6 +195,12 @@ class Graph { ...@@ -195,6 +195,12 @@ class Graph {
return nullptr; return nullptr;
} }
// Returns reference to the original program.
// WARN: After a series of passes, the current graph can be quite
// different from OriginProgram. Caller shouldn't assume much from
// the returned OriginProgram.
const ProgramDesc &OriginProgram() const { return program_; }
// This method takes ownership of `node`. // This method takes ownership of `node`.
ir::Node *AddNode(ir::Node *node) { ir::Node *AddNode(ir::Node *node) {
PADDLE_ENFORCE(node_set_.find(node) == node_set_.end()); PADDLE_ENFORCE(node_set_.find(node) == node_set_.end());
......
...@@ -184,9 +184,10 @@ std::vector<Scope *> &ParallelExecutor::GetLocalScopes() { ...@@ -184,9 +184,10 @@ std::vector<Scope *> &ParallelExecutor::GetLocalScopes() {
ParallelExecutor::ParallelExecutor( ParallelExecutor::ParallelExecutor(
const std::vector<platform::Place> &places, const std::vector<platform::Place> &places,
const std::unordered_set<std::string> &bcast_vars, const std::unordered_set<std::string> &bcast_vars,
const ProgramDesc &main_program, const std::string &loss_var_name, const std::string &loss_var_name, Scope *scope,
Scope *scope, const std::vector<Scope *> &local_scopes, const std::vector<Scope *> &local_scopes,
const ExecutionStrategy &exec_strategy, const BuildStrategy &build_strategy) const ExecutionStrategy &exec_strategy, const BuildStrategy &build_strategy,
ir::Graph *graph)
: member_(new ParallelExecutorPrivate(places)) { : member_(new ParallelExecutorPrivate(places)) {
member_->global_scope_ = scope; member_->global_scope_ = scope;
member_->use_cuda_ = exec_strategy.use_cuda_; member_->use_cuda_ = exec_strategy.use_cuda_;
...@@ -216,11 +217,13 @@ ParallelExecutor::ParallelExecutor( ...@@ -216,11 +217,13 @@ ParallelExecutor::ParallelExecutor(
} }
} }
std::unique_ptr<ir::Graph> temp_owned_graph(graph);
// FIXME(Yancey1989): parallel graph mode get better performance // FIXME(Yancey1989): parallel graph mode get better performance
// in GPU allreduce distributed training. Need an elegant way to // in GPU allreduce distributed training. Need an elegant way to
// choice the execution strategy. // choice the execution strategy.
build_strategy.enable_parallel_graph_ = build_strategy.enable_parallel_graph_ = EnableParallelGraphExecution(
EnableParallelGraphExecution(main_program, exec_strategy, build_strategy); *temp_owned_graph, exec_strategy, build_strategy);
if (build_strategy.enable_parallel_graph_) if (build_strategy.enable_parallel_graph_)
VLOG(0) << "The Executor would execute the graph by ParallelGraph " VLOG(0) << "The Executor would execute the graph by ParallelGraph "
"Execution which can get better performance," "Execution which can get better performance,"
...@@ -254,26 +257,32 @@ ParallelExecutor::ParallelExecutor( ...@@ -254,26 +257,32 @@ ParallelExecutor::ParallelExecutor(
if (member_->local_scopes_.size() != 1 && local_scopes.empty()) { if (member_->local_scopes_.size() != 1 && local_scopes.empty()) {
BCastParamsToDevices(bcast_vars); BCastParamsToDevices(bcast_vars);
} }
// Startup Program has been run. All local scopes has correct parameters. // Startup Program has been run. All local scopes has correct parameters.
// Step 2. Convert main_program to SSA form and dependency graph. Also, insert // Step 2. Convert main_program to SSA form and dependency graph. Also, insert
// ncclOp // ncclOp
std::unique_ptr<ir::Graph> graph;
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
graph = build_strategy.Apply(main_program, member_->places_, loss_var_name,
member_->local_scopes_, member_->nranks_, temp_owned_graph = build_strategy.Apply(
member_->use_cuda_, member_->nccl_ctxs_.get()); std::move(temp_owned_graph), member_->places_, loss_var_name,
member_->local_scopes_, member_->nranks_, member_->use_cuda_,
member_->nccl_ctxs_.get());
#else #else
graph = build_strategy.Apply(main_program, member_->places_, loss_var_name, temp_owned_graph = build_strategy.Apply(
member_->local_scopes_, member_->nranks_, std::move(temp_owned_graph), member_->places_, loss_var_name,
member_->use_cuda_); member_->local_scopes_, member_->nranks_, member_->use_cuda_);
#endif #endif
auto max_memory_size = GetEagerDeletionThreshold(); auto max_memory_size = GetEagerDeletionThreshold();
VLOG(10) << "Eager Deletion Threshold " VLOG(10) << "Eager Deletion Threshold "
<< static_cast<float>(max_memory_size) / (1 << 30); << static_cast<float>(max_memory_size) / (1 << 30);
if (max_memory_size >= 0) { if (max_memory_size >= 0) {
graph = member_->PrepareGCAndRefCnts(std::move(graph), graph = member_
static_cast<size_t>(max_memory_size)); ->PrepareGCAndRefCnts(std::move(temp_owned_graph),
static_cast<size_t>(max_memory_size))
.release();
} else {
graph = temp_owned_graph.release();
} }
// Step 3. Create vars in each scope. Passes may also create new vars. // Step 3. Create vars in each scope. Passes may also create new vars.
...@@ -308,8 +317,7 @@ ParallelExecutor::ParallelExecutor( ...@@ -308,8 +317,7 @@ ParallelExecutor::ParallelExecutor(
// TODO(Yancey1989): Remove passing in the main_program when // TODO(Yancey1989): Remove passing in the main_program when
// allreduce_seq_pass doesn't need it as the attr. // allreduce_seq_pass doesn't need it as the attr.
member_->executor_.reset(new details::ParallelSSAGraphExecutor( member_->executor_.reset(new details::ParallelSSAGraphExecutor(
exec_strategy, member_->local_scopes_, member_->places_, main_program, exec_strategy, member_->local_scopes_, member_->places_, graph));
std::move(graph)));
#else #else
PADDLE_THROW( PADDLE_THROW(
"Paddle should be compiled with CUDA for ParallelGraph Execution."); "Paddle should be compiled with CUDA for ParallelGraph Execution.");
...@@ -317,12 +325,10 @@ ParallelExecutor::ParallelExecutor( ...@@ -317,12 +325,10 @@ ParallelExecutor::ParallelExecutor(
} else { } else {
if (exec_strategy.type_ == ExecutionStrategy::kDefault) { if (exec_strategy.type_ == ExecutionStrategy::kDefault) {
member_->executor_.reset(new details::ThreadedSSAGraphExecutor( member_->executor_.reset(new details::ThreadedSSAGraphExecutor(
exec_strategy, member_->local_scopes_, member_->places_, exec_strategy, member_->local_scopes_, member_->places_, graph));
std::move(graph)));
} else { } else {
member_->executor_.reset(new details::FastThreadedSSAGraphExecutor( member_->executor_.reset(new details::FastThreadedSSAGraphExecutor(
exec_strategy, member_->local_scopes_, member_->places_, exec_strategy, member_->local_scopes_, member_->places_, graph));
std::move(graph)));
} }
} }
...@@ -452,24 +458,33 @@ void ParallelExecutor::FeedAndSplitTensorIntoLocalScopes( ...@@ -452,24 +458,33 @@ void ParallelExecutor::FeedAndSplitTensorIntoLocalScopes(
} }
} }
ParallelExecutor::~ParallelExecutor() {
for (auto &p : member_->places_) {
platform::DeviceContextPool::Instance().Get(p)->Wait();
}
delete member_;
}
bool ParallelExecutor::EnableParallelGraphExecution( bool ParallelExecutor::EnableParallelGraphExecution(
const ProgramDesc &main_program, const ExecutionStrategy &exec_strategy, const ir::Graph &graph, const ExecutionStrategy &exec_strategy,
const BuildStrategy &build_strategy) const { const BuildStrategy &build_strategy) const {
if (!FLAGS_enable_parallel_graph) return false; if (!FLAGS_enable_parallel_graph) return false;
bool enable_parallel_graph = true; bool enable_parallel_graph = true;
// TODO(Yancey1989): support sparse update in ParallelGraph mode.
for (auto &var_desc : main_program.Block(0).AllVars()) {
if (var_desc->GetType() == proto::VarType::SELECTED_ROWS) {
enable_parallel_graph = false;
}
}
// TODO(Yancey1989): support pserver mode for (ir::Node *node : graph.Nodes()) {
for (auto &op_desc : main_program.Block(0).AllOps()) { if (node->IsVar() && node->Var()) {
if (op_desc->Type() == "send" || op_desc->Type() == "recv") { // TODO(Yancey1989): support sparse update in ParallelGraph mode.
enable_parallel_graph = false; if (node->Var()->GetType() == proto::VarType::SELECTED_ROWS) {
break; enable_parallel_graph = false;
break;
}
} else if (node->IsOp() && node->Op()) {
// TODO(Yancey1989): support pserver mode
if (node->Op()->Type() == "send" || node->Op()->Type() == "recv") {
enable_parallel_graph = false;
break;
}
} }
} }
...@@ -481,13 +496,6 @@ bool ParallelExecutor::EnableParallelGraphExecution( ...@@ -481,13 +496,6 @@ bool ParallelExecutor::EnableParallelGraphExecution(
return enable_parallel_graph; return enable_parallel_graph;
} }
ParallelExecutor::~ParallelExecutor() {
for (auto &p : member_->places_) {
platform::DeviceContextPool::Instance().Get(p)->Wait();
}
delete member_;
}
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
......
...@@ -46,11 +46,11 @@ class ParallelExecutor { ...@@ -46,11 +46,11 @@ class ParallelExecutor {
public: public:
explicit ParallelExecutor(const std::vector<platform::Place> &places, explicit ParallelExecutor(const std::vector<platform::Place> &places,
const std::unordered_set<std::string> &bcast_vars, const std::unordered_set<std::string> &bcast_vars,
const ProgramDesc &main_program,
const std::string &loss_var_name, Scope *scope, const std::string &loss_var_name, Scope *scope,
const std::vector<Scope *> &local_scopes, const std::vector<Scope *> &local_scopes,
const ExecutionStrategy &exec_strategy, const ExecutionStrategy &exec_strategy,
const BuildStrategy &build_strategy); const BuildStrategy &build_strategy,
ir::Graph *graph);
~ParallelExecutor(); ~ParallelExecutor();
...@@ -71,7 +71,7 @@ class ParallelExecutor { ...@@ -71,7 +71,7 @@ class ParallelExecutor {
private: private:
void BCastParamsToDevices(const std::unordered_set<std::string> &vars) const; void BCastParamsToDevices(const std::unordered_set<std::string> &vars) const;
bool EnableParallelGraphExecution(const ProgramDesc &main_program, bool EnableParallelGraphExecution(const ir::Graph &graph,
const ExecutionStrategy &exec_strategy, const ExecutionStrategy &exec_strategy,
const BuildStrategy &build_strategy) const; const BuildStrategy &build_strategy) const;
......
...@@ -101,7 +101,8 @@ void BindGraph(py::module *m) { ...@@ -101,7 +101,8 @@ void BindGraph(py::module *m) {
[](Graph &self, Node &node) { return self.RemoveNode(&node); }) [](Graph &self, Node &node) { return self.RemoveNode(&node); })
.def("retrieve_node", &Graph::RetrieveNode, .def("retrieve_node", &Graph::RetrieveNode,
return_value_policy::reference) return_value_policy::reference)
.def("resolve_hazard", &Graph::ResolveHazard); .def("resolve_hazard", &Graph::ResolveHazard)
.def("origin_program_desc", &Graph::OriginProgram);
} }
void BindNode(py::module *m) { void BindNode(py::module *m) {
......
...@@ -976,6 +976,7 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -976,6 +976,7 @@ All parameter, weight, gradient are variables in Paddle.
[](ir::PassBuilder &self, size_t idx) { self.RemovePass(idx); }); [](ir::PassBuilder &self, size_t idx) { self.RemovePass(idx); });
// -- python binds for parallel executor. // -- python binds for parallel executor.
py::class_<ParallelExecutor> pe(m, "ParallelExecutor"); py::class_<ParallelExecutor> pe(m, "ParallelExecutor");
py::class_<ExecutionStrategy> exec_strategy(pe, "ExecutionStrategy", R"DOC( py::class_<ExecutionStrategy> exec_strategy(pe, "ExecutionStrategy", R"DOC(
ExecutionStrategy allows the user to more preciously control how to run ExecutionStrategy allows the user to more preciously control how to run
...@@ -1213,9 +1214,9 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -1213,9 +1214,9 @@ All parameter, weight, gradient are variables in Paddle.
cannot be updated after being finalized.)DOC"); cannot be updated after being finalized.)DOC");
pe.def(py::init<const std::vector<platform::Place> &, pe.def(py::init<const std::vector<platform::Place> &,
const std::unordered_set<std::string> &, const ProgramDesc &, const std::unordered_set<std::string> &, const std::string &,
const std::string &, Scope *, std::vector<Scope *> &, Scope *, std::vector<Scope *> &, const ExecutionStrategy &,
const ExecutionStrategy &, const BuildStrategy &>()) const BuildStrategy &, ir::Graph *>())
// NOTE: even we return a vec<Scope*>* to Python use reference policy. // NOTE: even we return a vec<Scope*>* to Python use reference policy.
// We still cannot get local_scope from this vector, since the element // We still cannot get local_scope from this vector, since the element
// of vec<Scope*> will be freed by Python GC. We can only return Scope* // of vec<Scope*> will be freed by Python GC. We can only return Scope*
......
...@@ -444,6 +444,7 @@ function assert_api_spec_approvals() { ...@@ -444,6 +444,7 @@ function assert_api_spec_approvals() {
"paddle/fluid/framework/ir/node.h" "paddle/fluid/framework/ir/node.h"
"paddle/fluid/framework/ir/graph.h" "paddle/fluid/framework/ir/graph.h"
"paddle/fluid/framework/framework.proto" "paddle/fluid/framework/framework.proto"
"python/paddle/fluid/compiler.py"
"paddle/fluid/operators/distributed/send_recv.proto.in") "paddle/fluid/operators/distributed/send_recv.proto.in")
for API_FILE in ${API_FILES[*]}; do for API_FILE in ${API_FILES[*]}; do
API_CHANGE=`git diff --name-only upstream/$BRANCH | grep "${API_FILE}" || true` API_CHANGE=`git diff --name-only upstream/$BRANCH | grep "${API_FILE}" || true`
......
...@@ -17,6 +17,7 @@ import os ...@@ -17,6 +17,7 @@ import os
import six import six
import sys import sys
from .. import compat as cpt from .. import compat as cpt
from . import framework
from . import core from . import core
from . import framework from . import framework
...@@ -37,7 +38,7 @@ def _place_obj(place): ...@@ -37,7 +38,7 @@ def _place_obj(place):
class CompiledProgram(object): class CompiledProgram(object):
""" """
Compiles a Program for execution. Compiles to Graph for execution.
1. Users first create the program with layers. 1. Users first create the program with layers.
2. Optionally, users use CompiledProgram to optimize the program before run. 2. Optionally, users use CompiledProgram to optimize the program before run.
...@@ -52,7 +53,7 @@ class CompiledProgram(object): ...@@ -52,7 +53,7 @@ class CompiledProgram(object):
Example: Example:
.. code-block:: python .. code-block:: python
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace() place = fluid.CUDAPlace(0) if use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place) exe = fluid.Executor(place)
exe.run(startup) exe.run(startup)
compiled_prog = compiler.CompiledProgram(main).with_data_parallel( compiled_prog = compiler.CompiledProgram(main).with_data_parallel(
...@@ -63,11 +64,25 @@ class CompiledProgram(object): ...@@ -63,11 +64,25 @@ class CompiledProgram(object):
fetch_list=[loss.name]) fetch_list=[loss.name])
Args: Args:
program: Program instance that contains the model logic. program_or_graph (Graph|Program): If it's Program, it will be first
lowered to a graph for further optimizations. If it's a graph
(potentially optimized before), it will be directly used for
further optimizations. Note: graph is only supported when compiled
with with_data_parallel option.
""" """
def __init__(self, program): def __init__(self, program_or_graph):
self._program = program if isinstance(program_or_graph, core.Graph):
self._graph = program_or_graph
self._program = None
elif isinstance(program_or_graph, framework.Program):
self._graph = core.Graph(program_or_graph.desc)
self._program = program_or_graph
else:
raise ValueError("Wrong program_to_graph type: %s" %
type(program_or_graph))
self._program_desc = self._graph.origin_program_desc()
self._scope = None self._scope = None
self._place = None self._place = None
self._executor = None self._executor = None
...@@ -102,6 +117,7 @@ class CompiledProgram(object): ...@@ -102,6 +117,7 @@ class CompiledProgram(object):
self self
""" """
assert not self._is_data_parallel, "Already compiled with parallel." assert not self._is_data_parallel, "Already compiled with parallel."
assert not self._is_inference, "Cannot compile both data parallel and inference"
self._is_data_parallel = True self._is_data_parallel = True
self._build_strategy = build_strategy self._build_strategy = build_strategy
self._exec_strategy = exec_strategy self._exec_strategy = exec_strategy
...@@ -123,11 +139,13 @@ class CompiledProgram(object): ...@@ -123,11 +139,13 @@ class CompiledProgram(object):
Returns: Returns:
self self
""" """
assert not self._is_data_parallel, "Cannot compile both data parallel and inference"
assert not self._is_inference, "Already compiled with inference"
assert any([ assert any([
isinstance(config, InferNativeConfig), isinstance(config, InferNativeConfig),
isinstance(config, InferAnalysisConfig) isinstance(config, InferAnalysisConfig)
]) ])
self._is_data_parallel = False
self._is_inference = True self._is_inference = True
self._infer_config = config self._infer_config = config
return self return self
...@@ -176,37 +194,41 @@ class CompiledProgram(object): ...@@ -176,37 +194,41 @@ class CompiledProgram(object):
os.environ.get('CPU_NUM', multiprocessing.cpu_count())) os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
self._exec_strategy.num_threads = cpu_num * 2 self._exec_strategy.num_threads = cpu_num * 2
trainers_endpoints = self._program._trainers_endpoints
# FIXME(dzhwinter): enable_inplace should be after memory_optimize # FIXME(dzhwinter): enable_inplace should be after memory_optimize
# if turn on python memory optimize, turn off the inplace_pass. # if turn on python memory optimize, turn off the inplace_pass.
if self._build_strategy.memory_optimize is None: if self._build_strategy.memory_optimize is None:
self._build_strategy.memory_optimize = False if self._program._is_mem_optimized else True self._build_strategy.memory_optimize = False if self._program and self._program._is_mem_optimized else True
if self._build_strategy.enable_inplace is None: if self._build_strategy.enable_inplace is None:
self._build_strategy.enable_inplace = False if self._program._is_mem_optimized else True self._build_strategy.enable_inplace = False if self._program and self._program._is_mem_optimized else True
# TODO(wuyi): trainer endpoings should be passed in through
# build_strategy, not program.xxx.
if self._program and self._build_strategy.num_trainers > 1 and \
self._program._trainers_endpoints:
tps = self._program._trainers_endpoints
if self._build_strategy.num_trainers > 1 and trainers_endpoints:
assert self._build_strategy.num_trainers == len( assert self._build_strategy.num_trainers == len(
trainers_endpoints), "num_trainers == len(end_points)" tps), "num_trainers == len(end_points)"
self._build_strategy.trainers_endpoints = trainers_endpoints self._build_strategy.trainers_endpoints = tps
self._persistable_vars = set([ self._persistable_vars = []
cpt.to_text(v.name) for block_id in range(self._program_desc.num_blocks()):
for v in [ bdesc = self._program_desc.block(block_id)
var for var in self._program.list_vars() self._persistable_vars.extend([
if var.persistable and var.type != core.VarDesc.VarType.RAW cpt.to_text(v.name()) for v in bdesc.all_vars()
] if v.persistable() and v.type() != core.VarDesc.VarType.RAW
]) ])
places = list(map(_place_obj, self._places)) places = list(map(_place_obj, self._places))
return core.ParallelExecutor( return core.ParallelExecutor(
places, self._persistable_vars, self._program.desc, places,
set(self._persistable_vars),
cpt.to_text(self._loss_name) cpt.to_text(self._loss_name)
if self._loss_name else six.u(''), self._scope, self._local_scopes, if self._loss_name else six.u(''), self._scope, self._local_scopes,
self._exec_strategy, self._build_strategy) self._exec_strategy, self._build_strategy, self._graph)
def _compile_inference(self): def _compile_inference(self):
assert self._is_data_parallel is False
return core.create_paddle_predictor(self._infer_config) return core.create_paddle_predictor(self._infer_config)
def _compile(self, scope, place): def _compile(self, scope, place):
......
...@@ -538,6 +538,8 @@ class Executor(object): ...@@ -538,6 +538,8 @@ class Executor(object):
else: else:
# TODO(panyx0718): Can compile program to optimize executor # TODO(panyx0718): Can compile program to optimize executor
# performance. # performance.
# TODO(panyx0718): executor should be able to run graph.
assert program._program, "CompiledProgram is compiled from graph, can only run with_data_parallel."
return self._run( return self._run(
program._program, program._program,
self._default_executor, self._default_executor,
......
...@@ -176,10 +176,13 @@ class ParallelExecutor(object): ...@@ -176,10 +176,13 @@ class ParallelExecutor(object):
places = list(map(place_obj, self._places)) places = list(map(place_obj, self._places))
# step7: init ParallelExecutor # step7: init ParallelExecutor
# ParallelExecutor API will be deprecated, don't support parallel graph.
self._graph = core.Graph(main.desc)
self.executor = core.ParallelExecutor( self.executor = core.ParallelExecutor(
places, persistable_vars, main.desc, places, persistable_vars,
cpt.to_text(loss_name) if loss_name else six.u(''), scope, cpt.to_text(loss_name) if loss_name else six.u(''), scope,
local_scopes, exec_strategy, build_strategy) local_scopes, exec_strategy, build_strategy, self._graph)
self.scope = scope self.scope = scope
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册