From 1a3c811f04ac4e92563c41ee647e4595ff077af9 Mon Sep 17 00:00:00 2001 From: lym0302 Date: Fri, 8 Apr 2022 15:59:52 +0800 Subject: [PATCH] code format, test=doc --- .../server/engine/asr/online/asr_engine.py | 18 +- .../server/engine/tts/online/tts_engine.py | 161 +++++------------- .../server/tests/tts/online/ws_client.py | 4 +- .../tests/tts/online/ws_client_playaudio.py | 4 +- paddlespeech/server/utils/audio_process.py | 14 ++ paddlespeech/server/utils/util.py | 13 ++ paddlespeech/server/ws/tts_socket.py | 4 +- 7 files changed, 73 insertions(+), 145 deletions(-) diff --git a/paddlespeech/server/engine/asr/online/asr_engine.py b/paddlespeech/server/engine/asr/online/asr_engine.py index 389175a0..ca82b615 100644 --- a/paddlespeech/server/engine/asr/online/asr_engine.py +++ b/paddlespeech/server/engine/asr/online/asr_engine.py @@ -27,6 +27,7 @@ from paddlespeech.s2t.frontend.speech import SpeechSegment from paddlespeech.s2t.modules.ctc import CTCDecoder from paddlespeech.s2t.utils.utility import UpdateConfig from paddlespeech.server.engine.base_engine import BaseEngine +from paddlespeech.server.utils.audio_process import pcm2float from paddlespeech.server.utils.paddle_predictor import init_predictor __all__ = ['ASREngine'] @@ -222,21 +223,6 @@ class ASRServerExecutor(ASRExecutor): else: raise Exception("invalid model name") - def _pcm16to32(self, audio): - """pcm int16 to float32 - - Args: - audio(numpy.array): numpy.int16 - - Returns: - audio(numpy.array): numpy.float32 - """ - if audio.dtype == np.int16: - audio = audio.astype("float32") - bits = np.iinfo(np.int16).bits - audio = audio / (2**(bits - 1)) - return audio - def extract_feat(self, samples, sample_rate): """extract feat @@ -249,7 +235,7 @@ class ASRServerExecutor(ASRExecutor): x_chunk_lens (numpy.array): shape[B] """ # pcm16 -> pcm 32 - samples = self._pcm16to32(samples) + samples = pcm2float(samples) # read audio speech_segment = SpeechSegment.from_pcm( diff --git a/paddlespeech/server/engine/tts/online/tts_engine.py b/paddlespeech/server/engine/tts/online/tts_engine.py index 2f068b3b..25a8bc76 100644 --- a/paddlespeech/server/engine/tts/online/tts_engine.py +++ b/paddlespeech/server/engine/tts/online/tts_engine.py @@ -12,29 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. 
 import base64
-import io
 import time
 
-import librosa
 import numpy as np
 import paddle
-import soundfile as sf
-from scipy.io import wavfile
 
 from paddlespeech.cli.log import logger
 from paddlespeech.cli.tts.infer import TTSExecutor
 from paddlespeech.server.engine.base_engine import BaseEngine
-from paddlespeech.server.utils.audio_process import change_speed
-from paddlespeech.server.utils.errors import ErrorCode
-from paddlespeech.server.utils.exception import ServerBaseException
 from paddlespeech.server.utils.audio_process import float2pcm
-from paddlespeech.server.utils.config import get_config
-from paddlespeech.server.utils.util import denorm
 from paddlespeech.server.utils.util import get_chunks
-
-import math
-
 __all__ = ['TTSEngine']
 
 
@@ -44,15 +32,16 @@ class TTSServerExecutor(TTSExecutor):
         pass
 
     @paddle.no_grad()
-    def infer(self,
-              text: str,
-              lang: str='zh',
-              am: str='fastspeech2_csmsc',
-              spk_id: int=0,
-              am_block: int=42,
-              am_pad: int=12,
-              voc_block: int=14,
-              voc_pad: int=14,):
+    def infer(
+            self,
+            text: str,
+            lang: str='zh',
+            am: str='fastspeech2_csmsc',
+            spk_id: int=0,
+            am_block: int=42,
+            am_pad: int=12,
+            voc_block: int=14,
+            voc_pad: int=14, ):
         """
         Model inference and result stored in self.output.
         """
@@ -61,8 +50,6 @@ class TTSServerExecutor(TTSExecutor):
         get_tone_ids = False
         merge_sentences = False
         frontend_st = time.time()
-        if am_name == 'speedyspeech':
-            get_tone_ids = True
         if lang == 'zh':
             input_ids = self.frontend.get_input_ids(
                 text,
@@ -95,7 +82,7 @@ class TTSServerExecutor(TTSExecutor):
         else:
             mel = self.am_inference(part_phone_ids)
         am_et = time.time()
-
+
         # voc streaming
         voc_upsample = self.voc_config.n_shift
         mel_chunks = get_chunks(mel, voc_block, voc_pad, "voc")
         chunk_num = len(mel_chunks)
         voc_st = time.time()
         for i, mel_chunk in enumerate(mel_chunks):
             sub_wav = self.voc_inference(mel_chunk)
-            front_pad = min(i*voc_block, voc_pad)
+            front_pad = min(i * voc_block, voc_pad)
             if i == 0:
-                sub_wav = sub_wav[: voc_block * voc_upsample]
+                sub_wav = sub_wav[:voc_block * voc_upsample]
             elif i == chunk_num - 1:
-                sub_wav = sub_wav[front_pad * voc_upsample : ]
+                sub_wav = sub_wav[front_pad * voc_upsample:]
             else:
-                sub_wav = sub_wav[front_pad * voc_upsample: (front_pad + voc_block) * voc_upsample]
-
+                sub_wav = sub_wav[front_pad * voc_upsample:(
+                    front_pad + voc_block) * voc_upsample]
+
             yield sub_wav
 
+
 class TTSEngine(BaseEngine):
     """TTS server engine
 
     Args:
         metaclass: Defaults to Singleton.
     """
 
     def __init__(self, name=None):
         """Initialize TTS server engine
         """
         super(TTSEngine, self).__init__()
 
     def init(self, config: dict) -> bool:
         self.executor = TTSServerExecutor()
-
+        self.config = config
+        assert "fastspeech2_csmsc" in config.am and (
+            config.voc == "hifigan_csmsc-zh" or config.voc == "mb_melgan_csmsc"
+        ), 'Please check config: am supports fastspeech2_csmsc; voc supports hifigan_csmsc-zh or mb_melgan_csmsc.'
try: - self.config = config if self.config.device: self.device = self.config.device else: @@ -176,86 +167,11 @@ class TTSEngine(BaseEngine): def preprocess(self, text_bese64: str=None, text_bytes: bytes=None): # Convert byte to text if text_bese64: - text_bytes = base64.b64decode(text_bese64) # base64 to bytes - text = text_bytes.decode('utf-8') # bytes to text + text_bytes = base64.b64decode(text_bese64) # base64 to bytes + text = text_bytes.decode('utf-8') # bytes to text return text - def postprocess(self, - wav, - original_fs: int, - target_fs: int=0, - volume: float=1.0, - speed: float=1.0, - audio_path: str=None): - """Post-processing operations, including speech, volume, sample rate, save audio file - - Args: - wav (numpy(float)): Synthesized audio sample points - original_fs (int): original audio sample rate - target_fs (int): target audio sample rate - volume (float): target volume - speed (float): target speed - - Raises: - ServerBaseException: Throws an exception if the change speed unsuccessfully. - - Returns: - target_fs: target sample rate for synthesized audio. - wav_base64: The base64 format of the synthesized audio. - """ - - # transform sample_rate - if target_fs == 0 or target_fs > original_fs: - target_fs = original_fs - wav_tar_fs = wav - logger.info( - "The sample rate of synthesized audio is the same as model, which is {}Hz". - format(original_fs)) - else: - wav_tar_fs = librosa.resample( - np.squeeze(wav), original_fs, target_fs) - logger.info( - "The sample rate of model is {}Hz and the target sample rate is {}Hz. Converting the sample rate of the synthesized audio successfully.". - format(original_fs, target_fs)) - # transform volume - wav_vol = wav_tar_fs * volume - logger.info("Transform the volume of the audio successfully.") - - # transform speed - try: # windows not support soxbindings - wav_speed = change_speed(wav_vol, speed, target_fs) - logger.info("Transform the speed of the audio successfully.") - except ServerBaseException: - raise ServerBaseException( - ErrorCode.SERVER_INTERNAL_ERR, - "Failed to transform speed. Can not install soxbindings on your system. \ - You need to set speed value 1.0.") - except BaseException: - logger.error("Failed to transform speed.") - - # wav to base64 - buf = io.BytesIO() - wavfile.write(buf, target_fs, wav_speed) - base64_bytes = base64.b64encode(buf.read()) - wav_base64 = base64_bytes.decode('utf-8') - logger.info("Audio to string successfully.") - - # save audio - if audio_path is not None: - if audio_path.endswith(".wav"): - sf.write(audio_path, wav_speed, target_fs) - elif audio_path.endswith(".pcm"): - wav_norm = wav_speed * (32767 / max(0.001, - np.max(np.abs(wav_speed)))) - with open(audio_path, "wb") as f: - f.write(wav_norm.astype(np.int16)) - logger.info("Save audio to {} successfully.".format(audio_path)) - else: - logger.info("There is no need to save audio.") - - return target_fs, wav_base64 - def run(self, sentence: str, spk_id: int=0, @@ -275,31 +191,30 @@ class TTSEngine(BaseEngine): save_path (str, optional): The save path of the synthesized audio. None means do not save audio. Defaults to None. - Raises: - ServerBaseException: Throws an exception if tts inference unsuccessfully. - ServerBaseException: Throws an exception if postprocess unsuccessfully. - Returns: - lang: model language - target_sample_rate: target sample rate for synthesized audio. wav_base64: The base64 format of the synthesized audio. 
""" lang = self.config.lang wav_list = [] - for wav in self.executor.infer(text=sentence, lang=lang, am=self.config.am, spk_id=spk_id, am_block=self.am_block, am_pad=self.am_pad, voc_block=self.voc_block, voc_pad=self.voc_pad): + for wav in self.executor.infer( + text=sentence, + lang=lang, + am=self.config.am, + spk_id=spk_id, + am_block=self.am_block, + am_pad=self.am_pad, + voc_block=self.voc_block, + voc_pad=self.voc_pad): # wav type: float32, convert to pcm (base64) - wav = float2pcm(wav) # float32 to int16 + wav = float2pcm(wav) # float32 to int16 wav_bytes = wav.tobytes() # to bytes wav_base64 = base64.b64encode(wav_bytes).decode('utf8') # to base64 wav_list.append(wav) - - yield wav_base64 - - wav_all = np.concatenate(wav_list, axis=0) - logger.info("The durations of audio is: {} s".format(len(wav_all)/self.executor.am_config.fs)) - + yield wav_base64 - + wav_all = np.concatenate(wav_list, axis=0) + logger.info("The durations of audio is: {} s".format( + len(wav_all) / self.executor.am_config.fs)) diff --git a/paddlespeech/server/tests/tts/online/ws_client.py b/paddlespeech/server/tests/tts/online/ws_client.py index e0f47b55..eef010cf 100644 --- a/paddlespeech/server/tests/tts/online/ws_client.py +++ b/paddlespeech/server/tests/tts/online/ws_client.py @@ -25,7 +25,7 @@ st = 0.0 all_bytes = b'' -class Ws_Param(object): +class WsParam(object): # 初始化 def __init__(self, text, server="127.0.0.1", port=8090): self.server = server @@ -116,7 +116,7 @@ if __name__ == "__main__": print("Sentence to be synthesized: ", args.text) print("***************************************") - wsParam = Ws_Param(text=args.text, server=args.server, port=args.port) + wsParam = WsParam(text=args.text, server=args.server, port=args.port) websocket.enableTrace(False) wsUrl = wsParam.create_url() diff --git a/paddlespeech/server/tests/tts/online/ws_client_playaudio.py b/paddlespeech/server/tests/tts/online/ws_client_playaudio.py index 4e1c538d..cdeb362d 100644 --- a/paddlespeech/server/tests/tts/online/ws_client_playaudio.py +++ b/paddlespeech/server/tests/tts/online/ws_client_playaudio.py @@ -32,7 +32,7 @@ st = 0.0 all_bytes = 0.0 -class Ws_Param(object): +class WsParam(object): # 初始化 def __init__(self, text, server="127.0.0.1", port=8090): self.server = server @@ -144,7 +144,7 @@ if __name__ == "__main__": print("Sentence to be synthesized: ", args.text) print("***************************************") - wsParam = Ws_Param(text=args.text, server=args.server, port=args.port) + wsParam = WsParam(text=args.text, server=args.server, port=args.port) websocket.enableTrace(False) wsUrl = wsParam.create_url() diff --git a/paddlespeech/server/utils/audio_process.py b/paddlespeech/server/utils/audio_process.py index 1d4b158c..e85b9a27 100644 --- a/paddlespeech/server/utils/audio_process.py +++ b/paddlespeech/server/utils/audio_process.py @@ -126,3 +126,17 @@ def float2pcm(sig, dtype='int16'): abs_max = 2**(i.bits - 1) offset = i.min + abs_max return (sig * abs_max + offset).clip(i.min, i.max).astype(dtype) + + +def pcm2float(data): + """pcm int16 to float32 + Args: + audio(numpy.array): numpy.int16 + Returns: + audio(numpy.array): numpy.float32 + """ + if data.dtype == np.int16: + data = data.astype("float32") + bits = np.iinfo(np.int16).bits + data = data / (2**(bits - 1)) + return data diff --git a/paddlespeech/server/utils/util.py b/paddlespeech/server/utils/util.py index c35939b7..0fe70849 100644 --- a/paddlespeech/server/utils/util.py +++ b/paddlespeech/server/utils/util.py @@ -35,10 +35,23 @@ def self_check(): def 
denorm(data, mean, std): + """stream am model need to denorm + """ return data * std + mean def get_chunks(data, block_size, pad_size, step): + """Divide data into multiple chunks + + Args: + data (tensor): data + block_size (int): [description] + pad_size (int): [description] + step (str): set "am" or "voc", generate chunk for step am or vocoder(voc) + + Returns: + list: chunks list + """ if step == "am": data_len = data.shape[1] elif step == "voc": diff --git a/paddlespeech/server/ws/tts_socket.py b/paddlespeech/server/ws/tts_socket.py index 4df2850a..11458b3c 100644 --- a/paddlespeech/server/ws/tts_socket.py +++ b/paddlespeech/server/ws/tts_socket.py @@ -44,11 +44,11 @@ async def websocket_endpoint(websocket: WebSocket): sentence = tts_engine.preprocess(text_bese64=text_bese64) # run - wav = tts_engine.run(sentence) + wav_generator = tts_engine.run(sentence) while True: try: - tts_results = next(wav) + tts_results = next(wav_generator) resp = {"status": 1, "audio": tts_results} await websocket.send_json(resp) logger.info("streaming audio...") -- GitLab
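
The block/pad bookkeeping shared by get_chunks() and the vocoder loop in TTSServerExecutor.infer() is easiest to check in isolation. Below is a minimal numpy-only sketch of the same scheme; the helper names split_into_chunks and trim_chunk are illustrative only (not PaddleSpeech APIs), and upsample=1 models an identity "vocoder" in place of the real one, which expands each mel frame into voc_config.n_shift output samples:

import numpy as np

def split_into_chunks(data, block_size, pad_size):
    # Cut `data` along its first axis into blocks of `block_size` frames,
    # each carrying up to `pad_size` frames of left/right context.
    chunks = []
    for start in range(0, len(data), block_size):
        left = max(0, start - pad_size)
        right = min(len(data), start + block_size + pad_size)
        chunks.append(data[left:right])
    return chunks

def trim_chunk(sub_wav, i, chunk_num, block, pad, upsample=1):
    # Mirror of the slicing in TTSServerExecutor.infer(): drop the context
    # that was only there to give the vocoder clean block boundaries.
    front_pad = min(i * block, pad)  # early chunks have less left context
    if i == 0:
        return sub_wav[:block * upsample]
    if i == chunk_num - 1:
        return sub_wav[front_pad * upsample:]
    return sub_wav[front_pad * upsample:(front_pad + block) * upsample]

# Round trip: trimming the padded chunks reconstructs the input exactly,
# with no dropped or repeated frames at the chunk seams.
frames = np.arange(100)
chunks = split_into_chunks(frames, block_size=14, pad_size=14)
out = np.concatenate(
    [trim_chunk(c, i, len(chunks), 14, 14) for i, c in enumerate(chunks)])
assert np.array_equal(out, frames)

The min(i * block, pad) term exists because the first few chunks have fewer than pad_size frames of left context available; clamping front_pad to the context that was actually prepended lets the same slice expression serve every chunk position.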