未验证 提交 cf186f39 编写于 作者: W wanghuancoder 提交者: GitHub

fix ptb_dy time print for benchmark, test=develop (#4866)

上级 db6ce5e9
......@@ -37,6 +37,19 @@ if sys.version[0] == '2':
reload(sys)
sys.setdefaultencoding("utf-8")
class TimeCostAverage(object):
def __init__(self):
self.reset()
def reset(self):
self.cnt = 0
self.total_time = 0
def record(self, usetime):
self.cnt += 1
self.total_time += usetime
def get_average(self):
if self.cnt == 0:
return 0
return self.total_time / self.cnt
class SimpleLSTMRNN(fluid.Layer):
def __init__(self,
......@@ -405,10 +418,17 @@ def train_ptb_lm():
init_hidden = to_variable(init_hidden_data)
init_cell = to_variable(init_cell_data)
batch_cost_avg = TimeCostAverage()
reader_cost_avg = TimeCostAverage()
batch_start = time.time()
for batch_id, batch in enumerate(train_data_loader):
if args.max_iter and total_batch_num == args.max_iter:
return
train_reader_cost = time.time() - batch_start
reader_cost_avg.record(train_reader_cost)
x, y = batch
dy_loss, last_hidden, last_cell = ptb_model(x, y, init_hidden,
......@@ -426,13 +446,17 @@ def train_ptb_lm():
total_batch_num = total_batch_num + 1 #this is for benchmark
train_batch_cost = time.time() - batch_start
batch_cost_avg.record(train_batch_cost)
if batch_id > 0 and batch_id % log_interval == 0:
ppl = np.exp(total_loss / iters)
print(
"-- Epoch:[%d]; Batch:[%d]; ppl: %.5f, lr: %.5f, loss: %.5f, batch_cost: %.5f s"
"-- Epoch:[%d]; Batch:[%d]; ppl: %.5f, lr: %.5f, loss: %.5f, batch_cost: %.5f s, reader_cost: %.5f s"
% (epoch_id, batch_id, ppl[0],
sgd._global_learning_rate().numpy(), out_loss,
train_batch_cost))
batch_cost_avg.get_average(), reader_cost_avg.get_average()))
batch_cost_avg.reset()
reader_cost_avg.reset()
batch_start = time.time()
ppl = np.exp(total_loss / iters)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册