提交 a3319267 编写于 作者: W wuzewu

Add temp directory in hub home

上级 ef273daa
......@@ -15,3 +15,4 @@
from . import utils
from .utils import get_running_device_info
from .dir import tmp_dir, tmp_file
#coding:utf-8
# coding:utf-8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
......@@ -14,6 +14,9 @@
# limitations under the License.
import os
import contextlib
import shutil
import tempfile
# TODO: Change dir.py's filename, this naming rule is not qualified
......@@ -37,3 +40,24 @@ CACHE_HOME = os.path.join(gen_hub_home(), "cache")
DATA_HOME = os.path.join(gen_hub_home(), "dataset")
CONF_HOME = os.path.join(gen_hub_home(), "conf")
THIRD_PARTY_HOME = os.path.join(gen_hub_home(), "thirdparty")
TMP_HOME = os.path.join(gen_hub_home(), "tmp")
if not os.path.exists(TMP_HOME):
os.mkdir(TMP_HOME)
@contextlib.contextmanager
def tmp_file():
with tempfile.NamedTemporaryFile(dir=TMP_HOME) as file:
yield file.name
@contextlib.contextmanager
def tmp_dir():
try:
_dir = tempfile.mkdtemp(dir=TMP_HOME)
yield _dir
except:
raise
finally:
shutil.rmtree(_dir)
......@@ -25,7 +25,7 @@ import inspect
import importlib
import tarfile
import six
from shutil import copyfile
import shutil
import paddle
import paddle.fluid as fluid
......@@ -36,6 +36,7 @@ from paddlehub.common.dir import CACHE_HOME
from paddlehub.common.lock import lock
from paddlehub.common.logger import logger
from paddlehub.common.hub_server import CacheUpdater
from paddlehub.common import tmp_dir
from paddlehub.module import module_desc_pb2
from paddlehub.module.manager import default_module_manager
from paddlehub.module.checker import ModuleChecker
......@@ -58,7 +59,13 @@ HUB_PACKAGE_SUFFIX = "phm"
def create_module(directory, name, author, email, module_type, summary,
version):
save_file_name = "{}-{}.{}".format(name, version, HUB_PACKAGE_SUFFIX)
save_file = "{}-{}.{}".format(name, version, HUB_PACKAGE_SUFFIX)
with tmp_dir() as base_dir:
# package the module
with tarfile.open(save_file, "w:gz") as tar:
module_dir = os.path.join(base_dir, name)
shutil.copytree(directory, module_dir)
# record module info and serialize
desc = module_desc_pb2.ModuleDesc()
......@@ -67,34 +74,36 @@ def create_module(directory, name, author, email, module_type, summary,
module_info = attr.map.data['module_info']
module_info.type = module_desc_pb2.MAP
utils.from_pyobj_to_module_attr(name, module_info.map.data['name'])
utils.from_pyobj_to_module_attr(author, module_info.map.data['author'])
utils.from_pyobj_to_module_attr(email, module_info.map.data['author_email'])
utils.from_pyobj_to_module_attr(module_type, module_info.map.data['type'])
utils.from_pyobj_to_module_attr(summary, module_info.map.data['summary'])
utils.from_pyobj_to_module_attr(version, module_info.map.data['version'])
module_desc_path = os.path.join(directory, "module_desc.pb")
utils.from_pyobj_to_module_attr(author,
module_info.map.data['author'])
utils.from_pyobj_to_module_attr(
email, module_info.map.data['author_email'])
utils.from_pyobj_to_module_attr(module_type,
module_info.map.data['type'])
utils.from_pyobj_to_module_attr(summary,
module_info.map.data['summary'])
utils.from_pyobj_to_module_attr(version,
module_info.map.data['version'])
module_desc_path = os.path.join(module_dir, "module_desc.pb")
with open(module_desc_path, "wb") as f:
f.write(desc.SerializeToString())
# generate check info
checker = ModuleChecker(directory)
checker = ModuleChecker(module_dir)
checker.generate_check_info()
# add __init__
module_init = os.path.join(directory, "__init__.py")
module_init = os.path.join(module_dir, "__init__.py")
with open(module_init, "a") as file:
file.write("")
# package the module
with tarfile.open(save_file_name, "w:gz") as tar:
for dirname, _, files in os.walk(directory):
_cwd = os.getcwd()
os.chdir(base_dir)
for dirname, _, files in os.walk(module_dir):
for file in files:
tar.add(os.path.join(dirname, file))
tar.add(os.path.join(dirname, file).replace(base_dir, "."))
os.remove(module_desc_path)
os.remove(checker.pb_path)
os.remove(module_init)
os.chdir(_cwd)
_module_runable_func = {}
......@@ -340,7 +349,7 @@ class ModuleV1(Module):
for asset in self.assets:
filename = os.path.basename(asset)
newfile = os.path.join(self.helper.assets_path(), filename)
copyfile(asset, newfile)
shutil.copyfile(asset, newfile)
def _load_assets(self):
assets_path = self.helper.assets_path()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册