CMakeLists.txt 13.1 KB
Newer Older
W
WangZhen 已提交
1 2 3
# Collect every test_*.py script of this directory. TEST_OPS holds the file
# names relative to this directory.
file(GLOB TEST_OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "test_*.py")
# Drop the ".py" extension of each list element. The pattern is anchored to the
# end of an element (next ';' or end of the string) so that a ".py" occurring in
# the middle of a file name is left untouched.
string(REGEX REPLACE "\\.py(;|$)" "\\1" TEST_OPS "${TEST_OPS}")

4
# Registers a python INT8 test through py_test.
#   target     - test name; also used as sub-directory of int8_models/ for
#                --int8_model_save_path
#   model_dir  - model directory; "${model_dir}/model" is passed as --infer_model
#   data_path  - passed to the script as --infer_data
#   filename   - python script handed to py_test as SRCS
#   use_mkldnn - value exported to the test as FLAGS_use_mkldnn
# WARMUP_BATCH_SIZE is not a parameter: it is read from the calling scope.
function(_inference_analysis_python_api_int8_test target model_dir data_path filename use_mkldnn)
    # Only emit --warmup_batch_size when a value is available. An empty
    # expansion would otherwise leave the flag without its value on the
    # command line.
    set(warmup_args "")
    if(NOT "${WARMUP_BATCH_SIZE}" STREQUAL "")
        set(warmup_args --warmup_batch_size ${WARMUP_BATCH_SIZE})
    endif()

    py_test(${target} SRCS ${filename}
        ENVS CPU_NUM_THREADS=${CPU_NUM_THREADS_ON_CI}
             FLAGS_use_mkldnn=${use_mkldnn}
        ARGS --infer_model ${model_dir}/model
             --infer_data ${data_path}
             --int8_model_save_path int8_models/${target}
             ${warmup_args}
             --batch_size 50)
endfunction()

15 16
# Registers the INT8 python test with FLAGS_use_mkldnn=False.
# All arguments are forwarded unchanged to
# _inference_analysis_python_api_int8_test.
function(inference_analysis_python_api_int8_test target model_dir data_path filename)
    _inference_analysis_python_api_int8_test(${target} ${model_dir} ${data_path} ${filename} False)
endfunction()

19 20 21 22 23
# Variant of inference_analysis_python_api_int8_test that takes an explicit
# warm-up batch size as its last argument.
# WARMUP_BATCH_SIZE is set in this function's scope before delegating, so the
# functions called from here see the overridden value; the caller's own
# WARMUP_BATCH_SIZE is not modified.
function(inference_analysis_python_api_int8_test_custom_warmup_batch_size target model_dir data_dir filename warmup_batch_size)
    set(WARMUP_BATCH_SIZE ${warmup_batch_size})
    inference_analysis_python_api_int8_test(
        ${target}
        ${model_dir}
        ${data_dir}
        ${filename})
endfunction()

24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39
# Registers the INT8 python test with FLAGS_use_mkldnn=True.
# The four arguments are forwarded unchanged to
# _inference_analysis_python_api_int8_test.
function(inference_analysis_python_api_int8_test_mkldnn target model_dir data_path filename)
    _inference_analysis_python_api_int8_test(
        ${target}
        ${model_dir}
        ${data_path}
        ${filename}
        True)
endfunction()

# Fetches ${data_file} from ${INFERENCE_URL}/int8 into ${install_dir} by calling
# inference_download_and_uncompress (defined outside this file), unless
# ${install_dir}/${data_file} already exists.
function(download_qat_data install_dir data_file)
    # Quoted so the existence check always receives exactly one path argument.
    if(NOT EXISTS "${install_dir}/${data_file}")
        inference_download_and_uncompress(${install_dir} ${INFERENCE_URL}/int8 ${data_file})
    endif()
endfunction()

# Fetches ${data_file} from ${INFERENCE_URL}/int8/QAT_models into ${install_dir}
# by calling inference_download_and_uncompress (defined outside this file),
# unless ${install_dir}/${data_file} already exists.
function(download_qat_model install_dir data_file)
    # Quoted so the existence check always receives exactly one path argument.
    if(NOT EXISTS "${install_dir}/${data_file}")
        inference_download_and_uncompress(${install_dir} ${INFERENCE_URL}/int8/QAT_models ${data_file})
    endif()
