提交 a3319267 编写于 作者: W wuzewu

Add temp directory in hub home

上级 ef273daa
...@@ -15,3 +15,4 @@ ...@@ -15,3 +15,4 @@
from . import utils from . import utils
from .utils import get_running_device_info from .utils import get_running_device_info
from .dir import tmp_dir, tmp_file
#coding:utf-8 # coding:utf-8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License" # Licensed under the Apache License, Version 2.0 (the "License"
...@@ -14,6 +14,9 @@ ...@@ -14,6 +14,9 @@
# limitations under the License. # limitations under the License.
import os import os
import contextlib
import shutil
import tempfile
# TODO: Change dir.py's filename, this naming rule is not qualified # TODO: Change dir.py's filename, this naming rule is not qualified
...@@ -37,3 +40,24 @@ CACHE_HOME = os.path.join(gen_hub_home(), "cache") ...@@ -37,3 +40,24 @@ CACHE_HOME = os.path.join(gen_hub_home(), "cache")
DATA_HOME = os.path.join(gen_hub_home(), "dataset") DATA_HOME = os.path.join(gen_hub_home(), "dataset")
CONF_HOME = os.path.join(gen_hub_home(), "conf") CONF_HOME = os.path.join(gen_hub_home(), "conf")
THIRD_PARTY_HOME = os.path.join(gen_hub_home(), "thirdparty") THIRD_PARTY_HOME = os.path.join(gen_hub_home(), "thirdparty")
TMP_HOME = os.path.join(gen_hub_home(), "tmp")
if not os.path.exists(TMP_HOME):
os.mkdir(TMP_HOME)
@contextlib.contextmanager
def tmp_file():
with tempfile.NamedTemporaryFile(dir=TMP_HOME) as file:
yield file.name
@contextlib.contextmanager
def tmp_dir():
try:
_dir = tempfile.mkdtemp(dir=TMP_HOME)
yield _dir
except:
raise
finally:
shutil.rmtree(_dir)
...@@ -25,7 +25,7 @@ import inspect ...@@ -25,7 +25,7 @@ import inspect
import importlib import importlib
import tarfile import tarfile
import six import six
from shutil import copyfile import shutil
import paddle import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
...@@ -36,6 +36,7 @@ from paddlehub.common.dir import CACHE_HOME ...@@ -36,6 +36,7 @@ from paddlehub.common.dir import CACHE_HOME
from paddlehub.common.lock import lock from paddlehub.common.lock import lock
from paddlehub.common.logger import logger from paddlehub.common.logger import logger
from paddlehub.common.hub_server import CacheUpdater from paddlehub.common.hub_server import CacheUpdater
from paddlehub.common import tmp_dir
from paddlehub.module import module_desc_pb2 from paddlehub.module import module_desc_pb2
from paddlehub.module.manager import default_module_manager from paddlehub.module.manager import default_module_manager
from paddlehub.module.checker import ModuleChecker from paddlehub.module.checker import ModuleChecker
...@@ -58,7 +59,13 @@ HUB_PACKAGE_SUFFIX = "phm" ...@@ -58,7 +59,13 @@ HUB_PACKAGE_SUFFIX = "phm"
def create_module(directory, name, author, email, module_type, summary, def create_module(directory, name, author, email, module_type, summary,
version): version):
save_file_name = "{}-{}.{}".format(name, version, HUB_PACKAGE_SUFFIX) save_file = "{}-{}.{}".format(name, version, HUB_PACKAGE_SUFFIX)
with tmp_dir() as base_dir:
# package the module
with tarfile.open(save_file, "w:gz") as tar:
module_dir = os.path.join(base_dir, name)
shutil.copytree(directory, module_dir)
# record module info and serialize # record module info and serialize
desc = module_desc_pb2.ModuleDesc() desc = module_desc_pb2.ModuleDesc()
...@@ -67,34 +74,36 @@ def create_module(directory, name, author, email, module_type, summary, ...@@ -67,34 +74,36 @@ def create_module(directory, name, author, email, module_type, summary,
module_info = attr.map.data['module_info'] module_info = attr.map.data['module_info']
module_info.type = module_desc_pb2.MAP module_info.type = module_desc_pb2.MAP
utils.from_pyobj_to_module_attr(name, module_info.map.data['name']) utils.from_pyobj_to_module_attr(name, module_info.map.data['name'])
utils.from_pyobj_to_module_attr(author, module_info.map.data['author']) utils.from_pyobj_to_module_attr(author,
utils.from_pyobj_to_module_attr(email, module_info.map.data['author_email']) module_info.map.data['author'])
utils.from_pyobj_to_module_attr(module_type, module_info.map.data['type']) utils.from_pyobj_to_module_attr(
utils.from_pyobj_to_module_attr(summary, module_info.map.data['summary']) email, module_info.map.data['author_email'])
utils.from_pyobj_to_module_attr(version, module_info.map.data['version']) utils.from_pyobj_to_module_attr(module_type,
module_info.map.data['type'])
module_desc_path = os.path.join(directory, "module_desc.pb") utils.from_pyobj_to_module_attr(summary,
module_info.map.data['summary'])
utils.from_pyobj_to_module_attr(version,
module_info.map.data['version'])
module_desc_path = os.path.join(module_dir, "module_desc.pb")
with open(module_desc_path, "wb") as f: with open(module_desc_path, "wb") as f:
f.write(desc.SerializeToString()) f.write(desc.SerializeToString())
# generate check info # generate check info
checker = ModuleChecker(directory) checker = ModuleChecker(module_dir)
checker.generate_check_info() checker.generate_check_info()
# add __init__ # add __init__
module_init = os.path.join(directory, "__init__.py") module_init = os.path.join(module_dir, "__init__.py")
with open(module_init, "a") as file: with open(module_init, "a") as file:
file.write("") file.write("")
# package the module _cwd = os.getcwd()
with tarfile.open(save_file_name, "w:gz") as tar: os.chdir(base_dir)
for dirname, _, files in os.walk(directory): for dirname, _, files in os.walk(module_dir):
for file in files: for file in files:
tar.add(os.path.join(dirname, file)) tar.add(os.path.join(dirname, file).replace(base_dir, "."))
os.remove(module_desc_path) os.chdir(_cwd)
os.remove(checker.pb_path)
os.remove(module_init)
_module_runable_func = {} _module_runable_func = {}
...@@ -340,7 +349,7 @@ class ModuleV1(Module): ...@@ -340,7 +349,7 @@ class ModuleV1(Module):
for asset in self.assets: for asset in self.assets:
filename = os.path.basename(asset) filename = os.path.basename(asset)
newfile = os.path.join(self.helper.assets_path(), filename) newfile = os.path.join(self.helper.assets_path(), filename)
copyfile(asset, newfile) shutil.copyfile(asset, newfile)
def _load_assets(self): def _load_assets(self):
assets_path = self.helper.assets_path() assets_path = self.helper.assets_path()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册