From 54b3b726f9f64106cd9cdcb3024052bff5875d09 Mon Sep 17 00:00:00 2001 From: wangxiao1021 Date: Tue, 28 Apr 2020 02:50:21 +0800 Subject: [PATCH] add emotion_detection and update senta --- examples/emotion_detection/config.yaml | 23 +++ examples/emotion_detection/download.py | 123 ++++++++++++ examples/emotion_detection/download_data.sh | 8 + examples/emotion_detection/models.py | 179 ++++++++++++++++++ examples/emotion_detection/run_classifier.py | 158 ++++++++++++++++ examples/sentiment_classification/models.py | 40 +++- .../sentiment_classifier.py | 55 +++--- hapi/text/emo_tect/__init__.py | 15 ++ hapi/text/emo_tect/data_processor.py | 79 ++++++++ hapi/text/emo_tect/data_reader.py | 126 ++++++++++++ hapi/text/senta/__init__.py | 2 +- .../{data_processer.py => data_processor.py} | 0 12 files changed, 772 insertions(+), 36 deletions(-) create mode 100644 examples/emotion_detection/config.yaml create mode 100644 examples/emotion_detection/download.py create mode 100644 examples/emotion_detection/download_data.sh create mode 100644 examples/emotion_detection/models.py create mode 100644 examples/emotion_detection/run_classifier.py create mode 100644 hapi/text/emo_tect/__init__.py create mode 100644 hapi/text/emo_tect/data_processor.py create mode 100644 hapi/text/emo_tect/data_reader.py rename hapi/text/senta/{data_processer.py => data_processor.py} (100%) diff --git a/examples/emotion_detection/config.yaml b/examples/emotion_detection/config.yaml new file mode 100644 index 0000000..121fdf6 --- /dev/null +++ b/examples/emotion_detection/config.yaml @@ -0,0 +1,23 @@ +model_type: "bow_net" +num_labels: 3 +vocab_size: 240465 +vocab_path: "./data/vocab.txt" +data_dir: "./data" +inference_model_dir: "./inference_model" +save_checkpoint_dir: "" +init_checkpoint: "" +checkpoints: "./checkpoints/" +lr: 0.02 +epoch: 10 +batch_size: 24 +do_train: True +do_val: True +do_infer: False +do_save_inference_model: False +max_seq_len: 20 +skip_steps: 10 +save_freq: 1 +eval_freq: 1 
def _stream_to_file(res, filename, chunk_size):
    """
    Write the body of the streaming response `res` to `filename` in chunks,
    printing a progress indicator as data arrives.
    """
    # `content-length` is optional (it is absent for chunked transfers), so
    # treat a missing or malformed value as "size unknown" instead of failing.
    try:
        filesize = int(res.headers.get('content-length', 0))
    except (TypeError, ValueError):
        filesize = 0
    if filesize > 0:
        print("[Filesize]: %0.2f MB" % (filesize / 1024.0 / 1024.0))

    size = 0
    with io.open(filename, "wb") as fout:
        for chunk in res.iter_content(chunk_size=chunk_size):
            if not chunk:
                continue
            fout.write(chunk)
            size += len(chunk)
            if filesize > 0:
                pr = '>' * int(size * 50 / filesize)
                print('\r[Process ]: %s%.2f%%' %
                      (pr, float(size) / filesize * 100), end='')
            else:
                # Total size unknown: report the raw amount received instead.
                print('\r[Process ]: %d bytes' % size, end='')


