提交 e23b173c 编写于 作者: Y Yang Zhou

add partial result

上级 1d01c5b5
......@@ -45,7 +45,7 @@ export GLOG_logtostderr=1
# 3. gen cmvn
cmvn=$data/cmvn.ark
cmvn-json2kaldi --json_file=$ckpt_dir/data/mean_std.json --cmvn_write_path=$cmvn
cmvn_json2kaldi_main --json_file=$ckpt_dir/data/mean_std.json --cmvn_write_path=$cmvn
wfst=$data/wfst/
......
......@@ -47,6 +47,26 @@ void TLGDecoder::Reset() {
return;
}
// Returns the current best hypothesis for the frames decoded so far, built by
// concatenating (with no separator) the word symbols found along the decoder's
// best path. Returns an empty string when no frame has been decoded yet.
std::string TLGDecoder::GetPartialResult() {
    if (frame_decoded_size_ == 0) {
        // Guard: requesting a best path before any frame was decoded fails
        // inside the decoder with
        // Assertion failed: (this->NumFramesDecoded() > 0 && "You cannot call
        // BestPathEnd if no frames were decoded.")
        return std::string("");
    }

    kaldi::Lattice lat;
    kaldi::LatticeWeight weight;
    std::vector<int> alignment;
    std::vector<int> words_id;

    // NOTE(review): the `false` argument presumably disables final-probs
    // because the utterance is still in progress -- confirm against the
    // decoder's GetBestPath contract.
    decoder_->GetBestPath(&lat, false);
    // Only the output (word) ids are used below; alignment and weight are
    // required out-parameters of the call.
    fst::GetLinearSymbolSequence(lat, &alignment, &words_id, &weight);

    std::string words;
    // Range-for avoids the signed/unsigned comparison of an int32 index
    // against words_id.size().
    for (const int word_id : words_id) {
        words += word_symbol_table_->Find(word_id);
    }
    return words;
}
std::string TLGDecoder::GetFinalBestPath() {
if (frame_decoded_size_ == 0) {
// Assertion failed: (this->NumFramesDecoded() > 0 && "You cannot call
......
......@@ -38,6 +38,7 @@ class TLGDecoder {
std::string GetBestPath();
std::vector<std::pair<double, std::string>> GetNBestPath();
std::string GetFinalBestPath();
std::string GetPartialResult();
int NumFrameDecoded();
int DecodeLikelihoods(const std::vector<std::vector<BaseFloat>>& probs,
std::vector<std::string>& nbest_words);
......
......@@ -44,6 +44,10 @@ std::string Recognizer::GetFinalResult() {
return decoder_->GetFinalBestPath();
}
// Returns the decoder's partial result for the audio processed so far.
// This is a pure forwarding call; no state is changed in Recognizer itself.
std::string Recognizer::GetPartialResult() {
    std::string partial = decoder_->GetPartialResult();
    return partial;
}
void Recognizer::SetFinished() {
feature_pipeline_->SetFinished();
input_finished_ = true;
......
......@@ -43,6 +43,7 @@ class Recognizer {
void Accept(const kaldi::Vector<kaldi::BaseFloat>& waves);
void Decode();
std::string GetFinalResult();
std::string GetPartialResult();
void SetFinished();
bool IsFinished();
void Reset();
......
......@@ -67,6 +67,9 @@ void WebSocketClient::ReadLoopFunc() {
if (obj["type"] == "final_result") {
result_ = obj["result"].as_string().c_str();
}
if (obj["type"] == "partial_result") {
partial_result_ = obj["partial_result"].as_string().c_str();
}
if (obj["type"] == "speech_end") {
done_ = true;
break;
......
......@@ -41,11 +41,13 @@ class WebSocketClient {
void SendDataEnd();
bool Done() const { return done_; }
std::string GetResult() { return result_; }
std::string GetPartialResult() { return partial_result_; }
private:
void Connect();
std::string host_;
std::string result_;
std::string partial_result_;
int port_;
bool done_ = false;
asio::io_context ioc_;
......
......@@ -59,7 +59,6 @@ int main(int argc, char* argv[]) {
client.SendBinaryData(wav_chunk.data(),
wav_chunk.size() * sizeof(int16));
sample_offset += cur_chunk_size;
LOG(INFO) << "Send " << cur_chunk_size << " samples";
std::this_thread::sleep_for(
......
......@@ -75,9 +75,10 @@ void ConnectionHandler::OnSpeechData(const beast::flat_buffer& buffer) {
CHECK(recognizer_ != nullptr);
recognizer_->Accept(pcm_data);
// TODO: return partial result
std::string partial_result = recognizer_->GetPartialResult();
json::value rv = {
{"status", "ok"}, {"type", "partial_result"}, {"result", "TODO"}};
{"status", "ok"}, {"type", "partial_result"}, {"partial_result", partial_result}};
ws_.text(true);
ws_.write(asio::buffer(json::serialize(rv)));
}
......
......@@ -44,7 +44,6 @@ class ConnectionHandler {
void OnFinish();
void OnSpeechData(const beast::flat_buffer& buffer);
void OnError(const std::string& message);
void OnPartialResult(const std::string& result);
void OnFinalResult(const std::string& result);
void DecodeThreadFunc();
std::string SerializeResult(bool finish);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册