From e1fc57deb1454c926c8925fba040ada210183168 Mon Sep 17 00:00:00 2001 From: Hui Zhang Date: Sun, 9 Oct 2022 09:29:37 +0000 Subject: [PATCH] add math and rename ds2 nnet --- speechx/speechx/base/common.h | 5 +++ .../ctc_prefix_beam_search_decoder_main.cc | 2 +- speechx/speechx/decoder/recognizer.h | 2 +- speechx/speechx/decoder/tlg_decoder_main.cc | 2 +- speechx/speechx/nnet/CMakeLists.txt | 8 ++-- .../nnet/{paddle_nnet.cc => ds2_nnet.cc} | 2 +- .../nnet/{paddle_nnet.h => ds2_nnet.h} | 0 ...{nnet_forward_main.cc => ds2_nnet_main.cc} | 2 +- speechx/speechx/utils/math.cc | 37 ++++++++++++------- speechx/speechx/utils/math.h | 9 +++-- 10 files changed, 42 insertions(+), 27 deletions(-) rename speechx/speechx/nnet/{paddle_nnet.cc => ds2_nnet.cc} (99%) rename speechx/speechx/nnet/{paddle_nnet.h => ds2_nnet.h} (100%) rename speechx/speechx/nnet/{nnet_forward_main.cc => ds2_nnet_main.cc} (99%) diff --git a/speechx/speechx/base/common.h b/speechx/speechx/base/common.h index 778c06d7..dfb14885 100644 --- a/speechx/speechx/base/common.h +++ b/speechx/speechx/base/common.h @@ -14,19 +14,24 @@ #pragma once +#include #include +#include #include #include +#include #include #include #include #include #include +#include #include #include #include #include #include +#include #include #include #include diff --git a/speechx/speechx/decoder/ctc_prefix_beam_search_decoder_main.cc b/speechx/speechx/decoder/ctc_prefix_beam_search_decoder_main.cc index 7cfee06c..e4e5c2af 100644 --- a/speechx/speechx/decoder/ctc_prefix_beam_search_decoder_main.cc +++ b/speechx/speechx/decoder/ctc_prefix_beam_search_decoder_main.cc @@ -20,7 +20,7 @@ #include "frontend/audio/data_cache.h" #include "kaldi/util/table-types.h" #include "nnet/decodable.h" -#include "nnet/paddle_nnet.h" +#include "nnet/ds2_nnet.h" DEFINE_string(feature_rspecifier, "", "test feature rspecifier"); DEFINE_string(result_wspecifier, "", "test result wspecifier"); diff --git a/speechx/speechx/decoder/recognizer.h b/speechx/speechx/decoder/recognizer.h index 35e1e167..e47ca433 100644 --- a/speechx/speechx/decoder/recognizer.h +++ b/speechx/speechx/decoder/recognizer.h @@ -20,7 +20,7 @@ #include "decoder/ctc_tlg_decoder.h" #include "frontend/audio/feature_pipeline.h" #include "nnet/decodable.h" -#include "nnet/paddle_nnet.h" +#include "nnet/ds2_nnet.h" namespace ppspeech { diff --git a/speechx/speechx/decoder/tlg_decoder_main.cc b/speechx/speechx/decoder/tlg_decoder_main.cc index b175ed13..93f84da3 100644 --- a/speechx/speechx/decoder/tlg_decoder_main.cc +++ b/speechx/speechx/decoder/tlg_decoder_main.cc @@ -20,7 +20,7 @@ #include "frontend/audio/data_cache.h" #include "kaldi/util/table-types.h" #include "nnet/decodable.h" -#include "nnet/paddle_nnet.h" +#include "nnet/ds2_nnet.h" DEFINE_string(feature_rspecifier, "", "test feature rspecifier"); DEFINE_string(result_wspecifier, "", "test result wspecifier"); diff --git a/speechx/speechx/nnet/CMakeLists.txt b/speechx/speechx/nnet/CMakeLists.txt index c325ce75..565bba3e 100644 --- a/speechx/speechx/nnet/CMakeLists.txt +++ b/speechx/speechx/nnet/CMakeLists.txt @@ -2,13 +2,11 @@ project(nnet) add_library(nnet STATIC decodable.cc - paddle_nnet.cc + ds2_nnet.cc ) target_link_libraries(nnet absl::strings) -set(bin_name nnet_forward_main) +set(bin_name ds2_nnet_main) add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc) target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi) -target_link_libraries(${bin_name} utils kaldi-util kaldi-matrix gflags glog nnet ${DEPS}) - - +target_link_libraries(${bin_name} utils kaldi-util kaldi-matrix gflags glog nnet ${DEPS}) \ No newline at end of file diff --git a/speechx/speechx/nnet/paddle_nnet.cc b/speechx/speechx/nnet/ds2_nnet.cc similarity index 99% rename from speechx/speechx/nnet/paddle_nnet.cc rename to speechx/speechx/nnet/ds2_nnet.cc index 881a82f5..a89c0f20 100644 --- a/speechx/speechx/nnet/paddle_nnet.cc +++ b/speechx/speechx/nnet/ds2_nnet.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "nnet/paddle_nnet.h" +#include "nnet/ds2_nnet.h" #include "absl/strings/str_split.h" namespace ppspeech { diff --git a/speechx/speechx/nnet/paddle_nnet.h b/speechx/speechx/nnet/ds2_nnet.h similarity index 100% rename from speechx/speechx/nnet/paddle_nnet.h rename to speechx/speechx/nnet/ds2_nnet.h diff --git a/speechx/speechx/nnet/nnet_forward_main.cc b/speechx/speechx/nnet/ds2_nnet_main.cc similarity index 99% rename from speechx/speechx/nnet/nnet_forward_main.cc rename to speechx/speechx/nnet/ds2_nnet_main.cc index 0d4ea8ff..e2904208 100644 --- a/speechx/speechx/nnet/nnet_forward_main.cc +++ b/speechx/speechx/nnet/ds2_nnet_main.cc @@ -12,13 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include "nnet/ds2_nnet.h" #include "base/flags.h" #include "base/log.h" #include "frontend/audio/assembler.h" #include "frontend/audio/data_cache.h" #include "kaldi/util/table-types.h" #include "nnet/decodable.h" -#include "nnet/paddle_nnet.h" DEFINE_string(feature_rspecifier, "", "test feature rspecifier"); DEFINE_string(nnet_prob_wspecifier, "", "nnet porb wspecifier"); diff --git a/speechx/speechx/utils/math.cc b/speechx/speechx/utils/math.cc index fe5c7118..7c319295 100644 --- a/speechx/speechx/utils/math.cc +++ b/speechx/speechx/utils/math.cc @@ -1,4 +1,5 @@ +// Copyright (c) 2021 Mobvoi Inc (Zhendong Peng) // Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); @@ -17,10 +18,10 @@ #include "base/common.h" -#include #include -#include +#include #include +#include namespace ppspeech { @@ -36,28 +37,36 @@ float LogSumExp(float x, float y) { // greater compare for smallest priority_queue template struct ValGreaterComp { - bool operator()(const std::pair& lhs, const std::pair& rhs) const { - return lhs.first > rhs.first || (lhs.first == rhs.first && lhs.second < rhs.second); + bool operator()(const std::pair& lhs, + const std::pair& rhs) const { + return lhs.first > rhs.first || + (lhs.first == rhs.first && lhs.second < rhs.second); } } -template -void TopK(const std::vector& data, int32_t k, std::vector* values, std::vector* indices) { - int n = data.size(); - int min_k_n = std::min(k, n); +template +void TopK(const std::vector& data, + int32_t k, + std::vector* values, + std::vector* indices) { + int n = data.size(); + int min_k_n = std::min(k, n); // smallest heap, (val, idx) - std::vector> smallest_heap; - for (int i = 0; i < min_k_n; i++){ + std::vector> smallest_heap; + for (int i = 0; i < min_k_n; i++) { smallest_heap.emplace_back(data[i], i); } // smallest priority_queue - std::priority_queue, std::vector>, ValGreaterComp> pq(ValGreaterComp(), std::move(smallest_heap)); + std::priority_queue, + std::vector>, + ValGreaterComp> + pq(ValGreaterComp(), std::move(smallest_heap)); // top k - for (int i = k ; i < n; i++){ - if (pq.top().first < data[i]){ + for (int i = k; i < n; i++) { + if (pq.top().first < data[i]) { pq.pop(); pq.emplace_back(data[i], i); } @@ -68,7 +77,7 @@ void TopK(const std::vector& data, int32_t k, std::vector* values, std::ve // from largest to samllest int cur = values->size() - 1; - while(!pq.empty()){ + while (!pq.empty()) { const auto& item = pq.top(); pq.pop(); diff --git a/speechx/speechx/utils/math.h b/speechx/speechx/utils/math.h index 452bf089..7c863b00 100644 --- a/speechx/speechx/utils/math.h +++ b/speechx/speechx/utils/math.h @@ -14,15 +14,18 @@ #pragma once -#include #include +#include namespace ppspeech { // Sum in log scale float LogSumExp(float x, float y); -template -void TopK(const std::vector& data, int32_t k, std::vector* values, std::vector* indices); +template +void TopK(const std::vector& data, + int32_t k, + std::vector* values, + std::vector* indices); } // namespace ppspeech \ No newline at end of file -- GitLab