未验证 提交 0952a2de 编写于 作者: H hysunflower 提交者: GitHub

add_maxiter_for_ptb (#4605)

上级 1bff563e
......@@ -57,6 +57,12 @@ def parse_args():
type=str,
default=None,
help='dir to init model.')
# NOTE: used for benchmark
parser.add_argument(
'--max_iter',
type=int,
default=0,
help='the max iters for train, used for benchmark.')
parser.add_argument('--ce', action='store_true', help="run ce")
args = parser.parse_args()
return args
......@@ -374,6 +374,8 @@ def train_ptb_lm():
ce_time = []
ce_ppl = []
total_batch_num = 0 #this is for benchmark
for epoch_id in range(max_epoch):
ptb_model.train()
total_loss = 0.0
......@@ -389,6 +391,9 @@ def train_ptb_lm():
init_cell = to_variable(init_cell_data)
start_time = time.time()
for batch_id, batch in enumerate(train_data_iter):
if args.max_iter and total_batch_num == args.max_iter:
return
batch_start = time.time()
x_data, y_data = batch
x_data = x_data.reshape((-1, num_steps, 1))
......@@ -408,13 +413,16 @@ def train_ptb_lm():
ptb_model.clear_gradients()
total_loss += out_loss
batch_end = time.time()
train_batch_cost = batch_end - batch_start
iters += num_steps
total_batch_num = total_batch_num + 1 #this is for benchmark
if batch_id > 0 and batch_id % log_interval == 0:
ppl = np.exp(total_loss / iters)
print("-- Epoch:[%d]; Batch:[%d]; ppl: %.5f, lr: %.5f, loss: %.5f" %
print("-- Epoch:[%d]; Batch:[%d]; ppl: %.5f, lr: %.5f, loss: %.5f, batch cost: %.5f" %
(epoch_id, batch_id, ppl[0],
sgd._global_learning_rate().numpy(), out_loss))
sgd._global_learning_rate().numpy(), out_loss, train_batch_cost))
print("one epoch finished", epoch_id)
print("time cost ", time.time() - start_time)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册