# CMakeLists.txt (15.4 KB)
# Last committed by WangZhen.
# Collect every test_*.py script in this directory (paths relative to it) and
# strip the ".py" extension, so TEST_OPS holds bare test names such as
# "test_foo". Entries are removed below for tests that need custom
# registration, and the rest are registered by the foreach() at the end.
file(GLOB TEST_OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "test_*.py")
string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}")
# Register a Python INT8 quantization test through the project's py_test().
#
#   target     - name of the test; also names its output dir int8_models/<target>
#   model_dir  - ${model_dir}/model is passed as --infer_model
#   data_path  - passed as --infer_data
#   filename   - Python script to run
#   use_mkldnn - exported to the test as FLAGS_use_mkldnn
#
# WARMUP_BATCH_SIZE is not a parameter. It is read from the calling scope,
# which is how inference_analysis_python_api_int8_test_custom_warmup_batch_size
# overrides it.
# NOTE(review): if WARMUP_BATCH_SIZE is undefined in the calling scope, the
# unquoted expansion disappears and the arguments become
# "--warmup_batch_size --batch_size 50". Confirm that a parent CMakeLists.txt
# sets it.
function(_inference_analysis_python_api_int8_test target model_dir data_path filename use_mkldnn)
    py_test(${target} SRCS ${filename}
        ENVS CPU_NUM_THREADS=${CPU_NUM_THREADS_ON_CI}
             FLAGS_use_mkldnn=${use_mkldnn}
        ARGS --infer_model ${model_dir}/model
             --infer_data ${data_path}
             --int8_model_save_path int8_models/${target}
             --warmup_batch_size ${WARMUP_BATCH_SIZE}
             --batch_size 50)
endfunction()

# Register an INT8 quantization Python test with FLAGS_use_mkldnn=False.
# Arguments are the same as for _inference_analysis_python_api_int8_test
# (target, model_dir, data_path, filename).
function(inference_analysis_python_api_int8_test target model_dir data_path filename)
    _inference_analysis_python_api_int8_test(${target} ${model_dir} ${data_path} ${filename} False)
endfunction()

# Same as inference_analysis_python_api_int8_test, but with an explicit
# warmup batch size. WARMUP_BATCH_SIZE is set only in this function's scope.
# The nested calls still see it, because a CMake function starts with a copy
# of its caller's variables.
function(inference_analysis_python_api_int8_test_custom_warmup_batch_size target model_dir data_dir filename warmup_batch_size)
    set(WARMUP_BATCH_SIZE ${warmup_batch_size})
    inference_analysis_python_api_int8_test(${target} ${model_dir} ${data_dir} ${filename})
endfunction()

# Register an INT8 quantization Python test with FLAGS_use_mkldnn=True.
# Arguments are the same as for _inference_analysis_python_api_int8_test
# (target, model_dir, data_path, filename).
function(inference_analysis_python_api_int8_test_mkldnn target model_dir data_path filename)
    _inference_analysis_python_api_int8_test(${target} ${model_dir} ${data_path} ${filename} True)
endfunction()

# Fetch data_file from ${INFERENCE_URL}/int8 into install_dir using the
# project's inference_download_and_uncompress(). Nothing happens if
# ${install_dir}/${data_file} already exists.
function(download_qat_data install_dir data_file)
    if (EXISTS ${install_dir}/${data_file})
        return()
    endif()
    inference_download_and_uncompress(${install_dir} ${INFERENCE_URL}/int8 ${data_file})
endfunction()

# Fetch data_file from ${INFERENCE_URL}/int8/QAT_models into install_dir using
# the project's inference_download_and_uncompress(). Nothing happens if
# ${install_dir}/${data_file} already exists.
function(download_qat_model install_dir data_file)
    if (EXISTS ${install_dir}/${data_file})
        return()
    endif()
    inference_download_and_uncompress(${install_dir} ${INFERENCE_URL}/int8/QAT_models ${data_file})
endfunction()

