diff --git a/paddle/fluid/framework/details/build_strategy.cc b/paddle/fluid/framework/details/build_strategy.cc
index f8bf43bcb48226b4d1317a78ade7179741097378..645dd421d48b89c26c3489a7de02371736556b2d 100644
--- a/paddle/fluid/framework/details/build_strategy.cc
+++ b/paddle/fluid/framework/details/build_strategy.cc
@@ -142,6 +142,14 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
       AppendPass("memory_optimize_pass");
     }
 
+    // runtime_context_cache pass should be the last pass, so that the attr is
+    // enabled on all original and fused operators. However, no operator gets
+    // this attr enabled if the pass is placed after MultiDevPass.
+    if (strategy_.cache_runtime_context_) {
+      VLOG(10) << "Add runtime_context_cache_pass";
+      AppendPass("runtime_context_cache_pass");
+    }
+
     AppendMultiDevPass(strategy_);
 
     if (strategy_.fuse_all_reduce_ops_) {
@@ -328,3 +336,4 @@ USE_PASS(graph_to_program_pass);
 USE_PASS(fuse_adam_op_pass);
 USE_PASS(fuse_sgd_op_pass);
 USE_PASS(fuse_all_reduce_op_pass);
+USE_PASS(runtime_context_cache_pass);
diff --git a/paddle/fluid/framework/details/build_strategy.h b/paddle/fluid/framework/details/build_strategy.h
index cc48c51e924039d93b2e1e18bea752611e7bef92..8aa444a30c0f7f1f5c19d54cf248f86c3e3b3cf3 100644
--- a/paddle/fluid/framework/details/build_strategy.h
+++ b/paddle/fluid/framework/details/build_strategy.h
@@ -107,6 +107,8 @@ struct BuildStrategy {
   std::vector<std::string> trainers_endpoints_;
 
   bool remove_unnecessary_lock_{true};
+  bool cache_runtime_context_{false};
+
   // NOTE:
   // Before you add new options, think if it's a general strategy that works
   // with other strategy. If not, the strategy should be created through
diff --git a/paddle/fluid/framework/ir/runtime_context_cache_pass.cc b/paddle/fluid/framework/ir/runtime_context_cache_pass.cc
index c7cf9b0dc342bbfaa80b622d7dcd0f6348f78d42..566b654f237cbd71e1983c971374ee13d7b36805 100644
--- a/paddle/fluid/framework/ir/runtime_context_cache_pass.cc
+++ b/paddle/fluid/framework/ir/runtime_context_cache_pass.cc
@@ -23,7 +23,7 @@ namespace ir {
 void RuntimeContextCachePass::ApplyImpl(ir::Graph* graph) const {
   VLOG(3) << "Applies Runtime Context Cache strategy.";
   for (const Node* n : graph->Nodes()) {
-    if (n->IsOp()) {
+    if (n->IsOp() && n->Op()) {
       n->Op()->SetAttr(kEnableCacheRuntimeContext, true);
     }
   }
diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc
index 8c34e3efe2a07cadde5aa06669fda88be7661db1..0d1a1ca07771b991df78743501907401cf56f933 100644
--- a/paddle/fluid/pybind/pybind.cc
+++ b/paddle/fluid/pybind/pybind.cc
@@ -1356,6 +1356,10 @@ All parameter, weight, gradient are variables in Paddle.
           "fuse_all_reduce_ops",
           [](const BuildStrategy &self) { return self.fuse_all_reduce_ops_; },
           [](BuildStrategy &self, bool b) { self.fuse_all_reduce_ops_ = b; })
+      .def_property(
+          "cache_runtime_context",
+          [](const BuildStrategy &self) { return self.cache_runtime_context_; },
+          [](BuildStrategy &self, bool b) { self.cache_runtime_context_ = b; })
       .def("_finalize_strategy_and_create_passes",
           [](BuildStrategy &self) -> std::shared_ptr<ir::PassBuilder> {
             return self.CreatePassesFromStrategy(true);
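
For reference, below is a minimal usage sketch (not part of the patch) of how the new cache_runtime_context flag exposed by the pybind change could be flipped from Python. It assumes the Paddle 1.x fluid API; the toy regression network, variable names, and batch size are illustrative placeholders only.

# Hedged usage sketch, not part of the diff: assumes the Paddle 1.x fluid API.
# The toy network and random data below are illustrative placeholders.
import os
import numpy as np
import paddle.fluid as fluid

os.environ.setdefault('CPU_NUM', '1')  # device count for the CPU ParallelExecutor path

# Build a tiny network so the CompiledProgram has something to run.
x = fluid.layers.data(name='x', shape=[13], dtype='float32')
y = fluid.layers.data(name='y', shape=[1], dtype='float32')
pred = fluid.layers.fc(input=x, size=1)
loss = fluid.layers.mean(fluid.layers.square_error_cost(input=pred, label=y))
fluid.optimizer.SGD(learning_rate=0.01).minimize(loss)

exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_startup_program())

# The flag added by the pybind change above; it only takes effect on the
# ParallelExecutor/CompiledProgram path, where the pass builder appends
# runtime_context_cache_pass.
build_strategy = fluid.BuildStrategy()
build_strategy.cache_runtime_context = True

compiled = fluid.CompiledProgram(
    fluid.default_main_program()).with_data_parallel(
        loss_name=loss.name, build_strategy=build_strategy)

feed = {'x': np.random.rand(4, 13).astype('float32'),
        'y': np.random.rand(4, 1).astype('float32')}
loss_val, = exe.run(compiled, feed=feed, fetch_list=[loss.name])

Note that the pass itself only sets kEnableCacheRuntimeContext on every operator in the graph; the actual caching is done in the operator run path, so exposing the switch through BuildStrategy keeps the behavior opt-in.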