提交 b7f6f1f1 编写于 作者: W wuzewu

add text classification reader

上级 3298af5c
......@@ -14,4 +14,5 @@
from .nlp_reader import ClassifyReader
from .nlp_reader import SequenceLabelReader
from .nlp_reader import TextClassificationReader
from .cv_reader import ImageClassificationReader
......@@ -20,11 +20,13 @@ import csv
import json
from collections import namedtuple
import paddle
import numpy as np
from paddlehub.reader import tokenization
from paddlehub.common.logger import logger
from .batching import pad_batch_data
import paddlehub as hub
class BaseReader(object):
......@@ -381,5 +383,55 @@ class ExtractEmbeddingReader(BaseReader):
return return_list
class TextClassificationReader(object):
    """Batch reader that converts raw text into lists of vocabulary ids.

    Text is segmented into words with the PaddleHub ``lac`` module; every
    word that is present in the vocabulary is replaced by its id, and words
    missing from the vocabulary are dropped.

    Args:
        dataset: Object providing ``get_train_examples()``,
            ``get_test_examples()`` and ``get_dev_examples()``. Each example
            is read through its ``text_a`` and ``label`` attributes.
        vocab_path: Path of the vocabulary file handed to
            ``tokenization.FullTokenizer``.
        do_lower_case: Forwarded unchanged to ``tokenization.FullTokenizer``.
    """

    def __init__(self, dataset, vocab_path, do_lower_case=False):
        self.dataset = dataset
        self.tokenizer = tokenization.FullTokenizer(
            vocab_file=vocab_path, do_lower_case=do_lower_case)
        # Only the tokenizer's vocabulary mapping is used by this reader.
        self.vocab = self.tokenizer.vocab
        self.lac = hub.Module(name="lac")
        # Name of the first input field declared by the module for the
        # "lexical_analysis" signature; used as the key of the feed dict.
        self.feed_key = list(
            self.lac.processor.data_format(
                sign_name="lexical_analysis").keys())[0]

    def data_generator(self,
                       batch_size=1,
                       phase="train",
                       shuffle=False,
                       data=None):
        """Build a batched reader for the requested phase.

        Args:
            batch_size: Batch size passed to ``paddle.batch``.
            phase: ``"train"``, ``"test"``, ``"val"``/``"dev"`` or
                ``"predict"``. The first three pull examples from the
                dataset and ignore ``data``; ``"predict"`` (and any other
                value) uses ``data`` as given.
            shuffle: When true, the examples are reshuffled at the start of
                every pass over the reader. Always disabled for the test,
                dev and predict phases so the output order matches the
                input order.
            data: Iterable of raw texts, used for the ``"predict"`` phase.

        Returns:
            The reader produced by ``paddle.batch``. Samples are
            ``(word_ids,)`` in the predict phase and ``(word_ids, label)``
            otherwise.
        """
        if phase == "train":
            data = self.dataset.get_train_examples()
        elif phase == "test":
            shuffle = False
            data = self.dataset.get_test_examples()
        elif phase == "val" or phase == "dev":
            shuffle = False
            data = self.dataset.get_dev_examples()
        elif phase == "predict":
            # Predictions must stay aligned with the caller's inputs.
            shuffle = False

        def preprocess(text):
            """Segment ``text`` with LAC and map in-vocabulary words to ids."""
            data_dict = {self.feed_key: [text]}
            processed = self.lac.lexical_analysis(data=data_dict)
            return [
                self.vocab[word] for word in processed[0]['word']
                if word in self.vocab
            ]

        def _data_reader():
            examples = data
            if shuffle:
                # Shuffle a copy so the dataset's own example list keeps
                # its original order.
                examples = list(data)
                np.random.shuffle(examples)

            if phase == "predict":
                for text in examples:
                    yield (preprocess(text), )
            else:
                for item in examples:
                    yield (preprocess(item.text_a), item.label)

        return paddle.batch(_data_reader, batch_size=batch_size)
if __name__ == "__main__":
    # This module has no command-line behaviour; running it directly is a
    # no-op.
    pass
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册