Unverified commit 179be88c, authored by zhengya01 and committed by GitHub

Delete reader_ce.py

Parent 190528db
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import math
import random
import functools
import numpy as np
import paddle
from imgtool_ce import process_image
random.seed(0)
DATA_DIR = "./data/Stanford_Online_Products/"
TRAIN_LIST = './data/Stanford_Online_Products/Ebay_train.txt'
VAL_LIST = './data/Stanford_Online_Products/Ebay_test.txt'
def init_sop(mode):
    # Parse the Ebay list file for the given mode into a {label: [paths]} dict and flat image lists.
    if mode == 'train':
        train_data = {}
        train_image_list = []
        train_list = open(TRAIN_LIST, "r").readlines()
        for i, item in enumerate(train_list):
            items = item.strip().split()
            if items[0] == 'image_id':
                continue
            path = items[3]
            label = int(items[1]) - 1
            train_image_list.append((path, label))
            if label not in train_data:
                train_data[label] = []
            train_data[label].append(path)
        #random.shuffle(train_image_list)
        print("{} dataset size: {}".format(mode, len(train_data)))
        return train_data, train_image_list
    else:
        val_data = {}
        val_image_list = []
        test_image_list = []
        val_list = open(VAL_LIST, "r").readlines()
        for i, item in enumerate(val_list):
            items = item.strip().split()
            if items[0] == 'image_id':
                continue
            path = items[3]
            label = int(items[1])
            val_image_list.append((path, label))
            test_image_list.append(path)
            if label not in val_data:
                val_data[label] = []
            val_data[label].append(path)
        print("{} dataset size: {}".format(mode, len(val_data)))
        if mode == 'val':
            return val_data, val_image_list
        else:
            return test_image_list
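
# The Ebay_train.txt / Ebay_test.txt lists parsed above are whitespace-separated,
# with a header row starting with "image_id"; items[1] is used as the class id
# and items[3] as the image path relative to DATA_DIR. In the standard
# Stanford Online Products release the columns are:
#   image_id class_id super_class_id path
# e.g. "1 1 1 bicycle_final/111085122871_0.JPG" (illustrative line, not taken
# from the dataset).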

def common_iterator(data, settings):
    # Yield samples_each_class images from each of the first class_num classes (shuffling disabled for determinism).
    batch_size = settings.train_batch_size
    samples_each_class = settings.samples_each_class
    assert (batch_size % samples_each_class == 0)
    class_num = batch_size // samples_each_class

    def train_iterator():
        labs = list(data.keys())
        lab_num = len(labs)
        ind = list(range(0, lab_num))
        while True:
            #random.shuffle(ind)
            ind_sample = ind[:class_num]
            for ind_i in ind_sample:
                lab = labs[ind_i]
                data_list = data[lab]
                data_ind = list(range(0, len(data_list)))
                #random.shuffle(data_ind)
                anchor_ind = data_ind[:samples_each_class]
                for anchor_ind_i in anchor_ind:
                    anchor_path = DATA_DIR + data_list[anchor_ind_i]
                    yield anchor_path, lab

    return train_iterator

def triplet_iterator(data, settings):
    # Yield (anchor, positive, negative) samples for the triplet loss.
    batch_size = settings.train_batch_size
    assert (batch_size % 3 == 0)

    def train_iterator():
        labs = list(data.keys())
        lab_num = len(labs)
        ind = list(range(0, lab_num))
        while True:
            #random.shuffle(ind)
            ind_pos, ind_neg = ind[:2]
            lab_pos = labs[ind_pos]
            pos_data_list = data[lab_pos]
            data_ind = list(range(0, len(pos_data_list)))
            #random.shuffle(data_ind)
            anchor_ind, pos_ind = data_ind[:2]
            lab_neg = labs[ind_neg]
            neg_data_list = data[lab_neg]
            #neg_ind = random.randint(0, len(neg_data_list) - 1)
            neg_ind = 1
            anchor_path = DATA_DIR + pos_data_list[anchor_ind]
            yield anchor_path, lab_pos
            pos_path = DATA_DIR + pos_data_list[pos_ind]
            yield pos_path, lab_pos
            neg_path = DATA_DIR + neg_data_list[neg_ind]
            yield neg_path, lab_neg

    return train_iterator
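
# With shuffling disabled, triplet_iterator always draws the anchor/positive
# pair from labs[0] and the negative from labs[1], and emits samples as
# repeated (anchor, positive, negative) groups of three; hence the
# train_batch_size % 3 == 0 assertion above.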

def arcmargin_iterator(data, settings):
    # Endlessly cycle over the flat (path, label) list for softmax/arcmargin training.
    def train_iterator():
        while True:
            for items in data:
                path, label = items
                path = DATA_DIR + path
                yield path, label

    return train_iterator

def image_iterator(data, mode):
    # 'val' yields (path, label) pairs; 'test' yields [path] only.
    def val_iterator():
        for items in data:
            path, label = items
            path = DATA_DIR + path
            yield path, label

    def test_iterator():
        for item in data:
            path = item
            path = DATA_DIR + path
            yield [path]

    if mode == 'val':
        return val_iterator
    else:
        return test_iterator

def createreader(settings, mode):
    # Pick the sample iterator for the mode/loss and wrap it with xmap_readers
    # so images are loaded and preprocessed by 8 parallel workers.
    def metric_reader():
        if mode == 'train':
            train_data, train_image_list = init_sop('train')
            loss_name = settings.loss_name
            if loss_name in ["softmax", "arcmargin"]:
                return arcmargin_iterator(train_image_list, settings)()
            elif loss_name == 'triplet':
                return triplet_iterator(train_data, settings)()
            else:
                return common_iterator(train_data, settings)()
        elif mode == 'val':
            val_data, val_image_list = init_sop('val')
            return image_iterator(val_image_list, 'val')()
        else:
            test_image_list = init_sop('test')
            return image_iterator(test_image_list, 'test')()

    image_shape = settings.image_shape.split(',')
    assert (image_shape[1] == image_shape[2])
    image_size = int(image_shape[2])
    #keep_order = False if mode != 'train' or settings.loss_name in ['softmax', 'arcmargin'] else True
    keep_order = True
    image_mapper = functools.partial(
        process_image,
        mode=mode, color_jitter=False, rotate=False, crop_size=image_size)
    reader = paddle.reader.xmap_readers(
        image_mapper, metric_reader, 8, 1000, order=keep_order)
    return reader

def train(settings):
    return createreader(settings, "train")


def test(settings):
    return createreader(settings, "val")


def infer(settings):
    return createreader(settings, "test")
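

# ---------------------------------------------------------------------------
# Minimal usage sketch, kept out of the import path behind __main__. The
# attribute names on _Settings are the ones read by the readers above; the
# concrete values and the _Settings container itself are illustrative
# assumptions, and actually iterating requires the Stanford_Online_Products
# data under DATA_DIR plus the imgtool_ce module.
if __name__ == "__main__":
    class _Settings(object):
        train_batch_size = 120     # common/triplet iterators assert divisibility constraints on this
        samples_each_class = 2
        loss_name = "arcmargin"    # "softmax"/"arcmargin" -> arcmargin_iterator, "triplet" -> triplet_iterator
        image_shape = "3,224,224"  # channels,height,width; height must equal width

    reader = train(_Settings())    # xmap-wrapped reader; calling it yields preprocessed samples
    for i, sample in enumerate(reader()):
        print("sample {}: {}".format(i, type(sample)))
        if i >= 4:                 # peek at a few samples only
            break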