diff --git a/paddle/fluid/inference/tests/api/analyzer_bert_tester.cc b/paddle/fluid/inference/tests/api/analyzer_bert_tester.cc index 54b74a537c7a274c8e3ea5475ba81e9afe5b0670..45256234b83b804967cf3605fbc2acd6d3cc5ac6 100644 --- a/paddle/fluid/inference/tests/api/analyzer_bert_tester.cc +++ b/paddle/fluid/inference/tests/api/analyzer_bert_tester.cc @@ -146,7 +146,7 @@ bool LoadInputData(std::vector> *inputs) { void SetConfig(AnalysisConfig *config) { config->SetModel(FLAGS_infer_model); } -void profile(bool use_mkldnn = false) { +void profile(bool use_mkldnn = false, bool use_ngraph = false) { AnalysisConfig config; SetConfig(&config); @@ -155,6 +155,10 @@ void profile(bool use_mkldnn = false) { config.pass_builder()->AppendPass("fc_mkldnn_pass"); } + if (use_ngraph) { + config.EnableNgraph(); + } + std::vector> outputs; std::vector> inputs; LoadInputData(&inputs); @@ -164,7 +168,11 @@ void profile(bool use_mkldnn = false) { TEST(Analyzer_bert, profile) { profile(); } #ifdef PADDLE_WITH_MKLDNN -TEST(Analyzer_bert, profile_mkldnn) { profile(true); } +TEST(Analyzer_bert, profile_mkldnn) { profile(true, false); } +#endif + +#ifdef PADDLE_WITH_NGRAPH +TEST(Analyzer_bert, profile_ngraph) { profile(false, true); } #endif // Check the fuse status @@ -179,7 +187,7 @@ TEST(Analyzer_bert, fuse_statis) { } // Compare result of NativeConfig and AnalysisConfig -void compare(bool use_mkldnn = false) { +void compare(bool use_mkldnn = false, bool use_ngraph = false) { AnalysisConfig cfg; SetConfig(&cfg); if (use_mkldnn) { @@ -187,6 +195,10 @@ void compare(bool use_mkldnn = false) { cfg.pass_builder()->AppendPass("fc_mkldnn_pass"); } + if (use_ngraph) { + cfg.EnableNgraph(); + } + std::vector> inputs; LoadInputData(&inputs); CompareNativeAndAnalysis( @@ -195,7 +207,15 @@ void compare(bool use_mkldnn = false) { TEST(Analyzer_bert, compare) { compare(); } #ifdef PADDLE_WITH_MKLDNN -TEST(Analyzer_bert, compare_mkldnn) { compare(true /* use_mkldnn */); } +TEST(Analyzer_bert, compare_mkldnn) { + compare(true, false /* use_mkldnn, no use_ngraph */); +} +#endif + +#ifdef PADDLE_WITH_NGRAPH +TEST(Analyzer_bert, compare_ngraph) { + compare(false, true /* no use_mkldnn, use_ngraph */); +} #endif // Compare Deterministic result