提交 910cd415 编写于 作者: J Jacek Czaja

- Disabled embedding_fc_lstm_fuse by defult and

  extended test_text_classification ot use new op
上级 d5114c60
...@@ -216,7 +216,7 @@ struct AnalysisConfig : public NativeConfig { ...@@ -216,7 +216,7 @@ struct AnalysisConfig : public NativeConfig {
bool enable_ir_optim = true; bool enable_ir_optim = true;
// Manually determine the IR passes to run. // Manually determine the IR passes to run.
IrPassMode ir_mode{IrPassMode::kExclude}; IrPassMode ir_mode{IrPassMode::kExclude};
std::vector<std::string> ir_passes; std::vector<std::string> ir_passes{"embedding_fc_lstm_fuse_pass"};
// NOTE this is just for internal development, please not use it. // NOTE this is just for internal development, please not use it.
bool _use_mkldnn{false}; bool _use_mkldnn{false};
......
...@@ -104,5 +104,18 @@ TEST(Analyzer_Text_Classification, compare) { ...@@ -104,5 +104,18 @@ TEST(Analyzer_Text_Classification, compare) {
CompareNativeAndAnalysis(cfg, input_slots_all); CompareNativeAndAnalysis(cfg, input_slots_all);
} }
TEST(Analyzer_Text_Classification, compare_against_embedding_fc_lstm_fused) {
AnalysisConfig cfg;
SetConfig(&cfg);
// Enable embedding_fc_lstm_fuse_pass (disabled by default)
auto it = std::find(cfg.ir_passes.begin(), cfg.ir_passes.end(),
"embedding_fc_lstm_fuse_pass");
if (it != cfg.ir_passes.end()) cfg.ir_passes.erase(it);
std::vector<std::vector<PaddleTensor>> input_slots_all;
SetInput(&input_slots_all);
CompareNativeAndAnalysis(cfg, input_slots_all);
}
} // namespace inference } // namespace inference
} // namespace paddle } // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册