未验证 提交 a6f42259 编写于 作者: M minghaoBD 提交者: GitHub

fix imagenet reader, test=develop (#835)

* fix imagenet reader, test=develop

* update unstructuredPruner accordingly
上级 10e7c2da
......@@ -191,7 +191,7 @@ def compress(args):
opt.clear_grad()
pruner.step()
train_run_cost += time.time() - train_start
total_samples += args.batch_size * ParallelEnv().nranks
total_samples += args.batch_size
if batch_id % args.log_period == 0:
_logger.info(
......
......@@ -144,20 +144,7 @@ def _reader_creator(file_list,
full_lines = [line.strip() for line in flist]
if shuffle:
np.random.shuffle(full_lines)
if mode == 'train' and os.getenv('PADDLE_TRAINING_ROLE'):
# distributed mode if the env var `PADDLE_TRAINING_ROLE` exits
trainer_id = int(os.getenv("PADDLE_TRAINER_ID", "0"))
trainer_count = int(os.getenv("PADDLE_TRAINERS_NUM", "1"))
per_node_lines = len(full_lines) // trainer_count
lines = full_lines[trainer_id * per_node_lines:(
trainer_id + 1) * per_node_lines]
print(
"read images from %d, length: %d, lines length: %d, total: %d"
% (trainer_id * per_node_lines, per_node_lines,
len(lines), len(full_lines)))
else:
lines = full_lines
lines = full_lines
for line in lines:
if mode == 'train' or mode == 'val':
img_path, label = line.split()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册