提交 ff3bc5b8 作者: Zeyu Chen

Fix typo in ERNIE sequence labeling task

上级 6985001a
...@@ -46,7 +46,7 @@ if __name__ == '__main__': ...@@ -46,7 +46,7 @@ if __name__ == '__main__':
trainable=True, max_seq_len=args.max_seq_len) trainable=True, max_seq_len=args.max_seq_len)
# Step2: Download dataset and use SequenceLabelReader to read dataset # Step2: Download dataset and use SequenceLabelReader to read dataset
dataset = hub.dataset.MSRA_NER(), dataset = hub.dataset.MSRA_NER()
reader = hub.reader.SequenceLabelReader( reader = hub.reader.SequenceLabelReader(
dataset=dataset, dataset=dataset,
vocab_path=module.get_vocab_path(), vocab_path=module.get_vocab_path(),
...@@ -91,6 +91,7 @@ if __name__ == '__main__': ...@@ -91,6 +91,7 @@ if __name__ == '__main__':
use_cuda=True, use_cuda=True,
num_epoch=args.num_epoch, num_epoch=args.num_epoch,
batch_size=args.batch_size, batch_size=args.batch_size,
checkpoint_dir=args.checkpoint_dir,
strategy=strategy) strategy=strategy)
# Finetune and evaluate model by PaddleHub's API # Finetune and evaluate model by PaddleHub's API
# will finish training, evaluation, testing, save model automatically # will finish training, evaluation, testing, save model automatically
......
...@@ -68,7 +68,6 @@ class ChnSentiCorp(HubDataset): ...@@ -68,7 +68,6 @@ class ChnSentiCorp(HubDataset):
return self.test_examples return self.test_examples
def get_labels(self): def get_labels(self):
"""See base class."""
return ["0", "1"] return ["0", "1"]
def _read_tsv(self, input_file, quotechar=None): def _read_tsv(self, input_file, quotechar=None):
......
...@@ -21,6 +21,7 @@ import csv ...@@ -21,6 +21,7 @@ import csv
import json import json
from collections import namedtuple from collections import namedtuple
from paddlehub.dataset import InputExample, HubDataset
from paddlehub.common.downloader import default_downloader from paddlehub.common.downloader import default_downloader
from paddlehub.common.dir import DATA_HOME from paddlehub.common.dir import DATA_HOME
from paddlehub.common.logger import logger from paddlehub.common.logger import logger
...@@ -28,7 +29,14 @@ from paddlehub.common.logger import logger ...@@ -28,7 +29,14 @@ from paddlehub.common.logger import logger
DATA_URL = "https://paddlehub-dataset.bj.bcebos.com/msra_ner.tar.gz" DATA_URL = "https://paddlehub-dataset.bj.bcebos.com/msra_ner.tar.gz"
class MSRA_NER(object): class MSRA_NER(HubDataset):
"""
A set of manually annotated Chinese word-segmentation data and
specifications for training and testing a Chinese word-segmentation system
for research purposes. For more information please refer to
https://www.microsoft.com/en-us/download/details.aspx?id=52531
"""
def __init__(self): def __init__(self):
self.dataset_dir = os.path.join(DATA_HOME, "msra_ner") self.dataset_dir = os.path.join(DATA_HOME, "msra_ner")
if not os.path.exists(self.dataset_dir): if not os.path.exists(self.dataset_dir):
...@@ -78,12 +86,13 @@ class MSRA_NER(object): ...@@ -78,12 +86,13 @@ class MSRA_NER(object):
"""Reads a tab separated value file.""" """Reads a tab separated value file."""
with open(input_file, "r") as f: with open(input_file, "r") as f:
reader = csv.reader(f, delimiter="\t", quotechar=quotechar) reader = csv.reader(f, delimiter="\t", quotechar=quotechar)
headers = next(reader)
Example = namedtuple('Example', headers)
examples = [] examples = []
seq_id = 0
header = next(reader) # skip header
for line in reader: for line in reader:
example = Example(*line) example = InputExample(
guid=seq_id, label=line[1], text_a=line[0])
seq_id += 1
examples.append(example) examples.append(example)
return examples return examples
...@@ -92,4 +101,4 @@ class MSRA_NER(object): ...@@ -92,4 +101,4 @@ class MSRA_NER(object):
if __name__ == "__main__": if __name__ == "__main__":
ds = MSRA_NER() ds = MSRA_NER()
for e in ds.get_train_examples(): for e in ds.get_train_examples():
print(e) print("{}\t{}\t{}\t{}".format(e.guid, e.text_a, e.text_b, e.label))
...@@ -29,6 +29,12 @@ DATA_URL = "https://paddlehub-dataset.bj.bcebos.com/nlpcc-dbqa.tar.gz" ...@@ -29,6 +29,12 @@ DATA_URL = "https://paddlehub-dataset.bj.bcebos.com/nlpcc-dbqa.tar.gz"
class NLPCC_DBQA(HubDataset): class NLPCC_DBQA(HubDataset):
"""
Please refer to
http://tcci.ccf.org.cn/conference/2017/dldoc/taskgline05.pdf
for more information
"""
def __init__(self): def __init__(self):
self.dataset_dir = os.path.join(DATA_HOME, "nlpcc-dbqa") self.dataset_dir = os.path.join(DATA_HOME, "nlpcc-dbqa")
if not os.path.exists(self.dataset_dir): if not os.path.exists(self.dataset_dir):
......
Markdown is supported
0%
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请注册