CMakeLists.txt 9.7 KB
Newer Older
W
WangZhen 已提交
1 2 3
# Gather every test_*.py script in this directory (as paths relative to it),
# then drop each ".py" occurrence so TEST_OPS holds bare test names.
file(
    GLOB TEST_OPS
    RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}"
    "test_*.py"
)
string(
    REPLACE ".py" ""
    TEST_OPS
    "${TEST_OPS}"
)

4
# Forwards an INT8 quantization test script to py_test() with a fixed set of
# environment variables and command-line arguments.
#
# Arguments:
#   target     - test name given to py_test(); also appended to "int8_models/"
#                to form the --int8_model_save_path value
#   model_dir  - directory whose "model" entry is passed as --infer_model
#   data_dir   - directory whose "data.bin" file is passed as --infer_data
#   filename   - Python script passed to py_test() as SRCS
#   use_mkldnn - value exported to the test as FLAGS_use_mkldnn
function(_inference_analysis_python_api_int8_test target model_dir data_dir filename use_mkldnn)
    py_test(${target} SRCS ${filename}
        ENVS CPU_NUM_THREADS=${CPU_NUM_THREADS_ON_CI}
             FLAGS_use_mkldnn=${use_mkldnn}
        ARGS --infer_model ${model_dir}/model
             --infer_data ${data_dir}/data.bin
             --int8_model_save_path int8_models/${target}
             --warmup_batch_size 100
             --batch_size 50)
endfunction()

15 16 17 18 19 20 21 22
# Thin wrapper around _inference_analysis_python_api_int8_test() that passes
# the four arguments through unchanged and fixes use_mkldnn to False.
function(inference_analysis_python_api_int8_test target model_dir data_dir filename)
    _inference_analysis_python_api_int8_test(
        ${target}
        ${model_dir}
        ${data_dir}
        ${filename}
        False)
endfunction()

# Thin wrapper around _inference_analysis_python_api_int8_test() that passes
# the four arguments through unchanged and fixes use_mkldnn to True.
function(inference_analysis_python_api_int8_test_mkldnn target model_dir data_dir filename)
    _inference_analysis_python_api_int8_test(
        ${target}
        ${model_dir}
        ${data_dir}
        ${filename}
        True)
endfunction()

23 24 25 26 27 28 29 30 31 32 33 34
# Hands a QAT comparison script to py_test().
#
# Arguments:
#   target      - test name given to py_test()
#   model_dir   - directory whose "model" entry is passed as --qat_model
#   data_dir    - directory whose "data.bin" file is passed as --infer_data
#   test_script - Python script passed to py_test() as SRCS
#   use_mkldnn  - value exported to the test as FLAGS_use_mkldnn
function(inference_qat_int8_test target model_dir data_dir test_script use_mkldnn)
    # Environment variables exported to the test process.
    set(qat_int8_test_envs
        FLAGS_OMP_NUM_THREADS=${CPU_NUM_THREADS_ON_CI}
        OMP_NUM_THREADS=${CPU_NUM_THREADS_ON_CI}
        FLAGS_use_mkldnn=${use_mkldnn})

    # Command-line arguments handed to the test script.
    set(qat_int8_test_args
        --qat_model ${model_dir}/model
        --infer_data ${data_dir}/data.bin
        --batch_size 25
        --batch_num 2
        --acc_diff_threshold 0.1)

    py_test(${target}
            SRCS ${test_script}
            ENVS ${qat_int8_test_envs}
            ARGS ${qat_int8_test_args})
endfunction()

35 36 37 38 39 40 41 42 43 44 45 46 47
# Hands a QAT2 comparison script to py_test(). Differs from
# inference_qat_int8_test() in that the model path is "${model_dir}/float"
# and the extra --qat2 switch is passed to the script.
#
# Arguments:
#   target      - test name given to py_test()
#   model_dir   - directory whose "float" entry is passed as --qat_model
#   data_dir    - directory whose "data.bin" file is passed as --infer_data
#   test_script - Python script passed to py_test() as SRCS
#   use_mkldnn  - value exported to the test as FLAGS_use_mkldnn
function(inference_qat2_int8_test target model_dir data_dir test_script use_mkldnn)
    # Environment variables exported to the test process.
    set(qat2_int8_test_envs
        FLAGS_OMP_NUM_THREADS=${CPU_NUM_THREADS_ON_CI}
        OMP_NUM_THREADS=${CPU_NUM_THREADS_ON_CI}
        FLAGS_use_mkldnn=${use_mkldnn})

    # Command-line arguments handed to the test script.
    set(qat2_int8_test_args
        --qat_model ${model_dir}/float
        --infer_data ${data_dir}/data.bin
        --batch_size 25
        --batch_num 2
        --acc_diff_threshold 0.1
        --qat2)

    py_test(${target}
            SRCS ${test_script}
            ENVS ${qat2_int8_test_envs}
            ARGS ${qat2_int8_test_args})
