提交 b46e467a 编写于 作者: L lidanqing

add wget and unzip part and change data_dir

test=develop
上级 894aa9b2
...@@ -21,6 +21,7 @@ import functools ...@@ -21,6 +21,7 @@ import functools
import contextlib import contextlib
from PIL import Image, ImageEnhance from PIL import Image, ImageEnhance
import math import math
from paddle.dataset.common import download
random.seed(0) random.seed(0)
np.random.seed(0) np.random.seed(0)
...@@ -30,8 +31,6 @@ DATA_DIM = 224 ...@@ -30,8 +31,6 @@ DATA_DIM = 224
SIZE_FLOAT32 = 4 SIZE_FLOAT32 = 4
SIZE_INT64 = 8 SIZE_INT64 = 8
DATA_DIR = './data/ILSVRC2012/data.bin'
img_mean = np.array([0.485, 0.456, 0.406]).reshape((3, 1, 1)) img_mean = np.array([0.485, 0.456, 0.406]).reshape((3, 1, 1))
img_std = np.array([0.229, 0.224, 0.225]).reshape((3, 1, 1)) img_std = np.array([0.229, 0.224, 0.225]).reshape((3, 1, 1))
...@@ -71,15 +70,60 @@ def process_image(img_path, mode, color_jitter, rotate): ...@@ -71,15 +70,60 @@ def process_image(img_path, mode, color_jitter, rotate):
return img return img
def download_unzip():
tmp_folder = 'int8/download'
cache_folder = os.path.expanduser('~/.cache/' + tmp_folder)
data_urls = []
data_md5s = []
data_urls.append(
'https://paddle-inference-dist.bj.bcebos.com/int8/ILSVRC2012_img_val.tar.gz.partaa'
)
data_md5s.append('60f6525b0e1d127f345641d75d41f0a8')
data_urls.append(
'https://paddle-inference-dist.bj.bcebos.com/int8/ILSVRC2012_img_val.tar.gz.partab'
)
data_md5s.append('1e9f15f64e015e58d6f9ec3210ed18b5')
file_names = []
for i in range(0, len(data_urls)):
download(data_urls[i], tmp_folder, data_md5s[i])
file_names.append(data_urls[i].split('/')[-1])
zip_path = os.path.join(cache_folder, 'full_imagenet_val.tar.gz')
if not os.path.exists(zip_path):
cat_command = 'cat'
for file_name in file_names:
cat_command += ' ' + os.path.join(cache_folder, file_name)
cat_command += ' > ' + zip_path
os.system(cat_command)
if not os.path.exists(cache_folder):
cmd = 'mkdir {0} && tar xf {1} -C {0}'.format(cache_folder, zip_path)
cmd = 'rm -rf {3} && ln -s {1} {0}'.format("data", cache_folder, zip_path)
os.system(cmd)
data_dir = os.path.expanduser(cache_folder + 'data')
return data_dir
def reader(): def reader():
data_dir = DATA_DIR data_dir = download_unzip()
file_list = os.path.join(data_dir, 'val_list.txt') file_list = os.path.join(data_dir, 'val_list.txt')
bin_file = os.path.join(data_dir, 'data.bin') output_file = os.path.join(data_dir, 'int8_full_val.bin')
with open(file_list) as flist: with open(file_list) as flist:
lines = [line.strip() for line in flist] lines = [line.strip() for line in flist]
num_images = len(lines) num_images = len(lines)
with open(bin_file, "w+b") as of: with open(output_file, "w+b") as of:
#save num_images(int64_t) to file
of.seek(0) of.seek(0)
num = np.array(int(num_images)).astype('int64') num = np.array(int(num_images)).astype('int64')
of.write(num.tobytes()) of.write(num.tobytes())
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册