# Collect every test_*.py script in this directory and strip the ".py"
# suffix, so TEST_OPS holds bare test names (e.g. "test_light_nas").
# Scripts that are registered by a dedicated helper, or that are unsupported
# on the current platform, are removed from TEST_OPS further down; whatever
# remains is registered as a plain py_test by the loop at the end of the file.
file(GLOB TEST_OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "test_*.py")
string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}")

# Registers a Python-API INT8 quantization test.
#   target     - name of the test to register
#   model_dir  - "<model_dir>/model" is passed to the script as --infer_model
#   data_dir   - "<data_dir>/data.bin" is passed as --infer_data
#   filename   - Python test script to run
#   use_mkldnn - exported to the test as FLAGS_use_mkldnn (True/False)
# The script also receives "int8_models/<target>" as --int8_model_save_path
# and a fixed --batch_size of 50.
# NOTE(review): WARMUP_BATCH_SIZE is read from the calling scope. In this file
# it is only set by inference_analysis_python_api_int8_test_custom_warmup_batch_size;
# presumably a parent CMakeLists.txt provides a default -- confirm. If it is
# undefined, --warmup_batch_size is passed without a value.
function(_inference_analysis_python_api_int8_test target model_dir data_dir filename use_mkldnn)
    py_test(${target} SRCS ${filename}
        ENVS CPU_NUM_THREADS=${CPU_NUM_THREADS_ON_CI}
             FLAGS_use_mkldnn=${use_mkldnn}
        ARGS --infer_model ${model_dir}/model
             --infer_data ${data_dir}/data.bin
             --int8_model_save_path int8_models/${target}
             --warmup_batch_size ${WARMUP_BATCH_SIZE}
             --batch_size 50)
endfunction()

# Registers a Python-API INT8 test that runs with MKL-DNN disabled
# (FLAGS_use_mkldnn=False). Arguments are forwarded unchanged to
# _inference_analysis_python_api_int8_test.
function(inference_analysis_python_api_int8_test target model_dir data_dir filename)
    _inference_analysis_python_api_int8_test(${target} ${model_dir} ${data_dir} ${filename} False)
endfunction()

# Registers a Python-API INT8 test that runs with MKL-DNN enabled
# (FLAGS_use_mkldnn=True); all other arguments are forwarded unchanged to
# _inference_analysis_python_api_int8_test.
function(inference_analysis_python_api_int8_test_mkldnn target model_dir data_dir filename)
    set(mkldnn_flag True)
    _inference_analysis_python_api_int8_test(
        ${target}
        ${model_dir}
        ${data_dir}
        ${filename}
        ${mkldnn_flag})
endfunction()

# Same as inference_analysis_python_api_int8_test (MKL-DNN disabled), but with
# an explicit warmup batch size.
# CMake functions see the variables of their callers, so setting
# WARMUP_BATCH_SIZE in this function's scope makes it visible to the nested
# _inference_analysis_python_api_int8_test call without leaking it into the
# directory scope.
function(inference_analysis_python_api_int8_test_custom_warmup_batch_size target model_dir data_dir filename warmup_batch_size)
    set(WARMUP_BATCH_SIZE ${warmup_batch_size})
    inference_analysis_python_api_int8_test(${target} ${model_dir} ${data_dir} ${filename})
endfunction()

# Registers a QAT FP32 vs. INT8 comparison test.
#   target      - name of the test to register
#   model_dir   - "<model_dir>/model" is passed to the script as --qat_model
#   data_dir    - "<data_dir>/data.bin" is passed as --infer_data
#   test_script - Python comparison script to run
#   use_mkldnn  - exported to the test as FLAGS_use_mkldnn
# The script runs 2 batches of size 25 with an accuracy-difference threshold
# of 0.1; OpenMP threads are limited to CPU_NUM_THREADS_ON_CI.
function(inference_qat_int8_test target model_dir data_dir test_script use_mkldnn)
    py_test(${target} SRCS ${test_script}
            ENVS FLAGS_OMP_NUM_THREADS=${CPU_NUM_THREADS_ON_CI}
                 OMP_NUM_THREADS=${CPU_NUM_THREADS_ON_CI}
                 FLAGS_use_mkldnn=${use_mkldnn}
            ARGS --qat_model ${model_dir}/model
                 --infer_data ${data_dir}/data.bin
                 --batch_size 25
                 --batch_num 2
                 --acc_diff_threshold 0.1)
endfunction()

