提交 64410f01 编写于 作者: H hong 提交者: Leo Chen

update ptb to support remove build once; test=develop (#4100)

上级 54f64c66
...@@ -42,13 +42,12 @@ if sys.version[0] == '2': ...@@ -42,13 +42,12 @@ if sys.version[0] == '2':
class SimpleLSTMRNN(fluid.Layer): class SimpleLSTMRNN(fluid.Layer):
def __init__(self, def __init__(self,
name_scope,
hidden_size, hidden_size,
num_steps, num_steps,
num_layers=2, num_layers=2,
init_scale=0.1, init_scale=0.1,
dropout=None): dropout=None):
super(SimpleLSTMRNN, self).__init__(name_scope) super(SimpleLSTMRNN, self).__init__()
self._hidden_size = hidden_size self._hidden_size = hidden_size
self._num_layers = num_layers self._num_layers = num_layers
self._init_scale = init_scale self._init_scale = init_scale
...@@ -132,14 +131,13 @@ class SimpleLSTMRNN(fluid.Layer): ...@@ -132,14 +131,13 @@ class SimpleLSTMRNN(fluid.Layer):
class PtbModel(fluid.Layer): class PtbModel(fluid.Layer):
def __init__(self, def __init__(self,
name_scope,
hidden_size, hidden_size,
vocab_size, vocab_size,
num_layers=2, num_layers=2,
num_steps=20, num_steps=20,
init_scale=0.1, init_scale=0.1,
dropout=None): dropout=None):
super(PtbModel, self).__init__(name_scope) super(PtbModel, self).__init__()
self.hidden_size = hidden_size self.hidden_size = hidden_size
self.vocab_size = vocab_size self.vocab_size = vocab_size
self.init_scale = init_scale self.init_scale = init_scale
...@@ -147,14 +145,12 @@ class PtbModel(fluid.Layer): ...@@ -147,14 +145,12 @@ class PtbModel(fluid.Layer):
self.num_steps = num_steps self.num_steps = num_steps
self.dropout = dropout self.dropout = dropout
self.simple_lstm_rnn = SimpleLSTMRNN( self.simple_lstm_rnn = SimpleLSTMRNN(
self.full_name(),
hidden_size, hidden_size,
num_steps, num_steps,
num_layers=num_layers, num_layers=num_layers,
init_scale=init_scale, init_scale=init_scale,
dropout=dropout) dropout=dropout)
self.embedding = Embedding( self.embedding = Embedding(
self.full_name(),
size=[vocab_size, hidden_size], size=[vocab_size, hidden_size],
dtype='float32', dtype='float32',
is_sparse=False, is_sparse=False,
...@@ -286,7 +282,6 @@ def train_ptb_lm(): ...@@ -286,7 +282,6 @@ def train_ptb_lm():
fluid.default_main_program().random_seed = seed fluid.default_main_program().random_seed = seed
max_epoch = 1 max_epoch = 1
ptb_model = PtbModel( ptb_model = PtbModel(
"ptb_model",
hidden_size=hidden_size, hidden_size=hidden_size,
vocab_size=vocab_size, vocab_size=vocab_size,
num_layers=num_layers, num_layers=num_layers,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册