# Collect every test_*.py script in this directory as a candidate unit test.
# TEST_OPS holds the script names without the ".py" suffix; entries that need
# special handling are removed from it further down, before the generic
# py_test() registration loop at the end of this file.
# Note: the glob is evaluated at configure time, so re-run CMake after adding
# a new test script.
file(GLOB TEST_OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "test_*.py")
string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}")

# Register a Python INT8 quantization test for an image-classification model.
#
#   target     - name of the test to register (also used for the INT8 model
#                save path: int8_models/<target>)
#   model_dir  - model directory; its "model" subdirectory is passed as
#                --infer_model
#   data_path  - inference dataset file, passed as --infer_data
#   filename   - Python test script to run
#   use_mkldnn - value exported to the test as FLAGS_use_mkldnn
#
# --warmup_batch_size is read from WARMUP_BATCH_SIZE in the *calling* scope
# (inference_analysis_python_api_int8_test_custom_warmup_batch_size sets it).
# NOTE(review): if no caller or parent scope defines WARMUP_BATCH_SIZE, the
# unquoted expansion is empty and --warmup_batch_size is passed with no value;
# confirm it is defined globally for the plain variants.
function(_inference_analysis_python_api_int8_test target model_dir data_path filename use_mkldnn)
    py_test(${target} SRCS ${filename}
        ENVS CPU_NUM_THREADS=${CPU_NUM_THREADS_ON_CI}
             FLAGS_use_mkldnn=${use_mkldnn}
        ARGS --infer_model ${model_dir}/model
             --infer_data ${data_path}
             --int8_model_save_path int8_models/${target}
             --warmup_batch_size ${WARMUP_BATCH_SIZE}
             --batch_size 50)
endfunction()

# Register an INT8 quantization test that runs with FLAGS_use_mkldnn=False.
# All arguments are forwarded unchanged to _inference_analysis_python_api_int8_test.
function(inference_analysis_python_api_int8_test target model_dir data_path filename)
    _inference_analysis_python_api_int8_test(${target} ${model_dir} ${data_path} ${filename} False)
endfunction()

# Same as inference_analysis_python_api_int8_test, but with an explicit warmup
# batch size. A function call inherits a copy of its caller's variables, so
# setting WARMUP_BATCH_SIZE in this function's scope is visible to the nested
# call without leaking the value back into the caller's scope.
function(inference_analysis_python_api_int8_test_custom_warmup_batch_size target model_dir data_dir filename warmup_batch_size)
    set(WARMUP_BATCH_SIZE ${warmup_batch_size})
    inference_analysis_python_api_int8_test(${target} ${model_dir} ${data_dir} ${filename})
endfunction()

# Register an INT8 quantization test that runs with FLAGS_use_mkldnn=True.
# All arguments are forwarded unchanged to _inference_analysis_python_api_int8_test.
function(inference_analysis_python_api_int8_test_mkldnn target model_dir data_path filename)
    _inference_analysis_python_api_int8_test(${target} ${model_dir} ${data_path} ${filename} True)
endfunction()

# Fetch ${INFERENCE_URL}/int8/<data_file> into install_dir via
# inference_download_and_uncompress(), unless install_dir/<data_file> already
# exists.
function(download_qat_data install_dir data_file)
    # Quoted so that a path containing spaces or ';' is checked as one path.
    if (NOT EXISTS "${install_dir}/${data_file}")
        inference_download_and_uncompress(${install_dir} ${INFERENCE_URL}/int8 ${data_file})
    endif()
endfunction()

# Fetch ${INFERENCE_URL}/int8/QAT_models/<data_file> into install_dir via
# inference_download_and_uncompress(), unless install_dir/<data_file> already
# exists.
function(download_qat_model install_dir data_file)
    # Quoted so that a path containing spaces or ';' is checked as one path.
    if (NOT EXISTS "${install_dir}/${data_file}")
        inference_download_and_uncompress(${install_dir} ${INFERENCE_URL}/int8/QAT_models ${data_file})
    endif()
endfunction()

