提交 9dfd6e56 编写于 作者: Z Zeyu Chen

add test downloader

上级 5dc25a43
......@@ -5,4 +5,5 @@ from __future__ import print_function
import paddle.fluid as fluid
from paddle_hub.module import Module
from paddle_hub.module import ModuleSpec
from paddle_hub.module import ModuleDesc
from paddle_hub.downloader import download_and_uncompress
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# coding=utf-8
from __future__ import print_function
from __future__ import division
......@@ -17,14 +31,11 @@ tqdm prograss hook
"""
__all__ = [
'MODULE_HOME',
'download',
'md5file',
'split',
'cluster_files_reader',
'convert',
'MODULE_HOME', 'download', 'md5file', 'split', 'cluster_files_reader',
'convert', 'download_and_uncompress'
]
# TODO(ZeyuChen) add environment varialble to set MODULE_HOME
MODULE_HOME = os.path.expanduser('~/.cache/paddle/module')
......@@ -58,8 +69,7 @@ def download_and_uncompress(url, save_name=None):
if not os.path.exists(dirname):
os.makedirs(dirname)
#TODO add download md5 file to verify file completeness
#TODO(ZeyuChen) add download md5 file to verify file completeness
file_name = os.path.join(
dirname,
url.split('/')[-1] if save_name is None else save_name)
......@@ -83,6 +93,7 @@ def download_and_uncompress(url, save_name=None):
with open(file_name, 'wb') as f:
shutil.copyfileobj(r.raw, f)
else:
#TODO(ZeyuChen) upgrade to tqdm process
with open(file_name, 'wb') as f:
dl = 0
total_length = int(total_length)
......@@ -95,6 +106,8 @@ def download_and_uncompress(url, save_name=None):
sys.stdout.flush()
print("file download completed!", file_name)
#TODO(ZeyuChen) add md5 check error and file incompleted error, then raise
# them and catch them
with tarfile.open(file_name, "r:gz") as tar:
file_names = tar.getnames()
print(file_names)
......@@ -169,6 +182,7 @@ class DownloadManager(object):
if __name__ == "__main__":
# TODO(ZeyuChen) add unit test
link = "http://paddlehub.bj.bcebos.com/word2vec/word2vec-dim16-simple-example-1.tar.gz"
module_path = download_and_uncompress(link)
......
......@@ -3,35 +3,120 @@ from __future__ import division
from __future__ import print_function
import paddle.fluid as fluid
import paddle_hub as hub
import numpy as np
# import paddle_hub as hub
import tempfile
import os
from collections import defaultdict
class Module(object):
def __init__(self, module_url):
module_dir = download_and_extract(module_url)
def __init__(self, module_name, module_dir=None):
if module_dir is None:
self.module_dir = tempfile.gettempdir()
else:
self.module_dir = module_dir
from paddle_hub.downloader import download_and_uncompress
self.module_name = module_name
self.module_dir = os.path.join(self.module_dir, self.module_name)
print("create module dir folder at {}".format(self.module_dir))
self._mkdir(self.module_dir)
self.feed_list = []
self.output_list = []
pass
def save_dict(self, word_dict):
with open(os.path.join(self.module_dir, "tokens.txt"), "w") as fo:
#map(str, word_dict)
dict_str = "\n".join(word_dict)
fo.write(dict_str)
class Module(object):
def __init__(self, module_url):
# donwload module
#module_dir = downloader.download_and_uncompress(module_url)
module_dir = download_and_uncompress(module_url)
# load paddle inference model
place = fluid.CPUPlace()
self.exe = fluid.Executor(fluid.CPUPlace())
[self.inference_program, self.feed_target_names,
self.fetch_targets] = fluid.io.load_inference_model(
dirname=module_dir, executor=self.exe)
print("inference_program")
print(self.inference_program)
print("feed_target_names")
print(self.feed_target_names)
print("fetch_targets")
print(self.fetch_targets)
# load assets
# self._load_assets(module_dir)
def __call__(self, inputs=None, signature=None):
word_ids_lod_tensor = self._process_input(inputs)
np_words_id = np.array(word_ids_lod_tensor)
print("word_ids_lod_tensor\n", np_words_id)
results = self.exe.run(
self.inference_program,
feed={self.feed_target_names[0]: word_ids_lod_tensor},
fetch_list=self.fetch_targets,
return_numpy=False) # return_numpy=Flase is important
print(self.feed_target_names)
print(self.fetch_targets)
# np_result = np.array(results[0])
return np_result
def get_vars(self):
return self.inference_program.list_vars()
def get_input_vars(self):
for var in self.inference_program.list_vars():
print(var)
if var.name == "words":
return var
# return self.fetch_targets
def get_module_output(self):
for var in self.inference_program.list_vars():
print(var)
if var.name == "embedding_0.tmp_0":
return var
def get_inference_program(self):
return self.inference_program
# for text sequence input, transform to lod tensor as paddle graph's input
def _process_input(self, inputs):
# words id mapping and dealing with oov
# transform to lod tensor
seq = []
for s in inputs:
seq.append(self._word_id_mapping(s))
lod_tensor = self.seq2lod_tensor(seq)
return lod_tensor
def seq2lod_tensor(self, seq_inputs, place=fluid.CPUPlace()):
""" sequence to lod tensor, need to determine which space"""
lod = []
lod.append([])
for s in seq_inputs:
# generate lod
lod[0].append(len(s))
# print("seq", seq_inputs)
# print("lod", lod)
lod_tensor = fluid.create_lod_tensor(seq_inputs, lod, place)
return lod_tensor
def _word_id_mapping(self, inputs):
return list(map(lambda x: self.dict[x], inputs))
# load assets folder
def _load_assets(self, module_dir):
self.dict = defaultdict(int)
self.dict.setdefault(0)
assets_dir = os.path.join(module_dir, "assets")
tokens_path = os.path.join(assets_dir, "tokens.txt")
word_id = 0
with open(tokens_path) as fi:
words = fi.readlines()
words = map(str.strip, words)
for w in words:
self.dict[w] = word_id
word_id += 1
print(w, word_id)
def add_module_feed_list(self, feed_list):
self.feed_list = feed_list
......@@ -43,9 +128,6 @@ class Module(object):
if not os.path.exists(path):
os.makedirs(path)
def save_inference_module(feed_var_names, target_vars, executor):
pass
class ModuleImpl(object):
def get_signature_name():
......@@ -53,10 +135,36 @@ class ModuleImpl(object):
class ModuleDesc(object):
def __init__(self, input_list, output_list):
self.input_list = input_sig
self.output_list = output_list
def __init__(self):
pass
def add_signature(input, output):
@staticmethod
def _mkdir(path):
if not os.path.exists(path):
print("mkdir", path)
os.makedirs(path)
@staticmethod
def save_dict(path, word_dict):
ModuleDesc._mkdir(path)
with open(os.path.join(path, "tokens.txt"), "w") as fo:
print("tokens.txt path", os.path.join(path, "tokens.txt"))
dict_str = "\n".join(word_dict)
fo.write(dict_str)
@staticmethod
def save_module_dict(module_path, word_dict):
assets_path = os.path.join(module_path, "assets")
print("save_module_dict", assets_path)
ModuleDesc.save_dict(assets_path, word_dict)
pass
if __name__ == "__main__":
module_link = "http://paddlehub.cdn.bcebos.com/word2vec/w2v_saved_inference_module.tar.gz"
m = Module(module_link)
inputs = [["it", "is", "new"], ["hello", "world"]]
#tensor = m._process_input(inputs)
#print(tensor)
result = m(inputs)
print(result)
......@@ -14,11 +14,21 @@
// =============================================================================
syntax = "proto3";
package paddle_hub
package paddle_hub;
message InputDesc {
}
message OutputDesc {
bool return_numpy = 1;
}
// A Hub Module is stored in a directory with a file 'paddlehub_module.pb'
// containing a serialized protocol message of this type. The further contents
// of the directory depend on the storage format described by the message.
message ModuleDesc {
required stirng name = 0; // PaddleHub module name
string name = 1; // PaddleHub module name
repeated string input_signature
}
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import downloader as dl
import unittest
class TestDownloader(unittest.TestCase):
def test_download(self):
link = "http://paddlehub.bj.bcebos.com/word2vec/word2vec-dim16-simple-example-1.tar.gz"
module_path = dl.
ownload_and_uncompress(link)
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册