From 99b3632d4d904e348e4cf37397538bb0a11bd2a8 Mon Sep 17 00:00:00 2001 From: Hui Zhang Date: Tue, 18 Oct 2022 03:41:09 +0000 Subject: [PATCH] separate recognizer; NnetBase as base class --- speechx/speechx/CMakeLists.txt | 6 +++ speechx/speechx/decoder/CMakeLists.txt | 14 ++---- speechx/speechx/nnet/decodable.cc | 2 +- speechx/speechx/nnet/decodable.h | 6 +-- speechx/speechx/nnet/ds2_nnet.h | 2 +- speechx/speechx/nnet/nnet_itf.h | 10 +++-- speechx/speechx/nnet/u2_nnet.cc | 2 +- speechx/speechx/nnet/u2_nnet.h | 6 +-- .../speechx/protocol/websocket/CMakeLists.txt | 2 +- .../protocol/websocket/websocket_server.h | 2 +- speechx/speechx/recognizer/CMakeLists.txt | 45 +++++++++++++++++++ .../{decoder => recognizer}/recognizer.cc | 2 +- .../{decoder => recognizer}/recognizer.h | 0 .../recognizer_main.cc | 13 +----- .../{decoder => recognizer}/u2_recognizer.cc | 4 +- .../{decoder => recognizer}/u2_recognizer.h | 0 .../u2_recognizer_main.cc | 2 +- 17 files changed, 78 insertions(+), 40 deletions(-) create mode 100644 speechx/speechx/recognizer/CMakeLists.txt rename speechx/speechx/{decoder => recognizer}/recognizer.cc (97%) rename speechx/speechx/{decoder => recognizer}/recognizer.h (100%) rename speechx/speechx/{decoder => recognizer}/recognizer_main.cc (88%) rename speechx/speechx/{decoder => recognizer}/u2_recognizer.cc (98%) rename speechx/speechx/{decoder => recognizer}/u2_recognizer.h (100%) rename speechx/speechx/{decoder => recognizer}/u2_recognizer_main.cc (99%) diff --git a/speechx/speechx/CMakeLists.txt b/speechx/speechx/CMakeLists.txt index c8e21d48..60c18347 100644 --- a/speechx/speechx/CMakeLists.txt +++ b/speechx/speechx/CMakeLists.txt @@ -32,6 +32,12 @@ ${CMAKE_CURRENT_SOURCE_DIR}/decoder ) add_subdirectory(decoder) +include_directories( +${CMAKE_CURRENT_SOURCE_DIR} +${CMAKE_CURRENT_SOURCE_DIR}/recognizer +) +add_subdirectory(recognizer) + include_directories( ${CMAKE_CURRENT_SOURCE_DIR} ${CMAKE_CURRENT_SOURCE_DIR}/protocol diff --git 
a/speechx/speechx/decoder/CMakeLists.txt b/speechx/speechx/decoder/CMakeLists.txt index d06c3529..5bec24a6 100644 --- a/speechx/speechx/decoder/CMakeLists.txt +++ b/speechx/speechx/decoder/CMakeLists.txt @@ -1,28 +1,24 @@ -project(decoder) - include_directories(${CMAKE_CURRENT_SOURCE_DIR/ctc_decoders}) -set(decoder_src ) +set(srcs) if (USING_DS2) -list(APPEND decoder_src +list(APPEND srcs ctc_decoders/decoder_utils.cpp ctc_decoders/path_trie.cpp ctc_decoders/scorer.cpp ctc_beam_search_decoder.cc ctc_tlg_decoder.cc -recognizer.cc ) endif() if (USING_U2) - list(APPEND decoder_src + list(APPEND srcs ctc_prefix_beam_search_decoder.cc - u2_recognizer.cc ) endif() -add_library(decoder STATIC ${decoder_src}) +add_library(decoder STATIC ${srcs}) target_link_libraries(decoder PUBLIC kenlm utils fst frontend nnet kaldi-decoder absl::strings) # test @@ -30,7 +26,6 @@ if (USING_DS2) set(BINS ctc_beam_search_decoder_main nnet_logprob_decoder_main - recognizer_main ctc_tlg_decoder_main ) @@ -45,7 +40,6 @@ endif() if (USING_U2) set(TEST_BINS ctc_prefix_beam_search_decoder_main - u2_recognizer_main ) foreach(bin_name IN LISTS TEST_BINS) diff --git a/speechx/speechx/nnet/decodable.cc b/speechx/speechx/nnet/decodable.cc index dc971e0f..9bad8ed4 100644 --- a/speechx/speechx/nnet/decodable.cc +++ b/speechx/speechx/nnet/decodable.cc @@ -21,7 +21,7 @@ using kaldi::Matrix; using kaldi::Vector; using std::vector; -Decodable::Decodable(const std::shared_ptr& nnet, +Decodable::Decodable(const std::shared_ptr& nnet, const std::shared_ptr& frontend, kaldi::BaseFloat acoustic_scale) : frontend_(frontend), diff --git a/speechx/speechx/nnet/decodable.h b/speechx/speechx/nnet/decodable.h index 70a16e2c..dd7b329e 100644 --- a/speechx/speechx/nnet/decodable.h +++ b/speechx/speechx/nnet/decodable.h @@ -24,7 +24,7 @@ struct DecodableOpts; class Decodable : public kaldi::DecodableInterface { public: - explicit Decodable(const std::shared_ptr& nnet, + explicit Decodable(const std::shared_ptr& nnet, 
const std::shared_ptr& frontend, kaldi::BaseFloat acoustic_scale = 1.0); @@ -63,14 +63,14 @@ class Decodable : public kaldi::DecodableInterface { int32 TokenId2NnetId(int32 token_id); - std::shared_ptr Nnet() { return nnet_; } + std::shared_ptr Nnet() { return nnet_; } // for offline test void Acceptlikelihood(const kaldi::Matrix& likelihood); private: std::shared_ptr frontend_; - std::shared_ptr nnet_; + std::shared_ptr nnet_; // nnet outputs' cache kaldi::Matrix nnet_out_cache_; diff --git a/speechx/speechx/nnet/ds2_nnet.h b/speechx/speechx/nnet/ds2_nnet.h index 4aeec32f..d1e3ac8c 100644 --- a/speechx/speechx/nnet/ds2_nnet.h +++ b/speechx/speechx/nnet/ds2_nnet.h @@ -48,7 +48,7 @@ class Tensor { std::vector _data; }; -class PaddleNnet : public NnetInterface { +class PaddleNnet : public NnetBase { public: PaddleNnet(const ModelOptions& opts); diff --git a/speechx/speechx/nnet/nnet_itf.h b/speechx/speechx/nnet/nnet_itf.h index cc737ce0..a504cce5 100644 --- a/speechx/speechx/nnet/nnet_itf.h +++ b/speechx/speechx/nnet/nnet_itf.h @@ -11,8 +11,6 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. - - #pragma once #include "base/basic_types.h" @@ -105,11 +103,15 @@ class NnetInterface { // true, nnet output is logprob; otherwise is prob, virtual bool IsLogProb() = 0; - int SubsamplingRate() const { return subsampling_rate_; } - // using to get encoder outs. e.g. seq2seq with Attention model. 
virtual void EncoderOuts( std::vector>* encoder_out) const = 0; +}; + + +class NnetBase : public NnetInterface { + public: + int SubsamplingRate() const { return subsampling_rate_; } protected: int subsampling_rate_{1}; diff --git a/speechx/speechx/nnet/u2_nnet.cc b/speechx/speechx/nnet/u2_nnet.cc index 4bafdf83..c92c96aa 100644 --- a/speechx/speechx/nnet/u2_nnet.cc +++ b/speechx/speechx/nnet/u2_nnet.cc @@ -193,7 +193,7 @@ U2Nnet::U2Nnet(const U2Nnet& other) { // ignore inner states } -std::shared_ptr U2Nnet::Copy() const { +std::shared_ptr U2Nnet::Copy() const { auto asr_model = std::make_shared(*this); // reset inner state for new decoding asr_model->Reset(); diff --git a/speechx/speechx/nnet/u2_nnet.h b/speechx/speechx/nnet/u2_nnet.h index 3435bca8..a37a88f2 100644 --- a/speechx/speechx/nnet/u2_nnet.h +++ b/speechx/speechx/nnet/u2_nnet.h @@ -24,7 +24,7 @@ namespace ppspeech { -class U2NnetBase : public NnetInterface { +class U2NnetBase : public NnetBase { public: virtual int context() const { return right_context_ + 1; } virtual int right_context() const { return right_context_; } @@ -41,7 +41,7 @@ class U2NnetBase : public NnetInterface { // start: false, it is the start chunk of one sentence, else true virtual int num_frames_for_chunk(bool start) const; - virtual std::shared_ptr Copy() const = 0; + virtual std::shared_ptr Copy() const = 0; virtual void ForwardEncoderChunk( const std::vector& chunk_feats, @@ -99,7 +99,7 @@ class U2Nnet : public U2NnetBase { std::shared_ptr model() const { return model_; } - std::shared_ptr Copy() const override; + std::shared_ptr Copy() const override; void ForwardEncoderChunkImpl( const std::vector& chunk_feats, diff --git a/speechx/speechx/protocol/websocket/CMakeLists.txt b/speechx/speechx/protocol/websocket/CMakeLists.txt index a171d84d..cafbbec7 100644 --- a/speechx/speechx/protocol/websocket/CMakeLists.txt +++ b/speechx/speechx/protocol/websocket/CMakeLists.txt @@ -2,7 +2,7 @@ add_library(websocket STATIC 
websocket_server.cc websocket_client.cc ) -target_link_libraries(websocket PUBLIC frontend decoder nnet) +target_link_libraries(websocket PUBLIC frontend nnet decoder recognizer) add_executable(websocket_server_main ${CMAKE_CURRENT_SOURCE_DIR}/websocket_server_main.cc) target_include_directories(websocket_server_main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi) diff --git a/speechx/speechx/protocol/websocket/websocket_server.h b/speechx/speechx/protocol/websocket/websocket_server.h index 8f3360e4..9b05f868 100644 --- a/speechx/speechx/protocol/websocket/websocket_server.h +++ b/speechx/speechx/protocol/websocket/websocket_server.h @@ -19,7 +19,7 @@ #include "boost/asio/ip/tcp.hpp" #include "boost/beast/core.hpp" #include "boost/beast/websocket.hpp" -#include "decoder/recognizer.h" +#include "recognizer/recognizer.h" #include "frontend/audio/feature_pipeline.h" namespace beast = boost::beast; // from diff --git a/speechx/speechx/recognizer/CMakeLists.txt b/speechx/speechx/recognizer/CMakeLists.txt new file mode 100644 index 00000000..05078873 --- /dev/null +++ b/speechx/speechx/recognizer/CMakeLists.txt @@ -0,0 +1,45 @@ +set(srcs) + +if (USING_DS2) +list(APPEND srcs +recognizer.cc +) +endif() + +if (USING_U2) + list(APPEND srcs + u2_recognizer.cc + ) +endif() + +add_library(recognizer STATIC ${srcs}) +target_link_libraries(recognizer PUBLIC decoder) + +# test +if (USING_DS2) + set(BINS recognizer_main) + + foreach(bin_name IN LISTS BINS) + add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc) + target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi) + target_link_libraries(${bin_name} PUBLIC recognizer nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util ${DEPS}) + endforeach() +endif() + + +if (USING_U2) + set(TEST_BINS + u2_recognizer_main + ) + + foreach(bin_name IN LISTS TEST_BINS) + add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc) + target_include_directories(${bin_name} 
PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi) + target_link_libraries(${bin_name} recognizer nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util) + target_compile_options(${bin_name} PRIVATE ${PADDLE_COMPILE_FLAGS}) + target_include_directories(${bin_name} PRIVATE ${pybind11_INCLUDE_DIRS} ${PROJECT_SOURCE_DIR}) + target_link_libraries(${bin_name} ${PYTHON_LIBRARIES} ${PADDLE_LINK_FLAGS}) + endforeach() + +endif() + diff --git a/speechx/speechx/decoder/recognizer.cc b/speechx/speechx/recognizer/recognizer.cc similarity index 97% rename from speechx/speechx/decoder/recognizer.cc rename to speechx/speechx/recognizer/recognizer.cc index 870aa40a..c6631813 100644 --- a/speechx/speechx/decoder/recognizer.cc +++ b/speechx/speechx/recognizer/recognizer.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "decoder/recognizer.h" +#include "recognizer/recognizer.h" namespace ppspeech { diff --git a/speechx/speechx/decoder/recognizer.h b/speechx/speechx/recognizer/recognizer.h similarity index 100% rename from speechx/speechx/decoder/recognizer.h rename to speechx/speechx/recognizer/recognizer.h diff --git a/speechx/speechx/decoder/recognizer_main.cc b/speechx/speechx/recognizer/recognizer_main.cc similarity index 88% rename from speechx/speechx/decoder/recognizer_main.cc rename to speechx/speechx/recognizer/recognizer_main.cc index 8e83b188..7c30fe6a 100644 --- a/speechx/speechx/decoder/recognizer_main.cc +++ b/speechx/speechx/recognizer/recognizer_main.cc @@ -13,7 +13,7 @@ // limitations under the License. 
#include "decoder/param.h" -#include "decoder/recognizer.h" +#include "recognizer/recognizer.h" #include "kaldi/feat/wave-reader.h" #include "kaldi/util/table-types.h" @@ -22,15 +22,6 @@ DEFINE_string(result_wspecifier, "", "test result wspecifier"); DEFINE_double(streaming_chunk, 0.36, "streaming feature chunk size"); DEFINE_int32(sample_rate, 16000, "sample rate"); -ppspeech::RecognizerResource InitRecognizerResoure() { - ppspeech::RecognizerResource resource; - resource.acoustic_scale = FLAGS_acoustic_scale; - resource.feature_pipeline_opts = - ppspeech::FeaturePipelineOptions::InitFromFlags(); - resource.model_opts = ppspeech::ModelOptions::InitFromFlags(); - resource.tlg_opts = ppspeech::TLGDecoderOptions::InitFromFlags(); - return resource; -} int main(int argc, char* argv[]) { gflags::SetUsageMessage("Usage:"); @@ -39,7 +30,7 @@ int main(int argc, char* argv[]) { google::InstallFailureSignalHandler(); FLAGS_logtostderr = 1; - ppspeech::RecognizerResource resource = InitRecognizerResoure(); + ppspeech::RecognizerResource resource = ppspeech::RecognizerResource::InitFromFlags(); ppspeech::Recognizer recognizer(resource); kaldi::SequentialTableReader wav_reader( diff --git a/speechx/speechx/decoder/u2_recognizer.cc b/speechx/speechx/recognizer/u2_recognizer.cc similarity index 98% rename from speechx/speechx/decoder/u2_recognizer.cc rename to speechx/speechx/recognizer/u2_recognizer.cc index 04712e7b..75834aa5 100644 --- a/speechx/speechx/decoder/u2_recognizer.cc +++ b/speechx/speechx/recognizer/u2_recognizer.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#include "decoder/u2_recognizer.h" +#include "recognizer/u2_recognizer.h" #include "nnet/u2_nnet.h" @@ -30,7 +30,7 @@ U2Recognizer::U2Recognizer(const U2RecognizerResource& resource) const FeaturePipelineOptions& feature_opts = resource.feature_pipeline_opts; feature_pipeline_.reset(new FeaturePipeline(feature_opts)); - std::shared_ptr nnet(new U2Nnet(resource.model_opts)); + std::shared_ptr nnet(new U2Nnet(resource.model_opts)); BaseFloat am_scale = resource.acoustic_scale; decodable_.reset(new Decodable(nnet, feature_pipeline_, am_scale)); diff --git a/speechx/speechx/decoder/u2_recognizer.h b/speechx/speechx/recognizer/u2_recognizer.h similarity index 100% rename from speechx/speechx/decoder/u2_recognizer.h rename to speechx/speechx/recognizer/u2_recognizer.h diff --git a/speechx/speechx/decoder/u2_recognizer_main.cc b/speechx/speechx/recognizer/u2_recognizer_main.cc similarity index 99% rename from speechx/speechx/decoder/u2_recognizer_main.cc rename to speechx/speechx/recognizer/u2_recognizer_main.cc index 9eb0441b..ff848f58 100644 --- a/speechx/speechx/decoder/u2_recognizer_main.cc +++ b/speechx/speechx/recognizer/u2_recognizer_main.cc @@ -13,7 +13,7 @@ // limitations under the License. #include "decoder/param.h" -#include "decoder/u2_recognizer.h" +#include "recognizer/u2_recognizer.h" #include "kaldi/feat/wave-reader.h" #include "kaldi/util/table-types.h" -- GitLab