提交 db5fd353 编写于 作者: Z Zeyu Chen

remove useless files

上级 b43e0aa3
...@@ -64,9 +64,9 @@ def download_and_uncompress(url, save_name=None): ...@@ -64,9 +64,9 @@ def download_and_uncompress(url, save_name=None):
os.makedirs(dirname) os.makedirs(dirname)
#TODO(ZeyuChen) add download md5 file to verify file completeness #TODO(ZeyuChen) add download md5 file to verify file completeness
file_name = os.path.join(dirname, file_name = os.path.join(
url.split('/')[-1] dirname,
if save_name is None else save_name) url.split('/')[-1] if save_name is None else save_name)
retry = 0 retry = 0
retry_limit = 3 retry_limit = 3
...@@ -76,8 +76,9 @@ def download_and_uncompress(url, save_name=None): ...@@ -76,8 +76,9 @@ def download_and_uncompress(url, save_name=None):
if retry < retry_limit: if retry < retry_limit:
retry += 1 retry += 1
else: else:
raise RuntimeError("Cannot download {0} within retry limit {1}". raise RuntimeError(
format(url, retry_limit)) "Cannot download {0} within retry limit {1}".format(
url, retry_limit))
print("Cache file %s not found, downloading %s" % (file_name, url)) print("Cache file %s not found, downloading %s" % (file_name, url))
r = requests.get(url, stream=True) r = requests.get(url, stream=True)
total_length = r.headers.get('content-length') total_length = r.headers.get('content-length')
...@@ -94,8 +95,8 @@ def download_and_uncompress(url, save_name=None): ...@@ -94,8 +95,8 @@ def download_and_uncompress(url, save_name=None):
dl += len(data) dl += len(data)
f.write(data) f.write(data)
done = int(50 * dl / total_length) done = int(50 * dl / total_length)
sys.stdout.write("\r[%s%s]" % ('=' * done, sys.stdout.write(
' ' * (50 - done))) "\r[%s%s]" % ('=' * done, ' ' * (50 - done)))
sys.stdout.flush() sys.stdout.flush()
print("file download completed!", file_name) print("file download completed!", file_name)
...@@ -111,64 +112,6 @@ def download_and_uncompress(url, save_name=None): ...@@ -111,64 +112,6 @@ def download_and_uncompress(url, save_name=None):
return module_name, module_dir return module_name, module_dir
class TqdmProgress(tqdm):
    """tqdm progress bar whose ``update_to`` method works as a report hook.

    The bar remembers the block count passed to the previous ``update_to``
    call, so each call advances only by the blocks reported since then.
    """

    # Block count passed to the most recent ``update_to`` call.
    last_block = 0

    def update_to(self, block_num=1, block_size=1, total_size=None):
        """Advance the bar to ``block_num`` blocks of ``block_size`` each.

        Args:
            block_num: Number of blocks transferred so far.
            block_size: Size of a single block.
            total_size: Overall size; when not ``None`` it replaces the
                bar's ``total``.
        """
        if total_size is not None:
            self.total = total_size
        # Only the blocks that arrived since the previous call are new.
        newly_transferred = block_num - self.last_block
        self.update(newly_transferred * block_size)
        self.last_block = block_num
