未验证 提交 b68bb428 编写于 作者: J joanna.wozna.intel 提交者: GitHub

Add possibility to test native config in mkldnn tests (#41562)

上级 51cae7f7
......@@ -16,6 +16,8 @@ limitations under the License. */
#include "paddle/fluid/inference/tests/api/tester_helper.h"
#include "paddle/fluid/platform/cpu_info.h"
DEFINE_bool(enable_mkldnn, true, "Enable MKLDNN");
namespace paddle {
namespace inference {
namespace analysis {
......@@ -31,7 +33,7 @@ void SetConfig(AnalysisConfig *cfg) {
cfg->SwitchIrOptim();
cfg->SwitchSpecifyInputNames();
cfg->SetCpuMathLibraryNumThreads(FLAGS_num_threads);
cfg->EnableMKLDNN();
if (FLAGS_enable_mkldnn) cfg->EnableMKLDNN();
}
TEST(Analyzer_bfloat16_image_classification, bfloat16) {
......@@ -44,7 +46,7 @@ TEST(Analyzer_bfloat16_image_classification, bfloat16) {
// read data from file and prepare batches with test data
std::vector<std::vector<PaddleTensor>> input_slots_all;
SetInputs(&input_slots_all);
if (FLAGS_enable_bf16 &&
if (FLAGS_enable_mkldnn && FLAGS_enable_bf16 &&
platform::MayIUse(platform::cpu_isa_t::avx512_bf16)) {
b_cfg.EnableMkldnnBfloat16();
} else {
......
......@@ -17,6 +17,8 @@ limitations under the License. */
#include "paddle/fluid/inference/api/paddle_analysis_config.h"
#include "paddle/fluid/inference/tests/api/tester_helper.h"
DEFINE_bool(enable_mkldnn, true, "Enable MKLDNN");
namespace paddle {
namespace inference {
namespace analysis {
......@@ -32,7 +34,7 @@ void SetConfig(AnalysisConfig *cfg) {
cfg->SwitchIrOptim();
cfg->SwitchSpecifyInputNames();
cfg->SetCpuMathLibraryNumThreads(FLAGS_cpu_num_threads);
cfg->EnableMKLDNN();
if (FLAGS_enable_mkldnn) cfg->EnableMKLDNN();
}
TEST(Analyzer_int8_image_classification, quantization) {
......@@ -46,7 +48,7 @@ TEST(Analyzer_int8_image_classification, quantization) {
std::vector<std::vector<PaddleTensor>> input_slots_all;
SetInputs(&input_slots_all);
if (FLAGS_enable_int8) {
if (FLAGS_enable_mkldnn && FLAGS_enable_int8) {
// prepare warmup batch from input data read earlier
// warmup batch size can be different than batch size
std::shared_ptr<std::vector<PaddleTensor>> warmup_data =
......
......@@ -17,6 +17,8 @@ limitations under the License. */
#include "paddle/fluid/inference/api/paddle_analysis_config.h"
#include "paddle/fluid/inference/tests/api/tester_helper.h"
DEFINE_bool(enable_mkldnn, true, "Enable MKLDNN");
// setting iterations to 0 means processing the whole dataset
namespace paddle {
namespace inference {
......@@ -28,7 +30,7 @@ void SetConfig(AnalysisConfig *cfg) {
cfg->SwitchIrOptim(true);
cfg->SwitchSpecifyInputNames(false);
cfg->SetCpuMathLibraryNumThreads(FLAGS_cpu_num_threads);
cfg->EnableMKLDNN();
if (FLAGS_enable_mkldnn) cfg->EnableMKLDNN();
}
std::vector<size_t> ReadObjectsNum(std::ifstream &file, size_t offset,
......@@ -268,13 +270,16 @@ TEST(Analyzer_int8_mobilenet_ssd, quantization) {
GetWarmupData(input_slots_all);
// configure quantizer
if (FLAGS_enable_mkldnn) {
q_cfg.EnableMkldnnQuantizer();
q_cfg.mkldnn_quantizer_config();
std::unordered_set<std::string> quantize_operators(
{"conv2d", "depthwise_conv2d", "prior_box", "transpose2", "reshape2"});
q_cfg.mkldnn_quantizer_config()->SetEnabledOpTypes(quantize_operators);
q_cfg.mkldnn_quantizer_config()->SetWarmupData(warmup_data);
q_cfg.mkldnn_quantizer_config()->SetWarmupBatchSize(FLAGS_warmup_batch_size);
q_cfg.mkldnn_quantizer_config()->SetWarmupBatchSize(
FLAGS_warmup_batch_size);
}
// 0 is avg_cost, 1 is top1_acc, 2 is top5_acc or mAP
CompareQuantizedAndAnalysis(&cfg, &q_cfg, input_slots_all, 2);
......
......@@ -17,6 +17,8 @@ limitations under the License. */
#include "paddle/fluid/inference/api/paddle_analysis_config.h"
#include "paddle/fluid/inference/tests/api/tester_helper.h"
DEFINE_bool(enable_mkldnn, true, "Enable MKLDNN");
namespace paddle {
namespace inference {
namespace analysis {
......@@ -27,7 +29,7 @@ void SetConfig(AnalysisConfig *cfg, std::string model_path) {
cfg->SwitchIrOptim(false);
cfg->SwitchSpecifyInputNames();
cfg->SetCpuMathLibraryNumThreads(FLAGS_cpu_num_threads);
cfg->EnableMKLDNN();
if (FLAGS_enable_mkldnn) cfg->EnableMKLDNN();
}
template <typename T>
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册