未验证 提交 273e300f 编写于 作者: Q qingqing01 提交者: GitHub

Reduce buf size in image_classification/reader (#2639)

* Reduce buf size in image_classification/reader
* Fix format
上级 57003b84
...@@ -27,7 +27,7 @@ np.random.seed(0) ...@@ -27,7 +27,7 @@ np.random.seed(0)
DATA_DIM = 224 DATA_DIM = 224
THREAD = 8 THREAD = 8
BUF_SIZE = 1024 BUF_SIZE = 2048
DATA_DIR = 'data/ILSVRC2012' DATA_DIR = 'data/ILSVRC2012'
...@@ -162,12 +162,12 @@ def _reader_creator(file_list, ...@@ -162,12 +162,12 @@ def _reader_creator(file_list,
trainer_id = int(os.getenv("PADDLE_TRAINER_ID", "0")) trainer_id = int(os.getenv("PADDLE_TRAINER_ID", "0"))
trainer_count = int(os.getenv("PADDLE_TRAINERS_NUM", "1")) trainer_count = int(os.getenv("PADDLE_TRAINERS_NUM", "1"))
per_node_lines = len(full_lines) // trainer_count per_node_lines = len(full_lines) // trainer_count
lines = full_lines[trainer_id * per_node_lines:(trainer_id + 1) lines = full_lines[trainer_id * per_node_lines:(
* per_node_lines] trainer_id + 1) * per_node_lines]
print( print(
"read images from %d, length: %d, lines length: %d, total: %d" "read images from %d, length: %d, lines length: %d, total: %d"
% (trainer_id * per_node_lines, per_node_lines, len(lines), % (trainer_id * per_node_lines, per_node_lines,
len(full_lines))) len(lines), len(full_lines)))
else: else:
lines = full_lines lines = full_lines
...@@ -206,11 +206,9 @@ def train(data_dir=DATA_DIR, pass_id_as_seed=1, infinite=False): ...@@ -206,11 +206,9 @@ def train(data_dir=DATA_DIR, pass_id_as_seed=1, infinite=False):
def val(data_dir=DATA_DIR): def val(data_dir=DATA_DIR):
file_list = os.path.join(data_dir, 'val_list.txt') file_list = os.path.join(data_dir, 'val_list.txt')
return _reader_creator(file_list, 'val', shuffle=False, return _reader_creator(file_list, 'val', shuffle=False, data_dir=data_dir)
data_dir=data_dir)
def test(data_dir=DATA_DIR): def test(data_dir=DATA_DIR):
file_list = os.path.join(data_dir, 'val_list.txt') file_list = os.path.join(data_dir, 'val_list.txt')
return _reader_creator(file_list, 'test', shuffle=False, return _reader_creator(file_list, 'test', shuffle=False, data_dir=data_dir)
data_dir=data_dir)
...@@ -28,13 +28,14 @@ np.random.seed(0) ...@@ -28,13 +28,14 @@ np.random.seed(0)
DATA_DIM = 224 DATA_DIM = 224
THREAD = 8 THREAD = 8
BUF_SIZE = 102400 BUF_SIZE = 2048
DATA_DIR = './data/ILSVRC2012' DATA_DIR = './data/ILSVRC2012'
img_mean = np.array([0.485, 0.456, 0.406]).reshape((3, 1, 1)) img_mean = np.array([0.485, 0.456, 0.406]).reshape((3, 1, 1))
img_std = np.array([0.229, 0.224, 0.225]).reshape((3, 1, 1)) img_std = np.array([0.229, 0.224, 0.225]).reshape((3, 1, 1))
def rotate_image(img): def rotate_image(img):
""" rotate_image """ """ rotate_image """
(h, w) = img.shape[:2] (h, w) = img.shape[:2]
...@@ -44,6 +45,7 @@ def rotate_image(img): ...@@ -44,6 +45,7 @@ def rotate_image(img):
rotated = cv2.warpAffine(img, M, (w, h)) rotated = cv2.warpAffine(img, M, (w, h))
return rotated return rotated
def random_crop(img, size, settings, scale=None, ratio=None): def random_crop(img, size, settings, scale=None, ratio=None):
""" random_crop """ """ random_crop """
lower_scale = settings.lower_scale lower_scale = settings.lower_scale
...@@ -52,7 +54,6 @@ def random_crop(img, size, settings, scale=None, ratio=None): ...@@ -52,7 +54,6 @@ def random_crop(img, size, settings, scale=None, ratio=None):
scale = [lower_scale, 1.0] if scale is None else scale scale = [lower_scale, 1.0] if scale is None else scale
ratio = [lower_ratio, upper_ratio] if ratio is None else ratio ratio = [lower_ratio, upper_ratio] if ratio is None else ratio
aspect_ratio = math.sqrt(np.random.uniform(*ratio)) aspect_ratio = math.sqrt(np.random.uniform(*ratio))
w = 1. * aspect_ratio w = 1. * aspect_ratio
h = 1. / aspect_ratio h = 1. / aspect_ratio
...@@ -73,24 +74,31 @@ def random_crop(img, size, settings, scale=None, ratio=None): ...@@ -73,24 +74,31 @@ def random_crop(img, size, settings, scale=None, ratio=None):
img = img[i:i + h, j:j + w, :] img = img[i:i + h, j:j + w, :]
resized = cv2.resize(img, (size, size) resized = cv2.resize(
#, interpolation=cv2.INTER_LANCZOS4 img,
) (size, size)
#, interpolation=cv2.INTER_LANCZOS4
)
return resized return resized
def distort_color(img): def distort_color(img):
return img return img
def resize_short(img, target_size): def resize_short(img, target_size):
""" resize_short """ """ resize_short """
percent = float(target_size) / min(img.shape[0], img.shape[1]) percent = float(target_size) / min(img.shape[0], img.shape[1])
resized_width = int(round(img.shape[1] * percent)) resized_width = int(round(img.shape[1] * percent))
resized_height = int(round(img.shape[0] * percent)) resized_height = int(round(img.shape[0] * percent))
resized = cv2.resize(img, (resized_width, resized_height), resized = cv2.resize(
#interpolation=cv2.INTER_LANCZOS4 img,
) (resized_width, resized_height),
#interpolation=cv2.INTER_LANCZOS4
)
return resized return resized
def crop_image(img, target_size, center): def crop_image(img, target_size, center):
""" crop_image """ """ crop_image """
height, width = img.shape[:2] height, width = img.shape[:2]
...@@ -106,26 +114,28 @@ def crop_image(img, target_size, center): ...@@ -106,26 +114,28 @@ def crop_image(img, target_size, center):
img = img[h_start:h_end, w_start:w_end, :] img = img[h_start:h_end, w_start:w_end, :]
return img return img
def create_mixup_reader(settings, rd):
def create_mixup_reader(settings, rd):
class context: class context:
tmp_mix = [] tmp_mix = []
tmp_l1 = [] tmp_l1 = []
tmp_l2 = [] tmp_l2 = []
tmp_lam = [] tmp_lam = []
batch_size = settings.batch_size batch_size = settings.batch_size
alpha = settings.mixup_alpha alpha = settings.mixup_alpha
def fetch_data(): def fetch_data():
data_list = [] data_list = []
for i, item in enumerate(rd()): for i, item in enumerate(rd()):
data_list.append(item) data_list.append(item)
if i % batch_size == batch_size - 1: if i % batch_size == batch_size - 1:
yield data_list yield data_list
data_list =[] data_list = []
def mixup_data(): def mixup_data():
for data_list in fetch_data(): for data_list in fetch_data():
if alpha > 0.: if alpha > 0.:
lam = np.random.beta(alpha, alpha) lam = np.random.beta(alpha, alpha)
...@@ -133,11 +143,13 @@ def create_mixup_reader(settings, rd): ...@@ -133,11 +143,13 @@ def create_mixup_reader(settings, rd):
lam = 1. lam = 1.
l1 = np.array(data_list) l1 = np.array(data_list)
l2 = np.random.permutation(l1) l2 = np.random.permutation(l1)
mixed_l = [l1[i][0] * lam + (1 - lam) * l2[i][0] for i in range(len(l1))] mixed_l = [
l1[i][0] * lam + (1 - lam) * l2[i][0] for i in range(len(l1))
]
yield mixed_l, l1, l2, lam yield mixed_l, l1, l2, lam
def mixup_reader(): def mixup_reader():
for context.tmp_mix, context.tmp_l1, context.tmp_l2, context.tmp_lam in mixup_data(): for context.tmp_mix, context.tmp_l1, context.tmp_l2, context.tmp_lam in mixup_data():
for i in range(len(context.tmp_mix)): for i in range(len(context.tmp_mix)):
mixed_l = context.tmp_mix[i] mixed_l = context.tmp_mix[i]
...@@ -145,11 +157,11 @@ def create_mixup_reader(settings, rd): ...@@ -145,11 +157,11 @@ def create_mixup_reader(settings, rd):
l2 = context.tmp_l2[i] l2 = context.tmp_l2[i]
lam = context.tmp_lam lam = context.tmp_lam
yield mixed_l, l1[1], l2[1], lam yield mixed_l, l1[1], l2[1], lam
return mixup_reader return mixup_reader
def process_image(
sample, def process_image(sample,
settings, settings,
mode, mode,
color_jitter, color_jitter,
...@@ -169,7 +181,7 @@ def process_image( ...@@ -169,7 +181,7 @@ def process_image(
if rotate: if rotate:
img = rotate_image(img) img = rotate_image(img)
if crop_size > 0: if crop_size > 0:
img = random_crop(img, crop_size,settings) img = random_crop(img, crop_size, settings)
if color_jitter: if color_jitter:
img = distort_color(img) img = distort_color(img)
if np.random.randint(0, 2) == 1: if np.random.randint(0, 2) == 1:
...@@ -235,8 +247,9 @@ def _reader_creator(settings, ...@@ -235,8 +247,9 @@ def _reader_creator(settings,
elif mode == 'test': elif mode == 'test':
img_path, label = line.split() img_path, label = line.split()
img_path = os.path.join(data_dir, img_path) img_path = os.path.join(data_dir, img_path)
yield [img_path] yield [img_path]
crop_size = int(settings.image_shape.split(",")[2]) crop_size = int(settings.image_shape.split(",")[2])
image_mapper = functools.partial( image_mapper = functools.partial(
process_image, process_image,
...@@ -249,9 +262,10 @@ def _reader_creator(settings, ...@@ -249,9 +262,10 @@ def _reader_creator(settings,
image_mapper, reader, THREAD, BUF_SIZE, order=False) image_mapper, reader, THREAD, BUF_SIZE, order=False)
return reader return reader
def train(settings, data_dir=DATA_DIR, pass_id_as_seed=0): def train(settings, data_dir=DATA_DIR, pass_id_as_seed=0):
file_list = os.path.join(data_dir, 'train_list.txt') file_list = os.path.join(data_dir, 'train_list.txt')
reader = _reader_creator( reader = _reader_creator(
settings, settings,
file_list, file_list,
'train', 'train',
...@@ -259,19 +273,19 @@ def train(settings, data_dir=DATA_DIR, pass_id_as_seed=0): ...@@ -259,19 +273,19 @@ def train(settings, data_dir=DATA_DIR, pass_id_as_seed=0):
color_jitter=False, color_jitter=False,
rotate=False, rotate=False,
data_dir=data_dir, data_dir=data_dir,
pass_id_as_seed=pass_id_as_seed, pass_id_as_seed=pass_id_as_seed, )
)
if settings.use_mixup == True: if settings.use_mixup == True:
reader = create_mixup_reader(settings, reader) reader = create_mixup_reader(settings, reader)
return reader return reader
def val(settings,data_dir=DATA_DIR):
def val(settings, data_dir=DATA_DIR):
file_list = os.path.join(data_dir, 'val_list.txt') file_list = os.path.join(data_dir, 'val_list.txt')
return _reader_creator(settings ,file_list, 'val', shuffle=False, return _reader_creator(
data_dir=data_dir) settings, file_list, 'val', shuffle=False, data_dir=data_dir)
def test(settings,data_dir=DATA_DIR): def test(settings, data_dir=DATA_DIR):
file_list = os.path.join(data_dir, 'val_list.txt') file_list = os.path.join(data_dir, 'val_list.txt')
return _reader_creator(settings, file_list, 'test', shuffle=False, return _reader_creator(
data_dir=data_dir) settings, file_list, 'test', shuffle=False, data_dir=data_dir)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册