lstm错误
Created by: 3wGTA
Python 3.7, PaddlePaddle 1.7.2 (GPU)
# Vocabulary size and embedding width; both size the embedding lookups below.
V = 32
h = 3  # NOTE(review): not referenced anywhere in this snippet
emb_size = 5
# Upper bound on sequence length handed to fluid.layers.lstm. It must be
# >= the longest sequence in a batch (the sample batch shown contains a
# length-7 sequence), so the value 7 is intentional.
max_len = 7
hidden_size = 2
num_layers = 1
# A bidirectional LSTM needs one initial-state slice per layer AND per
# direction, so the state tensors' leading dim is num_layers * num_directions.
is_bidirec = True
num_directions = 2 if is_bidirec else 1

label = fluid.data(name='label', shape=[None, 1], dtype='int64')
x = fluid.data(name='t', shape=[None], dtype='int64', lod_level=1)
y = fluid.data(name='h', shape=[None], dtype='int64', lod_level=1)

# Both inputs share one frozen embedding table named 'emb_vec', initialized
# from `weight` (defined outside this snippet; presumably a (V, emb_size)
# array — confirm).
w = fluid.ParamAttr(
    name='emb_vec',
    initializer=fluid.initializer.NumpyArrayInitializer(weight),
    trainable=False,
)
emb_x = fluid.embedding(input=x, size=[V, emb_size], param_attr=w)
emb_y = fluid.embedding(input=y, size=[V, emb_size], param_attr=w)

# Turn the variable-length (LoD) embeddings into dense tensors, filling the
# tail of shorter sequences with 0.0. No maxlen is given, so each batch is
# padded to its own longest sequence (pad_x came out as (5, 4, 5) above).
pad_value = fluid.layers.assign(input=np.array([0.0], dtype=np.float32))
pad_x, info_x = fluid.layers.sequence_pad(emb_x, pad_value)
pad_y, info_y = fluid.layers.sequence_pad(emb_y, pad_value)

# The initial states are built with a fixed batch dimension, so every batch
# fed below must contain exactly batch_size samples.
batch_size = 5
state_shape = [num_layers * num_directions, batch_size, hidden_size]
init_h = fluid.layers.fill_constant(state_shape, 'float32', 0)
init_c = fluid.layers.fill_constant(state_shape, 'float32', 0)

# LSTM network over the padded x embeddings.
# NOTE(review): for a bidirectional LSTM the output's last dim is expected
# to be hidden_size * 2 (forward and backward concatenated). If it still
# reports hidden_size with is_bidirec=True, check this behavior against the
# installed Paddle release — it may be a framework limitation, not a
# problem in this script.
lstm_x, x_last_h, x_last_c = fluid.layers.lstm(
    pad_x, init_h, init_c, max_len, hidden_size, num_layers,
    is_bidirec=is_bidirec,
)

use_gpu = True
place = fluid.CUDAPlace(0) if use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)
main_program = fluid.default_main_program()
# feed_list order must match each sample tuple yielded by train_reader:
# (t ids, h ids, label).
feeder = fluid.DataFeeder(feed_list=['t', 'h', 'label'], place=place)
exe.run(fluid.default_startup_program())

fetch_var = [x, y, emb_x, emb_y, pad_x, pad_y, lstm_x, x_last_h]
# Run only the first batch to inspect tensor shapes. return_numpy=False
# returns LoDTensors, so the LoD offsets stay visible in the printout.
for data in train_reader():
    print(data)
    result = exe.run(
        main_program,
        feed=feeder.feed(data),
        fetch_list=fetch_var,
        return_numpy=False,
    )
    break
