未验证 提交 36df70cb 编写于 作者: H Hui Zhang 提交者: GitHub

Merge pull request #1638 from zh794390558/spx_refactor

[speechx] refactor audio/data/feature cache
.DS_Store
*.pyc
.vscode
*log
*.log
*.wav
*.pdmodel
*.pdiparams*
......
......@@ -52,7 +52,7 @@ pull_request_rules:
add: ["T2S"]
- name: "auto add label=Audio"
conditions:
- files~=^audio/
- files~=^paddleaudio/
actions:
label:
add: ["Audio"]
......
......@@ -108,7 +108,12 @@ class SpeechSegment(AudioSegment):
token_ids)
@classmethod
def from_pcm(cls, samples, sample_rate, transcript, tokens=None, token_ids=None):
def from_pcm(cls,
samples,
sample_rate,
transcript,
tokens=None,
token_ids=None):
"""Create speech segment from pcm on online mode
Args:
samples (numpy.ndarray): Audio samples [num_samples x num_channels].
......
......@@ -18,8 +18,8 @@ from fastapi import FastAPI
from paddlespeech.server.engine.engine_pool import init_engine_pool
from paddlespeech.server.restful.api import setup_router as setup_http_router
from paddlespeech.server.ws.api import setup_router as setup_ws_router
from paddlespeech.server.utils.config import get_config
from paddlespeech.server.ws.api import setup_router as setup_ws_router
app = FastAPI(
title="PaddleSpeech Serving API", description="Api", version="0.0.1")
......
......@@ -11,29 +11,23 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import io
import os
import time
from typing import Optional
import pickle
import numpy as np
from numpy import float32
import soundfile
import numpy as np
import paddle
from numpy import float32
from yacs.config import CfgNode
from paddlespeech.s2t.frontend.speech import SpeechSegment
from paddlespeech.cli.asr.infer import ASRExecutor
from paddlespeech.cli.log import logger
from paddlespeech.cli.utils import MODEL_HOME
from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
from paddlespeech.s2t.frontend.speech import SpeechSegment
from paddlespeech.s2t.modules.ctc import CTCDecoder
from paddlespeech.s2t.utils.utility import UpdateConfig
from paddlespeech.server.engine.base_engine import BaseEngine
from paddlespeech.server.utils.config import get_config
from paddlespeech.server.utils.paddle_predictor import init_predictor
from paddlespeech.server.utils.paddle_predictor import run_model
__all__ = ['ASREngine']
......@@ -141,10 +135,10 @@ class ASRServerExecutor(ASRExecutor):
reduction=True, # sum
batch_average=True, # sum / batch_size
grad_norm_type=self.config.get('ctc_grad_norm_type', None))
# init decoder
cfg = self.config.decode
decode_batch_size = 1 # for online
decode_batch_size = 1 # for online
self.decoder.init_decoder(
decode_batch_size, self.text_feature.vocab_list,
cfg.decoding_method, cfg.lang_model_path, cfg.alpha, cfg.beta,
......@@ -182,10 +176,11 @@ class ASRServerExecutor(ASRExecutor):
Returns:
[type]: [description]
"""
if "deepspeech2online" in model_type :
if "deepspeech2online" in model_type:
input_names = self.am_predictor.get_input_names()
audio_handle = self.am_predictor.get_input_handle(input_names[0])
audio_len_handle = self.am_predictor.get_input_handle(input_names[1])
audio_len_handle = self.am_predictor.get_input_handle(
input_names[1])
h_box_handle = self.am_predictor.get_input_handle(input_names[2])
c_box_handle = self.am_predictor.get_input_handle(input_names[3])
......@@ -203,7 +198,8 @@ class ASRServerExecutor(ASRExecutor):
output_names = self.am_predictor.get_output_names()
output_handle = self.am_predictor.get_output_handle(output_names[0])
output_lens_handle = self.am_predictor.get_output_handle(output_names[1])
output_lens_handle = self.am_predictor.get_output_handle(
output_names[1])
output_state_h_handle = self.am_predictor.get_output_handle(
output_names[2])
output_state_c_handle = self.am_predictor.get_output_handle(
......@@ -341,7 +337,8 @@ class ASREngine(BaseEngine):
x_chunk_lens (numpy.array): shape[B]
decoder_chunk_size(int)
"""
self.output = self.executor.decode_one_chunk(x_chunk, x_chunk_lens, self.config.model_type)
self.output = self.executor.decode_one_chunk(x_chunk, x_chunk_lens,
self.config.model_type)
def postprocess(self):
"""postprocess
......
......@@ -43,10 +43,10 @@ class ChunkBuffer(object):
audio = self.remained_audio + audio
self.remained_audio = b''
n = int(self.sample_rate *
(self.frame_duration_ms / 1000.0) * self.sample_width)
shift_n = int(self.sample_rate *
(self.shift_ms / 1000.0) * self.sample_width)
n = int(self.sample_rate * (self.frame_duration_ms / 1000.0) *
self.sample_width)
shift_n = int(self.sample_rate * (self.shift_ms / 1000.0) *
self.sample_width)
offset = 0
timestamp = 0.0
duration = (float(n) / self.sample_rate) / self.sample_width
......
......@@ -24,11 +24,11 @@ DEFINE_string(nnet_prob_respecifier, "", "test nnet prob rspecifier");
DEFINE_string(dict_file, "vocab.txt", "vocabulary of lm");
DEFINE_string(lm_path, "lm.klm", "language model");
using kaldi::BaseFloat;
using kaldi::Matrix;
using std::vector;
// test decoder by feeding nnet posterior probability
int main(int argc, char* argv[]) {
gflags::ParseCommandLineFlags(&argc, &argv, false);
google::InitGoogleLogging(argv[0]);
......@@ -37,6 +37,8 @@ int main(int argc, char* argv[]) {
FLAGS_nnet_prob_respecifier);
std::string dict_file = FLAGS_dict_file;
std::string lm_path = FLAGS_lm_path;
LOG(INFO) << "dict path: " << dict_file;
LOG(INFO) << "lm path: " << lm_path;
int32 num_done = 0, num_err = 0;
......@@ -53,6 +55,9 @@ int main(int argc, char* argv[]) {
for (; !likelihood_reader.Done(); likelihood_reader.Next()) {
string utt = likelihood_reader.Key();
const kaldi::Matrix<BaseFloat> likelihood = likelihood_reader.Value();
LOG(INFO) << "process utt: " << utt;
LOG(INFO) << "rows: " << likelihood.NumRows();
LOG(INFO) << "cols: " << likelihood.NumCols();
decodable->Acceptlikelihood(likelihood);
decoder.AdvanceDecode(decodable);
std::string result;
......
......@@ -17,7 +17,7 @@
#include "base/flags.h"
#include "base/log.h"
#include "decoder/ctc_beam_search_decoder.h"
#include "frontend/raw_audio.h"
#include "frontend/data_cache.h"
#include "kaldi/util/table-types.h"
#include "nnet/decodable.h"
#include "nnet/paddle_nnet.h"
......@@ -34,6 +34,7 @@ using kaldi::BaseFloat;
using kaldi::Matrix;
using std::vector;
// test decoder by feeding speech feature, deprecated.
int main(int argc, char* argv[]) {
gflags::ParseCommandLineFlags(&argc, &argv, false);
google::InitGoogleLogging(argv[0]);
......@@ -59,8 +60,7 @@ int main(int argc, char* argv[]) {
model_opts.params_path = model_params;
std::shared_ptr<ppspeech::PaddleNnet> nnet(
new ppspeech::PaddleNnet(model_opts));
std::shared_ptr<ppspeech::RawDataCache> raw_data(
new ppspeech::RawDataCache());
std::shared_ptr<ppspeech::DataCache> raw_data(new ppspeech::DataCache());
std::shared_ptr<ppspeech::Decodable> decodable(
new ppspeech::Decodable(nnet, raw_data));
LOG(INFO) << "Init decodeable.";
......
......@@ -17,7 +17,7 @@
#include "base/flags.h"
#include "base/log.h"
#include "decoder/ctc_beam_search_decoder.h"
#include "frontend/raw_audio.h"
#include "frontend/data_cache.h"
#include "kaldi/util/table-types.h"
#include "nnet/decodable.h"
#include "nnet/paddle_nnet.h"
......@@ -27,12 +27,19 @@ DEFINE_string(model_path, "avg_1.jit.pdmodel", "paddle nnet model");
DEFINE_string(param_path, "avg_1.jit.pdiparams", "paddle nnet model param");
DEFINE_string(dict_file, "vocab.txt", "vocabulary of lm");
DEFINE_string(lm_path, "lm.klm", "language model");
DEFINE_int32(receptive_field_length,
7,
"receptive field of two CNN(kernel=5) downsampling module.");
DEFINE_int32(downsampling_rate,
4,
"two CNN(kernel=5) module downsampling rate.");
using kaldi::BaseFloat;
using kaldi::Matrix;
using std::vector;
// test ds2 online decoder by feeding speech feature
int main(int argc, char* argv[]) {
gflags::ParseCommandLineFlags(&argc, &argv, false);
google::InitGoogleLogging(argv[0]);
......@@ -43,6 +50,11 @@ int main(int argc, char* argv[]) {
std::string model_params = FLAGS_param_path;
std::string dict_file = FLAGS_dict_file;
std::string lm_path = FLAGS_lm_path;
LOG(INFO) << "model path: " << model_graph;
LOG(INFO) << "model param: " << model_params;
LOG(INFO) << "dict path: " << dict_file;
LOG(INFO) << "lm path: " << lm_path;
int32 num_done = 0, num_err = 0;
......@@ -57,34 +69,44 @@ int main(int argc, char* argv[]) {
model_opts.cache_shape = "5-1-1024,5-1-1024";
std::shared_ptr<ppspeech::PaddleNnet> nnet(
new ppspeech::PaddleNnet(model_opts));
std::shared_ptr<ppspeech::RawDataCache> raw_data(
new ppspeech::RawDataCache());
std::shared_ptr<ppspeech::DataCache> raw_data(new ppspeech::DataCache());
std::shared_ptr<ppspeech::Decodable> decodable(
new ppspeech::Decodable(nnet, raw_data));
int32 chunk_size = 7;
int32 chunk_stride = 4;
int32 receptive_field_length = 7;
int32 chunk_size = FLAGS_receptive_field_length;
int32 chunk_stride = FLAGS_downsampling_rate;
int32 receptive_field_length = FLAGS_receptive_field_length;
LOG(INFO) << "chunk size (frame): " << chunk_size;
LOG(INFO) << "chunk stride (frame): " << chunk_stride;
LOG(INFO) << "receptive field (frame): " << receptive_field_length;
decoder.InitDecoder();
for (; !feature_reader.Done(); feature_reader.Next()) {
string utt = feature_reader.Key();
kaldi::Matrix<BaseFloat> feature = feature_reader.Value();
raw_data->SetDim(feature.NumCols());
LOG(INFO) << "process utt: " << utt;
LOG(INFO) << "rows: " << feature.NumRows();
LOG(INFO) << "cols: " << feature.NumCols();
int32 row_idx = 0;
int32 padding_len = 0;
int32 ori_feature_len = feature.NumRows();
if ( (feature.NumRows() - chunk_size) % chunk_stride != 0) {
padding_len = chunk_stride - (feature.NumRows() - chunk_size) % chunk_stride;
feature.Resize(feature.NumRows() + padding_len, feature.NumCols(), kaldi::kCopyData);
int32 ori_feature_len = feature.NumRows();
if ((feature.NumRows() - chunk_size) % chunk_stride != 0) {
padding_len =
chunk_stride - (feature.NumRows() - chunk_size) % chunk_stride;
feature.Resize(feature.NumRows() + padding_len,
feature.NumCols(),
kaldi::kCopyData);
}
int32 num_chunks = (feature.NumRows() - chunk_size) / chunk_stride + 1;
for (int chunk_idx = 0; chunk_idx < num_chunks; ++chunk_idx) {
kaldi::Vector<kaldi::BaseFloat> feature_chunk(chunk_size *
feature.NumCols());
int32 feature_chunk_size = 0;
if ( ori_feature_len > chunk_idx * chunk_stride) {
feature_chunk_size = std::min(ori_feature_len - chunk_idx * chunk_stride, chunk_size);
int32 feature_chunk_size = 0;
if (ori_feature_len > chunk_idx * chunk_stride) {
feature_chunk_size = std::min(
ori_feature_len - chunk_idx * chunk_stride, chunk_size);
}
if (feature_chunk_size < receptive_field_length) break;
......
......@@ -17,10 +17,11 @@
#include "frontend/linear_spectrogram.h"
#include "base/flags.h"
#include "base/log.h"
#include "frontend/audio_cache.h"
#include "frontend/data_cache.h"
#include "frontend/feature_cache.h"
#include "frontend/feature_extractor_interface.h"
#include "frontend/normalizer.h"
#include "frontend/raw_audio.h"
#include "kaldi/feat/wave-reader.h"
#include "kaldi/util/kaldi-io.h"
#include "kaldi/util/table-types.h"
......@@ -170,9 +171,9 @@ int main(int argc, char* argv[]) {
// window -->linear_spectrogram --> global cmvn -> feat cache
// std::unique_ptr<ppspeech::FeatureExtractorInterface> data_source(new
// ppspeech::RawDataCache());
// ppspeech::DataCache());
std::unique_ptr<ppspeech::FeatureExtractorInterface> data_source(
new ppspeech::RawAudioCache());
new ppspeech::AudioCache());
ppspeech::DecibelNormalizerOptions db_norm_opt;
std::unique_ptr<ppspeech::FeatureExtractorInterface> db_norm(
......
cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
add_executable(glog_test ${CMAKE_CURRENT_SOURCE_DIR}/glog_test.cc)
target_link_libraries(glog_test glog)
add_executable(glog_logtostderr_test ${CMAKE_CURRENT_SOURCE_DIR}/glog_logtostderr_test.cc)
target_link_libraries(glog_logtostderr_test glog)
\ No newline at end of file
# [GLOG](https://rpg.ifi.uzh.ch/docs/glog.html)
Unless otherwise specified, glog writes to the filename `/tmp/<program name>.<hostname>.<user name>.log.<severity level>.<date>.<time>.<pid>` (e.g., "/tmp/hello_world.example.com.hamaji.log.INFO.20080709-222411.10474"). By default, glog copies the log messages of severity level ERROR or FATAL to standard error (stderr) in addition to log files.
Several flags influence glog's output behavior. If the Google gflags library is installed on your machine, the configure script (see the INSTALL file in the package for detail of this script) will automatically detect and use it, allowing you to pass flags on the command line. For example, if you want to turn the flag --logtostderr on, you can start your application with the following command line:
`./your_application --logtostderr=1`
If the Google gflags library isn't installed, you set flags via environment variables, prefixing the flag name with "GLOG_", e.g.
`GLOG_logtostderr=1 ./your_application`
You can also modify flag values in your program by modifying global variables `FLAGS_*` . Most settings start working immediately after you update `FLAGS_*` . The exceptions are the flags related to destination files. For example, you might want to set `FLAGS_log_dir` before calling `google::InitGoogleLogging` . Here is an example:
∂∂
```c++
LOG(INFO) << "file";
// Most flags work immediately after updating values.
FLAGS_logtostderr = 1;
LOG(INFO) << "stderr";
FLAGS_logtostderr = 0;
// This won't change the log destination. If you want to set this
// value, you should do this before google::InitGoogleLogging .
FLAGS_log_dir = "/some/log/directory";
LOG(INFO) << "the same file";
```
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <glog/logging.h>
int main(int argc, char* argv[]) {
// Initialize Google’s logging library.
google::InitGoogleLogging(argv[0]);
FLAGS_logtostderr = 1;
LOG(INFO) << "Found " << 10 << " cookies";
LOG(ERROR) << "Found " << 10 << " error";
}
\ No newline at end of file
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <glog/logging.h>
int main(int argc, char* argv[]) {
// Initialize Google’s logging library.
google::InitGoogleLogging(argv[0]);
LOG(INFO) << "Found " << 10 << " cookies";
LOG(ERROR) << "Found " << 10 << " error";
}
\ No newline at end of file
# This contains the locations of binarys build required for running the examples.
SPEECHX_ROOT=$PWD/../..
SPEECHX_EXAMPLES=$SPEECHX_ROOT/build/examples
SPEECHX_TOOLS=$SPEECHX_ROOT/tools
TOOLS_BIN=$SPEECHX_TOOLS/valgrind/install/bin
[ -d $SPEECHX_EXAMPLES ] || { echo "Error: 'build/examples' directory not found. please ensure that the project build successfully"; }
export LC_AL=C
SPEECHX_BIN=$SPEECHX_EXAMPLES/glog
export PATH=$PATH:$SPEECHX_BIN:$TOOLS_BIN
#!/bin/bash
set +x
set -e
. ./path.sh
# 1. compile
if [ ! -d ${SPEECHX_EXAMPLES} ]; then
pushd ${SPEECHX_ROOT}
bash build.sh
popd
fi
# 2. run
glog_test
echo "------"
export FLAGS_logtostderr=1
glog_test
echo "------"
glog_logtostderr_test
......@@ -3,8 +3,8 @@ project(frontend)
add_library(frontend STATIC
normalizer.cc
linear_spectrogram.cc
raw_audio.cc
audio_cache.cc
feature_cache.cc
)
target_link_libraries(frontend PUBLIC kaldi-matrix)
target_link_libraries(frontend PUBLIC kaldi-matrix)
\ No newline at end of file
......@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "frontend/raw_audio.h"
#include "frontend/audio_cache.h"
#include "kaldi/base/timer.h"
namespace ppspeech {
......@@ -21,38 +21,43 @@ using kaldi::BaseFloat;
using kaldi::VectorBase;
using kaldi::Vector;
RawAudioCache::RawAudioCache(int buffer_size)
: finished_(false), data_length_(0), start_(0), timeout_(1) {
ring_buffer_.resize(buffer_size);
AudioCache::AudioCache(int buffer_size)
: finished_(false),
capacity_(buffer_size),
size_(0),
offset_(0),
timeout_(1) {
ring_buffer_.resize(capacity_);
}
void RawAudioCache::Accept(const VectorBase<BaseFloat>& waves) {
void AudioCache::Accept(const VectorBase<BaseFloat>& waves) {
std::unique_lock<std::mutex> lock(mutex_);
while (data_length_ + waves.Dim() > ring_buffer_.size()) {
while (size_ + waves.Dim() > ring_buffer_.size()) {
ready_feed_condition_.wait(lock);
}
for (size_t idx = 0; idx < waves.Dim(); ++idx) {
int32 buffer_idx = (idx + start_) % ring_buffer_.size();
int32 buffer_idx = (idx + offset_) % ring_buffer_.size();
ring_buffer_[buffer_idx] = waves(idx);
}
data_length_ += waves.Dim();
size_ += waves.Dim();
}
bool RawAudioCache::Read(Vector<BaseFloat>* waves) {
bool AudioCache::Read(Vector<BaseFloat>* waves) {
size_t chunk_size = waves->Dim();
kaldi::Timer timer;
std::unique_lock<std::mutex> lock(mutex_);
while (chunk_size > data_length_) {
while (chunk_size > size_) {
// when audio is empty and no more data feed
// ready_read_condition will block in dead lock. so replace with
// timeout_
// ready_read_condition will block in dead lock,
// so replace with timeout_
// ready_read_condition_.wait(lock);
int32 elapsed = static_cast<int32>(timer.Elapsed() * 1000);
if (elapsed > timeout_) {
if (finished_ == true) { // read last chunk data
if (finished_ == true) {
// read last chunk data
break;
}
if (chunk_size > data_length_) {
if (chunk_size > size_) {
return false;
}
}
......@@ -60,17 +65,17 @@ bool RawAudioCache::Read(Vector<BaseFloat>* waves) {
}
// read last chunk data
if (chunk_size > data_length_) {
chunk_size = data_length_;
if (chunk_size > size_) {
chunk_size = size_;
waves->Resize(chunk_size);
}
for (size_t idx = 0; idx < chunk_size; ++idx) {
int buff_idx = (start_ + idx) % ring_buffer_.size();
int buff_idx = (offset_ + idx) % ring_buffer_.size();
waves->Data()[idx] = ring_buffer_[buff_idx];
}
data_length_ -= chunk_size;
start_ = (start_ + chunk_size) % ring_buffer_.size();
size_ -= chunk_size;
offset_ = (offset_ + chunk_size) % ring_buffer_.size();
ready_feed_condition_.notify_one();
return true;
}
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "base/common.h"
#include "frontend/feature_extractor_interface.h"
namespace ppspeech {
// waves cache
class AudioCache : public FeatureExtractorInterface {
public:
explicit AudioCache(int buffer_size = kint16max);
virtual void Accept(const kaldi::VectorBase<BaseFloat>& waves);
virtual bool Read(kaldi::Vector<kaldi::BaseFloat>* waves);
// the audio dim is 1, one sample
virtual size_t Dim() const { return 1; }
virtual void SetFinished() {
std::lock_guard<std::mutex> lock(mutex_);
finished_ = true;
}
virtual bool IsFinished() const { return finished_; }
virtual void Reset() {
offset_ = 0;
size_ = 0;
finished_ = false;
}
private:
std::vector<kaldi::BaseFloat> ring_buffer_;
size_t offset_; // offset in ring_buffer_
size_t size_; // samples in ring_buffer_ now
size_t capacity_; // capacity of ring_buffer_
bool finished_; // reach audio end
mutable std::mutex mutex_;
std::condition_variable ready_feed_condition_;
kaldi::int32 timeout_; // millisecond
DISALLOW_COPY_AND_ASSIGN(AudioCache);
};
} // namespace ppspeech
......@@ -15,51 +15,22 @@
#pragma once
#include "base/common.h"
#include "frontend/feature_extractor_interface.h"
#pragma once
namespace ppspeech {
class RawAudioCache : public FeatureExtractorInterface {
// A data source for testing different frontend module.
// It accepts waves or feats.
class DataCache : public FeatureExtractorInterface {
public:
explicit RawAudioCache(int buffer_size = kint16max);
virtual void Accept(const kaldi::VectorBase<BaseFloat>& waves);
virtual bool Read(kaldi::Vector<kaldi::BaseFloat>* waves);
// the audio dim is 1
virtual size_t Dim() const { return 1; }
virtual void SetFinished() {
std::lock_guard<std::mutex> lock(mutex_);
finished_ = true;
}
virtual bool IsFinished() const { return finished_; }
virtual void Reset() {
start_ = 0;
data_length_ = 0;
finished_ = false;
}
private:
std::vector<kaldi::BaseFloat> ring_buffer_;
size_t start_;
size_t data_length_;
bool finished_;
mutable std::mutex mutex_;
std::condition_variable ready_feed_condition_;
kaldi::int32 timeout_;
DISALLOW_COPY_AND_ASSIGN(RawAudioCache);
};
explicit DataCache() { finished_ = false; }
// it is a datasource for testing different frontend module.
// it accepts waves or feats.
class RawDataCache : public FeatureExtractorInterface {
public:
explicit RawDataCache() { finished_ = false; }
virtual void Accept(const kaldi::VectorBase<kaldi::BaseFloat>& inputs) {
data_ = inputs;
}
virtual bool Read(kaldi::Vector<kaldi::BaseFloat>* feats) {
if (data_.Dim() == 0) {
return false;
......@@ -80,7 +51,6 @@ class RawDataCache : public FeatureExtractorInterface {
bool finished_;
int32 dim_;
DISALLOW_COPY_AND_ASSIGN(RawDataCache);
DISALLOW_COPY_AND_ASSIGN(DataCache);
};
} // namespace ppspeech
}
\ No newline at end of file
......@@ -82,7 +82,7 @@ void Decodable::Reset() {
if (nnet_ != nullptr) nnet_->Reset();
frame_offset_ = 0;
frames_ready_ = 0;
nnet_cache_.Resize(0,0);
nnet_cache_.Resize(0, 0);
}
} // namespace ppspeech
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册