提交 c15278ed 编写于 作者: H Hui Zhang

format

上级 94327238
......@@ -26,9 +26,8 @@ def get_audios(path):
"""
supported_formats = [".wav", ".mp3", ".ogg", ".flac", ".m4a"]
return [
item
for sublist in [[os.path.join(dir, file) for file in files]
for dir, _, files in list(os.walk(path))]
item for sublist in [[os.path.join(dir, file) for file in files]
for dir, _, files in list(os.walk(path))]
for item in sublist if os.path.splitext(item)[1] in supported_formats
]
......
......@@ -62,4 +62,4 @@ I0513 10:58:13.884493 41768 feature_cache.h:52] set finished
I0513 10:58:24.247171 41768 paddle_nnet.h:76] Tensor neml: 10240
I0513 10:58:24.247249 41768 paddle_nnet.h:76] Tensor neml: 10240
LOG ([5.5.544~2-f21d7]:main():decoder/recognizer_test_main.cc:90) the result of case_10 is 五月十二日二十二点三十六分加班打车回家四十一元
```
\ No newline at end of file
```
......@@ -13,9 +13,7 @@
# limitations under the License.
#!/usr/bin/python
# -*- coding: UTF-8 -*-
# script for calc RTF: grep -rn RTF log.txt | awk '{print $NF}' | awk -F "=" '{sum += $NF} END {print "all time",sum, "audio num", NR, "RTF", sum/NR}'
import argparse
import asyncio
import codecs
......
......@@ -92,5 +92,3 @@ server 的 demo: [streaming_asr_server](https://github.com/PaddlePaddle/Paddle
## 4. 快速开始
关于如果使用 PP-ASR,可以看这里的 [install](https://github.com/PaddlePaddle/PaddleSpeech/blob/develop/docs/source/install_cn.md),其中提供了 **简单****中等****困难** 三种安装方式。如果想体验 paddlespeech 的推理功能,可以用 **简单** 安装方式。
......@@ -24,11 +24,11 @@ from typing import Any
from typing import Dict
import paddle
import paddleaudio
import requests
import yaml
from paddle.framework import load
import paddleaudio
from . import download
from .entry import commands
try:
......
......@@ -51,7 +51,7 @@ def _batch_shuffle(indices, batch_size, epoch, clipped=False):
"""
rng = np.random.RandomState(epoch)
shift_len = rng.randint(0, batch_size - 1)
batch_indices = list(zip(*[iter(indices[shift_len:])] * batch_size))
batch_indices = list(zip(* [iter(indices[shift_len:])] * batch_size))
rng.shuffle(batch_indices)
batch_indices = [item for batch in batch_indices for item in batch]
assert clipped is False
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
......@@ -153,8 +153,7 @@ class PaddleASRConnectionHanddler:
spectrum = self.collate_fn_test._normalizer.apply(spectrum)
# spectrum augment
feat = self.collate_fn_test.augmentation.transform_feature(
spectrum)
feat = self.collate_fn_test.augmentation.transform_feature(spectrum)
# audio_len is frame num
frame_num = feat.shape[0]
......@@ -189,14 +188,16 @@ class PaddleASRConnectionHanddler:
assert samples.ndim == 1
self.num_samples += samples.shape[0]
logger.info(f"This package receive {samples.shape[0]} pcm data. Global samples:{self.num_samples}")
logger.info(
f"This package receive {samples.shape[0]} pcm data. Global samples:{self.num_samples}"
)
# self.reamined_wav stores all the samples,
# include the original remained_wav and this package samples
if self.remained_wav is None:
self.remained_wav = samples
else:
assert self.remained_wav.ndim == 1 # (T,)
assert self.remained_wav.ndim == 1 # (T,)
self.remained_wav = np.concatenate([self.remained_wav, samples])
logger.info(
f"The concatenation of remain and now audio samples length is: {self.remained_wav.shape}"
......@@ -216,8 +217,8 @@ class PaddleASRConnectionHanddler:
if self.cached_feat is None:
self.cached_feat = x_chunk
else:
assert (len(x_chunk.shape) == 3) # (B,T,D)
assert (len(self.cached_feat.shape) == 3) # (B,T,D)
assert (len(x_chunk.shape) == 3) # (B,T,D)
assert (len(self.cached_feat.shape) == 3) # (B,T,D)
self.cached_feat = paddle.concat(
[self.cached_feat, x_chunk], axis=1)
......@@ -234,18 +235,16 @@ class PaddleASRConnectionHanddler:
# update remained wav
self.remained_wav = self.remained_wav[self.n_shift * num_frames:]
logger.info(
f"process the audio feature success, the cached feat shape: {self.cached_feat.shape}"
)
logger.info(
f"After extract feat, the cached remain the audio samples: {self.remained_wav.shape}"
)
logger.info(f"global samples: {self.num_samples}")
logger.info(f"global samples: {self.num_samples}")
logger.info(f"global frames: {self.num_frames}")
else:
raise ValueError(f"not supported: {self.model_type}")
raise ValueError(f"not supported: {self.model_type}")
def reset(self):
if "deepspeech2" in self.model_type:
......@@ -263,12 +262,11 @@ class PaddleASRConnectionHanddler:
# global sample and frame step
self.num_samples = 0
self.num_frames = 0
# cache for audio and feat
self.remained_wav = None
self.cached_feat = None
# partial/ending decoding results
self.result_transcripts = ['']
......@@ -280,17 +278,16 @@ class PaddleASRConnectionHanddler:
self.conformer_cnn_cache = None
self.encoder_out = None
# conformer decoding state
self.chunk_num = 0 # globa decoding chunk num
self.offset = 0 # global offset in decoding frame unit
self.chunk_num = 0 # globa decoding chunk num
self.offset = 0 # global offset in decoding frame unit
self.hyps = []
# token timestamp result
self.word_time_stamp = []
# one best timestamp viterbi prob is large.
self.time_stamp = []
def decode(self, is_finished=False):
"""advance decoding
......@@ -307,7 +304,7 @@ class PaddleASRConnectionHanddler:
decoding_chunk_size = 1 # decoding chunk size = 1. int decoding frame unit
context = 7 # context=7, in audio frame unit
subsampling = 4 # subsampling=4, in audio frame unit
cached_feature_num = context - subsampling
# decoding window for model, in audio frame unit
decoding_window = (decoding_chunk_size - 1) * subsampling + context
......@@ -373,7 +370,6 @@ class PaddleASRConnectionHanddler:
else:
raise Exception("invalid model name")
@paddle.no_grad()
def decode_one_chunk(self, x_chunk, x_chunk_lens):
"""forward one chunk frames
......@@ -425,10 +421,11 @@ class PaddleASRConnectionHanddler:
logger.info(f"decode one best result for deepspeech2: {trans_best[0]}")
return trans_best[0]
@paddle.no_grad()
def advance_decoding(self, is_finished=False):
logger.info("Conformer/Transformer: start to decode with advanced_decoding method")
logger.info(
"Conformer/Transformer: start to decode with advanced_decoding method"
)
cfg = self.ctc_decode_config
# cur chunk size, in decoding frame unit
......@@ -563,7 +560,6 @@ class PaddleASRConnectionHanddler:
"""
return self.word_time_stamp
@paddle.no_grad()
def rescoring(self):
"""Second-Pass Decoding,
......@@ -572,9 +568,11 @@ class PaddleASRConnectionHanddler:
if "deepspeech2" in self.model_type:
logger.info("deepspeech2 not support rescoring decoding.")
return
if "attention_rescoring" != self.ctc_decode_config.decoding_method:
logger.info(f"decoding method not match: {self.ctc_decode_config.decoding_method}, need attention_rescoring")
logger.info(
f"decoding method not match: {self.ctc_decode_config.decoding_method}, need attention_rescoring"
)
return
logger.info("rescoring the final result")
......@@ -605,7 +603,8 @@ class PaddleASRConnectionHanddler:
hyp_content, place=self.device, dtype=paddle.long)
hyp_list.append(hyp_content)
hyps_pad = pad_sequence(hyp_list, batch_first=True, padding_value=self.model.ignore_id)
hyps_pad = pad_sequence(
hyp_list, batch_first=True, padding_value=self.model.ignore_id)
hyps_lens = paddle.to_tensor(
[len(hyp[0]) for hyp in hyps], place=self.device,
dtype=paddle.long) # (beam_size,)
......@@ -689,12 +688,11 @@ class PaddleASRConnectionHanddler:
"ed": end
})
# logger.info(f"{word_time_stamp[-1]}")
self.word_time_stamp = word_time_stamp
logger.info(f"word time stamp: {self.word_time_stamp}")
class ASRServerExecutor(ASRExecutor):
def __init__(self):
super().__init__()
......@@ -741,7 +739,7 @@ class ASRServerExecutor(ASRExecutor):
self.am_model = os.path.abspath(am_model)
self.am_params = os.path.abspath(am_params)
self.res_path = os.path.dirname(
os.path.dirname(os.path.abspath(self.cfg_path)))
os.path.dirname(os.path.abspath(self.cfg_path)))
logger.info(self.cfg_path)
logger.info(self.am_model)
......@@ -855,7 +853,7 @@ class ASRServerExecutor(ASRExecutor):
self.transformer_decode_reset()
else:
raise ValueError(f"Not support: {model_type}")
return True
......
......@@ -17,12 +17,12 @@ from typing import List
from fastapi import APIRouter
from paddlespeech.cli.log import logger
from paddlespeech.server.restful.acs_api import router as acs_router
from paddlespeech.server.restful.asr_api import router as asr_router
from paddlespeech.server.restful.cls_api import router as cls_router
from paddlespeech.server.restful.text_api import router as text_router
from paddlespeech.server.restful.tts_api import router as tts_router
from paddlespeech.server.restful.vector_api import router as vec_router
from paddlespeech.server.restful.acs_api import router as acs_router
_router = APIRouter()
......
......@@ -248,7 +248,7 @@ class ASRHttpHandler:
}
res = requests.post(url=self.url, data=json.dumps(data))
return res.json()
......
......@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
class Frame(object):
"""Represents a "frame" of audio data."""
......@@ -45,7 +46,7 @@ class ChunkBuffer(object):
self.shift_ms = shift_ms
self.sample_rate = sample_rate
self.sample_width = sample_width # int16 = 2; float32 = 4
self.window_sec = float((self.window_n - 1) * self.shift_ms +
self.window_ms) / 1000.0
self.shift_sec = float(self.shift_n * self.shift_ms / 1000.0)
......@@ -77,8 +78,8 @@ class ChunkBuffer(object):
offset = 0
while offset + self.window_bytes <= len(audio):
yield Frame(audio[offset:offset + self.window_bytes], self.timestamp,
self.window_sec)
yield Frame(audio[offset:offset + self.window_bytes],
self.timestamp, self.window_sec)
self.timestamp += self.shift_sec
offset += self.shift_bytes
......
......@@ -176,7 +176,10 @@ def main():
parser.add_argument(
"--ngpu", type=int, default=1, help="if ngpu == 0, use cpu or xpu.")
parser.add_argument(
"--nxpu", type=int, default=0, help="if nxpu == 0 and ngpu == 0, use cpu.")
"--nxpu",
type=int,
default=0,
help="if nxpu == 0 and ngpu == 0, use cpu.")
args, _ = parser.parse_known_args()
......
......@@ -188,7 +188,10 @@ def main():
parser.add_argument("--dev-metadata", type=str, help="dev data.")
parser.add_argument("--output-dir", type=str, help="output dir.")
parser.add_argument(
"--nxpu", type=int, default=0, help="if nxpu == 0 and ngpu == 0, use cpu.")
"--nxpu",
type=int,
default=0,
help="if nxpu == 0 and ngpu == 0, use cpu.")
parser.add_argument(
"--ngpu", type=int, default=1, help="if ngpu == 0, use cpu or xpu")
......
......@@ -36,4 +36,4 @@ def repeat(N, fn):
Returns:
MultiSequential: Repeated model instance.
"""
return MultiSequential(*[fn(n) for n in range(N)])
return MultiSequential(* [fn(n) for n in range(N)])
......@@ -98,7 +98,6 @@ requirements = {
}
def check_call(cmd: str, shell=False, executable=None):
try:
sp.check_call(
......@@ -112,12 +111,13 @@ def check_call(cmd: str, shell=False, executable=None):
file=sys.stderr)
raise e
def check_output(cmd: str, shell=False):
try:
out_bytes = sp.check_output(cmd.split())
except sp.CalledProcessError as e:
out_bytes = e.output # Output generated before error
code = e.returncode # Return code
out_bytes = e.output # Output generated before error
code = e.returncode # Return code
print(
f"{__file__}:{inspect.currentframe().f_lineno}: CMD: {cmd}, Error:",
out_bytes,
......@@ -146,6 +146,7 @@ def _remove(files: str):
for f in files:
f.unlink()
################################# Install ##################################
......@@ -308,6 +309,5 @@ setup_info = dict(
]
})
with version_info():
setup(**setup_info)
......@@ -20,7 +20,6 @@ of each audio file in the data set.
"""
import argparse
import codecs
import json
import os
from pathlib import Path
......@@ -89,7 +88,7 @@ def create_manifest(data_dir, manifest_path_prefix):
duration = float(len(audio_data) / samplerate)
text = transcript_dict[audio_id]
json_lines.append(audio_path)
reference_lines.append(str(total_num+1) + "\t" + text)
reference_lines.append(str(total_num + 1) + "\t" + text)
total_sec += duration
total_text += len(text)
......@@ -106,6 +105,7 @@ def create_manifest(data_dir, manifest_path_prefix):
manifest_dir = os.path.dirname(manifest_path_prefix)
def prepare_dataset(url, md5sum, target_dir, manifest_path=None):
"""Download, unpack and create manifest file."""
data_dir = os.path.join(target_dir, 'data_aishell')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册