# Reconstructed from a line-squashed `git diff` of paddlehub/reader/nlp_reader.py:
# the TextClassificationReader added by the patch, with the imports it requires.

import random

import paddle

import paddlehub as hub
from paddlehub.reader import tokenization


class TextClassificationReader(object):
    """Reader that turns raw text examples into word-id sequences.

    Text is segmented with the pretrained "lac" lexical-analysis module,
    then each segmented word is mapped through the vocabulary loaded from
    ``vocab_path``; out-of-vocabulary words are silently dropped.
    """

    def __init__(self, dataset, vocab_path, do_lower_case=False):
        # dataset: must provide get_train_examples()/get_dev_examples()/
        #   get_test_examples(); each example exposes .text_a and .label
        #   (predict mode instead consumes raw strings) — TODO confirm
        #   against the dataset classes, which are not visible here.
        self.dataset = dataset
        self.tokenizer = tokenization.FullTokenizer(
            vocab_file=vocab_path, do_lower_case=do_lower_case)
        self.vocab = self.tokenizer.vocab
        # Pretrained lexical-analysis (word segmentation) module.
        self.lac = hub.Module(name="lac")
        # Name of the lac module's single input slot, discovered from its
        # "lexical_analysis" signature rather than hard-coded.
        self.feed_key = list(
            self.lac.processor.data_format(
                sign_name="lexical_analysis").keys())[0]

    def data_generator(self,
                       batch_size=1,
                       phase="train",
                       shuffle=False,
                       data=None):
        """Return a batched reader over the requested dataset split.

        Args:
            batch_size: number of samples per yielded batch.
            phase: one of "train", "test", "val"/"dev", "predict".
            shuffle: shuffle examples once before batching. Honored for
                "train" and "predict"; forced off for eval/test splits,
                matching the original intent.
            data: iterable of raw strings; required when phase=="predict".

        Returns:
            A ``paddle.batch`` reader yielding lists of
            ``(token_ids, label)`` tuples, or ``(token_ids,)`` tuples in
            predict mode.

        Raises:
            ValueError: if ``phase`` is unknown, or ``phase=="predict"``
                and no ``data`` was supplied.
        """
        if phase == "train":
            data = self.dataset.get_train_examples()
        elif phase == "test":
            shuffle = False
            data = self.dataset.get_test_examples()
        elif phase in ("val", "dev"):
            shuffle = False
            data = self.dataset.get_dev_examples()
        elif phase == "predict":
            # Fix: fail loudly here instead of crashing later with an
            # opaque TypeError when iterating None.
            if data is None:
                raise ValueError(
                    "phase 'predict' requires the `data` argument")
        else:
            # Fix: the original fell through with data=None on a typo'd
            # phase and crashed inside the reader.
            raise ValueError("Unknown phase: %r" % phase)

        if shuffle:
            # Fix: the original accepted `shuffle` but never applied it,
            # so shuffle=True was a silent no-op. Copy first so a
            # caller-supplied list is not mutated in place.
            data = list(data)
            random.shuffle(data)

        def preprocess(text):
            # Segment with lac, then keep only in-vocabulary word ids;
            # OOV words are dropped rather than mapped to an UNK id.
            processed = self.lac.lexical_analysis(
                data={self.feed_key: [text]})
            return [
                self.vocab[word] for word in processed[0]['word']
                if word in self.vocab
            ]

        def _data_reader():
            if phase == "predict":
                for text in data:
                    yield (preprocess(text), )
            else:
                for example in data:
                    yield (preprocess(example.text_a), example.label)

        return paddle.batch(_data_reader, batch_size=batch_size)