diff --git a/paddle/fluid/inference/tests/api/CMakeLists.txt b/paddle/fluid/inference/tests/api/CMakeLists.txt index fc79be0e83fb7e6606a57c6c48b82a80798e4e38..4efb10ad2fe156b5b3c6218a3e198297e92510b0 100644 --- a/paddle/fluid/inference/tests/api/CMakeLists.txt +++ b/paddle/fluid/inference/tests/api/CMakeLists.txt @@ -375,6 +375,15 @@ if(WITH_MKLDNN) # resnet50 bfloat16 inference_analysis_api_bfloat16_test_run(test_analyzer_bfloat16_resnet50 ${BF16_IMG_CLASS_TEST_APP} ${INT8_RESNET50_MODEL_DIR} ${IMAGENET_DATA_PATH}) + + # googlenet bfloat16 + inference_analysis_api_bfloat16_test_run(test_analyzer_bfloat16_googlenet ${BF16_IMG_CLASS_TEST_APP} ${INT8_GOOGLENET_MODEL_DIR} ${IMAGENET_DATA_PATH}) + + # mobilenetv1 bfloat16 + inference_analysis_api_bfloat16_test_run(test_analyzer_bfloat16_mobilenetv1 ${BF16_IMG_CLASS_TEST_APP} ${INT8_MOBILENETV1_MODEL_DIR} ${IMAGENET_DATA_PATH}) + + # mobilenetv2 bfloat16 + inference_analysis_api_bfloat16_test_run(test_analyzer_bfloat16_mobilenetv2 ${BF16_IMG_CLASS_TEST_APP} ${INT8_MOBILENETV2_MODEL_DIR} ${IMAGENET_DATA_PATH}) ### Object detection models set(PASCALVOC_DATA_PATH "${INT8_DATA_DIR}/pascalvoc_val_head_300.bin") diff --git a/paddle/fluid/inference/tests/api/analyzer_bfloat16_image_classification_tester.cc b/paddle/fluid/inference/tests/api/analyzer_bfloat16_image_classification_tester.cc index 3621477148fffd343a67047247be846bb6ee652a..3b16b0d34fd4cb87879bb6ed585e72b48167ac2c 100644 --- a/paddle/fluid/inference/tests/api/analyzer_bfloat16_image_classification_tester.cc +++ b/paddle/fluid/inference/tests/api/analyzer_bfloat16_image_classification_tester.cc @@ -28,20 +28,18 @@ void SetConfig(AnalysisConfig *cfg) { cfg->EnableMKLDNN(); } -TEST(Analyzer_int8_image_classification, bfloat16) { +TEST(Analyzer_bfloat16_image_classification, bfloat16) { AnalysisConfig cfg; SetConfig(&cfg); - AnalysisConfig q_cfg; - SetConfig(&q_cfg); + AnalysisConfig b_cfg; + SetConfig(&b_cfg); // read data from file and prepare batches with test data std::vector> input_slots_all; SetInputs(&input_slots_all); - q_cfg.SwitchIrDebug(); - q_cfg.EnableMkldnnBfloat16(); - q_cfg.SetBfloat16Op({"conv2d"}); - CompareBFloat16AndAnalysis(&cfg, &q_cfg, input_slots_all); + b_cfg.EnableMkldnnBfloat16(); + CompareBFloat16AndAnalysis(&cfg, &b_cfg, input_slots_all); } } // namespace analysis