提交 b7f6f1f1 编写于 作者: W wuzewu

add text classification reader

上级 3298af5c
...@@ -14,4 +14,5 @@ ...@@ -14,4 +14,5 @@
from .nlp_reader import ClassifyReader from .nlp_reader import ClassifyReader
from .nlp_reader import SequenceLabelReader from .nlp_reader import SequenceLabelReader
from .nlp_reader import TextClassificationReader
from .cv_reader import ImageClassificationReader from .cv_reader import ImageClassificationReader
...@@ -20,11 +20,13 @@ import csv ...@@ -20,11 +20,13 @@ import csv
import json import json
from collections import namedtuple from collections import namedtuple
import paddle
import numpy as np import numpy as np
from paddlehub.reader import tokenization from paddlehub.reader import tokenization
from paddlehub.common.logger import logger from paddlehub.common.logger import logger
from .batching import pad_batch_data from .batching import pad_batch_data
import paddlehub as hub
class BaseReader(object): class BaseReader(object):
...@@ -381,5 +383,55 @@ class ExtractEmbeddingReader(BaseReader): ...@@ -381,5 +383,55 @@ class ExtractEmbeddingReader(BaseReader):
return return_list return return_list
class TextClassificationReader(object):
def __init__(self, dataset, vocab_path, do_lower_case=False):
self.dataset = dataset
self.tokenizer = tokenization.FullTokenizer(
vocab_file=vocab_path, do_lower_case=do_lower_case)
self.vocab = self.tokenizer.vocab
self.lac = hub.Module(name="lac")
self.feed_key = list(
self.lac.processor.data_format(
sign_name="lexical_analysis").keys())[0]
def data_generator(self,
batch_size=1,
phase="train",
shuffle=False,
data=None):
if phase == "train":
data = self.dataset.get_train_examples()
elif phase == "test":
shuffle = False
data = self.dataset.get_test_examples()
elif phase == "val" or phase == "dev":
shuffle = False
data = self.dataset.get_dev_examples()
elif phase == "predict":
data = data
def preprocess(text):
data_dict = {self.feed_key: [text]}
processed = self.lac.lexical_analysis(data=data_dict)
processed = [
self.vocab[word] for word in processed[0]['word']
if word in self.vocab
]
return processed
def _data_reader():
if phase == "predict":
for text in data:
text = preprocess(text)
yield (text, )
else:
for item in data:
text = preprocess(item.text_a)
yield (text, item.label)
return paddle.batch(_data_reader, batch_size=batch_size)
if __name__ == '__main__': if __name__ == '__main__':
pass pass
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册