未验证 提交 db6ce5e9 编写于 作者: W wanghuancoder 提交者: GitHub

fix language model time print (#4865)

* fix language_model timecost algorithm

* fix dataloader time calc, test=develop
上级 295c16b6
......@@ -49,6 +49,19 @@ import pickle
SEED = 123
class TimeCostAverage(object):
def __init__(self):
self.reset()
def reset(self):
self.cnt = 0
self.total_time = 0
def record(self, usetime):
self.cnt += 1
self.total_time += usetime
def get_average(self):
if self.cnt == 0:
return 0
return self.total_time / self.cnt
@contextlib.contextmanager
def profile_context(profile=True, profiler_path='/tmp/paddingrnn.profile'):
......@@ -293,8 +306,10 @@ def main():
total_loss = 0
iters = 0
batch_cost_avg = TimeCostAverage()
init_hidden, init_cell = generate_init_data()
batch_start_time = time.time()
for batch_id, batch in enumerate(train_data_iter):
input_data_feed = prepare_input(
batch,
......@@ -303,7 +318,6 @@ def main():
epoch_id=epoch_id,
with_lr=True,
device_count=device_count)
batch_start_time = time.time()
fetch_outs = exe.run(train_program,
feed=input_data_feed,
fetch_list=[
......@@ -313,6 +327,7 @@ def main():
use_program_cache=True)
batch_time = time.time() - batch_start_time
batch_times.append(batch_time)
batch_cost_avg.record(batch_time)
cost_train = np.array(fetch_outs[0])
lr = np.array(fetch_outs[1])
......@@ -324,13 +339,17 @@ def main():
ppl = np.exp(total_loss / iters)
print(
"-- Epoch:[%d]; Batch:[%d]; Time: %.5f s; ppl: %.5f, lr: %.5f"
% (epoch_id, batch_id, batch_time, ppl[0], lr[0]))
% (epoch_id, batch_id, batch_cost_avg.get_average(), ppl[0], lr[0]))
batch_cost_avg.reset()
# profiler tools for benchmark
if args.profile and batch_id == log_interval:
profiler.reset_profiler()
elif args.profile and batch_id == (log_interval + 5):
break
batch_start_time = time.time()
ppl = np.exp(total_loss / iters)
return ppl
......@@ -342,6 +361,7 @@ def main():
total_loss = 0
iters = 0
batch_cost_avg = TimeCostAverage()
dataloader.start()
batch_id = 0
......@@ -355,6 +375,7 @@ def main():
batch_time = time.time() - batch_start_time
batch_times.append(batch_time)
batch_start_time = time.time()
batch_cost_avg.record(batch_time)
new_lr = generate_new_lr(epoch_id, device_count)
data_feeds['learning_rate'] = new_lr
......@@ -381,7 +402,8 @@ def main():
ppl = np.exp(total_loss / iters)
print(
"-- Epoch:[%d]; Batch:[%d]; Time: %.5f s; ppl: %.5f, lr: %.5f"
% (epoch_id, batch_id, batch_time, ppl[0], lr[0]))
% (epoch_id, batch_id, batch_cost_avg.get_average(), ppl[0], lr[0]))
batch_cost_avg.reset()
batch_id += 1
# profiler tools for benchmark
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册