# Fetch ${INFERENCE_URL}/int8/QAT_models/fp32/<data_file> into install_dir via
# inference_download_and_uncompress(), unless install_dir/<data_file> already
# exists.
function(download_qat_fp32_model install_dir data_file)
    # Quoted so that a path containing spaces or ';' is checked as one path.
    if (NOT EXISTS "${install_dir}/${data_file}")
        inference_download_and_uncompress(${install_dir} ${INFERENCE_URL}/int8/QAT_models/fp32 ${data_file})
    endif()
endfunction()

# Register a QATv1 FP32 & INT8 comparison test for an image-classification
# model, run with MKL-DNN enabled (FLAGS_use_mkldnn=true).
#
#   target        - name of the test to register
#   qat_model_dir - QAT model directory, passed as --qat_model
#   dataset_path  - inference dataset file, passed as --infer_data
function(inference_qat_int8_image_classification_test target qat_model_dir dataset_path)
    py_test(${target} SRCS "${CMAKE_CURRENT_SOURCE_DIR}/qat_int8_image_classification_comparison.py"
            ENVS FLAGS_OMP_NUM_THREADS=${CPU_NUM_THREADS_ON_CI}
                 OMP_NUM_THREADS=${CPU_NUM_THREADS_ON_CI}
                 FLAGS_use_mkldnn=true
            ARGS --qat_model ${qat_model_dir}
                 --infer_data ${dataset_path}
                 --batch_size 25
                 --batch_num 2
                 --acc_diff_threshold 0.1)
endfunction()

# Register a QATv2 FP32 & INT8 comparison test for an image-classification
# model, run with MKL-DNN enabled (FLAGS_use_mkldnn=true).
#
#   target         - name of the test to register
#   qat_model_dir  - QAT model directory, passed as --qat_model
#   fp32_model_dir - FP32 model directory, passed as --fp32_model
#   dataset_path   - inference dataset file, passed as --infer_data
#   quantized_ops  - comma-separated operator list, passed as --quantized_ops
#
# batch_size is set to 10 for the unit test only (to avoid OOM). For the whole
# dataset, use batch_size 25.
function(inference_qat2_int8_image_classification_test target qat_model_dir fp32_model_dir dataset_path quantized_ops)
    py_test(${target} SRCS "${CMAKE_CURRENT_SOURCE_DIR}/qat2_int8_image_classification_comparison.py"
            ENVS FLAGS_OMP_NUM_THREADS=${CPU_NUM_THREADS_ON_CI}
                 OMP_NUM_THREADS=${CPU_NUM_THREADS_ON_CI}
                 FLAGS_use_mkldnn=true
            ARGS --qat_model ${qat_model_dir}
                 --fp32_model ${fp32_model_dir}
                 --infer_data ${dataset_path}
                 --batch_size 10
                 --batch_num 2
                 --acc_diff_threshold 0.1
                 --quantized_ops ${quantized_ops})
endfunction()

# Register a QATv2 FP32 & INT8 comparison test for an NLP model, run with
# MKL-DNN enabled (FLAGS_use_mkldnn=true).
#
#   target         - name of the test to register
#   qat_model_dir  - QAT model directory, passed as --qat_model
#   fp32_model_dir - FP32 model directory, passed as --fp32_model
#   dataset_path   - inference dataset file, passed as --infer_data
#   labels_path    - labels file, passed as --labels
#   quantized_ops  - comma-separated operator list, passed as --quantized_ops
#
# batch_size is set to 10 for the unit test only (to avoid OOM). For the whole
# dataset, use batch_size 20.
function(inference_qat2_int8_nlp_test target qat_model_dir fp32_model_dir dataset_path labels_path quantized_ops)
    py_test(${target} SRCS "${CMAKE_CURRENT_SOURCE_DIR}/qat2_int8_nlp_comparison.py"
            ENVS FLAGS_OMP_NUM_THREADS=${CPU_NUM_THREADS_ON_CI}
                 OMP_NUM_THREADS=${CPU_NUM_THREADS_ON_CI}
                 FLAGS_use_mkldnn=true
            ARGS --qat_model ${qat_model_dir}
                 --fp32_model ${fp32_model_dir}
                 --infer_data ${dataset_path}
                 --labels ${labels_path}
                 --batch_size 10
                 --batch_num 2
                 --acc_diff_threshold 0.1
                 --quantized_ops ${quantized_ops})
