LSTM training: cost goes NaN
Created by: 357589873
Data: the Chinese portion extracted from a set of web page titles, 100K+ samples per class.
Classes: 3-way classification.
Network structure:
import paddle.v2 as paddle


def stacked_lstm_net(input_dim,
                     class_dim=2,
                     emb_dim=128,
                     hid_dim=512,
                     stacked_num=5):
    # layers alternate forward/backward direction, so the stack depth must be odd
    assert stacked_num % 2 == 1

    fc_para_attr = paddle.attr.Param(learning_rate=1e-3)
    lstm_para_attr = paddle.attr.Param(initial_std=0., learning_rate=1.)
    para_attr = [fc_para_attr, lstm_para_attr]
    bias_attr = paddle.attr.Param(initial_std=0., l2_rate=0.)
    relu = paddle.activation.Relu()
    brelu = paddle.activation.BRelu()
    tanh = paddle.activation.Tanh()
    linear = paddle.activation.Linear()

    data = paddle.layer.data(
        "x", paddle.data_type.integer_value_sequence(input_dim))
    emb = paddle.layer.embedding(input=data, size=emb_dim)

    fc1 = paddle.layer.fc(
        input=emb, size=hid_dim, act=brelu, bias_attr=bias_attr)
    lstm1 = paddle.layer.lstmemory(input=fc1, act=tanh, bias_attr=bias_attr)

    inputs = [fc1, lstm1]
    for i in range(2, stacked_num + 1):
        fc = paddle.layer.fc(
            input=inputs,
            size=hid_dim,
            act=brelu,
            param_attr=para_attr,
            bias_attr=bias_attr)
        lstm = paddle.layer.lstmemory(
            input=fc, reverse=(i % 2) == 0, act=tanh, bias_attr=bias_attr)
        inputs = [fc, lstm]

    fc_last = paddle.layer.pooling(
        input=inputs[0], pooling_type=paddle.pooling.Max())
    lstm_last = paddle.layer.pooling(
        input=inputs[1], pooling_type=paddle.pooling.Max())
    output = paddle.layer.fc(
        input=[fc_last, lstm_last],
        size=class_dim,
        act=paddle.activation.Softmax(),
        layer_attr=paddle.attr.ExtraLayerAttribute(
            error_clipping_threshold=10.0),
        bias_attr=bias_attr,
        param_attr=para_attr)

    lbl = paddle.layer.data("y", paddle.data_type.integer_value(1))
    cost = paddle.layer.classification_cost(input=output, label=lbl)
    return cost, output, lbl
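The data-preparation code is not shown in the post. For context, here is a minimal sketch of the kind of reader the data description implies; the vocabulary dict word_dict (what meta[0] is assumed to hold) and the pre-tokenized (tokens, label) samples are assumptions, not taken from the post:

import random

def make_reader(samples, word_dict):
    # samples: list of (chinese_tokens, label) pairs, label in {0, 1, 2}
    # word_dict: token -> integer id (assumed to be what meta[0] holds)
    def reader():
        random.shuffle(samples)
        for tokens, label in samples:
            # map tokens to ids, dropping out-of-vocabulary tokens
            ids = [word_dict[t] for t in tokens if t in word_dict]
            if ids:
                # matches the "x" integer_value_sequence and "y" integer_value layers
                yield ids, label
    return reader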
Parameter choices:
word_dim = min(meta[-1], 200)

cost, prob, label = stacked_lstm_net(
    input_dim=len(meta[0]),
    class_dim=3,
    emb_dim=word_dim,
    hid_dim=512,
    stacked_num=3,
)

optimizer = paddle.optimizer.Adam(
    learning_rate=0.00001,
    gradient_clipping_threshold=2,
    regularization=paddle.optimizer.L2Regularization(rate=0.00001),
    model_average=paddle.optimizer.ModelAverage(average_window=0.5))
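The training loop itself is not included in the post. Assuming it follows the standard PaddlePaddle v2 pattern, the wiring would look roughly like the sketch below; train_reader, batch_size and num_passes are placeholders, not values from the post:

# sketch only -- assumes the usual v2 trainer setup
paddle.init(use_gpu=False, trainer_count=1)

parameters = paddle.parameters.create(cost)
trainer = paddle.trainer.SGD(
    cost=cost, parameters=parameters, update_equation=optimizer)

def event_handler(event):
    # print the cost periodically so a NaN can be traced to a specific pass/batch
    if isinstance(event, paddle.event.EndIteration) and event.batch_id % 100 == 0:
        print("pass %d, batch %d, cost %f" % (
            event.pass_id, event.batch_id, event.cost))

trainer.train(
    reader=paddle.batch(
        paddle.reader.shuffle(train_reader, buf_size=10000), batch_size=128),
    event_handler=event_handler,
    feeding={'x': 0, 'y': 1},
    num_passes=10)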
I tried adjusting learning_rate and gradient_clipping_threshold, but it made no difference.
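One thing the tuning above does not rule out is the data itself. A quick sanity check, sketched under the assumption that the reader yields (ids, label) pairs as above: a label outside [0, class_dim) or an empty id sequence can, in some setups, surface as NaN cost rather than a clean error:

from collections import Counter

def sanity_check(reader, class_dim=3):
    # count labels and flag samples that may not feed cleanly into the network
    label_counts = Counter()
    suspect = 0
    for ids, label in reader():
        label_counts[label] += 1
        if len(ids) == 0 or not (0 <= label < class_dim):
            suspect += 1
    print("label distribution:", dict(label_counts), "suspect samples:", suspect)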