未验证 提交 db6ce5e9 编写于 作者: W wanghuancoder 提交者: GitHub

fix language model time print (#4865)

* fix language_model timecost algorithm

* fix dataloader time calc, test=develop
上级 295c16b6
...@@ -49,6 +49,19 @@ import pickle ...@@ -49,6 +49,19 @@ import pickle
SEED = 123 SEED = 123
class TimeCostAverage(object):
def __init__(self):
self.reset()
def reset(self):
self.cnt = 0
self.total_time = 0
def record(self, usetime):
self.cnt += 1
self.total_time += usetime
def get_average(self):
if self.cnt == 0:
return 0
return self.total_time / self.cnt
@contextlib.contextmanager @contextlib.contextmanager
def profile_context(profile=True, profiler_path='/tmp/paddingrnn.profile'): def profile_context(profile=True, profiler_path='/tmp/paddingrnn.profile'):
...@@ -293,8 +306,10 @@ def main(): ...@@ -293,8 +306,10 @@ def main():
total_loss = 0 total_loss = 0
iters = 0 iters = 0
batch_cost_avg = TimeCostAverage()
init_hidden, init_cell = generate_init_data() init_hidden, init_cell = generate_init_data()
batch_start_time = time.time()
for batch_id, batch in enumerate(train_data_iter): for batch_id, batch in enumerate(train_data_iter):
input_data_feed = prepare_input( input_data_feed = prepare_input(
batch, batch,
...@@ -303,7 +318,6 @@ def main(): ...@@ -303,7 +318,6 @@ def main():
epoch_id=epoch_id, epoch_id=epoch_id,
with_lr=True, with_lr=True,
device_count=device_count) device_count=device_count)
batch_start_time = time.time()
fetch_outs = exe.run(train_program, fetch_outs = exe.run(train_program,
feed=input_data_feed, feed=input_data_feed,
fetch_list=[ fetch_list=[
...@@ -313,6 +327,7 @@ def main(): ...@@ -313,6 +327,7 @@ def main():
use_program_cache=True) use_program_cache=True)
batch_time = time.time() - batch_start_time batch_time = time.time() - batch_start_time
batch_times.append(batch_time) batch_times.append(batch_time)
batch_cost_avg.record(batch_time)
cost_train = np.array(fetch_outs[0]) cost_train = np.array(fetch_outs[0])
lr = np.array(fetch_outs[1]) lr = np.array(fetch_outs[1])
...@@ -324,13 +339,17 @@ def main(): ...@@ -324,13 +339,17 @@ def main():
ppl = np.exp(total_loss / iters) ppl = np.exp(total_loss / iters)
print( print(
"-- Epoch:[%d]; Batch:[%d]; Time: %.5f s; ppl: %.5f, lr: %.5f" "-- Epoch:[%d]; Batch:[%d]; Time: %.5f s; ppl: %.5f, lr: %.5f"
% (epoch_id, batch_id, batch_time, ppl[0], lr[0])) % (epoch_id, batch_id, batch_cost_avg.get_average(), ppl[0], lr[0]))
batch_cost_avg.reset()
# profiler tools for benchmark # profiler tools for benchmark
if args.profile and batch_id == log_interval: if args.profile and batch_id == log_interval:
profiler.reset_profiler() profiler.reset_profiler()
elif args.profile and batch_id == (log_interval + 5): elif args.profile and batch_id == (log_interval + 5):
break break
batch_start_time = time.time()
ppl = np.exp(total_loss / iters) ppl = np.exp(total_loss / iters)
return ppl return ppl
...@@ -342,6 +361,7 @@ def main(): ...@@ -342,6 +361,7 @@ def main():
total_loss = 0 total_loss = 0
iters = 0 iters = 0
batch_cost_avg = TimeCostAverage()
dataloader.start() dataloader.start()
batch_id = 0 batch_id = 0
...@@ -355,6 +375,7 @@ def main(): ...@@ -355,6 +375,7 @@ def main():
batch_time = time.time() - batch_start_time batch_time = time.time() - batch_start_time
batch_times.append(batch_time) batch_times.append(batch_time)
batch_start_time = time.time() batch_start_time = time.time()
batch_cost_avg.record(batch_time)
new_lr = generate_new_lr(epoch_id, device_count) new_lr = generate_new_lr(epoch_id, device_count)
data_feeds['learning_rate'] = new_lr data_feeds['learning_rate'] = new_lr
...@@ -381,7 +402,8 @@ def main(): ...@@ -381,7 +402,8 @@ def main():
ppl = np.exp(total_loss / iters) ppl = np.exp(total_loss / iters)
print( print(
"-- Epoch:[%d]; Batch:[%d]; Time: %.5f s; ppl: %.5f, lr: %.5f" "-- Epoch:[%d]; Batch:[%d]; Time: %.5f s; ppl: %.5f, lr: %.5f"
% (epoch_id, batch_id, batch_time, ppl[0], lr[0])) % (epoch_id, batch_id, batch_cost_avg.get_average(), ppl[0], lr[0]))
batch_cost_avg.reset()
batch_id += 1 batch_id += 1
# profiler tools for benchmark # profiler tools for benchmark
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册