未验证 提交 112f1614 编写于 作者: Y Yiqun Liu 提交者: GitHub

Add an option to enable caching of the expected kernel in the training phase. (#16724)

* Add an option to enable caching of the expected kernel in the training phase.
test=develop

* Change the default value of cache_expected_kernel to true.
上级 2e07c19a
...@@ -150,6 +150,11 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder { ...@@ -150,6 +150,11 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
AppendPass("runtime_context_cache_pass"); AppendPass("runtime_context_cache_pass");
} }
if (strategy_.cache_expected_kernel_) {
VLOG(10) << "Add expected_kernel_cache_pass";
AppendPass("expected_kernel_cache_pass");
}
AppendMultiDevPass(strategy_); AppendMultiDevPass(strategy_);
if (strategy_.fuse_all_reduce_ops_) { if (strategy_.fuse_all_reduce_ops_) {
...@@ -337,3 +342,4 @@ USE_PASS(fuse_adam_op_pass); ...@@ -337,3 +342,4 @@ USE_PASS(fuse_adam_op_pass);
USE_PASS(fuse_sgd_op_pass); USE_PASS(fuse_sgd_op_pass);
USE_PASS(fuse_all_reduce_op_pass); USE_PASS(fuse_all_reduce_op_pass);
USE_PASS(runtime_context_cache_pass); USE_PASS(runtime_context_cache_pass);
USE_PASS(expected_kernel_cache_pass);
...@@ -108,6 +108,7 @@ struct BuildStrategy { ...@@ -108,6 +108,7 @@ struct BuildStrategy {
bool remove_unnecessary_lock_{true}; bool remove_unnecessary_lock_{true};
bool cache_runtime_context_{false}; bool cache_runtime_context_{false};
bool cache_expected_kernel_{true};
// NOTE: // NOTE:
// Before you add new options, think if it's a general strategy that works // Before you add new options, think if it's a general strategy that works
......
...@@ -23,7 +23,7 @@ namespace ir { ...@@ -23,7 +23,7 @@ namespace ir {
void ExpectedKernelCachePass::ApplyImpl(ir::Graph* graph) const { void ExpectedKernelCachePass::ApplyImpl(ir::Graph* graph) const {
VLOG(3) << "Applies Expected Kernel Cache strategy."; VLOG(3) << "Applies Expected Kernel Cache strategy.";
for (const Node* n : graph->Nodes()) { for (const Node* n : graph->Nodes()) {
if (n->IsOp()) { if (n->IsOp() && n->Op()) {
n->Op()->SetAttr(kEnableCacheExpectedKernel, true); n->Op()->SetAttr(kEnableCacheExpectedKernel, true);
} }
} }
......
...@@ -1366,6 +1366,10 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -1366,6 +1366,10 @@ All parameter, weight, gradient are variables in Paddle.
"cache_runtime_context", "cache_runtime_context",
[](const BuildStrategy &self) { return self.cache_runtime_context_; }, [](const BuildStrategy &self) { return self.cache_runtime_context_; },
[](BuildStrategy &self, bool b) { self.cache_runtime_context_ = b; }) [](BuildStrategy &self, bool b) { self.cache_runtime_context_ = b; })
.def_property(
"cache_expected_kernel",
[](const BuildStrategy &self) { return self.cache_expected_kernel_; },
[](BuildStrategy &self, bool b) { self.cache_expected_kernel_ = b; })
.def("_finalize_strategy_and_create_passes", .def("_finalize_strategy_and_create_passes",
[](BuildStrategy &self) -> std::shared_ptr<ir::PassBuilder> { [](BuildStrategy &self) -> std::shared_ptr<ir::PassBuilder> {
return self.CreatePassesFromStrategy(true); return self.CreatePassesFromStrategy(true);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册