提交 087acab2 编写于 作者: W wuzewu

Fix the tts module compatibility issue

上级 1dea3f05
......@@ -19,8 +19,10 @@ from easydict import EasyDict
__version__ = '2.0.0-alpha0'
from paddlehub import env
from paddlehub.config import config
from paddlehub.utils import log, parser, utils
from paddlehub.utils import download as _download
from paddlehub.utils.paddlex import download, ResourceNotFoundError
from paddlehub.server import server_check
from paddlehub.server.server_source import ServerConnectionError
......@@ -40,6 +42,8 @@ from paddlehub.compat.task.config import RunConfig
from paddlehub.compat.task.text_generation_task import TextGenerationTask
sys.modules['paddlehub.io.parser'] = parser
sys.modules['paddlehub.common.dir'] = env
sys.modules['paddlehub.common.downloader'] = _download
sys.modules['paddlehub.common.logger'] = log
sys.modules['paddlehub.common.paddle_helper'] = paddle_utils
sys.modules['paddlehub.common.utils'] = utils
......
......@@ -67,15 +67,7 @@ class RunModule(object):
'''The base class of PaddleHub Module, users can inherit this class to implement to realize custom class.'''
def __init__(self, *args, **kwargs):
# Avoid module being initialized multiple times
if '_is_initialize' in self.__dict__ and self._is_initialize:
return
super(RunModule, self).__init__()
_run_func_name = self._get_func_name(self.__class__, _module_runnable_func)
self._run_func = getattr(self, _run_func_name) if _run_func_name else None
self._serving_func_name = self._get_func_name(self.__class__, _module_serving_func)
self._is_initialize = True
def _get_func_name(self, current_cls: Generic, module_func_dict: dict) -> Optional[str]:
mod = current_cls.__module__ + '.' + current_cls.__name__
......
......@@ -15,18 +15,40 @@
import os
from paddlehub.env import DATA_HOME
import paddlehub.env as hubenv
from paddle.utils.download import get_path_from_url
from paddlehub.utils import log, utils, xarfile
def download_data(url):
save_name = os.path.basename(url).split('.')[0]
output_path = os.path.join(DATA_HOME, save_name)
output_path = os.path.join(hubenv.DATA_HOME, save_name)
if not os.path.exists(output_path):
get_path_from_url(url, DATA_HOME)
get_path_from_url(url, hubenv.DATA_HOME)
def _wrapper(Dataset):
return Dataset
return _wrapper
class Downloader:
def download_file_and_uncompress(self, url: str, save_path: str, print_progress: bool):
with utils.generate_tempdir() as _dir:
if print_progress:
with log.ProgressBar('Download {}'.format(url)) as bar:
for path, ds, ts in utils.download_with_progress(url=url, path=_dir):
bar.update(float(ds) / ts)
else:
path = utils.download(url=url, path=_dir)
if print_progress:
with log.ProgressBar('Decompress {}'.format(path)) as bar:
for path, ds, ts in xarfile.unarchive_with_progress(name=path, path=save_path):
bar.update(float(ds) / ts)
else:
path = xarfile.unarchive(name=path, path=save_path)
default_downloader = Downloader()
......@@ -154,10 +154,12 @@ def seconds_to_hms(seconds: int) -> str:
hms_str = '{:0>2}:{:0>2}:{:0>2}'.format(h, m, s)
return hms_str
def cv2_to_base64(image: np.ndarray) -> str:
data = cv2.imencode('.jpg', image)[1]
return base64.b64encode(data.tostring()).decode('utf8')
def base64_to_cv2(b64str: str) -> np.ndarray:
'''Convert a string in base64 format to cv2 data'''
data = base64.b64decode(b64str.encode('utf8'))
......@@ -304,11 +306,11 @@ def record_exception(msg: str) -> str:
utils.log.logger.warning('{}. Detailed error information can be found in the {}.'.format(msg, file))
def get_record_file():
def get_record_file() -> str:
return os.path.join(hubenv.LOG_HOME, time.strftime('%Y%m%d.log'))
def is_port_occupied(ip, port):
def is_port_occupied(ip: str, port: int) -> bool:
'''
Check if port os occupied.
'''
......@@ -319,3 +321,9 @@ def is_port_occupied(ip, port):
return True
except:
return False
def mkdir(path: str):
"""The same as the shell command `mkdir -p`."""
if not os.path.exists(path):
os.makedirs(path)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册