未验证 提交 46989e88 编写于 作者: L lidanqing 提交者: GitHub

Fix python3 incompatibility issues (#30698)

* solve python3 incompatibility issues

* update checksum
上级 a12b6bb9
......@@ -13,6 +13,7 @@
import hashlib
import unittest
import os
import io
import numpy as np
import time
import sys
......@@ -23,10 +24,9 @@ from PIL import Image
import math
from paddle.dataset.common import download
import tarfile
from six.moves import StringIO
import argparse
import shutil
random.seed(0)
np.random.seed(0)
DATA_DIM = 224
......@@ -34,7 +34,7 @@ SIZE_FLOAT32 = 4
SIZE_INT64 = 8
FULL_SIZE_BYTES = 30106000008
FULL_IMAGES = 50000
TARGET_HASH = '22d2e0008dca693916d9595a5ea3ded8'
TARGET_HASH = '0be07c2c23296b97dad83c626682c66a'
FOLDER_NAME = "ILSVRC2012/"
VALLIST_TAR_NAME = "ILSVRC2012/val_list.txt"
CHUNK_SIZE = 8192
......@@ -55,8 +55,8 @@ def crop_image(img, target_size, center):
width, height = img.size
size = target_size
if center == True:
w_start = (width - size) / 2
h_start = (height - size) / 2
w_start = (width - size) // 2
h_start = (height - size) // 2
else:
w_start = np.random.randint(0, width - size + 1)
h_start = np.random.randint(0, height - size + 1)
......@@ -95,11 +95,9 @@ def download_concat(cache_folder, zip_path):
file_name = os.path.join(cache_folder, data_urls[i].split('/')[-1])
file_names.append(file_name)
print("Downloaded part {0}\n".format(file_name))
if not os.path.exists(zip_path):
with open(zip_path, "w+") as outfile:
with open(zip_path, "wb") as outfile:
for fname in file_names:
with open(fname) as infile:
outfile.write(infile.read())
shutil.copyfileobj(open(fname, 'rb'), outfile)
def print_processbar(done_percentage):
......@@ -114,12 +112,12 @@ def check_integrity(filename, target_hash):
print('\nThe binary file exists. Checking file integrity...\n')
md = hashlib.md5()
count = 0
onepart = FULL_SIZE_BYTES / CHUNK_SIZE / 100
with open(filename) as ifs:
onepart = FULL_SIZE_BYTES // CHUNK_SIZE // 100
with open(filename, 'rb') as ifs:
while True:
buf = ifs.read(CHUNK_SIZE)
if count % onepart == 0:
done = count / onepart
done = count // onepart
print_processbar(done)
count = count + 1
if not buf:
......@@ -142,28 +140,26 @@ def convert_Imagenet_tar2bin(tar_file, output_file):
for tarInfo in tar:
if tarInfo.isfile() and tarInfo.name != VALLIST_TAR_NAME:
dataset[tarInfo.name] = tar.extractfile(tarInfo).read()
with open(output_file, "w+b") as ofs:
ofs.seek(0)
num = np.array(int(FULL_IMAGES)).astype('int64')
ofs.write(num.tobytes())
per_percentage = FULL_IMAGES / 100
per_percentage = FULL_IMAGES // 100
val_info = tar.getmember(VALLIST_TAR_NAME)
val_list = tar.extractfile(val_info).read().decode("utf-8")
lines = val_list.splitlines()
idx = 0
for imagedata in dataset.values():
img = Image.open(StringIO(imagedata))
img = Image.open(io.BytesIO(imagedata))
img = process_image(img)
np_img = np.array(img)
ofs.write(np_img.astype('float32').tobytes())
if idx % per_percentage == 0:
print_processbar(idx / per_percentage)
print_processbar(idx // per_percentage)
idx = idx + 1
val_info = tar.getmember(VALLIST_TAR_NAME)
val_list = tar.extractfile(val_info).read()
lines = val_list.split('\n')
val_dict = {}
for line_idx, line in enumerate(lines):
if line_idx == FULL_IMAGES:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册