提交 5dc25a43 编写于 作者: Z Zeyu Chen

add module_desc

上级 403bbf0b
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle.fluid as fluid
from paddle_hub.module import Module
from paddle_hub.module import ModuleSpec
# coding=utf-8
from __future__ import print_function
from __future__ import division
from __future__ import print_function
from urllib.request import urlretrieve
from tqdm import tqdm
import os
import sys
import hashlib
import requests
import tempfile
import tarfile
"""
tqdm progress hook
"""
# Names exported by a star-import of this module.
# NOTE(review): 'download', 'split', 'cluster_files_reader' and 'convert' have
# no definition in the part of the file visible here — confirm they exist,
# otherwise a star-import of this module will raise AttributeError.
__all__ = [
'MODULE_HOME',
'download',
'md5file',
'split',
'cluster_files_reader',
'convert',
]
# Root directory under which downloaded modules are cached
# (~/.cache/paddle/module, with ~ expanded to the user's home directory).
MODULE_HOME = os.path.expanduser('~/.cache/paddle/module')
# When running unit tests, multiple processes may try to create the
# MODULE_HOME directory simultaneously, so we cannot use an `if` condition
# to check for the existence of the directory; instead, we use the
# filesystem as the synchronization mechanism by catching the errors it
# returns.
def must_mkdirs(path):
    """Create directory ``path`` (and missing parents), tolerating races.

    Several processes may try to create the same directory at once, so
    instead of checking for existence first, the directory is simply created
    and an "already exists" error is ignored.

    Args:
        path: Directory to create.

    Raises:
        OSError: If creation fails for any reason other than the directory
            already existing (including ``path`` existing as a non-directory).
    """
    # Imported locally because errno is not among the module-level imports.
    import errno

    try:
        # Create the requested path, not the MODULE_HOME constant.
        os.makedirs(path)
    except OSError as exc:
        # EEXIST is only benign when what exists really is a directory.
        if exc.errno != errno.EEXIST or not os.path.isdir(path):
            raise
def md5file(fname):
    """Return the hexadecimal MD5 digest of the file at ``fname``.

    The file is read in 4096-byte chunks so arbitrarily large files can be
    hashed without loading them fully into memory.

    Args:
        fname: Path of the file to hash.

    Returns:
        The MD5 digest as a hex string.
    """
    hash_md5 = hashlib.md5()
    # ``with`` guarantees the handle is closed even if a read fails midway.
    with open(fname, "rb") as f:
        for chunk in iter(lambda: f.read(4096), b""):
            hash_md5.update(chunk)
    return hash_md5.hexdigest()
def _download_to_file(url, file_name):
    """Stream ``url`` into ``file_name``, drawing a progress bar when possible.

    The payload is written to ``file_name + ".part"`` and renamed only after
    the transfer has finished, so an interrupted download never leaves behind
    a file that a later call would mistake for a complete cache entry.
    """
    # Imported locally because shutil is not among the module-level imports.
    import shutil

    part_name = file_name + ".part"
    r = requests.get(url, stream=True)
    try:
        # Fail on 4xx/5xx instead of saving the error page as the archive.
        r.raise_for_status()
        content_length = r.headers.get('content-length')
        total_length = 0 if content_length is None else int(content_length)
        with open(part_name, 'wb') as f:
            if total_length <= 0:
                # Size unknown: no progress bar, just copy the raw stream.
                shutil.copyfileobj(r.raw, f)
            else:
                dl = 0
                for data in r.iter_content(chunk_size=4096):
                    dl += len(data)
                    f.write(data)
                    # 50-character wide text progress bar.
                    done = int(50 * dl / total_length)
                    sys.stdout.write(
                        "\r[%s%s]" % ('=' * done, ' ' * (50 - done)))
                    sys.stdout.flush()
        os.rename(part_name, file_name)
    except Exception:
        # Never leave a truncated file behind.
        if os.path.exists(part_name):
            os.remove(part_name)
        raise
    finally:
        r.close()


def download_and_uncompress(url, save_name=None):
    """Download a ``.tar.gz`` module archive and unpack it under MODULE_HOME.

    The archive is cached at ``MODULE_HOME/<second-to-last URL segment>/``;
    if the archive file already exists there it is not downloaded again.

    Args:
        url: Location of the gzip-compressed tar archive.
        save_name: File name to store the archive under. Defaults to the
            last path segment of ``url``.

    Returns:
        Path formed by joining the cache directory with the first member
        name listed in the archive.

    Raises:
        RuntimeError: If the archive is still missing after the retry limit,
            is empty, or contains a member that would be extracted outside
            the cache directory.
        requests.exceptions.RequestException, IOError, OSError: Re-raised
            from the final download attempt when every attempt failed.
    """
    module_name = url.split("/")[-2]
    dirname = os.path.join(MODULE_HOME, module_name)
    print("download to dir", dirname)
    # Create-then-handle instead of check-then-create: avoids failing when
    # another process creates the directory in between.
    try:
        os.makedirs(dirname)
    except OSError:
        if not os.path.isdir(dirname):
            raise

    #TODO add download md5 file to verify file completeness
    file_name = os.path.join(
        dirname,
        url.split('/')[-1] if save_name is None else save_name)

    retry_limit = 3
    for attempt in range(1, retry_limit + 1):
        if os.path.exists(file_name):
            break
        print("Cache file %s not found, downloading %s" % (file_name, url))
        try:
            _download_to_file(url, file_name)
        except (requests.exceptions.RequestException, IOError, OSError) as exc:
            # Surface the real error once the retry budget is spent.
            if attempt == retry_limit:
                raise
            print("download attempt %d failed: %s" % (attempt, exc))
    if not os.path.exists(file_name):
        raise RuntimeError(
            "Cannot download {0} within retry limit {1}".format(
                url, retry_limit))

    print("file download completed!", file_name)
    with tarfile.open(file_name, "r:gz") as tar:
        file_names = tar.getnames()
        print(file_names)
        if not file_names:
            raise RuntimeError("Archive {0} is empty".format(file_name))
        module_dir = os.path.join(dirname, file_names[0])
        # The archive comes from the network: refuse member names that would
        # resolve outside the cache directory (e.g. "../../etc/passwd").
        root = os.path.realpath(dirname)
        for member_name in file_names:
            target = os.path.realpath(os.path.join(root, member_name))
            if target != root and not target.startswith(root + os.sep):
                raise RuntimeError(
                    "Refusing to extract {0} outside of {1}".format(
                        member_name, dirname))
            tar.extract(member_name, dirname)

    return module_dir