经过多次测试，发现经过 LSTM 网络之后输出的形状是 [batch_size, seq_len, hidden_size]，并且 LSTM 的双向设置（is_bidirec=True）没有效果：无论是否使用双向，得到的结果都是上述形状。
本次输入数据 [([9, 1, 3, 8], [8, 5], 0), ([0, 3, 2, 5], [8, 3], 0), ([9, 5], [4, 5, 9, 3, 2, 3, 7], 1), ([4, 5, 5, 4], [7, 5, 2], 1), ([6, 7], [5, 4, 6, 0, 4], 1)]
batch_size=5，fetch 的结果如下：
t [[0, 4, 8, 10, 14, 16]] [9 1 3 8 0 3 2 5 9 5 4 5 5 4 6 7] (16,)
h [[0, 2, 4, 11, 14, 19]] [8 5 8 3 4 5 9 3 2 3 7 7 5 2 5 4 6 0 4] (19,)
embedding_0.tmp_0 [[0, 4, 8, 10, 14, 16]] [[0. 1. 1. 0. 1.] [1. 0. 0. 1. 0.] [0. 1. 0. 1. 0.] [0. 0. 0. 0. 1.] [0. 0. 0. 0. 0.] [0. 1. 0. 1. 0.] [1. 0. 1. 0. 0.] [0. 1. 0. 0. 0.] [0. 1. 1. 0. 1.] [0. 1. 0. 0. 0.] [0. 1. 1. 1. 0.] [0. 1. 0. 0. 0.] [0. 1. 0. 0. 0.] [0. 1. 1. 1. 0.] [0. 0. 1. 1. 1.] [0. 0. 1. 1. 0.]] (16, 5)
embedding_1.tmp_0 [[0, 2, 4, 11, 14, 19]] [[0. 0. 0. 0. 1.] [0. 1. 0. 0. 0.] [0. 0. 0. 0. 1.] [0. 1. 0. 1. 0.] [0. 1. 1. 1. 0.] [0. 1. 0. 0. 0.] [0. 1. 1. 0. 1.] [0. 1. 0. 1. 0.] [1. 0. 1. 0. 0.] [0. 1. 0. 1. 0.] [0. 0. 1. 1. 0.] [0. 0. 1. 1. 0.] [0. 1. 0. 0. 0.] [1. 0. 1. 0. 0.] [0. 1. 0. 0. 0.] [0. 1. 1. 1. 0.] [0. 0. 1. 1. 1.] [0. 0. 0. 0. 0.] [0. 1. 1. 1. 0.]] (19, 5)
sequence_pad_0.tmp_0 [] [[[0. 1. 1. 0. 1.] [1. 0. 0. 1. 0.] [0. 1. 0. 1. 0.] [0. 0. 0. 0. 1.]]
[[0. 0. 0. 0. 0.] [0. 1. 0. 1. 0.] [1. 0. 1. 0. 0.] [0. 1. 0. 0. 0.]]
[[0. 1. 1. 0. 1.] [0. 1. 0. 0. 0.] [0. 0. 0. 0. 0.] [0. 0. 0. 0. 0.]]
[[0. 1. 1. 1. 0.] [0. 1. 0. 0. 0.] [0. 1. 0. 0. 0.] [0. 1. 1. 1. 0.]]
[[0. 0. 1. 1. 1.] [0. 0. 1. 1. 0.] [0. 0. 0. 0. 0.] [0. 0. 0. 0. 0.]]] (5, 4, 5)
sequence_pad_1.tmp_0 [] [[[0. 0. 0. 0. 1.] [0. 1. 0. 0. 0.] [0. 0. 0. 0. 0.] [0. 0. 0. 0. 0.] [0. 0. 0. 0. 0.] [0. 0. 0. 0. 0.] [0. 0. 0. 0. 0.]]
[[0. 0. 0. 0. 1.] [0. 1. 0. 1. 0.] [0. 0. 0. 0. 0.] [0. 0. 0. 0. 0.] [0. 0. 0. 0. 0.] [0. 0. 0. 0. 0.] [0. 0. 0. 0. 0.]]
[[0. 1. 1. 1. 0.] [0. 1. 0. 0. 0.] [0. 1. 1. 0. 1.] [0. 1. 0. 1. 0.] [1. 0. 1. 0. 0.] [0. 1. 0. 1. 0.] [0. 0. 1. 1. 0.]]
[[0. 0. 1. 1. 0.] [0. 1. 0. 0. 0.] [1. 0. 1. 0. 0.] [0. 0. 0. 0. 0.] [0. 0. 0. 0. 0.] [0. 0. 0. 0. 0.] [0. 0. 0. 0. 0.]]
[[0. 1. 0. 0. 0.] [0. 1. 1. 1. 0.] [0. 0. 1. 1. 1.] [0. 0. 0. 0. 0.] [0. 1. 1. 1. 0.] [0. 0. 0. 0. 0.] [0. 0. 0. 0. 0.]]] (5, 7, 5)
cudnn_lstm_0.tmp_0 [] [[[-0.01163395 -0.03404206] [-0.03346499 0.09141013] [-0.01192071 -0.00806206] [ 0.01852721 0.02037505]]
[[-0.02112382 0.0178064 ] [ 0.00036139 0.02811656] [-0.04091306 -0.03953137] [ 0.01994075 0.04983816]]
[[-0.03010507 -0.02586269] [-0.01206493 0.04903913] [-0.02649166 0.01360786] [-0.02587583 0.01543753]]
[[ 0.00416536 -0.02817909] [ 0.04569617 0.04708511] [-0.03898652 -0.01187311] [ 0.00342977 0.01862298]]
[[-0.02776574 -0.04738818] [-0.06142154 0.09363435] [-0.03319343 0.01389923] [-0.00538392 0.0187942 ]]] (5, 4, 2)