endfunction()

40 41 42 43 44 45 46
# Fetches ${data_file} from ${INFERENCE_URL}/int8/QAT_models/fp32 into
# ${install_dir} by calling inference_download_and_uncompress (defined outside
# this file), unless ${install_dir}/${data_file} already exists.
function(download_qat_fp32_model install_dir data_file)
    # Quoted so the existence check always receives exactly one path argument.
    if(NOT EXISTS "${install_dir}/${data_file}")
        inference_download_and_uncompress(${install_dir} ${INFERENCE_URL}/int8/QAT_models/fp32 ${data_file})
    endif()
endfunction()

# Registers a py_test that runs qat_int8_image_classification_comparison.py
# with FLAGS_use_mkldnn=true.
#   target        - test name
#   qat_model_dir - passed to the script as --qat_model
#   dataset_path  - passed to the script as --infer_data
# The script is always invoked with --batch_size 25, --batch_num 2 and
# --acc_diff_threshold 0.1.
function(inference_qat_int8_image_classification_test target qat_model_dir dataset_path)
    py_test(${target} SRCS "${CMAKE_CURRENT_SOURCE_DIR}/qat_int8_image_classification_comparison.py"
            ENVS FLAGS_OMP_NUM_THREADS=${CPU_NUM_THREADS_ON_CI}
                 OMP_NUM_THREADS=${CPU_NUM_THREADS_ON_CI}
                 FLAGS_use_mkldnn=true
            ARGS --qat_model ${qat_model_dir}
                 --infer_data ${dataset_path}
                 --batch_size 25
                 --batch_num 2
                 --acc_diff_threshold 0.1)
endfunction()

58 59

# Registers a py_test that runs qat2_int8_image_classification_comparison.py
# with FLAGS_use_mkldnn=true.
#   target         - test name
#   qat_model_dir  - passed to the script as --qat_model
#   fp32_model_dir - passed to the script as --fp32_model
#   dataset_path   - passed to the script as --infer_data
#   quantized_ops  - passed to the script as --quantized_ops
# batch_size is set to 10 for the unit test only (avoids OOM). For the whole
# dataset, use batch_size 25.
function(inference_qat2_int8_image_classification_test target qat_model_dir fp32_model_dir dataset_path quantized_ops)
    py_test(${target} SRCS "${CMAKE_CURRENT_SOURCE_DIR}/qat2_int8_image_classification_comparison.py"
            ENVS FLAGS_OMP_NUM_THREADS=${CPU_NUM_THREADS_ON_CI}
                 OMP_NUM_THREADS=${CPU_NUM_THREADS_ON_CI}
                 FLAGS_use_mkldnn=true
            ARGS --qat_model ${qat_model_dir}
                 --fp32_model ${fp32_model_dir}
                 --infer_data ${dataset_path}
                 --batch_size 10
                 --batch_num 2
                 --acc_diff_threshold 0.1
                 --quantized_ops ${quantized_ops})
endfunction()

# Registers a py_test that runs qat2_int8_nlp_comparison.py with
# FLAGS_use_mkldnn=true.
#   target         - test name
#   qat_model_dir  - passed to the script as --qat_model
#   fp32_model_dir - passed to the script as --fp32_model
#   dataset_path   - passed to the script as --infer_data
#   labels_path    - passed to the script as --labels
#   quantized_ops  - passed to the script as --quantized_ops
# batch_size is set to 10 for the unit test only (avoids OOM). For the whole
# dataset, use batch_size 20.
function(inference_qat2_int8_nlp_test target qat_model_dir fp32_model_dir dataset_path labels_path quantized_ops)
    py_test(${target} SRCS "${CMAKE_CURRENT_SOURCE_DIR}/qat2_int8_nlp_comparison.py"
            ENVS FLAGS_OMP_NUM_THREADS=${CPU_NUM_THREADS_ON_CI}
                 OMP_NUM_THREADS=${CPU_NUM_THREADS_ON_CI}
                 FLAGS_use_mkldnn=true
            ARGS --qat_model ${qat_model_dir}
                 --fp32_model ${fp32_model_dir}
                 --infer_data ${dataset_path}
                 --labels ${labels_path}
                 --batch_size 10
                 --batch_num 2
                 --acc_diff_threshold 0.1
                 --quantized_ops ${quantized_ops})