endfunction()

# Note: download_qat_data() and download_qat_model() are defined once, earlier
# in this file next to download_qat_fp32_model(); they are not redefined here.

# Register a test that runs save_qat_model.py on a QAT model.
#
#   target               - name of the test to register
#   qat_model_dir        - QAT model directory, passed as --qat_model_path
#   fp32_model_save_path - where the FP32 model is saved (--fp32_model_save_path)
#   int8_model_save_path - where the INT8 model is saved (--int8_model_save_path)
#   quantized_ops        - comma-separated operator list, passed as --quantized_ops
function(save_qat_model_test target qat_model_dir fp32_model_save_path int8_model_save_path quantized_ops)
    py_test(${target} SRCS ${CMAKE_CURRENT_SOURCE_DIR}/save_qat_model.py
            ARGS --qat_model_path ${qat_model_dir}
                 --fp32_model_save_path ${fp32_model_save_path}
                 --int8_model_save_path ${int8_model_save_path}
                 --quantized_ops ${quantized_ops})
endfunction()

# These tests are not registered on Windows.
if(WIN32)
    list(REMOVE_ITEM TEST_OPS test_light_nas)
    list(REMOVE_ITEM TEST_OPS test_post_training_quantization_mobilenetv1)
    list(REMOVE_ITEM TEST_OPS test_post_training_quantization_resnet50)
    list(REMOVE_ITEM TEST_OPS test_weight_quantization_mobilenetv1)
endif()

# Temporarily disabled because this unit test fails randomly.
list(REMOVE_ITEM TEST_OPS test_quantization_scale_pass)

