From 63deaed1ca0d20093453e4556fda9905c0e470f4 Mon Sep 17 00:00:00 2001 From: wopeizl Date: Fri, 31 May 2019 14:05:27 +0800 Subject: [PATCH] add cnn sentiment model for dygraph (#2309) * add cnn sentiment model for dygraph --- dygraph/sentiment/README.md | 77 ++++++++++ dygraph/sentiment/main.py | 276 ++++++++++++++++++++++++++++++++++++ dygraph/sentiment/nets.py | 93 ++++++++++++ dygraph/sentiment/reader.py | 70 +++++++++ dygraph/sentiment/utils.py | 81 +++++++++++ 5 files changed, 597 insertions(+) create mode 100644 dygraph/sentiment/README.md create mode 100644 dygraph/sentiment/main.py create mode 100644 dygraph/sentiment/nets.py create mode 100644 dygraph/sentiment/reader.py create mode 100644 dygraph/sentiment/utils.py diff --git a/dygraph/sentiment/README.md b/dygraph/sentiment/README.md new file mode 100644 index 00000000..0e6dd392 --- /dev/null +++ b/dygraph/sentiment/README.md @@ -0,0 +1,77 @@ +## 简介 + + +情感是人类的一种高级智能行为,为了识别文本的情感倾向,需要深入的语义建模。另外,不同领域(如餐饮、体育)在情感的表达各不相同,因而需要有大规模覆盖各个领域的数据进行模型训练。为此,我们通过基于深度学习的语义模型和大规模数据挖掘解决上述两个问题。效果上,我们基于开源情感倾向分类数据集ChnSentiCorp进行评测。具体数据如下所示: + +| 模型 | dev | +| :------| :------ | +| CNN | 90.6% | + + +## 快速开始 + +本项目依赖于 Paddlepaddle 1.5.0 及以上版本,请参考 [安装指南](http://www.paddlepaddle.org/#quick-start) 进行安装 + +python版本依赖python 2.7或python 3.5及以上版本 + +#### 安装代码 + +克隆数据集代码库到本地 +```shell +git clone https://github.com/PaddlePaddle/models.git +cd models/dygraph/sentiment +``` + +#### 数据准备 + +下载经过预处理的数据,文件解压之后,senta_data目录下会存在训练数据(train.tsv)、开发集数据(dev.tsv)、测试集数据(test.tsv)以及对应的词典(word_dict.txt) +```shell +wget https://baidu-nlp.bj.bcebos.com/sentiment_classification-dataset-1.0.0.tar.gz +tar -zxvf sentiment_classification-dataset-1.0.0.tar.gz +``` + +#### 模型训练 + +基于示例的数据集,可以运行下面的命令,在训练集(train.tsv)上进行模型训练,并在开发集(dev.tsv)验证 +```shell +python main.py +``` + +#### 模型预测 + +利用已有模型,可以运行下面命令,对未知label的数据(test.tsv)进行预测 +```shell +python main.py --do_train false --do_infer true --checkpoints ./path_to_save_models +``` + +## 进阶使用 + +#### 任务定义 + +传统的情感分类主要基于词典或者特征工程的方式进行分类,这种方法需要繁琐的人工特征设计和先验知识,理解停留于浅层并且扩展泛化能力差。为了避免传统方法的局限,我们采用近年来飞速发展的深度学习技术。基于深度学习的情感分类不依赖于人工特征,它能够端到端的对输入文本进行语义理解,并基于语义表示进行情感倾向的判断。 +#### 模型原理介绍 + +本项目针对情感倾向性分类问题,: + ++ CNN(Convolutional Neural Networks),是一个基础的序列模型,能处理变长序列输入,提取局部区域之内的特征; + +#### 数据格式说明 + +训练、预测、评估使用的数据可以由用户根据实际的应用场景,自己组织数据。数据由两列组成,以制表符分隔,第一列是以空格分词的中文文本(分词预处理方法将在下文具体说明),文件为utf8编码;第二列是情感倾向分类的类别(0表示消极;1表示积极),注意数据文件第一行固定表示为"text_a\tlabel" + +```text +特 喜欢 这种 好看的 狗狗 1 +这 真是 惊艳 世界 的 中国 黑科技 1 +环境 特别 差 ,脏兮兮 的,再也 不去 了 0 +``` + +#### 代码结构说明 + +```text +. +├── reader.py # 定义了读入数据,加载词典的功能 +├── main.py # 该项目的主函数,封装包括训练、预测、评估的部分 +├── nets.py # 网络结构 +├── utils.py # 定义了其他常用的功能函数 +``` + diff --git a/dygraph/sentiment/main.py b/dygraph/sentiment/main.py new file mode 100644 index 00000000..d87f60c2 --- /dev/null +++ b/dygraph/sentiment/main.py @@ -0,0 +1,276 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import print_function + +import os +import time +import argparse +import numpy as np +import paddle.fluid as fluid +from paddle.fluid.dygraph.base import to_variable +import nets +import reader +from utils import ArgumentGroup + +parser = argparse.ArgumentParser(__doc__) +model_g = ArgumentGroup(parser, "model", "model configuration and paths.") +model_g.add_arg("checkpoints", str, "checkpoints", "Path to save checkpoints") + +train_g = ArgumentGroup(parser, "training", "training options.") +train_g.add_arg("epoch", int, 10, "Number of epoches for training.") +train_g.add_arg("save_steps", int, 1000, + "The steps interval to save checkpoints.") +train_g.add_arg("validation_steps", int, 200, + "The steps interval to evaluate model performance.") +train_g.add_arg("lr", float, 0.002, "The Learning rate value for training.") +train_g.add_arg("padding_size", int, 150, + "The padding size for input sequences.") + +log_g = ArgumentGroup(parser, "logging", "logging related") +log_g.add_arg("skip_steps", int, 10, "The steps interval to print loss.") +log_g.add_arg("verbose", bool, False, "Whether to output verbose log") + +data_g = ArgumentGroup(parser, "data", + "Data paths, vocab paths and data processing options") +data_g.add_arg("data_dir", str, "./senta_data/", "Path to training data.") +data_g.add_arg("vocab_path", str, "./senta_data/word_dict.txt", + "Vocabulary path.") +data_g.add_arg("vocab_size", int, 33256, "Vocabulary path.") +data_g.add_arg("batch_size", int, 16, + "Total examples' number in batch for training.") +data_g.add_arg("random_seed", int, 0, "Random seed.") + +run_type_g = ArgumentGroup(parser, "run_type", "running type options.") +run_type_g.add_arg("use_cuda", bool, False, "If set, use GPU for training.") +run_type_g.add_arg("do_train", bool, True, "Whether to perform training.") +run_type_g.add_arg("do_val", bool, True, "Whether to perform evaluation.") +run_type_g.add_arg("do_infer", bool, False, "Whether to perform inference.") +run_type_g.add_arg("profile_steps", int, 15000, + "The steps interval to record the performance.") + +args = parser.parse_args() + +if args.use_cuda: + place = fluid.CUDAPlace(int(os.getenv('FLAGS_selected_gpus', '0'))) + dev_count = fluid.core.get_cuda_device_count() +else: + place = fluid.CPUPlace() + dev_count = 1 + +import paddle.fluid.profiler as profiler +import contextlib + + +@contextlib.contextmanager +def profile_context(profile=True): + if profile: + with profiler.profiler('All', 'total', '/tmp/profile_file'): + yield + else: + yield + + +def train(): + with fluid.dygraph.guard(place): + processor = reader.SentaProcessor( + data_dir=args.data_dir, + vocab_path=args.vocab_path, + random_seed=args.random_seed) + num_labels = len(processor.get_labels()) + + num_train_examples = processor.get_num_examples(phase="train") + + max_train_steps = args.epoch * num_train_examples // args.batch_size // dev_count + + train_data_generator = processor.data_generator( + batch_size=args.batch_size, + phase='train', + epoch=args.epoch, + shuffle=True) + + eval_data_generator = processor.data_generator( + batch_size=args.batch_size, + phase='dev', + epoch=args.epoch, + shuffle=False) + + cnn_net = nets.CNN("cnn_net", args.vocab_size, args.batch_size, + args.padding_size) + + sgd_optimizer = fluid.optimizer.Adagrad(learning_rate=args.lr) + steps = 0 + total_cost, total_acc, total_num_seqs = [], [], [] + + for eop in range(args.epoch): + time_begin = time.time() + for batch_id, data in enumerate(train_data_generator()): + enable_profile = steps > args.profile_steps + + with profile_context(enable_profile): + + steps += 1 + doc = to_variable( + np.array([ + np.pad(x[0][0:args.padding_size], ( + 0, args.padding_size - len(x[0][ + 0:args.padding_size])), + 'constant', + constant_values=(args.vocab_size)) + for x in data + ]).astype('int64').reshape(-1, 1)) + + label = to_variable( + np.array([x[1] for x in data]).astype('int64').reshape( + args.batch_size, 1)) + + cnn_net.train() + avg_cost, prediction, acc = cnn_net(doc, label) + avg_cost.backward() + + np_mask = (doc.numpy() != args.vocab_size).astype('int32') + word_num = np.sum(np_mask) + sgd_optimizer.minimize(avg_cost) + cnn_net.clear_gradients() + total_cost.append(avg_cost.numpy() * word_num) + total_acc.append(acc.numpy() * word_num) + total_num_seqs.append(word_num) + + if steps % args.skip_steps == 0: + time_end = time.time() + used_time = time_end - time_begin + print("step: %d, ave loss: %f, " + "ave acc: %f, speed: %f steps/s" % + (steps, + np.sum(total_cost) / np.sum(total_num_seqs), + np.sum(total_acc) / np.sum(total_num_seqs), + args.skip_steps / used_time)) + total_cost, total_acc, total_num_seqs = [], [], [] + time_begin = time.time() + + if steps % args.validation_steps == 0: + total_eval_cost, total_eval_acc, total_eval_num_seqs = [], [], [] + cnn_net.eval() + eval_steps = 0 + for eval_batch_id, eval_data in enumerate( + eval_data_generator()): + eval_np_doc = np.array([ + np.pad(x[0][0:args.padding_size], + (0, args.padding_size - + len(x[0][0:args.padding_size])), + 'constant', + constant_values=(args.vocab_size)) + for x in eval_data + ]).astype('int64').reshape(1, -1) + eval_label = to_variable( + np.array([x[1] for x in eval_data]).astype( + 'int64').reshape(args.batch_size, 1)) + eval_doc = to_variable(eval_np_doc.reshape(-1, 1)) + eval_avg_cost, eval_prediction, eval_acc = cnn_net( + eval_doc, eval_label) + + eval_np_mask = ( + eval_np_doc != args.vocab_size).astype('int32') + eval_word_num = np.sum(eval_np_mask) + total_eval_cost.append(eval_avg_cost.numpy() * + eval_word_num) + total_eval_acc.append(eval_acc.numpy() * + eval_word_num) + total_eval_num_seqs.append(eval_word_num) + + eval_steps += 1 + + time_end = time.time() + used_time = time_end - time_begin + print( + "Final validation result: step: %d, ave loss: %f, " + "ave acc: %f, speed: %f steps/s" % + (steps, np.sum(total_eval_cost) / + np.sum(total_eval_num_seqs), np.sum(total_eval_acc) + / np.sum(total_eval_num_seqs), + eval_steps / used_time)) + time_begin = time.time() + + if steps % args.save_steps == 0: + save_path = "save_dir_" + str(steps) + print('save model to: ' + save_path) + fluid.dygraph.save_persistables(cnn_net.state_dict(), + save_path) + + if enable_profile: + print('save profile result into /tmp/profile_file') + return + + +def infer(): + with fluid.dygraph.guard(place): + processor = reader.SentaProcessor( + data_dir=args.data_dir, + vocab_path=args.vocab_path, + random_seed=args.random_seed) + + infer_data_generator = processor.data_generator( + batch_size=args.batch_size, + phase='infer', + epoch=args.epoch, + shuffle=False) + + cnn_net_infer = nets.CNN("cnn_net", args.vocab_size, args.batch_size, + args.padding_size) + + print('Do inferring ...... ') + total_acc, total_num_seqs = [], [] + + restore = fluid.dygraph.load_persistables(args.checkpoints) + cnn_net_infer.load_dict(restore) + cnn_net_infer.eval() + + steps = 0 + time_begin = time.time() + for batch_id, data in enumerate(infer_data_generator()): + steps += 1 + np_doc = np.array([ + np.pad(x[0][0:args.padding_size], + (0, args.padding_size - len(x[0][0:args.padding_size])), + 'constant', + constant_values=(args.vocab_size)) for x in data + ]).astype('int64').reshape(-1, 1) + doc = to_variable(np_doc) + label = to_variable( + np.array([x[1] for x in data]).astype('int64').reshape( + args.batch_size, 1)) + + _, _, acc = cnn_net_infer(doc, label) + + mask = (np_doc != args.vocab_size).astype('int32') + word_num = np.sum(mask) + total_acc.append(acc.numpy() * word_num) + total_num_seqs.append(word_num) + + time_end = time.time() + used_time = time_end - time_begin + + print("Final infer result: ave acc: %f, speed: %f steps/s" % + (np.sum(total_acc) / np.sum(total_num_seqs), steps / used_time)) + + +def main(): + if args.do_train: + train() + elif args.do_infer: + infer() + + +if __name__ == '__main__': + print(args) + main() diff --git a/dygraph/sentiment/nets.py b/dygraph/sentiment/nets.py new file mode 100644 index 00000000..7a4cb9da --- /dev/null +++ b/dygraph/sentiment/nets.py @@ -0,0 +1,93 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import paddle.fluid as fluid +from paddle.fluid.dygraph.nn import Conv2D, Pool2D, FC, Embedding +from paddle.fluid.dygraph.base import to_variable + + +class SimpleConvPool(fluid.dygraph.Layer): + def __init__(self, + name_scope, + num_channels, + num_filters, + filter_size, + use_cudnn=False, + batch_size=None): + super(SimpleConvPool, self).__init__(name_scope) + self.batch_size = batch_size + self._conv2d = Conv2D( + self.full_name(), + num_channels=num_channels, + num_filters=num_filters, + filter_size=filter_size, + padding=[1, 1], + use_cudnn=use_cudnn, + act='tanh') + + def forward(self, inputs): + x = self._conv2d(inputs) + x = fluid.layers.reduce_max(x, dim=-1) + x = fluid.layers.reshape(x, shape=[self.batch_size, -1]) + return x + + +class CNN(fluid.dygraph.Layer): + def __init__(self, name_scope, dict_dim, batch_size, seq_len): + super(CNN, self).__init__(name_scope) + self.dict_dim = dict_dim + self.emb_dim = 128 + self.hid_dim = 128 + self.fc_hid_dim = 96 + self.class_dim = 2 + self.win_size = [3, self.hid_dim] + self.batch_size = batch_size + self.seq_len = seq_len + self.embedding = Embedding( + self.full_name(), + size=[self.dict_dim + 1, self.emb_dim], + dtype='float32', + is_sparse=False) + + self._simple_conv_pool_1 = SimpleConvPool( + self.full_name(), + 1, + self.hid_dim, + self.win_size, + batch_size=self.batch_size) + self._fc1 = FC(self.full_name(), size=self.fc_hid_dim, act="softmax") + self._fc_prediction = FC(self.full_name(), + size=self.class_dim, + act="softmax") + + def forward(self, inputs, label=None): + emb = self.embedding(inputs) + o_np_mask = (inputs.numpy() != self.dict_dim).astype('float32') + mask_emb = fluid.layers.expand( + to_variable(o_np_mask), [1, self.hid_dim]) + emb = emb * mask_emb + emb = fluid.layers.reshape( + emb, shape=[-1, 1, self.seq_len, self.hid_dim]) + conv_3 = self._simple_conv_pool_1(emb) + + fc_1 = self._fc1(conv_3) + prediction = self._fc_prediction(fc_1) + + if label: + cost = fluid.layers.cross_entropy(input=prediction, label=label) + avg_cost = fluid.layers.mean(x=cost) + acc = fluid.layers.accuracy(input=prediction, label=label) + + return avg_cost, prediction, acc + else: + return prediction \ No newline at end of file diff --git a/dygraph/sentiment/reader.py b/dygraph/sentiment/reader.py new file mode 100644 index 00000000..db1f2718 --- /dev/null +++ b/dygraph/sentiment/reader.py @@ -0,0 +1,70 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import numpy as np +from utils import load_vocab +from utils import data_reader + +import paddle + + +class SentaProcessor(object): + def __init__(self, data_dir, vocab_path, random_seed=None): + self.data_dir = data_dir + self.vocab = load_vocab(vocab_path) + self.num_examples = {"train": -1, "dev": -1, "infer": -1} + np.random.seed(random_seed) + + def get_train_examples(self, data_dir, epoch): + return data_reader((self.data_dir + "/train.tsv"), self.vocab, + self.num_examples, "train", epoch) + + def get_dev_examples(self, data_dir, epoch): + return data_reader((self.data_dir + "/dev.tsv"), self.vocab, + self.num_examples, "dev", epoch) + + def get_test_examples(self, data_dir, epoch): + return data_reader((self.data_dir + "/test.tsv"), self.vocab, + self.num_examples, "infer", epoch) + + def get_labels(self): + return ["0", "1"] + + def get_num_examples(self, phase): + if phase not in ['train', 'dev', 'infer']: + raise ValueError( + "Unknown phase, which should be in ['train', 'dev', 'infer'].") + return self.num_examples[phase] + + def get_train_progress(self): + return self.current_train_example, self.current_train_epoch + + def data_generator(self, batch_size, phase='train', epoch=1, shuffle=True): + if phase == "train": + return paddle.batch( + self.get_train_examples(self.data_dir, epoch), + batch_size, + drop_last=True) + elif phase == "dev": + return paddle.batch( + self.get_dev_examples(self.data_dir, epoch), + batch_size, + drop_last=True) + elif phase == "infer": + return paddle.batch( + self.get_test_examples(self.data_dir, epoch), + batch_size, + drop_last=True) + else: + raise ValueError( + "Unknown phase, which should be in ['train', 'dev', 'infer'].") diff --git a/dygraph/sentiment/utils.py b/dygraph/sentiment/utils.py new file mode 100644 index 00000000..20d80dc5 --- /dev/null +++ b/dygraph/sentiment/utils.py @@ -0,0 +1,81 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import io +import sys +import random + + +def str2bool(v): + return v.lower() in ("true", "t", "1") + + +class ArgumentGroup(object): + def __init__(self, parser, title, des): + self._group = parser.add_argument_group(title=title, description=des) + + def add_arg(self, name, type, default, help, **kwargs): + type = str2bool if type == bool else type + self._group.add_argument( + "--" + name, + default=default, + type=type, + help=help + ' Default: %(default)s.', + **kwargs) + + +def data_reader(file_path, word_dict, num_examples, phrase, epoch): + unk_id = len(word_dict) + all_data = [] + with io.open(file_path, "r", encoding='utf8') as fin: + for line in fin: + if line.startswith('text_a'): + continue + cols = line.strip().split("\t") + if len(cols) != 2: + sys.stderr.write("[NOTICE] Error Format Line!") + continue + label = int(cols[1]) + wids = [ + word_dict[x] if x in word_dict else unk_id + for x in cols[0].split(" ") + ] + all_data.append((wids, label)) + + if phrase == "train": + random.shuffle(all_data) + + num_examples[phrase] = len(all_data) + + def reader(): + for epoch_index in range(epoch): + for doc, label in all_data: + yield doc, label + + return reader + + +def load_vocab(file_path): + vocab = {} + with io.open(file_path, 'r', encoding='utf8') as f: + wid = 0 + for line in f: + if line.strip() not in vocab: + vocab[line.strip()] = wid + wid += 1 + vocab[""] = len(vocab) + return vocab -- GitLab