# Fetch data_file from ${INFERENCE_URL}/int8/QAT_models/fp32 into install_dir
# using the project's inference_download_and_uncompress(). Skipped when
# ${install_dir}/${data_file} already exists.
function(download_qat_fp32_model install_dir data_file)
    if (NOT EXISTS ${install_dir}/${data_file})
        inference_download_and_uncompress(${install_dir} ${INFERENCE_URL}/int8/QAT_models/fp32 ${data_file})
    endif()
endfunction()

# Register a QATv1 image-classification comparison test that runs
# qat_int8_image_classification_comparison.py with FLAGS_use_mkldnn=true and
# OMP threads set to CPU_NUM_THREADS_ON_CI.
#   target        - test name
#   qat_model_dir - passed as --qat_model
#   dataset_path  - passed as --infer_data
# Fixed arguments: --batch_size 25 --batch_num 2 --acc_diff_threshold 0.1.
function(inference_qat_int8_image_classification_test target qat_model_dir dataset_path)
    py_test(${target} SRCS "${CMAKE_CURRENT_SOURCE_DIR}/qat_int8_image_classification_comparison.py"
            ENVS FLAGS_OMP_NUM_THREADS=${CPU_NUM_THREADS_ON_CI}
                 OMP_NUM_THREADS=${CPU_NUM_THREADS_ON_CI}
                 FLAGS_use_mkldnn=true
            ARGS --qat_model ${qat_model_dir}
                 --infer_data ${dataset_path}
                 --batch_size 25
                 --batch_num 2
                 --acc_diff_threshold 0.1)
endfunction()

# Register a QATv2 image-classification comparison test that runs
# qat2_int8_image_classification_comparison.py with FLAGS_use_mkldnn=true.
#   target          - test name
#   qat_model_dir   - passed as --qat_model
#   fp32_model_dir  - passed as --fp32_model
#   dataset_path    - passed as --infer_data
#   ops_to_quantize - comma-separated op list passed as --ops_to_quantize
# batch_size is 10 for the unit test only, to avoid OOM. For the whole
# dataset, use batch_size 25.
function(inference_qat2_int8_image_classification_test target qat_model_dir fp32_model_dir dataset_path ops_to_quantize)
    py_test(${target} SRCS "${CMAKE_CURRENT_SOURCE_DIR}/qat2_int8_image_classification_comparison.py"
            ENVS FLAGS_OMP_NUM_THREADS=${CPU_NUM_THREADS_ON_CI}
                 OMP_NUM_THREADS=${CPU_NUM_THREADS_ON_CI}
                 FLAGS_use_mkldnn=true
            ARGS --qat_model ${qat_model_dir}
                 --fp32_model ${fp32_model_dir}
                 --infer_data ${dataset_path}
                 --batch_size 10
                 --batch_num 2
                 --acc_diff_threshold 0.1
                 --ops_to_quantize ${ops_to_quantize})
endfunction()

# Register a QATv2 NLP comparison test that runs qat2_int8_nlp_comparison.py
# with FLAGS_use_mkldnn=true.
#   target         - test name
#   qat_model_dir  - passed as --qat_model
#   fp32_model_dir - passed as --fp32_model
#   dataset_path   - passed as --infer_data
#   labels_path    - passed as --labels
# batch_size is 10 for the unit test only, to avoid OOM. For the whole
# dataset, use batch_size 20.
function(inference_qat2_int8_nlp_test target qat_model_dir fp32_model_dir dataset_path labels_path)
    py_test(${target} SRCS "${CMAKE_CURRENT_SOURCE_DIR}/qat2_int8_nlp_comparison.py"
            ENVS FLAGS_OMP_NUM_THREADS=${CPU_NUM_THREADS_ON_CI}
                 OMP_NUM_THREADS=${CPU_NUM_THREADS_ON_CI}
                 FLAGS_use_mkldnn=true
            ARGS --qat_model ${qat_model_dir}
                 --fp32_model ${fp32_model_dir}
                 --infer_data ${dataset_path}
                 --labels ${labels_path}
                 --batch_size 10
                 --batch_num 2
                 --acc_diff_threshold 0.1)