# Registers a QAT2 FP32 vs. INT8 comparison test (the script is run with --qat2).
# Arguments match inference_qat_int8_test, except that the model is read
# from "<model_dir>/float".
# batch_size is 10 here only to avoid running out of memory in unit tests;
# use batch_size 25 when evaluating on the whole dataset.
function(inference_qat2_int8_test target model_dir data_dir test_script use_mkldnn)
    py_test(${target} SRCS ${test_script}
            ENVS FLAGS_OMP_NUM_THREADS=${CPU_NUM_THREADS_ON_CI}
                 OMP_NUM_THREADS=${CPU_NUM_THREADS_ON_CI}
                 FLAGS_use_mkldnn=${use_mkldnn}
            ARGS --qat_model ${model_dir}/float
                 --infer_data ${data_dir}/data.bin
                 --batch_size 10
                 --batch_num 2
                 --acc_diff_threshold 0.1
                 --qat2)
endfunction()

# Registers a test that runs test_script on the QAT model in qat_model_dir,
# passing fp32_model_save_path and int8_model_save_path as
# --fp32_model_save_path and --int8_model_save_path.
# Any arguments after test_script are ignored.
function(save_qat_model_test target qat_model_dir fp32_model_save_path int8_model_save_path test_script)
    py_test(${target} SRCS ${test_script}
            ARGS --qat_model_path ${qat_model_dir}
                 --fp32_model_save_path ${fp32_model_save_path}
                 --int8_model_save_path ${int8_model_save_path})
endfunction()

# These tests are not registered on Windows.
if(WIN32)
    list(REMOVE_ITEM TEST_OPS
        test_light_nas
        test_post_training_quantization_mobilenetv1
        test_post_training_quantization_resnet50
        test_weight_quantization_mobilenetv1)
endif()

