From 3fe8cb0dd793228ce238bd6d631499a0d72256fc Mon Sep 17 00:00:00 2001 From: Yiqun Liu Date: Mon, 8 Apr 2019 14:10:02 +0800 Subject: [PATCH] Enable the runtime_context_cache pass in train phase (#16640) * Try to enable the runtime_context_cache pass in train phase. * Put the append of runtime_context_cache pass ahead of multi_dev passes. test=develop --- paddle/fluid/framework/details/build_strategy.cc | 9 +++++++++ paddle/fluid/framework/details/build_strategy.h | 2 ++ paddle/fluid/framework/ir/runtime_context_cache_pass.cc | 2 +- paddle/fluid/pybind/pybind.cc | 4 ++++ 4 files changed, 16 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/framework/details/build_strategy.cc b/paddle/fluid/framework/details/build_strategy.cc index f8bf43bcb48..645dd421d48 100644 --- a/paddle/fluid/framework/details/build_strategy.cc +++ b/paddle/fluid/framework/details/build_strategy.cc @@ -142,6 +142,14 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder { AppendPass("memory_optimize_pass"); } + // runtime_context_cache pass should be the last pass to enable the attr of + // all original and fused operators. But no operators can be enabled this + // attr if putting it after MultiDevPass. + if (strategy_.cache_runtime_context_) { + VLOG(10) << "Add runtime_context_cache_pass"; + AppendPass("runtime_context_cache_pass"); + } + AppendMultiDevPass(strategy_); if (strategy_.fuse_all_reduce_ops_) { @@ -328,3 +336,4 @@ USE_PASS(graph_to_program_pass); USE_PASS(fuse_adam_op_pass); USE_PASS(fuse_sgd_op_pass); USE_PASS(fuse_all_reduce_op_pass); +USE_PASS(runtime_context_cache_pass); diff --git a/paddle/fluid/framework/details/build_strategy.h b/paddle/fluid/framework/details/build_strategy.h index cc48c51e924..8aa444a30c0 100644 --- a/paddle/fluid/framework/details/build_strategy.h +++ b/paddle/fluid/framework/details/build_strategy.h @@ -107,6 +107,8 @@ struct BuildStrategy { std::vector trainers_endpoints_; bool remove_unnecessary_lock_{true}; + bool cache_runtime_context_{false}; + // NOTE: // Before you add new options, think if it's a general strategy that works // with other strategy. If not, the strategy should be created through diff --git a/paddle/fluid/framework/ir/runtime_context_cache_pass.cc b/paddle/fluid/framework/ir/runtime_context_cache_pass.cc index c7cf9b0dc34..566b654f237 100644 --- a/paddle/fluid/framework/ir/runtime_context_cache_pass.cc +++ b/paddle/fluid/framework/ir/runtime_context_cache_pass.cc @@ -23,7 +23,7 @@ namespace ir { void RuntimeContextCachePass::ApplyImpl(ir::Graph* graph) const { VLOG(3) << "Applies Runtime Context Cache strategy."; for (const Node* n : graph->Nodes()) { - if (n->IsOp()) { + if (n->IsOp() && n->Op()) { n->Op()->SetAttr(kEnableCacheRuntimeContext, true); } } diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index 8c34e3efe2a..0d1a1ca0777 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -1356,6 +1356,10 @@ All parameter, weight, gradient are variables in Paddle. "fuse_all_reduce_ops", [](const BuildStrategy &self) { return self.fuse_all_reduce_ops_; }, [](BuildStrategy &self, bool b) { self.fuse_all_reduce_ops_ = b; }) + .def_property( + "cache_runtime_context", + [](const BuildStrategy &self) { return self.cache_runtime_context_; }, + [](BuildStrategy &self, bool b) { self.cache_runtime_context_ = b; }) .def("_finalize_strategy_and_create_passes", [](BuildStrategy &self) -> std::shared_ptr { return self.CreatePassesFromStrategy(true); -- GitLab