未验证 提交 fab8bbf2 编写于 作者: L LoveAn 提交者: GitHub

Modify data download function and support unittests of inference APIs on windows (#26988)

* Modify data download function, and support unittests of inference APIs on windows, test=develop

* The import error compatible with py2 and py3, and fix unittests problems of inference APIs on Windows, test=develop
上级 4ff16eb2
......@@ -125,7 +125,7 @@ endfunction()
if(NOT APPLE AND WITH_MKLML)
# RNN1
set(RNN1_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/rnn1")
download_model_and_data(${RNN1_INSTALL_DIR} "rnn1%2Fmodel.tar.gz" "rnn1%2Fdata.txt.tar.gz")
download_model_and_data(${RNN1_INSTALL_DIR} "rnn1/model.tar.gz" "rnn1/data.txt.tar.gz")
inference_analysis_api_test(test_analyzer_rnn1 ${RNN1_INSTALL_DIR} analyzer_rnn1_tester.cc)
# seq_pool1
......@@ -210,7 +210,7 @@ inference_analysis_api_test(test_analyzer_seq_conv1 ${SEQ_CONV1_INSTALL_DIR} ana
# transformer, the dataset only works on batch_size=8 now
set(TRANSFORMER_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/transformer")
download_model_and_data(${TRANSFORMER_INSTALL_DIR} "temp%2Ftransformer_model.tar.gz" "temp%2Ftransformer_data.txt.tar.gz")
download_model_and_data(${TRANSFORMER_INSTALL_DIR} "temp/transformer_model.tar.gz" "temp/transformer_data.txt.tar.gz")
inference_analysis_test(test_analyzer_transformer SRCS analyzer_transformer_tester.cc
EXTRA_DEPS ${INFERENCE_EXTRA_DEPS}
ARGS --infer_model=${TRANSFORMER_INSTALL_DIR}/model --infer_data=${TRANSFORMER_INSTALL_DIR}/data.txt --batch_size=8
......@@ -219,7 +219,7 @@ inference_analysis_test(test_analyzer_transformer SRCS analyzer_transformer_test
# ocr
set(OCR_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/ocr")
if (NOT EXISTS ${OCR_INSTALL_DIR}/ocr.tar.gz)
inference_download_and_uncompress(${OCR_INSTALL_DIR} "http://paddlemodels.bj.bcebos.com/" "inference-vis-demos%2Focr.tar.gz")
inference_download_and_uncompress(${OCR_INSTALL_DIR} "http://paddlemodels.bj.bcebos.com/" "inference-vis-demos/ocr.tar.gz")
endif()
inference_analysis_api_test(test_analyzer_ocr ${OCR_INSTALL_DIR} analyzer_vis_tester.cc)
......@@ -235,7 +235,7 @@ set_property(TEST test_analyzer_detect PROPERTY ENVIRONMENT GLOG_vmodule=analysi
# mobilenet with transpose op
set(MOBILENET_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/mobilenet")
if (NOT EXISTS ${MOBILENET_INSTALL_DIR}/mobilenet.tar.gz)
inference_download_and_uncompress(${MOBILENET_INSTALL_DIR} "http://paddlemodels.bj.bcebos.com/" "inference-vis-demos%2Fmobilenet.tar.gz")
inference_download_and_uncompress(${MOBILENET_INSTALL_DIR} "http://paddlemodels.bj.bcebos.com/" "inference-vis-demos/mobilenet.tar.gz")
endif()
inference_analysis_api_test(test_analyzer_mobilenet_transpose ${MOBILENET_INSTALL_DIR} analyzer_vis_tester.cc)
......@@ -363,9 +363,9 @@ if(WITH_MKLDNN)
inference_analysis_api_test_build(${QUANT_IMG_CLASS_TEST_APP} ${QUANT_IMG_CLASS_TEST_APP_SRC})
# MobileNetV1 FP32 vs. Quant INT8
# The FP32 model should already be downloaded for slim Quant unit tests
set(QUANT2_MobileNetV1_MODEL_DIR "${QUANT_DATA_DIR}/MobileNetV1_quant2")
set(QUANT2_INT8_MobileNetV1_MODEL_DIR "${QUANT_DATA_DIR}/MobileNetV1_quant2_int8")
download_quant_data(${QUANT2_MobileNetV1_MODEL_DIR} "MobileNet_qat_perf.tar.gz")
download_quant_data(${QUANT2_INT8_MobileNetV1_MODEL_DIR} "MobileNet_qat_perf_int8.tar.gz")
inference_analysis_api_quant_test_run(test_analyzer_quant_performance_benchmark ${QUANT_IMG_CLASS_TEST_APP} ${QUANT2_MobileNetV1_MODEL_DIR}/MobileNet_qat_perf/float ${QUANT2_INT8_MobileNetV1_MODEL_DIR}/MobileNet_qat_perf_int8 ${IMAGENET_DATA_PATH})
......
......@@ -44,7 +44,7 @@ void zero_copy_run() {
const int channels = 3;
const int height = 318;
const int width = 318;
float input[batch_size * channels * height * width] = {0};
float *input = new float[batch_size * channels * height * width]();
int shape[4] = {batch_size, channels, height, width};
int shape_size = 4;
......@@ -65,6 +65,7 @@ void zero_copy_run() {
PD_PredictorZeroCopyRun(config, inputs, in_size, &outputs, &out_size);
delete[] input;
delete[] inputs;
delete[] outputs;
}
......
......@@ -112,7 +112,11 @@ TEST(Analyzer_resnet50, compare_determine) {
TEST(Analyzer_resnet50, save_optim_model) {
AnalysisConfig cfg;
std::string optimModelPath = FLAGS_infer_model + "/saved_optim_model";
#ifdef _WIN32
_mkdir(optimModelPath.c_str());
#else
mkdir(optimModelPath.c_str(), 0777);
#endif
SetConfig(&cfg);
SaveOptimModel(&cfg, optimModelPath);
}
......
......@@ -123,7 +123,7 @@ void profile(bool memory_load = false) {
size_t size = GetSize(output[0]);
PADDLE_ENFORCE_GT(size, 0);
int64_t *result = static_cast<int64_t *>(output[0].data.data());
for (size_t i = 0; i < std::min(11UL, size); i++) {
for (size_t i = 0; i < std::min<size_t>(11, size); i++) {
EXPECT_EQ(result[i], chinese_ner_result_data[i]);
}
}
......
......@@ -23,7 +23,7 @@ from PIL import Image
import math
from paddle.dataset.common import download
import tarfile
import StringIO
from six.moves import StringIO
import argparse
random.seed(0)
......@@ -152,7 +152,7 @@ def convert_Imagenet_tar2bin(tar_file, output_file):
idx = 0
for imagedata in dataset.values():
img = Image.open(StringIO.StringIO(imagedata))
img = Image.open(StringIO(imagedata))
img = process_image(img)
np_img = np.array(img)
ofs.write(np_img.astype('float32').tobytes())
......
......@@ -19,7 +19,7 @@ import os
import sys
from paddle.dataset.common import download
import tarfile
import StringIO
from six.moves import StringIO
import hashlib
import tarfile
import argparse
......@@ -191,7 +191,7 @@ def convert_pascalvoc_tar2bin(tar_path, data_out_path):
gt_labels[name_prefix] = tar.extractfile(tarInfo).read()
for line_idx, name_prefix in enumerate(lines):
im = Image.open(StringIO.StringIO(images[name_prefix]))
im = Image.open(StringIO(images[name_prefix]))
if im.mode == 'L':
im = im.convert('RGB')
im_width, im_height = im.size
......
......@@ -25,7 +25,8 @@ endfunction()
function(inference_download_and_uncompress INSTALL_DIR URL FILENAME)
message(STATUS "Download inference test stuff from ${URL}/${FILENAME}")
string(REGEX REPLACE "[-%.]" "_" FILENAME_EX ${FILENAME})
string(REGEX REPLACE "[-%./\\]" "_" FILENAME_EX ${FILENAME})
string(REGEX MATCH "[^/\\]+$" DOWNLOAD_NAME ${FILENAME})
set(EXTERNAL_PROJECT_NAME "extern_inference_download_${FILENAME_EX}")
set(UNPACK_DIR "${INSTALL_DIR}/src/${EXTERNAL_PROJECT_NAME}")
ExternalProject_Add(
......@@ -38,7 +39,7 @@ function(inference_download_and_uncompress INSTALL_DIR URL FILENAME)
DOWNLOAD_NO_PROGRESS 1
CONFIGURE_COMMAND ""
BUILD_COMMAND ${CMAKE_COMMAND} -E chdir ${INSTALL_DIR}
${CMAKE_COMMAND} -E tar xzf ${FILENAME}
${CMAKE_COMMAND} -E tar xzf ${DOWNLOAD_NAME}
UPDATE_COMMAND ""
INSTALL_COMMAND ""
)
......
......@@ -58,7 +58,7 @@ if not defined WITH_AVX set WITH_AVX=ON
if not defined WITH_TESTING set WITH_TESTING=ON
if not defined WITH_PYTHON set WITH_PYTHON=ON
if not defined ON_INFER set ON_INFER=ON
if not defined WITH_INFERENCE_API_TEST set WITH_INFERENCE_API_TEST=OFF
if not defined WITH_INFERENCE_API_TEST set WITH_INFERENCE_API_TEST=ON
if not defined WITH_TPCACHE set WITH_TPCACHE=ON
rem ------set cache third_party------
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册