diff --git a/demo/sequence_labeling/README.md b/demo/sequence_labeling/README.md index c18c0fca280c2c1530057db3c55d46877a6ae251..5c8be25321dab809b0ebedd3f8dce6e0c5da6554 100644 --- a/demo/sequence_labeling/README.md +++ b/demo/sequence_labeling/README.md @@ -32,7 +32,7 @@ python train.py 在命名实体识别的任务中,因不同的数据集标识实体的标签不同,评测的方式也有所差异。因此,在初始化模型的之前,需要先确定实际标签的形式,下方的`label_list`则是MSRA-NER数据集中使用的标签类别。 如果用户使用的实体识别的数据集的标签方式与MSRA-NER不同,则需要自行根据数据集确定。 ```python -label_list = ["B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC", "I-LOC", "O"] +label_list = hub.datasets.MSRA_NER.label_list label_map = { idx: label for idx, label in enumerate(label_list) } diff --git a/demo/sequence_labeling/train.py b/demo/sequence_labeling/train.py index 3e26d20b835b2413fc4758bfbe250e1f916f2dcd..b9acbf0a9c3a38eeb618d11a4c7189e8a6ae71e5 100644 --- a/demo/sequence_labeling/train.py +++ b/demo/sequence_labeling/train.py @@ -32,7 +32,7 @@ args = parser.parse_args() if __name__ == '__main__': - label_list = ["B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC", "I-LOC", "O"] + label_list = MSRA_NER.label_list label_map = { idx: label for idx, label in enumerate(label_list) } diff --git a/paddlehub/datasets/msra_ner.py b/paddlehub/datasets/msra_ner.py index 8440e7c0b2f70e362db37519d57591f68fd45fd9..e258b7414f37f5175c41bf3657003340c27ebaaf 100644 --- a/paddlehub/datasets/msra_ner.py +++ b/paddlehub/datasets/msra_ner.py @@ -31,6 +31,7 @@ class MSRA_NER(SeqLabelingDataset): for research purposes. For more information please refer to https://www.microsoft.com/en-us/download/details.aspx?id=52531 """ + label_list = ["B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC", "I-LOC", "O"] def __init__( self, @@ -39,7 +40,6 @@ class MSRA_NER(SeqLabelingDataset): mode: str = 'train', ): base_path = os.path.join(DATA_HOME, "msra_ner") - label_list = ["B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC", "I-LOC", "O"] if mode == 'train': data_file = 'train.tsv' @@ -54,6 +54,6 @@ class MSRA_NER(SeqLabelingDataset): mode=mode, data_file=data_file, label_file=None, - label_list=label_list, + label_list=self.label_list, is_file_with_header=True, ) diff --git a/paddlehub/utils/download.py b/paddlehub/utils/download.py index 5f80e41f46ff13de765bd64189e67e9b0e8da2d4..d0d83a2e55bee85c0ebdca5e716327baf4d28861 100644 --- a/paddlehub/utils/download.py +++ b/paddlehub/utils/download.py @@ -25,17 +25,20 @@ from paddlehub.utils import log, utils, xarfile def download_data(url): def _wrapper(Dataset): - - def _download_dataset_from_url(*args, **kwargs): + def _check_download(): save_name = os.path.basename(url).split('.')[0] output_path = os.path.join(hubenv.DATA_HOME, save_name) lock = filelock.FileLock(os.path.join(hubenv.TMP_HOME, save_name)) with lock: if not os.path.exists(output_path): default_downloader.download_file_and_uncompress(url, hubenv.DATA_HOME, True) - - return Dataset(*args, **kwargs) - return _download_dataset_from_url + + class WrapperDataset(Dataset): + def __new__(cls, *args, **kwargs): + _check_download() + return super(WrapperDataset, cls).__new__(cls) + + return WrapperDataset return _wrapper