未验证 提交 3fe8cb0d 编写于 作者: Y Yiqun Liu 提交者: GitHub

Enable the runtime_context_cache pass in train phase (#16640)

* Try to enable the runtime_context_cache pass in train phase.

* Put the append of runtime_context_cache pass ahead of multi_dev passes.
test=develop
上级 4048a268
...@@ -142,6 +142,14 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder { ...@@ -142,6 +142,14 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
AppendPass("memory_optimize_pass"); AppendPass("memory_optimize_pass");
} }
// runtime_context_cache pass should be the last pass to enable the attr of
// all original and fused operators. But no operators can be enabled this
// attr if putting it after MultiDevPass.
if (strategy_.cache_runtime_context_) {
VLOG(10) << "Add runtime_context_cache_pass";
AppendPass("runtime_context_cache_pass");
}
AppendMultiDevPass(strategy_); AppendMultiDevPass(strategy_);
if (strategy_.fuse_all_reduce_ops_) { if (strategy_.fuse_all_reduce_ops_) {
...@@ -328,3 +336,4 @@ USE_PASS(graph_to_program_pass); ...@@ -328,3 +336,4 @@ USE_PASS(graph_to_program_pass);
USE_PASS(fuse_adam_op_pass); USE_PASS(fuse_adam_op_pass);
USE_PASS(fuse_sgd_op_pass); USE_PASS(fuse_sgd_op_pass);
USE_PASS(fuse_all_reduce_op_pass); USE_PASS(fuse_all_reduce_op_pass);
USE_PASS(runtime_context_cache_pass);
...@@ -107,6 +107,8 @@ struct BuildStrategy { ...@@ -107,6 +107,8 @@ struct BuildStrategy {
std::vector<std::string> trainers_endpoints_; std::vector<std::string> trainers_endpoints_;
bool remove_unnecessary_lock_{true}; bool remove_unnecessary_lock_{true};
bool cache_runtime_context_{false};
// NOTE: // NOTE:
// Before you add new options, think if it's a general strategy that works // Before you add new options, think if it's a general strategy that works
// with other strategy. If not, the strategy should be created through // with other strategy. If not, the strategy should be created through
......
...@@ -23,7 +23,7 @@ namespace ir { ...@@ -23,7 +23,7 @@ namespace ir {
void RuntimeContextCachePass::ApplyImpl(ir::Graph* graph) const { void RuntimeContextCachePass::ApplyImpl(ir::Graph* graph) const {
VLOG(3) << "Applies Runtime Context Cache strategy."; VLOG(3) << "Applies Runtime Context Cache strategy.";
for (const Node* n : graph->Nodes()) { for (const Node* n : graph->Nodes()) {
if (n->IsOp()) { if (n->IsOp() && n->Op()) {
n->Op()->SetAttr(kEnableCacheRuntimeContext, true); n->Op()->SetAttr(kEnableCacheRuntimeContext, true);
} }
} }
......
...@@ -1356,6 +1356,10 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -1356,6 +1356,10 @@ All parameter, weight, gradient are variables in Paddle.
"fuse_all_reduce_ops", "fuse_all_reduce_ops",
[](const BuildStrategy &self) { return self.fuse_all_reduce_ops_; }, [](const BuildStrategy &self) { return self.fuse_all_reduce_ops_; },
[](BuildStrategy &self, bool b) { self.fuse_all_reduce_ops_ = b; }) [](BuildStrategy &self, bool b) { self.fuse_all_reduce_ops_ = b; })
.def_property(
"cache_runtime_context",
[](const BuildStrategy &self) { return self.cache_runtime_context_; },
[](BuildStrategy &self, bool b) { self.cache_runtime_context_ = b; })
.def("_finalize_strategy_and_create_passes", .def("_finalize_strategy_and_create_passes",
[](BuildStrategy &self) -> std::shared_ptr<ir::PassBuilder> { [](BuildStrategy &self) -> std::shared_ptr<ir::PassBuilder> {
return self.CreatePassesFromStrategy(true); return self.CreatePassesFromStrategy(true);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册