downloader.py 5.0 KB
Newer Older
W
wangxiao 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
# -*- coding: UTF-8 -*-
#   Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
X
xixiaoyao 已提交
15 16 17

from __future__ import print_function
import os
W
wangxiao 已提交
18 19
import tarfile
import shutil
W
wangxiao 已提交
20 21 22 23 24 25 26 27
try:
    from urllib.request import urlopen # Python 3
except ImportError:
    from urllib2 import urlopen # Python 2

# for https
# NOTE(review): this replaces the ssl module's default HTTPS context factory
# for the WHOLE process (the global opt-out described in PEP 476). Every HTTPS
# request made afterwards through the default context -- not only the
# downloads in this module -- skips certificate and hostname verification.
# Consider passing an unverified context only to the urlopen call that needs
# it, or dropping this override if the download hosts have valid certificates.
import ssl
ssl._create_default_https_context = ssl._create_unverified_context
X
xixiaoyao 已提交
28 29

# Registry of downloadable resources, keyed as item -> {scope: archive URL}.
# A value of None marks an entry that has no download URL.
_items = {
    'pretrain': {
        'ernie-en-uncased-large':
            'https://ernie.bj.bcebos.com/ERNIE_Large_en_stable-2.0.0.tar.gz',
        'bert-en-uncased-large':
            'https://bert-models.bj.bcebos.com/uncased_L-24_H-1024_A-16.tar.gz',
        'utils': None,
    },
    'reader': {'utils': None},
    'backbone': {'utils': None},
    'tasktype': {'utils': None},
}

W
wangxiao 已提交
38 39 40 41 42 43 44
def _report_progress(bytes_so_far, total_size, silent=False):
    """Print an in-place (carriage-return) download progress line.

    Args:
        bytes_so_far: number of bytes received so far.
        total_size: expected size in bytes, or None/0 when the server did not
            send a usable Content-Length header.
        silent: when True, print nothing.
    """
    if silent:
        return
    if total_size:
        percent = min(float(bytes_so_far) / float(total_size), 1.0)
        print('\r>> Downloading... {:.1%}'.format(percent), end="")
    else:
        # Without a Content-Length a percentage cannot be computed.
        print('\r>> Downloading... {} bytes'.format(bytes_so_far), end="")


def _fetch(url, filename, silent=False, chunk_size=16 * 1024):
    """Stream ``url`` into the local file ``filename``.

    Args:
        url: address to download.
        filename: destination path; it is created or overwritten.
        silent: when True, suppress progress output.
        chunk_size: number of bytes requested per read.

    Returns:
        The number of bytes written.
    """
    response = urlopen(url)
    try:
        # Message.get works on Python 2 (mimetools) and Python 3
        # (http.client.HTTPMessage); getheader() only exists on Python 2.
        length = response.info().get('Content-Length')
        total_size = int(length) if length else None
        bytes_so_far = 0
        with open(filename, 'wb') as f:
            while True:
                chunk = response.read(chunk_size)
                if not chunk:
                    break
                f.write(chunk)
                bytes_so_far += len(chunk)
                _report_progress(bytes_so_far, total_size, silent)
        return bytes_so_far
    finally:
        response.close()


def _extract_pretrain(data_dir, data_name, scope, silent=False):
    """Unpack the archive ``data_dir/data_name`` in place and delete it.

    For the BERT archive, the contents of its single top-level folder (named
    after the archive) are moved up directly into ``data_dir``.

    Args:
        data_dir: directory holding the downloaded archive.
        data_name: archive file name, e.g. 'uncased_L-24_H-1024_A-16.tar.gz'.
        scope: pretrain model name; selects the BERT-specific flattening.
        silent: when True, suppress status messages.
    """
    filename = os.path.join(data_dir, data_name)
    if not silent:
        print('Extracting {}...'.format(data_name), end=" ")
    if os.path.exists(filename):
        # NOTE(review): extractall trusts member paths inside the archive;
        # only feed it archives from the trusted URLs registered in _items.
        tar = tarfile.open(filename, 'r')
        try:
            tar.extractall(path=data_dir)
        finally:
            tar.close()
        os.remove(filename)
    if scope == 'bert-en-uncased-large':
        source_path = os.path.join(data_dir, data_name.split('.')[0])
        for entry in os.listdir(source_path):
            shutil.move(os.path.join(source_path, entry), data_dir)
        # rmdir only removes the now-empty wrapper folder; removedirs would
        # also prune any empty parent directories above it.
        os.rmdir(source_path)
    if not silent:
        print('done!')


