From a000c6ff836ae76725825f0dd1aa717dc6a792c5 Mon Sep 17 00:00:00 2001
From: zhangxuefei
Date: Mon, 14 Oct 2019 19:41:07 +0800
Subject: [PATCH] add dataset tnews

---
 paddlehub/dataset/__init__.py |   1 +
 paddlehub/dataset/tnews.py    | 136 ++++++++++++++++++++++++++++++++++
 paddlehub/dataset/toxic.py    |   4 +-
 3 files changed, 139 insertions(+), 2 deletions(-)
 create mode 100644 paddlehub/dataset/tnews.py

diff --git a/paddlehub/dataset/__init__.py b/paddlehub/dataset/__init__.py
index 80156414..e507dde5 100644
--- a/paddlehub/dataset/__init__.py
+++ b/paddlehub/dataset/__init__.py
@@ -23,6 +23,7 @@ from .toxic import Toxic
 from .squad import SQUAD
 from .xnli import XNLI
 from .glue import GLUE
+from .tnews import TNews
 
 # CV Dataset
 from .dogcat import DogCatDataset as DogCat
diff --git a/paddlehub/dataset/tnews.py b/paddlehub/dataset/tnews.py
new file mode 100644
index 00000000..c36d6a12
--- /dev/null
+++ b/paddlehub/dataset/tnews.py
@@ -0,0 +1,136 @@
+# coding:utf-8
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from collections import namedtuple
+import io
+import os
+import csv
+
+from paddlehub.dataset import InputExample, HubDataset
+from paddlehub.common.downloader import default_downloader
+from paddlehub.common.dir import DATA_HOME
+from paddlehub.common.logger import logger
+
+_DATA_URL = "https://bj.bcebos.com/paddlehub-dataset/tnews.tar.gz"
+
+
+class TNews(HubDataset):
+    """
+    TNews is a Chinese news classification dataset from the Jinri Toutiao
+    (Toutiao) app.
+
+    The constructor downloads the archive into DATA_HOME when it is not
+    cached yet and then reads the train, test and dev splits into memory.
+    """
+
+    def __init__(self):
+        self.dataset_dir = os.path.join(DATA_HOME, "tnews")
+        if not os.path.exists(self.dataset_dir):
+            # NOTE(review): ret and tips are never inspected -- presumably a
+            # success flag and a message; confirm and fail early on errors.
+            ret, tips, self.dataset_dir = default_downloader.download_file_and_uncompress(
+                url=_DATA_URL, save_path=DATA_HOME, print_progress=True)
+        else:
+            logger.info("Dataset {} already cached.".format(self.dataset_dir))
+
+        self._load_train_examples()
+        self._load_test_examples()
+        self._load_dev_examples()
+
+    def _load_train_examples(self):
+        """Read the training split into self.train_examples."""
+        self.train_file = os.path.join(self.dataset_dir,
+                                       "toutiao_category_train.txt")
+        self.train_examples = self._read_file(self.train_file)
+
+    def _load_dev_examples(self):
+        """Read the development split into self.dev_examples."""
+        self.dev_file = os.path.join(self.dataset_dir,
+                                     "toutiao_category_dev.txt")
+        self.dev_examples = self._read_file(self.dev_file)
+
+    def _load_test_examples(self):
+        """Read the test split into self.test_examples."""
+        self.test_file = os.path.join(self.dataset_dir,
+                                      "toutiao_category_test.txt")
+        self.test_examples = self._read_file(self.test_file)
+
+    def get_train_examples(self):
+        """Return the examples of the training split."""
+        return self.train_examples
+
+    def get_dev_examples(self):
+        """Return the examples of the development split."""
+        return self.dev_examples
+
+    def get_test_examples(self):
+        """Return the examples of the test split."""
+        return self.test_examples
+
+    def get_labels(self):
+        """Return the category names used as labels."""
+        return [
+            'news_game', 'news_sports', 'news_finance', 'news_entertainment',
+            'news_tech', 'news_house', 'news_car', 'news_culture', 'news_world',
+            'news_travel', 'news_agriculture', 'news_military', 'news_edu',
+            'news_story', 'stock'
+        ]
+
+    @property
+    def num_labels(self):
+        """
+        Return the number of labels in the dataset.
+        """
+        return len(self.get_labels())
+
+    def _read_file(self, input_file):
+        """
+        Read one TNews split file and return its examples.
+
+        Every non-blank line is split on "_!_"; field 0 becomes the guid,
+        field 2 the label and field 3 the text (text_a).
+
+        Args:
+            input_file (str): path of the UTF-8 encoded split file.
+
+        Returns:
+            list: one InputExample per non-blank line, in file order.
+
+        Raises:
+            IndexError: if a non-blank line has fewer than four fields.
+        """
+        examples = []
+        with io.open(input_file, "r", encoding="UTF-8") as reader:
+            for line in reader:
+                line = line.strip()
+                if not line:
+                    # A blank line carries no example; skip it instead of
+                    # raising IndexError on the missing fields.
+                    continue
+                data = line.split("_!_")
+                examples.append(
+                    InputExample(guid=data[0], label=data[2], text_a=data[3]))
+
+        return examples
+
+
+if __name__ == "__main__":
+    ds = TNews()
+    for e in ds.get_train_examples()[:10]:
+        print("{}\t{}\t{}\t{}".format(e.guid, e.text_a, e.text_b, e.label))
diff --git a/paddlehub/dataset/toxic.py b/paddlehub/dataset/toxic.py
index 1b85ec00..472e3211 100644
--- a/paddlehub/dataset/toxic.py
+++ b/paddlehub/dataset/toxic.py
@@ -33,8 +33,8 @@ _DATA_URL = "https://bj.bcebos.com/paddlehub-dataset/toxic.tar.gz"
 
 class Toxic(HubDataset):
     """
-    ChnSentiCorp (by Tan Songbo at ICT of Chinese Academy of Sciences, and for
-    opinion mining)
+    The Kaggle Toxic dataset:
+    https://www.kaggle.com/c/jigsaw-toxic-comment-classification-challenge
     """
 
     def __init__(self):
-- 
GitLab