提交 5dc25a43 编写于 作者: Z Zeyu Chen

add module_desc

上级 403bbf0b
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle.fluid as fluid
from paddle_hub.module import Module
from paddle_hub.module import ModuleSpec
# coding=utf-8
from __future__ import print_function
from __future__ import division
from __future__ import print_function
from urllib.request import urlretrieve from urllib.request import urlretrieve
from tqdm import tqdm from tqdm import tqdm
import os import os
import sys
import hashlib
import requests
import tempfile import tempfile
import tarfile
""" """
tqdm progress hook tqdm progress hook
""" """
__all__ = [
'MODULE_HOME',
'download',
'md5file',
'split',
'cluster_files_reader',
'convert',
]
MODULE_HOME = os.path.expanduser('~/.cache/paddle/module')
# When running unit tests, there could be multiple processes
# trying to create the MODULE_HOME directory simultaneously, so we cannot
# use an if condition to check for the existence of the directory;
# instead, we use the filesystem as the synchronization mechanism by
# catching returned errors.
def must_mkdirs(path):
    """Create directory ``path`` (including missing parents) if needed.

    The directory is created unconditionally and an "already exists" error
    is swallowed, so concurrent callers racing to create the same directory
    do not fail.

    Args:
        path: Directory to create.

    Raises:
        OSError: For any failure other than the path already existing.
    """
    # Imported locally: ``errno`` is not among the file's visible imports.
    import errno

    try:
        # Create the directory that was asked for, not a fixed location.
        os.makedirs(path)
    except OSError as exc:
        if exc.errno != errno.EEXIST:
            raise
def md5file(fname):
    """Return the hexadecimal MD5 digest of the file at ``fname``.

    The file is read in 4096-byte chunks so that large files are not loaded
    into memory at once.

    Args:
        fname: Path of the file to hash.

    Returns:
        The MD5 digest as a hexadecimal string.
    """
    hash_md5 = hashlib.md5()
    # ``with`` guarantees the handle is closed even if a read fails.
    with open(fname, "rb") as f:
        for chunk in iter(lambda: f.read(4096), b""):
            hash_md5.update(chunk)
    return hash_md5.hexdigest()
def download_and_uncompress(url, save_name=None):
    """Download a gzip-compressed tar archive into the cache and unpack it.

    The archive is stored in ``MODULE_HOME/<second-to-last URL segment>/``.
    If the archive file is already present there, the download is skipped
    and only the extraction is performed.

    Args:
        url: Address of the ``.tar.gz`` archive.
        save_name: Optional file name for the cached archive; defaults to
            the last path segment of ``url``.

    Returns:
        The cache directory joined with the name of the first archive member.

    Raises:
        RuntimeError: If the download does not succeed within the retry
            limit, if the archive has no members, or if a member would be
            extracted outside the cache directory.
    """
    # Imported locally: ``shutil`` is not among the file's visible imports.
    import shutil

    module_name = url.split("/")[-2]
    dirname = os.path.join(MODULE_HOME, module_name)
    print("download to dir", dirname)
    if not os.path.exists(dirname):
        os.makedirs(dirname)

    #TODO add download md5 file to verify file completeness
    file_name = os.path.join(
        dirname,
        url.split('/')[-1] if save_name is None else save_name)

    retry = 0
    retry_limit = 3
    while not os.path.exists(file_name):
        if retry < retry_limit:
            retry += 1
        else:
            raise RuntimeError(
                "Cannot download {0} within retry limit {1}".format(
                    url, retry_limit))
        print("Cache file %s not found, downloading %s" % (file_name, url))

        # Download to a side file first so that an interrupted transfer is
        # never mistaken for a complete cached archive on the next call.
        part_name = file_name + ".part"
        try:
            r = requests.get(url, stream=True)
            try:
                # Without this an HTTP error page would be cached as the
                # archive and fail later, during extraction.
                r.raise_for_status()
                total_length = r.headers.get('content-length')
                total_length = int(total_length) if total_length else 0
                with open(part_name, 'wb') as f:
                    if total_length <= 0:
                        # Size unknown: copy the raw stream, no progress bar.
                        shutil.copyfileobj(r.raw, f)
                    else:
                        dl = 0
                        for data in r.iter_content(chunk_size=4096):
                            dl += len(data)
                            f.write(data)
                            done = int(50 * dl / total_length)
                            sys.stdout.write(
                                "\r[%s%s]" % ('=' * done, ' ' * (50 - done)))
                            sys.stdout.flush()
            finally:
                r.close()
        except (requests.exceptions.RequestException, IOError) as exc:
            print("download attempt %d failed: %s" % (retry, exc))
            if os.path.exists(part_name):
                os.remove(part_name)
            continue
        os.rename(part_name, file_name)

    print("file download completed!", file_name)

    with tarfile.open(file_name, "r:gz") as tar:
        member_names = tar.getnames()
        print(member_names)
        if not member_names:
            raise RuntimeError("Archive {0} is empty".format(file_name))
        module_dir = os.path.join(dirname, member_names[0])

        # The archive comes from the network: refuse members whose path
        # would resolve outside the cache directory.
        base = os.path.realpath(dirname)
        for member_name in member_names:
            target = os.path.realpath(os.path.join(base, member_name))
            if target != base and not target.startswith(base + os.sep):
                raise RuntimeError(
                    "Unsafe path in archive {0}: {1}".format(
                        file_name, member_name))
            tar.extract(member_name, dirname)

    return module_dir
