提交 c1379bf2 编写于 作者: M mozga-intel 提交者: tensor-tang

[NGraph] Bert model for a capi, ngraph's support test=develop (#17844)

上级 83e51ded
......@@ -146,7 +146,7 @@ bool LoadInputData(std::vector<std::vector<paddle::PaddleTensor>> *inputs) {
void SetConfig(AnalysisConfig *config) { config->SetModel(FLAGS_infer_model); }
void profile(bool use_mkldnn = false) {
void profile(bool use_mkldnn = false, bool use_ngraph = false) {
AnalysisConfig config;
SetConfig(&config);
......@@ -155,6 +155,10 @@ void profile(bool use_mkldnn = false) {
config.pass_builder()->AppendPass("fc_mkldnn_pass");
}
if (use_ngraph) {
config.EnableNgraph();
}
std::vector<std::vector<PaddleTensor>> outputs;
std::vector<std::vector<PaddleTensor>> inputs;
LoadInputData(&inputs);
......@@ -164,7 +168,11 @@ void profile(bool use_mkldnn = false) {
TEST(Analyzer_bert, profile) { profile(); }
#ifdef PADDLE_WITH_MKLDNN
TEST(Analyzer_bert, profile_mkldnn) { profile(true); }
TEST(Analyzer_bert, profile_mkldnn) { profile(true, false); }
#endif
#ifdef PADDLE_WITH_NGRAPH
TEST(Analyzer_bert, profile_ngraph) { profile(false, true); }
#endif
// Check the fuse status
......@@ -179,7 +187,7 @@ TEST(Analyzer_bert, fuse_statis) {
}
// Compare result of NativeConfig and AnalysisConfig
void compare(bool use_mkldnn = false) {
void compare(bool use_mkldnn = false, bool use_ngraph = false) {
AnalysisConfig cfg;
SetConfig(&cfg);
if (use_mkldnn) {
......@@ -187,6 +195,10 @@ void compare(bool use_mkldnn = false) {
cfg.pass_builder()->AppendPass("fc_mkldnn_pass");
}
if (use_ngraph) {
cfg.EnableNgraph();
}
std::vector<std::vector<PaddleTensor>> inputs;
LoadInputData(&inputs);
CompareNativeAndAnalysis(
......@@ -195,7 +207,15 @@ void compare(bool use_mkldnn = false) {
TEST(Analyzer_bert, compare) { compare(); }
#ifdef PADDLE_WITH_MKLDNN
TEST(Analyzer_bert, compare_mkldnn) { compare(true /* use_mkldnn */); }
TEST(Analyzer_bert, compare_mkldnn) {
compare(true, false /* use_mkldnn, no use_ngraph */);
}
#endif
#ifdef PADDLE_WITH_NGRAPH
TEST(Analyzer_bert, compare_ngraph) {
compare(false, true /* no use_mkldnn, use_ngraph */);
}
#endif
// Compare Deterministic result
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册