提交 de02d40e 编写于 作者: L lidanqing

improve preprocess script and read from tar

test=develop
上级 a67fbffd
...@@ -19,10 +19,11 @@ import sys ...@@ -19,10 +19,11 @@ import sys
import random import random
import functools import functools
import contextlib import contextlib
from PIL import Image, ImageEnhance from PIL import Image
import math import math
from paddle.dataset.common import download, md5file from paddle.dataset.common import download
import tarfile import tarfile
import StringIO
random.seed(0) random.seed(0)
np.random.seed(0) np.random.seed(0)
...@@ -32,9 +33,11 @@ SIZE_FLOAT32 = 4 ...@@ -32,9 +33,11 @@ SIZE_FLOAT32 = 4
SIZE_INT64 = 8 SIZE_INT64 = 8
FULL_SIZE_BYTES = 30106000008 FULL_SIZE_BYTES = 30106000008
FULL_IMAGES = 50000 FULL_IMAGES = 50000
DATA_DIR_NAME = 'ILSVRC2012' TARGET_HASH = '22d2e0008dca693916d9595a5ea3ded8'
IMG_DIR_NAME = 'var' FOLDER_NAME = "ILSVRC2012/"
TARGET_HASH = '8dc592db6dcc8d521e4d5ba9da5ca7d2' VALLIST_TAR_NAME = "ILSVRC2012/val_list.txt"
CHUNK_SIZE = 8192
img_mean = np.array([0.485, 0.456, 0.406]).reshape((3, 1, 1)) img_mean = np.array([0.485, 0.456, 0.406]).reshape((3, 1, 1))
img_std = np.array([0.229, 0.224, 0.225]).reshape((3, 1, 1)) img_std = np.array([0.229, 0.224, 0.225]).reshape((3, 1, 1))
...@@ -62,8 +65,7 @@ def crop_image(img, target_size, center): ...@@ -62,8 +65,7 @@ def crop_image(img, target_size, center):
return img return img
def process_image(img_path, mode, color_jitter, rotate): def process_image(img):
img = Image.open(img_path)
img = resize_short(img, target_size=256) img = resize_short(img, target_size=256)
img = crop_image(img, target_size=DATA_DIM, center=True) img = crop_image(img, target_size=DATA_DIM, center=True)
if img.mode != 'RGB': if img.mode != 'RGB':
...@@ -99,26 +101,11 @@ def download_concat(cache_folder, zip_path): ...@@ -99,26 +101,11 @@ def download_concat(cache_folder, zip_path):
outfile.write(infile.read()) outfile.write(infile.read())
def extract(zip_path, extract_folder): def print_processbar(done_percentage):
data_dir = os.path.join(extract_folder, DATA_DIR_NAME) done_filled = done_percentage * '='
img_dir = os.path.join(data_dir, IMG_DIR_NAME) empty_filled = (100 - done_percentage) * ' '
print("Extracting...\n")
if not (os.path.exists(img_dir) and
len(os.listdir(img_dir)) == FULL_IMAGES):
tar = tarfile.open(zip_path)
tar.extractall(path=extract_folder)
tar.close()
print('Extracted. Full Imagenet Validation dataset is located at {0}\n'.
format(data_dir))
def print_processbar(done, total):
done_filled = done * '='
empty_filled = (total - done) * ' '
percentage_done = done * 100 / total
sys.stdout.write("\r[%s%s]%d%%" % sys.stdout.write("\r[%s%s]%d%%" %
(done_filled, empty_filled, percentage_done)) (done_filled, empty_filled, done_percentage))
sys.stdout.flush() sys.stdout.flush()
...@@ -126,15 +113,13 @@ def check_integrity(filename, target_hash): ...@@ -126,15 +113,13 @@ def check_integrity(filename, target_hash):
print('\nThe binary file exists. Checking file integrity...\n') print('\nThe binary file exists. Checking file integrity...\n')
md = hashlib.md5() md = hashlib.md5()
count = 0 count = 0
total_parts = 50 onepart = FULL_SIZE_BYTES / CHUNK_SIZE / 100
chunk_size = 8192
onepart = FULL_SIZE_BYTES / chunk_size / total_parts
with open(filename) as ifs: with open(filename) as ifs:
while True: while True:
buf = ifs.read(8192) buf = ifs.read(CHUNK_SIZE)
if count % onepart == 0: if count % onepart == 0:
done = count / onepart done = count / onepart
print_processbar(done, total_parts) print_processbar(done)
count = count + 1 count = count + 1
if not buf: if not buf:
break break
...@@ -146,54 +131,61 @@ def check_integrity(filename, target_hash): ...@@ -146,54 +131,61 @@ def check_integrity(filename, target_hash):
return False return False
def convert(file_list, data_dir, output_file): def convert(tar_file, output_file):
print('Converting 50000 images to binary file ...\n') print('Converting 50000 images to binary file ...\n')
with open(file_list) as flist: tar = tarfile.open(name=tar_file, mode='r:gz')
lines = [line.strip() for line in flist]
num_images = len(lines) print_processbar(0)
with open(output_file, "w+b") as ofs:
#save num_images(int64_t) to file dataset = {}
ofs.seek(0) for tarInfo in tar:
num = np.array(int(num_images)).astype('int64') if tarInfo.isfile() and tarInfo.name != VALLIST_TAR_NAME:
ofs.write(num.tobytes()) dataset[tarInfo.name] = tar.extractfile(tarInfo).read()
per_parts = 1000
full_parts = FULL_IMAGES / per_parts with open(output_file, "w+b") as ofs:
print_processbar(0, full_parts) ofs.seek(0)
for idx, line in enumerate(lines): num = np.array(int(FULL_IMAGES)).astype('int64')
img_path, label = line.split() ofs.write(num.tobytes())
img_path = os.path.join(data_dir, img_path)
if not os.path.exists(img_path): per_percentage = FULL_IMAGES / 100
continue
idx = 0
#save image(float32) to file for imagedata in dataset.values():
img = process_image( img = Image.open(StringIO.StringIO(imagedata))
img_path, 'val', color_jitter=False, rotate=False) img = process_image(img)
np_img = np.array(img) np_img = np.array(img)
ofs.seek(SIZE_INT64 + SIZE_FLOAT32 * DATA_DIM * DATA_DIM * 3 * ofs.write(np_img.astype('float32').tobytes())
idx) if idx % per_percentage == 0:
ofs.write(np_img.astype('float32').tobytes()) print_processbar(idx / per_percentage)
ofs.flush() idx = idx + 1
#save label(int64_t) to file val_info = tar.getmember(VALLIST_TAR_NAME)
label_int = (int)(label) val_list = tar.extractfile(val_info).read()
np_label = np.array(label_int)
ofs.seek(SIZE_INT64 + SIZE_FLOAT32 * DATA_DIM * DATA_DIM * 3 * lines = val_list.split('\n')
num_images + idx * SIZE_INT64) val_dict = {}
ofs.write(np_label.astype('int64').tobytes()) for line_idx, line in enumerate(lines):
ofs.flush() if line_idx == FULL_IMAGES:
if (idx + 1) % per_parts == 0: break
done = (idx + 1) / per_parts name, label = line.split()
print_processbar(done, full_parts) val_dict[name] = label
for img_name in dataset.keys():
remove_len = (len(FOLDER_NAME))
img_name_prim = img_name[remove_len:]
label = val_dict[img_name_prim]
label_int = (int)(label)
np_label = np.array(label_int)
ofs.write(np_label.astype('int64').tobytes())
print_processbar(100)
tar.close()
print("Conversion finished.") print("Conversion finished.")
def run_convert(): def run_convert():
print('Start to download and convert 50000 images to binary file...') print('Start to download and convert 50000 images to binary file...')
cache_folder = os.path.expanduser('~/.cache/paddle/dataset/int8/download') cache_folder = os.path.expanduser('~/.cache/paddle/dataset/int8/download')
extract_folder = os.path.join(cache_folder, 'full_data') zip_path = os.path.join(cache_folder, 'full_imagenet_val.tar.gz.partaa')
data_dir = os.path.join(extract_folder, DATA_DIR_NAME)
file_list = os.path.join(data_dir, 'val_list.txt')
zip_path = os.path.join(cache_folder, 'full_imagenet_val.tar.gz')
output_file = os.path.join(cache_folder, 'int8_full_val.bin') output_file = os.path.join(cache_folder, 'int8_full_val.bin')
retry = 0 retry = 0
try_limit = 3 try_limit = 3
...@@ -213,8 +205,7 @@ def run_convert(): ...@@ -213,8 +205,7 @@ def run_convert():
"Can not convert the dataset to binary file with try limit {0}". "Can not convert the dataset to binary file with try limit {0}".
format(try_limit)) format(try_limit))
download_concat(cache_folder, zip_path) download_concat(cache_folder, zip_path)
extract(zip_path, extract_folder) convert(zip_path, output_file)
convert(file_list, data_dir, output_file)
print("\nSuccess! The binary file can be found at {0}".format(output_file)) print("\nSuccess! The binary file can be found at {0}".format(output_file))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册