提交 89f1a4d6 编写于 作者: Z zhengya01

update reader

上级 179be88c
......@@ -13,15 +13,14 @@ import os
is_ce = int(os.environ.get('is_ce', 0))
#random.seed(0)
if is_ce:
random.seed(0)
def rotate_image(img):
""" rotate_image """
(h, w) = img.shape[:2]
center = (w // 2, h // 2)
angle = random.randint(-10, 10)
if is_ce:
aggle = 0
M = cv2.getRotationMatrix2D(center, angle, 1.0)
rotated = cv2.warpAffine(img, M, (w, h))
return rotated
......@@ -32,8 +31,6 @@ def random_crop(img, size, scale=None, ratio=None):
ratio = [3. / 4., 4. / 3.] if ratio is None else ratio
aspect_ratio = math.sqrt(random.uniform(*ratio))
if is_ce:
aspect_ratio = math.sqrt(1.)
w = 1. * aspect_ratio
h = 1. / aspect_ratio
......@@ -43,17 +40,12 @@ def random_crop(img, size, scale=None, ratio=None):
scale_min = min(scale[0], bound)
target_area = img.shape[0] * img.shape[1] * random.uniform(scale_min, scale_max)
if is_ce:
target_area = img.shape[0] * img.shape[1] * (scale_min + scale_max) / 2.
target_size = math.sqrt(target_area)
w = int(target_size * w)
h = int(target_size * h)
i = random.randint(0, img.shape[0] - h)
j = random.randint(0, img.shape[1] - w)
if is_ce:
i = int(img.shape[0] - h) // 2
j = int(img.shape[1] - w) // 2
img = img[i:i+h, j:j+w, :]
resized = cv2.resize(img, (size, size), interpolation=cv2.INTER_LANCZOS4)
......@@ -80,9 +72,6 @@ def crop_image(img, target_size, center):
else:
w_start = random.randint(0, width - size)
h_start = random.randint(0, height - size)
if is_ce:
w_start = (width - size) // 2
h_start = (height - size) // 2
w_end = w_start + size
h_end = h_start + size
img = img[h_start:h_end, w_start:w_end, :]
......@@ -107,8 +96,6 @@ def process_image(sample, mode, color_jitter, rotate,
img = distort_color(img)
if random.randint(0, 1) == 1:
img = img[:, ::-1, :]
if is_ce:
img = img[:, ::-1, :]
else:
if crop_size > 0:
img = resize_short(img, crop_size)
......
......@@ -34,8 +34,7 @@ def init_sop(mode):
if label not in train_data:
train_data[label] = []
train_data[label].append(path)
if not is_ce:
random.shuffle(train_image_list)
random.shuffle(train_image_list)
print("{} dataset size: {}".format(mode, len(train_data)))
return train_data, train_image_list
else:
......@@ -70,15 +69,13 @@ def common_iterator(data, settings):
lab_num = len(labs)
ind = list(range(0, lab_num))
while True:
if not is_ce:
random.shuffle(ind)
random.shuffle(ind)
ind_sample = ind[:class_num]
for ind_i in ind_sample:
lab = labs[ind_i]
data_list = data[lab]
data_ind = list(range(0, len(data_list)))
if not is_ce:
random.shuffle(data_ind)
random.shuffle(data_ind)
anchor_ind = data_ind[:samples_each_class]
for anchor_ind_i in anchor_ind:
......@@ -95,21 +92,17 @@ def triplet_iterator(data, settings):
lab_num = len(labs)
ind = list(range(0, lab_num))
while True:
if not is_ce:
random.shuffle(ind)
random.shuffle(ind)
ind_pos, ind_neg = ind[:2]
lab_pos = labs[ind_pos]
pos_data_list = data[lab_pos]
data_ind = list(range(0, len(pos_data_list)))
if not is_ce:
random.shuffle(data_ind)
random.shuffle(data_ind)
anchor_ind, pos_ind = data_ind[:2]
lab_neg = labs[ind_neg]
neg_data_list = data[lab_neg]
neg_ind = random.randint(0, len(neg_data_list) - 1)
if is_ce:
neg_ind = 1
anchor_path = DATA_DIR + pos_data_list[anchor_ind]
yield anchor_path, lab_pos
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册