未验证 提交 662e396d 编写于 作者: H hong 提交者: GitHub

add reader cost, and change speed to ips (#4933)

* add reader cost, and change speed to ips; test=develop

* fix inf bug; test=develop
上级 5338dde4
......@@ -196,7 +196,7 @@ class BaseModel(object):
max_tar_seq_len = layers.shape(self.tar)[1]
tar_mask = layers.sequence_mask(
self.tar_sequence_length, maxlen=max_tar_seq_len, dtype='float32')
-loss = loss * tar_mask
+loss = layers.elementwise_mul(loss, tar_mask, axis=0)
loss = layers.reduce_mean(loss, dim=[0])
loss = layers.reduce_sum(loss)
return loss
......
......@@ -192,10 +192,13 @@ def main():
time_interval = 0.0
batch_start_time = time.time()
epoch_word_count = 0.0
+total_reader_cost = 0.0
+batch_read_start = time.time()
for batch_id, batch in enumerate(train_data_iter):
input_data_feed, word_num = prepare_input(
batch, epoch_id=epoch_id)
word_count += word_num
+total_reader_cost += time.time() - batch_read_start
fetch_outs = exe.run(program=CompiledProgram,
feed=input_data_feed,
fetch_list=[loss.name],
......@@ -212,14 +215,15 @@ def main():
if batch_id > 0 and batch_id % 100 == 0:
print(
-"-- Epoch:[%d]; Batch:[%d]; Time: %.5f s; ppl: %.5f; speed: %0.5f tokens/sec"
+"-- Epoch:[%d]; Batch:[%d]; Time: %.5f s; ppl: %.5f; reader cost: %0.5f s; ips: %0.5f tokens/sec"
% (epoch_id, batch_id, batch_time,
np.exp(total_loss / word_count),
-word_count / time_interval))
+total_reader_cost / 100, word_count / time_interval))
ce_ppl.append(np.exp(total_loss / word_count))
total_loss = 0.0
word_count = 0.0
time_interval = 0.0
+total_reader_cost = 0.0
# profiler tools
if args.profile and epoch_id == 0 and batch_id == 100:
......@@ -227,12 +231,13 @@ def main():
elif args.profile and epoch_id == 0 and batch_id == 105:
return
batch_start_time = time.time()
+batch_read_start = time.time()
end_time = time.time()
epoch_time = end_time - start_time
ce_time.append(epoch_time)
print(
-"\nTrain epoch:[%d]; Epoch Time: %.5f; avg_time: %.5f s/step; speed: %0.5f tokens/sec\n"
+"\nTrain epoch:[%d]; Epoch Time: %.5f; avg_time: %.5f s/step; ips: %0.5f tokens/sec\n"
% (epoch_id, epoch_time, sum(batch_times) / len(batch_times),
epoch_word_count / sum(batch_times)))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册