diff --git a/paddle/fluid/framework/details/build_strategy.cc b/paddle/fluid/framework/details/build_strategy.cc index afe5078bf80d00b595789a5f45d91a5e7a8dfce6..20cfa75292cf52a01bf794a2714deaac1e821f50 100644 --- a/paddle/fluid/framework/details/build_strategy.cc +++ b/paddle/fluid/framework/details/build_strategy.cc @@ -150,6 +150,11 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder { AppendPass("runtime_context_cache_pass"); } + if (strategy_.cache_expected_kernel_) { + VLOG(10) << "Add expected_kernel_cache_pass"; + AppendPass("expected_kernel_cache_pass"); + } + AppendMultiDevPass(strategy_); if (strategy_.fuse_all_reduce_ops_) { @@ -337,3 +342,4 @@ USE_PASS(fuse_adam_op_pass); USE_PASS(fuse_sgd_op_pass); USE_PASS(fuse_all_reduce_op_pass); USE_PASS(runtime_context_cache_pass); +USE_PASS(expected_kernel_cache_pass); diff --git a/paddle/fluid/framework/details/build_strategy.h b/paddle/fluid/framework/details/build_strategy.h index 121d4a27cd30abf88134986009208cc7d6399f16..b1601cfbcd5e9c66f1bbecd1f6fe10bc279cea26 100644 --- a/paddle/fluid/framework/details/build_strategy.h +++ b/paddle/fluid/framework/details/build_strategy.h @@ -108,6 +108,7 @@ struct BuildStrategy { bool remove_unnecessary_lock_{true}; bool cache_runtime_context_{false}; + bool cache_expected_kernel_{true}; // NOTE: // Before you add new options, think if it's a general strategy that works diff --git a/paddle/fluid/framework/ir/expected_kernel_cache_pass.cc b/paddle/fluid/framework/ir/expected_kernel_cache_pass.cc index ee67af0aff5c90a9da0ece8f197d9a0c0a8e5b9c..4a99d4c1a9c0f0bd973097d281e380341fe88515 100644 --- a/paddle/fluid/framework/ir/expected_kernel_cache_pass.cc +++ b/paddle/fluid/framework/ir/expected_kernel_cache_pass.cc @@ -23,7 +23,7 @@ namespace ir { void ExpectedKernelCachePass::ApplyImpl(ir::Graph* graph) const { VLOG(3) << "Applies Expected Kernel Cache strategy."; for (const Node* n : graph->Nodes()) { - if (n->IsOp()) { + if (n->IsOp() && n->Op()) { n->Op()->SetAttr(kEnableCacheExpectedKernel, true); } } diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index f0ea6d9b0a751c86e3911c35d9403a32604056d7..a8a2a94d473b18fdcd78771063ef4565c7fe0e42 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -1366,6 +1366,10 @@ All parameter, weight, gradient are variables in Paddle. "cache_runtime_context", [](const BuildStrategy &self) { return self.cache_runtime_context_; }, [](BuildStrategy &self, bool b) { self.cache_runtime_context_ = b; }) + .def_property( + "cache_expected_kernel", + [](const BuildStrategy &self) { return self.cache_expected_kernel_; }, + [](BuildStrategy &self, bool b) { self.cache_expected_kernel_ = b; }) .def("_finalize_strategy_and_create_passes", [](BuildStrategy &self) -> std::shared_ptr { return self.CreatePassesFromStrategy(true);