未验证 提交 273e300f 编写于 作者: Q qingqing01 提交者: GitHub

Reduce buf size in image_classification/reader (#2639)

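* Set BUF_SIZE to 2048 in the image_classification readers (previously 1024 in one file and 102400 in the other)
* Fix code formatting (line wrapping, spacing after commas, blank lines between definitions)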
上级 57003b84
...@@ -27,7 +27,7 @@ np.random.seed(0) ...@@ -27,7 +27,7 @@ np.random.seed(0)
DATA_DIM = 224 DATA_DIM = 224
THREAD = 8 THREAD = 8
BUF_SIZE = 1024 BUF_SIZE = 2048
DATA_DIR = 'data/ILSVRC2012' DATA_DIR = 'data/ILSVRC2012'
...@@ -162,12 +162,12 @@ def _reader_creator(file_list, ...@@ -162,12 +162,12 @@ def _reader_creator(file_list,
trainer_id = int(os.getenv("PADDLE_TRAINER_ID", "0")) trainer_id = int(os.getenv("PADDLE_TRAINER_ID", "0"))
trainer_count = int(os.getenv("PADDLE_TRAINERS_NUM", "1")) trainer_count = int(os.getenv("PADDLE_TRAINERS_NUM", "1"))
per_node_lines = len(full_lines) // trainer_count per_node_lines = len(full_lines) // trainer_count
lines = full_lines[trainer_id * per_node_lines:(trainer_id + 1) lines = full_lines[trainer_id * per_node_lines:(
* per_node_lines] trainer_id + 1) * per_node_lines]
print( print(
"read images from %d, length: %d, lines length: %d, total: %d" "read images from %d, length: %d, lines length: %d, total: %d"
% (trainer_id * per_node_lines, per_node_lines, len(lines), % (trainer_id * per_node_lines, per_node_lines,
len(full_lines))) len(lines), len(full_lines)))
else: else:
lines = full_lines lines = full_lines
...@@ -206,11 +206,9 @@ def train(data_dir=DATA_DIR, pass_id_as_seed=1, infinite=False): ...@@ -206,11 +206,9 @@ def train(data_dir=DATA_DIR, pass_id_as_seed=1, infinite=False):
def val(data_dir=DATA_DIR): def val(data_dir=DATA_DIR):
file_list = os.path.join(data_dir, 'val_list.txt') file_list = os.path.join(data_dir, 'val_list.txt')
return _reader_creator(file_list, 'val', shuffle=False, return _reader_creator(file_list, 'val', shuffle=False, data_dir=data_dir)
data_dir=data_dir)
def test(data_dir=DATA_DIR): def test(data_dir=DATA_DIR):
file_list = os.path.join(data_dir, 'val_list.txt') file_list = os.path.join(data_dir, 'val_list.txt')
return _reader_creator(file_list, 'test', shuffle=False, return _reader_creator(file_list, 'test', shuffle=False, data_dir=data_dir)
data_dir=data_dir)
...@@ -28,13 +28,14 @@ np.random.seed(0) ...@@ -28,13 +28,14 @@ np.random.seed(0)
DATA_DIM = 224 DATA_DIM = 224
THREAD = 8 THREAD = 8
BUF_SIZE = 102400 BUF_SIZE = 2048
DATA_DIR = './data/ILSVRC2012' DATA_DIR = './data/ILSVRC2012'
img_mean = np.array([0.485, 0.456, 0.406]).reshape((3, 1, 1)) img_mean = np.array([0.485, 0.456, 0.406]).reshape((3, 1, 1))
img_std = np.array([0.229, 0.224, 0.225]).reshape((3, 1, 1)) img_std = np.array([0.229, 0.224, 0.225]).reshape((3, 1, 1))
def rotate_image(img): def rotate_image(img):
""" rotate_image """ """ rotate_image """
(h, w) = img.shape[:2] (h, w) = img.shape[:2]
...@@ -44,6 +45,7 @@ def rotate_image(img): ...@@ -44,6 +45,7 @@ def rotate_image(img):
rotated = cv2.warpAffine(img, M, (w, h)) rotated = cv2.warpAffine(img, M, (w, h))
return rotated return rotated
def random_crop(img, size, settings, scale=None, ratio=None): def random_crop(img, size, settings, scale=None, ratio=None):
""" random_crop """ """ random_crop """
lower_scale = settings.lower_scale lower_scale = settings.lower_scale
...@@ -52,7 +54,6 @@ def random_crop(img, size, settings, scale=None, ratio=None): ...@@ -52,7 +54,6 @@ def random_crop(img, size, settings, scale=None, ratio=None):
scale = [lower_scale, 1.0] if scale is None else scale scale = [lower_scale, 1.0] if scale is None else scale
ratio = [lower_ratio, upper_ratio] if ratio is None else ratio ratio = [lower_ratio, upper_ratio] if ratio is None else ratio
aspect_ratio = math.sqrt(np.random.uniform(*ratio)) aspect_ratio = math.sqrt(np.random.uniform(*ratio))
w = 1. * aspect_ratio w = 1. * aspect_ratio
h = 1. / aspect_ratio h = 1. / aspect_ratio
...@@ -73,24 +74,31 @@ def random_crop(img, size, settings, scale=None, ratio=None): ...@@ -73,24 +74,31 @@ def random_crop(img, size, settings, scale=None, ratio=None):
img = img[i:i + h, j:j + w, :] img = img[i:i + h, j:j + w, :]
resized = cv2.resize(img, (size, size) resized = cv2.resize(
img,
(size, size)
#, interpolation=cv2.INTER_LANCZOS4 #, interpolation=cv2.INTER_LANCZOS4
) )
return resized return resized
def distort_color(img): def distort_color(img):
return img return img
def resize_short(img, target_size): def resize_short(img, target_size):
""" resize_short """ """ resize_short """
percent = float(target_size) / min(img.shape[0], img.shape[1]) percent = float(target_size) / min(img.shape[0], img.shape[1])
resized_width = int(round(img.shape[1] * percent)) resized_width = int(round(img.shape[1] * percent))
resized_height = int(round(img.shape[0] * percent)) resized_height = int(round(img.shape[0] * percent))
resized = cv2.resize(img, (resized_width, resized_height), resized = cv2.resize(
img,
(resized_width, resized_height),
#interpolation=cv2.INTER_LANCZOS4 #interpolation=cv2.INTER_LANCZOS4
) )
return resized return resized
def crop_image(img, target_size, center): def crop_image(img, target_size, center):
""" crop_image """ """ crop_image """
height, width = img.shape[:2] height, width = img.shape[:2]
...@@ -106,6 +114,7 @@ def crop_image(img, target_size, center): ...@@ -106,6 +114,7 @@ def crop_image(img, target_size, center):
img = img[h_start:h_end, w_start:w_end, :] img = img[h_start:h_end, w_start:w_end, :]
return img return img
def create_mixup_reader(settings, rd): def create_mixup_reader(settings, rd):
class context: class context:
tmp_mix = [] tmp_mix = []
...@@ -115,6 +124,7 @@ def create_mixup_reader(settings, rd): ...@@ -115,6 +124,7 @@ def create_mixup_reader(settings, rd):
batch_size = settings.batch_size batch_size = settings.batch_size
alpha = settings.mixup_alpha alpha = settings.mixup_alpha
def fetch_data(): def fetch_data():
data_list = [] data_list = []
...@@ -122,7 +132,7 @@ def create_mixup_reader(settings, rd): ...@@ -122,7 +132,7 @@ def create_mixup_reader(settings, rd):
data_list.append(item) data_list.append(item)
if i % batch_size == batch_size - 1: if i % batch_size == batch_size - 1:
yield data_list yield data_list
data_list =[] data_list = []
def mixup_data(): def mixup_data():
...@@ -133,7 +143,9 @@ def create_mixup_reader(settings, rd): ...@@ -133,7 +143,9 @@ def create_mixup_reader(settings, rd):
lam = 1. lam = 1.
l1 = np.array(data_list) l1 = np.array(data_list)
l2 = np.random.permutation(l1) l2 = np.random.permutation(l1)
mixed_l = [l1[i][0] * lam + (1 - lam) * l2[i][0] for i in range(len(l1))] mixed_l = [
l1[i][0] * lam + (1 - lam) * l2[i][0] for i in range(len(l1))
]
yield mixed_l, l1, l2, lam yield mixed_l, l1, l2, lam
def mixup_reader(): def mixup_reader():
...@@ -148,8 +160,8 @@ def create_mixup_reader(settings, rd): ...@@ -148,8 +160,8 @@ def create_mixup_reader(settings, rd):
return mixup_reader return mixup_reader
def process_image(
sample, def process_image(sample,
settings, settings,
mode, mode,
color_jitter, color_jitter,
...@@ -169,7 +181,7 @@ def process_image( ...@@ -169,7 +181,7 @@ def process_image(
if rotate: if rotate:
img = rotate_image(img) img = rotate_image(img)
if crop_size > 0: if crop_size > 0:
img = random_crop(img, crop_size,settings) img = random_crop(img, crop_size, settings)
if color_jitter: if color_jitter:
img = distort_color(img) img = distort_color(img)
if np.random.randint(0, 2) == 1: if np.random.randint(0, 2) == 1:
...@@ -237,6 +249,7 @@ def _reader_creator(settings, ...@@ -237,6 +249,7 @@ def _reader_creator(settings,
img_path = os.path.join(data_dir, img_path) img_path = os.path.join(data_dir, img_path)
yield [img_path] yield [img_path]
crop_size = int(settings.image_shape.split(",")[2]) crop_size = int(settings.image_shape.split(",")[2])
image_mapper = functools.partial( image_mapper = functools.partial(
process_image, process_image,
...@@ -249,6 +262,7 @@ def _reader_creator(settings, ...@@ -249,6 +262,7 @@ def _reader_creator(settings,
image_mapper, reader, THREAD, BUF_SIZE, order=False) image_mapper, reader, THREAD, BUF_SIZE, order=False)
return reader return reader
def train(settings, data_dir=DATA_DIR, pass_id_as_seed=0): def train(settings, data_dir=DATA_DIR, pass_id_as_seed=0):
file_list = os.path.join(data_dir, 'train_list.txt') file_list = os.path.join(data_dir, 'train_list.txt')
reader = _reader_creator( reader = _reader_creator(
...@@ -259,19 +273,19 @@ def train(settings, data_dir=DATA_DIR, pass_id_as_seed=0): ...@@ -259,19 +273,19 @@ def train(settings, data_dir=DATA_DIR, pass_id_as_seed=0):
color_jitter=False, color_jitter=False,
rotate=False, rotate=False,
data_dir=data_dir, data_dir=data_dir,
pass_id_as_seed=pass_id_as_seed, pass_id_as_seed=pass_id_as_seed, )
)
if settings.use_mixup == True: if settings.use_mixup == True:
reader = create_mixup_reader(settings, reader) reader = create_mixup_reader(settings, reader)
return reader return reader
def val(settings,data_dir=DATA_DIR):
def val(settings, data_dir=DATA_DIR):
file_list = os.path.join(data_dir, 'val_list.txt') file_list = os.path.join(data_dir, 'val_list.txt')
return _reader_creator(settings ,file_list, 'val', shuffle=False, return _reader_creator(
data_dir=data_dir) settings, file_list, 'val', shuffle=False, data_dir=data_dir)
def test(settings,data_dir=DATA_DIR): def test(settings, data_dir=DATA_DIR):
file_list = os.path.join(data_dir, 'val_list.txt') file_list = os.path.join(data_dir, 'val_list.txt')
return _reader_creator(settings, file_list, 'test', shuffle=False, return _reader_creator(
data_dir=data_dir) settings, file_list, 'test', shuffle=False, data_dir=data_dir)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册