未验证 提交 35738988 编写于 作者: W WilliamZhang06 提交者: GitHub

Merge pull request #1411 from lym0302/tts-server2

[server] add tts postprocess
...@@ -13,15 +13,19 @@ ...@@ -13,15 +13,19 @@
# limitations under the License. # limitations under the License.
import argparse import argparse
import base64 import base64
import os
import random
import librosa import librosa
import numpy as np import numpy as np
import soundfile as sf import soundfile as sf
import yaml import yaml
from engine.base_engine import BaseEngine from engine.base_engine import BaseEngine
from ffmpeg import audio
from paddlespeech.cli.log import logger from paddlespeech.cli.log import logger
from paddlespeech.cli.tts.infer import TTSExecutor from paddlespeech.cli.tts.infer import TTSExecutor
from utils.audio_types import wav2pcm
from utils.errors import ErrorCode from utils.errors import ErrorCode
from utils.exception import ServerBaseException from utils.exception import ServerBaseException
...@@ -80,8 +84,7 @@ class TTSEngine(BaseEngine): ...@@ -80,8 +84,7 @@ class TTSEngine(BaseEngine):
target_fs: int=16000, target_fs: int=16000,
volume: float=1.0, volume: float=1.0,
speed: float=1.0, speed: float=1.0,
audio_path: str=None, audio_path: str=None):
audio_format: str="wav"):
"""Post-processing operations, including speech, volume, sample rate, save audio file """Post-processing operations, including speech, volume, sample rate, save audio file
Args: Args:
...@@ -104,18 +107,26 @@ class TTSEngine(BaseEngine): ...@@ -104,18 +107,26 @@ class TTSEngine(BaseEngine):
wav_vol = wav_tar_fs * volume wav_vol = wav_tar_fs * volume
# transform speed # transform speed
# TODO hash = random.getrandbits(128)
target_wav = wav_vol.reshape(-1, 1) temp_wav = str(hash) + ".wav"
temp_speed_wav = str(hash + 1) + ".wav"
# save audio sf.write(temp_wav, wav_vol.reshape(-1, 1), target_fs)
if audio_path is not None: audio.a_speed(temp_wav, speed, temp_speed_wav)
sf.write(audio_path, target_wav, target_fs) os.system("rm %s" % (temp_wav))
logger.info('Wave file has been generated: {}'.format(audio_path))
# wav to base64 # wav to base64
base64_bytes = base64.b64encode(target_wav) with open(temp_speed_wav, 'rb') as f:
base64_string = base64_bytes.decode('utf-8') base64_bytes = base64.b64encode(f.read())
wav_base64 = base64_string wav_base64 = base64_bytes.decode('utf-8')
# save audio
if audio_path is not None and audio_path.endswith(".wav"):
os.system("mv %s %s" % (temp_speed_wav, audio_path))
elif audio_path is not None and audio_path.endswith(".pcm"):
wav2pcm(temp_speed_wav, audio_path, data_type=np.int16)
os.system("rm %s" % (temp_speed_wav))
else:
os.system("rm %s" % (temp_speed_wav))
return target_fs, wav_base64 return target_fs, wav_base64
...@@ -125,8 +136,7 @@ class TTSEngine(BaseEngine): ...@@ -125,8 +136,7 @@ class TTSEngine(BaseEngine):
speed: float=1.0, speed: float=1.0,
volume: float=1.0, volume: float=1.0,
sample_rate: int=0, sample_rate: int=0,
save_path: str=None, save_path: str=None):
audio_format: str="wav"):
lang = self.conf_dict["lang"] lang = self.conf_dict["lang"]
...@@ -147,8 +157,7 @@ class TTSEngine(BaseEngine): ...@@ -147,8 +157,7 @@ class TTSEngine(BaseEngine):
target_fs=sample_rate, target_fs=sample_rate,
volume=volume, volume=volume,
speed=speed, speed=speed,
audio_path=save_path, audio_path=save_path)
audio_format=audio_format)
except: except:
raise ServerBaseException(ErrorCode.SERVER_INTERNAL_ERR, raise ServerBaseException(ErrorCode.SERVER_INTERNAL_ERR,
"tts postprocess failed.") "tts postprocess failed.")
......
...@@ -54,8 +54,7 @@ class TTSRequest(BaseModel): ...@@ -54,8 +54,7 @@ class TTSRequest(BaseModel):
"speed": 1.0, "speed": 1.0,
"volume": 1.0, "volume": 1.0,
"sample_rate": 0, "sample_rate": 0,
"tts_audio_path": "./tts.wav", "tts_audio_path": "./tts.wav"
"audio_format": "wav"
} }
""" """
...@@ -66,5 +65,3 @@ class TTSRequest(BaseModel): ...@@ -66,5 +65,3 @@ class TTSRequest(BaseModel):
volume: float = 1.0 volume: float = 1.0
sample_rate: int = 0 sample_rate: int = 0
save_path: str = None save_path: str = None
audio_format: str = "wav"
...@@ -19,7 +19,6 @@ from fastapi import APIRouter ...@@ -19,7 +19,6 @@ from fastapi import APIRouter
from .request import TTSRequest from .request import TTSRequest
from .response import TTSResponse from .response import TTSResponse
from utils.errors import ErrorCode from utils.errors import ErrorCode
from utils.errors import ErrorMsg
from utils.errors import failed_response from utils.errors import failed_response
from utils.exception import ServerBaseException from utils.exception import ServerBaseException
...@@ -33,9 +32,9 @@ def help(): ...@@ -33,9 +32,9 @@ def help():
Returns: Returns:
json: [description] json: [description]
""" """
json_body = { response = {
"success": "True", "success": "True",
"code": 0, "code": 200,
"message": { "message": {
"global": "success" "global": "success"
}, },
...@@ -45,7 +44,7 @@ def help(): ...@@ -45,7 +44,7 @@ def help():
"audio": "the base64 of audio" "audio": "the base64 of audio"
} }
} }
return json_body return response
@router.post("/paddlespeech/tts", response_model=TTSResponse) @router.post("/paddlespeech/tts", response_model=TTSResponse)
...@@ -66,12 +65,11 @@ def tts(request_body: TTSRequest): ...@@ -66,12 +65,11 @@ def tts(request_body: TTSRequest):
volume = item_dict['volume'] volume = item_dict['volume']
sample_rate = item_dict['sample_rate'] sample_rate = item_dict['sample_rate']
save_path = item_dict['save_path'] save_path = item_dict['save_path']
audio_format = item_dict['audio_format']
# Check parameters # Check parameters
if speed <=0 or speed > 3 or volume <=0 or volume > 3 or \ if speed <=0 or speed > 3 or volume <=0 or volume > 3 or \
sample_rate not in [0, 16000, 8000] or \ sample_rate not in [0, 16000, 8000] or \
audio_format not in ["pcm", "wav"]: (save_path is not None and save_path.endswith("pcm") == False and save_path.endswith("wav") == False):
return failed_response(ErrorCode.SERVER_PARAM_ERR) return failed_response(ErrorCode.SERVER_PARAM_ERR)
# single # single
...@@ -80,29 +78,28 @@ def tts(request_body: TTSRequest): ...@@ -80,29 +78,28 @@ def tts(request_body: TTSRequest):
# run # run
try: try:
lang, target_sample_rate, wav_base64 = tts_engine.run( lang, target_sample_rate, wav_base64 = tts_engine.run(
sentence, spk_id, speed, volume, sample_rate, save_path, sentence, spk_id, speed, volume, sample_rate, save_path)
audio_format)
response = {
"success": True,
"code": 200,
"message": {
"description": "success."
},
"result": {
"lang": lang,
"spk_id": spk_id,
"speed": speed,
"volume": volume,
"sample_rate": target_sample_rate,
"save_path": save_path,
"audio": wav_base64
}
}
except ServerBaseException as e: except ServerBaseException as e:
response = failed_response(e.error_code, e.msg) response = failed_response(e.error_code, e.msg)
except: except:
response = failed_response(ErrorCode.SERVER_UNKOWN_ERR) response = failed_response(ErrorCode.SERVER_UNKOWN_ERR)
traceback.print_exc() traceback.print_exc()
json_body = { return response
"success": True,
"code": 200,
"message": {
"description": "success."
},
"result": {
"lang": lang,
"spk_id": spk_id,
"speed": speed,
"volume": volume,
"sample_rate": target_sample_rate,
"save_path": save_path,
"audio": wav_base64
}
}
return json_body
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import wave
import numpy as np
def wav2pcm(wavfile, pcmfile, data_type=np.int16):
f = open(wavfile, "rb")
f.seek(0)
f.read(44)
data = np.fromfile(f, dtype=data_type)
data.tofile(pcmfile)
def pcm2wav(pcm_file, wav_file, channels=1, bits=16, sample_rate=16000):
pcmf = open(pcm_file, 'rb')
pcmdata = pcmf.read()
pcmf.close()
if bits % 8 != 0:
raise ValueError("bits % 8 must == 0. now bits:" + str(bits))
wavfile = wave.open(wav_file, 'wb')
wavfile.setnchannels(channels)
wavfile.setsampwidth(bits // 8)
wavfile.setframerate(sample_rate)
wavfile.writeframes(pcmdata)
wavfile.close()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册