From 0d656996bf8768a11e1c3cb796b895dbab00fadb Mon Sep 17 00:00:00 2001 From: lidanqing Date: Thu, 28 Mar 2019 17:06:36 +0100 Subject: [PATCH] fix some bugs of unzip and reading val list test=develop --- .../api/full_ILSVRC2012_val_preprocess.py | 83 ++++++++++--------- 1 file changed, 46 insertions(+), 37 deletions(-) diff --git a/paddle/fluid/inference/tests/api/full_ILSVRC2012_val_preprocess.py b/paddle/fluid/inference/tests/api/full_ILSVRC2012_val_preprocess.py index 99b892ed9..4d968c83d 100644 --- a/paddle/fluid/inference/tests/api/full_ILSVRC2012_val_preprocess.py +++ b/paddle/fluid/inference/tests/api/full_ILSVRC2012_val_preprocess.py @@ -71,10 +71,14 @@ def process_image(img_path, mode, color_jitter, rotate): def download_unzip(): + int8_download = 'int8/download' - tmp_folder = 'int8/download' + target_name = 'data' - cache_folder = os.path.expanduser('~/.cache/' + tmp_folder) + cache_folder = os.path.expanduser('~/.cache/paddle/dataset/' + + int8_download) + + target_folder = os.path.join(cache_folder, target_name) data_urls = [] data_md5s = [] @@ -89,8 +93,9 @@ def download_unzip(): data_md5s.append('1e9f15f64e015e58d6f9ec3210ed18b5') file_names = [] + for i in range(0, len(data_urls)): - download(data_urls[i], tmp_folder, data_md5s[i]) + download(data_urls[i], cache_folder, data_md5s[i]) file_names.append(data_urls[i].split('/')[-1]) zip_path = os.path.join(cache_folder, 'full_imagenet_val.tar.gz') @@ -101,16 +106,15 @@ def download_unzip(): cat_command += ' ' + os.path.join(cache_folder, file_name) cat_command += ' > ' + zip_path os.system(cat_command) + print('Data is downloaded at {0}\n').format(zip_path) - if not os.path.exists(cache_folder): - cmd = 'mkdir {0} && tar xf {1} -C {0}'.format(cache_folder, zip_path) - - cmd = 'rm -rf {3} && ln -s {1} {0}'.format("data", cache_folder, zip_path) - - os.system(cmd) - - data_dir = os.path.expanduser(cache_folder + 'data') + if not os.path.exists(target_folder): + cmd = 'mkdir {0} && tar xf {1} -C {0}'.format(target_folder, zip_path) + os.system(cmd) + print('Data is unzipped at {0}\n'.format(target_folder)) + data_dir = os.path.join(target_folder, 'ILSVRC2012') + print('ILSVRC2012 full val set at {0}\n'.format(data_dir)) return data_dir @@ -121,32 +125,37 @@ def reader(): with open(file_list) as flist: lines = [line.strip() for line in flist] num_images = len(lines) - - with open(output_file, "w+b") as of: - #save num_images(int64_t) to file - of.seek(0) - num = np.array(int(num_images)).astype('int64') - of.write(num.tobytes()) - for idx, line in enumerate(lines): - img_path, label = line.split() - img_path = os.path.join(data_dir, img_path) - if not os.path.exists(img_path): - continue - - #save image(float32) to file - img = process_image( - img_path, 'val', color_jitter=False, rotate=False) - np_img = np.array(img) - of.seek(SIZE_INT64 + SIZE_FLOAT32 * DATA_DIM * DATA_DIM * 3 * - idx) - of.write(np_img.astype('float32').tobytes()) - - #save label(int64_t) to file - label_int = (int)(label) - np_label = np.array(label_int) - of.seek(SIZE_INT64 + SIZE_FLOAT32 * DATA_DIM * DATA_DIM * 3 * - num_images + idx * SIZE_INT64) - of.write(np_label.astype('int64').tobytes()) + if not os.path.exists(output_file): + print( + 'Preprocessing to binary file......\n' + ) + with open(output_file, "w+b") as of: + #save num_images(int64_t) to file + of.seek(0) + num = np.array(int(num_images)).astype('int64') + of.write(num.tobytes()) + for idx, line in enumerate(lines): + img_path, label = line.split() + img_path = os.path.join(data_dir, img_path) + if not os.path.exists(img_path): + continue + + #save image(float32) to file + img = process_image( + img_path, 'val', color_jitter=False, rotate=False) + np_img = np.array(img) + of.seek(SIZE_INT64 + SIZE_FLOAT32 * DATA_DIM * DATA_DIM * 3 + * idx) + of.write(np_img.astype('float32').tobytes()) + + #save label(int64_t) to file + label_int = (int)(label) + np_label = np.array(label_int) + of.seek(SIZE_INT64 + SIZE_FLOAT32 * DATA_DIM * DATA_DIM * 3 + * num_images + idx * SIZE_INT64) + of.write(np_label.astype('int64').tobytes()) + + print('The preprocessed binary file path {}\n'.format(output_file)) if __name__ == '__main__': -- GitLab