def _download(item, scope, path, silent=False):
    """Download one registered resource into ``<path>/<item>/<scope>``.

    Pretrain archives are additionally extracted and their parameters
    converted with ``_convert``. Entries whose URL is None have nothing to
    fetch and are skipped without touching the filesystem.

    Args:
        item: resource category, a key of ``_items`` (e.g. 'pretrain').
        scope: entry within that category (e.g. 'ernie-en-uncased-large').
        path: root directory for downloads.
        silent: when True, suppress progress/status output.

    Raises:
        KeyError: if ``item``/``scope`` is not registered in ``_items``.
    """
    data_url = _items[item][scope]
    if data_url is None:
        return
    if not silent:
        print('Downloading {}: {} from {}...'.format(item, scope, data_url))
    data_dir = os.path.join(path, item, scope)
    if not os.path.exists(data_dir):
        os.makedirs(data_dir)
    data_name = data_url.split('/')[-1]

    _fetch(data_url, os.path.join(data_dir, data_name), silent)
    if not silent:
        print(' done!')

    if item == 'pretrain':
        _extract_pretrain(data_dir, data_name, scope, silent)
        if not silent:
            print('Converting params...', end=" ")
        _convert(data_dir, silent)
        if not silent:
            print('done!')
W
wangxiao 已提交
101

X
xixiaoyao 已提交
102

W
wangxiao 已提交
103 104 105 106 107 108 109 110 111 112 113 114 115
def _convert(path, silent=False):
    """Pack the raw parameter files in ``<path>/params`` into PaddlePALM format.

    After conversion ``<path>/params`` contains exactly two files:

    * ``__palmmodel__`` -- a gzip tar holding every original parameter file,
      each stored under the member name ``'__paddlepalm_' + <file name>``;
    * ``__palminfo__`` -- those member names written back to back.

    A directory that already has ``params/__palminfo__`` is treated as
    converted. A directory without a ``params`` sub-directory is left alone.

    Args:
        path: model directory containing a ``params`` sub-directory.
        silent: when True, suppress the "already converted." message.
    """
    params_dir = os.path.join(path, 'params')
    info_path = os.path.join(params_dir, '__palminfo__')
    if os.path.isfile(info_path):
        if not silent:
            print('already converted.')
        return
    if not os.path.isdir(params_dir):
        return

    # Move the raw files aside so the rebuilt params dir holds only the outputs.
    staging_dir = os.path.join(path, 'params1')
    os.rename(params_dir, staging_dir)
    os.mkdir(params_dir)
    tar_model = tarfile.open(os.path.join(params_dir, '__palmmodel__'), 'w:gz')
    try:
        with open(info_path, 'w') as tar_info:
            for root, _, files in os.walk(staging_dir):
                for name in files:
                    member = '__paddlepalm_' + name
                    tar_model.add(os.path.join(root, name), member)
                    # NOTE(review): names are written with no separator; keep
                    # this format unless the reader of __palminfo__ changes too.
                    tar_info.write(member)
    finally:
        tar_model.close()
    # Only delete the originals once everything is archived; rmtree also
    # removes any nested sub-directories that were walked above.
    shutil.rmtree(staging_dir)

X
xixiaoyao 已提交
123 124 125
def download(item, scope='all', path='.'):
    """Download a registered resource (plus its helper utilities, if any).

    Args:
        item: resource category, a key of ``_items`` (case-insensitive),
            e.g. 'pretrain'.
        scope: a specific entry of ``item`` (case-insensitive), or 'all' to
            fetch every entry that has a download URL.
        path: root directory; entries are stored under
            ``<path>/<item>/<scope>``.

    Raises:
        AssertionError: if ``item`` or ``scope`` is not registered.
    """
    item = item.lower()
    scope = scope.lower()
    assert item in _items, '{} is not found. Support list: {}'.format(item, list(_items.keys()))
    scopes = _items[item]

    # Helper utilities are fetched first, quietly, whenever the item has them.
    if scopes.get('utils') is not None:
        _download(item, 'utils', path, silent=True)

    if scope != 'all':
        assert scope in scopes, '{} is not found. Support scopes: {}'.format(scope, list(scopes.keys()))
        targets = [scope]
    else:
        targets = list(scopes.keys())

    for s in targets:
        # 'utils' was handled above; entries without a URL have nothing to
        # fetch (passing them on used to crash on None.split).
        if s == 'utils' or scopes[s] is None:
            continue
        _download(item, s, path)
X
xixiaoyao 已提交
137 138 139 140


def ls(item=None, scope='all'):
    """List downloadable resources -- not implemented yet.

    Args:
        item: resource category to list (currently ignored).
        scope: entry within ``item`` to list (currently ignored).

    Returns:
        None; this function is a placeholder with no effect.
    """
    return None