提交 e5aa24fa 编写于 作者: L lym0302

resolve setup.py conflicts, test=doc

......@@ -50,12 +50,13 @@ repos:
entry: bash .pre-commit-hooks/clang-format.hook -i
language: system
files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|cuh|proto)$
exclude: (?=speechx/speechx/kaldi).*(\.cpp|\.cc|\.h|\.py)$
- id: copyright_checker
name: copyright_checker
entry: python .pre-commit-hooks/copyright-check.hook
language: system
files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|proto|py)$
exclude: (?=third_party|pypinyin).*(\.cpp|\.h|\.py)$
exclude: (?=third_party|pypinyin|speechx/speechx/kaldi).*(\.cpp|\.cc|\.h|\.py)$
- repo: https://github.com/asottile/reorder_python_imports
rev: v2.4.0
hooks:
......
......@@ -80,6 +80,7 @@ parser.add_argument(
args = parser.parse_args()
def create_manifest(data_dir, manifest_path_prefix):
print("Creating manifest %s ..." % manifest_path_prefix)
json_lines = []
......@@ -128,6 +129,7 @@ def create_manifest(data_dir, manifest_path_prefix):
print(f"{total_text / total_sec} text/sec", file=f)
print(f"{total_sec / total_num} sec/utt", file=f)
def prepare_dataset(base_url, data_list, target_dir, manifest_path,
target_data):
if not os.path.exists(target_dir):
......@@ -164,6 +166,7 @@ def prepare_dataset(base_url, data_list, target_dir, manifest_path,
# create the manifest file
create_manifest(data_dir=target_dir, manifest_path_prefix=manifest_path)
def main():
if args.target_dir.startswith('~'):
args.target_dir = os.path.expanduser(args.target_dir)
......@@ -184,5 +187,6 @@ def main():
print("Manifest prepare done!")
if __name__ == '__main__':
main()
......@@ -5,4 +5,4 @@ cfg_path: # [optional]
ckpt_path: # [optional]
decode_method: 'attention_rescoring'
force_yes: True
device: 'gpu:3' # set 'gpu:id' or 'cpu'
device: 'cpu' # set 'gpu:id' or 'cpu'
......@@ -15,7 +15,7 @@ decode_method:
force_yes: True
am_predictor_conf:
device: 'gpu:3' # set 'gpu:id' or 'cpu'
device: 'cpu' # set 'gpu:id' or 'cpu'
enable_mkldnn: True
switch_ir_optim: True
......
......@@ -29,4 +29,4 @@ voc_stat:
# OTHERS #
##################################################################
lang: 'zh'
device: 'gpu:3' # set 'gpu:id' or 'cpu'
device: 'cpu' # set 'gpu:id' or 'cpu'
......@@ -15,7 +15,7 @@ speaker_dict:
spk_id: 0
am_predictor_conf:
device: 'gpu:3' # set 'gpu:id' or 'cpu'
device: 'cpu' # set 'gpu:id' or 'cpu'
enable_mkldnn: False
switch_ir_optim: False
......@@ -30,7 +30,7 @@ voc_params: # the pdiparams file of your vocoder static model (XX.pdipparams)
voc_sample_rate: 24000
voc_predictor_conf:
device: 'gpu:3' # set 'gpu:id' or 'cpu'
device: 'cpu' # set 'gpu:id' or 'cpu'
enable_mkldnn: False
switch_ir_optim: False
......
......@@ -30,12 +30,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"Cloning into 'warp-ctc'...\n",
"remote: Enumerating objects: 829, done.\u001b[K\n",
"remote: Total 829 (delta 0), reused 0 (delta 0), pack-reused 829\u001b[K\n",
"Receiving objects: 100% (829/829), 388.85 KiB | 140.00 KiB/s, done.\n",
"Resolving deltas: 100% (419/419), done.\n",
"Checking connectivity... done.\n"
"fatal: destination path 'warp-ctc' already exists and is not an empty directory.\r\n"
]
}
],
......@@ -99,30 +94,6 @@
"name": "stdout",
"output_type": "stream",
"text": [
"-- The C compiler identification is GNU 5.4.0\n",
"-- The CXX compiler identification is GNU 5.4.0\n",
"-- Check for working C compiler: /usr/bin/cc\n",
"-- Check for working C compiler: /usr/bin/cc -- works\n",
"-- Detecting C compiler ABI info\n",
"-- Detecting C compiler ABI info - done\n",
"-- Detecting C compile features\n",
"-- Detecting C compile features - done\n",
"-- Check for working CXX compiler: /usr/bin/c++\n",
"-- Check for working CXX compiler: /usr/bin/c++ -- works\n",
"-- Detecting CXX compiler ABI info\n",
"-- Detecting CXX compiler ABI info - done\n",
"-- Detecting CXX compile features\n",
"-- Detecting CXX compile features - done\n",
"-- Looking for pthread.h\n",
"-- Looking for pthread.h - found\n",
"-- Performing Test CMAKE_HAVE_LIBC_PTHREAD\n",
"-- Performing Test CMAKE_HAVE_LIBC_PTHREAD - Failed\n",
"-- Looking for pthread_create in pthreads\n",
"-- Looking for pthread_create in pthreads - not found\n",
"-- Looking for pthread_create in pthread\n",
"-- Looking for pthread_create in pthread - found\n",
"-- Found Threads: TRUE \n",
"-- Found CUDA: /usr/local/cuda (found suitable version \"10.2\", minimum required is \"6.5\") \n",
"-- cuda found TRUE\n",
"-- Building shared library with GPU support\n",
"-- Configuring done\n",
......@@ -145,20 +116,11 @@
"name": "stdout",
"output_type": "stream",
"text": [
"[ 11%] \u001b[34m\u001b[1mBuilding NVCC (Device) object CMakeFiles/warpctc.dir/src/warpctc_generated_reduce.cu.o\u001b[0m\n",
"[ 22%] \u001b[34m\u001b[1mBuilding NVCC (Device) object CMakeFiles/warpctc.dir/src/warpctc_generated_ctc_entrypoint.cu.o\u001b[0m\n",
"\u001b[35m\u001b[1mScanning dependencies of target warpctc\u001b[0m\n",
"[ 33%] \u001b[32m\u001b[1mLinking CXX shared library libwarpctc.so\u001b[0m\n",
"[ 11%] \u001b[32m\u001b[1mLinking CXX shared library libwarpctc.so\u001b[0m\n",
"[ 33%] Built target warpctc\n",
"[ 44%] \u001b[34m\u001b[1mBuilding NVCC (Device) object CMakeFiles/test_gpu.dir/tests/test_gpu_generated_test_gpu.cu.o\u001b[0m\n",
"\u001b[35m\u001b[1mScanning dependencies of target test_cpu\u001b[0m\n",
"[ 55%] \u001b[32mBuilding CXX object CMakeFiles/test_cpu.dir/tests/test_cpu.cpp.o\u001b[0m\n",
"[ 66%] \u001b[32mBuilding CXX object CMakeFiles/test_cpu.dir/tests/random.cpp.o\u001b[0m\n",
"[ 77%] \u001b[32m\u001b[1mLinking CXX executable test_cpu\u001b[0m\n",
"[ 44%] \u001b[32m\u001b[1mLinking CXX executable test_cpu\u001b[0m\n",
"[ 55%] \u001b[32m\u001b[1mLinking CXX executable test_gpu\u001b[0m\n",
"[ 77%] Built target test_cpu\n",
"\u001b[35m\u001b[1mScanning dependencies of target test_gpu\u001b[0m\n",
"[ 88%] \u001b[32mBuilding CXX object CMakeFiles/test_gpu.dir/tests/random.cpp.o\u001b[0m\n",
"[100%] \u001b[32m\u001b[1mLinking CXX executable test_gpu\u001b[0m\n",
"[100%] Built target test_gpu\n"
]
}
......@@ -169,7 +131,7 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 9,
"id": "31761a31",
"metadata": {},
"outputs": [
......@@ -187,7 +149,7 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 10,
"id": "f53316f6",
"metadata": {},
"outputs": [
......@@ -205,7 +167,7 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 11,
"id": "084f1e49",
"metadata": {},
"outputs": [
......@@ -216,29 +178,20 @@
"running install\n",
"running bdist_egg\n",
"running egg_info\n",
"creating warpctc_pytorch.egg-info\n",
"writing warpctc_pytorch.egg-info/PKG-INFO\n",
"writing dependency_links to warpctc_pytorch.egg-info/dependency_links.txt\n",
"writing top-level names to warpctc_pytorch.egg-info/top_level.txt\n",
"writing manifest file 'warpctc_pytorch.egg-info/SOURCES.txt'\n",
"writing manifest file 'warpctc_pytorch.egg-info/SOURCES.txt'\n",
"installing library code to build/bdist.linux-x86_64/egg\n",
"running install_lib\n",
"running build_py\n",
"creating build\n",
"creating build/lib.linux-x86_64-3.9\n",
"creating build/lib.linux-x86_64-3.9/warpctc_pytorch\n",
"copying warpctc_pytorch/__init__.py -> build/lib.linux-x86_64-3.9/warpctc_pytorch\n",
"running build_ext\n",
"building 'warpctc_pytorch._warp_ctc' extension\n",
"creating /workspace/zhanghui/DeepSpeech-2.x/docs/topic/ctc/warp-ctc/pytorch_binding/build/temp.linux-x86_64-3.9\n",
"creating /workspace/zhanghui/DeepSpeech-2.x/docs/topic/ctc/warp-ctc/pytorch_binding/build/temp.linux-x86_64-3.9/src\n",
"Emitting ninja build file /workspace/zhanghui/DeepSpeech-2.x/docs/topic/ctc/warp-ctc/pytorch_binding/build/temp.linux-x86_64-3.9/build.ninja...\n",
"Compiling objects...\n",
"Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)\n",
"[1/1] c++ -MMD -MF /workspace/zhanghui/DeepSpeech-2.x/docs/topic/ctc/warp-ctc/pytorch_binding/build/temp.linux-x86_64-3.9/src/binding.o.d -pthread -B /workspace/zhanghui/DeepSpeech-2.x/tools/venv/compiler_compat -Wl,--sysroot=/ -Wno-unused-result -Wsign-compare -DNDEBUG -O2 -Wall -fPIC -O2 -isystem /workspace/zhanghui/DeepSpeech-2.x/tools/venv/include -fPIC -O2 -isystem /workspace/zhanghui/DeepSpeech-2.x/tools/venv/include -fPIC -I/workspace/zhanghui/DeepSpeech-2.x/docs/topic/ctc/warp-ctc/include -I/workspace/zhanghui/DeepSpeech-2.x/tools/venv/lib/python3.9/site-packages/torch/include -I/workspace/zhanghui/DeepSpeech-2.x/tools/venv/lib/python3.9/site-packages/torch/include/torch/csrc/api/include -I/workspace/zhanghui/DeepSpeech-2.x/tools/venv/lib/python3.9/site-packages/torch/include/TH -I/workspace/zhanghui/DeepSpeech-2.x/tools/venv/lib/python3.9/site-packages/torch/include/THC -I/usr/local/cuda/include -I/workspace/zhanghui/DeepSpeech-2.x/tools/venv/include/python3.9 -c -c /workspace/zhanghui/DeepSpeech-2.x/docs/topic/ctc/warp-ctc/pytorch_binding/src/binding.cpp -o /workspace/zhanghui/DeepSpeech-2.x/docs/topic/ctc/warp-ctc/pytorch_binding/build/temp.linux-x86_64-3.9/src/binding.o -std=c++14 -fPIC -DWARPCTC_ENABLE_GPU -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE=\"_gcc\"' '-DPYBIND11_STDLIB=\"_libstdcpp\"' '-DPYBIND11_BUILD_ABI=\"_cxxabi1011\"' -DTORCH_EXTENSION_NAME=_warp_ctc -D_GLIBCXX_USE_CXX11_ABI=0\n",
"ninja: no work to do.\n",
"g++ -pthread -B /workspace/zhanghui/DeepSpeech-2.x/tools/venv/compiler_compat -Wl,--sysroot=/ -shared -Wl,-rpath,/workspace/zhanghui/DeepSpeech-2.x/tools/venv/lib -Wl,-rpath-link,/workspace/zhanghui/DeepSpeech-2.x/tools/venv/lib -L/workspace/zhanghui/DeepSpeech-2.x/tools/venv/lib -Wl,-rpath,/workspace/zhanghui/DeepSpeech-2.x/tools/venv/lib -Wl,-rpath-link,/workspace/zhanghui/DeepSpeech-2.x/tools/venv/lib -L/workspace/zhanghui/DeepSpeech-2.x/tools/venv/lib /workspace/zhanghui/DeepSpeech-2.x/docs/topic/ctc/warp-ctc/pytorch_binding/build/temp.linux-x86_64-3.9/src/binding.o -L/workspace/zhanghui/DeepSpeech-2.x/docs/topic/ctc/warp-ctc/build -L/workspace/zhanghui/DeepSpeech-2.x/tools/venv/lib/python3.9/site-packages/torch/lib -L/usr/local/cuda/lib64 -lwarpctc -lc10 -ltorch -ltorch_cpu -ltorch_python -lcudart -lc10_cuda -ltorch_cuda -o build/lib.linux-x86_64-3.9/warpctc_pytorch/_warp_ctc.cpython-39-x86_64-linux-gnu.so -Wl,-rpath,/workspace/zhanghui/DeepSpeech-2.x/docs/topic/ctc/warp-ctc/build\n",
"creating build/bdist.linux-x86_64\n",
"creating build/bdist.linux-x86_64/egg\n",
"creating build/bdist.linux-x86_64/egg/warpctc_pytorch\n",
"copying build/lib.linux-x86_64-3.9/warpctc_pytorch/__init__.py -> build/bdist.linux-x86_64/egg/warpctc_pytorch\n",
......@@ -254,7 +207,6 @@
"writing build/bdist.linux-x86_64/egg/EGG-INFO/native_libs.txt\n",
"zip_safe flag not set; analyzing archive contents...\n",
"warpctc_pytorch.__pycache__._warp_ctc.cpython-39: module references __file__\n",
"creating dist\n",
"creating 'dist/warpctc_pytorch-0.1-py3.9-linux-x86_64.egg' and adding 'build/bdist.linux-x86_64/egg' to it\n",
"removing 'build/bdist.linux-x86_64/egg' (and everything under it)\n",
"Processing warpctc_pytorch-0.1-py3.9-linux-x86_64.egg\n",
......@@ -275,7 +227,7 @@
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": 12,
"id": "ee4ca9e3",
"metadata": {},
"outputs": [
......@@ -293,7 +245,7 @@
},
{
"cell_type": "code",
"execution_count": 14,
"execution_count": 13,
"id": "59255ed8",
"metadata": {},
"outputs": [
......@@ -311,21 +263,14 @@
},
{
"cell_type": "code",
"execution_count": 15,
"execution_count": 22,
"id": "1dae09b9",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"grep: warning: GREP_OPTIONS is deprecated; please use an alias or script\n"
]
}
],
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"import warpctc_pytorch as wp\n",
"import paddle.nn as pn\n",
"import paddle"
......@@ -333,7 +278,7 @@
},
{
"cell_type": "code",
"execution_count": 16,
"execution_count": 15,
"id": "83d0762e",
"metadata": {},
"outputs": [
......@@ -343,7 +288,7 @@
"'1.10.0+cu102'"
]
},
"execution_count": 16,
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
......@@ -354,17 +299,17 @@
},
{
"cell_type": "code",
"execution_count": 17,
"execution_count": 16,
"id": "62501e2c",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'2.2.0'"
"'2.2.1'"
]
},
"execution_count": 17,
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
......@@ -375,7 +320,7 @@
},
{
"cell_type": "code",
"execution_count": 18,
"execution_count": 17,
"id": "9e8e0f40",
"metadata": {},
"outputs": [
......@@ -392,6 +337,7 @@
}
],
"source": [
"# warpctc_pytorch CTCLoss\n",
"probs = torch.FloatTensor([[\n",
" [0.1, 0.6, 0.1, 0.1, 0.1], [0.1, 0.1, 0.6, 0.1, 0.1]\n",
" ]]).transpose(0, 1).contiguous()\n",
......@@ -412,7 +358,7 @@
},
{
"cell_type": "code",
"execution_count": 19,
"execution_count": 18,
"id": "2cd46569",
"metadata": {},
"outputs": [
......@@ -428,6 +374,7 @@
}
],
"source": [
"# pytorch CTCLoss\n",
"probs = torch.FloatTensor([[\n",
" [0.1, 0.6, 0.1, 0.1, 0.1], [0.1, 0.1, 0.6, 0.1, 0.1]\n",
" ]]).transpose(0, 1).contiguous()\n",
......@@ -449,7 +396,7 @@
},
{
"cell_type": "code",
"execution_count": 22,
"execution_count": 27,
"id": "85c3461a",
"metadata": {},
"outputs": [
......@@ -467,6 +414,7 @@
}
],
"source": [
"# Paddle CTCLoss\n",
"paddle.set_device('cpu')\n",
"probs = paddle.to_tensor([[\n",
" [0.1, 0.6, 0.1, 0.1, 0.1], [0.1, 0.1, 0.6, 0.1, 0.1],\n",
......@@ -490,7 +438,55 @@
{
"cell_type": "code",
"execution_count": null,
"id": "d390cd91",
"id": "8cdf76c2",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 26,
"id": "2c305eaf",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([2, 1, 5])\n",
"2.4628584384918213\n",
"[[[ 0.17703117 -0.7081247 0.17703117 0.17703117 0.17703117]]\n",
"\n",
" [[ 0.17703117 0.17703117 -0.7081247 0.17703117 0.17703117]]]\n"
]
}
],
"source": [
"# warpctc_pytorch CTCLoss, log_softmax idempotent\n",
"probs = torch.FloatTensor([[\n",
" [0.1, 0.6, 0.1, 0.1, 0.1], [0.1, 0.1, 0.6, 0.1, 0.1]\n",
" ]]).transpose(0, 1).contiguous()\n",
"print(probs.size())\n",
"labels = torch.IntTensor([1, 2])\n",
"label_sizes = torch.IntTensor([2])\n",
"probs_sizes = torch.IntTensor([2])\n",
"probs.requires_grad_(True)\n",
"bs = probs.size(1)\n",
"\n",
"ctc_loss = wp.CTCLoss(size_average=False, length_average=False)\n",
"\n",
"log_probs = torch.log_softmax(probs, axis=-1)\n",
"cost = ctc_loss(log_probs, labels, probs_sizes, label_sizes)\n",
"cost = cost.sum() / bs\n",
"print(cost.item())\n",
"cost.backward()\n",
"print(probs.grad.numpy())"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "443336f0",
"metadata": {},
"outputs": [],
"source": []
......
......@@ -22,19 +22,17 @@ Authors
* qingenz123@126.com (Qingen ZHAO) 2022
"""
import os
import logging
import argparse
import xml.etree.ElementTree as et
import glob
import json
from ami_splits import get_AMI_split
import logging
import os
import xml.etree.ElementTree as et
from distutils.util import strtobool
from dataio import (
load_pkl,
save_pkl, )
from ami_splits import get_AMI_split
from dataio import load_pkl
from dataio import save_pkl
logger = logging.getLogger(__name__)
SAMPLERATE = 16000
......
......@@ -12,28 +12,30 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Make VoxCeleb1 trial of kaldi format
this script creat the test trial from kaldi trial voxceleb1_test_v2.txt or official trial veri_test2.txt
to kaldi trial format
"""
import argparse
import codecs
import os
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument("--voxceleb_trial",
default="voxceleb1_test_v2",
type=str,
help="VoxCeleb trial file. Default we use the kaldi trial voxceleb1_test_v2.txt")
parser.add_argument("--trial",
default="data/test/trial",
type=str,
help="Kaldi format trial file")
parser.add_argument(
"--voxceleb_trial",
default="voxceleb1_test_v2",
type=str,
help="VoxCeleb trial file. Default we use the kaldi trial voxceleb1_test_v2.txt"
)
parser.add_argument(
"--trial",
default="data/test/trial",
type=str,
help="Kaldi format trial file")
args = parser.parse_args()
def main(voxceleb_trial, trial):
"""
VoxCeleb provide several trial file, which format is different with kaldi format.
......@@ -58,7 +60,9 @@ def main(voxceleb_trial, trial):
"""
print("Start convert the voxceleb trial to kaldi format")
if not os.path.exists(voxceleb_trial):
raise RuntimeError("{} does not exist. Pleas input the correct file path".format(voxceleb_trial))
raise RuntimeError(
"{} does not exist. Pleas input the correct file path".format(
voxceleb_trial))
trial_dirname = os.path.dirname(trial)
if not os.path.exists(trial_dirname):
......@@ -66,9 +70,9 @@ def main(voxceleb_trial, trial):
with codecs.open(voxceleb_trial, 'r', encoding='utf-8') as f, \
codecs.open(trial, 'w', encoding='utf-8') as w:
for line in f:
for line in f:
target_or_nontarget, path1, path2 = line.strip().split()
utt_id1 = "-".join(path1.split("/"))
utt_id2 = "-".join(path2.split("/"))
target = "nontarget"
......@@ -77,5 +81,6 @@ def main(voxceleb_trial, trial):
w.write("{} {} {}\n".format(utt_id1, utt_id2, target))
print("Convert the voxceleb trial to kaldi format successfully")
if __name__ == "__main__":
main(args.voxceleb_trial, args.trial)
......@@ -11,14 +11,3 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
......@@ -413,7 +413,8 @@ class ASRExecutor(BaseExecutor):
def _check(self, audio_file: str, sample_rate: int, force_yes: bool):
self.sample_rate = sample_rate
if self.sample_rate != 16000 and self.sample_rate != 8000:
logger.error("invalid sample rate, please input --sr 8000 or --sr 16000")
logger.error(
"invalid sample rate, please input --sr 8000 or --sr 16000")
return False
if isinstance(audio_file, (str, os.PathLike)):
......
......@@ -11,8 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List
from io import BytesIO
from typing import List
import numpy as np
......
......@@ -12,8 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import uvicorn
import yaml
from fastapi import FastAPI
from paddlespeech.server.engine.engine_pool import init_engine_pool
......
......@@ -48,8 +48,9 @@ class TTSClientExecutor(BaseExecutor):
self.parser.add_argument(
'--input',
type=str,
default="你好,欢迎使用语音合成服务",
help='A sentence to be synthesized.')
default=None,
help='Text to be synthesized.',
required=True)
self.parser.add_argument(
'--spk_id', type=int, default=0, help='Speaker id')
self.parser.add_argument(
......@@ -123,7 +124,7 @@ class TTSClientExecutor(BaseExecutor):
logger.info("RTF: %f " % (time_consume / duration))
return True
except:
except BaseException:
logger.error("Failed to synthesized audio.")
return False
......@@ -163,7 +164,7 @@ class TTSClientExecutor(BaseExecutor):
print("Audio duration: %f s." % (duration))
print("Response time: %f s." % (time_consume))
print("RTF: %f " % (time_consume / duration))
except:
except BaseException:
print("Failed to synthesized audio.")
......@@ -181,8 +182,9 @@ class ASRClientExecutor(BaseExecutor):
self.parser.add_argument(
'--input',
type=str,
default="./paddlespeech/server/tests/16_audio.wav",
help='Audio file to be recognized')
default=None,
help='Audio file to be recognized',
required=True)
self.parser.add_argument(
'--sample_rate', type=int, default=16000, help='audio sample rate')
self.parser.add_argument(
......@@ -209,7 +211,7 @@ class ASRClientExecutor(BaseExecutor):
logger.info(r.json())
logger.info("time cost %f s." % (time_end - time_start))
return True
except:
except BaseException:
logger.error("Failed to speech recognition.")
return False
......@@ -240,5 +242,5 @@ class ASRClientExecutor(BaseExecutor):
time_end = time.time()
print(r.json())
print("time cost %f s." % (time_end - time_start))
except:
print("Failed to speech recognition.")
\ No newline at end of file
except BaseException:
print("Failed to speech recognition.")
......@@ -41,7 +41,8 @@ class ServerExecutor(BaseExecutor):
"--config_file",
action="store",
help="yaml file of the app",
default="./conf/application.yaml")
default=None,
required=True)
self.parser.add_argument(
"--log_file",
......
......@@ -5,4 +5,4 @@ cfg_path: # [optional]
ckpt_path: # [optional]
decode_method: 'attention_rescoring'
force_yes: True
device: 'gpu:3' # set 'gpu:id' or 'cpu'
device: 'cpu' # set 'gpu:id' or 'cpu'
......@@ -15,7 +15,7 @@ decode_method:
force_yes: True
am_predictor_conf:
device: 'gpu:3' # set 'gpu:id' or 'cpu'
device: 'cpu' # set 'gpu:id' or 'cpu'
enable_mkldnn: True
switch_ir_optim: True
......
......@@ -29,4 +29,4 @@ voc_stat:
# OTHERS #
##################################################################
lang: 'zh'
device: 'gpu:3' # set 'gpu:id' or 'cpu'
device: 'cpu' # set 'gpu:id' or 'cpu'
......@@ -15,7 +15,7 @@ speaker_dict:
spk_id: 0
am_predictor_conf:
device: 'gpu:3' # set 'gpu:id' or 'cpu'
device: 'cpu' # set 'gpu:id' or 'cpu'
enable_mkldnn: False
switch_ir_optim: False
......@@ -30,7 +30,7 @@ voc_params: # the pdiparams file of your vocoder static model (XX.pdipparams)
voc_sample_rate: 24000 #must match the model
voc_predictor_conf:
device: 'gpu:3' # set 'gpu:id' or 'cpu'
device: 'cpu' # set 'gpu:id' or 'cpu'
enable_mkldnn: False
switch_ir_optim: False
......
......@@ -13,31 +13,24 @@
# limitations under the License.
import io
import os
from typing import List
from typing import Optional
from typing import Union
import librosa
import paddle
import soundfile
from yacs.config import CfgNode
from paddlespeech.cli.utils import MODEL_HOME
from paddlespeech.s2t.modules.ctc import CTCDecoder
from paddlespeech.cli.asr.infer import ASRExecutor
from paddlespeech.cli.log import logger
from paddlespeech.cli.utils import MODEL_HOME
from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
from paddlespeech.s2t.transform.transformation import Transformation
from paddlespeech.s2t.utils.dynamic_import import dynamic_import
from paddlespeech.s2t.modules.ctc import CTCDecoder
from paddlespeech.s2t.utils.utility import UpdateConfig
from paddlespeech.server.engine.base_engine import BaseEngine
from paddlespeech.server.utils.config import get_config
from paddlespeech.server.utils.paddle_predictor import init_predictor
from paddlespeech.server.utils.paddle_predictor import run_model
from paddlespeech.server.engine.base_engine import BaseEngine
__all__ = ['ASREngine']
pretrained_models = {
"deepspeech2offline_aishell-zh-16k": {
'url':
......@@ -143,7 +136,6 @@ class ASRServerExecutor(ASRExecutor):
batch_average=True, # sum / batch_size
grad_norm_type=self.config.get('ctc_grad_norm_type', None))
@paddle.no_grad()
def infer(self, model_type: str):
"""
......@@ -161,9 +153,8 @@ class ASRServerExecutor(ASRExecutor):
cfg.beam_size, cfg.cutoff_prob, cfg.cutoff_top_n,
cfg.num_proc_bsearch)
output_data = run_model(
self.am_predictor,
[audio.numpy(), audio_len.numpy()])
output_data = run_model(self.am_predictor,
[audio.numpy(), audio_len.numpy()])
probs = output_data[0]
eouts_len = output_data[1]
......@@ -208,14 +199,14 @@ class ASREngine(BaseEngine):
paddle.set_device(paddle.get_device())
self.executor._init_from_path(
model_type=self.config.model_type,
am_model=self.config.am_model,
am_params=self.config.am_params,
lang=self.config.lang,
sample_rate=self.config.sample_rate,
cfg_path=self.config.cfg_path,
decode_method=self.config.decode_method,
am_predictor_conf=self.config.am_predictor_conf)
model_type=self.config.model_type,
am_model=self.config.am_model,
am_params=self.config.am_params,
lang=self.config.lang,
sample_rate=self.config.sample_rate,
cfg_path=self.config.cfg_path,
decode_method=self.config.decode_method,
am_predictor_conf=self.config.am_predictor_conf)
logger.info("Initialize ASR server engine successfully.")
return True
......@@ -230,7 +221,8 @@ class ASREngine(BaseEngine):
io.BytesIO(audio_data), self.config.sample_rate,
self.config.force_yes):
logger.info("start running asr engine")
self.executor.preprocess(self.config.model_type, io.BytesIO(audio_data))
self.executor.preprocess(self.config.model_type,
io.BytesIO(audio_data))
self.executor.infer(self.config.model_type)
self.output = self.executor.postprocess() # Retrieve result of asr.
logger.info("end inferring asr engine")
......
......@@ -53,7 +53,10 @@ class ASREngine(BaseEngine):
self.executor = ASRServerExecutor()
self.config = get_config(config_file)
paddle.set_device(self.config.device)
if self.config.device is None:
paddle.set_device(paddle.get_device())
else:
paddle.set_device(self.config.device)
self.executor._init_from_path(
self.config.model, self.config.lang, self.config.sample_rate,
self.config.cfg_path, self.config.decode_method,
......
......@@ -12,8 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Any
from typing import List
from typing import Union
from pattern_singleton import Singleton
......
......@@ -13,7 +13,6 @@
# limitations under the License.
from typing import Text
__all__ = ['EngineFactory']
......
......@@ -29,8 +29,10 @@ def init_engine_pool(config) -> bool:
"""
global ENGINE_POOL
for engine in config.engine_backend:
ENGINE_POOL[engine] = EngineFactory.get_engine(engine_name=engine, engine_type=config.engine_type[engine])
if not ENGINE_POOL[engine].init(config_file=config.engine_backend[engine]):
ENGINE_POOL[engine] = EngineFactory.get_engine(
engine_name=engine, engine_type=config.engine_type[engine])
if not ENGINE_POOL[engine].init(
config_file=config.engine_backend[engine]):
return False
return True
......@@ -360,8 +360,8 @@ class TTSEngine(BaseEngine):
am_predictor_conf=self.config.am_predictor_conf,
voc_predictor_conf=self.config.voc_predictor_conf, )
except:
logger.info("Initialize TTS server engine Failed.")
except BaseException:
logger.error("Initialize TTS server engine Failed.")
return False
logger.info("Initialize TTS server engine successfully.")
......@@ -405,11 +405,13 @@ class TTSEngine(BaseEngine):
# transform speed
try: # windows not support soxbindings
wav_speed = change_speed(wav_vol, speed, target_fs)
except:
except ServerBaseException:
raise ServerBaseException(
ErrorCode.SERVER_INTERNAL_ERR,
"Transform speed failed. Can not install soxbindings on your system. \
You need to set speed value 1.0.")
except BaseException:
logger.error("Transform speed failed.")
# wav to base64
buf = io.BytesIO()
......@@ -462,9 +464,11 @@ class TTSEngine(BaseEngine):
try:
self.executor.infer(
text=sentence, lang=lang, am=self.config.am, spk_id=spk_id)
except:
except ServerBaseException:
raise ServerBaseException(ErrorCode.SERVER_INTERNAL_ERR,
"tts infer failed.")
except BaseException:
logger.error("tts infer failed.")
try:
target_sample_rate, wav_base64 = self.postprocess(
......@@ -474,8 +478,10 @@ class TTSEngine(BaseEngine):
volume=volume,
speed=speed,
audio_path=save_path)
except:
except ServerBaseException:
raise ServerBaseException(ErrorCode.SERVER_INTERNAL_ERR,
"tts postprocess failed.")
except BaseException:
logger.error("tts postprocess failed.")
return lang, target_sample_rate, wav_base64
......@@ -54,7 +54,10 @@ class TTSEngine(BaseEngine):
try:
self.config = get_config(config_file)
paddle.set_device(self.config.device)
if self.config.device is None:
paddle.set_device(paddle.get_device())
else:
paddle.set_device(self.config.device)
self.executor._init_from_path(
am=self.config.am,
......@@ -69,8 +72,8 @@ class TTSEngine(BaseEngine):
voc_ckpt=self.config.voc_ckpt,
voc_stat=self.config.voc_stat,
lang=self.config.lang)
except:
logger.info("Initialize TTS server engine Failed.")
except BaseException:
logger.error("Initialize TTS server engine Failed.")
return False
logger.info("Initialize TTS server engine successfully.")
......@@ -114,10 +117,13 @@ class TTSEngine(BaseEngine):
# transform speed
try: # windows not support soxbindings
wav_speed = change_speed(wav_vol, speed, target_fs)
except:
except ServerBaseException:
raise ServerBaseException(
ErrorCode.SERVER_INTERNAL_ERR,
"Can not install soxbindings on your system.")
"Transform speed failed. Can not install soxbindings on your system. \
You need to set speed value 1.0.")
except BaseException:
logger.error("Transform speed failed.")
# wav to base64
buf = io.BytesIO()
......@@ -170,9 +176,11 @@ class TTSEngine(BaseEngine):
try:
self.executor.infer(
text=sentence, lang=lang, am=self.config.am, spk_id=spk_id)
except:
except ServerBaseException:
raise ServerBaseException(ErrorCode.SERVER_INTERNAL_ERR,
"tts infer failed.")
except BaseException:
logger.error("tts infer failed.")
try:
target_sample_rate, wav_base64 = self.postprocess(
......@@ -182,8 +190,10 @@ class TTSEngine(BaseEngine):
volume=volume,
speed=speed,
audio_path=save_path)
except:
except ServerBaseException:
raise ServerBaseException(ErrorCode.SERVER_INTERNAL_ERR,
"tts postprocess failed.")
except BaseException:
logger.error("tts postprocess failed.")
return lang, target_sample_rate, wav_base64
......@@ -14,6 +14,7 @@
import base64
import traceback
from typing import Union
from fastapi import APIRouter
from paddlespeech.server.engine.engine_pool import get_engine_pool
......@@ -83,7 +84,7 @@ def asr(request_body: ASRRequest):
except ServerBaseException as e:
response = failed_response(e.error_code, e.msg)
except:
except BaseException:
response = failed_response(ErrorCode.SERVER_UNKOWN_ERR)
traceback.print_exc()
......
......@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List
from typing import Optional
from pydantic import BaseModel
......
......@@ -11,9 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List
from typing import Optional
from pydantic import BaseModel
__all__ = ['ASRResponse', 'TTSResponse']
......
......@@ -114,7 +114,7 @@ def tts(request_body: TTSRequest):
}
except ServerBaseException as e:
response = failed_response(e.error_code, e.msg)
except:
except BaseException:
response = failed_response(ErrorCode.SERVER_UNKOWN_ERR)
traceback.print_exc()
......
......@@ -10,11 +10,11 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the
import requests
import base64
import json
import time
import base64
import io
import requests
def readwav2base64(wav_file):
......@@ -34,23 +34,23 @@ def main():
url = "http://127.0.0.1:8090/paddlespeech/asr"
# start Timestamp
time_start=time.time()
time_start = time.time()
test_audio_dir = "./16_audio.wav"
audio = readwav2base64(test_audio_dir)
data = {
"audio": audio,
"audio_format": "wav",
"sample_rate": 16000,
"lang": "zh_cn",
}
"audio": audio,
"audio_format": "wav",
"sample_rate": 16000,
"lang": "zh_cn",
}
r = requests.post(url=url, data=json.dumps(data))
# ending Timestamp
time_end=time.time()
print('time cost',time_end - time_start, 's')
time_end = time.time()
print('time cost', time_end - time_start, 's')
print(r.json())
......
......@@ -25,6 +25,7 @@ import soundfile
from paddlespeech.server.utils.audio_process import wav2pcm
# Request and response
def tts_client(args):
""" Request and response
......@@ -99,5 +100,5 @@ if __name__ == "__main__":
print("Inference time: %f" % (time_consume))
print("The duration of synthesized audio: %f" % (duration))
print("The RTF is: %f" % (rtf))
except:
except BaseException:
print("Failed to synthesized audio.")
......@@ -219,7 +219,7 @@ class ConfigCache:
try:
cfg = yaml.load(file, Loader=yaml.FullLoader)
self._data.update(cfg)
except:
except BaseException:
self.flush()
@property
......
......@@ -258,4 +258,4 @@ class ChainDataset(Dataset):
return dataset[i]
i -= len(dataset)
raise IndexError("dataset index out of range")
\ No newline at end of file
raise IndexError("dataset index out of range")
......@@ -27,47 +27,53 @@ from setuptools.command.install import install
HERE = Path(os.path.abspath(os.path.dirname(__file__)))
VERSION = '0.1.1'
VERSION = '0.1.2'
base = [
"editdistance",
"g2p_en",
"g2pM",
"h5py",
"inflect",
"jieba",
"jsonlines",
"kaldiio",
"librosa==0.8.1",
"loguru",
"matplotlib",
"nara_wpe",
"pandas",
"paddleaudio",
"paddlenlp",
"paddlespeech_feat",
"praatio==5.0.0",
"pypinyin",
"python-dateutil",
"pyworld",
"resampy==0.2.2",
"sacrebleu",
"scipy",
"sentencepiece~=0.1.96",
"soundfile~=0.10",
"textgrid",
"timer",
"tqdm",
"typeguard",
"visualdl",
"webrtcvad",
"yacs~=0.1.8",
]
server = [
"fastapi",
"uvicorn",
"pattern_singleton",
"prettytable",
]
requirements = {
"install": [
"editdistance",
"g2p_en",
"g2pM",
"h5py",
"inflect",
"jieba",
"jsonlines",
"kaldiio",
"librosa",
"loguru",
"matplotlib",
"nara_wpe",
"pandas",
"paddleaudio",
"paddlenlp",
"paddlespeech_feat",
"praatio==5.0.0",
"pypinyin",
"python-dateutil",
"pyworld",
"resampy==0.2.2",
"sacrebleu",
"scipy",
"sentencepiece~=0.1.96",
"soundfile~=0.10",
"textgrid",
"timer",
"tqdm",
"typeguard",
"visualdl",
"webrtcvad",
"yacs~=0.1.8",
# fastapi server
"fastapi",
"uvicorn",
"prettytable"
],
"install":
base + server,
"develop": [
"ConfigArgParse",
"coverage",
......
......@@ -23,10 +23,11 @@ Credits
This code is adapted from https://github.com/nryant/dscore
"""
import argparse
from distutils.util import strtobool
import os
import re
import subprocess
from distutils.util import strtobool
import numpy as np
FILE_IDS = re.compile(r"(?<=Speaker Diarization for).+(?=\*\*\*)")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册