未验证 提交 881bd978 编写于 作者: W Weiyue Su 提交者: GitHub

local Path in from_pretrained (#600)

Co-authored-by: Nsuweiyue <suweiyue@baidu.com>
上级 8855c7dd
@@ -202,7 +202,7 @@ class PretrainedModel(object):
             pretrain_dir = _fetch_from_remote(url, force_download)
         else:
             log.info('pretrain dir %s not in %s, read from local' % (pretrain_dir_or_url, repr(cls.resource_map)))
-            pretrain_dir = pretrain_dir_or_url
+            pretrain_dir = Path(pretrain_dir_or_url)
         if not pretrain_dir.exists():
             raise ValueError('pretrain dir not found: %s' % pretrain_dir)
......
@@ -24,6 +24,10 @@ import re
 import logging
 import tempfile
 from functools import partial
+if six.PY2:
+    from pathlib2 import Path
+else:
+    from pathlib import Path
 from tqdm import tqdm
 import numpy as np
@@ -90,7 +94,7 @@ class ErnieTokenizer(object):
             pretrain_dir = _fetch_from_remote(url, force_download=force_download)
         else:
             log.info('pretrain dir %s not in %s, read from local' % (pretrain_dir_or_url, repr(cls.resource_map)))
-            pretrain_dir = pretrain_dir_or_url
+            pretrain_dir = Path(pretrain_dir_or_url)
         if not pretrain_dir.exists():
             raise ValueError('pretrain dir not found: %s' % pretrain_dir)
         vocab_path = pretrain_dir / 'vocab.txt'
@@ -206,7 +210,7 @@ class ErnieTinyTokenizer(ErnieTokenizer):
             pretrain_dir = _fetch_from_remote(url, force_download)
         else:
             log.info('pretrain dir %s not in %s, read from local' % (pretrain_dir_or_url, repr(cls.resource_map)))
-            pretrain_dir = pretrain_dir_or_url
+            pretrain_dir = Path(pretrain_dir_or_url)
         if not pretrain_dir.exists():
             raise ValueError('pretrain dir not found: %s' % pretrain_dir)
         vocab_path = pretrain_dir / 'vocab.txt'
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册