class TqdmProgress(tqdm):
last_block = 0
......@@ -72,6 +169,10 @@ class DownloadManager(object):
# Script entry point: downloads an archive and prints the resulting module path.
# NOTE(review): this span appears to interleave two versions of the script
# body (one using DownloadManager, one using download_and_uncompress) —
# confirm which statements are meant to remain.
if __name__ == "__main__":
link = "ftp://nj03-rp-m22nlp062.nj03.baidu.com//home/disk0/chenzeyu01/movie/movie_summary.txt"
dl = DownloadManager()
dl.download_and_uncompress(link, "./tmp")
link = "http://paddlehub.bj.bcebos.com/word2vec/word2vec-dim16-simple-example-1.tar.gz"
module_path = download_and_uncompress(link)
print("module path", module_path)
# dl = DownloadManager()
# dl.download_and_uncompress(link, "./tmp")
......@@ -3,8 +3,60 @@ from __future__ import division
from __future__ import print_function
import paddle.fluid as fluid
import paddle_hub as hub
import tempfile
import os
class Module(object):
    """A named module backed by its own directory on disk.

    Constructing an instance creates ``<module_dir>/<module_name>`` and
    initialises empty feed and output lists.

    NOTE: earlier ``__init__`` variants taking ``spec`` or ``module_url``
    were shadowed by the definition below (Python keeps only the last
    definition of a name in a class body), so only this one is retained.
    """

    def __init__(self, module_name, module_dir=None):
        """Create the module directory and initialise bookkeeping lists.

        Args:
            module_name: Name of the module; also the directory name.
            module_dir: Parent directory for the module directory. Defaults
                to the system temporary directory.
        """
        base_dir = tempfile.gettempdir() if module_dir is None else module_dir
        self.module_name = module_name
        self.module_dir = os.path.join(base_dir, self.module_name)
        print("create module dir folder at {}".format(self.module_dir))
        self._mkdir(self.module_dir)
        self.feed_list = []
        self.output_list = []

    def save_dict(self, word_dict):
        """Write the entries of ``word_dict`` to ``tokens.txt``, one per line.

        Args:
            word_dict: Iterable of strings. Iterating a ``dict`` yields its
                keys, so passing a dict writes the keys.
        """
        with open(os.path.join(self.module_dir, "tokens.txt"), "w") as fo:
            fo.write("\n".join(word_dict))

    def add_module_feed_list(self, feed_list):
        """Replace the module's feed list with ``feed_list``."""
        self.feed_list = feed_list

    def add_module_output_list(self, output_list):
        """Replace the module's output list with ``output_list``."""
        self.output_list = output_list

    def _mkdir(self, path):
        """Create ``path`` if needed; an existing directory is not an error."""
        # Create-then-handle rather than check-then-create, so a directory
        # created concurrently by another process does not cause a failure.
        try:
            os.makedirs(path)
        except OSError:
            if not os.path.isdir(path):
                raise

    @staticmethod
    def save_inference_module(feed_var_names, target_vars, executor):
        """Placeholder; not implemented yet.

        Declared as a static method because the signature has no ``self``;
        this keeps the ``(feed_var_names, target_vars, executor)`` call shape
        while also making calls through an instance work.
        """
        pass
class ModuleImpl(object):
    """Placeholder class; its only method is an unimplemented stub."""

    @staticmethod
    def get_signature_name():
        """Placeholder; not implemented yet.

        Declared as a static method because the signature has no ``self``;
        without it, calling the method on an instance raises ``TypeError``.
        """
        pass
class ModuleDesc(object):
    """Holds the input and output lists given at construction."""

    def __init__(self, input_list, output_list):
        """Store the given lists.

        Args:
            input_list: Stored as ``self.input_list``.
            output_list: Stored as ``self.output_list``.
        """
        # Previously assigned the undefined name ``input_sig`` (NameError).
        self.input_list = input_list
        self.output_list = output_list

    @staticmethod
    def add_signature(input, output):
        """Placeholder; not implemented yet.

        Declared as a static method because the signature has no ``self``;
        the parameter names are kept as-is for keyword-argument callers.
        """
        pass
// Copyright 2018 The Paddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
syntax = "proto3";

package paddle_hub;

// A Hub Module is stored in a directory with a file 'paddlehub_module.pb'
// containing a serialized protocol message of this type. The further contents
// of the directory depend on the storage format described by the message.
message ModuleDesc {
  // PaddleHub module name.
  // proto3 has no `required` label, and field numbers must start at 1.
  string name = 1;
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册