endfunction()

# NOTE: download_qat_data() and download_qat_model() are each defined once,
# near the top of this file. Do not redefine them here: a later function()
# with the same name silently replaces the earlier definition.

# Register a test that runs save_qat_model.py on an image-classification QAT
# model.
#   target               - test name
#   qat_model_dir        - passed as --qat_model_path
#   fp32_model_save_path - passed as --fp32_model_save_path
#   int8_model_save_path - passed as --int8_model_save_path
#   ops_to_quantize      - passed as --ops_to_quantize
function(save_qat_ic_model_test target qat_model_dir fp32_model_save_path int8_model_save_path ops_to_quantize)
    py_test(${target} SRCS ${CMAKE_CURRENT_SOURCE_DIR}/save_qat_model.py
            ARGS --qat_model_path ${qat_model_dir}
                 --fp32_model_save_path ${fp32_model_save_path}
                 --int8_model_save_path ${int8_model_save_path}
                 --ops_to_quantize ${ops_to_quantize})
endfunction()

# Register a test that runs save_qat_model.py on an NLP QAT model. Same as
# save_qat_ic_model_test, but without --ops_to_quantize.
#   target               - test name
#   qat_model_dir        - passed as --qat_model_path
#   fp32_model_save_path - passed as --fp32_model_save_path
#   int8_model_save_path - passed as --int8_model_save_path
function(save_qat_nlp_model_test target qat_model_dir fp32_model_save_path int8_model_save_path)
    py_test(${target} SRCS ${CMAKE_CURRENT_SOURCE_DIR}/save_qat_model.py
            ARGS --qat_model_path ${qat_model_dir}
                 --fp32_model_save_path ${fp32_model_save_path}
                 --int8_model_save_path ${int8_model_save_path})
endfunction()
# Register a test that runs convert_model2dot.py on a model.
#   target          - test name
#   model_path      - passed as --model_path
#   save_graph_dir  - passed as --save_graph_dir
#   save_graph_name - passed as --save_graph_name
function(convert_model2dot_test target model_path save_graph_dir save_graph_name)
    py_test(${target} SRCS ${CMAKE_CURRENT_SOURCE_DIR}/convert_model2dot.py
            ARGS --model_path ${model_path}
                 --save_graph_dir ${save_graph_dir}
                 --save_graph_name ${save_graph_name})
endfunction()

# On Windows, exclude these tests from the generic py_test() registration
# loop at the end of this file.
if(WIN32)
    list(REMOVE_ITEM TEST_OPS
        test_light_nas
        test_post_training_quantization_mobilenetv1
        test_post_training_quantization_resnet50
        test_weight_quantization_mobilenetv1)
endif()

