未验证 提交 f09c4423 编写于 作者: T Tao Luo 提交者: GitHub

add ips for dygraph mobilenet and resnet models (#4883)

test=develop
上级 fa73c7f5
......@@ -38,20 +38,25 @@ args = parse_args()
if int(os.getenv("PADDLE_TRAINER_ID", 0)) == 0:
print_arguments(args)
class TimeCostAverage(object):
def __init__(self):
self.reset()
def reset(self):
self.cnt = 0
self.total_time = 0
def record(self, usetime):
self.cnt += 1
self.total_time += usetime
def get_average(self):
if self.cnt == 0:
return 0
return self.total_time / self.cnt
def eval(net, test_data_loader, eop):
total_loss = 0.0
total_acc1 = 0.0
......@@ -232,13 +237,14 @@ def train_mobilenet():
total_batch_num = total_batch_num + 1
if batch_id % args.print_step == 0:
ips = float(args.batch_size) / batch_cost_avg.get_average()
print(
"[Epoch %d, batch %d], avg_loss %.5f, acc_top1 %.5f, acc_top5 %.5f, batch_cost: %.5f s, net_t: %.5f s, backward_t: %.5f s, reader_t: %.5f s"
"[Epoch %d, batch %d], avg_loss %.5f, acc_top1 %.5f, acc_top5 %.5f, batch_cost: %.5f s, net_t: %.5f s, backward_t: %.5f s, reader_t: %.5f s, ips: %.5f images/sec"
% (eop, batch_id, avg_loss.numpy(), acc_top1.numpy(),
acc_top5.numpy(), batch_cost_avg.get_average(),
batch_net_avg.get_average(),
batch_backward_avg.get_average(),
batch_reader_avg.get_average()))
batch_reader_avg.get_average(), ips))
sys.stdout.flush()
batch_cost_avg.reset()
batch_net_avg.reset()
......
......@@ -494,12 +494,14 @@ def train_resnet():
total_batch_num = total_batch_num + 1 #this is for benchmark
if batch_id % 10 == 0:
ips = float(
args.batch_size) / train_batch_cost_avg.get_average()
print(
"[Epoch %d, batch %d] loss %.5f, acc1 %.5f, acc5 %.5f, batch_cost: %.5f s, reader_cost: %.5f s"
"[Epoch %d, batch %d] loss %.5f, acc1 %.5f, acc5 %.5f, batch_cost: %.5f s, reader_cost: %.5f s, ips: %.5f images/sec"
% (eop, batch_id, total_loss / total_sample,
total_acc1 / total_sample, total_acc5 / total_sample,
train_batch_cost_avg.get_average(),
train_reader_cost_avg.get_average()))
train_reader_cost_avg.get_average(), ips))
train_batch_cost_avg.reset()
train_reader_cost_avg.reset()
batch_start = time.time()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册