Commit 7db51b20 authored by root

add nets.py utils.py

Parent a4d00a37
"""
For http://wiki.baidu.com/display/LegoNet/Text+Classification
"""
import paddle.fluid as fluid
import paddle.v2 as paddle
import numpy as np
import sys
import time


def bow_net(data, label,
            dict_dim,
            emb_dim=128,
            hid_dim=128,
            hid_dim2=96,
            class_dim=2):
    """
    Bag-of-words net: embedding -> sum pool -> tanh -> two fc layers -> softmax.
    """
    emb = fluid.layers.embedding(input=data, size=[dict_dim, emb_dim])
    # Sum the word embeddings over each sequence into one fixed-size vector.
    bow = fluid.layers.sequence_pool(input=emb, pool_type='sum')
    bow_tanh = fluid.layers.tanh(bow)
    fc_1 = fluid.layers.fc(input=bow_tanh, size=hid_dim, act="tanh")
    fc_2 = fluid.layers.fc(input=fc_1, size=hid_dim2, act="tanh")
    prediction = fluid.layers.fc(input=fc_2, size=class_dim, act="softmax")
    cost = fluid.layers.cross_entropy(input=prediction, label=label)
    avg_cost = fluid.layers.mean(x=cost)
    acc = fluid.layers.accuracy(input=prediction, label=label)
    return avg_cost, acc, prediction


def conv_net(data, label,
             dict_dim,
             emb_dim=128,
             hid_dim=128,
             hid_dim2=96,
             class_dim=2,
             win_size=3):
    """
    Text-convolution net: embedding -> sequence conv + max pool -> fc -> softmax.
    """
    emb = fluid.layers.embedding(input=data, size=[dict_dim, emb_dim])
    # Convolve a win_size-word window over each sequence, then max-pool over time.
    conv_3 = fluid.nets.sequence_conv_pool(input=emb,
                                           num_filters=hid_dim,
                                           filter_size=win_size,
                                           act="tanh",
                                           pool_type="max")
    fc_1 = fluid.layers.fc(input=conv_3, size=hid_dim2)
    prediction = fluid.layers.fc(input=fc_1, size=class_dim, act="softmax")
    cost = fluid.layers.cross_entropy(input=prediction, label=label)
    avg_cost = fluid.layers.mean(x=cost)
    acc = fluid.layers.accuracy(input=prediction, label=label)
    return avg_cost, acc, prediction


def lstm_net(data, label,
             dict_dim,
             emb_dim=128,
             hid_dim=128,
             hid_dim2=96,
             class_dim=2,
             emb_lr=30.0):
    """
    LSTM net: embedding -> fc -> dynamic_lstm -> max pool -> tanh -> fc -> softmax.
    """
    # emb_lr scales the embedding table's learning rate relative to the rest.
    emb = fluid.layers.embedding(
        input=data,
        size=[dict_dim, emb_dim],
        param_attr=fluid.ParamAttr(learning_rate=emb_lr))
    # dynamic_lstm expects its input (and size) to be 4 * hid_dim wide,
    # one slice per gate, hence the projection fc below.
    fc0 = fluid.layers.fc(input=emb, size=hid_dim * 4, act='tanh')
    lstm_h, c = fluid.layers.dynamic_lstm(input=fc0,
                                          size=hid_dim * 4,
                                          is_reverse=False)
    lstm_max = fluid.layers.sequence_pool(input=lstm_h, pool_type='max')
    lstm_max_tanh = fluid.layers.tanh(lstm_max)
    fc1 = fluid.layers.fc(input=lstm_max_tanh, size=hid_dim2, act='tanh')
    prediction = fluid.layers.fc(input=fc1, size=class_dim, act='softmax')
    cost = fluid.layers.cross_entropy(input=prediction, label=label)
    avg_cost = fluid.layers.mean(x=cost)
    acc = fluid.layers.accuracy(input=prediction, label=label)
    return avg_cost, acc, prediction


def gru_net(data, label,
            dict_dim,
            emb_dim=128,
            hid_dim=128,
            hid_dim2=96,
            class_dim=2,
            emb_lr=400.0):
    """
    GRU net: embedding -> fc -> dynamic_gru -> max pool -> tanh -> fc -> softmax.
    """
    # emb_lr scales the embedding table's learning rate relative to the rest.
    emb = fluid.layers.embedding(
        input=data,
        size=[dict_dim, emb_dim],
        param_attr=fluid.ParamAttr(learning_rate=emb_lr))
    # dynamic_gru expects input of width 3 * hid_dim: update gate, reset gate,
    # and candidate hidden state, hence the projection fc below.
    fc0 = fluid.layers.fc(input=emb, size=hid_dim * 3)
    gru_h = fluid.layers.dynamic_gru(input=fc0, size=hid_dim, is_reverse=False)
    gru_max = fluid.layers.sequence_pool(input=gru_h, pool_type='max')
    gru_max_tanh = fluid.layers.tanh(gru_max)
    fc1 = fluid.layers.fc(input=gru_max_tanh, size=hid_dim2, act='tanh')
    prediction = fluid.layers.fc(input=fc1, size=class_dim, act='softmax')
    cost = fluid.layers.cross_entropy(input=prediction, label=label)
    avg_cost = fluid.layers.mean(x=cost)
    acc = fluid.layers.accuracy(input=prediction, label=label)
    return avg_cost, acc, prediction
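

# A minimal usage sketch of these builders: wiring one net into a trainable
# fluid program. conv_net/lstm_net/gru_net drop in the same way. The feed
# names "words"/"label" are assumptions chosen to match data2tensor() in
# utils.py below, and Adagrad(0.002) is an illustrative optimizer choice.
def build_program_sketch(dict_dim):
    """Sketch: assemble bow_net into a trainable program."""
    data = fluid.layers.data(name="words", shape=[1], dtype="int64", lod_level=1)
    label = fluid.layers.data(name="label", shape=[1], dtype="int64")
    avg_cost, acc, prediction = bow_net(data, label, dict_dim)
    fluid.optimizer.Adagrad(learning_rate=0.002).minimize(avg_cost)
    return avg_cost, acc, prediction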
"""
For http://wiki.baidu.com/display/LegoNet/Text+Classification
"""
import paddle.fluid as fluid
import paddle.v2 as paddle
import numpy as np
import sys
import time
import light_imdb
import tiny_imdb


def to_lodtensor(data, place):
    """
    Convert a batch of variable-length sequences into a fluid LoDTensor.
    """
    seq_lens = [len(seq) for seq in data]
    # Build the level-of-detail offsets: [0, len_0, len_0 + len_1, ...].
    cur_len = 0
    lod = [cur_len]
    for l in seq_lens:
        cur_len += l
        lod.append(cur_len)
    # Flatten all sequences into one [total_len, 1] int64 column.
    flattened_data = np.concatenate(data, axis=0).astype("int64")
    flattened_data = flattened_data.reshape([len(flattened_data), 1])
    res = fluid.LoDTensor()
    res.set(flattened_data, place)
    res.set_lod([lod])
    return res
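

# A small sketch of to_lodtensor in use (the ids here are arbitrary): two
# sequences of lengths 3 and 2 flatten into five rows, and the LoD offsets
# are the cumulative boundaries [0, 3, 5].
def _demo_to_lodtensor():
    place = fluid.CPUPlace()
    tensor = to_lodtensor([[1, 2, 3], [4, 5]], place)
    print(tensor.lod())  # expected: [[0, 3, 5]]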


def load_vocab(filename):
    """
    Load the imdb vocabulary: one token per line, id = line number.
    """
    vocab = {}
    with open(filename) as f:
        for wid, line in enumerate(f):
            vocab[line.strip()] = wid
    vocab["<unk>"] = len(vocab)
    return vocab
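

# Sketch of the file format load_vocab() expects: one token per line, ids in
# line order, with "<unk>" appended last. The file name is illustrative.
def _demo_load_vocab():
    with open("demo.vocab", "w") as f:
        f.write("the\nmovie\ngreat\n")
    vocab = load_vocab("demo.vocab")
    print(vocab)  # {"the": 0, "movie": 1, "great": 2, "<unk>": 3}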


def data2tensor(data, place):
    """
    Turn a batch of (word_ids, label) samples into a feed dict.
    """
    # List comprehensions rather than map() so this also works on Python 3,
    # where map() returns an iterator that np.concatenate cannot consume.
    input_seq = to_lodtensor([x[0] for x in data], place)
    y_data = np.array([x[1] for x in data]).astype("int64")
    y_data = y_data.reshape([-1, 1])
    return {"words": input_seq, "label": y_data}
def prepare_data(data_type="imdb",
self_dict=False,
batch_size=128,
buf_size=50000):
"""
prepare data
"""
if self_dict:
word_dict = load_vocab(data_type + ".vocab")
else:
if data_type == "imdb":
word_dict = paddle.dataset.imdb.word_dict()
elif data_type == "light_imdb":
word_dict = light_imdb.word_dict()
elif data_type == "tiny_imdb":
word_dict = tiny_imdb.word_dict()
else:
raise RuntimeError("No such dataset")
if data_type == "imdb":
train_reader = paddle.batch(
paddle.reader.shuffle(
paddle.dataset.imdb.train(word_dict),
buf_size = buf_size),
batch_size = batch_size)
test_reader = paddle.batch(
paddle.reader.shuffle(
paddle.dataset.imdb.test(word_dict),
buf_size = buf_size),
batch_size = batch_size)
elif data_type == "light_imdb":
train_reader = paddle.batch(
paddle.reader.shuffle(
light_imdb.train(word_dict),
buf_size = buf_size),
batch_size = batch_size)
test_reader = paddle.batch(
paddle.reader.shuffle(
light_imdb.test(word_dict),
buf_size = buf_size),
batch_size = batch_size)
elif data_type == "tiny_imdb":
train_reader = paddle.batch(
paddle.reader.shuffle(
tiny_imdb.train(word_dict),
buf_size = buf_size),
batch_size = batch_size)
test_reader = paddle.batch(
paddle.reader.shuffle(
tiny_imdb.test(word_dict),
buf_size = buf_size),
batch_size = batch_size)
else:
raise RuntimeError("no such dataset")
return word_dict, train_reader, test_reader
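

# An end-to-end sketch tying these helpers to the nets above. It assumes
# nets.py from this commit is importable as a sibling module, and every
# concrete choice (tiny_imdb, bow_net, Adagrad, the epoch count) is
# illustrative rather than prescribed by the original files.
def train_sketch(num_epochs=2):
    from nets import bow_net  # assumed: the sibling file from this commit
    word_dict, train_reader, test_reader = prepare_data("tiny_imdb")
    data = fluid.layers.data(name="words", shape=[1], dtype="int64", lod_level=1)
    label = fluid.layers.data(name="label", shape=[1], dtype="int64")
    avg_cost, acc, _ = bow_net(data, label, dict_dim=len(word_dict))
    fluid.optimizer.Adagrad(learning_rate=0.002).minimize(avg_cost)
    place = fluid.CPUPlace()
    exe = fluid.Executor(place)
    exe.run(fluid.default_startup_program())
    for epoch in range(num_epochs):
        for batch in train_reader():
            cost_val, acc_val = exe.run(fluid.default_main_program(),
                                        feed=data2tensor(batch, place),
                                        fetch_list=[avg_cost, acc])
        print("epoch %d: avg cost %f, acc %f"
              % (epoch, float(cost_val), float(acc_val)))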