未验证 提交 5b584500 编写于 作者: W whs 提交者: GitHub

Fix reader of slim. (#2477)

* Fix reader of slim.

* Remove unused comments.
上级 2aed7f63
......@@ -133,33 +133,36 @@ def _reader_creator(file_list,
data_dir=DATA_DIR,
batch_size=1):
def reader():
with open(file_list) as flist:
full_lines = [line.strip() for line in flist]
if shuffle:
np.random.shuffle(full_lines)
if mode == 'train' and os.getenv('PADDLE_TRAINING_ROLE'):
# distributed mode if the env var `PADDLE_TRAINING_ROLE` exits
trainer_id = int(os.getenv("PADDLE_TRAINER_ID", "0"))
trainer_count = int(os.getenv("PADDLE_TRAINERS", "1"))
per_node_lines = len(full_lines) // trainer_count
lines = full_lines[trainer_id * per_node_lines:(trainer_id + 1)
* per_node_lines]
print(
"read images from %d, length: %d, lines length: %d, total: %d"
% (trainer_id * per_node_lines, per_node_lines, len(lines),
len(full_lines)))
else:
lines = full_lines
for line in lines:
if mode == 'train' or mode == 'val':
img_path, label = line.split()
# img_path = img_path.replace("JPEG", "jpeg")
img_path = os.path.join(data_dir, img_path)
yield img_path, int(label)
elif mode == 'test':
img_path = os.path.join(data_dir, line)
yield [img_path]
try:
with open(file_list) as flist:
full_lines = [line.strip() for line in flist]
if shuffle:
np.random.shuffle(full_lines)
if mode == 'train' and os.getenv('PADDLE_TRAINING_ROLE'):
# distributed mode if the env var `PADDLE_TRAINING_ROLE` exits
trainer_id = int(os.getenv("PADDLE_TRAINER_ID", "0"))
trainer_count = int(os.getenv("PADDLE_TRAINERS", "1"))
per_node_lines = len(full_lines) // trainer_count
lines = full_lines[trainer_id * per_node_lines:(
trainer_id + 1) * per_node_lines]
print(
"read images from %d, length: %d, lines length: %d, total: %d"
% (trainer_id * per_node_lines, per_node_lines,
len(lines), len(full_lines)))
else:
lines = full_lines
for line in lines:
if mode == 'train' or mode == 'val':
img_path, label = line.split()
img_path = os.path.join(data_dir, img_path)
yield img_path, int(label)
elif mode == 'test':
img_path = os.path.join(data_dir, line)
yield [img_path]
except Exception as e:
print("Reader failed!\n{}".format(str(e)))
os._exit(1)
mapper = functools.partial(
process_image, mode=mode, color_jitter=color_jitter, rotate=rotate)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册