未验证 提交 66f6039f 编写于 作者: L Li Fuchen 提交者: GitHub

fix a bug in language model (#4120)

上级 db713042
...@@ -125,7 +125,6 @@ def main(): ...@@ -125,7 +125,6 @@ def main():
res_vars = lm_model.lm_model( res_vars = lm_model.lm_model(
config.hidden_size, config.hidden_size,
config.vocab_size, config.vocab_size,
config.batch_size,
num_layers=config.num_layers, num_layers=config.num_layers,
num_steps=config.num_steps, num_steps=config.num_steps,
init_scale=config.init_scale, init_scale=config.init_scale,
...@@ -160,7 +159,6 @@ def main(): ...@@ -160,7 +159,6 @@ def main():
lm_model.lm_model( lm_model.lm_model(
config.hidden_size, config.hidden_size,
config.vocab_size, config.vocab_size,
config.batch_size,
num_layers=config.num_layers, num_layers=config.num_layers,
num_steps=config.num_steps, num_steps=config.num_steps,
init_scale=config.init_scale, init_scale=config.init_scale,
...@@ -319,7 +317,7 @@ def main(): ...@@ -319,7 +317,7 @@ def main():
print( print(
"-- Epoch:[%d]; Batch:[%d]; Time: %.5f s; ppl: %.5f, lr: %.5f" "-- Epoch:[%d]; Batch:[%d]; Time: %.5f s; ppl: %.5f, lr: %.5f"
% (epoch_id, batch_id, batch_time, ppl[0], lr[0])) % (epoch_id, batch_id, batch_time, ppl[0], lr[0]))
# profiler tools for benchmark # profiler tools for benchmark
if args.profile and batch_id == log_interval: if args.profile and batch_id == log_interval:
profiler.reset_profiler() profiler.reset_profiler()
......
...@@ -26,7 +26,6 @@ from paddle.fluid.contrib.layers import basic_lstm ...@@ -26,7 +26,6 @@ from paddle.fluid.contrib.layers import basic_lstm
def lm_model(hidden_size, def lm_model(hidden_size,
vocab_size, vocab_size,
batch_size,
num_layers=2, num_layers=2,
num_steps=20, num_steps=20,
init_scale=0.1, init_scale=0.1,
...@@ -253,7 +252,6 @@ def lm_model(hidden_size, ...@@ -253,7 +252,6 @@ def lm_model(hidden_size,
return real_res, last_hidden, last_cell return real_res, last_hidden, last_cell
batch_size_each = batch_size // fluid.core.get_cuda_device_count()
x = fluid.data(name="x", shape=[None, num_steps, 1], dtype='int64') x = fluid.data(name="x", shape=[None, num_steps, 1], dtype='int64')
y = fluid.data(name="y", shape=[None, 1], dtype='int64') y = fluid.data(name="y", shape=[None, 1], dtype='int64')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册