未验证 提交 cf2e2c78 编写于 作者: K KP 提交者: GitHub

Update download_data wrapper

上级 a5c91d1d
...@@ -32,7 +32,7 @@ python train.py ...@@ -32,7 +32,7 @@ python train.py
在命名实体识别的任务中,因不同的数据集标识实体的标签不同,评测的方式也有所差异。因此,在初始化模型的之前,需要先确定实际标签的形式,下方的`label_list`则是MSRA-NER数据集中使用的标签类别。 在命名实体识别的任务中,因不同的数据集标识实体的标签不同,评测的方式也有所差异。因此,在初始化模型的之前,需要先确定实际标签的形式,下方的`label_list`则是MSRA-NER数据集中使用的标签类别。
如果用户使用的实体识别的数据集的标签方式与MSRA-NER不同,则需要自行根据数据集确定。 如果用户使用的实体识别的数据集的标签方式与MSRA-NER不同,则需要自行根据数据集确定。
```python ```python
label_list = ["B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC", "I-LOC", "O"] label_list = hub.datasets.MSRA_NER.label_list
label_map = { label_map = {
idx: label for idx, label in enumerate(label_list) idx: label for idx, label in enumerate(label_list)
} }
......
...@@ -32,7 +32,7 @@ args = parser.parse_args() ...@@ -32,7 +32,7 @@ args = parser.parse_args()
if __name__ == '__main__': if __name__ == '__main__':
label_list = ["B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC", "I-LOC", "O"] label_list = MSRA_NER.label_list
label_map = { label_map = {
idx: label for idx, label in enumerate(label_list) idx: label for idx, label in enumerate(label_list)
} }
......
...@@ -31,6 +31,7 @@ class MSRA_NER(SeqLabelingDataset): ...@@ -31,6 +31,7 @@ class MSRA_NER(SeqLabelingDataset):
for research purposes. For more information please refer to for research purposes. For more information please refer to
https://www.microsoft.com/en-us/download/details.aspx?id=52531 https://www.microsoft.com/en-us/download/details.aspx?id=52531
""" """
label_list = ["B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC", "I-LOC", "O"]
def __init__( def __init__(
self, self,
...@@ -39,7 +40,6 @@ class MSRA_NER(SeqLabelingDataset): ...@@ -39,7 +40,6 @@ class MSRA_NER(SeqLabelingDataset):
mode: str = 'train', mode: str = 'train',
): ):
base_path = os.path.join(DATA_HOME, "msra_ner") base_path = os.path.join(DATA_HOME, "msra_ner")
label_list = ["B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC", "I-LOC", "O"]
if mode == 'train': if mode == 'train':
data_file = 'train.tsv' data_file = 'train.tsv'
...@@ -54,6 +54,6 @@ class MSRA_NER(SeqLabelingDataset): ...@@ -54,6 +54,6 @@ class MSRA_NER(SeqLabelingDataset):
mode=mode, mode=mode,
data_file=data_file, data_file=data_file,
label_file=None, label_file=None,
label_list=label_list, label_list=self.label_list,
is_file_with_header=True, is_file_with_header=True,
) )
...@@ -25,8 +25,7 @@ from paddlehub.utils import log, utils, xarfile ...@@ -25,8 +25,7 @@ from paddlehub.utils import log, utils, xarfile
def download_data(url): def download_data(url):
def _wrapper(Dataset): def _wrapper(Dataset):
def _check_download():
def _download_dataset_from_url(*args, **kwargs):
save_name = os.path.basename(url).split('.')[0] save_name = os.path.basename(url).split('.')[0]
output_path = os.path.join(hubenv.DATA_HOME, save_name) output_path = os.path.join(hubenv.DATA_HOME, save_name)
lock = filelock.FileLock(os.path.join(hubenv.TMP_HOME, save_name)) lock = filelock.FileLock(os.path.join(hubenv.TMP_HOME, save_name))
...@@ -34,8 +33,12 @@ def download_data(url): ...@@ -34,8 +33,12 @@ def download_data(url):
if not os.path.exists(output_path): if not os.path.exists(output_path):
default_downloader.download_file_and_uncompress(url, hubenv.DATA_HOME, True) default_downloader.download_file_and_uncompress(url, hubenv.DATA_HOME, True)
return Dataset(*args, **kwargs) class WrapperDataset(Dataset):
return _download_dataset_from_url def __new__(cls, *args, **kwargs):
_check_download()
return super(WrapperDataset, cls).__new__(cls)
return WrapperDataset
return _wrapper return _wrapper
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册