From 9e8d372ff43ef7d9a0eae639161d4d64c7016062 Mon Sep 17 00:00:00 2001 From: Yan Chunwei Date: Thu, 27 Sep 2018 12:30:50 +0800 Subject: [PATCH] hide attention lstm fuse (#13615) --- .../framework/ir/attention_lstm_fuse_pass.cc | 16 ++++++++++++++++ .../fluid/inference/api/paddle_inference_api.h | 5 +++-- 2 files changed, 19 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/framework/ir/attention_lstm_fuse_pass.cc b/paddle/fluid/framework/ir/attention_lstm_fuse_pass.cc index bb52d7e498..1c75cb5a82 100644 --- a/paddle/fluid/framework/ir/attention_lstm_fuse_pass.cc +++ b/paddle/fluid/framework/ir/attention_lstm_fuse_pass.cc @@ -257,6 +257,22 @@ std::unique_ptr AttentionLSTMFusePass::ApplyImpl( std::unique_ptr graph) const { PDPattern external_pattern, subblock_pattern; + // Use the following variables to tell whether this model is RNN1. + // This fuse can only works on the RNN1 model. + std::unordered_set specified_vars({"data_lod_attention", + "cell_init", "hidden_init", + "data", "week", "minute"}); + int count = 0; + for (auto* node : graph->Nodes()) { + if (node->IsVar() && specified_vars.count(node->Name())) { + ++count; + } + } + if (count < specified_vars.size()) { + return graph; + } + + // Continue to fuse. FindWhileOp(graph.get()); return graph; } diff --git a/paddle/fluid/inference/api/paddle_inference_api.h b/paddle/fluid/inference/api/paddle_inference_api.h index 01ea0d9c3a..984358b2bd 100644 --- a/paddle/fluid/inference/api/paddle_inference_api.h +++ b/paddle/fluid/inference/api/paddle_inference_api.h @@ -212,10 +212,11 @@ struct AnalysisConfig : public NativeConfig { kExclude // Specify the disabled passes in `ir_passes`. }; + // Determine whether to perform graph optimization. bool enable_ir_optim = true; + // Manually determine the IR passes to run. IrPassMode ir_mode{IrPassMode::kExclude}; - // attention lstm fuse works only on some specific models, disable as default. - std::vector ir_passes{"attention_lstm_fuse_pass"}; + std::vector ir_passes; // NOTE this is just for internal development, please not use it. bool _use_mkldnn{false}; -- GitLab