未验证 提交 9e8d372f 编写于 作者: Y Yan Chunwei 提交者: GitHub

hide attention lstm fuse (#13615)

上级 e9bc5faa
...@@ -257,6 +257,22 @@ std::unique_ptr<ir::Graph> AttentionLSTMFusePass::ApplyImpl( ...@@ -257,6 +257,22 @@ std::unique_ptr<ir::Graph> AttentionLSTMFusePass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const { std::unique_ptr<ir::Graph> graph) const {
PDPattern external_pattern, subblock_pattern; PDPattern external_pattern, subblock_pattern;
// Use the following variables to tell whether this model is RNN1.
// This fuse can only works on the RNN1 model.
std::unordered_set<std::string> specified_vars({"data_lod_attention",
"cell_init", "hidden_init",
"data", "week", "minute"});
int count = 0;
for (auto* node : graph->Nodes()) {
if (node->IsVar() && specified_vars.count(node->Name())) {
++count;
}
}
if (count < specified_vars.size()) {
return graph;
}
// Continue to fuse.
FindWhileOp(graph.get()); FindWhileOp(graph.get());
return graph; return graph;
} }
......
...@@ -212,10 +212,11 @@ struct AnalysisConfig : public NativeConfig { ...@@ -212,10 +212,11 @@ struct AnalysisConfig : public NativeConfig {
kExclude // Specify the disabled passes in `ir_passes`. kExclude // Specify the disabled passes in `ir_passes`.
}; };
// Determine whether to perform graph optimization.
bool enable_ir_optim = true; bool enable_ir_optim = true;
// Manually determine the IR passes to run.
IrPassMode ir_mode{IrPassMode::kExclude}; IrPassMode ir_mode{IrPassMode::kExclude};
// attention lstm fuse works only on some specific models, disable as default. std::vector<std::string> ir_passes;
std::vector<std::string> ir_passes{"attention_lstm_fuse_pass"};
// NOTE this is just for internal development, please not use it. // NOTE this is just for internal development, please not use it.
bool _use_mkldnn{false}; bool _use_mkldnn{false};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册