class TextClassificationReader(object):
    """Reader that turns a text-classification dataset into batched
    (token-id list, label) pairs.

    Raw text is segmented with the PaddleHub "lac" lexical-analysis module,
    then each segmented word that appears in the tokenizer vocabulary is
    mapped to its id (out-of-vocabulary words are silently dropped).
    """

    def __init__(self, dataset, vocab_path, do_lower_case=False):
        """
        Args:
            dataset: a dataset object exposing get_train_examples(),
                get_test_examples() and get_dev_examples().
            vocab_path (str): path to the vocabulary file consumed by
                tokenization.FullTokenizer.
            do_lower_case (bool): lower-case input before tokenizing.
        """
        self.dataset = dataset
        self.tokenizer = tokenization.FullTokenizer(
            vocab_file=vocab_path, do_lower_case=do_lower_case)
        self.vocab = self.tokenizer.vocab
        # LAC performs the word segmentation; its data_format() for the
        # "lexical_analysis" signature names the input slot we must feed.
        self.lac = hub.Module(name="lac")
        self.feed_key = list(
            self.lac.processor.data_format(
                sign_name="lexical_analysis").keys())[0]

    def data_generator(self,
                       batch_size=1,
                       phase="train",
                       shuffle=False,
                       data=None):
        """Build a batched reader over the requested dataset split.

        Args:
            batch_size (int): number of examples per yielded batch.
            phase (str): one of "train", "test", "val"/"dev", "predict".
            shuffle (bool): if True, shuffle examples on each pass.
                Forced off for test/dev/predict phases.
            data (list of str): raw texts; required when phase == "predict".

        Returns:
            A callable batched reader, as produced by paddle.batch.

        Raises:
            ValueError: if phase is unknown, or phase == "predict" and
                no data was supplied.
        """
        # Needed for the shuffle fix below; local import keeps the patch
        # self-contained.
        import random

        if phase == "train":
            data = self.dataset.get_train_examples()
        elif phase == "test":
            shuffle = False
            data = self.dataset.get_test_examples()
        elif phase == "val" or phase == "dev":
            shuffle = False
            data = self.dataset.get_dev_examples()
        elif phase == "predict":
            # BUG FIX: the original had a no-op `data = data` here and let a
            # missing data argument crash later with an opaque TypeError.
            if data is None:
                raise ValueError(
                    "phase 'predict' requires the 'data' argument")
            shuffle = False
        else:
            # BUG FIX: the original silently fell through on an unknown
            # phase, leaving data=None and failing obscurely downstream.
            raise ValueError("Unknown phase %r" % phase)

        def preprocess(text):
            # Segment with LAC, then keep only in-vocab words mapped to ids
            # (OOV words are dropped — original behavior, preserved).
            data_dict = {self.feed_key: [text]}
            processed = self.lac.lexical_analysis(data=data_dict)
            return [
                self.vocab[word] for word in processed[0]['word']
                if word in self.vocab
            ]

        def _data_reader():
            # BUG FIX: the original accepted (and even reset) `shuffle` but
            # never shuffled anything. Copy first so the caller's list and
            # the dataset's cache are never mutated.
            examples = list(data)
            if shuffle:
                random.shuffle(examples)
            if phase == "predict":
                for text in examples:
                    yield (preprocess(text), )
            else:
                for example in examples:
                    yield (preprocess(example.text_a), example.label)

        return paddle.batch(_data_reader, batch_size=batch_size)