提交 de02d40e 编写于 作者: L lidanqing

improve preprocess script and read from tar

test=develop
上级 a67fbffd
......@@ -19,10 +19,11 @@ import sys
import random
import functools
import contextlib
from PIL import Image, ImageEnhance
from PIL import Image
import math
from paddle.dataset.common import download, md5file
from paddle.dataset.common import download
import tarfile
import StringIO
random.seed(0)
np.random.seed(0)
......@@ -32,9 +33,11 @@ SIZE_FLOAT32 = 4
SIZE_INT64 = 8
FULL_SIZE_BYTES = 30106000008
FULL_IMAGES = 50000
DATA_DIR_NAME = 'ILSVRC2012'
IMG_DIR_NAME = 'var'
TARGET_HASH = '8dc592db6dcc8d521e4d5ba9da5ca7d2'
TARGET_HASH = '22d2e0008dca693916d9595a5ea3ded8'
FOLDER_NAME = "ILSVRC2012/"
VALLIST_TAR_NAME = "ILSVRC2012/val_list.txt"
CHUNK_SIZE = 8192
img_mean = np.array([0.485, 0.456, 0.406]).reshape((3, 1, 1))
img_std = np.array([0.229, 0.224, 0.225]).reshape((3, 1, 1))
......@@ -62,8 +65,7 @@ def crop_image(img, target_size, center):
return img
def process_image(img_path, mode, color_jitter, rotate):
img = Image.open(img_path)
def process_image(img):
img = resize_short(img, target_size=256)
img = crop_image(img, target_size=DATA_DIM, center=True)
if img.mode != 'RGB':
......@@ -99,26 +101,11 @@ def download_concat(cache_folder, zip_path):
outfile.write(infile.read())
def extract(zip_path, extract_folder):
data_dir = os.path.join(extract_folder, DATA_DIR_NAME)
img_dir = os.path.join(data_dir, IMG_DIR_NAME)
print("Extracting...\n")
if not (os.path.exists(img_dir) and
len(os.listdir(img_dir)) == FULL_IMAGES):
tar = tarfile.open(zip_path)
tar.extractall(path=extract_folder)
tar.close()
print('Extracted. Full Imagenet Validation dataset is located at {0}\n'.
format(data_dir))
def print_processbar(done, total):
done_filled = done * '='
empty_filled = (total - done) * ' '
percentage_done = done * 100 / total
def print_processbar(done_percentage):
done_filled = done_percentage * '='
empty_filled = (100 - done_percentage) * ' '
sys.stdout.write("\r[%s%s]%d%%" %
(done_filled, empty_filled, percentage_done))
(done_filled, empty_filled, done_percentage))
sys.stdout.flush()
......@@ -126,15 +113,13 @@ def check_integrity(filename, target_hash):
print('\nThe binary file exists. Checking file integrity...\n')
md = hashlib.md5()
count = 0
total_parts = 50
chunk_size = 8192
onepart = FULL_SIZE_BYTES / chunk_size / total_parts
onepart = FULL_SIZE_BYTES / CHUNK_SIZE / 100
with open(filename) as ifs:
while True:
buf = ifs.read(8192)
buf = ifs.read(CHUNK_SIZE)
if count % onepart == 0:
done = count / onepart
print_processbar(done, total_parts)
print_processbar(done)
count = count + 1
if not buf:
break
......@@ -146,54 +131,61 @@ def check_integrity(filename, target_hash):
return False
def convert(file_list, data_dir, output_file):
def convert(tar_file, output_file):
print('Converting 50000 images to binary file ...\n')
with open(file_list) as flist:
lines = [line.strip() for line in flist]
num_images = len(lines)
tar = tarfile.open(name=tar_file, mode='r:gz')
print_processbar(0)
dataset = {}
for tarInfo in tar:
if tarInfo.isfile() and tarInfo.name != VALLIST_TAR_NAME:
dataset[tarInfo.name] = tar.extractfile(tarInfo).read()
with open(output_file, "w+b") as ofs:
#save num_images(int64_t) to file
ofs.seek(0)
num = np.array(int(num_images)).astype('int64')
num = np.array(int(FULL_IMAGES)).astype('int64')
ofs.write(num.tobytes())
per_parts = 1000
full_parts = FULL_IMAGES / per_parts
print_processbar(0, full_parts)
for idx, line in enumerate(lines):
img_path, label = line.split()
img_path = os.path.join(data_dir, img_path)
if not os.path.exists(img_path):
continue
#save image(float32) to file
img = process_image(
img_path, 'val', color_jitter=False, rotate=False)
per_percentage = FULL_IMAGES / 100
idx = 0
for imagedata in dataset.values():
img = Image.open(StringIO.StringIO(imagedata))
img = process_image(img)
np_img = np.array(img)
ofs.seek(SIZE_INT64 + SIZE_FLOAT32 * DATA_DIM * DATA_DIM * 3 *
idx)
ofs.write(np_img.astype('float32').tobytes())
ofs.flush()
if idx % per_percentage == 0:
print_processbar(idx / per_percentage)
idx = idx + 1
val_info = tar.getmember(VALLIST_TAR_NAME)
val_list = tar.extractfile(val_info).read()
#save label(int64_t) to file
lines = val_list.split('\n')
val_dict = {}
for line_idx, line in enumerate(lines):
if line_idx == FULL_IMAGES:
break
name, label = line.split()
val_dict[name] = label
for img_name in dataset.keys():
remove_len = (len(FOLDER_NAME))
img_name_prim = img_name[remove_len:]
label = val_dict[img_name_prim]
label_int = (int)(label)
np_label = np.array(label_int)
ofs.seek(SIZE_INT64 + SIZE_FLOAT32 * DATA_DIM * DATA_DIM * 3 *
num_images + idx * SIZE_INT64)
ofs.write(np_label.astype('int64').tobytes())
ofs.flush()
if (idx + 1) % per_parts == 0:
done = (idx + 1) / per_parts
print_processbar(done, full_parts)
print_processbar(100)
tar.close()
print("Conversion finished.")
def run_convert():
print('Start to download and convert 50000 images to binary file...')
cache_folder = os.path.expanduser('~/.cache/paddle/dataset/int8/download')
extract_folder = os.path.join(cache_folder, 'full_data')
data_dir = os.path.join(extract_folder, DATA_DIR_NAME)
file_list = os.path.join(data_dir, 'val_list.txt')
zip_path = os.path.join(cache_folder, 'full_imagenet_val.tar.gz')
zip_path = os.path.join(cache_folder, 'full_imagenet_val.tar.gz.partaa')
output_file = os.path.join(cache_folder, 'int8_full_val.bin')
retry = 0
try_limit = 3
......@@ -213,8 +205,7 @@ def run_convert():
"Can not convert the dataset to binary file with try limit {0}".
format(try_limit))
download_concat(cache_folder, zip_path)
extract(zip_path, extract_folder)
convert(file_list, data_dir, output_file)
convert(zip_path, output_file)
print("\nSuccess! The binary file can be found at {0}".format(output_file))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册