提交 6fdd5de2 编写于 作者: Y yi.wu

update

上级 8893cf12
...@@ -266,7 +266,10 @@ def train(avg_loss, infer_prog, optimizer, train_reader, test_reader, batch_acc, ...@@ -266,7 +266,10 @@ def train(avg_loss, infer_prog, optimizer, train_reader, test_reader, batch_acc,
# FIXME(wuyi): For use_reader_op, if the current # FIXME(wuyi): For use_reader_op, if the current
# pass is not the last, the last batch of this pass # pass is not the last, the last batch of this pass
# is also equal to args.batch_size. # is also equal to args.batch_size.
num_samples += len(args.batch_size) if args.use_reader_op:
num_samples += args.batch_size
else:
num_samples += len(data)
train_losses.append(loss) train_losses.append(loss)
print("Pass: %d, Iter: %d, Loss: %f\n" % print("Pass: %d, Iter: %d, Loss: %f\n" %
(pass_id, iters, np.mean(train_losses))) (pass_id, iters, np.mean(train_losses)))
...@@ -350,9 +353,6 @@ def train_parallel(avg_loss, infer_prog, optimizer, train_reader, test_reader, ...@@ -350,9 +353,6 @@ def train_parallel(avg_loss, infer_prog, optimizer, train_reader, test_reader,
if iters == args.skip_batch_num: if iters == args.skip_batch_num:
start_time = time.time() start_time = time.time()
num_samples = 0 num_samples = 0
# NOTE: if use reader ops, the input data is not splited to multiple cards
if args.use_reader_op and iters >= args.iterations / args.gpus:
break
if args.use_fake_data or args.use_reader_op: if args.use_fake_data or args.use_reader_op:
try: try:
loss, = exe.run([avg_loss.name]) loss, = exe.run([avg_loss.name])
...@@ -362,7 +362,10 @@ def train_parallel(avg_loss, infer_prog, optimizer, train_reader, test_reader, ...@@ -362,7 +362,10 @@ def train_parallel(avg_loss, infer_prog, optimizer, train_reader, test_reader,
loss, = exe.run([avg_loss.name], feed=feeder.feed(data)) loss, = exe.run([avg_loss.name], feed=feeder.feed(data))
if args.update_method == "pserver": if args.update_method == "pserver":
exe.bcast_params() exe.bcast_params()
num_samples += len(data) if args.use_reader_op:
num_samples += args.batch_size
else:
num_samples += len(data)
iters += 1 iters += 1
if batch_id % 1 == 0: if batch_id % 1 == 0:
print("Pass %d, batch %d, loss %s" % print("Pass %d, batch %d, loss %s" %
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册