提交 162361d8 编写于 作者: L lym0302

format code, test=doc

上级 434708cf
......@@ -12,8 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import uvicorn
import yaml
from fastapi import FastAPI
from paddlespeech.server.engine.engine_pool import init_engine_pool
......
......@@ -124,7 +124,7 @@ class TTSClientExecutor(BaseExecutor):
logger.info("RTF: %f " % (time_consume / duration))
return True
except:
except BaseException:
logger.error("Failed to synthesized audio.")
return False
......@@ -164,7 +164,7 @@ class TTSClientExecutor(BaseExecutor):
print("Audio duration: %f s." % (duration))
print("Response time: %f s." % (time_consume))
print("RTF: %f " % (time_consume / duration))
except:
except BaseException:
print("Failed to synthesized audio.")
......@@ -211,7 +211,7 @@ class ASRClientExecutor(BaseExecutor):
logger.info(r.json())
logger.info("time cost %f s." % (time_end - time_start))
return True
except:
except BaseException:
logger.error("Failed to speech recognition.")
return False
......@@ -242,5 +242,5 @@ class ASRClientExecutor(BaseExecutor):
time_end = time.time()
print(r.json())
print("time cost %f s." % (time_end - time_start))
except:
except BaseException:
print("Failed to speech recognition.")
......@@ -13,31 +13,24 @@
# limitations under the License.
import io
import os
from typing import List
from typing import Optional
from typing import Union
import librosa
import paddle
import soundfile
from yacs.config import CfgNode
from paddlespeech.cli.utils import MODEL_HOME
from paddlespeech.s2t.modules.ctc import CTCDecoder
from paddlespeech.cli.asr.infer import ASRExecutor
from paddlespeech.cli.log import logger
from paddlespeech.cli.utils import MODEL_HOME
from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
from paddlespeech.s2t.transform.transformation import Transformation
from paddlespeech.s2t.utils.dynamic_import import dynamic_import
from paddlespeech.s2t.modules.ctc import CTCDecoder
from paddlespeech.s2t.utils.utility import UpdateConfig
from paddlespeech.server.engine.base_engine import BaseEngine
from paddlespeech.server.utils.config import get_config
from paddlespeech.server.utils.paddle_predictor import init_predictor
from paddlespeech.server.utils.paddle_predictor import run_model
from paddlespeech.server.engine.base_engine import BaseEngine
__all__ = ['ASREngine']
pretrained_models = {
"deepspeech2offline_aishell-zh-16k": {
'url':
......@@ -143,7 +136,6 @@ class ASRServerExecutor(ASRExecutor):
batch_average=True, # sum / batch_size
grad_norm_type=self.config.get('ctc_grad_norm_type', None))
@paddle.no_grad()
def infer(self, model_type: str):
"""
......@@ -161,9 +153,8 @@ class ASRServerExecutor(ASRExecutor):
cfg.beam_size, cfg.cutoff_prob, cfg.cutoff_top_n,
cfg.num_proc_bsearch)
output_data = run_model(
self.am_predictor,
[audio.numpy(), audio_len.numpy()])
output_data = run_model(self.am_predictor,
[audio.numpy(), audio_len.numpy()])
probs = output_data[0]
eouts_len = output_data[1]
......@@ -208,14 +199,14 @@ class ASREngine(BaseEngine):
paddle.set_device(paddle.get_device())
self.executor._init_from_path(
model_type=self.config.model_type,
am_model=self.config.am_model,
am_params=self.config.am_params,
lang=self.config.lang,
sample_rate=self.config.sample_rate,
cfg_path=self.config.cfg_path,
decode_method=self.config.decode_method,
am_predictor_conf=self.config.am_predictor_conf)
model_type=self.config.model_type,
am_model=self.config.am_model,
am_params=self.config.am_params,
lang=self.config.lang,
sample_rate=self.config.sample_rate,
cfg_path=self.config.cfg_path,
decode_method=self.config.decode_method,
am_predictor_conf=self.config.am_predictor_conf)
logger.info("Initialize ASR server engine successfully.")
return True
......@@ -230,7 +221,8 @@ class ASREngine(BaseEngine):
io.BytesIO(audio_data), self.config.sample_rate,
self.config.force_yes):
logger.info("start running asr engine")
self.executor.preprocess(self.config.model_type, io.BytesIO(audio_data))
self.executor.preprocess(self.config.model_type,
io.BytesIO(audio_data))
self.executor.infer(self.config.model_type)
self.output = self.executor.postprocess() # Retrieve result of asr.
logger.info("end inferring asr engine")
......
......@@ -12,8 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Any
from typing import List
from typing import Union
from pattern_singleton import Singleton
......
......@@ -13,7 +13,6 @@
# limitations under the License.
from typing import Text
__all__ = ['EngineFactory']
......
......@@ -29,8 +29,10 @@ def init_engine_pool(config) -> bool:
"""
global ENGINE_POOL
for engine in config.engine_backend:
ENGINE_POOL[engine] = EngineFactory.get_engine(engine_name=engine, engine_type=config.engine_type[engine])
if not ENGINE_POOL[engine].init(config_file=config.engine_backend[engine]):
ENGINE_POOL[engine] = EngineFactory.get_engine(
engine_name=engine, engine_type=config.engine_type[engine])
if not ENGINE_POOL[engine].init(
config_file=config.engine_backend[engine]):
return False
return True
......@@ -360,8 +360,8 @@ class TTSEngine(BaseEngine):
am_predictor_conf=self.config.am_predictor_conf,
voc_predictor_conf=self.config.voc_predictor_conf, )
except:
logger.info("Initialize TTS server engine Failed.")
except BaseException:
logger.error("Initialize TTS server engine Failed.")
return False
logger.info("Initialize TTS server engine successfully.")
......@@ -405,11 +405,13 @@ class TTSEngine(BaseEngine):
# transform speed
try: # windows not support soxbindings
wav_speed = change_speed(wav_vol, speed, target_fs)
except:
except ServerBaseException:
raise ServerBaseException(
ErrorCode.SERVER_INTERNAL_ERR,
"Transform speed failed. Can not install soxbindings on your system. \
You need to set speed value 1.0.")
except BaseException:
logger.error("Transform speed failed.")
# wav to base64
buf = io.BytesIO()
......@@ -462,9 +464,11 @@ class TTSEngine(BaseEngine):
try:
self.executor.infer(
text=sentence, lang=lang, am=self.config.am, spk_id=spk_id)
except:
except ServerBaseException:
raise ServerBaseException(ErrorCode.SERVER_INTERNAL_ERR,
"tts infer failed.")
except BaseException:
logger.error("tts infer failed.")
try:
target_sample_rate, wav_base64 = self.postprocess(
......@@ -474,8 +478,10 @@ class TTSEngine(BaseEngine):
volume=volume,
speed=speed,
audio_path=save_path)
except:
except ServerBaseException:
raise ServerBaseException(ErrorCode.SERVER_INTERNAL_ERR,
"tts postprocess failed.")
except BaseException:
logger.error("tts postprocess failed.")
return lang, target_sample_rate, wav_base64
......@@ -72,8 +72,8 @@ class TTSEngine(BaseEngine):
voc_ckpt=self.config.voc_ckpt,
voc_stat=self.config.voc_stat,
lang=self.config.lang)
except:
logger.info("Initialize TTS server engine Failed.")
except BaseException:
logger.error("Initialize TTS server engine Failed.")
return False
logger.info("Initialize TTS server engine successfully.")
......@@ -117,10 +117,13 @@ class TTSEngine(BaseEngine):
# transform speed
try: # windows not support soxbindings
wav_speed = change_speed(wav_vol, speed, target_fs)
except:
except ServerBaseException:
raise ServerBaseException(
ErrorCode.SERVER_INTERNAL_ERR,
"Can not install soxbindings on your system.")
"Transform speed failed. Can not install soxbindings on your system. \
You need to set speed value 1.0.")
except BaseException:
logger.error("Transform speed failed.")
# wav to base64
buf = io.BytesIO()
......@@ -173,9 +176,11 @@ class TTSEngine(BaseEngine):
try:
self.executor.infer(
text=sentence, lang=lang, am=self.config.am, spk_id=spk_id)
except:
except ServerBaseException:
raise ServerBaseException(ErrorCode.SERVER_INTERNAL_ERR,
"tts infer failed.")
except BaseException:
logger.error("tts infer failed.")
try:
target_sample_rate, wav_base64 = self.postprocess(
......@@ -185,8 +190,10 @@ class TTSEngine(BaseEngine):
volume=volume,
speed=speed,
audio_path=save_path)
except:
except ServerBaseException:
raise ServerBaseException(ErrorCode.SERVER_INTERNAL_ERR,
"tts postprocess failed.")
except BaseException:
logger.error("tts postprocess failed.")
return lang, target_sample_rate, wav_base64
......@@ -14,6 +14,7 @@
import base64
import traceback
from typing import Union
from fastapi import APIRouter
from paddlespeech.server.engine.engine_pool import get_engine_pool
......@@ -83,7 +84,7 @@ def asr(request_body: ASRRequest):
except ServerBaseException as e:
response = failed_response(e.error_code, e.msg)
except:
except BaseException:
response = failed_response(ErrorCode.SERVER_UNKOWN_ERR)
traceback.print_exc()
......
......@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List
from typing import Optional
from pydantic import BaseModel
......
......@@ -11,9 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List
from typing import Optional
from pydantic import BaseModel
__all__ = ['ASRResponse', 'TTSResponse']
......
......@@ -114,7 +114,7 @@ def tts(request_body: TTSRequest):
}
except ServerBaseException as e:
response = failed_response(e.error_code, e.msg)
except:
except BaseException:
response = failed_response(ErrorCode.SERVER_UNKOWN_ERR)
traceback.print_exc()
......
......@@ -10,11 +10,11 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the
import requests
import base64
import json
import time
import base64
import io
import requests
def readwav2base64(wav_file):
......@@ -34,23 +34,23 @@ def main():
url = "http://127.0.0.1:8090/paddlespeech/asr"
# start Timestamp
time_start=time.time()
time_start = time.time()
test_audio_dir = "./16_audio.wav"
audio = readwav2base64(test_audio_dir)
data = {
"audio": audio,
"audio_format": "wav",
"sample_rate": 16000,
"lang": "zh_cn",
}
"audio": audio,
"audio_format": "wav",
"sample_rate": 16000,
"lang": "zh_cn",
}
r = requests.post(url=url, data=json.dumps(data))
# ending Timestamp
time_end=time.time()
print('time cost',time_end - time_start, 's')
time_end = time.time()
print('time cost', time_end - time_start, 's')
print(r.json())
......
......@@ -25,6 +25,7 @@ import soundfile
from paddlespeech.server.utils.audio_process import wav2pcm
# Request and response
def tts_client(args):
""" Request and response
......@@ -99,5 +100,5 @@ if __name__ == "__main__":
print("Inference time: %f" % (time_consume))
print("The duration of synthesized audio: %f" % (duration))
print("The RTF is: %f" % (rtf))
except:
except BaseException:
print("Failed to synthesized audio.")
......@@ -219,7 +219,7 @@ class ConfigCache:
try:
cfg = yaml.load(file, Loader=yaml.FullLoader)
self._data.update(cfg)
except:
except BaseException:
self.flush()
@property
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册