关于自定义 LSTM 单元的构建问题
Created by: ARDUJS
环境
- paddle = 1.7.2
- python = 3.7.5
问题
尝试的解决方案
class DecoderCell(layers.RNNCell):
    """Custom LSTM-style decoder cell carrying three states ``(h, c, t)``.

    At every step the input, the previous projected state ``t`` and the
    previous hidden state ``h`` feed the input/forget/output gates and the
    candidate value. The step output is a projection to ``class_num`` units
    (``class_num`` is read from the enclosing module scope).
    """

    def __init__(self, hidden_size):
        """Store the width used for every gate and state projection."""
        self.hidden_size = hidden_size
        # NOTE(review): this built-in cell is never used by ``call``; it is
        # kept only so the attribute set of the class stays unchanged.
        self.lstm_cell = layers.LSTMCell(hidden_size)

    def call(self, step_input, hidden):
        """Run one decoding step.

        Args:
            step_input: Input for the current time step.
            hidden: Sequence ``(h, c, t)`` holding the previous hidden state,
                cell state and projected state.

        Returns:
            A pair ``(y, [h, c, t])`` with the step output and the updated
            states, in the same order as ``hidden``.
        """
        h, c, t = hidden
        gate_inputs = [step_input, t, h]

        i_gate = layers.fc(input=gate_inputs, size=self.hidden_size, act='sigmoid')
        f_gate = layers.fc(input=gate_inputs, size=self.hidden_size, act='sigmoid')
        z = layers.fc(input=gate_inputs, size=self.hidden_size, act='tanh')

        # New cell state: keep part of the old state, add the gated candidate.
        ct = layers.elementwise_mul(f_gate, c) + layers.elementwise_mul(i_gate, z)

        o_gate = layers.fc(gate_inputs, size=self.hidden_size, act='sigmoid')
        new_h = layers.elementwise_mul(o_gate, layers.tanh(ct))
        new_t = layers.fc(new_h, size=self.hidden_size)
        y = layers.fc(new_t, size=class_num)

        # Return the *updated* cell state; handing back the incoming ``c``
        # would freeze the cell state at its initial value for every step.
        return y, [new_h, ct, new_t]
# Build the three initial states (h, c, t) from the padded features. Each
# ``fc`` call is a separate layer; they are created in the order h, c, t.
h, c, t = [
    layers.fc(pad_feature, size=128, act="tanh") for _ in range(3)
]

decoder_cell = DecoderCell(128)

# Run the custom cell over the batch-major input; the final states are discarded.
decoder_output, _ = layers.rnn(
    cell=decoder_cell,
    inputs=pad_feature,
    initial_states=[h, c, t],
    time_major=False,
    sequence_length=seq_len_used,
)
> 这样会导致训练非常慢,请问有没有其他的解决方案?求各位大佬帮忙!