提交 ff3bc5b8 编写于 作者: Z Zeyu Chen

fix typo in ernie sequence labeling task

上级 6985001a
......@@ -46,7 +46,7 @@ if __name__ == '__main__':
trainable=True, max_seq_len=args.max_seq_len)
# Step2: Download dataset and use SequenceLabelReader to read dataset
dataset = hub.dataset.MSRA_NER(),
dataset = hub.dataset.MSRA_NER()
reader = hub.reader.SequenceLabelReader(
dataset=dataset,
vocab_path=module.get_vocab_path(),
......@@ -91,6 +91,7 @@ if __name__ == '__main__':
use_cuda=True,
num_epoch=args.num_epoch,
batch_size=args.batch_size,
checkpoint_dir=args.checkpoint_dir,
strategy=strategy)
# Fine-tune and evaluate the model with PaddleHub's API.
# This finishes training, evaluation, testing, and model saving automatically.
......
......@@ -68,7 +68,6 @@ class ChnSentiCorp(HubDataset):
return self.test_examples
def get_labels(self):
"""See base class."""
return ["0", "1"]
def _read_tsv(self, input_file, quotechar=None):
......
......@@ -21,6 +21,7 @@ import csv
import json
from collections import namedtuple
from paddlehub.dataset import InputExample, HubDataset
from paddlehub.common.downloader import default_downloader
from paddlehub.common.dir import DATA_HOME
from paddlehub.common.logger import logger
......@@ -28,7 +29,14 @@ from paddlehub.common.logger import logger
DATA_URL = "https://paddlehub-dataset.bj.bcebos.com/msra_ner.tar.gz"
class MSRA_NER(object):
class MSRA_NER(HubDataset):
"""
A set of manually annotated Chinese word-segmentation data and
specifications for training and testing a Chinese word-segmentation system
for research purposes. For more information please refer to
https://www.microsoft.com/en-us/download/details.aspx?id=52531
"""
def __init__(self):
self.dataset_dir = os.path.join(DATA_HOME, "msra_ner")
if not os.path.exists(self.dataset_dir):
......@@ -78,12 +86,13 @@ class MSRA_NER(object):
"""Reads a tab separated value file."""
with open(input_file, "r") as f:
reader = csv.reader(f, delimiter="\t", quotechar=quotechar)
headers = next(reader)
Example = namedtuple('Example', headers)
examples = []
seq_id = 0
header = next(reader) # skip header
for line in reader:
example = Example(*line)
example = InputExample(
guid=seq_id, label=line[1], text_a=line[0])
seq_id += 1
examples.append(example)
return examples
......@@ -92,4 +101,4 @@ class MSRA_NER(object):
if __name__ == "__main__":
ds = MSRA_NER()
for e in ds.get_train_examples():
print(e)
print("{}\t{}\t{}\t{}".format(e.guid, e.text_a, e.text_b, e.label))
......@@ -29,6 +29,12 @@ DATA_URL = "https://paddlehub-dataset.bj.bcebos.com/nlpcc-dbqa.tar.gz"
class NLPCC_DBQA(HubDataset):
"""
Please refer to
http://tcci.ccf.org.cn/conference/2017/dldoc/taskgline05.pdf
for more information
"""
def __init__(self):
self.dataset_dir = os.path.join(DATA_HOME, "nlpcc-dbqa")
if not os.path.exists(self.dataset_dir):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册