Commit 07e21b51 authored by kinghuin, committed by wuzewu

reconstitute reader and dataset (#279)

* reconstitute reader and dataset
Parent 537d3c58
@@ -20,20 +20,7 @@ from __future__ import print_function
 
 import argparse
 import ast
-import collections
-import json
-import io
-import math
-import numpy as np
-import os
-import six
-import sys
-import time
-
-import paddle
-import paddle.fluid as fluid
 import paddlehub as hub
-from paddlehub.finetune.task.reading_comprehension_task import write_predictions
 
 hub.common.logger.logger.setLevel("INFO")
...
@@ -14,7 +14,7 @@
 # limitations under the License.
 
 # NLP Dataset
-from .dataset import InputExample, HubDataset
+from .dataset import InputExample, BaseDataset
 from .chnsenticorp import ChnSentiCorp
 from .msra_ner import MSRA_NER
 from .nlpcc_dbqa import NLPCC_DBQA
...
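
The import rename above is the visible half of the refactor; dataset.py (further down in this diff) keeps the old name alive with a HubDataset = BaseDataset alias, so code written against the old class name still imports, just from the module rather than the package. A two-line sketch:

    from paddlehub.dataset.dataset import BaseDataset, HubDataset

    assert HubDataset is BaseDataset  # the old name is now just an alias
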
@@ -18,15 +18,61 @@ from __future__ import division
 from __future__ import print_function
 
 import os
 import numpy as np
 
+from paddlehub.dataset import BaseDataset
 import paddlehub as hub
 from paddlehub.common.downloader import default_downloader
+from paddlehub.common.logger import logger
+
+
+class BaseCVDatast(BaseDataset):
+    def __init__(self,
+                 base_path,
+                 train_list_file=None,
+                 validate_list_file=None,
+                 test_list_file=None,
+                 predict_list_file=None,
+                 label_list_file=None,
+                 label_list=None):
+        super(BaseCVDatast, self).__init__(
+            base_path=base_path,
+            train_file=train_list_file,
+            dev_file=validate_list_file,
+            test_file=test_list_file,
+            predict_file=predict_list_file,
+            label_file=label_list_file,
+            label_list=label_list)
+
+    def _read_file(self, data_path, phase=None):
+        data = []
+        with open(data_path, "r") as file:
+            while True:
+                line = file.readline()
+                if not line:
+                    break
+                line = line.strip()
+                items = line.split(" ")
+                if len(items) > 2:
+                    image_path = " ".join(items[0:-1])
+                else:
+                    image_path = items[0]
+                if not os.path.isabs(image_path):
+                    if self.base_path is not None:
+                        image_path = os.path.join(self.base_path, image_path)
+                label = items[-1]
+                data.append((image_path, label))
+        return data
 
+
+# Deprecated. Please use BaseCVDatast instead.
 class ImageClassificationDataset(object):
     def __init__(self):
+        logger.warning(
+            "ImageClassificationDataset is deprecated since PaddleHub v1.5.0; "
+            "please use BaseCVDatast instead. It is easier to use, offers "
+            "more functionality, and supports evaluating the test set "
+            "automatically at the end of finetune.")
         self.base_path = None
         self.train_list_file = None
         self.test_list_file = None
@@ -99,12 +145,12 @@ class ImageClassificationDataset(object):
 
     def test_data(self, shuffle=False):
         test_data_path = os.path.join(self.base_path, self.test_list_file)
-        return self._parse_data(test_data_path, shuffle, phase='dev')
+        return self._parse_data(test_data_path, shuffle, phase='test')
 
     def validate_data(self, shuffle=False):
         validate_data_path = os.path.join(self.base_path,
                                           self.validate_list_file)
-        return self._parse_data(validate_data_path, shuffle, phase='test')
+        return self._parse_data(validate_data_path, shuffle, phase='dev')
 
     def get_train_examples(self):
         return self.train_examples
...
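
With BaseCVDatast in place, an image-classification dataset shrinks to a declaration of its list files. A minimal sketch under assumed inputs (MyImageDataset and /path/to/my_dataset are illustrative, not part of this commit); each list file holds one "<image_path> <label>" pair per line, and _read_file resolves relative image paths against base_path:

    from paddlehub.dataset.base_cv_dataset import BaseCVDatast


    class MyImageDataset(BaseCVDatast):  # hypothetical subclass
        def __init__(self):
            super(MyImageDataset, self).__init__(
                base_path="/path/to/my_dataset",  # assumed local directory
                train_list_file="train_list.txt",
                validate_list_file="validate_list.txt",
                test_list_file="test_list.txt",
                label_list_file="label_list.txt")
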
+# coding:utf-8
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import io
+import csv
+
+from paddlehub.dataset import InputExample, BaseDataset
+
+
+class BaseNLPDatast(BaseDataset):
+    def __init__(self,
+                 base_path,
+                 train_file=None,
+                 dev_file=None,
+                 test_file=None,
+                 predict_file=None,
+                 label_file=None,
+                 label_list=None,
+                 train_file_with_head=False,
+                 dev_file_with_head=False,
+                 test_file_with_head=False,
+                 predict_file_with_head=False):
+        super(BaseNLPDatast, self).__init__(
+            base_path=base_path,
+            train_file=train_file,
+            dev_file=dev_file,
+            test_file=test_file,
+            predict_file=predict_file,
+            label_file=label_file,
+            label_list=label_list,
+            train_file_with_head=train_file_with_head,
+            dev_file_with_head=dev_file_with_head,
+            test_file_with_head=test_file_with_head,
+            predict_file_with_head=predict_file_with_head)
+
+    def _read_file(self, input_file, phase=None):
+        """Reads a tab separated value file."""
+        with io.open(input_file, "r", encoding="UTF-8") as file:
+            reader = csv.reader(file, delimiter="\t", quotechar=None)
+            examples = []
+            for (i, line) in enumerate(reader):
+                if i == 0:
+                    ncol = len(line)
+                    if self.if_file_with_head[phase]:
+                        continue
+                if ncol == 1:
+                    if phase == "predict":
+                        example = InputExample(guid=i, text_a=line[0])
+                    else:
+                        raise Exception(
+                            "the %s file: %s only has one column but it is not a predict file"
+                            % (phase, input_file))
+                elif ncol == 2:
+                    example = InputExample(
+                        guid=i, text_a=line[0], label=line[1])
+                elif ncol == 3:
+                    example = InputExample(
+                        guid=i, text_a=line[0], text_b=line[1], label=line[2])
+                else:
+                    raise Exception(
+                        "the %s file: %s has too many columns (should <= 3)" %
+                        (phase, input_file))
+                examples.append(example)
+            return examples
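
BaseNLPDatast plays the same role for text: its _read_file already understands one-, two- and three-column TSV files (predict files are single-column), so a new dataset only names its files and labels. A minimal sketch with assumed paths and file names (MyTextDataset is illustrative):

    import os

    from paddlehub.common.dir import DATA_HOME
    from paddlehub.dataset.base_nlp_dataset import BaseNLPDatast


    class MyTextDataset(BaseNLPDatast):  # hypothetical subclass
        def __init__(self):
            super(MyTextDataset, self).__init__(
                base_path=os.path.join(DATA_HOME, "my_text_data"),  # assumed
                train_file="train.tsv",  # columns: text_a<TAB>label
                dev_file="dev.tsv",
                test_file="test.tsv",
                predict_file="predict.tsv",  # single column: raw text
                label_list=["0", "1"],
                train_file_with_head=True,  # first row is a header, skip it
                dev_file_with_head=True,
                test_file_with_head=True)
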
@@ -17,76 +17,37 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-from collections import namedtuple
-import io
 import os
-import csv
 
-from paddlehub.dataset import InputExample, HubDataset
-from paddlehub.common.downloader import default_downloader
 from paddlehub.common.dir import DATA_HOME
-from paddlehub.common.logger import logger
+from paddlehub.dataset.base_nlp_dataset import BaseNLPDatast
 
-_DATA_URL = "https://bj.bcebos.com/paddlehub-dataset/bq.tar.gz"
 
-class BQ(HubDataset):
+class BQ(BaseNLPDatast):
     def __init__(self):
-        self.dataset_dir = os.path.join(DATA_HOME, "bq")
-        if not os.path.exists(self.dataset_dir):
-            ret, tips, self.dataset_dir = default_downloader.download_file_and_uncompress(
-                url=_DATA_URL, save_path=DATA_HOME, print_progress=True)
-        else:
-            logger.info("Dataset {} already cached.".format(self.dataset_dir))
-
-        self._load_train_examples()
-        self._load_test_examples()
-        self._load_dev_examples()
-
-    def _load_train_examples(self):
-        self.train_file = os.path.join(self.dataset_dir, "train.txt")
-        self.train_examples = self._read_file(self.train_file)
-
-    def _load_dev_examples(self):
-        self.dev_file = os.path.join(self.dataset_dir, "dev.txt")
-        self.dev_examples = self._read_file(self.dev_file)
-
-    def _load_test_examples(self):
-        self.test_file = os.path.join(self.dataset_dir, "test.txt")
-        self.test_examples = self._read_file(self.test_file)
-
-    def get_train_examples(self):
-        return self.train_examples
-
-    def get_dev_examples(self):
-        return self.dev_examples
-
-    def get_test_examples(self):
-        return self.test_examples
-
-    def get_labels(self):
-        return ["0", "1"]
-
-    @property
-    def num_labels(self):
-        """
-        Return the number of labels in the dataset.
-        """
-        return len(self.get_labels())
-
-    def _read_file(self, input_file):
-        """Reads a tab separated value file."""
-        with io.open(input_file, "r", encoding="UTF-8") as file:
-            examples = []
-            for (i, line) in enumerate(file):
-                data = line.strip().split("\t")
-                example = InputExample(
-                    guid=i, label=data[2], text_a=data[0], text_b=data[1])
-                examples.append(example)
-            return examples
+        dataset_dir = os.path.join(DATA_HOME, "bq")
+        base_path = self._download_dataset(
+            dataset_dir,
+            url="https://bj.bcebos.com/paddlehub-dataset/bq.tar.gz")
+        super(BQ, self).__init__(
+            base_path=base_path,
+            train_file="train.txt",
+            dev_file="dev.txt",
+            test_file="test.txt",
+            label_file=None,
+            label_list=["0", "1"],
+        )
 
 
 if __name__ == "__main__":
     ds = BQ()
+    print("first 10 dev")
     for e in ds.get_dev_examples()[:10]:
         print("{}\t{}\t{}\t{}".format(e.guid, e.text_a, e.text_b, e.label))
+    print("first 10 train")
+    for e in ds.get_train_examples()[:10]:
+        print("{}\t{}\t{}\t{}".format(e.guid, e.text_a, e.text_b, e.label))
+    print("first 10 test")
+    for e in ds.get_test_examples()[:10]:
+        print("{}\t{}\t{}\t{}".format(e.guid, e.text_a, e.text_b, e.label))
+    print(ds)
@@ -17,72 +17,39 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-from collections import namedtuple
 import codecs
 import os
 import csv
 
-from paddlehub.dataset import InputExample, HubDataset
-from paddlehub.common.downloader import default_downloader
+from paddlehub.dataset import InputExample
 from paddlehub.common.dir import DATA_HOME
-from paddlehub.common.logger import logger
+from paddlehub.dataset.base_nlp_dataset import BaseNLPDatast
 
-_DATA_URL = "https://bj.bcebos.com/paddlehub-dataset/chnsenticorp.tar.gz"
 
-class ChnSentiCorp(HubDataset):
+class ChnSentiCorp(BaseNLPDatast):
     """
     ChnSentiCorp (by Tan Songbo at ICT of Chinese Academy of Sciences, and for
     opinion mining)
     """
 
     def __init__(self):
-        self.dataset_dir = os.path.join(DATA_HOME, "chnsenticorp")
-        if not os.path.exists(self.dataset_dir):
-            ret, tips, self.dataset_dir = default_downloader.download_file_and_uncompress(
-                url=_DATA_URL, save_path=DATA_HOME, print_progress=True)
-        else:
-            logger.info("Dataset {} already cached.".format(self.dataset_dir))
-
-        self._load_train_examples()
-        self._load_test_examples()
-        self._load_dev_examples()
-
-    def _load_train_examples(self):
-        self.train_file = os.path.join(self.dataset_dir, "train.tsv")
-        self.train_examples = self._read_tsv(self.train_file)
-
-    def _load_dev_examples(self):
-        self.dev_file = os.path.join(self.dataset_dir, "dev.tsv")
-        self.dev_examples = self._read_tsv(self.dev_file)
-
-    def _load_test_examples(self):
-        self.test_file = os.path.join(self.dataset_dir, "test.tsv")
-        self.test_examples = self._read_tsv(self.test_file)
-
-    def get_train_examples(self):
-        return self.train_examples
-
-    def get_dev_examples(self):
-        return self.dev_examples
-
-    def get_test_examples(self):
-        return self.test_examples
-
-    def get_labels(self):
-        return ["0", "1"]
-
-    @property
-    def num_labels(self):
-        """
-        Return the number of labels in the dataset.
-        """
-        return len(self.get_labels())
-
-    def _read_tsv(self, input_file, quotechar=None):
+        dataset_dir = os.path.join(DATA_HOME, "chnsenticorp")
+        base_path = self._download_dataset(
+            dataset_dir,
+            url="https://bj.bcebos.com/paddlehub-dataset/chnsenticorp.tar.gz")
+        super(ChnSentiCorp, self).__init__(
+            base_path=base_path,
+            train_file="train.tsv",
+            dev_file="dev.tsv",
+            test_file="test.tsv",
+            label_file=None,
+            label_list=["0", "1"],
+        )
+
+    def _read_file(self, input_file, phase=None):
         """Reads a tab separated value file."""
         with codecs.open(input_file, "r", encoding="UTF-8") as f:
-            reader = csv.reader(f, delimiter="\t", quotechar=quotechar)
+            reader = csv.reader(f, delimiter="\t", quotechar=None)
             examples = []
             seq_id = 0
             header = next(reader)  # skip header
@@ -97,5 +64,5 @@ class ChnSentiCorp(HubDataset):
 
 if __name__ == "__main__":
     ds = ChnSentiCorp()
-    for e in ds.get_train_examples():
+    for e in ds.get_train_examples()[:10]:
         print("{}\t{}\t{}\t{}".format(e.guid, e.text_a, e.text_b, e.label))
@@ -16,12 +16,11 @@
 import json
 import os
-import sys
 
 from paddlehub.reader import tokenization
-from paddlehub.common.downloader import default_downloader
 from paddlehub.common.dir import DATA_HOME
 from paddlehub.common.logger import logger
+from paddlehub.dataset.base_nlp_dataset import BaseNLPDatast
 
 _DATA_URL = "https://bj.bcebos.com/paddlehub-dataset/cmrc2018.tar.gz"
 SPIECE_UNDERLINE = '▁'
@@ -63,42 +62,22 @@ class CMRC2018Example(object):
         return s
 
 
-class CMRC2018(object):
+class CMRC2018(BaseNLPDatast):
     """A single set of features of data."""
 
     def __init__(self):
-        self.dataset_dir = os.path.join(DATA_HOME, "cmrc2018")
-        if not os.path.exists(self.dataset_dir):
-            ret, tips, self.dataset_dir = default_downloader.download_file_and_uncompress(
-                url=_DATA_URL, save_path=DATA_HOME, print_progress=True)
-        else:
-            logger.info("Dataset {} already cached.".format(self.dataset_dir))
-
-        self._load_train_examples()
-        self._load_dev_examples()
-        self._load_test_examples()
-
-    def _load_train_examples(self):
-        self.train_file = os.path.join(self.dataset_dir, "cmrc2018_train.json")
-        self.train_examples = self._read_json(self.train_file, is_training=True)
-
-    def _load_dev_examples(self):
-        self.dev_file = os.path.join(self.dataset_dir, "cmrc2018_dev.json")
-        self.dev_examples = self._read_json(self.dev_file, is_training=False)
-
-    def _load_test_examples(self):
-        pass
-
-    def get_train_examples(self):
-        return self.train_examples
-
-    def get_dev_examples(self):
-        return self.dev_examples
-
-    def get_test_examples(self):
-        return []
-
-    def _read_json(self, input_file, is_training=False):
+        dataset_dir = os.path.join(DATA_HOME, "cmrc2018")
+        base_path = self._download_dataset(dataset_dir, url=_DATA_URL)
+        super(CMRC2018, self).__init__(
+            base_path=base_path,
+            train_file="cmrc2018_train.json",
+            dev_file="cmrc2018_dev.json",
+            test_file=None,
+            label_file=None,
+            label_list=None,
+        )
+
+    def _read_file(self, input_file, phase=None):
         """Read a cmrc2018 json file into a list of CMRC2018Example."""
 
         def _is_chinese_char(cp):
@@ -197,7 +176,7 @@ class CMRC2018(object):
                     #
                     # Note that this means for training mode, every example is NOT
                     # guaranteed to be preserved.
-                    if is_training:
+                    if phase == "train":
                         actual_text = "".join(
                             doc_tokens[start_position:(end_position + 1)])
                         cleaned_answer_text = "".join(
...
@@ -17,6 +17,12 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import os
+
+import paddlehub as hub
+from paddlehub.common.downloader import default_downloader
+from paddlehub.common.logger import logger
+
 
 class InputExample(object):
     """
@@ -49,21 +55,124 @@ class InputExample(object):
             self.text_a, self.text_b, self.label)
 
 
-class HubDataset(object):
+class BaseDataset(object):
+    def __init__(self,
+                 base_path,
+                 train_file=None,
+                 dev_file=None,
+                 test_file=None,
+                 predict_file=None,
+                 label_file=None,
+                 label_list=None,
+                 train_file_with_head=False,
+                 dev_file_with_head=False,
+                 test_file_with_head=False,
+                 predict_file_with_head=False):
+        if not (train_file or dev_file or test_file):
+            raise ValueError("At least one file should be assigned")
+        self.base_path = base_path
+        self.train_file = train_file
+        self.dev_file = dev_file
+        self.test_file = test_file
+        self.predict_file = predict_file
+        self.label_file = label_file
+        self.label_list = label_list
+
+        self.train_examples = []
+        self.dev_examples = []
+        self.test_examples = []
+        self.predict_examples = []
+
+        self.if_file_with_head = {
+            "train": train_file_with_head,
+            "dev": dev_file_with_head,
+            "test": test_file_with_head,
+            "predict": predict_file_with_head
+        }
+
+        if train_file:
+            self._load_train_examples()
+        if dev_file:
+            self._load_dev_examples()
+        if test_file:
+            self._load_test_examples()
+        if predict_file:
+            self._load_predict_examples()
+        if self.label_file:
+            if not self.label_list:
+                self.label_list = self._load_label_data()
+            else:
+                logger.warning(
+                    "As label_list has been assigned, label_file is ignored")
+
     def get_train_examples(self):
-        raise NotImplementedError()
+        return self.train_examples
 
     def get_dev_examples(self):
-        raise NotImplementedError()
+        return self.dev_examples
 
     def get_test_examples(self):
-        raise NotImplementedError()
+        return self.test_examples
 
     def get_val_examples(self):
         return self.get_dev_examples()
 
+    def get_predict_examples(self):
+        return self.predict_examples
+
     def get_labels(self):
-        raise NotImplementedError()
+        return self.label_list
 
     @property
     def num_labels(self):
-        raise NotImplementedError()
+        return len(self.label_list)
+
+    def label_dict(self):
+        return {index: key for index, key in enumerate(self.label_list)}
+
+    def _download_dataset(self, dataset_path, url):
+        if not os.path.exists(dataset_path):
+            result, tips, dataset_path = default_downloader.download_file_and_uncompress(
+                url=url,
+                save_path=hub.common.dir.DATA_HOME,
+                print_progress=True,
+                replace=True)
+            if not result:
+                raise Exception(tips)
+        else:
+            logger.info("Dataset {} already cached.".format(dataset_path))
+        return dataset_path
+
+    def _load_train_examples(self):
+        self.train_path = os.path.join(self.base_path, self.train_file)
+        self.train_examples = self._read_file(self.train_path, phase="train")
+
+    def _load_dev_examples(self):
+        self.dev_path = os.path.join(self.base_path, self.dev_file)
+        self.dev_examples = self._read_file(self.dev_path, phase="dev")
+
+    def _load_test_examples(self):
+        self.test_path = os.path.join(self.base_path, self.test_file)
+        self.test_examples = self._read_file(self.test_path, phase="test")
+
+    def _load_predict_examples(self):
+        self.predict_path = os.path.join(self.base_path, self.predict_file)
+        self.predict_examples = self._read_file(
+            self.predict_path, phase="predict")
+
+    def _read_file(self, path, phase=None):
+        raise NotImplementedError
+
+    def _load_label_data(self):
+        with open(os.path.join(self.base_path, self.label_file), "r") as file:
+            return file.read().split("\n")
+
+    def __str__(self):
+        return "Dataset: %s with %i train examples, %i dev examples and %i test examples" % (
+            self.__class__.__name__, len(self.train_examples),
+            len(self.dev_examples), len(self.test_examples))
+
+
+# Alias kept for backward compatibility with earlier PaddleHub versions.
+HubDataset = BaseDataset
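
Everything the per-dataset classes used to duplicate (download-and-cache, per-split loading, label accessors) now lives in BaseDataset; the single method a subclass must supply is _read_file(path, phase), which receives the resolved file path plus the split name ("train", "dev", "test" or "predict") and returns a list of examples. A minimal sketch of that contract (TwoColumnDataset is illustrative, assuming two tab-separated columns per line):

    from paddlehub.dataset import BaseDataset, InputExample


    class TwoColumnDataset(BaseDataset):  # hypothetical subclass
        def _read_file(self, path, phase=None):
            examples = []
            with open(path, "r") as f:
                for i, line in enumerate(f):
                    if phase == "predict":
                        # predict files carry raw text only, no gold label
                        examples.append(
                            InputExample(guid=i, text_a=line.strip()))
                    else:
                        text, label = line.strip().split("\t")
                        examples.append(
                            InputExample(guid=i, text_a=text, label=label))
            return examples
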
@@ -20,18 +20,33 @@ from __future__ import print_function
 import os
 
 import paddlehub as hub
-from paddlehub.dataset.base_cv_dataset import ImageClassificationDataset
+from paddlehub.dataset.base_cv_dataset import BaseCVDatast
 
 
-class DogCatDataset(ImageClassificationDataset):
+class DogCatDataset(BaseCVDatast):
     def __init__(self):
-        super(DogCatDataset, self).__init__()
         dataset_path = os.path.join(hub.common.dir.DATA_HOME, "dog-cat")
-        self.base_path = self._download_dataset(
+        base_path = self._download_dataset(
             dataset_path=dataset_path,
             url="https://bj.bcebos.com/paddlehub-dataset/dog-cat.tar.gz")
-        self.train_list_file = "train_list.txt"
-        self.test_list_file = "test_list.txt"
-        self.validate_list_file = "validate_list.txt"
-        self.label_list_file = "label_list.txt"
-        self.num_labels = 2
+        super(DogCatDataset, self).__init__(
+            base_path=base_path,
+            train_list_file="train_list.txt",
+            validate_list_file="validate_list.txt",
+            test_list_file="test_list.txt",
+            label_list_file="label_list.txt",
+            label_list=None)
+
+
+if __name__ == "__main__":
+    ds = DogCatDataset()
+    print("first 10 dev")
+    for e in ds.get_dev_examples()[:10]:
+        print(e)
+    print("first 10 train")
+    for e in ds.get_train_examples()[:10]:
+        print(e)
+    print("first 10 test")
+    for e in ds.get_test_examples()[:10]:
+        print(e)
+    print(ds)
@@ -16,12 +16,11 @@
 import json
 import os
-import sys
 
 from paddlehub.reader import tokenization
-from paddlehub.common.downloader import default_downloader
 from paddlehub.common.dir import DATA_HOME
 from paddlehub.common.logger import logger
+from paddlehub.dataset.base_nlp_dataset import BaseNLPDatast
 
 _DATA_URL = "https://bj.bcebos.com/paddlehub-dataset/drcd.tar.gz"
 SPIECE_UNDERLINE = '▁'
@@ -39,8 +38,7 @@ class DRCDExample(object):
                  doc_tokens,
                  orig_answer_text=None,
                  start_position=None,
-                 end_position=None,
-                 is_impossible=False):
+                 end_position=None):
         self.qas_id = qas_id
         self.question_text = question_text
         self.doc_tokens = doc_tokens
@@ -64,43 +62,22 @@ class DRCDExample(object):
         return s
 
 
-class DRCD(object):
+class DRCD(BaseNLPDatast):
     """A single set of features of data."""
 
     def __init__(self):
-        self.dataset_dir = os.path.join(DATA_HOME, "drcd")
-        if not os.path.exists(self.dataset_dir):
-            ret, tips, self.dataset_dir = default_downloader.download_file_and_uncompress(
-                url=_DATA_URL, save_path=DATA_HOME, print_progress=True)
-        else:
-            logger.info("Dataset {} already cached.".format(self.dataset_dir))
-
-        self._load_train_examples()
-        self._load_dev_examples()
-        self._load_test_examples()
-
-    def _load_train_examples(self):
-        self.train_file = os.path.join(self.dataset_dir, "DRCD_training.json")
-        self.train_examples = self._read_json(self.train_file)
-
-    def _load_dev_examples(self):
-        self.dev_file = os.path.join(self.dataset_dir, "DRCD_dev.json")
-        self.dev_examples = self._read_json(self.dev_file)
-
-    def _load_test_examples(self):
-        self.test_file = os.path.join(self.dataset_dir, "DRCD_test.json")
-        self.test_examples = self._read_json(self.test_file)
-
-    def get_train_examples(self):
-        return self.train_examples
-
-    def get_dev_examples(self):
-        return self.dev_examples
-
-    def get_test_examples(self):
-        return self.test_examples
-
-    def _read_json(self, input_file):
+        dataset_dir = os.path.join(DATA_HOME, "drcd")
+        base_path = self._download_dataset(dataset_dir, url=_DATA_URL)
+        super(DRCD, self).__init__(
+            base_path=base_path,
+            train_file="DRCD_training.json",
+            dev_file="DRCD_dev.json",
+            test_file="DRCD_test.json",
+            label_file=None,
+            label_list=None,
+        )
+
+    def _read_file(self, input_file, phase=None):
         """Read a DRCD json file into a list of DRCDExample."""
 
         def _is_chinese_char(cp):
...
@@ -20,18 +20,33 @@ from __future__ import print_function
 import os
 
 import paddlehub as hub
-from paddlehub.dataset.base_cv_dataset import ImageClassificationDataset
+from paddlehub.dataset.base_cv_dataset import BaseCVDatast
 
 
-class FlowersDataset(ImageClassificationDataset):
+class FlowersDataset(BaseCVDatast):
    def __init__(self):
-        super(FlowersDataset, self).__init__()
         dataset_path = os.path.join(hub.common.dir.DATA_HOME, "flower_photos")
-        self.base_path = self._download_dataset(
+        base_path = self._download_dataset(
             dataset_path=dataset_path,
             url="https://bj.bcebos.com/paddlehub-dataset/flower_photos.tar.gz")
-        self.train_list_file = "train_list.txt"
-        self.test_list_file = "test_list.txt"
-        self.validate_list_file = "validate_list.txt"
-        self.label_list_file = "label_list.txt"
-        self.num_labels = 5
+        super(FlowersDataset, self).__init__(
+            base_path=base_path,
+            train_list_file="train_list.txt",
+            validate_list_file="validate_list.txt",
+            test_list_file="test_list.txt",
+            label_list_file="label_list.txt",
+            label_list=None)
+
+
+if __name__ == "__main__":
+    ds = FlowersDataset()
+    print("first 10 dev")
+    for e in ds.get_dev_examples()[:10]:
+        print(e)
+    print("first 10 train")
+    for e in ds.get_train_examples()[:10]:
+        print(e)
+    print("first 10 test")
+    for e in ds.get_test_examples()[:10]:
+        print(e)
+    print(ds)
@@ -20,19 +20,33 @@ from __future__ import print_function
 import os
 
 import paddlehub as hub
-from paddlehub.dataset.base_cv_dataset import ImageClassificationDataset
+from paddlehub.dataset.base_cv_dataset import BaseCVDatast
 
 
-class Food101Dataset(ImageClassificationDataset):
+class Food101Dataset(BaseCVDatast):
     def __init__(self):
-        super(Food101Dataset, self).__init__()
         dataset_path = os.path.join(hub.common.dir.DATA_HOME, "food-101",
                                     "images")
-        self.base_path = self._download_dataset(
+        base_path = self._download_dataset(
             dataset_path=dataset_path,
             url="https://bj.bcebos.com/paddlehub-dataset/Food101.tar.gz")
-        self.train_list_file = "train_list.txt"
-        self.test_list_file = "test_list.txt"
-        self.validate_list_file = "validate_list.txt"
-        self.label_list_file = "label_list.txt"
-        self.num_labels = 101
+        super(Food101Dataset, self).__init__(
+            base_path=base_path,
+            train_list_file="train_list.txt",
+            test_list_file="test_list.txt",
+            validate_list_file="validate_list.txt",
+            label_list_file="label_list.txt")
+
+
+if __name__ == "__main__":
+    ds = Food101Dataset()
+    print("first 10 dev")
+    for e in ds.get_dev_examples()[:10]:
+        print(e)
+    print("first 10 train")
+    for e in ds.get_train_examples()[:10]:
+        print(e)
+    print("first 10 test")
+    for e in ds.get_test_examples()[:10]:
+        print(e)
+    print(ds)
@@ -21,15 +21,15 @@ import os
 import csv
 import io
 
-from paddlehub.dataset import InputExample, HubDataset
-from paddlehub.common.downloader import default_downloader
-from paddlehub.common.dir import DATA_HOME
+from paddlehub.dataset import InputExample
 from paddlehub.common.logger import logger
+from paddlehub.common.dir import DATA_HOME
+from paddlehub.dataset.base_nlp_dataset import BaseNLPDatast
 
 _DATA_URL = "https://bj.bcebos.com/paddlehub-dataset/glue_data.tar.gz"
 
 
-class GLUE(HubDataset):
+class GLUE(BaseNLPDatast):
     """
     Please refer to
     https://gluebenchmark.com
@@ -43,147 +43,107 @@ class GLUE(HubDataset):
                 'RTE', 'SST-2', 'STS-B'
         ]:
             raise Exception(
-                sub_dataset +
-                " is not in GLUE benchmark. Please confirm the data set")
+                "%s is not in GLUE benchmark. Please confirm the data set" %
+                sub_dataset)
 
-        self.mismatch = False
+        mismatch = False
         if sub_dataset == 'MNLI_mm':
             sub_dataset = 'MNLI'
-            self.mismatch = True
+            mismatch = True
         elif sub_dataset == 'MNLI_m':
             sub_dataset = 'MNLI'
         self.sub_dataset = sub_dataset
-        self.dataset_dir = os.path.join(DATA_HOME, "glue_data")
 
-        if not os.path.exists(self.dataset_dir):
-            ret, tips, self.dataset_dir = default_downloader.download_file_and_uncompress(
-                url=_DATA_URL, save_path=DATA_HOME, print_progress=True)
-        else:
-            logger.info("Dataset {} already cached.".format(self.dataset_dir))
-
-        self._load_train_examples()
-        self._load_dev_examples()
-        self._load_test_examples()
-        self._load_predict_examples()
-
-    def _load_train_examples(self):
-        self.train_file = os.path.join(self.dataset_dir, self.sub_dataset,
-                                       "train.tsv")
-        self.train_examples = self._read_tsv(self.train_file)
-
-    def _load_dev_examples(self):
-        if self.sub_dataset == 'MNLI' and not self.mismatch:
-            self.dev_file = os.path.join(self.dataset_dir, self.sub_dataset,
-                                         "dev_matched.tsv")
-        elif self.sub_dataset == 'MNLI' and self.mismatch:
-            self.dev_file = os.path.join(self.dataset_dir, self.sub_dataset,
-                                         "dev_mismatched.tsv")
-        else:
-            self.dev_file = os.path.join(self.dataset_dir, self.sub_dataset,
-                                         "dev.tsv")
-        self.dev_examples = self._read_tsv(self.dev_file)
-
-    def _load_test_examples(self):
-        self.test_examples = []
-
-    def _load_predict_examples(self):
-        if self.sub_dataset == 'MNLI' and not self.mismatch:
-            self.predict_file = os.path.join(self.dataset_dir, self.sub_dataset,
-                                             "test_matched.tsv")
-        elif self.sub_dataset == 'MNLI' and self.mismatch:
-            self.predict_file = os.path.join(self.dataset_dir, self.sub_dataset,
-                                             "test_mismatched.tsv")
-        else:
-            self.predict_file = os.path.join(self.dataset_dir, self.sub_dataset,
-                                             "test.tsv")
-        self.predict_examples = self._read_tsv(self.predict_file, wo_label=True)
-
-    def get_train_examples(self):
-        return self.train_examples
-
-    def get_dev_examples(self):
-        return self.dev_examples
-
-    def get_test_examples(self):
-        return self.test_examples
-
-    def get_predict_examples(self):
-        return self.predict_examples
-
-    def get_labels(self):
-        """See base class."""
-        if self.sub_dataset in ['MRPC', 'QQP', 'SST-2', 'CoLA']:
-            return ["0", "1"]
-        elif self.sub_dataset in ['QNLI', 'RTE']:
-            return ['not_entailment', 'entailment']
-        elif self.sub_dataset in ['MNLI']:
-            return ["neutral", "contradiction", "entailment"]
-        elif self.sub_dataset in ['STS-B']:
-            return Exception("No category labels for regreesion tasks")
-
-    @property
-    def num_labels(self):
-        """
-        Return the number of labels in the dataset.
-        """
-        return len(self.get_labels())
-
-    def _read_tsv(self, input_file, quotechar=None, wo_label=False):
+        # test.tsv has no label, so it is a predict file
+        dev_file = "dev.tsv"
+        predict_file = "test.tsv"
+        if sub_dataset == 'MNLI' and not mismatch:
+            dev_file = 'dev_matched.tsv'
+            predict_file = "test_matched.tsv"
+        elif sub_dataset == 'MNLI' and mismatch:
+            dev_file = 'dev_mismatched.tsv'
+            predict_file = "test_mismatched.tsv"
+
+        dataset_dir = os.path.join(DATA_HOME, "glue_data")
+        dataset_dir = self._download_dataset(dataset_dir, url=_DATA_URL)
+        base_path = os.path.join(dataset_dir, self.sub_dataset)
+
+        label_list = None
+        if sub_dataset in ['MRPC', 'QQP', 'SST-2', 'CoLA']:
+            label_list = ["0", "1"]
+        elif sub_dataset in ['QNLI', 'RTE']:
+            label_list = ['not_entailment', 'entailment']
+        elif sub_dataset in ['MNLI']:
+            label_list = ["neutral", "contradiction", "entailment"]
+        elif sub_dataset in ['STS-B']:
+            label_list = None
+
+        super(GLUE, self).__init__(
+            base_path=base_path,
+            train_file="train.tsv",
+            dev_file=dev_file,
+            predict_file=predict_file,
+            label_file=None,
+            label_list=label_list,
+        )
+
+    def _read_file(self, input_file, phase=None):
         """Reads a tab separated value file."""
         with io.open(input_file, "r", encoding="UTF-8") as f:
-            reader = csv.reader(f, delimiter="\t", quotechar=quotechar)
+            reader = csv.reader(f, delimiter="\t", quotechar=None)
             examples = []
             seq_id = 0
-            if self.sub_dataset != 'CoLA' or wo_label:
+            if self.sub_dataset != 'CoLA' or phase == "predict":
                 header = next(reader)  # skip header
             if self.sub_dataset in [
                     'MRPC',
             ]:
-                if wo_label:
+                if phase == "predict":
                     label_index, text_a_index, text_b_index = [None, -2, -1]
                 else:
                     label_index, text_a_index, text_b_index = [0, -2, -1]
             elif self.sub_dataset in [
                     'QNLI',
             ]:
-                if wo_label:
+                if phase == "predict":
                     label_index, text_a_index, text_b_index = [None, 1, 2]
                 else:
                     label_index, text_a_index, text_b_index = [3, 1, 2]
             elif self.sub_dataset in [
                     'QQP',
             ]:
-                if wo_label:
+                if phase == "predict":
                     label_index, text_a_index, text_b_index = [None, 1, 2]
                 else:
                     label_index, text_a_index, text_b_index = [5, 3, 4]
             elif self.sub_dataset in [
                     'RTE',
             ]:
-                if wo_label:
+                if phase == "predict":
                     label_index, text_a_index, text_b_index = [None, 1, 2]
                 else:
                     label_index, text_a_index, text_b_index = [3, 1, 2]
             elif self.sub_dataset in [
                     'SST-2',
             ]:
-                if wo_label:
+                if phase == "predict":
                     label_index, text_a_index, text_b_index = [None, 1, None]
                 else:
                     label_index, text_a_index, text_b_index = [1, 0, None]
             elif self.sub_dataset in [
                     'MNLI',
             ]:
-                if wo_label:
+                if phase == "predict":
                     label_index, text_a_index, text_b_index = [None, 8, 9]
                 else:
                     label_index, text_a_index, text_b_index = [-1, 8, 9]
             elif self.sub_dataset in ['CoLA']:
-                if wo_label:
+                if phase == "predict":
                     label_index, text_a_index, text_b_index = [None, 1, None]
                 else:
                     label_index, text_a_index, text_b_index = [1, 3, None]
             elif self.sub_dataset in ['STS-B']:
-                if wo_label:
+                if phase == "predict":
                     label_index, text_a_index, text_b_index = [None, -2, -1]
                 else:
                     label_index, text_a_index, text_b_index = [-1, -3, -2]
...
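
Folding the old wo_label flag into phase means GLUE's unlabeled test.tsv now arrives through the predict split rather than an always-empty test split. A hedged usage sketch (assuming the constructor takes the sub-dataset name checked above as its first argument):

    from paddlehub.dataset.glue import GLUE

    ds = GLUE("MNLI_m")  # matched MNLI; "MNLI_mm" selects the mismatched split
    print(ds.get_labels())  # ["neutral", "contradiction", "entailment"]
    for e in ds.get_predict_examples()[:3]:
        print(e.guid, e.text_a, e.text_b)  # test_matched.tsv rows, no gold label
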
@@ -17,64 +17,30 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-from collections import namedtuple
 import io
 import os
-import csv
 
-from paddlehub.dataset import InputExample, HubDataset
-from paddlehub.common.downloader import default_downloader
+from paddlehub.dataset import InputExample
 from paddlehub.common.dir import DATA_HOME
-from paddlehub.common.logger import logger
+from paddlehub.dataset.base_nlp_dataset import BaseNLPDatast
 
 _DATA_URL = "https://bj.bcebos.com/paddlehub-dataset/iflytek.tar.gz"
 
 
-class IFLYTEK(HubDataset):
+class IFLYTEK(BaseNLPDatast):
     def __init__(self):
-        self.dataset_dir = os.path.join(DATA_HOME, "iflytek")
-        if not os.path.exists(self.dataset_dir):
-            ret, tips, self.dataset_dir = default_downloader.download_file_and_uncompress(
-                url=_DATA_URL, save_path=DATA_HOME, print_progress=True)
-        else:
-            logger.info("Dataset {} already cached.".format(self.dataset_dir))
-
-        self._load_train_examples()
-        self._load_test_examples()
-        self._load_dev_examples()
-
-    def _load_train_examples(self):
-        self.train_file = os.path.join(self.dataset_dir, "train.txt")
-        self.train_examples = self._read_file(self.train_file)
-
-    def _load_dev_examples(self):
-        self.dev_file = os.path.join(self.dataset_dir, "dev.txt")
-        self.dev_examples = self._read_file(self.dev_file)
-
-    def _load_test_examples(self):
-        self.test_file = os.path.join(self.dataset_dir, "test.txt")
-        self.test_examples = self._read_file(self.test_file)
-
-    def get_train_examples(self):
-        return self.train_examples
-
-    def get_dev_examples(self):
-        return self.dev_examples
-
-    def get_test_examples(self):
-        return self.test_examples
-
-    def get_labels(self):
-        return [str(i) for i in range(119)]
-
-    @property
-    def num_labels(self):
-        """
-        Return the number of labels in the dataset.
-        """
-        return len(self.get_labels())
-
-    def _read_file(self, input_file):
+        dataset_dir = os.path.join(DATA_HOME, "iflytek")
+        base_path = self._download_dataset(dataset_dir, url=_DATA_URL)
+        super(IFLYTEK, self).__init__(
+            base_path=base_path,
+            train_file="train.txt",
+            dev_file="dev.txt",
+            test_file="test.txt",
+            label_file=None,
+            label_list=[str(i) for i in range(119)],
+        )
+
+    def _read_file(self, input_file, phase=None):
         """Reads a tab separated value file."""
         with io.open(input_file, "r", encoding="UTF-8") as file:
             examples = []
@@ -91,5 +57,13 @@ class IFLYTEK(HubDataset):
 
 if __name__ == "__main__":
     ds = IFLYTEK()
+    print("first 10 dev")
+    for e in ds.get_dev_examples()[:10]:
+        print("{}\t{}\t{}\t{}".format(e.guid, e.text_a, e.text_b, e.label))
+    print("first 10 train")
     for e in ds.get_train_examples()[:10]:
         print("{}\t{}\t{}\t{}".format(e.guid, e.text_a, e.text_b, e.label))
+    print("first 10 test")
+    for e in ds.get_test_examples()[:10]:
+        print("{}\t{}\t{}\t{}".format(e.guid, e.text_a, e.text_b, e.label))
+    print(ds)
@@ -20,18 +20,33 @@ from __future__ import print_function
 import os
 
 import paddlehub as hub
-from paddlehub.dataset.base_cv_dataset import ImageClassificationDataset
+from paddlehub.dataset.base_cv_dataset import BaseCVDatast
 
 
-class Indoor67Dataset(ImageClassificationDataset):
+class Indoor67Dataset(BaseCVDatast):
     def __init__(self):
-        super(Indoor67Dataset, self).__init__()
         dataset_path = os.path.join(hub.common.dir.DATA_HOME, "Indoor67")
-        self.base_path = self._download_dataset(
+        base_path = self._download_dataset(
             dataset_path=dataset_path,
             url="https://bj.bcebos.com/paddlehub-dataset/Indoor67.tar.gz")
-        self.train_list_file = "train_list.txt"
-        self.test_list_file = "test_list.txt"
-        self.validate_list_file = "validate_list.txt"
-        self.label_list_file = "label_list.txt"
-        self.num_labels = 67
+        super(Indoor67Dataset, self).__init__(
+            base_path=base_path,
+            train_list_file="train_list.txt",
+            validate_list_file="validate_list.txt",
+            test_list_file="test_list.txt",
+            label_list_file="label_list.txt",
+            label_list=None)
+
+
+if __name__ == "__main__":
+    ds = Indoor67Dataset()
+    print("first 10 dev")
+    for e in ds.get_dev_examples()[:10]:
+        print(e)
+    print("first 10 train")
+    for e in ds.get_train_examples()[:10]:
+        print(e)
+    print("first 10 test")
+    for e in ds.get_test_examples()[:10]:
+        print(e)
+    print(ds)
@@ -17,73 +17,40 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-from collections import namedtuple
 import io
 import os
 import csv
 
-from paddlehub.dataset import InputExample, HubDataset
-from paddlehub.common.downloader import default_downloader
+from paddlehub.dataset import InputExample
 from paddlehub.common.dir import DATA_HOME
-from paddlehub.common.logger import logger
+from paddlehub.dataset.base_nlp_dataset import BaseNLPDatast
 
 _DATA_URL = "https://bj.bcebos.com/paddlehub-dataset/inews.tar.gz"
 
 
-class INews(HubDataset):
+class INews(BaseNLPDatast):
     """
     INews is a sentiment analysis dataset for Internet News
     """
 
     def __init__(self):
-        self.dataset_dir = os.path.join(DATA_HOME, "inews")
-        if not os.path.exists(self.dataset_dir):
-            ret, tips, self.dataset_dir = default_downloader.download_file_and_uncompress(
-                url=_DATA_URL, save_path=DATA_HOME, print_progress=True)
-        else:
-            logger.info("Dataset {} already cached.".format(self.dataset_dir))
-
-        self._load_train_examples()
-        self._load_test_examples()
-        self._load_dev_examples()
-
-    def _load_train_examples(self):
-        self.train_file = os.path.join(self.dataset_dir, "train.txt")
-        self.train_examples = self._read_file(self.train_file, is_training=True)
-
-    def _load_dev_examples(self):
-        self.dev_file = os.path.join(self.dataset_dir, "dev.txt")
-        self.dev_examples = self._read_file(self.dev_file, is_training=False)
-
-    def _load_test_examples(self):
-        self.test_file = os.path.join(self.dataset_dir, "test.txt")
-        self.test_examples = self._read_file(self.test_file, is_training=False)
-
-    def get_train_examples(self):
-        return self.train_examples
-
-    def get_dev_examples(self):
-        return self.dev_examples
-
-    def get_test_examples(self):
-        return self.test_examples
-
-    def get_labels(self):
-        return ["0", "1", "2"]
-
-    @property
-    def num_labels(self):
-        """
-        Return the number of labels in the dataset.
-        """
-        return len(self.get_labels())
-
-    def _read_file(self, input_file, is_training):
+        dataset_dir = os.path.join(DATA_HOME, "inews")
+        base_path = self._download_dataset(dataset_dir, url=_DATA_URL)
+        super(INews, self).__init__(
+            base_path=base_path,
+            train_file="train.txt",
+            dev_file="dev.txt",
+            test_file="test.txt",
+            label_file=None,
+            label_list=["0", "1", "2"],
+        )
+
+    def _read_file(self, input_file, phase=None):
         """Reads a tab separated value file."""
         with io.open(input_file, "r", encoding="UTF-8") as file:
             examples = []
             for (i, line) in enumerate(file):
-                if i == 0 and is_training:
+                if i == 0 and phase == 'train':
                     continue
                 data = line.strip().split("_!_")
                 example = InputExample(
@@ -94,5 +61,13 @@ class INews(HubDataset):
 
 if __name__ == "__main__":
     ds = INews()
+    print("first 10 dev")
+    for e in ds.get_dev_examples()[:10]:
+        print("{}\t{}\t{}\t{}".format(e.guid, e.text_a, e.text_b, e.label))
+    print("first 10 train")
     for e in ds.get_train_examples()[:10]:
         print("{}\t{}\t{}\t{}".format(e.guid, e.text_a, e.text_b, e.label))
+    print("first 10 test")
+    for e in ds.get_test_examples()[:10]:
+        print("{}\t{}\t{}\t{}".format(e.guid, e.text_a, e.text_b, e.label))
+    print(ds)
@@ -17,68 +17,34 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-from collections import namedtuple
 import codecs
 import os
 import csv
 
-from paddlehub.dataset import InputExample, HubDataset
-from paddlehub.common.downloader import default_downloader
+from paddlehub.dataset import InputExample
 from paddlehub.common.dir import DATA_HOME
-from paddlehub.common.logger import logger
+from paddlehub.dataset.base_nlp_dataset import BaseNLPDatast
 
 _DATA_URL = "https://bj.bcebos.com/paddlehub-dataset/lcqmc.tar.gz"
 
 
-class LCQMC(HubDataset):
+class LCQMC(BaseNLPDatast):
     def __init__(self):
-        self.dataset_dir = os.path.join(DATA_HOME, "lcqmc")
-        if not os.path.exists(self.dataset_dir):
-            ret, tips, self.dataset_dir = default_downloader.download_file_and_uncompress(
-                url=_DATA_URL, save_path=DATA_HOME, print_progress=True)
-        else:
-            logger.info("Dataset {} already cached.".format(self.dataset_dir))
-
-        self._load_train_examples()
-        self._load_test_examples()
-        self._load_dev_examples()
-
-    def _load_train_examples(self):
-        self.train_file = os.path.join(self.dataset_dir, "train.tsv")
-        self.train_examples = self._read_tsv(self.train_file)
-
-    def _load_dev_examples(self):
-        self.dev_file = os.path.join(self.dataset_dir, "dev.tsv")
-        self.dev_examples = self._read_tsv(self.dev_file)
-
-    def _load_test_examples(self):
-        self.test_file = os.path.join(self.dataset_dir, "test.tsv")
-        self.test_examples = self._read_tsv(self.test_file)
-
-    def get_train_examples(self):
-        return self.train_examples
-
-    def get_dev_examples(self):
-        return self.dev_examples
-
-    def get_test_examples(self):
-        return self.test_examples
-
-    def get_labels(self):
-        """See base class."""
-        return ["0", "1"]
-
-    @property
-    def num_labels(self):
-        """
-        Return the number of labels in the dataset.
-        """
-        return len(self.get_labels())
-
-    def _read_tsv(self, input_file, quotechar=None):
+        dataset_dir = os.path.join(DATA_HOME, "lcqmc")
+        base_path = self._download_dataset(dataset_dir, url=_DATA_URL)
+        super(LCQMC, self).__init__(
+            base_path=base_path,
+            train_file="train.tsv",
+            dev_file="dev.tsv",
+            test_file="test.tsv",
+            label_file=None,
+            label_list=["0", "1"],
+        )
+
+    def _read_file(self, input_file, phase=None):
         """Reads a tab separated value file."""
         with codecs.open(input_file, "r", encoding="UTF-8") as f:
-            reader = csv.reader(f, delimiter="\t", quotechar=quotechar)
+            reader = csv.reader(f, delimiter="\t", quotechar=None)
             examples = []
             seq_id = 0
             header = next(reader)  # skip header
@@ -93,5 +59,13 @@ class LCQMC(HubDataset):
 
 if __name__ == "__main__":
     ds = LCQMC()
-    for e in ds.get_train_examples():
+    print("first 10 dev")
+    for e in ds.get_dev_examples()[:10]:
+        print("{}\t{}\t{}\t{}".format(e.guid, e.text_a, e.text_b, e.label))
+    print("first 10 train")
+    for e in ds.get_train_examples()[:10]:
+        print("{}\t{}\t{}\t{}".format(e.guid, e.text_a, e.text_b, e.label))
+    print("first 10 test")
+    for e in ds.get_test_examples()[:10]:
         print("{}\t{}\t{}\t{}".format(e.guid, e.text_a, e.text_b, e.label))
+    print(ds)
@@ -20,18 +20,15 @@ from __future__ import print_function
 import os
 import codecs
 import csv
-import json
-
-from collections import namedtuple
 
-from paddlehub.dataset import InputExample, HubDataset
-from paddlehub.common.downloader import default_downloader
+from paddlehub.dataset import InputExample
 from paddlehub.common.dir import DATA_HOME
-from paddlehub.common.logger import logger
+from paddlehub.dataset.base_nlp_dataset import BaseNLPDatast
 
 _DATA_URL = "https://bj.bcebos.com/paddlehub-dataset/msra_ner.tar.gz"
 
 
-class MSRA_NER(HubDataset):
+class MSRA_NER(BaseNLPDatast):
     """
     A set of manually annotated Chinese word-segmentation data and
     specifications for training and testing a Chinese word-segmentation system
@@ -40,55 +37,23 @@ class MSRA_NER(HubDataset):
     """
 
     def __init__(self):
-        self.dataset_dir = os.path.join(DATA_HOME, "msra_ner")
-        if not os.path.exists(self.dataset_dir):
-            ret, tips, self.dataset_dir = default_downloader.download_file_and_uncompress(
-                url=_DATA_URL, save_path=DATA_HOME, print_progress=True)
-        else:
-            logger.info("Dataset {} already cached.".format(self.dataset_dir))
-
-        self._load_train_examples()
-        self._load_test_examples()
-        self._load_dev_examples()
-
-    def _load_train_examples(self):
-        train_file = os.path.join(self.dataset_dir, "train.tsv")
-        self.train_examples = self._read_tsv(train_file)
-
-    def _load_dev_examples(self):
-        self.dev_file = os.path.join(self.dataset_dir, "dev.tsv")
-        self.dev_examples = self._read_tsv(self.dev_file)
-
-    def _load_test_examples(self):
-        self.test_file = os.path.join(self.dataset_dir, "test.tsv")
-        self.test_examples = self._read_tsv(self.test_file)
-
-    def get_train_examples(self):
-        return self.train_examples
-
-    def get_dev_examples(self):
-        return self.dev_examples
-
-    def get_test_examples(self):
-        return self.test_examples
-
-    def get_labels(self):
-        return ["B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC", "I-LOC", "O"]
-
-    @property
-    def num_labels(self):
-        """
-        Return the number of labels in the dataset.
-        """
-        return len(self.get_labels())
-
-    def get_label_map(self):
-        return self.label_map
-
-    def _read_tsv(self, input_file, quotechar=None):
+        dataset_dir = os.path.join(DATA_HOME, "msra_ner")
+        base_path = self._download_dataset(dataset_dir, url=_DATA_URL)
+        super(MSRA_NER, self).__init__(
+            base_path=base_path,
+            train_file="train.tsv",
+            dev_file="dev.tsv",
+            test_file="test.tsv",
+            label_file=None,
+            label_list=[
+                "B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC", "I-LOC", "O"
+            ],
+        )
+
+    def _read_file(self, input_file, phase=None):
         """Reads a tab separated value file."""
         with codecs.open(input_file, "r", encoding="UTF-8") as f:
-            reader = csv.reader(f, delimiter="\t", quotechar=quotechar)
+            reader = csv.reader(f, delimiter="\t", quotechar=None)
             examples = []
             seq_id = 0
             header = next(reader)  # skip header
@@ -103,5 +68,13 @@ class MSRA_NER(HubDataset):
 
 if __name__ == "__main__":
     ds = MSRA_NER()
-    for e in ds.get_train_examples():
+    print("first 10 dev")
+    for e in ds.get_dev_examples()[:10]:
+        print("{}\t{}\t{}\t{}".format(e.guid, e.text_a, e.text_b, e.label))
+    print("first 10 train")
+    for e in ds.get_train_examples()[:10]:
+        print("{}\t{}\t{}\t{}".format(e.guid, e.text_a, e.text_b, e.label))
+    print("first 10 test")
+    for e in ds.get_test_examples()[:10]:
         print("{}\t{}\t{}\t{}".format(e.guid, e.text_a, e.text_b, e.label))
+    print(ds)
@@ -17,20 +17,18 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-from collections import namedtuple
 import codecs
 import os
 import csv
 
-from paddlehub.dataset import InputExample, HubDataset
-from paddlehub.common.downloader import default_downloader
+from paddlehub.dataset import InputExample
 from paddlehub.common.dir import DATA_HOME
-from paddlehub.common.logger import logger
+from paddlehub.dataset.base_nlp_dataset import BaseNLPDatast
 
 _DATA_URL = "https://bj.bcebos.com/paddlehub-dataset/nlpcc-dbqa.tar.gz"
 
 
-class NLPCC_DBQA(HubDataset):
+class NLPCC_DBQA(BaseNLPDatast):
     """
     Please refer to
     http://tcci.ccf.org.cn/conference/2017/dldoc/taskgline05.pdf
@@ -38,53 +36,21 @@ class NLPCC_DBQA(HubDataset):
     """
 
     def __init__(self):
-        self.dataset_dir = os.path.join(DATA_HOME, "nlpcc-dbqa")
-        if not os.path.exists(self.dataset_dir):
-            ret, tips, self.dataset_dir = default_downloader.download_file_and_uncompress(
-                url=_DATA_URL, save_path=DATA_HOME, print_progress=True)
-        else:
-            logger.info("Dataset {} already cached.".format(self.dataset_dir))
-
-        self._load_train_examples()
-        self._load_test_examples()
-        self._load_dev_examples()
-
-    def _load_train_examples(self):
-        self.train_file = os.path.join(self.dataset_dir, "train.tsv")
-        self.train_examples = self._read_tsv(self.train_file)
-
-    def _load_dev_examples(self):
-        self.dev_file = os.path.join(self.dataset_dir, "dev.tsv")
-        self.dev_examples = self._read_tsv(self.dev_file)
-
-    def _load_test_examples(self):
-        self.test_file = os.path.join(self.dataset_dir, "test.tsv")
-        self.test_examples = self._read_tsv(self.test_file)
-
-    def get_train_examples(self):
-        return self.train_examples
-
-    def get_dev_examples(self):
-        return self.dev_examples
-
-    def get_test_examples(self):
-        return self.test_examples
-
-    def get_labels(self):
-        """See base class."""
-        return ["0", "1"]
-
-    @property
-    def num_labels(self):
-        """
-        Return the number of labels in the dataset.
-        """
-        return len(self.get_labels())
-
-    def _read_tsv(self, input_file, quotechar=None):
+        dataset_dir = os.path.join(DATA_HOME, "nlpcc-dbqa")
+        base_path = self._download_dataset(dataset_dir, url=_DATA_URL)
+        super(NLPCC_DBQA, self).__init__(
+            base_path=base_path,
+            train_file="train.tsv",
+            dev_file="dev.tsv",
+            test_file="test.tsv",
+            label_file=None,
+            label_list=["0", "1"],
+        )
+
+    def _read_file(self, input_file, phase=None):
         """Reads a tab separated value file."""
         with codecs.open(input_file, "r", encoding="UTF-8") as f:
-            reader = csv.reader(f, delimiter="\t", quotechar=quotechar)
+            reader = csv.reader(f, delimiter="\t", quotechar=None)
             examples = []
             seq_id = 0
             header = next(reader)  # skip header
@@ -99,5 +65,13 @@ class NLPCC_DBQA(HubDataset):
 
 if __name__ == "__main__":
     ds = NLPCC_DBQA()
-    for e in ds.get_train_examples():
+    print("first 10 dev")
+    for e in ds.get_dev_examples()[:10]:
+        print("{}\t{}\t{}\t{}".format(e.guid, e.text_a, e.text_b, e.label))
+    print("first 10 train")
+    for e in ds.get_train_examples()[:10]:
+        print("{}\t{}\t{}\t{}".format(e.guid, e.text_a, e.text_b, e.label))
+    print("first 10 test")
+    for e in ds.get_test_examples()[:10]:
         print("{}\t{}\t{}\t{}".format(e.guid, e.text_a, e.text_b, e.label))
+    print(ds)
...@@ -16,12 +16,11 @@ ...@@ -16,12 +16,11 @@
import json import json
import os import os
import sys
from paddlehub.reader import tokenization from paddlehub.reader import tokenization
from paddlehub.common.downloader import default_downloader
from paddlehub.common.dir import DATA_HOME from paddlehub.common.dir import DATA_HOME
from paddlehub.common.logger import logger from paddlehub.common.logger import logger
from paddlehub.dataset.base_nlp_dataset import BaseNLPDatast
_DATA_URL = "https://bj.bcebos.com/paddlehub-dataset/squad.tar.gz" _DATA_URL = "https://bj.bcebos.com/paddlehub-dataset/squad.tar.gz"
@@ -66,61 +65,31 @@ class SquadExample(object):
        return s


class SQUAD(BaseNLPDatast):
    """The SQuAD reading comprehension dataset (v1.1 or v2.0)."""

    def __init__(self, version_2_with_negative=False):
        self.version_2_with_negative = version_2_with_negative
        if not version_2_with_negative:
            train_file = "train-v1.1.json"
            dev_file = "dev-v1.1.json"
        else:
            train_file = "train-v2.0.json"
            dev_file = "dev-v2.0.json"

        dataset_dir = os.path.join(DATA_HOME, "squad_data")
        base_path = self._download_dataset(dataset_dir, url=_DATA_URL)

        super(SQUAD, self).__init__(
            base_path=base_path,
            train_file=train_file,
            dev_file=dev_file,
            test_file=None,
            label_file=None,
            label_list=None)

    def _read_file(self, input_file, phase=None):
        """Read a SQuAD json file into a list of SquadExample."""
        with open(input_file, "r") as reader:
            input_data = json.load(reader)["data"]
@@ -156,13 +125,15 @@ class SQUAD(object):
                    end_position = None
                    orig_answer_text = None
                    is_impossible = False
                    if phase in ["train", "dev"]:
                        if self.version_2_with_negative:
                            is_impossible = qa["is_impossible"]
                        if phase == "train" and (len(
                                qa["answers"]) != 1) and (not is_impossible):
                            print(qa)
                            raise ValueError(
                                "For training, each question should have exactly 1 answer."
                            )
                        if not is_impossible:
                            answer = qa["answers"][0]
                            orig_answer_text = answer["text"]
@@ -206,8 +177,14 @@ class SQUAD(object):
if __name__ == "__main__":
    ds = SQUAD(version_2_with_negative=True)
    print("first 2 dev")
    for e in ds.get_dev_examples()[:2]:
        print(e)
    print("first 2 train")
    for e in ds.get_train_examples()[:2]:
        print(e)
    print("first 2 test")
    for e in ds.get_test_examples()[:2]:
        print(e)
    print(ds)
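A brief usage sketch of the refactored class (example counts depend on the downloaded archive):

ds_v1 = SQUAD(version_2_with_negative=False)  # train-v1.1.json / dev-v1.1.json
print(len(ds_v1.get_train_examples()), len(ds_v1.get_dev_examples()))
ds_v2 = SQUAD(version_2_with_negative=True)   # train-v2.0.json / dev-v2.0.json
print(ds_v2.get_test_examples())  # expected empty: test_file=None above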
@@ -20,20 +20,35 @@ from __future__ import print_function

import os

import paddlehub as hub
from paddlehub.dataset.base_cv_dataset import BaseCVDatast


class StanfordDogsDataset(BaseCVDatast):
    def __init__(self):
        dataset_path = os.path.join(hub.common.dir.DATA_HOME,
                                    "StanfordDogs-120")
        base_path = self._download_dataset(
            dataset_path=dataset_path,
            url="https://bj.bcebos.com/paddlehub-dataset/StanfordDogs-120.tar.gz")
        super(StanfordDogsDataset, self).__init__(
            base_path=base_path,
            train_list_file="train_list.txt",
            validate_list_file="validate_list.txt",
            test_list_file="test_list.txt",
            label_list_file="label_list.txt",
            label_list=None)


if __name__ == "__main__":
    ds = StanfordDogsDataset()
    print("first 10 dev")
    for e in ds.get_dev_examples()[:10]:
        print(e)
    print("first 10 train")
    for e in ds.get_train_examples()[:10]:
        print(e)
    print("first 10 test")
    for e in ds.get_test_examples()[:10]:
        print(e)
    print(ds)
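The same list-file layout serves any image folder; a sketch with assumed paths, where each list file carries one "image_path label" line per image:

from paddlehub.dataset.base_cv_dataset import BaseCVDatast


class MyImageDataset(BaseCVDatast):
    def __init__(self, base_path):
        # base_path (e.g. "/data/my_images") is an assumption for illustration
        super(MyImageDataset, self).__init__(
            base_path=base_path,
            train_list_file="train_list.txt",
            validate_list_file="validate_list.txt",
            test_list_file="test_list.txt",
            label_list_file="label_list.txt",
            label_list=None)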
@@ -17,64 +17,30 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import io
import os

from paddlehub.dataset import InputExample
from paddlehub.common.dir import DATA_HOME
from paddlehub.dataset.base_nlp_dataset import BaseNLPDatast

_DATA_URL = "https://bj.bcebos.com/paddlehub-dataset/thucnews.tar.gz"


class THUCNEWS(BaseNLPDatast):
    def __init__(self):
        dataset_dir = os.path.join(DATA_HOME, "thucnews")
        base_path = self._download_dataset(dataset_dir, url=_DATA_URL)
        super(THUCNEWS, self).__init__(
            base_path=base_path,
            train_file="train.txt",
            dev_file="dev.txt",
            test_file="test.txt",
            label_file=None,
            label_list=[str(i) for i in range(14)])

    def _read_file(self, input_file, phase=None):
        """Reads a tab separated value file."""
        with io.open(input_file, "r", encoding="UTF-8") as file:
            examples = []
@@ -91,5 +57,13 @@ class THUCNEWS(HubDataset):
if __name__ == "__main__":
    ds = THUCNEWS()
    print("first 10 dev")
    for e in ds.get_dev_examples()[:10]:
        print("{}\t{}\t{}\t{}".format(e.guid, e.text_a, e.text_b, e.label))
    print("first 10 train")
    for e in ds.get_train_examples()[:10]:
        print("{}\t{}\t{}\t{}".format(e.guid, e.text_a, e.text_b, e.label))
    print("first 10 test")
    for e in ds.get_test_examples()[:10]:
        print("{}\t{}\t{}\t{}".format(e.guid, e.text_a, e.text_b, e.label))
    print(ds)
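With the labels declared up front, mapping a predicted class index back to its label string is a plain lookup; a sketch that mirrors the list passed above rather than assuming a base-class attribute name:

ds = THUCNEWS()
labels = [str(i) for i in range(14)]  # same list as label_list above
pred_id = 3                           # e.g. argmax over the 14 class logits
print(labels[pred_id])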
@@ -17,15 +17,11 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import io
import os

from paddlehub.dataset import InputExample, BaseDataset
from paddlehub.common.dir import DATA_HOME

_DATA_URL = "https://bj.bcebos.com/paddlehub-dataset/tnews.tar.gz"
@@ -48,64 +44,31 @@ LABEL_NAME = {
}


class TNews(BaseDataset):
    """
    TNews is a Chinese news classification dataset from the Jinri Toutiao app.
    """

    def __init__(self):
        dataset_dir = os.path.join(DATA_HOME, "tnews")
        base_path = self._download_dataset(dataset_dir, url=_DATA_URL)
        label_list = [
            '100', '101', '102', '103', '104', '106', '107', '108', '109',
            '110', '112', '113', '114', '115', '116'
        ]
        super(TNews, self).__init__(
            base_path=base_path,
            train_file="toutiao_category_train.txt",
            dev_file="toutiao_category_dev.txt",
            test_file="toutiao_category_test.txt",
            label_file=None,
            label_list=label_list)

    def get_label_name(self, id):
        return LABEL_NAME[id]

    def _read_file(self, input_file, phase=None):
        """Reads a tab separated value file."""
        with io.open(input_file, "r", encoding="UTF-8") as file:
            examples = []
@@ -120,5 +83,13 @@ class TNews(HubDataset):
if __name__ == "__main__":
    ds = TNews()
    print("first 10 dev")
    for e in ds.get_dev_examples()[:10]:
        print("{}\t{}\t{}\t{}".format(e.guid, e.text_a, e.text_b, e.label))
    print("first 10 train")
    for e in ds.get_train_examples()[:10]:
        print("{}\t{}\t{}\t{}".format(e.guid, e.text_a, e.text_b, e.label))
    print("first 10 test")
    for e in ds.get_test_examples()[:10]:
        print("{}\t{}\t{}\t{}".format(e.guid, e.text_a, e.text_b, e.label))
    print(ds)
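get_label_name turns a raw Toutiao category id into a readable name; a sketch, assuming LABEL_NAME (collapsed above) is keyed by the ids in label_list:

ds = TNews()
for e in ds.get_dev_examples()[:3]:
    print(e.text_a, e.label, ds.get_label_name(e.label))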
@@ -17,73 +17,39 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

import pandas as pd

from paddlehub.dataset import InputExample
from paddlehub.common.dir import DATA_HOME
from paddlehub.dataset.base_nlp_dataset import BaseNLPDatast

_DATA_URL = "https://bj.bcebos.com/paddlehub-dataset/toxic.tar.gz"


class Toxic(BaseNLPDatast):
    """
    The Kaggle Toxic dataset:
    https://www.kaggle.com/c/jigsaw-toxic-comment-classification-challenge
    """

    def __init__(self):
        dataset_dir = os.path.join(DATA_HOME, "toxic")
        base_path = self._download_dataset(dataset_dir, url=_DATA_URL)
        label_list = [
            'toxic', 'severe_toxic', 'obscene', 'threat', 'insult',
            'identity_hate'
        ]
        super(Toxic, self).__init__(
            base_path=base_path,
            train_file="train.csv",
            dev_file="dev.csv",
            test_file="test.csv",
            label_file=None,
            label_list=label_list)

    def _read_file(self, input_file, phase=None):
        """Reads a comma separated value file."""
        data = pd.read_csv(input_file, encoding="UTF-8")
        examples = []
@@ -99,5 +65,13 @@ class Toxic(HubDataset):
if __name__ == "__main__":
    ds = Toxic()
    print("first 10 dev")
    for e in ds.get_dev_examples()[:10]:
        print("{}\t{}\t{}\t{}".format(e.guid, e.text_a, e.text_b, e.label))
    print("first 10 train")
    for e in ds.get_train_examples()[:10]:
        print("{}\t{}\t{}\t{}".format(e.guid, e.text_a, e.text_b, e.label))
    print("first 10 test")
    for e in ds.get_test_examples()[:10]:
        print("{}\t{}\t{}\t{}".format(e.guid, e.text_a, e.text_b, e.label))
    print(ds)
@@ -23,15 +23,14 @@ import io
import os
import csv

from paddlehub.dataset import InputExample
from paddlehub.common.dir import DATA_HOME
from paddlehub.dataset.base_nlp_dataset import BaseNLPDatast

_DATA_URL = "https://bj.bcebos.com/paddlehub-dataset/XNLI-lan.tar.gz"


class XNLI(BaseNLPDatast):
    """
    Please refer to
    https://arxiv.org/pdf/1809.05053.pdf
@@ -43,61 +42,25 @@ class XNLI(HubDataset):
            "ar", "bg", "de", "el", "en", "es", "fr", "hi", "ru", "sw",
            "th", "tr", "ur", "vi", "zh"
        ]:
            raise Exception(
                "%s is not in XNLI. Please confirm the language" % language)
        self.language = language
        dataset_dir = os.path.join(DATA_HOME, "XNLI-lan")
        dataset_dir = self._download_dataset(dataset_dir, url=_DATA_URL)
        base_path = os.path.join(dataset_dir, language)
        super(XNLI, self).__init__(
            base_path=base_path,
            train_file="%s_train.tsv" % language,
            dev_file="%s_dev.tsv" % language,
            test_file="%s_test.tsv" % language,
            label_file=None,
            label_list=["neutral", "contradiction", "entailment"])

    def _read_file(self, input_file, phase=None):
        """Reads a tab separated value file."""
        with io.open(input_file, "r", encoding="UTF-8") as f:
            reader = csv.reader(f, delimiter="\t", quotechar=None)
            examples = []
            seq_id = 0
            header = next(reader)  # skip header
@@ -112,5 +75,13 @@ class XNLI(HubDataset):
if __name__ == "__main__":
    ds = XNLI()
    print("first 10 dev")
    for e in ds.get_dev_examples()[:10]:
        print("{}\t{}\t{}\t{}".format(e.guid, e.text_a, e.text_b, e.label))
    print("first 10 train")
    for e in ds.get_train_examples()[:10]:
        print("{}\t{}\t{}\t{}".format(e.guid, e.text_a, e.text_b, e.label))
    print("first 10 test")
    for e in ds.get_test_examples()[:10]:
        print("{}\t{}\t{}\t{}".format(e.guid, e.text_a, e.text_b, e.label))
    print(ds)
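Each language lives in its own subdirectory with per-language TSV files, so switching corpora is a constructor argument; a brief sketch:

ds_zh = XNLI(language="zh")  # zh/zh_train.tsv, zh/zh_dev.tsv, zh/zh_test.tsv
ds_de = XNLI(language="de")
print(ds_zh.get_dev_examples()[0].label)  # one of the three labels above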
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from .base_task import BaseTask, RunEnv, RunState
from .classifier_task import ClassifierTask, ImageClassifierTask, TextClassifierTask, MultiLabelClassifierTask
from .reading_comprehension_task import ReadingComprehensionTask
from .regression_task import RegressionTask
...
@@ -192,7 +192,7 @@ class TaskHooks():
        return self.info(only_customized=False)


class BaseTask(object):
    def __init__(self,
                 feed_list,
                 data_reader,
@@ -265,7 +265,7 @@ class BasicTask(object):
        for hook_type, event_hooks in self._hooks._registered_hooks.items():
            self._hooks.add(hook_type, "default",
                            eval("self._default_%s_event" % hook_type))
            setattr(BaseTask, "_%s_event" % hook_type,
                    self.create_event_function(hook_type))

        # accelerate predict
...
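The rename is mechanical, so downstream code migrates with an import swap; a sketch (the old path shown for reference only):

# before this commit:
# from paddlehub.finetune.task import BasicTask
from paddlehub.finetune.task import BaseTask  # re-exported via task/__init__.py above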
@@ -23,10 +23,10 @@ import numpy as np
import paddle.fluid as fluid

from paddlehub.finetune.evaluate import calculate_f1_np, matthews_corrcoef
from .base_task import BaseTask


class ClassifierTask(BaseTask):
    def __init__(self,
                 feature,
                 num_classes,
...
@@ -28,7 +28,7 @@ from collections import OrderedDict
import numpy as np
import paddle.fluid as fluid

from .base_task import BaseTask
from paddlehub.common.logger import logger
from paddlehub.reader import tokenization
from paddlehub.finetune.evaluator import squad1_evaluate
@@ -176,6 +176,13 @@ def write_predictions(all_examples, all_features, all_results, n_best_size,
                      output_nbest_file, output_null_log_odds_file,
                      version_2_with_negative, null_score_diff_threshold,
                      is_english):
    _PrelimPrediction = collections.namedtuple("PrelimPrediction", [
        "feature_index", "start_index", "end_index", "start_logit", "end_logit"
    ])
    _NbestPrediction = collections.namedtuple(
        "NbestPrediction", ["text", "start_logit", "end_logit"])

    example_index_to_features = collections.defaultdict(list)
    for feature in all_features:
        example_index_to_features[feature.example_index].append(feature)
@@ -184,10 +191,6 @@ def write_predictions(all_examples, all_features, all_results, n_best_size,
    for result in all_results:
        unique_id_to_result[result.unique_id] = result

    all_predictions = collections.OrderedDict()
    all_nbest_json = collections.OrderedDict()
    scores_diff_json = collections.OrderedDict()
@@ -262,9 +265,6 @@ def write_predictions(all_examples, all_features, all_results, n_best_size,
            key=lambda x: (x.start_logit + x.end_logit),
            reverse=True)

        seen_predictions = {}
        nbest = []
        if not prelim_predictions:
@@ -384,7 +384,7 @@ def write_predictions(all_examples, all_features, all_results, n_best_size,
                + "\n")


class ReadingComprehensionTask(BaseTask):
    def __init__(self,
                 feature,
                 feed_list,
@@ -420,6 +420,9 @@ class ReadingComprehensionTask(BasicTask):
        self.n_best_size = n_best_size
        self.max_answer_length = max_answer_length
        self.RawResult = collections.namedtuple(
            "RawResult", ["unique_id", "start_logits", "end_logits"])

    def _build_net(self):
        self.unique_ids = fluid.layers.data(
            name="unique_ids", shape=[-1, 1], lod_level=0, dtype="int64")
@@ -493,8 +496,6 @@ class ReadingComprehensionTask(BasicTask):
    def _calculate_metrics(self, run_states):
        total_cost, total_num_seqs, all_results = [], [], []
        run_step = 0
        for run_state in run_states:
            np_loss = run_state.run_results[0]
            np_num_seqs = run_state.run_results[1]
@@ -510,7 +511,7 @@ class ReadingComprehensionTask(BasicTask):
                start_logits = [float(x) for x in np_start_logits[idx].flat]
                end_logits = [float(x) for x in np_end_logits[idx].flat]
                all_results.append(
                    self.RawResult(
                        unique_id=unique_id,
                        start_logits=start_logits,
                        end_logits=end_logits))
@@ -544,13 +545,13 @@ class ReadingComprehensionTask(BasicTask):
            is_english=self.is_english)
        if self.phase == 'val' or self.phase == 'dev':
            with open(
                    self.data_reader.dataset.dev_path, 'r',
                    encoding="utf8") as dataset_file:
                dataset_json = json.load(dataset_file)
                dataset = dataset_json['data']
        elif self.phase == 'test':
            with open(
                    self.data_reader.dataset.test_path, 'r',
                    encoding="utf8") as dataset_file:
                dataset_json = json.load(dataset_file)
                dataset = dataset_json['data']
@@ -577,8 +578,6 @@ class ReadingComprehensionTask(BasicTask):
    def _default_predict_end_event(self, run_states):
        all_results = []
        for run_state in run_states:
            np_unique_ids = run_state.run_results[0]
            np_start_logits = run_state.run_results[1]
@@ -588,7 +587,7 @@ class ReadingComprehensionTask(BasicTask):
                start_logits = [float(x) for x in np_start_logits[idx].flat]
                end_logits = [float(x) for x in np_end_logits[idx].flat]
                all_results.append(
                    self.RawResult(
                        unique_id=unique_id,
                        start_logits=start_logits,
                        end_logits=end_logits))
...
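Hoisting RawResult into __init__ means every phase shares one namedtuple class instead of re-creating it per call; a standalone sketch of the pattern:

import collections

# Built once and reused, as the task now does in its constructor:
RawResult = collections.namedtuple(
    "RawResult", ["unique_id", "start_logits", "end_logits"])

r = RawResult(unique_id=7, start_logits=[0.1, 0.9], end_logits=[0.8, 0.2])
print(r.unique_id, max(r.start_logits))
# Re-declaring the namedtuple inside each method would create a new, distinct
# class object on every call; defining it once avoids that overhead and keeps
# evaluation and prediction results instances of the same type.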
@@ -23,10 +23,10 @@ from collections import OrderedDict
import numpy as np
import paddle.fluid as fluid
from scipy.stats import spearmanr

from .base_task import BaseTask


class RegressionTask(BaseTask):
    def __init__(self,
                 feature,
                 feed_list,
...
@@ -25,10 +25,10 @@ import paddle
import paddle.fluid as fluid

from paddlehub.finetune.evaluate import chunk_eval, calculate_f1
from paddlehub.common.utils import version_compare
from .base_task import BaseTask


class SequenceLabelTask(BaseTask):
    def __init__(self,
                 feature,
                 max_seq_len,
...
import numpy as np


class BaseReader(object):
    def __init__(self, dataset, random_seed=None):
        self.dataset = dataset
        self.num_examples = {'train': -1, 'dev': -1, 'test': -1}
        np.random.seed(random_seed)

    def get_train_examples(self):
        return self.dataset.get_train_examples()

    def get_dev_examples(self):
        return self.dataset.get_dev_examples()

    def get_test_examples(self):
        return self.dataset.get_test_examples()

    def data_generator(self):
        raise NotImplementedError
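Concrete readers inherit the example accessors and only override data_generator; a minimal sketch (the yielded tuple shape is an assumption for illustration):

class MinimalTextReader(BaseReader):
    def data_generator(self, batch_size=1, phase="train", shuffle=False):
        examples = {
            "train": self.get_train_examples,
            "dev": self.get_dev_examples,
            "test": self.get_test_examples,
        }[phase]()
        self.num_examples[phase] = len(examples)

        def reader():
            if shuffle:
                np.random.shuffle(examples)
            for e in examples:
                yield (e.text_a, e.label)  # field names follow InputExample

        return reader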
@@ -22,6 +22,7 @@ import numpy as np
from PIL import Image

import paddlehub.io.augmentation as image_augmentation
from .base_reader import BaseReader

channel_order_dict = {
    "RGB": [0, 1, 2],
@@ -33,7 +34,7 @@ channel_order_dict = {
}


class ImageClassificationReader(BaseReader):
    def __init__(self,
                 image_width,
                 image_height,
@@ -41,15 +42,15 @@ class ImageClassificationReader(object):
                 channel_order="RGB",
                 images_mean=None,
                 images_std=None,
                 data_augmentation=False,
                 random_seed=None):
        super(ImageClassificationReader, self).__init__(dataset, random_seed)
        self.image_width = image_width
        self.image_height = image_height
        self.channel_order = channel_order
        self.data_augmentation = data_augmentation
        self.images_std = images_std
        self.images_mean = images_mean

        if self.images_mean is None:
            try:
@@ -73,24 +74,38 @@ class ImageClassificationReader(object):
            raise ValueError("Image width and height should not be negative.")

    def data_generator(self,
                       batch_size=1,
                       phase="train",
                       shuffle=False,
                       data=None):
        if phase != 'predict' and not self.dataset:
            raise ValueError("The dataset is none and it's not allowed!")
        if phase == "train":
            shuffle = True
            if hasattr(self.dataset, "train_data"):
                # Compatible with ImageClassificationDataset which has done shuffle
                self.dataset.train_data()
                shuffle = False
            data = self.get_train_examples()
            self.num_examples['train'] = len(data)
        elif phase == "val" or phase == "dev":
            shuffle = False
            if hasattr(self.dataset, "validate_data"):
                # Compatible with ImageClassificationDataset
                self.dataset.validate_data()
                shuffle = False
            data = self.get_dev_examples()
            self.num_examples['dev'] = len(data)
        elif phase == "test":
            shuffle = False
            if hasattr(self.dataset, "test_data"):
                # Compatible with ImageClassificationDataset
                data = self.dataset.test_data()
                shuffle = False
            data = self.get_test_examples()
            self.num_examples['test'] = len(data)
        elif phase == "predict":
            shuffle = False
            data = data

        def preprocess(image_path):
@@ -118,6 +133,9 @@ class ImageClassificationReader(object):
            return image

        def _data_reader():
            if shuffle:
                np.random.shuffle(data)
            if phase == "predict":
                for image_path in data:
                    image = preprocess(image_path)
@@ -128,12 +146,3 @@ class ImageClassificationReader(object):
                    yield (image, label)

        return paddle.batch(_data_reader, batch_size=batch_size)
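A usage sketch of the unified reader (the dataset import path and sizes are assumptions); data_generator returns a paddle.batch-wrapped callable, so call it to iterate batches:

from paddlehub.dataset.stanford_dogs import StanfordDogsDataset  # path assumed

reader = ImageClassificationReader(
    image_width=224,
    image_height=224,
    dataset=StanfordDogsDataset(),
    data_augmentation=True,
    random_seed=0)

train_batches = reader.data_generator(batch_size=32, phase="train")
for batch in train_batches():     # paddle.batch returns a callable reader
    images, labels = zip(*batch)  # each item is an (image, label) pair
    break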
This diff is collapsed.