未验证 提交 cf2e2c78 编写于 作者: K KP 提交者: GitHub

Update download_data wrapper

上级 a5c91d1d
......@@ -32,7 +32,7 @@ python train.py
在命名实体识别的任务中,因不同的数据集标识实体的标签不同,评测的方式也有所差异。因此,在初始化模型之前,需要先确定实际标签的形式,下方的`label_list`则是MSRA-NER数据集中使用的标签类别。
如果用户使用的实体识别数据集的标签方式与MSRA-NER不同,则需要自行根据数据集确定`label_list`。
```python
label_list = ["B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC", "I-LOC", "O"]
label_list = hub.datasets.MSRA_NER.label_list
label_map = {
idx: label for idx, label in enumerate(label_list)
}
......
......@@ -32,7 +32,7 @@ args = parser.parse_args()
if __name__ == '__main__':
label_list = ["B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC", "I-LOC", "O"]
label_list = MSRA_NER.label_list
label_map = {
idx: label for idx, label in enumerate(label_list)
}
......
......@@ -31,6 +31,7 @@ class MSRA_NER(SeqLabelingDataset):
for research purposes. For more information please refer to
https://www.microsoft.com/en-us/download/details.aspx?id=52531
"""
label_list = ["B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC", "I-LOC", "O"]
def __init__(
self,
......@@ -39,7 +40,6 @@ class MSRA_NER(SeqLabelingDataset):
mode: str = 'train',
):
base_path = os.path.join(DATA_HOME, "msra_ner")
label_list = ["B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC", "I-LOC", "O"]
if mode == 'train':
data_file = 'train.tsv'
......@@ -54,6 +54,6 @@ class MSRA_NER(SeqLabelingDataset):
mode=mode,
data_file=data_file,
label_file=None,
label_list=label_list,
label_list=self.label_list,
is_file_with_header=True,
)
......@@ -25,8 +25,7 @@ from paddlehub.utils import log, utils, xarfile
def download_data(url):
def _wrapper(Dataset):
def _download_dataset_from_url(*args, **kwargs):
def _check_download():
save_name = os.path.basename(url).split('.')[0]
output_path = os.path.join(hubenv.DATA_HOME, save_name)
lock = filelock.FileLock(os.path.join(hubenv.TMP_HOME, save_name))
......@@ -34,8 +33,12 @@ def download_data(url):
if not os.path.exists(output_path):
default_downloader.download_file_and_uncompress(url, hubenv.DATA_HOME, True)
return Dataset(*args, **kwargs)
return _download_dataset_from_url
class WrapperDataset(Dataset):
def __new__(cls, *args, **kwargs):
_check_download()
return super(WrapperDataset, cls).__new__(cls)
return WrapperDataset
return _wrapper
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册