提交 a000c6ff 编写于 作者: Z zhangxuefei 提交者: wuzewu

add dataset tnews

上级 bb1a88bd
...@@ -23,6 +23,7 @@ from .toxic import Toxic ...@@ -23,6 +23,7 @@ from .toxic import Toxic
from .squad import SQUAD from .squad import SQUAD
from .xnli import XNLI from .xnli import XNLI
from .glue import GLUE from .glue import GLUE
from .tnews import TNews
# CV Dataset # CV Dataset
from .dogcat import DogCatDataset as DogCat from .dogcat import DogCatDataset as DogCat
......
# coding:utf-8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from collections import namedtuple
import io
import os
import csv
from paddlehub.dataset import InputExample, HubDataset
from paddlehub.common.downloader import default_downloader
from paddlehub.common.dir import DATA_HOME
from paddlehub.common.logger import logger
_DATA_URL = "https://bj.bcebos.com/paddlehub-dataset/tnews.tar.gz"
class TNews(HubDataset):
"""
TNews is the chinese news classification dataset on JinRiTouDiao App.
"""
def __init__(self):
self.dataset_dir = os.path.join(DATA_HOME, "tnews")
if not os.path.exists(self.dataset_dir):
ret, tips, self.dataset_dir = default_downloader.download_file_and_uncompress(
url=_DATA_URL, save_path=DATA_HOME, print_progress=True)
else:
logger.info("Dataset {} already cached.".format(self.dataset_dir))
self._load_train_examples()
self._load_test_examples()
self._load_dev_examples()
def _load_train_examples(self):
self.train_file = os.path.join(self.dataset_dir,
"toutiao_category_train.txt")
self.train_examples = self._read_file(self.train_file)
def _load_dev_examples(self):
self.dev_file = os.path.join(self.dataset_dir,
"toutiao_category_dev.txt")
self.dev_examples = self._read_file(self.dev_file)
def _load_test_examples(self):
self.test_file = os.path.join(self.dataset_dir,
"toutiao_category_test.txt")
self.test_examples = self._read_file(self.test_file)
def get_train_examples(self):
return self.train_examples
def get_dev_examples(self):
return self.dev_examples
def get_test_examples(self):
return self.test_examples
def get_labels(self):
return [
'news_game', 'news_sports', 'news_finance', 'news_entertainment',
'news_tech', 'news_house', 'news_car', 'news_culture', 'news_world',
'news_travel', 'news_agriculture', 'news_military', 'news_edu',
'news_story', 'stock'
]
@property
def num_labels(self):
"""
Return the number of labels in the dataset.
"""
return len(self.get_labels())
def _read_file(self, input_file):
"""Reads a tab separated value file."""
with io.open(input_file, "r", encoding="UTF-8") as file:
examples = []
for line in file:
data = line.strip().split("_!_")
example = InputExample(
guid=data[0], label=data[2], text_a=data[3])
examples.append(example)
return examples
if __name__ == "__main__":
ds = TNews()
for e in ds.get_train_examples()[:10]:
print("{}\t{}\t{}\t{}".format(e.guid, e.text_a, e.text_b, e.label))
...@@ -33,8 +33,8 @@ _DATA_URL = "https://bj.bcebos.com/paddlehub-dataset/toxic.tar.gz" ...@@ -33,8 +33,8 @@ _DATA_URL = "https://bj.bcebos.com/paddlehub-dataset/toxic.tar.gz"
class Toxic(HubDataset): class Toxic(HubDataset):
""" """
ChnSentiCorp (by Tan Songbo at ICT of Chinese Academy of Sciences, and for The kaggle Toxic dataset:
opinion mining) https://www.kaggle.com/c/jigsaw-toxic-comment-classification-challenge
""" """
def __init__(self): def __init__(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册