# Tests that are registered only on Linux with MKL-DNN: the INT8 quantization
# python-API tests, the QATv1/QATv2 FP32 vs. INT8 comparison tests, and the
# QAT2 model save/convert tools.
if(LINUX AND WITH_MKLDNN)

    #### Image classification dataset: ImageNet (small)
    # The dataset should already be downloaded for the INT8v2 unit tests.
    set(IMAGENET_DATA_PATH "${INFERENCE_DEMO_INSTALL_DIR}/imagenet/data.bin")

    #### INT8 image classification python api test
    # The models should already be downloaded for the INT8v2 unit tests.

    set(INT8_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/int8v2")
    set(INT8_IC_TEST_FILE "test_mkldnn_int8_quantization_strategy.py")
    set(INT8_IC_TEST_FILE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/${INT8_IC_TEST_FILE}")

    # googlenet int8 (warmup batch size 10)
    set(INT8_GOOGLENET_MODEL_DIR "${INT8_INSTALL_DIR}/googlenet")
    inference_analysis_python_api_int8_test_custom_warmup_batch_size(test_slim_int8_googlenet ${INT8_GOOGLENET_MODEL_DIR} ${IMAGENET_DATA_PATH} ${INT8_IC_TEST_FILE_PATH} 10)

    # mobilenet int8, registered with FLAGS_use_mkldnn=False and =True
    set(INT8_MOBILENET_MODEL_DIR "${INT8_INSTALL_DIR}/mobilenetv1")
    inference_analysis_python_api_int8_test(test_slim_int8_mobilenet ${INT8_MOBILENET_MODEL_DIR} ${IMAGENET_DATA_PATH} ${INT8_IC_TEST_FILE_PATH})
    inference_analysis_python_api_int8_test_mkldnn(test_slim_int8_mobilenet_mkldnn ${INT8_MOBILENET_MODEL_DIR} ${IMAGENET_DATA_PATH} ${INT8_IC_TEST_FILE_PATH})

    # WITH_SLIM_MKLDNN_FULL_TEST is a temporary flag that lets QA run the
    # following unit tests locally; they take too long to run on CI.
    if (WITH_SLIM_MKLDNN_FULL_TEST)
        # resnet50 int8
        set(INT8_RESNET50_MODEL_DIR "${INT8_INSTALL_DIR}/resnet50")
        inference_analysis_python_api_int8_test(test_slim_int8_resnet50 ${INT8_RESNET50_MODEL_DIR} ${IMAGENET_DATA_PATH} ${INT8_IC_TEST_FILE_PATH})

        # mobilenetv2 int8
        set(INT8_MOBILENETV2_MODEL_DIR "${INT8_INSTALL_DIR}/mobilenetv2")
        inference_analysis_python_api_int8_test(test_slim_int8_mobilenetv2 ${INT8_MOBILENETV2_MODEL_DIR} ${IMAGENET_DATA_PATH} ${INT8_IC_TEST_FILE_PATH})

        # resnet101 int8
        set(INT8_RESNET101_MODEL_DIR "${INT8_INSTALL_DIR}/resnet101")
        inference_analysis_python_api_int8_test(test_slim_int8_resnet101 ${INT8_RESNET101_MODEL_DIR} ${IMAGENET_DATA_PATH} ${INT8_IC_TEST_FILE_PATH})

        # vgg16 int8
        set(INT8_VGG16_MODEL_DIR "${INT8_INSTALL_DIR}/vgg16")
        inference_analysis_python_api_int8_test(test_slim_int8_vgg16 ${INT8_VGG16_MODEL_DIR} ${IMAGENET_DATA_PATH} ${INT8_IC_TEST_FILE_PATH})

        # vgg19 int8
        set(INT8_VGG19_MODEL_DIR "${INT8_INSTALL_DIR}/vgg19")
        inference_analysis_python_api_int8_test(test_slim_int8_vgg19 ${INT8_VGG19_MODEL_DIR} ${IMAGENET_DATA_PATH} ${INT8_IC_TEST_FILE_PATH})
    endif()

    #### QAT FP32 & INT8 comparison python api tests

    set(QAT_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/qat")

    ### QATv1 for image classification
    # NOTE(review): the ResNet101, VGG16 and VGG19 tests below are commented
    # out, but download_qat_model() is still called for their archives.
    # Confirm these downloads are still wanted.

    # QAT ResNet50
    set(QAT_RESNET50_MODEL_DIR "${QAT_INSTALL_DIR}/ResNet50_QAT")
    set(QAT_RESNET50_MODEL_ARCHIVE "ResNet50_qat_model.tar.gz")
    download_qat_model(${QAT_RESNET50_MODEL_DIR} ${QAT_RESNET50_MODEL_ARCHIVE})
    inference_qat_int8_image_classification_test(test_qat_int8_resnet50_mkldnn ${QAT_RESNET50_MODEL_DIR}/model ${IMAGENET_DATA_PATH})

    # QAT ResNet101
    set(QAT_RESNET101_MODEL_DIR "${QAT_INSTALL_DIR}/ResNet101_QAT")
    set(QAT_RESNET101_MODEL_ARCHIVE "ResNet101_qat_model.tar.gz")
    download_qat_model(${QAT_RESNET101_MODEL_DIR} ${QAT_RESNET101_MODEL_ARCHIVE})
    # inference_qat_int8_image_classification_test(test_qat_int8_resnet101_mkldnn ${QAT_RESNET101_MODEL_DIR}/model ${IMAGENET_DATA_PATH})

    # QAT GoogleNet
    set(QAT_GOOGLENET_MODEL_DIR "${QAT_INSTALL_DIR}/GoogleNet_QAT")
    set(QAT_GOOGLENET_MODEL_ARCHIVE "GoogleNet_qat_model.tar.gz")
    download_qat_model(${QAT_GOOGLENET_MODEL_DIR} ${QAT_GOOGLENET_MODEL_ARCHIVE})
    inference_qat_int8_image_classification_test(test_qat_int8_googlenet_mkldnn ${QAT_GOOGLENET_MODEL_DIR}/model ${IMAGENET_DATA_PATH})

    # QAT MobileNetV1
    set(QAT_MOBILENETV1_MODEL_DIR "${QAT_INSTALL_DIR}/MobileNetV1_QAT")
    set(QAT_MOBILENETV1_MODEL_ARCHIVE "MobileNetV1_qat_model.tar.gz")
    download_qat_model(${QAT_MOBILENETV1_MODEL_DIR} ${QAT_MOBILENETV1_MODEL_ARCHIVE})
    inference_qat_int8_image_classification_test(test_qat_int8_mobilenetv1_mkldnn ${QAT_MOBILENETV1_MODEL_DIR}/model ${IMAGENET_DATA_PATH})

    # QAT MobileNetV2
    set(QAT_MOBILENETV2_MODEL_DIR "${QAT_INSTALL_DIR}/MobileNetV2_QAT")
    set(QAT_MOBILENETV2_MODEL_ARCHIVE "MobileNetV2_qat_model.tar.gz")
    download_qat_model(${QAT_MOBILENETV2_MODEL_DIR} ${QAT_MOBILENETV2_MODEL_ARCHIVE})
    inference_qat_int8_image_classification_test(test_qat_int8_mobilenetv2_mkldnn ${QAT_MOBILENETV2_MODEL_DIR}/model ${IMAGENET_DATA_PATH})

    # QAT VGG16
    set(QAT_VGG16_MODEL_DIR "${QAT_INSTALL_DIR}/VGG16_QAT")
    set(QAT_VGG16_MODEL_ARCHIVE "VGG16_qat_model.tar.gz")
    download_qat_model(${QAT_VGG16_MODEL_DIR} ${QAT_VGG16_MODEL_ARCHIVE})
    # inference_qat_int8_image_classification_test(test_qat_int8_vgg16_mkldnn ${QAT_VGG16_MODEL_DIR}/model ${IMAGENET_DATA_PATH})

    # QAT VGG19
    set(QAT_VGG19_MODEL_DIR "${QAT_INSTALL_DIR}/VGG19_QAT")
    set(QAT_VGG19_MODEL_ARCHIVE "VGG19_qat_model.tar.gz")
    download_qat_model(${QAT_VGG19_MODEL_DIR} ${QAT_VGG19_MODEL_ARCHIVE})
    # inference_qat_int8_image_classification_test(test_qat_int8_vgg19_mkldnn ${QAT_VGG19_MODEL_DIR}/model ${IMAGENET_DATA_PATH})

    ### QATv2 for image classification

    set(QAT2_IC_OPS_TO_QUANTIZE "conv2d,pool2d")

    # QAT2 ResNet50 with input/output scales in `fake_quantize_moving_average_abs_max` operators,
    # with weight scales in `fake_dequantize_max_abs` operators
    set(QAT2_RESNET50_MODEL_DIR "${QAT_INSTALL_DIR}/ResNet50_qat_perf")
    set(FP32_RESNET50_MODEL_DIR "${INT8_INSTALL_DIR}/resnet50")
    set(QAT2_RESNET50_MODEL_ARCHIVE "ResNet50_qat_perf.tar.gz")
    download_qat_model(${QAT2_RESNET50_MODEL_DIR} ${QAT2_RESNET50_MODEL_ARCHIVE})
    inference_qat2_int8_image_classification_test(test_qat2_int8_resnet50_mkldnn ${QAT2_RESNET50_MODEL_DIR}/ResNet50_qat_perf/float ${FP32_RESNET50_MODEL_DIR}/model ${IMAGENET_DATA_PATH} ${QAT2_IC_OPS_TO_QUANTIZE})

    # QAT2 ResNet50 with input/output scales in `fake_quantize_range_abs_max` operators and the `out_threshold` attributes,
    # with weight scales in `fake_dequantize_max_abs` operators
    set(QAT2_RESNET50_RANGE_MODEL_DIR "${QAT_INSTALL_DIR}/ResNet50_qat_range")
    set(QAT2_RESNET50_RANGE_MODEL_ARCHIVE "ResNet50_qat_range.tar.gz")
    download_qat_model(${QAT2_RESNET50_RANGE_MODEL_DIR} ${QAT2_RESNET50_RANGE_MODEL_ARCHIVE})
    inference_qat2_int8_image_classification_test(test_qat2_int8_resnet50_range_mkldnn ${QAT2_RESNET50_RANGE_MODEL_DIR}/ResNet50_qat_range ${FP32_RESNET50_MODEL_DIR}/model ${IMAGENET_DATA_PATH} ${QAT2_IC_OPS_TO_QUANTIZE})

    # QAT2 ResNet50 with input/output scales in `fake_quantize_range_abs_max` operators and the `out_threshold` attributes,
    # with weight scales in `fake_channel_wise_dequantize_max_abs` operators
    set(QAT2_RESNET50_CHANNELWISE_MODEL_DIR "${QAT_INSTALL_DIR}/ResNet50_qat_channelwise")
    set(QAT2_RESNET50_CHANNELWISE_MODEL_ARCHIVE "ResNet50_qat_channelwise.tar.gz")
    download_qat_model(${QAT2_RESNET50_CHANNELWISE_MODEL_DIR} ${QAT2_RESNET50_CHANNELWISE_MODEL_ARCHIVE})
    inference_qat2_int8_image_classification_test(test_qat2_int8_resnet50_channelwise_mkldnn ${QAT2_RESNET50_CHANNELWISE_MODEL_DIR}/ResNet50_qat_channelwise ${FP32_RESNET50_MODEL_DIR}/model ${IMAGENET_DATA_PATH} ${QAT2_IC_OPS_TO_QUANTIZE})

    # QAT2 MobileNetV1
    set(QAT2_MOBILENETV1_MODEL_DIR "${QAT_INSTALL_DIR}/MobileNet_qat_perf")
    set(FP32_MOBILENETV1_MODEL_DIR "${INT8_INSTALL_DIR}/mobilenetv1")
    set(QAT2_MOBILENETV1_MODEL_ARCHIVE "MobileNet_qat_perf.tar.gz")
    download_qat_model(${QAT2_MOBILENETV1_MODEL_DIR} ${QAT2_MOBILENETV1_MODEL_ARCHIVE})
    inference_qat2_int8_image_classification_test(test_qat2_int8_mobilenetv1_mkldnn ${QAT2_MOBILENETV1_MODEL_DIR}/MobileNet_qat_perf/float ${FP32_MOBILENETV1_MODEL_DIR}/model ${IMAGENET_DATA_PATH} ${QAT2_IC_OPS_TO_QUANTIZE})

    ### QATv2 for NLP

    set(NLP_DATA_ARCHIVE "Ernie_dataset.tar.gz")
    set(NLP_DATA_DIR "${INFERENCE_DEMO_INSTALL_DIR}/Ernie_dataset")
    set(NLP_DATA_PATH "${NLP_DATA_DIR}/Ernie_dataset/1.8w.bs1")
    set(NLP_LABELS_PATH "${NLP_DATA_DIR}/Ernie_dataset/label.xnli.dev")
    download_qat_data(${NLP_DATA_DIR} ${NLP_DATA_ARCHIVE})

    # QAT2 Ernie
    set(QAT2_ERNIE_MODEL_ARCHIVE "ernie_qat.tar.gz")
    set(QAT2_ERNIE_MODEL_DIR "${QAT_INSTALL_DIR}/Ernie_qat")
    download_qat_model(${QAT2_ERNIE_MODEL_DIR} ${QAT2_ERNIE_MODEL_ARCHIVE})
    set(FP32_ERNIE_MODEL_ARCHIVE "ernie_fp32_model.tar.gz")
    set(FP32_ERNIE_MODEL_DIR "${QAT_INSTALL_DIR}/Ernie_float")
    download_qat_fp32_model(${FP32_ERNIE_MODEL_DIR} ${FP32_ERNIE_MODEL_ARCHIVE})
    inference_qat2_int8_nlp_test(test_qat2_int8_ernie_mkldnn ${QAT2_ERNIE_MODEL_DIR}/Ernie_qat/float ${FP32_ERNIE_MODEL_DIR}/ernie_fp32_model ${NLP_DATA_PATH} ${NLP_LABELS_PATH})

    ### Save QAT2 FP32 model or QAT2 INT8 model

    set(QAT2_INT8_RESNET50_SAVE_PATH "${QAT_INSTALL_DIR}/ResNet50_qat2_int8")
    set(QAT2_FP32_RESNET50_SAVE_PATH "${QAT_INSTALL_DIR}/ResNet50_qat2_fp32")
    save_qat_ic_model_test(save_qat2_model_resnet50 ${QAT2_RESNET50_MODEL_DIR}/ResNet50_qat_perf/float ${QAT2_FP32_RESNET50_SAVE_PATH} ${QAT2_INT8_RESNET50_SAVE_PATH} ${QAT2_IC_OPS_TO_QUANTIZE})

    set(QAT2_INT8_ERNIE_SAVE_PATH "${QAT_INSTALL_DIR}/Ernie_qat2_int8")
    set(QAT2_FP32_ERNIE_SAVE_PATH "${QAT_INSTALL_DIR}/Ernie_qat2_fp32")
    save_qat_nlp_model_test(save_qat2_model_ernie ${QAT2_ERNIE_MODEL_DIR}/Ernie_qat/float ${QAT2_FP32_ERNIE_SAVE_PATH} ${QAT2_INT8_ERNIE_SAVE_PATH})

    # Convert the QAT2 Ernie model to dot and pdf files.
    # NOTE(review): the input is the QAT model under Ernie_qat/float, but the
    # output directory and graph name say "qat2_int8". Confirm which model is
    # intended.
    set(QAT2_INT8_ERNIE_DOT_SAVE_PATH "${QAT_INSTALL_DIR}/Ernie_qat2_int8_dot_file")
    convert_model2dot_test(convert_model2dot_ernie ${QAT2_ERNIE_MODEL_DIR}/Ernie_qat/float ${QAT2_INT8_ERNIE_DOT_SAVE_PATH} "Ernie_qat2_int8")

endif()

# test_mkldnn_int8_quantization_strategy.py is registered explicitly, with
# custom arguments, in the Linux + MKL-DNN section, and it supports only that
# configuration. It is therefore always removed from the generic
# registration loop below.
# The QAT comparison scripts (qat_*_comparison.py, qat2_*_comparison.py) do
# not match the "test_*.py" glob that fills TEST_OPS, so they never need to
# be removed here.
list(REMOVE_ITEM TEST_OPS test_mkldnn_int8_quantization_strategy)

# Register every remaining test_*.py script as a plain Python test with no
# extra environment or arguments.
foreach(src ${TEST_OPS})
    py_test(${src} SRCS ${src}.py)
endforeach()