提交 c1379bf2 编写于 作者: M mozga-intel 提交者: tensor-tang

[NGraph] Bert model for a capi, ngraph's support test=develop (#17844)

上级 83e51ded
...@@ -146,7 +146,7 @@ bool LoadInputData(std::vector<std::vector<paddle::PaddleTensor>> *inputs) { ...@@ -146,7 +146,7 @@ bool LoadInputData(std::vector<std::vector<paddle::PaddleTensor>> *inputs) {
void SetConfig(AnalysisConfig *config) { config->SetModel(FLAGS_infer_model); } void SetConfig(AnalysisConfig *config) { config->SetModel(FLAGS_infer_model); }
void profile(bool use_mkldnn = false) { void profile(bool use_mkldnn = false, bool use_ngraph = false) {
AnalysisConfig config; AnalysisConfig config;
SetConfig(&config); SetConfig(&config);
...@@ -155,6 +155,10 @@ void profile(bool use_mkldnn = false) { ...@@ -155,6 +155,10 @@ void profile(bool use_mkldnn = false) {
config.pass_builder()->AppendPass("fc_mkldnn_pass"); config.pass_builder()->AppendPass("fc_mkldnn_pass");
} }
if (use_ngraph) {
config.EnableNgraph();
}
std::vector<std::vector<PaddleTensor>> outputs; std::vector<std::vector<PaddleTensor>> outputs;
std::vector<std::vector<PaddleTensor>> inputs; std::vector<std::vector<PaddleTensor>> inputs;
LoadInputData(&inputs); LoadInputData(&inputs);
...@@ -164,7 +168,11 @@ void profile(bool use_mkldnn = false) { ...@@ -164,7 +168,11 @@ void profile(bool use_mkldnn = false) {
TEST(Analyzer_bert, profile) { profile(); } TEST(Analyzer_bert, profile) { profile(); }
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
TEST(Analyzer_bert, profile_mkldnn) { profile(true); } TEST(Analyzer_bert, profile_mkldnn) { profile(true, false); }
#endif
#ifdef PADDLE_WITH_NGRAPH
TEST(Analyzer_bert, profile_ngraph) { profile(false, true); }
#endif #endif
// Check the fuse status // Check the fuse status
...@@ -179,7 +187,7 @@ TEST(Analyzer_bert, fuse_statis) { ...@@ -179,7 +187,7 @@ TEST(Analyzer_bert, fuse_statis) {
} }
// Compare result of NativeConfig and AnalysisConfig // Compare result of NativeConfig and AnalysisConfig
void compare(bool use_mkldnn = false) { void compare(bool use_mkldnn = false, bool use_ngraph = false) {
AnalysisConfig cfg; AnalysisConfig cfg;
SetConfig(&cfg); SetConfig(&cfg);
if (use_mkldnn) { if (use_mkldnn) {
...@@ -187,6 +195,10 @@ void compare(bool use_mkldnn = false) { ...@@ -187,6 +195,10 @@ void compare(bool use_mkldnn = false) {
cfg.pass_builder()->AppendPass("fc_mkldnn_pass"); cfg.pass_builder()->AppendPass("fc_mkldnn_pass");
} }
if (use_ngraph) {
cfg.EnableNgraph();
}
std::vector<std::vector<PaddleTensor>> inputs; std::vector<std::vector<PaddleTensor>> inputs;
LoadInputData(&inputs); LoadInputData(&inputs);
CompareNativeAndAnalysis( CompareNativeAndAnalysis(
...@@ -195,7 +207,15 @@ void compare(bool use_mkldnn = false) { ...@@ -195,7 +207,15 @@ void compare(bool use_mkldnn = false) {
TEST(Analyzer_bert, compare) { compare(); } TEST(Analyzer_bert, compare) { compare(); }
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
TEST(Analyzer_bert, compare_mkldnn) { compare(true /* use_mkldnn */); } TEST(Analyzer_bert, compare_mkldnn) {
compare(true, false /* use_mkldnn, no use_ngraph */);
}
#endif
#ifdef PADDLE_WITH_NGRAPH
TEST(Analyzer_bert, compare_ngraph) {
compare(false, true /* no use_mkldnn, use_ngraph */);
}
#endif #endif
// Compare Deterministic result // Compare Deterministic result
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册