diff --git a/fluid/text_classification/README.md b/fluid/text_classification/README.md
index 500ee6ae6db28e9d844d206a1cc894c36f1db09f..43c15934fa62af3db2261be37803ce21ba6bf946 100644
--- a/fluid/text_classification/README.md
+++ b/fluid/text_classification/README.md
@@ -1,16 +1,112 @@
-The minimum PaddlePaddle version needed for the code sample in this directory is the lastest develop branch. If you are on a version of PaddlePaddle earlier than this, [please update your installation](http://www.paddlepaddle.org/docs/develop/documentation/en/build_and_install/pip_install_en.html).
-
----
-
-# Text Classification
-
-## Data Preparation
-```
-wget http://ai.stanford.edu/%7Eamaas/data/sentiment/aclImdb_v1.tar.gz
-tar zxf aclImdb_v1.tar.gz
-```
-
-## Training
-```
-python train.py --dict_path 'aclImdb/imdb.vocab'
-```
+# Text Classification
+
+The directory layout of this example is as follows:
+
+```text
+.
+├── nets.py              # model definitions
+├── README.md            # this document
+├── train.py             # training script
+├── infer.py             # inference script
+└── utils.py             # shared utility functions (data preparation, tensor conversion)
+```
+
+## Introduction and Models
+
+The PaddlePaddle v2 [text classification](https://github.com/PaddlePaddle/models/blob/develop/text/README.md) example already covers the text classification task in detail, so that introduction is not repeated here.
+This example provides four common text classification models: bow, cnn, lstm and gru.
+
+## Training
+
+1. Run `python train.py bow` to start training.
+   ```bash
+   python train.py bow   # "bow" selects the network; it can be replaced with cnn, lstm or gru
+   ```
+
+2. (Optional) To use a custom network, add it to [nets.py](./nets.py) and pass the corresponding arguments to `train()` in [train.py](./train.py); a minimal sketch is given after the sample results below.
+   ```python
+   def train(train_reader,    # training data reader
+             word_dict,       # word dictionary
+             network,         # network configuration
+             use_cuda,        # whether to use the GPU
+             parallel,        # whether to train in parallel
+             save_dirname,    # directory for saved models
+             lr=0.2,          # learning rate
+             batch_size=128,  # samples per batch
+             pass_num=30):    # number of training passes
+   ```
+
+## Sample Training Results
+```text
+    pass_id: 0, avg_acc: 0.848040, avg_cost: 0.354073
+    pass_id: 1, avg_acc: 0.914200, avg_cost: 0.217945
+    pass_id: 2, avg_acc: 0.929800, avg_cost: 0.184302
+    pass_id: 3, avg_acc: 0.938680, avg_cost: 0.164240
+    pass_id: 4, avg_acc: 0.945120, avg_cost: 0.149150
+    pass_id: 5, avg_acc: 0.951280, avg_cost: 0.137117
+    pass_id: 6, avg_acc: 0.955360, avg_cost: 0.126434
+    pass_id: 7, avg_acc: 0.961400, avg_cost: 0.117405
+    pass_id: 8, avg_acc: 0.963560, avg_cost: 0.110070
+    pass_id: 9, avg_acc: 0.965840, avg_cost: 0.103273
+    pass_id: 10, avg_acc: 0.969800, avg_cost: 0.096314
+    pass_id: 11, avg_acc: 0.971720, avg_cost: 0.090206
+    pass_id: 12, avg_acc: 0.974800, avg_cost: 0.084970
+    pass_id: 13, avg_acc: 0.977400, avg_cost: 0.078981
+    pass_id: 14, avg_acc: 0.980000, avg_cost: 0.073685
+    pass_id: 15, avg_acc: 0.981080, avg_cost: 0.069898
+    pass_id: 16, avg_acc: 0.982080, avg_cost: 0.064923
+    pass_id: 17, avg_acc: 0.984680, avg_cost: 0.060861
+    pass_id: 18, avg_acc: 0.985840, avg_cost: 0.057095
+    pass_id: 19, avg_acc: 0.988080, avg_cost: 0.052424
+    pass_id: 20, avg_acc: 0.989160, avg_cost: 0.049059
+    pass_id: 21, avg_acc: 0.990120, avg_cost: 0.045882
+    pass_id: 22, avg_acc: 0.992080, avg_cost: 0.042140
+    pass_id: 23, avg_acc: 0.992280, avg_cost: 0.039722
+    pass_id: 24, avg_acc: 0.992840, avg_cost: 0.036607
+    pass_id: 25, avg_acc: 0.994440, avg_cost: 0.034040
+    pass_id: 26, avg_acc: 0.995000, avg_cost: 0.031501
+    pass_id: 27, avg_acc: 0.995440, avg_cost: 0.028988
+    pass_id: 28, avg_acc: 0.996240, avg_cost: 0.026639
+    pass_id: 29, avg_acc: 0.996960, avg_cost: 0.024186
+```
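+
+As mentioned in step 2 above, a custom network added to [nets.py](./nets.py) only needs to follow the same interface as the built-in ones: it takes `data`, `label` and `dict_dim`, and returns `avg_cost`, `acc` and `prediction`. Below is a minimal sketch; the name `avg_bow_net` and the use of average pooling are only illustrative, and a matching branch still has to be added to `train_net()` in [train.py](./train.py):
+
+```python
+def avg_bow_net(data, label, dict_dim,
+                emb_dim=128, hid_dim=128, class_dim=2):
+    """Illustrative custom network: average-pooled bag of words."""
+    emb = fluid.layers.embedding(input=data, size=[dict_dim, emb_dim])
+    # Average pooling over the sequence instead of the sum pooling used by bow_net.
+    avg_pool = fluid.layers.sequence_pool(input=emb, pool_type='average')
+    fc_1 = fluid.layers.fc(input=avg_pool, size=hid_dim, act="tanh")
+    prediction = fluid.layers.fc(input=fc_1, size=class_dim, act="softmax")
+    cost = fluid.layers.cross_entropy(input=prediction, label=label)
+    avg_cost = fluid.layers.mean(x=cost)
+    acc = fluid.layers.accuracy(input=prediction, label=label)
+    return avg_cost, acc, prediction
+```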
+
+## Inference
+
+1. Run `python infer.py bow_model` to start inference.
+   ```bash
+   python infer.py bow_model   # "bow_model" is the directory of the saved models to load
+   ```
+
+## Sample Inference Results
+```text
+    model_path: bow_model/epoch0, avg_acc: 0.882800
+    model_path: bow_model/epoch1, avg_acc: 0.882360
+    model_path: bow_model/epoch2, avg_acc: 0.881400
+    model_path: bow_model/epoch3, avg_acc: 0.877800
+    model_path: bow_model/epoch4, avg_acc: 0.872920
+    model_path: bow_model/epoch5, avg_acc: 0.872640
+    model_path: bow_model/epoch6, avg_acc: 0.869960
+    model_path: bow_model/epoch7, avg_acc: 0.865160
+    model_path: bow_model/epoch8, avg_acc: 0.863680
+    model_path: bow_model/epoch9, avg_acc: 0.861200
+    model_path: bow_model/epoch10, avg_acc: 0.853520
+    model_path: bow_model/epoch11, avg_acc: 0.850400
+    model_path: bow_model/epoch12, avg_acc: 0.855960
+    model_path: bow_model/epoch13, avg_acc: 0.853480
+    model_path: bow_model/epoch14, avg_acc: 0.855960
+    model_path: bow_model/epoch15, avg_acc: 0.854120
+    model_path: bow_model/epoch16, avg_acc: 0.854160
+    model_path: bow_model/epoch17, avg_acc: 0.852240
+    model_path: bow_model/epoch18, avg_acc: 0.852320
+    model_path: bow_model/epoch19, avg_acc: 0.850280
+    model_path: bow_model/epoch20, avg_acc: 0.849760
+    model_path: bow_model/epoch21, avg_acc: 0.850160
+    model_path: bow_model/epoch22, avg_acc: 0.846800
+    model_path: bow_model/epoch23, avg_acc: 0.845440
+    model_path: bow_model/epoch24, avg_acc: 0.845640
+    model_path: bow_model/epoch25, avg_acc: 0.846200
+    model_path: bow_model/epoch26, avg_acc: 0.845880
+    model_path: bow_model/epoch27, avg_acc: 0.844880
+    model_path: bow_model/epoch28, avg_acc: 0.844680
+    model_path: bow_model/epoch29, avg_acc: 0.844960
+```
+Note: the test accuracy keeps dropping over later epochs because the model overfits; this is expected and can be ignored.
diff --git a/fluid/text_classification/config.py b/fluid/text_classification/config.py
deleted file mode 100644
index 2aba3247eb9033d959bbf4a7c3d475d5c8309058..0000000000000000000000000000000000000000
--- a/fluid/text_classification/config.py
+++ /dev/null
@@ -1,16 +0,0 @@
-class TrainConfig(object):
-
-    # Whether to use GPU in training or not.
-    use_gpu = False
-
-    # The training batch size.
-    batch_size = 4
-
-    # The epoch number.
-    num_passes = 30
-
-    # The global learning rate.
-    learning_rate = 0.01
-
-    # Training log will be printed every log_period.
- log_period = 100 diff --git a/fluid/text_classification/infer.py b/fluid/text_classification/infer.py new file mode 100644 index 0000000000000000000000000000000000000000..d2a0363d786866a92195dba8b490287b3ca9bc9d --- /dev/null +++ b/fluid/text_classification/infer.py @@ -0,0 +1,50 @@ +import sys +import time +import unittest +import contextlib +import numpy as np + +import paddle.fluid as fluid +import paddle.v2 as paddle + +import utils + + +def infer(test_reader, use_cuda, model_path=None): + """ + inference function + """ + if model_path is None: + print(str(model_path) + " cannot be found") + return + + place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace() + exe = fluid.Executor(place) + + inference_scope = fluid.core.Scope() + with fluid.scope_guard(inference_scope): + [inference_program, feed_target_names, + fetch_targets] = fluid.io.load_inference_model(model_path, exe) + + total_acc = 0.0 + total_count = 0 + for data in test_reader(): + acc = exe.run(inference_program, + feed=utils.data2tensor(data, place), + fetch_list=fetch_targets, + return_numpy=True) + total_acc += acc[0] * len(data) + total_count += len(data) + + avg_acc = total_acc / total_count + print("model_path: %s, avg_acc: %f" % (model_path, avg_acc)) + + +if __name__ == "__main__": + word_dict, train_reader, test_reader = utils.prepare_data( + "imdb", self_dict=False, batch_size=128, buf_size=50000) + + model_path = sys.argv[1] + for i in range(30): + epoch_path = model_path + "/" + "epoch" + str(i) + infer(test_reader, use_cuda=False, model_path=epoch_path) diff --git a/fluid/text_classification/nets.py b/fluid/text_classification/nets.py new file mode 100644 index 0000000000000000000000000000000000000000..a21742d22d0bd1676c8c5874899af746b5225636 --- /dev/null +++ b/fluid/text_classification/nets.py @@ -0,0 +1,124 @@ +import sys +import time +import numpy as np + +import paddle.fluid as fluid +import paddle.v2 as paddle + + +def bow_net(data, + label, + dict_dim, + emb_dim=128, + hid_dim=128, + hid_dim2=96, + class_dim=2): + """ + bow net + """ + emb = fluid.layers.embedding(input=data, size=[dict_dim, emb_dim]) + bow = fluid.layers.sequence_pool(input=emb, pool_type='sum') + bow_tanh = fluid.layers.tanh(bow) + fc_1 = fluid.layers.fc(input=bow_tanh, size=hid_dim, act="tanh") + fc_2 = fluid.layers.fc(input=fc_1, size=hid_dim2, act="tanh") + prediction = fluid.layers.fc(input=[fc_2], size=class_dim, act="softmax") + cost = fluid.layers.cross_entropy(input=prediction, label=label) + avg_cost = fluid.layers.mean(x=cost) + acc = fluid.layers.accuracy(input=prediction, label=label) + + return avg_cost, acc, prediction + + +def cnn_net(data, + label, + dict_dim, + emb_dim=128, + hid_dim=128, + hid_dim2=96, + class_dim=2, + win_size=3): + """ + conv net + """ + emb = fluid.layers.embedding(input=data, size=[dict_dim, emb_dim]) + + conv_3 = fluid.nets.sequence_conv_pool( + input=emb, + num_filters=hid_dim, + filter_size=win_size, + act="tanh", + pool_type="max") + + fc_1 = fluid.layers.fc(input=[conv_3], size=hid_dim2) + + prediction = fluid.layers.fc(input=[fc_1], size=class_dim, act="softmax") + cost = fluid.layers.cross_entropy(input=prediction, label=label) + avg_cost = fluid.layers.mean(x=cost) + acc = fluid.layers.accuracy(input=prediction, label=label) + + return avg_cost, acc, prediction + + +def lstm_net(data, + label, + dict_dim, + emb_dim=128, + hid_dim=128, + hid_dim2=96, + class_dim=2, + emb_lr=30.0): + """ + lstm net + """ + emb = fluid.layers.embedding( + input=data, + size=[dict_dim, emb_dim], + 
param_attr=fluid.ParamAttr(learning_rate=emb_lr)) + + fc0 = fluid.layers.fc(input=emb, size=hid_dim * 4, act='tanh') + + lstm_h, c = fluid.layers.dynamic_lstm( + input=fc0, size=hid_dim * 4, is_reverse=False) + + lstm_max = fluid.layers.sequence_pool(input=lstm_h, pool_type='max') + lstm_max_tanh = fluid.layers.tanh(lstm_max) + + fc1 = fluid.layers.fc(input=lstm_max_tanh, size=hid_dim2, act='tanh') + + prediction = fluid.layers.fc(input=fc1, size=class_dim, act='softmax') + + cost = fluid.layers.cross_entropy(input=prediction, label=label) + avg_cost = fluid.layers.mean(x=cost) + acc = fluid.layers.accuracy(input=prediction, label=label) + + return avg_cost, acc, prediction + + +def gru_net(data, + label, + dict_dim, + emb_dim=128, + hid_dim=128, + hid_dim2=96, + class_dim=2, + emb_lr=400.0): + """ + gru net + """ + emb = fluid.layers.embedding( + input=data, + size=[dict_dim, emb_dim], + param_attr=fluid.ParamAttr(learning_rate=emb_lr)) + + fc0 = fluid.layers.fc(input=emb, size=hid_dim * 3) + gru_h = fluid.layers.dynamic_gru(input=fc0, size=hid_dim, is_reverse=False) + gru_max = fluid.layers.sequence_pool(input=gru_h, pool_type='max') + gru_max_tanh = fluid.layers.tanh(gru_max) + fc1 = fluid.layers.fc(input=gru_max_tanh, size=hid_dim2, act='tanh') + prediction = fluid.layers.fc(input=fc1, size=class_dim, act='softmax') + + cost = fluid.layers.cross_entropy(input=prediction, label=label) + avg_cost = fluid.layers.mean(x=cost) + acc = fluid.layers.accuracy(input=prediction, label=label) + + return avg_cost, acc, prediction diff --git a/fluid/text_classification/train.py b/fluid/text_classification/train.py index d32e1c4c878f4d6ef554cc27e0fc5ffc99f96a4a..dc164671e785b758365885b98788fae71d5f8a87 100644 --- a/fluid/text_classification/train.py +++ b/fluid/text_classification/train.py @@ -1,164 +1,131 @@ -import numpy as np import sys -import os -import argparse import time +import unittest +import contextlib -import paddle.v2 as paddle import paddle.fluid as fluid +import paddle.v2 as paddle -from config import TrainConfig as conf - - -def parse_args(): - parser = argparse.ArgumentParser() - parser.add_argument( - '--dict_path', - type=str, - required=True, - help="Path of the word dictionary.") - return parser.parse_args() - - -# Define to_lodtensor function to process the sequential data. -def to_lodtensor(data, place): - seq_lens = [len(seq) for seq in data] - cur_len = 0 - lod = [cur_len] - for l in seq_lens: - cur_len += l - lod.append(cur_len) - flattened_data = np.concatenate(data, axis=0).astype("int64") - flattened_data = flattened_data.reshape([len(flattened_data), 1]) - res = fluid.LoDTensor() - res.set(flattened_data, place) - res.set_lod([lod]) - return res - - -# Load the dictionary. -def load_vocab(filename): - vocab = {} - with open(filename) as f: - for idx, line in enumerate(f): - vocab[line.strip()] = idx - return vocab - - -# Define the convolution model. 
-def conv_net(dict_dim, - window_size=3, - emb_dim=128, - num_filters=128, - fc0_dim=96, - class_dim=2): - +import utils +from nets import bow_net +from nets import cnn_net +from nets import lstm_net +from nets import gru_net + + +def train(train_reader, + word_dict, + network, + use_cuda, + parallel, + save_dirname, + lr=0.2, + batch_size=128, + pass_num=30): + """ + train network + """ data = fluid.layers.data( name="words", shape=[1], dtype="int64", lod_level=1) label = fluid.layers.data(name="label", shape=[1], dtype="int64") - emb = fluid.layers.embedding(input=data, size=[dict_dim, emb_dim]) - - conv_3 = fluid.nets.sequence_conv_pool( - input=emb, - num_filters=num_filters, - filter_size=window_size, - act="tanh", - pool_type="max") - - fc_0 = fluid.layers.fc(input=[conv_3], size=fc0_dim) - - prediction = fluid.layers.fc(input=[fc_0], size=class_dim, act="softmax") - - cost = fluid.layers.cross_entropy(input=prediction, label=label) - - avg_cost = fluid.layers.mean(x=cost) - - return data, label, prediction, avg_cost - - -def main(dict_path): - word_dict = load_vocab(dict_path) - word_dict[""] = len(word_dict) - dict_dim = len(word_dict) - print("The dictionary size is : %d" % dict_dim) - - data, label, prediction, avg_cost = conv_net(dict_dim) - - sgd_optimizer = fluid.optimizer.SGD(learning_rate=conf.learning_rate) - sgd_optimizer.minimize(avg_cost) - - batch_size_var = fluid.layers.create_tensor(dtype='int64') - batch_acc_var = fluid.layers.accuracy( - input=prediction, label=label, total=batch_size_var) - - inference_program = fluid.default_main_program().clone() - with fluid.program_guard(inference_program): - inference_program = fluid.io.get_inference_program( - target_vars=[batch_acc_var, batch_size_var]) + if not parallel: + cost, acc, prediction = network(data, label, len(word_dict)) + else: + places = fluid.layers.get_places(device_count=2) + pd = fluid.layers.ParallelDo(places) + with pd.do(): + cost, acc, prediction = network( + pd.read_input(data), pd.read_input(label), len(word_dict)) - # The training data set. - train_reader = paddle.batch( - paddle.reader.shuffle( - paddle.dataset.imdb.train(word_dict), buf_size=51200), - batch_size=conf.batch_size) + pd.write_output(cost) + pd.write_output(acc) - # The testing data set. 
- test_reader = paddle.batch( - paddle.reader.shuffle( - paddle.dataset.imdb.test(word_dict), buf_size=51200), - batch_size=conf.batch_size) + cost, acc = pd() + cost = fluid.layers.mean(cost) + acc = fluid.layers.mean(acc) - if conf.use_gpu: - place = fluid.CUDAPlace(0) - else: - place = fluid.CPUPlace() + sgd_optimizer = fluid.optimizer.Adagrad(learning_rate=lr) + sgd_optimizer.minimize(cost) + place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace() exe = fluid.Executor(place) - feeder = fluid.DataFeeder(feed_list=[data, label], place=place) exe.run(fluid.default_startup_program()) + for pass_id in xrange(pass_num): + data_size, data_count, total_acc, total_cost = 0, 0, 0.0, 0.0 + for data in train_reader(): + avg_cost_np, avg_acc_np = exe.run(fluid.default_main_program(), + feed=feeder.feed(data), + fetch_list=[cost, acc]) + data_size = len(data) + total_acc += data_size * avg_acc_np + total_cost += data_size * avg_cost_np + data_count += data_size + avg_cost = total_cost / data_count + + avg_acc = total_acc / data_count + print("pass_id: %d, avg_acc: %f, avg_cost: %f" % + (pass_id, avg_acc, avg_cost)) + + epoch_model = save_dirname + "/" + "epoch" + str(pass_id) + fluid.io.save_inference_model(epoch_model, ["words", "label"], acc, exe) + + +def train_net(): + word_dict, train_reader, test_reader = utils.prepare_data( + "imdb", self_dict=False, batch_size=128, buf_size=50000) + + if sys.argv[1] == "bow": + train( + train_reader, + word_dict, + bow_net, + use_cuda=False, + parallel=False, + save_dirname="bow_model", + lr=0.002, + pass_num=30, + batch_size=128) + elif sys.argv[1] == "cnn": + train( + train_reader, + word_dict, + cnn_net, + use_cuda=True, + parallel=False, + save_dirname="cnn_model", + lr=0.01, + pass_num=30, + batch_size=4) + elif sys.argv[1] == "lstm": + train( + train_reader, + word_dict, + lstm_net, + use_cuda=True, + parallel=False, + save_dirname="lstm_model", + lr=0.05, + pass_num=30, + batch_size=4) + elif sys.argv[1] == "gru": + train( + train_reader, + word_dict, + lstm_net, + use_cuda=True, + parallel=False, + save_dirname="gru_model", + lr=0.05, + pass_num=30, + batch_size=128) + else: + print("network name cannot be found!") + sys.exit(1) + - train_pass_acc_evaluator = fluid.average.WeightedAverage() - test_pass_acc_evaluator = fluid.average.WeightedAverage() - - def test(exe): - test_pass_acc_evaluator.reset() - for batch_id, data in enumerate(test_reader()): - input_seq = to_lodtensor(map(lambda x: x[0], data), place) - y_data = np.array(map(lambda x: x[1], data)).astype("int64") - y_data = y_data.reshape([-1, 1]) - b_acc, b_size = exe.run(inference_program, - feed={"words": input_seq, - "label": y_data}, - fetch_list=[batch_acc_var, batch_size_var]) - test_pass_acc_evaluator.add(value=b_acc, weight=b_size) - test_acc = test_pass_acc_evaluator.eval() - return test_acc - - total_time = 0. 
- for pass_id in xrange(conf.num_passes): - train_pass_acc_evaluator.reset() - start_time = time.time() - for batch_id, data in enumerate(train_reader()): - cost_val, acc_val, size_val = exe.run( - fluid.default_main_program(), - feed=feeder.feed(data), - fetch_list=[avg_cost, batch_acc_var, batch_size_var]) - train_pass_acc_evaluator.add(value=acc_val, weight=size_val) - if batch_id and batch_id % conf.log_period == 0: - print("Pass id: %d, batch id: %d, cost: %f, pass_acc: %f" % - (pass_id, batch_id, cost_val, - train_pass_acc_evaluator.eval())) - end_time = time.time() - total_time += (end_time - start_time) - pass_test_acc = test(exe) - print("Pass id: %d, test_acc: %f" % (pass_id, pass_test_acc)) - print("Total train time: %f" % (total_time)) - - -if __name__ == '__main__': - args = parse_args() - main(args.dict_path) +if __name__ == "__main__": + train_net() diff --git a/fluid/text_classification/utils.py b/fluid/text_classification/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..a37d8d720e1f97d7c70464ff15ef75e3ba049c32 --- /dev/null +++ b/fluid/text_classification/utils.py @@ -0,0 +1,108 @@ +import sys +import time +import numpy as np + +import paddle.fluid as fluid +import paddle.v2 as paddle + +import light_imdb +import tiny_imdb + + +def to_lodtensor(data, place): + """ + convert to LODtensor + """ + seq_lens = [len(seq) for seq in data] + cur_len = 0 + lod = [cur_len] + for l in seq_lens: + cur_len += l + lod.append(cur_len) + flattened_data = np.concatenate(data, axis=0).astype("int64") + flattened_data = flattened_data.reshape([len(flattened_data), 1]) + res = fluid.LoDTensor() + res.set(flattened_data, place) + res.set_lod([lod]) + return res + + +def load_vocab(filename): + """ + load imdb vocabulary + """ + vocab = {} + with open(filename) as f: + wid = 0 + for line in f: + vocab[line.strip()] = wid + wid += 1 + vocab[""] = len(vocab) + return vocab + + +def data2tensor(data, place): + """ + data2tensor + """ + input_seq = to_lodtensor(map(lambda x: x[0], data), place) + y_data = np.array(map(lambda x: x[1], data)).astype("int64") + y_data = y_data.reshape([-1, 1]) + return {"words": input_seq, "label": y_data} + + +def prepare_data(data_type="imdb", + self_dict=False, + batch_size=128, + buf_size=50000): + """ + prepare data + """ + if self_dict: + word_dict = load_vocab(data_type + ".vocab") + else: + if data_type == "imdb": + word_dict = paddle.dataset.imdb.word_dict() + elif data_type == "light_imdb": + word_dict = light_imdb.word_dict() + elif data_type == "tiny_imdb": + word_dict = tiny_imdb.word_dict() + else: + raise RuntimeError("No such dataset") + + if data_type == "imdb": + train_reader = paddle.batch( + paddle.reader.shuffle( + paddle.dataset.imdb.train(word_dict), buf_size=buf_size), + batch_size=batch_size) + + test_reader = paddle.batch( + paddle.reader.shuffle( + paddle.dataset.imdb.test(word_dict), buf_size=buf_size), + batch_size=batch_size) + + elif data_type == "light_imdb": + train_reader = paddle.batch( + paddle.reader.shuffle( + light_imdb.train(word_dict), buf_size=buf_size), + batch_size=batch_size) + + test_reader = paddle.batch( + paddle.reader.shuffle( + light_imdb.test(word_dict), buf_size=buf_size), + batch_size=batch_size) + + elif data_type == "tiny_imdb": + train_reader = paddle.batch( + paddle.reader.shuffle( + tiny_imdb.train(word_dict), buf_size=buf_size), + batch_size=batch_size) + + test_reader = paddle.batch( + paddle.reader.shuffle( + tiny_imdb.test(word_dict), buf_size=buf_size), + 
+            batch_size=batch_size)
+    else:
+        raise RuntimeError("no such dataset")
+
+    return word_dict, train_reader, test_reader
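Taken together, a minimal end-to-end sketch of how the new scripts are meant to be driven; the `bow_model` directory name, the pass count and the choice of `bow_net` are only illustrative, while the function signatures are the ones defined in `train.py`, `infer.py` and `utils.py` above:

```python
import utils
from nets import bow_net
from train import train
from infer import infer

# Build the IMDB word dictionary plus batched train/test readers.
word_dict, train_reader, test_reader = utils.prepare_data(
    "imdb", self_dict=False, batch_size=128, buf_size=50000)

# Train a bag-of-words model, saving one inference model per pass.
train(train_reader, word_dict, bow_net,
      use_cuda=False, parallel=False, save_dirname="bow_model",
      lr=0.002, batch_size=128, pass_num=5)

# Evaluate every saved pass on the test set.
for i in range(5):
    infer(test_reader, use_cuda=False, model_path="bow_model/epoch" + str(i))
```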