def download(url, filename, md5sum, retry_limit=3, chunk_size=4096,
             timeout=None):
    """
    Download `url` to `filename` until the file's md5 equals `md5sum`.

    Nothing is downloaded when `filename` already exists with the expected
    md5.

    Args:
        url: address of the file to fetch.
        filename: local path the file is written to.
        md5sum: expected hex md5 digest of the downloaded file.
        retry_limit: maximum number of download attempts (default 3).
        chunk_size: number of bytes requested per streamed chunk.
        timeout: forwarded to `requests.get`; None keeps the previous
            behavior of waiting indefinitely.

    Raises:
        RuntimeError: if the md5 still does not match after `retry_limit`
            attempts.
    """
    attempts = 0
    while not (os.path.exists(filename) and md5file(filename) == md5sum):
        if attempts >= retry_limit:
            raise RuntimeError(
                "Cannot download dataset ({0}) with retry {1} times.".format(
                    url, retry_limit))
        attempts += 1

        start = time.time()
        try:
            res = requests.get(url, stream=True, timeout=timeout)
            try:
                if res.status_code != 200:
                    # Say why this attempt is being retried instead of
                    # looping silently.
                    print("[Error   ]: unexpected HTTP status %d for %s" %
                          (res.status_code, url))
                    continue
                _stream_to_file(res, filename, chunk_size)
            finally:
                # Always release the connection, including on failure.
                res.close()
        except (requests.RequestException, EnvironmentError) as e:
            # Best effort: report the failure and let the loop retry.
            print(e)
            continue
        print("\n[CostTime]: %.2f s" % (time.time() - start))
class CNN(Model):
    """
    Convolutional text classifier: a `CNNEncoder` followed by two `Linear`
    layers, the last of which produces `class_dim` (3) softmax outputs.

    Args:
        dict_dim: vocabulary size; the encoder is built with `dict_dim + 1`
            entries.
        seq_len: sequence length handed to the encoder and used to size the
            first fully connected layer.
    """

    def __init__(self, dict_dim, seq_len):
        super(CNN, self).__init__()
        embedding_width = 128
        hidden_width = 128
        fc_width = 96
        num_classes = 3

        # Hyper-parameters are kept as attributes, exactly as before.
        self.dict_dim = dict_dim
        self.seq_len = seq_len
        self.emb_dim = embedding_width
        self.hid_dim = hidden_width
        self.fc_hid_dim = fc_width
        self.class_dim = num_classes
        self.channels = 1
        self.win_size = [3, hidden_width]

        # Sub-layers are created in the same order as the original code.
        self._encoder = CNNEncoder(
            dict_size=dict_dim + 1,
            emb_dim=embedding_width,
            seq_len=seq_len,
            filter_size=self.win_size,
            num_filters=hidden_width,
            hidden_dim=hidden_width,
            padding_idx=None,
            act='tanh')
        # NOTE(review): the hidden layer uses softmax as written, while the
        # sibling models use tanh here - confirm this is intended.
        self._fc1 = Linear(
            input_dim=hidden_width * seq_len,
            output_dim=fc_width,
            act="softmax")
        self._fc_prediction = Linear(
            input_dim=fc_width, output_dim=num_classes, act="softmax")

    def forward(self, inputs):
        """Return the prediction of the output layer for `inputs`."""
        encoded = self._encoder(inputs)
        hidden = self._fc1(encoded)
        return self._fc_prediction(hidden)
class LSTM(Model):
    """
    LSTM text classifier: an `LSTMEncoder` followed by a tanh `Linear` layer
    and a softmax `Linear` layer with `class_dim` (3) outputs.

    Args:
        dict_dim: vocabulary size; the encoder is built with `dict_dim + 1`
            entries.
        seq_len: sequence length handed to the encoder.
    """

    def __init__(self, dict_dim, seq_len):
        super(LSTM, self).__init__()
        # These assignments previously ended with a trailing comma, which
        # made every attribute a 1-tuple (e.g. `(128,)`) and passed tuples
        # as layer dimensions below. They are plain scalars now.
        self.seq_len = seq_len
        self.dict_dim = dict_dim
        self.emb_dim = 128
        self.hid_dim = 128
        self.fc_hid_dim = 96
        self.class_dim = 3
        self.emb_lr = 30.0
        self._encoder = LSTMEncoder(
            dict_size=self.dict_dim + 1,
            emb_dim=self.emb_dim,
            lstm_dim=self.hid_dim,
            hidden_dim=self.hid_dim,
            seq_len=self.seq_len,
            padding_idx=None,
            is_reverse=False)

        self._fc1 = Linear(
            input_dim=self.hid_dim, output_dim=self.fc_hid_dim, act="tanh")
        self._fc_prediction = Linear(
            input_dim=self.fc_hid_dim,
            output_dim=self.class_dim,
            act="softmax")

    def forward(self, inputs):
        """Return the prediction of the output layer for `inputs`."""
        emb = self._encoder(inputs)
        fc_1 = self._fc1(emb)
        prediction = self._fc_prediction(fc_1)
        return prediction
