未验证 提交 3151637a 编写于 作者: H Hui Zhang 提交者: GitHub

Merge pull request #1460 from KPatr1ck/cli_batch

[CLI][Batch]Support batch input in cli.
...@@ -12,8 +12,10 @@ ...@@ -12,8 +12,10 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import argparse import argparse
import ast
import os import os
import sys import sys
from collections import OrderedDict
from typing import List from typing import List
from typing import Optional from typing import Optional
from typing import Union from typing import Union
...@@ -130,7 +132,7 @@ class ASRExecutor(BaseExecutor): ...@@ -130,7 +132,7 @@ class ASRExecutor(BaseExecutor):
self.parser = argparse.ArgumentParser( self.parser = argparse.ArgumentParser(
prog='paddlespeech.asr', add_help=True) prog='paddlespeech.asr', add_help=True)
self.parser.add_argument( self.parser.add_argument(
'--input', type=str, required=True, help='Audio file to recognize.') '--input', type=str, default=None, help='Audio file to recognize.')
self.parser.add_argument( self.parser.add_argument(
'--model', '--model',
type=str, type=str,
...@@ -180,6 +182,11 @@ class ASRExecutor(BaseExecutor): ...@@ -180,6 +182,11 @@ class ASRExecutor(BaseExecutor):
type=str, type=str,
default=paddle.get_device(), default=paddle.get_device(),
help='Choose device to execute model inference.') help='Choose device to execute model inference.')
self.parser.add_argument(
'--job_dump_result',
type=ast.literal_eval,
default=False,
help='Save job result into file.')
def _get_pretrained_path(self, tag: str) -> os.PathLike: def _get_pretrained_path(self, tag: str) -> os.PathLike:
""" """
...@@ -469,19 +476,31 @@ class ASRExecutor(BaseExecutor): ...@@ -469,19 +476,31 @@ class ASRExecutor(BaseExecutor):
sample_rate = parser_args.sample_rate sample_rate = parser_args.sample_rate
config = parser_args.config config = parser_args.config
ckpt_path = parser_args.ckpt_path ckpt_path = parser_args.ckpt_path
audio_file = parser_args.input
decode_method = parser_args.decode_method decode_method = parser_args.decode_method
force_yes = parser_args.yes force_yes = parser_args.yes
device = parser_args.device device = parser_args.device
job_dump_result = parser_args.job_dump_result
try: task_source = self.get_task_source(parser_args.input)
res = self(audio_file, model, lang, sample_rate, config, ckpt_path, task_results = OrderedDict()
decode_method, force_yes, device) has_exceptions = False
logger.info('ASR Result: {}'.format(res))
return True for id_, input_ in task_source.items():
except Exception as e: try:
logger.exception(e) res = self(input_, model, lang, sample_rate, config, ckpt_path,
decode_method, force_yes, device)
task_results[id_] = res
except Exception as e:
has_exceptions = True
task_results[id_] = f'{e.__class__.__name__}: {e}'
self.process_task_results(parser_args.input, task_results,
job_dump_result)
if has_exceptions:
return False return False
else:
return True
@stats_wrapper @stats_wrapper
def __call__(self, def __call__(self,
......
...@@ -12,7 +12,9 @@ ...@@ -12,7 +12,9 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import argparse import argparse
import ast
import os import os
from collections import OrderedDict
from typing import List from typing import List
from typing import Optional from typing import Optional
from typing import Union from typing import Union
...@@ -77,7 +79,7 @@ class CLSExecutor(BaseExecutor): ...@@ -77,7 +79,7 @@ class CLSExecutor(BaseExecutor):
self.parser = argparse.ArgumentParser( self.parser = argparse.ArgumentParser(
prog='paddlespeech.cls', add_help=True) prog='paddlespeech.cls', add_help=True)
self.parser.add_argument( self.parser.add_argument(
'--input', type=str, required=True, help='Audio file to classify.') '--input', type=str, default=None, help='Audio file to classify.')
self.parser.add_argument( self.parser.add_argument(
'--model', '--model',
type=str, type=str,
...@@ -109,6 +111,11 @@ class CLSExecutor(BaseExecutor): ...@@ -109,6 +111,11 @@ class CLSExecutor(BaseExecutor):
type=str, type=str,
default=paddle.get_device(), default=paddle.get_device(),
help='Choose device to execute model inference.') help='Choose device to execute model inference.')
self.parser.add_argument(
'--job_dump_result',
type=ast.literal_eval,
default=False,
help='Save job result into file.')
def _get_pretrained_path(self, tag: str) -> os.PathLike: def _get_pretrained_path(self, tag: str) -> os.PathLike:
""" """
...@@ -214,7 +221,7 @@ class CLSExecutor(BaseExecutor): ...@@ -214,7 +221,7 @@ class CLSExecutor(BaseExecutor):
ret = '' ret = ''
for idx in topk_idx: for idx in topk_idx:
label, score = self._label_list[idx], result[idx] label, score = self._label_list[idx], result[idx]
ret += f'{label}: {score}\n' ret += f'{label} {score} '
return ret return ret
def postprocess(self, topk: int) -> Union[str, os.PathLike]: def postprocess(self, topk: int) -> Union[str, os.PathLike]:
...@@ -234,18 +241,30 @@ class CLSExecutor(BaseExecutor): ...@@ -234,18 +241,30 @@ class CLSExecutor(BaseExecutor):
label_file = parser_args.label_file label_file = parser_args.label_file
cfg_path = parser_args.config cfg_path = parser_args.config
ckpt_path = parser_args.ckpt_path ckpt_path = parser_args.ckpt_path
audio_file = parser_args.input
topk = parser_args.topk topk = parser_args.topk
device = parser_args.device device = parser_args.device
job_dump_result = parser_args.job_dump_result
try: task_source = self.get_task_source(parser_args.input)
res = self(audio_file, model_type, cfg_path, ckpt_path, label_file, task_results = OrderedDict()
topk, device) has_exceptions = False
logger.info('CLS Result:\n{}'.format(res))
return True for id_, input_ in task_source.items():
except Exception as e: try:
logger.exception(e) res = self(input_, model_type, cfg_path, ckpt_path, label_file,
topk, device)
task_results[id_] = res
except Exception as e:
has_exceptions = True
task_results[id_] = f'{e.__class__.__name__}: {e}'
self.process_task_results(parser_args.input, task_results,
job_dump_result)
if has_exceptions:
return False return False
else:
return True
@stats_wrapper @stats_wrapper
def __call__(self, def __call__(self,
...@@ -259,7 +278,7 @@ class CLSExecutor(BaseExecutor): ...@@ -259,7 +278,7 @@ class CLSExecutor(BaseExecutor):
""" """
Python API to call an executor. Python API to call an executor.
""" """
audio_file = os.path.abspath(audio_file) audio_file = os.path.abspath(os.path.expanduser(audio_file))
paddle.set_device(device) paddle.set_device(device)
self._init_from_path(model, config, ckpt_path, label_file) self._init_from_path(model, config, ckpt_path, label_file)
self.preprocess(audio_file) self.preprocess(audio_file)
......
...@@ -12,14 +12,19 @@ ...@@ -12,14 +12,19 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import os import os
import sys
from abc import ABC from abc import ABC
from abc import abstractmethod from abc import abstractmethod
from collections import OrderedDict
from typing import Any from typing import Any
from typing import Dict
from typing import List from typing import List
from typing import Union from typing import Union
import paddle import paddle
from .log import logger
class BaseExecutor(ABC): class BaseExecutor(ABC):
""" """
...@@ -27,8 +32,8 @@ class BaseExecutor(ABC): ...@@ -27,8 +32,8 @@ class BaseExecutor(ABC):
""" """
def __init__(self): def __init__(self):
self._inputs = dict() self._inputs = OrderedDict()
self._outputs = dict() self._outputs = OrderedDict()
@abstractmethod @abstractmethod
def _get_pretrained_path(self, tag: str) -> os.PathLike: def _get_pretrained_path(self, tag: str) -> os.PathLike:
...@@ -100,3 +105,107 @@ class BaseExecutor(ABC): ...@@ -100,3 +105,107 @@ class BaseExecutor(ABC):
Python API to call an executor. Python API to call an executor.
""" """
pass pass
def get_task_source(self, input_: Union[str, os.PathLike, None]
                    ) -> Dict[str, Union[str, os.PathLike]]:
    """
    Get task input source from command line input.

    Three sources are handled:
      * a job file (as decided by `self._is_job_input`): its contents are
        read through `self._get_job_contents`;
      * stdin, when `input_` is None: one task per line, written either as
        `<input>` or as `<id> <input>`;
      * anything else: `input_` itself is the single task, stored under
        the id '1'.

    Stdin lines are split on any run of whitespace. A line holding one
    field gets its 1-based line number as id; a line holding two fields
    is taken as id and input. Blank lines and lines with more than two
    fields are skipped.

    Args:
        input_ (Union[str, os.PathLike, None]): Input from command line.

    Returns:
        Dict[str, Union[str, os.PathLike]]: A dict with ids and inputs,
            in the order the tasks were read.
    """
    if self._is_job_input(input_):
        return self._get_job_contents(input_)

    ret = OrderedDict()
    if input_ is None:  # Take input from stdin
        for line_no, line in enumerate(sys.stdin, start=1):
            # split() with no separator drops surrounding whitespace and
            # yields [] for a blank line, so empty tasks are never queued.
            fields = line.split()
            if len(fields) == 1:
                ret[str(line_no)] = fields[0]
            elif len(fields) == 2:
                id_, info = fields
                ret[id_] = info
            # Otherwise: no valid input info on this line, skip it.
    else:
        # String id, consistent with the ids produced for stdin lines.
        ret['1'] = input_
    return ret
def process_task_results(self,
                         input_: Union[str, os.PathLike, None],
                         results: Dict[str, os.PathLike],
                         job_dump_result: bool=False):
    """
    Print task results and optionally dump them next to the job file.

    The formatted results are always written to stdout. When `input_` is
    a job file (per `self._is_job_input`) and `job_dump_result` is true,
    the same text is also written to `<absolute job path>.done`.

    Args:
        input_ (Union[str, os.PathLike, None]): Input from command line.
        results (Dict[str, os.PathLike]): Task outputs.
        job_dump_result (bool, optional): if True, dumps job results into file. Defaults to False.
    """
    raw_text = self._format_task_results(results)
    print(raw_text, end='')

    if self._is_job_input(input_) and job_dump_result:
        job_output_file = os.path.abspath(input_) + '.done'
        # Write through a dedicated handle instead of rebinding sys.stdout:
        # stdout stays usable afterwards, and a failing open() cannot end
        # up closing the real stdout.
        with open(job_output_file, 'w') as f:
            f.write(raw_text)
        logger.info(f'Results had been saved to: {job_output_file}')
def _is_job_input(self, input_: Union[str, os.PathLike]) -> bool:
"""
Check if current input file is a job input or not.
Args:
input_ (Union[str, os.PathLike]): Input file of current task.
Returns:
bool: return `True` for job input, `False` otherwise.
"""
return input_ and os.path.isfile(input_) and input_.endswith('.job')
def _get_job_contents(
self, job_input: os.PathLike) -> Dict[str, Union[str, os.PathLike]]:
"""
Read a job input file and return its contents in a dictionary.
Args:
job_input (os.PathLike): The job input file.
Returns:
Dict[str, str]: Contents of job input.
"""
job_contents = OrderedDict()
with open(job_input) as f:
for line in f:
line = line.strip()
if not line:
continue
k, v = line.split(' ')
job_contents[k] = v
return job_contents
def _format_task_results(
self, results: Dict[str, Union[str, os.PathLike]]) -> str:
"""
Convert task results to raw text.
Args:
results (Dict[str, str]): A dictionary of task results.
Returns:
str: A string object contains task results.
"""
ret = ''
for k, v in results.items():
ret += f'{k} {v}\n'
return ret
...@@ -12,8 +12,10 @@ ...@@ -12,8 +12,10 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import argparse import argparse
import ast
import os import os
import subprocess import subprocess
from collections import OrderedDict
from typing import List from typing import List
from typing import Optional from typing import Optional
from typing import Union from typing import Union
...@@ -69,7 +71,7 @@ class STExecutor(BaseExecutor): ...@@ -69,7 +71,7 @@ class STExecutor(BaseExecutor):
self.parser = argparse.ArgumentParser( self.parser = argparse.ArgumentParser(
prog="paddlespeech.st", add_help=True) prog="paddlespeech.st", add_help=True)
self.parser.add_argument( self.parser.add_argument(
"--input", type=str, required=True, help="Audio file to translate.") "--input", type=str, default=None, help="Audio file to translate.")
self.parser.add_argument( self.parser.add_argument(
"--model", "--model",
type=str, type=str,
...@@ -107,6 +109,11 @@ class STExecutor(BaseExecutor): ...@@ -107,6 +109,11 @@ class STExecutor(BaseExecutor):
type=str, type=str,
default=paddle.get_device(), default=paddle.get_device(),
help="Choose device to execute model inference.") help="Choose device to execute model inference.")
self.parser.add_argument(
'--job_dump_result',
type=ast.literal_eval,
default=False,
help='Save job result into file.')
def _get_pretrained_path(self, tag: str) -> os.PathLike: def _get_pretrained_path(self, tag: str) -> os.PathLike:
""" """
...@@ -319,17 +326,29 @@ class STExecutor(BaseExecutor): ...@@ -319,17 +326,29 @@ class STExecutor(BaseExecutor):
sample_rate = parser_args.sample_rate sample_rate = parser_args.sample_rate
config = parser_args.config config = parser_args.config
ckpt_path = parser_args.ckpt_path ckpt_path = parser_args.ckpt_path
audio_file = parser_args.input
device = parser_args.device device = parser_args.device
job_dump_result = parser_args.job_dump_result
try: task_source = self.get_task_source(parser_args.input)
res = self(audio_file, model, src_lang, tgt_lang, sample_rate, task_results = OrderedDict()
config, ckpt_path, device) has_exceptions = False
logger.info("ST Result: {}".format(res))
return True for id_, input_ in task_source.items():
except Exception as e: try:
logger.exception(e) res = self(input_, model, src_lang, tgt_lang, sample_rate,
config, ckpt_path, device)
task_results[id_] = res
except Exception as e:
has_exceptions = True
task_results[id_] = f'{e.__class__.__name__}: {e}'
self.process_task_results(parser_args.input, task_results,
job_dump_result)
if has_exceptions:
return False return False
else:
return True
@stats_wrapper @stats_wrapper
def __call__(self, def __call__(self,
......
...@@ -12,8 +12,10 @@ ...@@ -12,8 +12,10 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import argparse import argparse
import ast
import os import os
import re import re
from collections import OrderedDict
from typing import List from typing import List
from typing import Optional from typing import Optional
from typing import Union from typing import Union
...@@ -80,7 +82,7 @@ class TextExecutor(BaseExecutor): ...@@ -80,7 +82,7 @@ class TextExecutor(BaseExecutor):
self.parser = argparse.ArgumentParser( self.parser = argparse.ArgumentParser(
prog='paddlespeech.text', add_help=True) prog='paddlespeech.text', add_help=True)
self.parser.add_argument( self.parser.add_argument(
'--input', type=str, required=True, help='Input text.') '--input', type=str, default=None, help='Input text.')
self.parser.add_argument( self.parser.add_argument(
'--task', '--task',
type=str, type=str,
...@@ -119,6 +121,11 @@ class TextExecutor(BaseExecutor): ...@@ -119,6 +121,11 @@ class TextExecutor(BaseExecutor):
type=str, type=str,
default=paddle.get_device(), default=paddle.get_device(),
help='Choose device to execute model inference.') help='Choose device to execute model inference.')
self.parser.add_argument(
'--job_dump_result',
type=ast.literal_eval,
default=False,
help='Save job result into file.')
def _get_pretrained_path(self, tag: str) -> os.PathLike: def _get_pretrained_path(self, tag: str) -> os.PathLike:
""" """
...@@ -256,7 +263,6 @@ class TextExecutor(BaseExecutor): ...@@ -256,7 +263,6 @@ class TextExecutor(BaseExecutor):
""" """
parser_args = self.parser.parse_args(argv) parser_args = self.parser.parse_args(argv)
text = parser_args.input
task = parser_args.task task = parser_args.task
model_type = parser_args.model model_type = parser_args.model
lang = parser_args.lang lang = parser_args.lang
...@@ -264,15 +270,28 @@ class TextExecutor(BaseExecutor): ...@@ -264,15 +270,28 @@ class TextExecutor(BaseExecutor):
ckpt_path = parser_args.ckpt_path ckpt_path = parser_args.ckpt_path
punc_vocab = parser_args.punc_vocab punc_vocab = parser_args.punc_vocab
device = parser_args.device device = parser_args.device
job_dump_result = parser_args.job_dump_result
try: task_source = self.get_task_source(parser_args.input)
res = self(text, task, model_type, lang, cfg_path, ckpt_path, task_results = OrderedDict()
punc_vocab, device) has_exceptions = False
logger.info('Text Result:\n{}'.format(res))
return True for id_, input_ in task_source.items():
except Exception as e: try:
logger.exception(e) res = self(input_, task, model_type, lang, cfg_path, ckpt_path,
punc_vocab, device)
task_results[id_] = res
except Exception as e:
has_exceptions = True
task_results[id_] = f'{e.__class__.__name__}: {e}'
self.process_task_results(parser_args.input, task_results,
job_dump_result)
if has_exceptions:
return False return False
else:
return True
@stats_wrapper @stats_wrapper
def __call__( def __call__(
......
...@@ -12,7 +12,9 @@ ...@@ -12,7 +12,9 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import argparse import argparse
import ast
import os import os
from collections import OrderedDict
from typing import Any from typing import Any
from typing import List from typing import List
from typing import Optional from typing import Optional
...@@ -298,7 +300,7 @@ class TTSExecutor(BaseExecutor): ...@@ -298,7 +300,7 @@ class TTSExecutor(BaseExecutor):
self.parser = argparse.ArgumentParser( self.parser = argparse.ArgumentParser(
prog='paddlespeech.tts', add_help=True) prog='paddlespeech.tts', add_help=True)
self.parser.add_argument( self.parser.add_argument(
'--input', type=str, required=True, help='Input text to generate.') '--input', type=str, default=None, help='Input text to generate.')
# acoustic model # acoustic model
self.parser.add_argument( self.parser.add_argument(
'--am', '--am',
...@@ -397,6 +399,11 @@ class TTSExecutor(BaseExecutor): ...@@ -397,6 +399,11 @@ class TTSExecutor(BaseExecutor):
self.parser.add_argument( self.parser.add_argument(
'--output', type=str, default='output.wav', help='output file name') '--output', type=str, default='output.wav', help='output file name')
self.parser.add_argument(
'--job_dump_result',
type=ast.literal_eval,
default=False,
help='Save job result into file.')
def _get_pretrained_path(self, tag: str) -> os.PathLike: def _get_pretrained_path(self, tag: str) -> os.PathLike:
""" """
...@@ -671,7 +678,6 @@ class TTSExecutor(BaseExecutor): ...@@ -671,7 +678,6 @@ class TTSExecutor(BaseExecutor):
args = self.parser.parse_args(argv) args = self.parser.parse_args(argv)
text = args.input
am = args.am am = args.am
am_config = args.am_config am_config = args.am_config
am_ckpt = args.am_ckpt am_ckpt = args.am_ckpt
...@@ -686,35 +692,53 @@ class TTSExecutor(BaseExecutor): ...@@ -686,35 +692,53 @@ class TTSExecutor(BaseExecutor):
voc_stat = args.voc_stat voc_stat = args.voc_stat
lang = args.lang lang = args.lang
device = args.device device = args.device
output = args.output
spk_id = args.spk_id spk_id = args.spk_id
job_dump_result = args.job_dump_result
try: task_source = self.get_task_source(args.input)
res = self( task_results = OrderedDict()
text=text, has_exceptions = False
# acoustic model related
am=am, for id_, input_ in task_source.items():
am_config=am_config, if len(task_source) > 1:
am_ckpt=am_ckpt, assert isinstance(args.output,
am_stat=am_stat, str) and args.output.endswith('.wav')
phones_dict=phones_dict, output = args.output.replace('.wav', f'_{id_}.wav')
tones_dict=tones_dict, else:
speaker_dict=speaker_dict, output = args.output
spk_id=spk_id,
# vocoder related try:
voc=voc, res = self(
voc_config=voc_config, text=input_,
voc_ckpt=voc_ckpt, # acoustic model related
voc_stat=voc_stat, am=am,
# other am_config=am_config,
lang=lang, am_ckpt=am_ckpt,
device=device, am_stat=am_stat,
output=output) phones_dict=phones_dict,
logger.info('Wave file has been generated: {}'.format(res)) tones_dict=tones_dict,
return True speaker_dict=speaker_dict,
except Exception as e: spk_id=spk_id,
logger.exception(e) # vocoder related
voc=voc,
voc_config=voc_config,
voc_ckpt=voc_ckpt,
voc_stat=voc_stat,
# other
lang=lang,
device=device,
output=output)
task_results[id_] = res
except Exception as e:
has_exceptions = True
task_results[id_] = f'{e.__class__.__name__}: {e}'
self.process_task_results(args.input, task_results, job_dump_result)
if has_exceptions:
return False return False
else:
return True
@stats_wrapper @stats_wrapper
def __call__(self, def __call__(self,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册