提交 936f00ce 编写于 作者: W wuzewu

save processor with an unique name to avoid naming conflict

上级 978757a0
......@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Store PaddleHub version string """
import os
USER_HOME = os.path.expanduser('~')
......
......@@ -81,7 +81,3 @@ class HubServer:
default_hub_server = HubServer()
if __name__ == "__main__":
print(default_hub_server.search_module("ssd"))
print(default_hub_server.get_module_url("ssd_mobilenet_pascal"))
......@@ -164,15 +164,3 @@ class ModuleChecker:
logger.error("file type error %s" % file_path)
return False
return True
if __name__ == "__main__":
# check_info = check_info_pb2.CheckInfo()
# check_info.paddle_version = "1"
# check_info.hub_version = "1"
# check_info.module_proto_version = "1"
# with open(os.path.join(".", CHECK_INFO_PB_FILENAME), "wb") as fi:
# fi.write(check_info.SerializeToString())
check_info = ModuleChecker(
"/home/wuzewu/code/PaddleHub/demo/object-detection/hub_module_ssd")
print(check_info.check())
......@@ -28,6 +28,7 @@ from paddle_hub import version
from paddle_hub.module.base_processor import BaseProcessor
from shutil import copyfile
import os
import time
import sys
import functools
import paddle
......@@ -58,7 +59,7 @@ MODULE_DESC_PBNAME = "module_desc.pb"
PYTHON_DIR = "python"
PROCESSOR_NAME = "processor"
# paddle hub var prefix
HUB_VAR_PREFIX = "@HUB@"
HUB_VAR_PREFIX = "@HUB_%s@"
class ModuleHelper:
......@@ -98,6 +99,7 @@ class Module:
self.module_info = None
self.processor = None
self.assets = []
self.name = "temp"
if url:
self._init_with_url(url=url)
elif module_dir:
......@@ -128,17 +130,22 @@ class Module:
pymodule = inspect.getmodule(self.processor)
pycode = inspect.getsource(pymodule)
processor_path = self.helper.processor_path()
processor_name = self.helper.processor_name()
processor_md5 = utils.md5(pycode)
processor_md5 += str(time.time())
processor_name = utils.md5(processor_md5)
output_file = os.path.join(processor_path, processor_name + ".py")
utils.mkdir(processor_path)
with open(output_file, "w") as file:
file.write(pycode)
utils.from_pyobj_to_flexible_data(
processor_name, self.desc.extra_info.map.data['processor_info'])
def _load_processor(self):
processor_path = self.helper.processor_path()
if os.path.exists(processor_path):
sys.path.append(processor_path)
processor_name = self.helper.processor_name()
processor_name = utils.from_flexible_data_to_pyobj(
self.desc.extra_info.map.data['processor_info'])
self.processor = __import__(processor_name).Processor(module=self)
else:
self.processor = None
......@@ -199,7 +206,7 @@ class Module:
param_attrs = self.desc.extra_info.map.data['param_attrs']
for key, param_attr in param_attrs.map.data.items():
param = paddle_helper.from_flexible_data_to_param(param_attr)
param['name'] = HUB_VAR_PREFIX + key
param['name'] = self.get_var_name_with_prefix(key)
if (param['name'] not in global_block.vars):
continue
var = global_block.var(param['name'])
......@@ -221,7 +228,7 @@ class Module:
stop_gradient = utils.from_flexible_data_to_pyobj(
var_infos.map.data[var_info].map.data['stop_gradient'])
block = program.blocks[idx]
var_name = HUB_VAR_PREFIX + var_info
var_name = self.get_var_name_with_prefix(var_info)
if var_name in block.vars:
var = block.vars[var_name]
var.stop_gradient = stop_gradient
......@@ -317,12 +324,12 @@ class Module:
fetch_names = sign.fetch_names
for index, input in enumerate(sign.inputs):
feed_var = feed_desc.add()
feed_var.var_name = HUB_VAR_PREFIX + input.name
feed_var.var_name = self.get_var_name_with_prefix(input.name)
feed_var.alias = feed_names[index]
for index, output in enumerate(sign.outputs):
fetch_var = fetch_desc.add()
fetch_var.var_name = HUB_VAR_PREFIX + output.name
fetch_var.var_name = self.get_var_name_with_prefix(output.name)
fetch_var.alias = fetch_names[index]
# save module info
......@@ -435,6 +442,12 @@ class Module:
return feed_dict, fetch_dict, program
def get_name_prefix(self):
return HUB_VAR_PREFIX % self.name
def get_var_name_with_prefix(self, var_name):
return self.get_name_prefix() + var_name
def parameters(self):
pass
......@@ -506,11 +519,11 @@ class Module:
varlist = {
var: block
for block in rename_program.blocks for var in block.vars
if HUB_VAR_PREFIX not in var
if self.get_name_prefix() not in var
}
for var, block in varlist.items():
old_name = var
new_name = HUB_VAR_PREFIX + old_name
new_name = self.get_var_name_with_prefix(old_name)
block._rename_var(old_name, new_name)
utils.mkdir(self.helper.model_path())
with open(
......@@ -519,17 +532,12 @@ class Module:
f.write(rename_program.desc.serialize_to_string())
for file in os.listdir(self.helper.model_path()):
if (file == "__model__" or HUB_VAR_PREFIX in file):
if (file == "__model__" or self.get_name_prefix() in file):
continue
os.rename(
os.path.join(self.helper.model_path(), file),
os.path.join(self.helper.model_path(),
HUB_VAR_PREFIX + file))
# Serialize module_desc pb
module_pb = self.desc.SerializeToString()
with open(self.helper.module_desc_path(), "wb") as f:
f.write(module_pb)
self.get_var_name_with_prefix(file)))
# create processor file
if self.processor:
......@@ -541,3 +549,8 @@ class Module:
# create check info
checker = ModuleChecker(self.helper.module_dir)
checker.generate_check_info()
# Serialize module_desc pb
module_pb = self.desc.SerializeToString()
with open(self.helper.module_desc_path(), "wb") as f:
f.write(module_pb)
......@@ -22,6 +22,7 @@ from paddle_hub.tools.logger import logger
import paddle
import paddle.fluid as fluid
import os
import hashlib
def to_list(input):
......@@ -39,6 +40,23 @@ def mkdir(path):
os.makedirs(path)
def md5_of_file(file):
md5 = hashlib.md5()
with open(file, "rb") as f:
for chunk in iter(lambda: f.read(4096), b""):
md5.update(chunk)
return md5.hexdigest()
def md5(text):
if isinstance(text, str):
text = text.encode("utf8")
md5 = hashlib.md5()
md5.update(text)
return md5.hexdigest()
def get_keyed_type_of_pyobj(pyobj):
if isinstance(pyobj, bool):
return module_desc_pb2.BOOLEAN
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册