提交 6fdd5de2 编写于 作者: Y yi.wu

update

上级 8893cf12
......@@ -266,7 +266,10 @@ def train(avg_loss, infer_prog, optimizer, train_reader, test_reader, batch_acc,
# FIXME(wuyi): For use_reader_op, if the current
# pass is not the last, the last batch of this pass
# is also equal to args.batch_size.
num_samples += len(args.batch_size)
if args.use_reader_op:
num_samples += args.batch_size
else:
num_samples += len(data)
train_losses.append(loss)
print("Pass: %d, Iter: %d, Loss: %f\n" %
(pass_id, iters, np.mean(train_losses)))
......@@ -350,9 +353,6 @@ def train_parallel(avg_loss, infer_prog, optimizer, train_reader, test_reader,
if iters == args.skip_batch_num:
start_time = time.time()
num_samples = 0
# NOTE: if use reader ops, the input data is not splited to multiple cards
if args.use_reader_op and iters >= args.iterations / args.gpus:
break
if args.use_fake_data or args.use_reader_op:
try:
loss, = exe.run([avg_loss.name])
......@@ -362,7 +362,10 @@ def train_parallel(avg_loss, infer_prog, optimizer, train_reader, test_reader,
loss, = exe.run([avg_loss.name], feed=feeder.feed(data))
if args.update_method == "pserver":
exe.bcast_params()
num_samples += len(data)
if args.use_reader_op:
num_samples += args.batch_size
else:
num_samples += len(data)
iters += 1
if batch_id % 1 == 0:
print("Pass %d, batch %d, loss %s" %
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册