未验证 提交 f04c97a0 编写于 作者: F fengjiayi 提交者: GitHub

refine test_understand_sentiment_lstm (#5781)

* fix

* Fix a bug
上级 3e9ea348
...@@ -54,17 +54,17 @@ def to_lodtensor(data, place): ...@@ -54,17 +54,17 @@ def to_lodtensor(data, place):
return res return res
def chop_data(data, chop_len=80, batch_len=50): def chop_data(data, chop_len=80, batch_size=50):
data = [(x[0][:chop_len], x[1]) for x in data if len(x[0]) >= chop_len] data = [(x[0][:chop_len], x[1]) for x in data if len(x[0]) >= chop_len]
return data[:batch_len] return data[:batch_size]
def prepare_feed_data(data, place): def prepare_feed_data(data, place):
tensor_words = to_lodtensor(map(lambda x: x[0], data), place) tensor_words = to_lodtensor(map(lambda x: x[0], data), place)
label = np.array(map(lambda x: x[1], data)).astype("int64") label = np.array(map(lambda x: x[1], data)).astype("int64")
label = label.reshape([50, 1]) label = label.reshape([len(label), 1])
tensor_label = core.LoDTensor() tensor_label = core.LoDTensor()
tensor_label.set(label, place) tensor_label.set(label, place)
...@@ -72,33 +72,41 @@ def prepare_feed_data(data, place): ...@@ -72,33 +72,41 @@ def prepare_feed_data(data, place):
def main(): def main():
word_dict = paddle.dataset.imdb.word_dict() BATCH_SIZE = 100
cost, acc = lstm_net(dict_dim=len(word_dict), class_dim=2) PASS_NUM = 5
batch_size = 100 word_dict = paddle.dataset.imdb.word_dict()
train_data = paddle.batch( print "load word dict successfully"
paddle.reader.buffered( dict_dim = len(word_dict)
paddle.dataset.imdb.train(word_dict), size=batch_size * 10), class_dim = 2
batch_size=batch_size)
data = chop_data(next(train_data())) cost, acc = lstm_net(dict_dim=dict_dim, class_dim=class_dim)
train_data = paddle.batch(
paddle.reader.shuffle(
paddle.dataset.imdb.train(word_dict), buf_size=BATCH_SIZE * 10),
batch_size=BATCH_SIZE)
place = core.CPUPlace() place = core.CPUPlace()
tensor_words, tensor_label = prepare_feed_data(data, place)
exe = Executor(place) exe = Executor(place)
exe.run(framework.default_startup_program()) exe.run(framework.default_startup_program())
while True: for pass_id in xrange(PASS_NUM):
outs = exe.run(framework.default_main_program(), for data in train_data():
feed={"words": tensor_words, chopped_data = chop_data(data)
"label": tensor_label}, tensor_words, tensor_label = prepare_feed_data(chopped_data, place)
fetch_list=[cost, acc])
cost_val = np.array(outs[0]) outs = exe.run(framework.default_main_program(),
acc_val = np.array(outs[1]) feed={"words": tensor_words,
"label": tensor_label},
print("cost=" + str(cost_val) + " acc=" + str(acc_val)) fetch_list=[cost, acc])
if acc_val > 0.9: cost_val = np.array(outs[0])
break acc_val = np.array(outs[1])
print("cost=" + str(cost_val) + " acc=" + str(acc_val))
if acc_val > 0.7:
exit(0)
exit(1)
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册