未验证 提交 f04c97a0 编写于 作者: F fengjiayi 提交者: GitHub

refine test_understand_sentiment_lstm (#5781)

* fix

* Fix a bug
上级 3e9ea348
......@@ -54,17 +54,17 @@ def to_lodtensor(data, place):
return res
def chop_data(data, chop_len=80, batch_len=50):
def chop_data(data, chop_len=80, batch_size=50):
data = [(x[0][:chop_len], x[1]) for x in data if len(x[0]) >= chop_len]
return data[:batch_len]
return data[:batch_size]
def prepare_feed_data(data, place):
tensor_words = to_lodtensor(map(lambda x: x[0], data), place)
label = np.array(map(lambda x: x[1], data)).astype("int64")
label = label.reshape([50, 1])
label = label.reshape([len(label), 1])
tensor_label = core.LoDTensor()
tensor_label.set(label, place)
......@@ -72,33 +72,41 @@ def prepare_feed_data(data, place):
def main():
word_dict = paddle.dataset.imdb.word_dict()
cost, acc = lstm_net(dict_dim=len(word_dict), class_dim=2)
BATCH_SIZE = 100
PASS_NUM = 5
batch_size = 100
train_data = paddle.batch(
paddle.reader.buffered(
paddle.dataset.imdb.train(word_dict), size=batch_size * 10),
batch_size=batch_size)
word_dict = paddle.dataset.imdb.word_dict()
print "load word dict successfully"
dict_dim = len(word_dict)
class_dim = 2
data = chop_data(next(train_data()))
cost, acc = lstm_net(dict_dim=dict_dim, class_dim=class_dim)
train_data = paddle.batch(
paddle.reader.shuffle(
paddle.dataset.imdb.train(word_dict), buf_size=BATCH_SIZE * 10),
batch_size=BATCH_SIZE)
place = core.CPUPlace()
tensor_words, tensor_label = prepare_feed_data(data, place)
exe = Executor(place)
exe.run(framework.default_startup_program())
while True:
outs = exe.run(framework.default_main_program(),
feed={"words": tensor_words,
"label": tensor_label},
fetch_list=[cost, acc])
cost_val = np.array(outs[0])
acc_val = np.array(outs[1])
print("cost=" + str(cost_val) + " acc=" + str(acc_val))
if acc_val > 0.9:
break
for pass_id in xrange(PASS_NUM):
for data in train_data():
chopped_data = chop_data(data)
tensor_words, tensor_label = prepare_feed_data(chopped_data, place)
outs = exe.run(framework.default_main_program(),
feed={"words": tensor_words,
"label": tensor_label},
fetch_list=[cost, acc])
cost_val = np.array(outs[0])
acc_val = np.array(outs[1])
print("cost=" + str(cost_val) + " acc=" + str(acc_val))
if acc_val > 0.7:
exit(0)
exit(1)
if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册