未验证 提交 9d20a10b 编写于 作者: Honei_X's avatar Honei_X 提交者: GitHub

Merge branch 'develop' into server

...@@ -50,13 +50,13 @@ repos: ...@@ -50,13 +50,13 @@ repos:
entry: bash .pre-commit-hooks/clang-format.hook -i entry: bash .pre-commit-hooks/clang-format.hook -i
language: system language: system
files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|cuh|proto)$ files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|cuh|proto)$
exclude: (?=speechx/speechx/kaldi|speechx/patch).*(\.cpp|\.cc|\.h|\.py)$ exclude: (?=speechx/speechx/kaldi|speechx/patch|speechx/tools/fstbin|speechx/tools/lmbin).*(\.cpp|\.cc|\.h|\.py)$
- id: copyright_checker - id: copyright_checker
name: copyright_checker name: copyright_checker
entry: python .pre-commit-hooks/copyright-check.hook entry: python .pre-commit-hooks/copyright-check.hook
language: system language: system
files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|proto|py)$ files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|proto|py)$
exclude: (?=third_party|pypinyin|speechx/speechx/kaldi|speechx/patch).*(\.cpp|\.cc|\.h|\.py)$ exclude: (?=third_party|pypinyin|speechx/speechx/kaldi|speechx/patch|speechx/tools/fstbin|speechx/tools/lmbin).*(\.cpp|\.cc|\.h|\.py)$
- repo: https://github.com/asottile/reorder_python_imports - repo: https://github.com/asottile/reorder_python_imports
rev: v2.4.0 rev: v2.4.0
hooks: hooks:
......
...@@ -90,7 +90,7 @@ Then to start the system server, and it provides HTTP backend services. ...@@ -90,7 +90,7 @@ Then to start the system server, and it provides HTTP backend services.
```bash ```bash
export PYTHONPATH=$PYTHONPATH:./src:../../paddleaudio export PYTHONPATH=$PYTHONPATH:./src:../../paddleaudio
python src/main.py python src/audio_search.py
``` ```
Then you will see the Application is started: Then you will see the Application is started:
...@@ -111,7 +111,7 @@ Then to start the system server, and it provides HTTP backend services. ...@@ -111,7 +111,7 @@ Then to start the system server, and it provides HTTP backend services.
```bash ```bash
wget -c https://www.openslr.org/resources/82/cn-celeb_v2.tar.gz && tar -xvf cn-celeb_v2.tar.gz wget -c https://www.openslr.org/resources/82/cn-celeb_v2.tar.gz && tar -xvf cn-celeb_v2.tar.gz
``` ```
**Note**: If you want to build a quick demo, you can use ./src/test_main.py:download_audio_data function, it downloads 20 audio files , Subsequent results show this collection as an example **Note**: If you want to build a quick demo, you can use ./src/test_audio_search.py:download_audio_data function, it downloads 20 audio files , Subsequent results show this collection as an example
- Prepare model(Skip this step if you use the default model.) - Prepare model(Skip this step if you use the default model.)
```bash ```bash
...@@ -123,7 +123,7 @@ Then to start the system server, and it provides HTTP backend services. ...@@ -123,7 +123,7 @@ Then to start the system server, and it provides HTTP backend services.
The internal process is downloading data, loading the paddlespeech model, extracting embedding, storing library, retrieving and deleting library The internal process is downloading data, loading the paddlespeech model, extracting embedding, storing library, retrieving and deleting library
```bash ```bash
python ./src/test_main.py python ./src/test_audio_search.py
``` ```
Output: Output:
......
...@@ -92,7 +92,7 @@ ffce340b3790 minio/minio:RELEASE.2020-12-03T00-03-10Z "/usr/bin/docker-ent…" ...@@ -92,7 +92,7 @@ ffce340b3790 minio/minio:RELEASE.2020-12-03T00-03-10Z "/usr/bin/docker-ent…"
```bash ```bash
export PYTHONPATH=$PYTHONPATH:./src:../../paddleaudio export PYTHONPATH=$PYTHONPATH:./src:../../paddleaudio
python src/main.py python src/audio_search.py
``` ```
然后你会看到应用程序启动: 然后你会看到应用程序启动:
...@@ -113,7 +113,7 @@ ffce340b3790 minio/minio:RELEASE.2020-12-03T00-03-10Z "/usr/bin/docker-ent…" ...@@ -113,7 +113,7 @@ ffce340b3790 minio/minio:RELEASE.2020-12-03T00-03-10Z "/usr/bin/docker-ent…"
```bash ```bash
wget -c https://www.openslr.org/resources/82/cn-celeb_v2.tar.gz && tar -xvf cn-celeb_v2.tar.gz wget -c https://www.openslr.org/resources/82/cn-celeb_v2.tar.gz && tar -xvf cn-celeb_v2.tar.gz
``` ```
**注**:如果希望快速搭建 demo,可以采用 ./src/test_main.py:download_audio_data 内部的 20 条音频,另外后续结果展示以该集合为例 **注**:如果希望快速搭建 demo,可以采用 ./src/test_audio_search.py:download_audio_data 内部的 20 条音频,另外后续结果展示以该集合为例
- 准备模型(如果使用默认模型,可以跳过此步骤) - 准备模型(如果使用默认模型,可以跳过此步骤)
```bash ```bash
...@@ -124,7 +124,7 @@ ffce340b3790 minio/minio:RELEASE.2020-12-03T00-03-10Z "/usr/bin/docker-ent…" ...@@ -124,7 +124,7 @@ ffce340b3790 minio/minio:RELEASE.2020-12-03T00-03-10Z "/usr/bin/docker-ent…"
- 脚本测试(推荐) - 脚本测试(推荐)
```bash ```bash
python ./src/test_main.py python ./src/test_audio_search.py
``` ```
注:内部将依次下载数据,加载 paddlespeech 模型,提取 embedding,存储建库,检索,删库 注:内部将依次下载数据,加载 paddlespeech 模型,提取 embedding,存储建库,检索,删库
......
...@@ -40,7 +40,6 @@ app.add_middleware( ...@@ -40,7 +40,6 @@ app.add_middleware(
allow_methods=["*"], allow_methods=["*"],
allow_headers=["*"]) allow_headers=["*"])
MODEL = None
MILVUS_CLI = MilvusHelper() MILVUS_CLI = MilvusHelper()
MYSQL_CLI = MySQLHelper() MYSQL_CLI = MySQLHelper()
......
...@@ -12,8 +12,8 @@ ...@@ -12,8 +12,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import numpy as np import numpy as np
from logs import LOGGER from logs import LOGGER
from paddlespeech.cli import VectorExecutor from paddlespeech.cli import VectorExecutor
vector_executor = VectorExecutor() vector_executor = VectorExecutor()
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
import sys import sys
import numpy
import pymysql import pymysql
from config import MYSQL_DB from config import MYSQL_DB
from config import MYSQL_HOST from config import MYSQL_HOST
...@@ -69,7 +70,7 @@ class MySQLHelper(): ...@@ -69,7 +70,7 @@ class MySQLHelper():
sys.exit(1) sys.exit(1)
def load_data_to_mysql(self, table_name, data): def load_data_to_mysql(self, table_name, data):
# Batch insert (Milvus_ids, img_path) to mysql # Batch insert (Milvus_ids, audio_path) to mysql
self.test_connection() self.test_connection()
sql = "insert into " + table_name + " (milvus_id,audio_path) values (%s,%s);" sql = "insert into " + table_name + " (milvus_id,audio_path) values (%s,%s);"
try: try:
...@@ -82,7 +83,7 @@ class MySQLHelper(): ...@@ -82,7 +83,7 @@ class MySQLHelper():
sys.exit(1) sys.exit(1)
def search_by_milvus_ids(self, ids, table_name): def search_by_milvus_ids(self, ids, table_name):
# Get the img_path according to the milvus ids # Get the audio_path according to the milvus ids
self.test_connection() self.test_connection()
str_ids = str(ids).replace('[', '').replace(']', '') str_ids = str(ids).replace('[', '').replace(']', '')
sql = "select audio_path from " + table_name + " where milvus_id in (" + str_ids + ") order by field (milvus_id," + str_ids + ");" sql = "select audio_path from " + table_name + " where milvus_id in (" + str_ids + ") order by field (milvus_id," + str_ids + ");"
...@@ -120,14 +121,83 @@ class MySQLHelper(): ...@@ -120,14 +121,83 @@ class MySQLHelper():
sys.exit(1) sys.exit(1)
def count_table(self, table_name): def count_table(self, table_name):
# Get the number of mysql table # Get the number of spk in mysql table
self.test_connection() self.test_connection()
sql = "select count(milvus_id) from " + table_name + ";" sql = "select count(spk_id) from " + table_name + ";"
try: try:
self.cursor.execute(sql) self.cursor.execute(sql)
results = self.cursor.fetchall() results = self.cursor.fetchall()
LOGGER.debug(f"MYSQL count table:{table_name}") LOGGER.debug(f"MYSQL count table:{results[0][0]}")
return results[0][0] return results[0][0]
except Exception as e: except Exception as e:
LOGGER.error(f"MYSQL ERROR: {e} with sql: {sql}") LOGGER.error(f"MYSQL ERROR: {e} with sql: {sql}")
sys.exit(1) sys.exit(1)
def create_mysql_table_vpr(self, table_name):
# Create mysql table if not exists
self.test_connection()
sql = "create table if not exists " + table_name + "(spk_id TEXT, audio_path TEXT, embedding TEXT);"
try:
self.cursor.execute(sql)
LOGGER.debug(f"MYSQL create table: {table_name} with sql: {sql}")
except Exception as e:
LOGGER.error(f"MYSQL ERROR: {e} with sql: {sql}")
sys.exit(1)
def load_data_to_mysql_vpr(self, table_name, data):
# Insert (spk, audio, embedding) to mysql
self.test_connection()
sql = "insert into " + table_name + " (spk_id,audio_path,embedding) values (%s,%s,%s);"
try:
self.cursor.execute(sql, data)
LOGGER.debug(
f"MYSQL loads data to table: {table_name} successfully")
except Exception as e:
LOGGER.error(f"MYSQL ERROR: {e} with sql: {sql}")
sys.exit(1)
def list_vpr(self, table_name):
# Get all records in mysql
self.test_connection()
sql = "select * from " + table_name + " ;"
try:
self.cursor.execute(sql)
results = self.cursor.fetchall()
self.conn.commit()
spk_ids = [res[0] for res in results]
audio_paths = [res[1] for res in results]
embeddings = [
numpy.array(
str(res[2]).replace('[', '').replace(']', '').split(","))
for res in results
]
return spk_ids, audio_paths, embeddings
except Exception as e:
LOGGER.error(f"MYSQL ERROR: {e} with sql: {sql}")
sys.exit(1)
def search_audio_vpr(self, table_name, spk_id):
# Get the audio_path according to the spk_id
self.test_connection()
sql = "select audio_path from " + table_name + " where spk_id='" + spk_id + "' ;"
try:
self.cursor.execute(sql)
results = self.cursor.fetchall()
LOGGER.debug(
f"MYSQL search by spk id {spk_id} to get audio {results[0][0]}.")
return results[0][0]
except Exception as e:
LOGGER.error(f"MYSQL ERROR: {e} with sql: {sql}")
sys.exit(1)
def delete_data_vpr(self, table_name, spk_id):
# Delete a record by spk_id in mysql table
self.test_connection()
sql = "delete from " + table_name + " where spk_id='" + spk_id + "';"
try:
self.cursor.execute(sql)
LOGGER.debug(
f"MYSQL delete a record {spk_id} in table {table_name}")
except Exception as e:
LOGGER.error(f"MYSQL ERROR: {e} with sql: {sql}")
sys.exit(1)
...@@ -31,3 +31,45 @@ def do_count(table_name, milvus_cli): ...@@ -31,3 +31,45 @@ def do_count(table_name, milvus_cli):
except Exception as e: except Exception as e:
LOGGER.error(f"Error attempting to count table {e}") LOGGER.error(f"Error attempting to count table {e}")
sys.exit(1) sys.exit(1)
def do_count_vpr(table_name, mysql_cli):
"""
Returns the total number of spk in the system
"""
if not table_name:
table_name = DEFAULT_TABLE
try:
num = mysql_cli.count_table(table_name)
return num
except Exception as e:
LOGGER.error(f"Error attempting to count table {e}")
sys.exit(1)
def do_list(table_name, mysql_cli):
"""
Returns the total records of vpr in the system
"""
if not table_name:
table_name = DEFAULT_TABLE
try:
spk_ids, audio_paths, _ = mysql_cli.list_vpr(table_name)
return spk_ids, audio_paths
except Exception as e:
LOGGER.error(f"Error attempting to count table {e}")
sys.exit(1)
def do_get(table_name, spk_id, mysql_cli):
"""
Returns the audio path by spk_id in the system
"""
if not table_name:
table_name = DEFAULT_TABLE
try:
audio_apth = mysql_cli.search_audio_vpr(table_name, spk_id)
return audio_apth
except Exception as e:
LOGGER.error(f"Error attempting to count table {e}")
sys.exit(1)
...@@ -32,3 +32,31 @@ def do_drop(table_name, milvus_cli, mysql_cli): ...@@ -32,3 +32,31 @@ def do_drop(table_name, milvus_cli, mysql_cli):
except Exception as e: except Exception as e:
LOGGER.error(f"Error attempting to drop table: {e}") LOGGER.error(f"Error attempting to drop table: {e}")
sys.exit(1) sys.exit(1)
def do_drop_vpr(table_name, mysql_cli):
"""
Delete the table of MySQL
"""
if not table_name:
table_name = DEFAULT_TABLE
try:
mysql_cli.delete_table(table_name)
return "OK"
except Exception as e:
LOGGER.error(f"Error attempting to drop table: {e}")
sys.exit(1)
def do_delete(table_name, spk_id, mysql_cli):
"""
Delete a record by spk_id in MySQL
"""
if not table_name:
table_name = DEFAULT_TABLE
try:
mysql_cli.delete_data_vpr(table_name, spk_id)
return "OK"
except Exception as e:
LOGGER.error(f"Error attempting to drop table: {e}")
sys.exit(1)
...@@ -82,3 +82,16 @@ def do_load(table_name, audio_dir, milvus_cli, mysql_cli): ...@@ -82,3 +82,16 @@ def do_load(table_name, audio_dir, milvus_cli, mysql_cli):
mysql_cli.create_mysql_table(table_name) mysql_cli.create_mysql_table(table_name)
mysql_cli.load_data_to_mysql(table_name, format_data(ids, names)) mysql_cli.load_data_to_mysql(table_name, format_data(ids, names))
return len(ids) return len(ids)
def do_enroll(table_name, spk_id, audio_path, mysql_cli):
"""
Import spk_id,audio_path,embedding to Mysql
"""
if not table_name:
table_name = DEFAULT_TABLE
embedding = get_audio_embedding(audio_path)
mysql_cli.create_mysql_table_vpr(table_name)
data = (spk_id, audio_path, str(embedding))
mysql_cli.load_data_to_mysql_vpr(table_name, data)
return "OK"
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
import sys import sys
import numpy
from config import DEFAULT_TABLE from config import DEFAULT_TABLE
from config import TOP_K from config import TOP_K
from encode import get_audio_embedding from encode import get_audio_embedding
...@@ -39,3 +40,26 @@ def do_search(host, table_name, audio_path, milvus_cli, mysql_cli): ...@@ -39,3 +40,26 @@ def do_search(host, table_name, audio_path, milvus_cli, mysql_cli):
except Exception as e: except Exception as e:
LOGGER.error(f"Error with search: {e}") LOGGER.error(f"Error with search: {e}")
sys.exit(1) sys.exit(1)
def do_search_vpr(host, table_name, audio_path, mysql_cli):
"""
Search the uploaded audio in MySQL
"""
try:
if not table_name:
table_name = DEFAULT_TABLE
emb = get_audio_embedding(audio_path)
emb = numpy.array(emb)
spk_ids, paths, vectors = mysql_cli.list_vpr(table_name)
scores = [numpy.dot(emb, x.astype(numpy.float64)) for x in vectors]
spk_ids = [str(x) for x in spk_ids]
paths = [str(x) for x in paths]
for i in range(len(paths)):
tmp = "http://" + str(host) + "/data?audio_path=" + str(paths[i])
paths[i] = tmp
scores[i] = scores[i] * 100
return spk_ids, paths, scores
except Exception as e:
LOGGER.error(f"Error with search: {e}")
sys.exit(1)
...@@ -11,8 +11,8 @@ ...@@ -11,8 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from audio_search import app
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
from main import app
from utils.utility import download from utils.utility import download
from utils.utility import unpack from utils.utility import unpack
...@@ -22,7 +22,7 @@ client = TestClient(app) ...@@ -22,7 +22,7 @@ client = TestClient(app)
def download_audio_data(): def download_audio_data():
""" """
download audio data Download audio data
""" """
url = "https://paddlespeech.bj.bcebos.com/vector/audio/example_audio.tar.gz" url = "https://paddlespeech.bj.bcebos.com/vector/audio/example_audio.tar.gz"
md5sum = "52ac69316c1aa1fdef84da7dd2c67b39" md5sum = "52ac69316c1aa1fdef84da7dd2c67b39"
...@@ -64,7 +64,7 @@ def test_count(): ...@@ -64,7 +64,7 @@ def test_count():
""" """
Returns the total number of vectors in the system Returns the total number of vectors in the system
""" """
response = client.get("audio/count") response = client.get("/audio/count")
assert response.status_code == 200 assert response.status_code == 200
assert response.json() == 20 assert response.json() == 20
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from fastapi.testclient import TestClient
from vpr_search import app
from utils.utility import download
from utils.utility import unpack
client = TestClient(app)
def download_audio_data():
"""
Download audio data
"""
url = "https://paddlespeech.bj.bcebos.com/vector/audio/example_audio.tar.gz"
md5sum = "52ac69316c1aa1fdef84da7dd2c67b39"
target_dir = "./"
filepath = download(url, md5sum, target_dir)
unpack(filepath, target_dir, True)
def test_drop():
"""
Delete the table of MySQL
"""
response = client.post("/vpr/drop")
assert response.status_code == 200
def test_enroll_local(spk: str, audio: str):
"""
Enroll the audio to MySQL
"""
response = client.post("/vpr/enroll/local?spk_id=" + spk +
"&audio_path=.%2Fexample_audio%2F" + audio + ".wav")
assert response.status_code == 200
assert response.json() == {
'status': True,
'msg': "Successfully enroll data!"
}
def test_search_local():
"""
Search the spk in MySQL by audio
"""
response = client.post(
"/vpr/recog/local?audio_path=.%2Fexample_audio%2Ftest.wav")
assert response.status_code == 200
def test_list():
"""
Get all records in MySQL
"""
response = client.get("/vpr/list")
assert response.status_code == 200
def test_data(spk: str):
"""
Get the audio file by spk_id in MySQL
"""
response = client.get("/vpr/data?spk_id=" + spk)
assert response.status_code == 200
def test_del(spk: str):
"""
Delete the record in MySQL by spk_id
"""
response = client.post("/vpr/del?spk_id=" + spk)
assert response.status_code == 200
def test_count():
"""
Get the number of spk in MySQL
"""
response = client.get("/vpr/count")
assert response.status_code == 200
if __name__ == "__main__":
download_audio_data()
test_enroll_local("spk1", "arms_strikes")
test_enroll_local("spk2", "sword_wielding")
test_enroll_local("spk3", "test")
test_list()
test_data("spk1")
test_count()
test_search_local()
test_del("spk1")
test_count()
test_search_local()
test_enroll_local("spk1", "arms_strikes")
test_count()
test_search_local()
test_drop()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import uvicorn
from config import UPLOAD_PATH
from fastapi import FastAPI
from fastapi import File
from fastapi import UploadFile
from logs import LOGGER
from mysql_helpers import MySQLHelper
from operations.count import do_count_vpr
from operations.count import do_get
from operations.count import do_list
from operations.drop import do_delete
from operations.drop import do_drop_vpr
from operations.load import do_enroll
from operations.search import do_search_vpr
from starlette.middleware.cors import CORSMiddleware
from starlette.requests import Request
from starlette.responses import FileResponse
app = FastAPI()
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"])
MYSQL_CLI = MySQLHelper()
# Mkdir 'tmp/audio-data'
if not os.path.exists(UPLOAD_PATH):
os.makedirs(UPLOAD_PATH)
LOGGER.info(f"Mkdir the path: {UPLOAD_PATH}")
@app.post('/vpr/enroll')
async def vpr_enroll(table_name: str=None,
spk_id: str=None,
audio: UploadFile=File(...)):
# Enroll the uploaded audio with spk-id into MySQL
try:
# Save the upload data to server.
content = await audio.read()
audio_path = os.path.join(UPLOAD_PATH, audio.filename)
with open(audio_path, "wb+") as f:
f.write(content)
do_enroll(table_name, spk_id, audio_path, MYSQL_CLI)
LOGGER.info(f"Successfully enrolled {spk_id} online!")
return {'status': True, 'msg': "Successfully enroll data!"}
except Exception as e:
LOGGER.error(e)
return {'status': False, 'msg': e}, 400
@app.post('/vpr/enroll/local')
async def vpr_enroll_local(table_name: str=None,
spk_id: str=None,
audio_path: str=None):
# Enroll the local audio with spk-id into MySQL
try:
do_enroll(table_name, spk_id, audio_path, MYSQL_CLI)
LOGGER.info(f"Successfully enrolled {spk_id} locally!")
return {'status': True, 'msg': "Successfully enroll data!"}
except Exception as e:
LOGGER.error(e)
return {'status': False, 'msg': e}, 400
@app.post('/vpr/recog')
async def vpr_recog(request: Request,
table_name: str=None,
audio: UploadFile=File(...)):
# Voice print recognition online
try:
# Save the upload data to server.
content = await audio.read()
query_audio_path = os.path.join(UPLOAD_PATH, audio.filename)
with open(query_audio_path, "wb+") as f:
f.write(content)
host = request.headers['host']
spk_ids, paths, scores = do_search_vpr(host, table_name,
query_audio_path, MYSQL_CLI)
for spk_id, path, score in zip(spk_ids, paths, scores):
LOGGER.info(f"spk {spk_id}, score {score}, audio path {path}, ")
res = dict(zip(spk_ids, zip(paths, scores)))
# Sort results by distance metric, closest distances first
res = sorted(res.items(), key=lambda item: item[1][1], reverse=True)
LOGGER.info("Successfully speaker recognition online!")
return res
except Exception as e:
LOGGER.error(e)
return {'status': False, 'msg': e}, 400
@app.post('/vpr/recog/local')
async def vpr_recog_local(request: Request,
table_name: str=None,
audio_path: str=None):
# Voice print recognition locally
try:
host = request.headers['host']
spk_ids, paths, scores = do_search_vpr(host, table_name, audio_path,
MYSQL_CLI)
for spk_id, path, score in zip(spk_ids, paths, scores):
LOGGER.info(f"spk {spk_id}, score {score}, audio path {path}, ")
res = dict(zip(spk_ids, zip(paths, scores)))
# Sort results by distance metric, closest distances first
res = sorted(res.items(), key=lambda item: item[1][1], reverse=True)
LOGGER.info("Successfully speaker recognition locally!")
return res
except Exception as e:
LOGGER.error(e)
return {'status': False, 'msg': e}, 400
@app.post('/vpr/del')
async def vpr_del(table_name: str=None, spk_id: str=None):
# Delete a record by spk_id in MySQL
try:
do_delete(table_name, spk_id, MYSQL_CLI)
LOGGER.info("Successfully delete a record by spk_id in MySQL")
return {'status': True, 'msg': "Successfully delete data!"}
except Exception as e:
LOGGER.error(e)
return {'status': False, 'msg': e}, 400
@app.get('/vpr/list')
async def vpr_list(table_name: str=None):
# Get all records in MySQL
try:
spk_ids, audio_paths = do_list(table_name, MYSQL_CLI)
for i in range(len(spk_ids)):
LOGGER.debug(f"spk {spk_ids[i]}, audio path {audio_paths[i]}")
LOGGER.info("Successfully list all records from mysql!")
return spk_ids, audio_paths
except Exception as e:
LOGGER.error(e)
return {'status': False, 'msg': e}, 400
@app.get('/vpr/data')
async def vpr_data(
table_name: str=None,
spk_id: str=None, ):
# Get the audio file from path by spk_id in MySQL
try:
audio_path = do_get(table_name, spk_id, MYSQL_CLI)
LOGGER.info(f"Successfully get audio path {audio_path}!")
return FileResponse(audio_path)
except Exception as e:
LOGGER.error(e)
return {'status': False, 'msg': e}, 400
@app.get('/vpr/count')
async def vpr_count(table_name: str=None):
# Get the total number of spk in MySQL
try:
num = do_count_vpr(table_name, MYSQL_CLI)
LOGGER.info("Successfully count the number of spk!")
return num
except Exception as e:
LOGGER.error(e)
return {'status': False, 'msg': e}, 400
@app.post('/vpr/drop')
async def drop_tables(table_name: str=None):
# Delete the table of MySQL
try:
do_drop_vpr(table_name, MYSQL_CLI)
LOGGER.info("Successfully drop tables in MySQL!")
return {'status': True, 'msg': "Successfully drop tables!"}
except Exception as e:
LOGGER.error(e)
return {'status': False, 'msg': e}, 400
@app.get('/data')
def audio_path(audio_path):
# Get the audio file from path
try:
LOGGER.info(f"Successfully get audio: {audio_path}")
return FileResponse(audio_path)
except Exception as e:
LOGGER.error(f"get audio error: {e}")
return {'status': False, 'msg': e}, 400
if __name__ == '__main__':
uvicorn.run(app=app, host='0.0.0.0', port=8002)
...@@ -85,6 +85,10 @@ wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav https://paddlespee ...@@ -85,6 +85,10 @@ wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav https://paddlespee
- 命令行 (推荐使用) - 命令行 (推荐使用)
``` ```
paddlespeech_client asr --server_ip 127.0.0.1 --port 8090 --input ./zh.wav paddlespeech_client asr --server_ip 127.0.0.1 --port 8090 --input ./zh.wav
# 流式ASR
paddlespeech_client asr_online --server_ip 127.0.0.1 --port 8091 --input ./zh.wav
``` ```
使用帮助: 使用帮助:
...@@ -191,7 +195,7 @@ wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav https://paddlespee ...@@ -191,7 +195,7 @@ wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav https://paddlespee
``` ```
### 5. CLS 客户端使用方法 ### 6. CLS 客户端使用方法
**注意:** 初次使用客户端时响应时间会略长 **注意:** 初次使用客户端时响应时间会略长
- 命令行 (推荐使用) - 命令行 (推荐使用)
``` ```
......
...@@ -37,7 +37,7 @@ Model Type | Dataset| Example Link | Pretrained Models|Static Models|Size (stati ...@@ -37,7 +37,7 @@ Model Type | Dataset| Example Link | Pretrained Models|Static Models|Size (stati
Tacotron2|LJSpeech|[tacotron2-ljspeech](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/ljspeech/tts0)|[tacotron2_ljspeech_ckpt_0.2.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/tacotron2/tacotron2_ljspeech_ckpt_0.2.0.zip)||| Tacotron2|LJSpeech|[tacotron2-ljspeech](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/ljspeech/tts0)|[tacotron2_ljspeech_ckpt_0.2.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/tacotron2/tacotron2_ljspeech_ckpt_0.2.0.zip)|||
Tacotron2|CSMSC|[tacotron2-csmsc](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/csmsc/tts0)|[tacotron2_csmsc_ckpt_0.2.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/tacotron2/tacotron2_csmsc_ckpt_0.2.0.zip)|[tacotron2_csmsc_static_0.2.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/tacotron2/tacotron2_csmsc_static_0.2.0.zip)|103MB| Tacotron2|CSMSC|[tacotron2-csmsc](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/csmsc/tts0)|[tacotron2_csmsc_ckpt_0.2.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/tacotron2/tacotron2_csmsc_ckpt_0.2.0.zip)|[tacotron2_csmsc_static_0.2.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/tacotron2/tacotron2_csmsc_static_0.2.0.zip)|103MB|
TransformerTTS| LJSpeech| [transformer-ljspeech](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/ljspeech/tts1)|[transformer_tts_ljspeech_ckpt_0.4.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/transformer_tts/transformer_tts_ljspeech_ckpt_0.4.zip)||| TransformerTTS| LJSpeech| [transformer-ljspeech](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/ljspeech/tts1)|[transformer_tts_ljspeech_ckpt_0.4.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/transformer_tts/transformer_tts_ljspeech_ckpt_0.4.zip)|||
SpeedySpeech| CSMSC | [speedyspeech-csmsc](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/csmsc/tts2) |[speedyspeech_nosil_baker_ckpt_0.5.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/speedyspeech/speedyspeech_nosil_baker_ckpt_0.5.zip)|[speedyspeech_csmsc_static_0.2.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/speedyspeech/speedyspeech_csmsc_static_0.2.0.zip)|12MB| SpeedySpeech| CSMSC | [speedyspeech-csmsc](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/csmsc/tts2)|[speedyspeech_csmsc_ckpt_0.2.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/speedyspeech/speedyspeech_csmsc_ckpt_0.2.0.zip)|[speedyspeech_csmsc_static_0.2.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/speedyspeech/speedyspeech_csmsc_static_0.2.0.zip)|12MB|
FastSpeech2| CSMSC |[fastspeech2-csmsc](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/csmsc/tts3)|[fastspeech2_nosil_baker_ckpt_0.4.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_nosil_baker_ckpt_0.4.zip)|[fastspeech2_csmsc_static_0.2.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_csmsc_static_0.2.0.zip)|157MB| FastSpeech2| CSMSC |[fastspeech2-csmsc](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/csmsc/tts3)|[fastspeech2_nosil_baker_ckpt_0.4.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_nosil_baker_ckpt_0.4.zip)|[fastspeech2_csmsc_static_0.2.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_csmsc_static_0.2.0.zip)|157MB|
FastSpeech2-Conformer| CSMSC |[fastspeech2-csmsc](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/csmsc/tts3)|[fastspeech2_conformer_baker_ckpt_0.5.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_conformer_baker_ckpt_0.5.zip)||| FastSpeech2-Conformer| CSMSC |[fastspeech2-csmsc](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/csmsc/tts3)|[fastspeech2_conformer_baker_ckpt_0.5.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_conformer_baker_ckpt_0.5.zip)|||
FastSpeech2| AISHELL-3 |[fastspeech2-aishell3](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/aishell3/tts3)|[fastspeech2_nosil_aishell3_ckpt_0.4.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_nosil_aishell3_ckpt_0.4.zip)||| FastSpeech2| AISHELL-3 |[fastspeech2-aishell3](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/aishell3/tts3)|[fastspeech2_nosil_aishell3_ckpt_0.4.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_nosil_aishell3_ckpt_0.4.zip)|||
......
...@@ -223,22 +223,28 @@ CUDA_VISIBLE_DEVICES=${gpus} ./local/inference.sh ${train_output_path} ...@@ -223,22 +223,28 @@ CUDA_VISIBLE_DEVICES=${gpus} ./local/inference.sh ${train_output_path}
## Pretrained Model ## Pretrained Model
Pretrained SpeedySpeech model with no silence in the edge of audios: Pretrained SpeedySpeech model with no silence in the edge of audios:
- [speedyspeech_nosil_baker_ckpt_0.5.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/speedyspeech/speedyspeech_nosil_baker_ckpt_0.5.zip) - [speedyspeech_nosil_baker_ckpt_0.5.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/speedyspeech/speedyspeech_nosil_baker_ckpt_0.5.zip)
- [speedyspeech_csmsc_ckpt_0.2.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/speedyspeech/speedyspeech_csmsc_ckpt_0.2.0.zip)
The static model can be downloaded here: The static model can be downloaded here:
- [speedyspeech_nosil_baker_static_0.5.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/speedyspeech/speedyspeech_nosil_baker_static_0.5.zip) - [speedyspeech_nosil_baker_static_0.5.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/speedyspeech/speedyspeech_nosil_baker_static_0.5.zip)
- [speedyspeech_csmsc_static_0.2.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/speedyspeech/speedyspeech_csmsc_static_0.2.0.zip) - [speedyspeech_csmsc_static_0.2.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/speedyspeech/speedyspeech_csmsc_static_0.2.0.zip)
The ONNX model can be downloaded here:
- [speedyspeech_csmsc_onnx_0.2.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/speedyspeech/speedyspeech_csmsc_onnx_0.2.0.zip)
Model | Step | eval/loss | eval/l1_loss | eval/duration_loss | eval/ssim_loss Model | Step | eval/loss | eval/l1_loss | eval/duration_loss | eval/ssim_loss
:-------------:| :------------:| :-----: | :-----: | :--------:|:--------: :-------------:| :------------:| :-----: | :-----: | :--------:|:--------:
default| 1(gpu) x 11400|0.83655|0.42324|0.03211| 0.38119 default| 1(gpu) x 11400|0.79532|0.400246|0.030259| 0.36482
SpeedySpeech checkpoint contains files listed below. SpeedySpeech checkpoint contains files listed below.
```text ```text
speedyspeech_nosil_baker_ckpt_0.5 speedyspeech_csmsc_ckpt_0.2.0
├── default.yaml # default config used to train speedyspeech ├── default.yaml # default config used to train speedyspeech
├── feats_stats.npy # statistics used to normalize spectrogram when training speedyspeech ├── feats_stats.npy # statistics used to normalize spectrogram when training speedyspeech
├── phone_id_map.txt # phone vocabulary file when training speedyspeech ├── phone_id_map.txt # phone vocabulary file when training speedyspeech
├── snapshot_iter_11400.pdz # model parameters and optimizer states ├── snapshot_iter_30600.pdz # model parameters and optimizer states
└── tone_id_map.txt # tone vocabulary file when training speedyspeech └── tone_id_map.txt # tone vocabulary file when training speedyspeech
``` ```
You can use the following scripts to synthesize for `${BIN_DIR}/../sentences.txt` using pretrained speedyspeech and parallel wavegan models. You can use the following scripts to synthesize for `${BIN_DIR}/../sentences.txt` using pretrained speedyspeech and parallel wavegan models.
...@@ -249,9 +255,9 @@ FLAGS_allocator_strategy=naive_best_fit \ ...@@ -249,9 +255,9 @@ FLAGS_allocator_strategy=naive_best_fit \
FLAGS_fraction_of_gpu_memory_to_use=0.01 \ FLAGS_fraction_of_gpu_memory_to_use=0.01 \
python3 ${BIN_DIR}/../synthesize_e2e.py \ python3 ${BIN_DIR}/../synthesize_e2e.py \
--am=speedyspeech_csmsc \ --am=speedyspeech_csmsc \
--am_config=speedyspeech_nosil_baker_ckpt_0.5/default.yaml \ --am_config=speedyspeech_csmsc_ckpt_0.2.0/default.yaml \
--am_ckpt=speedyspeech_nosil_baker_ckpt_0.5/snapshot_iter_11400.pdz \ --am_ckpt=speedyspeech_csmsc_ckpt_0.2.0/snapshot_iter_30600.pdz \
--am_stat=speedyspeech_nosil_baker_ckpt_0.5/feats_stats.npy \ --am_stat=speedyspeech_csmsc_ckpt_0.2.0/feats_stats.npy \
--voc=pwgan_csmsc \ --voc=pwgan_csmsc \
--voc_config=pwg_baker_ckpt_0.4/pwg_default.yaml \ --voc_config=pwg_baker_ckpt_0.4/pwg_default.yaml \
--voc_ckpt=pwg_baker_ckpt_0.4/pwg_snapshot_iter_400000.pdz \ --voc_ckpt=pwg_baker_ckpt_0.4/pwg_snapshot_iter_400000.pdz \
...@@ -260,6 +266,6 @@ python3 ${BIN_DIR}/../synthesize_e2e.py \ ...@@ -260,6 +266,6 @@ python3 ${BIN_DIR}/../synthesize_e2e.py \
--text=${BIN_DIR}/../sentences.txt \ --text=${BIN_DIR}/../sentences.txt \
--output_dir=exp/default/test_e2e \ --output_dir=exp/default/test_e2e \
--inference_dir=exp/default/inference \ --inference_dir=exp/default/inference \
--phones_dict=speedyspeech_nosil_baker_ckpt_0.5/phone_id_map.txt \ --phones_dict=speedyspeech_csmsc_ckpt_0.2.0/phone_id_map.txt \
--tones_dict=speedyspeech_nosil_baker_ckpt_0.5/tone_id_map.txt --tones_dict=speedyspeech_csmsc_ckpt_0.2.0/tone_id_map.txt
``` ```
train_output_path=$1
stage=0
stop_stage=0
# only support default_fastspeech2/speedyspeech + hifigan/mb_melgan now!
# synthesize from metadata
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
python3 ${BIN_DIR}/../ort_predict.py \
--inference_dir=${train_output_path}/inference_onnx \
--am=speedyspeech_csmsc \
--voc=hifigan_csmsc \
--test_metadata=dump/test/norm/metadata.jsonl \
--output_dir=${train_output_path}/onnx_infer_out \
--device=cpu \
--cpu_threads=2
fi
# e2e, synthesize from text
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
python3 ${BIN_DIR}/../ort_predict_e2e.py \
--inference_dir=${train_output_path}/inference_onnx \
--am=speedyspeech_csmsc \
--voc=hifigan_csmsc \
--output_dir=${train_output_path}/onnx_infer_out_e2e \
--text=${BIN_DIR}/../csmsc_test.txt \
--phones_dict=dump/phone_id_map.txt \
--tones_dict=dump/tone_id_map.txt \
--device=cpu \
--cpu_threads=2
fi
../../tts3/local/paddle2onnx.sh
\ No newline at end of file
...@@ -40,3 +40,25 @@ if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then ...@@ -40,3 +40,25 @@ if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
# inference with static model # inference with static model
CUDA_VISIBLE_DEVICES=${gpus} ./local/inference.sh ${train_output_path} || exit -1 CUDA_VISIBLE_DEVICES=${gpus} ./local/inference.sh ${train_output_path} || exit -1
fi fi
# paddle2onnx, please make sure the static models are in ${train_output_path}/inference first
# we have only tested the following models so far
if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
# install paddle2onnx
version=$(echo `pip list |grep "paddle2onnx"` |awk -F" " '{print $2}')
if [[ -z "$version" || ${version} != '0.9.4' ]]; then
pip install paddle2onnx==0.9.4
fi
./local/paddle2onnx.sh ${train_output_path} inference inference_onnx speedyspeech_csmsc
./local/paddle2onnx.sh ${train_output_path} inference inference_onnx hifigan_csmsc
fi
# inference with onnxruntime, use fastspeech2 + hifigan by default
if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
# install onnxruntime
version=$(echo `pip list |grep "onnxruntime"` |awk -F" " '{print $2}')
if [[ -z "$version" || ${version} != '1.10.0' ]]; then
pip install onnxruntime==1.10.0
fi
./local/ort_predict.sh ${train_output_path}
fi
...@@ -3,7 +3,7 @@ train_output_path=$1 ...@@ -3,7 +3,7 @@ train_output_path=$1
stage=0 stage=0
stop_stage=0 stop_stage=0
# only support default_fastspeech2 + hifigan/mb_melgan now! # only support default_fastspeech2/speedyspeech + hifigan/mb_melgan now!
# synthesize from metadata # synthesize from metadata
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
......
...@@ -19,4 +19,5 @@ paddle2onnx \ ...@@ -19,4 +19,5 @@ paddle2onnx \
--model_filename ${model}.pdmodel \ --model_filename ${model}.pdmodel \
--params_filename ${model}.pdiparams \ --params_filename ${model}.pdiparams \
--save_file ${train_output_path}/${output_dir}/${model}.onnx \ --save_file ${train_output_path}/${output_dir}/${model}.onnx \
--opset_version 11 \
--enable_dev_version ${enable_dev_version} --enable_dev_version ${enable_dev_version}
\ No newline at end of file
...@@ -133,6 +133,9 @@ The pretrained model can be downloaded here: ...@@ -133,6 +133,9 @@ The pretrained model can be downloaded here:
The static model can be downloaded here: The static model can be downloaded here:
- [pwg_baker_static_0.4.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwg_baker_static_0.4.zip) - [pwg_baker_static_0.4.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwg_baker_static_0.4.zip)
The ONNX model can be downloaded here:
- [pwgan_csmsc_onnx_0.2.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwgan_csmsc_onnx_0.2.0.zip)
Model | Step | eval/generator_loss | eval/log_stft_magnitude_loss| eval/spectral_convergence_loss Model | Step | eval/generator_loss | eval/log_stft_magnitude_loss| eval/spectral_convergence_loss
:-------------:| :------------:| :-----: | :-----: | :--------: :-------------:| :------------:| :-----: | :-----: | :--------:
default| 1(gpu) x 400000|1.948763|0.670098|0.248882 default| 1(gpu) x 400000|1.948763|0.670098|0.248882
......
...@@ -26,10 +26,10 @@ from paddlespeech.s2t.utils.log import Log ...@@ -26,10 +26,10 @@ from paddlespeech.s2t.utils.log import Log
#TODO(Hui Zhang): remove fluid import #TODO(Hui Zhang): remove fluid import
logger = Log(__name__).getlog() logger = Log(__name__).getlog()
########### hcak logging ############# ########### hack logging #############
logger.warn = logger.warning logger.warn = logger.warning
########### hcak paddle ############# ########### hack paddle #############
paddle.half = 'float16' paddle.half = 'float16'
paddle.float = 'float32' paddle.float = 'float32'
paddle.double = 'float64' paddle.double = 'float64'
...@@ -110,7 +110,7 @@ if not hasattr(paddle, 'cat'): ...@@ -110,7 +110,7 @@ if not hasattr(paddle, 'cat'):
paddle.cat = cat paddle.cat = cat
########### hcak paddle.Tensor ############# ########### hack paddle.Tensor #############
def item(x: paddle.Tensor): def item(x: paddle.Tensor):
return x.numpy().item() return x.numpy().item()
...@@ -353,7 +353,7 @@ if not hasattr(paddle.Tensor, 'tolist'): ...@@ -353,7 +353,7 @@ if not hasattr(paddle.Tensor, 'tolist'):
setattr(paddle.Tensor, 'tolist', tolist) setattr(paddle.Tensor, 'tolist', tolist)
########### hcak paddle.nn ############# ########### hack paddle.nn #############
class GLU(nn.Layer): class GLU(nn.Layer):
"""Gated Linear Units (GLU) Layer""" """Gated Linear Units (GLU) Layer"""
......
...@@ -341,7 +341,7 @@ def stft(x: np.ndarray, ...@@ -341,7 +341,7 @@ def stft(x: np.ndarray,
hop_length (Optional[int], optional): Number of steps to advance between adjacent windows. Defaults to None. hop_length (Optional[int], optional): Number of steps to advance between adjacent windows. Defaults to None.
win_length (Optional[int], optional): The size of window. Defaults to None. win_length (Optional[int], optional): The size of window. Defaults to None.
window (str, optional): A string of window specification. Defaults to "hann". window (str, optional): A string of window specification. Defaults to "hann".
center (bool, optional): Whether to pad `x` to make that the :math:`t \times hop\_length` at the center of `t`-th frame. Defaults to True. center (bool, optional): Whether to pad `x` to make that the :math:`t \times hop\\_length` at the center of `t`-th frame. Defaults to True.
dtype (type, optional): Data type of STFT results. Defaults to np.complex64. dtype (type, optional): Data type of STFT results. Defaults to np.complex64.
pad_mode (str, optional): Choose padding pattern when `center` is `True`. Defaults to "reflect". pad_mode (str, optional): Choose padding pattern when `center` is `True`. Defaults to "reflect".
...@@ -509,7 +509,7 @@ def melspectrogram(x: np.ndarray, ...@@ -509,7 +509,7 @@ def melspectrogram(x: np.ndarray,
fmin (float, optional): Minimum frequency in Hz. Defaults to 50.0. fmin (float, optional): Minimum frequency in Hz. Defaults to 50.0.
fmax (Optional[float], optional): Maximum frequency in Hz. Defaults to None. fmax (Optional[float], optional): Maximum frequency in Hz. Defaults to None.
window (str, optional): A string of window specification. Defaults to "hann". window (str, optional): A string of window specification. Defaults to "hann".
center (bool, optional): Whether to pad `x` to make that the :math:`t \times hop\_length` at the center of `t`-th frame. Defaults to True. center (bool, optional): Whether to pad `x` to make that the :math:`t \times hop\\_length` at the center of `t`-th frame. Defaults to True.
pad_mode (str, optional): Choose padding pattern when `center` is `True`. Defaults to "reflect". pad_mode (str, optional): Choose padding pattern when `center` is `True`. Defaults to "reflect".
power (float, optional): Exponent for the magnitude melspectrogram. Defaults to 2.0. power (float, optional): Exponent for the magnitude melspectrogram. Defaults to 2.0.
to_db (bool, optional): Enable db scale. Defaults to True. to_db (bool, optional): Enable db scale. Defaults to True.
...@@ -564,7 +564,7 @@ def spectrogram(x: np.ndarray, ...@@ -564,7 +564,7 @@ def spectrogram(x: np.ndarray,
window_size (int, optional): Size of FFT and window length. Defaults to 512. window_size (int, optional): Size of FFT and window length. Defaults to 512.
hop_length (int, optional): Number of steps to advance between adjacent windows. Defaults to 320. hop_length (int, optional): Number of steps to advance between adjacent windows. Defaults to 320.
window (str, optional): A string of window specification. Defaults to "hann". window (str, optional): A string of window specification. Defaults to "hann".
center (bool, optional): Whether to pad `x` to make that the :math:`t \times hop\_length` at the center of `t`-th frame. Defaults to True. center (bool, optional): Whether to pad `x` to make that the :math:`t \times hop\\_length` at the center of `t`-th frame. Defaults to True.
pad_mode (str, optional): Choose padding pattern when `center` is `True`. Defaults to "reflect". pad_mode (str, optional): Choose padding pattern when `center` is `True`. Defaults to "reflect".
power (float, optional): Exponent for the magnitude melspectrogram. Defaults to 2.0. power (float, optional): Exponent for the magnitude melspectrogram. Defaults to 2.0.
......
...@@ -42,7 +42,7 @@ class Spectrogram(nn.Layer): ...@@ -42,7 +42,7 @@ class Spectrogram(nn.Layer):
win_length (Optional[int], optional): The window length of the short time FFT. If `None`, it is set to same as `n_fft`. Defaults to None. win_length (Optional[int], optional): The window length of the short time FFT. If `None`, it is set to same as `n_fft`. Defaults to None.
window (str, optional): The window function applied to the signal before the Fourier transform. Supported window functions: 'hamming', 'hann', 'kaiser', 'gaussian', 'exponential', 'triang', 'bohman', 'blackman', 'cosine', 'tukey', 'taylor'. Defaults to 'hann'. window (str, optional): The window function applied to the signal before the Fourier transform. Supported window functions: 'hamming', 'hann', 'kaiser', 'gaussian', 'exponential', 'triang', 'bohman', 'blackman', 'cosine', 'tukey', 'taylor'. Defaults to 'hann'.
power (float, optional): Exponent for the magnitude spectrogram. Defaults to 2.0. power (float, optional): Exponent for the magnitude spectrogram. Defaults to 2.0.
center (bool, optional): Whether to pad `x` to make that the :math:`t \times hop\_length` at the center of `t`-th frame. Defaults to True. center (bool, optional): Whether to pad `x` to make that the :math:`t \times hop\\_length` at the center of `t`-th frame. Defaults to True.
pad_mode (str, optional): Choose padding pattern when `center` is `True`. Defaults to 'reflect'. pad_mode (str, optional): Choose padding pattern when `center` is `True`. Defaults to 'reflect'.
dtype (str, optional): Data type of input and window. Defaults to 'float32'. dtype (str, optional): Data type of input and window. Defaults to 'float32'.
""" """
...@@ -99,7 +99,7 @@ class MelSpectrogram(nn.Layer): ...@@ -99,7 +99,7 @@ class MelSpectrogram(nn.Layer):
win_length (Optional[int], optional): The window length of the short time FFT. If `None`, it is set to same as `n_fft`. Defaults to None. win_length (Optional[int], optional): The window length of the short time FFT. If `None`, it is set to same as `n_fft`. Defaults to None.
window (str, optional): The window function applied to the signal before the Fourier transform. Supported window functions: 'hamming', 'hann', 'kaiser', 'gaussian', 'exponential', 'triang', 'bohman', 'blackman', 'cosine', 'tukey', 'taylor'. Defaults to 'hann'. window (str, optional): The window function applied to the signal before the Fourier transform. Supported window functions: 'hamming', 'hann', 'kaiser', 'gaussian', 'exponential', 'triang', 'bohman', 'blackman', 'cosine', 'tukey', 'taylor'. Defaults to 'hann'.
power (float, optional): Exponent for the magnitude spectrogram. Defaults to 2.0. power (float, optional): Exponent for the magnitude spectrogram. Defaults to 2.0.
center (bool, optional): Whether to pad `x` to make that the :math:`t \times hop\_length` at the center of `t`-th frame. Defaults to True. center (bool, optional): Whether to pad `x` to make that the :math:`t \times hop\\_length` at the center of `t`-th frame. Defaults to True.
pad_mode (str, optional): Choose padding pattern when `center` is `True`. Defaults to 'reflect'. pad_mode (str, optional): Choose padding pattern when `center` is `True`. Defaults to 'reflect'.
n_mels (int, optional): Number of mel bins. Defaults to 64. n_mels (int, optional): Number of mel bins. Defaults to 64.
f_min (float, optional): Minimum frequency in Hz. Defaults to 50.0. f_min (float, optional): Minimum frequency in Hz. Defaults to 50.0.
...@@ -176,7 +176,7 @@ class LogMelSpectrogram(nn.Layer): ...@@ -176,7 +176,7 @@ class LogMelSpectrogram(nn.Layer):
win_length (Optional[int], optional): The window length of the short time FFT. If `None`, it is set to same as `n_fft`. Defaults to None. win_length (Optional[int], optional): The window length of the short time FFT. If `None`, it is set to same as `n_fft`. Defaults to None.
window (str, optional): The window function applied to the signal before the Fourier transform. Supported window functions: 'hamming', 'hann', 'kaiser', 'gaussian', 'exponential', 'triang', 'bohman', 'blackman', 'cosine', 'tukey', 'taylor'. Defaults to 'hann'. window (str, optional): The window function applied to the signal before the Fourier transform. Supported window functions: 'hamming', 'hann', 'kaiser', 'gaussian', 'exponential', 'triang', 'bohman', 'blackman', 'cosine', 'tukey', 'taylor'. Defaults to 'hann'.
power (float, optional): Exponent for the magnitude spectrogram. Defaults to 2.0. power (float, optional): Exponent for the magnitude spectrogram. Defaults to 2.0.
center (bool, optional): Whether to pad `x` to make that the :math:`t \times hop\_length` at the center of `t`-th frame. Defaults to True. center (bool, optional): Whether to pad `x` to make that the :math:`t \times hop\\_length` at the center of `t`-th frame. Defaults to True.
pad_mode (str, optional): Choose padding pattern when `center` is `True`. Defaults to 'reflect'. pad_mode (str, optional): Choose padding pattern when `center` is `True`. Defaults to 'reflect'.
n_mels (int, optional): Number of mel bins. Defaults to 64. n_mels (int, optional): Number of mel bins. Defaults to 64.
f_min (float, optional): Minimum frequency in Hz. Defaults to 50.0. f_min (float, optional): Minimum frequency in Hz. Defaults to 50.0.
...@@ -257,7 +257,7 @@ class MFCC(nn.Layer): ...@@ -257,7 +257,7 @@ class MFCC(nn.Layer):
win_length (Optional[int], optional): The window length of the short time FFT. If `None`, it is set to same as `n_fft`. Defaults to None. win_length (Optional[int], optional): The window length of the short time FFT. If `None`, it is set to same as `n_fft`. Defaults to None.
window (str, optional): The window function applied to the signal before the Fourier transform. Supported window functions: 'hamming', 'hann', 'kaiser', 'gaussian', 'exponential', 'triang', 'bohman', 'blackman', 'cosine', 'tukey', 'taylor'. Defaults to 'hann'. window (str, optional): The window function applied to the signal before the Fourier transform. Supported window functions: 'hamming', 'hann', 'kaiser', 'gaussian', 'exponential', 'triang', 'bohman', 'blackman', 'cosine', 'tukey', 'taylor'. Defaults to 'hann'.
power (float, optional): Exponent for the magnitude spectrogram. Defaults to 2.0. power (float, optional): Exponent for the magnitude spectrogram. Defaults to 2.0.
center (bool, optional): Whether to pad `x` to make that the :math:`t \times hop\_length` at the center of `t`-th frame. Defaults to True. center (bool, optional): Whether to pad `x` to make that the :math:`t \times hop\\_length` at the center of `t`-th frame. Defaults to True.
pad_mode (str, optional): Choose padding pattern when `center` is `True`. Defaults to 'reflect'. pad_mode (str, optional): Choose padding pattern when `center` is `True`. Defaults to 'reflect'.
n_mels (int, optional): Number of mel bins. Defaults to 64. n_mels (int, optional): Number of mel bins. Defaults to 64.
f_min (float, optional): Minimum frequency in Hz. Defaults to 50.0. f_min (float, optional): Minimum frequency in Hz. Defaults to 50.0.
......
...@@ -43,13 +43,13 @@ pretrained_models = { ...@@ -43,13 +43,13 @@ pretrained_models = {
# speedyspeech # speedyspeech
"speedyspeech_csmsc-zh": { "speedyspeech_csmsc-zh": {
'url': 'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/speedyspeech/speedyspeech_nosil_baker_ckpt_0.5.zip', 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/speedyspeech/speedyspeech_csmsc_ckpt_0.2.0.zip',
'md5': 'md5':
'9edce23b1a87f31b814d9477bf52afbc', '6f6fa967b408454b6662c8c00c0027cb',
'config': 'config':
'default.yaml', 'default.yaml',
'ckpt': 'ckpt':
'snapshot_iter_11400.pdz', 'snapshot_iter_30600.pdz',
'speech_stats': 'speech_stats':
'feats_stats.npy', 'feats_stats.npy',
'phones_dict': 'phones_dict':
......
...@@ -26,10 +26,10 @@ from paddlespeech.s2t.utils.log import Log ...@@ -26,10 +26,10 @@ from paddlespeech.s2t.utils.log import Log
#TODO(Hui Zhang): remove fluid import #TODO(Hui Zhang): remove fluid import
logger = Log(__name__).getlog() logger = Log(__name__).getlog()
########### hcak logging ############# ########### hack logging #############
logger.warn = logger.warning logger.warn = logger.warning
########### hcak paddle ############# ########### hack paddle #############
paddle.half = 'float16' paddle.half = 'float16'
paddle.float = 'float32' paddle.float = 'float32'
paddle.double = 'float64' paddle.double = 'float64'
...@@ -110,7 +110,7 @@ if not hasattr(paddle, 'cat'): ...@@ -110,7 +110,7 @@ if not hasattr(paddle, 'cat'):
paddle.cat = cat paddle.cat = cat
########### hcak paddle.Tensor ############# ########### hack paddle.Tensor #############
def item(x: paddle.Tensor): def item(x: paddle.Tensor):
return x.numpy().item() return x.numpy().item()
......
...@@ -79,7 +79,6 @@ class U2Infer(): ...@@ -79,7 +79,6 @@ class U2Infer():
ilen = paddle.to_tensor(feat.shape[0]) ilen = paddle.to_tensor(feat.shape[0])
xs = paddle.to_tensor(feat, dtype='float32').unsqueeze(axis=0) xs = paddle.to_tensor(feat, dtype='float32').unsqueeze(axis=0)
decode_config = self.config.decode decode_config = self.config.decode
result_transcripts = self.model.decode( result_transcripts = self.model.decode(
xs, xs,
...@@ -129,6 +128,7 @@ if __name__ == "__main__": ...@@ -129,6 +128,7 @@ if __name__ == "__main__":
args = parser.parse_args() args = parser.parse_args()
config = CfgNode(new_allowed=True) config = CfgNode(new_allowed=True)
if args.config: if args.config:
config.merge_from_file(args.config) config.merge_from_file(args.config)
if args.decode_cfg: if args.decode_cfg:
......
...@@ -12,9 +12,11 @@ ...@@ -12,9 +12,11 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import argparse import argparse
import asyncio
import base64 import base64
import io import io
import json import json
import logging
import os import os
import random import random
import time import time
...@@ -28,6 +30,7 @@ from ..executor import BaseExecutor ...@@ -28,6 +30,7 @@ from ..executor import BaseExecutor
from ..util import cli_client_register from ..util import cli_client_register
from ..util import stats_wrapper from ..util import stats_wrapper
from paddlespeech.cli.log import logger from paddlespeech.cli.log import logger
from paddlespeech.server.tests.asr.online.websocket_client import ASRAudioHandler
from paddlespeech.server.utils.audio_process import wav2pcm from paddlespeech.server.utils.audio_process import wav2pcm
from paddlespeech.server.utils.util import wav2base64 from paddlespeech.server.utils.util import wav2base64
...@@ -230,6 +233,76 @@ class ASRClientExecutor(BaseExecutor): ...@@ -230,6 +233,76 @@ class ASRClientExecutor(BaseExecutor):
return res return res
@cli_client_register(
name='paddlespeech_client.asr_online',
description='visit asr online service')
class ASRClientExecutor(BaseExecutor):
def __init__(self):
super(ASRClientExecutor, self).__init__()
self.parser = argparse.ArgumentParser(
prog='paddlespeech_client.asr', add_help=True)
self.parser.add_argument(
'--server_ip', type=str, default='127.0.0.1', help='server ip')
self.parser.add_argument(
'--port', type=int, default=8091, help='server port')
self.parser.add_argument(
'--input',
type=str,
default=None,
help='Audio file to be recognized',
required=True)
self.parser.add_argument(
'--sample_rate', type=int, default=16000, help='audio sample rate')
self.parser.add_argument(
'--lang', type=str, default="zh_cn", help='language')
self.parser.add_argument(
'--audio_format', type=str, default="wav", help='audio format')
def execute(self, argv: List[str]) -> bool:
args = self.parser.parse_args(argv)
input_ = args.input
server_ip = args.server_ip
port = args.port
sample_rate = args.sample_rate
lang = args.lang
audio_format = args.audio_format
try:
time_start = time.time()
res = self(
input=input_,
server_ip=server_ip,
port=port,
sample_rate=sample_rate,
lang=lang,
audio_format=audio_format)
time_end = time.time()
logger.info(res.json())
logger.info("Response time %f s." % (time_end - time_start))
return True
except Exception as e:
logger.error("Failed to speech recognition.")
return False
@stats_wrapper
def __call__(self,
input: str,
server_ip: str="127.0.0.1",
port: int=8091,
sample_rate: int=16000,
lang: str="zh_cn",
audio_format: str="wav"):
"""
Python API to call an executor.
"""
logging.basicConfig(level=logging.INFO)
logging.info("asr websocket client start")
handler = ASRAudioHandler(server_ip, port)
loop = asyncio.get_event_loop()
loop.run_until_complete(handler.run(input))
logging.info("asr websocket client finished")
@cli_client_register( @cli_client_register(
name='paddlespeech_client.cls', description='visit cls service') name='paddlespeech_client.cls', description='visit cls service')
class CLSClientExecutor(BaseExecutor): class CLSClientExecutor(BaseExecutor):
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
# SERVER SETTING # # SERVER SETTING #
################################################################################# #################################################################################
host: 0.0.0.0 host: 0.0.0.0
port: 8096 port: 8090
# The task format in the engin_list is: <speech task>_<engine type> # The task format in the engin_list is: <speech task>_<engine type>
# task choices = ['asr_online', 'tts_online'] # task choices = ['asr_online', 'tts_online']
......
([简体中文](./README_cn.md)|English)
# 语音服务
## 介绍
本文档介绍如何使用流式ASR的三种不同客户端:网页、麦克风、Python模拟流式服务。
## 使用方法
### 1. 安装
请看 [安装文档](https://github.com/PaddlePaddle/PaddleSpeech/blob/develop/docs/source/install.md).
推荐使用 **paddlepaddle 2.2.1** 或以上版本。
你可以从 medium,hard 三中方式中选择一种方式安装 PaddleSpeech。
### 2. 准备测试文件
这个 ASR client 的输入应该是一个 WAV 文件(`.wav`),并且采样率必须与模型的采样率相同。
可以下载此 ASR client的示例音频:
```bash
wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav https://paddlespeech.bj.bcebos.com/PaddleAudio/en.wav
```
### 2. 流式 ASR 客户端使用方法
- Python模拟流式服务命令行
```
# 流式ASR
paddlespeech_client asr_online --server_ip 127.0.0.1 --port 8091 --input ./zh.wav
```
- 麦克风
```
# 直接调用麦克风设备
python microphone_client.py
```
- 网页
```
# 进入web目录后参考相关readme.md
```
#!/usr/bin/env python3 #!/usr/bin/env python3
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2021 Mobvoi Inc. All Rights Reserved. # Copyright 2021 Mobvoi Inc. All Rights Reserved.
# Author: zhendong.peng@mobvoi.com (Zhendong Peng) # Author: zhendong.peng@mobvoi.com (Zhendong Peng)
import argparse import argparse
from flask import Flask, render_template from flask import Flask
from flask import render_template
parser = argparse.ArgumentParser(description='training your network') parser = argparse.ArgumentParser(description='training your network')
parser.add_argument('--port', default=19999, type=int, help='port id') parser.add_argument('--port', default=19999, type=int, help='port id')
...@@ -14,9 +13,11 @@ args = parser.parse_args() ...@@ -14,9 +13,11 @@ args = parser.parse_args()
app = Flask(__name__) app = Flask(__name__)
@app.route('/') @app.route('/')
def index(): def index():
return render_template('index.html') return render_template('index.html')
if __name__ == '__main__': if __name__ == '__main__':
app.run(host='0.0.0.0', port=args.port, debug=True) app.run(host='0.0.0.0', port=args.port, debug=True)
# paddlespeech serving 网页Demo
- 感谢[wenet](https://github.com/wenet-e2e/wenet)团队的前端demo代码.
## 使用方法
### 1. 在本地电脑启动网页服务
```
python app.py
```
### 2. 本地电脑浏览器
在浏览器中输入127.0.0.1:19999 即可看到相关网页Demo。
![图片](./paddle_web_demo.png)
...@@ -34,22 +34,23 @@ class ASRAudioHandler: ...@@ -34,22 +34,23 @@ class ASRAudioHandler:
def read_wave(self, wavfile_path: str): def read_wave(self, wavfile_path: str):
samples, sample_rate = soundfile.read(wavfile_path, dtype='int16') samples, sample_rate = soundfile.read(wavfile_path, dtype='int16')
x_len = len(samples) x_len = len(samples)
chunk_stride = 40 * 16 #40ms, sample_rate = 16kHz # chunk_stride = 40 * 16 #40ms, sample_rate = 16kHz
chunk_size = 80 * 16 #80ms, sample_rate = 16kHz chunk_size = 80 * 16 #80ms, sample_rate = 16kHz
if (x_len - chunk_size) % chunk_stride != 0: if x_len % chunk_size != 0:
padding_len_x = chunk_stride - (x_len - chunk_size) % chunk_stride padding_len_x = chunk_size - x_len % chunk_size
else: else:
padding_len_x = 0 padding_len_x = 0
padding = np.zeros((padding_len_x), dtype=samples.dtype) padding = np.zeros((padding_len_x), dtype=samples.dtype)
padded_x = np.concatenate([samples, padding], axis=0) padded_x = np.concatenate([samples, padding], axis=0)
num_chunk = (x_len + padding_len_x - chunk_size) / chunk_stride + 1 assert (x_len + padding_len_x) % chunk_size == 0
num_chunk = (x_len + padding_len_x) / chunk_size
num_chunk = int(num_chunk) num_chunk = int(num_chunk)
for i in range(0, num_chunk): for i in range(0, num_chunk):
start = i * chunk_stride start = i * chunk_size
end = start + chunk_size end = start + chunk_size
x_chunk = padded_x[start:end] x_chunk = padded_x[start:end]
yield x_chunk yield x_chunk
...@@ -80,6 +81,7 @@ class ASRAudioHandler: ...@@ -80,6 +81,7 @@ class ASRAudioHandler:
msg = await ws.recv() msg = await ws.recv()
msg = json.loads(msg) msg = json.loads(msg)
logging.info("receive msg={}".format(msg)) logging.info("receive msg={}".format(msg))
result = msg result = msg
# finished # finished
audio_info = json.dumps( audio_info = json.dumps(
...@@ -93,6 +95,7 @@ class ASRAudioHandler: ...@@ -93,6 +95,7 @@ class ASRAudioHandler:
separators=(',', ': ')) separators=(',', ': '))
await ws.send(audio_info) await ws.send(audio_info)
msg = await ws.recv() msg = await ws.recv()
# decode the bytes to str # decode the bytes to str
msg = json.loads(msg) msg = json.loads(msg)
logging.info("receive msg={}".format(msg)) logging.info("receive msg={}".format(msg))
...@@ -103,7 +106,7 @@ class ASRAudioHandler: ...@@ -103,7 +106,7 @@ class ASRAudioHandler:
def main(args): def main(args):
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
logging.info("asr websocket client start") logging.info("asr websocket client start")
handler = ASRAudioHandler("127.0.0.1", 8096) handler = ASRAudioHandler("127.0.0.1", 8090)
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
# support to process single audio file # support to process single audio file
......
...@@ -24,15 +24,38 @@ class Frame(object): ...@@ -24,15 +24,38 @@ class Frame(object):
class ChunkBuffer(object): class ChunkBuffer(object):
def __init__(self, def __init__(self,
frame_duration_ms=80, window_n=7,
shift_ms=40, shift_n=4,
window_ms=20,
shift_ms=10,
sample_rate=16000, sample_rate=16000,
sample_width=2): sample_width=2):
self.sample_rate = sample_rate """audio sample data point buffer
self.frame_duration_ms = frame_duration_ms
Args:
window_n (int, optional): decode window frame length. Defaults to 7 frame.
shift_n (int, optional): decode shift frame length. Defaults to 4 frame.
window_ms (int, optional): frame length, ms. Defaults to 20 ms.
shift_ms (int, optional): shift length, ms. Defaults to 10 ms.
sample_rate (int, optional): audio sample rate. Defaults to 16000.
sample_width (int, optional): sample point bytes. Defaults to 2 bytes.
"""
self.window_n = window_n
self.shift_n = shift_n
self.window_ms = window_ms
self.shift_ms = shift_ms self.shift_ms = shift_ms
self.remained_audio = b'' self.sample_rate = sample_rate
self.sample_width = sample_width # int16 = 2; float32 = 4 self.sample_width = sample_width # int16 = 2; float32 = 4
self.remained_audio = b''
self.window_sec = float((self.window_n - 1) * self.shift_ms +
self.window_ms) / 1000.0
self.shift_sec = float(self.shift_n * self.shift_ms / 1000.0)
self.window_bytes = int(self.window_sec * self.sample_rate *
self.sample_width)
self.shift_bytes = int(self.shift_sec * self.sample_rate *
self.sample_width)
def frame_generator(self, audio): def frame_generator(self, audio):
"""Generates audio frames from PCM audio data. """Generates audio frames from PCM audio data.
...@@ -43,17 +66,13 @@ class ChunkBuffer(object): ...@@ -43,17 +66,13 @@ class ChunkBuffer(object):
audio = self.remained_audio + audio audio = self.remained_audio + audio
self.remained_audio = b'' self.remained_audio = b''
n = int(self.sample_rate * (self.frame_duration_ms / 1000.0) *
self.sample_width)
shift_n = int(self.sample_rate * (self.shift_ms / 1000.0) *
self.sample_width)
offset = 0 offset = 0
timestamp = 0.0 timestamp = 0.0
duration = (float(n) / self.sample_rate) / self.sample_width
shift_duration = (float(shift_n) / self.sample_rate) / self.sample_width while offset + self.window_bytes <= len(audio):
while offset + n <= len(audio): yield Frame(audio[offset:offset + self.window_bytes], timestamp,
yield Frame(audio[offset:offset + n], timestamp, duration) self.window_sec)
timestamp += shift_duration timestamp += self.shift_sec
offset += shift_n offset += self.shift_bytes
self.remained_audio += audio[offset:] self.remained_audio += audio[offset:]
...@@ -37,7 +37,10 @@ async def websocket_endpoint(websocket: WebSocket): ...@@ -37,7 +37,10 @@ async def websocket_endpoint(websocket: WebSocket):
# init buffer # init buffer
chunk_buffer_conf = asr_engine.config.chunk_buffer_conf chunk_buffer_conf = asr_engine.config.chunk_buffer_conf
chunk_buffer = ChunkBuffer( chunk_buffer = ChunkBuffer(
frame_duration_ms=chunk_buffer_conf['frame_duration_ms'], window_n=7,
shift_n=4,
window_ms=20,
shift_ms=10,
sample_rate=chunk_buffer_conf['sample_rate'], sample_rate=chunk_buffer_conf['sample_rate'],
sample_width=chunk_buffer_conf['sample_width']) sample_width=chunk_buffer_conf['sample_width'])
# init vad # init vad
...@@ -80,10 +83,6 @@ async def websocket_endpoint(websocket: WebSocket): ...@@ -80,10 +83,6 @@ async def websocket_endpoint(websocket: WebSocket):
elif "bytes" in message: elif "bytes" in message:
message = message["bytes"] message = message["bytes"]
# # vad for input bytes audio
# vad.add_audio(message)
# message = b''.join(f for f in vad.vad_collector()
# if f is not None)
engine_pool = get_engine_pool() engine_pool = get_engine_pool()
asr_engine = engine_pool['asr'] asr_engine = engine_pool['asr']
asr_results = "" asr_results = ""
......
...@@ -38,8 +38,6 @@ def get_predictor(args, filed='am'): ...@@ -38,8 +38,6 @@ def get_predictor(args, filed='am'):
config.enable_use_gpu(100, 0) config.enable_use_gpu(100, 0)
elif args.device == "cpu": elif args.device == "cpu":
config.disable_gpu() config.disable_gpu()
# This line must be commented for fastspeech2, if not, it will OOM
if model_name != 'fastspeech2':
config.enable_memory_optim() config.enable_memory_optim()
predictor = inference.create_predictor(config) predictor = inference.create_predictor(config)
return predictor return predictor
......
...@@ -70,8 +70,15 @@ def ort_predict(args): ...@@ -70,8 +70,15 @@ def ort_predict(args):
# am warmup # am warmup
for T in [27, 38, 54]: for T in [27, 38, 54]:
data = np.random.randint(1, 266, size=(T, )) am_input_feed = {}
am_sess.run(None, {"text": data}) if am_name == 'fastspeech2':
phone_ids = np.random.randint(1, 266, size=(T, ))
am_input_feed.update({'text': phone_ids})
elif am_name == 'speedyspeech':
phone_ids = np.random.randint(1, 92, size=(T, ))
tone_ids = np.random.randint(1, 5, size=(T, ))
am_input_feed.update({'phones': phone_ids, 'tones': tone_ids})
am_sess.run(None, input_feed=am_input_feed)
# voc warmup # voc warmup
for T in [227, 308, 544]: for T in [227, 308, 544]:
...@@ -81,14 +88,20 @@ def ort_predict(args): ...@@ -81,14 +88,20 @@ def ort_predict(args):
N = 0 N = 0
T = 0 T = 0
am_input_feed = {}
for example in test_dataset: for example in test_dataset:
utt_id = example['utt_id'] utt_id = example['utt_id']
if am_name == 'fastspeech2':
phone_ids = example["text"] phone_ids = example["text"]
am_input_feed.update({'text': phone_ids})
elif am_name == 'speedyspeech':
phone_ids = example["phones"]
tone_ids = example["tones"]
am_input_feed.update({'phones': phone_ids, 'tones': tone_ids})
with timer() as t: with timer() as t:
mel = am_sess.run(output_names=None, input_feed={'text': phone_ids}) mel = am_sess.run(output_names=None, input_feed=am_input_feed)
mel = mel[0] mel = mel[0]
wav = voc_sess.run(output_names=None, input_feed={'logmel': mel}) wav = voc_sess.run(output_names=None, input_feed={'logmel': mel})
N += len(wav[0]) N += len(wav[0])
T += t.elapse T += t.elapse
speed = len(wav[0]) / t.elapse speed = len(wav[0]) / t.elapse
...@@ -110,9 +123,7 @@ def parse_args(): ...@@ -110,9 +123,7 @@ def parse_args():
'--am', '--am',
type=str, type=str,
default='fastspeech2_csmsc', default='fastspeech2_csmsc',
choices=[ choices=['fastspeech2_csmsc', 'speedyspeech_csmsc'],
'fastspeech2_csmsc',
],
help='Choose acoustic model type of tts task.') help='Choose acoustic model type of tts task.')
# voc # voc
......
...@@ -68,39 +68,58 @@ def ort_predict(args): ...@@ -68,39 +68,58 @@ def ort_predict(args):
# vocoder # vocoder
voc_sess = get_sess(args, filed='voc') voc_sess = get_sess(args, filed='voc')
# frontend warmup
# Loading model cost 0.5+ seconds
if args.lang == 'zh':
frontend.get_input_ids("你好,欢迎使用飞桨框架进行深度学习研究!", merge_sentences=True)
else:
print("lang should in be 'zh' here!")
# am warmup # am warmup
for T in [27, 38, 54]: for T in [27, 38, 54]:
data = np.random.randint(1, 266, size=(T, )) am_input_feed = {}
am_sess.run(None, {"text": data}) if am_name == 'fastspeech2':
phone_ids = np.random.randint(1, 266, size=(T, ))
am_input_feed.update({'text': phone_ids})
elif am_name == 'speedyspeech':
phone_ids = np.random.randint(1, 92, size=(T, ))
tone_ids = np.random.randint(1, 5, size=(T, ))
am_input_feed.update({'phones': phone_ids, 'tones': tone_ids})
am_sess.run(None, input_feed=am_input_feed)
# voc warmup # voc warmup
for T in [227, 308, 544]: for T in [227, 308, 544]:
data = np.random.rand(T, 80).astype("float32") data = np.random.rand(T, 80).astype("float32")
voc_sess.run(None, {"logmel": data}) voc_sess.run(None, input_feed={"logmel": data})
print("warm up done!") print("warm up done!")
# frontend warmup
# Loading model cost 0.5+ seconds
if args.lang == 'zh':
frontend.get_input_ids("你好,欢迎使用飞桨框架进行深度学习研究!", merge_sentences=True)
else:
print("lang should in be 'zh' here!")
N = 0 N = 0
T = 0 T = 0
merge_sentences = True merge_sentences = True
get_tone_ids = False
am_input_feed = {}
if am_name == 'speedyspeech':
get_tone_ids = True
for utt_id, sentence in sentences: for utt_id, sentence in sentences:
with timer() as t: with timer() as t:
if args.lang == 'zh': if args.lang == 'zh':
input_ids = frontend.get_input_ids( input_ids = frontend.get_input_ids(
sentence, merge_sentences=merge_sentences) sentence,
merge_sentences=merge_sentences,
get_tone_ids=get_tone_ids)
phone_ids = input_ids["phone_ids"] phone_ids = input_ids["phone_ids"]
if get_tone_ids:
tone_ids = input_ids["tone_ids"]
else: else:
print("lang should in be 'zh' here!") print("lang should in be 'zh' here!")
# merge_sentences=True here, so we only use the first item of phone_ids # merge_sentences=True here, so we only use the first item of phone_ids
phone_ids = phone_ids[0].numpy() phone_ids = phone_ids[0].numpy()
mel = am_sess.run(output_names=None, input_feed={'text': phone_ids}) if am_name == 'fastspeech2':
am_input_feed.update({'text': phone_ids})
elif am_name == 'speedyspeech':
tone_ids = tone_ids[0].numpy()
am_input_feed.update({'phones': phone_ids, 'tones': tone_ids})
mel = am_sess.run(output_names=None, input_feed=am_input_feed)
mel = mel[0] mel = mel[0]
wav = voc_sess.run(output_names=None, input_feed={'logmel': mel}) wav = voc_sess.run(output_names=None, input_feed={'logmel': mel})
...@@ -125,9 +144,7 @@ def parse_args(): ...@@ -125,9 +144,7 @@ def parse_args():
'--am', '--am',
type=str, type=str,
default='fastspeech2_csmsc', default='fastspeech2_csmsc',
choices=[ choices=['fastspeech2_csmsc', 'speedyspeech_csmsc'],
'fastspeech2_csmsc',
],
help='Choose acoustic model type of tts task.') help='Choose acoustic model type of tts task.')
parser.add_argument( parser.add_argument(
"--phones_dict", type=str, default=None, help="phone vocabulary file.") "--phones_dict", type=str, default=None, help="phone vocabulary file.")
......
...@@ -68,13 +68,15 @@ def evaluate(args): ...@@ -68,13 +68,15 @@ def evaluate(args):
# but still not stopping in the end (NOTE by yuantian01 Feb 9 2022) # but still not stopping in the end (NOTE by yuantian01 Feb 9 2022)
if am_name == 'tacotron2': if am_name == 'tacotron2':
merge_sentences = True merge_sentences = True
get_tone_ids = False
if am_name == 'speedyspeech':
get_tone_ids = True
N = 0 N = 0
T = 0 T = 0
for utt_id, sentence in sentences: for utt_id, sentence in sentences:
with timer() as t: with timer() as t:
get_tone_ids = False
if am_name == 'speedyspeech':
get_tone_ids = True
if args.lang == 'zh': if args.lang == 'zh':
input_ids = frontend.get_input_ids( input_ids = frontend.get_input_ids(
sentence, sentence,
......
...@@ -667,8 +667,8 @@ class FastSpeech2(nn.Layer): ...@@ -667,8 +667,8 @@ class FastSpeech2(nn.Layer):
use_teacher_forcing(bool, optional): Whether to use teacher forcing. use_teacher_forcing(bool, optional): Whether to use teacher forcing.
If true, groundtruth of duration, pitch and energy will be used. If true, groundtruth of duration, pitch and energy will be used.
spk_emb(Tensor, optional, optional): peaker embedding vector (spk_embed_dim,). (Default value = None) spk_emb(Tensor, optional, optional): peaker embedding vector (spk_embed_dim,). (Default value = None)
spk_id(Tensor, optional(int64), optional): Batch of padded spk ids (1,). (Default value = None) spk_id(Tensor, optional(int64), optional): spk ids (1,). (Default value = None)
tone_id(Tensor, optional(int64), optional): Batch of padded tone ids (T,). (Default value = None) tone_id(Tensor, optional(int64), optional): tone ids (T,). (Default value = None)
Returns: Returns:
...@@ -751,7 +751,6 @@ class FastSpeech2(nn.Layer): ...@@ -751,7 +751,6 @@ class FastSpeech2(nn.Layer):
Returns: Returns:
""" """
if self.tone_embed_integration_type == "add": if self.tone_embed_integration_type == "add":
# apply projection and then add to hidden states # apply projection and then add to hidden states
......
...@@ -11,17 +11,35 @@ ...@@ -11,17 +11,35 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import List
import paddle import paddle
from paddle import nn from paddle import nn
from paddlespeech.t2s.modules.nets_utils import initialize from paddlespeech.t2s.modules.nets_utils import initialize
from paddlespeech.t2s.modules.positional_encoding import sinusoid_position_encoding
from paddlespeech.t2s.modules.predictor.length_regulator import LengthRegulator from paddlespeech.t2s.modules.predictor.length_regulator import LengthRegulator
from paddlespeech.t2s.modules.transformer.embedding import ScaledPositionalEncoding
class ResidualBlock(nn.Layer): class ResidualBlock(nn.Layer):
def __init__(self, channels, kernel_size, dilation, n=2): def __init__(self,
channels: int=128,
kernel_size: int=3,
dilation: int=3,
n: int=2):
"""SpeedySpeech encoder module.
Args:
channels (int, optional): Feature size of the residual output(and also the input).
kernel_size (int, optional): Kernel size of the 1D convolution.
dilation (int, optional): Dilation of the 1D convolution.
n (int): Number of blocks.
"""
super().__init__() super().__init__()
total_pad = (dilation * (kernel_size - 1))
begin = total_pad // 2
end = total_pad - begin
# remove padding='same' here, cause onnx don't support dilation + 'same' padding
blocks = [ blocks = [
nn.Sequential( nn.Sequential(
nn.Conv1D( nn.Conv1D(
...@@ -29,14 +47,20 @@ class ResidualBlock(nn.Layer): ...@@ -29,14 +47,20 @@ class ResidualBlock(nn.Layer):
channels, channels,
kernel_size, kernel_size,
dilation=dilation, dilation=dilation,
padding="same", # make sure output T == input T
data_format="NLC"), padding=((0, 0), (0, 0), (begin, end))),
nn.ReLU(), nn.ReLU(),
nn.BatchNorm1D(channels, data_format="NLC"), ) for _ in range(n) nn.BatchNorm1D(channels), ) for _ in range(n)
] ]
self.blocks = nn.Sequential(*blocks) self.blocks = nn.Sequential(*blocks)
def forward(self, x): def forward(self, x: paddle.Tensor):
"""Calculate forward propagation.
Args:
x(Tensor): Batch of input sequences (B, hidden_size, Tmax).
Returns:
Tensor: The residual output (B, hidden_size, Tmax).
"""
return x + self.blocks(x) return x + self.blocks(x)
...@@ -62,7 +86,15 @@ class TextEmbedding(nn.Layer): ...@@ -62,7 +86,15 @@ class TextEmbedding(nn.Layer):
tone_vocab_size, tone_embedding_size, tone_padding_idx) tone_vocab_size, tone_embedding_size, tone_padding_idx)
self.concat = concat self.concat = concat
def forward(self, text, tone=None): def forward(self, text: paddle.Tensor, tone: paddle.Tensor=None):
"""Calculate forward propagation.
Args:
text(Tensor(int64)): Batch of padded token ids (B, Tmax).
tones(Tensor, optional(int64)): Batch of padded tone ids (B, Tmax).
Returns:
Tensor: The residual output (B, Tmax, embedding_size).
"""
text_embed = self.text_embedding(text) text_embed = self.text_embedding(text)
if tone is None: if tone is None:
return text_embed return text_embed
...@@ -75,13 +107,24 @@ class TextEmbedding(nn.Layer): ...@@ -75,13 +107,24 @@ class TextEmbedding(nn.Layer):
class SpeedySpeechEncoder(nn.Layer): class SpeedySpeechEncoder(nn.Layer):
"""SpeedySpeech encoder module.
Args:
vocab_size (int): Dimension of the inputs.
tone_size (Optional[int]): Number of tones.
hidden_size (int): Number of encoder hidden units.
kernel_size (int): Kernel size of encoder.
dilations (List[int]): Dilations of encoder.
spk_num (Optional[int]): Number of speakers.
"""
def __init__(self, def __init__(self,
vocab_size, vocab_size: int,
tone_size, tone_size: int,
hidden_size, hidden_size: int=128,
kernel_size, kernel_size: int=3,
dilations, dilations: List[int]=[1, 3, 9, 27, 1, 3, 9, 27, 1, 1],
spk_num=None): spk_num=None):
super().__init__() super().__init__()
self.embedding = TextEmbedding( self.embedding = TextEmbedding(
vocab_size, vocab_size,
...@@ -109,34 +152,71 @@ class SpeedySpeechEncoder(nn.Layer): ...@@ -109,34 +152,71 @@ class SpeedySpeechEncoder(nn.Layer):
self.postnet1 = nn.Sequential(nn.Linear(hidden_size, hidden_size)) self.postnet1 = nn.Sequential(nn.Linear(hidden_size, hidden_size))
self.postnet2 = nn.Sequential( self.postnet2 = nn.Sequential(
nn.ReLU(), nn.ReLU(),
nn.BatchNorm1D(hidden_size, data_format="NLC"), nn.BatchNorm1D(hidden_size), )
nn.Linear(hidden_size, hidden_size), ) self.linear = nn.Linear(hidden_size, hidden_size)
def forward(self, text, tones, spk_id=None): def forward(self,
text: paddle.Tensor,
tones: paddle.Tensor,
spk_id: paddle.Tensor=None):
"""Encoder input sequence.
Args:
text(Tensor(int64)): Batch of padded token ids (B, Tmax).
tones(Tensor, optional(int64)): Batch of padded tone ids (B, Tmax).
spk_id(Tnesor, optional(int64)): Batch of speaker ids (B,)
Returns:
Tensor: Output tensor (B, Tmax, hidden_size).
"""
embedding = self.embedding(text, tones) embedding = self.embedding(text, tones)
if self.spk_emb: if self.spk_emb:
embedding += self.spk_emb(spk_id).unsqueeze(1) embedding += self.spk_emb(spk_id).unsqueeze(1)
embedding = self.prenet(embedding) embedding = self.prenet(embedding)
x = self.res_blocks(embedding) x = self.res_blocks(embedding.transpose([0, 2, 1])).transpose([0, 2, 1])
# (B, T, dim)
x = embedding + self.postnet1(x) x = embedding + self.postnet1(x)
x = self.postnet2(x) x = self.postnet2(x.transpose([0, 2, 1])).transpose([0, 2, 1])
x = self.linear(x)
return x return x
class DurationPredictor(nn.Layer): class DurationPredictor(nn.Layer):
def __init__(self, hidden_size): def __init__(self, hidden_size: int=128):
super().__init__() super().__init__()
self.layers = nn.Sequential( self.layers = nn.Sequential(
ResidualBlock(hidden_size, 4, 1, n=1), ResidualBlock(hidden_size, 4, 1, n=1),
ResidualBlock(hidden_size, 3, 1, n=1), ResidualBlock(hidden_size, 3, 1, n=1),
ResidualBlock(hidden_size, 1, 1, n=1), nn.Linear(hidden_size, 1)) ResidualBlock(hidden_size, 1, 1, n=1), )
self.linear = nn.Linear(hidden_size, 1)
def forward(self, x): def forward(self, x: paddle.Tensor):
return paddle.squeeze(self.layers(x), -1) """Calculate forward propagation.
Args:
x(Tensor): Batch of input sequences (B, Tmax, hidden_size).
Returns:
Tensor: Batch of predicted durations in log domain (B, Tmax).
"""
x = self.layers(x.transpose([0, 2, 1])).transpose([0, 2, 1])
x = self.linear(x)
return paddle.squeeze(x, -1)
class SpeedySpeechDecoder(nn.Layer): class SpeedySpeechDecoder(nn.Layer):
def __init__(self, hidden_size, output_size, kernel_size, dilations): def __init__(self,
hidden_size: int=128,
output_size: int=80,
kernel_size: int=3,
dilations: List[int]=[
1, 3, 9, 27, 1, 3, 9, 27, 1, 3, 9, 27, 1, 3, 9, 27, 1, 1
]):
"""SpeedySpeech decoder module.
Args:
hidden_size (int): Number of decoder hidden units.
kernel_size (int): Kernel size of decoder.
output_size (int): Dimension of the outputs.
dilations (List[int]): Dilations of decoder.
"""
super().__init__() super().__init__()
res_blocks = [ res_blocks = [
ResidualBlock(hidden_size, kernel_size, d, n=2) for d in dilations ResidualBlock(hidden_size, kernel_size, d, n=2) for d in dilations
...@@ -144,14 +224,21 @@ class SpeedySpeechDecoder(nn.Layer): ...@@ -144,14 +224,21 @@ class SpeedySpeechDecoder(nn.Layer):
self.res_blocks = nn.Sequential(*res_blocks) self.res_blocks = nn.Sequential(*res_blocks)
self.postnet1 = nn.Sequential(nn.Linear(hidden_size, hidden_size)) self.postnet1 = nn.Sequential(nn.Linear(hidden_size, hidden_size))
self.postnet2 = nn.Sequential( self.postnet2 = ResidualBlock(hidden_size, kernel_size, 1, n=2)
ResidualBlock(hidden_size, kernel_size, 1, n=2), self.linear = nn.Linear(hidden_size, output_size)
nn.Linear(hidden_size, output_size))
def forward(self, x): def forward(self, x):
xx = self.res_blocks(x) """Decoder input sequence.
Args:
x(Tensor): Input tensor (B, time, hidden_size).
Returns:
Tensor: Output tensor (B, time, output_size).
"""
xx = self.res_blocks(x.transpose([0, 2, 1])).transpose([0, 2, 1])
x = x + self.postnet1(xx) x = x + self.postnet1(xx)
x = self.postnet2(x) x = self.postnet2(x.transpose([0, 2, 1])).transpose([0, 2, 1])
x = self.linear(x)
return x return x
...@@ -159,17 +246,35 @@ class SpeedySpeech(nn.Layer): ...@@ -159,17 +246,35 @@ class SpeedySpeech(nn.Layer):
def __init__( def __init__(
self, self,
vocab_size, vocab_size,
encoder_hidden_size, encoder_hidden_size: int=128,
encoder_kernel_size, encoder_kernel_size: int=3,
encoder_dilations, encoder_dilations: List[int]=[1, 3, 9, 27, 1, 3, 9, 27, 1, 1],
duration_predictor_hidden_size, duration_predictor_hidden_size: int=128,
decoder_hidden_size, decoder_hidden_size: int=128,
decoder_output_size, decoder_output_size: int=80,
decoder_kernel_size, decoder_kernel_size: int=3,
decoder_dilations, decoder_dilations: List[
tone_size=None, int]=[1, 3, 9, 27, 1, 3, 9, 27, 1, 3, 9, 27, 1, 3, 9, 27, 1, 1],
spk_num=None, tone_size: int=None,
init_type: str="xavier_uniform", ): spk_num: int=None,
init_type: str="xavier_uniform",
positional_dropout_rate: int=0.1):
"""Initialize SpeedySpeech module.
Args:
vocab_size (int): Dimension of the inputs.
encoder_hidden_size (int): Number of encoder hidden units.
encoder_kernel_size (int): Kernel size of encoder.
encoder_dilations (List[int]): Dilations of encoder.
duration_predictor_hidden_size (int): Number of duration predictor hidden units.
decoder_hidden_size (int): Number of decoder hidden units.
decoder_kernel_size (int): Kernel size of decoder.
decoder_dilations (List[int]): Dilations of decoder.
decoder_output_size (int): Dimension of the outputs.
tone_size (Optional[int]): Number of tones.
spk_num (Optional[int]): Number of speakers.
init_type (str): How to initialize transformer parameters.
"""
super().__init__() super().__init__()
# initialize parameters # initialize parameters
...@@ -181,6 +286,8 @@ class SpeedySpeech(nn.Layer): ...@@ -181,6 +286,8 @@ class SpeedySpeech(nn.Layer):
duration_predictor = DurationPredictor(duration_predictor_hidden_size) duration_predictor = DurationPredictor(duration_predictor_hidden_size)
decoder = SpeedySpeechDecoder(decoder_hidden_size, decoder_output_size, decoder = SpeedySpeechDecoder(decoder_hidden_size, decoder_output_size,
decoder_kernel_size, decoder_dilations) decoder_kernel_size, decoder_dilations)
self.position_enc = ScaledPositionalEncoding(encoder_hidden_size,
positional_dropout_rate)
self.encoder = encoder self.encoder = encoder
self.duration_predictor = duration_predictor self.duration_predictor = duration_predictor
...@@ -190,7 +297,22 @@ class SpeedySpeech(nn.Layer): ...@@ -190,7 +297,22 @@ class SpeedySpeech(nn.Layer):
nn.initializer.set_global_initializer(None) nn.initializer.set_global_initializer(None)
def forward(self, text, tones, durations, spk_id: paddle.Tensor=None): def forward(self,
text: paddle.Tensor,
tones: paddle.Tensor,
durations: paddle.Tensor,
spk_id: paddle.Tensor=None):
"""Calculate forward propagation.
Args:
text(Tensor(int64)): Batch of padded token ids (B, Tmax).
durations(Tensor(int64)): Batch of padded durations (B, Tmax).
tones(Tensor, optional(int64)): Batch of padded tone ids (B, Tmax).
spk_id(Tnesor, optional(int64)): Batch of speaker ids (B,)
Returns:
Tensor: Output tensor (B, T_frames, decoder_output_size).
Tensor: Predicted durations (B, Tmax).
"""
# input of embedding must be int64 # input of embedding must be int64
text = paddle.cast(text, 'int64') text = paddle.cast(text, 'int64')
tones = paddle.cast(tones, 'int64') tones = paddle.cast(tones, 'int64')
...@@ -198,23 +320,30 @@ class SpeedySpeech(nn.Layer): ...@@ -198,23 +320,30 @@ class SpeedySpeech(nn.Layer):
spk_id = paddle.cast(spk_id, 'int64') spk_id = paddle.cast(spk_id, 'int64')
durations = paddle.cast(durations, 'int64') durations = paddle.cast(durations, 'int64')
encodings = self.encoder(text, tones, spk_id) encodings = self.encoder(text, tones, spk_id)
pred_durations = self.duration_predictor(encodings.detach()) pred_durations = self.duration_predictor(encodings.detach())
# expand encodings # expand encodings
durations_to_expand = durations durations_to_expand = durations
encodings = self.length_regulator(encodings, durations_to_expand) encodings = self.length_regulator(encodings, durations_to_expand)
encodings = self.position_enc(encodings)
# decode # decode
# remove positional encoding here
_, t_dec, feature_size = encodings.shape
encodings += sinusoid_position_encoding(t_dec, feature_size)
decoded = self.decoder(encodings) decoded = self.decoder(encodings)
return decoded, pred_durations return decoded, pred_durations
def inference(self, text, tones=None, durations=None, spk_id=None): def inference(self,
# text: [T] text: paddle.Tensor,
# tones: [T] tones: paddle.Tensor=None,
durations: paddle.Tensor=None,
spk_id: paddle.Tensor=None):
"""Generate the sequence of features given the sequences of characters.
Args:
text(Tensor(int64)): Input sequence of characters (T,).
tones(Tensor, optional(int64)): Batch of padded tone ids (T, ).
durations(Tensor, optional (int64)): Groundtruth of duration (T,).
spk_id(Tensor, optional(int64), optional): spk ids (1,). (Default value = None)
Returns:
Tensor: logmel (T, decoder_output_size).
"""
# input of embedding must be int64 # input of embedding must be int64
text = paddle.cast(text, 'int64') text = paddle.cast(text, 'int64')
text = text.unsqueeze(0) text = text.unsqueeze(0)
...@@ -233,10 +362,7 @@ class SpeedySpeech(nn.Layer): ...@@ -233,10 +362,7 @@ class SpeedySpeech(nn.Layer):
durations_to_expand = durations durations_to_expand = durations
encodings = self.length_regulator( encodings = self.length_regulator(
encodings, durations_to_expand, is_inference=True) encodings, durations_to_expand, is_inference=True)
encodings = self.position_enc(encodings)
shape = paddle.shape(encodings)
t_dec, feature_size = shape[1], shape[2]
encodings += sinusoid_position_encoding(t_dec, feature_size)
decoded = self.decoder(encodings) decoded = self.decoder(encodings)
return decoded[0] return decoded[0]
......
...@@ -86,7 +86,7 @@ class LengthRegulator(nn.Layer): ...@@ -86,7 +86,7 @@ class LengthRegulator(nn.Layer):
M[:, i] = m - init M[:, i] = m - init
init = m init = m
M = paddle.reshape(M, shape=[t_dec_1, batch_size, t_enc]) M = paddle.reshape(M, shape=[t_dec_1, batch_size, t_enc])
M = M[1:, :, :] M = M[1:t_dec_1, :, :]
M = paddle.transpose(M, (1, 0, 2)) M = paddle.transpose(M, (1, 0, 2))
encodings = paddle.matmul(M, encodings) encodings = paddle.matmul(M, encodings)
return encodings return encodings
......
...@@ -30,7 +30,7 @@ class WaveNetResidualBlock(nn.Layer): ...@@ -30,7 +30,7 @@ class WaveNetResidualBlock(nn.Layer):
Args: Args:
kernel_size (int, optional): Kernel size of the 1D convolution, by default 3 kernel_size (int, optional): Kernel size of the 1D convolution, by default 3
residual_channels (int, optional): Feature size of the resiaudl output(and also the input), by default 64 residual_channels (int, optional): Feature size of the residual output(and also the input), by default 64
gate_channels (int, optional): Output feature size of the 1D convolution, by default 128 gate_channels (int, optional): Output feature size of the 1D convolution, by default 128
skip_channels (int, optional): Feature size of the skip output, by default 64 skip_channels (int, optional): Feature size of the skip output, by default 64
aux_channels (int, optional): Feature size of the auxiliary input (e.g. spectrogram), by default 80 aux_channels (int, optional): Feature size of the auxiliary input (e.g. spectrogram), by default 80
......
...@@ -347,7 +347,7 @@ class TransformerEncoder(BaseEncoder): ...@@ -347,7 +347,7 @@ class TransformerEncoder(BaseEncoder):
encoder_type="transformer") encoder_type="transformer")
def forward(self, xs, masks): def forward(self, xs, masks):
"""Encode input sequence. """Encoder input sequence.
Args: Args:
xs(Tensor): Input tensor (#batch, time, idim). xs(Tensor): Input tensor (#batch, time, idim).
...@@ -355,7 +355,7 @@ class TransformerEncoder(BaseEncoder): ...@@ -355,7 +355,7 @@ class TransformerEncoder(BaseEncoder):
Returns: Returns:
Tensor: Output tensor (#batch, time, attention_dim). Tensor: Output tensor (#batch, time, attention_dim).
Tensor:Mask tensor (#batch, 1, time). Tensor: Mask tensor (#batch, 1, time).
""" """
xs = self.embed(xs) xs = self.embed(xs)
xs, masks = self.encoders(xs, masks) xs, masks = self.encoders(xs, masks)
......
...@@ -11,6 +11,7 @@ ...@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# Modified from speechbrain(https://github.com/speechbrain/speechbrain)
""" """
This script contains basic functions used for speaker diarization. This script contains basic functions used for speaker diarization.
This script has an optional dependency on open source sklearn library. This script has an optional dependency on open source sklearn library.
...@@ -18,11 +19,11 @@ A few sklearn functions are modified in this script as per requirement. ...@@ -18,11 +19,11 @@ A few sklearn functions are modified in this script as per requirement.
""" """
import argparse import argparse
import warnings import warnings
from distutils.util import strtobool
import numpy as np import numpy as np
import scipy import scipy
import sklearn import sklearn
from distutils.util import strtobool
from scipy import sparse from scipy import sparse
from scipy.sparse.csgraph import connected_components from scipy.sparse.csgraph import connected_components
from scipy.sparse.csgraph import laplacian as csgraph_laplacian from scipy.sparse.csgraph import laplacian as csgraph_laplacian
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
from dataclasses import dataclass from dataclasses import dataclass
from dataclasses import fields from dataclasses import fields
from paddle.io import Dataset from paddle.io import Dataset
from paddleaudio import load as load_audio from paddleaudio import load as load_audio
......
...@@ -12,9 +12,9 @@ ...@@ -12,9 +12,9 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import json import json
from dataclasses import dataclass from dataclasses import dataclass
from dataclasses import fields from dataclasses import fields
from paddle.io import Dataset from paddle.io import Dataset
from paddleaudio import load as load_audio from paddleaudio import load as load_audio
......
cmake_minimum_required(VERSION 3.14 FATAL_ERROR) cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
add_subdirectory(feat) add_subdirectory(ds2_ol)
add_subdirectory(nnet) add_subdirectory(dev)
add_subdirectory(decoder) \ No newline at end of file
add_subdirectory(glog)
\ No newline at end of file
# Examples # Examples for SpeechX
* dev - for speechx developer, using for test.
* ngram - using to build NGram ARPA lm.
* ds2_ol - ds2 streaming test under `aishell-1` test dataset.
The entrypoint is `ds2_ol/aishell/run.sh`
* glog - glog usage
* feat - mfcc, linear
* nnet - ds2 nn
* decoder - online decoder to work as offline
## How to run ## How to run
`run.sh` is the entry point. `run.sh` is the entry point.
Example to play `decoder`: Example to play `ds2_ol`:
``` ```
pushd decoder pushd ds2_ol/aishell
bash run.sh bash run.sh
``` ```
## Display Model with [Netron](https://github.com/lutzroeder/netron)
```
pip install netron
netron exp/deepspeech2_online/checkpoints/avg_1.jit.pdmodel --port 8022 --host 10.21.55.20
```
../../../utils
\ No newline at end of file
cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
add_executable(offline_decoder_sliding_chunk_main ${CMAKE_CURRENT_SOURCE_DIR}/offline_decoder_sliding_chunk_main.cc)
target_include_directories(offline_decoder_sliding_chunk_main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(offline_decoder_sliding_chunk_main PUBLIC nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util ${DEPS})
add_executable(offline_decoder_main ${CMAKE_CURRENT_SOURCE_DIR}/offline_decoder_main.cc)
target_include_directories(offline_decoder_main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(offline_decoder_main PUBLIC nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util ${DEPS})
add_executable(offline_wfst_decoder_main ${CMAKE_CURRENT_SOURCE_DIR}/offline_wfst_decoder_main.cc)
target_include_directories(offline_wfst_decoder_main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(offline_wfst_decoder_main PUBLIC nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util kaldi-decoder ${DEPS})
add_executable(decoder_test_main ${CMAKE_CURRENT_SOURCE_DIR}/decoder_test_main.cc)
target_include_directories(decoder_test_main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(decoder_test_main PUBLIC nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util ${DEPS})
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// todo refactor, repalce with gtest
#include "base/flags.h"
#include "base/log.h"
#include "decoder/ctc_beam_search_decoder.h"
#include "frontend/audio/data_cache.h"
#include "kaldi/util/table-types.h"
#include "nnet/decodable.h"
#include "nnet/paddle_nnet.h"
DEFINE_string(feature_respecifier, "", "feature matrix rspecifier");
DEFINE_string(model_path, "avg_1.jit.pdmodel", "paddle nnet model");
DEFINE_string(param_path, "avg_1.jit.pdiparams", "paddle nnet model param");
DEFINE_string(dict_file, "vocab.txt", "vocabulary of lm");
DEFINE_string(lm_path, "lm.klm", "language model");
DEFINE_int32(chunk_size, 35, "feat chunk size");
using kaldi::BaseFloat;
using kaldi::Matrix;
using std::vector;
// test decoder by feeding speech feature, deprecated.
int main(int argc, char* argv[]) {
gflags::ParseCommandLineFlags(&argc, &argv, false);
google::InitGoogleLogging(argv[0]);
kaldi::SequentialBaseFloatMatrixReader feature_reader(
FLAGS_feature_respecifier);
std::string model_graph = FLAGS_model_path;
std::string model_params = FLAGS_param_path;
std::string dict_file = FLAGS_dict_file;
std::string lm_path = FLAGS_lm_path;
int32 chunk_size = FLAGS_chunk_size;
LOG(INFO) << "model path: " << model_graph;
LOG(INFO) << "model param: " << model_params;
LOG(INFO) << "dict path: " << dict_file;
LOG(INFO) << "lm path: " << lm_path;
LOG(INFO) << "chunk size (frame): " << chunk_size;
int32 num_done = 0, num_err = 0;
// frontend + nnet is decodable
ppspeech::ModelOptions model_opts;
model_opts.model_path = model_graph;
model_opts.params_path = model_params;
std::shared_ptr<ppspeech::PaddleNnet> nnet(
new ppspeech::PaddleNnet(model_opts));
std::shared_ptr<ppspeech::DataCache> raw_data(new ppspeech::DataCache());
std::shared_ptr<ppspeech::Decodable> decodable(
new ppspeech::Decodable(nnet, raw_data));
LOG(INFO) << "Init decodeable.";
// init decoder
ppspeech::CTCBeamSearchOptions opts;
opts.dict_file = dict_file;
opts.lm_path = lm_path;
ppspeech::CTCBeamSearch decoder(opts);
LOG(INFO) << "Init decoder.";
decoder.InitDecoder();
for (; !feature_reader.Done(); feature_reader.Next()) {
string utt = feature_reader.Key();
const kaldi::Matrix<BaseFloat> feature = feature_reader.Value();
LOG(INFO) << "utt: " << utt;
// feat dim
raw_data->SetDim(feature.NumCols());
LOG(INFO) << "dim: " << raw_data->Dim();
int32 row_idx = 0;
int32 num_chunks = feature.NumRows() / chunk_size;
LOG(INFO) << "n chunks: " << num_chunks;
for (int chunk_idx = 0; chunk_idx < num_chunks; ++chunk_idx) {
// feat chunk
kaldi::Vector<kaldi::BaseFloat> feature_chunk(chunk_size *
feature.NumCols());
for (int row_id = 0; row_id < chunk_size; ++row_id) {
kaldi::SubVector<kaldi::BaseFloat> feat_one_row(feature,
row_idx);
kaldi::SubVector<kaldi::BaseFloat> f_chunk_tmp(
feature_chunk.Data() + row_id * feature.NumCols(),
feature.NumCols());
f_chunk_tmp.CopyFromVec(feat_one_row);
row_idx++;
}
// feed to raw cache
raw_data->Accept(feature_chunk);
if (chunk_idx == num_chunks - 1) {
raw_data->SetFinished();
}
// decode step
decoder.AdvanceDecode(decodable);
}
std::string result;
result = decoder.GetFinalBestPath();
KALDI_LOG << " the result of " << utt << " is " << result;
decodable->Reset();
decoder.Reset();
++num_done;
}
KALDI_LOG << "Done " << num_done << " utterances, " << num_err
<< " with errors.";
return (num_done != 0 ? 0 : 1);
}
#!/bin/bash
set +x
set -e
. path.sh
# 1. compile
if [ ! -d ${SPEECHX_EXAMPLES} ]; then
pushd ${SPEECHX_ROOT}
bash build.sh
popd
fi
# 2. download model
if [ ! -d ../paddle_asr_model ]; then
wget -c https://paddlespeech.bj.bcebos.com/s2t/paddle_asr_online/paddle_asr_model.tar.gz
tar xzfv paddle_asr_model.tar.gz
mv ./paddle_asr_model ../
# produce wav scp
echo "utt1 " $PWD/../paddle_asr_model/BAC009S0764W0290.wav > ../paddle_asr_model/wav.scp
fi
model_dir=../paddle_asr_model
feat_wspecifier=./feats.ark
cmvn=./cmvn.ark
export GLOG_logtostderr=1
# 3. gen linear feat
linear_spectrogram_main \
--wav_rspecifier=scp:$model_dir/wav.scp \
--feature_wspecifier=ark,t:$feat_wspecifier \
--cmvn_write_path=$cmvn
# 4. run decoder
offline_decoder_main \
--feature_respecifier=ark:$feat_wspecifier \
--model_path=$model_dir/avg_1.jit.pdmodel \
--param_path=$model_dir/avg_1.jit.pdparams \
--dict_file=$model_dir/vocab.txt \
--lm_path=$model_dir/avg_1.jit.klm
cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
add_subdirectory(glog)
# This contains the locations of binarys build required for running the examples. # This contains the locations of binarys build required for running the examples.
SPEECHX_ROOT=$PWD/../.. SPEECHX_ROOT=$PWD/../../../
SPEECHX_EXAMPLES=$SPEECHX_ROOT/build/examples
SPEECHX_TOOLS=$SPEECHX_ROOT/tools SPEECHX_TOOLS=$SPEECHX_ROOT/tools
TOOLS_BIN=$SPEECHX_TOOLS/valgrind/install/bin TOOLS_BIN=$SPEECHX_TOOLS/valgrind/install/bin
[ -d $SPEECHX_EXAMPLES ] || { echo "Error: 'build/examples' directory not found. please ensure that the project build successfully"; }
export LC_AL=C SPEECHX_EXAMPLES=$SPEECHX_ROOT/build/examples
[ -d $SPEECHX_EXAMPLES ] || { echo "Error: 'build/examples' directory not found. please ensure that the project build successfully"; }
SPEECHX_BIN=$SPEECHX_EXAMPLES/feat SPEECHX_BIN=$SPEECHX_EXAMPLES/dev/glog
export PATH=$PATH:$SPEECHX_BIN:$TOOLS_BIN export PATH=$PATH:$SPEECHX_BIN:$TOOLS_BIN
export LC_AL=C
cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
add_subdirectory(feat)
add_subdirectory(nnet)
add_subdirectory(decoder)
\ No newline at end of file
# Deepspeech2 Streaming
Please go to `aishell` to test it.
* aishell
Deepspeech2 Streaming Decoding under aishell dataset.
The below is for developing and offline testing:
* nnet
* feat
* decoder
# Aishell - Deepspeech2 Streaming
## CTC Prefix Beam Search w/o LM
```
Overall -> 16.14 % N=104612 C=88190 S=16110 D=312 I=465
Mandarin -> 16.14 % N=104612 C=88190 S=16110 D=312 I=465
Other -> 0.00 % N=0 C=0 S=0 D=0 I=0
```
## CTC Prefix Beam Search w LM
```
```
## CTC WFST
```
```
\ No newline at end of file
# This contains the locations of binarys build required for running the examples. # This contains the locations of binarys build required for running the examples.
SPEECHX_ROOT=$PWD/../.. SPEECHX_ROOT=$PWD/../../../
SPEECHX_EXAMPLES=$SPEECHX_ROOT/build/examples SPEECHX_EXAMPLES=$SPEECHX_ROOT/build/examples
SPEECHX_TOOLS=$SPEECHX_ROOT/tools SPEECHX_TOOLS=$SPEECHX_ROOT/tools
...@@ -10,5 +10,5 @@ TOOLS_BIN=$SPEECHX_TOOLS/valgrind/install/bin ...@@ -10,5 +10,5 @@ TOOLS_BIN=$SPEECHX_TOOLS/valgrind/install/bin
export LC_AL=C export LC_AL=C
SPEECHX_BIN=$SPEECHX_EXAMPLES/decoder:$SPEECHX_EXAMPLES/feat SPEECHX_BIN=$SPEECHX_EXAMPLES/ds2_ol/decoder:$SPEECHX_EXAMPLES/ds2_ol/feat
export PATH=$PATH:$SPEECHX_BIN:$TOOLS_BIN export PATH=$PATH:$SPEECHX_BIN:$TOOLS_BIN
...@@ -4,6 +4,9 @@ set -e ...@@ -4,6 +4,9 @@ set -e
. path.sh . path.sh
nj=40
# 1. compile # 1. compile
if [ ! -d ${SPEECHX_EXAMPLES} ]; then if [ ! -d ${SPEECHX_EXAMPLES} ]; then
pushd ${SPEECHX_ROOT} pushd ${SPEECHX_ROOT}
...@@ -11,52 +14,59 @@ if [ ! -d ${SPEECHX_EXAMPLES} ]; then ...@@ -11,52 +14,59 @@ if [ ! -d ${SPEECHX_EXAMPLES} ]; then
popd popd
fi fi
# input
# 2. download model
if [ ! -d ../paddle_asr_model ]; then
wget -c https://paddlespeech.bj.bcebos.com/s2t/paddle_asr_online/paddle_asr_model.tar.gz
tar xzfv paddle_asr_model.tar.gz
mv ./paddle_asr_model ../
# produce wav scp
echo "utt1 " $PWD/../paddle_asr_model/BAC009S0764W0290.wav > ../paddle_asr_model/wav.scp
fi
mkdir -p data mkdir -p data
data=$PWD/data data=$PWD/data
ckpt_dir=$data/model
model_dir=$ckpt_dir/exp/deepspeech2_online/checkpoints/
vocb_dir=$ckpt_dir/data/lang_char/
# output
mkdir -p exp
exp=$PWD/exp
aishell_wav_scp=aishell_test.scp aishell_wav_scp=aishell_test.scp
if [ ! -d $data/test ]; then if [ ! -d $data/test ]; then
pushd $data
wget -c https://paddlespeech.bj.bcebos.com/s2t/paddle_asr_online/aishell_test.zip wget -c https://paddlespeech.bj.bcebos.com/s2t/paddle_asr_online/aishell_test.zip
unzip -d $data aishell_test.zip unzip aishell_test.zip
popd
realpath $data/test/*/*.wav > $data/wavlist realpath $data/test/*/*.wav > $data/wavlist
awk -F '/' '{ print $(NF) }' $data/wavlist | awk -F '.' '{ print $1 }' > $data/utt_id awk -F '/' '{ print $(NF) }' $data/wavlist | awk -F '.' '{ print $1 }' > $data/utt_id
paste $data/utt_id $data/wavlist > $data/$aishell_wav_scp paste $data/utt_id $data/wavlist > $data/$aishell_wav_scp
fi fi
model_dir=$PWD/aishell_ds2_online_model
if [ ! -d $model_dir ]; then if [ ! -d $ckpt_dir ]; then
mkdir -p $model_dir mkdir -p $ckpt_dir
wget -P $model_dir -c https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz wget -P $ckpt_dir -c https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz
tar xzfv $model_dir/asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz -C $model_dir tar xzfv $model_dir/asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz -C $ckpt_dir
fi
lm=$data/zh_giga.no_cna_cmn.prune01244.klm
if [ ! -f $lm ]; then
pushd $data
wget -c https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm
popd
fi fi
# 3. make feature # 3. make feature
aishell_online_model=$model_dir/exp/deepspeech2_online/checkpoints
lm_model_dir=../paddle_asr_model
label_file=./aishell_result label_file=./aishell_result
wer=./aishell_wer wer=./aishell_wer
nj=40
export GLOG_logtostderr=1 export GLOG_logtostderr=1
#./local/split_data.sh $data $data/$aishell_wav_scp $aishell_wav_scp $nj
data=$PWD/data
# 3. gen linear feat # 3. gen linear feat
cmvn=$PWD/cmvn.ark cmvn=$PWD/cmvn.ark
cmvn_json2binary_main --json_file=$model_dir/data/mean_std.json --cmvn_write_path=$cmvn cmvn-json2kaldi --json_file=$ckpt_dir/data/mean_std.json --cmvn_write_path=$cmvn
utils/run.pl JOB=1:$nj $data/split${nj}/JOB/feat_log \
linear_spectrogram_without_db_norm_main \ ./local/split_data.sh $data $data/$aishell_wav_scp $aishell_wav_scp $nj
utils/run.pl JOB=1:$nj $data/split${nj}/JOB/feat.log \
linear-spectrogram-wo-db-norm-ol \
--wav_rspecifier=scp:$data/split${nj}/JOB/${aishell_wav_scp} \ --wav_rspecifier=scp:$data/split${nj}/JOB/${aishell_wav_scp} \
--feature_wspecifier=ark,scp:$data/split${nj}/JOB/feat.ark,$data/split${nj}/JOB/feat.scp \ --feature_wspecifier=ark,scp:$data/split${nj}/JOB/feat.ark,$data/split${nj}/JOB/feat.scp \
--cmvn_file=$cmvn \ --cmvn_file=$cmvn \
...@@ -65,31 +75,33 @@ linear_spectrogram_without_db_norm_main \ ...@@ -65,31 +75,33 @@ linear_spectrogram_without_db_norm_main \
text=$data/test/text text=$data/test/text
# 4. recognizer # 4. recognizer
utils/run.pl JOB=1:$nj $data/split${nj}/JOB/log \ utils/run.pl JOB=1:$nj $data/split${nj}/JOB/recog.wolm.log \
offline_decoder_sliding_chunk_main \ ctc-prefix-beam-search-decoder-ol \
--feature_rspecifier=scp:$data/split${nj}/JOB/feat.scp \ --feature_rspecifier=scp:$data/split${nj}/JOB/feat.scp \
--model_path=$aishell_online_model/avg_1.jit.pdmodel \ --model_path=$model_dir/avg_1.jit.pdmodel \
--param_path=$aishell_online_model/avg_1.jit.pdiparams \ --param_path=$model_dir/avg_1.jit.pdiparams \
--model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \ --model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \
--dict_file=$lm_model_dir/vocab.txt \ --dict_file=$vocb_dir/vocab.txt \
--result_wspecifier=ark,t:$data/split${nj}/JOB/result --result_wspecifier=ark,t:$data/split${nj}/JOB/result
cat $data/split${nj}/*/result > ${label_file} cat $data/split${nj}/*/result > ${label_file}
local/compute-wer.py --char=1 --v=1 ${label_file} $text > ${wer} utils/compute-wer.py --char=1 --v=1 ${label_file} $text > ${wer}
# 4. decode with lm # 4. decode with lm
utils/run.pl JOB=1:$nj $data/split${nj}/JOB/log_lm \ utils/run.pl JOB=1:$nj $data/split${nj}/JOB/recog.lm.log \
offline_decoder_sliding_chunk_main \ ctc-prefix-beam-search-decoder-ol \
--feature_rspecifier=scp:$data/split${nj}/JOB/feat.scp \ --feature_rspecifier=scp:$data/split${nj}/JOB/feat.scp \
--model_path=$aishell_online_model/avg_1.jit.pdmodel \ --model_path=$model_dir/avg_1.jit.pdmodel \
--param_path=$aishell_online_model/avg_1.jit.pdiparams \ --param_path=$model_dir/avg_1.jit.pdiparams \
--model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \ --model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \
--dict_file=$lm_model_dir/vocab.txt \ --dict_file=$vocb_dir/vocab.txt \
--lm_path=$lm_model_dir/avg_1.jit.klm \ --lm_path=$lm \
--result_wspecifier=ark,t:$data/split${nj}/JOB/result_lm --result_wspecifier=ark,t:$data/split${nj}/JOB/result_lm
cat $data/split${nj}/*/result_lm > ${label_file}_lm cat $data/split${nj}/*/result_lm > ${label_file}_lm
local/compute-wer.py --char=1 --v=1 ${label_file}_lm $text > ${wer}_lm utils/compute-wer.py --char=1 --v=1 ${label_file}_lm $text > ${wer}_lm
graph_dir=./aishell_graph graph_dir=./aishell_graph
if [ ! -d $ ]; then if [ ! -d $ ]; then
...@@ -97,17 +109,19 @@ if [ ! -d $ ]; then ...@@ -97,17 +109,19 @@ if [ ! -d $ ]; then
unzip -d aishell_graph.zip unzip -d aishell_graph.zip
fi fi
# 5. test TLG decoder # 5. test TLG decoder
utils/run.pl JOB=1:$nj $data/split${nj}/JOB/log_tlg \ utils/run.pl JOB=1:$nj $data/split${nj}/JOB/recog.wfst.log \
offline_wfst_decoder_main \ wfst-decoder-ol \
--feature_rspecifier=scp:$data/split${nj}/JOB/feat.scp \ --feature_rspecifier=scp:$data/split${nj}/JOB/feat.scp \
--model_path=$aishell_online_model/avg_1.jit.pdmodel \ --model_path=$model_dir/avg_1.jit.pdmodel \
--param_path=$aishell_online_model/avg_1.jit.pdiparams \ --param_path=$model_dir/avg_1.jit.pdiparams \
--word_symbol_table=$graph_dir/words.txt \ --word_symbol_table=$graph_dir/words.txt \
--model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \ --model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \
--graph_path=$graph_dir/TLG.fst --max_active=7500 \ --graph_path=$graph_dir/TLG.fst --max_active=7500 \
--acoustic_scale=1.2 \ --acoustic_scale=1.2 \
--result_wspecifier=ark,t:$data/split${nj}/JOB/result_tlg --result_wspecifier=ark,t:$data/split${nj}/JOB/result_tlg
cat $data/split${nj}/*/result_tlg > ${label_file}_tlg cat $data/split${nj}/*/result_tlg > ${label_file}_tlg
local/compute-wer.py --char=1 --v=1 ${label_file}_tlg $text > ${wer}_tlg utils/compute-wer.py --char=1 --v=1 ${label_file}_tlg $text > ${wer}_tlg
\ No newline at end of file \ No newline at end of file
../../../../utils/
\ No newline at end of file
cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
set(bin_name ctc-prefix-beam-search-decoder-ol)
add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc)
target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(${bin_name} PUBLIC nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util ${DEPS})
set(bin_name wfst-decoder-ol)
add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc)
target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(${bin_name} PUBLIC nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util kaldi-decoder ${DEPS})
set(bin_name nnet-logprob-decoder-test)
add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc)
target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(${bin_name} PUBLIC nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util ${DEPS})
# ASR Decoder
ASR Decoder test bins. We using theses bins to test CTC BeamSearch decoder and WFST decoder.
* decoder_test_main.cc
feed nnet output logprob, and only test decoder
* offline_decoder_sliding_chunk_main.cc
feed streaming audio feature, decode as streaming manner.
* offline_wfst_decoder_main.cc
feed streaming audio feature, decode using WFST as streaming manner.
...@@ -34,10 +34,12 @@ DEFINE_int32(receptive_field_length, ...@@ -34,10 +34,12 @@ DEFINE_int32(receptive_field_length,
DEFINE_int32(downsampling_rate, DEFINE_int32(downsampling_rate,
4, 4,
"two CNN(kernel=5) module downsampling rate."); "two CNN(kernel=5) module downsampling rate.");
DEFINE_string(
model_input_names,
"audio_chunk,audio_chunk_lens,chunk_state_h_box,chunk_state_c_box",
"model input names");
DEFINE_string(model_output_names, DEFINE_string(model_output_names,
"save_infer_model/scale_0.tmp_1,save_infer_model/" "softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0",
"scale_1.tmp_1,save_infer_model/scale_2.tmp_1,save_infer_model/"
"scale_3.tmp_1",
"model output names"); "model output names");
DEFINE_string(model_cache_names, "5-1-1024,5-1-1024", "model cache names"); DEFINE_string(model_cache_names, "5-1-1024,5-1-1024", "model cache names");
...@@ -50,9 +52,13 @@ int main(int argc, char* argv[]) { ...@@ -50,9 +52,13 @@ int main(int argc, char* argv[]) {
gflags::ParseCommandLineFlags(&argc, &argv, false); gflags::ParseCommandLineFlags(&argc, &argv, false);
google::InitGoogleLogging(argv[0]); google::InitGoogleLogging(argv[0]);
CHECK(FLAGS_result_wspecifier != "");
CHECK(FLAGS_feature_rspecifier != "");
kaldi::SequentialBaseFloatMatrixReader feature_reader( kaldi::SequentialBaseFloatMatrixReader feature_reader(
FLAGS_feature_rspecifier); FLAGS_feature_rspecifier);
kaldi::TokenWriter result_writer(FLAGS_result_wspecifier); kaldi::TokenWriter result_writer(FLAGS_result_wspecifier);
std::string model_graph = FLAGS_model_path; std::string model_graph = FLAGS_model_path;
std::string model_params = FLAGS_param_path; std::string model_params = FLAGS_param_path;
std::string dict_file = FLAGS_dict_file; std::string dict_file = FLAGS_dict_file;
...@@ -73,6 +79,7 @@ int main(int argc, char* argv[]) { ...@@ -73,6 +79,7 @@ int main(int argc, char* argv[]) {
model_opts.model_path = model_graph; model_opts.model_path = model_graph;
model_opts.params_path = model_params; model_opts.params_path = model_params;
model_opts.cache_shape = FLAGS_model_cache_names; model_opts.cache_shape = FLAGS_model_cache_names;
model_opts.input_names = FLAGS_model_input_names;
model_opts.output_names = FLAGS_model_output_names; model_opts.output_names = FLAGS_model_output_names;
std::shared_ptr<ppspeech::PaddleNnet> nnet( std::shared_ptr<ppspeech::PaddleNnet> nnet(
new ppspeech::PaddleNnet(model_opts)); new ppspeech::PaddleNnet(model_opts));
......
# This contains the locations of binarys build required for running the examples. # This contains the locations of binarys build required for running the examples.
SPEECHX_ROOT=$PWD/../.. SPEECHX_ROOT=$PWD/../../../
SPEECHX_EXAMPLES=$SPEECHX_ROOT/build/examples SPEECHX_EXAMPLES=$SPEECHX_ROOT/build/examples
SPEECHX_TOOLS=$SPEECHX_ROOT/tools SPEECHX_TOOLS=$SPEECHX_ROOT/tools
...@@ -10,5 +10,5 @@ TOOLS_BIN=$SPEECHX_TOOLS/valgrind/install/bin ...@@ -10,5 +10,5 @@ TOOLS_BIN=$SPEECHX_TOOLS/valgrind/install/bin
export LC_AL=C export LC_AL=C
SPEECHX_BIN=$SPEECHX_EXAMPLES/decoder:$SPEECHX_EXAMPLES/feat SPEECHX_BIN=$SPEECHX_EXAMPLES/ds2_ol/decoder:$SPEECHX_EXAMPLES/ds2_ol/feat
export PATH=$PATH:$SPEECHX_BIN:$TOOLS_BIN export PATH=$PATH:$SPEECHX_BIN:$TOOLS_BIN
#!/bin/bash
set +x
set -e
. path.sh
# 1. compile
if [ ! -d ${SPEECHX_EXAMPLES} ]; then
pushd ${SPEECHX_ROOT}
bash build.sh
popd
fi
# input
mkdir -p data
data=$PWD/data
ckpt_dir=$data/model
model_dir=$ckpt_dir/exp/deepspeech2_online/checkpoints/
vocb_dir=$ckpt_dir/data/lang_char/
lm=$data/zh_giga.no_cna_cmn.prune01244.klm
# output
exp_dir=./exp
mkdir -p $exp_dir
# 2. download model
if [[ ! -f data/model/asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz ]]; then
mkdir -p data/model
pushd data/model
wget -c https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz
tar xzfv asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz
popd
fi
# produce wav scp
if [ ! -f data/wav.scp ]; then
pushd data
wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav
echo "utt1 " $PWD/zh.wav > wav.scp
popd
fi
# download lm
if [ ! -f $lm ]; then
pushd data
wget -c https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm
popd
fi
feat_wspecifier=$exp_dir/feats.ark
cmvn=$exp_dir/cmvn.ark
export GLOG_logtostderr=1
# dump json cmvn to kaldi
cmvn-json2kaldi \
--json_file $ckpt_dir/data/mean_std.json \
--cmvn_write_path $exp_dir/cmvn.ark \
--binary=false
echo "convert json cmvn to kaldi ark."
# generate linear feature as streaming
linear-spectrogram-wo-db-norm-ol \
--wav_rspecifier=scp:$data/wav.scp \
--feature_wspecifier=ark,t:$feat_wspecifier \
--cmvn_file=$exp_dir/cmvn.ark
echo "compute linear spectrogram feature."
# run ctc beam search decoder as streaming
ctc-prefix-beam-search-decoder-ol \
--result_wspecifier=ark,t:$exp_dir/result.txt \
--feature_rspecifier=ark:$feat_wspecifier \
--model_path=$model_dir/avg_1.jit.pdmodel \
--param_path=$model_dir/avg_1.jit.pdiparams \
--dict_file=$vocb_dir/vocab.txt \
--lm_path=$lm
\ No newline at end of file
...@@ -28,6 +28,7 @@ DEFINE_string(model_path, "avg_1.jit.pdmodel", "paddle nnet model"); ...@@ -28,6 +28,7 @@ DEFINE_string(model_path, "avg_1.jit.pdmodel", "paddle nnet model");
DEFINE_string(param_path, "avg_1.jit.pdiparams", "paddle nnet model param"); DEFINE_string(param_path, "avg_1.jit.pdiparams", "paddle nnet model param");
DEFINE_string(word_symbol_table, "words.txt", "word symbol table"); DEFINE_string(word_symbol_table, "words.txt", "word symbol table");
DEFINE_string(graph_path, "TLG", "decoder graph"); DEFINE_string(graph_path, "TLG", "decoder graph");
DEFINE_double(acoustic_scale, 1.0, "acoustic scale"); DEFINE_double(acoustic_scale, 1.0, "acoustic scale");
DEFINE_int32(max_active, 7500, "decoder graph"); DEFINE_int32(max_active, 7500, "decoder graph");
DEFINE_int32(receptive_field_length, DEFINE_int32(receptive_field_length,
......
cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
set(bin_name linear-spectrogram-wo-db-norm-ol)
add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc)
target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(${bin_name} frontend kaldi-util kaldi-feat-common gflags glog)
set(bin_name cmvn-json2kaldi)
add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc)
target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(${bin_name} utils kaldi-util kaldi-matrix gflags glog)
# Deepspeech2 Straming Audio Feature
ASR audio feature test bins. We using theses bins to test linaer/fbank/mfcc asr feature as streaming manner.
* linear_spectrogram_without_db_norm_main.cc
compute linear spectrogram w/o db norm in streaming manner.
...@@ -12,6 +12,8 @@ ...@@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
// Note: Do not print/log ondemand object.
#include "base/flags.h" #include "base/flags.h"
#include "base/log.h" #include "base/log.h"
#include "kaldi/matrix/kaldi-matrix.h" #include "kaldi/matrix/kaldi-matrix.h"
...@@ -29,22 +31,34 @@ int main(int argc, char* argv[]) { ...@@ -29,22 +31,34 @@ int main(int argc, char* argv[]) {
gflags::ParseCommandLineFlags(&argc, &argv, false); gflags::ParseCommandLineFlags(&argc, &argv, false);
google::InitGoogleLogging(argv[0]); google::InitGoogleLogging(argv[0]);
ondemand::parser parser; LOG(INFO) << "cmvn josn path: " << FLAGS_json_file;
try {
padded_string json = padded_string::load(FLAGS_json_file); padded_string json = padded_string::load(FLAGS_json_file);
ondemand::document val = parser.iterate(json);
ondemand::object doc = val; ondemand::parser parser;
kaldi::int32 frame_num = uint64_t(doc["frame_num"]); ondemand::document doc = parser.iterate(json);
auto mean_stat = doc["mean_stat"]; ondemand::value val = doc;
ondemand::array mean_stat = val["mean_stat"];
std::vector<kaldi::BaseFloat> mean_stat_vec; std::vector<kaldi::BaseFloat> mean_stat_vec;
for (double x : mean_stat) { for (double x : mean_stat) {
mean_stat_vec.push_back(x); mean_stat_vec.push_back(x);
} }
auto var_stat = doc["var_stat"]; // LOG(INFO) << mean_stat; this line will casue
// simdjson::simdjson_error("Objects and arrays can only be iterated
// when
// they are first encountered")
ondemand::array var_stat = val["var_stat"];
std::vector<kaldi::BaseFloat> var_stat_vec; std::vector<kaldi::BaseFloat> var_stat_vec;
for (double x : var_stat) { for (double x : var_stat) {
var_stat_vec.push_back(x); var_stat_vec.push_back(x);
} }
kaldi::int32 frame_num = uint64_t(val["frame_num"]);
LOG(INFO) << "nframe: " << frame_num;
size_t mean_size = mean_stat_vec.size(); size_t mean_size = mean_stat_vec.size();
kaldi::Matrix<double> cmvn_stats(2, mean_size + 1); kaldi::Matrix<double> cmvn_stats(2, mean_size + 1);
for (size_t idx = 0; idx < mean_size; ++idx) { for (size_t idx = 0; idx < mean_size; ++idx) {
...@@ -52,7 +66,16 @@ int main(int argc, char* argv[]) { ...@@ -52,7 +66,16 @@ int main(int argc, char* argv[]) {
cmvn_stats(1, idx) = var_stat_vec[idx]; cmvn_stats(1, idx) = var_stat_vec[idx];
} }
cmvn_stats(0, mean_size) = frame_num; cmvn_stats(0, mean_size) = frame_num;
kaldi::WriteKaldiObject(cmvn_stats, FLAGS_cmvn_write_path, FLAGS_binary); LOG(INFO) << cmvn_stats;
LOG(INFO) << "the json file have write into " << FLAGS_cmvn_write_path;
kaldi::WriteKaldiObject(
cmvn_stats, FLAGS_cmvn_write_path, FLAGS_binary);
LOG(INFO) << "cmvn stats have write into: " << FLAGS_cmvn_write_path;
LOG(INFO) << "Binary: " << FLAGS_binary;
} catch (simdjson::simdjson_error& err) {
LOG(ERR) << err.what();
}
return 0; return 0;
} }
\ No newline at end of file
...@@ -32,6 +32,7 @@ DEFINE_string(feature_wspecifier, "", "output feats wspecifier"); ...@@ -32,6 +32,7 @@ DEFINE_string(feature_wspecifier, "", "output feats wspecifier");
DEFINE_string(cmvn_file, "./cmvn.ark", "read cmvn"); DEFINE_string(cmvn_file, "./cmvn.ark", "read cmvn");
DEFINE_double(streaming_chunk, 0.36, "streaming feature chunk size"); DEFINE_double(streaming_chunk, 0.36, "streaming feature chunk size");
int main(int argc, char* argv[]) { int main(int argc, char* argv[]) {
gflags::ParseCommandLineFlags(&argc, &argv, false); gflags::ParseCommandLineFlags(&argc, &argv, false);
google::InitGoogleLogging(argv[0]); google::InitGoogleLogging(argv[0]);
......
# This contains the locations of binarys build required for running the examples. # This contains the locations of binarys build required for running the examples.
SPEECHX_ROOT=$PWD/../.. SPEECHX_ROOT=$PWD/../../../
SPEECHX_EXAMPLES=$SPEECHX_ROOT/build/examples SPEECHX_EXAMPLES=$SPEECHX_ROOT/build/examples
SPEECHX_TOOLS=$SPEECHX_ROOT/tools SPEECHX_TOOLS=$SPEECHX_ROOT/tools
...@@ -10,5 +10,5 @@ TOOLS_BIN=$SPEECHX_TOOLS/valgrind/install/bin ...@@ -10,5 +10,5 @@ TOOLS_BIN=$SPEECHX_TOOLS/valgrind/install/bin
export LC_AL=C export LC_AL=C
SPEECHX_BIN=$SPEECHX_EXAMPLES/glog SPEECHX_BIN=$SPEECHX_EXAMPLES/ds2_ol/feat
export PATH=$PATH:$SPEECHX_BIN:$TOOLS_BIN export PATH=$PATH:$SPEECHX_BIN:$TOOLS_BIN
#!/bin/bash
set +x
set -e
. ./path.sh
# 1. compile
if [ ! -d ${SPEECHX_EXAMPLES} ]; then
pushd ${SPEECHX_ROOT}
bash build.sh
popd
fi
# 2. download model
if [ ! -e data/model/asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz ]; then
mkdir -p data/model
pushd data/model
wget -c https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz
tar xzfv asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz
popd
fi
# produce wav scp
if [ ! -f data/wav.scp ]; then
mkdir -p data
pushd data
wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav
echo "utt1 " $PWD/zh.wav > wav.scp
popd
fi
# input
data_dir=./data
exp_dir=./exp
model_dir=$data_dir/model/
mkdir -p $exp_dir
# 3. run feat
export GLOG_logtostderr=1
cmvn-json2kaldi \
--json_file $model_dir/data/mean_std.json \
--cmvn_write_path $exp_dir/cmvn.ark \
--binary=false
echo "convert json cmvn to kaldi ark."
linear-spectrogram-wo-db-norm-ol \
--wav_rspecifier=scp:$data_dir/wav.scp \
--feature_wspecifier=ark,t:$exp_dir/feats.ark \
--cmvn_file=$exp_dir/cmvn.ark
echo "compute linear spectrogram feature."
cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
set(bin_name ds2-model-ol-test)
add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc)
target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(${bin_name} PUBLIC nnet gflags glog ${DEPS})
\ No newline at end of file
# Deepspeech2 Streaming NNet Test
Using for ds2 streaming nnet inference test.
...@@ -12,7 +12,8 @@ ...@@ -12,7 +12,8 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include <gflags/gflags.h> // deepspeech2 online model info
#include <algorithm> #include <algorithm>
#include <fstream> #include <fstream>
#include <functional> #include <functional>
...@@ -20,21 +21,26 @@ ...@@ -20,21 +21,26 @@
#include <iterator> #include <iterator>
#include <numeric> #include <numeric>
#include <thread> #include <thread>
#include "base/flags.h"
#include "base/log.h"
#include "paddle_inference_api.h" #include "paddle_inference_api.h"
using std::cout; using std::cout;
using std::endl; using std::endl;
DEFINE_string(model_path, "avg_1.jit.pdmodel", "xxx.pdmodel");
DEFINE_string(param_path, "avg_1.jit.pdiparams", "xxx.pdiparams"); DEFINE_string(model_path, "", "xxx.pdmodel");
DEFINE_string(param_path, "", "xxx.pdiparams");
DEFINE_int32(chunk_size, 35, "feature chunk size, unit:frame");
DEFINE_int32(feat_dim, 161, "feature dim");
void produce_data(std::vector<std::vector<float>>* data); void produce_data(std::vector<std::vector<float>>* data);
void model_forward_test(); void model_forward_test();
void produce_data(std::vector<std::vector<float>>* data) { void produce_data(std::vector<std::vector<float>>* data) {
int chunk_size = 35; // chunk_size in frame int chunk_size = FLAGS_chunk_size; // chunk_size in frame
int col_size = 161; // feat dim int col_size = FLAGS_feat_dim; // feat dim
cout << "chunk size: " << chunk_size << endl; cout << "chunk size: " << chunk_size << endl;
cout << "feat dim: " << col_size << endl; cout << "feat dim: " << col_size << endl;
...@@ -57,6 +63,8 @@ void model_forward_test() { ...@@ -57,6 +63,8 @@ void model_forward_test() {
; ;
std::string model_graph = FLAGS_model_path; std::string model_graph = FLAGS_model_path;
std::string model_params = FLAGS_param_path; std::string model_params = FLAGS_param_path;
CHECK(model_graph != "");
CHECK(model_params != "");
cout << "model path: " << model_graph << endl; cout << "model path: " << model_graph << endl;
cout << "model param path : " << model_params << endl; cout << "model param path : " << model_params << endl;
...@@ -106,7 +114,7 @@ void model_forward_test() { ...@@ -106,7 +114,7 @@ void model_forward_test() {
// state_h // state_h
std::unique_ptr<paddle_infer::Tensor> chunk_state_h_box = std::unique_ptr<paddle_infer::Tensor> chunk_state_h_box =
predictor->GetInputHandle(input_names[2]); predictor->GetInputHandle(input_names[2]);
std::vector<int> chunk_state_h_box_shape = {3, 1, 1024}; std::vector<int> chunk_state_h_box_shape = {5, 1, 1024};
chunk_state_h_box->Reshape(chunk_state_h_box_shape); chunk_state_h_box->Reshape(chunk_state_h_box_shape);
int chunk_state_h_box_size = int chunk_state_h_box_size =
std::accumulate(chunk_state_h_box_shape.begin(), std::accumulate(chunk_state_h_box_shape.begin(),
...@@ -119,7 +127,7 @@ void model_forward_test() { ...@@ -119,7 +127,7 @@ void model_forward_test() {
// state_c // state_c
std::unique_ptr<paddle_infer::Tensor> chunk_state_c_box = std::unique_ptr<paddle_infer::Tensor> chunk_state_c_box =
predictor->GetInputHandle(input_names[3]); predictor->GetInputHandle(input_names[3]);
std::vector<int> chunk_state_c_box_shape = {3, 1, 1024}; std::vector<int> chunk_state_c_box_shape = {5, 1, 1024};
chunk_state_c_box->Reshape(chunk_state_c_box_shape); chunk_state_c_box->Reshape(chunk_state_c_box_shape);
int chunk_state_c_box_size = int chunk_state_c_box_size =
std::accumulate(chunk_state_c_box_shape.begin(), std::accumulate(chunk_state_c_box_shape.begin(),
...@@ -187,7 +195,9 @@ void model_forward_test() { ...@@ -187,7 +195,9 @@ void model_forward_test() {
} }
int main(int argc, char* argv[]) { int main(int argc, char* argv[]) {
gflags::ParseCommandLineFlags(&argc, &argv, true); gflags::ParseCommandLineFlags(&argc, &argv, false);
google::InitGoogleLogging(argv[0]);
model_forward_test(); model_forward_test();
return 0; return 0;
} }
# This contains the locations of binarys build required for running the examples. # This contains the locations of binarys build required for running the examples.
SPEECHX_ROOT=$PWD/../.. SPEECHX_ROOT=$PWD/../../../
SPEECHX_EXAMPLES=$SPEECHX_ROOT/build/examples SPEECHX_EXAMPLES=$SPEECHX_ROOT/build/examples
SPEECHX_TOOLS=$SPEECHX_ROOT/tools SPEECHX_TOOLS=$SPEECHX_ROOT/tools
...@@ -10,5 +10,5 @@ TOOLS_BIN=$SPEECHX_TOOLS/valgrind/install/bin ...@@ -10,5 +10,5 @@ TOOLS_BIN=$SPEECHX_TOOLS/valgrind/install/bin
export LC_AL=C export LC_AL=C
SPEECHX_BIN=$SPEECHX_EXAMPLES/nnet SPEECHX_BIN=$SPEECHX_EXAMPLES/ds2_ol/nnet
export PATH=$PATH:$SPEECHX_BIN:$TOOLS_BIN export PATH=$PATH:$SPEECHX_BIN:$TOOLS_BIN
#!/bin/bash
set +x
set -e
. path.sh
# 1. compile
if [ ! -d ${SPEECHX_EXAMPLES} ]; then
pushd ${SPEECHX_ROOT}
bash build.sh
popd
fi
# 2. download model
if [ ! -f data/model/asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz ]; then
mkdir -p data/model
pushd data/model
wget -c https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz
tar xzfv asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz
popd
fi
# produce wav scp
if [ ! -f data/wav.scp ]; then
mkdir -p data
pushd data
wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav
echo "utt1 " $PWD/zh.wav > wav.scp
popd
fi
ckpt_dir=./data/model
model_dir=$ckpt_dir/exp/deepspeech2_online/checkpoints/
ds2-model-ol-test \
--model_path=$model_dir/avg_1.jit.pdmodel \
--param_path=$model_dir/avg_1.jit.pdiparams
cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
add_executable(mfcc-test ${CMAKE_CURRENT_SOURCE_DIR}/feature-mfcc-test.cc)
target_include_directories(mfcc-test PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(mfcc-test kaldi-mfcc)
add_executable(linear_spectrogram_main ${CMAKE_CURRENT_SOURCE_DIR}/linear_spectrogram_main.cc)
target_include_directories(linear_spectrogram_main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(linear_spectrogram_main frontend kaldi-util kaldi-feat-common gflags glog)
add_executable(linear_spectrogram_without_db_norm_main ${CMAKE_CURRENT_SOURCE_DIR}/linear_spectrogram_without_db_norm_main.cc)
target_include_directories(linear_spectrogram_without_db_norm_main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(linear_spectrogram_without_db_norm_main frontend kaldi-util kaldi-feat-common gflags glog)
add_executable(cmvn_json2binary_main ${CMAKE_CURRENT_SOURCE_DIR}/cmvn_json2binary_main.cc)
target_include_directories(cmvn_json2binary_main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(cmvn_json2binary_main utils kaldi-util kaldi-matrix gflags glog)
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册