提交 a28dc8af 编写于 作者: Z zhengya01

update ce

上级 0d004fa8
...@@ -147,9 +147,6 @@ def train(args): ...@@ -147,9 +147,6 @@ def train(args):
# Dataloader # Dataloader
train_reader = paddle.batch(reader.train(), batch_size=args.batch_size) train_reader = paddle.batch(reader.train(), batch_size=args.batch_size)
if args.enable_ce:
import lib.coco_reader_ce as reader_ce
train_reader = paddle.batch(reader_ce.train_ce(), batch_size=args.batch_size)
feeder = fluid.DataFeeder(place=place, feed_list=[image, target, target_weight]) feeder = fluid.DataFeeder(place=place, feed_list=[image, target, target_weight])
train_exe = fluid.ParallelExecutor( train_exe = fluid.ParallelExecutor(
......
...@@ -9,6 +9,9 @@ import math ...@@ -9,6 +9,9 @@ import math
import random import random
import functools import functools
import numpy as np import numpy as np
import os
is_ce = int(os.environ.get('is_ce', 0))
#random.seed(0) #random.seed(0)
...@@ -17,6 +20,8 @@ def rotate_image(img): ...@@ -17,6 +20,8 @@ def rotate_image(img):
(h, w) = img.shape[:2] (h, w) = img.shape[:2]
center = (w // 2, h // 2) center = (w // 2, h // 2)
angle = random.randint(-10, 10) angle = random.randint(-10, 10)
if is_ce:
aggle = 0
M = cv2.getRotationMatrix2D(center, angle, 1.0) M = cv2.getRotationMatrix2D(center, angle, 1.0)
rotated = cv2.warpAffine(img, M, (w, h)) rotated = cv2.warpAffine(img, M, (w, h))
return rotated return rotated
...@@ -27,6 +32,8 @@ def random_crop(img, size, scale=None, ratio=None): ...@@ -27,6 +32,8 @@ def random_crop(img, size, scale=None, ratio=None):
ratio = [3. / 4., 4. / 3.] if ratio is None else ratio ratio = [3. / 4., 4. / 3.] if ratio is None else ratio
aspect_ratio = math.sqrt(random.uniform(*ratio)) aspect_ratio = math.sqrt(random.uniform(*ratio))
if is_ce:
aspect_ratio = math.sqrt(1.)
w = 1. * aspect_ratio w = 1. * aspect_ratio
h = 1. / aspect_ratio h = 1. / aspect_ratio
...@@ -35,14 +42,18 @@ def random_crop(img, size, scale=None, ratio=None): ...@@ -35,14 +42,18 @@ def random_crop(img, size, scale=None, ratio=None):
scale_max = min(scale[1], bound) scale_max = min(scale[1], bound)
scale_min = min(scale[0], bound) scale_min = min(scale[0], bound)
target_area = img.shape[0] * img.shape[1] * random.uniform(scale_min, target_area = img.shape[0] * img.shape[1] * random.uniform(scale_min, scale_max)
scale_max) if is_ce:
target_area = img.shape[0] * img.shape[1] * (scale_min + scale_max) / 2.
target_size = math.sqrt(target_area) target_size = math.sqrt(target_area)
w = int(target_size * w) w = int(target_size * w)
h = int(target_size * h) h = int(target_size * h)
i = random.randint(0, img.shape[0] - h) i = random.randint(0, img.shape[0] - h)
j = random.randint(0, img.shape[1] - w) j = random.randint(0, img.shape[1] - w)
if is_ce:
i = int(img.shape[0] - h) // 2
j = int(img.shape[1] - w) // 2
img = img[i:i+h, j:j+w, :] img = img[i:i+h, j:j+w, :]
resized = cv2.resize(img, (size, size), interpolation=cv2.INTER_LANCZOS4) resized = cv2.resize(img, (size, size), interpolation=cv2.INTER_LANCZOS4)
...@@ -69,6 +80,9 @@ def crop_image(img, target_size, center): ...@@ -69,6 +80,9 @@ def crop_image(img, target_size, center):
else: else:
w_start = random.randint(0, width - size) w_start = random.randint(0, width - size)
h_start = random.randint(0, height - size) h_start = random.randint(0, height - size)
if is_ce:
w_start = (width - size) // 2
h_start = (height - size) // 2
w_end = w_start + size w_end = w_start + size
h_end = h_start + size h_end = h_start + size
img = img[h_start:h_end, w_start:w_end, :] img = img[h_start:h_end, w_start:w_end, :]
...@@ -93,6 +107,8 @@ def process_image(sample, mode, color_jitter, rotate, ...@@ -93,6 +107,8 @@ def process_image(sample, mode, color_jitter, rotate,
img = distort_color(img) img = distort_color(img)
if random.randint(0, 1) == 1: if random.randint(0, 1) == 1:
img = img[:, ::-1, :] img = img[:, ::-1, :]
if is_ce:
img = img[:, ::-1, :]
else: else:
if crop_size > 0: if crop_size > 0:
img = resize_short(img, crop_size) img = resize_short(img, crop_size)
......
...@@ -10,6 +10,8 @@ import numpy as np ...@@ -10,6 +10,8 @@ import numpy as np
import paddle import paddle
from imgtool import process_image from imgtool import process_image
is_ce = int(os.environ.get('is_ce', 0))
random.seed(0) random.seed(0)
DATA_DIR = "./data/Stanford_Online_Products/" DATA_DIR = "./data/Stanford_Online_Products/"
...@@ -32,7 +34,8 @@ def init_sop(mode): ...@@ -32,7 +34,8 @@ def init_sop(mode):
if label not in train_data: if label not in train_data:
train_data[label] = [] train_data[label] = []
train_data[label].append(path) train_data[label].append(path)
random.shuffle(train_image_list) if not is_ce:
random.shuffle(train_image_list)
print("{} dataset size: {}".format(mode, len(train_data))) print("{} dataset size: {}".format(mode, len(train_data)))
return train_data, train_image_list return train_data, train_image_list
else: else:
...@@ -67,13 +70,15 @@ def common_iterator(data, settings): ...@@ -67,13 +70,15 @@ def common_iterator(data, settings):
lab_num = len(labs) lab_num = len(labs)
ind = list(range(0, lab_num)) ind = list(range(0, lab_num))
while True: while True:
random.shuffle(ind) if not is_ce:
random.shuffle(ind)
ind_sample = ind[:class_num] ind_sample = ind[:class_num]
for ind_i in ind_sample: for ind_i in ind_sample:
lab = labs[ind_i] lab = labs[ind_i]
data_list = data[lab] data_list = data[lab]
data_ind = list(range(0, len(data_list))) data_ind = list(range(0, len(data_list)))
random.shuffle(data_ind) if not is_ce:
random.shuffle(data_ind)
anchor_ind = data_ind[:samples_each_class] anchor_ind = data_ind[:samples_each_class]
for anchor_ind_i in anchor_ind: for anchor_ind_i in anchor_ind:
...@@ -90,17 +95,21 @@ def triplet_iterator(data, settings): ...@@ -90,17 +95,21 @@ def triplet_iterator(data, settings):
lab_num = len(labs) lab_num = len(labs)
ind = list(range(0, lab_num)) ind = list(range(0, lab_num))
while True: while True:
random.shuffle(ind) if not is_ce:
random.shuffle(ind)
ind_pos, ind_neg = ind[:2] ind_pos, ind_neg = ind[:2]
lab_pos = labs[ind_pos] lab_pos = labs[ind_pos]
pos_data_list = data[lab_pos] pos_data_list = data[lab_pos]
data_ind = list(range(0, len(pos_data_list))) data_ind = list(range(0, len(pos_data_list)))
random.shuffle(data_ind) if not is_ce:
random.shuffle(data_ind)
anchor_ind, pos_ind = data_ind[:2] anchor_ind, pos_ind = data_ind[:2]
lab_neg = labs[ind_neg] lab_neg = labs[ind_neg]
neg_data_list = data[lab_neg] neg_data_list = data[lab_neg]
neg_ind = random.randint(0, len(neg_data_list) - 1) neg_ind = random.randint(0, len(neg_data_list) - 1)
if is_ce:
neg_ind = 1
anchor_path = DATA_DIR + pos_data_list[anchor_ind] anchor_path = DATA_DIR + pos_data_list[anchor_ind]
yield anchor_path, lab_pos yield anchor_path, lab_pos
...@@ -158,6 +167,8 @@ def createreader(settings, mode): ...@@ -158,6 +167,8 @@ def createreader(settings, mode):
assert(image_shape[1] == image_shape[2]) assert(image_shape[1] == image_shape[2])
image_size = int(image_shape[2]) image_size = int(image_shape[2])
keep_order = False if mode != 'train' or settings.loss_name in ['softmax', 'arcmargin'] else True keep_order = False if mode != 'train' or settings.loss_name in ['softmax', 'arcmargin'] else True
if is_ce:
keep_order = True
image_mapper = functools.partial(process_image, image_mapper = functools.partial(process_image,
mode=mode, color_jitter=False, rotate=False, crop_size=image_size) mode=mode, color_jitter=False, rotate=False, crop_size=image_size)
reader = paddle.reader.xmap_readers( reader = paddle.reader.xmap_readers(
......
...@@ -194,10 +194,6 @@ def train_async(args): ...@@ -194,10 +194,6 @@ def train_async(args):
train_reader = paddle.batch(reader.train(args), batch_size=train_batch_size, drop_last=True) train_reader = paddle.batch(reader.train(args), batch_size=train_batch_size, drop_last=True)
test_reader = paddle.batch(reader.test(args), batch_size=test_batch_size, drop_last=False) test_reader = paddle.batch(reader.test(args), batch_size=test_batch_size, drop_last=False)
if args.enable_ce:
import reader_ce
train_reader = paddle.batch(reader_ce.train(args), batch_size=train_batch_size, drop_last=False)
test_reader = paddle.batch(reader_ce.test(args), batch_size=test_batch_size, drop_last=False)
test_feeder = fluid.DataFeeder(place=place, feed_list=[image, label]) test_feeder = fluid.DataFeeder(place=place, feed_list=[image, label])
train_py_reader.decorate_paddle_reader(train_reader) train_py_reader.decorate_paddle_reader(train_reader)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册