Commit 8071d264, authored by Leo Chen, committed by hong

Dev/update ptb (#4083)

* update ptb model to use auto generated op functions, test=develop

* refine model code, test=develop

* refine model code, test=develop

* update ptb model, test=develop
Parent 1c52e005
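
Reviewer note: the visible model-code change is mostly replacing explicit fluid.layers.slice + fluid.layers.reshape pairs with tensor indexing such as init_hidden[i] and input_embedding[:, index, :]. A minimal sketch of that equivalence, assuming PaddlePaddle 1.x dygraph mode (shapes are illustrative, not taken from the diff):

    import numpy as np
    import paddle.fluid as fluid

    with fluid.dygraph.guard():
        # [num_layers, batch, hidden] -- illustrative shape
        init_hidden = fluid.dygraph.to_variable(
            np.zeros((2, 4, 8), dtype='float32'))

        # Old style: slice one layer out along axis 0, then drop the
        # leading dimension with an explicit reshape.
        pre_hidden_old = fluid.layers.slice(
            init_hidden, axes=[0], starts=[0], ends=[1])
        pre_hidden_old = fluid.layers.reshape(pre_hidden_old, shape=[-1, 8])

        # New style: indexing does both in one step.
        pre_hidden_new = init_hidden[0]

        assert pre_hidden_old.shape == pre_hidden_new.shape  # both [4, 8]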
@@ -53,7 +53,6 @@ class SimpleLSTMRNN(fluid.Layer):
         self._num_layers = num_layers
         self._init_scale = init_scale
         self._dropout = dropout
-        self._input = None
         self._num_steps = num_steps
         self.cell_array = []
         self.hidden_array = []
@@ -83,34 +82,23 @@ class SimpleLSTMRNN(fluid.Layer):
             self.bias_arr.append(self.add_parameter('b_%d' % i, bias_1))
 
     def forward(self, input_embedding, init_hidden=None, init_cell=None):
-        self.cell_array = []
-        self.hidden_array = []
+        cell_array = []
+        hidden_array = []
         for i in range(self._num_layers):
-            pre_hidden = fluid.layers.slice(
-                init_hidden, axes=[0], starts=[i], ends=[i + 1])
-            pre_cell = fluid.layers.slice(
-                init_cell, axes=[0], starts=[i], ends=[i + 1])
-            pre_hidden = fluid.layers.reshape(
-                pre_hidden, shape=[-1, self._hidden_size])
-            pre_cell = fluid.layers.reshape(
-                pre_cell, shape=[-1, self._hidden_size])
-            self.hidden_array.append(pre_hidden)
-            self.cell_array.append(pre_cell)
+            hidden_array.append(init_hidden[i])
+            cell_array.append(init_cell[i])
 
         res = []
         for index in range(self._num_steps):
-            self._input = fluid.layers.slice(
-                input_embedding, axes=[1], starts=[index], ends=[index + 1])
-            self._input = fluid.layers.reshape(
-                self._input, shape=[-1, self._hidden_size])
+            step_input = input_embedding[:, index, :]
             for k in range(self._num_layers):
-                pre_hidden = self.hidden_array[k]
-                pre_cell = self.cell_array[k]
+                pre_hidden = hidden_array[k]
+                pre_cell = cell_array[k]
                 weight_1 = self.weight_1_arr[k]
                 bias = self.bias_arr[k]
 
-                nn = fluid.layers.concat([self._input, pre_hidden], 1)
+                nn = fluid.layers.concat([step_input, pre_hidden], 1)
                 gate_input = fluid.layers.matmul(x=nn, y=weight_1)
 
                 gate_input = fluid.layers.elementwise_add(gate_input, bias)
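
Reviewer note: the loop body around these lines is the standard LSTM cell; the split of gate_input into i, j, f, o happens in context lines elided between this hunk and the next. A standalone numpy sketch of the step being computed (the 4-way split is inferred from the variable names in the next hunk, not shown in the diff):

    import numpy as np

    def sigmoid(x):
        return 1.0 / (1.0 + np.exp(-x))

    def lstm_step(step_input, pre_hidden, pre_cell, weight, bias):
        # step_input, pre_hidden: [batch, hidden]; weight: [2*hidden, 4*hidden]
        nn = np.concatenate([step_input, pre_hidden], axis=1)
        gate_input = nn @ weight + bias
        i, j, f, o = np.split(gate_input, 4, axis=1)  # inferred split
        c = pre_cell * sigmoid(f) + sigmoid(i) * np.tanh(j)
        m = np.tanh(c) * sigmoid(o)
        return m, c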
@@ -119,25 +107,23 @@ class SimpleLSTMRNN(fluid.Layer):
                 c = pre_cell * fluid.layers.sigmoid(f) + fluid.layers.sigmoid(
                     i) * fluid.layers.tanh(j)
                 m = fluid.layers.tanh(c) * fluid.layers.sigmoid(o)
-                self.hidden_array[k] = m
-                self.cell_array[k] = c
-                self._input = m
+                hidden_array[k] = m
+                cell_array[k] = c
+                step_input = m
 
                 if self._dropout is not None and self._dropout > 0.0:
-                    self._input = fluid.layers.dropout(
-                        self._input,
+                    step_input = fluid.layers.dropout(
+                        step_input,
                         dropout_prob=self._dropout,
                         dropout_implementation='upscale_in_train')
-            res.append(
-                fluid.layers.reshape(
-                    self._input, shape=[1, -1, self._hidden_size]))
-        real_res = fluid.layers.concat(res, 0)
-        real_res = fluid.layers.transpose(x=real_res, perm=[1, 0, 2])
-        last_hidden = fluid.layers.concat(self.hidden_array, 1)
+            res.append(step_input)
+        real_res = fluid.layers.concat(res, 1)
+        real_res = fluid.layers.reshape(
+            real_res, [-1, self._num_steps, self._hidden_size])
+        last_hidden = fluid.layers.concat(hidden_array, 1)
         last_hidden = fluid.layers.reshape(
             last_hidden, shape=[-1, self._num_layers, self._hidden_size])
         last_hidden = fluid.layers.transpose(x=last_hidden, perm=[1, 0, 2])
-        last_cell = fluid.layers.concat(self.cell_array, 1)
+        last_cell = fluid.layers.concat(cell_array, 1)
         last_cell = fluid.layers.reshape(
             last_cell, shape=[-1, self._num_layers, self._hidden_size])
         last_cell = fluid.layers.transpose(x=last_cell, perm=[1, 0, 2])
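
Reviewer note: the old code stacked per-step outputs as [1, batch, hidden] slices, concatenated on axis 0, and transposed; the new code concatenates the [batch, hidden] steps on axis 1 and reshapes. Both yield the same [batch, num_steps, hidden] layout; a small numpy check of that claim (shapes illustrative):

    import numpy as np

    batch, num_steps, hidden = 4, 3, 8
    res = [np.random.rand(batch, hidden) for _ in range(num_steps)]

    # Old: [1, batch, hidden] slices -> concat axis 0 -> transpose(1, 0, 2)
    old = np.concatenate(
        [r.reshape(1, batch, hidden) for r in res], axis=0).transpose(1, 0, 2)

    # New: concat axis 1 -> [batch, num_steps * hidden] -> reshape
    new = np.concatenate(res, axis=1).reshape(batch, num_steps, hidden)

    assert np.array_equal(old, new)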
@@ -212,12 +198,9 @@ class PtbModel(fluid.Layer):
         rnn_out, last_hidden, last_cell = self.simple_lstm_rnn(x_emb, init_h,
                                                                init_c)
-        rnn_out = fluid.layers.reshape(
-            rnn_out, shape=[-1, self.num_steps, self.hidden_size])
 
         projection = fluid.layers.matmul(rnn_out, self.softmax_weight)
         projection = fluid.layers.elementwise_add(projection, self.softmax_bias)
-        projection = fluid.layers.reshape(
-            projection, shape=[-1, self.vocab_size])
         loss = fluid.layers.softmax_with_cross_entropy(
             logits=projection, label=label, soft_label=False)
         loss = fluid.layers.reshape(loss, shape=[-1, self.num_steps])
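
Reviewer note: dropping these two reshapes relies on rnn_out now coming back as [batch, num_steps, hidden] and on matmul and softmax_with_cross_entropy accepting rank-3 inputs, with the label reshaped to [batch, num_steps, 1] in the training loop below. A minimal sketch of that assumption in Paddle 1.x dygraph (shapes illustrative):

    import numpy as np
    import paddle.fluid as fluid

    with fluid.dygraph.guard():
        batch, num_steps, hidden, vocab = 4, 3, 8, 10
        rnn_out = fluid.dygraph.to_variable(
            np.random.rand(batch, num_steps, hidden).astype('float32'))
        softmax_weight = fluid.dygraph.to_variable(
            np.random.rand(hidden, vocab).astype('float32'))
        label = fluid.dygraph.to_variable(
            np.zeros((batch, num_steps, 1), dtype='int64'))

        # The rank-2 weight is applied over the leading batch dim:
        # projection is [batch, num_steps, vocab]
        projection = fluid.layers.matmul(rnn_out, softmax_weight)
        # Cross entropy along the last axis -> loss is [batch, num_steps, 1]
        loss = fluid.layers.softmax_with_cross_entropy(
            logits=projection, label=label, soft_label=False)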
@@ -334,7 +317,7 @@ def train_ptb_lm():
     batch_len = len(train_data) // batch_size
     total_batch_size = (batch_len - 1) // num_steps
-    log_interval = total_batch_size // 20
+    log_interval = 200
 
     bd = []
     lr_arr = [1.0]
@@ -395,23 +378,24 @@ def train_ptb_lm():
         train_data_iter = reader.get_data_iter(train_data, batch_size,
                                                num_steps)
 
+        init_hidden = to_variable(init_hidden_data)
+        init_cell = to_variable(init_cell_data)
         start_time = time.time()
         for batch_id, batch in enumerate(train_data_iter):
             x_data, y_data = batch
-            x_data = x_data.reshape((-1, num_steps))
-            y_data = y_data.reshape((-1, 1))
+
+            x_data = x_data.reshape((-1, num_steps, 1))
+            y_data = y_data.reshape((-1, num_steps, 1))
+
             x = to_variable(x_data)
             y = to_variable(y_data)
-            init_hidden = to_variable(init_hidden_data)
-            init_cell = to_variable(init_cell_data)
+
             dy_loss, last_hidden, last_cell = ptb_model(x, y, init_hidden,
                                                         init_cell)
+            init_hidden = last_hidden
+            init_cell = last_cell
+
             out_loss = dy_loss.numpy()
-            init_hidden_data = last_hidden.numpy()
-            init_cell_data = last_cell.numpy()
 
             dy_loss.backward()
             sgd.minimize(dy_loss, grad_clip=grad_clip)
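
Reviewer note: the loop now wraps the initial state once per epoch and feeds each batch's last_hidden/last_cell straight back in as dygraph variables, instead of round-tripping through .numpy() and to_variable on every iteration. A condensed sketch of the resulting pattern (the run_epoch wrapper and the clear_gradients call are illustrative framing, not lines from the diff):

    from paddle.fluid.dygraph.base import to_variable

    def run_epoch(ptb_model, train_data_iter, sgd, grad_clip,
                  init_hidden_data, init_cell_data, num_steps):
        # Wrap the initial state once per epoch ...
        init_hidden = to_variable(init_hidden_data)
        init_cell = to_variable(init_cell_data)
        for x_data, y_data in train_data_iter:
            x = to_variable(x_data.reshape((-1, num_steps, 1)))
            y = to_variable(y_data.reshape((-1, num_steps, 1)))
            dy_loss, last_hidden, last_cell = ptb_model(
                x, y, init_hidden, init_cell)
            # ... then thread this batch's final state into the next batch.
            init_hidden = last_hidden
            init_cell = last_cell
            dy_loss.backward()
            sgd.minimize(dy_loss, grad_clip=grad_clip)
            ptb_model.clear_gradients()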
@@ -421,9 +405,9 @@ def train_ptb_lm():
             if batch_id > 0 and batch_id % log_interval == 0:
                 ppl = np.exp(total_loss / iters)
-                print("-- Epoch:[%d]; Batch:[%d]; ppl: %.5f, lr: %.5f" %
+                print("-- Epoch:[%d]; Batch:[%d]; ppl: %.5f, lr: %.5f, loss: %.5f" %
                       (epoch_id, batch_id, ppl[0],
-                       sgd._global_learning_rate().numpy()))
+                       sgd._global_learning_rate().numpy(), out_loss))
 
         print("one ecpoh finished", epoch_id)
         print("time cost ", time.time() - start_time)
...