diff --git a/paddle/fluid/inference/tests/api/full_ILSVRC2012_val_preprocess.py b/paddle/fluid/inference/tests/api/full_ILSVRC2012_val_preprocess.py index 4d968c83d9c9bf9d947204d73f4460e62039cdda..842865933f2b4741aea034b19952d4c59344ba06 100644 --- a/paddle/fluid/inference/tests/api/full_ILSVRC2012_val_preprocess.py +++ b/paddle/fluid/inference/tests/api/full_ILSVRC2012_val_preprocess.py @@ -1,5 +1,4 @@ # copyright (c) 2019 paddlepaddle authors. all rights reserved. -# # licensed under the apache license, version 2.0 (the "license"); # you may not use this file except in compliance with the license. # you may obtain a copy of the license at @@ -11,6 +10,7 @@ # without warranties or conditions of any kind, either express or implied. # see the license for the specific language governing permissions and # limitations under the license. +import hashlib import unittest import os import numpy as np @@ -21,16 +21,20 @@ import functools import contextlib from PIL import Image, ImageEnhance import math -from paddle.dataset.common import download +from paddle.dataset.common import download, md5file +import tarfile random.seed(0) np.random.seed(0) DATA_DIM = 224 - SIZE_FLOAT32 = 4 SIZE_INT64 = 8 - +FULL_SIZE_BYTES = 30106000008 +FULL_IMAGES = 50000 +DATA_DIR_NAME = 'ILSVRC2012' +IMG_DIR_NAME = 'var' +TARGET_HASH = '8dc592db6dcc8d521e4d5ba9da5ca7d2' img_mean = np.array([0.485, 0.456, 0.406]).reshape((3, 1, 1)) img_std = np.array([0.229, 0.224, 0.225]).reshape((3, 1, 1)) @@ -70,19 +74,9 @@ def process_image(img_path, mode, color_jitter, rotate): return img -def download_unzip(): - int8_download = 'int8/download' - - target_name = 'data' - - cache_folder = os.path.expanduser('~/.cache/paddle/dataset/' + - int8_download) - - target_folder = os.path.join(cache_folder, target_name) - +def download_concat(cache_folder, zip_path): data_urls = [] data_md5s = [] - data_urls.append( 'https://paddle-inference-dist.bj.bcebos.com/int8/ILSVRC2012_img_val.tar.gz.partaa' ) @@ -91,72 +85,138 @@ def download_unzip(): 'https://paddle-inference-dist.bj.bcebos.com/int8/ILSVRC2012_img_val.tar.gz.partab' ) data_md5s.append('1e9f15f64e015e58d6f9ec3210ed18b5') - file_names = [] - + print("Downloading full ImageNet Validation dataset ...") for i in range(0, len(data_urls)): download(data_urls[i], cache_folder, data_md5s[i]) - file_names.append(data_urls[i].split('/')[-1]) - - zip_path = os.path.join(cache_folder, 'full_imagenet_val.tar.gz') - + file_name = os.path.join(cache_folder, data_urls[i].split('/')[-1]) + file_names.append(file_name) + print("Downloaded part {0}\n".format(file_name)) if not os.path.exists(zip_path): - cat_command = 'cat' - for file_name in file_names: - cat_command += ' ' + os.path.join(cache_folder, file_name) - cat_command += ' > ' + zip_path - os.system(cat_command) - print('Data is downloaded at {0}\n').format(zip_path) - - if not os.path.exists(target_folder): - cmd = 'mkdir {0} && tar xf {1} -C {0}'.format(target_folder, zip_path) - os.system(cmd) - print('Data is unzipped at {0}\n'.format(target_folder)) - - data_dir = os.path.join(target_folder, 'ILSVRC2012') - print('ILSVRC2012 full val set at {0}\n'.format(data_dir)) - return data_dir + with open(zip_path, "w+") as outfile: + for fname in file_names: + with open(fname) as infile: + outfile.write(infile.read()) + + +def extract(zip_path, extract_folder): + data_dir = os.path.join(extract_folder, DATA_DIR_NAME) + img_dir = os.path.join(data_dir, IMG_DIR_NAME) + print("Extracting...\n") + + if not (os.path.exists(img_dir) and + len(os.listdir(img_dir)) == FULL_IMAGES): + tar = tarfile.open(zip_path) + tar.extractall(path=extract_folder) + tar.close() + print('Extracted. Full Imagenet Validation dataset is located at {0}\n'. + format(data_dir)) + + +def print_processbar(done, total): + done_filled = done * '=' + empty_filled = (total - done) * ' ' + percentage_done = done * 100 / total + sys.stdout.write("\r[%s%s]%d%%" % + (done_filled, empty_filled, percentage_done)) + sys.stdout.flush() + + +def check_integrity(filename, target_hash): + print('\nThe binary file exists. Checking file integrity...\n') + md = hashlib.md5() + count = 0 + total_parts = 50 + chunk_size = 8192 + onepart = FULL_SIZE_BYTES / chunk_size / total_parts + with open(filename) as ifs: + while True: + buf = ifs.read(8192) + if count % onepart == 0: + done = count / onepart + print_processbar(done, total_parts) + count = count + 1 + if not buf: + break + md.update(buf) + hash1 = md.hexdigest() + if hash1 == target_hash: + return True + else: + return False -def reader(): - data_dir = download_unzip() - file_list = os.path.join(data_dir, 'val_list.txt') - output_file = os.path.join(data_dir, 'int8_full_val.bin') +def convert(file_list, data_dir, output_file): + print('Converting 50000 images to binary file ...\n') with open(file_list) as flist: lines = [line.strip() for line in flist] num_images = len(lines) - if not os.path.exists(output_file): - print( - 'Preprocessing to binary file......\n' - ) - with open(output_file, "w+b") as of: - #save num_images(int64_t) to file - of.seek(0) - num = np.array(int(num_images)).astype('int64') - of.write(num.tobytes()) - for idx, line in enumerate(lines): - img_path, label = line.split() - img_path = os.path.join(data_dir, img_path) - if not os.path.exists(img_path): - continue - - #save image(float32) to file - img = process_image( - img_path, 'val', color_jitter=False, rotate=False) - np_img = np.array(img) - of.seek(SIZE_INT64 + SIZE_FLOAT32 * DATA_DIM * DATA_DIM * 3 - * idx) - of.write(np_img.astype('float32').tobytes()) - - #save label(int64_t) to file - label_int = (int)(label) - np_label = np.array(label_int) - of.seek(SIZE_INT64 + SIZE_FLOAT32 * DATA_DIM * DATA_DIM * 3 - * num_images + idx * SIZE_INT64) - of.write(np_label.astype('int64').tobytes()) - - print('The preprocessed binary file path {}\n'.format(output_file)) + with open(output_file, "w+b") as ofs: + #save num_images(int64_t) to file + ofs.seek(0) + num = np.array(int(num_images)).astype('int64') + ofs.write(num.tobytes()) + per_parts = 1000 + full_parts = FULL_IMAGES / per_parts + print_processbar(0, full_parts) + for idx, line in enumerate(lines): + img_path, label = line.split() + img_path = os.path.join(data_dir, img_path) + if not os.path.exists(img_path): + continue + + #save image(float32) to file + img = process_image( + img_path, 'val', color_jitter=False, rotate=False) + np_img = np.array(img) + ofs.seek(SIZE_INT64 + SIZE_FLOAT32 * DATA_DIM * DATA_DIM * 3 * + idx) + ofs.write(np_img.astype('float32').tobytes()) + ofs.flush() + + #save label(int64_t) to file + label_int = (int)(label) + np_label = np.array(label_int) + ofs.seek(SIZE_INT64 + SIZE_FLOAT32 * DATA_DIM * DATA_DIM * 3 * + num_images + idx * SIZE_INT64) + ofs.write(np_label.astype('int64').tobytes()) + ofs.flush() + if (idx + 1) % per_parts == 0: + done = (idx + 1) / per_parts + print_processbar(done, full_parts) + print("Conversion finished.") + + +def run_convert(): + print('Start to download and convert 50000 images to binary file...') + cache_folder = os.path.expanduser('~/.cache/paddle/dataset/int8/download') + extract_folder = os.path.join(cache_folder, 'full_data') + data_dir = os.path.join(extract_folder, DATA_DIR_NAME) + file_list = os.path.join(data_dir, 'val_list.txt') + zip_path = os.path.join(cache_folder, 'full_imagenet_val.tar.gz') + output_file = os.path.join(cache_folder, 'int8_full_val.bin') + retry = 0 + try_limit = 3 + + while not (os.path.exists(output_file) and + os.path.getsize(output_file) == FULL_SIZE_BYTES and + check_integrity(output_file, TARGET_HASH)): + if os.path.exists(output_file): + sys.stderr.write( + "\n\nThe existing binary file is broken. Start to generate new one...\n\n". + format(output_file)) + os.remove(output_file) + if retry < try_limit: + retry = retry + 1 + else: + raise RuntimeError( + "Can not convert the dataset to binary file with try limit {0}". + format(try_limit)) + download_concat(cache_folder, zip_path) + extract(zip_path, extract_folder) + convert(file_list, data_dir, output_file) + print("\nSuccess! The binary file can be found at {0}".format(output_file)) if __name__ == '__main__': - reader() + run_convert()