if(LINUX AND WITH_MKLDNN)

    #### Image classification dataset: ImageNet (small)
    # The dataset should already be downloaded for the INT8v2 unit tests.
    set(IMAGENET_DATA_PATH "${INFERENCE_DEMO_INSTALL_DIR}/imagenet/data.bin")

    #### INT8 image classification python api test
    # Models should already be downloaded for the INT8v2 unit tests.

    set(INT8_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/int8v2")
    set(INT8_IC_TEST_FILE "test_mkldnn_int8_quantization_strategy.py")
    set(INT8_IC_TEST_FILE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/${INT8_IC_TEST_FILE}")

    # googlenet int8 (warmup batch size 10)
    set(INT8_GOOGLENET_MODEL_DIR "${INT8_INSTALL_DIR}/googlenet")
    inference_analysis_python_api_int8_test_custom_warmup_batch_size(test_slim_int8_googlenet ${INT8_GOOGLENET_MODEL_DIR} ${IMAGENET_DATA_PATH} ${INT8_IC_TEST_FILE_PATH} 10)

    # mobilenet int8, registered both without and with FLAGS_use_mkldnn
    set(INT8_MOBILENET_MODEL_DIR "${INT8_INSTALL_DIR}/mobilenetv1")
    inference_analysis_python_api_int8_test(test_slim_int8_mobilenet ${INT8_MOBILENET_MODEL_DIR} ${IMAGENET_DATA_PATH} ${INT8_IC_TEST_FILE_PATH})
    inference_analysis_python_api_int8_test_mkldnn(test_slim_int8_mobilenet_mkldnn ${INT8_MOBILENET_MODEL_DIR} ${IMAGENET_DATA_PATH} ${INT8_IC_TEST_FILE_PATH})

    # WITH_SLIM_MKLDNN_FULL_TEST temporarily guards the following UTs so that QA
    # can run them locally; they take too much time to run on CI.
    if (WITH_SLIM_MKLDNN_FULL_TEST)
        # resnet50 int8
        set(INT8_RESNET50_MODEL_DIR "${INT8_INSTALL_DIR}/resnet50")
        inference_analysis_python_api_int8_test(test_slim_int8_resnet50 ${INT8_RESNET50_MODEL_DIR} ${IMAGENET_DATA_PATH} ${INT8_IC_TEST_FILE_PATH})

        # mobilenetv2 int8
        set(INT8_MOBILENETV2_MODEL_DIR "${INT8_INSTALL_DIR}/mobilenetv2")
        inference_analysis_python_api_int8_test(test_slim_int8_mobilenetv2 ${INT8_MOBILENETV2_MODEL_DIR} ${IMAGENET_DATA_PATH} ${INT8_IC_TEST_FILE_PATH})

        # resnet101 int8
        set(INT8_RESNET101_MODEL_DIR "${INT8_INSTALL_DIR}/resnet101")
        inference_analysis_python_api_int8_test(test_slim_int8_resnet101 ${INT8_RESNET101_MODEL_DIR} ${IMAGENET_DATA_PATH} ${INT8_IC_TEST_FILE_PATH})

        # vgg16 int8
        set(INT8_VGG16_MODEL_DIR "${INT8_INSTALL_DIR}/vgg16")
        inference_analysis_python_api_int8_test(test_slim_int8_vgg16 ${INT8_VGG16_MODEL_DIR} ${IMAGENET_DATA_PATH} ${INT8_IC_TEST_FILE_PATH})

        # vgg19 int8
        set(INT8_VGG19_MODEL_DIR "${INT8_INSTALL_DIR}/vgg19")
        inference_analysis_python_api_int8_test(test_slim_int8_vgg19 ${INT8_VGG19_MODEL_DIR} ${IMAGENET_DATA_PATH} ${INT8_IC_TEST_FILE_PATH})
    endif()

    #### QAT FP32 & INT8 comparison python api tests

    set(QAT_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/qat")

    ### QATv1 for image classification

    # QAT ResNet50
    set(QAT_RESNET50_MODEL_DIR "${QAT_INSTALL_DIR}/ResNet50_QAT")
    set(QAT_RESNET50_MODEL_ARCHIVE "ResNet50_qat_model.tar.gz")
    download_qat_model(${QAT_RESNET50_MODEL_DIR} ${QAT_RESNET50_MODEL_ARCHIVE})
    inference_qat_int8_image_classification_test(test_qat_int8_resnet50_mkldnn ${QAT_RESNET50_MODEL_DIR}/model ${IMAGENET_DATA_PATH})

    # QAT ResNet101
    # NOTE(review): the test below is disabled, but download_qat_model() is
    # still called for its model archive.
    set(QAT_RESNET101_MODEL_DIR "${QAT_INSTALL_DIR}/ResNet101_QAT")
    set(QAT_RESNET101_MODEL_ARCHIVE "ResNet101_qat_model.tar.gz")
    download_qat_model(${QAT_RESNET101_MODEL_DIR} ${QAT_RESNET101_MODEL_ARCHIVE})
    # inference_qat_int8_image_classification_test(test_qat_int8_resnet101_mkldnn ${QAT_RESNET101_MODEL_DIR}/model ${IMAGENET_DATA_PATH})

    # QAT GoogleNet
    set(QAT_GOOGLENET_MODEL_DIR "${QAT_INSTALL_DIR}/GoogleNet_QAT")
    set(QAT_GOOGLENET_MODEL_ARCHIVE "GoogleNet_qat_model.tar.gz")
    download_qat_model(${QAT_GOOGLENET_MODEL_DIR} ${QAT_GOOGLENET_MODEL_ARCHIVE})
    inference_qat_int8_image_classification_test(test_qat_int8_googlenet_mkldnn ${QAT_GOOGLENET_MODEL_DIR}/model ${IMAGENET_DATA_PATH})

    # QAT MobileNetV1
    set(QAT_MOBILENETV1_MODEL_DIR "${QAT_INSTALL_DIR}/MobileNetV1_QAT")
    set(QAT_MOBILENETV1_MODEL_ARCHIVE "MobileNetV1_qat_model.tar.gz")
    download_qat_model(${QAT_MOBILENETV1_MODEL_DIR} ${QAT_MOBILENETV1_MODEL_ARCHIVE})
    inference_qat_int8_image_classification_test(test_qat_int8_mobilenetv1_mkldnn ${QAT_MOBILENETV1_MODEL_DIR}/model ${IMAGENET_DATA_PATH})

    # QAT MobileNetV2
    set(QAT_MOBILENETV2_MODEL_DIR "${QAT_INSTALL_DIR}/MobileNetV2_QAT")
    set(QAT_MOBILENETV2_MODEL_ARCHIVE "MobileNetV2_qat_model.tar.gz")
    download_qat_model(${QAT_MOBILENETV2_MODEL_DIR} ${QAT_MOBILENETV2_MODEL_ARCHIVE})
    inference_qat_int8_image_classification_test(test_qat_int8_mobilenetv2_mkldnn ${QAT_MOBILENETV2_MODEL_DIR}/model ${IMAGENET_DATA_PATH})

    # QAT VGG16
    # NOTE(review): the test below is disabled, but download_qat_model() is
    # still called for its model archive.
    set(QAT_VGG16_MODEL_DIR "${QAT_INSTALL_DIR}/VGG16_QAT")
    set(QAT_VGG16_MODEL_ARCHIVE "VGG16_qat_model.tar.gz")
    download_qat_model(${QAT_VGG16_MODEL_DIR} ${QAT_VGG16_MODEL_ARCHIVE})
    # inference_qat_int8_image_classification_test(test_qat_int8_vgg16_mkldnn ${QAT_VGG16_MODEL_DIR}/model ${IMAGENET_DATA_PATH})

    # QAT VGG19
    # NOTE(review): the test below is disabled, but download_qat_model() is
    # still called for its model archive.
    set(QAT_VGG19_MODEL_DIR "${QAT_INSTALL_DIR}/VGG19_QAT")
    set(QAT_VGG19_MODEL_ARCHIVE "VGG19_qat_model.tar.gz")
    download_qat_model(${QAT_VGG19_MODEL_DIR} ${QAT_VGG19_MODEL_ARCHIVE})
    # inference_qat_int8_image_classification_test(test_qat_int8_vgg19_mkldnn ${QAT_VGG19_MODEL_DIR}/model ${IMAGENET_DATA_PATH})

    ### QATv2 for image classification

    set(QAT2_IC_QUANTIZED_OPS "conv2d,pool2d")

    # QAT2 ResNet50 (the FP32 reference model comes from the INT8v2 install dir)
    set(QAT2_RESNET50_MODEL_DIR "${QAT_INSTALL_DIR}/ResNet50_qat_perf")
    set(FP32_RESNET50_MODEL_DIR "${INT8_INSTALL_DIR}/resnet50")
    set(QAT2_RESNET50_MODEL_ARCHIVE "ResNet50_qat_perf.tar.gz")
    download_qat_model(${QAT2_RESNET50_MODEL_DIR} ${QAT2_RESNET50_MODEL_ARCHIVE})
    inference_qat2_int8_image_classification_test(test_qat2_int8_resnet50_mkldnn ${QAT2_RESNET50_MODEL_DIR}/ResNet50_qat_perf/float ${FP32_RESNET50_MODEL_DIR}/model ${IMAGENET_DATA_PATH} ${QAT2_IC_QUANTIZED_OPS})

    # QAT2 MobileNetV1 (the FP32 reference model comes from the INT8v2 install dir)
    set(QAT2_MOBILENETV1_MODEL_DIR "${QAT_INSTALL_DIR}/MobileNet_qat_perf")
    set(FP32_MOBILENETV1_MODEL_DIR "${INT8_INSTALL_DIR}/mobilenetv1")
    set(QAT2_MOBILENETV1_MODEL_ARCHIVE "MobileNet_qat_perf.tar.gz")
    download_qat_model(${QAT2_MOBILENETV1_MODEL_DIR} ${QAT2_MOBILENETV1_MODEL_ARCHIVE})
    inference_qat2_int8_image_classification_test(test_qat2_int8_mobilenetv1_mkldnn ${QAT2_MOBILENETV1_MODEL_DIR}/MobileNet_qat_perf/float ${FP32_MOBILENETV1_MODEL_DIR}/model ${IMAGENET_DATA_PATH} ${QAT2_IC_QUANTIZED_OPS})

    ### QATv2 for NLP

    set(QAT2_NLP_QUANTIZED_OPS "fc,reshape2,transpose2")

    set(NLP_DATA_ARCHIVE "Ernie_dataset.tar.gz")
    set(NLP_DATA_DIR "${INFERENCE_DEMO_INSTALL_DIR}/Ernie_dataset")
    set(NLP_DATA_PATH "${NLP_DATA_DIR}/Ernie_dataset/1.8w.bs1")
    set(NLP_LABELS_PATH "${NLP_DATA_DIR}/Ernie_dataset/label.xnli.dev")
    download_qat_data(${NLP_DATA_DIR} ${NLP_DATA_ARCHIVE})

    # QAT2 Ernie
    set(QAT2_ERNIE_MODEL_ARCHIVE "ernie_qat.tar.gz")
    set(QAT2_ERNIE_MODEL_DIR "${QAT_INSTALL_DIR}/Ernie_qat")
    download_qat_model(${QAT2_ERNIE_MODEL_DIR} ${QAT2_ERNIE_MODEL_ARCHIVE})
    set(FP32_ERNIE_MODEL_ARCHIVE "ernie_fp32_model.tar.gz")
    set(FP32_ERNIE_MODEL_DIR "${QAT_INSTALL_DIR}/Ernie_float")
    download_qat_fp32_model(${FP32_ERNIE_MODEL_DIR} ${FP32_ERNIE_MODEL_ARCHIVE})
    inference_qat2_int8_nlp_test(test_qat2_int8_ernie_mkldnn ${QAT2_ERNIE_MODEL_DIR}/Ernie_qat/float ${FP32_ERNIE_MODEL_DIR}/ernie_fp32_model ${NLP_DATA_PATH} ${NLP_LABELS_PATH} ${QAT2_NLP_QUANTIZED_OPS})

    ### Save QAT2 FP32 and QAT2 INT8 models

    set(QAT2_INT8_RESNET50_SAVE_PATH "${QAT_INSTALL_DIR}/ResNet50_qat2_int8")
    set(QAT2_FP32_RESNET50_SAVE_PATH "${QAT_INSTALL_DIR}/ResNet50_qat2_fp32")
    save_qat_model_test(save_qat2_model_resnet50 ${QAT2_RESNET50_MODEL_DIR}/ResNet50_qat_perf/float ${QAT2_FP32_RESNET50_SAVE_PATH} ${QAT2_INT8_RESNET50_SAVE_PATH} ${QAT2_IC_QUANTIZED_OPS})

    set(QAT2_INT8_ERNIE_SAVE_PATH "${QAT_INSTALL_DIR}/Ernie_qat2_int8")
    set(QAT2_FP32_ERNIE_SAVE_PATH "${QAT_INSTALL_DIR}/Ernie_qat2_fp32")
    save_qat_model_test(save_qat2_model_ernie ${QAT2_ERNIE_MODEL_DIR}/Ernie_qat/float ${QAT2_FP32_ERNIE_SAVE_PATH} ${QAT2_INT8_ERNIE_SAVE_PATH} ${QAT2_NLP_QUANTIZED_OPS})

endif()

# The QAT FP32 & INT8 comparison tests are supported only on Linux with
# MKL-DNN. Where supported, they are registered explicitly (with their own
# arguments) in the LINUX AND WITH_MKLDNN section above. Remove them here so the
# generic py_test() loop below neither runs them on other systems nor registers
# them a second time.
# NOTE(review): TEST_OPS is globbed from "test_*.py", so only
# test_mkldnn_int8_quantization_strategy can actually be present. The qat_*
# entries are no-ops, and qat_int8_nlp_comparison does not match the
# qat2_int8_nlp_comparison.py script registered above.
list(REMOVE_ITEM TEST_OPS
    test_mkldnn_int8_quantization_strategy
    qat_int8_image_classification_comparison
    qat_int8_nlp_comparison)
# Register every remaining test_*.py script as a plain Python unit test.
foreach(src ${TEST_OPS})
    py_test(${src} SRCS ${src}.py)
endforeach()