Unverified commit 0952a2de, authored by hysunflower, committed by GitHub

add_maxiter_for_ptb (#4605)

Parent 1bff563e
@@ -57,6 +57,12 @@ def parse_args():
         type=str,
         default=None,
         help='dir to init model.')
+    # NOTE: used for benchmark
+    parser.add_argument(
+        '--max_iter',
+        type=int,
+        default=0,
+        help='the max iters for train, used for benchmark.')
     parser.add_argument('--ce', action='store_true', help="run ce")
     args = parser.parse_args()
     return args
@@ -374,6 +374,8 @@ def train_ptb_lm():
     ce_time = []
     ce_ppl = []
+    total_batch_num = 0 #this is for benchmark
     for epoch_id in range(max_epoch):
         ptb_model.train()
         total_loss = 0.0
...@@ -389,6 +391,9 @@ def train_ptb_lm(): ...@@ -389,6 +391,9 @@ def train_ptb_lm():
init_cell = to_variable(init_cell_data) init_cell = to_variable(init_cell_data)
start_time = time.time() start_time = time.time()
for batch_id, batch in enumerate(train_data_iter): for batch_id, batch in enumerate(train_data_iter):
if args.max_iter and total_batch_num == args.max_iter:
return
batch_start = time.time()
x_data, y_data = batch x_data, y_data = batch
x_data = x_data.reshape((-1, num_steps, 1)) x_data = x_data.reshape((-1, num_steps, 1))
@@ -408,13 +413,16 @@ def train_ptb_lm():
             ptb_model.clear_gradients()
             total_loss += out_loss
+            batch_end = time.time()
+            train_batch_cost = batch_end - batch_start
             iters += num_steps
+            total_batch_num = total_batch_num + 1 #this is for benchmark
             if batch_id > 0 and batch_id % log_interval == 0:
                 ppl = np.exp(total_loss / iters)
-                print("-- Epoch:[%d]; Batch:[%d]; ppl: %.5f, lr: %.5f, loss: %.5f" %
+                print("-- Epoch:[%d]; Batch:[%d]; ppl: %.5f, lr: %.5f, loss: %.5f, batch cost: %.5f" %
                       (epoch_id, batch_id, ppl[0],
-                       sgd._global_learning_rate().numpy(), out_loss))
+                       sgd._global_learning_rate().numpy(), out_loss, train_batch_cost))
         print("one epoch finished", epoch_id)
         print("time cost ", time.time() - start_time)
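
Taken together, the hunks implement a simple benchmark budget: count batches across epochs, time each one, and stop training once the cap is reached. A hedged, self-contained sketch of that control flow, with time.sleep standing in for the real forward/backward step (all names here are illustrative stand-ins, not the repo's objects):

import time

def train(max_iter=0, max_epoch=2, batches_per_epoch=5):
    total_batch_num = 0  # counts batches across epochs, as in the commit
    for epoch_id in range(max_epoch):
        for batch_id in range(batches_per_epoch):
            # early exit once the benchmark budget is spent
            if max_iter and total_batch_num == max_iter:
                return
            batch_start = time.time()
            time.sleep(0.01)  # stand-in for forward/backward/optimizer step
            train_batch_cost = time.time() - batch_start
            total_batch_num += 1
            print("epoch %d, batch %d, batch cost: %.5f s"
                  % (epoch_id, batch_id, train_batch_cost))

train(max_iter=7)  # runs 7 batches in total, then returns mid-epoch

Note that the commit's early exit is a bare return from train_ptb_lm, so hitting the cap skips the rest of the epoch loop (and anything after it in the function) rather than finishing the current epoch.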