提交 a3319267 编写于 作者: W wuzewu

Add temp directory in hub home

上级 ef273daa
...@@ -15,3 +15,4 @@ ...@@ -15,3 +15,4 @@
from . import utils from . import utils
from .utils import get_running_device_info from .utils import get_running_device_info
from .dir import tmp_dir, tmp_file
#coding:utf-8 # coding:utf-8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License" # Licensed under the Apache License, Version 2.0 (the "License"
...@@ -14,6 +14,9 @@ ...@@ -14,6 +14,9 @@
# limitations under the License. # limitations under the License.
import os import os
import contextlib
import shutil
import tempfile
# TODO: Change dir.py's filename, this naming rule is not qualified # TODO: Change dir.py's filename, this naming rule is not qualified
...@@ -37,3 +40,24 @@ CACHE_HOME = os.path.join(gen_hub_home(), "cache") ...@@ -37,3 +40,24 @@ CACHE_HOME = os.path.join(gen_hub_home(), "cache")
DATA_HOME = os.path.join(gen_hub_home(), "dataset") DATA_HOME = os.path.join(gen_hub_home(), "dataset")
CONF_HOME = os.path.join(gen_hub_home(), "conf") CONF_HOME = os.path.join(gen_hub_home(), "conf")
THIRD_PARTY_HOME = os.path.join(gen_hub_home(), "thirdparty") THIRD_PARTY_HOME = os.path.join(gen_hub_home(), "thirdparty")
TMP_HOME = os.path.join(gen_hub_home(), "tmp")
if not os.path.exists(TMP_HOME):
os.mkdir(TMP_HOME)
@contextlib.contextmanager
def tmp_file():
with tempfile.NamedTemporaryFile(dir=TMP_HOME) as file:
yield file.name
@contextlib.contextmanager
def tmp_dir():
try:
_dir = tempfile.mkdtemp(dir=TMP_HOME)
yield _dir
except:
raise
finally:
shutil.rmtree(_dir)
...@@ -25,7 +25,7 @@ import inspect ...@@ -25,7 +25,7 @@ import inspect
import importlib import importlib
import tarfile import tarfile
import six import six
from shutil import copyfile import shutil
import paddle import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
...@@ -36,6 +36,7 @@ from paddlehub.common.dir import CACHE_HOME ...@@ -36,6 +36,7 @@ from paddlehub.common.dir import CACHE_HOME
from paddlehub.common.lock import lock from paddlehub.common.lock import lock
from paddlehub.common.logger import logger from paddlehub.common.logger import logger
from paddlehub.common.hub_server import CacheUpdater from paddlehub.common.hub_server import CacheUpdater
from paddlehub.common import tmp_dir
from paddlehub.module import module_desc_pb2 from paddlehub.module import module_desc_pb2
from paddlehub.module.manager import default_module_manager from paddlehub.module.manager import default_module_manager
from paddlehub.module.checker import ModuleChecker from paddlehub.module.checker import ModuleChecker
...@@ -58,43 +59,51 @@ HUB_PACKAGE_SUFFIX = "phm" ...@@ -58,43 +59,51 @@ HUB_PACKAGE_SUFFIX = "phm"
def create_module(directory, name, author, email, module_type, summary, def create_module(directory, name, author, email, module_type, summary,
version): version):
save_file_name = "{}-{}.{}".format(name, version, HUB_PACKAGE_SUFFIX) save_file = "{}-{}.{}".format(name, version, HUB_PACKAGE_SUFFIX)
# record module info and serialize with tmp_dir() as base_dir:
desc = module_desc_pb2.ModuleDesc() # package the module
attr = desc.attr with tarfile.open(save_file, "w:gz") as tar:
attr.type = module_desc_pb2.MAP module_dir = os.path.join(base_dir, name)
module_info = attr.map.data['module_info'] shutil.copytree(directory, module_dir)
module_info.type = module_desc_pb2.MAP
utils.from_pyobj_to_module_attr(name, module_info.map.data['name']) # record module info and serialize
utils.from_pyobj_to_module_attr(author, module_info.map.data['author']) desc = module_desc_pb2.ModuleDesc()
utils.from_pyobj_to_module_attr(email, module_info.map.data['author_email']) attr = desc.attr
utils.from_pyobj_to_module_attr(module_type, module_info.map.data['type']) attr.type = module_desc_pb2.MAP
utils.from_pyobj_to_module_attr(summary, module_info.map.data['summary']) module_info = attr.map.data['module_info']
utils.from_pyobj_to_module_attr(version, module_info.map.data['version']) module_info.type = module_desc_pb2.MAP
utils.from_pyobj_to_module_attr(name, module_info.map.data['name'])
module_desc_path = os.path.join(directory, "module_desc.pb") utils.from_pyobj_to_module_attr(author,
with open(module_desc_path, "wb") as f: module_info.map.data['author'])
f.write(desc.SerializeToString()) utils.from_pyobj_to_module_attr(
email, module_info.map.data['author_email'])
# generate check info utils.from_pyobj_to_module_attr(module_type,
checker = ModuleChecker(directory) module_info.map.data['type'])
checker.generate_check_info() utils.from_pyobj_to_module_attr(summary,
module_info.map.data['summary'])
# add __init__ utils.from_pyobj_to_module_attr(version,
module_init = os.path.join(directory, "__init__.py") module_info.map.data['version'])
with open(module_init, "a") as file: module_desc_path = os.path.join(module_dir, "module_desc.pb")
file.write("") with open(module_desc_path, "wb") as f:
f.write(desc.SerializeToString())
# package the module
with tarfile.open(save_file_name, "w:gz") as tar: # generate check info
for dirname, _, files in os.walk(directory): checker = ModuleChecker(module_dir)
for file in files: checker.generate_check_info()
tar.add(os.path.join(dirname, file))
# add __init__
os.remove(module_desc_path) module_init = os.path.join(module_dir, "__init__.py")
os.remove(checker.pb_path) with open(module_init, "a") as file:
os.remove(module_init) file.write("")
_cwd = os.getcwd()
os.chdir(base_dir)
for dirname, _, files in os.walk(module_dir):
for file in files:
tar.add(os.path.join(dirname, file).replace(base_dir, "."))
os.chdir(_cwd)
_module_runable_func = {} _module_runable_func = {}
...@@ -340,7 +349,7 @@ class ModuleV1(Module): ...@@ -340,7 +349,7 @@ class ModuleV1(Module):
for asset in self.assets: for asset in self.assets:
filename = os.path.basename(asset) filename = os.path.basename(asset)
newfile = os.path.join(self.helper.assets_path(), filename) newfile = os.path.join(self.helper.assets_path(), filename)
copyfile(asset, newfile) shutil.copyfile(asset, newfile)
def _load_assets(self): def _load_assets(self):
assets_path = self.helper.assets_path() assets_path = self.helper.assets_path()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册