未验证 提交 159d8fd6 编写于 作者: Q qingen 提交者: GitHub

Merge branch 'develop' into cluster

...@@ -50,13 +50,13 @@ repos: ...@@ -50,13 +50,13 @@ repos:
entry: bash .pre-commit-hooks/clang-format.hook -i entry: bash .pre-commit-hooks/clang-format.hook -i
language: system language: system
files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|cuh|proto)$ files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|cuh|proto)$
exclude: (?=speechx/speechx/kaldi|speechx/patch).*(\.cpp|\.cc|\.h|\.py)$ exclude: (?=speechx/speechx/kaldi|speechx/patch|speechx/tools/fstbin|speechx/tools/lmbin).*(\.cpp|\.cc|\.h|\.py)$
- id: copyright_checker - id: copyright_checker
name: copyright_checker name: copyright_checker
entry: python .pre-commit-hooks/copyright-check.hook entry: python .pre-commit-hooks/copyright-check.hook
language: system language: system
files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|proto|py)$ files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|proto|py)$
exclude: (?=third_party|pypinyin|speechx/speechx/kaldi|speechx/patch).*(\.cpp|\.cc|\.h|\.py)$ exclude: (?=third_party|pypinyin|speechx/speechx/kaldi|speechx/patch|speechx/tools/fstbin|speechx/tools/lmbin).*(\.cpp|\.cc|\.h|\.py)$
- repo: https://github.com/asottile/reorder_python_imports - repo: https://github.com/asottile/reorder_python_imports
rev: v2.4.0 rev: v2.4.0
hooks: hooks:
......
...@@ -90,7 +90,7 @@ Then to start the system server, and it provides HTTP backend services. ...@@ -90,7 +90,7 @@ Then to start the system server, and it provides HTTP backend services.
```bash ```bash
export PYTHONPATH=$PYTHONPATH:./src:../../paddleaudio export PYTHONPATH=$PYTHONPATH:./src:../../paddleaudio
python src/main.py python src/audio_search.py
``` ```
Then you will see the Application is started: Then you will see the Application is started:
...@@ -111,7 +111,7 @@ Then to start the system server, and it provides HTTP backend services. ...@@ -111,7 +111,7 @@ Then to start the system server, and it provides HTTP backend services.
```bash ```bash
wget -c https://www.openslr.org/resources/82/cn-celeb_v2.tar.gz && tar -xvf cn-celeb_v2.tar.gz wget -c https://www.openslr.org/resources/82/cn-celeb_v2.tar.gz && tar -xvf cn-celeb_v2.tar.gz
``` ```
**Note**: If you want to build a quick demo, you can use ./src/test_main.py:download_audio_data function, it downloads 20 audio files , Subsequent results show this collection as an example **Note**: If you want to build a quick demo, you can use ./src/test_audio_search.py:download_audio_data function, it downloads 20 audio files , Subsequent results show this collection as an example
- Prepare model(Skip this step if you use the default model.) - Prepare model(Skip this step if you use the default model.)
```bash ```bash
...@@ -123,7 +123,7 @@ Then to start the system server, and it provides HTTP backend services. ...@@ -123,7 +123,7 @@ Then to start the system server, and it provides HTTP backend services.
The internal process is downloading data, loading the paddlespeech model, extracting embedding, storing library, retrieving and deleting library The internal process is downloading data, loading the paddlespeech model, extracting embedding, storing library, retrieving and deleting library
```bash ```bash
python ./src/test_main.py python ./src/test_audio_search.py
``` ```
Output: Output:
......
...@@ -92,7 +92,7 @@ ffce340b3790 minio/minio:RELEASE.2020-12-03T00-03-10Z "/usr/bin/docker-ent…" ...@@ -92,7 +92,7 @@ ffce340b3790 minio/minio:RELEASE.2020-12-03T00-03-10Z "/usr/bin/docker-ent…"
```bash ```bash
export PYTHONPATH=$PYTHONPATH:./src:../../paddleaudio export PYTHONPATH=$PYTHONPATH:./src:../../paddleaudio
python src/main.py python src/audio_search.py
``` ```
然后你会看到应用程序启动: 然后你会看到应用程序启动:
...@@ -113,7 +113,7 @@ ffce340b3790 minio/minio:RELEASE.2020-12-03T00-03-10Z "/usr/bin/docker-ent…" ...@@ -113,7 +113,7 @@ ffce340b3790 minio/minio:RELEASE.2020-12-03T00-03-10Z "/usr/bin/docker-ent…"
```bash ```bash
wget -c https://www.openslr.org/resources/82/cn-celeb_v2.tar.gz && tar -xvf cn-celeb_v2.tar.gz wget -c https://www.openslr.org/resources/82/cn-celeb_v2.tar.gz && tar -xvf cn-celeb_v2.tar.gz
``` ```
**注**:如果希望快速搭建 demo,可以采用 ./src/test_main.py:download_audio_data 内部的 20 条音频,另外后续结果展示以该集合为例 **注**:如果希望快速搭建 demo,可以采用 ./src/test_audio_search.py:download_audio_data 内部的 20 条音频,另外后续结果展示以该集合为例
- 准备模型(如果使用默认模型,可以跳过此步骤) - 准备模型(如果使用默认模型,可以跳过此步骤)
```bash ```bash
...@@ -124,7 +124,7 @@ ffce340b3790 minio/minio:RELEASE.2020-12-03T00-03-10Z "/usr/bin/docker-ent…" ...@@ -124,7 +124,7 @@ ffce340b3790 minio/minio:RELEASE.2020-12-03T00-03-10Z "/usr/bin/docker-ent…"
- 脚本测试(推荐) - 脚本测试(推荐)
```bash ```bash
python ./src/test_main.py python ./src/test_audio_search.py
``` ```
注:内部将依次下载数据,加载 paddlespeech 模型,提取 embedding,存储建库,检索,删库 注:内部将依次下载数据,加载 paddlespeech 模型,提取 embedding,存储建库,检索,删库
......
...@@ -40,7 +40,6 @@ app.add_middleware( ...@@ -40,7 +40,6 @@ app.add_middleware(
allow_methods=["*"], allow_methods=["*"],
allow_headers=["*"]) allow_headers=["*"])
MODEL = None
MILVUS_CLI = MilvusHelper() MILVUS_CLI = MilvusHelper()
MYSQL_CLI = MySQLHelper() MYSQL_CLI = MySQLHelper()
......
...@@ -12,8 +12,8 @@ ...@@ -12,8 +12,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import numpy as np import numpy as np
from logs import LOGGER from logs import LOGGER
from paddlespeech.cli import VectorExecutor from paddlespeech.cli import VectorExecutor
vector_executor = VectorExecutor() vector_executor = VectorExecutor()
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
import sys import sys
import numpy
import pymysql import pymysql
from config import MYSQL_DB from config import MYSQL_DB
from config import MYSQL_HOST from config import MYSQL_HOST
...@@ -69,7 +70,7 @@ class MySQLHelper(): ...@@ -69,7 +70,7 @@ class MySQLHelper():
sys.exit(1) sys.exit(1)
def load_data_to_mysql(self, table_name, data): def load_data_to_mysql(self, table_name, data):
# Batch insert (Milvus_ids, img_path) to mysql # Batch insert (Milvus_ids, audio_path) to mysql
self.test_connection() self.test_connection()
sql = "insert into " + table_name + " (milvus_id,audio_path) values (%s,%s);" sql = "insert into " + table_name + " (milvus_id,audio_path) values (%s,%s);"
try: try:
...@@ -82,7 +83,7 @@ class MySQLHelper(): ...@@ -82,7 +83,7 @@ class MySQLHelper():
sys.exit(1) sys.exit(1)
def search_by_milvus_ids(self, ids, table_name): def search_by_milvus_ids(self, ids, table_name):
# Get the img_path according to the milvus ids # Get the audio_path according to the milvus ids
self.test_connection() self.test_connection()
str_ids = str(ids).replace('[', '').replace(']', '') str_ids = str(ids).replace('[', '').replace(']', '')
sql = "select audio_path from " + table_name + " where milvus_id in (" + str_ids + ") order by field (milvus_id," + str_ids + ");" sql = "select audio_path from " + table_name + " where milvus_id in (" + str_ids + ") order by field (milvus_id," + str_ids + ");"
...@@ -120,14 +121,83 @@ class MySQLHelper(): ...@@ -120,14 +121,83 @@ class MySQLHelper():
sys.exit(1) sys.exit(1)
def count_table(self, table_name): def count_table(self, table_name):
# Get the number of mysql table # Get the number of spk in mysql table
self.test_connection() self.test_connection()
sql = "select count(milvus_id) from " + table_name + ";" sql = "select count(spk_id) from " + table_name + ";"
try: try:
self.cursor.execute(sql) self.cursor.execute(sql)
results = self.cursor.fetchall() results = self.cursor.fetchall()
LOGGER.debug(f"MYSQL count table:{table_name}") LOGGER.debug(f"MYSQL count table:{results[0][0]}")
return results[0][0] return results[0][0]
except Exception as e: except Exception as e:
LOGGER.error(f"MYSQL ERROR: {e} with sql: {sql}") LOGGER.error(f"MYSQL ERROR: {e} with sql: {sql}")
sys.exit(1) sys.exit(1)
def create_mysql_table_vpr(self, table_name):
# Create mysql table if not exists
self.test_connection()
sql = "create table if not exists " + table_name + "(spk_id TEXT, audio_path TEXT, embedding TEXT);"
try:
self.cursor.execute(sql)
LOGGER.debug(f"MYSQL create table: {table_name} with sql: {sql}")
except Exception as e:
LOGGER.error(f"MYSQL ERROR: {e} with sql: {sql}")
sys.exit(1)
def load_data_to_mysql_vpr(self, table_name, data):
# Insert (spk, audio, embedding) to mysql
self.test_connection()
sql = "insert into " + table_name + " (spk_id,audio_path,embedding) values (%s,%s,%s);"
try:
self.cursor.execute(sql, data)
LOGGER.debug(
f"MYSQL loads data to table: {table_name} successfully")
except Exception as e:
LOGGER.error(f"MYSQL ERROR: {e} with sql: {sql}")
sys.exit(1)
def list_vpr(self, table_name):
# Get all records in mysql
self.test_connection()
sql = "select * from " + table_name + " ;"
try:
self.cursor.execute(sql)
results = self.cursor.fetchall()
self.conn.commit()
spk_ids = [res[0] for res in results]
audio_paths = [res[1] for res in results]
embeddings = [
numpy.array(
str(res[2]).replace('[', '').replace(']', '').split(","))
for res in results
]
return spk_ids, audio_paths, embeddings
except Exception as e:
LOGGER.error(f"MYSQL ERROR: {e} with sql: {sql}")
sys.exit(1)
def search_audio_vpr(self, table_name, spk_id):
# Get the audio_path according to the spk_id
self.test_connection()
sql = "select audio_path from " + table_name + " where spk_id='" + spk_id + "' ;"
try:
self.cursor.execute(sql)
results = self.cursor.fetchall()
LOGGER.debug(
f"MYSQL search by spk id {spk_id} to get audio {results[0][0]}.")
return results[0][0]
except Exception as e:
LOGGER.error(f"MYSQL ERROR: {e} with sql: {sql}")
sys.exit(1)
def delete_data_vpr(self, table_name, spk_id):
# Delete a record by spk_id in mysql table
self.test_connection()
sql = "delete from " + table_name + " where spk_id='" + spk_id + "';"
try:
self.cursor.execute(sql)
LOGGER.debug(
f"MYSQL delete a record {spk_id} in table {table_name}")
except Exception as e:
LOGGER.error(f"MYSQL ERROR: {e} with sql: {sql}")
sys.exit(1)
...@@ -31,3 +31,45 @@ def do_count(table_name, milvus_cli): ...@@ -31,3 +31,45 @@ def do_count(table_name, milvus_cli):
except Exception as e: except Exception as e:
LOGGER.error(f"Error attempting to count table {e}") LOGGER.error(f"Error attempting to count table {e}")
sys.exit(1) sys.exit(1)
def do_count_vpr(table_name, mysql_cli):
"""
Returns the total number of spk in the system
"""
if not table_name:
table_name = DEFAULT_TABLE
try:
num = mysql_cli.count_table(table_name)
return num
except Exception as e:
LOGGER.error(f"Error attempting to count table {e}")
sys.exit(1)
def do_list(table_name, mysql_cli):
"""
Returns the total records of vpr in the system
"""
if not table_name:
table_name = DEFAULT_TABLE
try:
spk_ids, audio_paths, _ = mysql_cli.list_vpr(table_name)
return spk_ids, audio_paths
except Exception as e:
LOGGER.error(f"Error attempting to count table {e}")
sys.exit(1)
def do_get(table_name, spk_id, mysql_cli):
"""
Returns the audio path by spk_id in the system
"""
if not table_name:
table_name = DEFAULT_TABLE
try:
audio_apth = mysql_cli.search_audio_vpr(table_name, spk_id)
return audio_apth
except Exception as e:
LOGGER.error(f"Error attempting to count table {e}")
sys.exit(1)
...@@ -32,3 +32,31 @@ def do_drop(table_name, milvus_cli, mysql_cli): ...@@ -32,3 +32,31 @@ def do_drop(table_name, milvus_cli, mysql_cli):
except Exception as e: except Exception as e:
LOGGER.error(f"Error attempting to drop table: {e}") LOGGER.error(f"Error attempting to drop table: {e}")
sys.exit(1) sys.exit(1)
def do_drop_vpr(table_name, mysql_cli):
"""
Delete the table of MySQL
"""
if not table_name:
table_name = DEFAULT_TABLE
try:
mysql_cli.delete_table(table_name)
return "OK"
except Exception as e:
LOGGER.error(f"Error attempting to drop table: {e}")
sys.exit(1)
def do_delete(table_name, spk_id, mysql_cli):
"""
Delete a record by spk_id in MySQL
"""
if not table_name:
table_name = DEFAULT_TABLE
try:
mysql_cli.delete_data_vpr(table_name, spk_id)
return "OK"
except Exception as e:
LOGGER.error(f"Error attempting to drop table: {e}")
sys.exit(1)
...@@ -82,3 +82,16 @@ def do_load(table_name, audio_dir, milvus_cli, mysql_cli): ...@@ -82,3 +82,16 @@ def do_load(table_name, audio_dir, milvus_cli, mysql_cli):
mysql_cli.create_mysql_table(table_name) mysql_cli.create_mysql_table(table_name)
mysql_cli.load_data_to_mysql(table_name, format_data(ids, names)) mysql_cli.load_data_to_mysql(table_name, format_data(ids, names))
return len(ids) return len(ids)
def do_enroll(table_name, spk_id, audio_path, mysql_cli):
"""
Import spk_id,audio_path,embedding to Mysql
"""
if not table_name:
table_name = DEFAULT_TABLE
embedding = get_audio_embedding(audio_path)
mysql_cli.create_mysql_table_vpr(table_name)
data = (spk_id, audio_path, str(embedding))
mysql_cli.load_data_to_mysql_vpr(table_name, data)
return "OK"
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
import sys import sys
import numpy
from config import DEFAULT_TABLE from config import DEFAULT_TABLE
from config import TOP_K from config import TOP_K
from encode import get_audio_embedding from encode import get_audio_embedding
...@@ -39,3 +40,26 @@ def do_search(host, table_name, audio_path, milvus_cli, mysql_cli): ...@@ -39,3 +40,26 @@ def do_search(host, table_name, audio_path, milvus_cli, mysql_cli):
except Exception as e: except Exception as e:
LOGGER.error(f"Error with search: {e}") LOGGER.error(f"Error with search: {e}")
sys.exit(1) sys.exit(1)
def do_search_vpr(host, table_name, audio_path, mysql_cli):
"""
Search the uploaded audio in MySQL
"""
try:
if not table_name:
table_name = DEFAULT_TABLE
emb = get_audio_embedding(audio_path)
emb = numpy.array(emb)
spk_ids, paths, vectors = mysql_cli.list_vpr(table_name)
scores = [numpy.dot(emb, x.astype(numpy.float64)) for x in vectors]
spk_ids = [str(x) for x in spk_ids]
paths = [str(x) for x in paths]
for i in range(len(paths)):
tmp = "http://" + str(host) + "/data?audio_path=" + str(paths[i])
paths[i] = tmp
scores[i] = scores[i] * 100
return spk_ids, paths, scores
except Exception as e:
LOGGER.error(f"Error with search: {e}")
sys.exit(1)
...@@ -11,8 +11,8 @@ ...@@ -11,8 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from audio_search import app
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
from main import app
from utils.utility import download from utils.utility import download
from utils.utility import unpack from utils.utility import unpack
...@@ -22,7 +22,7 @@ client = TestClient(app) ...@@ -22,7 +22,7 @@ client = TestClient(app)
def download_audio_data(): def download_audio_data():
""" """
download audio data Download audio data
""" """
url = "https://paddlespeech.bj.bcebos.com/vector/audio/example_audio.tar.gz" url = "https://paddlespeech.bj.bcebos.com/vector/audio/example_audio.tar.gz"
md5sum = "52ac69316c1aa1fdef84da7dd2c67b39" md5sum = "52ac69316c1aa1fdef84da7dd2c67b39"
...@@ -64,7 +64,7 @@ def test_count(): ...@@ -64,7 +64,7 @@ def test_count():
""" """
Returns the total number of vectors in the system Returns the total number of vectors in the system
""" """
response = client.get("audio/count") response = client.get("/audio/count")
assert response.status_code == 200 assert response.status_code == 200
assert response.json() == 20 assert response.json() == 20
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from fastapi.testclient import TestClient
from vpr_search import app
from utils.utility import download
from utils.utility import unpack
client = TestClient(app)
def download_audio_data():
"""
Download audio data
"""
url = "https://paddlespeech.bj.bcebos.com/vector/audio/example_audio.tar.gz"
md5sum = "52ac69316c1aa1fdef84da7dd2c67b39"
target_dir = "./"
filepath = download(url, md5sum, target_dir)
unpack(filepath, target_dir, True)
def test_drop():
"""
Delete the table of MySQL
"""
response = client.post("/vpr/drop")
assert response.status_code == 200
def test_enroll_local(spk: str, audio: str):
"""
Enroll the audio to MySQL
"""
response = client.post("/vpr/enroll/local?spk_id=" + spk +
"&audio_path=.%2Fexample_audio%2F" + audio + ".wav")
assert response.status_code == 200
assert response.json() == {
'status': True,
'msg': "Successfully enroll data!"
}
def test_search_local():
"""
Search the spk in MySQL by audio
"""
response = client.post(
"/vpr/recog/local?audio_path=.%2Fexample_audio%2Ftest.wav")
assert response.status_code == 200
def test_list():
"""
Get all records in MySQL
"""
response = client.get("/vpr/list")
assert response.status_code == 200
def test_data(spk: str):
"""
Get the audio file by spk_id in MySQL
"""
response = client.get("/vpr/data?spk_id=" + spk)
assert response.status_code == 200
def test_del(spk: str):
"""
Delete the record in MySQL by spk_id
"""
response = client.post("/vpr/del?spk_id=" + spk)
assert response.status_code == 200
def test_count():
"""
Get the number of spk in MySQL
"""
response = client.get("/vpr/count")
assert response.status_code == 200
if __name__ == "__main__":
download_audio_data()
test_enroll_local("spk1", "arms_strikes")
test_enroll_local("spk2", "sword_wielding")
test_enroll_local("spk3", "test")
test_list()
test_data("spk1")
test_count()
test_search_local()
test_del("spk1")
test_count()
test_search_local()
test_enroll_local("spk1", "arms_strikes")
test_count()
test_search_local()
test_drop()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import uvicorn
from config import UPLOAD_PATH
from fastapi import FastAPI
from fastapi import File
from fastapi import UploadFile
from logs import LOGGER
from mysql_helpers import MySQLHelper
from operations.count import do_count_vpr
from operations.count import do_get
from operations.count import do_list
from operations.drop import do_delete
from operations.drop import do_drop_vpr
from operations.load import do_enroll
from operations.search import do_search_vpr
from starlette.middleware.cors import CORSMiddleware
from starlette.requests import Request
from starlette.responses import FileResponse
app = FastAPI()
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"])
MYSQL_CLI = MySQLHelper()
# Mkdir 'tmp/audio-data'
if not os.path.exists(UPLOAD_PATH):
os.makedirs(UPLOAD_PATH)
LOGGER.info(f"Mkdir the path: {UPLOAD_PATH}")
@app.post('/vpr/enroll')
async def vpr_enroll(table_name: str=None,
spk_id: str=None,
audio: UploadFile=File(...)):
# Enroll the uploaded audio with spk-id into MySQL
try:
# Save the upload data to server.
content = await audio.read()
audio_path = os.path.join(UPLOAD_PATH, audio.filename)
with open(audio_path, "wb+") as f:
f.write(content)
do_enroll(table_name, spk_id, audio_path, MYSQL_CLI)
LOGGER.info(f"Successfully enrolled {spk_id} online!")
return {'status': True, 'msg': "Successfully enroll data!"}
except Exception as e:
LOGGER.error(e)
return {'status': False, 'msg': e}, 400
@app.post('/vpr/enroll/local')
async def vpr_enroll_local(table_name: str=None,
spk_id: str=None,
audio_path: str=None):
# Enroll the local audio with spk-id into MySQL
try:
do_enroll(table_name, spk_id, audio_path, MYSQL_CLI)
LOGGER.info(f"Successfully enrolled {spk_id} locally!")
return {'status': True, 'msg': "Successfully enroll data!"}
except Exception as e:
LOGGER.error(e)
return {'status': False, 'msg': e}, 400
@app.post('/vpr/recog')
async def vpr_recog(request: Request,
table_name: str=None,
audio: UploadFile=File(...)):
# Voice print recognition online
try:
# Save the upload data to server.
content = await audio.read()
query_audio_path = os.path.join(UPLOAD_PATH, audio.filename)
with open(query_audio_path, "wb+") as f:
f.write(content)
host = request.headers['host']
spk_ids, paths, scores = do_search_vpr(host, table_name,
query_audio_path, MYSQL_CLI)
for spk_id, path, score in zip(spk_ids, paths, scores):
LOGGER.info(f"spk {spk_id}, score {score}, audio path {path}, ")
res = dict(zip(spk_ids, zip(paths, scores)))
# Sort results by distance metric, closest distances first
res = sorted(res.items(), key=lambda item: item[1][1], reverse=True)
LOGGER.info("Successfully speaker recognition online!")
return res
except Exception as e:
LOGGER.error(e)
return {'status': False, 'msg': e}, 400
@app.post('/vpr/recog/local')
async def vpr_recog_local(request: Request,
table_name: str=None,
audio_path: str=None):
# Voice print recognition locally
try:
host = request.headers['host']
spk_ids, paths, scores = do_search_vpr(host, table_name, audio_path,
MYSQL_CLI)
for spk_id, path, score in zip(spk_ids, paths, scores):
LOGGER.info(f"spk {spk_id}, score {score}, audio path {path}, ")
res = dict(zip(spk_ids, zip(paths, scores)))
# Sort results by distance metric, closest distances first
res = sorted(res.items(), key=lambda item: item[1][1], reverse=True)
LOGGER.info("Successfully speaker recognition locally!")
return res
except Exception as e:
LOGGER.error(e)
return {'status': False, 'msg': e}, 400
@app.post('/vpr/del')
async def vpr_del(table_name: str=None, spk_id: str=None):
# Delete a record by spk_id in MySQL
try:
do_delete(table_name, spk_id, MYSQL_CLI)
LOGGER.info("Successfully delete a record by spk_id in MySQL")
return {'status': True, 'msg': "Successfully delete data!"}
except Exception as e:
LOGGER.error(e)
return {'status': False, 'msg': e}, 400
@app.get('/vpr/list')
async def vpr_list(table_name: str=None):
# Get all records in MySQL
try:
spk_ids, audio_paths = do_list(table_name, MYSQL_CLI)
for i in range(len(spk_ids)):
LOGGER.debug(f"spk {spk_ids[i]}, audio path {audio_paths[i]}")
LOGGER.info("Successfully list all records from mysql!")
return spk_ids, audio_paths
except Exception as e:
LOGGER.error(e)
return {'status': False, 'msg': e}, 400
@app.get('/vpr/data')
async def vpr_data(
table_name: str=None,
spk_id: str=None, ):
# Get the audio file from path by spk_id in MySQL
try:
audio_path = do_get(table_name, spk_id, MYSQL_CLI)
LOGGER.info(f"Successfully get audio path {audio_path}!")
return FileResponse(audio_path)
except Exception as e:
LOGGER.error(e)
return {'status': False, 'msg': e}, 400
@app.get('/vpr/count')
async def vpr_count(table_name: str=None):
# Get the total number of spk in MySQL
try:
num = do_count_vpr(table_name, MYSQL_CLI)
LOGGER.info("Successfully count the number of spk!")
return num
except Exception as e:
LOGGER.error(e)
return {'status': False, 'msg': e}, 400
@app.post('/vpr/drop')
async def drop_tables(table_name: str=None):
# Delete the table of MySQL
try:
do_drop_vpr(table_name, MYSQL_CLI)
LOGGER.info("Successfully drop tables in MySQL!")
return {'status': True, 'msg': "Successfully drop tables!"}
except Exception as e:
LOGGER.error(e)
return {'status': False, 'msg': e}, 400
@app.get('/data')
def audio_path(audio_path):
# Get the audio file from path
try:
LOGGER.info(f"Successfully get audio: {audio_path}")
return FileResponse(audio_path)
except Exception as e:
LOGGER.error(f"get audio error: {e}")
return {'status': False, 'msg': e}, 400
if __name__ == '__main__':
uvicorn.run(app=app, host='0.0.0.0', port=8002)
...@@ -7,4 +7,4 @@ paddlespeech asr --input ./zh.wav ...@@ -7,4 +7,4 @@ paddlespeech asr --input ./zh.wav
# asr + punc # asr + punc
paddlespeech asr --input ./zh.wav | paddlespeech text --task punc paddlespeech asr --input ./zh.wav | paddlespeech text --task punc
\ No newline at end of file
...@@ -85,6 +85,10 @@ wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav https://paddlespee ...@@ -85,6 +85,10 @@ wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav https://paddlespee
- 命令行 (推荐使用) - 命令行 (推荐使用)
``` ```
paddlespeech_client asr --server_ip 127.0.0.1 --port 8090 --input ./zh.wav paddlespeech_client asr --server_ip 127.0.0.1 --port 8090 --input ./zh.wav
# 流式ASR
paddlespeech_client asr_online --server_ip 127.0.0.1 --port 8091 --input ./zh.wav
``` ```
使用帮助: 使用帮助:
...@@ -191,7 +195,7 @@ wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav https://paddlespee ...@@ -191,7 +195,7 @@ wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav https://paddlespee
``` ```
### 5. CLS 客户端使用方法 ### 6. CLS 客户端使用方法
**注意:** 初次使用客户端时响应时间会略长 **注意:** 初次使用客户端时响应时间会略长
- 命令行 (推荐使用) - 命令行 (推荐使用)
``` ```
......
...@@ -84,7 +84,7 @@ setuptools.setup( ...@@ -84,7 +84,7 @@ setuptools.setup(
install_requires=[ install_requires=[
'numpy >= 1.15.0', 'scipy >= 1.0.0', 'resampy >= 0.2.2', 'numpy >= 1.15.0', 'scipy >= 1.0.0', 'resampy >= 0.2.2',
'soundfile >= 0.9.0', 'colorlog', 'dtaidistance == 2.3.1', 'pathos' 'soundfile >= 0.9.0', 'colorlog', 'dtaidistance == 2.3.1', 'pathos'
], ],
extras_require={ extras_require={
'test': [ 'test': [
'nose', 'librosa==0.8.1', 'soundfile==0.10.3.post1', 'nose', 'librosa==0.8.1', 'soundfile==0.10.3.post1',
......
...@@ -79,7 +79,6 @@ class U2Infer(): ...@@ -79,7 +79,6 @@ class U2Infer():
ilen = paddle.to_tensor(feat.shape[0]) ilen = paddle.to_tensor(feat.shape[0])
xs = paddle.to_tensor(feat, dtype='float32').unsqueeze(axis=0) xs = paddle.to_tensor(feat, dtype='float32').unsqueeze(axis=0)
decode_config = self.config.decode decode_config = self.config.decode
result_transcripts = self.model.decode( result_transcripts = self.model.decode(
xs, xs,
...@@ -129,6 +128,7 @@ if __name__ == "__main__": ...@@ -129,6 +128,7 @@ if __name__ == "__main__":
args = parser.parse_args() args = parser.parse_args()
config = CfgNode(new_allowed=True) config = CfgNode(new_allowed=True)
if args.config: if args.config:
config.merge_from_file(args.config) config.merge_from_file(args.config)
if args.decode_cfg: if args.decode_cfg:
......
...@@ -12,9 +12,11 @@ ...@@ -12,9 +12,11 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import argparse import argparse
import asyncio
import base64 import base64
import io import io
import json import json
import logging
import os import os
import random import random
import time import time
...@@ -28,6 +30,7 @@ from ..executor import BaseExecutor ...@@ -28,6 +30,7 @@ from ..executor import BaseExecutor
from ..util import cli_client_register from ..util import cli_client_register
from ..util import stats_wrapper from ..util import stats_wrapper
from paddlespeech.cli.log import logger from paddlespeech.cli.log import logger
from paddlespeech.server.tests.asr.online.websocket_client import ASRAudioHandler
from paddlespeech.server.utils.audio_process import wav2pcm from paddlespeech.server.utils.audio_process import wav2pcm
from paddlespeech.server.utils.util import wav2base64 from paddlespeech.server.utils.util import wav2base64
...@@ -230,6 +233,76 @@ class ASRClientExecutor(BaseExecutor): ...@@ -230,6 +233,76 @@ class ASRClientExecutor(BaseExecutor):
return res return res
@cli_client_register(
name='paddlespeech_client.asr_online',
description='visit asr online service')
class ASRClientExecutor(BaseExecutor):
def __init__(self):
super(ASRClientExecutor, self).__init__()
self.parser = argparse.ArgumentParser(
prog='paddlespeech_client.asr', add_help=True)
self.parser.add_argument(
'--server_ip', type=str, default='127.0.0.1', help='server ip')
self.parser.add_argument(
'--port', type=int, default=8091, help='server port')
self.parser.add_argument(
'--input',
type=str,
default=None,
help='Audio file to be recognized',
required=True)
self.parser.add_argument(
'--sample_rate', type=int, default=16000, help='audio sample rate')
self.parser.add_argument(
'--lang', type=str, default="zh_cn", help='language')
self.parser.add_argument(
'--audio_format', type=str, default="wav", help='audio format')
def execute(self, argv: List[str]) -> bool:
args = self.parser.parse_args(argv)
input_ = args.input
server_ip = args.server_ip
port = args.port
sample_rate = args.sample_rate
lang = args.lang
audio_format = args.audio_format
try:
time_start = time.time()
res = self(
input=input_,
server_ip=server_ip,
port=port,
sample_rate=sample_rate,
lang=lang,
audio_format=audio_format)
time_end = time.time()
logger.info(res.json())
logger.info("Response time %f s." % (time_end - time_start))
return True
except Exception as e:
logger.error("Failed to speech recognition.")
return False
@stats_wrapper
def __call__(self,
input: str,
server_ip: str="127.0.0.1",
port: int=8091,
sample_rate: int=16000,
lang: str="zh_cn",
audio_format: str="wav"):
"""
Python API to call an executor.
"""
logging.basicConfig(level=logging.INFO)
logging.info("asr websocket client start")
handler = ASRAudioHandler(server_ip, port)
loop = asyncio.get_event_loop()
loop.run_until_complete(handler.run(input))
logging.info("asr websocket client finished")
@cli_client_register( @cli_client_register(
name='paddlespeech_client.cls', description='visit cls service') name='paddlespeech_client.cls', description='visit cls service')
class CLSClientExecutor(BaseExecutor): class CLSClientExecutor(BaseExecutor):
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
# SERVER SETTING # # SERVER SETTING #
################################################################################# #################################################################################
host: 0.0.0.0 host: 0.0.0.0
port: 8091 port: 8090
# The task format in the engin_list is: <speech task>_<engine type> # The task format in the engin_list is: <speech task>_<engine type>
# task choices = ['asr_online', 'tts_online'] # task choices = ['asr_online', 'tts_online']
......
([简体中文](./README_cn.md)|English)
# 语音服务
## 介绍
本文档介绍如何使用流式ASR的三种不同客户端:网页、麦克风、Python模拟流式服务。
## 使用方法
### 1. 安装
请看 [安装文档](https://github.com/PaddlePaddle/PaddleSpeech/blob/develop/docs/source/install.md).
推荐使用 **paddlepaddle 2.2.1** 或以上版本。
你可以从 medium,hard 三中方式中选择一种方式安装 PaddleSpeech。
### 2. 准备测试文件
这个 ASR client 的输入应该是一个 WAV 文件(`.wav`),并且采样率必须与模型的采样率相同。
可以下载此 ASR client的示例音频:
```bash
wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav https://paddlespeech.bj.bcebos.com/PaddleAudio/en.wav
```
### 2. 流式 ASR 客户端使用方法
- Python模拟流式服务命令行
```
# 流式ASR
paddlespeech_client asr_online --server_ip 127.0.0.1 --port 8091 --input ./zh.wav
```
- 麦克风
```
# 直接调用麦克风设备
python microphone_client.py
```
- 网页
```
# 进入web目录后参考相关readme.md
```
#!/usr/bin/env python3 #!/usr/bin/env python3
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2021 Mobvoi Inc. All Rights Reserved. # Copyright 2021 Mobvoi Inc. All Rights Reserved.
# Author: zhendong.peng@mobvoi.com (Zhendong Peng) # Author: zhendong.peng@mobvoi.com (Zhendong Peng)
import argparse import argparse
from flask import Flask, render_template from flask import Flask
from flask import render_template
parser = argparse.ArgumentParser(description='training your network') parser = argparse.ArgumentParser(description='training your network')
parser.add_argument('--port', default=19999, type=int, help='port id') parser.add_argument('--port', default=19999, type=int, help='port id')
...@@ -14,9 +13,11 @@ args = parser.parse_args() ...@@ -14,9 +13,11 @@ args = parser.parse_args()
app = Flask(__name__) app = Flask(__name__)
@app.route('/') @app.route('/')
def index(): def index():
return render_template('index.html') return render_template('index.html')
if __name__ == '__main__': if __name__ == '__main__':
app.run(host='0.0.0.0', port=args.port, debug=True) app.run(host='0.0.0.0', port=args.port, debug=True)
# paddlespeech serving 网页Demo
- 感谢[wenet](https://github.com/wenet-e2e/wenet)团队的前端demo代码.
## 使用方法
### 1. 在本地电脑启动网页服务
```
python app.py
```
### 2. 本地电脑浏览器
在浏览器中输入127.0.0.1:19999 即可看到相关网页Demo。
![图片](./paddle_web_demo.png)
...@@ -15,8 +15,10 @@ ...@@ -15,8 +15,10 @@
# -*- coding: UTF-8 -*- # -*- coding: UTF-8 -*-
import argparse import argparse
import asyncio import asyncio
import codecs
import json import json
import logging import logging
import os
import numpy as np import numpy as np
import soundfile import soundfile
...@@ -32,34 +34,30 @@ class ASRAudioHandler: ...@@ -32,34 +34,30 @@ class ASRAudioHandler:
def read_wave(self, wavfile_path: str): def read_wave(self, wavfile_path: str):
samples, sample_rate = soundfile.read(wavfile_path, dtype='int16') samples, sample_rate = soundfile.read(wavfile_path, dtype='int16')
x_len = len(samples) x_len = len(samples)
chunk_stride = 40 * 16 #40ms, sample_rate = 16kHz # chunk_stride = 40 * 16 #40ms, sample_rate = 16kHz
chunk_size = 80 * 16 #80ms, sample_rate = 16kHz chunk_size = 80 * 16 #80ms, sample_rate = 16kHz
if (x_len - chunk_size) % chunk_stride != 0: if x_len % chunk_size != 0:
padding_len_x = chunk_stride - (x_len - chunk_size) % chunk_stride padding_len_x = chunk_size - x_len % chunk_size
else: else:
padding_len_x = 0 padding_len_x = 0
padding = np.zeros((padding_len_x), dtype=samples.dtype) padding = np.zeros((padding_len_x), dtype=samples.dtype)
padded_x = np.concatenate([samples, padding], axis=0) padded_x = np.concatenate([samples, padding], axis=0)
num_chunk = (x_len + padding_len_x - chunk_size) / chunk_stride + 1 assert (x_len + padding_len_x) % chunk_size == 0
num_chunk = (x_len + padding_len_x) / chunk_size
num_chunk = int(num_chunk) num_chunk = int(num_chunk)
for i in range(0, num_chunk): for i in range(0, num_chunk):
start = i * chunk_stride start = i * chunk_size
end = start + chunk_size end = start + chunk_size
x_chunk = padded_x[start:end] x_chunk = padded_x[start:end]
yield x_chunk yield x_chunk
async def run(self, wavfile_path: str): async def run(self, wavfile_path: str):
logging.info("send a message to the server") logging.info("send a message to the server")
# 读取音频
# self.read_wave()
# 发送 websocket 的 handshake 协议头
async with websockets.connect(self.url) as ws: async with websockets.connect(self.url) as ws:
# server 端已经接收到 handshake 协议头
# 发送开始指令
audio_info = json.dumps( audio_info = json.dumps(
{ {
"name": "test.wav", "name": "test.wav",
...@@ -77,8 +75,10 @@ class ASRAudioHandler: ...@@ -77,8 +75,10 @@ class ASRAudioHandler:
for chunk_data in self.read_wave(wavfile_path): for chunk_data in self.read_wave(wavfile_path):
await ws.send(chunk_data.tobytes()) await ws.send(chunk_data.tobytes())
msg = await ws.recv() msg = await ws.recv()
msg = json.loads(msg)
logging.info("receive msg={}".format(msg)) logging.info("receive msg={}".format(msg))
result = msg
# finished # finished
audio_info = json.dumps( audio_info = json.dumps(
{ {
...@@ -91,16 +91,35 @@ class ASRAudioHandler: ...@@ -91,16 +91,35 @@ class ASRAudioHandler:
separators=(',', ': ')) separators=(',', ': '))
await ws.send(audio_info) await ws.send(audio_info)
msg = await ws.recv() msg = await ws.recv()
msg = json.loads(msg)
logging.info("receive msg={}".format(msg)) logging.info("receive msg={}".format(msg))
return result
def main(args): def main(args):
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
logging.info("asr websocket client start") logging.info("asr websocket client start")
handler = ASRAudioHandler("127.0.0.1", 8091) handler = ASRAudioHandler("127.0.0.1", 8090)
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
loop.run_until_complete(handler.run(args.wavfile))
logging.info("asr websocket client finished") # support to process single audio file
if args.wavfile and os.path.exists(args.wavfile):
logging.info(f"start to process the wavscp: {args.wavfile}")
result = loop.run_until_complete(handler.run(args.wavfile))
result = result["asr_results"]
logging.info(f"asr websocket client finished : {result}")
# support to process batch audios from wav.scp
if args.wavscp and os.path.exists(args.wavscp):
logging.info(f"start to process the wavscp: {args.wavscp}")
with codecs.open(args.wavscp, 'r', encoding='utf-8') as f,\
codecs.open("result.txt", 'w', encoding='utf-8') as w:
for line in f:
utt_name, utt_path = line.strip().split()
result = loop.run_until_complete(handler.run(utt_path))
result = result["asr_results"]
w.write(f"{utt_name} {result}\n")
if __name__ == "__main__": if __name__ == "__main__":
...@@ -110,6 +129,8 @@ if __name__ == "__main__": ...@@ -110,6 +129,8 @@ if __name__ == "__main__":
action="store", action="store",
help="wav file path ", help="wav file path ",
default="./16_audio.wav") default="./16_audio.wav")
parser.add_argument(
"--wavscp", type=str, default=None, help="The batch audios dict text")
args = parser.parse_args() args = parser.parse_args()
main(args) main(args)
...@@ -24,15 +24,38 @@ class Frame(object): ...@@ -24,15 +24,38 @@ class Frame(object):
class ChunkBuffer(object): class ChunkBuffer(object):
def __init__(self, def __init__(self,
frame_duration_ms=80, window_n=7,
shift_ms=40, shift_n=4,
window_ms=20,
shift_ms=10,
sample_rate=16000, sample_rate=16000,
sample_width=2): sample_width=2):
self.sample_rate = sample_rate """audio sample data point buffer
self.frame_duration_ms = frame_duration_ms
Args:
window_n (int, optional): decode window frame length. Defaults to 7 frame.
shift_n (int, optional): decode shift frame length. Defaults to 4 frame.
window_ms (int, optional): frame length, ms. Defaults to 20 ms.
shift_ms (int, optional): shift length, ms. Defaults to 10 ms.
sample_rate (int, optional): audio sample rate. Defaults to 16000.
sample_width (int, optional): sample point bytes. Defaults to 2 bytes.
"""
self.window_n = window_n
self.shift_n = shift_n
self.window_ms = window_ms
self.shift_ms = shift_ms self.shift_ms = shift_ms
self.remained_audio = b'' self.sample_rate = sample_rate
self.sample_width = sample_width # int16 = 2; float32 = 4 self.sample_width = sample_width # int16 = 2; float32 = 4
self.remained_audio = b''
self.window_sec = float((self.window_n - 1) * self.shift_ms +
self.window_ms) / 1000.0
self.shift_sec = float(self.shift_n * self.shift_ms / 1000.0)
self.window_bytes = int(self.window_sec * self.sample_rate *
self.sample_width)
self.shift_bytes = int(self.shift_sec * self.sample_rate *
self.sample_width)
def frame_generator(self, audio): def frame_generator(self, audio):
"""Generates audio frames from PCM audio data. """Generates audio frames from PCM audio data.
...@@ -43,17 +66,13 @@ class ChunkBuffer(object): ...@@ -43,17 +66,13 @@ class ChunkBuffer(object):
audio = self.remained_audio + audio audio = self.remained_audio + audio
self.remained_audio = b'' self.remained_audio = b''
n = int(self.sample_rate * (self.frame_duration_ms / 1000.0) *
self.sample_width)
shift_n = int(self.sample_rate * (self.shift_ms / 1000.0) *
self.sample_width)
offset = 0 offset = 0
timestamp = 0.0 timestamp = 0.0
duration = (float(n) / self.sample_rate) / self.sample_width
shift_duration = (float(shift_n) / self.sample_rate) / self.sample_width while offset + self.window_bytes <= len(audio):
while offset + n <= len(audio): yield Frame(audio[offset:offset + self.window_bytes], timestamp,
yield Frame(audio[offset:offset + n], timestamp, duration) self.window_sec)
timestamp += shift_duration timestamp += self.shift_sec
offset += shift_n offset += self.shift_bytes
self.remained_audio += audio[offset:] self.remained_audio += audio[offset:]
...@@ -36,6 +36,10 @@ async def websocket_endpoint(websocket: WebSocket): ...@@ -36,6 +36,10 @@ async def websocket_endpoint(websocket: WebSocket):
# init buffer # init buffer
chunk_buffer_conf = asr_engine.config.chunk_buffer_conf chunk_buffer_conf = asr_engine.config.chunk_buffer_conf
chunk_buffer = ChunkBuffer( chunk_buffer = ChunkBuffer(
window_n=7,
shift_n=4,
window_ms=20,
shift_ms=10,
sample_rate=chunk_buffer_conf['sample_rate'], sample_rate=chunk_buffer_conf['sample_rate'],
sample_width=chunk_buffer_conf['sample_width']) sample_width=chunk_buffer_conf['sample_width'])
# init vad # init vad
...@@ -75,11 +79,6 @@ async def websocket_endpoint(websocket: WebSocket): ...@@ -75,11 +79,6 @@ async def websocket_endpoint(websocket: WebSocket):
elif "bytes" in message: elif "bytes" in message:
message = message["bytes"] message = message["bytes"]
# vad for input bytes audio
vad.add_audio(message)
message = b''.join(f for f in vad.vad_collector()
if f is not None)
engine_pool = get_engine_pool() engine_pool = get_engine_pool()
asr_engine = engine_pool['asr'] asr_engine = engine_pool['asr']
asr_results = "" asr_results = ""
......
...@@ -11,6 +11,7 @@ ...@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# Modified from speechbrain(https://github.com/speechbrain/speechbrain)
""" """
This script contains basic functions used for speaker diarization. This script contains basic functions used for speaker diarization.
This script has an optional dependency on open source sklearn library. This script has an optional dependency on open source sklearn library.
...@@ -19,11 +20,11 @@ A few sklearn functions are modified in this script as per requirement. ...@@ -19,11 +20,11 @@ A few sklearn functions are modified in this script as per requirement.
import argparse import argparse
import copy import copy
import warnings import warnings
from distutils.util import strtobool
import numpy as np import numpy as np
import scipy import scipy
import sklearn import sklearn
from distutils.util import strtobool
from scipy import linalg from scipy import linalg
from scipy import sparse from scipy import sparse
from scipy.sparse.csgraph import connected_components from scipy.sparse.csgraph import connected_components
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
from dataclasses import dataclass from dataclasses import dataclass
from dataclasses import fields from dataclasses import fields
from paddle.io import Dataset from paddle.io import Dataset
from paddleaudio import load as load_audio from paddleaudio import load as load_audio
......
...@@ -12,9 +12,9 @@ ...@@ -12,9 +12,9 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import json import json
from dataclasses import dataclass from dataclasses import dataclass
from dataclasses import fields from dataclasses import fields
from paddle.io import Dataset from paddle.io import Dataset
from paddleaudio import load as load_audio from paddleaudio import load as load_audio
......
cmake_minimum_required(VERSION 3.14 FATAL_ERROR) cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
add_subdirectory(feat) add_subdirectory(ds2_ol)
add_subdirectory(nnet) add_subdirectory(dev)
add_subdirectory(decoder) \ No newline at end of file
add_subdirectory(glog)
\ No newline at end of file
# Examples # Examples for SpeechX
* dev - for speechx developer, using for test.
* ngram - using to build NGram ARPA lm.
* ds2_ol - ds2 streaming test under `aishell-1` test dataset.
The entrypoint is `ds2_ol/aishell/run.sh`
* glog - glog usage
* feat - mfcc, linear
* nnet - ds2 nn
* decoder - online decoder to work as offline
## How to run ## How to run
`run.sh` is the entry point. `run.sh` is the entry point.
Example to play `decoder`: Example to play `ds2_ol`:
``` ```
pushd decoder pushd ds2_ol/aishell
bash run.sh bash run.sh
``` ```
## Display Model with [Netron](https://github.com/lutzroeder/netron)
```
pip install netron
netron exp/deepspeech2_online/checkpoints/avg_1.jit.pdmodel --port 8022 --host 10.21.55.20
```
../../../utils
\ No newline at end of file
cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
add_executable(offline_decoder_sliding_chunk_main ${CMAKE_CURRENT_SOURCE_DIR}/offline_decoder_sliding_chunk_main.cc)
target_include_directories(offline_decoder_sliding_chunk_main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(offline_decoder_sliding_chunk_main PUBLIC nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util ${DEPS})
add_executable(offline_decoder_main ${CMAKE_CURRENT_SOURCE_DIR}/offline_decoder_main.cc)
target_include_directories(offline_decoder_main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(offline_decoder_main PUBLIC nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util ${DEPS})
add_executable(offline_wfst_decoder_main ${CMAKE_CURRENT_SOURCE_DIR}/offline_wfst_decoder_main.cc)
target_include_directories(offline_wfst_decoder_main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(offline_wfst_decoder_main PUBLIC nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util kaldi-decoder ${DEPS})
add_executable(decoder_test_main ${CMAKE_CURRENT_SOURCE_DIR}/decoder_test_main.cc)
target_include_directories(decoder_test_main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(decoder_test_main PUBLIC nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util ${DEPS})
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// todo refactor, repalce with gtest
#include "base/flags.h"
#include "base/log.h"
#include "decoder/ctc_beam_search_decoder.h"
#include "frontend/audio/data_cache.h"
#include "kaldi/util/table-types.h"
#include "nnet/decodable.h"
#include "nnet/paddle_nnet.h"
DEFINE_string(feature_respecifier, "", "feature matrix rspecifier");
DEFINE_string(model_path, "avg_1.jit.pdmodel", "paddle nnet model");
DEFINE_string(param_path, "avg_1.jit.pdiparams", "paddle nnet model param");
DEFINE_string(dict_file, "vocab.txt", "vocabulary of lm");
DEFINE_string(lm_path, "lm.klm", "language model");
DEFINE_int32(chunk_size, 35, "feat chunk size");
using kaldi::BaseFloat;
using kaldi::Matrix;
using std::vector;
// test decoder by feeding speech feature, deprecated.
int main(int argc, char* argv[]) {
gflags::ParseCommandLineFlags(&argc, &argv, false);
google::InitGoogleLogging(argv[0]);
kaldi::SequentialBaseFloatMatrixReader feature_reader(
FLAGS_feature_respecifier);
std::string model_graph = FLAGS_model_path;
std::string model_params = FLAGS_param_path;
std::string dict_file = FLAGS_dict_file;
std::string lm_path = FLAGS_lm_path;
int32 chunk_size = FLAGS_chunk_size;
LOG(INFO) << "model path: " << model_graph;
LOG(INFO) << "model param: " << model_params;
LOG(INFO) << "dict path: " << dict_file;
LOG(INFO) << "lm path: " << lm_path;
LOG(INFO) << "chunk size (frame): " << chunk_size;
int32 num_done = 0, num_err = 0;
// frontend + nnet is decodable
ppspeech::ModelOptions model_opts;
model_opts.model_path = model_graph;
model_opts.params_path = model_params;
std::shared_ptr<ppspeech::PaddleNnet> nnet(
new ppspeech::PaddleNnet(model_opts));
std::shared_ptr<ppspeech::DataCache> raw_data(new ppspeech::DataCache());
std::shared_ptr<ppspeech::Decodable> decodable(
new ppspeech::Decodable(nnet, raw_data));
LOG(INFO) << "Init decodeable.";
// init decoder
ppspeech::CTCBeamSearchOptions opts;
opts.dict_file = dict_file;
opts.lm_path = lm_path;
ppspeech::CTCBeamSearch decoder(opts);
LOG(INFO) << "Init decoder.";
decoder.InitDecoder();
for (; !feature_reader.Done(); feature_reader.Next()) {
string utt = feature_reader.Key();
const kaldi::Matrix<BaseFloat> feature = feature_reader.Value();
LOG(INFO) << "utt: " << utt;
// feat dim
raw_data->SetDim(feature.NumCols());
LOG(INFO) << "dim: " << raw_data->Dim();
int32 row_idx = 0;
int32 num_chunks = feature.NumRows() / chunk_size;
LOG(INFO) << "n chunks: " << num_chunks;
for (int chunk_idx = 0; chunk_idx < num_chunks; ++chunk_idx) {
// feat chunk
kaldi::Vector<kaldi::BaseFloat> feature_chunk(chunk_size *
feature.NumCols());
for (int row_id = 0; row_id < chunk_size; ++row_id) {
kaldi::SubVector<kaldi::BaseFloat> feat_one_row(feature,
row_idx);
kaldi::SubVector<kaldi::BaseFloat> f_chunk_tmp(
feature_chunk.Data() + row_id * feature.NumCols(),
feature.NumCols());
f_chunk_tmp.CopyFromVec(feat_one_row);
row_idx++;
}
// feed to raw cache
raw_data->Accept(feature_chunk);
if (chunk_idx == num_chunks - 1) {
raw_data->SetFinished();
}
// decode step
decoder.AdvanceDecode(decodable);
}
std::string result;
result = decoder.GetFinalBestPath();
KALDI_LOG << " the result of " << utt << " is " << result;
decodable->Reset();
decoder.Reset();
++num_done;
}
KALDI_LOG << "Done " << num_done << " utterances, " << num_err
<< " with errors.";
return (num_done != 0 ? 0 : 1);
}
#!/bin/bash
set +x
set -e
. path.sh
# 1. compile
if [ ! -d ${SPEECHX_EXAMPLES} ]; then
pushd ${SPEECHX_ROOT}
bash build.sh
popd
fi
# 2. download model
if [ ! -d ../paddle_asr_model ]; then
wget -c https://paddlespeech.bj.bcebos.com/s2t/paddle_asr_online/paddle_asr_model.tar.gz
tar xzfv paddle_asr_model.tar.gz
mv ./paddle_asr_model ../
# produce wav scp
echo "utt1 " $PWD/../paddle_asr_model/BAC009S0764W0290.wav > ../paddle_asr_model/wav.scp
fi
model_dir=../paddle_asr_model
feat_wspecifier=./feats.ark
cmvn=./cmvn.ark
export GLOG_logtostderr=1
# 3. gen linear feat
linear_spectrogram_main \
--wav_rspecifier=scp:$model_dir/wav.scp \
--feature_wspecifier=ark,t:$feat_wspecifier \
--cmvn_write_path=$cmvn
# 4. run decoder
offline_decoder_main \
--feature_respecifier=ark:$feat_wspecifier \
--model_path=$model_dir/avg_1.jit.pdmodel \
--param_path=$model_dir/avg_1.jit.pdparams \
--dict_file=$model_dir/vocab.txt \
--lm_path=$model_dir/avg_1.jit.klm
cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
add_subdirectory(glog)
# This contains the locations of binarys build required for running the examples. # This contains the locations of binarys build required for running the examples.
SPEECHX_ROOT=$PWD/../.. SPEECHX_ROOT=$PWD/../../../
SPEECHX_EXAMPLES=$SPEECHX_ROOT/build/examples
SPEECHX_TOOLS=$SPEECHX_ROOT/tools SPEECHX_TOOLS=$SPEECHX_ROOT/tools
TOOLS_BIN=$SPEECHX_TOOLS/valgrind/install/bin TOOLS_BIN=$SPEECHX_TOOLS/valgrind/install/bin
[ -d $SPEECHX_EXAMPLES ] || { echo "Error: 'build/examples' directory not found. please ensure that the project build successfully"; }
export LC_AL=C SPEECHX_EXAMPLES=$SPEECHX_ROOT/build/examples
[ -d $SPEECHX_EXAMPLES ] || { echo "Error: 'build/examples' directory not found. please ensure that the project build successfully"; }
SPEECHX_BIN=$SPEECHX_EXAMPLES/feat SPEECHX_BIN=$SPEECHX_EXAMPLES/dev/glog
export PATH=$PATH:$SPEECHX_BIN:$TOOLS_BIN export PATH=$PATH:$SPEECHX_BIN:$TOOLS_BIN
export LC_AL=C
cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
add_subdirectory(feat)
add_subdirectory(nnet)
add_subdirectory(decoder)
\ No newline at end of file
# Deepspeech2 Streaming
Please go to `aishell` to test it.
* aishell
Deepspeech2 Streaming Decoding under aishell dataset.
The below is for developing and offline testing:
* nnet
* feat
* decoder
# Aishell - Deepspeech2 Streaming
## CTC Prefix Beam Search w/o LM
```
Overall -> 16.14 % N=104612 C=88190 S=16110 D=312 I=465
Mandarin -> 16.14 % N=104612 C=88190 S=16110 D=312 I=465
Other -> 0.00 % N=0 C=0 S=0 D=0 I=0
```
## CTC Prefix Beam Search w LM
```
```
## CTC WFST
```
```
\ No newline at end of file
# This contains the locations of binarys build required for running the examples. # This contains the locations of binarys build required for running the examples.
SPEECHX_ROOT=$PWD/../.. SPEECHX_ROOT=$PWD/../../../
SPEECHX_EXAMPLES=$SPEECHX_ROOT/build/examples SPEECHX_EXAMPLES=$SPEECHX_ROOT/build/examples
SPEECHX_TOOLS=$SPEECHX_ROOT/tools SPEECHX_TOOLS=$SPEECHX_ROOT/tools
...@@ -10,5 +10,5 @@ TOOLS_BIN=$SPEECHX_TOOLS/valgrind/install/bin ...@@ -10,5 +10,5 @@ TOOLS_BIN=$SPEECHX_TOOLS/valgrind/install/bin
export LC_AL=C export LC_AL=C
SPEECHX_BIN=$SPEECHX_EXAMPLES/decoder:$SPEECHX_EXAMPLES/feat SPEECHX_BIN=$SPEECHX_EXAMPLES/ds2_ol/decoder:$SPEECHX_EXAMPLES/ds2_ol/feat
export PATH=$PATH:$SPEECHX_BIN:$TOOLS_BIN export PATH=$PATH:$SPEECHX_BIN:$TOOLS_BIN
...@@ -4,6 +4,9 @@ set -e ...@@ -4,6 +4,9 @@ set -e
. path.sh . path.sh
nj=40
# 1. compile # 1. compile
if [ ! -d ${SPEECHX_EXAMPLES} ]; then if [ ! -d ${SPEECHX_EXAMPLES} ]; then
pushd ${SPEECHX_ROOT} pushd ${SPEECHX_ROOT}
...@@ -11,52 +14,59 @@ if [ ! -d ${SPEECHX_EXAMPLES} ]; then ...@@ -11,52 +14,59 @@ if [ ! -d ${SPEECHX_EXAMPLES} ]; then
popd popd
fi fi
# input
# 2. download model
if [ ! -d ../paddle_asr_model ]; then
wget -c https://paddlespeech.bj.bcebos.com/s2t/paddle_asr_online/paddle_asr_model.tar.gz
tar xzfv paddle_asr_model.tar.gz
mv ./paddle_asr_model ../
# produce wav scp
echo "utt1 " $PWD/../paddle_asr_model/BAC009S0764W0290.wav > ../paddle_asr_model/wav.scp
fi
mkdir -p data mkdir -p data
data=$PWD/data data=$PWD/data
ckpt_dir=$data/model
model_dir=$ckpt_dir/exp/deepspeech2_online/checkpoints/
vocb_dir=$ckpt_dir/data/lang_char/
# output
mkdir -p exp
exp=$PWD/exp
aishell_wav_scp=aishell_test.scp aishell_wav_scp=aishell_test.scp
if [ ! -d $data/test ]; then if [ ! -d $data/test ]; then
pushd $data
wget -c https://paddlespeech.bj.bcebos.com/s2t/paddle_asr_online/aishell_test.zip wget -c https://paddlespeech.bj.bcebos.com/s2t/paddle_asr_online/aishell_test.zip
unzip -d $data aishell_test.zip unzip aishell_test.zip
popd
realpath $data/test/*/*.wav > $data/wavlist realpath $data/test/*/*.wav > $data/wavlist
awk -F '/' '{ print $(NF) }' $data/wavlist | awk -F '.' '{ print $1 }' > $data/utt_id awk -F '/' '{ print $(NF) }' $data/wavlist | awk -F '.' '{ print $1 }' > $data/utt_id
paste $data/utt_id $data/wavlist > $data/$aishell_wav_scp paste $data/utt_id $data/wavlist > $data/$aishell_wav_scp
fi fi
model_dir=$PWD/aishell_ds2_online_model
if [ ! -d $model_dir ]; then if [ ! -d $ckpt_dir ]; then
mkdir -p $model_dir mkdir -p $ckpt_dir
wget -P $model_dir -c https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz wget -P $ckpt_dir -c https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz
tar xzfv $model_dir/asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz -C $model_dir tar xzfv $model_dir/asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz -C $ckpt_dir
fi
lm=$data/zh_giga.no_cna_cmn.prune01244.klm
if [ ! -f $lm ]; then
pushd $data
wget -c https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm
popd
fi fi
# 3. make feature # 3. make feature
aishell_online_model=$model_dir/exp/deepspeech2_online/checkpoints
lm_model_dir=../paddle_asr_model
label_file=./aishell_result label_file=./aishell_result
wer=./aishell_wer wer=./aishell_wer
nj=40
export GLOG_logtostderr=1 export GLOG_logtostderr=1
#./local/split_data.sh $data $data/$aishell_wav_scp $aishell_wav_scp $nj
data=$PWD/data
# 3. gen linear feat # 3. gen linear feat
cmvn=$PWD/cmvn.ark cmvn=$PWD/cmvn.ark
cmvn_json2binary_main --json_file=$model_dir/data/mean_std.json --cmvn_write_path=$cmvn cmvn-json2kaldi --json_file=$ckpt_dir/data/mean_std.json --cmvn_write_path=$cmvn
utils/run.pl JOB=1:$nj $data/split${nj}/JOB/feat_log \
linear_spectrogram_without_db_norm_main \ ./local/split_data.sh $data $data/$aishell_wav_scp $aishell_wav_scp $nj
utils/run.pl JOB=1:$nj $data/split${nj}/JOB/feat.log \
linear-spectrogram-wo-db-norm-ol \
--wav_rspecifier=scp:$data/split${nj}/JOB/${aishell_wav_scp} \ --wav_rspecifier=scp:$data/split${nj}/JOB/${aishell_wav_scp} \
--feature_wspecifier=ark,scp:$data/split${nj}/JOB/feat.ark,$data/split${nj}/JOB/feat.scp \ --feature_wspecifier=ark,scp:$data/split${nj}/JOB/feat.ark,$data/split${nj}/JOB/feat.scp \
--cmvn_file=$cmvn \ --cmvn_file=$cmvn \
...@@ -65,31 +75,33 @@ linear_spectrogram_without_db_norm_main \ ...@@ -65,31 +75,33 @@ linear_spectrogram_without_db_norm_main \
text=$data/test/text text=$data/test/text
# 4. recognizer # 4. recognizer
utils/run.pl JOB=1:$nj $data/split${nj}/JOB/log \ utils/run.pl JOB=1:$nj $data/split${nj}/JOB/recog.wolm.log \
offline_decoder_sliding_chunk_main \ ctc-prefix-beam-search-decoder-ol \
--feature_rspecifier=scp:$data/split${nj}/JOB/feat.scp \ --feature_rspecifier=scp:$data/split${nj}/JOB/feat.scp \
--model_path=$aishell_online_model/avg_1.jit.pdmodel \ --model_path=$model_dir/avg_1.jit.pdmodel \
--param_path=$aishell_online_model/avg_1.jit.pdiparams \ --param_path=$model_dir/avg_1.jit.pdiparams \
--model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \ --model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \
--dict_file=$lm_model_dir/vocab.txt \ --dict_file=$vocb_dir/vocab.txt \
--result_wspecifier=ark,t:$data/split${nj}/JOB/result --result_wspecifier=ark,t:$data/split${nj}/JOB/result
cat $data/split${nj}/*/result > ${label_file} cat $data/split${nj}/*/result > ${label_file}
local/compute-wer.py --char=1 --v=1 ${label_file} $text > ${wer} utils/compute-wer.py --char=1 --v=1 ${label_file} $text > ${wer}
# 4. decode with lm # 4. decode with lm
utils/run.pl JOB=1:$nj $data/split${nj}/JOB/log_lm \ utils/run.pl JOB=1:$nj $data/split${nj}/JOB/recog.lm.log \
offline_decoder_sliding_chunk_main \ ctc-prefix-beam-search-decoder-ol \
--feature_rspecifier=scp:$data/split${nj}/JOB/feat.scp \ --feature_rspecifier=scp:$data/split${nj}/JOB/feat.scp \
--model_path=$aishell_online_model/avg_1.jit.pdmodel \ --model_path=$model_dir/avg_1.jit.pdmodel \
--param_path=$aishell_online_model/avg_1.jit.pdiparams \ --param_path=$model_dir/avg_1.jit.pdiparams \
--model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \ --model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \
--dict_file=$lm_model_dir/vocab.txt \ --dict_file=$vocb_dir/vocab.txt \
--lm_path=$lm_model_dir/avg_1.jit.klm \ --lm_path=$lm \
--result_wspecifier=ark,t:$data/split${nj}/JOB/result_lm --result_wspecifier=ark,t:$data/split${nj}/JOB/result_lm
cat $data/split${nj}/*/result_lm > ${label_file}_lm cat $data/split${nj}/*/result_lm > ${label_file}_lm
local/compute-wer.py --char=1 --v=1 ${label_file}_lm $text > ${wer}_lm utils/compute-wer.py --char=1 --v=1 ${label_file}_lm $text > ${wer}_lm
graph_dir=./aishell_graph graph_dir=./aishell_graph
if [ ! -d $ ]; then if [ ! -d $ ]; then
...@@ -97,17 +109,19 @@ if [ ! -d $ ]; then ...@@ -97,17 +109,19 @@ if [ ! -d $ ]; then
unzip -d aishell_graph.zip unzip -d aishell_graph.zip
fi fi
# 5. test TLG decoder # 5. test TLG decoder
utils/run.pl JOB=1:$nj $data/split${nj}/JOB/log_tlg \ utils/run.pl JOB=1:$nj $data/split${nj}/JOB/recog.wfst.log \
offline_wfst_decoder_main \ wfst-decoder-ol \
--feature_rspecifier=scp:$data/split${nj}/JOB/feat.scp \ --feature_rspecifier=scp:$data/split${nj}/JOB/feat.scp \
--model_path=$aishell_online_model/avg_1.jit.pdmodel \ --model_path=$model_dir/avg_1.jit.pdmodel \
--param_path=$aishell_online_model/avg_1.jit.pdiparams \ --param_path=$model_dir/avg_1.jit.pdiparams \
--word_symbol_table=$graph_dir/words.txt \ --word_symbol_table=$graph_dir/words.txt \
--model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \ --model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \
--graph_path=$graph_dir/TLG.fst --max_active=7500 \ --graph_path=$graph_dir/TLG.fst --max_active=7500 \
--acoustic_scale=1.2 \ --acoustic_scale=1.2 \
--result_wspecifier=ark,t:$data/split${nj}/JOB/result_tlg --result_wspecifier=ark,t:$data/split${nj}/JOB/result_tlg
cat $data/split${nj}/*/result_tlg > ${label_file}_tlg cat $data/split${nj}/*/result_tlg > ${label_file}_tlg
local/compute-wer.py --char=1 --v=1 ${label_file}_tlg $text > ${wer}_tlg utils/compute-wer.py --char=1 --v=1 ${label_file}_tlg $text > ${wer}_tlg
\ No newline at end of file \ No newline at end of file
../../../../utils/
\ No newline at end of file
cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
set(bin_name ctc-prefix-beam-search-decoder-ol)
add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc)
target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(${bin_name} PUBLIC nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util ${DEPS})
set(bin_name wfst-decoder-ol)
add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc)
target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(${bin_name} PUBLIC nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util kaldi-decoder ${DEPS})
set(bin_name nnet-logprob-decoder-test)
add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc)
target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(${bin_name} PUBLIC nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util ${DEPS})
# ASR Decoder
ASR Decoder test bins. We using theses bins to test CTC BeamSearch decoder and WFST decoder.
* decoder_test_main.cc
feed nnet output logprob, and only test decoder
* offline_decoder_sliding_chunk_main.cc
feed streaming audio feature, decode as streaming manner.
* offline_wfst_decoder_main.cc
feed streaming audio feature, decode using WFST as streaming manner.
...@@ -34,10 +34,12 @@ DEFINE_int32(receptive_field_length, ...@@ -34,10 +34,12 @@ DEFINE_int32(receptive_field_length,
DEFINE_int32(downsampling_rate, DEFINE_int32(downsampling_rate,
4, 4,
"two CNN(kernel=5) module downsampling rate."); "two CNN(kernel=5) module downsampling rate.");
DEFINE_string(
model_input_names,
"audio_chunk,audio_chunk_lens,chunk_state_h_box,chunk_state_c_box",
"model input names");
DEFINE_string(model_output_names, DEFINE_string(model_output_names,
"save_infer_model/scale_0.tmp_1,save_infer_model/" "softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0",
"scale_1.tmp_1,save_infer_model/scale_2.tmp_1,save_infer_model/"
"scale_3.tmp_1",
"model output names"); "model output names");
DEFINE_string(model_cache_names, "5-1-1024,5-1-1024", "model cache names"); DEFINE_string(model_cache_names, "5-1-1024,5-1-1024", "model cache names");
...@@ -50,9 +52,13 @@ int main(int argc, char* argv[]) { ...@@ -50,9 +52,13 @@ int main(int argc, char* argv[]) {
gflags::ParseCommandLineFlags(&argc, &argv, false); gflags::ParseCommandLineFlags(&argc, &argv, false);
google::InitGoogleLogging(argv[0]); google::InitGoogleLogging(argv[0]);
CHECK(FLAGS_result_wspecifier != "");
CHECK(FLAGS_feature_rspecifier != "");
kaldi::SequentialBaseFloatMatrixReader feature_reader( kaldi::SequentialBaseFloatMatrixReader feature_reader(
FLAGS_feature_rspecifier); FLAGS_feature_rspecifier);
kaldi::TokenWriter result_writer(FLAGS_result_wspecifier); kaldi::TokenWriter result_writer(FLAGS_result_wspecifier);
std::string model_graph = FLAGS_model_path; std::string model_graph = FLAGS_model_path;
std::string model_params = FLAGS_param_path; std::string model_params = FLAGS_param_path;
std::string dict_file = FLAGS_dict_file; std::string dict_file = FLAGS_dict_file;
...@@ -73,6 +79,7 @@ int main(int argc, char* argv[]) { ...@@ -73,6 +79,7 @@ int main(int argc, char* argv[]) {
model_opts.model_path = model_graph; model_opts.model_path = model_graph;
model_opts.params_path = model_params; model_opts.params_path = model_params;
model_opts.cache_shape = FLAGS_model_cache_names; model_opts.cache_shape = FLAGS_model_cache_names;
model_opts.input_names = FLAGS_model_input_names;
model_opts.output_names = FLAGS_model_output_names; model_opts.output_names = FLAGS_model_output_names;
std::shared_ptr<ppspeech::PaddleNnet> nnet( std::shared_ptr<ppspeech::PaddleNnet> nnet(
new ppspeech::PaddleNnet(model_opts)); new ppspeech::PaddleNnet(model_opts));
......
# This contains the locations of binarys build required for running the examples. # This contains the locations of binarys build required for running the examples.
SPEECHX_ROOT=$PWD/../.. SPEECHX_ROOT=$PWD/../../../
SPEECHX_EXAMPLES=$SPEECHX_ROOT/build/examples SPEECHX_EXAMPLES=$SPEECHX_ROOT/build/examples
SPEECHX_TOOLS=$SPEECHX_ROOT/tools SPEECHX_TOOLS=$SPEECHX_ROOT/tools
...@@ -10,5 +10,5 @@ TOOLS_BIN=$SPEECHX_TOOLS/valgrind/install/bin ...@@ -10,5 +10,5 @@ TOOLS_BIN=$SPEECHX_TOOLS/valgrind/install/bin
export LC_AL=C export LC_AL=C
SPEECHX_BIN=$SPEECHX_EXAMPLES/decoder:$SPEECHX_EXAMPLES/feat SPEECHX_BIN=$SPEECHX_EXAMPLES/ds2_ol/decoder:$SPEECHX_EXAMPLES/ds2_ol/feat
export PATH=$PATH:$SPEECHX_BIN:$TOOLS_BIN export PATH=$PATH:$SPEECHX_BIN:$TOOLS_BIN
#!/bin/bash
set +x
set -e
. path.sh
# 1. compile
if [ ! -d ${SPEECHX_EXAMPLES} ]; then
pushd ${SPEECHX_ROOT}
bash build.sh
popd
fi
# input
mkdir -p data
data=$PWD/data
ckpt_dir=$data/model
model_dir=$ckpt_dir/exp/deepspeech2_online/checkpoints/
vocb_dir=$ckpt_dir/data/lang_char/
lm=$data/zh_giga.no_cna_cmn.prune01244.klm
# output
exp_dir=./exp
mkdir -p $exp_dir
# 2. download model
if [[ ! -f data/model/asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz ]]; then
mkdir -p data/model
pushd data/model
wget -c https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz
tar xzfv asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz
popd
fi
# produce wav scp
if [ ! -f data/wav.scp ]; then
pushd data
wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav
echo "utt1 " $PWD/zh.wav > wav.scp
popd
fi
# download lm
if [ ! -f $lm ]; then
pushd data
wget -c https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm
popd
fi
feat_wspecifier=$exp_dir/feats.ark
cmvn=$exp_dir/cmvn.ark
export GLOG_logtostderr=1
# dump json cmvn to kaldi
cmvn-json2kaldi \
--json_file $ckpt_dir/data/mean_std.json \
--cmvn_write_path $exp_dir/cmvn.ark \
--binary=false
echo "convert json cmvn to kaldi ark."
# generate linear feature as streaming
linear-spectrogram-wo-db-norm-ol \
--wav_rspecifier=scp:$data/wav.scp \
--feature_wspecifier=ark,t:$feat_wspecifier \
--cmvn_file=$exp_dir/cmvn.ark
echo "compute linear spectrogram feature."
# run ctc beam search decoder as streaming
ctc-prefix-beam-search-decoder-ol \
--result_wspecifier=ark,t:$exp_dir/result.txt \
--feature_rspecifier=ark:$feat_wspecifier \
--model_path=$model_dir/avg_1.jit.pdmodel \
--param_path=$model_dir/avg_1.jit.pdiparams \
--dict_file=$vocb_dir/vocab.txt \
--lm_path=$lm
\ No newline at end of file
...@@ -28,6 +28,7 @@ DEFINE_string(model_path, "avg_1.jit.pdmodel", "paddle nnet model"); ...@@ -28,6 +28,7 @@ DEFINE_string(model_path, "avg_1.jit.pdmodel", "paddle nnet model");
DEFINE_string(param_path, "avg_1.jit.pdiparams", "paddle nnet model param"); DEFINE_string(param_path, "avg_1.jit.pdiparams", "paddle nnet model param");
DEFINE_string(word_symbol_table, "words.txt", "word symbol table"); DEFINE_string(word_symbol_table, "words.txt", "word symbol table");
DEFINE_string(graph_path, "TLG", "decoder graph"); DEFINE_string(graph_path, "TLG", "decoder graph");
DEFINE_double(acoustic_scale, 1.0, "acoustic scale"); DEFINE_double(acoustic_scale, 1.0, "acoustic scale");
DEFINE_int32(max_active, 7500, "decoder graph"); DEFINE_int32(max_active, 7500, "decoder graph");
DEFINE_int32(receptive_field_length, DEFINE_int32(receptive_field_length,
......
cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
set(bin_name linear-spectrogram-wo-db-norm-ol)
add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc)
target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(${bin_name} frontend kaldi-util kaldi-feat-common gflags glog)
set(bin_name cmvn-json2kaldi)
add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc)
target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(${bin_name} utils kaldi-util kaldi-matrix gflags glog)
# Deepspeech2 Straming Audio Feature
ASR audio feature test bins. We using theses bins to test linaer/fbank/mfcc asr feature as streaming manner.
* linear_spectrogram_without_db_norm_main.cc
compute linear spectrogram w/o db norm in streaming manner.
...@@ -12,6 +12,8 @@ ...@@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
// Note: Do not print/log ondemand object.
#include "base/flags.h" #include "base/flags.h"
#include "base/log.h" #include "base/log.h"
#include "kaldi/matrix/kaldi-matrix.h" #include "kaldi/matrix/kaldi-matrix.h"
...@@ -29,30 +31,51 @@ int main(int argc, char* argv[]) { ...@@ -29,30 +31,51 @@ int main(int argc, char* argv[]) {
gflags::ParseCommandLineFlags(&argc, &argv, false); gflags::ParseCommandLineFlags(&argc, &argv, false);
google::InitGoogleLogging(argv[0]); google::InitGoogleLogging(argv[0]);
ondemand::parser parser; LOG(INFO) << "cmvn josn path: " << FLAGS_json_file;
padded_string json = padded_string::load(FLAGS_json_file);
ondemand::document val = parser.iterate(json); try {
ondemand::object doc = val; padded_string json = padded_string::load(FLAGS_json_file);
kaldi::int32 frame_num = uint64_t(doc["frame_num"]);
auto mean_stat = doc["mean_stat"]; ondemand::parser parser;
std::vector<kaldi::BaseFloat> mean_stat_vec; ondemand::document doc = parser.iterate(json);
for (double x : mean_stat) { ondemand::value val = doc;
mean_stat_vec.push_back(x);
} ondemand::array mean_stat = val["mean_stat"];
auto var_stat = doc["var_stat"]; std::vector<kaldi::BaseFloat> mean_stat_vec;
std::vector<kaldi::BaseFloat> var_stat_vec; for (double x : mean_stat) {
for (double x : var_stat) { mean_stat_vec.push_back(x);
var_stat_vec.push_back(x); }
} // LOG(INFO) << mean_stat; this line will casue
// simdjson::simdjson_error("Objects and arrays can only be iterated
// when
// they are first encountered")
size_t mean_size = mean_stat_vec.size(); ondemand::array var_stat = val["var_stat"];
kaldi::Matrix<double> cmvn_stats(2, mean_size + 1); std::vector<kaldi::BaseFloat> var_stat_vec;
for (size_t idx = 0; idx < mean_size; ++idx) { for (double x : var_stat) {
cmvn_stats(0, idx) = mean_stat_vec[idx]; var_stat_vec.push_back(x);
cmvn_stats(1, idx) = var_stat_vec[idx]; }
kaldi::int32 frame_num = uint64_t(val["frame_num"]);
LOG(INFO) << "nframe: " << frame_num;
size_t mean_size = mean_stat_vec.size();
kaldi::Matrix<double> cmvn_stats(2, mean_size + 1);
for (size_t idx = 0; idx < mean_size; ++idx) {
cmvn_stats(0, idx) = mean_stat_vec[idx];
cmvn_stats(1, idx) = var_stat_vec[idx];
}
cmvn_stats(0, mean_size) = frame_num;
LOG(INFO) << cmvn_stats;
kaldi::WriteKaldiObject(
cmvn_stats, FLAGS_cmvn_write_path, FLAGS_binary);
LOG(INFO) << "cmvn stats have write into: " << FLAGS_cmvn_write_path;
LOG(INFO) << "Binary: " << FLAGS_binary;
} catch (simdjson::simdjson_error& err) {
LOG(ERR) << err.what();
} }
cmvn_stats(0, mean_size) = frame_num;
kaldi::WriteKaldiObject(cmvn_stats, FLAGS_cmvn_write_path, FLAGS_binary);
LOG(INFO) << "the json file have write into " << FLAGS_cmvn_write_path;
return 0; return 0;
} }
\ No newline at end of file
...@@ -32,6 +32,7 @@ DEFINE_string(feature_wspecifier, "", "output feats wspecifier"); ...@@ -32,6 +32,7 @@ DEFINE_string(feature_wspecifier, "", "output feats wspecifier");
DEFINE_string(cmvn_file, "./cmvn.ark", "read cmvn"); DEFINE_string(cmvn_file, "./cmvn.ark", "read cmvn");
DEFINE_double(streaming_chunk, 0.36, "streaming feature chunk size"); DEFINE_double(streaming_chunk, 0.36, "streaming feature chunk size");
int main(int argc, char* argv[]) { int main(int argc, char* argv[]) {
gflags::ParseCommandLineFlags(&argc, &argv, false); gflags::ParseCommandLineFlags(&argc, &argv, false);
google::InitGoogleLogging(argv[0]); google::InitGoogleLogging(argv[0]);
......
# This contains the locations of binarys build required for running the examples. # This contains the locations of binarys build required for running the examples.
SPEECHX_ROOT=$PWD/../.. SPEECHX_ROOT=$PWD/../../../
SPEECHX_EXAMPLES=$SPEECHX_ROOT/build/examples SPEECHX_EXAMPLES=$SPEECHX_ROOT/build/examples
SPEECHX_TOOLS=$SPEECHX_ROOT/tools SPEECHX_TOOLS=$SPEECHX_ROOT/tools
...@@ -10,5 +10,5 @@ TOOLS_BIN=$SPEECHX_TOOLS/valgrind/install/bin ...@@ -10,5 +10,5 @@ TOOLS_BIN=$SPEECHX_TOOLS/valgrind/install/bin
export LC_AL=C export LC_AL=C
SPEECHX_BIN=$SPEECHX_EXAMPLES/glog SPEECHX_BIN=$SPEECHX_EXAMPLES/ds2_ol/feat
export PATH=$PATH:$SPEECHX_BIN:$TOOLS_BIN export PATH=$PATH:$SPEECHX_BIN:$TOOLS_BIN
#!/bin/bash
set +x
set -e
. ./path.sh
# 1. compile
if [ ! -d ${SPEECHX_EXAMPLES} ]; then
pushd ${SPEECHX_ROOT}
bash build.sh
popd
fi
# 2. download model
if [ ! -e data/model/asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz ]; then
mkdir -p data/model
pushd data/model
wget -c https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz
tar xzfv asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz
popd
fi
# produce wav scp
if [ ! -f data/wav.scp ]; then
mkdir -p data
pushd data
wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav
echo "utt1 " $PWD/zh.wav > wav.scp
popd
fi
# input
data_dir=./data
exp_dir=./exp
model_dir=$data_dir/model/
mkdir -p $exp_dir
# 3. run feat
export GLOG_logtostderr=1
cmvn-json2kaldi \
--json_file $model_dir/data/mean_std.json \
--cmvn_write_path $exp_dir/cmvn.ark \
--binary=false
echo "convert json cmvn to kaldi ark."
linear-spectrogram-wo-db-norm-ol \
--wav_rspecifier=scp:$data_dir/wav.scp \
--feature_wspecifier=ark,t:$exp_dir/feats.ark \
--cmvn_file=$exp_dir/cmvn.ark
echo "compute linear spectrogram feature."
cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
set(bin_name ds2-model-ol-test)
add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc)
target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(${bin_name} PUBLIC nnet gflags glog ${DEPS})
\ No newline at end of file
# Deepspeech2 Streaming NNet Test
Using for ds2 streaming nnet inference test.
...@@ -12,7 +12,8 @@ ...@@ -12,7 +12,8 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include <gflags/gflags.h> // deepspeech2 online model info
#include <algorithm> #include <algorithm>
#include <fstream> #include <fstream>
#include <functional> #include <functional>
...@@ -20,21 +21,26 @@ ...@@ -20,21 +21,26 @@
#include <iterator> #include <iterator>
#include <numeric> #include <numeric>
#include <thread> #include <thread>
#include "base/flags.h"
#include "base/log.h"
#include "paddle_inference_api.h" #include "paddle_inference_api.h"
using std::cout; using std::cout;
using std::endl; using std::endl;
DEFINE_string(model_path, "avg_1.jit.pdmodel", "xxx.pdmodel");
DEFINE_string(param_path, "avg_1.jit.pdiparams", "xxx.pdiparams"); DEFINE_string(model_path, "", "xxx.pdmodel");
DEFINE_string(param_path, "", "xxx.pdiparams");
DEFINE_int32(chunk_size, 35, "feature chunk size, unit:frame");
DEFINE_int32(feat_dim, 161, "feature dim");
void produce_data(std::vector<std::vector<float>>* data); void produce_data(std::vector<std::vector<float>>* data);
void model_forward_test(); void model_forward_test();
void produce_data(std::vector<std::vector<float>>* data) { void produce_data(std::vector<std::vector<float>>* data) {
int chunk_size = 35; // chunk_size in frame int chunk_size = FLAGS_chunk_size; // chunk_size in frame
int col_size = 161; // feat dim int col_size = FLAGS_feat_dim; // feat dim
cout << "chunk size: " << chunk_size << endl; cout << "chunk size: " << chunk_size << endl;
cout << "feat dim: " << col_size << endl; cout << "feat dim: " << col_size << endl;
...@@ -57,6 +63,8 @@ void model_forward_test() { ...@@ -57,6 +63,8 @@ void model_forward_test() {
; ;
std::string model_graph = FLAGS_model_path; std::string model_graph = FLAGS_model_path;
std::string model_params = FLAGS_param_path; std::string model_params = FLAGS_param_path;
CHECK(model_graph != "");
CHECK(model_params != "");
cout << "model path: " << model_graph << endl; cout << "model path: " << model_graph << endl;
cout << "model param path : " << model_params << endl; cout << "model param path : " << model_params << endl;
...@@ -106,7 +114,7 @@ void model_forward_test() { ...@@ -106,7 +114,7 @@ void model_forward_test() {
// state_h // state_h
std::unique_ptr<paddle_infer::Tensor> chunk_state_h_box = std::unique_ptr<paddle_infer::Tensor> chunk_state_h_box =
predictor->GetInputHandle(input_names[2]); predictor->GetInputHandle(input_names[2]);
std::vector<int> chunk_state_h_box_shape = {3, 1, 1024}; std::vector<int> chunk_state_h_box_shape = {5, 1, 1024};
chunk_state_h_box->Reshape(chunk_state_h_box_shape); chunk_state_h_box->Reshape(chunk_state_h_box_shape);
int chunk_state_h_box_size = int chunk_state_h_box_size =
std::accumulate(chunk_state_h_box_shape.begin(), std::accumulate(chunk_state_h_box_shape.begin(),
...@@ -119,7 +127,7 @@ void model_forward_test() { ...@@ -119,7 +127,7 @@ void model_forward_test() {
// state_c // state_c
std::unique_ptr<paddle_infer::Tensor> chunk_state_c_box = std::unique_ptr<paddle_infer::Tensor> chunk_state_c_box =
predictor->GetInputHandle(input_names[3]); predictor->GetInputHandle(input_names[3]);
std::vector<int> chunk_state_c_box_shape = {3, 1, 1024}; std::vector<int> chunk_state_c_box_shape = {5, 1, 1024};
chunk_state_c_box->Reshape(chunk_state_c_box_shape); chunk_state_c_box->Reshape(chunk_state_c_box_shape);
int chunk_state_c_box_size = int chunk_state_c_box_size =
std::accumulate(chunk_state_c_box_shape.begin(), std::accumulate(chunk_state_c_box_shape.begin(),
...@@ -187,7 +195,9 @@ void model_forward_test() { ...@@ -187,7 +195,9 @@ void model_forward_test() {
} }
int main(int argc, char* argv[]) { int main(int argc, char* argv[]) {
gflags::ParseCommandLineFlags(&argc, &argv, true); gflags::ParseCommandLineFlags(&argc, &argv, false);
google::InitGoogleLogging(argv[0]);
model_forward_test(); model_forward_test();
return 0; return 0;
} }
# This contains the locations of binarys build required for running the examples. # This contains the locations of binarys build required for running the examples.
SPEECHX_ROOT=$PWD/../.. SPEECHX_ROOT=$PWD/../../../
SPEECHX_EXAMPLES=$SPEECHX_ROOT/build/examples SPEECHX_EXAMPLES=$SPEECHX_ROOT/build/examples
SPEECHX_TOOLS=$SPEECHX_ROOT/tools SPEECHX_TOOLS=$SPEECHX_ROOT/tools
...@@ -10,5 +10,5 @@ TOOLS_BIN=$SPEECHX_TOOLS/valgrind/install/bin ...@@ -10,5 +10,5 @@ TOOLS_BIN=$SPEECHX_TOOLS/valgrind/install/bin
export LC_AL=C export LC_AL=C
SPEECHX_BIN=$SPEECHX_EXAMPLES/nnet SPEECHX_BIN=$SPEECHX_EXAMPLES/ds2_ol/nnet
export PATH=$PATH:$SPEECHX_BIN:$TOOLS_BIN export PATH=$PATH:$SPEECHX_BIN:$TOOLS_BIN
#!/bin/bash
set +x
set -e
. path.sh
# 1. compile
if [ ! -d ${SPEECHX_EXAMPLES} ]; then
pushd ${SPEECHX_ROOT}
bash build.sh
popd
fi
# 2. download model
if [ ! -f data/model/asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz ]; then
mkdir -p data/model
pushd data/model
wget -c https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz
tar xzfv asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz
popd
fi
# produce wav scp
if [ ! -f data/wav.scp ]; then
mkdir -p data
pushd data
wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav
echo "utt1 " $PWD/zh.wav > wav.scp
popd
fi
ckpt_dir=./data/model
model_dir=$ckpt_dir/exp/deepspeech2_online/checkpoints/
ds2-model-ol-test \
--model_path=$model_dir/avg_1.jit.pdmodel \
--param_path=$model_dir/avg_1.jit.pdiparams
cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
add_executable(mfcc-test ${CMAKE_CURRENT_SOURCE_DIR}/feature-mfcc-test.cc)
target_include_directories(mfcc-test PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(mfcc-test kaldi-mfcc)
add_executable(linear_spectrogram_main ${CMAKE_CURRENT_SOURCE_DIR}/linear_spectrogram_main.cc)
target_include_directories(linear_spectrogram_main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(linear_spectrogram_main frontend kaldi-util kaldi-feat-common gflags glog)
add_executable(linear_spectrogram_without_db_norm_main ${CMAKE_CURRENT_SOURCE_DIR}/linear_spectrogram_without_db_norm_main.cc)
target_include_directories(linear_spectrogram_without_db_norm_main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(linear_spectrogram_without_db_norm_main frontend kaldi-util kaldi-feat-common gflags glog)
add_executable(cmvn_json2binary_main ${CMAKE_CURRENT_SOURCE_DIR}/cmvn_json2binary_main.cc)
target_include_directories(cmvn_json2binary_main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(cmvn_json2binary_main utils kaldi-util kaldi-matrix gflags glog)
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// feat/feature-mfcc-test.cc
// Copyright 2009-2011 Karel Vesely; Petr Motlicek
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#include <iostream>
#include "base/kaldi-math.h"
#include "feat/feature-mfcc.h"
#include "feat/wave-reader.h"
#include "matrix/kaldi-matrix-inl.h"
using namespace kaldi;
static void UnitTestReadWave() {
std::cout << "=== UnitTestReadWave() ===\n";
Vector<BaseFloat> v, v2;
std::cout << "<<<=== Reading waveform\n";
{
std::ifstream is("test_data/test.wav", std::ios_base::binary);
WaveData wave;
wave.Read(is);
const Matrix<BaseFloat> data(wave.Data());
KALDI_ASSERT(data.NumRows() == 1);
v.Resize(data.NumCols());
v.CopyFromVec(data.Row(0));
}
std::cout
<< "<<<=== Reading Vector<BaseFloat> waveform, prepared by matlab\n";
std::ifstream input("test_data/test_matlab.ascii");
KALDI_ASSERT(input.good());
v2.Read(input, false);
input.close();
std::cout
<< "<<<=== Comparing freshly read waveform to 'libsndfile' waveform\n";
KALDI_ASSERT(v.Dim() == v2.Dim());
for (int32 i = 0; i < v.Dim(); i++) {
KALDI_ASSERT(v(i) == v2(i));
}
std::cout << "<<<=== Comparing done\n";
// std::cout << "== The Waveform Samples == \n";
// std::cout << v;
std::cout << "Test passed :)\n\n";
}
/**
*/
static void UnitTestSimple() {
std::cout << "=== UnitTestSimple() ===\n";
Vector<BaseFloat> v(100000);
Matrix<BaseFloat> m;
// init with noise
for (int32 i = 0; i < v.Dim(); i++) {
v(i) = (abs(i * 433024253) % 65535) - (65535 / 2);
}
std::cout << "<<<=== Just make sure it runs... Nothing is compared\n";
// the parametrization object
MfccOptions op;
// trying to have same opts as baseline.
op.frame_opts.dither = 0.0;
op.frame_opts.preemph_coeff = 0.0;
op.frame_opts.window_type = "rectangular";
op.frame_opts.remove_dc_offset = false;
op.frame_opts.round_to_power_of_two = true;
op.mel_opts.low_freq = 0.0;
op.mel_opts.htk_mode = true;
op.htk_compat = true;
Mfcc mfcc(op);
// use default parameters
// compute mfccs.
mfcc.Compute(v, 1.0, &m);
// possibly dump
// std::cout << "== Output features == \n" << m;
std::cout << "Test passed :)\n\n";
}
static void UnitTestHTKCompare1() {
std::cout << "=== UnitTestHTKCompare1() ===\n";
std::ifstream is("test_data/test.wav", std::ios_base::binary);
WaveData wave;
wave.Read(is);
KALDI_ASSERT(wave.Data().NumRows() == 1);
SubVector<BaseFloat> waveform(wave.Data(), 0);
// read the HTK features
Matrix<BaseFloat> htk_features;
{
std::ifstream is("test_data/test.wav.fea_htk.1",
std::ios::in | std::ios_base::binary);
bool ans = ReadHtk(is, &htk_features, 0);
KALDI_ASSERT(ans);
}
// use mfcc with default configuration...
MfccOptions op;
op.frame_opts.dither = 0.0;
op.frame_opts.preemph_coeff = 0.0;
op.frame_opts.window_type = "hamming";
op.frame_opts.remove_dc_offset = false;
op.frame_opts.round_to_power_of_two = true;
op.mel_opts.low_freq = 0.0;
op.mel_opts.htk_mode = true;
op.htk_compat = true;
op.use_energy = false; // C0 not energy.
Mfcc mfcc(op);
// calculate kaldi features
Matrix<BaseFloat> kaldi_raw_features;
mfcc.Compute(waveform, 1.0, &kaldi_raw_features);
DeltaFeaturesOptions delta_opts;
Matrix<BaseFloat> kaldi_features;
ComputeDeltas(delta_opts, kaldi_raw_features, &kaldi_features);
// compare the results
bool passed = true;
int32 i_old = -1;
KALDI_ASSERT(kaldi_features.NumRows() == htk_features.NumRows());
KALDI_ASSERT(kaldi_features.NumCols() == htk_features.NumCols());
// Ignore ends-- we make slightly different choices than
// HTK about how to treat the deltas at the ends.
for (int32 i = 10; i + 10 < kaldi_features.NumRows(); i++) {
for (int32 j = 0; j < kaldi_features.NumCols(); j++) {
BaseFloat a = kaldi_features(i, j), b = htk_features(i, j);
if ((std::abs(b - a)) > 1.0) { //<< TOLERANCE TO DIFFERENCES!!!!!
// print the non-matching data only once per-line
if (i_old != i) {
std::cout << "\n\n\n[HTK-row: " << i << "] "
<< htk_features.Row(i) << "\n";
std::cout << "[Kaldi-row: " << i << "] "
<< kaldi_features.Row(i) << "\n\n\n";
i_old = i;
}
// print indices of non-matching cells
std::cout << "[" << i << ", " << j << "]";
passed = false;
}
}
}
if (!passed) KALDI_ERR << "Test failed";
// write the htk features for later inspection
HtkHeader header = {
kaldi_features.NumRows(),
100000, // 10ms
static_cast<int16>(sizeof(float) * kaldi_features.NumCols()),
021406 // MFCC_D_A_0
};
{
std::ofstream os("tmp.test.wav.fea_kaldi.1",
std::ios::out | std::ios::binary);
WriteHtk(os, kaldi_features, header);
}
std::cout << "Test passed :)\n\n";
unlink("tmp.test.wav.fea_kaldi.1");
}
static void UnitTestHTKCompare2() {
std::cout << "=== UnitTestHTKCompare2() ===\n";
std::ifstream is("test_data/test.wav", std::ios_base::binary);
WaveData wave;
wave.Read(is);
KALDI_ASSERT(wave.Data().NumRows() == 1);
SubVector<BaseFloat> waveform(wave.Data(), 0);
// read the HTK features
Matrix<BaseFloat> htk_features;
{
std::ifstream is("test_data/test.wav.fea_htk.2",
std::ios::in | std::ios_base::binary);
bool ans = ReadHtk(is, &htk_features, 0);
KALDI_ASSERT(ans);
}
// use mfcc with default configuration...
MfccOptions op;
op.frame_opts.dither = 0.0;
op.frame_opts.preemph_coeff = 0.0;
op.frame_opts.window_type = "hamming";
op.frame_opts.remove_dc_offset = false;
op.frame_opts.round_to_power_of_two = true;
op.mel_opts.low_freq = 0.0;
op.mel_opts.htk_mode = true;
op.htk_compat = true;
op.use_energy = true; // Use energy.
Mfcc mfcc(op);
// calculate kaldi features
Matrix<BaseFloat> kaldi_raw_features;
mfcc.Compute(waveform, 1.0, &kaldi_raw_features);
DeltaFeaturesOptions delta_opts;
Matrix<BaseFloat> kaldi_features;
ComputeDeltas(delta_opts, kaldi_raw_features, &kaldi_features);
// compare the results
bool passed = true;
int32 i_old = -1;
KALDI_ASSERT(kaldi_features.NumRows() == htk_features.NumRows());
KALDI_ASSERT(kaldi_features.NumCols() == htk_features.NumCols());
// Ignore ends-- we make slightly different choices than
// HTK about how to treat the deltas at the ends.
for (int32 i = 10; i + 10 < kaldi_features.NumRows(); i++) {
for (int32 j = 0; j < kaldi_features.NumCols(); j++) {
BaseFloat a = kaldi_features(i, j), b = htk_features(i, j);
if ((std::abs(b - a)) > 1.0) { //<< TOLERANCE TO DIFFERENCES!!!!!
// print the non-matching data only once per-line
if (i_old != i) {
std::cout << "\n\n\n[HTK-row: " << i << "] "
<< htk_features.Row(i) << "\n";
std::cout << "[Kaldi-row: " << i << "] "
<< kaldi_features.Row(i) << "\n\n\n";
i_old = i;
}
// print indices of non-matching cells
std::cout << "[" << i << ", " << j << "]";
passed = false;
}
}
}
if (!passed) KALDI_ERR << "Test failed";
// write the htk features for later inspection
HtkHeader header = {
kaldi_features.NumRows(),
100000, // 10ms
static_cast<int16>(sizeof(float) * kaldi_features.NumCols()),
021406 // MFCC_D_A_0
};
{
std::ofstream os("tmp.test.wav.fea_kaldi.2",
std::ios::out | std::ios::binary);
WriteHtk(os, kaldi_features, header);
}
std::cout << "Test passed :)\n\n";
unlink("tmp.test.wav.fea_kaldi.2");
}
static void UnitTestHTKCompare3() {
std::cout << "=== UnitTestHTKCompare3() ===\n";
std::ifstream is("test_data/test.wav", std::ios_base::binary);
WaveData wave;
wave.Read(is);
KALDI_ASSERT(wave.Data().NumRows() == 1);
SubVector<BaseFloat> waveform(wave.Data(), 0);
// read the HTK features
Matrix<BaseFloat> htk_features;
{
std::ifstream is("test_data/test.wav.fea_htk.3",
std::ios::in | std::ios_base::binary);
bool ans = ReadHtk(is, &htk_features, 0);
KALDI_ASSERT(ans);
}
// use mfcc with default configuration...
MfccOptions op;
op.frame_opts.dither = 0.0;
op.frame_opts.preemph_coeff = 0.0;
op.frame_opts.window_type = "hamming";
op.frame_opts.remove_dc_offset = false;
op.frame_opts.round_to_power_of_two = true;
op.htk_compat = true;
op.use_energy = true; // Use energy.
op.mel_opts.low_freq = 20.0;
// op.mel_opts.debug_mel = true;
op.mel_opts.htk_mode = true;
Mfcc mfcc(op);
// calculate kaldi features
Matrix<BaseFloat> kaldi_raw_features;
mfcc.Compute(waveform, 1.0, &kaldi_raw_features);
DeltaFeaturesOptions delta_opts;
Matrix<BaseFloat> kaldi_features;
ComputeDeltas(delta_opts, kaldi_raw_features, &kaldi_features);
// compare the results
bool passed = true;
int32 i_old = -1;
KALDI_ASSERT(kaldi_features.NumRows() == htk_features.NumRows());
KALDI_ASSERT(kaldi_features.NumCols() == htk_features.NumCols());
// Ignore ends-- we make slightly different choices than
// HTK about how to treat the deltas at the ends.
for (int32 i = 10; i + 10 < kaldi_features.NumRows(); i++) {
for (int32 j = 0; j < kaldi_features.NumCols(); j++) {
BaseFloat a = kaldi_features(i, j), b = htk_features(i, j);
if ((std::abs(b - a)) > 1.0) { //<< TOLERANCE TO DIFFERENCES!!!!!
// print the non-matching data only once per-line
if (static_cast<int32>(i_old) != i) {
std::cout << "\n\n\n[HTK-row: " << i << "] "
<< htk_features.Row(i) << "\n";
std::cout << "[Kaldi-row: " << i << "] "
<< kaldi_features.Row(i) << "\n\n\n";
i_old = i;
}
// print indices of non-matching cells
std::cout << "[" << i << ", " << j << "]";
passed = false;
}
}
}
if (!passed) KALDI_ERR << "Test failed";
// write the htk features for later inspection
HtkHeader header = {
kaldi_features.NumRows(),
100000, // 10ms
static_cast<int16>(sizeof(float) * kaldi_features.NumCols()),
021406 // MFCC_D_A_0
};
{
std::ofstream os("tmp.test.wav.fea_kaldi.3",
std::ios::out | std::ios::binary);
WriteHtk(os, kaldi_features, header);
}
std::cout << "Test passed :)\n\n";
unlink("tmp.test.wav.fea_kaldi.3");
}
static void UnitTestHTKCompare4() {
std::cout << "=== UnitTestHTKCompare4() ===\n";
std::ifstream is("test_data/test.wav", std::ios_base::binary);
WaveData wave;
wave.Read(is);
KALDI_ASSERT(wave.Data().NumRows() == 1);
SubVector<BaseFloat> waveform(wave.Data(), 0);
// read the HTK features
Matrix<BaseFloat> htk_features;
{
std::ifstream is("test_data/test.wav.fea_htk.4",
std::ios::in | std::ios_base::binary);
bool ans = ReadHtk(is, &htk_features, 0);
KALDI_ASSERT(ans);
}
// use mfcc with default configuration...
MfccOptions op;
op.frame_opts.dither = 0.0;
op.frame_opts.window_type = "hamming";
op.frame_opts.remove_dc_offset = false;
op.frame_opts.round_to_power_of_two = true;
op.mel_opts.low_freq = 0.0;
op.htk_compat = true;
op.use_energy = true; // Use energy.
op.mel_opts.htk_mode = true;
Mfcc mfcc(op);
// calculate kaldi features
Matrix<BaseFloat> kaldi_raw_features;
mfcc.Compute(waveform, 1.0, &kaldi_raw_features);
DeltaFeaturesOptions delta_opts;
Matrix<BaseFloat> kaldi_features;
ComputeDeltas(delta_opts, kaldi_raw_features, &kaldi_features);
// compare the results
bool passed = true;
int32 i_old = -1;
KALDI_ASSERT(kaldi_features.NumRows() == htk_features.NumRows());
KALDI_ASSERT(kaldi_features.NumCols() == htk_features.NumCols());
// Ignore ends-- we make slightly different choices than
// HTK about how to treat the deltas at the ends.
for (int32 i = 10; i + 10 < kaldi_features.NumRows(); i++) {
for (int32 j = 0; j < kaldi_features.NumCols(); j++) {
BaseFloat a = kaldi_features(i, j), b = htk_features(i, j);
if ((std::abs(b - a)) > 1.0) { //<< TOLERANCE TO DIFFERENCES!!!!!
// print the non-matching data only once per-line
if (static_cast<int32>(i_old) != i) {
std::cout << "\n\n\n[HTK-row: " << i << "] "
<< htk_features.Row(i) << "\n";
std::cout << "[Kaldi-row: " << i << "] "
<< kaldi_features.Row(i) << "\n\n\n";
i_old = i;
}
// print indices of non-matching cells
std::cout << "[" << i << ", " << j << "]";
passed = false;
}
}
}
if (!passed) KALDI_ERR << "Test failed";
// write the htk features for later inspection
HtkHeader header = {
kaldi_features.NumRows(),
100000, // 10ms
static_cast<int16>(sizeof(float) * kaldi_features.NumCols()),
021406 // MFCC_D_A_0
};
{
std::ofstream os("tmp.test.wav.fea_kaldi.4",
std::ios::out | std::ios::binary);
WriteHtk(os, kaldi_features, header);
}
std::cout << "Test passed :)\n\n";
unlink("tmp.test.wav.fea_kaldi.4");
}
static void UnitTestHTKCompare5() {
std::cout << "=== UnitTestHTKCompare5() ===\n";
std::ifstream is("test_data/test.wav", std::ios_base::binary);
WaveData wave;
wave.Read(is);
KALDI_ASSERT(wave.Data().NumRows() == 1);
SubVector<BaseFloat> waveform(wave.Data(), 0);
// read the HTK features
Matrix<BaseFloat> htk_features;
{
std::ifstream is("test_data/test.wav.fea_htk.5",
std::ios::in | std::ios_base::binary);
bool ans = ReadHtk(is, &htk_features, 0);
KALDI_ASSERT(ans);
}
// use mfcc with default configuration...
MfccOptions op;
op.frame_opts.dither = 0.0;
op.frame_opts.window_type = "hamming";
op.frame_opts.remove_dc_offset = false;
op.frame_opts.round_to_power_of_two = true;
op.htk_compat = true;
op.use_energy = true; // Use energy.
op.mel_opts.low_freq = 0.0;
op.mel_opts.vtln_low = 100.0;
op.mel_opts.vtln_high = 7500.0;
op.mel_opts.htk_mode = true;
BaseFloat vtln_warp =
1.1; // our approach identical to htk for warp factor >1,
// differs slightly for higher mel bins if warp_factor <0.9
Mfcc mfcc(op);
// calculate kaldi features
Matrix<BaseFloat> kaldi_raw_features;
mfcc.Compute(waveform, vtln_warp, &kaldi_raw_features);
DeltaFeaturesOptions delta_opts;
Matrix<BaseFloat> kaldi_features;
ComputeDeltas(delta_opts, kaldi_raw_features, &kaldi_features);
// compare the results
bool passed = true;
int32 i_old = -1;
KALDI_ASSERT(kaldi_features.NumRows() == htk_features.NumRows());
KALDI_ASSERT(kaldi_features.NumCols() == htk_features.NumCols());
// Ignore ends-- we make slightly different choices than
// HTK about how to treat the deltas at the ends.
for (int32 i = 10; i + 10 < kaldi_features.NumRows(); i++) {
for (int32 j = 0; j < kaldi_features.NumCols(); j++) {
BaseFloat a = kaldi_features(i, j), b = htk_features(i, j);
if ((std::abs(b - a)) > 1.0) { //<< TOLERANCE TO DIFFERENCES!!!!!
// print the non-matching data only once per-line
if (static_cast<int32>(i_old) != i) {
std::cout << "\n\n\n[HTK-row: " << i << "] "
<< htk_features.Row(i) << "\n";
std::cout << "[Kaldi-row: " << i << "] "
<< kaldi_features.Row(i) << "\n\n\n";
i_old = i;
}
// print indices of non-matching cells
std::cout << "[" << i << ", " << j << "]";
passed = false;
}
}
}
if (!passed) KALDI_ERR << "Test failed";
// write the htk features for later inspection
HtkHeader header = {
kaldi_features.NumRows(),
100000, // 10ms
static_cast<int16>(sizeof(float) * kaldi_features.NumCols()),
021406 // MFCC_D_A_0
};
{
std::ofstream os("tmp.test.wav.fea_kaldi.5",
std::ios::out | std::ios::binary);
WriteHtk(os, kaldi_features, header);
}
std::cout << "Test passed :)\n\n";
unlink("tmp.test.wav.fea_kaldi.5");
}
static void UnitTestHTKCompare6() {
std::cout << "=== UnitTestHTKCompare6() ===\n";
std::ifstream is("test_data/test.wav", std::ios_base::binary);
WaveData wave;
wave.Read(is);
KALDI_ASSERT(wave.Data().NumRows() == 1);
SubVector<BaseFloat> waveform(wave.Data(), 0);
// read the HTK features
Matrix<BaseFloat> htk_features;
{
std::ifstream is("test_data/test.wav.fea_htk.6",
std::ios::in | std::ios_base::binary);
bool ans = ReadHtk(is, &htk_features, 0);
KALDI_ASSERT(ans);
}
// use mfcc with default configuration...
MfccOptions op;
op.frame_opts.dither = 0.0;
op.frame_opts.preemph_coeff = 0.97;
op.frame_opts.window_type = "hamming";
op.frame_opts.remove_dc_offset = false;
op.frame_opts.round_to_power_of_two = true;
op.mel_opts.num_bins = 24;
op.mel_opts.low_freq = 125.0;
op.mel_opts.high_freq = 7800.0;
op.htk_compat = true;
op.use_energy = false; // C0 not energy.
Mfcc mfcc(op);
// calculate kaldi features
Matrix<BaseFloat> kaldi_raw_features;
mfcc.Compute(waveform, 1.0, &kaldi_raw_features);
DeltaFeaturesOptions delta_opts;
Matrix<BaseFloat> kaldi_features;
ComputeDeltas(delta_opts, kaldi_raw_features, &kaldi_features);
// compare the results
bool passed = true;
int32 i_old = -1;
KALDI_ASSERT(kaldi_features.NumRows() == htk_features.NumRows());
KALDI_ASSERT(kaldi_features.NumCols() == htk_features.NumCols());
// Ignore ends-- we make slightly different choices than
// HTK about how to treat the deltas at the ends.
for (int32 i = 10; i + 10 < kaldi_features.NumRows(); i++) {
for (int32 j = 0; j < kaldi_features.NumCols(); j++) {
BaseFloat a = kaldi_features(i, j), b = htk_features(i, j);
if ((std::abs(b - a)) > 1.0) { //<< TOLERANCE TO DIFFERENCES!!!!!
// print the non-matching data only once per-line
if (static_cast<int32>(i_old) != i) {
std::cout << "\n\n\n[HTK-row: " << i << "] "
<< htk_features.Row(i) << "\n";
std::cout << "[Kaldi-row: " << i << "] "
<< kaldi_features.Row(i) << "\n\n\n";
i_old = i;
}
// print indices of non-matching cells
std::cout << "[" << i << ", " << j << "]";
passed = false;
}
}
}
if (!passed) KALDI_ERR << "Test failed";
// write the htk features for later inspection
HtkHeader header = {
kaldi_features.NumRows(),
100000, // 10ms
static_cast<int16>(sizeof(float) * kaldi_features.NumCols()),
021406 // MFCC_D_A_0
};
{
std::ofstream os("tmp.test.wav.fea_kaldi.6",
std::ios::out | std::ios::binary);
WriteHtk(os, kaldi_features, header);
}
std::cout << "Test passed :)\n\n";
unlink("tmp.test.wav.fea_kaldi.6");
}
void UnitTestVtln() {
// Test the function VtlnWarpFreq.
BaseFloat low_freq = 10, high_freq = 7800, vtln_low_cutoff = 20,
vtln_high_cutoff = 7400;
for (size_t i = 0; i < 100; i++) {
BaseFloat freq = 5000, warp_factor = 0.9 + RandUniform() * 0.2;
AssertEqual(MelBanks::VtlnWarpFreq(vtln_low_cutoff,
vtln_high_cutoff,
low_freq,
high_freq,
warp_factor,
freq),
freq / warp_factor);
AssertEqual(MelBanks::VtlnWarpFreq(vtln_low_cutoff,
vtln_high_cutoff,
low_freq,
high_freq,
warp_factor,
low_freq),
low_freq);
AssertEqual(MelBanks::VtlnWarpFreq(vtln_low_cutoff,
vtln_high_cutoff,
low_freq,
high_freq,
warp_factor,
high_freq),
high_freq);
BaseFloat freq2 = low_freq + (high_freq - low_freq) * RandUniform(),
freq3 = freq2 +
(high_freq - freq2) * RandUniform(); // freq3>=freq2
BaseFloat w2 = MelBanks::VtlnWarpFreq(vtln_low_cutoff,
vtln_high_cutoff,
low_freq,
high_freq,
warp_factor,
freq2);
BaseFloat w3 = MelBanks::VtlnWarpFreq(vtln_low_cutoff,
vtln_high_cutoff,
low_freq,
high_freq,
warp_factor,
freq3);
KALDI_ASSERT(w3 >= w2); // increasing function.
BaseFloat w3dash = MelBanks::VtlnWarpFreq(
vtln_low_cutoff, vtln_high_cutoff, low_freq, high_freq, 1.0, freq3);
AssertEqual(w3dash, freq3);
}
}
static void UnitTestFeat() {
UnitTestVtln();
UnitTestReadWave();
UnitTestSimple();
UnitTestHTKCompare1();
UnitTestHTKCompare2();
// commenting out this one as it doesn't compare right now I normalized
// the way the FFT bins are treated (removed offset of 0.5)... this seems
// to relate to the way frequency zero behaves.
UnitTestHTKCompare3();
UnitTestHTKCompare4();
UnitTestHTKCompare5();
UnitTestHTKCompare6();
std::cout << "Tests succeeded.\n";
}
int main() {
try {
for (int i = 0; i < 5; i++) UnitTestFeat();
std::cout << "Tests succeeded.\n";
return 0;
} catch (const std::exception &e) {
std::cerr << e.what();
return 1;
}
}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// todo refactor, repalce with gtest
#include "base/flags.h"
#include "base/log.h"
#include "kaldi/feat/wave-reader.h"
#include "kaldi/util/kaldi-io.h"
#include "kaldi/util/table-types.h"
#include "frontend/audio/audio_cache.h"
#include "frontend/audio/data_cache.h"
#include "frontend/audio/feature_cache.h"
#include "frontend/audio/frontend_itf.h"
#include "frontend/audio/linear_spectrogram.h"
#include "frontend/audio/normalizer.h"
DEFINE_string(wav_rspecifier, "", "test wav scp path");
DEFINE_string(feature_wspecifier, "", "output feats wspecifier");
DEFINE_string(cmvn_write_path, "./cmvn.ark", "write cmvn");
DEFINE_double(streaming_chunk, 0.36, "streaming feature chunk size");
std::vector<float> mean_{
-13730251.531853663, -12982852.199316509, -13673844.299583456,
-13089406.559646806, -12673095.524938712, -12823859.223276224,
-13590267.158903603, -14257618.467152044, -14374605.116185192,
-14490009.21822485, -14849827.158924166, -15354435.470563512,
-15834149.206532761, -16172971.985514281, -16348740.496746974,
-16423536.699409386, -16556246.263649225, -16744088.772748645,
-16916184.08510357, -17054034.840031497, -17165612.509455364,
-17255955.470915023, -17322572.527648456, -17408943.862033736,
-17521554.799865916, -17620623.254924215, -17699792.395918526,
-17723364.411134344, -17741483.4433254, -17747426.888704527,
-17733315.928209435, -17748780.160905756, -17808336.883775543,
-17895918.671983004, -18009812.59173023, -18098188.66548325,
-18195798.958462656, -18293617.62980999, -18397432.92077201,
-18505834.787318766, -18585451.8100908, -18652438.235649142,
-18700960.306275308, -18734944.58792185, -18737426.313365128,
-18735347.165987637, -18738813.444170244, -18737086.848890636,
-18731576.2474336, -18717405.44095871, -18703089.25545657,
-18691014.546456724, -18692460.568905357, -18702119.628629155,
-18727710.621126678, -18761582.72034647, -18806745.835547544,
-18850674.8692112, -18884431.510951452, -18919999.992506847,
-18939303.799078144, -18952946.273760635, -18980289.22996379,
-19011610.17803294, -19040948.61805145, -19061021.429847397,
-19112055.53768819, -19149667.414264943, -19201127.05091321,
-19270250.82564605, -19334606.883057203, -19390513.336589377,
-19444176.259208687, -19502755.000038862, -19544333.014549147,
-19612668.183176614, -19681902.19006569, -19771969.951249883,
-19873329.723376893, -19996752.59235844, -20110031.131400537,
-20231658.612529557, -20319378.894054495, -20378534.45718066,
-20413332.089584175, -20438147.844177883, -20443710.248040095,
-20465457.02238927, -20488610.969337028, -20516295.16424432,
-20541423.795738827, -20553192.874953747, -20573605.50701977,
-20577871.61936797, -20571807.008916274, -20556242.38912231,
-20542199.30819195, -20521239.063551214, -20519150.80004532,
-20527204.80248933, -20536933.769257784, -20543470.522332076,
-20549700.089992985, -20551525.24958494, -20554873.406493705,
-20564277.65794227, -20572211.740052115, -20574305.69550465,
-20575494.450104576, -20567092.577932164, -20549302.929608088,
-20545445.11878376, -20546625.326603737, -20549190.03499401,
-20554824.947828256, -20568341.378989458, -20577582.331383612,
-20577980.519402675, -20566603.03458152, -20560131.592262644,
-20552166.469060015, -20549063.06763577, -20544490.562339947,
-20539817.82346569, -20528747.715731595, -20518026.24576161,
-20510977.844974525, -20506874.36087992, -20506731.11977665,
-20510482.133420516, -20507760.92101862, -20494644.834457114,
-20480107.89304893, -20461312.091867123, -20442941.75080173,
-20426123.02834838, -20424607.675283, -20426810.369107097,
-20434024.50097819, -20437404.75544205, -20447688.63916367,
-20460893.335563846, -20482922.735127095, -20503610.119434915,
-20527062.76448319, -20557830.035128627, -20593274.72068722,
-20632528.452965066, -20673637.471334763, -20733106.97143075,
-20842921.0447562, -21054357.83621519, -21416569.534189366,
-21978460.272811692, -22753170.052172784, -23671344.10563395,
-24613499.293358143, -25406477.12230188, -25884377.82156489,
-26049040.62791664, -26996879.104431007};
std::vector<float> variance_{
213747175.10846674, 188395815.34302503, 212706429.10966414,
199109025.81461075, 189235901.23864496, 194901336.53253657,
217481594.29306737, 238689869.12327808, 243977501.24115244,
248479623.6431067, 259766741.47116545, 275516766.7790273,
291271202.3691234, 302693239.8220509, 308627358.3997694,
311143911.38788426, 315446105.07731867, 321705430.9341829,
327458907.4659941, 332245072.43223983, 336251717.5935284,
339694069.7639722, 342188204.4322228, 345587110.31313115,
349903086.2875232, 353660214.20643026, 356700344.5270885,
357665362.3529641, 358493352.05658793, 358857951.620328,
358375239.52774596, 358899733.6342954, 361051818.3511561,
364361716.05025816, 368750322.3771452, 372047800.6462831,
375655861.1349018, 379358519.1980013, 383327605.3935181,
387458599.282341, 390434692.3406868, 392994486.35057056,
394874418.04603153, 396230525.79763395, 396365592.0414835,
396334819.8242737, 396488353.19250053, 396438877.00744957,
396197980.4459586, 395590921.6672991, 395001107.62072515,
394528291.7318225, 394593110.424006, 395018405.59353715,
396110577.5415993, 397506704.0371068, 399400197.4657644,
401243568.2468382, 402687134.7805103, 404136047.2872507,
404883170.001883, 405522253.219517, 406660365.3626476,
407919346.0991902, 409045348.5384909, 409759588.7889818,
411974821.8564483, 413489718.78201455, 415535392.56684107,
418466481.97674364, 421104678.35678065, 423405392.5200779,
425550570.40798235, 427929423.9579701, 429585274.253478,
432368493.55181056, 435193587.13513297, 438886855.20476013,
443058876.8633751, 448181232.5093362, 452883835.6332396,
458056721.77926534, 461816531.22735566, 464363620.1970998,
465886343.5057493, 466928872.0651, 467180536.42647296,
468111848.70714295, 469138695.3071312, 470378429.6930793,
471517958.7132626, 472109050.4262365, 473087417.0177867,
473381322.04648733, 473220195.85483915, 472666071.8998819,
472124669.87879956, 471298571.411737, 471251033.2902761,
471672676.43128747, 472177147.2193172, 472572361.7711908,
472968783.7751127, 473156295.4164052, 473398034.82676554,
473897703.5203811, 474328271.33112127, 474452670.98002136,
474549003.99284613, 474252887.13567275, 473557462.909069,
473483385.85193115, 473609738.04855174, 473746944.82085115,
474016729.91696435, 474617321.94138587, 475045097.237122,
475125402.586558, 474664112.9824912, 474426247.5800283,
474104075.42796475, 473978219.7273978, 473773171.7798875,
473578534.69508696, 473102924.16904145, 472651240.5232615,
472374383.1810912, 472209479.6956096, 472202298.8921673,
472370090.76781124, 472220933.99374026, 471625467.37106377,
470994646.51883453, 470182428.9637543, 469348211.5939578,
468570387.4467277, 468540442.7225135, 468672018.90414184,
468994346.9533251, 469138757.58201426, 469553915.95710236,
470134523.38582784, 471082421.62055486, 471962316.51804745,
472939745.1708408, 474250621.5944825, 475773933.43199486,
477465399.71087736, 479218782.61382693, 481752299.7930922,
486608947.8984568, 496119403.2067917, 512730085.5704984,
539048915.2641417, 576285298.3548826, 621610270.2240586,
669308196.4436442, 710656993.5957186, 736344437.3725077,
745481288.0241544, 801121432.9925804};
int count_ = 912592;
void WriteMatrix() {
kaldi::Matrix<double> cmvn_stats(2, mean_.size() + 1);
for (size_t idx = 0; idx < mean_.size(); ++idx) {
cmvn_stats(0, idx) = mean_[idx];
cmvn_stats(1, idx) = variance_[idx];
}
cmvn_stats(0, mean_.size()) = count_;
kaldi::WriteKaldiObject(cmvn_stats, FLAGS_cmvn_write_path, false);
}
int main(int argc, char* argv[]) {
gflags::ParseCommandLineFlags(&argc, &argv, false);
google::InitGoogleLogging(argv[0]);
kaldi::SequentialTableReader<kaldi::WaveHolder> wav_reader(
FLAGS_wav_rspecifier);
kaldi::BaseFloatMatrixWriter feat_writer(FLAGS_feature_wspecifier);
WriteMatrix();
int32 num_done = 0, num_err = 0;
// feature pipeline: wave cache --> decibel_normalizer --> hanning
// window -->linear_spectrogram --> global cmvn -> feat cache
// std::unique_ptr<ppspeech::FrontendInterface> data_source(new
// ppspeech::DataCache());
std::unique_ptr<ppspeech::FrontendInterface> data_source(
new ppspeech::AudioCache());
ppspeech::DecibelNormalizerOptions db_norm_opt;
std::unique_ptr<ppspeech::FrontendInterface> db_norm(
new ppspeech::DecibelNormalizer(db_norm_opt, std::move(data_source)));
ppspeech::LinearSpectrogramOptions opt;
opt.frame_opts.frame_length_ms = 20;
opt.frame_opts.frame_shift_ms = 10;
opt.streaming_chunk = FLAGS_streaming_chunk;
opt.frame_opts.dither = 0.0;
opt.frame_opts.remove_dc_offset = false;
opt.frame_opts.window_type = "hanning";
opt.frame_opts.preemph_coeff = 0.0;
LOG(INFO) << "frame length (ms): " << opt.frame_opts.frame_length_ms;
LOG(INFO) << "frame shift (ms): " << opt.frame_opts.frame_shift_ms;
std::unique_ptr<ppspeech::FrontendInterface> linear_spectrogram(
new ppspeech::LinearSpectrogram(opt, std::move(db_norm)));
std::unique_ptr<ppspeech::FrontendInterface> cmvn(new ppspeech::CMVN(
FLAGS_cmvn_write_path, std::move(linear_spectrogram)));
ppspeech::FeatureCache feature_cache(kint16max, std::move(cmvn));
LOG(INFO) << "feat dim: " << feature_cache.Dim();
int sample_rate = 16000;
float streaming_chunk = FLAGS_streaming_chunk;
int chunk_sample_size = streaming_chunk * sample_rate;
LOG(INFO) << "sr: " << sample_rate;
LOG(INFO) << "chunk size (s): " << streaming_chunk;
LOG(INFO) << "chunk size (sample): " << chunk_sample_size;
for (; !wav_reader.Done(); wav_reader.Next()) {
std::string utt = wav_reader.Key();
const kaldi::WaveData& wave_data = wav_reader.Value();
LOG(INFO) << "process utt: " << utt;
int32 this_channel = 0;
kaldi::SubVector<kaldi::BaseFloat> waveform(wave_data.Data(),
this_channel);
int tot_samples = waveform.Dim();
LOG(INFO) << "wav len (sample): " << tot_samples;
int sample_offset = 0;
std::vector<kaldi::Vector<BaseFloat>> feats;
int feature_rows = 0;
while (sample_offset < tot_samples) {
int cur_chunk_size =
std::min(chunk_sample_size, tot_samples - sample_offset);
kaldi::Vector<kaldi::BaseFloat> wav_chunk(cur_chunk_size);
for (int i = 0; i < cur_chunk_size; ++i) {
wav_chunk(i) = waveform(sample_offset + i);
}
kaldi::Vector<BaseFloat> features;
feature_cache.Accept(wav_chunk);
if (cur_chunk_size < chunk_sample_size) {
feature_cache.SetFinished();
}
feature_cache.Read(&features);
if (features.Dim() == 0) break;
feats.push_back(features);
sample_offset += cur_chunk_size;
feature_rows += features.Dim() / feature_cache.Dim();
}
int cur_idx = 0;
kaldi::Matrix<kaldi::BaseFloat> features(feature_rows,
feature_cache.Dim());
for (auto feat : feats) {
int num_rows = feat.Dim() / feature_cache.Dim();
for (int row_idx = 0; row_idx < num_rows; ++row_idx) {
for (size_t col_idx = 0; col_idx < feature_cache.Dim();
++col_idx) {
features(cur_idx, col_idx) =
feat(row_idx * feature_cache.Dim() + col_idx);
}
++cur_idx;
}
}
feat_writer.Write(utt, features);
feature_cache.Reset();
if (num_done % 50 == 0 && num_done != 0)
KALDI_VLOG(2) << "Processed " << num_done << " utterances";
num_done++;
}
KALDI_LOG << "Done " << num_done << " utterances, " << num_err
<< " with errors.";
return (num_done != 0 ? 0 : 1);
}
#!/bin/bash
set +x
set -e
. ./path.sh
# 1. compile
if [ ! -d ${SPEECHX_EXAMPLES} ]; then
pushd ${SPEECHX_ROOT}
bash build.sh
popd
fi
# 2. download model
if [ ! -d ../paddle_asr_model ]; then
wget https://paddlespeech.bj.bcebos.com/s2t/paddle_asr_online/paddle_asr_model.tar.gz
tar xzfv paddle_asr_model.tar.gz
mv ./paddle_asr_model ../
# produce wav scp
echo "utt1 " $PWD/../paddle_asr_model/BAC009S0764W0290.wav > ../paddle_asr_model/wav.scp
fi
model_dir=../paddle_asr_model
feat_wspecifier=./feats.ark
cmvn=./cmvn.ark
# 3. run feat
export GLOG_logtostderr=1
linear_spectrogram_main \
--wav_rspecifier=scp:$model_dir/wav.scp \
--feature_wspecifier=ark,t:$feat_wspecifier \
--cmvn_write_path=$cmvn
cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
add_executable(pp-model-test ${CMAKE_CURRENT_SOURCE_DIR}/pp-model-test.cc)
target_include_directories(pp-model-test PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(pp-model-test PUBLIC nnet gflags ${DEPS})
\ No newline at end of file
#!/bin/bash
set +x
set -e
. path.sh
# 1. compile
if [ ! -d ${SPEECHX_EXAMPLES} ]; then
pushd ${SPEECHX_ROOT}
bash build.sh
popd
fi
# 2. download model
if [ ! -d ../paddle_asr_model ]; then
wget https://paddlespeech.bj.bcebos.com/s2t/paddle_asr_online/paddle_asr_model.tar.gz
tar xzfv paddle_asr_model.tar.gz
mv ./paddle_asr_model ../
# produce wav scp
echo "utt1 " $PWD/../paddle_asr_model/BAC009S0764W0290.wav > ../paddle_asr_model/wav.scp
fi
model_dir=../paddle_asr_model
# 4. run decoder
pp-model-test \
--model_path=$model_dir/avg_1.jit.pdmodel \
--param_path=$model_dir/avg_1.jit.pdparams
...@@ -92,8 +92,7 @@ void CTCBeamSearch::AdvanceDecode( ...@@ -92,8 +92,7 @@ void CTCBeamSearch::AdvanceDecode(
while (1) { while (1) {
vector<vector<BaseFloat>> likelihood; vector<vector<BaseFloat>> likelihood;
vector<BaseFloat> frame_prob; vector<BaseFloat> frame_prob;
bool flag = bool flag = decodable->FrameLikelihood(num_frame_decoded_, &frame_prob);
decodable->FrameLikelihood(num_frame_decoded_, &frame_prob);
if (flag == false) break; if (flag == false) break;
likelihood.push_back(frame_prob); likelihood.push_back(frame_prob);
AdvanceDecoding(likelihood); AdvanceDecoding(likelihood);
......
...@@ -46,10 +46,10 @@ class LinearSpectrogram : public FrontendInterface { ...@@ -46,10 +46,10 @@ class LinearSpectrogram : public FrontendInterface {
virtual size_t Dim() const { return dim_; } virtual size_t Dim() const { return dim_; }
virtual void SetFinished() { base_extractor_->SetFinished(); } virtual void SetFinished() { base_extractor_->SetFinished(); }
virtual bool IsFinished() const { return base_extractor_->IsFinished(); } virtual bool IsFinished() const { return base_extractor_->IsFinished(); }
virtual void Reset() { virtual void Reset() {
base_extractor_->Reset(); base_extractor_->Reset();
reminded_wav_.Resize(0); reminded_wav_.Resize(0);
} }
private: private:
bool Compute(const kaldi::Vector<kaldi::BaseFloat>& waves, bool Compute(const kaldi::Vector<kaldi::BaseFloat>& waves,
......
...@@ -49,19 +49,19 @@ bool Decodable::IsLastFrame(int32 frame) { ...@@ -49,19 +49,19 @@ bool Decodable::IsLastFrame(int32 frame) {
int32 Decodable::NumIndices() const { return 0; } int32 Decodable::NumIndices() const { return 0; }
// the ilable(TokenId) of wfst(TLG) insert <eps>(id = 0) in front of Nnet prob id. // the ilable(TokenId) of wfst(TLG) insert <eps>(id = 0) in front of Nnet prob
int32 Decodable::TokenId2NnetId(int32 token_id) { // id.
return token_id - 1; int32 Decodable::TokenId2NnetId(int32 token_id) { return token_id - 1; }
}
BaseFloat Decodable::LogLikelihood(int32 frame, int32 index) { BaseFloat Decodable::LogLikelihood(int32 frame, int32 index) {
CHECK_LE(index, nnet_cache_.NumCols()); CHECK_LE(index, nnet_cache_.NumCols());
CHECK_LE(frame, frames_ready_); CHECK_LE(frame, frames_ready_);
int32 frame_idx = frame - frame_offset_; int32 frame_idx = frame - frame_offset_;
// the nnet output is prob ranther than log prob // the nnet output is prob ranther than log prob
// the index - 1, because the ilabel // the index - 1, because the ilabel
return acoustic_scale_ * std::log(nnet_cache_(frame_idx, TokenId2NnetId(index)) + return acoustic_scale_ *
std::numeric_limits<float>::min()); std::log(nnet_cache_(frame_idx, TokenId2NnetId(index)) +
std::numeric_limits<float>::min());
} }
bool Decodable::EnsureFrameHaveComputed(int32 frame) { bool Decodable::EnsureFrameHaveComputed(int32 frame) {
......
...@@ -37,8 +37,7 @@ std::string ReadFile2String(const std::string& path) { ...@@ -37,8 +37,7 @@ std::string ReadFile2String(const std::string& path) {
if (!input_file.is_open()) { if (!input_file.is_open()) {
std::cerr << "please input a valid file" << std::endl; std::cerr << "please input a valid file" << std::endl;
} }
return std::string((std::istreambuf_iterator<char>(input_file)), return std::string((std::istreambuf_iterator<char>(input_file)),
std::istreambuf_iterator<char>()); std::istreambuf_iterator<char>());
} }
} }
...@@ -20,5 +20,4 @@ bool ReadFileToVector(const std::string& filename, ...@@ -20,5 +20,4 @@ bool ReadFileToVector(const std::string& filename,
std::vector<std::string>* data); std::vector<std::string>* data);
std::string ReadFile2String(const std::string& path); std::string ReadFile2String(const std::string& path);
} }
因为 它太大了无法显示 source diff 。你可以改为 查看blob
因为 它太大了无法显示 source diff 。你可以改为 查看blob
...@@ -35,66 +35,68 @@ ...@@ -35,66 +35,68 @@
*/ */
int main(int argc, char *argv[]) { int main(int argc, char *argv[]) {
try { try {
using namespace kaldi; // NOLINT using namespace kaldi; // NOLINT
using namespace fst; // NOLINT using namespace fst; // NOLINT
using kaldi::int32; using kaldi::int32;
const char *usage = const char *usage =
"Adds self-loops to states of an FST to propagate disambiguation " "Adds self-loops to states of an FST to propagate disambiguation "
"symbols through it\n" "symbols through it\n"
"They are added on each final state and each state with non-epsilon " "They are added on each final state and each state with "
"output symbols\n" "non-epsilon "
"on at least one arc out of the state. Useful in conjunction with " "output symbols\n"
"predeterminize\n" "on at least one arc out of the state. Useful in conjunction with "
"\n" "predeterminize\n"
"Usage: fstaddselfloops in-disambig-list out-disambig-list [in.fst " "\n"
"[out.fst] ]\n" "Usage: fstaddselfloops in-disambig-list out-disambig-list "
"E.g: fstaddselfloops in.list out.list < in.fst > withloops.fst\n" "[in.fst "
"in.list and out.list are lists of integers, one per line, of the\n" "[out.fst] ]\n"
"same length.\n"; "E.g: fstaddselfloops in.list out.list < in.fst > withloops.fst\n"
"in.list and out.list are lists of integers, one per line, of the\n"
ParseOptions po(usage); "same length.\n";
po.Read(argc, argv);
ParseOptions po(usage);
if (po.NumArgs() < 2 || po.NumArgs() > 4) { po.Read(argc, argv);
po.PrintUsage();
exit(1); if (po.NumArgs() < 2 || po.NumArgs() > 4) {
po.PrintUsage();
exit(1);
}
std::string disambig_in_rxfilename = po.GetArg(1),
disambig_out_rxfilename = po.GetArg(2),
fst_in_filename = po.GetOptArg(3),
fst_out_filename = po.GetOptArg(4);
VectorFst<StdArc> *fst = ReadFstKaldi(fst_in_filename);
std::vector<int32> disambig_in;
if (!ReadIntegerVectorSimple(disambig_in_rxfilename, &disambig_in))
KALDI_ERR << "fstaddselfloops: Could not read disambiguation "
"symbols from "
<< kaldi::PrintableRxfilename(disambig_in_rxfilename);
std::vector<int32> disambig_out;
if (!ReadIntegerVectorSimple(disambig_out_rxfilename, &disambig_out))
KALDI_ERR << "fstaddselfloops: Could not read disambiguation "
"symbols from "
<< kaldi::PrintableRxfilename(disambig_out_rxfilename);
if (disambig_in.size() != disambig_out.size())
KALDI_ERR << "fstaddselfloops: mismatch in size of disambiguation "
"symbols";
AddSelfLoops(fst, disambig_in, disambig_out);
WriteFstKaldi(*fst, fst_out_filename);
delete fst;
return 0;
} catch (const std::exception &e) {
std::cerr << e.what();
return -1;
} }
std::string disambig_in_rxfilename = po.GetArg(1),
disambig_out_rxfilename = po.GetArg(2),
fst_in_filename = po.GetOptArg(3),
fst_out_filename = po.GetOptArg(4);
VectorFst<StdArc> *fst = ReadFstKaldi(fst_in_filename);
std::vector<int32> disambig_in;
if (!ReadIntegerVectorSimple(disambig_in_rxfilename, &disambig_in))
KALDI_ERR
<< "fstaddselfloops: Could not read disambiguation symbols from "
<< kaldi::PrintableRxfilename(disambig_in_rxfilename);
std::vector<int32> disambig_out;
if (!ReadIntegerVectorSimple(disambig_out_rxfilename, &disambig_out))
KALDI_ERR
<< "fstaddselfloops: Could not read disambiguation symbols from "
<< kaldi::PrintableRxfilename(disambig_out_rxfilename);
if (disambig_in.size() != disambig_out.size())
KALDI_ERR
<< "fstaddselfloops: mismatch in size of disambiguation symbols";
AddSelfLoops(fst, disambig_in, disambig_out);
WriteFstKaldi(*fst, fst_out_filename);
delete fst;
return 0; return 0;
} catch (const std::exception &e) {
std::cerr << e.what();
return -1;
}
return 0;
} }
...@@ -56,59 +56,61 @@ bool debug_location = false; ...@@ -56,59 +56,61 @@ bool debug_location = false;
void signal_handler(int) { debug_location = true; } void signal_handler(int) { debug_location = true; }
int main(int argc, char *argv[]) { int main(int argc, char *argv[]) {
try { try {
using namespace kaldi; // NOLINT using namespace kaldi; // NOLINT
using namespace fst; // NOLINT using namespace fst; // NOLINT
using kaldi::int32; using kaldi::int32;
const char *usage = const char *usage =
"Removes epsilons and determinizes in one step\n" "Removes epsilons and determinizes in one step\n"
"\n" "\n"
"Usage: fstdeterminizestar [in.fst [out.fst] ]\n" "Usage: fstdeterminizestar [in.fst [out.fst] ]\n"
"\n" "\n"
"See also: fstdeterminizelog, lattice-determinize\n"; "See also: fstdeterminizelog, lattice-determinize\n";
float delta = kDelta; float delta = kDelta;
int max_states = -1; int max_states = -1;
bool use_log = false; bool use_log = false;
ParseOptions po(usage); ParseOptions po(usage);
po.Register("use-log", &use_log, "Determinize in log semiring."); po.Register("use-log", &use_log, "Determinize in log semiring.");
po.Register("delta", &delta, po.Register("delta",
"Delta value used to determine equivalence of weights."); &delta,
po.Register( "Delta value used to determine equivalence of weights.");
"max-states", &max_states, po.Register("max-states",
"Maximum number of states in determinized FST before it will abort."); &max_states,
po.Read(argc, argv); "Maximum number of states in determinized FST before it "
"will abort.");
po.Read(argc, argv);
if (po.NumArgs() > 2) { if (po.NumArgs() > 2) {
po.PrintUsage(); po.PrintUsage();
exit(1); exit(1);
} }
std::string fst_in_str = po.GetOptArg(1), fst_out_str = po.GetOptArg(2); std::string fst_in_str = po.GetOptArg(1), fst_out_str = po.GetOptArg(2);
// This enables us to get traceback info from determinization that is // This enables us to get traceback info from determinization that is
// not seeming to terminate. // not seeming to terminate.
#if !defined(_MSC_VER) && !defined(__APPLE__) #if !defined(_MSC_VER) && !defined(__APPLE__)
signal(SIGUSR1, signal_handler); signal(SIGUSR1, signal_handler);
#endif #endif
// Normal case: just files. // Normal case: just files.
VectorFst<StdArc> *fst = ReadFstKaldi(fst_in_str); VectorFst<StdArc> *fst = ReadFstKaldi(fst_in_str);
ArcSort(fst, ILabelCompare<StdArc>()); // improves speed. ArcSort(fst, ILabelCompare<StdArc>()); // improves speed.
if (use_log) { if (use_log) {
DeterminizeStarInLog(fst, delta, &debug_location, max_states); DeterminizeStarInLog(fst, delta, &debug_location, max_states);
} else { } else {
VectorFst<StdArc> det_fst; VectorFst<StdArc> det_fst;
DeterminizeStar(*fst, &det_fst, delta, &debug_location, max_states); DeterminizeStar(*fst, &det_fst, delta, &debug_location, max_states);
*fst = det_fst; // will do shallow copy and then det_fst goes *fst = det_fst; // will do shallow copy and then det_fst goes
// out of scope anyway. // out of scope anyway.
}
WriteFstKaldi(*fst, fst_out_str);
delete fst;
return 0;
} catch (const std::exception &e) {
std::cerr << e.what();
return -1;
} }
WriteFstKaldi(*fst, fst_out_str);
delete fst;
return 0;
} catch (const std::exception &e) {
std::cerr << e.what();
return -1;
}
} }
...@@ -42,50 +42,51 @@ ...@@ -42,50 +42,51 @@
// though not stochastic because we gave it an absurdly large delta. // though not stochastic because we gave it an absurdly large delta.
int main(int argc, char *argv[]) { int main(int argc, char *argv[]) {
try { try {
using namespace kaldi; // NOLINT using namespace kaldi; // NOLINT
using namespace fst; // NOLINT using namespace fst; // NOLINT
using kaldi::int32; using kaldi::int32;
const char *usage = const char *usage =
"Checks whether an FST is stochastic and exits with success if so.\n" "Checks whether an FST is stochastic and exits with success if "
"Prints out maximum error (in log units).\n" "so.\n"
"\n" "Prints out maximum error (in log units).\n"
"Usage: fstisstochastic [ in.fst ]\n"; "\n"
"Usage: fstisstochastic [ in.fst ]\n";
float delta = 0.01; float delta = 0.01;
bool test_in_log = true; bool test_in_log = true;
ParseOptions po(usage); ParseOptions po(usage);
po.Register("delta", &delta, "Maximum error to accept."); po.Register("delta", &delta, "Maximum error to accept.");
po.Register("test-in-log", &test_in_log, po.Register(
"Test stochasticity in log semiring."); "test-in-log", &test_in_log, "Test stochasticity in log semiring.");
po.Read(argc, argv); po.Read(argc, argv);
if (po.NumArgs() > 1) { if (po.NumArgs() > 1) {
po.PrintUsage(); po.PrintUsage();
exit(1); exit(1);
} }
std::string fst_in_filename = po.GetOptArg(1); std::string fst_in_filename = po.GetOptArg(1);
Fst<StdArc> *fst = ReadFstKaldiGeneric(fst_in_filename); Fst<StdArc> *fst = ReadFstKaldiGeneric(fst_in_filename);
bool ans; bool ans;
StdArc::Weight min, max; StdArc::Weight min, max;
if (test_in_log) if (test_in_log)
ans = IsStochasticFstInLog(*fst, delta, &min, &max); ans = IsStochasticFstInLog(*fst, delta, &min, &max);
else else
ans = IsStochasticFst(*fst, delta, &min, &max); ans = IsStochasticFst(*fst, delta, &min, &max);
std::cout << min.Value() << " " << max.Value() << '\n'; std::cout << min.Value() << " " << max.Value() << '\n';
delete fst; delete fst;
if (ans) if (ans)
return 0; // success; return 0; // success;
else else
return 1; return 1;
} catch (const std::exception &e) { } catch (const std::exception &e) {
std::cerr << e.what(); std::cerr << e.what();
return -1; return -1;
} }
} }
...@@ -33,42 +33,43 @@ ...@@ -33,42 +33,43 @@
*/ */
int main(int argc, char *argv[]) { int main(int argc, char *argv[]) {
try { try {
using namespace kaldi; // NOLINT using namespace kaldi; // NOLINT
using namespace fst; // NOLINT using namespace fst; // NOLINT
using kaldi::int32; using kaldi::int32;
const char *usage = const char *usage =
"Minimizes FST after encoding [similar to fstminimize, but no " "Minimizes FST after encoding [similar to fstminimize, but no "
"weight-pushing]\n" "weight-pushing]\n"
"\n" "\n"
"Usage: fstminimizeencoded [in.fst [out.fst] ]\n"; "Usage: fstminimizeencoded [in.fst [out.fst] ]\n";
float delta = kDelta; float delta = kDelta;
ParseOptions po(usage); ParseOptions po(usage);
po.Register("delta", &delta, po.Register("delta",
"Delta likelihood used for quantization of weights"); &delta,
po.Read(argc, argv); "Delta likelihood used for quantization of weights");
po.Read(argc, argv);
if (po.NumArgs() > 2) { if (po.NumArgs() > 2) {
po.PrintUsage(); po.PrintUsage();
exit(1); exit(1);
} }
std::string fst_in_filename = po.GetOptArg(1), std::string fst_in_filename = po.GetOptArg(1),
fst_out_filename = po.GetOptArg(2); fst_out_filename = po.GetOptArg(2);
VectorFst<StdArc> *fst = ReadFstKaldi(fst_in_filename); VectorFst<StdArc> *fst = ReadFstKaldi(fst_in_filename);
MinimizeEncoded(fst, delta); MinimizeEncoded(fst, delta);
WriteFstKaldi(*fst, fst_out_filename); WriteFstKaldi(*fst, fst_out_filename);
delete fst; delete fst;
return 0;
} catch (const std::exception &e) {
std::cerr << e.what();
return -1;
}
return 0; return 0;
} catch (const std::exception &e) {
std::cerr << e.what();
return -1;
}
return 0;
} }
...@@ -37,97 +37,104 @@ ...@@ -37,97 +37,104 @@
*/ */
int main(int argc, char *argv[]) { int main(int argc, char *argv[]) {
try { try {
using namespace kaldi; // NOLINT using namespace kaldi; // NOLINT
using namespace fst; // NOLINT using namespace fst; // NOLINT
using kaldi::int32; using kaldi::int32;
/* /*
fsttablecompose should always give equivalent results to compose, fsttablecompose should always give equivalent results to compose,
but it is more efficient for certain kinds of inputs. but it is more efficient for certain kinds of inputs.
In particular, it is useful when, say, the left FST has states In particular, it is useful when, say, the left FST has states
that typically either have epsilon olabels, or that typically either have epsilon olabels, or
one transition out for each of the possible symbols (as the one transition out for each of the possible symbols (as the
olabel). The same with the input symbols of the right-hand FST olabel). The same with the input symbols of the right-hand FST
is possible. is possible.
*/ */
const char *usage = const char *usage =
"Composition algorithm [between two FSTs of standard type, in " "Composition algorithm [between two FSTs of standard type, in "
"tropical\n" "tropical\n"
"semiring] that is more efficient for certain cases-- in particular,\n" "semiring] that is more efficient for certain cases-- in "
"where one of the FSTs (the left one, if --match-side=left) has large\n" "particular,\n"
"out-degree\n" "where one of the FSTs (the left one, if --match-side=left) has "
"\n" "large\n"
"Usage: fsttablecompose (fst1-rxfilename|fst1-rspecifier) " "out-degree\n"
"(fst2-rxfilename|fst2-rspecifier) [(out-rxfilename|out-rspecifier)]\n"; "\n"
"Usage: fsttablecompose (fst1-rxfilename|fst1-rspecifier) "
ParseOptions po(usage); "(fst2-rxfilename|fst2-rspecifier) "
"[(out-rxfilename|out-rspecifier)]\n";
TableComposeOptions opts;
std::string match_side = "left"; ParseOptions po(usage);
std::string compose_filter = "sequence";
TableComposeOptions opts;
po.Register("connect", &opts.connect, "If true, trim FST before output."); std::string match_side = "left";
po.Register("match-side", &match_side, std::string compose_filter = "sequence";
"Side of composition to do table "
"match, one of: \"left\" or \"right\"."); po.Register(
po.Register("compose-filter", &compose_filter, "connect", &opts.connect, "If true, trim FST before output.");
"Composition filter to use, " po.Register("match-side",
"one of: \"alt_sequence\", \"auto\", \"match\", \"sequence\""); &match_side,
"Side of composition to do table "
po.Read(argc, argv); "match, one of: \"left\" or \"right\".");
po.Register(
if (match_side == "left") { "compose-filter",
opts.table_match_type = MATCH_OUTPUT; &compose_filter,
} else if (match_side == "right") { "Composition filter to use, "
opts.table_match_type = MATCH_INPUT; "one of: \"alt_sequence\", \"auto\", \"match\", \"sequence\"");
} else {
KALDI_ERR << "Invalid match-side option: " << match_side; po.Read(argc, argv);
if (match_side == "left") {
opts.table_match_type = MATCH_OUTPUT;
} else if (match_side == "right") {
opts.table_match_type = MATCH_INPUT;
} else {
KALDI_ERR << "Invalid match-side option: " << match_side;
}
if (compose_filter == "alt_sequence") {
opts.filter_type = ALT_SEQUENCE_FILTER;
} else if (compose_filter == "auto") {
opts.filter_type = AUTO_FILTER;
} else if (compose_filter == "match") {
opts.filter_type = MATCH_FILTER;
} else if (compose_filter == "sequence") {
opts.filter_type = SEQUENCE_FILTER;
} else {
KALDI_ERR << "Invalid compose-filter option: " << compose_filter;
}
if (po.NumArgs() < 2 || po.NumArgs() > 3) {
po.PrintUsage();
exit(1);
}
std::string fst1_in_str = po.GetArg(1), fst2_in_str = po.GetArg(2),
fst_out_str = po.GetOptArg(3);
VectorFst<StdArc> *fst1 = ReadFstKaldi(fst1_in_str);
VectorFst<StdArc> *fst2 = ReadFstKaldi(fst2_in_str);
// Checks if <fst1> is olabel sorted and <fst2> is ilabel sorted.
if (fst1->Properties(fst::kOLabelSorted, true) == 0) {
KALDI_WARN << "The first FST is not olabel sorted.";
}
if (fst2->Properties(fst::kILabelSorted, true) == 0) {
KALDI_WARN << "The second FST is not ilabel sorted.";
}
VectorFst<StdArc> composed_fst;
TableCompose(*fst1, *fst2, &composed_fst, opts);
delete fst1;
delete fst2;
WriteFstKaldi(composed_fst, fst_out_str);
return 0;
} catch (const std::exception &e) {
std::cerr << e.what();
return -1;
} }
if (compose_filter == "alt_sequence") {
opts.filter_type = ALT_SEQUENCE_FILTER;
} else if (compose_filter == "auto") {
opts.filter_type = AUTO_FILTER;
} else if (compose_filter == "match") {
opts.filter_type = MATCH_FILTER;
} else if (compose_filter == "sequence") {
opts.filter_type = SEQUENCE_FILTER;
} else {
KALDI_ERR << "Invalid compose-filter option: " << compose_filter;
}
if (po.NumArgs() < 2 || po.NumArgs() > 3) {
po.PrintUsage();
exit(1);
}
std::string fst1_in_str = po.GetArg(1), fst2_in_str = po.GetArg(2),
fst_out_str = po.GetOptArg(3);
VectorFst<StdArc> *fst1 = ReadFstKaldi(fst1_in_str);
VectorFst<StdArc> *fst2 = ReadFstKaldi(fst2_in_str);
// Checks if <fst1> is olabel sorted and <fst2> is ilabel sorted.
if (fst1->Properties(fst::kOLabelSorted, true) == 0) {
KALDI_WARN << "The first FST is not olabel sorted.";
}
if (fst2->Properties(fst::kILabelSorted, true) == 0) {
KALDI_WARN << "The second FST is not ilabel sorted.";
}
VectorFst<StdArc> composed_fst;
TableCompose(*fst1, *fst2, &composed_fst, opts);
delete fst1;
delete fst2;
WriteFstKaldi(composed_fst, fst_out_str);
return 0;
} catch (const std::exception &e) {
std::cerr << e.what();
return -1;
}
} }
...@@ -24,122 +24,130 @@ ...@@ -24,122 +24,130 @@
#include "util/parse-options.h" #include "util/parse-options.h"
int main(int argc, char *argv[]) { int main(int argc, char *argv[]) {
using namespace kaldi; // NOLINT using namespace kaldi; // NOLINT
try { try {
const char *usage = const char *usage =
"Convert an ARPA format language model into an FST\n" "Convert an ARPA format language model into an FST\n"
"Usage: arpa2fst [opts] <input-arpa> <output-fst>\n" "Usage: arpa2fst [opts] <input-arpa> <output-fst>\n"
" e.g.: arpa2fst --disambig-symbol=#0 --read-symbol-table=" " e.g.: arpa2fst --disambig-symbol=#0 --read-symbol-table="
"data/lang/words.txt lm/input.arpa G.fst\n\n" "data/lang/words.txt lm/input.arpa G.fst\n\n"
"Note: When called without switches, the output G.fst will contain\n" "Note: When called without switches, the output G.fst will "
"an embedded symbol table. This is compatible with the way a previous\n" "contain\n"
"version of arpa2fst worked.\n"; "an embedded symbol table. This is compatible with the way a "
"previous\n"
ParseOptions po(usage); "version of arpa2fst worked.\n";
ArpaParseOptions options; ParseOptions po(usage);
options.Register(&po);
ArpaParseOptions options;
// Option flags. options.Register(&po);
std::string bos_symbol = "<s>";
std::string eos_symbol = "</s>"; // Option flags.
std::string disambig_symbol; std::string bos_symbol = "<s>";
std::string read_syms_filename; std::string eos_symbol = "</s>";
std::string write_syms_filename; std::string disambig_symbol;
bool keep_symbols = false; std::string read_syms_filename;
bool ilabel_sort = true; std::string write_syms_filename;
bool keep_symbols = false;
po.Register("bos-symbol", &bos_symbol, "Beginning of sentence symbol"); bool ilabel_sort = true;
po.Register("eos-symbol", &eos_symbol, "End of sentence symbol");
po.Register("disambig-symbol", &disambig_symbol, po.Register("bos-symbol", &bos_symbol, "Beginning of sentence symbol");
"Disambiguator. If provided (e. g. #0), used on input side of " po.Register("eos-symbol", &eos_symbol, "End of sentence symbol");
"backoff links, and <s> and </s> are replaced with epsilons"); po.Register(
po.Register("read-symbol-table", &read_syms_filename, "disambig-symbol",
"Use existing symbol table"); &disambig_symbol,
po.Register("write-symbol-table", &write_syms_filename, "Disambiguator. If provided (e. g. #0), used on input side of "
"Write generated symbol table to a file"); "backoff links, and <s> and </s> are replaced with epsilons");
po.Register("keep-symbols", &keep_symbols, po.Register("read-symbol-table",
"Store symbol table with FST. Symbols always saved to FST if " &read_syms_filename,
"symbol tables are neither read or written (otherwise symbols " "Use existing symbol table");
"would be lost entirely)"); po.Register("write-symbol-table",
po.Register("ilabel-sort", &ilabel_sort, "Ilabel-sort the output FST"); &write_syms_filename,
"Write generated symbol table to a file");
po.Read(argc, argv); po.Register(
"keep-symbols",
if (po.NumArgs() != 1 && po.NumArgs() != 2) { &keep_symbols,
po.PrintUsage(); "Store symbol table with FST. Symbols always saved to FST if "
exit(1); "symbol tables are neither read or written (otherwise symbols "
"would be lost entirely)");
po.Register("ilabel-sort", &ilabel_sort, "Ilabel-sort the output FST");
po.Read(argc, argv);
if (po.NumArgs() != 1 && po.NumArgs() != 2) {
po.PrintUsage();
exit(1);
}
std::string arpa_rxfilename = po.GetArg(1),
fst_wxfilename = po.GetOptArg(2);
int64 disambig_symbol_id = 0;
fst::SymbolTable *symbols;
if (!read_syms_filename.empty()) {
// Use existing symbols. Required symbols must be in the table.
kaldi::Input kisym(read_syms_filename);
symbols = fst::SymbolTable::ReadText(
kisym.Stream(), PrintableWxfilename(read_syms_filename));
if (symbols == NULL)
KALDI_ERR << "Could not read symbol table from file "
<< read_syms_filename;
options.oov_handling = ArpaParseOptions::kSkipNGram;
if (!disambig_symbol.empty()) {
disambig_symbol_id = symbols->Find(disambig_symbol);
if (disambig_symbol_id == -1) // fst::kNoSymbol
KALDI_ERR << "Symbol table " << read_syms_filename
<< " has no symbol for " << disambig_symbol;
}
} else {
// Create a new symbol table and populate it from ARPA file.
symbols = new fst::SymbolTable(PrintableWxfilename(fst_wxfilename));
options.oov_handling = ArpaParseOptions::kAddToSymbols;
symbols->AddSymbol("<eps>", 0);
if (!disambig_symbol.empty()) {
disambig_symbol_id = symbols->AddSymbol(disambig_symbol);
}
}
// Add or use existing BOS and EOS.
options.bos_symbol = symbols->AddSymbol(bos_symbol);
options.eos_symbol = symbols->AddSymbol(eos_symbol);
// If producing new (not reading existing) symbols and not saving them,
// need to keep symbols with FST, otherwise they would be lost.
if (read_syms_filename.empty() && write_syms_filename.empty())
keep_symbols = true;
// Actually compile LM.
KALDI_ASSERT(symbols != NULL);
ArpaLmCompiler lm_compiler(options, disambig_symbol_id, symbols);
{
Input ki(arpa_rxfilename);
lm_compiler.Read(ki.Stream());
}
// Sort the FST in-place if requested by options.
if (ilabel_sort) {
fst::ArcSort(lm_compiler.MutableFst(), fst::StdILabelCompare());
}
// Write symbols if requested.
if (!write_syms_filename.empty()) {
kaldi::Output kosym(write_syms_filename, false);
symbols->WriteText(kosym.Stream());
}
// Write LM FST.
bool write_binary = true, write_header = false;
kaldi::Output kofst(fst_wxfilename, write_binary, write_header);
fst::FstWriteOptions wopts(PrintableWxfilename(fst_wxfilename));
wopts.write_isymbols = wopts.write_osymbols = keep_symbols;
lm_compiler.Fst().Write(kofst.Stream(), wopts);
delete symbols;
} catch (const std::exception &e) {
std::cerr << e.what();
return -1;
} }
std::string arpa_rxfilename = po.GetArg(1),
fst_wxfilename = po.GetOptArg(2);
int64 disambig_symbol_id = 0;
fst::SymbolTable *symbols;
if (!read_syms_filename.empty()) {
// Use existing symbols. Required symbols must be in the table.
kaldi::Input kisym(read_syms_filename);
symbols = fst::SymbolTable::ReadText(
kisym.Stream(), PrintableWxfilename(read_syms_filename));
if (symbols == NULL)
KALDI_ERR << "Could not read symbol table from file "
<< read_syms_filename;
options.oov_handling = ArpaParseOptions::kSkipNGram;
if (!disambig_symbol.empty()) {
disambig_symbol_id = symbols->Find(disambig_symbol);
if (disambig_symbol_id == -1) // fst::kNoSymbol
KALDI_ERR << "Symbol table " << read_syms_filename
<< " has no symbol for " << disambig_symbol;
}
} else {
// Create a new symbol table and populate it from ARPA file.
symbols = new fst::SymbolTable(PrintableWxfilename(fst_wxfilename));
options.oov_handling = ArpaParseOptions::kAddToSymbols;
symbols->AddSymbol("<eps>", 0);
if (!disambig_symbol.empty()) {
disambig_symbol_id = symbols->AddSymbol(disambig_symbol);
}
}
// Add or use existing BOS and EOS.
options.bos_symbol = symbols->AddSymbol(bos_symbol);
options.eos_symbol = symbols->AddSymbol(eos_symbol);
// If producing new (not reading existing) symbols and not saving them,
// need to keep symbols with FST, otherwise they would be lost.
if (read_syms_filename.empty() && write_syms_filename.empty())
keep_symbols = true;
// Actually compile LM.
KALDI_ASSERT(symbols != NULL);
ArpaLmCompiler lm_compiler(options, disambig_symbol_id, symbols);
{
Input ki(arpa_rxfilename);
lm_compiler.Read(ki.Stream());
}
// Sort the FST in-place if requested by options.
if (ilabel_sort) {
fst::ArcSort(lm_compiler.MutableFst(), fst::StdILabelCompare());
}
// Write symbols if requested.
if (!write_syms_filename.empty()) {
kaldi::Output kosym(write_syms_filename, false);
symbols->WriteText(kosym.Stream());
}
// Write LM FST.
bool write_binary = true, write_header = false;
kaldi::Output kofst(fst_wxfilename, write_binary, write_header);
fst::FstWriteOptions wopts(PrintableWxfilename(fst_wxfilename));
wopts.write_isymbols = wopts.write_osymbols = keep_symbols;
lm_compiler.Fst().Write(kofst.Stream(), wopts);
delete symbols;
} catch (const std::exception &e) {
std::cerr << e.what();
return -1;
}
} }
...@@ -26,9 +26,9 @@ import argparse ...@@ -26,9 +26,9 @@ import argparse
import os import os
import re import re
import subprocess import subprocess
from distutils.util import strtobool
import numpy as np import numpy as np
from distutils.util import strtobool
FILE_IDS = re.compile(r"(?<=Speaker Diarization for).+(?=\*\*\*)") FILE_IDS = re.compile(r"(?<=Speaker Diarization for).+(?=\*\*\*)")
SCORED_SPEAKER_TIME = re.compile(r"(?<=SCORED SPEAKER TIME =)[\d.]+") SCORED_SPEAKER_TIME = re.compile(r"(?<=SCORED SPEAKER TIME =)[\d.]+")
......
#!/usr/bin/env python3 #!/usr/bin/env python3
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# CopyRight WeNet Apache-2.0 License
import re, sys, unicodedata import re, sys, unicodedata
import codecs import codecs
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册