提交 cb39777a 编写于 作者: H Hui Zhang

format code

上级 760e5d44
......@@ -108,7 +108,12 @@ class SpeechSegment(AudioSegment):
token_ids)
@classmethod
def from_pcm(cls, samples, sample_rate, transcript, tokens=None, token_ids=None):
def from_pcm(cls,
samples,
sample_rate,
transcript,
tokens=None,
token_ids=None):
"""Create speech segment from pcm on online mode
Args:
samples (numpy.ndarray): Audio samples [num_samples x num_channels].
......
......@@ -18,8 +18,8 @@ from fastapi import FastAPI
from paddlespeech.server.engine.engine_pool import init_engine_pool
from paddlespeech.server.restful.api import setup_router as setup_http_router
from paddlespeech.server.ws.api import setup_router as setup_ws_router
from paddlespeech.server.utils.config import get_config
from paddlespeech.server.ws.api import setup_router as setup_ws_router
app = FastAPI(
title="PaddleSpeech Serving API", description="Api", version="0.0.1")
......
......@@ -11,29 +11,23 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import io
import os
import time
from typing import Optional
import pickle
import numpy as np
from numpy import float32
import soundfile
import numpy as np
import paddle
from numpy import float32
from yacs.config import CfgNode
from paddlespeech.s2t.frontend.speech import SpeechSegment
from paddlespeech.cli.asr.infer import ASRExecutor
from paddlespeech.cli.log import logger
from paddlespeech.cli.utils import MODEL_HOME
from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
from paddlespeech.s2t.frontend.speech import SpeechSegment
from paddlespeech.s2t.modules.ctc import CTCDecoder
from paddlespeech.s2t.utils.utility import UpdateConfig
from paddlespeech.server.engine.base_engine import BaseEngine
from paddlespeech.server.utils.config import get_config
from paddlespeech.server.utils.paddle_predictor import init_predictor
from paddlespeech.server.utils.paddle_predictor import run_model
__all__ = ['ASREngine']
......@@ -141,10 +135,10 @@ class ASRServerExecutor(ASRExecutor):
reduction=True, # sum
batch_average=True, # sum / batch_size
grad_norm_type=self.config.get('ctc_grad_norm_type', None))
# init decoder
cfg = self.config.decode
decode_batch_size = 1 # for online
decode_batch_size = 1 # for online
self.decoder.init_decoder(
decode_batch_size, self.text_feature.vocab_list,
cfg.decoding_method, cfg.lang_model_path, cfg.alpha, cfg.beta,
......@@ -182,10 +176,11 @@ class ASRServerExecutor(ASRExecutor):
Returns:
[type]: [description]
"""
if "deepspeech2online" in model_type :
if "deepspeech2online" in model_type:
input_names = self.am_predictor.get_input_names()
audio_handle = self.am_predictor.get_input_handle(input_names[0])
audio_len_handle = self.am_predictor.get_input_handle(input_names[1])
audio_len_handle = self.am_predictor.get_input_handle(
input_names[1])
h_box_handle = self.am_predictor.get_input_handle(input_names[2])
c_box_handle = self.am_predictor.get_input_handle(input_names[3])
......@@ -203,7 +198,8 @@ class ASRServerExecutor(ASRExecutor):
output_names = self.am_predictor.get_output_names()
output_handle = self.am_predictor.get_output_handle(output_names[0])
output_lens_handle = self.am_predictor.get_output_handle(output_names[1])
output_lens_handle = self.am_predictor.get_output_handle(
output_names[1])
output_state_h_handle = self.am_predictor.get_output_handle(
output_names[2])
output_state_c_handle = self.am_predictor.get_output_handle(
......@@ -341,7 +337,8 @@ class ASREngine(BaseEngine):
x_chunk_lens (numpy.array): shape[B]
decoder_chunk_size(int)
"""
self.output = self.executor.decode_one_chunk(x_chunk, x_chunk_lens, self.config.model_type)
self.output = self.executor.decode_one_chunk(x_chunk, x_chunk_lens,
self.config.model_type)
def postprocess(self):
"""postprocess
......
......@@ -43,10 +43,10 @@ class ChunkBuffer(object):
audio = self.remained_audio + audio
self.remained_audio = b''
n = int(self.sample_rate *
(self.frame_duration_ms / 1000.0) * self.sample_width)
shift_n = int(self.sample_rate *
(self.shift_ms / 1000.0) * self.sample_width)
n = int(self.sample_rate * (self.frame_duration_ms / 1000.0) *
self.sample_width)
shift_n = int(self.sample_rate * (self.shift_ms / 1000.0) *
self.sample_width)
offset = 0
timestamp = 0.0
duration = (float(n) / self.sample_rate) / self.sample_width
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册