提交 cececbbf 编写于 作者: D danleifeng

add reader data_format

上级 b85e4841
......@@ -693,7 +693,8 @@ class Entry(object):
if self.predict_reader is None:
predict_reader = paddle.batch(reader.arc_train(self.dataset_dir,
self.num_classes),
self.num_classes,
data_format=self.data_format),
batch_size=self.train_batch_size)
else:
predict_reader = self.predict_reader
......@@ -925,7 +926,7 @@ class Entry(object):
if self.train_reader is None:
train_reader = paddle.batch(reader.arc_train(
self.dataset_dir, self.num_classes),
self.dataset_dir, self.num_classes, data_format=self.data_format),
batch_size=self.train_batch_size)
else:
train_reader = self.train_reader
......
......@@ -172,7 +172,8 @@ def process_image(sample,
color_jitter,
rotate,
rand_mirror,
normalize):
normalize,
data_format='NCHW'):
img_data = base64.b64decode(sample[0])
img = Image.open(StringIO(img_data))
......@@ -198,6 +199,9 @@ def process_image(sample,
assert sample[1] < class_dim, \
"label of train dataset should be less than the class_dim."
if data_format == 'NHWC':
img = img.transpose((1, 2, 0))
return img, sample[1]
......@@ -208,7 +212,8 @@ def arc_iterator(data_dir,
color_jitter=False,
rotate=False,
rand_mirror=False,
normalize=False):
normalize=False,
data_format='NCHW'):
trainer_id = int(os.getenv("PADDLE_TRAINER_ID", "0"))
num_trainers = int(os.getenv("PADDLE_TRAINERS_NUM", "1"))
......@@ -237,11 +242,12 @@ def arc_iterator(data_dir,
color_jitter=color_jitter,
rotate=rotate,
rand_mirror=rand_mirror,
normalize=normalize)
normalize=normalize,
data_format=data_format)
return paddle.reader.xmap_readers(mapper, reader, THREAD, BUF_SIZE)
def load_bin(path, image_size):
def load_bin(path, image_size, data_format ='NCHW'):
if six.PY2:
bins, issame_list = pickle.load(open(path, 'rb'))
else:
......@@ -267,6 +273,8 @@ def load_bin(path, image_size):
img = np.array(img).astype('float32').transpose((2, 0, 1))
img -= img_mean
img /= img_std
if data_format == 'NHWC':
img = img.transpose((1, 2, 0))
data_list[flip][i][:] = img
if i % 1000 == 0:
print('loading bin', i)
......@@ -274,7 +282,7 @@ def load_bin(path, image_size):
return data_list, issame_list
def train(data_dir, num_classes):
def train(data_dir, num_classes, data_format ='NCHW'):
file_path = os.path.join(data_dir, 'file_list.txt')
return arc_iterator(data_dir,
file_path,
......@@ -282,16 +290,17 @@ def train(data_dir, num_classes):
color_jitter=False,
rotate=False,
rand_mirror=True,
normalize=True)
normalize=True,
data_format=data_format)
def test(data_dir, datasets):
def test(data_dir, datasets, data_format ='NCHW'):
test_list = []
test_name_list = []
for name in datasets.split(','):
path = os.path.join(data_dir, name+".bin")
if os.path.exists(path):
data_set = load_bin(path, (DATA_DIM, DATA_DIM))
data_set = load_bin(path, (DATA_DIM, DATA_DIM), data_format=data_format)
test_list.append(data_set)
test_name_list.append(name)
print('test', name)
......
......@@ -184,7 +184,8 @@ def process_image_imagepath(sample,
color_jitter,
rotate,
rand_mirror,
normalize):
normalize,
data_format='NCHW'):
imgpath = sample[0]
img = Image.open(imgpath)
......@@ -211,6 +212,9 @@ def process_image_imagepath(sample,
assert sample[1] < class_dim, \
"label of train dataset should be less than the class_dim."
if data_format == 'NHWC':
img = img.transpose((1, 2, 0))
return img, sample[1]
......@@ -221,7 +225,8 @@ def arc_iterator(data,
color_jitter=False,
rotate=False,
rand_mirror=False,
normalize=False):
normalize=False,
data_format ='NCHW'):
def reader():
if shuffle:
random.shuffle(data)
......@@ -235,11 +240,12 @@ def arc_iterator(data,
color_jitter=color_jitter,
rotate=rotate,
rand_mirror=rand_mirror,
normalize=normalize)
normalize=normalize,
data_format=data_format)
return paddle.reader.xmap_readers(mapper, reader, THREAD, BUF_SIZE)
def load_bin(path, image_size):
def load_bin(path, image_size, data_format ='NCHW'):
if six.PY2:
bins, issame_list = pickle.load(open(path, 'rb'))
else:
......@@ -265,6 +271,8 @@ def load_bin(path, image_size):
img = np.array(img).astype('float32').transpose((2, 0, 1))
img -= img_mean
img /= img_std
if data_format == 'NHWC':
img = img.transpose((1, 2, 0))
data_list[flip][i][:] = img
if i % 1000 == 0:
print('loading bin', i)
......@@ -272,7 +280,7 @@ def load_bin(path, image_size):
return data_list, issame_list
def arc_train(data_dir, class_dim):
def arc_train(data_dir, class_dim, data_format ='NCHW'):
train_image_list = get_train_image_list(data_dir)
return arc_iterator(train_image_list,
shuffle=True,
......@@ -281,16 +289,17 @@ def arc_train(data_dir, class_dim):
color_jitter=False,
rotate=False,
rand_mirror=True,
normalize=True)
normalize=True,
data_format=data_format)
def test(data_dir, datasets):
def test(data_dir, datasets, data_format ='NCHW'):
test_list = []
test_name_list = []
for name in datasets.split(','):
path = os.path.join(data_dir, name+".bin")
if os.path.exists(path):
data_set = load_bin(path, (DATA_DIM, DATA_DIM))
data_set = load_bin(path, (DATA_DIM, DATA_DIM), data_format=data_format)
test_list.append(data_set)
test_name_list.append(name)
print('test', name)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册