From de02d40e98ba55ca6dc9ea9b88ffb527dce60aab Mon Sep 17 00:00:00 2001 From: lidanqing Date: Mon, 15 Apr 2019 05:23:21 +0200 Subject: [PATCH] improve preprocess script and read from tar test=develop --- .../api/full_ILSVRC2012_val_preprocess.py | 139 ++++++++---------- 1 file changed, 65 insertions(+), 74 deletions(-) diff --git a/paddle/fluid/inference/tests/api/full_ILSVRC2012_val_preprocess.py b/paddle/fluid/inference/tests/api/full_ILSVRC2012_val_preprocess.py index 842865933f2..826c45311f4 100644 --- a/paddle/fluid/inference/tests/api/full_ILSVRC2012_val_preprocess.py +++ b/paddle/fluid/inference/tests/api/full_ILSVRC2012_val_preprocess.py @@ -19,10 +19,11 @@ import sys import random import functools import contextlib -from PIL import Image, ImageEnhance +from PIL import Image import math -from paddle.dataset.common import download, md5file +from paddle.dataset.common import download import tarfile +import StringIO random.seed(0) np.random.seed(0) @@ -32,9 +33,11 @@ SIZE_FLOAT32 = 4 SIZE_INT64 = 8 FULL_SIZE_BYTES = 30106000008 FULL_IMAGES = 50000 -DATA_DIR_NAME = 'ILSVRC2012' -IMG_DIR_NAME = 'var' -TARGET_HASH = '8dc592db6dcc8d521e4d5ba9da5ca7d2' +TARGET_HASH = '22d2e0008dca693916d9595a5ea3ded8' +FOLDER_NAME = "ILSVRC2012/" +VALLIST_TAR_NAME = "ILSVRC2012/val_list.txt" +CHUNK_SIZE = 8192 + img_mean = np.array([0.485, 0.456, 0.406]).reshape((3, 1, 1)) img_std = np.array([0.229, 0.224, 0.225]).reshape((3, 1, 1)) @@ -62,8 +65,7 @@ def crop_image(img, target_size, center): return img -def process_image(img_path, mode, color_jitter, rotate): - img = Image.open(img_path) +def process_image(img): img = resize_short(img, target_size=256) img = crop_image(img, target_size=DATA_DIM, center=True) if img.mode != 'RGB': @@ -99,26 +101,11 @@ def download_concat(cache_folder, zip_path): outfile.write(infile.read()) -def extract(zip_path, extract_folder): - data_dir = os.path.join(extract_folder, DATA_DIR_NAME) - img_dir = os.path.join(data_dir, IMG_DIR_NAME) - print("Extracting...\n") - - if not (os.path.exists(img_dir) and - len(os.listdir(img_dir)) == FULL_IMAGES): - tar = tarfile.open(zip_path) - tar.extractall(path=extract_folder) - tar.close() - print('Extracted. Full Imagenet Validation dataset is located at {0}\n'. - format(data_dir)) - - -def print_processbar(done, total): - done_filled = done * '=' - empty_filled = (total - done) * ' ' - percentage_done = done * 100 / total +def print_processbar(done_percentage): + done_filled = done_percentage * '=' + empty_filled = (100 - done_percentage) * ' ' sys.stdout.write("\r[%s%s]%d%%" % - (done_filled, empty_filled, percentage_done)) + (done_filled, empty_filled, done_percentage)) sys.stdout.flush() @@ -126,15 +113,13 @@ def check_integrity(filename, target_hash): print('\nThe binary file exists. Checking file integrity...\n') md = hashlib.md5() count = 0 - total_parts = 50 - chunk_size = 8192 - onepart = FULL_SIZE_BYTES / chunk_size / total_parts + onepart = FULL_SIZE_BYTES / CHUNK_SIZE / 100 with open(filename) as ifs: while True: - buf = ifs.read(8192) + buf = ifs.read(CHUNK_SIZE) if count % onepart == 0: done = count / onepart - print_processbar(done, total_parts) + print_processbar(done) count = count + 1 if not buf: break @@ -146,54 +131,61 @@ def check_integrity(filename, target_hash): return False -def convert(file_list, data_dir, output_file): +def convert(tar_file, output_file): print('Converting 50000 images to binary file ...\n') - with open(file_list) as flist: - lines = [line.strip() for line in flist] - num_images = len(lines) - with open(output_file, "w+b") as ofs: - #save num_images(int64_t) to file - ofs.seek(0) - num = np.array(int(num_images)).astype('int64') - ofs.write(num.tobytes()) - per_parts = 1000 - full_parts = FULL_IMAGES / per_parts - print_processbar(0, full_parts) - for idx, line in enumerate(lines): - img_path, label = line.split() - img_path = os.path.join(data_dir, img_path) - if not os.path.exists(img_path): - continue - - #save image(float32) to file - img = process_image( - img_path, 'val', color_jitter=False, rotate=False) - np_img = np.array(img) - ofs.seek(SIZE_INT64 + SIZE_FLOAT32 * DATA_DIM * DATA_DIM * 3 * - idx) - ofs.write(np_img.astype('float32').tobytes()) - ofs.flush() - - #save label(int64_t) to file - label_int = (int)(label) - np_label = np.array(label_int) - ofs.seek(SIZE_INT64 + SIZE_FLOAT32 * DATA_DIM * DATA_DIM * 3 * - num_images + idx * SIZE_INT64) - ofs.write(np_label.astype('int64').tobytes()) - ofs.flush() - if (idx + 1) % per_parts == 0: - done = (idx + 1) / per_parts - print_processbar(done, full_parts) + tar = tarfile.open(name=tar_file, mode='r:gz') + + print_processbar(0) + + dataset = {} + for tarInfo in tar: + if tarInfo.isfile() and tarInfo.name != VALLIST_TAR_NAME: + dataset[tarInfo.name] = tar.extractfile(tarInfo).read() + + with open(output_file, "w+b") as ofs: + ofs.seek(0) + num = np.array(int(FULL_IMAGES)).astype('int64') + ofs.write(num.tobytes()) + + per_percentage = FULL_IMAGES / 100 + + idx = 0 + for imagedata in dataset.values(): + img = Image.open(StringIO.StringIO(imagedata)) + img = process_image(img) + np_img = np.array(img) + ofs.write(np_img.astype('float32').tobytes()) + if idx % per_percentage == 0: + print_processbar(idx / per_percentage) + idx = idx + 1 + + val_info = tar.getmember(VALLIST_TAR_NAME) + val_list = tar.extractfile(val_info).read() + + lines = val_list.split('\n') + val_dict = {} + for line_idx, line in enumerate(lines): + if line_idx == FULL_IMAGES: + break + name, label = line.split() + val_dict[name] = label + + for img_name in dataset.keys(): + remove_len = (len(FOLDER_NAME)) + img_name_prim = img_name[remove_len:] + label = val_dict[img_name_prim] + label_int = (int)(label) + np_label = np.array(label_int) + ofs.write(np_label.astype('int64').tobytes()) + print_processbar(100) + tar.close() print("Conversion finished.") def run_convert(): print('Start to download and convert 50000 images to binary file...') cache_folder = os.path.expanduser('~/.cache/paddle/dataset/int8/download') - extract_folder = os.path.join(cache_folder, 'full_data') - data_dir = os.path.join(extract_folder, DATA_DIR_NAME) - file_list = os.path.join(data_dir, 'val_list.txt') - zip_path = os.path.join(cache_folder, 'full_imagenet_val.tar.gz') + zip_path = os.path.join(cache_folder, 'full_imagenet_val.tar.gz.partaa') output_file = os.path.join(cache_folder, 'int8_full_val.bin') retry = 0 try_limit = 3 @@ -213,8 +205,7 @@ def run_convert(): "Can not convert the dataset to binary file with try limit {0}". format(try_limit)) download_concat(cache_folder, zip_path) - extract(zip_path, extract_folder) - convert(file_list, data_dir, output_file) + convert(zip_path, output_file) print("\nSuccess! The binary file can be found at {0}".format(output_file)) -- GitLab