diff --git a/paddle/fluid/inference/api/paddle_inference_api.h b/paddle/fluid/inference/api/paddle_inference_api.h
index 984358b2bd90daf768cea0a6e36b5805d81050d6..77b04bb6f5df24142837773ef4e54b2251630963 100644
--- a/paddle/fluid/inference/api/paddle_inference_api.h
+++ b/paddle/fluid/inference/api/paddle_inference_api.h
@@ -216,7 +216,7 @@ struct AnalysisConfig : public NativeConfig {
   bool enable_ir_optim = true;
   // Manually determine the IR passes to run.
   IrPassMode ir_mode{IrPassMode::kExclude};
-  std::vector<std::string> ir_passes;
+  std::vector<std::string> ir_passes{"embedding_fc_lstm_fuse_pass"};
 
   // NOTE this is just for internal development, please not use it.
   bool _use_mkldnn{false};
diff --git a/paddle/fluid/inference/tests/api/analyzer_text_classification_tester.cc b/paddle/fluid/inference/tests/api/analyzer_text_classification_tester.cc
index 340ef152f0b1a15a451f840b36ae845ef4984740..ca19475bda372398d425b0fa6f9a732cd79a8166 100644
--- a/paddle/fluid/inference/tests/api/analyzer_text_classification_tester.cc
+++ b/paddle/fluid/inference/tests/api/analyzer_text_classification_tester.cc
@@ -104,5 +104,18 @@ TEST(Analyzer_Text_Classification, compare) {
   CompareNativeAndAnalysis(cfg, input_slots_all);
 }
 
+TEST(Analyzer_Text_Classification, compare_against_embedding_fc_lstm_fused) {
+  AnalysisConfig cfg;
+  SetConfig(&cfg);
+  // Enable embedding_fc_lstm_fuse_pass (disabled by default)
+  auto it = std::find(cfg.ir_passes.begin(), cfg.ir_passes.end(),
+                      "embedding_fc_lstm_fuse_pass");
+  if (it != cfg.ir_passes.end()) cfg.ir_passes.erase(it);
+
+  std::vector<std::vector<PaddleTensor>> input_slots_all;
+  SetInput(&input_slots_all);
+  CompareNativeAndAnalysis(cfg, input_slots_all);
+}
+
 } // namespace inference
 } // namespace paddle
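
Note (not part of the patch): ir_mode defaults to IrPassMode::kExclude, so ir_passes is a list of passes to skip. Seeding it with "embedding_fc_lstm_fuse_pass" therefore disables the embedding+FC+LSTM fusion by default, and the new test re-enables it by erasing that entry before comparing the native and analysis predictors. Below is a minimal sketch of the same toggle for application code; the helper name and the template form are illustrative assumptions, and the config type is only expected to expose the ir_passes member shown in the header hunk above.

#include <algorithm>
#include <string>
#include <vector>

#include "paddle/fluid/inference/api/paddle_inference_api.h"

// Sketch only: the helper name is hypothetical; Config stands for any type
// exposing the ir_passes vector from AnalysisConfig above.
template <typename Config>
void ReEnableEmbeddingFcLstmFuse(Config* cfg) {
  // With IrPassMode::kExclude, ir_passes is a skip-list; removing the entry
  // lets embedding_fc_lstm_fuse_pass run during IR optimization.
  auto it = std::find(cfg->ir_passes.begin(), cfg->ir_passes.end(),
                      "embedding_fc_lstm_fuse_pass");
  if (it != cfg->ir_passes.end()) cfg->ir_passes.erase(it);
}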