未验证 提交 f2796062 编写于 作者: H Hui Zhang 提交者: GitHub

Merge pull request #1563 from SmileGoat/rename_author

[speechx]align nnet & decoder
cmake_minimum_required(VERSION 3.14 FATAL_ERROR) cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
add_executable(offline-decoder-main ${CMAKE_CURRENT_SOURCE_DIR}/offline-decoder-main.cc) add_executable(offline_decoder_main ${CMAKE_CURRENT_SOURCE_DIR}/offline_decoder_main.cc)
target_include_directories(offline-decoder-main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi) target_include_directories(offline_decoder_main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(offline-decoder-main PUBLIC nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util ${DEPS}) target_link_libraries(offline_decoder_main PUBLIC nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util ${DEPS})
\ No newline at end of file
...@@ -17,50 +17,75 @@ ...@@ -17,50 +17,75 @@
#include "base/flags.h" #include "base/flags.h"
#include "base/log.h" #include "base/log.h"
#include "decoder/ctc_beam_search_decoder.h" #include "decoder/ctc_beam_search_decoder.h"
#include "frontend/raw_audio.h"
#include "kaldi/util/table-types.h" #include "kaldi/util/table-types.h"
#include "nnet/decodable.h" #include "nnet/decodable.h"
#include "nnet/paddle_nnet.h" #include "nnet/paddle_nnet.h"
DEFINE_string(feature_respecifier, "", "test nnet prob"); DEFINE_string(feature_respecifier, "", "test feature rspecifier");
DEFINE_string(model_path, "avg_1.jit.pdmodel", "paddle nnet model");
DEFINE_string(param_path, "avg_1.jit.pdiparams", "paddle nnet model param");
DEFINE_string(dict_file, "vocab.txt", "vocabulary of lm");
DEFINE_string(lm_path, "lm.klm", "language model");
using kaldi::BaseFloat; using kaldi::BaseFloat;
using kaldi::Matrix; using kaldi::Matrix;
using std::vector; using std::vector;
// void SplitFeature(kaldi::Matrix<BaseFloat> feature,
// int32 chunk_size,
// std::vector<kaldi::Matrix<BaseFloat>* feature_chunks) {
//}
int main(int argc, char* argv[]) { int main(int argc, char* argv[]) {
gflags::ParseCommandLineFlags(&argc, &argv, false); gflags::ParseCommandLineFlags(&argc, &argv, false);
google::InitGoogleLogging(argv[0]); google::InitGoogleLogging(argv[0]);
kaldi::SequentialBaseFloatMatrixReader feature_reader( kaldi::SequentialBaseFloatMatrixReader feature_reader(
FLAGS_feature_respecifier); FLAGS_feature_respecifier);
std::string model_graph = FLAGS_model_path;
std::string model_params = FLAGS_param_path;
std::string dict_file = FLAGS_dict_file;
std::string lm_path = FLAGS_lm_path;
// test nnet_output --> decoder result
int32 num_done = 0, num_err = 0; int32 num_done = 0, num_err = 0;
ppspeech::CTCBeamSearchOptions opts; ppspeech::CTCBeamSearchOptions opts;
opts.dict_file = dict_file;
opts.lm_path = lm_path;
ppspeech::CTCBeamSearch decoder(opts); ppspeech::CTCBeamSearch decoder(opts);
ppspeech::ModelOptions model_opts; ppspeech::ModelOptions model_opts;
model_opts.model_path = model_graph;
model_opts.params_path = model_params;
std::shared_ptr<ppspeech::PaddleNnet> nnet( std::shared_ptr<ppspeech::PaddleNnet> nnet(
new ppspeech::PaddleNnet(model_opts)); new ppspeech::PaddleNnet(model_opts));
std::shared_ptr<ppspeech::RawDataCache> raw_data(
new ppspeech::RawDataCache());
std::shared_ptr<ppspeech::Decodable> decodable( std::shared_ptr<ppspeech::Decodable> decodable(
new ppspeech::Decodable(nnet)); new ppspeech::Decodable(nnet, raw_data));
// int32 chunk_size = 35; int32 chunk_size = 35;
decoder.InitDecoder(); decoder.InitDecoder();
for (; !feature_reader.Done(); feature_reader.Next()) { for (; !feature_reader.Done(); feature_reader.Next()) {
string utt = feature_reader.Key(); string utt = feature_reader.Key();
const kaldi::Matrix<BaseFloat> feature = feature_reader.Value(); const kaldi::Matrix<BaseFloat> feature = feature_reader.Value();
decodable->FeedFeatures(feature); raw_data->SetDim(feature.NumCols());
decoder.AdvanceDecode(decodable, 8); int32 row_idx = 0;
decodable->InputFinished(); int32 num_chunks = feature.NumRows() / chunk_size;
for (int chunk_idx = 0; chunk_idx < num_chunks; ++chunk_idx) {
kaldi::Vector<kaldi::BaseFloat> feature_chunk(chunk_size *
feature.NumCols());
for (int row_id = 0; row_id < chunk_size; ++row_id) {
kaldi::SubVector<kaldi::BaseFloat> tmp(feature, row_idx);
kaldi::SubVector<kaldi::BaseFloat> f_chunk_tmp(
feature_chunk.Data() + row_id * feature.NumCols(),
feature.NumCols());
f_chunk_tmp.CopyFromVec(tmp);
row_idx++;
}
raw_data->Accept(feature_chunk);
if (chunk_idx == num_chunks - 1) {
raw_data->SetFinished();
}
decoder.AdvanceDecode(decodable);
}
std::string result; std::string result;
result = decoder.GetFinalBestPath(); result = decoder.GetFinalBestPath();
KALDI_LOG << " the result of " << utt << " is " << result; KALDI_LOG << " the result of " << utt << " is " << result;
......
...@@ -79,21 +79,19 @@ void CTCBeamSearch::Decode( ...@@ -79,21 +79,19 @@ void CTCBeamSearch::Decode(
return; return;
} }
int32 CTCBeamSearch::NumFrameDecoded() { return num_frame_decoded_; } int32 CTCBeamSearch::NumFrameDecoded() { return num_frame_decoded_ + 1; }
// todo rename, refactor // todo rename, refactor
void CTCBeamSearch::AdvanceDecode( void CTCBeamSearch::AdvanceDecode(
const std::shared_ptr<kaldi::DecodableInterface>& decodable, const std::shared_ptr<kaldi::DecodableInterface>& decodable) {
int max_frames) { while (1) {
while (max_frames > 0) {
vector<vector<BaseFloat>> likelihood; vector<vector<BaseFloat>> likelihood;
if (decodable->IsLastFrame(NumFrameDecoded() + 1)) { vector<BaseFloat> frame_prob;
break; bool flag =
} decodable->FrameLogLikelihood(num_frame_decoded_, &frame_prob);
likelihood.push_back( if (flag == false) break;
decodable->FrameLogLikelihood(NumFrameDecoded() + 1)); likelihood.push_back(frame_prob);
AdvanceDecoding(likelihood); AdvanceDecoding(likelihood);
max_frames--;
} }
} }
......
...@@ -32,8 +32,8 @@ struct CTCBeamSearchOptions { ...@@ -32,8 +32,8 @@ struct CTCBeamSearchOptions {
int cutoff_top_n; int cutoff_top_n;
int num_proc_bsearch; int num_proc_bsearch;
CTCBeamSearchOptions() CTCBeamSearchOptions()
: dict_file("./model/words.txt"), : dict_file("vocab.txt"),
lm_path("./model/lm.arpa"), lm_path("lm.klm"),
alpha(1.9f), alpha(1.9f),
beta(5.0), beta(5.0),
beam_size(300), beam_size(300),
...@@ -68,8 +68,7 @@ class CTCBeamSearch { ...@@ -68,8 +68,7 @@ class CTCBeamSearch {
int DecodeLikelihoods(const std::vector<std::vector<BaseFloat>>& probs, int DecodeLikelihoods(const std::vector<std::vector<BaseFloat>>& probs,
std::vector<std::string>& nbest_words); std::vector<std::string>& nbest_words);
void AdvanceDecode( void AdvanceDecode(
const std::shared_ptr<kaldi::DecodableInterface>& decodable, const std::shared_ptr<kaldi::DecodableInterface>& decodable);
int max_frames);
void Reset(); void Reset();
private: private:
...@@ -83,7 +82,6 @@ class CTCBeamSearch { ...@@ -83,7 +82,6 @@ class CTCBeamSearch {
CTCBeamSearchOptions opts_; CTCBeamSearchOptions opts_;
std::shared_ptr<Scorer> init_ext_scorer_; // todo separate later std::shared_ptr<Scorer> init_ext_scorer_; // todo separate later
// std::vector<DecodeResult> decoder_results_;
std::vector<std::string> vocabulary_; // todo remove later std::vector<std::string> vocabulary_; // todo remove later
size_t blank_id; size_t blank_id;
int space_id; int space_id;
......
...@@ -18,6 +18,8 @@ ...@@ -18,6 +18,8 @@
#include "base/common.h" #include "base/common.h"
#include "frontend/feature_extractor_interface.h" #include "frontend/feature_extractor_interface.h"
#pragma once
namespace ppspeech { namespace ppspeech {
class RawAudioCache : public FeatureExtractorInterface { class RawAudioCache : public FeatureExtractorInterface {
...@@ -45,13 +47,12 @@ class RawAudioCache : public FeatureExtractorInterface { ...@@ -45,13 +47,12 @@ class RawAudioCache : public FeatureExtractorInterface {
DISALLOW_COPY_AND_ASSIGN(RawAudioCache); DISALLOW_COPY_AND_ASSIGN(RawAudioCache);
}; };
// it is a data source to test different frontend module. // it is a datasource for testing different frontend module.
// it Accepts waves or feats. // it accepts waves or feats.
class RawDataCache: public FeatureExtractorInterface { class RawDataCache : public FeatureExtractorInterface {
public: public:
explicit RawDataCache() { finished_ = false; } explicit RawDataCache() { finished_ = false; }
virtual void Accept( virtual void Accept(const kaldi::VectorBase<kaldi::BaseFloat>& inputs) {
const kaldi::VectorBase<kaldi::BaseFloat>& inputs) {
data_ = inputs; data_ = inputs;
} }
virtual bool Read(kaldi::Vector<kaldi::BaseFloat>* feats) { virtual bool Read(kaldi::Vector<kaldi::BaseFloat>* feats) {
...@@ -62,14 +63,15 @@ class RawDataCache: public FeatureExtractorInterface { ...@@ -62,14 +63,15 @@ class RawDataCache: public FeatureExtractorInterface {
data_.Resize(0); data_.Resize(0);
return true; return true;
} }
//the dim is data_ length virtual size_t Dim() const { return dim_; }
virtual size_t Dim() const { return data_.Dim(); }
virtual void SetFinished() { finished_ = true; } virtual void SetFinished() { finished_ = true; }
virtual bool IsFinished() const { return finished_; } virtual bool IsFinished() const { return finished_; }
void SetDim(int32 dim) { dim_ = dim; }
private: private:
kaldi::Vector<kaldi::BaseFloat> data_; kaldi::Vector<kaldi::BaseFloat> data_;
bool finished_; bool finished_;
int32 dim_;
DISALLOW_COPY_AND_ASSIGN(RawDataCache); DISALLOW_COPY_AND_ASSIGN(RawDataCache);
}; };
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// itf/decodable-itf.h // itf/decodable-itf.h
// Copyright 2009-2011 Microsoft Corporation; Saarland University; // Copyright 2009-2011 Microsoft Corporation; Saarland University;
...@@ -56,10 +42,8 @@ namespace kaldi { ...@@ -56,10 +42,8 @@ namespace kaldi {
For online decoding, where the features are coming in in real time, it is For online decoding, where the features are coming in in real time, it is
important to understand the IsLastFrame() and NumFramesReady() functions. important to understand the IsLastFrame() and NumFramesReady() functions.
There are two ways these are used: the old online-decoding code, in There are two ways these are used: the old online-decoding code, in ../online/,
../online/, and the new online-decoding code, in ../online2/. In the old online-decoding
and the new online-decoding code, in ../online2/. In the old
online-decoding
code, the decoder would do: code, the decoder would do:
\code{.cc} \code{.cc}
for (int frame = 0; !decodable.IsLastFrame(frame); frame++) { for (int frame = 0; !decodable.IsLastFrame(frame); frame++) {
...@@ -68,16 +52,13 @@ namespace kaldi { ...@@ -68,16 +52,13 @@ namespace kaldi {
\endcode \endcode
and the call to IsLastFrame would block if the features had not arrived yet. and the call to IsLastFrame would block if the features had not arrived yet.
The decodable object would have to know when to terminate the decoding. This The decodable object would have to know when to terminate the decoding. This
online-decoding mode is still supported, it is what happens when you call, online-decoding mode is still supported, it is what happens when you call, for
for
example, LatticeFasterDecoder::Decode(). example, LatticeFasterDecoder::Decode().
We realized that this "blocking" mode of decoding is not very convenient We realized that this "blocking" mode of decoding is not very convenient
because it forces the program to be multi-threaded and makes it complex to because it forces the program to be multi-threaded and makes it complex to
control endpointing. In the "new" decoding code, you don't call (for control endpointing. In the "new" decoding code, you don't call (for example)
example) LatticeFasterDecoder::Decode(), you call LatticeFasterDecoder::InitDecoding(),
LatticeFasterDecoder::Decode(), you call
LatticeFasterDecoder::InitDecoding(),
and then each time you get more features, you provide them to the decodable and then each time you get more features, you provide them to the decodable
object, and you call LatticeFasterDecoder::AdvanceDecoding(), which does object, and you call LatticeFasterDecoder::AdvanceDecoding(), which does
something like this: something like this:
...@@ -87,8 +68,7 @@ namespace kaldi { ...@@ -87,8 +68,7 @@ namespace kaldi {
} }
\endcode \endcode
So the decodable object never has IsLastFrame() called. For decoding where So the decodable object never has IsLastFrame() called. For decoding where
you are starting with a matrix of features, the NumFramesReady() function you are starting with a matrix of features, the NumFramesReady() function will
will
always just return the number of frames in the file, and IsLastFrame() will always just return the number of frames in the file, and IsLastFrame() will
return true for the last frame. return true for the last frame.
...@@ -102,39 +82,30 @@ namespace kaldi { ...@@ -102,39 +82,30 @@ namespace kaldi {
class DecodableInterface { class DecodableInterface {
public: public:
/// Returns the log likelihood, which will be negated in the decoder. /// Returns the log likelihood, which will be negated in the decoder.
/// The "frame" starts from zero. You should verify that NumFramesReady() > /// The "frame" starts from zero. You should verify that NumFramesReady() > frame
/// frame
/// before calling this. /// before calling this.
virtual BaseFloat LogLikelihood(int32 frame, int32 index) = 0; virtual BaseFloat LogLikelihood(int32 frame, int32 index) = 0;
/// Returns true if this is the last frame. Frames are zero-based, so the /// Returns true if this is the last frame. Frames are zero-based, so the
/// first frame is zero. IsLastFrame(-1) will return false, unless the file /// first frame is zero. IsLastFrame(-1) will return false, unless the file
/// is empty (which is a case that I'm not sure all the code will handle, so /// is empty (which is a case that I'm not sure all the code will handle, so
/// be careful). Caution: the behavior of this function in an online /// be careful). Caution: the behavior of this function in an online setting
/// setting
/// is being changed somewhat. In future it may return false in cases where /// is being changed somewhat. In future it may return false in cases where
/// we haven't yet decided to terminate decoding, but later true if we /// we haven't yet decided to terminate decoding, but later true if we decide
/// decide
/// to terminate decoding. The plan in future is to rely more on /// to terminate decoding. The plan in future is to rely more on
/// NumFramesReady(), and in future, IsLastFrame() would always return false /// NumFramesReady(), and in future, IsLastFrame() would always return false
/// in an online-decoding setting, and would only return true in a /// in an online-decoding setting, and would only return true in a
/// decoding-from-matrix setting where we want to allow the last delta or /// decoding-from-matrix setting where we want to allow the last delta or LDA
/// LDA
/// features to be flushed out for compatibility with the baseline setup. /// features to be flushed out for compatibility with the baseline setup.
virtual bool IsLastFrame(int32 frame) const = 0; virtual bool IsLastFrame(int32 frame) const = 0;
/// The call NumFramesReady() will return the number of frames currently /// The call NumFramesReady() will return the number of frames currently available
/// available /// for this decodable object. This is for use in setups where you don't want the
/// for this decodable object. This is for use in setups where you don't /// decoder to block while waiting for input. This is newly added as of Jan 2014,
/// want the /// and I hope, going forward, to rely on this mechanism more than IsLastFrame to
/// decoder to block while waiting for input. This is newly added as of Jan
/// 2014,
/// and I hope, going forward, to rely on this mechanism more than
/// IsLastFrame to
/// know when to stop decoding. /// know when to stop decoding.
virtual int32 NumFramesReady() const { virtual int32 NumFramesReady() const {
KALDI_ERR KALDI_ERR << "NumFramesReady() not implemented for this decodable type.";
<< "NumFramesReady() not implemented for this decodable type.";
return -1; return -1;
} }
...@@ -143,7 +114,9 @@ class DecodableInterface { ...@@ -143,7 +114,9 @@ class DecodableInterface {
/// this is for compatibility with OpenFst). /// this is for compatibility with OpenFst).
virtual int32 NumIndices() const = 0; virtual int32 NumIndices() const = 0;
virtual std::vector<BaseFloat> FrameLogLikelihood(int32 frame) = 0; virtual bool FrameLogLikelihood(int32 frame,
std::vector<kaldi::BaseFloat>* likelihood) = 0;
virtual ~DecodableInterface() {} virtual ~DecodableInterface() {}
}; };
......
...@@ -18,9 +18,16 @@ namespace ppspeech { ...@@ -18,9 +18,16 @@ namespace ppspeech {
using kaldi::BaseFloat; using kaldi::BaseFloat;
using kaldi::Matrix; using kaldi::Matrix;
using std::vector;
using kaldi::Vector;
Decodable::Decodable(const std::shared_ptr<NnetInterface>& nnet) Decodable::Decodable(const std::shared_ptr<NnetInterface>& nnet,
: frontend_(NULL), nnet_(nnet), finished_(false), frames_ready_(0) {} const std::shared_ptr<FeatureExtractorInterface>& frontend)
: frontend_(frontend),
nnet_(nnet),
finished_(false),
frame_offset_(0),
frames_ready_(0) {}
void Decodable::Acceptlikelihood(const Matrix<BaseFloat>& likelihood) { void Decodable::Acceptlikelihood(const Matrix<BaseFloat>& likelihood) {
frames_ready_ += likelihood.NumRows(); frames_ready_ += likelihood.NumRows();
...@@ -31,26 +38,46 @@ void Decodable::Acceptlikelihood(const Matrix<BaseFloat>& likelihood) { ...@@ -31,26 +38,46 @@ void Decodable::Acceptlikelihood(const Matrix<BaseFloat>& likelihood) {
bool Decodable::IsLastFrame(int32 frame) const { bool Decodable::IsLastFrame(int32 frame) const {
CHECK_LE(frame, frames_ready_); CHECK_LE(frame, frames_ready_);
return finished_ && (frame == frames_ready_ - 1); return IsInputFinished() && (frame == frames_ready_ - 1);
} }
int32 Decodable::NumIndices() const { return 0; } int32 Decodable::NumIndices() const { return 0; }
BaseFloat Decodable::LogLikelihood(int32 frame, int32 index) { return 0; } BaseFloat Decodable::LogLikelihood(int32 frame, int32 index) {
CHECK_LE(index, nnet_cache_.NumCols());
return 0;
}
void Decodable::FeedFeatures(const Matrix<kaldi::BaseFloat>& features) { bool Decodable::EnsureFrameHaveComputed(int32 frame) {
nnet_->FeedForward(features, &nnet_cache_); if (frame >= frames_ready_) {
return AdvanceChunk();
}
return true;
}
bool Decodable::AdvanceChunk() {
Vector<BaseFloat> features;
if (frontend_->Read(&features) == false) {
return false;
}
int32 nnet_dim = 0;
Vector<BaseFloat> inferences;
nnet_->FeedForward(features, frontend_->Dim(), &inferences, &nnet_dim);
nnet_cache_.Resize(inferences.Dim() / nnet_dim, nnet_dim);
nnet_cache_.CopyRowsFromVec(inferences);
frame_offset_ = frames_ready_;
frames_ready_ += nnet_cache_.NumRows(); frames_ready_ += nnet_cache_.NumRows();
return; return true;
} }
std::vector<BaseFloat> Decodable::FrameLogLikelihood(int32 frame) { bool Decodable::FrameLogLikelihood(int32 frame, vector<BaseFloat>* likelihood) {
std::vector<BaseFloat> result; std::vector<BaseFloat> result;
result.reserve(nnet_cache_.NumCols()); if (EnsureFrameHaveComputed(frame) == false) return false;
likelihood->resize(nnet_cache_.NumCols());
for (int32 idx = 0; idx < nnet_cache_.NumCols(); ++idx) { for (int32 idx = 0; idx < nnet_cache_.NumCols(); ++idx) {
result[idx] = nnet_cache_(frame, idx); (*likelihood)[idx] = nnet_cache_(frame - frame_offset_, idx);
} }
return result; return true;
} }
void Decodable::Reset() { void Decodable::Reset() {
......
...@@ -24,25 +24,35 @@ struct DecodableOpts; ...@@ -24,25 +24,35 @@ struct DecodableOpts;
class Decodable : public kaldi::DecodableInterface { class Decodable : public kaldi::DecodableInterface {
public: public:
explicit Decodable(const std::shared_ptr<NnetInterface>& nnet); explicit Decodable(
const std::shared_ptr<NnetInterface>& nnet,
const std::shared_ptr<FeatureExtractorInterface>& frontend);
// void Init(DecodableOpts config); // void Init(DecodableOpts config);
virtual kaldi::BaseFloat LogLikelihood(int32 frame, int32 index); virtual kaldi::BaseFloat LogLikelihood(int32 frame, int32 index);
virtual bool IsLastFrame(int32 frame) const; virtual bool IsLastFrame(int32 frame) const;
virtual int32 NumIndices() const; virtual int32 NumIndices() const;
virtual std::vector<BaseFloat> FrameLogLikelihood(int32 frame); virtual bool FrameLogLikelihood(int32 frame,
void Acceptlikelihood( std::vector<kaldi::BaseFloat>* likelihood);
const kaldi::Matrix<kaldi::BaseFloat>& likelihood); // remove later // for offline test
void FeedFeatures(const kaldi::Matrix<kaldi::BaseFloat>& void Acceptlikelihood(const kaldi::Matrix<kaldi::BaseFloat>& likelihood);
feature); // only for test, todo remove later
void Reset(); void Reset();
void InputFinished() { finished_ = true; } bool IsInputFinished() const { return frontend_->IsFinished(); }
bool EnsureFrameHaveComputed(int32 frame);
private: private:
bool AdvanceChunk();
std::shared_ptr<FeatureExtractorInterface> frontend_; std::shared_ptr<FeatureExtractorInterface> frontend_;
std::shared_ptr<NnetInterface> nnet_; std::shared_ptr<NnetInterface> nnet_;
kaldi::Matrix<kaldi::BaseFloat> nnet_cache_; kaldi::Matrix<kaldi::BaseFloat> nnet_cache_;
// std::vector<std::vector<kaldi::BaseFloat>> nnet_cache_;
bool finished_; bool finished_;
int32 frame_offset_;
int32 frames_ready_; int32 frames_ready_;
// todo: feature frame mismatch with nnet inference frame
// eg: 35 frame features output 8 frame inferences
// so use subsampled_frame
int32 current_log_post_subsampled_offset_;
int32 num_chunk_computed_;
}; };
} // namespace ppspeech } // namespace ppspeech
...@@ -23,8 +23,10 @@ namespace ppspeech { ...@@ -23,8 +23,10 @@ namespace ppspeech {
class NnetInterface { class NnetInterface {
public: public:
virtual void FeedForward(const kaldi::Matrix<kaldi::BaseFloat>& features, virtual void FeedForward(const kaldi::Vector<kaldi::BaseFloat>& features,
kaldi::Matrix<kaldi::BaseFloat>* inferences) = 0; int32 feature_dim,
kaldi::Vector<kaldi::BaseFloat>* inferences,
int32* inference_dim) = 0;
virtual void Reset() = 0; virtual void Reset() = 0;
virtual ~NnetInterface() {} virtual ~NnetInterface() {}
}; };
......
...@@ -21,6 +21,7 @@ using std::vector; ...@@ -21,6 +21,7 @@ using std::vector;
using std::string; using std::string;
using std::shared_ptr; using std::shared_ptr;
using kaldi::Matrix; using kaldi::Matrix;
using kaldi::Vector;
void PaddleNnet::InitCacheEncouts(const ModelOptions& opts) { void PaddleNnet::InitCacheEncouts(const ModelOptions& opts) {
std::vector<std::string> cache_names; std::vector<std::string> cache_names;
...@@ -143,34 +144,27 @@ shared_ptr<Tensor<BaseFloat>> PaddleNnet::GetCacheEncoder(const string& name) { ...@@ -143,34 +144,27 @@ shared_ptr<Tensor<BaseFloat>> PaddleNnet::GetCacheEncoder(const string& name) {
return cache_encouts_[iter->second]; return cache_encouts_[iter->second];
} }
void PaddleNnet::FeedForward(const Matrix<BaseFloat>& features, void PaddleNnet::FeedForward(const Vector<BaseFloat>& features,
Matrix<BaseFloat>* inferences) { int32 feature_dim,
Vector<BaseFloat>* inferences,
int32* inference_dim) {
paddle_infer::Predictor* predictor = GetPredictor(); paddle_infer::Predictor* predictor = GetPredictor();
int row = features.NumRows(); int feat_row = features.Dim() / feature_dim;
int col = features.NumCols();
std::vector<BaseFloat> feed_feature;
// todo refactor feed feature: SmileGoat
feed_feature.reserve(row * col);
for (size_t row_idx = 0; row_idx < features.NumRows(); ++row_idx) {
for (size_t col_idx = 0; col_idx < features.NumCols(); ++col_idx) {
feed_feature.push_back(features(row_idx, col_idx));
}
}
std::vector<std::string> input_names = predictor->GetInputNames(); std::vector<std::string> input_names = predictor->GetInputNames();
std::vector<std::string> output_names = predictor->GetOutputNames(); std::vector<std::string> output_names = predictor->GetOutputNames();
LOG(INFO) << "feat info: row=" << row << ", col= " << col; LOG(INFO) << "feat info: rows, cols: " << feat_row << ", " << feature_dim;
std::unique_ptr<paddle_infer::Tensor> input_tensor = std::unique_ptr<paddle_infer::Tensor> input_tensor =
predictor->GetInputHandle(input_names[0]); predictor->GetInputHandle(input_names[0]);
std::vector<int> INPUT_SHAPE = {1, row, col}; std::vector<int> INPUT_SHAPE = {1, feat_row, feature_dim};
input_tensor->Reshape(INPUT_SHAPE); input_tensor->Reshape(INPUT_SHAPE);
input_tensor->CopyFromCpu(feed_feature.data()); input_tensor->CopyFromCpu(features.Data());
std::unique_ptr<paddle_infer::Tensor> input_len = std::unique_ptr<paddle_infer::Tensor> input_len =
predictor->GetInputHandle(input_names[1]); predictor->GetInputHandle(input_names[1]);
std::vector<int> input_len_size = {1}; std::vector<int> input_len_size = {1};
input_len->Reshape(input_len_size); input_len->Reshape(input_len_size);
std::vector<int64_t> audio_len; std::vector<int64_t> audio_len;
audio_len.push_back(row); audio_len.push_back(feat_row);
input_len->CopyFromCpu(audio_len.data()); input_len->CopyFromCpu(audio_len.data());
std::unique_ptr<paddle_infer::Tensor> h_box = std::unique_ptr<paddle_infer::Tensor> h_box =
...@@ -203,20 +197,12 @@ void PaddleNnet::FeedForward(const Matrix<BaseFloat>& features, ...@@ -203,20 +197,12 @@ void PaddleNnet::FeedForward(const Matrix<BaseFloat>& features,
std::unique_ptr<paddle_infer::Tensor> output_tensor = std::unique_ptr<paddle_infer::Tensor> output_tensor =
predictor->GetOutputHandle(output_names[0]); predictor->GetOutputHandle(output_names[0]);
std::vector<int> output_shape = output_tensor->shape(); std::vector<int> output_shape = output_tensor->shape();
row = output_shape[1]; int32 row = output_shape[1];
col = output_shape[2]; int32 col = output_shape[2];
vector<float> inferences_result; inferences->Resize(row * col);
inferences->Resize(row, col); *inference_dim = col;
inferences_result.resize(row * col); output_tensor->CopyToCpu(inferences->Data());
output_tensor->CopyToCpu(inferences_result.data());
ReleasePredictor(predictor); ReleasePredictor(predictor);
for (int row_idx = 0; row_idx < row; ++row_idx) {
for (int col_idx = 0; col_idx < col; ++col_idx) {
(*inferences)(row_idx, col_idx) =
inferences_result[col * row_idx + col_idx];
}
}
} }
} // namespace ppspeech } // namespace ppspeech
\ No newline at end of file
...@@ -39,12 +39,8 @@ struct ModelOptions { ...@@ -39,12 +39,8 @@ struct ModelOptions {
bool enable_fc_padding; bool enable_fc_padding;
bool enable_profile; bool enable_profile;
ModelOptions() ModelOptions()
: model_path( : model_path("avg_1.jit.pdmodel"),
"../../../../model/paddle_online_deepspeech/model/" params_path("avg_1.jit.pdiparams"),
"avg_1.jit.pdmodel"),
params_path(
"../../../../model/paddle_online_deepspeech/model/"
"avg_1.jit.pdiparams"),
thread_num(2), thread_num(2),
use_gpu(false), use_gpu(false),
input_names( input_names(
...@@ -107,8 +103,11 @@ class Tensor { ...@@ -107,8 +103,11 @@ class Tensor {
class PaddleNnet : public NnetInterface { class PaddleNnet : public NnetInterface {
public: public:
PaddleNnet(const ModelOptions& opts); PaddleNnet(const ModelOptions& opts);
virtual void FeedForward(const kaldi::Matrix<kaldi::BaseFloat>& features, virtual void FeedForward(const kaldi::Vector<kaldi::BaseFloat>& features,
kaldi::Matrix<kaldi::BaseFloat>* inferences); int32 feature_dim,
kaldi::Vector<kaldi::BaseFloat>* inferences,
int32* inference_dim);
void Dim();
virtual void Reset(); virtual void Reset();
std::shared_ptr<Tensor<kaldi::BaseFloat>> GetCacheEncoder( std::shared_ptr<Tensor<kaldi::BaseFloat>> GetCacheEncoder(
const std::string& name); const std::string& name);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册