提交 da08f1c1 编写于 作者: K KP

Add RFT for asr task.

上级 99a76e43
......@@ -14,6 +14,7 @@
import argparse
import os
import sys
import time
from collections import OrderedDict
from typing import List
from typing import Optional
......@@ -29,8 +30,10 @@ from ..download import get_path_from_url
from ..executor import BaseExecutor
from ..log import logger
from ..utils import cli_register
from ..utils import CLI_TIMER
from ..utils import MODEL_HOME
from ..utils import stats_wrapper
from ..utils import timer_register
from .pretrained_models import model_alias
from .pretrained_models import pretrained_models
from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
......@@ -41,6 +44,7 @@ from paddlespeech.s2t.utils.utility import UpdateConfig
__all__ = ['ASRExecutor']
@timer_register
@cli_register(
name='paddlespeech.asr', description='Speech to text infer command.')
class ASRExecutor(BaseExecutor):
......@@ -99,6 +103,11 @@ class ASRExecutor(BaseExecutor):
default=False,
help='No additional parameters required. Once set this parameter, it means accepting the request of the program by default, which includes transforming the audio sample rate'
)
self.parser.add_argument(
'--rtf',
action="store_true",
default=False,
help='Show Real-time Factor(RTF).')
self.parser.add_argument(
'--device',
type=str,
......@@ -407,6 +416,7 @@ class ASRExecutor(BaseExecutor):
ckpt_path = parser_args.ckpt_path
decode_method = parser_args.decode_method
force_yes = parser_args.yes
rtf = parser_args.rtf
device = parser_args.device
if not parser_args.verbose:
......@@ -419,12 +429,15 @@ class ASRExecutor(BaseExecutor):
for id_, input_ in task_source.items():
try:
res = self(input_, model, lang, sample_rate, config, ckpt_path,
decode_method, force_yes, device)
decode_method, force_yes, rtf, device)
task_results[id_] = res
except Exception as e:
has_exceptions = True
task_results[id_] = f'{e.__class__.__name__}: {e}'
if rtf:
self.show_rtf(CLI_TIMER[self.__class__.__name__])
self.process_task_results(parser_args.input, task_results,
parser_args.job_dump_result)
......@@ -443,6 +456,7 @@ class ASRExecutor(BaseExecutor):
ckpt_path: os.PathLike=None,
decode_method: str='attention_rescoring',
force_yes: bool=False,
rtf: bool=False,
device=paddle.get_device()):
"""
Python API to call an executor.
......@@ -454,7 +468,18 @@ class ASRExecutor(BaseExecutor):
self._init_from_path(model, lang, sample_rate, config, decode_method,
ckpt_path)
self.preprocess(model, audio_file)
self.infer(model)
if rtf:
k = self.__class__.__name__
CLI_TIMER[k]['start'].append(time.time())
self.infer(model)
CLI_TIMER[k]['end'].append(time.time())
audio, audio_sample_rate = soundfile.read(
audio_file, dtype="int16", always_2d=True)
CLI_TIMER[k]['extra'].append(audio.shape[0] / audio_sample_rate)
else:
self.infer(model)
res = self.postprocess() # Retrieve result of asr.
return res
......@@ -235,3 +235,19 @@ class BaseExecutor(ABC):
'Use pretrained model stored in: {}'.format(decompressed_path))
return decompressed_path
def show_rtf(self, info: Dict[str, List[float]]):
"""
Calculate rft of current task and show results.
"""
num_samples = 0
task_duration = 0.0
wav_duration = 0.0
for start, end, dur in zip(info['start'], info['end'], info['extra']):
num_samples += 1
task_duration += end - start
wav_duration += dur
logger.info('Sample Count: {}'.format(num_samples))
logger.info('Avg RTF: {}'.format(task_duration / wav_duration))
......@@ -24,11 +24,11 @@ from typing import Any
from typing import Dict
import paddle
import paddleaudio
import requests
import yaml
from paddle.framework import load
import paddleaudio
from . import download
from .entry import commands
try:
......@@ -39,6 +39,7 @@ except ImportError:
requests.adapters.DEFAULT_RETRIES = 3
__all__ = [
'timer_register',
'cli_register',
'get_command',
'download_and_decompress',
......@@ -46,6 +47,13 @@ __all__ = [
'stats_wrapper',
]
CLI_TIMER = {}
def timer_register(command):
CLI_TIMER[command.__name__] = {'start': [], 'end': [], 'extra': []}
return command
def cli_register(name: str, description: str='') -> Any:
def _warpper(command):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册