Unverified · Commit 5b1565a7 authored by Tao Luo, committed by GitHub

Merge pull request #16875 from lidanqing-intel/lidanqing/improve_preprocess_script

Improve preprocessing script and read from tar
@@ -19,10 +19,11 @@ import sys
 import random
 import functools
 import contextlib
-from PIL import Image, ImageEnhance
+from PIL import Image
 import math
-from paddle.dataset.common import download, md5file
+from paddle.dataset.common import download
 import tarfile
+import StringIO

 random.seed(0)
 np.random.seed(0)
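The added `import StringIO` is what lets the rewritten script decode images straight from bytes held in memory instead of from files extracted to disk (this is Python 2; on Python 3 the equivalent is `io.BytesIO`). A minimal sketch of the pattern, with a placeholder byte source:

```python
import StringIO

from PIL import Image

# raw_bytes stands in for the bytes of one tar member; the path is a
# placeholder for illustration only.
raw_bytes = open('some_image.jpg', 'rb').read()

# Wrap the bytes in a file-like object so PIL can decode them without
# any intermediate file on disk.
img = Image.open(StringIO.StringIO(raw_bytes))
print(img.size)
```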
@@ -32,9 +33,11 @@ SIZE_FLOAT32 = 4
 SIZE_INT64 = 8
 FULL_SIZE_BYTES = 30106000008
 FULL_IMAGES = 50000
-DATA_DIR_NAME = 'ILSVRC2012'
-IMG_DIR_NAME = 'var'
-TARGET_HASH = '8dc592db6dcc8d521e4d5ba9da5ca7d2'
+TARGET_HASH = '22d2e0008dca693916d9595a5ea3ded8'
+FOLDER_NAME = "ILSVRC2012/"
+VALLIST_TAR_NAME = "ILSVRC2012/val_list.txt"
+CHUNK_SIZE = 8192

 img_mean = np.array([0.485, 0.456, 0.406]).reshape((3, 1, 1))
 img_std = np.array([0.229, 0.224, 0.225]).reshape((3, 1, 1))
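The unchanged `FULL_SIZE_BYTES = 30106000008` is consistent with the binary layout the script produces: one int64 image count, then 50000 float32 CHW images, then 50000 int64 labels. A quick check, assuming `DATA_DIM = 224` (that constant is defined outside the visible hunks):

```python
SIZE_FLOAT32, SIZE_INT64 = 4, 8
FULL_IMAGES = 50000
DATA_DIM = 224  # assumption; defined outside the visible hunks

total = (SIZE_INT64                                              # image count header
         + FULL_IMAGES * 3 * DATA_DIM * DATA_DIM * SIZE_FLOAT32  # image data
         + FULL_IMAGES * SIZE_INT64)                             # label data
assert total == 30106000008  # matches FULL_SIZE_BYTES
```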
@@ -62,8 +65,7 @@ def crop_image(img, target_size, center):
     return img


-def process_image(img_path, mode, color_jitter, rotate):
-    img = Image.open(img_path)
+def process_image(img):
     img = resize_short(img, target_size=256)
     img = crop_image(img, target_size=DATA_DIM, center=True)
     if img.mode != 'RGB':
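The new `process_image(img)` takes an already-decoded PIL image rather than a path, and drops the unused training-mode arguments. The visible steps follow the standard ImageNet evaluation pipeline; a self-contained sketch of the same resize-crop-normalize sequence is below, with `resize_short` and `crop_image` reimplemented as assumptions since their bodies lie outside the diff:

```python
import numpy as np
from PIL import Image

img_mean = np.array([0.485, 0.456, 0.406]).reshape((3, 1, 1))
img_std = np.array([0.229, 0.224, 0.225]).reshape((3, 1, 1))


def preprocess(img, target_size=224):
    # Resize so the shorter side becomes 256 pixels, preserving aspect ratio.
    w, h = img.size
    scale = 256.0 / min(w, h)
    img = img.resize((int(round(w * scale)), int(round(h * scale))),
                     Image.BILINEAR)
    # Center-crop a target_size x target_size window.
    w, h = img.size
    left, top = (w - target_size) // 2, (h - target_size) // 2
    img = img.crop((left, top, left + target_size, top + target_size))
    if img.mode != 'RGB':
        img = img.convert('RGB')
    # HWC uint8 -> CHW float32 in [0, 1], then per-channel normalization.
    arr = np.array(img).astype('float32').transpose((2, 0, 1)) / 255.0
    return (arr - img_mean) / img_std
```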
@@ -99,26 +101,11 @@ def download_concat(cache_folder, zip_path):
             outfile.write(infile.read())


-def extract(zip_path, extract_folder):
-    data_dir = os.path.join(extract_folder, DATA_DIR_NAME)
-    img_dir = os.path.join(data_dir, IMG_DIR_NAME)
-    print("Extracting...\n")
-
-    if not (os.path.exists(img_dir) and
-            len(os.listdir(img_dir)) == FULL_IMAGES):
-        tar = tarfile.open(zip_path)
-        tar.extractall(path=extract_folder)
-        tar.close()
-    print('Extracted. Full Imagenet Validation dataset is located at {0}\n'.
-          format(data_dir))
-
-
-def print_processbar(done, total):
-    done_filled = done * '='
-    empty_filled = (total - done) * ' '
-    percentage_done = done * 100 / total
+def print_processbar(done_percentage):
+    done_filled = done_percentage * '='
+    empty_filled = (100 - done_percentage) * ' '
     sys.stdout.write("\r[%s%s]%d%%" %
-                     (done_filled, empty_filled, percentage_done))
+                     (done_filled, empty_filled, done_percentage))
     sys.stdout.flush()
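`print_processbar` now receives the percentage directly instead of deriving it from done/total counters. The carriage-return redraw it relies on, in isolation:

```python
import sys
import time

# Redraw a fixed-width bar in place by returning to column 0 with '\r'.
for pct in range(0, 101, 10):
    sys.stdout.write("\r[%s%s]%d%%" % (pct * '=', (100 - pct) * ' ', pct))
    sys.stdout.flush()
    time.sleep(0.05)  # stand-in for real work
print('')
```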
@@ -126,15 +113,13 @@ def check_integrity(filename, target_hash):
     print('\nThe binary file exists. Checking file integrity...\n')
     md = hashlib.md5()
     count = 0
-    total_parts = 50
-    chunk_size = 8192
-    onepart = FULL_SIZE_BYTES / chunk_size / total_parts
+    onepart = FULL_SIZE_BYTES / CHUNK_SIZE / 100
     with open(filename) as ifs:
         while True:
-            buf = ifs.read(8192)
+            buf = ifs.read(CHUNK_SIZE)
             if count % onepart == 0:
                 done = count / onepart
-                print_processbar(done, total_parts)
+                print_processbar(done)
             count = count + 1
             if not buf:
                 break
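`check_integrity` uses the standard pattern for hashing a file too large to hold in memory: read fixed-size chunks and feed them to an incremental MD5. A standalone version of the same idea (opening in binary mode, which is also the safer choice for the code above):

```python
import hashlib


def md5_of_file(path, chunk_size=8192):
    md = hashlib.md5()
    with open(path, 'rb') as f:  # binary mode keeps the hash byte-exact
        for chunk in iter(lambda: f.read(chunk_size), b''):
            md.update(chunk)
    return md.hexdigest()

# Usage sketch: md5_of_file(filename) == target_hash
```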
@@ -146,54 +131,61 @@ def check_integrity(filename, target_hash):
         return False


-def convert(file_list, data_dir, output_file):
+def convert(tar_file, output_file):
     print('Converting 50000 images to binary file ...\n')
-    with open(file_list) as flist:
-        lines = [line.strip() for line in flist]
-        num_images = len(lines)
-        with open(output_file, "w+b") as ofs:
-            #save num_images(int64_t) to file
-            ofs.seek(0)
-            num = np.array(int(num_images)).astype('int64')
-            ofs.write(num.tobytes())
-            per_parts = 1000
-            full_parts = FULL_IMAGES / per_parts
-            print_processbar(0, full_parts)
-            for idx, line in enumerate(lines):
-                img_path, label = line.split()
-                img_path = os.path.join(data_dir, img_path)
-                if not os.path.exists(img_path):
-                    continue
-                #save image(float32) to file
-                img = process_image(
-                    img_path, 'val', color_jitter=False, rotate=False)
-                np_img = np.array(img)
-                ofs.seek(SIZE_INT64 + SIZE_FLOAT32 * DATA_DIM * DATA_DIM * 3 *
-                         idx)
-                ofs.write(np_img.astype('float32').tobytes())
-                ofs.flush()
-                #save label(int64_t) to file
-                label_int = (int)(label)
-                np_label = np.array(label_int)
-                ofs.seek(SIZE_INT64 + SIZE_FLOAT32 * DATA_DIM * DATA_DIM * 3 *
-                         num_images + idx * SIZE_INT64)
-                ofs.write(np_label.astype('int64').tobytes())
-                ofs.flush()
-                if (idx + 1) % per_parts == 0:
-                    done = (idx + 1) / per_parts
-                    print_processbar(done, full_parts)
+    tar = tarfile.open(name=tar_file, mode='r:gz')
+    print_processbar(0)
+    dataset = {}
+    for tarInfo in tar:
+        if tarInfo.isfile() and tarInfo.name != VALLIST_TAR_NAME:
+            dataset[tarInfo.name] = tar.extractfile(tarInfo).read()
+    with open(output_file, "w+b") as ofs:
+        ofs.seek(0)
+        num = np.array(int(FULL_IMAGES)).astype('int64')
+        ofs.write(num.tobytes())
+        per_percentage = FULL_IMAGES / 100
+        idx = 0
+        for imagedata in dataset.values():
+            img = Image.open(StringIO.StringIO(imagedata))
+            img = process_image(img)
+            np_img = np.array(img)
+            ofs.write(np_img.astype('float32').tobytes())
+            if idx % per_percentage == 0:
+                print_processbar(idx / per_percentage)
+            idx = idx + 1
+        val_info = tar.getmember(VALLIST_TAR_NAME)
+        val_list = tar.extractfile(val_info).read()
+        lines = val_list.split('\n')
+        val_dict = {}
+        for line_idx, line in enumerate(lines):
+            if line_idx == FULL_IMAGES:
+                break
+            name, label = line.split()
+            val_dict[name] = label
+        for img_name in dataset.keys():
+            remove_len = (len(FOLDER_NAME))
+            img_name_prim = img_name[remove_len:]
+            label = val_dict[img_name_prim]
+            label_int = (int)(label)
+            np_label = np.array(label_int)
+            ofs.write(np_label.astype('int64').tobytes())
+        print_processbar(100)
+    tar.close()
     print("Conversion finished.")


 def run_convert():
     print('Start to download and convert 50000 images to binary file...')
     cache_folder = os.path.expanduser('~/.cache/paddle/dataset/int8/download')
-    extract_folder = os.path.join(cache_folder, 'full_data')
-    data_dir = os.path.join(extract_folder, DATA_DIR_NAME)
-    file_list = os.path.join(data_dir, 'val_list.txt')
-    zip_path = os.path.join(cache_folder, 'full_imagenet_val.tar.gz')
+    zip_path = os.path.join(cache_folder, 'full_imagenet_val.tar.gz.partaa')
     output_file = os.path.join(cache_folder, 'int8_full_val.bin')
     retry = 0
     try_limit = 3
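The heart of the rewrite is that `convert` now walks the gzipped tar directly instead of extracting 50000 files to disk first. The tarfile iteration pattern it relies on, in isolation (the archive path is a placeholder):

```python
import tarfile

# Stream members of a gzipped tar sequentially, without extracting to disk.
with tarfile.open('archive.tar.gz', mode='r:gz') as tar:  # placeholder path
    for info in tar:
        if info.isfile():
            data = tar.extractfile(info).read()  # raw bytes of this member
            print(info.name, len(data))
```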
@@ -213,8 +205,7 @@ def run_convert():
             "Can not convert the dataset to binary file with try limit {0}".
             format(try_limit))
     download_concat(cache_folder, zip_path)
-    extract(zip_path, extract_folder)
-    convert(file_list, data_dir, output_file)
+    convert(zip_path, output_file)
     print("\nSuccess! The binary file can be found at {0}".format(output_file))