diff --git a/paddle_hub/downloader.py b/paddle_hub/downloader.py index 90b076ebc8a9bc2764d73d801d58552cd8a3b838..cd8b81964ef9b2179111fecf9bf5f1489671d36f 100644 --- a/paddle_hub/downloader.py +++ b/paddle_hub/downloader.py @@ -64,9 +64,9 @@ def download_and_uncompress(url, save_name=None): os.makedirs(dirname) #TODO(ZeyuChen) add download md5 file to verify file completeness - file_name = os.path.join(dirname, - url.split('/')[-1] - if save_name is None else save_name) + file_name = os.path.join( + dirname, + url.split('/')[-1] if save_name is None else save_name) retry = 0 retry_limit = 3 @@ -76,8 +76,9 @@ def download_and_uncompress(url, save_name=None): if retry < retry_limit: retry += 1 else: - raise RuntimeError("Cannot download {0} within retry limit {1}". - format(url, retry_limit)) + raise RuntimeError( + "Cannot download {0} within retry limit {1}".format( + url, retry_limit)) print("Cache file %s not found, downloading %s" % (file_name, url)) r = requests.get(url, stream=True) total_length = r.headers.get('content-length') @@ -94,8 +95,8 @@ def download_and_uncompress(url, save_name=None): dl += len(data) f.write(data) done = int(50 * dl / total_length) - sys.stdout.write("\r[%s%s]" % ('=' * done, - ' ' * (50 - done))) + sys.stdout.write( + "\r[%s%s]" % ('=' * done, ' ' * (50 - done))) sys.stdout.flush() print("file download completed!", file_name) @@ -111,64 +112,6 @@ def download_and_uncompress(url, save_name=None): return module_name, module_dir -class TqdmProgress(tqdm): - """ - tqdm prograss hook - """ - last_block = 0 - - def update_to(self, block_num=1, block_size=1, total_size=None): - if total_size is not None: - self.total = total_size - self.update((block_num - self.last_block) * block_size) - self.last_block = block_num - - -class DownloadManager(object): - def __init__(self): - self.dst_path = tempfile.mkstemp() - - def download(self, link, dst_path): - file_name = link.split("/")[-1] - if dst_path is not None: - self.dst_path = dst_path - if not os.path.exists(self.dst_path): - os.makedirs(self.dst_path) - file_path = os.path.join(self.dst_path, file_name) - print("download filepath", file_path) - - with TqdmProgress( - unit='B', - unit_scale=True, - unit_divisor=1024, - miniters=1, - desc=file_name) as progress: - path, header = urlretrieve( - link, - filename=file_path, - reporthook=progress.update_to, - data=None) - - return path - - def _extract_file(self, tgz, tarinfo, dst_path, buffer_size=10 << 20): - """Extracts 'tarinfo' from 'tgz' and writes to 'dst_path'.""" - src = tgz.extractfile(tarinfo) - dst = tf.gfile.GFile(dst_path, "wb") - while 1: - buf = src.read(buffer_size) - if not buf: - break - dst.write(buf) - self._log_progress(len(buf)) - dst.close() - src.close() - - def download_and_uncompress(self, link, dst_path): - file_name = self.download(link, dst_path) - print(file_name) - - if __name__ == "__main__": # TODO(ZeyuChen) add unit test link = "http://paddlehub.bj.bcebos.com/word2vec/word2vec-dim16-simple-example-1.tar.gz" diff --git a/paddle_hub/module.py b/paddle_hub/module.py index 517a2ceb35c4200d796804e96cc8ab9d73c0d31a..9c86d3c6b9159041675bda766f8f542ff02ca231 100644 --- a/paddle_hub/module.py +++ b/paddle_hub/module.py @@ -159,19 +159,6 @@ class Module(object): word_dict = self.config.get_dict() return list(map(lambda x: word_dict[x], inputs)) - # # load assets folder - # def _load_assets(self, module_dir): - # assets_dir = os.path.join(module_dir, ASSETS_NAME) - # dict_path = os.path.join(assets_dir, DICT_NAME) - # word_id = 0 - - # with open(dict_path) as fi: - # words = fi.readlines() - # #TODO(ZeyuChen) check whether word id is duplicated and valid - # for line in fi: - # w, w_id = line.split() - # self.dict[w] = int(w_id) - def add_module_feed_list(self, feed_list): self.feed_list = feed_list @@ -240,9 +227,13 @@ class ModuleConfig(object): w_id = self.dict[w] fo.write("{}\t{}\n".format(w, w_id)) - def register_input_var(self, var): + def register_input_var(self, var, signature="default"): + var_name = var.name() + self.desc.sign2input[signature].append(var_name) + + def register_output_var(self, var, signature="default"): var_name = var.name() - self.feed_list.add(var_name) + self.desc.sign2output[signature].append(var_name) def save_dict(self, word_dict, dict_name=DICT_NAME): """ Save dictionary for NLP module