diff --git a/speechx/examples/ds2_ol/websocket/path.sh b/speechx/examples/ds2_ol/websocket/path.sh index 3ad03203107f4257ced125250c5ffe498e407902..d25e88a2764b1d1cef43675d86f8907248ad2018 100755 --- a/speechx/examples/ds2_ol/websocket/path.sh +++ b/speechx/examples/ds2_ol/websocket/path.sh @@ -10,5 +10,5 @@ TOOLS_BIN=$SPEECHX_TOOLS/valgrind/install/bin export LC_AL=C -SPEECHX_BIN=$SPEECHX_BUILD/websocket +SPEECHX_BIN=$SPEECHX_BUILD/protocol/websocket export PATH=$PATH:$SPEECHX_BIN:$TOOLS_BIN diff --git a/speechx/examples/ds2_ol/websocket/websocket_server.sh b/speechx/examples/ds2_ol/websocket/websocket_server.sh index fc57e326fb8cc2491d2443738fb8552e052fd033..f798dfd41ac8c6f83fcf1d7847237caac341a463 100755 --- a/speechx/examples/ds2_ol/websocket/websocket_server.sh +++ b/speechx/examples/ds2_ol/websocket/websocket_server.sh @@ -45,7 +45,7 @@ export GLOG_logtostderr=1 # 3. gen cmvn cmvn=$data/cmvn.ark -cmvn-json2kaldi --json_file=$ckpt_dir/data/mean_std.json --cmvn_write_path=$cmvn +cmvn_json2kaldi_main --json_file=$ckpt_dir/data/mean_std.json --cmvn_write_path=$cmvn wfst=$data/wfst/ diff --git a/speechx/speechx/CMakeLists.txt b/speechx/speechx/CMakeLists.txt index a9a8a398d2d6661bab5d3816a58fc775ece7d61e..c8e21d4867d615b6005be11e4175e6f6e24aaed1 100644 --- a/speechx/speechx/CMakeLists.txt +++ b/speechx/speechx/CMakeLists.txt @@ -34,9 +34,9 @@ add_subdirectory(decoder) include_directories( ${CMAKE_CURRENT_SOURCE_DIR} -${CMAKE_CURRENT_SOURCE_DIR}/websocket +${CMAKE_CURRENT_SOURCE_DIR}/protocol ) -add_subdirectory(websocket) +add_subdirectory(protocol) include_directories( ${CMAKE_CURRENT_SOURCE_DIR} diff --git a/speechx/speechx/decoder/ctc_tlg_decoder.cc b/speechx/speechx/decoder/ctc_tlg_decoder.cc index 02e6431658a453b70c23435fdb22b4ac9fc034d5..3f8bdd5a7e18cefa3a4b1956b85546eca5ec9f18 100644 --- a/speechx/speechx/decoder/ctc_tlg_decoder.cc +++ b/speechx/speechx/decoder/ctc_tlg_decoder.cc @@ -47,6 +47,26 @@ void TLGDecoder::Reset() { return; } +std::string TLGDecoder::GetPartialResult() { + if (frame_decoded_size_ == 0) { + // Assertion failed: (this->NumFramesDecoded() > 0 && "You cannot call + // BestPathEnd if no frames were decoded.") + return std::string(""); + } + kaldi::Lattice lat; + kaldi::LatticeWeight weight; + std::vector alignment; + std::vector words_id; + decoder_->GetBestPath(&lat, false); + fst::GetLinearSymbolSequence(lat, &alignment, &words_id, &weight); + std::string words; + for (int32 idx = 0; idx < words_id.size(); ++idx) { + std::string word = word_symbol_table_->Find(words_id[idx]); + words += word; + } + return words; +} + std::string TLGDecoder::GetFinalBestPath() { if (frame_decoded_size_ == 0) { // Assertion failed: (this->NumFramesDecoded() > 0 && "You cannot call diff --git a/speechx/speechx/decoder/ctc_tlg_decoder.h b/speechx/speechx/decoder/ctc_tlg_decoder.h index 361c44af5b75d704d91c0083399f543729840a7f..1ac46ac640140a3a38bc790530eb575688406683 100644 --- a/speechx/speechx/decoder/ctc_tlg_decoder.h +++ b/speechx/speechx/decoder/ctc_tlg_decoder.h @@ -38,6 +38,7 @@ class TLGDecoder { std::string GetBestPath(); std::vector> GetNBestPath(); std::string GetFinalBestPath(); + std::string GetPartialResult(); int NumFrameDecoded(); int DecodeLikelihoods(const std::vector>& probs, std::vector& nbest_words); diff --git a/speechx/speechx/decoder/recognizer.cc b/speechx/speechx/decoder/recognizer.cc index 2c90ada99e92ed61d922c640fddb2bbe3a7d0ec4..44c3911c92def5d7f9ba8b79336880fdae4bea81 100644 --- a/speechx/speechx/decoder/recognizer.cc +++ b/speechx/speechx/decoder/recognizer.cc @@ -44,6 +44,10 @@ std::string Recognizer::GetFinalResult() { return decoder_->GetFinalBestPath(); } +std::string Recognizer::GetPartialResult() { + return decoder_->GetPartialResult(); +} + void Recognizer::SetFinished() { feature_pipeline_->SetFinished(); input_finished_ = true; diff --git a/speechx/speechx/decoder/recognizer.h b/speechx/speechx/decoder/recognizer.h index 9a7e7d11eb39989b2d50fc935f512cbf9a40361b..35e1e1676d1836bbfbbe9599a919a86ada09c613 100644 --- a/speechx/speechx/decoder/recognizer.h +++ b/speechx/speechx/decoder/recognizer.h @@ -43,6 +43,7 @@ class Recognizer { void Accept(const kaldi::Vector& waves); void Decode(); std::string GetFinalResult(); + std::string GetPartialResult(); void SetFinished(); bool IsFinished(); void Reset(); diff --git a/speechx/speechx/protocol/CMakeLists.txt b/speechx/speechx/protocol/CMakeLists.txt index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..98b2f38b43a87e9b548c34d96a1f601e957e0045 100644 --- a/speechx/speechx/protocol/CMakeLists.txt +++ b/speechx/speechx/protocol/CMakeLists.txt @@ -0,0 +1,3 @@ +cmake_minimum_required(VERSION 3.14 FATAL_ERROR) + +add_subdirectory(websocket) diff --git a/speechx/speechx/websocket/CMakeLists.txt b/speechx/speechx/protocol/websocket/CMakeLists.txt similarity index 100% rename from speechx/speechx/websocket/CMakeLists.txt rename to speechx/speechx/protocol/websocket/CMakeLists.txt diff --git a/speechx/speechx/websocket/websocket_client.cc b/speechx/speechx/protocol/websocket/websocket_client.cc similarity index 96% rename from speechx/speechx/websocket/websocket_client.cc rename to speechx/speechx/protocol/websocket/websocket_client.cc index 6bd930b858aa10d15ac24397e2e29f33eeb22ebb..60e06db638d2121ad2ead24f9e5deaaae73160de 100644 --- a/speechx/speechx/websocket/websocket_client.cc +++ b/speechx/speechx/protocol/websocket/websocket_client.cc @@ -67,6 +67,9 @@ void WebSocketClient::ReadLoopFunc() { if (obj["type"] == "final_result") { result_ = obj["result"].as_string().c_str(); } + if (obj["type"] == "partial_result") { + partial_result_ = obj["result"].as_string().c_str(); + } if (obj["type"] == "speech_end") { done_ = true; break; diff --git a/speechx/speechx/websocket/websocket_client.h b/speechx/speechx/protocol/websocket/websocket_client.h similarity index 91% rename from speechx/speechx/websocket/websocket_client.h rename to speechx/speechx/protocol/websocket/websocket_client.h index ac0aed310bd1f017550e3663a8589c740f769294..8635501a8e6a9d029d104083e51675a2820b232d 100644 --- a/speechx/speechx/websocket/websocket_client.h +++ b/speechx/speechx/protocol/websocket/websocket_client.h @@ -40,12 +40,14 @@ class WebSocketClient { void SendEndSignal(); void SendDataEnd(); bool Done() const { return done_; } - std::string GetResult() { return result_; } + std::string GetResult() const { return result_; } + std::string GetPartialResult() const { return partial_result_;} private: void Connect(); std::string host_; std::string result_; + std::string partial_result_; int port_; bool done_ = false; asio::io_context ioc_; diff --git a/speechx/speechx/websocket/websocket_client_main.cc b/speechx/speechx/protocol/websocket/websocket_client_main.cc similarity index 99% rename from speechx/speechx/websocket/websocket_client_main.cc rename to speechx/speechx/protocol/websocket/websocket_client_main.cc index df658b0a2218b72af99c23d7a986aa4867732c6d..7ad36e3a563c58e40918666beca7b185115eb1cb 100644 --- a/speechx/speechx/websocket/websocket_client_main.cc +++ b/speechx/speechx/protocol/websocket/websocket_client_main.cc @@ -59,7 +59,6 @@ int main(int argc, char* argv[]) { client.SendBinaryData(wav_chunk.data(), wav_chunk.size() * sizeof(int16)); - sample_offset += cur_chunk_size; LOG(INFO) << "Send " << cur_chunk_size << " samples"; std::this_thread::sleep_for( diff --git a/speechx/speechx/websocket/websocket_server.cc b/speechx/speechx/protocol/websocket/websocket_server.cc similarity index 98% rename from speechx/speechx/websocket/websocket_server.cc rename to speechx/speechx/protocol/websocket/websocket_server.cc index 28c9eca4ee7776e8f1c4606dff60fa13b1a284bd..a1abd98e66b8f1bac00ff5276506241c5702d2e9 100644 --- a/speechx/speechx/websocket/websocket_server.cc +++ b/speechx/speechx/protocol/websocket/websocket_server.cc @@ -75,9 +75,10 @@ void ConnectionHandler::OnSpeechData(const beast::flat_buffer& buffer) { CHECK(recognizer_ != nullptr); recognizer_->Accept(pcm_data); - // TODO: return lpartial result + std::string partial_result = recognizer_->GetPartialResult(); + json::value rv = { - {"status", "ok"}, {"type", "partial_result"}, {"result", "TODO"}}; + {"status", "ok"}, {"type", "partial_result"}, {"result", partial_result}}; ws_.text(true); ws_.write(asio::buffer(json::serialize(rv))); } diff --git a/speechx/speechx/websocket/websocket_server.h b/speechx/speechx/protocol/websocket/websocket_server.h similarity index 98% rename from speechx/speechx/websocket/websocket_server.h rename to speechx/speechx/protocol/websocket/websocket_server.h index 9ea88282ec5e60682daeabcccf0d4f2c09c114fb..009fc42ed827fed1258f241a3e48936bf71a7daf 100644 --- a/speechx/speechx/websocket/websocket_server.h +++ b/speechx/speechx/protocol/websocket/websocket_server.h @@ -44,7 +44,6 @@ class ConnectionHandler { void OnFinish(); void OnSpeechData(const beast::flat_buffer& buffer); void OnError(const std::string& message); - void OnPartialResult(const std::string& result); void OnFinalResult(const std::string& result); void DecodeThreadFunc(); std::string SerializeResult(bool finish); diff --git a/speechx/speechx/websocket/websocket_server_main.cc b/speechx/speechx/protocol/websocket/websocket_server_main.cc similarity index 100% rename from speechx/speechx/websocket/websocket_server_main.cc rename to speechx/speechx/protocol/websocket/websocket_server_main.cc