未验证 提交 35738988 编写于 作者: W WilliamZhang06 提交者: GitHub

Merge pull request #1411 from lym0302/tts-server2

[server] add tts postprocess
......@@ -13,15 +13,19 @@
# limitations under the License.
import argparse
import base64
import os
import random
import librosa
import numpy as np
import soundfile as sf
import yaml
from engine.base_engine import BaseEngine
from ffmpeg import audio
from paddlespeech.cli.log import logger
from paddlespeech.cli.tts.infer import TTSExecutor
from utils.audio_types import wav2pcm
from utils.errors import ErrorCode
from utils.exception import ServerBaseException
......@@ -80,8 +84,7 @@ class TTSEngine(BaseEngine):
target_fs: int=16000,
volume: float=1.0,
speed: float=1.0,
audio_path: str=None,
audio_format: str="wav"):
audio_path: str=None):
"""Post-processing operations, including speech, volume, sample rate, save audio file
Args:
......@@ -104,18 +107,26 @@ class TTSEngine(BaseEngine):
wav_vol = wav_tar_fs * volume
# transform speed
# TODO
target_wav = wav_vol.reshape(-1, 1)
# save audio
if audio_path is not None:
sf.write(audio_path, target_wav, target_fs)
logger.info('Wave file has been generated: {}'.format(audio_path))
hash = random.getrandbits(128)
temp_wav = str(hash) + ".wav"
temp_speed_wav = str(hash + 1) + ".wav"
sf.write(temp_wav, wav_vol.reshape(-1, 1), target_fs)
audio.a_speed(temp_wav, speed, temp_speed_wav)
os.system("rm %s" % (temp_wav))
# wav to base64
base64_bytes = base64.b64encode(target_wav)
base64_string = base64_bytes.decode('utf-8')
wav_base64 = base64_string
with open(temp_speed_wav, 'rb') as f:
base64_bytes = base64.b64encode(f.read())
wav_base64 = base64_bytes.decode('utf-8')
# save audio
if audio_path is not None and audio_path.endswith(".wav"):
os.system("mv %s %s" % (temp_speed_wav, audio_path))
elif audio_path is not None and audio_path.endswith(".pcm"):
wav2pcm(temp_speed_wav, audio_path, data_type=np.int16)
os.system("rm %s" % (temp_speed_wav))
else:
os.system("rm %s" % (temp_speed_wav))
return target_fs, wav_base64
......@@ -125,8 +136,7 @@ class TTSEngine(BaseEngine):
speed: float=1.0,
volume: float=1.0,
sample_rate: int=0,
save_path: str=None,
audio_format: str="wav"):
save_path: str=None):
lang = self.conf_dict["lang"]
......@@ -147,8 +157,7 @@ class TTSEngine(BaseEngine):
target_fs=sample_rate,
volume=volume,
speed=speed,
audio_path=save_path,
audio_format=audio_format)
audio_path=save_path)
except:
raise ServerBaseException(ErrorCode.SERVER_INTERNAL_ERR,
"tts postprocess failed.")
......
......@@ -54,8 +54,7 @@ class TTSRequest(BaseModel):
"speed": 1.0,
"volume": 1.0,
"sample_rate": 0,
"tts_audio_path": "./tts.wav",
"audio_format": "wav"
"tts_audio_path": "./tts.wav"
}
"""
......@@ -66,5 +65,3 @@ class TTSRequest(BaseModel):
volume: float = 1.0
sample_rate: int = 0
save_path: str = None
audio_format: str = "wav"
......@@ -19,7 +19,6 @@ from fastapi import APIRouter
from .request import TTSRequest
from .response import TTSResponse
from utils.errors import ErrorCode
from utils.errors import ErrorMsg
from utils.errors import failed_response
from utils.exception import ServerBaseException
......@@ -33,9 +32,9 @@ def help():
Returns:
json: [description]
"""
json_body = {
response = {
"success": "True",
"code": 0,
"code": 200,
"message": {
"global": "success"
},
......@@ -45,7 +44,7 @@ def help():
"audio": "the base64 of audio"
}
}
return json_body
return response
@router.post("/paddlespeech/tts", response_model=TTSResponse)
......@@ -66,12 +65,11 @@ def tts(request_body: TTSRequest):
volume = item_dict['volume']
sample_rate = item_dict['sample_rate']
save_path = item_dict['save_path']
audio_format = item_dict['audio_format']
# Check parameters
if speed <=0 or speed > 3 or volume <=0 or volume > 3 or \
sample_rate not in [0, 16000, 8000] or \
audio_format not in ["pcm", "wav"]:
(save_path is not None and save_path.endswith("pcm") == False and save_path.endswith("wav") == False):
return failed_response(ErrorCode.SERVER_PARAM_ERR)
# single
......@@ -80,15 +78,9 @@ def tts(request_body: TTSRequest):
# run
try:
lang, target_sample_rate, wav_base64 = tts_engine.run(
sentence, spk_id, speed, volume, sample_rate, save_path,
audio_format)
except ServerBaseException as e:
response = failed_response(e.error_code, e.msg)
except:
response = failed_response(ErrorCode.SERVER_UNKOWN_ERR)
traceback.print_exc()
sentence, spk_id, speed, volume, sample_rate, save_path)
json_body = {
response = {
"success": True,
"code": 200,
"message": {
......@@ -104,5 +96,10 @@ def tts(request_body: TTSRequest):
"audio": wav_base64
}
}
except ServerBaseException as e:
response = failed_response(e.error_code, e.msg)
except:
response = failed_response(ErrorCode.SERVER_UNKOWN_ERR)
traceback.print_exc()
return json_body
return response
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import wave
import numpy as np
def wav2pcm(wavfile, pcmfile, data_type=np.int16):
    """Strip the header from a WAV file and write the remaining samples as raw PCM.

    The first 44 bytes of ``wavfile`` are discarded and everything after them
    is reinterpreted as ``data_type`` and dumped to ``pcmfile``.

    NOTE(review): 44 bytes is the size of a canonical RIFF/WAVE header with a
    plain PCM ``fmt`` chunk and no extra chunks. A file with additional chunks
    (e.g. ``LIST``) before the sample data would leak header bytes into the
    output -- confirm that all inputs are written in the canonical layout.

    Args:
        wavfile: Path of the WAV file to read.
        pcmfile: Path of the raw PCM file to write.
        data_type: NumPy dtype used to interpret the sample bytes.
    """
    header_size = 44
    # The context manager guarantees the input handle is released even when
    # reading or writing fails.
    with open(wavfile, "rb") as f:
        f.read(header_size)  # skip the header; the payload starts right after
        data = np.fromfile(f, dtype=data_type)
    data.tofile(pcmfile)
def pcm2wav(pcm_file, wav_file, channels=1, bits=16, sample_rate=16000):
    """Wrap raw PCM data in a WAV container.

    Args:
        pcm_file: Path of the raw PCM file to read.
        wav_file: Path of the WAV file to write.
        channels: Number of audio channels recorded in the WAV header.
        bits: Bits per sample; must be a multiple of 8.
        sample_rate: Sampling rate in Hz recorded in the WAV header.

    Raises:
        ValueError: If ``bits`` is not a multiple of 8.
    """
    with open(pcm_file, 'rb') as pcmf:
        pcmdata = pcmf.read()

    if bits % 8 != 0:
        raise ValueError("bits % 8 must == 0. now bits:" + str(bits))

    # Using the writer as a context manager closes it (and finalizes the
    # header) even if one of the setters rejects its argument.
    with wave.open(wav_file, 'wb') as wavfile:
        wavfile.setnchannels(channels)
        wavfile.setsampwidth(bits // 8)  # the wave module expects bytes per sample
        wavfile.setframerate(sample_rate)
        wavfile.writeframes(pcmdata)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册