diff --git a/speechx/speechx/CMakeLists.txt b/speechx/speechx/CMakeLists.txt index c8e21d4867d615b6005be11e4175e6f6e24aaed1..60c183472baa6f507509d0490d5db768e81901b1 100644 --- a/speechx/speechx/CMakeLists.txt +++ b/speechx/speechx/CMakeLists.txt @@ -32,6 +32,12 @@ ${CMAKE_CURRENT_SOURCE_DIR}/decoder ) add_subdirectory(decoder) +include_directories( +${CMAKE_CURRENT_SOURCE_DIR} +${CMAKE_CURRENT_SOURCE_DIR}/recognizer +) +add_subdirectory(recognizer) + include_directories( ${CMAKE_CURRENT_SOURCE_DIR} ${CMAKE_CURRENT_SOURCE_DIR}/protocol diff --git a/speechx/speechx/decoder/CMakeLists.txt b/speechx/speechx/decoder/CMakeLists.txt index d06c3529ba2a33f4663eb846923b57da28518f1a..5bec24a6138839182b0573220dbf34033774a6f3 100644 --- a/speechx/speechx/decoder/CMakeLists.txt +++ b/speechx/speechx/decoder/CMakeLists.txt @@ -1,28 +1,24 @@ -project(decoder) - include_directories(${CMAKE_CURRENT_SOURCE_DIR/ctc_decoders}) -set(decoder_src ) +set(srcs) if (USING_DS2) -list(APPEND decoder_src +list(APPEND srcs ctc_decoders/decoder_utils.cpp ctc_decoders/path_trie.cpp ctc_decoders/scorer.cpp ctc_beam_search_decoder.cc ctc_tlg_decoder.cc -recognizer.cc ) endif() if (USING_U2) - list(APPEND decoder_src + list(APPEND srcs ctc_prefix_beam_search_decoder.cc - u2_recognizer.cc ) endif() -add_library(decoder STATIC ${decoder_src}) +add_library(decoder STATIC ${srcs}) target_link_libraries(decoder PUBLIC kenlm utils fst frontend nnet kaldi-decoder absl::strings) # test @@ -30,7 +26,6 @@ if (USING_DS2) set(BINS ctc_beam_search_decoder_main nnet_logprob_decoder_main - recognizer_main ctc_tlg_decoder_main ) @@ -45,7 +40,6 @@ endif() if (USING_U2) set(TEST_BINS ctc_prefix_beam_search_decoder_main - u2_recognizer_main ) foreach(bin_name IN LISTS TEST_BINS) diff --git a/speechx/speechx/nnet/decodable.cc b/speechx/speechx/nnet/decodable.cc index dc971e0f6a718fc2184b2ef32a0363d3bbad025a..9bad8ed45d5e9c079b4b6914ee04846512af9e27 100644 --- a/speechx/speechx/nnet/decodable.cc +++ b/speechx/speechx/nnet/decodable.cc @@ -21,7 +21,7 @@ using kaldi::Matrix; using kaldi::Vector; using std::vector; -Decodable::Decodable(const std::shared_ptr& nnet, +Decodable::Decodable(const std::shared_ptr& nnet, const std::shared_ptr& frontend, kaldi::BaseFloat acoustic_scale) : frontend_(frontend), diff --git a/speechx/speechx/nnet/decodable.h b/speechx/speechx/nnet/decodable.h index 70a16e2c687574eca65ba5dd70b23fa3fcde9943..dd7b329e581a6d5d3e91534fa5356af7cfa3f169 100644 --- a/speechx/speechx/nnet/decodable.h +++ b/speechx/speechx/nnet/decodable.h @@ -24,7 +24,7 @@ struct DecodableOpts; class Decodable : public kaldi::DecodableInterface { public: - explicit Decodable(const std::shared_ptr& nnet, + explicit Decodable(const std::shared_ptr& nnet, const std::shared_ptr& frontend, kaldi::BaseFloat acoustic_scale = 1.0); @@ -63,14 +63,14 @@ class Decodable : public kaldi::DecodableInterface { int32 TokenId2NnetId(int32 token_id); - std::shared_ptr Nnet() { return nnet_; } + std::shared_ptr Nnet() { return nnet_; } // for offline test void Acceptlikelihood(const kaldi::Matrix& likelihood); private: std::shared_ptr frontend_; - std::shared_ptr nnet_; + std::shared_ptr nnet_; // nnet outputs' cache kaldi::Matrix nnet_out_cache_; diff --git a/speechx/speechx/nnet/ds2_nnet.h b/speechx/speechx/nnet/ds2_nnet.h index 4aeec32f388164460f47cf6bb733bab7851abae6..d1e3ac8c9286fec15c278944e1762053f4febfd4 100644 --- a/speechx/speechx/nnet/ds2_nnet.h +++ b/speechx/speechx/nnet/ds2_nnet.h @@ -48,7 +48,7 @@ class Tensor { std::vector _data; }; -class PaddleNnet : public NnetInterface { +class PaddleNnet : public NnetBase { public: PaddleNnet(const ModelOptions& opts); diff --git a/speechx/speechx/nnet/nnet_itf.h b/speechx/speechx/nnet/nnet_itf.h index cc737ce054b084259bee23b3f9e8123f34dac7b7..a504cce51704006377da5bf56e4d020f7b14c96c 100644 --- a/speechx/speechx/nnet/nnet_itf.h +++ b/speechx/speechx/nnet/nnet_itf.h @@ -11,8 +11,6 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. - - #pragma once #include "base/basic_types.h" @@ -105,11 +103,15 @@ class NnetInterface { // true, nnet output is logprob; otherwise is prob, virtual bool IsLogProb() = 0; - int SubsamplingRate() const { return subsampling_rate_; } - // using to get encoder outs. e.g. seq2seq with Attention model. virtual void EncoderOuts( std::vector>* encoder_out) const = 0; +}; + + +class NnetBase : public NnetInterface { + public: + int SubsamplingRate() const { return subsampling_rate_; } protected: int subsampling_rate_{1}; diff --git a/speechx/speechx/nnet/u2_nnet.cc b/speechx/speechx/nnet/u2_nnet.cc index 4bafdf83176eeaee9ce71d0ed4547f77e8284d8b..c92c96aaae86af65cd000aae73bea297584a02fd 100644 --- a/speechx/speechx/nnet/u2_nnet.cc +++ b/speechx/speechx/nnet/u2_nnet.cc @@ -193,7 +193,7 @@ U2Nnet::U2Nnet(const U2Nnet& other) { // ignore inner states } -std::shared_ptr U2Nnet::Copy() const { +std::shared_ptr U2Nnet::Copy() const { auto asr_model = std::make_shared(*this); // reset inner state for new decoding asr_model->Reset(); diff --git a/speechx/speechx/nnet/u2_nnet.h b/speechx/speechx/nnet/u2_nnet.h index 3435bca8ba0dfbdeafe915ac8fc4fda18924a83b..a37a88f2fe6fefb391ce262f7acfd88a8b441120 100644 --- a/speechx/speechx/nnet/u2_nnet.h +++ b/speechx/speechx/nnet/u2_nnet.h @@ -24,7 +24,7 @@ namespace ppspeech { -class U2NnetBase : public NnetInterface { +class U2NnetBase : public NnetBase { public: virtual int context() const { return right_context_ + 1; } virtual int right_context() const { return right_context_; } @@ -41,7 +41,7 @@ class U2NnetBase : public NnetInterface { // start: false, it is the start chunk of one sentence, else true virtual int num_frames_for_chunk(bool start) const; - virtual std::shared_ptr Copy() const = 0; + virtual std::shared_ptr Copy() const = 0; virtual void ForwardEncoderChunk( const std::vector& chunk_feats, @@ -99,7 +99,7 @@ class U2Nnet : public U2NnetBase { std::shared_ptr model() const { return model_; } - std::shared_ptr Copy() const override; + std::shared_ptr Copy() const override; void ForwardEncoderChunkImpl( const std::vector& chunk_feats, diff --git a/speechx/speechx/protocol/websocket/CMakeLists.txt b/speechx/speechx/protocol/websocket/CMakeLists.txt index a171d84d0755c7b39b7d5532a803cd7c1c42f630..cafbbec73a5b616ce896a9d8c5a4b838e6b9477d 100644 --- a/speechx/speechx/protocol/websocket/CMakeLists.txt +++ b/speechx/speechx/protocol/websocket/CMakeLists.txt @@ -2,7 +2,7 @@ add_library(websocket STATIC websocket_server.cc websocket_client.cc ) -target_link_libraries(websocket PUBLIC frontend decoder nnet) +target_link_libraries(websocket PUBLIC frontend nnet decoder recognizer) add_executable(websocket_server_main ${CMAKE_CURRENT_SOURCE_DIR}/websocket_server_main.cc) target_include_directories(websocket_server_main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi) diff --git a/speechx/speechx/protocol/websocket/websocket_server.h b/speechx/speechx/protocol/websocket/websocket_server.h index 8f3360e40ba0b74cc8282f00afc49663cb19bab7..9b05f868e8cf0cc5447e29ea53afcf81fa282968 100644 --- a/speechx/speechx/protocol/websocket/websocket_server.h +++ b/speechx/speechx/protocol/websocket/websocket_server.h @@ -19,7 +19,7 @@ #include "boost/asio/ip/tcp.hpp" #include "boost/beast/core.hpp" #include "boost/beast/websocket.hpp" -#include "decoder/recognizer.h" +#include "recognizer/recognizer.h" #include "frontend/audio/feature_pipeline.h" namespace beast = boost::beast; // from diff --git a/speechx/speechx/recognizer/CMakeLists.txt b/speechx/speechx/recognizer/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..05078873952a33ced19c7e723ef8e2aa080d45f9 --- /dev/null +++ b/speechx/speechx/recognizer/CMakeLists.txt @@ -0,0 +1,45 @@ +set(srcs) + +if (USING_DS2) +list(APPEND srcs +recognizer.cc +) +endif() + +if (USING_U2) + list(APPEND srcs + u2_recognizer.cc + ) +endif() + +add_library(recognizer STATIC ${srcs}) +target_link_libraries(recognizer PUBLIC decoder) + +# test +if (USING_DS2) + set(BINS recognizer_main) + + foreach(bin_name IN LISTS BINS) + add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc) + target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi) + target_link_libraries(${bin_name} PUBLIC recognizer nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util ${DEPS}) + endforeach() +endif() + + +if (USING_U2) + set(TEST_BINS + u2_recognizer_main + ) + + foreach(bin_name IN LISTS TEST_BINS) + add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc) + target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi) + target_link_libraries(${bin_name} recognizer nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util) + target_compile_options(${bin_name} PRIVATE ${PADDLE_COMPILE_FLAGS}) + target_include_directories(${bin_name} PRIVATE ${pybind11_INCLUDE_DIRS} ${PROJECT_SOURCE_DIR}) + target_link_libraries(${bin_name} ${PYTHON_LIBRARIES} ${PADDLE_LINK_FLAGS}) + endforeach() + +endif() + diff --git a/speechx/speechx/decoder/recognizer.cc b/speechx/speechx/recognizer/recognizer.cc similarity index 97% rename from speechx/speechx/decoder/recognizer.cc rename to speechx/speechx/recognizer/recognizer.cc index 870aa40acf4fd295726a43e6c7e45b2b1934d38b..c663181319504e5b8258b4aa06f5994c8a63fd0a 100644 --- a/speechx/speechx/decoder/recognizer.cc +++ b/speechx/speechx/recognizer/recognizer.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "decoder/recognizer.h" +#include "recognizer/recognizer.h" namespace ppspeech { diff --git a/speechx/speechx/decoder/recognizer.h b/speechx/speechx/recognizer/recognizer.h similarity index 100% rename from speechx/speechx/decoder/recognizer.h rename to speechx/speechx/recognizer/recognizer.h diff --git a/speechx/speechx/decoder/recognizer_main.cc b/speechx/speechx/recognizer/recognizer_main.cc similarity index 88% rename from speechx/speechx/decoder/recognizer_main.cc rename to speechx/speechx/recognizer/recognizer_main.cc index 8e83b1888e0340f026715099b7fc7c65acdab343..7c30fe6adc900dd6381cb307b2379ba0cdc5f5dd 100644 --- a/speechx/speechx/decoder/recognizer_main.cc +++ b/speechx/speechx/recognizer/recognizer_main.cc @@ -13,7 +13,7 @@ // limitations under the License. #include "decoder/param.h" -#include "decoder/recognizer.h" +#include "recognizer/recognizer.h" #include "kaldi/feat/wave-reader.h" #include "kaldi/util/table-types.h" @@ -22,15 +22,6 @@ DEFINE_string(result_wspecifier, "", "test result wspecifier"); DEFINE_double(streaming_chunk, 0.36, "streaming feature chunk size"); DEFINE_int32(sample_rate, 16000, "sample rate"); -ppspeech::RecognizerResource InitRecognizerResoure() { - ppspeech::RecognizerResource resource; - resource.acoustic_scale = FLAGS_acoustic_scale; - resource.feature_pipeline_opts = - ppspeech::FeaturePipelineOptions::InitFromFlags(); - resource.model_opts = ppspeech::ModelOptions::InitFromFlags(); - resource.tlg_opts = ppspeech::TLGDecoderOptions::InitFromFlags(); - return resource; -} int main(int argc, char* argv[]) { gflags::SetUsageMessage("Usage:"); @@ -39,7 +30,7 @@ int main(int argc, char* argv[]) { google::InstallFailureSignalHandler(); FLAGS_logtostderr = 1; - ppspeech::RecognizerResource resource = InitRecognizerResoure(); + ppspeech::RecognizerResource resource = ppspeech::RecognizerResource::InitFromFlags(); ppspeech::Recognizer recognizer(resource); kaldi::SequentialTableReader wav_reader( diff --git a/speechx/speechx/decoder/u2_recognizer.cc b/speechx/speechx/recognizer/u2_recognizer.cc similarity index 98% rename from speechx/speechx/decoder/u2_recognizer.cc rename to speechx/speechx/recognizer/u2_recognizer.cc index 04712e7b5b2edb7060dbe6ae5282c22edb1d963e..75834aa5de5ca5ed5dac59c7dd9c29f8fe71e247 100644 --- a/speechx/speechx/decoder/u2_recognizer.cc +++ b/speechx/speechx/recognizer/u2_recognizer.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "decoder/u2_recognizer.h" +#include "recognizer/u2_recognizer.h" #include "nnet/u2_nnet.h" @@ -30,7 +30,7 @@ U2Recognizer::U2Recognizer(const U2RecognizerResource& resource) const FeaturePipelineOptions& feature_opts = resource.feature_pipeline_opts; feature_pipeline_.reset(new FeaturePipeline(feature_opts)); - std::shared_ptr nnet(new U2Nnet(resource.model_opts)); + std::shared_ptr nnet(new U2Nnet(resource.model_opts)); BaseFloat am_scale = resource.acoustic_scale; decodable_.reset(new Decodable(nnet, feature_pipeline_, am_scale)); diff --git a/speechx/speechx/decoder/u2_recognizer.h b/speechx/speechx/recognizer/u2_recognizer.h similarity index 100% rename from speechx/speechx/decoder/u2_recognizer.h rename to speechx/speechx/recognizer/u2_recognizer.h diff --git a/speechx/speechx/decoder/u2_recognizer_main.cc b/speechx/speechx/recognizer/u2_recognizer_main.cc similarity index 99% rename from speechx/speechx/decoder/u2_recognizer_main.cc rename to speechx/speechx/recognizer/u2_recognizer_main.cc index 9eb0441b103eeb321ddbf11964a68b2c28403f25..ff848f5899c9968d62a61ccbd497dc782578d721 100644 --- a/speechx/speechx/decoder/u2_recognizer_main.cc +++ b/speechx/speechx/recognizer/u2_recognizer_main.cc @@ -13,7 +13,7 @@ // limitations under the License. #include "decoder/param.h" -#include "decoder/u2_recognizer.h" +#include "recognizer/u2_recognizer.h" #include "kaldi/feat/wave-reader.h" #include "kaldi/util/table-types.h"