endfunction()

# Fetches ${data_file} from ${INFERENCE_URL}/int8 into ${install_dir} by calling
# inference_download_and_uncompress (defined outside this file), unless
# ${install_dir}/${data_file} already exists.
# NOTE(review): a function with this name and the same behaviour is already
# defined earlier in this file; this later definition is the one in effect.
# One of the two copies can be dropped.
function(download_qat_data install_dir data_file)
    # Quoted so the existence check always receives exactly one path argument.
    if(NOT EXISTS "${install_dir}/${data_file}")
        inference_download_and_uncompress(${install_dir} ${INFERENCE_URL}/int8 ${data_file})
    endif()
endfunction()

# Fetches ${data_file} from ${INFERENCE_URL}/int8/QAT_models into ${install_dir}
# by calling inference_download_and_uncompress (defined outside this file),
# unless ${install_dir}/${data_file} already exists.
# NOTE(review): a function with this name and the same behaviour is already
# defined earlier in this file; this later definition is the one in effect.
# One of the two copies can be dropped.
function(download_qat_model install_dir data_file)
    # Quoted so the existence check always receives exactly one path argument.
    if(NOT EXISTS "${install_dir}/${data_file}")
        inference_download_and_uncompress(${install_dir} ${INFERENCE_URL}/int8/QAT_models ${data_file})
    endif()
endfunction()

102 103
# Registers a py_test that runs save_qat_model.py.
#   target               - test name
#   qat_model_dir        - passed to the script as --qat_model_path
#   fp32_model_save_path - passed to the script as --fp32_model_save_path
#   int8_model_save_path - passed to the script as --int8_model_save_path
#   quantized_ops        - passed to the script as --quantized_ops
function(save_qat_model_test target qat_model_dir fp32_model_save_path int8_model_save_path quantized_ops)
    py_test(${target} SRCS "${CMAKE_CURRENT_SOURCE_DIR}/save_qat_model.py"
            ARGS --qat_model_path ${qat_model_dir}
                 --fp32_model_save_path ${fp32_model_save_path}
                 --int8_model_save_path ${int8_model_save_path}
                 --quantized_ops ${quantized_ops})
endfunction()
109

110 111
# Temporarily disable this unit test: removing it from TEST_OPS keeps it from
# being registered with the other test_*.py scripts.
list(REMOVE_ITEM TEST_OPS test_post_training_quantization_mobilenetv1)
W
whs 已提交
112 113
# Tests that are not registered on Windows.
if(WIN32)
    list(REMOVE_ITEM TEST_OPS
         test_light_nas
         test_post_training_quantization_resnet50
         test_weight_quantization_mobilenetv1)
endif()

