提交 d4a2cb17 编写于 作者: W wangxiao

update downloader.py

上级 4d23b9da
...@@ -84,11 +84,22 @@ def _download(item, scope, path, silent=False): ...@@ -84,11 +84,22 @@ def _download(item, scope, path, silent=False):
tar.extractall(path = data_dir) tar.extractall(path = data_dir)
tar.close() tar.close()
os.remove(filename) os.remove(filename)
if scope == 'bert-en-uncased-large':
source_path = data_dir + '/' + data_name.split('.')[0]
fileList = os.listdir(source_path)
print('source: {}'.format(source_path))
print('filelist: {}'.format(fileList))
for file in fileList:
filePath = os.path.join(source_path, file)
print('filepath: {}'.format(filePath))
print('datadir: {}'.format(data_dir))
shutil.move(filePath, data_dir)
os.removedirs(source_path)
if not silent: if not silent:
print ('done!') print ('done!')
if not silent: if not silent:
print ('Converting params...', end=" ") print ('Converting params...', end=" ")
_convert(data_dir + '/' + data_name.split('.')[0], silent) _convert(data_dir, silent)
if not silent: if not silent:
print ('done!') print ('done!')
...@@ -115,9 +126,6 @@ def _convert(path, silent=False): ...@@ -115,9 +126,6 @@ def _convert(path, silent=False):
tar_info.close() tar_info.close()
os.removedirs(path + '/params1/') os.removedirs(path + '/params1/')
# raise NotImplementedError()
def download(item, scope='all', path='.'): def download(item, scope='all', path='.'):
item = item.lower() item = item.lower()
scope = scope.lower() scope = scope.lower()
...@@ -136,5 +144,3 @@ def download(item, scope='all', path='.'): ...@@ -136,5 +144,3 @@ def download(item, scope='all', path='.'):
def ls(item=None, scope='all'): def ls(item=None, scope='all'):
pass pass
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册