class TqdmProgress(tqdm): class TqdmProgress(tqdm):
last_block = 0 last_block = 0
...@@ -72,6 +169,10 @@ class DownloadManager(object): ...@@ -72,6 +169,10 @@ class DownloadManager(object):
if __name__ == "__main__": if __name__ == "__main__":
link = "ftp://nj03-rp-m22nlp062.nj03.baidu.com//home/disk0/chenzeyu01/movie/movie_summary.txt" link = "http://paddlehub.bj.bcebos.com/word2vec/word2vec-dim16-simple-example-1.tar.gz"
dl = DownloadManager()
dl.download_and_uncompress(link, "./tmp") module_path = download_and_uncompress(link)
print("module path", module_path)
# dl = DownloadManager()
# dl.download_and_uncompress(link, "./tmp")
...@@ -3,8 +3,60 @@ from __future__ import division ...@@ -3,8 +3,60 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
import paddle.fluid as fluid import paddle.fluid as fluid
import paddle_hub as hub
import tempfile
import os
class Module(object): class Module(object):
def __init__(self, spec): def __init__(self, module_url):
module_dir = download_and_extract(module_url)
def __init__(self, module_name, module_dir=None):
if module_dir is None:
self.module_dir = tempfile.gettempdir()
else:
self.module_dir = module_dir
self.module_name = module_name
self.module_dir = os.path.join(self.module_dir, self.module_name)
print("create module dir folder at {}".format(self.module_dir))
self._mkdir(self.module_dir)
self.feed_list = []
self.output_list = []
pass
def save_dict(self, word_dict):
with open(os.path.join(self.module_dir, "tokens.txt"), "w") as fo:
#map(str, word_dict)
dict_str = "\n".join(word_dict)
fo.write(dict_str)
def add_module_feed_list(self, feed_list):
self.feed_list = feed_list
def add_module_output_list(self, output_list):
self.output_list = output_list
def _mkdir(self, path):
if not os.path.exists(path):
os.makedirs(path)
def save_inference_module(feed_var_names, target_vars, executor):
    """Placeholder with no implementation yet; returns ``None``.

    NOTE(review): the parameter names suggest this is meant to persist an
    inference program for the given feed and target variables -- confirm
    the intended behavior before implementing.
    """
class ModuleImpl(object):
    """Placeholder class; no behavior is implemented yet."""

    def get_signature_name():
        """Stub that does nothing and returns ``None``.

        NOTE(review): the signature takes no ``self``, so this cannot be
        called on an instance -- confirm whether that is intended.
        """
class ModuleDesc(object):
    """Holds the input and output lists that describe a module."""

    def __init__(self, input_list, output_list):
        """Store the given input and output lists.

        Args:
            input_list: Inputs of the module; stored as ``self.input_list``.
            output_list: Outputs of the module; stored as
                ``self.output_list``.
        """
        # Assign the actual parameter; the previous right-hand side named a
        # variable that does not exist and raised NameError on construction.
        self.input_list = input_list
        self.output_list = output_list

    def add_signature(input, output):
        """Stub that does nothing and returns ``None``.

        NOTE(review): there is no ``self`` parameter, so when called on an
        instance the instance is bound to ``input`` -- confirm the intended
        signature before implementing.
        """
        pass
// Copyright 2018 The Paddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
syntax = "proto3";

// Every statement must end with a semicolon, including `package`.
package paddle_hub;
// A Hub Module is stored in a directory with a file 'paddlehub_module.pb'
// containing a serialized protocol message of this type. The further contents
// of the directory depend on the storage format described by the message.
message ModuleDesc {
  // proto3 has no `required` label, and field numbers must start at 1.
  string name = 1;  // PaddleHub module name
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册