提交 614946cf 编写于 作者: L Liu Yiqun

Calculate and print the average time for dygraph of slowfast.

上级 0d7d65d2
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time
class TimeAverager(object):
def __init__(self):
self.reset()
def reset(self):
self._cnt = 0
self._total_time = 0
def record(self, usetime):
self._cnt += 1
self._total_time += usetime
def get_average(self):
if self._cnt == 0:
return 0
return self._total_time / self._cnt
......@@ -30,6 +30,7 @@ from model import *
from config_utils import *
from lr_policy import get_epoch_lr
from kinetics_dataset import KineticsDataset
from timer import TimeAverager
logging.root.handlers = []
FORMAT = '[%(levelname)s: %(filename)s: %(lineno)4d]: %(message)s'
......@@ -345,6 +346,8 @@ def train(args):
+ str(local_rank))
# 4. train loop
reader_cost_averager = TimeAverager()
batch_cost_averager = TimeAverager()
for epoch in range(train_config.TRAIN.epoch):
epoch_start = time.time()
if args.resume and epoch <= args.resume_epoch:
......@@ -361,7 +364,8 @@ def train(args):
train_config.TRAIN.epoch))
batch_start = time.time()
for batch_id, data in enumerate(train_loader):
batch_reader_end = time.time()
reader_cost_averager.record(time.time() - batch_start)
y_data = data[2]
labels = to_variable(y_data)
labels.stop_gradient = True
......@@ -405,12 +409,18 @@ def train(args):
step=epoch * train_iter_num + batch_id,
value=1.0 - acc_top5.numpy())
train_batch_cost = time.time() - batch_start
train_reader_cost = batch_reader_end - batch_start
batch_start = time.time()
batch_cost_averager.record(time.time() - batch_start)
if batch_id % args.log_interval == 0:
print( "[Epoch %d, batch %d] loss %.5f, err1 %.5f, err5 %.5f, batch_cost: %.5f s, reader_cost: %.5f s" % \
(epoch, batch_id, avg_loss.numpy(), 1.0 - acc_top1.numpy(), 1. - acc_top5.numpy(), train_batch_cost, train_reader_cost))
print(
"[Epoch %d, batch %d] loss %.5f, err1 %.5f, err5 %.5f, batch_cost: %.5f s, reader_cost: %.5f s"
% (epoch, batch_id, avg_loss.numpy(),
1.0 - acc_top1.numpy(), 1. - acc_top5.numpy(),
batch_cost_averager.get_average(),
reader_cost_averager.get_average()))
reader_cost_averager.reset()
batch_cost_averager.reset()
batch_start = time.time()
train_epoch_cost = time.time() - epoch_start
print( '[Epoch %d end] avg_loss %.5f, avg_err1 %.5f, avg_err5= %.5f, epoch_cost: %.5f s' % \
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册