diff --git a/paddle/fluid/inference/tests/api/CMakeLists.txt b/paddle/fluid/inference/tests/api/CMakeLists.txt
index 64fee61a1c0a36653126ea2c8ffdc09b60c5c39a..6de371cffbb9b9e2b3ee8fbabed98dbb01605e76 100644
--- a/paddle/fluid/inference/tests/api/CMakeLists.txt
+++ b/paddle/fluid/inference/tests/api/CMakeLists.txt
@@ -23,11 +23,11 @@ function(inference_analysis_api_test target install_dir filename)
             ARGS --infer_model=${install_dir}/model --infer_data=${install_dir}/data.txt)
 endfunction()
 
-function(inference_analysis_api_int8_test target model_dir data_dir filename)
+function(inference_analysis_api_int8_test target model_dir data_path filename)
     inference_analysis_test(${target} SRCS ${filename}
         EXTRA_DEPS ${INFERENCE_EXTRA_DEPS} benchmark
         ARGS --infer_model=${model_dir}/model
-             --infer_data=${data_dir}/data.bin
+             --infer_data=${data_path}
              --warmup_batch_size=100
              --batch_size=50
              --paddle_num_threads=${CPU_NUM_THREADS_ON_CI}
@@ -159,55 +159,70 @@ if(WITH_MKLDNN)
     if (NOT EXISTS ${INT8_DATA_DIR})
         inference_download_and_uncompress(${INT8_DATA_DIR} "${INFERENCE_URL}/int8" "imagenet_val_100_tail.tar.gz")
     endif()
+
+    if (NOT EXISTS ${INT8_DATA_DIR}/pascalvoc_data.bin)
+        inference_download_and_uncompress(${INT8_DATA_DIR} "${INFERENCE_URL}/int8" "pascalvoc_val_200_head.tar.gz")
+    endif()
+    set(IMAGENET_DATA_PATH "${INT8_DATA_DIR}/data.bin")
+    set(PASCALVOC_DATA_PATH "${INT8_DATA_DIR}/pascalvoc_data.bin")
 
     #resnet50 int8
     set(INT8_RESNET50_MODEL_DIR "${INT8_DATA_DIR}/resnet50")
     if (NOT EXISTS ${INT8_RESNET50_MODEL_DIR})
         inference_download_and_uncompress(${INT8_RESNET50_MODEL_DIR} "${INFERENCE_URL}/int8" "resnet50_int8_model.tar.gz" )
     endif()
-    inference_analysis_api_int8_test(test_analyzer_int8_resnet50 ${INT8_RESNET50_MODEL_DIR} ${INT8_DATA_DIR} analyzer_int8_image_classification_tester.cc)
+    inference_analysis_api_int8_test(test_analyzer_int8_resnet50 ${INT8_RESNET50_MODEL_DIR} ${IMAGENET_DATA_PATH} analyzer_int8_image_classification_tester.cc)
 
     #mobilenet int8
     set(INT8_MOBILENET_MODEL_DIR "${INT8_DATA_DIR}/mobilenet")
     if (NOT EXISTS ${INT8_MOBILENET_MODEL_DIR})
         inference_download_and_uncompress(${INT8_MOBILENET_MODEL_DIR} "${INFERENCE_URL}/int8" "mobilenetv1_int8_model.tar.gz" )
     endif()
-    inference_analysis_api_int8_test(test_analyzer_int8_mobilenet ${INT8_MOBILENET_MODEL_DIR} ${INT8_DATA_DIR} analyzer_int8_image_classification_tester.cc)
+    inference_analysis_api_int8_test(test_analyzer_int8_mobilenet ${INT8_MOBILENET_MODEL_DIR} ${IMAGENET_DATA_PATH} analyzer_int8_image_classification_tester.cc)
 
     #mobilenetv2 int8
     set(INT8_MOBILENETV2_MODEL_DIR "${INT8_DATA_DIR}/mobilenetv2")
     if (NOT EXISTS ${INT8_MOBILENETV2_MODEL_DIR})
         inference_download_and_uncompress(${INT8_MOBILENETV2_MODEL_DIR} "${INFERENCE_URL}/int8" "mobilenet_v2_int8_model.tar.gz" )
     endif()
-    inference_analysis_api_int8_test(test_analyzer_int8_mobilenetv2 ${INT8_MOBILENETV2_MODEL_DIR} ${INT8_DATA_DIR} analyzer_int8_image_classification_tester.cc)
+    inference_analysis_api_int8_test(test_analyzer_int8_mobilenetv2 ${INT8_MOBILENETV2_MODEL_DIR} ${IMAGENET_DATA_PATH} analyzer_int8_image_classification_tester.cc)
 
     #resnet101 int8
     set(INT8_RESNET101_MODEL_DIR "${INT8_DATA_DIR}/resnet101")
     if (NOT EXISTS ${INT8_RESNET101_MODEL_DIR})
         inference_download_and_uncompress(${INT8_RESNET101_MODEL_DIR} "${INFERENCE_URL}/int8" "Res101_int8_model.tar.gz" )
     endif()
-    inference_analysis_api_int8_test(test_analyzer_int8_resnet101 ${INT8_RESNET101_MODEL_DIR} ${INT8_DATA_DIR} analyzer_int8_image_classification_tester.cc)
+    inference_analysis_api_int8_test(test_analyzer_int8_resnet101 ${INT8_RESNET101_MODEL_DIR} ${IMAGENET_DATA_PATH} analyzer_int8_image_classification_tester.cc)
 
     #vgg16 int8
     set(INT8_VGG16_MODEL_DIR "${INT8_DATA_DIR}/vgg16")
     if (NOT EXISTS ${INT8_VGG16_MODEL_DIR})
         inference_download_and_uncompress(${INT8_VGG16_MODEL_DIR} "${INFERENCE_URL}/int8" "VGG16_int8_model.tar.gz" )
     endif()
-    inference_analysis_api_int8_test(test_analyzer_int8_vgg16 ${INT8_VGG16_MODEL_DIR} ${INT8_DATA_DIR} analyzer_int8_image_classification_tester.cc)
+    inference_analysis_api_int8_test(test_analyzer_int8_vgg16 ${INT8_VGG16_MODEL_DIR} ${IMAGENET_DATA_PATH} analyzer_int8_image_classification_tester.cc)
 
     #vgg19 int8
     set(INT8_VGG19_MODEL_DIR "${INT8_DATA_DIR}/vgg19")
     if (NOT EXISTS ${INT8_VGG19_MODEL_DIR})
         inference_download_and_uncompress(${INT8_VGG19_MODEL_DIR} "${INFERENCE_URL}/int8" "VGG19_int8_model.tar.gz" )
     endif()
-    inference_analysis_api_int8_test(test_analyzer_int8_vgg19 ${INT8_VGG19_MODEL_DIR} ${INT8_DATA_DIR} analyzer_int8_image_classification_tester.cc)
+    inference_analysis_api_int8_test(test_analyzer_int8_vgg19 ${INT8_VGG19_MODEL_DIR} ${IMAGENET_DATA_PATH} analyzer_int8_image_classification_tester.cc)
 
     #googlenet int8
     set(INT8_GOOGLENET_MODEL_DIR "${INT8_DATA_DIR}/googlenet")
     if (NOT EXISTS ${INT8_GOOGLENET_MODEL_DIR})
         inference_download_and_uncompress(${INT8_GOOGLENET_MODEL_DIR} "${INFERENCE_URL}/int8" "GoogleNet_int8_model.tar.gz" )
     endif()
-    inference_analysis_api_int8_test(test_analyzer_int8_googlenet ${INT8_GOOGLENET_MODEL_DIR} ${INT8_DATA_DIR} analyzer_int8_image_classification_tester.cc SERIAL)
+    inference_analysis_api_int8_test(test_analyzer_int8_googlenet ${INT8_GOOGLENET_MODEL_DIR} ${IMAGENET_DATA_PATH} analyzer_int8_image_classification_tester.cc SERIAL)
+
+    #mobilenet-ssd int8 model
+    set(INT8_MOBILENET_SSD_MODEL_DIR "${INT8_DATA_DIR}/mobilenet-ssd")
+    if (NOT EXISTS ${INT8_MOBILENET_SSD_MODEL_DIR})
+        inference_download_and_uncompress(${INT8_MOBILENET_SSD_MODEL_DIR} "${INFERENCE_URL}/int8" "mobilenet_ssd_int8_model.tar.gz" )
+    endif()
+    inference_analysis_api_int8_test(test_analyzer_int8_mobilenet_ssd ${INT8_MOBILENET_SSD_MODEL_DIR} ${PASCALVOC_DATA_PATH} analyzer_int8_object_detection_tester.cc)
+
 endif()
 
 # bert, max_len=20, embedding_dim=128
diff --git a/paddle/fluid/inference/tests/api/analyzer_int8_object_detection_tester.cc b/paddle/fluid/inference/tests/api/analyzer_int8_object_detection_tester.cc
new file mode 100644
index 0000000000000000000000000000000000000000..3c86f32bf7fc5139d09d57851c07901ef53ec306
--- /dev/null
+++ b/paddle/fluid/inference/tests/api/analyzer_int8_object_detection_tester.cc
@@ -0,0 +1,278 @@
+/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+#include <fstream>
+#include <iostream>
+#include "paddle/fluid/inference/api/paddle_analysis_config.h"
+#include "paddle/fluid/inference/tests/api/tester_helper.h"
+
+namespace paddle {
+namespace inference {
+namespace analysis {
+
+void SetConfig(AnalysisConfig *cfg) {
+  cfg->SetModel(FLAGS_infer_model);
+  cfg->DisableGpu();
+  cfg->SwitchIrOptim(true);
+  cfg->SwitchSpecifyInputNames(false);
+  cfg->SetCpuMathLibraryNumThreads(FLAGS_paddle_num_threads);
+  cfg->EnableMKLDNN();
+}
+
+std::vector<size_t> ReadObjectsNum(std::ifstream &file, size_t offset,
+                                   int64_t total_images) {
+  std::vector<size_t> num_objects;
+  num_objects.resize(total_images);
+
+  file.clear();
+  file.seekg(offset);
+  file.read(reinterpret_cast<char *>(num_objects.data()),
+            total_images * sizeof(size_t));
+
+  if (file.eof()) LOG(ERROR) << "Reached end of stream";
+  if (file.fail()) throw std::runtime_error("Failed reading file.");
+  return num_objects;
+}
+
+template <typename T>
+class TensorReader {
+ public:
+  TensorReader(std::ifstream &file, size_t beginning_offset, std::string name)
+      : file_(file), position(beginning_offset), name_(name) {}
+
+  PaddleTensor NextBatch(std::vector<int> shape, std::vector<size_t> lod) {
+    int numel =
+        std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int>());
+    PaddleTensor tensor;
+    tensor.name = name_;
+    tensor.shape = shape;
+    tensor.dtype = GetPaddleDType<T>();
+    tensor.data.Resize(numel * sizeof(T));
+    if (lod.empty() == false) {
+      tensor.lod.clear();
+      tensor.lod.push_back(lod);
+    }
+    file_.seekg(position);
+    file_.read(reinterpret_cast<char *>(tensor.data.data()),
+               numel * sizeof(T));
+    position = file_.tellg();
+    if (file_.eof()) LOG(ERROR) << name_ << ": reached end of stream";
+    if (file_.fail())
+      throw std::runtime_error(name_ + ": failed reading file.");
+    return tensor;
+  }
+
+ protected:
+  std::ifstream &file_;
+  size_t position;
+  std::string name_;
+};
+
+void SetInput(std::vector<std::vector<PaddleTensor>> *inputs,
+              int32_t batch_size = FLAGS_batch_size, int process_images = 0) {
+  std::ifstream file(FLAGS_infer_data, std::ios::binary);
+  if (!file) {
+    FAIL() << "Couldn't open file: " << FLAGS_infer_data;
+  }
+
+  int64_t total_images{0};
+  file.read(reinterpret_cast<char *>(&total_images), sizeof(int64_t));
+  LOG(INFO) << "Total images in file: " << total_images;
+
+  size_t image_beginning_offset = static_cast<size_t>(file.tellg());
+  auto lod_offset_in_file =
+      image_beginning_offset + sizeof(float) * total_images * 3 * 300 * 300;
+  auto labels_beginning_offset =
+      lod_offset_in_file + sizeof(size_t) * total_images;
+
+  std::vector<size_t> lod_full =
+      ReadObjectsNum(file, lod_offset_in_file, total_images);
+  size_t sum_objects_num =
+      std::accumulate(lod_full.begin(), lod_full.end(), 0UL);
+
+  auto bbox_beginning_offset =
+      labels_beginning_offset + sizeof(int64_t) * sum_objects_num;
+  auto difficult_beginning_offset =
+      bbox_beginning_offset + sizeof(float) * sum_objects_num * 4;
+
+  TensorReader<float> image_reader(file, image_beginning_offset, "image");
+  TensorReader<int64_t> label_reader(file, labels_beginning_offset,
+                                     "gt_label");
+  TensorReader<float> bbox_reader(file, bbox_beginning_offset, "gt_bbox");
+  TensorReader<int64_t> difficult_reader(file, difficult_beginning_offset,
+                                         "gt_difficult");
+  if (process_images == 0) process_images = total_images;
+  auto iterations_max = process_images / batch_size;
+  for (auto i = 0; i < iterations_max; i++) {
+    auto images_tensor = image_reader.NextBatch({batch_size, 3, 300, 300}, {});
+    std::vector<size_t> batch_lod(lod_full.begin() + i * batch_size,
+                                  lod_full.begin() + batch_size * (i + 1));
+    size_t batch_num_objects =
+        std::accumulate(batch_lod.begin(), batch_lod.end(), 0UL);
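+    // Turn the per-image object counts for this batch into cumulative
+    // offsets {0, n0, n0+n1, ...}: the LoD layout PaddleTensor expects
+    // for the variable-length gt_* inputs below.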
+    batch_lod.insert(batch_lod.begin(), 0UL);
+    for (auto it = batch_lod.begin() + 1; it != batch_lod.end(); it++) {
+      *it = *it + *(it - 1);
+    }
+    auto labels_tensor = label_reader.NextBatch(
+        {static_cast<int>(batch_num_objects), 1}, batch_lod);
+    auto bbox_tensor = bbox_reader.NextBatch(
+        {static_cast<int>(batch_num_objects), 4}, batch_lod);
+    auto difficult_tensor = difficult_reader.NextBatch(
+        {static_cast<int>(batch_num_objects), 1}, batch_lod);
+
+    inputs->emplace_back(std::vector<PaddleTensor>{
+        std::move(images_tensor), std::move(bbox_tensor),
+        std::move(labels_tensor), std::move(difficult_tensor)});
+  }
+}
+
+std::shared_ptr<std::vector<PaddleTensor>> GetWarmupData(
+    const std::vector<std::vector<PaddleTensor>> &test_data,
+    int32_t num_images = FLAGS_warmup_batch_size) {
+  int test_data_batch_size = test_data[0][0].shape[0];
+  auto iterations_max = test_data.size();
+  PADDLE_ENFORCE(
+      static_cast<size_t>(num_images) <= iterations_max * test_data_batch_size,
+      "The requested quantization warmup data size " +
+          std::to_string(num_images) + " is bigger than all test data size.");
+
+  PaddleTensor images;
+  images.name = "image";
+  images.shape = {num_images, 3, 300, 300};
+  images.dtype = PaddleDType::FLOAT32;
+  images.data.Resize(sizeof(float) * num_images * 3 * 300 * 300);
+
+  int batches = num_images / test_data_batch_size;
+  int batch_remain = num_images % test_data_batch_size;
+  size_t num_objects = 0UL;
+  std::vector<size_t> accum_lod;
+  accum_lod.push_back(0UL);
+  for (int i = 0; i < batches; i++) {
+    std::transform(test_data[i][1].lod[0].begin() + 1,
+                   test_data[i][1].lod[0].end(), std::back_inserter(accum_lod),
+                   [&num_objects](size_t lodtemp) -> size_t {
+                     return lodtemp + num_objects;
+                   });
+    num_objects += test_data[i][1].lod[0][test_data_batch_size];
+  }
+  if (batch_remain > 0) {
+    std::transform(test_data[batches][1].lod[0].begin() + 1,
+                   test_data[batches][1].lod[0].begin() + batch_remain + 1,
+                   std::back_inserter(accum_lod),
+                   [&num_objects](size_t lodtemp) -> size_t {
+                     return lodtemp + num_objects;
+                   });
+    num_objects = num_objects + test_data[batches][1].lod[0][batch_remain];
+  }
+
+  PaddleTensor labels;
+  labels.name = "gt_label";
+  labels.shape = {static_cast<int>(num_objects), 1};
+  labels.dtype = PaddleDType::INT64;
+  labels.data.Resize(sizeof(int64_t) * num_objects);
+  labels.lod.push_back(accum_lod);
+
+  PaddleTensor bbox;
+  bbox.name = "gt_bbox";
+  bbox.shape = {static_cast<int>(num_objects), 4};
+  bbox.dtype = PaddleDType::FLOAT32;
+  bbox.data.Resize(sizeof(float) * num_objects * 4);
+  bbox.lod.push_back(accum_lod);
+
+  PaddleTensor difficult;
+  difficult.name = "gt_difficult";
+  difficult.shape = {static_cast<int>(num_objects), 1};
+  difficult.dtype = PaddleDType::INT64;
+  difficult.data.Resize(sizeof(int64_t) * num_objects);
+  difficult.lod.push_back(accum_lod);
+
+  size_t objects_accum = 0;
+  size_t objects_in_batch = 0;
+  for (int i = 0; i < batches; i++) {
+    objects_in_batch = test_data[i][1].lod[0][test_data_batch_size];
+    std::copy_n(static_cast<float *>(test_data[i][0].data.data()),
+                test_data_batch_size * 3 * 300 * 300,
+                static_cast<float *>(images.data.data()) +
+                    i * test_data_batch_size * 3 * 300 * 300);
+    // test_data batches hold {image, gt_bbox, gt_label, gt_difficult},
+    // so labels come from index 2 and boxes from index 1
+    std::copy_n(static_cast<int64_t *>(test_data[i][2].data.data()),
+                objects_in_batch,
+                static_cast<int64_t *>(labels.data.data()) + objects_accum);
+    std::copy_n(static_cast<float *>(test_data[i][1].data.data()),
+                objects_in_batch * 4,
+                static_cast<float *>(bbox.data.data()) + objects_accum * 4);
+    std::copy_n(static_cast<int64_t *>(test_data[i][3].data.data()),
+                objects_in_batch,
+                static_cast<int64_t *>(difficult.data.data()) + objects_accum);
+    objects_accum = objects_accum + objects_in_batch;
+  }
+
+  size_t objects_remain = test_data[batches][1].lod[0][batch_remain];
+  std::copy_n(static_cast<float *>(test_data[batches][0].data.data()),
+              batch_remain * 3 * 300 * 300,
+              static_cast<float *>(images.data.data()) +
+                  batches * test_data_batch_size * 3 * 300 * 300);
+  std::copy_n(static_cast<int64_t *>(test_data[batches][2].data.data()),
+              objects_remain,
+              static_cast<int64_t *>(labels.data.data()) + objects_accum);
+  std::copy_n(static_cast<float *>(test_data[batches][1].data.data()),
+              objects_remain * 4,
+              static_cast<float *>(bbox.data.data()) + objects_accum * 4);
+  std::copy_n(static_cast<int64_t *>(test_data[batches][3].data.data()),
+              objects_remain,
+              static_cast<int64_t *>(difficult.data.data()) + objects_accum);
+
+  objects_accum = objects_accum + objects_remain;
+  PADDLE_ENFORCE(
+      static_cast<size_t>(num_objects) == static_cast<size_t>(objects_accum),
+      "The requested num of objects " + std::to_string(num_objects) +
+          " differs from the accumulated num of objects " +
+          std::to_string(objects_accum) + ".");
+
+  auto warmup_data = std::make_shared<std::vector<PaddleTensor>>(4);
+  (*warmup_data)[0] = std::move(images);
+  (*warmup_data)[1] = std::move(bbox);
+  (*warmup_data)[2] = std::move(labels);
+  (*warmup_data)[3] = std::move(difficult);
+
+  return warmup_data;
+}
+
+TEST(Analyzer_int8_mobilenet_ssd, quantization) {
+  AnalysisConfig cfg;
+  SetConfig(&cfg);
+
+  AnalysisConfig q_cfg;
+  SetConfig(&q_cfg);
+
+  // read data from file and prepare batches with test data
+  std::vector<std::vector<PaddleTensor>> input_slots_all;
+  SetInput(&input_slots_all);
+
+  // prepare warmup batch from input data read earlier
+  // warmup batch size can be different from batch size
+  std::shared_ptr<std::vector<PaddleTensor>> warmup_data =
+      GetWarmupData(input_slots_all);
+
+  // configure quantizer
+  q_cfg.EnableMkldnnQuantizer();
+  std::unordered_set<std::string> quantize_operators(
+      {"conv2d", "depthwise_conv2d", "prior_box"});
+  q_cfg.mkldnn_quantizer_config()->SetEnabledOpTypes(quantize_operators);
+  q_cfg.mkldnn_quantizer_config()->SetWarmupData(warmup_data);
+  q_cfg.mkldnn_quantizer_config()->SetWarmupBatchSize(FLAGS_warmup_batch_size);
+
+  CompareQuantizedAndAnalysis(&cfg, &q_cfg, input_slots_all);
+}
+
+}  // namespace analysis
+}  // namespace inference
+}  // namespace paddle
diff --git a/paddle/fluid/inference/tests/api/full_pascalvoc_test_preprocess.py b/paddle/fluid/inference/tests/api/full_pascalvoc_test_preprocess.py
new file mode 100644
index 0000000000000000000000000000000000000000..2ca8e582f8cda55c27249e95092ec6ce6a1c40d0
--- /dev/null
+++ b/paddle/fluid/inference/tests/api/full_pascalvoc_test_preprocess.py
@@ -0,0 +1,187 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
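+#
+# The output binary uses the layout read by
+# analyzer_int8_object_detection_tester.cc: an int64 image count, float32
+# images (3 x 300 x 300 each), uint64 per-image object counts, then int64
+# labels, float32 boxes (4 per object) and int64 difficult flags.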
+import xml.etree.ElementTree as ET
+from PIL import Image
+import numpy as np
+import os
+import sys
+from paddle.dataset.common import download
+import tarfile
+import StringIO
+import hashlib
+
+DATA_URL = "http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtest_06-Nov-2007.tar"
+DATA_DIR = os.path.expanduser("~/.cache/paddle/dataset/pascalvoc/")
+TAR_FILE = "VOCtest_06-Nov-2007.tar"
+TAR_PATH = os.path.join(DATA_DIR, TAR_FILE)
+RESIZE_H = 300
+RESIZE_W = 300
+mean_value = [127.5, 127.5, 127.5]
+ap_version = '11point'
+DATA_OUT = 'pascalvoc_full.bin'
+DATA_OUT_PATH = os.path.join(DATA_DIR, DATA_OUT)
+BIN_TARGETHASH = "f6546cadc42f5ff13178b84ed29b740b"
+TAR_TARGETHASH = "b6e924de25625d8de591ea690078ad9f"
+TEST_LIST_KEY = "VOCdevkit/VOC2007/ImageSets/Main/test.txt"
+BIN_FULLSIZE = 5348678856
+
+
+def preprocess(img):
+    img_width, img_height = img.size
+
+    img = img.resize((RESIZE_W, RESIZE_H), Image.ANTIALIAS)
+    img = np.array(img)
+
+    # HWC to CHW
+    if len(img.shape) == 3:
+        img = np.swapaxes(img, 1, 2)
+        img = np.swapaxes(img, 1, 0)
+    # RGB to BGR
+    img = img[[2, 1, 0], :, :]
+    img = img.astype('float32')
+    img_mean = np.array(mean_value)[:, np.newaxis, np.newaxis].astype('float32')
+    img -= img_mean
+    img = img * 0.007843
+    return img
+
+
+def print_processbar(done_percentage):
+    done_filled = done_percentage * '='
+    empty_filled = (100 - done_percentage) * ' '
+    sys.stdout.write("\r[%s%s]%d%%" %
+                     (done_filled, empty_filled, done_percentage))
+    sys.stdout.flush()
+
+
+def convert_pascalvoc(tar_path, data_out_path):
+    print("Start converting ...\n")
+    images = {}
+    gt_labels = {}
+    boxes = []
+    lbls = []
+    difficults = []
+    object_nums = []
+
+    # map label to number (index)
+    label_list = [
+        "background", "aeroplane", "bicycle", "bird", "boat", "bottle", "bus",
+        "car", "cat", "chair", "cow", "diningtable", "dog", "horse",
+        "motorbike", "person", "pottedplant", "sheep", "sofa", "train",
+        "tvmonitor"
+    ]
+    print_processbar(0)
+    # read from tar file and write to bin
+    tar = tarfile.open(tar_path, "r")
+    f_test = tar.extractfile(TEST_LIST_KEY).read()
+    lines = f_test.split('\n')
+    del lines[-1]
+    line_len = len(lines)
+    per_percentage = line_len / 100
+
+    f1 = open(data_out_path, "w+b")
+    f1.seek(0)
+    f1.write(np.array(line_len).astype('int64').tobytes())
+    for tarInfo in tar:
+        if tarInfo.isfile():
+            tmp_filename = tarInfo.name
+            name_arr = tmp_filename.split('/')
+            name_prefix = name_arr[-1].split('.')[0]
+            if name_arr[-2] == 'JPEGImages' and name_prefix in lines:
+                images[name_prefix] = tar.extractfile(tarInfo).read()
+            if name_arr[-2] == 'Annotations' and name_prefix in lines:
+                gt_labels[name_prefix] = tar.extractfile(tarInfo).read()
+
+    for line_idx, name_prefix in enumerate(lines):
+        im = Image.open(StringIO.StringIO(images[name_prefix]))
+        if im.mode == 'L':
+            im = im.convert('RGB')
+        im_width, im_height = im.size
+
+        im = preprocess(im)
+        np_im = np.array(im)
+        f1.write(np_im.astype('float32').tobytes())
+
+        # layout: label | xmin | ymin | xmax | ymax | difficult
+        bbox_labels = []
+        root = ET.fromstring(gt_labels[name_prefix])
+
+        objects = root.findall('object')
+        objects_size = len(objects)
+        object_nums.append(objects_size)
+
+        for object in objects:
+            bbox_sample = []
+            bbox_sample.append(
+                float(label_list.index(object.find('name').text)))
+            bbox = object.find('bndbox')
+            difficult = float(object.find('difficult').text)
+            bbox_sample.append(float(bbox.find('xmin').text) / im_width)
+            bbox_sample.append(float(bbox.find('ymin').text) / im_height)
+            bbox_sample.append(float(bbox.find('xmax').text) / im_width)
+            bbox_sample.append(float(bbox.find('ymax').text) / im_height)
+            bbox_sample.append(difficult)
+            bbox_labels.append(bbox_sample)
+
+        bbox_labels = np.array(bbox_labels)
+        if len(bbox_labels) == 0: continue
+        lbls.extend(bbox_labels[:, 0])
+        boxes.extend(bbox_labels[:, 1:5])
+        difficults.extend(bbox_labels[:, -1])
+
+        if line_idx % per_percentage == 0:
+            print_processbar(line_idx / per_percentage)
+
+    f1.write(np.array(object_nums).astype('uint64').tobytes())
+    f1.write(np.array(lbls).astype('int64').tobytes())
+    f1.write(np.array(boxes).astype('float32').tobytes())
+    f1.write(np.array(difficults).astype('int64').tobytes())
+    f1.close()
+    print_processbar(100)
+    print("Conversion finished!\n")
+
+
+def download_pascalvoc(data_url, data_dir, tar_targethash, tar_path):
+    print("Downloading pascalvoc test set...")
+    download(data_url, data_dir, tar_targethash)
+    if not os.path.exists(tar_path):
+        print("Failed in downloading pascalvoc test set. URL %s\n" % data_url)
+    else:
+        tmp_hash = hashlib.md5(open(tar_path, 'rb').read()).hexdigest()
+        if tmp_hash != tar_targethash:
+            print("Downloaded test set is broken, removing ...\n")
+        else:
+            print("Downloaded successfully. Path: %s\n" % tar_path)
+
+
+def run_convert():
+    try_limit = 2
+    retry = 0
+    while not (os.path.exists(DATA_OUT_PATH) and
+               os.path.getsize(DATA_OUT_PATH) == BIN_FULLSIZE and BIN_TARGETHASH
+               == hashlib.md5(open(DATA_OUT_PATH, 'rb').read()).hexdigest()):
+        if os.path.exists(DATA_OUT_PATH):
+            sys.stderr.write(
+                "The existing binary file is broken. It is being removed...\n")
+            os.remove(DATA_OUT_PATH)
+        if retry < try_limit:
+            retry = retry + 1
+        else:
+            raise RuntimeError(
+                "Failed to generate %s within %d attempts." %
+                (DATA_OUT_PATH, try_limit))
+        download_pascalvoc(DATA_URL, DATA_DIR, TAR_TARGETHASH, TAR_PATH)
+        convert_pascalvoc(TAR_PATH, DATA_OUT_PATH)
+    print("Success! \nThe binary file can be found at %s\n" % DATA_OUT_PATH)
+
+
+if __name__ == "__main__":
+    run_convert()
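
A note on how the two new files fit together (a minimal sketch, not part of
the patch): the offsets below mirror the ones computed in
analyzer_int8_object_detection_tester.cc and the write order in
full_pascalvoc_test_preprocess.py; the check_layout helper and the hard-coded
path are illustrative only.

    import os
    import numpy as np

    def check_layout(path):
        """Verify the pascalvoc binary's section sizes against the file size."""
        with open(path, 'rb') as f:
            # header: total image count, a single int64
            total_images = int(np.frombuffer(f.read(8), dtype=np.int64)[0])
            # images: float32, 3 x 300 x 300 per image
            image_bytes = 4 * total_images * 3 * 300 * 300
            f.seek(image_bytes, os.SEEK_CUR)
            # per-image object counts: uint64 (read as size_t by the tester)
            counts = np.frombuffer(f.read(8 * total_images), dtype=np.uint64)
        total_objects = int(counts.sum())
        expected = (8                     # image-count header
                    + image_bytes         # images
                    + 8 * total_images    # object counts
                    + 8 * total_objects   # int64 gt_label
                    + 16 * total_objects  # float32 gt_bbox, 4 per object
                    + 8 * total_objects)  # int64 gt_difficult
        assert expected == os.path.getsize(path), "unexpected binary layout"
        return total_images, total_objects

    print(check_layout(os.path.expanduser(
        "~/.cache/paddle/dataset/pascalvoc/pascalvoc_full.bin")))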