提交 087acab2 编写于 作者: W wuzewu

Fix the tts module compatibility issue

上级 1dea3f05
...@@ -19,8 +19,10 @@ from easydict import EasyDict ...@@ -19,8 +19,10 @@ from easydict import EasyDict
__version__ = '2.0.0-alpha0' __version__ = '2.0.0-alpha0'
from paddlehub import env
from paddlehub.config import config from paddlehub.config import config
from paddlehub.utils import log, parser, utils from paddlehub.utils import log, parser, utils
from paddlehub.utils import download as _download
from paddlehub.utils.paddlex import download, ResourceNotFoundError from paddlehub.utils.paddlex import download, ResourceNotFoundError
from paddlehub.server import server_check from paddlehub.server import server_check
from paddlehub.server.server_source import ServerConnectionError from paddlehub.server.server_source import ServerConnectionError
...@@ -40,6 +42,8 @@ from paddlehub.compat.task.config import RunConfig ...@@ -40,6 +42,8 @@ from paddlehub.compat.task.config import RunConfig
from paddlehub.compat.task.text_generation_task import TextGenerationTask from paddlehub.compat.task.text_generation_task import TextGenerationTask
sys.modules['paddlehub.io.parser'] = parser sys.modules['paddlehub.io.parser'] = parser
sys.modules['paddlehub.common.dir'] = env
sys.modules['paddlehub.common.downloader'] = _download
sys.modules['paddlehub.common.logger'] = log sys.modules['paddlehub.common.logger'] = log
sys.modules['paddlehub.common.paddle_helper'] = paddle_utils sys.modules['paddlehub.common.paddle_helper'] = paddle_utils
sys.modules['paddlehub.common.utils'] = utils sys.modules['paddlehub.common.utils'] = utils
......
...@@ -67,15 +67,7 @@ class RunModule(object): ...@@ -67,15 +67,7 @@ class RunModule(object):
'''The base class of PaddleHub Module, users can inherit this class to implement to realize custom class.''' '''The base class of PaddleHub Module, users can inherit this class to implement to realize custom class.'''
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
# Avoid module being initialized multiple times
if '_is_initialize' in self.__dict__ and self._is_initialize:
return
super(RunModule, self).__init__() super(RunModule, self).__init__()
_run_func_name = self._get_func_name(self.__class__, _module_runnable_func)
self._run_func = getattr(self, _run_func_name) if _run_func_name else None
self._serving_func_name = self._get_func_name(self.__class__, _module_serving_func)
self._is_initialize = True
def _get_func_name(self, current_cls: Generic, module_func_dict: dict) -> Optional[str]: def _get_func_name(self, current_cls: Generic, module_func_dict: dict) -> Optional[str]:
mod = current_cls.__module__ + '.' + current_cls.__name__ mod = current_cls.__module__ + '.' + current_cls.__name__
...@@ -133,7 +125,7 @@ class RunModule(object): ...@@ -133,7 +125,7 @@ class RunModule(object):
`hub run` command. `hub run` command.
''' '''
return True if self._run_func else False return True if self._run_func else False
@property @property
def serving_func_name(self): def serving_func_name(self):
return self._get_func_name(self.__class__, _module_serving_func) return self._get_func_name(self.__class__, _module_serving_func)
...@@ -343,4 +335,4 @@ def moduleinfo(name: str, ...@@ -343,4 +335,4 @@ def moduleinfo(name: str,
wrap_cls._hook_by_hub = True wrap_cls._hook_by_hub = True
return wrap_cls return wrap_cls
return _wrapper return _wrapper
\ No newline at end of file
...@@ -15,18 +15,40 @@ ...@@ -15,18 +15,40 @@
import os import os
from paddlehub.env import DATA_HOME import paddlehub.env as hubenv
from paddle.utils.download import get_path_from_url from paddle.utils.download import get_path_from_url
from paddlehub.utils import log, utils, xarfile
def download_data(url): def download_data(url):
save_name = os.path.basename(url).split('.')[0] save_name = os.path.basename(url).split('.')[0]
output_path = os.path.join(DATA_HOME, save_name) output_path = os.path.join(hubenv.DATA_HOME, save_name)
if not os.path.exists(output_path): if not os.path.exists(output_path):
get_path_from_url(url, DATA_HOME) get_path_from_url(url, hubenv.DATA_HOME)
def _wrapper(Dataset): def _wrapper(Dataset):
return Dataset return Dataset
return _wrapper return _wrapper
class Downloader:
def download_file_and_uncompress(self, url: str, save_path: str, print_progress: bool):
with utils.generate_tempdir() as _dir:
if print_progress:
with log.ProgressBar('Download {}'.format(url)) as bar:
for path, ds, ts in utils.download_with_progress(url=url, path=_dir):
bar.update(float(ds) / ts)
else:
path = utils.download(url=url, path=_dir)
if print_progress:
with log.ProgressBar('Decompress {}'.format(path)) as bar:
for path, ds, ts in xarfile.unarchive_with_progress(name=path, path=save_path):
bar.update(float(ds) / ts)
else:
path = xarfile.unarchive(name=path, path=save_path)
default_downloader = Downloader()
...@@ -154,10 +154,12 @@ def seconds_to_hms(seconds: int) -> str: ...@@ -154,10 +154,12 @@ def seconds_to_hms(seconds: int) -> str:
hms_str = '{:0>2}:{:0>2}:{:0>2}'.format(h, m, s) hms_str = '{:0>2}:{:0>2}:{:0>2}'.format(h, m, s)
return hms_str return hms_str
def cv2_to_base64(image: np.ndarray) -> str: def cv2_to_base64(image: np.ndarray) -> str:
data = cv2.imencode('.jpg', image)[1] data = cv2.imencode('.jpg', image)[1]
return base64.b64encode(data.tostring()).decode('utf8') return base64.b64encode(data.tostring()).decode('utf8')
def base64_to_cv2(b64str: str) -> np.ndarray: def base64_to_cv2(b64str: str) -> np.ndarray:
'''Convert a string in base64 format to cv2 data''' '''Convert a string in base64 format to cv2 data'''
data = base64.b64decode(b64str.encode('utf8')) data = base64.b64decode(b64str.encode('utf8'))
...@@ -304,11 +306,11 @@ def record_exception(msg: str) -> str: ...@@ -304,11 +306,11 @@ def record_exception(msg: str) -> str:
utils.log.logger.warning('{}. Detailed error information can be found in the {}.'.format(msg, file)) utils.log.logger.warning('{}. Detailed error information can be found in the {}.'.format(msg, file))
def get_record_file(): def get_record_file() -> str:
return os.path.join(hubenv.LOG_HOME, time.strftime('%Y%m%d.log')) return os.path.join(hubenv.LOG_HOME, time.strftime('%Y%m%d.log'))
def is_port_occupied(ip, port): def is_port_occupied(ip: str, port: int) -> bool:
''' '''
Check if port os occupied. Check if port os occupied.
''' '''
...@@ -319,3 +321,9 @@ def is_port_occupied(ip, port): ...@@ -319,3 +321,9 @@ def is_port_occupied(ip, port):
return True return True
except: except:
return False return False
def mkdir(path: str):
"""The same as the shell command `mkdir -p`."""
if not os.path.exists(path):
os.makedirs(path)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册