提交 2ca0de3c 编写于 作者: L lidanqing 提交者: Tao Luo

fix preprocess script with processbar, integrity check and logs (#16608)

* fix preprocess script with processbar, integrity check and logs
test=develop

* delete unnecessary empty lines, change function name
test=develop
上级 c797aed8
# copyright (c) 2019 paddlepaddle authors. all rights reserved. # copyright (c) 2019 paddlepaddle authors. all rights reserved.
#
# licensed under the apache license, version 2.0 (the "license"); # licensed under the apache license, version 2.0 (the "license");
# you may not use this file except in compliance with the license. # you may not use this file except in compliance with the license.
# you may obtain a copy of the license at # you may obtain a copy of the license at
...@@ -11,6 +10,7 @@ ...@@ -11,6 +10,7 @@
# without warranties or conditions of any kind, either express or implied. # without warranties or conditions of any kind, either express or implied.
# see the license for the specific language governing permissions and # see the license for the specific language governing permissions and
# limitations under the license. # limitations under the license.
import hashlib
import unittest import unittest
import os import os
import numpy as np import numpy as np
...@@ -21,16 +21,20 @@ import functools ...@@ -21,16 +21,20 @@ import functools
import contextlib import contextlib
from PIL import Image, ImageEnhance from PIL import Image, ImageEnhance
import math import math
from paddle.dataset.common import download from paddle.dataset.common import download, md5file
import tarfile
random.seed(0) random.seed(0)
np.random.seed(0) np.random.seed(0)
DATA_DIM = 224 DATA_DIM = 224
SIZE_FLOAT32 = 4 SIZE_FLOAT32 = 4
SIZE_INT64 = 8 SIZE_INT64 = 8
FULL_SIZE_BYTES = 30106000008
FULL_IMAGES = 50000
DATA_DIR_NAME = 'ILSVRC2012'
IMG_DIR_NAME = 'var'
TARGET_HASH = '8dc592db6dcc8d521e4d5ba9da5ca7d2'
img_mean = np.array([0.485, 0.456, 0.406]).reshape((3, 1, 1)) img_mean = np.array([0.485, 0.456, 0.406]).reshape((3, 1, 1))
img_std = np.array([0.229, 0.224, 0.225]).reshape((3, 1, 1)) img_std = np.array([0.229, 0.224, 0.225]).reshape((3, 1, 1))
...@@ -70,19 +74,9 @@ def process_image(img_path, mode, color_jitter, rotate): ...@@ -70,19 +74,9 @@ def process_image(img_path, mode, color_jitter, rotate):
return img return img
def download_unzip(): def download_concat(cache_folder, zip_path):
int8_download = 'int8/download'
target_name = 'data'
cache_folder = os.path.expanduser('~/.cache/paddle/dataset/' +
int8_download)
target_folder = os.path.join(cache_folder, target_name)
data_urls = [] data_urls = []
data_md5s = [] data_md5s = []
data_urls.append( data_urls.append(
'https://paddle-inference-dist.bj.bcebos.com/int8/ILSVRC2012_img_val.tar.gz.partaa' 'https://paddle-inference-dist.bj.bcebos.com/int8/ILSVRC2012_img_val.tar.gz.partaa'
) )
...@@ -91,72 +85,138 @@ def download_unzip(): ...@@ -91,72 +85,138 @@ def download_unzip():
'https://paddle-inference-dist.bj.bcebos.com/int8/ILSVRC2012_img_val.tar.gz.partab' 'https://paddle-inference-dist.bj.bcebos.com/int8/ILSVRC2012_img_val.tar.gz.partab'
) )
data_md5s.append('1e9f15f64e015e58d6f9ec3210ed18b5') data_md5s.append('1e9f15f64e015e58d6f9ec3210ed18b5')
file_names = [] file_names = []
print("Downloading full ImageNet Validation dataset ...")
for i in range(0, len(data_urls)): for i in range(0, len(data_urls)):
download(data_urls[i], cache_folder, data_md5s[i]) download(data_urls[i], cache_folder, data_md5s[i])
file_names.append(data_urls[i].split('/')[-1]) file_name = os.path.join(cache_folder, data_urls[i].split('/')[-1])
file_names.append(file_name)
zip_path = os.path.join(cache_folder, 'full_imagenet_val.tar.gz') print("Downloaded part {0}\n".format(file_name))
if not os.path.exists(zip_path): if not os.path.exists(zip_path):
cat_command = 'cat' with open(zip_path, "w+") as outfile:
for file_name in file_names: for fname in file_names:
cat_command += ' ' + os.path.join(cache_folder, file_name) with open(fname) as infile:
cat_command += ' > ' + zip_path outfile.write(infile.read())
os.system(cat_command)
print('Data is downloaded at {0}\n').format(zip_path)
def extract(zip_path, extract_folder):
if not os.path.exists(target_folder): data_dir = os.path.join(extract_folder, DATA_DIR_NAME)
cmd = 'mkdir {0} && tar xf {1} -C {0}'.format(target_folder, zip_path) img_dir = os.path.join(data_dir, IMG_DIR_NAME)
os.system(cmd) print("Extracting...\n")
print('Data is unzipped at {0}\n'.format(target_folder))
if not (os.path.exists(img_dir) and
data_dir = os.path.join(target_folder, 'ILSVRC2012') len(os.listdir(img_dir)) == FULL_IMAGES):
print('ILSVRC2012 full val set at {0}\n'.format(data_dir)) tar = tarfile.open(zip_path)
return data_dir tar.extractall(path=extract_folder)
tar.close()
print('Extracted. Full Imagenet Validation dataset is located at {0}\n'.
format(data_dir))
def print_processbar(done, total):
done_filled = done * '='
empty_filled = (total - done) * ' '
percentage_done = done * 100 / total
sys.stdout.write("\r[%s%s]%d%%" %
(done_filled, empty_filled, percentage_done))
sys.stdout.flush()
def check_integrity(filename, target_hash):
print('\nThe binary file exists. Checking file integrity...\n')
md = hashlib.md5()
count = 0
total_parts = 50
chunk_size = 8192
onepart = FULL_SIZE_BYTES / chunk_size / total_parts
with open(filename) as ifs:
while True:
buf = ifs.read(8192)
if count % onepart == 0:
done = count / onepart
print_processbar(done, total_parts)
count = count + 1
if not buf:
break
md.update(buf)
hash1 = md.hexdigest()
if hash1 == target_hash:
return True
else:
return False
def reader(): def convert(file_list, data_dir, output_file):
data_dir = download_unzip() print('Converting 50000 images to binary file ...\n')
file_list = os.path.join(data_dir, 'val_list.txt')
output_file = os.path.join(data_dir, 'int8_full_val.bin')
with open(file_list) as flist: with open(file_list) as flist:
lines = [line.strip() for line in flist] lines = [line.strip() for line in flist]
num_images = len(lines) num_images = len(lines)
if not os.path.exists(output_file): with open(output_file, "w+b") as ofs:
print( #save num_images(int64_t) to file
'Preprocessing to binary file...<num_images><all images><all labels>...\n' ofs.seek(0)
) num = np.array(int(num_images)).astype('int64')
with open(output_file, "w+b") as of: ofs.write(num.tobytes())
#save num_images(int64_t) to file per_parts = 1000
of.seek(0) full_parts = FULL_IMAGES / per_parts
num = np.array(int(num_images)).astype('int64') print_processbar(0, full_parts)
of.write(num.tobytes()) for idx, line in enumerate(lines):
for idx, line in enumerate(lines): img_path, label = line.split()
img_path, label = line.split() img_path = os.path.join(data_dir, img_path)
img_path = os.path.join(data_dir, img_path) if not os.path.exists(img_path):
if not os.path.exists(img_path): continue
continue
#save image(float32) to file
#save image(float32) to file img = process_image(
img = process_image( img_path, 'val', color_jitter=False, rotate=False)
img_path, 'val', color_jitter=False, rotate=False) np_img = np.array(img)
np_img = np.array(img) ofs.seek(SIZE_INT64 + SIZE_FLOAT32 * DATA_DIM * DATA_DIM * 3 *
of.seek(SIZE_INT64 + SIZE_FLOAT32 * DATA_DIM * DATA_DIM * 3 idx)
* idx) ofs.write(np_img.astype('float32').tobytes())
of.write(np_img.astype('float32').tobytes()) ofs.flush()
#save label(int64_t) to file #save label(int64_t) to file
label_int = (int)(label) label_int = (int)(label)
np_label = np.array(label_int) np_label = np.array(label_int)
of.seek(SIZE_INT64 + SIZE_FLOAT32 * DATA_DIM * DATA_DIM * 3 ofs.seek(SIZE_INT64 + SIZE_FLOAT32 * DATA_DIM * DATA_DIM * 3 *
* num_images + idx * SIZE_INT64) num_images + idx * SIZE_INT64)
of.write(np_label.astype('int64').tobytes()) ofs.write(np_label.astype('int64').tobytes())
ofs.flush()
print('The preprocessed binary file path {}\n'.format(output_file)) if (idx + 1) % per_parts == 0:
done = (idx + 1) / per_parts
print_processbar(done, full_parts)
print("Conversion finished.")
def run_convert():
print('Start to download and convert 50000 images to binary file...')
cache_folder = os.path.expanduser('~/.cache/paddle/dataset/int8/download')
extract_folder = os.path.join(cache_folder, 'full_data')
data_dir = os.path.join(extract_folder, DATA_DIR_NAME)
file_list = os.path.join(data_dir, 'val_list.txt')
zip_path = os.path.join(cache_folder, 'full_imagenet_val.tar.gz')
output_file = os.path.join(cache_folder, 'int8_full_val.bin')
retry = 0
try_limit = 3
while not (os.path.exists(output_file) and
os.path.getsize(output_file) == FULL_SIZE_BYTES and
check_integrity(output_file, TARGET_HASH)):
if os.path.exists(output_file):
sys.stderr.write(
"\n\nThe existing binary file is broken. Start to generate new one...\n\n".
format(output_file))
os.remove(output_file)
if retry < try_limit:
retry = retry + 1
else:
raise RuntimeError(
"Can not convert the dataset to binary file with try limit {0}".
format(try_limit))
download_concat(cache_folder, zip_path)
extract(zip_path, extract_folder)
convert(file_list, data_dir, output_file)
print("\nSuccess! The binary file can be found at {0}".format(output_file))
if __name__ == '__main__': if __name__ == '__main__':
reader() run_convert()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册