未验证 提交 f09c4423 编写于 作者: T Tao Luo 提交者: GitHub

add ips for dygraph mobilenet and resnet models (#4883)

test=develop
上级 fa73c7f5
...@@ -38,20 +38,25 @@ args = parse_args() ...@@ -38,20 +38,25 @@ args = parse_args()
if int(os.getenv("PADDLE_TRAINER_ID", 0)) == 0: if int(os.getenv("PADDLE_TRAINER_ID", 0)) == 0:
print_arguments(args) print_arguments(args)
class TimeCostAverage(object): class TimeCostAverage(object):
def __init__(self): def __init__(self):
self.reset() self.reset()
def reset(self): def reset(self):
self.cnt = 0 self.cnt = 0
self.total_time = 0 self.total_time = 0
def record(self, usetime): def record(self, usetime):
self.cnt += 1 self.cnt += 1
self.total_time += usetime self.total_time += usetime
def get_average(self): def get_average(self):
if self.cnt == 0: if self.cnt == 0:
return 0 return 0
return self.total_time / self.cnt return self.total_time / self.cnt
def eval(net, test_data_loader, eop): def eval(net, test_data_loader, eop):
total_loss = 0.0 total_loss = 0.0
total_acc1 = 0.0 total_acc1 = 0.0
...@@ -232,13 +237,14 @@ def train_mobilenet(): ...@@ -232,13 +237,14 @@ def train_mobilenet():
total_batch_num = total_batch_num + 1 total_batch_num = total_batch_num + 1
if batch_id % args.print_step == 0: if batch_id % args.print_step == 0:
ips = float(args.batch_size) / batch_cost_avg.get_average()
print( print(
"[Epoch %d, batch %d], avg_loss %.5f, acc_top1 %.5f, acc_top5 %.5f, batch_cost: %.5f s, net_t: %.5f s, backward_t: %.5f s, reader_t: %.5f s" "[Epoch %d, batch %d], avg_loss %.5f, acc_top1 %.5f, acc_top5 %.5f, batch_cost: %.5f s, net_t: %.5f s, backward_t: %.5f s, reader_t: %.5f s, ips: %.5f images/sec"
% (eop, batch_id, avg_loss.numpy(), acc_top1.numpy(), % (eop, batch_id, avg_loss.numpy(), acc_top1.numpy(),
acc_top5.numpy(), batch_cost_avg.get_average(), acc_top5.numpy(), batch_cost_avg.get_average(),
batch_net_avg.get_average(), batch_net_avg.get_average(),
batch_backward_avg.get_average(), batch_backward_avg.get_average(),
batch_reader_avg.get_average())) batch_reader_avg.get_average(), ips))
sys.stdout.flush() sys.stdout.flush()
batch_cost_avg.reset() batch_cost_avg.reset()
batch_net_avg.reset() batch_net_avg.reset()
......
...@@ -494,12 +494,14 @@ def train_resnet(): ...@@ -494,12 +494,14 @@ def train_resnet():
total_batch_num = total_batch_num + 1 #this is for benchmark total_batch_num = total_batch_num + 1 #this is for benchmark
if batch_id % 10 == 0: if batch_id % 10 == 0:
ips = float(
args.batch_size) / train_batch_cost_avg.get_average()
print( print(
"[Epoch %d, batch %d] loss %.5f, acc1 %.5f, acc5 %.5f, batch_cost: %.5f s, reader_cost: %.5f s" "[Epoch %d, batch %d] loss %.5f, acc1 %.5f, acc5 %.5f, batch_cost: %.5f s, reader_cost: %.5f s, ips: %.5f images/sec"
% (eop, batch_id, total_loss / total_sample, % (eop, batch_id, total_loss / total_sample,
total_acc1 / total_sample, total_acc5 / total_sample, total_acc1 / total_sample, total_acc5 / total_sample,
train_batch_cost_avg.get_average(), train_batch_cost_avg.get_average(),
train_reader_cost_avg.get_average())) train_reader_cost_avg.get_average(), ips))
train_batch_cost_avg.reset() train_batch_cost_avg.reset()
train_reader_cost_avg.reset() train_reader_cost_avg.reset()
batch_start = time.time() batch_start = time.time()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册