未验证 提交 0fcc1aab 编写于 作者: L Li Fuchen 提交者: GitHub

Fixed issue with train_dyn_rnn.py training report nan (#808)

Fixed issue with train_dyn_rnn.py training report nan
上级 cdb5d193
...@@ -44,35 +44,11 @@ def parse_args(): ...@@ -44,35 +44,11 @@ def parse_args():
def dynamic_rnn_lstm(data, input_dim, class_dim, emb_dim, lstm_size): def dynamic_rnn_lstm(data, input_dim, class_dim, emb_dim, lstm_size):
emb = fluid.layers.embedding( emb = fluid.layers.embedding(
input=data, size=[input_dim, emb_dim], is_sparse=True) input=data, size=[input_dim, emb_dim], is_sparse=True)
sentence = fluid.layers.fc(input=emb, size=lstm_size, act='tanh') sentence = fluid.layers.fc(input=emb, size=lstm_size * 4, act='tanh')
rnn = fluid.layers.DynamicRNN() lstm, _ = fluid.layers.dynamic_lstm(sentence, size=lstm_size * 4)
with rnn.block():
word = rnn.step_input(sentence) last = fluid.layers.sequence_last_step(lstm)
prev_hidden = rnn.memory(value=0.0, shape=[lstm_size])
prev_cell = rnn.memory(value=0.0, shape=[lstm_size])
def gate_common(ipt, hidden, size):
gate0 = fluid.layers.fc(input=ipt, size=size, bias_attr=True)
gate1 = fluid.layers.fc(input=hidden, size=size, bias_attr=False)
return gate0 + gate1
forget_gate = fluid.layers.sigmoid(x=gate_common(word, prev_hidden,
lstm_size))
input_gate = fluid.layers.sigmoid(x=gate_common(word, prev_hidden,
lstm_size))
output_gate = fluid.layers.sigmoid(x=gate_common(word, prev_hidden,
lstm_size))
cell_gate = fluid.layers.sigmoid(x=gate_common(word, prev_hidden,
lstm_size))
cell = forget_gate * prev_cell + input_gate * cell_gate
hidden = output_gate * fluid.layers.tanh(x=cell)
rnn.update_memory(prev_cell, cell)
rnn.update_memory(prev_hidden, hidden)
rnn.output(hidden)
last = fluid.layers.sequence_last_step(rnn())
prediction = fluid.layers.fc(input=last, size=class_dim, act="softmax") prediction = fluid.layers.fc(input=last, size=class_dim, act="softmax")
return prediction return prediction
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册