# INT8 image classification Python API tests (Linux with MKL-DNN only).
if(LINUX AND WITH_MKLDNN)
  set(INT8_DATA_DIR "${INFERENCE_DEMO_INSTALL_DIR}/int8v2")
  set(MKLDNN_INT8_TEST_FILE "test_mkldnn_int8_quantization_strategy.py")
  set(MKLDNN_INT8_TEST_FILE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/${MKLDNN_INT8_TEST_FILE}")

  # googlenet int8, with a warmup batch size of 10
  set(INT8_GOOGLENET_MODEL_DIR "${INT8_DATA_DIR}/googlenet")
  inference_analysis_python_api_int8_test_custom_warmup_batch_size(test_slim_int8_googlenet ${INT8_GOOGLENET_MODEL_DIR} ${INT8_DATA_DIR} ${MKLDNN_INT8_TEST_FILE_PATH} 10)

  # mobilenet int8, once with MKL-DNN disabled and once with it enabled
  set(INT8_MOBILENET_MODEL_DIR "${INT8_DATA_DIR}/mobilenetv1")
  inference_analysis_python_api_int8_test(test_slim_int8_mobilenet ${INT8_MOBILENET_MODEL_DIR} ${INT8_DATA_DIR} ${MKLDNN_INT8_TEST_FILE_PATH})
  inference_analysis_python_api_int8_test_mkldnn(test_slim_int8_mobilenet_mkldnn ${INT8_MOBILENET_MODEL_DIR} ${INT8_DATA_DIR} ${MKLDNN_INT8_TEST_FILE_PATH})

  # WITH_SLIM_MKLDNN_FULL_TEST is a temporary flag that lets QA run the
  # following tests locally; they take too long to run on CI.
  if(WITH_SLIM_MKLDNN_FULL_TEST)
    # Registers test_slim_int8_<model> for the model stored in
    # "${INT8_DATA_DIR}/<model>" and keeps the INT8_<MODEL>_MODEL_DIR
    # variables (e.g. INT8_RESNET50_MODEL_DIR) defined as before.
    foreach(int8_model resnet50 mobilenetv2 resnet101 vgg16 vgg19)
      string(TOUPPER "${int8_model}" int8_model_upper)
      set(INT8_${int8_model_upper}_MODEL_DIR "${INT8_DATA_DIR}/${int8_model}")
      inference_analysis_python_api_int8_test(test_slim_int8_${int8_model} ${INT8_${int8_model_upper}_MODEL_DIR} ${INT8_DATA_DIR} ${MKLDNN_INT8_TEST_FILE_PATH})
    endforeach()
  endif()
endif()

# test_mkldnn_int8_quantization_strategy only supports Linux with MKL-DNN,
# where it is already used as the script of the INT8 tests registered above.
# Remove it from TEST_OPS so it is neither run twice nor run on other systems.
list(REMOVE_ITEM TEST_OPS test_mkldnn_int8_quantization_strategy)

# QAT FP32 & INT8 comparison Python API tests (Linux with MKL-DNN only).
# Datasets and models are downloaded at configure time only when their
# directory does not exist yet.
if(LINUX AND WITH_MKLDNN)
  set(DATASET_DIR "${INFERENCE_DEMO_INSTALL_DIR}/int8v2")
  set(QAT_DATA_DIR "${INFERENCE_DEMO_INSTALL_DIR}/int8v2")
  set(QAT_MODELS_BASE_URL "${INFERENCE_URL}/int8/QAT_models")
  set(MKLDNN_QAT_TEST_FILE "qat_int8_comparison.py")
  set(MKLDNN_QAT_TEST_FILE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/${MKLDNN_QAT_TEST_FILE}")

  # Small ImageNet dataset; it may already have been downloaded for the
  # INT8v2 unit tests.
  if(NOT EXISTS "${DATASET_DIR}")
    inference_download_and_uncompress(${DATASET_DIR} "${INFERENCE_URL}/int8" "imagenet_val_100_tail.tar.gz")
  endif()

  # QAT ResNet50
  set(QAT_RESNET50_MODEL_DIR "${QAT_DATA_DIR}/ResNet50_QAT")
  if(NOT EXISTS "${QAT_RESNET50_MODEL_DIR}")
    inference_download_and_uncompress(${QAT_RESNET50_MODEL_DIR} "${QAT_MODELS_BASE_URL}" "ResNet50_qat_model.tar.gz")
  endif()
  inference_qat_int8_test(test_qat_int8_resnet50_mkldnn ${QAT_RESNET50_MODEL_DIR} ${DATASET_DIR} ${MKLDNN_QAT_TEST_FILE_PATH} true)

  # QAT ResNet101. The model is downloaded, but its test is currently not registered.
  set(QAT_RESNET101_MODEL_DIR "${QAT_DATA_DIR}/ResNet101_QAT")
  if(NOT EXISTS "${QAT_RESNET101_MODEL_DIR}")
    inference_download_and_uncompress(${QAT_RESNET101_MODEL_DIR} "${QAT_MODELS_BASE_URL}" "ResNet101_qat_model.tar.gz")
  endif()
  # inference_qat_int8_test(test_qat_int8_resnet101_mkldnn ${QAT_RESNET101_MODEL_DIR} ${DATASET_DIR} ${MKLDNN_QAT_TEST_FILE_PATH} true)

  # QAT GoogleNet
  set(QAT_GOOGLENET_MODEL_DIR "${QAT_DATA_DIR}/GoogleNet_QAT")
  if(NOT EXISTS "${QAT_GOOGLENET_MODEL_DIR}")
    inference_download_and_uncompress(${QAT_GOOGLENET_MODEL_DIR} "${QAT_MODELS_BASE_URL}" "GoogleNet_qat_model.tar.gz")
  endif()
  inference_qat_int8_test(test_qat_int8_googlenet_mkldnn ${QAT_GOOGLENET_MODEL_DIR} ${DATASET_DIR} ${MKLDNN_QAT_TEST_FILE_PATH} true)

  # QAT MobileNetV1
  set(QAT_MOBILENETV1_MODEL_DIR "${QAT_DATA_DIR}/MobileNetV1_QAT")
  if(NOT EXISTS "${QAT_MOBILENETV1_MODEL_DIR}")
    inference_download_and_uncompress(${QAT_MOBILENETV1_MODEL_DIR} "${QAT_MODELS_BASE_URL}" "MobileNetV1_qat_model.tar.gz")
  endif()
  inference_qat_int8_test(test_qat_int8_mobilenetv1_mkldnn ${QAT_MOBILENETV1_MODEL_DIR} ${DATASET_DIR} ${MKLDNN_QAT_TEST_FILE_PATH} true)

  # QAT MobileNetV2
  set(QAT_MOBILENETV2_MODEL_DIR "${QAT_DATA_DIR}/MobileNetV2_QAT")
  if(NOT EXISTS "${QAT_MOBILENETV2_MODEL_DIR}")
    inference_download_and_uncompress(${QAT_MOBILENETV2_MODEL_DIR} "${QAT_MODELS_BASE_URL}" "MobileNetV2_qat_model.tar.gz")
  endif()
  inference_qat_int8_test(test_qat_int8_mobilenetv2_mkldnn ${QAT_MOBILENETV2_MODEL_DIR} ${DATASET_DIR} ${MKLDNN_QAT_TEST_FILE_PATH} true)

  # QAT VGG16. The model is downloaded, but its test is currently not registered.
  set(QAT_VGG16_MODEL_DIR "${QAT_DATA_DIR}/VGG16_QAT")
  if(NOT EXISTS "${QAT_VGG16_MODEL_DIR}")
    inference_download_and_uncompress(${QAT_VGG16_MODEL_DIR} "${QAT_MODELS_BASE_URL}" "VGG16_qat_model.tar.gz")
  endif()
  # inference_qat_int8_test(test_qat_int8_vgg16_mkldnn ${QAT_VGG16_MODEL_DIR} ${DATASET_DIR} ${MKLDNN_QAT_TEST_FILE_PATH} true)

  # QAT VGG19. The model is downloaded, but its test is currently not registered.
  set(QAT_VGG19_MODEL_DIR "${QAT_DATA_DIR}/VGG19_QAT")
  if(NOT EXISTS "${QAT_VGG19_MODEL_DIR}")
    inference_download_and_uncompress(${QAT_VGG19_MODEL_DIR} "${QAT_MODELS_BASE_URL}" "VGG19_qat_model.tar.gz")
  endif()
  # inference_qat_int8_test(test_qat_int8_vgg19_mkldnn ${QAT_VGG19_MODEL_DIR} ${DATASET_DIR} ${MKLDNN_QAT_TEST_FILE_PATH} true)

  # QAT2 ResNet50; the test reads the model from the nested ResNet50_qat_perf directory.
  set(QAT2_RESNET50_MODEL_DIR "${QAT_DATA_DIR}/ResNet50_qat_perf")
  if(NOT EXISTS "${QAT2_RESNET50_MODEL_DIR}")
    inference_download_and_uncompress(${QAT2_RESNET50_MODEL_DIR} "${QAT_MODELS_BASE_URL}" "ResNet50_qat_perf.tar.gz")
  endif()
  inference_qat2_int8_test(test_qat2_int8_resnet50_mkldnn ${QAT2_RESNET50_MODEL_DIR}/ResNet50_qat_perf ${DATASET_DIR} ${MKLDNN_QAT_TEST_FILE_PATH} true)

  # QAT2 MobileNetV1; the test reads the model from the nested MobileNet_qat_perf directory.
  set(QAT2_MOBILENETV1_MODEL_DIR "${QAT_DATA_DIR}/MobileNet_qat_perf")
  if(NOT EXISTS "${QAT2_MOBILENETV1_MODEL_DIR}")
    inference_download_and_uncompress(${QAT2_MOBILENETV1_MODEL_DIR} "${QAT_MODELS_BASE_URL}" "MobileNet_qat_perf.tar.gz")
  endif()
  inference_qat2_int8_test(test_qat2_int8_mobilenetv1_mkldnn ${QAT2_MOBILENETV1_MODEL_DIR}/MobileNet_qat_perf ${DATASET_DIR} ${MKLDNN_QAT_TEST_FILE_PATH} true)

  # Save the QAT2 ResNet50 model as both an FP32 and an INT8 model.
  set(QAT2_INT8_SAVE_PATH "${QAT_DATA_DIR}/ResNet50_qat2_int8")
  set(QAT2_FP32_SAVE_PATH "${QAT_DATA_DIR}/ResNet50_qat2_fp32")
  set(SAVE_QAT2_MODEL_SCRIPT "${CMAKE_CURRENT_SOURCE_DIR}/save_qat_model.py")
  save_qat_model_test(save_qat2_model_resnet50 ${QAT2_RESNET50_MODEL_DIR}/ResNet50_qat_perf/float ${QAT2_FP32_SAVE_PATH} ${QAT2_INT8_SAVE_PATH} ${SAVE_QAT2_MODEL_SCRIPT})
endif()

# The QAT FP32 & INT8 comparison script only supports Linux with MKL-DNN,
# where it is registered above through the QAT test helpers. Make sure the
# generic loop below never registers it on other systems.
# TEST_OPS holds names with the ".py" suffix stripped, so the bare name must be
# removed; the previous "qat_int8_comparison.py" entry could never match.
# The script also does not match the "test_*.py" glob, so this is a safeguard.
list(REMOVE_ITEM TEST_OPS qat_int8_comparison)

# Register every script still listed in TEST_OPS as a plain Python test
# named after the script (without the ".py" suffix).
foreach(src ${TEST_OPS})
    py_test(${src} SRCS ${src}.py)
endforeach()