未验证 提交 46989e88 编写于 作者: L lidanqing 提交者: GitHub

Fix python3 incompatibility issues (#30698)

* solve python3 incompatibility issues

* update checksum
上级 a12b6bb9
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
import hashlib import hashlib
import unittest import unittest
import os import os
import io
import numpy as np import numpy as np
import time import time
import sys import sys
...@@ -23,10 +24,9 @@ from PIL import Image ...@@ -23,10 +24,9 @@ from PIL import Image
import math import math
from paddle.dataset.common import download from paddle.dataset.common import download
import tarfile import tarfile
from six.moves import StringIO
import argparse import argparse
import shutil
random.seed(0)
np.random.seed(0) np.random.seed(0)
DATA_DIM = 224 DATA_DIM = 224
...@@ -34,7 +34,7 @@ SIZE_FLOAT32 = 4 ...@@ -34,7 +34,7 @@ SIZE_FLOAT32 = 4
SIZE_INT64 = 8 SIZE_INT64 = 8
FULL_SIZE_BYTES = 30106000008 FULL_SIZE_BYTES = 30106000008
FULL_IMAGES = 50000 FULL_IMAGES = 50000
TARGET_HASH = '22d2e0008dca693916d9595a5ea3ded8' TARGET_HASH = '0be07c2c23296b97dad83c626682c66a'
FOLDER_NAME = "ILSVRC2012/" FOLDER_NAME = "ILSVRC2012/"
VALLIST_TAR_NAME = "ILSVRC2012/val_list.txt" VALLIST_TAR_NAME = "ILSVRC2012/val_list.txt"
CHUNK_SIZE = 8192 CHUNK_SIZE = 8192
...@@ -55,8 +55,8 @@ def crop_image(img, target_size, center): ...@@ -55,8 +55,8 @@ def crop_image(img, target_size, center):
width, height = img.size width, height = img.size
size = target_size size = target_size
if center == True: if center == True:
w_start = (width - size) / 2 w_start = (width - size) // 2
h_start = (height - size) / 2 h_start = (height - size) // 2
else: else:
w_start = np.random.randint(0, width - size + 1) w_start = np.random.randint(0, width - size + 1)
h_start = np.random.randint(0, height - size + 1) h_start = np.random.randint(0, height - size + 1)
...@@ -95,11 +95,9 @@ def download_concat(cache_folder, zip_path): ...@@ -95,11 +95,9 @@ def download_concat(cache_folder, zip_path):
file_name = os.path.join(cache_folder, data_urls[i].split('/')[-1]) file_name = os.path.join(cache_folder, data_urls[i].split('/')[-1])
file_names.append(file_name) file_names.append(file_name)
print("Downloaded part {0}\n".format(file_name)) print("Downloaded part {0}\n".format(file_name))
if not os.path.exists(zip_path): with open(zip_path, "wb") as outfile:
with open(zip_path, "w+") as outfile: for fname in file_names:
for fname in file_names: shutil.copyfileobj(open(fname, 'rb'), outfile)
with open(fname) as infile:
outfile.write(infile.read())
def print_processbar(done_percentage): def print_processbar(done_percentage):
...@@ -114,12 +112,12 @@ def check_integrity(filename, target_hash): ...@@ -114,12 +112,12 @@ def check_integrity(filename, target_hash):
print('\nThe binary file exists. Checking file integrity...\n') print('\nThe binary file exists. Checking file integrity...\n')
md = hashlib.md5() md = hashlib.md5()
count = 0 count = 0
onepart = FULL_SIZE_BYTES / CHUNK_SIZE / 100 onepart = FULL_SIZE_BYTES // CHUNK_SIZE // 100
with open(filename) as ifs: with open(filename, 'rb') as ifs:
while True: while True:
buf = ifs.read(CHUNK_SIZE) buf = ifs.read(CHUNK_SIZE)
if count % onepart == 0: if count % onepart == 0:
done = count / onepart done = count // onepart
print_processbar(done) print_processbar(done)
count = count + 1 count = count + 1
if not buf: if not buf:
...@@ -142,28 +140,26 @@ def convert_Imagenet_tar2bin(tar_file, output_file): ...@@ -142,28 +140,26 @@ def convert_Imagenet_tar2bin(tar_file, output_file):
for tarInfo in tar: for tarInfo in tar:
if tarInfo.isfile() and tarInfo.name != VALLIST_TAR_NAME: if tarInfo.isfile() and tarInfo.name != VALLIST_TAR_NAME:
dataset[tarInfo.name] = tar.extractfile(tarInfo).read() dataset[tarInfo.name] = tar.extractfile(tarInfo).read()
with open(output_file, "w+b") as ofs: with open(output_file, "w+b") as ofs:
ofs.seek(0) ofs.seek(0)
num = np.array(int(FULL_IMAGES)).astype('int64') num = np.array(int(FULL_IMAGES)).astype('int64')
ofs.write(num.tobytes()) ofs.write(num.tobytes())
per_percentage = FULL_IMAGES / 100 per_percentage = FULL_IMAGES // 100
val_info = tar.getmember(VALLIST_TAR_NAME)
val_list = tar.extractfile(val_info).read().decode("utf-8")
lines = val_list.splitlines()
idx = 0 idx = 0
for imagedata in dataset.values(): for imagedata in dataset.values():
img = Image.open(StringIO(imagedata)) img = Image.open(io.BytesIO(imagedata))
img = process_image(img) img = process_image(img)
np_img = np.array(img) np_img = np.array(img)
ofs.write(np_img.astype('float32').tobytes()) ofs.write(np_img.astype('float32').tobytes())
if idx % per_percentage == 0: if idx % per_percentage == 0:
print_processbar(idx / per_percentage) print_processbar(idx // per_percentage)
idx = idx + 1 idx = idx + 1
val_info = tar.getmember(VALLIST_TAR_NAME)
val_list = tar.extractfile(val_info).read()
lines = val_list.split('\n')
val_dict = {} val_dict = {}
for line_idx, line in enumerate(lines): for line_idx, line in enumerate(lines):
if line_idx == FULL_IMAGES: if line_idx == FULL_IMAGES:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册