提交 0d7d65d2 编写于 作者: L Liu Yiqun

Calculate and print the average time for video models.

上级 93c4daa4
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time
class TimeAverager(object):
def __init__(self):
self.reset()
def reset(self):
self._cnt = 0
self._total_time = 0
def record(self, usetime):
self._cnt += 1
self._total_time += usetime
def get_average(self):
if self._cnt == 0:
return 0
return self._total_time / self._cnt
...@@ -19,6 +19,7 @@ import numpy as np ...@@ -19,6 +19,7 @@ import numpy as np
import paddle import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.fluid import profiler from paddle.fluid import profiler
from utils.timer import TimeAverager
import logging import logging
import shutil import shutil
...@@ -82,29 +83,40 @@ def train_with_dataloader(exe, train_prog, compiled_train_prog, train_dataloader ...@@ -82,29 +83,40 @@ def train_with_dataloader(exe, train_prog, compiled_train_prog, train_dataloader
is_profiler = None, profiler_path = None): is_profiler = None, profiler_path = None):
if not train_dataloader: if not train_dataloader:
logger.error("[TRAIN] get dataloader failed.") logger.error("[TRAIN] get dataloader failed.")
epoch_periods = []
train_loss = 0 train_loss = 0
epoch_periods = []
reader_cost_averager = TimeAverager()
batch_cost_averager = TimeAverager()
for epoch in range(epochs): for epoch in range(epochs):
log_lr_and_step() log_lr_and_step()
train_iter = 0 train_iter = 0
epoch_periods = [] epoch_periods = []
cur_time = time.time() batch_start = time.time()
for data in train_dataloader(): for data in train_dataloader():
reader_cost_averager.record(time.time() - batch_start)
train_outs = exe.run(compiled_train_prog, train_outs = exe.run(compiled_train_prog,
fetch_list=train_fetch_list, fetch_list=train_fetch_list,
feed=data) feed=data)
period = time.time() - cur_time
epoch_periods.append(period) batch_cost = time.time() - batch_start
timeStamp = time.time() epoch_periods.append(batch_cost)
localTime = time.localtime(timeStamp) batch_cost_averager.record(batch_cost)
strTime = time.strftime("%Y-%m-%d %H:%M:%S", localTime)
local_time = time.localtime(time.time())
str_time = time.strftime("%Y-%m-%d %H:%M:%S", local_time)
if log_interval > 0 and (train_iter % log_interval == 0): if log_interval > 0 and (train_iter % log_interval == 0):
train_metrics.calculate_and_log_out(train_outs, \ train_metrics.calculate_and_log_out(train_outs, \
info = '[TRAIN {}] Epoch {}, iter {}, time {}, '.format(strTime, epoch, train_iter, period)) info = '[TRAIN {}] Epoch {}, iter {}, batch_cost {:.5}, reader_cost {:.5}'.format(str_time, epoch, train_iter, batch_cost_averager.get_average(), reader_cost_averager.get_average()))
reader_cost_averager.reset()
batch_cost_averager.reset()
train_iter += 1 train_iter += 1
cur_time = time.time() batch_start = time.time()
# NOTE: profiler tools, used for benchmark # NOTE: profiler tools, used for benchmark
if is_profiler and epoch == 0 and train_iter == log_interval: if is_profiler and epoch == 0 and train_iter == log_interval:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册