未验证 提交 273e300f 编写于 作者: Q qingqing01 提交者: GitHub

Reduce buf size in image_classification/reader (#2639)

* Reduce buf size in image_classification/reader
* Fix format
上级 57003b84
......@@ -27,7 +27,7 @@ np.random.seed(0)
DATA_DIM = 224
THREAD = 8
BUF_SIZE = 1024
BUF_SIZE = 2048
DATA_DIR = 'data/ILSVRC2012'
......@@ -162,12 +162,12 @@ def _reader_creator(file_list,
trainer_id = int(os.getenv("PADDLE_TRAINER_ID", "0"))
trainer_count = int(os.getenv("PADDLE_TRAINERS_NUM", "1"))
per_node_lines = len(full_lines) // trainer_count
lines = full_lines[trainer_id * per_node_lines:(trainer_id + 1)
* per_node_lines]
lines = full_lines[trainer_id * per_node_lines:(
trainer_id + 1) * per_node_lines]
print(
"read images from %d, length: %d, lines length: %d, total: %d"
% (trainer_id * per_node_lines, per_node_lines, len(lines),
len(full_lines)))
% (trainer_id * per_node_lines, per_node_lines,
len(lines), len(full_lines)))
else:
lines = full_lines
......@@ -206,11 +206,9 @@ def train(data_dir=DATA_DIR, pass_id_as_seed=1, infinite=False):
def val(data_dir=DATA_DIR):
file_list = os.path.join(data_dir, 'val_list.txt')
return _reader_creator(file_list, 'val', shuffle=False,
data_dir=data_dir)
return _reader_creator(file_list, 'val', shuffle=False, data_dir=data_dir)
def test(data_dir=DATA_DIR):
file_list = os.path.join(data_dir, 'val_list.txt')
return _reader_creator(file_list, 'test', shuffle=False,
data_dir=data_dir)
return _reader_creator(file_list, 'test', shuffle=False, data_dir=data_dir)
......@@ -28,13 +28,14 @@ np.random.seed(0)
DATA_DIM = 224
THREAD = 8
BUF_SIZE = 102400
BUF_SIZE = 2048
DATA_DIR = './data/ILSVRC2012'
img_mean = np.array([0.485, 0.456, 0.406]).reshape((3, 1, 1))
img_std = np.array([0.229, 0.224, 0.225]).reshape((3, 1, 1))
def rotate_image(img):
""" rotate_image """
(h, w) = img.shape[:2]
......@@ -44,6 +45,7 @@ def rotate_image(img):
rotated = cv2.warpAffine(img, M, (w, h))
return rotated
def random_crop(img, size, settings, scale=None, ratio=None):
""" random_crop """
lower_scale = settings.lower_scale
......@@ -52,7 +54,6 @@ def random_crop(img, size, settings, scale=None, ratio=None):
scale = [lower_scale, 1.0] if scale is None else scale
ratio = [lower_ratio, upper_ratio] if ratio is None else ratio
aspect_ratio = math.sqrt(np.random.uniform(*ratio))
w = 1. * aspect_ratio
h = 1. / aspect_ratio
......@@ -73,24 +74,31 @@ def random_crop(img, size, settings, scale=None, ratio=None):
img = img[i:i + h, j:j + w, :]
resized = cv2.resize(img, (size, size)
#, interpolation=cv2.INTER_LANCZOS4
)
resized = cv2.resize(
img,
(size, size)
#, interpolation=cv2.INTER_LANCZOS4
)
return resized
def distort_color(img):
return img
def resize_short(img, target_size):
""" resize_short """
percent = float(target_size) / min(img.shape[0], img.shape[1])
resized_width = int(round(img.shape[1] * percent))
resized_height = int(round(img.shape[0] * percent))
resized = cv2.resize(img, (resized_width, resized_height),
#interpolation=cv2.INTER_LANCZOS4
)
resized = cv2.resize(
img,
(resized_width, resized_height),
#interpolation=cv2.INTER_LANCZOS4
)
return resized
def crop_image(img, target_size, center):
""" crop_image """
height, width = img.shape[:2]
......@@ -106,26 +114,28 @@ def crop_image(img, target_size, center):
img = img[h_start:h_end, w_start:w_end, :]
return img
def create_mixup_reader(settings, rd):
def create_mixup_reader(settings, rd):
class context:
tmp_mix = []
tmp_l1 = []
tmp_l2 = []
tmp_lam = []
batch_size = settings.batch_size
alpha = settings.mixup_alpha
def fetch_data():
data_list = []
for i, item in enumerate(rd()):
data_list.append(item)
if i % batch_size == batch_size - 1:
if i % batch_size == batch_size - 1:
yield data_list
data_list =[]
data_list = []
def mixup_data():
for data_list in fetch_data():
if alpha > 0.:
lam = np.random.beta(alpha, alpha)
......@@ -133,11 +143,13 @@ def create_mixup_reader(settings, rd):
lam = 1.
l1 = np.array(data_list)
l2 = np.random.permutation(l1)
mixed_l = [l1[i][0] * lam + (1 - lam) * l2[i][0] for i in range(len(l1))]
mixed_l = [
l1[i][0] * lam + (1 - lam) * l2[i][0] for i in range(len(l1))
]
yield mixed_l, l1, l2, lam
def mixup_reader():
for context.tmp_mix, context.tmp_l1, context.tmp_l2, context.tmp_lam in mixup_data():
for i in range(len(context.tmp_mix)):
mixed_l = context.tmp_mix[i]
......@@ -145,11 +157,11 @@ def create_mixup_reader(settings, rd):
l2 = context.tmp_l2[i]
lam = context.tmp_lam
yield mixed_l, l1[1], l2[1], lam
return mixup_reader
def process_image(
sample,
def process_image(sample,
settings,
mode,
color_jitter,
......@@ -169,7 +181,7 @@ def process_image(
if rotate:
img = rotate_image(img)
if crop_size > 0:
img = random_crop(img, crop_size,settings)
img = random_crop(img, crop_size, settings)
if color_jitter:
img = distort_color(img)
if np.random.randint(0, 2) == 1:
......@@ -235,8 +247,9 @@ def _reader_creator(settings,
elif mode == 'test':
img_path, label = line.split()
img_path = os.path.join(data_dir, img_path)
yield [img_path]
crop_size = int(settings.image_shape.split(",")[2])
image_mapper = functools.partial(
process_image,
......@@ -249,9 +262,10 @@ def _reader_creator(settings,
image_mapper, reader, THREAD, BUF_SIZE, order=False)
return reader
def train(settings, data_dir=DATA_DIR, pass_id_as_seed=0):
file_list = os.path.join(data_dir, 'train_list.txt')
reader = _reader_creator(
reader = _reader_creator(
settings,
file_list,
'train',
......@@ -259,19 +273,19 @@ def train(settings, data_dir=DATA_DIR, pass_id_as_seed=0):
color_jitter=False,
rotate=False,
data_dir=data_dir,
pass_id_as_seed=pass_id_as_seed,
)
pass_id_as_seed=pass_id_as_seed, )
if settings.use_mixup == True:
reader = create_mixup_reader(settings, reader)
return reader
def val(settings,data_dir=DATA_DIR):
def val(settings, data_dir=DATA_DIR):
file_list = os.path.join(data_dir, 'val_list.txt')
return _reader_creator(settings ,file_list, 'val', shuffle=False,
data_dir=data_dir)
return _reader_creator(
settings, file_list, 'val', shuffle=False, data_dir=data_dir)
def test(settings,data_dir=DATA_DIR):
def test(settings, data_dir=DATA_DIR):
file_list = os.path.join(data_dir, 'val_list.txt')
return _reader_creator(settings, file_list, 'test', shuffle=False,
data_dir=data_dir)
return _reader_creator(
settings, file_list, 'test', shuffle=False, data_dir=data_dir)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册