未验证 提交 38ada7ff 编写于 作者: T Tao Luo 提交者: GitHub

fix resnet dygraph model time print (#4868)

上级 ba9a787d
......@@ -144,6 +144,24 @@ args = parse_args()
batch_size = args.batch_size
class TimeCostAverage(object):
def __init__(self):
self.reset()
def reset(self):
self.cnt = 0
self.total_time = 0
def record(self, usetime):
self.cnt += 1
self.total_time += usetime
def get_average(self):
if self.cnt == 0:
return 0
return self.total_time / self.cnt
def optimizer_setting(parameter_list=None):
total_images = IMAGENET1000
......@@ -433,6 +451,8 @@ def train_resnet():
total_acc5 = 0.0
total_sample = 0
train_batch_cost_avg = TimeCostAverage()
train_reader_cost_avg = TimeCostAverage()
batch_start = time.time()
for batch_id, data in enumerate(train_loader()):
#NOTE: used in benchmark
......@@ -469,13 +489,19 @@ def train_resnet():
total_sample += 1
train_batch_cost = time.time() - batch_start
train_batch_cost_avg.record(train_batch_cost)
train_reader_cost_avg.record(train_reader_cost)
total_batch_num = total_batch_num + 1 #this is for benchmark
if batch_id % 10 == 0:
print(
"[Epoch %d, batch %d] loss %.5f, acc1 %.5f, acc5 %.5f, batch_cost: %.5f s, reader_cost: %.5f s"
% (eop, batch_id, total_loss / total_sample,
total_acc1 / total_sample, total_acc5 / total_sample,
train_batch_cost, train_reader_cost))
train_batch_cost_avg.get_average(),
train_reader_cost_avg.get_average()))
train_batch_cost_avg.reset()
train_reader_cost_avg.reset()
batch_start = time.time()
if args.ce:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册