提交 f43c50a6 编写于 作者: W wuzewu

Update pypi utils

上级 30142577
......@@ -156,8 +156,8 @@ class LocalModuleManager(object):
branch: str = None) -> HubModule:
'''
Install a HubModule from name or directory or archive file or url. When installing with the name parameter, if a
module that meets the conditions (both name and version) already installed, the installation step will be
skipped. When installing with other parameter, The locally installed modules will be uninstalled.
module that meets the conditions (both name and version) already installed, the installation step will be skipped.
When installing with other parameter, The locally installed modules will be uninstalled.
Args:
name (str|optional): module name to install
......@@ -278,7 +278,7 @@ class LocalModuleManager(object):
def _install_from_source(self, name: str, version: str, source: str, update: bool = False,
branch: str = None) -> HubModule:
'''Install a HubModule from Git Repo'''
'''Install a HubModule from git repository'''
result = module_server.search_module(name=name, source=source, version=version, update=update, branch=branch)
for item in result:
if item['name'] == name and item['version'].match(version):
......@@ -292,7 +292,7 @@ class LocalModuleManager(object):
os.makedirs(installed_path)
module_file = os.path.join(installed_path, 'module.py')
# Generate a module.py file to reference objects from Git Repo
# Generate a module.py file to reference objects from git repository
with open(module_file, 'w') as file:
file.write('import sys\n\n')
file.write('sys.path.insert(0, \'{}\')\n'.format(item['path']))
......@@ -305,11 +305,20 @@ class LocalModuleManager(object):
file.write('branch: {}'.format(branch))
self._local_modules[name] = HubModule.load(installed_path)
module_file = sys.modules[self._local_modules[name].__module__].__file__
requirements_file = os.path.join(os.path.dirname(module_file), 'requirements.txt')
if os.path.exists(requirements_file):
shutil.copy(requirements_file, installed_path)
# Install python package requirements
self._install_module_requirements(self._local_modules[name])
if version:
log.logger.info('Successfully installed {}-{}'.format(name, version))
else:
log.logger.info('Successfully installed {}'.format(name))
return self._local_modules[name]
raise HubModuleNotFoundError(name=name, version=version, source=source)
def _install_from_directory(self, directory: str) -> HubModule:
......@@ -339,14 +348,8 @@ class LocalModuleManager(object):
hub_module_cls = HubModule.load(self._get_normalized_path(hub_module_cls.name))
self._local_modules[hub_module_cls.name] = hub_module_cls
for py_req in hub_module_cls.get_py_requirements():
log.logger.info('Installing dependent packages: {}'.format(py_req))
result = pypi.install(py_req)
if result:
log.logger.info('Successfully installed {}'.format(py_req))
else:
log.logger.info('Some errors occurred while installing {}'.format(py_req))
# Install python package requirements
self._install_module_requirements(hub_module_cls)
log.logger.info('Successfully installed {}-{}'.format(hub_module_cls.name, hub_module_cls.version))
return hub_module_cls
......@@ -361,3 +364,20 @@ class LocalModuleManager(object):
path = os.path.normpath(path)
directory = os.path.join(_tdir, path.split(os.sep)[0])
return self._install_from_directory(directory)
def _install_module_requirements(self, module: HubModule):
file = utils.get_record_file()
with open(file, 'a') as _stream:
for py_req in module.get_py_requirements():
if py_req.lstrip().rstrip() == '':
continue
with log.logger.processing('Installing dependent packages {}'.format(py_req)):
result = pypi.install(py_req, ostream=_stream, estream=_stream)
if result:
log.logger.info('Successfully installed dependent packages {}'.format(py_req))
else:
log.logger.warning(
'Some errors occurred while installing dependent packages {}. Detailed error information can be found in the {}.'
.format(py_req, file))
......@@ -114,7 +114,7 @@ class RunModule(object):
if not os.path.exists(req_file):
return []
with open(req_file, 'r') as file:
return file.read()
return file.read().split('\n')
@property
def is_runnable(self) -> bool:
......@@ -149,6 +149,7 @@ class Module(object):
'''
def __new__(cls,
*,
name: str = None,
directory: str = None,
version: str = None,
......
......@@ -13,8 +13,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import pip
import os
import subprocess
from pip._internal.utils.misc import get_installed_distributions
from typing import IO
from paddlehub.utils.utils import Version
from paddlehub.utils.io import discard_oe, typein
......@@ -38,20 +40,33 @@ def check(package: str, version: str = '') -> bool:
return pdict[package].match(version)
def install(package: str, version: str = '', upgrade=False) -> bool:
def install(package: str, version: str = '', upgrade: bool = False, ostream: IO = None, estream: IO = None) -> bool:
'''Install the python package.'''
with discard_oe():
cmds = ['install', '{}{}'.format(package, version)]
package = package.replace(' ', '')
if version:
package = '{}=={}'.format(package, version)
cmd = 'pip install "{}"'.format(package)
if upgrade:
cmds.append('--upgrade')
result = pip.main(cmds)
cmd += ' --upgrade'
result, content = subprocess.getstatusoutput(cmd)
if result:
estream.write(content)
else:
ostream.write(content)
return result == 0
def uninstall(package: str) -> bool:
def uninstall(package: str, ostream: IO = None, estream: IO = None) -> bool:
'''Uninstall the python package.'''
with discard_oe(), typein('y'):
with typein('y'):
# type in 'y' to confirm the uninstall operation
cmds = ['uninstall', '{}'.format(package)]
result = pip.main(cmds)
cmd = 'pip uninstall {}'.format(package)
result, content = subprocess.getstatusoutput(cmd)
if result:
estream.write(content)
else:
ostream.write(content)
return result == 0
......@@ -282,7 +282,7 @@ def md5(text: str):
def record(msg: str) -> str:
'''Record the specified text into the PaddleHub log file witch will be automatically stored according to date.'''
logfile = os.path.join(hubenv.LOG_HOME, time.strftime('%Y%m%d.log'))
logfile = get_record_file()
with open(logfile, 'a') as file:
file.write('=' * 50 + '\n')
file.write('Record at ' + time.strftime('%Y-%m-%d %H:%M:%S') + '\n')
......@@ -297,3 +297,7 @@ def record_exception(msg: str) -> str:
tb = traceback.format_exc()
file = record(tb)
utils.log.logger.warning('{}. Detailed error information can be found in the {}.'.format(msg, file))
def get_record_file():
return os.path.join(hubenv.LOG_HOME, time.strftime('%Y%m%d.log'))
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册