diff --git a/speechserving/speechserving/engine/tts/python/tts_engine.py b/speechserving/speechserving/engine/tts/python/tts_engine.py index 4f0e99066197ccdcd319dc7d912bcf91d7ad684a..65e35fb8fe77bd86b33d9fee91de3a70499c1fc1 100644 --- a/speechserving/speechserving/engine/tts/python/tts_engine.py +++ b/speechserving/speechserving/engine/tts/python/tts_engine.py @@ -13,15 +13,19 @@ # limitations under the License. import argparse import base64 +import os +import random import librosa import numpy as np import soundfile as sf import yaml from engine.base_engine import BaseEngine +from ffmpeg import audio from paddlespeech.cli.log import logger from paddlespeech.cli.tts.infer import TTSExecutor +from utils.audio_types import wav2pcm from utils.errors import ErrorCode from utils.exception import ServerBaseException @@ -80,8 +84,7 @@ class TTSEngine(BaseEngine): target_fs: int=16000, volume: float=1.0, speed: float=1.0, - audio_path: str=None, - audio_format: str="wav"): + audio_path: str=None): """Post-processing operations, including speech, volume, sample rate, save audio file Args: @@ -104,18 +107,26 @@ class TTSEngine(BaseEngine): wav_vol = wav_tar_fs * volume # transform speed - # TODO - target_wav = wav_vol.reshape(-1, 1) - - # save audio - if audio_path is not None: - sf.write(audio_path, target_wav, target_fs) - logger.info('Wave file has been generated: {}'.format(audio_path)) + hash = random.getrandbits(128) + temp_wav = str(hash) + ".wav" + temp_speed_wav = str(hash + 1) + ".wav" + sf.write(temp_wav, wav_vol.reshape(-1, 1), target_fs) + audio.a_speed(temp_wav, speed, temp_speed_wav) + os.system("rm %s" % (temp_wav)) # wav to base64 - base64_bytes = base64.b64encode(target_wav) - base64_string = base64_bytes.decode('utf-8') - wav_base64 = base64_string + with open(temp_speed_wav, 'rb') as f: + base64_bytes = base64.b64encode(f.read()) + wav_base64 = base64_bytes.decode('utf-8') + + # save audio + if audio_path is not None and audio_path.endswith(".wav"): + os.system("mv %s %s" % (temp_speed_wav, audio_path)) + elif audio_path is not None and audio_path.endswith(".pcm"): + wav2pcm(temp_speed_wav, audio_path, data_type=np.int16) + os.system("rm %s" % (temp_speed_wav)) + else: + os.system("rm %s" % (temp_speed_wav)) return target_fs, wav_base64 @@ -125,8 +136,7 @@ class TTSEngine(BaseEngine): speed: float=1.0, volume: float=1.0, sample_rate: int=0, - save_path: str=None, - audio_format: str="wav"): + save_path: str=None): lang = self.conf_dict["lang"] @@ -147,8 +157,7 @@ class TTSEngine(BaseEngine): target_fs=sample_rate, volume=volume, speed=speed, - audio_path=save_path, - audio_format=audio_format) + audio_path=save_path) except: raise ServerBaseException(ErrorCode.SERVER_INTERNAL_ERR, "tts postprocess failed.") diff --git a/speechserving/speechserving/restful/request.py b/speechserving/speechserving/restful/request.py index 58a0f08de4aa772dae6e18f76b6f52e4fb6a558b..9ebc6d752aaee8dc5897091593a3a8c5a6e88511 100644 --- a/speechserving/speechserving/restful/request.py +++ b/speechserving/speechserving/restful/request.py @@ -54,8 +54,7 @@ class TTSRequest(BaseModel): "speed": 1.0, "volume": 1.0, "sample_rate": 0, - "tts_audio_path": "./tts.wav", - "audio_format": "wav" + "tts_audio_path": "./tts.wav" } """ @@ -66,5 +65,3 @@ class TTSRequest(BaseModel): volume: float = 1.0 sample_rate: int = 0 save_path: str = None - audio_format: str = "wav" - diff --git a/speechserving/speechserving/restful/tts_api.py b/speechserving/speechserving/restful/tts_api.py index 69930f242965f9549ee0b3500aaf89c3c6c7942e..e9dcfa16fa17e736b8ee79fbd8ee170a6fcacc0a 100644 --- a/speechserving/speechserving/restful/tts_api.py +++ b/speechserving/speechserving/restful/tts_api.py @@ -19,7 +19,6 @@ from fastapi import APIRouter from .request import TTSRequest from .response import TTSResponse from utils.errors import ErrorCode -from utils.errors import ErrorMsg from utils.errors import failed_response from utils.exception import ServerBaseException @@ -33,9 +32,9 @@ def help(): Returns: json: [description] """ - json_body = { + response = { "success": "True", - "code": 0, + "code": 200, "message": { "global": "success" }, @@ -45,7 +44,7 @@ def help(): "audio": "the base64 of audio" } } - return json_body + return response @router.post("/paddlespeech/tts", response_model=TTSResponse) @@ -66,12 +65,11 @@ def tts(request_body: TTSRequest): volume = item_dict['volume'] sample_rate = item_dict['sample_rate'] save_path = item_dict['save_path'] - audio_format = item_dict['audio_format'] # Check parameters if speed <=0 or speed > 3 or volume <=0 or volume > 3 or \ sample_rate not in [0, 16000, 8000] or \ - audio_format not in ["pcm", "wav"]: + (save_path is not None and save_path.endswith("pcm") == False and save_path.endswith("wav") == False): return failed_response(ErrorCode.SERVER_PARAM_ERR) # single @@ -80,29 +78,28 @@ def tts(request_body: TTSRequest): # run try: lang, target_sample_rate, wav_base64 = tts_engine.run( - sentence, spk_id, speed, volume, sample_rate, save_path, - audio_format) + sentence, spk_id, speed, volume, sample_rate, save_path) + + response = { + "success": True, + "code": 200, + "message": { + "description": "success." + }, + "result": { + "lang": lang, + "spk_id": spk_id, + "speed": speed, + "volume": volume, + "sample_rate": target_sample_rate, + "save_path": save_path, + "audio": wav_base64 + } + } except ServerBaseException as e: response = failed_response(e.error_code, e.msg) except: response = failed_response(ErrorCode.SERVER_UNKOWN_ERR) traceback.print_exc() - json_body = { - "success": True, - "code": 200, - "message": { - "description": "success." - }, - "result": { - "lang": lang, - "spk_id": spk_id, - "speed": speed, - "volume": volume, - "sample_rate": target_sample_rate, - "save_path": save_path, - "audio": wav_base64 - } - } - - return json_body + return response diff --git a/speechserving/speechserving/utils/audio_types.py b/speechserving/speechserving/utils/audio_types.py new file mode 100644 index 0000000000000000000000000000000000000000..eb655ddd5902d27fb91fbc0718f7362400af91b4 --- /dev/null +++ b/speechserving/speechserving/utils/audio_types.py @@ -0,0 +1,40 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import wave + +import numpy as np + + +def wav2pcm(wavfile, pcmfile, data_type=np.int16): + f = open(wavfile, "rb") + f.seek(0) + f.read(44) + data = np.fromfile(f, dtype=data_type) + data.tofile(pcmfile) + + +def pcm2wav(pcm_file, wav_file, channels=1, bits=16, sample_rate=16000): + pcmf = open(pcm_file, 'rb') + pcmdata = pcmf.read() + pcmf.close() + + if bits % 8 != 0: + raise ValueError("bits % 8 must == 0. now bits:" + str(bits)) + + wavfile = wave.open(wav_file, 'wb') + wavfile.setnchannels(channels) + wavfile.setsampwidth(bits // 8) + wavfile.setframerate(sample_rate) + wavfile.writeframes(pcmdata) + wavfile.close()