def main():
    """
    Train, evaluate or run inference for the emotion detection task.

    All settings are read from `./config.yaml`. The first of `do_train`,
    `do_val` and `do_infer` that is set decides what is run; when training
    with `do_val` enabled, the 'dev' split is evaluated during training.

    Raises:
        ValueError: if none of `do_train`, `do_val`, `do_infer` is set, if
            `model_type` is unknown, or if validation/inference is requested
            without `init_checkpoint`.
    """
    args = Config(yaml_file='./config.yaml')
    args.build()
    args.Print()
    if not (args.do_train or args.do_val or args.do_infer):
        raise ValueError("For args `do_train`, `do_val` and `do_infer`, at "
                         "least one of them must be True.")

    place = set_device("gpu" if args.use_cuda else "cpu")
    fluid.enable_dygraph(place)

    processor = EmoTectProcessor(
        data_dir=args.data_dir,
        vocab_path=args.vocab_path,
        random_seed=args.random_seed)

    if args.model_type == 'cnn_net':
        model = CNN(args.vocab_size, args.max_seq_len)
    elif args.model_type == 'bow_net':
        model = BOW(args.vocab_size, args.max_seq_len)
    elif args.model_type == 'lstm_net':
        model = LSTM(args.vocab_size, args.max_seq_len)
    elif args.model_type == 'gru_net':
        model = GRU(args.vocab_size, args.max_seq_len)
    elif args.model_type == 'bigru_net':
        model = BiGRU(args.vocab_size, args.batch_size, args.max_seq_len)
    else:
        raise ValueError("Unknown model type!")

    inputs = [Input([None, args.max_seq_len], 'int64', name='doc')]
    optimizer = None
    labels = None

    if args.do_train:
        train_data_generator = processor.data_generator(
            batch_size=args.batch_size,
            places=place,
            phase='train',
            epoch=args.epoch,
            padding_size=args.max_seq_len)

        num_train_examples = processor.get_num_examples(phase="train")
        max_train_steps = (
            args.epoch * num_train_examples // args.batch_size + 1)

        print("Num train examples: %d" % num_train_examples)
        print("Max train steps: %d" % max_train_steps)

        labels = [Input([None, 1], 'int64', name='label')]
        optimizer = fluid.optimizer.Adagrad(
            learning_rate=args.lr, parameter_list=model.parameters())
        test_data_generator = None
        if args.do_val:
            # While training, validation runs on the 'dev' split.
            test_data_generator = processor.data_generator(
                batch_size=args.batch_size,
                phase='dev',
                epoch=1,
                places=place,
                padding_size=args.max_seq_len)

    elif args.do_val:
        # Evaluation without training uses the 'test' split.
        test_data_generator = processor.data_generator(
            batch_size=args.batch_size,
            phase='test',
            epoch=1,
            places=place,
            padding_size=args.max_seq_len)

    elif args.do_infer:
        infer_data_generator = processor.data_generator(
            batch_size=args.batch_size,
            phase='infer',
            epoch=1,
            places=place,
            padding_size=args.max_seq_len)

    model.prepare(
        optimizer,
        CrossEntropy(),
        Accuracy(topk=(1, )),
        inputs,
        labels,
        device=place)

    if args.do_train:
        if args.init_checkpoint:
            model.load(args.init_checkpoint)
    elif args.do_val or args.do_infer:
        if not args.init_checkpoint:
            raise ValueError("args 'init_checkpoint' should be set if "
                             "only doing validation or infer!")
        model.load(args.init_checkpoint, reset_optimizer=True)

    if args.do_train:
        model.fit(train_data=train_data_generator,
                  eval_data=test_data_generator,
                  batch_size=args.batch_size,
                  epochs=args.epoch,
                  save_dir=args.checkpoints,
                  eval_freq=args.eval_freq,
                  save_freq=args.save_freq)
    elif args.do_val:
        eval_result = model.evaluate(
            eval_data=test_data_generator, batch_size=args.batch_size)
        print("Final eval result: acc: {:.4f}, loss: {:.4f}".format(
            eval_result['acc'], eval_result['loss'][0]))

    elif args.do_infer:
        preds = model.predict(test_data=infer_data_generator)
        preds = np.array(preds[0]).reshape((-1, args.num_labels))

        if args.output_dir:
            # The directory named in the config may not exist yet.
            if not os.path.isdir(args.output_dir):
                os.makedirs(args.output_dir)
            pred_path = os.path.join(args.output_dir, 'predictions.json')
            with open(pred_path, 'w') as w:
                for index, probs in enumerate(preds):
                    # np.argmax returns a NumPy integer, which json.dumps
                    # cannot serialize on Python 3; convert it to int first.
                    result = json.dumps({
                        'index': index,
                        'label': int(np.argmax(probs)),
                        'probs': probs.tolist()
                    })
                    w.write(result + '\n')
            print('Predictions saved at ' + pred_path)
class LSTM(Model):
    """
    LSTM sentiment classifier: an `LSTMEncoder` followed by a tanh `Linear`
    layer and a softmax `Linear` layer with `class_dim` (2) outputs.

    Args:
        dict_dim: vocabulary size; the encoder is built with `dict_dim + 1`
            entries.
        seq_len: sequence length handed to the encoder.
    """

    def __init__(self, dict_dim, seq_len):
        super(LSTM, self).__init__()
        # These assignments previously ended with a trailing comma, which
        # made every attribute a 1-tuple (e.g. `(128,)`) and passed tuples
        # as layer dimensions below. They are plain scalars now.
        self.seq_len = seq_len
        self.dict_dim = dict_dim
        self.emb_dim = 128
        self.hid_dim = 128
        self.fc_hid_dim = 96
        self.class_dim = 2
        self.emb_lr = 30.0
        self._encoder = LSTMEncoder(
            dict_size=self.dict_dim + 1,
            emb_dim=self.emb_dim,
            lstm_dim=self.hid_dim,
            hidden_dim=self.hid_dim,
            seq_len=self.seq_len,
            padding_idx=None,
            is_reverse=False)

        self._fc1 = Linear(
            input_dim=self.hid_dim, output_dim=self.fc_hid_dim, act="tanh")
        self._fc_prediction = Linear(
            input_dim=self.fc_hid_dim,
            output_dim=self.class_dim,
            act="softmax")

    def forward(self, inputs):
        """Return the prediction of the output layer for `inputs`."""
        emb = self._encoder(inputs)
        fc_1 = self._fc1(emb)
        prediction = self._fc_prediction(fc_1)
        return prediction
set_device, Model, CrossEntropy, Input +from hapi.model import set_device, CrossEntropy, Input from hapi.configure import Config from hapi.text.senta import SentaProcessor from hapi.metrics import Accuracy -from models import CNN, BOW, GRU, BiGRU +from models import CNN, BOW, GRU, BiGRU, LSTM import json import os @@ -38,6 +38,26 @@ def main(): elif args.do_infer: infer() +def create_model(): + if args.model_type == 'cnn_net': + model = CNN( args.vocab_size, + args.padding_size) + elif args.model_type == 'bow_net': + model = BOW( args.vocab_size, + args.padding_size) + elif args.model_type == 'lstm_net': + model = LSTM( args.vocab_size, + args.padding_size) + elif args.model_type == 'gru_net': + model = GRU( args.vocab_size, + args.padding_size) + elif args.model_type == 'bigru_net': + model = BiGRU( args.vocab_size, args.batch_size, + args.padding_size) + else: + raise ValueError("Unknown model type!") + return model + def train(): fluid.enable_dygraph(device) processor = SentaProcessor( @@ -65,23 +85,13 @@ def train(): phase='dev', epoch=args.epoch, shuffle=False) - if args.model_type == 'cnn_net': - model = CNN( args.vocab_size, args.batch_size, - args.padding_size) - elif args.model_type == 'bow_net': - model = BOW( args.vocab_size, args.batch_size, - args.padding_size) - elif args.model_type == 'gru_net': - model = GRU( args.vocab_size, args.batch_size, - args.padding_size) - elif args.model_type == 'bigru_net': - model = BiGRU( args.vocab_size, args.batch_size, - args.padding_size) - optimizer = fluid.optimizer.Adagrad(learning_rate=args.lr, parameter_list=model.parameters()) inputs = [Input([None, None], 'int64', name='doc')] labels = [Input([None, 1], 'int64', name='label')] + + model = create_model() + optimizer = fluid.optimizer.Adagrad(learning_rate=args.lr, parameter_list=model.parameters()) model.prepare( optimizer, @@ -113,19 +123,8 @@ def infer(): phase='infer', epoch=1, shuffle=False) - if args.model_type == 'cnn_net': - model_infer = CNN( 
args.vocab_size, args.batch_size, - args.padding_size) - elif args.model_type == 'bow_net': - model_infer = BOW( args.vocab_size, args.batch_size, - args.padding_size) - elif args.model_type == 'gru_net': - model_infer = GRU( args.vocab_size, args.batch_size, - args.padding_size) - elif args.model_type == 'bigru_net': - model_infer = BiGRU( args.vocab_size, args.batch_size, - args.padding_size) - + + model_infer = create_model() print('Do inferring ...... ') inputs = [Input([None, None], 'int64', name='doc')] model_infer.prepare( diff --git a/hapi/text/emo_tect/__init__.py b/hapi/text/emo_tect/__init__.py new file mode 100644 index 0000000..a96df76 --- /dev/null +++ b/hapi/text/emo_tect/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from hapi.text.emo_tect.data_processor import EmoTectProcessor diff --git a/hapi/text/emo_tect/data_processor.py b/hapi/text/emo_tect/data_processor.py new file mode 100644 index 0000000..4dafa04 --- /dev/null +++ b/hapi/text/emo_tect/data_processor.py @@ -0,0 +1,79 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
class EmoTectProcessor(object):
    """
    Builds data loaders for the emotion detection task from the
    `<phase>.tsv` files found under `data_dir`.

    Args:
        data_dir: directory containing train.tsv / dev.tsv / test.tsv /
            infer.tsv.
        vocab_path: vocabulary file, loaded with `load_vocab`.
        random_seed: seed applied to both NumPy's and the standard library's
            global random generators (None leaves them unseeded).
    """

    _PHASES = ('train', 'dev', 'test', 'infer')

    def __init__(self, data_dir, vocab_path, random_seed=None):
        # Local import: the reader shuffles with the standard `random`
        # module, so seeding only NumPy did not make shuffling repeatable.
        import random

        self.data_dir = data_dir
        self.vocab = load_vocab(vocab_path)
        # Filled in by `data_reader` once a phase's file has been read.
        self.num_examples = {"train": -1, "dev": -1, "test": -1, "infer": -1}
        # Returned by `get_train_progress`. They were never initialised
        # before, so that method raised AttributeError. Nothing visible
        # here updates them; -1 means "unknown".
        self.current_train_example = -1
        self.current_train_epoch = -1
        np.random.seed(random_seed)
        random.seed(random_seed)

    def _build_loader(self, phase, epoch, batch_size, places, padding_size,
                      shuffle=False):
        """Create the DataLoader that feeds `<data_dir>/<phase>.tsv`."""
        reader = data_reader((self.data_dir + "/" + phase + ".tsv"),
                             self.vocab, self.num_examples, phase, epoch,
                             padding_size, shuffle)
        loader = DataLoader.from_generator(capacity=50, return_list=True)
        loader.set_sample_generator(
            reader, batch_size=batch_size, drop_last=False, places=places)
        return loader

    # NOTE: the `data_dir` argument of the four getters below is accepted
    # for compatibility but not used; `self.data_dir` is always read.

    def get_train_examples(self, data_dir, epoch, shuffle, batch_size, places,
                           padding_size):
        """Return the loader for the 'train' phase."""
        return self._build_loader("train", epoch, batch_size, places,
                                  padding_size, shuffle)

    def get_dev_examples(self, data_dir, epoch, shuffle, batch_size, places,
                         padding_size):
        """Return the loader for the 'dev' phase."""
        return self._build_loader("dev", epoch, batch_size, places,
                                  padding_size, shuffle)

    def get_test_examples(self, data_dir, epoch, batch_size, places,
                          padding_size):
        """Return the loader for the 'test' phase (never shuffled)."""
        return self._build_loader("test", epoch, batch_size, places,
                                  padding_size)

    def get_infer_examples(self, data_dir, epoch, batch_size, places,
                           padding_size):
        """Return the loader for the 'infer' phase (never shuffled)."""
        return self._build_loader("infer", epoch, batch_size, places,
                                  padding_size)

    def get_labels(self):
        """Return the label names as strings."""
        return ["0", "1", "2"]

    def get_num_examples(self, phase):
        """
        Return the number of examples recorded for `phase` (-1 until that
        phase's file has been read).

        Raises:
            ValueError: if `phase` is not a known phase name.
        """
        if phase not in self._PHASES:
            raise ValueError(
                "Unknown phase, which should be in "
                "['train', 'dev', 'test', 'infer'].")
        return self.num_examples[phase]

    def get_train_progress(self):
        """Return `(current_train_example, current_train_epoch)`."""
        return self.current_train_example, self.current_train_epoch

    def data_generator(self, padding_size, batch_size, places, phase='train',
                       epoch=1, shuffle=True):
        """
        Return the data loader for `phase`.

        `shuffle` is only forwarded for the 'train' and 'dev' phases.

        Raises:
            ValueError: if `phase` is not a known phase name.
        """
        if phase == "train":
            return self.get_train_examples(self.data_dir, epoch, shuffle,
                                           batch_size, places, padding_size)
        if phase == "dev":
            return self.get_dev_examples(self.data_dir, epoch, shuffle,
                                         batch_size, places, padding_size)
        if phase == "test":
            return self.get_test_examples(self.data_dir, epoch, batch_size,
                                          places, padding_size)
        if phase == "infer":
            return self.get_infer_examples(self.data_dir, epoch, batch_size,
                                           places, padding_size)
        raise ValueError(
            "Unknown phase, which should be in "
            "['train', 'dev', 'test', 'infer'].")
def pad_wid(wids, max_seq_len=128, pad_id=0):
    """
    Return a copy of `wids` that is exactly `max_seq_len` items long.

    Shorter sequences are right-padded with `pad_id`; longer ones are
    truncated. The input is never modified. Previously a short input list
    was extended in place, while a long one was copied.

    Args:
        wids: sequence of word ids.
        max_seq_len: length of the returned list.
        pad_id: id appended to short sequences.

    Returns:
        A new list of length `max_seq_len`.
    """
    padded = list(wids[:max_seq_len])
    # A non-positive repeat count yields an empty list, so sequences that
    # are already long enough are left as truncated.
    padded.extend([pad_id] * (max_seq_len - len(padded)))
    return padded
def load_vocab(file_path):
    """
    Read a vocabulary file with one token per line.

    Each distinct stripped line gets the next consecutive id, starting at 0,
    in order of first appearance. After the file is read, the empty-string
    key is set to the current vocabulary size.
    """
    vocab = {}
    next_id = 0
    with io.open(file_path, 'r', encoding='utf8') as fin:
        for raw_line in fin:
            token = raw_line.strip()
            if token in vocab:
                # Repeated tokens keep the id of their first occurrence.
                continue
            vocab[token] = next_id
            next_id += 1
    # NOTE(review): the key is the empty string as written; confirm this is
    # the intended key for the extra trailing entry.
    vocab[""] = len(vocab)
    return vocab