提交 0d7d65d2 编写于 作者: L Liu Yiqun

Calculate and print the average time for video models.

上级 93c4daa4
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time
class TimeAverager(object):
def __init__(self):
self.reset()
def reset(self):
self._cnt = 0
self._total_time = 0
def record(self, usetime):
self._cnt += 1
self._total_time += usetime
def get_average(self):
if self._cnt == 0:
return 0
return self._total_time / self._cnt
......@@ -19,6 +19,7 @@ import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.fluid import profiler
from utils.timer import TimeAverager
import logging
import shutil
......@@ -82,29 +83,40 @@ def train_with_dataloader(exe, train_prog, compiled_train_prog, train_dataloader
is_profiler = None, profiler_path = None):
if not train_dataloader:
logger.error("[TRAIN] get dataloader failed.")
epoch_periods = []
train_loss = 0
epoch_periods = []
reader_cost_averager = TimeAverager()
batch_cost_averager = TimeAverager()
for epoch in range(epochs):
log_lr_and_step()
train_iter = 0
epoch_periods = []
cur_time = time.time()
batch_start = time.time()
for data in train_dataloader():
reader_cost_averager.record(time.time() - batch_start)
train_outs = exe.run(compiled_train_prog,
fetch_list=train_fetch_list,
feed=data)
period = time.time() - cur_time
epoch_periods.append(period)
timeStamp = time.time()
localTime = time.localtime(timeStamp)
strTime = time.strftime("%Y-%m-%d %H:%M:%S", localTime)
batch_cost = time.time() - batch_start
epoch_periods.append(batch_cost)
batch_cost_averager.record(batch_cost)
local_time = time.localtime(time.time())
str_time = time.strftime("%Y-%m-%d %H:%M:%S", local_time)
if log_interval > 0 and (train_iter % log_interval == 0):
train_metrics.calculate_and_log_out(train_outs, \
info = '[TRAIN {}] Epoch {}, iter {}, time {}, '.format(strTime, epoch, train_iter, period))
info = '[TRAIN {}] Epoch {}, iter {}, batch_cost {:.5}, reader_cost {:.5}'.format(str_time, epoch, train_iter, batch_cost_averager.get_average(), reader_cost_averager.get_average()))
reader_cost_averager.reset()
batch_cost_averager.reset()
train_iter += 1
cur_time = time.time()
batch_start = time.time()
# NOTE: profiler tools, used for benchmark
if is_profiler and epoch == 0 and train_iter == log_interval:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册