class DownloadManager(object):
    """Download a file from a URL into a local directory.

    ``dst_path`` is the directory downloads are written to. It starts out as
    a freshly created temporary directory and is replaced by whatever
    non-``None`` directory is passed to ``download``.
    """

    def __init__(self):
        # mkdtemp() returns a directory path string. The previous mkstemp()
        # call returned an (fd, path) tuple, which cannot be passed to
        # os.path.exists / os.path.join in download() and also left an open
        # file descriptor behind.
        self.dst_path = tempfile.mkdtemp()

    def download(self, link, dst_path):
        """Fetch ``link`` into a directory and return the local file path.

        Args:
            link: URL to download; the text after its last ``/`` is used as
                the local file name.
            dst_path: Target directory, or ``None`` to keep using the current
                ``self.dst_path``. A non-``None`` value is stored on the
                instance.

        Returns:
            The path reported by ``urlretrieve`` for the downloaded file.
        """
        file_name = link.split("/")[-1]
        if dst_path is not None:
            self.dst_path = dst_path
        if not os.path.exists(self.dst_path):
            os.makedirs(self.dst_path)
        file_path = os.path.join(self.dst_path, file_name)
        print("download filepath", file_path)
        with TqdmProgress(
                unit='B',
                unit_scale=True,
                unit_divisor=1024,
                miniters=1,
                desc=file_name) as progress:
            # The response headers returned alongside the path are not used.
            path, _ = urlretrieve(
                link,
                filename=file_path,
                reporthook=progress.update_to,
                data=None)
        return path

    def _extract_file(self, tgz, tarinfo, dst_path, buffer_size=10 << 20):
        """Extracts 'tarinfo' from 'tgz' and writes to 'dst_path'.

        The member is copied in chunks of at most ``buffer_size`` bytes.
        Both handles are closed even if a read or write fails.
        """
        # This class defines no ``_log_progress``; report progress only when
        # a subclass supplies one instead of failing with AttributeError.
        log_progress = getattr(self, "_log_progress", None)
        src = tgz.extractfile(tarinfo)
        try:
            dst = tf.gfile.GFile(dst_path, "wb")
            try:
                while True:
                    buf = src.read(buffer_size)
                    if not buf:
                        break
                    dst.write(buf)
                    if log_progress is not None:
                        log_progress(len(buf))
            finally:
                dst.close()
        finally:
            src.close()

    def download_and_uncompress(self, link, dst_path):
        """Download ``link`` via ``download`` and print the resulting path.

        NOTE(review): despite the name, no decompression is performed here.
        """
        file_name = self.download(link, dst_path)
        print(file_name)
if __name__ == "__main__": if __name__ == "__main__":
# TODO(ZeyuChen) add unit test # TODO(ZeyuChen) add unit test
link = "http://paddlehub.bj.bcebos.com/word2vec/word2vec-dim16-simple-example-1.tar.gz" link = "http://paddlehub.bj.bcebos.com/word2vec/word2vec-dim16-simple-example-1.tar.gz"
......
...@@ -159,19 +159,6 @@ class Module(object): ...@@ -159,19 +159,6 @@ class Module(object):
word_dict = self.config.get_dict() word_dict = self.config.get_dict()
return list(map(lambda x: word_dict[x], inputs)) return list(map(lambda x: word_dict[x], inputs))
# # load assets folder
# def _load_assets(self, module_dir):
# assets_dir = os.path.join(module_dir, ASSETS_NAME)
# dict_path = os.path.join(assets_dir, DICT_NAME)
# word_id = 0
# with open(dict_path) as fi:
# words = fi.readlines()
# #TODO(ZeyuChen) check whether word id is duplicated and valid
# for line in fi:
# w, w_id = line.split()
# self.dict[w] = int(w_id)
def add_module_feed_list(self, feed_list): def add_module_feed_list(self, feed_list):
self.feed_list = feed_list self.feed_list = feed_list
...@@ -240,9 +227,13 @@ class ModuleConfig(object): ...@@ -240,9 +227,13 @@ class ModuleConfig(object):
w_id = self.dict[w] w_id = self.dict[w]
fo.write("{}\t{}\n".format(w, w_id)) fo.write("{}\t{}\n".format(w, w_id))
def register_input_var(self, var): def register_input_var(self, var, signature="default"):
var_name = var.name()
self.desc.sign2input[signature].append(var_name)
def register_output_var(self, var, signature="default"):
var_name = var.name() var_name = var.name()
self.feed_list.add(var_name) self.desc.sign2output[signature].append(var_name)
def save_dict(self, word_dict, dict_name=DICT_NAME): def save_dict(self, word_dict, dict_name=DICT_NAME):
""" Save dictionary for NLP module """ Save dictionary for NLP module
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册