未验证 提交 bb07144c 编写于 作者: Y YangZhou 提交者: GitHub

Merge pull request #1559 from SmileGoat/align_nnet_decoder

align nnet decoder & refactor
cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
add_executable(offline-decoder-main ${CMAKE_CURRENT_SOURCE_DIR}/offline-decoder-main.cc)
target_include_directories(offline-decoder-main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(offline-decoder-main PUBLIC nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util ${DEPS})
\ No newline at end of file
add_executable(offline_decoder_main ${CMAKE_CURRENT_SOURCE_DIR}/offline_decoder_main.cc)
target_include_directories(offline_decoder_main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(offline_decoder_main PUBLIC nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util ${DEPS})
......@@ -17,50 +17,75 @@
#include "base/flags.h"
#include "base/log.h"
#include "decoder/ctc_beam_search_decoder.h"
#include "frontend/raw_audio.h"
#include "kaldi/util/table-types.h"
#include "nnet/decodable.h"
#include "nnet/paddle_nnet.h"
DEFINE_string(feature_respecifier, "", "test nnet prob");
DEFINE_string(feature_respecifier, "", "test feature rspecifier");
DEFINE_string(model_path, "avg_1.jit.pdmodel", "paddle nnet model");
DEFINE_string(param_path, "avg_1.jit.pdiparams", "paddle nnet model param");
DEFINE_string(dict_file, "vocab.txt", "vocabulary of lm");
DEFINE_string(lm_path, "lm.klm", "language model");
using kaldi::BaseFloat;
using kaldi::Matrix;
using std::vector;
// void SplitFeature(kaldi::Matrix<BaseFloat> feature,
// int32 chunk_size,
// std::vector<kaldi::Matrix<BaseFloat>* feature_chunks) {
//}
int main(int argc, char* argv[]) {
gflags::ParseCommandLineFlags(&argc, &argv, false);
google::InitGoogleLogging(argv[0]);
kaldi::SequentialBaseFloatMatrixReader feature_reader(
FLAGS_feature_respecifier);
std::string model_graph = FLAGS_model_path;
std::string model_params = FLAGS_param_path;
std::string dict_file = FLAGS_dict_file;
std::string lm_path = FLAGS_lm_path;
// test nnet_output --> decoder result
int32 num_done = 0, num_err = 0;
ppspeech::CTCBeamSearchOptions opts;
opts.dict_file = dict_file;
opts.lm_path = lm_path;
ppspeech::CTCBeamSearch decoder(opts);
ppspeech::ModelOptions model_opts;
model_opts.model_path = model_graph;
model_opts.params_path = model_params;
std::shared_ptr<ppspeech::PaddleNnet> nnet(
new ppspeech::PaddleNnet(model_opts));
std::shared_ptr<ppspeech::RawDataCache> raw_data(
new ppspeech::RawDataCache());
std::shared_ptr<ppspeech::Decodable> decodable(
new ppspeech::Decodable(nnet));
new ppspeech::Decodable(nnet, raw_data));
// int32 chunk_size = 35;
int32 chunk_size = 35;
decoder.InitDecoder();
for (; !feature_reader.Done(); feature_reader.Next()) {
string utt = feature_reader.Key();
const kaldi::Matrix<BaseFloat> feature = feature_reader.Value();
decodable->FeedFeatures(feature);
decoder.AdvanceDecode(decodable, 8);
decodable->InputFinished();
raw_data->SetDim(feature.NumCols());
int32 row_idx = 0;
int32 num_chunks = feature.NumRows() / chunk_size;
for (int chunk_idx = 0; chunk_idx < num_chunks; ++chunk_idx) {
kaldi::Vector<kaldi::BaseFloat> feature_chunk(chunk_size *
feature.NumCols());
for (int row_id = 0; row_id < chunk_size; ++row_id) {
kaldi::SubVector<kaldi::BaseFloat> tmp(feature, row_idx);
kaldi::SubVector<kaldi::BaseFloat> f_chunk_tmp(
feature_chunk.Data() + row_id * feature.NumCols(),
feature.NumCols());
f_chunk_tmp.CopyFromVec(tmp);
row_idx++;
}
raw_data->Accept(feature_chunk);
if (chunk_idx == num_chunks - 1) {
raw_data->SetFinished();
}
decoder.AdvanceDecode(decodable);
}
std::string result;
result = decoder.GetFinalBestPath();
KALDI_LOG << " the result of " << utt << " is " << result;
......
......@@ -79,21 +79,19 @@ void CTCBeamSearch::Decode(
return;
}
int32 CTCBeamSearch::NumFrameDecoded() { return num_frame_decoded_; }
int32 CTCBeamSearch::NumFrameDecoded() { return num_frame_decoded_ + 1; }
// todo rename, refactor
void CTCBeamSearch::AdvanceDecode(
const std::shared_ptr<kaldi::DecodableInterface>& decodable,
int max_frames) {
while (max_frames > 0) {
const std::shared_ptr<kaldi::DecodableInterface>& decodable) {
while (1) {
vector<vector<BaseFloat>> likelihood;
if (decodable->IsLastFrame(NumFrameDecoded() + 1)) {
break;
}
likelihood.push_back(
decodable->FrameLogLikelihood(NumFrameDecoded() + 1));
vector<BaseFloat> frame_prob;
bool flag =
decodable->FrameLogLikelihood(num_frame_decoded_, &frame_prob);
if (flag == false) break;
likelihood.push_back(frame_prob);
AdvanceDecoding(likelihood);
max_frames--;
}
}
......
......@@ -32,8 +32,8 @@ struct CTCBeamSearchOptions {
int cutoff_top_n;
int num_proc_bsearch;
CTCBeamSearchOptions()
: dict_file("./model/words.txt"),
lm_path("./model/lm.arpa"),
: dict_file("vocab.txt"),
lm_path("lm.klm"),
alpha(1.9f),
beta(5.0),
beam_size(300),
......@@ -68,8 +68,7 @@ class CTCBeamSearch {
int DecodeLikelihoods(const std::vector<std::vector<BaseFloat>>& probs,
std::vector<std::string>& nbest_words);
void AdvanceDecode(
const std::shared_ptr<kaldi::DecodableInterface>& decodable,
int max_frames);
const std::shared_ptr<kaldi::DecodableInterface>& decodable);
void Reset();
private:
......@@ -83,7 +82,6 @@ class CTCBeamSearch {
CTCBeamSearchOptions opts_;
std::shared_ptr<Scorer> init_ext_scorer_; // todo separate later
// std::vector<DecodeResult> decoder_results_;
std::vector<std::string> vocabulary_; // todo remove later
size_t blank_id;
int space_id;
......
......@@ -18,6 +18,8 @@
#include "base/common.h"
#include "frontend/feature_extractor_interface.h"
#pragma once
namespace ppspeech {
class RawAudioCache : public FeatureExtractorInterface {
......@@ -45,13 +47,12 @@ class RawAudioCache : public FeatureExtractorInterface {
DISALLOW_COPY_AND_ASSIGN(RawAudioCache);
};
// it is a data source to test different frontend module.
// it Accepts waves or feats.
class RawDataCache: public FeatureExtractorInterface {
// it is a datasource for testing different frontend module.
// it accepts waves or feats.
class RawDataCache : public FeatureExtractorInterface {
public:
explicit RawDataCache() { finished_ = false; }
virtual void Accept(
const kaldi::VectorBase<kaldi::BaseFloat>& inputs) {
virtual void Accept(const kaldi::VectorBase<kaldi::BaseFloat>& inputs) {
data_ = inputs;
}
virtual bool Read(kaldi::Vector<kaldi::BaseFloat>* feats) {
......@@ -62,14 +63,15 @@ class RawDataCache: public FeatureExtractorInterface {
data_.Resize(0);
return true;
}
//the dim is data_ length
virtual size_t Dim() const { return data_.Dim(); }
virtual size_t Dim() const { return dim_; }
virtual void SetFinished() { finished_ = true; }
virtual bool IsFinished() const { return finished_; }
void SetDim(int32 dim) { dim_ = dim; }
private:
kaldi::Vector<kaldi::BaseFloat> data_;
bool finished_;
int32 dim_;
DISALLOW_COPY_AND_ASSIGN(RawDataCache);
};
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// itf/decodable-itf.h
// Copyright 2009-2011 Microsoft Corporation; Saarland University;
......@@ -56,10 +42,8 @@ namespace kaldi {
For online decoding, where the features are coming in in real time, it is
important to understand the IsLastFrame() and NumFramesReady() functions.
There are two ways these are used: the old online-decoding code, in
../online/,
and the new online-decoding code, in ../online2/. In the old
online-decoding
There are two ways these are used: the old online-decoding code, in ../online/,
and the new online-decoding code, in ../online2/. In the old online-decoding
code, the decoder would do:
\code{.cc}
for (int frame = 0; !decodable.IsLastFrame(frame); frame++) {
......@@ -68,16 +52,13 @@ namespace kaldi {
\endcode
and the call to IsLastFrame would block if the features had not arrived yet.
The decodable object would have to know when to terminate the decoding. This
online-decoding mode is still supported, it is what happens when you call,
for
online-decoding mode is still supported, it is what happens when you call, for
example, LatticeFasterDecoder::Decode().
We realized that this "blocking" mode of decoding is not very convenient
because it forces the program to be multi-threaded and makes it complex to
control endpointing. In the "new" decoding code, you don't call (for
example)
LatticeFasterDecoder::Decode(), you call
LatticeFasterDecoder::InitDecoding(),
control endpointing. In the "new" decoding code, you don't call (for example)
LatticeFasterDecoder::Decode(), you call LatticeFasterDecoder::InitDecoding(),
and then each time you get more features, you provide them to the decodable
object, and you call LatticeFasterDecoder::AdvanceDecoding(), which does
something like this:
......@@ -87,8 +68,7 @@ namespace kaldi {
}
\endcode
So the decodable object never has IsLastFrame() called. For decoding where
you are starting with a matrix of features, the NumFramesReady() function
will
you are starting with a matrix of features, the NumFramesReady() function will
always just return the number of frames in the file, and IsLastFrame() will
return true for the last frame.
......@@ -102,39 +82,30 @@ namespace kaldi {
class DecodableInterface {
public:
/// Returns the log likelihood, which will be negated in the decoder.
/// The "frame" starts from zero. You should verify that NumFramesReady() >
/// frame
/// The "frame" starts from zero. You should verify that NumFramesReady() > frame
/// before calling this.
virtual BaseFloat LogLikelihood(int32 frame, int32 index) = 0;
/// Returns true if this is the last frame. Frames are zero-based, so the
/// first frame is zero. IsLastFrame(-1) will return false, unless the file
/// is empty (which is a case that I'm not sure all the code will handle, so
/// be careful). Caution: the behavior of this function in an online
/// setting
/// be careful). Caution: the behavior of this function in an online setting
/// is being changed somewhat. In future it may return false in cases where
/// we haven't yet decided to terminate decoding, but later true if we
/// decide
/// we haven't yet decided to terminate decoding, but later true if we decide
/// to terminate decoding. The plan in future is to rely more on
/// NumFramesReady(), and in future, IsLastFrame() would always return false
/// in an online-decoding setting, and would only return true in a
/// decoding-from-matrix setting where we want to allow the last delta or
/// LDA
/// decoding-from-matrix setting where we want to allow the last delta or LDA
/// features to be flushed out for compatibility with the baseline setup.
virtual bool IsLastFrame(int32 frame) const = 0;
/// The call NumFramesReady() will return the number of frames currently
/// available
/// for this decodable object. This is for use in setups where you don't
/// want the
/// decoder to block while waiting for input. This is newly added as of Jan
/// 2014,
/// and I hope, going forward, to rely on this mechanism more than
/// IsLastFrame to
/// The call NumFramesReady() will return the number of frames currently available
/// for this decodable object. This is for use in setups where you don't want the
/// decoder to block while waiting for input. This is newly added as of Jan 2014,
/// and I hope, going forward, to rely on this mechanism more than IsLastFrame to
/// know when to stop decoding.
virtual int32 NumFramesReady() const {
KALDI_ERR
<< "NumFramesReady() not implemented for this decodable type.";
KALDI_ERR << "NumFramesReady() not implemented for this decodable type.";
return -1;
}
......@@ -143,7 +114,9 @@ class DecodableInterface {
/// this is for compatibility with OpenFst).
virtual int32 NumIndices() const = 0;
virtual std::vector<BaseFloat> FrameLogLikelihood(int32 frame) = 0;
virtual bool FrameLogLikelihood(int32 frame,
std::vector<kaldi::BaseFloat>* likelihood) = 0;
virtual ~DecodableInterface() {}
};
......
......@@ -18,9 +18,16 @@ namespace ppspeech {
using kaldi::BaseFloat;
using kaldi::Matrix;
using std::vector;
using kaldi::Vector;
Decodable::Decodable(const std::shared_ptr<NnetInterface>& nnet)
: frontend_(NULL), nnet_(nnet), finished_(false), frames_ready_(0) {}
Decodable::Decodable(const std::shared_ptr<NnetInterface>& nnet,
const std::shared_ptr<FeatureExtractorInterface>& frontend)
: frontend_(frontend),
nnet_(nnet),
finished_(false),
frame_offset_(0),
frames_ready_(0) {}
void Decodable::Acceptlikelihood(const Matrix<BaseFloat>& likelihood) {
frames_ready_ += likelihood.NumRows();
......@@ -31,26 +38,46 @@ void Decodable::Acceptlikelihood(const Matrix<BaseFloat>& likelihood) {
bool Decodable::IsLastFrame(int32 frame) const {
CHECK_LE(frame, frames_ready_);
return finished_ && (frame == frames_ready_ - 1);
return IsInputFinished() && (frame == frames_ready_ - 1);
}
int32 Decodable::NumIndices() const { return 0; }
BaseFloat Decodable::LogLikelihood(int32 frame, int32 index) { return 0; }
BaseFloat Decodable::LogLikelihood(int32 frame, int32 index) {
CHECK_LE(index, nnet_cache_.NumCols());
return 0;
}
void Decodable::FeedFeatures(const Matrix<kaldi::BaseFloat>& features) {
nnet_->FeedForward(features, &nnet_cache_);
bool Decodable::EnsureFrameHaveComputed(int32 frame) {
if (frame >= frames_ready_) {
return AdvanceChunk();
}
return true;
}
bool Decodable::AdvanceChunk() {
Vector<BaseFloat> features;
if (frontend_->Read(&features) == false) {
return false;
}
int32 nnet_dim = 0;
Vector<BaseFloat> inferences;
nnet_->FeedForward(features, frontend_->Dim(), &inferences, &nnet_dim);
nnet_cache_.Resize(inferences.Dim() / nnet_dim, nnet_dim);
nnet_cache_.CopyRowsFromVec(inferences);
frame_offset_ = frames_ready_;
frames_ready_ += nnet_cache_.NumRows();
return;
return true;
}
std::vector<BaseFloat> Decodable::FrameLogLikelihood(int32 frame) {
bool Decodable::FrameLogLikelihood(int32 frame, vector<BaseFloat>* likelihood) {
std::vector<BaseFloat> result;
result.reserve(nnet_cache_.NumCols());
if (EnsureFrameHaveComputed(frame) == false) return false;
likelihood->resize(nnet_cache_.NumCols());
for (int32 idx = 0; idx < nnet_cache_.NumCols(); ++idx) {
result[idx] = nnet_cache_(frame, idx);
(*likelihood)[idx] = nnet_cache_(frame - frame_offset_, idx);
}
return result;
return true;
}
void Decodable::Reset() {
......
......@@ -24,25 +24,35 @@ struct DecodableOpts;
class Decodable : public kaldi::DecodableInterface {
public:
explicit Decodable(const std::shared_ptr<NnetInterface>& nnet);
explicit Decodable(
const std::shared_ptr<NnetInterface>& nnet,
const std::shared_ptr<FeatureExtractorInterface>& frontend);
// void Init(DecodableOpts config);
virtual kaldi::BaseFloat LogLikelihood(int32 frame, int32 index);
virtual bool IsLastFrame(int32 frame) const;
virtual int32 NumIndices() const;
virtual std::vector<BaseFloat> FrameLogLikelihood(int32 frame);
void Acceptlikelihood(
const kaldi::Matrix<kaldi::BaseFloat>& likelihood); // remove later
void FeedFeatures(const kaldi::Matrix<kaldi::BaseFloat>&
feature); // only for test, todo remove later
virtual bool FrameLogLikelihood(int32 frame,
std::vector<kaldi::BaseFloat>* likelihood);
// for offline test
void Acceptlikelihood(const kaldi::Matrix<kaldi::BaseFloat>& likelihood);
void Reset();
void InputFinished() { finished_ = true; }
bool IsInputFinished() const { return frontend_->IsFinished(); }
bool EnsureFrameHaveComputed(int32 frame);
private:
bool AdvanceChunk();
std::shared_ptr<FeatureExtractorInterface> frontend_;
std::shared_ptr<NnetInterface> nnet_;
kaldi::Matrix<kaldi::BaseFloat> nnet_cache_;
// std::vector<std::vector<kaldi::BaseFloat>> nnet_cache_;
bool finished_;
int32 frame_offset_;
int32 frames_ready_;
// todo: feature frame mismatch with nnet inference frame
// eg: 35 frame features output 8 frame inferences
// so use subsampled_frame
int32 current_log_post_subsampled_offset_;
int32 num_chunk_computed_;
};
} // namespace ppspeech
......@@ -23,8 +23,10 @@ namespace ppspeech {
class NnetInterface {
public:
virtual void FeedForward(const kaldi::Matrix<kaldi::BaseFloat>& features,
kaldi::Matrix<kaldi::BaseFloat>* inferences) = 0;
virtual void FeedForward(const kaldi::Vector<kaldi::BaseFloat>& features,
int32 feature_dim,
kaldi::Vector<kaldi::BaseFloat>* inferences,
int32* inference_dim) = 0;
virtual void Reset() = 0;
virtual ~NnetInterface() {}
};
......
......@@ -21,6 +21,7 @@ using std::vector;
using std::string;
using std::shared_ptr;
using kaldi::Matrix;
using kaldi::Vector;
void PaddleNnet::InitCacheEncouts(const ModelOptions& opts) {
std::vector<std::string> cache_names;
......@@ -143,34 +144,27 @@ shared_ptr<Tensor<BaseFloat>> PaddleNnet::GetCacheEncoder(const string& name) {
return cache_encouts_[iter->second];
}
void PaddleNnet::FeedForward(const Matrix<BaseFloat>& features,
Matrix<BaseFloat>* inferences) {
void PaddleNnet::FeedForward(const Vector<BaseFloat>& features,
int32 feature_dim,
Vector<BaseFloat>* inferences,
int32* inference_dim) {
paddle_infer::Predictor* predictor = GetPredictor();
int row = features.NumRows();
int col = features.NumCols();
std::vector<BaseFloat> feed_feature;
// todo refactor feed feature: SmileGoat
feed_feature.reserve(row * col);
for (size_t row_idx = 0; row_idx < features.NumRows(); ++row_idx) {
for (size_t col_idx = 0; col_idx < features.NumCols(); ++col_idx) {
feed_feature.push_back(features(row_idx, col_idx));
}
}
int feat_row = features.Dim() / feature_dim;
std::vector<std::string> input_names = predictor->GetInputNames();
std::vector<std::string> output_names = predictor->GetOutputNames();
LOG(INFO) << "feat info: row=" << row << ", col= " << col;
LOG(INFO) << "feat info: rows, cols: " << feat_row << ", " << feature_dim;
std::unique_ptr<paddle_infer::Tensor> input_tensor =
predictor->GetInputHandle(input_names[0]);
std::vector<int> INPUT_SHAPE = {1, row, col};
std::vector<int> INPUT_SHAPE = {1, feat_row, feature_dim};
input_tensor->Reshape(INPUT_SHAPE);
input_tensor->CopyFromCpu(feed_feature.data());
input_tensor->CopyFromCpu(features.Data());
std::unique_ptr<paddle_infer::Tensor> input_len =
predictor->GetInputHandle(input_names[1]);
std::vector<int> input_len_size = {1};
input_len->Reshape(input_len_size);
std::vector<int64_t> audio_len;
audio_len.push_back(row);
audio_len.push_back(feat_row);
input_len->CopyFromCpu(audio_len.data());
std::unique_ptr<paddle_infer::Tensor> h_box =
......@@ -203,20 +197,12 @@ void PaddleNnet::FeedForward(const Matrix<BaseFloat>& features,
std::unique_ptr<paddle_infer::Tensor> output_tensor =
predictor->GetOutputHandle(output_names[0]);
std::vector<int> output_shape = output_tensor->shape();
row = output_shape[1];
col = output_shape[2];
vector<float> inferences_result;
inferences->Resize(row, col);
inferences_result.resize(row * col);
output_tensor->CopyToCpu(inferences_result.data());
int32 row = output_shape[1];
int32 col = output_shape[2];
inferences->Resize(row * col);
*inference_dim = col;
output_tensor->CopyToCpu(inferences->Data());
ReleasePredictor(predictor);
for (int row_idx = 0; row_idx < row; ++row_idx) {
for (int col_idx = 0; col_idx < col; ++col_idx) {
(*inferences)(row_idx, col_idx) =
inferences_result[col * row_idx + col_idx];
}
}
}
} // namespace ppspeech
\ No newline at end of file
......@@ -39,12 +39,8 @@ struct ModelOptions {
bool enable_fc_padding;
bool enable_profile;
ModelOptions()
: model_path(
"../../../../model/paddle_online_deepspeech/model/"
"avg_1.jit.pdmodel"),
params_path(
"../../../../model/paddle_online_deepspeech/model/"
"avg_1.jit.pdiparams"),
: model_path("avg_1.jit.pdmodel"),
params_path("avg_1.jit.pdiparams"),
thread_num(2),
use_gpu(false),
input_names(
......@@ -107,8 +103,11 @@ class Tensor {
class PaddleNnet : public NnetInterface {
public:
PaddleNnet(const ModelOptions& opts);
virtual void FeedForward(const kaldi::Matrix<kaldi::BaseFloat>& features,
kaldi::Matrix<kaldi::BaseFloat>* inferences);
virtual void FeedForward(const kaldi::Vector<kaldi::BaseFloat>& features,
int32 feature_dim,
kaldi::Vector<kaldi::BaseFloat>* inferences,
int32* inference_dim);
void Dim();
virtual void Reset();
std::shared_ptr<Tensor<kaldi::BaseFloat>> GetCacheEncoder(
const std::string& name);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册