提交 616fc459 编写于 作者: H Hui Zhang

refactor options

上级 17ea30e7
......@@ -14,7 +14,7 @@ ctc_prefix_beam_search_decoder_main \
--model_path=$model_dir/export.jit \
--nnet_decoder_chunk=16 \
--receptive_field_length=7 \
--downsampling_rate=4 \
--subsampling_rate=4 \
--vocab_path=$model_dir/unit.txt \
--feature_rspecifier=ark,t:$exp/fbank.ark \
--result_wspecifier=ark,t:$exp/result.ark
......
......@@ -15,7 +15,7 @@ u2_nnet_main \
--feature_rspecifier=ark,t:$exp/fbank.ark \
--nnet_decoder_chunk=16 \
--receptive_field_length=7 \
--downsampling_rate=4 \
--subsampling_rate=4 \
--acoustic_scale=1.0 \
--nnet_encoder_outs_wspecifier=ark,t:$exp/encoder_outs.ark \
--nnet_prob_wspecifier=ark,t:$exp/logprobs.ark
......
......@@ -16,7 +16,7 @@ u2_recognizer_main \
--model_path=$model_dir/export.jit \
--nnet_decoder_chunk=16 \
--receptive_field_length=7 \
--downsampling_rate=4 \
--subsampling_rate=4 \
--vocab_path=$model_dir/unit.txt \
--wav_rspecifier=scp:$data/wav.scp \
--result_wspecifier=ark,t:$exp/result.ark
project(decoder)
include_directories(${CMAKE_CURRENT_SOURCE_DIR/ctc_decoders})
add_library(decoder STATIC
ctc_decoders/decoder_utils.cpp
ctc_decoders/path_trie.cpp
ctc_decoders/scorer.cpp
ctc_beam_search_decoder.cc
ctc_prefix_beam_search_decoder.cc
ctc_tlg_decoder.cc
recognizer.cc
u2_recognizer.cc
set(decoder_src )
if (USING_DS2)
list(APPEND decoder_src
ctc_decoders/decoder_utils.cpp
ctc_decoders/path_trie.cpp
ctc_decoders/scorer.cpp
ctc_beam_search_decoder.cc
ctc_tlg_decoder.cc
recognizer.cc
)
endif()
if (USING_U2)
list(APPEND decoder_src
ctc_prefix_beam_search_decoder.cc
u2_recognizer.cc
)
endif()
add_library(decoder STATIC ${decoder_src})
target_link_libraries(decoder PUBLIC kenlm utils fst frontend nnet kaldi-decoder absl::strings)
# test
set(BINS
ctc_beam_search_decoder_main
nnet_logprob_decoder_main
recognizer_main
ctc_tlg_decoder_main
)
if (USING_DS2)
set(BINS
ctc_beam_search_decoder_main
nnet_logprob_decoder_main
recognizer_main
ctc_tlg_decoder_main
)
foreach(bin_name IN LISTS BINS)
add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc)
target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(${bin_name} PUBLIC nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util ${DEPS})
endforeach()
foreach(bin_name IN LISTS BINS)
add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc)
target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(${bin_name} PUBLIC nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util ${DEPS})
endforeach()
endif()
# u2
set(TEST_BINS
u2_recognizer_main
ctc_prefix_beam_search_decoder_main
)
if (USING_U2)
set(TEST_BINS
ctc_prefix_beam_search_decoder_main
u2_recognizer_main
)
foreach(bin_name IN LISTS TEST_BINS)
add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc)
target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(${bin_name} nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util)
target_compile_options(${bin_name} PRIVATE ${PADDLE_COMPILE_FLAGS})
target_include_directories(${bin_name} PRIVATE ${pybind11_INCLUDE_DIRS} ${PROJECT_SOURCE_DIR})
target_link_libraries(${bin_name} ${PYTHON_LIBRARIES} ${PADDLE_LINK_FLAGS})
endforeach()
endif()
foreach(bin_name IN LISTS TEST_BINS)
add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc)
target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(${bin_name} nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util)
target_compile_options(${bin_name} PRIVATE ${PADDLE_COMPILE_FLAGS})
target_include_directories(${bin_name} PRIVATE ${pybind11_INCLUDE_DIRS} ${PROJECT_SOURCE_DIR})
target_link_libraries(${bin_name} ${PYTHON_LIBRARIES} ${PADDLE_LINK_FLAGS})
endforeach()
\ No newline at end of file
......@@ -31,7 +31,7 @@ DEFINE_string(lm_path, "", "language model");
DEFINE_int32(receptive_field_length,
7,
"receptive field of two CNN(kernel=3) downsampling module.");
DEFINE_int32(downsampling_rate,
DEFINE_int32(subsampling_rate,
4,
"two CNN(kernel=3) module downsampling rate.");
DEFINE_string(
......@@ -81,13 +81,8 @@ int main(int argc, char* argv[]) {
opts.lm_path = lm_path;
ppspeech::CTCBeamSearch decoder(opts);
ppspeech::ModelOptions model_opts;
model_opts.model_path = model_path;
model_opts.param_path = model_params;
model_opts.cache_names = FLAGS_model_cache_names;
model_opts.cache_shape = FLAGS_model_cache_shapes;
model_opts.input_names = FLAGS_model_input_names;
model_opts.output_names = FLAGS_model_output_names;
ppspeech::ModelOptions model_opts = ppspeech::ModelOptions::InitFromFlags();
std::shared_ptr<ppspeech::PaddleNnet> nnet(
new ppspeech::PaddleNnet(model_opts));
std::shared_ptr<ppspeech::DataCache> raw_data(new ppspeech::DataCache());
......@@ -95,8 +90,8 @@ int main(int argc, char* argv[]) {
new ppspeech::Decodable(nnet, raw_data));
int32 chunk_size = FLAGS_receptive_field_length +
(FLAGS_nnet_decoder_chunk - 1) * FLAGS_downsampling_rate;
int32 chunk_stride = FLAGS_downsampling_rate * FLAGS_nnet_decoder_chunk;
(FLAGS_nnet_decoder_chunk - 1) * FLAGS_subsampling_rate;
int32 chunk_stride = FLAGS_subsampling_rate * FLAGS_nnet_decoder_chunk;
int32 receptive_field_length = FLAGS_receptive_field_length;
LOG(INFO) << "chunk size (frame): " << chunk_size;
LOG(INFO) << "chunk stride (frame): " << chunk_stride;
......
......@@ -30,7 +30,7 @@ DEFINE_string(model_path, "", "paddle nnet model");
DEFINE_int32(receptive_field_length,
7,
"receptive field of two CNN(kernel=3) downsampling module.");
DEFINE_int32(downsampling_rate,
DEFINE_int32(subsampling_rate,
4,
"two CNN(kernel=3) module downsampling rate.");
......@@ -81,8 +81,8 @@ int main(int argc, char* argv[]) {
int32 chunk_size = FLAGS_receptive_field_length +
(FLAGS_nnet_decoder_chunk - 1) * FLAGS_downsampling_rate;
int32 chunk_stride = FLAGS_downsampling_rate * FLAGS_nnet_decoder_chunk;
(FLAGS_nnet_decoder_chunk - 1) * FLAGS_subsampling_rate;
int32 chunk_stride = FLAGS_subsampling_rate * FLAGS_nnet_decoder_chunk;
int32 receptive_field_length = FLAGS_receptive_field_length;
LOG(INFO) << "chunk size (frame): " << chunk_size;
LOG(INFO) << "chunk stride (frame): " << chunk_stride;
......
......@@ -20,15 +20,37 @@
#include "kaldi/decoder/lattice-faster-online-decoder.h"
#include "util/parse-options.h"
DECLARE_string(graph_path);
DECLARE_string(word_symbol_table);
DECLARE_int32(max_active);
DECLARE_double(beam);
DECLARE_double(lattice_beam);
namespace ppspeech {
struct TLGDecoderOptions {
kaldi::LatticeFasterDecoderConfig opts;
kaldi::LatticeFasterDecoderConfig opts{};
// todo remove later, add into decode resource
std::string word_symbol_table;
std::string fst_path;
TLGDecoderOptions() : word_symbol_table(""), fst_path("") {}
std::string word_symbol_table{};
std::string fst_path{};
static TLGDecoderOptions InitFromFlags(){
TLGDecoderOptions decoder_opts;
decoder_opts.word_symbol_table = FLAGS_word_symbol_table;
decoder_opts.fst_path = FLAGS_graph_path;
LOG(INFO) << "fst path: " << decoder_opts.fst_path;
LOG(INFO) << "fst symbole table: " << decoder_opts.word_symbol_table;
decoder_opts.opts.max_active = FLAGS_max_active;
decoder_opts.opts.beam = FLAGS_beam;
decoder_opts.opts.lattice_beam = FLAGS_lattice_beam;
LOG(INFO) << "LatticeFasterDecoder max active: " << decoder_opts.opts.max_active ;
LOG(INFO) << "LatticeFasterDecoder beam: " << decoder_opts.opts.beam ;
LOG(INFO) << "LatticeFasterDecoder lattice_beam: " << decoder_opts.opts.lattice_beam ;
return decoder_opts;
}
};
class TLGDecoder : public DecoderInterface {
......
......@@ -19,6 +19,7 @@
#include "frontend/audio/data_cache.h"
#include "nnet/decodable.h"
#include "nnet/ds2_nnet.h"
#include "decoder/param.h"
#include "decoder/ctc_tlg_decoder.h"
#include "kaldi/util/table-types.h"
......@@ -26,30 +27,7 @@
DEFINE_string(feature_rspecifier, "", "test feature rspecifier");
DEFINE_string(result_wspecifier, "", "test result wspecifier");
DEFINE_string(model_path, "avg_1.jit.pdmodel", "paddle nnet model");
DEFINE_string(param_path, "avg_1.jit.pdiparams", "paddle nnet model param");
DEFINE_string(word_symbol_table, "words.txt", "word symbol table");
DEFINE_string(graph_path, "TLG", "decoder graph");
DEFINE_double(acoustic_scale, 1.0, "acoustic scale");
DEFINE_int32(max_active, 7500, "decoder graph");
DEFINE_int32(nnet_decoder_chunk, 1, "paddle nnet forward chunk");
DEFINE_int32(receptive_field_length,
7,
"receptive field of two CNN(kernel=3) downsampling module.");
DEFINE_int32(downsampling_rate,
4,
"two CNN(kernel=3) module downsampling rate.");
DEFINE_string(
model_input_names,
"audio_chunk,audio_chunk_lens,chunk_state_h_box,chunk_state_c_box",
"model input names");
DEFINE_string(model_output_names,
"softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0",
"model output names");
DEFINE_string(model_cache_names,
"chunk_state_h_box,chunk_state_c_box",
"model cache names");
DEFINE_string(model_cache_shapes, "5-1-1024,5-1-1024", "model cache shapes");
using kaldi::BaseFloat;
using kaldi::Matrix;
......@@ -66,32 +44,16 @@ int main(int argc, char* argv[]) {
kaldi::SequentialBaseFloatMatrixReader feature_reader(
FLAGS_feature_rspecifier);
kaldi::TokenWriter result_writer(FLAGS_result_wspecifier);
std::string model_graph = FLAGS_model_path;
std::string model_params = FLAGS_param_path;
std::string word_symbol_table = FLAGS_word_symbol_table;
std::string graph_path = FLAGS_graph_path;
LOG(INFO) << "model path: " << model_graph;
LOG(INFO) << "model param: " << model_params;
LOG(INFO) << "word symbol path: " << word_symbol_table;
LOG(INFO) << "graph path: " << graph_path;
int32 num_done = 0, num_err = 0;
ppspeech::TLGDecoderOptions opts;
opts.word_symbol_table = word_symbol_table;
opts.fst_path = graph_path;
opts.opts.max_active = FLAGS_max_active;
ppspeech::TLGDecoderOptions opts = ppspeech::TLGDecoderOptions::InitFromFlags();
opts.opts.beam = 15.0;
opts.opts.lattice_beam = 7.5;
ppspeech::TLGDecoder decoder(opts);
ppspeech::ModelOptions model_opts;
model_opts.model_path = model_graph;
model_opts.param_path = model_params;
model_opts.cache_names = FLAGS_model_cache_names;
model_opts.cache_shape = FLAGS_model_cache_shapes;
model_opts.input_names = FLAGS_model_input_names;
model_opts.output_names = FLAGS_model_output_names;
ppspeech::ModelOptions model_opts = ppspeech::ModelOptions::InitFromFlags();
std::shared_ptr<ppspeech::PaddleNnet> nnet(
new ppspeech::PaddleNnet(model_opts));
std::shared_ptr<ppspeech::DataCache> raw_data(new ppspeech::DataCache());
......@@ -99,12 +61,13 @@ int main(int argc, char* argv[]) {
new ppspeech::Decodable(nnet, raw_data, FLAGS_acoustic_scale));
int32 chunk_size = FLAGS_receptive_field_length +
(FLAGS_nnet_decoder_chunk - 1) * FLAGS_downsampling_rate;
int32 chunk_stride = FLAGS_downsampling_rate * FLAGS_nnet_decoder_chunk;
(FLAGS_nnet_decoder_chunk - 1) * FLAGS_subsampling_rate;
int32 chunk_stride = FLAGS_subsampling_rate * FLAGS_nnet_decoder_chunk;
int32 receptive_field_length = FLAGS_receptive_field_length;
LOG(INFO) << "chunk size (frame): " << chunk_size;
LOG(INFO) << "chunk stride (frame): " << chunk_stride;
LOG(INFO) << "receptive field (frame): " << receptive_field_length;
decoder.InitDecoder();
kaldi::Timer timer;
for (; !feature_reader.Done(); feature_reader.Next()) {
......
......@@ -17,8 +17,6 @@
#include "base/common.h"
#include "decoder/ctc_beam_search_decoder.h"
#include "decoder/ctc_tlg_decoder.h"
#include "frontend/audio/feature_pipeline.h"
// feature
DEFINE_bool(use_fbank, false, "False for fbank; or linear feature");
......@@ -27,18 +25,18 @@ DEFINE_bool(use_fbank, false, "False for fbank; or linear feature");
DEFINE_int32(num_bins, 161, "num bins of mel");
DEFINE_string(cmvn_file, "", "read cmvn");
// feature sliding window
DEFINE_int32(receptive_field_length,
7,
"receptive field of two CNN(kernel=3) downsampling module.");
DEFINE_int32(downsampling_rate,
4,
DEFINE_int32(subsampling_rate,
4,
"two CNN(kernel=3) module downsampling rate.");
DEFINE_int32(nnet_decoder_chunk, 1, "paddle nnet forward chunk");
// nnet
DEFINE_string(vocab_path, "", "nnet vocab path.");
DEFINE_string(model_path, "avg_1.jit.pdmodel", "paddle nnet model");
DEFINE_string(param_path, "avg_1.jit.pdiparams", "paddle nnet model param");
DEFINE_string(
......@@ -52,10 +50,11 @@ DEFINE_string(model_cache_names,
"chunk_state_h_box,chunk_state_c_box",
"model cache names");
DEFINE_string(model_cache_shapes, "5-1-1024,5-1-1024", "model cache shapes");
DEFINE_string(vocab_path, "", "nnet vocab path.");
// decoder
DEFINE_double(acoustic_scale, 1.0, "acoustic scale");
DEFINE_string(graph_path, "TLG", "decoder graph");
DEFINE_string(word_symbol_table, "words.txt", "word symbol table");
DEFINE_int32(max_active, 7500, "max active");
......@@ -63,37 +62,20 @@ DEFINE_double(beam, 15.0, "decoder beam");
DEFINE_double(lattice_beam, 7.5, "decoder beam");
namespace ppspeech {
// todo refactor later
FeaturePipelineOptions InitFeaturePipelineOptions() {
FeaturePipelineOptions opts;
opts.cmvn_file = FLAGS_cmvn_file;
kaldi::FrameExtractionOptions frame_opts;
frame_opts.dither = 0.0;
frame_opts.frame_shift_ms = 10;
opts.use_fbank = FLAGS_use_fbank;
LOG(INFO) << "feature type: " << (opts.use_fbank ? "fbank" : "linear");
if (opts.use_fbank) {
opts.to_float32 = false;
frame_opts.window_type = "povey";
frame_opts.frame_length_ms = 25;
opts.fbank_opts.mel_opts.num_bins = FLAGS_num_bins;
opts.fbank_opts.frame_opts = frame_opts;
LOG(INFO) << "num bins: " << opts.fbank_opts.mel_opts.num_bins;
} else {
opts.to_float32 = true;
frame_opts.remove_dc_offset = false;
frame_opts.frame_length_ms = 20;
frame_opts.window_type = "hanning";
frame_opts.preemph_coeff = 0.0;
opts.linear_spectrogram_opts.frame_opts = frame_opts;
}
opts.assembler_opts.subsampling_rate = FLAGS_downsampling_rate;
opts.assembler_opts.receptive_filed_length = FLAGS_receptive_field_length;
opts.assembler_opts.nnet_decoder_chunk = FLAGS_nnet_decoder_chunk;
return opts;
}
} // namespace ppspeech
// DecodeOptions flags
// DEFINE_int32(chunk_size, -1, "decoding chunk size");
DEFINE_int32(num_left_chunks, -1, "left chunks in decoding");
DEFINE_double(ctc_weight,
0.5,
"ctc weight when combining ctc score and rescoring score");
DEFINE_double(rescoring_weight,
1.0,
"rescoring weight when combining ctc score and rescoring score");
DEFINE_double(reverse_weight,
0.3,
"used for bitransformer rescoring. it must be 0.0 if decoder is"
"conventional transformer decoder, and only reverse_weight > 0.0"
"dose the right to left decoder will be calculated and used");
DEFINE_int32(nbest, 10, "nbest for ctc wfst or prefix search");
DEFINE_int32(blank, 0, "blank id in vocab");
......@@ -22,14 +22,26 @@
#include "nnet/decodable.h"
#include "nnet/ds2_nnet.h"
DECLARE_double(acoustic_scale);
namespace ppspeech {
struct RecognizerResource {
kaldi::BaseFloat acoustic_scale{1.0};
FeaturePipelineOptions feature_pipeline_opts{};
ModelOptions model_opts{};
TLGDecoderOptions tlg_opts{};
// CTCBeamSearchOptions beam_search_opts;
kaldi::BaseFloat acoustic_scale{1.0};
static RecognizerResource InitFromFlags(){
RecognizerResource resource;
resource.acoustic_scale = FLAGS_acoustic_scale;
resource.feature_pipeline_opts = FeaturePipelineOptions::InitFromFlags();
resource.model_opts = ModelOptions::InitFromFlags();
resource.tlg_opts = TLGDecoderOptions::InitFromFlags();
return resource;
}
};
class Recognizer {
......
......@@ -25,27 +25,9 @@ DEFINE_int32(sample_rate, 16000, "sample rate");
ppspeech::RecognizerResource InitRecognizerResoure() {
ppspeech::RecognizerResource resource;
resource.acoustic_scale = FLAGS_acoustic_scale;
resource.feature_pipeline_opts = ppspeech::InitFeaturePipelineOptions();
ppspeech::ModelOptions model_opts;
model_opts.model_path = FLAGS_model_path;
model_opts.param_path = FLAGS_param_path;
model_opts.cache_names = FLAGS_model_cache_names;
model_opts.cache_shape = FLAGS_model_cache_shapes;
model_opts.input_names = FLAGS_model_input_names;
model_opts.output_names = FLAGS_model_output_names;
model_opts.subsample_rate = FLAGS_downsampling_rate;
resource.model_opts = model_opts;
ppspeech::TLGDecoderOptions decoder_opts;
decoder_opts.word_symbol_table = FLAGS_word_symbol_table;
decoder_opts.fst_path = FLAGS_graph_path;
decoder_opts.opts.max_active = FLAGS_max_active;
decoder_opts.opts.beam = FLAGS_beam;
decoder_opts.opts.lattice_beam = FLAGS_lattice_beam;
resource.tlg_opts = decoder_opts;
resource.feature_pipeline_opts = ppspeech::FeaturePipelineOptions::InitFromFlags();
resource.model_opts = ppspeech::ModelOptions::InitFromFlags();
resource.tlg_opts = ppspeech::TLGDecoderOptions::InitFromFlags();
return resource;
}
......
......@@ -26,15 +26,25 @@
#include "fst/fstlib.h"
#include "fst/symbol-table.h"
namespace ppspeech {
DECLARE_int32(nnet_decoder_chunk);
DECLARE_int32(num_left_chunks);
DECLARE_double(ctc_weight);
DECLARE_double(rescoring_weight);
DECLARE_double(reverse_weight);
DECLARE_int32(nbest);
DECLARE_int32(blank);
DECLARE_double(acoustic_scale);
DECLARE_string(vocab_path);
namespace ppspeech {
struct DecodeOptions {
// chunk_size is the frame number of one chunk after subsampling.
// e.g. if subsample rate is 4 and chunk_size = 16, the frames in
// one chunk are 67=16*4 + 3, stride is 64=16*4
int chunk_size;
int num_left_chunks;
int chunk_size{16};
int num_left_chunks{-1};
// final_score = rescoring_weight * rescoring_score + ctc_weight *
// ctc_score;
......@@ -46,51 +56,27 @@ struct DecodeOptions {
// it's a sum(prefix) score + context score For CtcWfstBeamSerch, it's a
// max(viterbi) path score + context score So we should carefully set
// ctc_weight accroding to the search methods.
float ctc_weight;
float rescoring_weight;
float reverse_weight;
float ctc_weight{0.0};
float rescoring_weight{1.0};
float reverse_weight{0.0};
// CtcEndpointConfig ctc_endpoint_opts;
CTCBeamSearchOptions ctc_prefix_search_opts;
DecodeOptions()
: chunk_size(16),
num_left_chunks(-1),
ctc_weight(0.5),
rescoring_weight(1.0),
reverse_weight(0.0) {}
void Register(kaldi::OptionsItf* opts) {
std::string module = "DecoderConfig: ";
opts->Register(
"chunk-size",
&chunk_size,
module + "the frame number of one chunk after subsampling.");
opts->Register("num-left-chunks",
&num_left_chunks,
module + "the left history chunks number.");
opts->Register("ctc-weight",
&ctc_weight,
module +
"ctc weight for rescore. final_score = "
"rescoring_weight * rescoring_score + ctc_weight * "
"ctc_score.");
opts->Register("rescoring-weight",
&rescoring_weight,
module +
"attention score weight for rescore. final_score = "
"rescoring_weight * rescoring_score + ctc_weight * "
"ctc_score.");
opts->Register("reverse-weight",
&reverse_weight,
module +
"reverse decoder weight. rescoring_score = "
"left_to_right_score * (1 - reverse_weight) + "
"right_to_left_score * reverse_weight.");
CTCBeamSearchOptions ctc_prefix_search_opts{};
static DecodeOptions InitFromFlags(){
DecodeOptions decoder_opts;
decoder_opts.chunk_size=FLAGS_nnet_decoder_chunk;
decoder_opts.num_left_chunks = FLAGS_num_left_chunks;
decoder_opts.ctc_weight = FLAGS_ctc_weight;
decoder_opts.rescoring_weight = FLAGS_rescoring_weight;
decoder_opts.reverse_weight = FLAGS_reverse_weight;
decoder_opts.ctc_prefix_search_opts.blank = FLAGS_blank;
decoder_opts.ctc_prefix_search_opts.first_beam_size = FLAGS_nbest;
decoder_opts.ctc_prefix_search_opts.second_beam_size = FLAGS_nbest;
return decoder_opts;
}
};
struct U2RecognizerResource {
kaldi::BaseFloat acoustic_scale{1.0};
std::string vocab_path{};
......@@ -98,7 +84,17 @@ struct U2RecognizerResource {
FeaturePipelineOptions feature_pipeline_opts{};
ModelOptions model_opts{};
DecodeOptions decoder_opts{};
// CTCBeamSearchOptions beam_search_opts;
static U2RecognizerResource InitFromFlags() {
U2RecognizerResource resource;
resource.vocab_path = FLAGS_vocab_path;
resource.acoustic_scale = FLAGS_acoustic_scale;
resource.feature_pipeline_opts = ppspeech::FeaturePipelineOptions::InitFromFlags();
resource.model_opts = ppspeech::ModelOptions::InitFromFlags();
resource.decoder_opts = ppspeech::DecodeOptions::InitFromFlags();
return resource;
}
};
......
......@@ -22,35 +22,6 @@ DEFINE_string(result_wspecifier, "", "test result wspecifier");
DEFINE_double(streaming_chunk, 0.36, "streaming feature chunk size");
DEFINE_int32(sample_rate, 16000, "sample rate");
ppspeech::U2RecognizerResource InitOpts() {
ppspeech::U2RecognizerResource resource;
resource.vocab_path = FLAGS_vocab_path;
resource.acoustic_scale = FLAGS_acoustic_scale;
resource.feature_pipeline_opts = ppspeech::InitFeaturePipelineOptions();
LOG(INFO) << "feature!";
ppspeech::ModelOptions model_opts;
model_opts.model_path = FLAGS_model_path;
resource.model_opts = model_opts;
LOG(INFO) << "model!";
ppspeech::DecodeOptions decoder_opts;
decoder_opts.chunk_size=16;
decoder_opts.num_left_chunks = -1;
decoder_opts.ctc_weight = 0.5;
decoder_opts.rescoring_weight = 1.0;
decoder_opts.reverse_weight = 0.3;
decoder_opts.ctc_prefix_search_opts.blank = 0;
decoder_opts.ctc_prefix_search_opts.first_beam_size = 10;
decoder_opts.ctc_prefix_search_opts.second_beam_size = 10;
resource.decoder_opts = decoder_opts;
LOG(INFO) << "decoder!";
return resource;
}
int main(int argc, char* argv[]) {
gflags::SetUsageMessage("Usage:");
gflags::ParseCommandLineFlags(&argc, &argv, false);
......@@ -72,7 +43,7 @@ int main(int argc, char* argv[]) {
LOG(INFO) << "chunk size (s): " << streaming_chunk;
LOG(INFO) << "chunk size (sample): " << chunk_sample_size;
ppspeech::U2RecognizerResource resource = InitOpts();
ppspeech::U2RecognizerResource resource = ppspeech::U2RecognizerResource::InitFromFlags();
ppspeech::U2Recognizer recognizer(resource);
kaldi::Timer timer;
......
......@@ -25,26 +25,71 @@
#include "frontend/audio/linear_spectrogram.h"
#include "frontend/audio/normalizer.h"
// feature
DECLARE_bool(use_fbank);
DECLARE_int32(num_bins);
DECLARE_string(cmvn_file);
// feature sliding window
DECLARE_int32(receptive_field_length);
DECLARE_int32(subsampling_rate);
DECLARE_int32(nnet_decoder_chunk);
namespace ppspeech {
struct FeaturePipelineOptions {
std::string cmvn_file;
bool to_float32; // true, only for linear feature
bool use_fbank;
LinearSpectrogramOptions linear_spectrogram_opts;
kaldi::FbankOptions fbank_opts;
FeatureCacheOptions feature_cache_opts;
AssemblerOptions assembler_opts;
FeaturePipelineOptions()
: cmvn_file(""),
to_float32(false), // true, only for linear feature
use_fbank(true),
linear_spectrogram_opts(),
fbank_opts(),
feature_cache_opts(),
assembler_opts() {}
std::string cmvn_file{};
bool to_float32{false}; // true, only for linear feature
bool use_fbank{true};
LinearSpectrogramOptions linear_spectrogram_opts{};
kaldi::FbankOptions fbank_opts{};
FeatureCacheOptions feature_cache_opts{};
AssemblerOptions assembler_opts{};
static FeaturePipelineOptions InitFromFlags(){
FeaturePipelineOptions opts;
opts.cmvn_file = FLAGS_cmvn_file;
LOG(INFO) << "cmvn file: " << opts.cmvn_file;
// frame options
kaldi::FrameExtractionOptions frame_opts;
frame_opts.dither = 0.0;
LOG(INFO) << "dither: " << frame_opts.dither;
frame_opts.frame_shift_ms = 10;
LOG(INFO) << "frame shift ms: " << frame_opts.frame_shift_ms;
opts.use_fbank = FLAGS_use_fbank;
LOG(INFO) << "feature type: " << (opts.use_fbank ? "fbank" : "linear");
if (opts.use_fbank) {
opts.to_float32 = false;
frame_opts.window_type = "povey";
frame_opts.frame_length_ms = 25;
opts.fbank_opts.mel_opts.num_bins = FLAGS_num_bins;
LOG(INFO) << "num bins: " << opts.fbank_opts.mel_opts.num_bins;
opts.fbank_opts.frame_opts = frame_opts;
} else {
opts.to_float32 = true;
frame_opts.remove_dc_offset = false;
frame_opts.frame_length_ms = 20;
frame_opts.window_type = "hanning";
frame_opts.preemph_coeff = 0.0;
opts.linear_spectrogram_opts.frame_opts = frame_opts;
}
LOG(INFO) << "frame length ms: " << frame_opts.frame_length_ms;
// assembler opts
opts.assembler_opts.subsampling_rate = FLAGS_subsampling_rate;
LOG(INFO) << "subsampling rate: " << opts.assembler_opts.subsampling_rate;
opts.assembler_opts.receptive_filed_length = FLAGS_receptive_field_length;
LOG(INFO) << "nnet receptive filed length: " << opts.assembler_opts.receptive_filed_length;
opts.assembler_opts.nnet_decoder_chunk = FLAGS_nnet_decoder_chunk;
LOG(INFO) << "nnet chunk size: " << opts.assembler_opts.nnet_decoder_chunk;
return opts;
}
};
class FeaturePipeline : public FrontendInterface {
public:
explicit FeaturePipeline(const FeaturePipelineOptions& opts);
......
......@@ -14,6 +14,7 @@
#include "nnet/ds2_nnet.h"
#include "base/common.h"
#include "decoder/param.h"
#include "frontend/audio/assembler.h"
#include "frontend/audio/data_cache.h"
#include "kaldi/util/table-types.h"
......@@ -21,27 +22,6 @@
DEFINE_string(feature_rspecifier, "", "test feature rspecifier");
DEFINE_string(nnet_prob_wspecifier, "", "nnet porb wspecifier");
DEFINE_string(model_path, "avg_1.jit.pdmodel", "paddle nnet model");
DEFINE_string(param_path, "avg_1.jit.pdiparams", "paddle nnet model param");
DEFINE_int32(nnet_decoder_chunk, 1, "paddle nnet forward chunk");
DEFINE_int32(receptive_field_length,
7,
"receptive field of two CNN(kernel=3) downsampling module.");
DEFINE_int32(downsampling_rate,
4,
"two CNN(kernel=3) module downsampling rate.");
DEFINE_string(
model_input_names,
"audio_chunk,audio_chunk_lens,chunk_state_h_box,chunk_state_c_box",
"model input names");
DEFINE_string(model_output_names,
"softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0",
"model output names");
DEFINE_string(model_cache_names,
"chunk_state_h_box,chunk_state_c_box",
"model cache names");
DEFINE_string(model_cache_shapes, "5-1-1024,5-1-1024", "model cache shapes");
DEFINE_double(acoustic_scale, 1.0, "acoustic scale");
using kaldi::BaseFloat;
using kaldi::Matrix;
......@@ -64,13 +44,8 @@ int main(int argc, char* argv[]) {
int32 num_done = 0, num_err = 0;
ppspeech::ModelOptions model_opts;
model_opts.model_path = model_graph;
model_opts.param_path = model_params;
model_opts.cache_names = FLAGS_model_cache_names;
model_opts.cache_shape = FLAGS_model_cache_shapes;
model_opts.input_names = FLAGS_model_input_names;
model_opts.output_names = FLAGS_model_output_names;
ppspeech::ModelOptions model_opts = ppspeech::ModelOptions::InitFromFlags();
std::shared_ptr<ppspeech::PaddleNnet> nnet(
new ppspeech::PaddleNnet(model_opts));
std::shared_ptr<ppspeech::DataCache> raw_data(new ppspeech::DataCache());
......@@ -78,8 +53,8 @@ int main(int argc, char* argv[]) {
new ppspeech::Decodable(nnet, raw_data, FLAGS_acoustic_scale));
int32 chunk_size = FLAGS_receptive_field_length +
(FLAGS_nnet_decoder_chunk - 1) * FLAGS_downsampling_rate;
int32 chunk_stride = FLAGS_downsampling_rate * FLAGS_nnet_decoder_chunk;
(FLAGS_nnet_decoder_chunk - 1) * FLAGS_subsampling_rate;
int32 chunk_stride = FLAGS_subsampling_rate * FLAGS_nnet_decoder_chunk;
int32 receptive_field_length = FLAGS_receptive_field_length;
LOG(INFO) << "chunk size (frame): " << chunk_size;
LOG(INFO) << "chunk stride (frame): " << chunk_stride;
......
......@@ -20,53 +20,54 @@
#include "kaldi/matrix/kaldi-matrix.h"
#include "kaldi/util/options-itf.h"
namespace ppspeech {
DECLARE_int32(subsampling_rate);
DECLARE_string(model_path);
DECLARE_string(param_path);
DECLARE_string(model_input_names);
DECLARE_string(model_output_names);
DECLARE_string(model_cache_names);
DECLARE_string(model_cache_shapes);
namespace ppspeech {
struct ModelOptions {
// common
int subsample_rate{1};
int thread_num{1}; // predictor thread pool size for ds2;
bool use_gpu{false};
std::string model_path;
std::string param_path;
int thread_num; // predictor thread pool size for ds2;
bool use_gpu;
bool switch_ir_optim;
std::string input_names;
std::string output_names;
std::string cache_names;
std::string cache_shape;
bool enable_fc_padding;
bool enable_profile;
int subsample_rate;
ModelOptions()
: model_path(""),
param_path(""),
thread_num(1),
use_gpu(false),
input_names(""),
output_names(""),
cache_names(""),
cache_shape(""),
switch_ir_optim(false),
enable_fc_padding(false),
enable_profile(false),
subsample_rate(0) {}
void Register(kaldi::OptionsItf* opts) {
opts->Register("model-path", &model_path, "model file path");
opts->Register("model-param", &param_path, "params model file path");
opts->Register("thread-num", &thread_num, "thread num");
opts->Register("use-gpu", &use_gpu, "if use gpu");
opts->Register("input-names", &input_names, "paddle input names");
opts->Register("output-names", &output_names, "paddle output names");
opts->Register("cache-names", &cache_names, "cache names");
opts->Register("cache-shape", &cache_shape, "cache shape");
opts->Register("switch-ir-optiom",
&switch_ir_optim,
"paddle SwitchIrOptim option");
opts->Register("enable-fc-padding",
&enable_fc_padding,
"paddle EnableFCPadding option");
opts->Register(
"enable-profile", &enable_profile, "paddle EnableProfile option");
// ds2 for inference
std::string input_names{};
std::string output_names{};
std::string cache_names{};
std::string cache_shape{};
bool switch_ir_optim{false};
bool enable_fc_padding{false};
bool enable_profile{false};
static ModelOptions InitFromFlags(){
ModelOptions opts;
opts.subsample_rate = FLAGS_subsampling_rate;
LOG(INFO) << "subsampling rate: " << opts.subsample_rate;
opts.model_path = FLAGS_model_path;
LOG(INFO) << "model path: " << opts.model_path ;
opts.param_path = FLAGS_param_path;
LOG(INFO) << "param path: " << opts.param_path ;
LOG(INFO) << "DS2 param: ";
opts.cache_names = FLAGS_model_cache_names;
LOG(INFO) << " cache names: " << opts.cache_names;
opts.cache_shape = FLAGS_model_cache_shapes;
LOG(INFO) << " cache shape: " << opts.cache_shape;
opts.input_names = FLAGS_model_input_names;
LOG(INFO) << " input names: " << opts.input_names;
opts.output_names = FLAGS_model_output_names;
LOG(INFO) << " output names: " << opts.output_names;
return opts;
}
};
......
......@@ -17,7 +17,6 @@
#include "base/common.h"
#include "kaldi/matrix/kaldi-matrix.h"
#include "nnet/nnet_itf.h"
#include "paddle/extension.h"
#include "paddle/jit/all.h"
......
......@@ -12,28 +12,20 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "nnet/u2_nnet.h"
#include "base/common.h"
#include "frontend/audio/assembler.h"
#include "frontend/audio/data_cache.h"
#include "kaldi/util/table-types.h"
#include "nnet/decodable.h"
#include "decoder/param.h"
#include "nnet/u2_nnet.h"
DEFINE_string(feature_rspecifier, "", "test feature rspecifier");
DEFINE_string(nnet_prob_wspecifier, "", "nnet porb wspecifier");
DEFINE_string(nnet_encoder_outs_wspecifier, "", "nnet encoder outs wspecifier");
DEFINE_string(model_path, "", "paddle nnet model");
DEFINE_int32(nnet_decoder_chunk, 16, "nnet forward chunk");
DEFINE_int32(receptive_field_length,
7,
"receptive field of two CNN(kernel=3) downsampling module.");
DEFINE_int32(downsampling_rate,
4,
"two CNN(kernel=3) module downsampling rate.");
DEFINE_double(acoustic_scale, 1.0, "acoustic scale");
using kaldi::BaseFloat;
using kaldi::Matrix;
using std::vector;
......@@ -58,13 +50,12 @@ int main(int argc, char* argv[]) {
kaldi::BaseFloatMatrixWriter nnet_out_writer(FLAGS_nnet_prob_wspecifier);
kaldi::BaseFloatMatrixWriter nnet_encoder_outs_writer(FLAGS_nnet_encoder_outs_wspecifier);
ppspeech::ModelOptions model_opts;
model_opts.model_path = FLAGS_model_path;
ppspeech::ModelOptions model_opts = ppspeech::ModelOptions::InitFromFlags();
int32 chunk_size =
(FLAGS_nnet_decoder_chunk - 1) * FLAGS_downsampling_rate +
(FLAGS_nnet_decoder_chunk - 1) * FLAGS_subsampling_rate +
FLAGS_receptive_field_length;
int32 chunk_stride = FLAGS_downsampling_rate * FLAGS_nnet_decoder_chunk;
int32 chunk_stride = FLAGS_subsampling_rate * FLAGS_nnet_decoder_chunk;
int32 receptive_field_length = FLAGS_receptive_field_length;
LOG(INFO) << "chunk size (frame): " << chunk_size;
LOG(INFO) << "chunk stride (frame): " << chunk_stride;
......
......@@ -20,27 +20,9 @@ DEFINE_int32(port, 8082, "websocket listening port");
ppspeech::RecognizerResource InitRecognizerResoure() {
ppspeech::RecognizerResource resource;
resource.acoustic_scale = FLAGS_acoustic_scale;
resource.feature_pipeline_opts = ppspeech::InitFeaturePipelineOptions();
ppspeech::ModelOptions model_opts;
model_opts.model_path = FLAGS_model_path;
model_opts.param_path = FLAGS_param_path;
model_opts.cache_names = FLAGS_model_cache_names;
model_opts.cache_shape = FLAGS_model_cache_shapes;
model_opts.input_names = FLAGS_model_input_names;
model_opts.output_names = FLAGS_model_output_names;
model_opts.subsample_rate = FLAGS_downsampling_rate;
resource.model_opts = model_opts;
ppspeech::TLGDecoderOptions decoder_opts;
decoder_opts.word_symbol_table = FLAGS_word_symbol_table;
decoder_opts.fst_path = FLAGS_graph_path;
decoder_opts.opts.max_active = FLAGS_max_active;
decoder_opts.opts.beam = FLAGS_beam;
decoder_opts.opts.lattice_beam = FLAGS_lattice_beam;
resource.tlg_opts = decoder_opts;
resource.feature_pipeline_opts = ppspeech::FeaturePipelineOptions::InitFromFlags();
resource.model_opts = ppspeech::ModelOptions::InitFromFlags();
resource.tlg_opts = ppspeech::TLGDecoderOptions::InitFromFlags();
return resource;
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册