118 119
# INT8 and QAT tests that are only registered on Linux builds with MKL-DNN.
if(LINUX AND WITH_MKLDNN)

    #### Image classification dataset: ImageNet (small)
    # The dataset should already be downloaded for INT8v2 unit tests
    set(IMAGENET_DATA_PATH "${INFERENCE_DEMO_INSTALL_DIR}/imagenet/data.bin")

    #### INT8 image classification python api test
    # Models should be already downloaded for INT8v2 unit tests

    set(INT8_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/int8v2")
    set(INT8_IC_TEST_FILE "test_mkldnn_int8_quantization_strategy.py")
    set(INT8_IC_TEST_FILE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/${INT8_IC_TEST_FILE}")

    # googlenet int8 (warm-up batch size overridden to 10)
    set(INT8_GOOGLENET_MODEL_DIR "${INT8_INSTALL_DIR}/googlenet")
    inference_analysis_python_api_int8_test_custom_warmup_batch_size(test_slim_int8_googlenet ${INT8_GOOGLENET_MODEL_DIR} ${IMAGENET_DATA_PATH} ${INT8_IC_TEST_FILE_PATH} 10)

    # mobilenet int8
    set(INT8_MOBILENET_MODEL_DIR "${INT8_INSTALL_DIR}/mobilenetv1")
    inference_analysis_python_api_int8_test(test_slim_int8_mobilenet ${INT8_MOBILENET_MODEL_DIR} ${IMAGENET_DATA_PATH} ${INT8_IC_TEST_FILE_PATH})
    inference_analysis_python_api_int8_test_mkldnn(test_slim_int8_mobilenet_mkldnn ${INT8_MOBILENET_MODEL_DIR} ${IMAGENET_DATA_PATH} ${INT8_IC_TEST_FILE_PATH})

    # temporarily adding WITH_SLIM_MKLDNN_FULL_TEST FLAG for QA testing the following UTs locally,
    # since the following UTs cost too much time on CI test.
    if(WITH_SLIM_MKLDNN_FULL_TEST)
        # resnet50 int8
        set(INT8_RESNET50_MODEL_DIR "${INT8_INSTALL_DIR}/resnet50")
        inference_analysis_python_api_int8_test(test_slim_int8_resnet50 ${INT8_RESNET50_MODEL_DIR} ${IMAGENET_DATA_PATH} ${INT8_IC_TEST_FILE_PATH})

        # mobilenetv2 int8
        set(INT8_MOBILENETV2_MODEL_DIR "${INT8_INSTALL_DIR}/mobilenetv2")
        inference_analysis_python_api_int8_test(test_slim_int8_mobilenetv2 ${INT8_MOBILENETV2_MODEL_DIR} ${IMAGENET_DATA_PATH} ${INT8_IC_TEST_FILE_PATH})

        # resnet101 int8
        set(INT8_RESNET101_MODEL_DIR "${INT8_INSTALL_DIR}/resnet101")
        inference_analysis_python_api_int8_test(test_slim_int8_resnet101 ${INT8_RESNET101_MODEL_DIR} ${IMAGENET_DATA_PATH} ${INT8_IC_TEST_FILE_PATH})

        # vgg16 int8
        set(INT8_VGG16_MODEL_DIR "${INT8_INSTALL_DIR}/vgg16")
        inference_analysis_python_api_int8_test(test_slim_int8_vgg16 ${INT8_VGG16_MODEL_DIR} ${IMAGENET_DATA_PATH} ${INT8_IC_TEST_FILE_PATH})

        # vgg19 int8
        set(INT8_VGG19_MODEL_DIR "${INT8_INSTALL_DIR}/vgg19")
        inference_analysis_python_api_int8_test(test_slim_int8_vgg19 ${INT8_VGG19_MODEL_DIR} ${IMAGENET_DATA_PATH} ${INT8_IC_TEST_FILE_PATH})
    endif()

    #### QAT FP32 & INT8 comparison python api tests

    set(QAT_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/qat")

    ### QATv1 for image classification

    # QAT ResNet50
    set(QAT_RESNET50_MODEL_DIR "${QAT_INSTALL_DIR}/ResNet50_QAT")
    set(QAT_RESNET50_MODEL_ARCHIVE "ResNet50_qat_model.tar.gz")
    download_qat_model(${QAT_RESNET50_MODEL_DIR} ${QAT_RESNET50_MODEL_ARCHIVE})
    inference_qat_int8_image_classification_test(test_qat_int8_resnet50_mkldnn ${QAT_RESNET50_MODEL_DIR}/model ${IMAGENET_DATA_PATH})

    # QAT ResNet101
    # NOTE: the test registration is commented out, but the model archive is
    # still requested through download_qat_model.
    set(QAT_RESNET101_MODEL_DIR "${QAT_INSTALL_DIR}/ResNet101_QAT")
    set(QAT_RESNET101_MODEL_ARCHIVE "ResNet101_qat_model.tar.gz")
    download_qat_model(${QAT_RESNET101_MODEL_DIR} ${QAT_RESNET101_MODEL_ARCHIVE})
    # inference_qat_int8_image_classification_test(test_qat_int8_resnet101_mkldnn ${QAT_RESNET101_MODEL_DIR}/model ${IMAGENET_DATA_PATH})

    # QAT GoogleNet
    set(QAT_GOOGLENET_MODEL_DIR "${QAT_INSTALL_DIR}/GoogleNet_QAT")
    set(QAT_GOOGLENET_MODEL_ARCHIVE "GoogleNet_qat_model.tar.gz")
    download_qat_model(${QAT_GOOGLENET_MODEL_DIR} ${QAT_GOOGLENET_MODEL_ARCHIVE})
    inference_qat_int8_image_classification_test(test_qat_int8_googlenet_mkldnn ${QAT_GOOGLENET_MODEL_DIR}/model ${IMAGENET_DATA_PATH})

    # QAT MobileNetV1
    set(QAT_MOBILENETV1_MODEL_DIR "${QAT_INSTALL_DIR}/MobileNetV1_QAT")
    set(QAT_MOBILENETV1_MODEL_ARCHIVE "MobileNetV1_qat_model.tar.gz")
    download_qat_model(${QAT_MOBILENETV1_MODEL_DIR} ${QAT_MOBILENETV1_MODEL_ARCHIVE})
    inference_qat_int8_image_classification_test(test_qat_int8_mobilenetv1_mkldnn ${QAT_MOBILENETV1_MODEL_DIR}/model ${IMAGENET_DATA_PATH})

    # QAT MobileNetV2
    set(QAT_MOBILENETV2_MODEL_DIR "${QAT_INSTALL_DIR}/MobileNetV2_QAT")
    set(QAT_MOBILENETV2_MODEL_ARCHIVE "MobileNetV2_qat_model.tar.gz")
    download_qat_model(${QAT_MOBILENETV2_MODEL_DIR} ${QAT_MOBILENETV2_MODEL_ARCHIVE})
    inference_qat_int8_image_classification_test(test_qat_int8_mobilenetv2_mkldnn ${QAT_MOBILENETV2_MODEL_DIR}/model ${IMAGENET_DATA_PATH})

    # QAT VGG16
    # NOTE: the test registration is commented out, but the model archive is
    # still requested through download_qat_model.
    set(QAT_VGG16_MODEL_DIR "${QAT_INSTALL_DIR}/VGG16_QAT")
    set(QAT_VGG16_MODEL_ARCHIVE "VGG16_qat_model.tar.gz")
    download_qat_model(${QAT_VGG16_MODEL_DIR} ${QAT_VGG16_MODEL_ARCHIVE})
    # inference_qat_int8_image_classification_test(test_qat_int8_vgg16_mkldnn ${QAT_VGG16_MODEL_DIR}/model ${IMAGENET_DATA_PATH})

    # QAT VGG19
    # NOTE: the test registration is commented out, but the model archive is
    # still requested through download_qat_model.
    set(QAT_VGG19_MODEL_DIR "${QAT_INSTALL_DIR}/VGG19_QAT")
    set(QAT_VGG19_MODEL_ARCHIVE "VGG19_qat_model.tar.gz")
    download_qat_model(${QAT_VGG19_MODEL_DIR} ${QAT_VGG19_MODEL_ARCHIVE})
    # inference_qat_int8_image_classification_test(test_qat_int8_vgg19_mkldnn ${QAT_VGG19_MODEL_DIR}/model ${IMAGENET_DATA_PATH})

    ### QATv2 for image classification

    set(QAT2_IC_QUANTIZED_OPS "conv2d,pool2d")

    # QAT2 ResNet50
    set(QAT2_RESNET50_MODEL_DIR "${QAT_INSTALL_DIR}/ResNet50_qat_perf")
    set(FP32_RESNET50_MODEL_DIR "${INT8_INSTALL_DIR}/resnet50")
    set(QAT2_RESNET50_MODEL_ARCHIVE "ResNet50_qat_perf.tar.gz")
    download_qat_model(${QAT2_RESNET50_MODEL_DIR} ${QAT2_RESNET50_MODEL_ARCHIVE})
    inference_qat2_int8_image_classification_test(test_qat2_int8_resnet50_mkldnn ${QAT2_RESNET50_MODEL_DIR}/ResNet50_qat_perf/float ${FP32_RESNET50_MODEL_DIR}/model ${IMAGENET_DATA_PATH} ${QAT2_IC_QUANTIZED_OPS})

    # QAT2 MobileNetV1
    set(QAT2_MOBILENETV1_MODEL_DIR "${QAT_INSTALL_DIR}/MobileNet_qat_perf")
    set(FP32_MOBILENETV1_MODEL_DIR "${INT8_INSTALL_DIR}/mobilenetv1")
    set(QAT2_MOBILENETV1_MODEL_ARCHIVE "MobileNet_qat_perf.tar.gz")
    download_qat_model(${QAT2_MOBILENETV1_MODEL_DIR} ${QAT2_MOBILENETV1_MODEL_ARCHIVE})
    inference_qat2_int8_image_classification_test(test_qat2_int8_mobilenetv1_mkldnn ${QAT2_MOBILENETV1_MODEL_DIR}/MobileNet_qat_perf/float ${FP32_MOBILENETV1_MODEL_DIR}/model ${IMAGENET_DATA_PATH} ${QAT2_IC_QUANTIZED_OPS})

    ### QATv2 for NLP

    set(QAT2_NLP_QUANTIZED_OPS "fc,reshape2,transpose2")

    set(NLP_DATA_ARCHIVE "Ernie_dataset.tar.gz")
    set(NLP_DATA_DIR "${INFERENCE_DEMO_INSTALL_DIR}/Ernie_dataset")
    set(NLP_DATA_PATH "${NLP_DATA_DIR}/Ernie_dataset/1.8w.bs1")
    set(NLP_LABLES_PATH "${NLP_DATA_DIR}/Ernie_dataset/label.xnli.dev")
    download_qat_data(${NLP_DATA_DIR} ${NLP_DATA_ARCHIVE})

    # QAT2 Ernie
    set(QAT2_ERNIE_MODEL_ARCHIVE "ernie_qat.tar.gz")
    set(QAT2_ERNIE_MODEL_DIR "${QAT_INSTALL_DIR}/Ernie_qat")
    download_qat_model(${QAT2_ERNIE_MODEL_DIR} ${QAT2_ERNIE_MODEL_ARCHIVE})
    set(FP32_ERNIE_MODEL_ARCHIVE "ernie_fp32_model.tar.gz")
    set(FP32_ERNIE_MODEL_DIR "${QAT_INSTALL_DIR}/Ernie_float")
    download_qat_fp32_model(${FP32_ERNIE_MODEL_DIR} ${FP32_ERNIE_MODEL_ARCHIVE})
    inference_qat2_int8_nlp_test(test_qat2_int8_ernie_mkldnn ${QAT2_ERNIE_MODEL_DIR}/Ernie_qat/float ${FP32_ERNIE_MODEL_DIR}/ernie_fp32_model ${NLP_DATA_PATH} ${NLP_LABLES_PATH} ${QAT2_NLP_QUANTIZED_OPS})

    ### Save QAT2 FP32 model or QAT2 INT8 model

    set(QAT2_INT8_RESNET50_SAVE_PATH "${QAT_INSTALL_DIR}/ResNet50_qat2_int8")
    set(QAT2_FP32_RESNET50_SAVE_PATH "${QAT_INSTALL_DIR}/ResNet50_qat2_fp32")
    save_qat_model_test(save_qat2_model_resnet50 ${QAT2_RESNET50_MODEL_DIR}/ResNet50_qat_perf/float ${QAT2_FP32_RESNET50_SAVE_PATH} ${QAT2_INT8_RESNET50_SAVE_PATH} ${QAT2_IC_QUANTIZED_OPS})

    set(QAT2_INT8_ERNIE_SAVE_PATH "${QAT_INSTALL_DIR}/Ernie_qat2_int8")
    set(QAT2_FP32_ERNIE_SAVE_PATH "${QAT_INSTALL_DIR}/Ernie_qat2_fp32")
    save_qat_model_test(save_qat2_model_ernie ${QAT2_ERNIE_MODEL_DIR}/Ernie_qat/float ${QAT2_FP32_ERNIE_SAVE_PATH} ${QAT2_INT8_ERNIE_SAVE_PATH} ${QAT2_NLP_QUANTIZED_OPS})

endif()

261
# The tests for QAT FP32 & INT8 comparison only support Linux with MKL-DNN and
# are registered explicitly, with their arguments, inside the
# LINUX AND WITH_MKLDNN block. Remove them from TEST_OPS so that they are not
# registered as plain tests on other systems.
list(REMOVE_ITEM TEST_OPS
     test_mkldnn_int8_quantization_strategy
     qat_int8_image_classification_comparison
     qat_int8_nlp_comparison)
267

W
WangZhen 已提交
268 269 270
# Register every test name still left in TEST_OPS as a plain py_test that runs
# the python script of the same name.
foreach(test_name IN LISTS TEST_OPS)
    py_test(${test_name} SRCS ${test_name}.py)
endforeach()