endfunction()

48 49 50 51 52 53
# Hands a model-saving script to py_test().
#
# Arguments:
#   target               - test name given to py_test()
#   qat_model_dir        - passed to the script as --qat_model_path
#   fp32_model_save_path - passed to the script as --fp32_model_save_path
#   int8_model_save_path - passed to the script as --int8_model_save_path
#   test_script          - Python script passed to py_test() as SRCS
function(save_qat_model_test target qat_model_dir fp32_model_save_path int8_model_save_path test_script)
    py_test(
        ${target}
        SRCS ${test_script}
        ARGS
            --qat_model_path ${qat_model_dir}
            --fp32_model_save_path ${fp32_model_save_path}
            --int8_model_save_path ${int8_model_save_path}
    )
endfunction()
54

W
whs 已提交
55 56
# On Windows, drop the following tests from TEST_OPS so they are not
# registered by the generic loop at the end of this file.
if(WIN32)
    list(REMOVE_ITEM TEST_OPS
        test_light_nas
        test_post_training_quantization_mobilenetv1
        test_post_training_quantization_resnet50)
endif()

61 62 63 64
# INT8 image classification Python API tests.
# Only registered on Linux builds configured with MKL-DNN.
if(LINUX AND WITH_MKLDNN)
  set(INT8_DATA_DIR "${INFERENCE_DEMO_INSTALL_DIR}/int8v2")
  set(MKLDNN_INT8_TEST_FILE "test_mkldnn_int8_quantization_strategy.py")
  set(MKLDNN_INT8_TEST_FILE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/${MKLDNN_INT8_TEST_FILE}")

  # googlenet int8
  set(INT8_GOOGLENET_MODEL_DIR "${INT8_DATA_DIR}/googlenet")
  inference_analysis_python_api_int8_test(test_slim_int8_googlenet ${INT8_GOOGLENET_MODEL_DIR} ${INT8_DATA_DIR} ${MKLDNN_INT8_TEST_FILE_PATH})

  # mobilenet int8 (registered both with and without FLAGS_use_mkldnn)
  set(INT8_MOBILENET_MODEL_DIR "${INT8_DATA_DIR}/mobilenetv1")
  inference_analysis_python_api_int8_test(test_slim_int8_mobilenet ${INT8_MOBILENET_MODEL_DIR} ${INT8_DATA_DIR} ${MKLDNN_INT8_TEST_FILE_PATH})
  inference_analysis_python_api_int8_test_mkldnn(test_slim_int8_mobilenet_mkldnn ${INT8_MOBILENET_MODEL_DIR} ${INT8_DATA_DIR} ${MKLDNN_INT8_TEST_FILE_PATH})

  # WITH_SLIM_MKLDNN_FULL_TEST is a temporary flag that lets QA run the
  # following unit tests locally; they take too long to run on CI.
  if(WITH_SLIM_MKLDNN_FULL_TEST)
    # resnet50 int8
    set(INT8_RESNET50_MODEL_DIR "${INT8_DATA_DIR}/resnet50")
    inference_analysis_python_api_int8_test(test_slim_int8_resnet50 ${INT8_RESNET50_MODEL_DIR} ${INT8_DATA_DIR} ${MKLDNN_INT8_TEST_FILE_PATH})

    # mobilenetv2 int8
    set(INT8_MOBILENETV2_MODEL_DIR "${INT8_DATA_DIR}/mobilenetv2")
    inference_analysis_python_api_int8_test(test_slim_int8_mobilenetv2 ${INT8_MOBILENETV2_MODEL_DIR} ${INT8_DATA_DIR} ${MKLDNN_INT8_TEST_FILE_PATH})

    # resnet101 int8
    set(INT8_RESNET101_MODEL_DIR "${INT8_DATA_DIR}/resnet101")
    inference_analysis_python_api_int8_test(test_slim_int8_resnet101 ${INT8_RESNET101_MODEL_DIR} ${INT8_DATA_DIR} ${MKLDNN_INT8_TEST_FILE_PATH})

    # vgg16 int8
    set(INT8_VGG16_MODEL_DIR "${INT8_DATA_DIR}/vgg16")
    inference_analysis_python_api_int8_test(test_slim_int8_vgg16 ${INT8_VGG16_MODEL_DIR} ${INT8_DATA_DIR} ${MKLDNN_INT8_TEST_FILE_PATH})

    # vgg19 int8
    set(INT8_VGG19_MODEL_DIR "${INT8_DATA_DIR}/vgg19")
    inference_analysis_python_api_int8_test(test_slim_int8_vgg19 ${INT8_VGG19_MODEL_DIR} ${INT8_DATA_DIR} ${MKLDNN_INT8_TEST_FILE_PATH})
  endif()
endif()

101
# test_mkldnn_int8_quantization_strategy is only supported on Linux with
# MKL-DNN, where it is registered explicitly with its own arguments. Remove it
# from TEST_OPS so it is neither registered a second time there nor run on
# other systems.
list(REMOVE_ITEM TEST_OPS test_mkldnn_int8_quantization_strategy)

105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166
# QAT FP32 & INT8 comparison Python API tests.
# Only registered on Linux builds configured with MKL-DNN. Each model archive
# is fetched with inference_download_and_uncompress() only when its target
# directory does not exist yet at configure time.
if(LINUX AND WITH_MKLDNN)
  set(DATASET_DIR "${INFERENCE_DEMO_INSTALL_DIR}/int8v2")
  set(QAT_DATA_DIR "${INFERENCE_DEMO_INSTALL_DIR}/int8v2")
  set(QAT_MODELS_BASE_URL "${INFERENCE_URL}/int8/QAT_models")
  set(MKLDNN_QAT_TEST_FILE "qat_int8_comparison.py")
  set(MKLDNN_QAT_TEST_FILE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/${MKLDNN_QAT_TEST_FILE}")

  # ImageNet small dataset
  # May be already downloaded for INT8v2 unit tests
  if(NOT EXISTS ${DATASET_DIR})
    inference_download_and_uncompress(${DATASET_DIR} "${INFERENCE_URL}/int8" "imagenet_val_100_tail.tar.gz")
  endif()

  # QAT ResNet50
  set(QAT_RESNET50_MODEL_DIR "${QAT_DATA_DIR}/ResNet50_QAT")
  if(NOT EXISTS ${QAT_RESNET50_MODEL_DIR})
    inference_download_and_uncompress(${QAT_RESNET50_MODEL_DIR} "${QAT_MODELS_BASE_URL}" "ResNet50_qat_model.tar.gz")
  endif()
  inference_qat_int8_test(test_qat_int8_resnet50_mkldnn ${QAT_RESNET50_MODEL_DIR} ${DATASET_DIR} ${MKLDNN_QAT_TEST_FILE_PATH} true)

  # QAT ResNet101
  set(QAT_RESNET101_MODEL_DIR "${QAT_DATA_DIR}/ResNet101_QAT")
  if(NOT EXISTS ${QAT_RESNET101_MODEL_DIR})
    inference_download_and_uncompress(${QAT_RESNET101_MODEL_DIR} "${QAT_MODELS_BASE_URL}" "ResNet101_qat_model.tar.gz")
  endif()
  inference_qat_int8_test(test_qat_int8_resnet101_mkldnn ${QAT_RESNET101_MODEL_DIR} ${DATASET_DIR} ${MKLDNN_QAT_TEST_FILE_PATH} true)

  # QAT GoogleNet
  set(QAT_GOOGLENET_MODEL_DIR "${QAT_DATA_DIR}/GoogleNet_QAT")
  if(NOT EXISTS ${QAT_GOOGLENET_MODEL_DIR})
    inference_download_and_uncompress(${QAT_GOOGLENET_MODEL_DIR} "${QAT_MODELS_BASE_URL}" "GoogleNet_qat_model.tar.gz")
  endif()
  inference_qat_int8_test(test_qat_int8_googlenet_mkldnn ${QAT_GOOGLENET_MODEL_DIR} ${DATASET_DIR} ${MKLDNN_QAT_TEST_FILE_PATH} true)

  # QAT MobileNetV1
  set(QAT_MOBILENETV1_MODEL_DIR "${QAT_DATA_DIR}/MobileNetV1_QAT")
  if(NOT EXISTS ${QAT_MOBILENETV1_MODEL_DIR})
    inference_download_and_uncompress(${QAT_MOBILENETV1_MODEL_DIR} "${QAT_MODELS_BASE_URL}" "MobileNetV1_qat_model.tar.gz")
  endif()
  inference_qat_int8_test(test_qat_int8_mobilenetv1_mkldnn ${QAT_MOBILENETV1_MODEL_DIR} ${DATASET_DIR} ${MKLDNN_QAT_TEST_FILE_PATH} true)

  # QAT MobileNetV2
  set(QAT_MOBILENETV2_MODEL_DIR "${QAT_DATA_DIR}/MobileNetV2_QAT")
  if(NOT EXISTS ${QAT_MOBILENETV2_MODEL_DIR})
    inference_download_and_uncompress(${QAT_MOBILENETV2_MODEL_DIR} "${QAT_MODELS_BASE_URL}" "MobileNetV2_qat_model.tar.gz")
  endif()
  inference_qat_int8_test(test_qat_int8_mobilenetv2_mkldnn ${QAT_MOBILENETV2_MODEL_DIR} ${DATASET_DIR} ${MKLDNN_QAT_TEST_FILE_PATH} true)

  # QAT VGG16
  set(QAT_VGG16_MODEL_DIR "${QAT_DATA_DIR}/VGG16_QAT")
  if(NOT EXISTS ${QAT_VGG16_MODEL_DIR})
    inference_download_and_uncompress(${QAT_VGG16_MODEL_DIR} "${QAT_MODELS_BASE_URL}" "VGG16_qat_model.tar.gz")
  endif()
  inference_qat_int8_test(test_qat_int8_vgg16_mkldnn ${QAT_VGG16_MODEL_DIR} ${DATASET_DIR} ${MKLDNN_QAT_TEST_FILE_PATH} true)

  # QAT VGG19
  set(QAT_VGG19_MODEL_DIR "${QAT_DATA_DIR}/VGG19_QAT")
  if(NOT EXISTS ${QAT_VGG19_MODEL_DIR})
    inference_download_and_uncompress(${QAT_VGG19_MODEL_DIR} "${QAT_MODELS_BASE_URL}" "VGG19_qat_model.tar.gz")
  endif()
  inference_qat_int8_test(test_qat_int8_vgg19_mkldnn ${QAT_VGG19_MODEL_DIR} ${DATASET_DIR} ${MKLDNN_QAT_TEST_FILE_PATH} true)

  # QAT2 ResNet50
  set(QAT2_RESNET50_MODEL_DIR "${QAT_DATA_DIR}/ResNet50_qat_perf")
  if(NOT EXISTS ${QAT2_RESNET50_MODEL_DIR})
    inference_download_and_uncompress(${QAT2_RESNET50_MODEL_DIR} "${QAT_MODELS_BASE_URL}" "ResNet50_qat_perf.tar.gz")
  endif()
  inference_qat2_int8_test(test_qat2_int8_resnet50_mkldnn ${QAT2_RESNET50_MODEL_DIR}/ResNet50_qat_perf ${DATASET_DIR} ${MKLDNN_QAT_TEST_FILE_PATH} true)

  # QAT2 MobileNetV1
  set(QAT2_MOBILENETV1_MODEL_DIR "${QAT_DATA_DIR}/MobileNet_qat_perf")
  if(NOT EXISTS ${QAT2_MOBILENETV1_MODEL_DIR})
    inference_download_and_uncompress(${QAT2_MOBILENETV1_MODEL_DIR} "${QAT_MODELS_BASE_URL}" "MobileNet_qat_perf.tar.gz")
  endif()
  inference_qat2_int8_test(test_qat2_int8_mobilenetv1_mkldnn ${QAT2_MOBILENETV1_MODEL_DIR}/MobileNet_qat_perf ${DATASET_DIR} ${MKLDNN_QAT_TEST_FILE_PATH} true)

  # Save qat2 fp32 model or qat2 int8 model
  set(QAT2_INT8_SAVE_PATH "${QAT_DATA_DIR}/ResNet50_qat2_int8")
  set(QAT2_FP32_SAVE_PATH "${QAT_DATA_DIR}/ResNet50_qat2_fp32")
  set(SAVE_QAT2_MODEL_SCRIPT "${CMAKE_CURRENT_SOURCE_DIR}/save_qat_model.py")
  # save_qat_model_test() takes exactly five arguments and has no use_mkldnn
  # parameter, so no trailing flag is passed here.
  save_qat_model_test(save_qat2_model_resnet50 ${QAT2_RESNET50_MODEL_DIR}/ResNet50_qat_perf/float ${QAT2_FP32_SAVE_PATH} ${QAT2_INT8_SAVE_PATH} ${SAVE_QAT2_MODEL_SCRIPT})
endif()

# The QAT FP32 & INT8 comparison test is only supported on Linux with MKL-DNN,
# so make sure it is not in TEST_OPS and therefore not run on other systems.
# TEST_OPS entries carry no ".py" suffix (it is stripped right after the glob),
# so the bare name must be used for the removal to be able to match.
list(REMOVE_ITEM TEST_OPS qat_int8_comparison)

W
WangZhen 已提交
193 194 195
# Hand every test name still left in TEST_OPS to py_test(), using
# "<name>.py" as its source script.
foreach(test_name IN LISTS TEST_OPS)
    py_test(
        ${test_name}
        SRCS ${test_name}.py
    )
endforeach()