提交 4d23b9da 编写于 作者: W wangxiao

update downloader.py

上级 19370b22
......@@ -15,6 +15,8 @@
from __future__ import print_function
import os
import tarfile
import shutil
try:
from urllib.request import urlopen # Python 3
except ImportError:
......@@ -40,16 +42,19 @@ def _download(item, scope, path, silent=False):
data_dir = path + '/' + item + '/' + scope
if not os.path.exists(data_dir):
os.makedirs(os.path.join(data_dir))
filename = data_dir + '/' + data_url.split('/')[-1]
data_name = data_url.split('/')[-1]
filename = data_dir + '/' + data_name
def chunk_report(bytes_so_far, total_size):
# print process
def _chunk_report(bytes_so_far, total_size):
percent = float(bytes_so_far) / float(total_size)
if percent > 1:
percent = 1
if not silent:
print('\r>> Downloading... {:.1%}'.format(percent), end = "")
def chunk_read(response, url, chunk_size = 16 * 1024, report_hook = None):
# copy to local
def _chunk_read(response, url, chunk_size = 16 * 1024, report_hook = None):
total_size = response.info().getheader('Content-Length').strip()
total_size = int(total_size)
bytes_so_far = 0
......@@ -66,14 +71,51 @@ def _download(item, scope, path, silent=False):
return bytes_so_far
response = urlopen(data_url)
chunk_read(response, data_url, report_hook=chunk_report)
_chunk_read(response, data_url, report_hook=_chunk_report)
if not silent:
print(' done!')
if item == 'pretrain':
if not silent:
print ('Extracting {}...'.format(data_name), end=" ")
if os.path.exists(filename):
tar = tarfile.open(filename, 'r')
tar.extractall(path = data_dir)
tar.close()
os.remove(filename)
if not silent:
print ('done!')
if not silent:
print ('Converting params...', end=" ")
_convert(data_dir + '/' + data_name.split('.')[0], silent)
if not silent:
print ('done!')
def _convert():
raise NotImplementedError()
def _convert(path, silent=False):
if os.path.isfile(path + '/params/__palminfo__'):
if not silent:
print ('already converted.')
else:
if os.path.exists(path + '/params/'):
os.rename(path + '/params/', path + '/params1/')
os.mkdir(path + '/params/')
tar_model = tarfile.open(path + '/params/' + '__palmmodel__', 'w:gz')
tar_info = open(path + '/params/'+ '__palminfo__', 'w')
for root, dirs, files in os.walk(path + '/params1/'):
for file in files:
src_file = os.path.join(root, file)
newname = path + '/params1/' + '__paddlepalm_' + file
os.rename(src_file, newname)
tar_model.add(newname)
tar_info.write(newname)
os.remove(newname)
tar_model.close()
tar_info.close()
os.removedirs(path + '/params1/')
# raise NotImplementedError()
def download(item, scope='all', path='.'):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册