file(GLOB TEST_OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "test_*.py")
string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}")

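# Runs a post-training INT8 quantization Python API test (py_test) on the given
# model and dataset, with MKL-DNN enabled or disabled via the use_mkldnn argument.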
function(_inference_analysis_python_api_int8_test target model_dir data_path filename use_mkldnn)
    py_test(${target} SRCS ${filename}
        ENVS CPU_NUM_THREADS=${CPU_NUM_THREADS_ON_CI}
             FLAGS_use_mkldnn=${use_mkldnn}
        ARGS --infer_model ${model_dir}/model
             --infer_data ${data_path}
             --int8_model_save_path int8_models/${target}
             --warmup_batch_size ${WARMUP_BATCH_SIZE}
             --batch_size 50)
endfunction()

function(inference_analysis_python_api_int8_test target model_dir data_path filename)
    _inference_analysis_python_api_int8_test(${target} ${model_dir} ${data_path} ${filename} False)
endfunction()

function(inference_analysis_python_api_int8_test_custom_warmup_batch_size target model_dir data_dir filename warmup_batch_size)
    set(WARMUP_BATCH_SIZE ${warmup_batch_size})
    inference_analysis_python_api_int8_test(${target} ${model_dir} ${data_dir} ${filename})
endfunction()

function(inference_analysis_python_api_int8_test_mkldnn target model_dir data_path filename)
    _inference_analysis_python_api_int8_test(${target} ${model_dir} ${data_path} ${filename} True)
endfunction()

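# The download helpers below fetch and unpack an archive into install_dir only if
# it is not already present there.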
function(download_qat_data install_dir data_file)
    if (NOT EXISTS ${install_dir}/${data_file})
        inference_download_and_uncompress(${install_dir} ${INFERENCE_URL}/int8 ${data_file})
    endif()
endfunction()

function(download_qat_model install_dir data_file)
    if (NOT EXISTS ${install_dir}/${data_file})
        inference_download_and_uncompress(${install_dir} ${INFERENCE_URL}/int8/QAT_models ${data_file})
    endif()
endfunction()

function(download_qat_fp32_model install_dir data_file)
    if (NOT EXISTS ${install_dir}/${data_file})
        inference_download_and_uncompress(${install_dir} ${INFERENCE_URL}/int8/QAT_models/fp32 ${data_file})
    endif()
endfunction()

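# Compares the accuracy of a QAT FP32 model and its MKL-DNN INT8 counterpart on an
# image classification dataset (see qat_int8_image_classification_comparison.py).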
function(inference_qat_int8_image_classification_test target qat_model_dir dataset_path)
    py_test(${target} SRCS "${CMAKE_CURRENT_SOURCE_DIR}/qat_int8_image_classification_comparison.py"
            ENVS FLAGS_OMP_NUM_THREADS=${CPU_NUM_THREADS_ON_CI}
                 OMP_NUM_THREADS=${CPU_NUM_THREADS_ON_CI}
                 FLAGS_use_mkldnn=true
            ARGS --qat_model ${qat_model_dir}
                 --infer_data ${dataset_path}
                 --batch_size 25
                 --batch_num 2
                 --acc_diff_threshold 0.1)
endfunction()

# Set batch_size to 10 for UT only (to avoid OOM). For the whole dataset, use batch_size 25.
function(inference_qat2_int8_image_classification_test target qat_model_dir fp32_model_dir dataset_path quantized_ops)
    py_test(${target} SRCS "${CMAKE_CURRENT_SOURCE_DIR}/qat2_int8_image_classification_comparison.py"
            ENVS FLAGS_OMP_NUM_THREADS=${CPU_NUM_THREADS_ON_CI}
                 OMP_NUM_THREADS=${CPU_NUM_THREADS_ON_CI}
                 FLAGS_use_mkldnn=true
            ARGS --qat_model ${qat_model_dir}
                 --fp32_model ${fp32_model_dir}
                 --infer_data ${dataset_path}
                 --batch_size 10
                 --batch_num 2
                 --acc_diff_threshold 0.1
                 --quantized_ops ${quantized_ops})
endfunction()

# Set batch_size to 10 for UT only (to avoid OOM). For the whole dataset, use batch_size 20.
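# Compares the accuracy of a QAT2 FP32 model and its MKL-DNN INT8 counterpart on an
# NLP dataset (see qat2_int8_nlp_comparison.py).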
function(inference_qat2_int8_nlp_test target qat_model_dir fp32_model_dir dataset_path labels_path quantized_ops)
    py_test(${target} SRCS "${CMAKE_CURRENT_SOURCE_DIR}/qat2_int8_nlp_comparison.py"
            ENVS FLAGS_OMP_NUM_THREADS=${CPU_NUM_THREADS_ON_CI}
                 OMP_NUM_THREADS=${CPU_NUM_THREADS_ON_CI}
                 FLAGS_use_mkldnn=true
            ARGS --qat_model ${qat_model_dir}
                 --fp32_model ${fp32_model_dir}
                 --infer_data ${dataset_path}
                 --labels ${labels_path}
                 --batch_size 10
                 --batch_num 2
                 --acc_diff_threshold 0.1
                 --quantized_ops ${quantized_ops})
endfunction()

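# Saves a QAT2 model both as an optimized FP32 model and as an INT8 model, using
# save_qat_model.py.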
function(save_qat_model_test target qat_model_dir fp32_model_save_path int8_model_save_path quantized_ops)
    py_test(${target} SRCS ${CMAKE_CURRENT_SOURCE_DIR}/save_qat_model.py
            ARGS --qat_model_path ${qat_model_dir}
                 --fp32_model_save_path ${fp32_model_save_path}
                 --int8_model_save_path ${int8_model_save_path}
                 --quantized_ops ${quantized_ops})
endfunction()

if(WIN32)
    list(REMOVE_ITEM TEST_OPS test_light_nas)
    list(REMOVE_ITEM TEST_OPS test_post_training_quantization_mobilenetv1)
    list(REMOVE_ITEM TEST_OPS test_post_training_quantization_resnet50)
    list(REMOVE_ITEM TEST_OPS test_weight_quantization_mobilenetv1)
endif()

if(LINUX AND WITH_MKLDNN)

	#### Image classification dataset: ImageNet (small)
	# The dataset should already be downloaded for INT8v2 unit tests
	set(IMAGENET_DATA_PATH "${INFERENCE_DEMO_INSTALL_DIR}/imagenet/data.bin")

	#### INT8 image classification python api test
	# Models should be already downloaded for INT8v2 unit tests

	set(INT8_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/int8v2")
	set(INT8_IC_TEST_FILE "test_mkldnn_int8_quantization_strategy.py")
	set(INT8_IC_TEST_FILE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/${INT8_IC_TEST_FILE}")

	# googlenet int8
	set(INT8_GOOGLENET_MODEL_DIR "${INT8_INSTALL_DIR}/googlenet")
	inference_analysis_python_api_int8_test_custom_warmup_batch_size(test_slim_int8_googlenet ${INT8_GOOGLENET_MODEL_DIR} ${IMAGENET_DATA_PATH} ${INT8_IC_TEST_FILE_PATH} 10)

	# mobilenet int8
	set(INT8_MOBILENET_MODEL_DIR "${INT8_INSTALL_DIR}/mobilenetv1")
	inference_analysis_python_api_int8_test(test_slim_int8_mobilenet ${INT8_MOBILENET_MODEL_DIR} ${IMAGENET_DATA_PATH} ${INT8_IC_TEST_FILE_PATH})
	inference_analysis_python_api_int8_test_mkldnn(test_slim_int8_mobilenet_mkldnn ${INT8_MOBILENET_MODEL_DIR} ${IMAGENET_DATA_PATH} ${INT8_IC_TEST_FILE_PATH})

	# Temporarily add the WITH_SLIM_MKLDNN_FULL_TEST flag for QA to run the following UTs locally,
	# since these UTs take too much time in CI.
	if (WITH_SLIM_MKLDNN_FULL_TEST)
		# resnet50 int8
		set(INT8_RESNET50_MODEL_DIR "${INT8_INSTALL_DIR}/resnet50")
		inference_analysis_python_api_int8_test(test_slim_int8_resnet50 ${INT8_RESNET50_MODEL_DIR} ${IMAGENET_DATA_PATH} ${INT8_IC_TEST_FILE_PATH})

		# mobilenetv2 int8
		set(INT8_MOBILENETV2_MODEL_DIR "${INT8_INSTALL_DIR}/mobilenetv2")
		inference_analysis_python_api_int8_test(test_slim_int8_mobilenetv2 ${INT8_MOBILENETV2_MODEL_DIR} ${IMAGENET_DATA_PATH} ${INT8_IC_TEST_FILE_PATH})

		# resnet101 int8
		set(INT8_RESNET101_MODEL_DIR "${INT8_INSTALL_DIR}/resnet101")
		inference_analysis_python_api_int8_test(test_slim_int8_resnet101 ${INT8_RESNET101_MODEL_DIR} ${IMAGENET_DATA_PATH} ${INT8_IC_TEST_FILE_PATH})

		# vgg16 int8
		set(INT8_VGG16_MODEL_DIR "${INT8_INSTALL_DIR}/vgg16")
		inference_analysis_python_api_int8_test(test_slim_int8_vgg16 ${INT8_VGG16_MODEL_DIR} ${IMAGENET_DATA_PATH} ${INT8_IC_TEST_FILE_PATH})

		# vgg19 int8
		set(INT8_VGG19_MODEL_DIR "${INT8_INSTALL_DIR}/vgg19")
		inference_analysis_python_api_int8_test(test_slim_int8_vgg19 ${INT8_VGG19_MODEL_DIR} ${IMAGENET_DATA_PATH} ${INT8_IC_TEST_FILE_PATH})
	endif()

	#### QAT FP32 & INT8 comparison python api tests

	set(QAT_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/qat")

	### QATv1 for image classification

	# QAT ResNet50
	set(QAT_RESNET50_MODEL_DIR "${QAT_INSTALL_DIR}/ResNet50_QAT")
	set(QAT_RESNET50_MODEL_ARCHIVE "ResNet50_qat_model.tar.gz")
	download_qat_model(${QAT_RESNET50_MODEL_DIR} ${QAT_RESNET50_MODEL_ARCHIVE})
	inference_qat_int8_image_classification_test(test_qat_int8_resnet50_mkldnn ${QAT_RESNET50_MODEL_DIR}/model ${IMAGENET_DATA_PATH})

	# QAT ResNet101
	set(QAT_RESNET101_MODEL_DIR "${QAT_INSTALL_DIR}/ResNet101_QAT")
	set(QAT_RESNET101_MODEL_ARCHIVE "ResNet101_qat_model.tar.gz")
	download_qat_model(${QAT_RESNET101_MODEL_DIR} ${QAT_RESNET101_MODEL_ARCHIVE})
	# inference_qat_int8_image_classification_test(test_qat_int8_resnet101_mkldnn ${QAT_RESNET101_MODEL_DIR}/model ${IMAGENET_DATA_PATH})

	# QAT GoogleNet
	set(QAT_GOOGLENET_MODEL_DIR "${QAT_INSTALL_DIR}/GoogleNet_QAT")
	set(QAT_GOOGLENET_MODEL_ARCHIVE "GoogleNet_qat_model.tar.gz")
	download_qat_model(${QAT_GOOGLENET_MODEL_DIR} ${QAT_GOOGLENET_MODEL_ARCHIVE})
	inference_qat_int8_image_classification_test(test_qat_int8_googlenet_mkldnn ${QAT_GOOGLENET_MODEL_DIR}/model ${IMAGENET_DATA_PATH})

	# QAT MobileNetV1
	set(QAT_MOBILENETV1_MODEL_DIR "${QAT_INSTALL_DIR}/MobileNetV1_QAT")
	set(QAT_MOBILENETV1_MODEL_ARCHIVE "MobileNetV1_qat_model.tar.gz")
	download_qat_model(${QAT_MOBILENETV1_MODEL_DIR} ${QAT_MOBILENETV1_MODEL_ARCHIVE})
	inference_qat_int8_image_classification_test(test_qat_int8_mobilenetv1_mkldnn ${QAT_MOBILENETV1_MODEL_DIR}/model ${IMAGENET_DATA_PATH})

	# QAT MobileNetV2
	set(QAT_MOBILENETV2_MODEL_DIR "${QAT_INSTALL_DIR}/MobileNetV2_QAT")
	set(QAT_MOBILENETV2_MODEL_ARCHIVE "MobileNetV2_qat_model.tar.gz")
	download_qat_model(${QAT_MOBILENETV2_MODEL_DIR} ${QAT_MOBILENETV2_MODEL_ARCHIVE})
	inference_qat_int8_image_classification_test(test_qat_int8_mobilenetv2_mkldnn ${QAT_MOBILENETV2_MODEL_DIR}/model ${IMAGENET_DATA_PATH})

	# QAT VGG16
	set(QAT_VGG16_MODEL_DIR "${QAT_INSTALL_DIR}/VGG16_QAT")
	set(QAT_VGG16_MODEL_ARCHIVE "VGG16_qat_model.tar.gz")
	download_qat_model(${QAT_VGG16_MODEL_DIR} ${QAT_VGG16_MODEL_ARCHIVE})
	# inference_qat_int8_image_classification_test(test_qat_int8_vgg16_mkldnn ${QAT_VGG16_MODEL_DIR}/model ${IMAGENET_DATA_PATH})

	# QAT VGG19
	set(QAT_VGG19_MODEL_DIR "${QAT_INSTALL_DIR}/VGG19_QAT")
	set(QAT_VGG19_MODEL_ARCHIVE "VGG19_qat_model.tar.gz")
	download_qat_model(${QAT_VGG19_MODEL_DIR} ${QAT_VGG19_MODEL_ARCHIVE})
	# inference_qat_int8_image_classification_test(test_qat_int8_vgg19_mkldnn ${QAT_VGG19_MODEL_DIR}/model ${IMAGENET_DATA_PATH})

	### QATv2 for image classification

	set(QAT2_IC_QUANTIZED_OPS "conv2d,pool2d")

	# QAT2 ResNet50 with input/output scales in `fake_quantize_moving_average_abs_max` operators,
	# with weight scales in `fake_dequantize_max_abs` operators
	set(QAT2_RESNET50_MODEL_DIR "${QAT_INSTALL_DIR}/ResNet50_qat_perf")
	set(FP32_RESNET50_MODEL_DIR "${INT8_INSTALL_DIR}/resnet50")
	set(QAT2_RESNET50_MODEL_ARCHIVE "ResNet50_qat_perf.tar.gz")
	download_qat_model(${QAT2_RESNET50_MODEL_DIR} ${QAT2_RESNET50_MODEL_ARCHIVE})
	inference_qat2_int8_image_classification_test(test_qat2_int8_resnet50_mkldnn ${QAT2_RESNET50_MODEL_DIR}/ResNet50_qat_perf/float ${FP32_RESNET50_MODEL_DIR}/model ${IMAGENET_DATA_PATH} ${QAT2_IC_QUANTIZED_OPS})

	# QAT2 ResNet50 with input/output scales in `fake_quantize_range_abs_max` operators and the `out_threshold` attributes,
	# with weight scales in `fake_dequantize_max_abs` operators
	set(QAT2_RESNET50_RANGE_MODEL_DIR "${QAT_INSTALL_DIR}/ResNet50_qat_range")
	set(QAT2_RESNET50_RANGE_MODEL_ARCHIVE "ResNet50_qat_range.tar.gz")
	download_qat_model(${QAT2_RESNET50_RANGE_MODEL_DIR} ${QAT2_RESNET50_RANGE_MODEL_ARCHIVE})
	inference_qat2_int8_image_classification_test(test_qat2_int8_resnet50_range_mkldnn ${QAT2_RESNET50_RANGE_MODEL_DIR}/ResNet50_qat_range ${FP32_RESNET50_MODEL_DIR}/model ${IMAGENET_DATA_PATH} ${QAT2_IC_QUANTIZED_OPS})

	# QAT2 ResNet50 with input/output scales in `fake_quantize_range_abs_max` operators and the `out_threshold` attributes,
	# with weight scales in `fake_channel_wise_dequantize_max_abs` operators
	set(QAT2_RESNET50_CHANNELWISE_MODEL_DIR "${QAT_INSTALL_DIR}/ResNet50_qat_channelwise")
	set(QAT2_RESNET50_CHANNELWISE_MODEL_ARCHIVE "ResNet50_qat_channelwise.tar.gz")
	download_qat_model(${QAT2_RESNET50_CHANNELWISE_MODEL_DIR} ${QAT2_RESNET50_CHANNELWISE_MODEL_ARCHIVE})
	inference_qat2_int8_image_classification_test(test_qat2_int8_resnet50_channelwise_mkldnn ${QAT2_RESNET50_CHANNELWISE_MODEL_DIR}/ResNet50_qat_channelwise ${FP32_RESNET50_MODEL_DIR}/model ${IMAGENET_DATA_PATH} ${QAT2_IC_QUANTIZED_OPS})

	# QAT2 MobileNetV1
	set(QAT2_MOBILENETV1_MODEL_DIR "${QAT_INSTALL_DIR}/MobileNet_qat_perf")
	set(FP32_MOBILENETV1_MODEL_DIR "${INT8_INSTALL_DIR}/mobilenetv1")
	set(QAT2_MOBILENETV1_MODEL_ARCHIVE "MobileNet_qat_perf.tar.gz")
	download_qat_model(${QAT2_MOBILENETV1_MODEL_DIR} ${QAT2_MOBILENETV1_MODEL_ARCHIVE})
	inference_qat2_int8_image_classification_test(test_qat2_int8_mobilenetv1_mkldnn ${QAT2_MOBILENETV1_MODEL_DIR}/MobileNet_qat_perf/float ${FP32_MOBILENETV1_MODEL_DIR}/model ${IMAGENET_DATA_PATH} ${QAT2_IC_QUANTIZED_OPS})

	### QATv2 for NLP

	set(QAT2_NLP_QUANTIZED_OPS "fc,reshape2,transpose2,matmul")

	set(NLP_DATA_ARCHIVE "Ernie_dataset.tar.gz")
	set(NLP_DATA_DIR "${INFERENCE_DEMO_INSTALL_DIR}/Ernie_dataset")
	set(NLP_DATA_PATH "${NLP_DATA_DIR}/Ernie_dataset/1.8w.bs1")
	set(NLP_LABELS_PATH "${NLP_DATA_DIR}/Ernie_dataset/label.xnli.dev")
	download_qat_data(${NLP_DATA_DIR} ${NLP_DATA_ARCHIVE})

	# QAT2 Ernie
	set(QAT2_ERNIE_MODEL_ARCHIVE "ernie_qat.tar.gz")
	set(QAT2_ERNIE_MODEL_DIR "${QAT_INSTALL_DIR}/Ernie_qat")
	download_qat_model(${QAT2_ERNIE_MODEL_DIR} ${QAT2_ERNIE_MODEL_ARCHIVE})
	set(FP32_ERNIE_MODEL_ARCHIVE "ernie_fp32_model.tar.gz")
	set(FP32_ERNIE_MODEL_DIR "${QAT_INSTALL_DIR}/Ernie_float")
	download_qat_fp32_model(${FP32_ERNIE_MODEL_DIR} ${FP32_ERNIE_MODEL_ARCHIVE})
	inference_qat2_int8_nlp_test(test_qat2_int8_ernie_mkldnn ${QAT2_ERNIE_MODEL_DIR}/Ernie_qat/float ${FP32_ERNIE_MODEL_DIR}/ernie_fp32_model ${NLP_DATA_PATH} ${NLP_LABELS_PATH} ${QAT2_NLP_QUANTIZED_OPS})

	### Save QAT2 FP32 model or QAT2 INT8 model

	set(QAT2_INT8_RESNET50_SAVE_PATH "${QAT_INSTALL_DIR}/ResNet50_qat2_int8")
	set(QAT2_FP32_RESNET50_SAVE_PATH "${QAT_INSTALL_DIR}/ResNet50_qat2_fp32")
	save_qat_model_test(save_qat2_model_resnet50 ${QAT2_RESNET50_MODEL_DIR}/ResNet50_qat_perf/float ${QAT2_FP32_RESNET50_SAVE_PATH} ${QAT2_INT8_RESNET50_SAVE_PATH} ${QAT2_IC_QUANTIZED_OPS})

	set(QAT2_INT8_ERNIE_SAVE_PATH "${QAT_INSTALL_DIR}/Ernie_qat2_int8")
	set(QAT2_FP32_ERNIE_SAVE_PATH "${QAT_INSTALL_DIR}/Ernie_qat2_fp32")
	save_qat_model_test(save_qat2_model_ernie ${QAT2_ERNIE_MODEL_DIR}/Ernie_qat/float ${QAT2_FP32_ERNIE_SAVE_PATH} ${QAT2_INT8_ERNIE_SAVE_PATH} ${QAT2_NLP_QUANTIZED_OPS})

endif()

# Since the tests for QAT FP32 & INT8 comparison support testing only on Linux
# with MKL-DNN, we remove them here so they are not run on other systems.
list(REMOVE_ITEM TEST_OPS
	test_mkldnn_int8_quantization_strategy
	qat_int8_image_classification_comparison
	qat_int8_nlp_comparison)

foreach(src ${TEST_OPS})
    py_test(${src} SRCS ${src}.py)
endforeach()