Commit d14ee800
Authored Feb 17, 2022 by SmileGoat
Parent: e57efcb3

    add decodable & ctc_beam_search_deocder
Showing 20 changed files with 7,416 additions and 1 deletion (+7416 −1).
speechx/speechx/decoder/common.h                                   +7     −0
speechx/speechx/decoder/ctc_beam_search_decoder.cc                 +264   −0
speechx/speechx/decoder/ctc_beam_search_decoder.h                  +74    −0
speechx/speechx/decoder/ctc_decoders                               +1     −0
speechx/speechx/kaldi/decoder/lattice-faster-decoder.cc            +1020  −0
speechx/speechx/kaldi/decoder/lattice-faster-decoder.h             +549   −0
speechx/speechx/kaldi/decoder/lattice-faster-online-decoder.cc     +285   −0
speechx/speechx/kaldi/decoder/lattice-faster-online-decoder.h      +147   −0
speechx/speechx/kaldi/lat/determinize-lattice-pruned-test.cc       +147   −0
speechx/speechx/kaldi/lat/determinize-lattice-pruned.cc            +1541  −0
speechx/speechx/kaldi/lat/determinize-lattice-pruned.h             +296   −0
speechx/speechx/kaldi/lat/kaldi-lattice.cc                         +506   −0
speechx/speechx/kaldi/lat/kaldi-lattice.h                          +156   −0
speechx/speechx/kaldi/lat/lattice-functions.cc                     +1880  −0
speechx/speechx/kaldi/lat/lattice-functions.h                      +402   −0
speechx/speechx/nnet/ctc_decodable.h                               +0     −0
speechx/speechx/nnet/decodable-itf.h                               +122   −0
speechx/speechx/nnet/decodable.h                                   +18    −0
speechx/speechx/nnet/dnn_decodable.h                               +0     −0
speechx/speechx/nnet/nnet_interface.h                              +1     −1
speechx/speechx/decoder/common.h  (new file, mode 100644)

#pragma once

#include <utility>
#include <vector>

#include "base/basic_types.h"

struct DecoderResult {
    BaseFloat acoustic_score;
    std::vector<int32> words_idx;
    std::vector<std::pair<int32, int32>> time_stamp;
};
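For orientation, a minimal consumer sketch (not part of this commit). The reading of each time_stamp entry as a (begin_frame, end_frame) pair per word is an assumption; the commit does not define it.

// Hypothetical usage sketch: how a caller might print a DecoderResult.
// The (begin_frame, end_frame) meaning of time_stamp is an assumption.
#include <iostream>

void PrintResult(const DecoderResult& result) {
    std::cout << "acoustic score: " << result.acoustic_score << "\n";
    for (size_t i = 0; i < result.words_idx.size(); ++i) {
        std::cout << "word id " << result.words_idx[i];
        if (i < result.time_stamp.size()) {
            std::cout << " frames [" << result.time_stamp[i].first << ", "
                      << result.time_stamp[i].second << "]";
        }
        std::cout << "\n";
    }
}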
speechx/speechx/decoder/ctc_beam_search_decoder.cc  (new file, mode 100644)

#include "decoder/ctc_beam_search_decoder.h"

#include <algorithm>
#include <cmath>
#include <thread>

#include "base/basic_types.h"
#include "decoder/ctc_decoders/decoder_utils.h"

namespace ppspeech {

using std::string;
using std::vector;
using FSTMATCH = fst::SortedMatcher<fst::StdVectorFst>;

CTCBeamSearch::CTCBeamSearch(std::shared_ptr<CTCBeamSearchOptions> opts)
    : opts_(opts),
      vocabulary_(nullptr),
      init_ext_scorer_(nullptr),
      blank_id(-1),
      space_id(-1),
      root(nullptr) {
    LOG(INFO) << "dict path: " << opts_->dict_file;
    vocabulary_ = std::make_shared<vector<string>>();
    if (!basr::ReadDictToVector(opts_->dict_file, *vocabulary_)) {
        LOG(INFO) << "load the dict failed";
    }
    LOG(INFO) << "read the vocabulary success, dict size: "
              << vocabulary_->size();

    LOG(INFO) << "language model path: " << opts_->lm_path;
    init_ext_scorer_ = std::make_shared<Scorer>(
        opts_->alpha, opts_->beta, opts_->lm_path, *vocabulary_);
}

void CTCBeamSearch::InitDecoder() {
    blank_id = 0;
    auto it = std::find(vocabulary_->begin(), vocabulary_->end(), " ");
    space_id = it - vocabulary_->begin();
    // if no space in vocabulary
    if ((size_t)space_id >= vocabulary_->size()) {
        space_id = -2;
    }

    ResetPrefixes();
    root = std::make_shared<PathTrie>();
    root->score = root->log_prob_b_prev = 0.0;
    prefixes.push_back(root.get());
    if (init_ext_scorer_ != nullptr &&
        !init_ext_scorer_->is_character_based()) {
        auto fst_dict =
            static_cast<fst::StdVectorFst*>(init_ext_scorer_->dictionary);
        fst::StdVectorFst* dict_ptr = fst_dict->Copy(true);
        root->set_dictionary(dict_ptr);
        auto matcher = std::make_shared<FSTMATCH>(*dict_ptr, fst::MATCH_INPUT);
        root->set_matcher(matcher);
    }
}

void CTCBeamSearch::ResetPrefixes() {
    for (size_t i = 0; i < prefixes.size(); i++) {
        if (prefixes[i] != nullptr) {
            delete prefixes[i];
            prefixes[i] = nullptr;
        }
    }
}

int CTCBeamSearch::DecodeLikelihoods(const vector<vector<float>>& probs,
                                     vector<string>& nbest_words) {
    std::thread::id this_id = std::this_thread::get_id();
    Timer timer;
    vector<vector<double>> double_probs(
        probs.size(), vector<double>(probs[0].size(), 0));
    int row = probs.size();
    int col = probs[0].size();
    for (int i = 0; i < row; i++) {
        for (int j = 0; j < col; j++) {
            double_probs[i][j] = static_cast<double>(probs[i][j]);
        }
    }

    timer.Reset();
    vector<std::pair<double, string>> results = AdvanceDecoding(double_probs);
    LOG(INFO) << "ctc decoding elapsed time(s) "
              << static_cast<float>(timer.Elapsed()) / 1000.0f;
    for (const auto& item : results) {
        nbest_words.push_back(item.second);
    }
    return 0;
}

vector<std::pair<double, string>> CTCBeamSearch::AdvanceDecoding(
    const vector<vector<double>>& probs_seq) {
    size_t num_time_steps = probs_seq.size();
    size_t beam_size = opts_->beam_size;
    double cutoff_prob = opts_->cutoff_prob;
    size_t cutoff_top_n = opts_->cutoff_top_n;

    for (size_t time_step = 0; time_step < num_time_steps; time_step++) {
        const auto& prob = probs_seq[time_step];
        float min_cutoff = -NUM_FLT_INF;
        bool full_beam = false;
        if (init_ext_scorer_ != nullptr) {
            size_t num_prefixes = std::min(prefixes.size(), beam_size);
            std::sort(prefixes.begin(), prefixes.begin() + num_prefixes,
                      prefix_compare);
            if (num_prefixes == 0) {
                continue;
            }
            min_cutoff = prefixes[num_prefixes - 1]->score +
                         std::log(prob[blank_id]) -
                         std::max(0.0, init_ext_scorer_->beta);
            full_beam = (num_prefixes == beam_size);
        }

        vector<std::pair<size_t, float>> log_prob_idx =
            get_pruned_log_probs(prob, cutoff_prob, cutoff_top_n);

        // loop over chars
        size_t log_prob_idx_len = log_prob_idx.size();
        for (size_t index = 0; index < log_prob_idx_len; index++) {
            SearchOneChar(full_beam, log_prob_idx[index], min_cutoff);
        }

        prefixes.clear();

        // update log probs
        root->iterate_to_vec(prefixes);

        // only preserve top beam_size prefixes
        if (prefixes.size() >= beam_size) {
            std::nth_element(prefixes.begin(), prefixes.begin() + beam_size,
                             prefixes.end(), prefix_compare);
            for (size_t i = beam_size; i < prefixes.size(); ++i) {
                prefixes[i]->remove();
            }
        }  // if
    }      // for probs_seq

    // score the last word of each prefix that doesn't end with space
    LMRescore();
    CalculateApproxScore();
    return get_beam_search_result(prefixes, *vocabulary_, beam_size);
}
int32 CTCBeamSearch::SearchOneChar(
    const bool& full_beam,
    const std::pair<size_t, float>& log_prob_idx,
    const float& min_cutoff) {
    size_t beam_size = opts_->beam_size;
    const auto& c = log_prob_idx.first;
    const auto& log_prob_c = log_prob_idx.second;
    size_t prefixes_len = std::min(prefixes.size(), beam_size);

    for (size_t i = 0; i < prefixes_len; ++i) {
        auto prefix = prefixes[i];
        if (full_beam && log_prob_c + prefix->score < min_cutoff) {
            break;
        }

        if (c == blank_id) {
            prefix->log_prob_b_cur = log_sum_exp(prefix->log_prob_b_cur,
                                                 log_prob_c + prefix->score);
            continue;
        }

        // repeated character
        if (c == prefix->character) {
            // p_{nb}(l;x_{1:t}) = p(c;x_{t})p(l;x_{1:t-1})
            prefix->log_prob_nb_cur =
                log_sum_exp(prefix->log_prob_nb_cur,
                            log_prob_c + prefix->log_prob_nb_prev);
        }

        // get new prefix
        auto prefix_new = prefix->get_path_trie(c);
        if (prefix_new != nullptr) {
            float log_p = -NUM_FLT_INF;
            if (c == prefix->character &&
                prefix->log_prob_b_prev > -NUM_FLT_INF) {
                // p_{nb}(l^{+};x_{1:t}) = p(c;x_{t})p_{b}(l;x_{1:t-1})
                log_p = log_prob_c + prefix->log_prob_b_prev;
            } else if (c != prefix->character) {
                // p_{nb}(l^{+};x_{1:t}) = p(c;x_{t}) p(l;x_{1:t-1})
                log_p = log_prob_c + prefix->score;
            }

            // language model scoring
            if (init_ext_scorer_ != nullptr &&
                (c == space_id || init_ext_scorer_->is_character_based())) {
                PathTrie* prefix_to_score = nullptr;
                // skip scoring the space
                if (init_ext_scorer_->is_character_based()) {
                    prefix_to_score = prefix_new;
                } else {
                    prefix_to_score = prefix;
                }

                float score = 0.0;
                vector<string> ngram;
                ngram = init_ext_scorer_->make_ngram(prefix_to_score);
                // lm score: p_{lm}(W)^{\alpha} + \beta
                score = init_ext_scorer_->get_log_cond_prob(ngram) *
                        init_ext_scorer_->alpha;
                log_p += score;
                log_p += init_ext_scorer_->beta;
            }
            // p_{nb}(l;x_{1:t})
            prefix_new->log_prob_nb_cur =
                log_sum_exp(prefix_new->log_prob_nb_cur, log_p);
        }
    }  // end of loop over prefix
    return 0;
}
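Taken together, the branches in SearchOneChar correspond to the standard CTC prefix beam search recursion. Writing $\ell$ for a prefix, $c_\ell$ for its last character, $\ell^{+}$ for $\ell$ extended by character $c$, and $p(\ell;\cdot) = p_b + p_{nb}$ for the total prefix probability (prefix->score), the accumulations above are, with all sums done in log space via log_sum_exp:

$$
\begin{aligned}
p_{b}(\ell;\,x_{1:t}) &\mathrel{+}= p(\mathrm{blank};x_t)\,p(\ell;\,x_{1:t-1}),\\
p_{nb}(\ell;\,x_{1:t}) &\mathrel{+}= p(c_\ell;x_t)\,p_{nb}(\ell;\,x_{1:t-1}),\\
p_{nb}(\ell^{+};\,x_{1:t}) &\mathrel{+}=
\begin{cases}
p(c;x_t)\,p_{b}(\ell;\,x_{1:t-1}) & \text{if } c = c_\ell,\\
p(c;x_t)\,p(\ell;\,x_{1:t-1}) & \text{if } c \neq c_\ell,
\end{cases}
\end{aligned}
$$

and when a word boundary (space_id) is crossed, or the scorer is character-based, the language-model term $\alpha \log p_{lm}(W) + \beta$ is added to the extension's log probability.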
void CTCBeamSearch::CalculateApproxScore() {
    size_t beam_size = opts_->beam_size;
    size_t num_prefixes = std::min(prefixes.size(), beam_size);
    std::sort(prefixes.begin(), prefixes.begin() + num_prefixes,
              prefix_compare);

    // compute approximate ctc score as the return score, without affecting
    // the return order of decoding result. To delete when decoder gets stable.
    for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) {
        double approx_ctc = prefixes[i]->score;
        if (init_ext_scorer_ != nullptr) {
            vector<int> output;
            prefixes[i]->get_path_vec(output);
            auto prefix_length = output.size();
            auto words = init_ext_scorer_->split_labels(output);
            // remove word insert
            approx_ctc = approx_ctc - prefix_length * init_ext_scorer_->beta;
            // remove language model weight:
            approx_ctc -= (init_ext_scorer_->get_sent_log_prob(words)) *
                          init_ext_scorer_->alpha;
        }
        prefixes[i]->approx_ctc = approx_ctc;
    }
}

void CTCBeamSearch::LMRescore() {
    size_t beam_size = opts_->beam_size;
    if (init_ext_scorer_ != nullptr &&
        !init_ext_scorer_->is_character_based()) {
        for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) {
            auto prefix = prefixes[i];
            if (!prefix->is_empty() && prefix->character != space_id) {
                float score = 0.0;
                vector<string> ngram = init_ext_scorer_->make_ngram(prefix);
                score = init_ext_scorer_->get_log_cond_prob(ngram) *
                        init_ext_scorer_->alpha;
                score += init_ext_scorer_->beta;
                prefix->score += score;
            }
        }
    }
}

}  // namespace ppspeech
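For orientation, a minimal driver sketch (not part of this commit) showing how the class above might be called, under the signatures in this diff. The posterior matrix layout (rows = time steps, columns = vocabulary entries, softmax outputs) and the model paths are assumptions.

// Hypothetical usage sketch: construct options, initialize, decode one
// utterance's posteriors. `probs` layout and file paths are assumptions.
#include <memory>
#include <string>
#include <vector>

#include "decoder/ctc_beam_search_decoder.h"

std::vector<std::string> RunCTCDecode(
    const std::vector<std::vector<BaseFloat>>& probs) {
    auto opts = std::make_shared<ppspeech::CTCBeamSearchOptions>();
    opts->dict_file = "./model/words.txt";  // vocabulary, one token per line
    opts->lm_path = "./model/lm.arpa";      // optional n-gram LM for rescoring

    ppspeech::CTCBeamSearch decoder(opts);
    decoder.InitDecoder();  // builds the prefix trie, locates the space id

    std::vector<std::string> nbest_words;
    decoder.DecodeLikelihoods(probs, nbest_words);  // fills n-best hypotheses
    return nbest_words;
}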
speechx/speechx/decoder/ctc_beam_search_decoder.h  (new file, mode 100644)

#pragma once

#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "base/basic_types.h"
#include "decoder/common.h"
#include "decoder/ctc_decoders/path_trie.h"
#include "decoder/ctc_decoders/scorer.h"

namespace ppspeech {

struct CTCBeamSearchOptions {
    std::string dict_file;
    std::string lm_path;
    BaseFloat alpha;
    BaseFloat beta;
    BaseFloat cutoff_prob;
    int beam_size;
    int cutoff_top_n;
    int num_proc_bsearch;

    CTCBeamSearchOptions()
        : dict_file("./model/words.txt"),
          lm_path("./model/lm.arpa"),
          alpha(1.9f),
          beta(5.0),
          beam_size(300),
          cutoff_prob(0.99f),
          cutoff_top_n(40),
          num_proc_bsearch(0) {}

    void Register(kaldi::OptionsItf* opts) {
        opts->Register("dict", &dict_file, "dict file ");
        opts->Register("lm-path", &lm_path, "language model file");
        opts->Register("alpha", &alpha, "alpha");
        opts->Register("beta", &beta, "beta");
        opts->Register(
            "beam-size", &beam_size, "beam size for beam search method");
        opts->Register("cutoff-prob", &cutoff_prob, "cutoff probs");
        opts->Register("cutoff-top-n", &cutoff_top_n, "cutoff top n");
        opts->Register(
            "num-proc-bsearch", &num_proc_bsearch, "num proc bsearch");
    }
};

class CTCBeamSearch {
  public:
    explicit CTCBeamSearch(std::shared_ptr<CTCBeamSearchOptions> opts);
    ~CTCBeamSearch() {}
    void InitDecoder();
    int DecodeLikelihoods(const std::vector<std::vector<BaseFloat>>& probs,
                          std::vector<std::string>& nbest_words);
    std::vector<DecoderResult>& GetDecodeResult() { return decoder_results_; }

  private:
    void ResetPrefixes();
    int32 SearchOneChar(const bool& full_beam,
                        const std::pair<size_t, BaseFloat>& log_prob_idx,
                        const BaseFloat& min_cutoff);
    void CalculateApproxScore();
    void LMRescore();
    std::vector<std::pair<double, std::string>> AdvanceDecoding(
        const std::vector<std::vector<double>>& probs_seq);

    std::shared_ptr<CTCBeamSearchOptions> opts_;
    std::shared_ptr<Scorer> init_ext_scorer_;  // todo separate later
    std::vector<DecoderResult> decoder_results_;
    std::shared_ptr<std::vector<std::string>> vocabulary_;  // todo remove later
    int blank_id;
    int space_id;
    std::shared_ptr<PathTrie> root;
    std::vector<PathTrie*> prefixes;
};

}  // namespace ppspeech
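A short sketch (not part of this commit) of how these options might be wired to command-line flags with Kaldi's standard option parser, which implements kaldi::OptionsItf; the binary and its usage string are made up for illustration.

// Hypothetical sketch: registering CTCBeamSearchOptions with ParseOptions
// so --dict, --lm-path, --alpha, --beta, --beam-size, etc. become flags.
#include "util/parse-options.h"

int main(int argc, char* argv[]) {
    const char* usage = "CTC beam search decoding demo.";
    kaldi::ParseOptions po(usage);

    ppspeech::CTCBeamSearchOptions opts;
    opts.Register(&po);   // exposes the fields above as command-line flags
    po.Read(argc, argv);  // opts now reflects any flags that were passed
    return 0;
}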
speechx/speechx/decoder/ctc_decoders  (new symlink, mode 120000)

    ../../../third_party/ctc_decoders
speechx/speechx/kaldi/decoder/lattice-faster-decoder.cc  (new file, mode 100644)

// decoder/lattice-faster-decoder.cc

// Copyright 2009-2012  Microsoft Corporation  Mirko Hannemann
//           2013-2018  Johns Hopkins University (Author: Daniel Povey)
//                2014  Guoguo Chen
//                2018  Zhehuai Chen

// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//  http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.

#include "decoder/lattice-faster-decoder.h"
#include "lat/lattice-functions.h"

namespace kaldi {

// instantiate this class once for each thing you have to decode.
template <typename FST, typename Token>
LatticeFasterDecoderTpl<FST, Token>::LatticeFasterDecoderTpl(
    const FST &fst, const LatticeFasterDecoderConfig &config)
    : fst_(&fst),
      delete_fst_(false),
      config_(config),
      num_toks_(0),
      token_pool_(config.memory_pool_tokens_block_size),
      forward_link_pool_(config.memory_pool_links_block_size) {
  config.Check();
  toks_.SetSize(1000);  // just so on the first frame we do something reasonable.
}

template <typename FST, typename Token>
LatticeFasterDecoderTpl<FST, Token>::LatticeFasterDecoderTpl(
    const LatticeFasterDecoderConfig &config, FST *fst)
    : fst_(fst),
      delete_fst_(true),
      config_(config),
      num_toks_(0),
      token_pool_(config.memory_pool_tokens_block_size),
      forward_link_pool_(config.memory_pool_links_block_size) {
  config.Check();
  toks_.SetSize(1000);  // just so on the first frame we do something reasonable.
}

template <typename FST, typename Token>
LatticeFasterDecoderTpl<FST, Token>::~LatticeFasterDecoderTpl() {
  DeleteElems(toks_.Clear());
  ClearActiveTokens();
  if (delete_fst_) delete fst_;
}

template <typename FST, typename Token>
void LatticeFasterDecoderTpl<FST, Token>::InitDecoding() {
  // clean up from last time:
  DeleteElems(toks_.Clear());
  cost_offsets_.clear();
  ClearActiveTokens();
  warned_ = false;
  num_toks_ = 0;
  decoding_finalized_ = false;
  final_costs_.clear();
  StateId start_state = fst_->Start();
  KALDI_ASSERT(start_state != fst::kNoStateId);
  active_toks_.resize(1);
  Token *start_tok =
      new (token_pool_.Allocate()) Token(0.0, 0.0, NULL, NULL, NULL);
  active_toks_[0].toks = start_tok;
  toks_.Insert(start_state, start_tok);
  num_toks_++;
  ProcessNonemitting(config_.beam);
}

// Returns true if any kind of traceback is available (not necessarily from
// a final state). It should only very rarely return false; this indicates
// an unusual search error.
template <typename FST, typename Token>
bool LatticeFasterDecoderTpl<FST, Token>::Decode(DecodableInterface *decodable) {
  InitDecoding();
  // We use 1-based indexing for frames in this decoder (if you view it in
  // terms of features), but note that the decodable object uses zero-based
  // numbering, which we have to correct for when we call it.
  AdvanceDecoding(decodable);
  FinalizeDecoding();

  // Returns true if we have any kind of traceback available (not necessarily
  // to the end state; query ReachedFinal() for that).
  return !active_toks_.empty() && active_toks_.back().toks != NULL;
}

// Outputs an FST corresponding to the single best path through the lattice.
template <typename FST, typename Token>
bool LatticeFasterDecoderTpl<FST, Token>::GetBestPath(
    Lattice *olat, bool use_final_probs) const {
  Lattice raw_lat;
  GetRawLattice(&raw_lat, use_final_probs);
  ShortestPath(raw_lat, olat);
  return (olat->NumStates() != 0);
}

// Outputs an FST corresponding to the raw, state-level lattice
template <typename FST, typename Token>
bool LatticeFasterDecoderTpl<FST, Token>::GetRawLattice(
    Lattice *ofst, bool use_final_probs) const {
  typedef LatticeArc Arc;
  typedef Arc::StateId StateId;
  typedef Arc::Weight Weight;
  typedef Arc::Label Label;

  // Note: you can't use the old interface (Decode()) if you want to
  // get the lattice with use_final_probs = false. You'd have to do
  // InitDecoding() and then AdvanceDecoding().
  if (decoding_finalized_ && !use_final_probs)
    KALDI_ERR << "You cannot call FinalizeDecoding() and then call "
              << "GetRawLattice() with use_final_probs == false";

  unordered_map<Token*, BaseFloat> final_costs_local;

  const unordered_map<Token*, BaseFloat> &final_costs =
      (decoding_finalized_ ? final_costs_ : final_costs_local);
  if (!decoding_finalized_ && use_final_probs)
    ComputeFinalCosts(&final_costs_local, NULL, NULL);

  ofst->DeleteStates();
  // num-frames plus one (since frames are one-based, and we have
  // an extra frame for the start-state).
  int32 num_frames = active_toks_.size() - 1;
  KALDI_ASSERT(num_frames > 0);
  const int32 bucket_count = num_toks_ / 2 + 3;
  unordered_map<Token*, StateId> tok_map(bucket_count);
  // First create all states.
  std::vector<Token*> token_list;
  for (int32 f = 0; f <= num_frames; f++) {
    if (active_toks_[f].toks == NULL) {
      KALDI_WARN << "GetRawLattice: no tokens active on frame " << f
                 << ": not producing lattice.\n";
      return false;
    }
    TopSortTokens(active_toks_[f].toks, &token_list);
    for (size_t i = 0; i < token_list.size(); i++)
      if (token_list[i] != NULL) tok_map[token_list[i]] = ofst->AddState();
  }
  // The next statement sets the start state of the output FST. Because we
  // topologically sorted the tokens, state zero must be the start-state.
  ofst->SetStart(0);

  KALDI_VLOG(4) << "init:" << num_toks_ / 2 + 3
                << " buckets:" << tok_map.bucket_count()
                << " load:" << tok_map.load_factor()
                << " max:" << tok_map.max_load_factor();
  // Now create all arcs.
  for (int32 f = 0; f <= num_frames; f++) {
    for (Token *tok = active_toks_[f].toks; tok != NULL; tok = tok->next) {
      StateId cur_state = tok_map[tok];
      for (ForwardLinkT *l = tok->links; l != NULL; l = l->next) {
        typename unordered_map<Token*, StateId>::const_iterator iter =
            tok_map.find(l->next_tok);
        StateId nextstate = iter->second;
        KALDI_ASSERT(iter != tok_map.end());
        BaseFloat cost_offset = 0.0;
        if (l->ilabel != 0) {  // emitting..
          KALDI_ASSERT(f >= 0 && f < cost_offsets_.size());
          cost_offset = cost_offsets_[f];
        }
        Arc arc(l->ilabel, l->olabel,
                Weight(l->graph_cost, l->acoustic_cost - cost_offset),
                nextstate);
        ofst->AddArc(cur_state, arc);
      }
      if (f == num_frames) {
        if (use_final_probs && !final_costs.empty()) {
          typename unordered_map<Token*, BaseFloat>::const_iterator iter =
              final_costs.find(tok);
          if (iter != final_costs.end())
            ofst->SetFinal(cur_state, LatticeWeight(iter->second, 0));
        } else {
          ofst->SetFinal(cur_state, LatticeWeight::One());
        }
      }
    }
  }
  return (ofst->NumStates() > 0);
}

// This function is now deprecated, since now we do determinization from
// outside the LatticeFasterDecoder class. Outputs an FST corresponding to the
// lattice-determinized lattice (one path per word sequence).
template <typename FST, typename Token>
bool LatticeFasterDecoderTpl<FST, Token>::GetLattice(
    CompactLattice *ofst, bool use_final_probs) const {
  Lattice raw_fst;
  GetRawLattice(&raw_fst, use_final_probs);
  Invert(&raw_fst);  // make it so word labels are on the input.
  // (in phase where we get backward-costs).
  fst::ILabelCompare<LatticeArc> ilabel_comp;
  ArcSort(&raw_fst, ilabel_comp);  // sort on ilabel; makes
  // lattice-determinization more efficient.

  fst::DeterminizeLatticePrunedOptions lat_opts;
  lat_opts.max_mem = config_.det_opts.max_mem;

  DeterminizeLatticePruned(raw_fst, config_.lattice_beam, ofst, lat_opts);
  raw_fst.DeleteStates();  // Free memory-- raw_fst no longer needed.
  Connect(ofst);  // Remove unreachable states... there might be
  // a small number of these, in some cases.
  // Note: if something went wrong and the raw lattice was empty,
  // we should still get to this point in the code without warnings or failures.
  return (ofst->NumStates() != 0);
}

template <typename FST, typename Token>
void LatticeFasterDecoderTpl<FST, Token>::PossiblyResizeHash(size_t num_toks) {
  size_t new_sz = static_cast<size_t>(static_cast<BaseFloat>(num_toks) *
                                      config_.hash_ratio);
  if (new_sz > toks_.Size()) {
    toks_.SetSize(new_sz);
  }
}

/*
  A note on the definition of extra_cost.

  extra_cost is used in pruning tokens, to save memory.

  extra_cost can be thought of as a beta (backward) cost assuming
  we had set the betas on currently-active tokens to all be the negative
  of the alphas for those tokens. (So all currently active tokens would
  be on (tied) best paths).

  We can use the extra_cost to accurately prune away tokens that we know will
  never appear in the lattice. If the extra_cost is greater than the desired
  lattice beam, the token would provably never appear in the lattice, so we can
  prune away the token.

  (Note: we don't update all the extra_costs every time we update a frame; we
  only do it every 'config_.prune_interval' frames).
 */
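// A small worked example of the rule above (illustrative numbers only):
// suppose config_.lattice_beam = 4.0 and the best complete path currently
// costs 100.0. A token whose best continuation to the end costs 104.5 has
// extra_cost = 4.5 > 4.0, so no path through it can survive in the pruned
// lattice and it can be removed; a token with extra_cost = 3.0 is kept.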
// FindOrAddToken either locates a token in hash of toks_,
// or if necessary inserts a new, empty token (i.e. with no forward links)
// for the current frame. [note: it's inserted if necessary into hash toks_
// and also into the singly linked list of tokens active on this frame
// (whose head is at active_toks_[frame]).
template <typename FST, typename Token>
inline typename LatticeFasterDecoderTpl<FST, Token>::Elem*
LatticeFasterDecoderTpl<FST, Token>::FindOrAddToken(
    StateId state, int32 frame_plus_one, BaseFloat tot_cost,
    Token *backpointer, bool *changed) {
  // Returns the Token pointer. Sets "changed" (if non-NULL) to true
  // if the token was newly created or the cost changed.
  KALDI_ASSERT(frame_plus_one < active_toks_.size());
  Token *&toks = active_toks_[frame_plus_one].toks;
  Elem *e_found = toks_.Insert(state, NULL);
  if (e_found->val == NULL) {  // no such token presently.
    const BaseFloat extra_cost = 0.0;
    // tokens on the currently final frame have zero extra_cost
    // as any of them could end up
    // on the winning path.
    Token *new_tok = new (token_pool_.Allocate())
        Token(tot_cost, extra_cost, NULL, toks, backpointer);
    // NULL: no forward links yet
    toks = new_tok;
    num_toks_++;
    e_found->val = new_tok;
    if (changed) *changed = true;
    return e_found;
  } else {
    Token *tok = e_found->val;  // There is an existing Token for this state.
    if (tok->tot_cost > tot_cost) {  // replace old token
      tok->tot_cost = tot_cost;
      // SetBackpointer() just does tok->backpointer = backpointer in
      // the case where Token == BackpointerToken, else nothing.
      tok->SetBackpointer(backpointer);
      // we don't allocate a new token, the old stays linked in active_toks_
      // we only replace the tot_cost
      // in the current frame, there are no forward links (and no extra_cost)
      // only in ProcessNonemitting we have to delete forward links
      // in case we visit a state for the second time
      // those forward links, that lead to this replaced token before:
      // they remain and will hopefully be pruned later (PruneForwardLinks...)
      if (changed) *changed = true;
    } else {
      if (changed) *changed = false;
    }
    return e_found;
  }
}

// prunes outgoing links for all tokens in active_toks_[frame]
// it's called by PruneActiveTokens
// all links, that have link_extra_cost > lattice_beam are pruned
template <typename FST, typename Token>
void LatticeFasterDecoderTpl<FST, Token>::PruneForwardLinks(
    int32 frame_plus_one, bool *extra_costs_changed, bool *links_pruned,
    BaseFloat delta) {
  // delta is the amount by which the extra_costs must change
  // If delta is larger, we'll tend to go back less far
  //    toward the beginning of the file.
  // extra_costs_changed is set to true if extra_cost was changed for any token
  // links_pruned is set to true if any link in any token was pruned

  *extra_costs_changed = false;
  *links_pruned = false;
  KALDI_ASSERT(frame_plus_one >= 0 && frame_plus_one < active_toks_.size());
  if (active_toks_[frame_plus_one].toks == NULL) {
    // empty list; should not happen.
    if (!warned_) {
      KALDI_WARN << "No tokens alive [doing pruning].. warning first "
                    "time only for each utterance\n";
      warned_ = true;
    }
  }

  // We have to iterate until there is no more change, because the links
  // are not guaranteed to be in topological order.
  bool changed = true;  // difference new minus old extra cost >= delta ?
  while (changed) {
    changed = false;
    for (Token *tok = active_toks_[frame_plus_one].toks; tok != NULL;
         tok = tok->next) {
      ForwardLinkT *link, *prev_link = NULL;
      // will recompute tok_extra_cost for tok.
      BaseFloat tok_extra_cost = std::numeric_limits<BaseFloat>::infinity();
      // tok_extra_cost is the best (min) of link_extra_cost of outgoing links
      for (link = tok->links; link != NULL;) {
        // See if we need to excise this link...
        Token *next_tok = link->next_tok;
        BaseFloat link_extra_cost =
            next_tok->extra_cost +
            ((tok->tot_cost + link->acoustic_cost + link->graph_cost) -
             next_tok->tot_cost);  // difference in brackets is >= 0
        // link_exta_cost is the difference in score between the best paths
        // through link source state and through link destination state
        KALDI_ASSERT(link_extra_cost == link_extra_cost);  // check for NaN
        if (link_extra_cost > config_.lattice_beam) {  // excise link
          ForwardLinkT *next_link = link->next;
          if (prev_link != NULL)
            prev_link->next = next_link;
          else
            tok->links = next_link;
          forward_link_pool_.Free(link);
          link = next_link;  // advance link but leave prev_link the same.
          *links_pruned = true;
        } else {  // keep the link and update the tok_extra_cost if needed.
          if (link_extra_cost < 0.0) {  // this is just a precaution.
            if (link_extra_cost < -0.01)
              KALDI_WARN << "Negative extra_cost: " << link_extra_cost;
            link_extra_cost = 0.0;
          }
          if (link_extra_cost < tok_extra_cost)
            tok_extra_cost = link_extra_cost;
          prev_link = link;  // move to next link
          link = link->next;
        }
      }  // for all outgoing links
      if (fabs(tok_extra_cost - tok->extra_cost) > delta)
        changed = true;  // difference new minus old is bigger than delta
      tok->extra_cost = tok_extra_cost;
      // will be +infinity or <= lattice_beam_.
      // infinity indicates, that no forward link survived pruning
    }  // for all Token on active_toks_[frame]
    if (changed) *extra_costs_changed = true;

    // Note: it's theoretically possible that aggressive compiler
    // optimizations could cause an infinite loop here for small delta and
    // high-dynamic-range scores.
  }  // while changed
}

// PruneForwardLinksFinal is a version of PruneForwardLinks that we call
// on the final frame. If there are final tokens active, it uses
// the final-probs for pruning, otherwise it treats all tokens as final.
template <typename FST, typename Token>
void LatticeFasterDecoderTpl<FST, Token>::PruneForwardLinksFinal() {
  KALDI_ASSERT(!active_toks_.empty());
  int32 frame_plus_one = active_toks_.size() - 1;

  if (active_toks_[frame_plus_one].toks == NULL)  // empty list; should not happen.
    KALDI_WARN << "No tokens alive at end of file";

  typedef typename unordered_map<Token*, BaseFloat>::const_iterator IterType;
  ComputeFinalCosts(&final_costs_, &final_relative_cost_, &final_best_cost_);
  decoding_finalized_ = true;
  // We call DeleteElems() as a nicety, not because it's really necessary;
  // otherwise there would be a time, after calling PruneTokensForFrame() on
  // the final frame, when toks_.GetList() or toks_.Clear() would contain
  // pointers to nonexistent tokens.
  DeleteElems(toks_.Clear());

  // Now go through tokens on this frame, pruning forward links... may have to
  // iterate a few times until there is no more change, because the list is not
  // in topological order. This is a modified version of the code in
  // PruneForwardLinks, but here we also take account of the final-probs.
  bool changed = true;
  BaseFloat delta = 1.0e-05;
  while (changed) {
    changed = false;
    for (Token *tok = active_toks_[frame_plus_one].toks; tok != NULL;
         tok = tok->next) {
      ForwardLinkT *link, *prev_link = NULL;
      // will recompute tok_extra_cost. It has a term in it that corresponds
      // to the "final-prob", so instead of initializing tok_extra_cost to
      // infinity below we set it to the difference between the
      // (score+final_prob) of this token, and the best such (score+final_prob).
      BaseFloat final_cost;
      if (final_costs_.empty()) {
        final_cost = 0.0;
      } else {
        IterType iter = final_costs_.find(tok);
        if (iter != final_costs_.end())
          final_cost = iter->second;
        else
          final_cost = std::numeric_limits<BaseFloat>::infinity();
      }
      BaseFloat tok_extra_cost = tok->tot_cost + final_cost - final_best_cost_;
      // tok_extra_cost will be a "min" over either directly being final, or
      // being indirectly final through other links, and the loop below may
      // decrease its value:
      for (link = tok->links; link != NULL;) {
        // See if we need to excise this link...
        Token *next_tok = link->next_tok;
        BaseFloat link_extra_cost =
            next_tok->extra_cost +
            ((tok->tot_cost + link->acoustic_cost + link->graph_cost) -
             next_tok->tot_cost);
        if (link_extra_cost > config_.lattice_beam) {  // excise link
          ForwardLinkT *next_link = link->next;
          if (prev_link != NULL)
            prev_link->next = next_link;
          else
            tok->links = next_link;
          forward_link_pool_.Free(link);
          link = next_link;  // advance link but leave prev_link the same.
        } else {  // keep the link and update the tok_extra_cost if needed.
          if (link_extra_cost < 0.0) {  // this is just a precaution.
            if (link_extra_cost < -0.01)
              KALDI_WARN << "Negative extra_cost: " << link_extra_cost;
            link_extra_cost = 0.0;
          }
          if (link_extra_cost < tok_extra_cost)
            tok_extra_cost = link_extra_cost;
          prev_link = link;
          link = link->next;
        }
      }
      // prune away tokens worse than lattice_beam above best path. This step
      // was not necessary in the non-final case because then, this case
      // showed up as having no forward links. Here, the tok_extra_cost has
      // an extra component relating to the final-prob.
      if (tok_extra_cost > config_.lattice_beam)
        tok_extra_cost = std::numeric_limits<BaseFloat>::infinity();
      // to be pruned in PruneTokensForFrame

      if (!ApproxEqual(tok->extra_cost, tok_extra_cost, delta)) changed = true;
      tok->extra_cost = tok_extra_cost;  // will be +infinity or <= lattice_beam_.
    }
  }  // while changed
}

template <typename FST, typename Token>
BaseFloat LatticeFasterDecoderTpl<FST, Token>::FinalRelativeCost() const {
  if (!decoding_finalized_) {
    BaseFloat relative_cost;
    ComputeFinalCosts(NULL, &relative_cost, NULL);
    return relative_cost;
  } else {
    // we're not allowed to call that function if FinalizeDecoding() has
    // been called; return a cached value.
    return final_relative_cost_;
  }
}

// Prune away any tokens on this frame that have no forward links.
// [we don't do this in PruneForwardLinks because it would give us
// a problem with dangling pointers].
// It's called by PruneActiveTokens if any forward links have been pruned
template <typename FST, typename Token>
void LatticeFasterDecoderTpl<FST, Token>::PruneTokensForFrame(
    int32 frame_plus_one) {
  KALDI_ASSERT(frame_plus_one >= 0 && frame_plus_one < active_toks_.size());
  Token *&toks = active_toks_[frame_plus_one].toks;
  if (toks == NULL) KALDI_WARN << "No tokens alive [doing pruning]";
  Token *tok, *next_tok, *prev_tok = NULL;
  for (tok = toks; tok != NULL; tok = next_tok) {
    next_tok = tok->next;
    if (tok->extra_cost == std::numeric_limits<BaseFloat>::infinity()) {
      // token is unreachable from end of graph; (no forward links survived)
      // excise tok from list and delete tok.
      if (prev_tok != NULL)
        prev_tok->next = tok->next;
      else
        toks = tok->next;
      token_pool_.Free(tok);
      num_toks_--;
    } else {  // fetch next Token
      prev_tok = tok;
    }
  }
}

// Go backwards through still-alive tokens, pruning them, starting not from
// the current frame (where we want to keep all tokens) but from the frame
// before that. We go backwards through the frames and stop when we reach a
// point where the delta-costs are not changing (and the delta controls when we
// consider a cost to have "not changed").
template <typename FST, typename Token>
void LatticeFasterDecoderTpl<FST, Token>::PruneActiveTokens(BaseFloat delta) {
  int32 cur_frame_plus_one = NumFramesDecoded();
  int32 num_toks_begin = num_toks_;
  // The index "f" below represents a "frame plus one", i.e. you'd have to
  // subtract one to get the corresponding index for the decodable object.
  for (int32 f = cur_frame_plus_one - 1; f >= 0; f--) {
    // Reason why we need to prune forward links in this situation:
    // (1) we have never pruned them (new TokenList)
    // (2) we have not yet pruned the forward links to the next f,
    // after any of those tokens have changed their extra_cost.
    if (active_toks_[f].must_prune_forward_links) {
      bool extra_costs_changed = false, links_pruned = false;
      PruneForwardLinks(f, &extra_costs_changed, &links_pruned, delta);
      if (extra_costs_changed && f > 0)  // any token has changed extra_cost
        active_toks_[f - 1].must_prune_forward_links = true;
      if (links_pruned)  // any link was pruned
        active_toks_[f].must_prune_tokens = true;
      active_toks_[f].must_prune_forward_links = false;  // job done
    }
    if (f + 1 < cur_frame_plus_one &&  // except for last f (no forward links)
        active_toks_[f + 1].must_prune_tokens) {
      PruneTokensForFrame(f + 1);
      active_toks_[f + 1].must_prune_tokens = false;
    }
  }
  KALDI_VLOG(4) << "PruneActiveTokens: pruned tokens from " << num_toks_begin
                << " to " << num_toks_;
}

template <typename FST, typename Token>
void LatticeFasterDecoderTpl<FST, Token>::ComputeFinalCosts(
    unordered_map<Token*, BaseFloat> *final_costs,
    BaseFloat *final_relative_cost, BaseFloat *final_best_cost) const {
  KALDI_ASSERT(!decoding_finalized_);
  if (final_costs != NULL) final_costs->clear();
  const Elem *final_toks = toks_.GetList();
  BaseFloat infinity = std::numeric_limits<BaseFloat>::infinity();
  BaseFloat best_cost = infinity, best_cost_with_final = infinity;

  while (final_toks != NULL) {
    StateId state = final_toks->key;
    Token *tok = final_toks->val;
    const Elem *next = final_toks->tail;
    BaseFloat final_cost = fst_->Final(state).Value();
    BaseFloat cost = tok->tot_cost, cost_with_final = cost + final_cost;
    best_cost = std::min(cost, best_cost);
    best_cost_with_final = std::min(cost_with_final, best_cost_with_final);
    if (final_costs != NULL && final_cost != infinity)
      (*final_costs)[tok] = final_cost;
    final_toks = next;
  }
  if (final_relative_cost != NULL) {
    if (best_cost == infinity && best_cost_with_final == infinity) {
      // Likely this will only happen if there are no tokens surviving.
      // This seems the least bad way to handle it.
      *final_relative_cost = infinity;
    } else {
      *final_relative_cost = best_cost_with_final - best_cost;
    }
  }
  if (final_best_cost != NULL) {
    if (best_cost_with_final != infinity) {  // final-state exists.
      *final_best_cost = best_cost_with_final;
    } else {  // no final-state exists.
      *final_best_cost = best_cost;
    }
  }
}

template <typename FST, typename Token>
void LatticeFasterDecoderTpl<FST, Token>::AdvanceDecoding(
    DecodableInterface *decodable, int32 max_num_frames) {
  if (std::is_same<FST, fst::Fst<fst::StdArc> >::value) {
    // if the type 'FST' is the FST base-class, then see if the FST type of
    // fst_ is actually VectorFst or ConstFst. If so, call the
    // AdvanceDecoding() function after casting *this to the more specific type.
    if (fst_->Type() == "const") {
      LatticeFasterDecoderTpl<fst::ConstFst<fst::StdArc>, Token> *this_cast =
          reinterpret_cast<
              LatticeFasterDecoderTpl<fst::ConstFst<fst::StdArc>, Token>*>(this);
      this_cast->AdvanceDecoding(decodable, max_num_frames);
      return;
    } else if (fst_->Type() == "vector") {
      LatticeFasterDecoderTpl<fst::VectorFst<fst::StdArc>, Token> *this_cast =
          reinterpret_cast<
              LatticeFasterDecoderTpl<fst::VectorFst<fst::StdArc>, Token>*>(this);
      this_cast->AdvanceDecoding(decodable, max_num_frames);
      return;
    }
  }

  KALDI_ASSERT(!active_toks_.empty() && !decoding_finalized_ &&
               "You must call InitDecoding() before AdvanceDecoding");
  int32 num_frames_ready = decodable->NumFramesReady();
  // num_frames_ready must be >= num_frames_decoded, or else
  // the number of frames ready must have decreased (which doesn't
  // make sense) or the decodable object changed between calls
  // (which isn't allowed).
  KALDI_ASSERT(num_frames_ready >= NumFramesDecoded());
  int32 target_frames_decoded = num_frames_ready;
  if (max_num_frames >= 0)
    target_frames_decoded =
        std::min(target_frames_decoded, NumFramesDecoded() + max_num_frames);
  while (NumFramesDecoded() < target_frames_decoded) {
    if (NumFramesDecoded() % config_.prune_interval == 0) {
      PruneActiveTokens(config_.lattice_beam * config_.prune_scale);
    }
    BaseFloat cost_cutoff = ProcessEmitting(decodable);
    ProcessNonemitting(cost_cutoff);
  }
}

// FinalizeDecoding() is a version of PruneActiveTokens that we call
// (optionally) on the final frame. Takes into account the final-prob of
// tokens. This function used to be called PruneActiveTokensFinal().
template <typename FST, typename Token>
void LatticeFasterDecoderTpl<FST, Token>::FinalizeDecoding() {
  int32 final_frame_plus_one = NumFramesDecoded();
  int32 num_toks_begin = num_toks_;
  // PruneForwardLinksFinal() prunes final frame (with final-probs), and
  // sets decoding_finalized_.
  PruneForwardLinksFinal();
  for (int32 f = final_frame_plus_one - 1; f >= 0; f--) {
    bool b1, b2;  // values not used.
    BaseFloat dontcare = 0.0;  // delta of zero means we must always update
    PruneForwardLinks(f, &b1, &b2, dontcare);
    PruneTokensForFrame(f + 1);
  }
  PruneTokensForFrame(0);
  KALDI_VLOG(4) << "pruned tokens from " << num_toks_begin << " to "
                << num_toks_;
}
/// Gets the weight cutoff. Also counts the active tokens.
template <typename FST, typename Token>
BaseFloat LatticeFasterDecoderTpl<FST, Token>::GetCutoff(
    Elem *list_head, size_t *tok_count, BaseFloat *adaptive_beam,
    Elem **best_elem) {
  BaseFloat best_weight = std::numeric_limits<BaseFloat>::infinity();
  // positive == high cost == bad.
  size_t count = 0;
  if (config_.max_active == std::numeric_limits<int32>::max() &&
      config_.min_active == 0) {
    for (Elem *e = list_head; e != NULL; e = e->tail, count++) {
      BaseFloat w = static_cast<BaseFloat>(e->val->tot_cost);
      if (w < best_weight) {
        best_weight = w;
        if (best_elem) *best_elem = e;
      }
    }
    if (tok_count != NULL) *tok_count = count;
    if (adaptive_beam != NULL) *adaptive_beam = config_.beam;
    return best_weight + config_.beam;
  } else {
    tmp_array_.clear();
    for (Elem *e = list_head; e != NULL; e = e->tail, count++) {
      BaseFloat w = e->val->tot_cost;
      tmp_array_.push_back(w);
      if (w < best_weight) {
        best_weight = w;
        if (best_elem) *best_elem = e;
      }
    }
    if (tok_count != NULL) *tok_count = count;

    BaseFloat beam_cutoff = best_weight + config_.beam,
        min_active_cutoff = std::numeric_limits<BaseFloat>::infinity(),
        max_active_cutoff = std::numeric_limits<BaseFloat>::infinity();

    KALDI_VLOG(6) << "Number of tokens active on frame " << NumFramesDecoded()
                  << " is " << tmp_array_.size();

    if (tmp_array_.size() > static_cast<size_t>(config_.max_active)) {
      std::nth_element(tmp_array_.begin(),
                       tmp_array_.begin() + config_.max_active,
                       tmp_array_.end());
      max_active_cutoff = tmp_array_[config_.max_active];
    }
    if (max_active_cutoff < beam_cutoff) {  // max_active is tighter than beam.
      if (adaptive_beam)
        *adaptive_beam = max_active_cutoff - best_weight + config_.beam_delta;
      return max_active_cutoff;
    }
    if (tmp_array_.size() > static_cast<size_t>(config_.min_active)) {
      if (config_.min_active == 0)
        min_active_cutoff = best_weight;
      else {
        std::nth_element(
            tmp_array_.begin(), tmp_array_.begin() + config_.min_active,
            tmp_array_.size() > static_cast<size_t>(config_.max_active)
                ? tmp_array_.begin() + config_.max_active
                : tmp_array_.end());
        min_active_cutoff = tmp_array_[config_.min_active];
      }
    }
    if (min_active_cutoff > beam_cutoff) {  // min_active is looser than beam.
      if (adaptive_beam)
        *adaptive_beam = min_active_cutoff - best_weight + config_.beam_delta;
      return min_active_cutoff;
    } else {
      *adaptive_beam = config_.beam;
      return beam_cutoff;
    }
  }
}

template <typename FST, typename Token>
BaseFloat LatticeFasterDecoderTpl<FST, Token>::ProcessEmitting(
    DecodableInterface *decodable) {
  KALDI_ASSERT(active_toks_.size() > 0);
  int32 frame = active_toks_.size() - 1;  // frame is the frame-index
                                          // (zero-based) used to get likelihoods
                                          // from the decodable object.
  active_toks_.resize(active_toks_.size() + 1);

  Elem *final_toks = toks_.Clear();  // analogous to swapping prev_toks_ /
                                     // cur_toks_ in simple-decoder.h. Removes
                                     // the Elems from being indexed in the
                                     // hash in toks_.
  Elem *best_elem = NULL;
  BaseFloat adaptive_beam;
  size_t tok_cnt;
  BaseFloat cur_cutoff =
      GetCutoff(final_toks, &tok_cnt, &adaptive_beam, &best_elem);
  KALDI_VLOG(6) << "Adaptive beam on frame " << NumFramesDecoded() << " is "
                << adaptive_beam;

  PossiblyResizeHash(tok_cnt);  // This makes sure the hash is always big enough.

  BaseFloat next_cutoff = std::numeric_limits<BaseFloat>::infinity();
  // pruning "online" before having seen all tokens

  BaseFloat cost_offset = 0.0;  // Used to keep probabilities in a good
                                // dynamic range.

  // First process the best token to get a hopefully
  // reasonably tight bound on the next cutoff. The only
  // products of the next block are "next_cutoff" and "cost_offset".
  if (best_elem) {
    StateId state = best_elem->key;
    Token *tok = best_elem->val;
    cost_offset = -tok->tot_cost;
    for (fst::ArcIterator<FST> aiter(*fst_, state); !aiter.Done();
         aiter.Next()) {
      const Arc &arc = aiter.Value();
      if (arc.ilabel != 0) {  // propagate..
        BaseFloat new_weight = arc.weight.Value() + cost_offset -
                               decodable->LogLikelihood(frame, arc.ilabel) +
                               tok->tot_cost;
        if (new_weight + adaptive_beam < next_cutoff)
          next_cutoff = new_weight + adaptive_beam;
      }
    }
  }

  // Store the offset on the acoustic likelihoods that we're applying.
  // Could just do cost_offsets_.push_back(cost_offset), but we
  // do it this way as it's more robust to future code changes.
  cost_offsets_.resize(frame + 1, 0.0);
  cost_offsets_[frame] = cost_offset;

  // the tokens are now owned here, in final_toks, and the hash is empty.
  // 'owned' is a complex thing here; the point is we need to call DeleteElem
  // on each elem 'e' to let toks_ know we're done with them.
  for (Elem *e = final_toks, *e_tail; e != NULL; e = e_tail) {
    // loop this way because we delete "e" as we go.
    StateId state = e->key;
    Token *tok = e->val;
    if (tok->tot_cost <= cur_cutoff) {
      for (fst::ArcIterator<FST> aiter(*fst_, state); !aiter.Done();
           aiter.Next()) {
        const Arc &arc = aiter.Value();
        if (arc.ilabel != 0) {  // propagate..
          BaseFloat ac_cost =
                        cost_offset - decodable->LogLikelihood(frame, arc.ilabel),
                    graph_cost = arc.weight.Value(),
                    cur_cost = tok->tot_cost,
                    tot_cost = cur_cost + ac_cost + graph_cost;
          if (tot_cost >= next_cutoff)
            continue;
          else if (tot_cost + adaptive_beam < next_cutoff)
            next_cutoff = tot_cost + adaptive_beam;  // prune by best current token
          // Note: the frame indexes into active_toks_ are one-based,
          // hence the + 1.
          Elem *e_next =
              FindOrAddToken(arc.nextstate, frame + 1, tot_cost, tok, NULL);
          // NULL: no change indicator needed

          // Add ForwardLink from tok to next_tok (put on head of list tok->links)
          tok->links = new (forward_link_pool_.Allocate())
              ForwardLinkT(e_next->val, arc.ilabel, arc.olabel, graph_cost,
                           ac_cost, tok->links);
        }
      }  // for all arcs
    }
    e_tail = e->tail;
    toks_.Delete(e);  // delete Elem
  }
  return next_cutoff;
}
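// An illustration of the cost-offset trick above (numbers are made up): if
// the best token entering frame f has tot_cost = 513.2, cost_offset is set
// to -513.2, so the per-arc acoustic costs added on this frame land near
// zero instead of growing with utterance length. The same offset is saved in
// cost_offsets_[f], and GetRawLattice() adds it back when it builds arc
// weights (l->acoustic_cost - cost_offset), so final scores are unchanged.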
// static inline
template <typename FST, typename Token>
void LatticeFasterDecoderTpl<FST, Token>::DeleteForwardLinks(Token *tok) {
  ForwardLinkT *l = tok->links, *m;
  while (l != NULL) {
    m = l->next;
    forward_link_pool_.Free(l);
    l = m;
  }
  tok->links = NULL;
}

template <typename FST, typename Token>
void LatticeFasterDecoderTpl<FST, Token>::ProcessNonemitting(BaseFloat cutoff) {
  KALDI_ASSERT(!active_toks_.empty());
  int32 frame = static_cast<int32>(active_toks_.size()) - 2;
  // Note: "frame" is the time-index we just processed, or -1 if
  // we are processing the nonemitting transitions before the
  // first frame (called from InitDecoding()).

  // Processes nonemitting arcs for one frame. Propagates within toks_.
  // Note-- this queue structure is not very optimal as
  // it may cause us to process states unnecessarily (e.g. more than once),
  // but in the baseline code, turning this vector into a set to fix this
  // problem did not improve overall speed.

  KALDI_ASSERT(queue_.empty());

  if (toks_.GetList() == NULL) {
    if (!warned_) {
      KALDI_WARN << "Error, no surviving tokens: frame is " << frame;
      warned_ = true;
    }
  }

  for (const Elem *e = toks_.GetList(); e != NULL; e = e->tail) {
    StateId state = e->key;
    if (fst_->NumInputEpsilons(state) != 0) queue_.push_back(e);
  }

  while (!queue_.empty()) {
    const Elem *e = queue_.back();
    queue_.pop_back();

    StateId state = e->key;
    Token *tok = e->val;  // would segfault if e is a NULL pointer but this can't happen.
    BaseFloat cur_cost = tok->tot_cost;
    if (cur_cost >= cutoff)  // Don't bother processing successors.
      continue;
    // If "tok" has any existing forward links, delete them,
    // because we're about to regenerate them. This is a kind
    // of non-optimality (remember, this is the simple decoder),
    // but since most states are emitting it's not a huge issue.
    DeleteForwardLinks(tok);  // necessary when re-visiting
    tok->links = NULL;
    for (fst::ArcIterator<FST> aiter(*fst_, state); !aiter.Done();
         aiter.Next()) {
      const Arc &arc = aiter.Value();
      if (arc.ilabel == 0) {  // propagate nonemitting only...
        BaseFloat graph_cost = arc.weight.Value(),
                  tot_cost = cur_cost + graph_cost;
        if (tot_cost < cutoff) {
          bool changed;
          Elem *e_new =
              FindOrAddToken(arc.nextstate, frame + 1, tot_cost, tok, &changed);

          tok->links = new (forward_link_pool_.Allocate())
              ForwardLinkT(e_new->val, 0, arc.olabel, graph_cost, 0, tok->links);

          // "changed" tells us whether the new token has a different
          // cost from before, or is new [if so, add into queue].
          if (changed && fst_->NumInputEpsilons(arc.nextstate) != 0)
            queue_.push_back(e_new);
        }
      }
    }  // for all arcs
  }  // while queue not empty
}

template <typename FST, typename Token>
void LatticeFasterDecoderTpl<FST, Token>::DeleteElems(Elem *list) {
  for (Elem *e = list, *e_tail; e != NULL; e = e_tail) {
    e_tail = e->tail;
    toks_.Delete(e);
  }
}

template <typename FST, typename Token>
void LatticeFasterDecoderTpl<FST, Token>::ClearActiveTokens() {
  // a cleanup routine, at utt end/begin
  for (size_t i = 0; i < active_toks_.size(); i++) {
    // Delete all tokens alive on this frame, and any forward
    // links they may have.
    for (Token *tok = active_toks_[i].toks; tok != NULL;) {
      DeleteForwardLinks(tok);
      Token *next_tok = tok->next;
      token_pool_.Free(tok);
      num_toks_--;
      tok = next_tok;
    }
  }
  active_toks_.clear();
  KALDI_ASSERT(num_toks_ == 0);
}

// static
template <typename FST, typename Token>
void LatticeFasterDecoderTpl<FST, Token>::TopSortTokens(
    Token *tok_list, std::vector<Token*> *topsorted_list) {
  unordered_map<Token*, int32> token2pos;
  typedef typename unordered_map<Token*, int32>::iterator IterType;
  int32 num_toks = 0;
  for (Token *tok = tok_list; tok != NULL; tok = tok->next) num_toks++;
  int32 cur_pos = 0;
  // We assign the tokens numbers num_toks - 1, ... , 2, 1, 0.
  // This is likely to be in closer to topological order than
  // if we had given them ascending order, because of the way
  // new tokens are put at the front of the list.
  for (Token *tok = tok_list; tok != NULL; tok = tok->next)
    token2pos[tok] = num_toks - ++cur_pos;

  unordered_set<Token*> reprocess;

  for (IterType iter = token2pos.begin(); iter != token2pos.end(); ++iter) {
    Token *tok = iter->first;
    int32 pos = iter->second;
    for (ForwardLinkT *link = tok->links; link != NULL; link = link->next) {
      if (link->ilabel == 0) {
        // We only need to consider epsilon links, since non-epsilon links
        // transition between frames and this function only needs to sort a
        // list of tokens from a single frame.
        IterType following_iter = token2pos.find(link->next_tok);
        if (following_iter != token2pos.end()) {  // another token on this frame,
                                                  // so must consider it.
          int32 next_pos = following_iter->second;
          if (next_pos < pos) {  // reassign the position of the next Token.
            following_iter->second = cur_pos++;
            reprocess.insert(link->next_tok);
          }
        }
      }
    }
    // In case we had previously assigned this token to be reprocessed, we can
    // erase it from that set because it's "happy now" (we just processed it).
    reprocess.erase(tok);
  }

  size_t max_loop = 1000000, loop_count;  // max_loop is to detect epsilon cycles.
  for (loop_count = 0; !reprocess.empty() && loop_count < max_loop;
       ++loop_count) {
    std::vector<Token*> reprocess_vec;
    for (typename unordered_set<Token*>::iterator iter = reprocess.begin();
         iter != reprocess.end(); ++iter)
      reprocess_vec.push_back(*iter);
    reprocess.clear();
    for (typename std::vector<Token*>::iterator iter = reprocess_vec.begin();
         iter != reprocess_vec.end(); ++iter) {
      Token *tok = *iter;
      int32 pos = token2pos[tok];
      // Repeat the processing we did above (for comments, see above).
      for (ForwardLinkT *link = tok->links; link != NULL; link = link->next) {
        if (link->ilabel == 0) {
          IterType following_iter = token2pos.find(link->next_tok);
          if (following_iter != token2pos.end()) {
            int32 next_pos = following_iter->second;
            if (next_pos < pos) {
              following_iter->second = cur_pos++;
              reprocess.insert(link->next_tok);
            }
          }
        }
      }
    }
  }
  KALDI_ASSERT(loop_count < max_loop &&
               "Epsilon loops exist in your decoding "
               "graph (this is not allowed!)");

  topsorted_list->clear();
  topsorted_list->resize(cur_pos, NULL);  // create a list with NULLs in between.
  for (IterType iter = token2pos.begin(); iter != token2pos.end(); ++iter)
    (*topsorted_list)[iter->second] = iter->first;
}

// Instantiate the template for the combination of token types and FST types
// that we'll need.
template class LatticeFasterDecoderTpl<fst::Fst<fst::StdArc>,
                                       decoder::StdToken>;
template class LatticeFasterDecoderTpl<fst::VectorFst<fst::StdArc>,
                                       decoder::StdToken>;
template class LatticeFasterDecoderTpl<fst::ConstFst<fst::StdArc>,
                                       decoder::StdToken>;
template class LatticeFasterDecoderTpl<fst::ConstGrammarFst,
                                       decoder::StdToken>;
template class LatticeFasterDecoderTpl<fst::VectorGrammarFst,
                                       decoder::StdToken>;

template class LatticeFasterDecoderTpl<fst::Fst<fst::StdArc>,
                                       decoder::BackpointerToken>;
template class LatticeFasterDecoderTpl<fst::VectorFst<fst::StdArc>,
                                       decoder::BackpointerToken>;
template class LatticeFasterDecoderTpl<fst::ConstFst<fst::StdArc>,
                                       decoder::BackpointerToken>;
template class LatticeFasterDecoderTpl<fst::ConstGrammarFst,
                                       decoder::BackpointerToken>;
template class LatticeFasterDecoderTpl<fst::VectorGrammarFst,
                                       decoder::BackpointerToken>;

}  // end namespace kaldi.
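For orientation, a hedged sketch (not part of this commit) of the call sequence these methods are designed for, in the style of Kaldi decoding programs; `hclg` and `decodable` are placeholders standing in for a decoding graph and an acoustic-score source, and the beam values are illustrative only.

// Hypothetical driver: the InitDecoding / AdvanceDecoding / FinalizeDecoding
// cycle, using the instantiation over the generic fst::Fst<fst::StdArc>
// declared above.
#include "decoder/lattice-faster-decoder.h"

bool DecodeUtterance(const fst::Fst<fst::StdArc> &hclg,
                     kaldi::DecodableInterface *decodable,
                     kaldi::Lattice *best_path) {
  kaldi::LatticeFasterDecoderConfig config;
  config.beam = 16.0;         // wider beam: slower, more accurate
  config.lattice_beam = 8.0;  // controls lattice depth / pruning

  kaldi::LatticeFasterDecoderTpl<fst::Fst<fst::StdArc>,
                                 kaldi::decoder::StdToken>
      decoder(hclg, config);
  decoder.InitDecoding();
  decoder.AdvanceDecoding(decodable, -1);  // -1: consume all ready frames
  decoder.FinalizeDecoding();              // apply final-probs and prune
  return decoder.GetBestPath(best_path, true);
}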
speechx/speechx/kaldi/decoder/lattice-faster-decoder.h
0 → 100644
浏览文件 @
d14ee800
// decoder/lattice-faster-decoder.h
// Copyright 2009-2013 Microsoft Corporation; Mirko Hannemann;
// 2013-2014 Johns Hopkins University (Author: Daniel Povey)
// 2014 Guoguo Chen
// 2018 Zhehuai Chen
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#ifndef KALDI_DECODER_LATTICE_FASTER_DECODER_H_
#define KALDI_DECODER_LATTICE_FASTER_DECODER_H_
#include "decoder/grammar-fst.h"
#include "fst/fstlib.h"
#include "fst/memory.h"
#include "fstext/fstext-lib.h"
#include "itf/decodable-itf.h"
#include "lat/determinize-lattice-pruned.h"
#include "lat/kaldi-lattice.h"
#include "util/hash-list.h"
#include "util/stl-utils.h"
namespace kaldi {

struct LatticeFasterDecoderConfig {
  BaseFloat beam;
  int32 max_active;
  int32 min_active;
  BaseFloat lattice_beam;
  int32 prune_interval;
  bool determinize_lattice;  // not inspected by this class... used in
                             // command-line program.
  BaseFloat beam_delta;
  BaseFloat hash_ratio;
  // Note: we don't make prune_scale configurable on the command line, it's not
  // a very important parameter.  It affects the algorithm that prunes the
  // tokens as we go.
  BaseFloat prune_scale;

  // Number of elements in the block for Token and ForwardLink memory
  // pool allocation.
  int32 memory_pool_tokens_block_size;
  int32 memory_pool_links_block_size;

  // Most of the options inside det_opts are not actually queried by the
  // LatticeFasterDecoder class itself, but by the code that calls it, for
  // example in the function DecodeUtteranceLatticeFaster.
  fst::DeterminizeLatticePhonePrunedOptions det_opts;

  LatticeFasterDecoderConfig()
      : beam(16.0),
        max_active(std::numeric_limits<int32>::max()),
        min_active(200),
        lattice_beam(10.0),
        prune_interval(25),
        determinize_lattice(true),
        beam_delta(0.5),
        hash_ratio(2.0),
        prune_scale(0.1),
        memory_pool_tokens_block_size(1 << 8),
        memory_pool_links_block_size(1 << 8) {}

  void Register(OptionsItf *opts) {
    det_opts.Register(opts);
    opts->Register("beam", &beam,
                   "Decoding beam.  Larger->slower, more accurate.");
    opts->Register("max-active", &max_active,
                   "Decoder max active states.  Larger->slower; more accurate");
    opts->Register("min-active", &min_active,
                   "Decoder minimum #active states.");
    opts->Register("lattice-beam", &lattice_beam,
                   "Lattice generation beam.  Larger->slower, "
                   "and deeper lattices");
    opts->Register("prune-interval", &prune_interval,
                   "Interval (in frames) at which to prune tokens");
    opts->Register("determinize-lattice", &determinize_lattice,
                   "If true, determinize the lattice (lattice-determinization, "
                   "keeping only best pdf-sequence for each word-sequence).");
    opts->Register("beam-delta", &beam_delta,
                   "Increment used in decoding-- this "
                   "parameter is obscure and relates to a speedup in the way the "
                   "max-active constraint is applied.  Larger is more accurate.");
    opts->Register("hash-ratio", &hash_ratio,
                   "Setting used in decoder to control hash behavior");
    opts->Register("memory-pool-tokens-block-size", &memory_pool_tokens_block_size,
                   "Memory pool block size suggestion for storing tokens (in elements). "
                   "Smaller uses less memory but increases cache misses.");
    opts->Register("memory-pool-links-block-size", &memory_pool_links_block_size,
                   "Memory pool block size suggestion for storing links (in elements). "
                   "Smaller uses less memory but increases cache misses.");
  }
  void Check() const {
    KALDI_ASSERT(beam > 0.0 && max_active > 1 && lattice_beam > 0.0 &&
                 min_active <= max_active && prune_interval > 0 &&
                 beam_delta > 0.0 && hash_ratio >= 1.0 &&
                 prune_scale > 0.0 && prune_scale < 1.0);
  }
};
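
// A minimal usage sketch (illustrative, not part of this header): the beam
// values below are example assumptions, not defaults mandated elsewhere.
//
//   LatticeFasterDecoderConfig config;
//   config.beam = 11.0;          // tighter than the 16.0 default, for speed
//   config.lattice_beam = 6.0;   // shallower lattices than the 10.0 default
//   config.Check();              // asserts the settings are self-consistent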
namespace decoder {
// We will template the decoder on the token type as well as the FST type; this
// is a mechanism so that we can use the same underlying decoder code for
// versions of the decoder that support quickly getting the best path
// (LatticeFasterOnlineDecoder, see lattice-faster-online-decoder.h) and also
// those that do not (LatticeFasterDecoder).

// ForwardLinks are the links from a token to a token on the next frame,
// or sometimes on the current frame (for input-epsilon links).
template <typename Token>
struct ForwardLink {
  using Label = fst::StdArc::Label;

  Token *next_tok;  // the next token [or NULL if represents final-state]
  Label ilabel;  // ilabel on arc
  Label olabel;  // olabel on arc
  BaseFloat graph_cost;  // graph cost of traversing arc (contains LM, etc.)
  BaseFloat acoustic_cost;  // acoustic cost (pre-scaled) of traversing arc
  ForwardLink *next;  // next in singly-linked list of forward arcs (arcs
                      // in the state-level lattice) from a token.
  inline ForwardLink(Token *next_tok, Label ilabel, Label olabel,
                     BaseFloat graph_cost, BaseFloat acoustic_cost,
                     ForwardLink *next):
      next_tok(next_tok), ilabel(ilabel), olabel(olabel),
      graph_cost(graph_cost), acoustic_cost(acoustic_cost),
      next(next) { }
};

struct StdToken {
  using ForwardLinkT = ForwardLink<StdToken>;
  using Token = StdToken;

  // Standard token type for LatticeFasterDecoder.  Each active HCLG
  // (decoding-graph) state on each frame has one token.

  // tot_cost is the total (LM + acoustic) cost from the beginning of the
  // utterance up to this point.  (but see cost_offset_, which is subtracted
  // to keep it in a good numerical range).
  BaseFloat tot_cost;

  // extra_cost is >= 0.  After calling PruneForwardLinks, this equals the
  // minimum difference between the cost of the best path that this link is a
  // part of, and the cost of the absolute best path, under the assumption that
  // any of the currently active states at the decoding front may eventually
  // succeed (e.g. if you were to take the currently active states one by one
  // and compute this difference, and then take the minimum).
  BaseFloat extra_cost;

  // 'links' is the head of singly-linked list of ForwardLinks, which is what we
  // use for lattice generation.
  ForwardLinkT *links;

  //'next' is the next in the singly-linked list of tokens for this frame.
  Token *next;

  // This function does nothing and should be optimized out; it's needed
  // so we can share the regular LatticeFasterDecoderTpl code and the code
  // for LatticeFasterOnlineDecoder that supports fast traceback.
  inline void SetBackpointer(Token *backpointer) { }

  // This constructor just ignores the 'backpointer' argument.  That argument is
  // needed so that we can use the same decoder code for LatticeFasterDecoderTpl
  // and LatticeFasterOnlineDecoderTpl (which needs backpointers to support a
  // fast way to obtain the best path).
  inline StdToken(BaseFloat tot_cost, BaseFloat extra_cost, ForwardLinkT *links,
                  Token *next, Token *backpointer):
      tot_cost(tot_cost), extra_cost(extra_cost), links(links), next(next) { }
};

struct BackpointerToken {
  using ForwardLinkT = ForwardLink<BackpointerToken>;
  using Token = BackpointerToken;

  // BackpointerToken is like Token but also
  // Standard token type for LatticeFasterDecoder.  Each active HCLG
  // (decoding-graph) state on each frame has one token.

  // tot_cost is the total (LM + acoustic) cost from the beginning of the
  // utterance up to this point.  (but see cost_offset_, which is subtracted
  // to keep it in a good numerical range).
  BaseFloat tot_cost;

  // extra_cost is >= 0.  After calling PruneForwardLinks, this equals the
  // minimum difference between the cost of the best path that this token is
  // on, and the cost of the absolute best path, under the assumption
  // that any of the currently active states at the decoding front may
  // eventually succeed (e.g. if you were to take the currently active states
  // one by one and compute this difference, and then take the minimum).
  BaseFloat extra_cost;

  // 'links' is the head of singly-linked list of ForwardLinks, which is what we
  // use for lattice generation.
  ForwardLinkT *links;

  //'next' is the next in the singly-linked list of tokens for this frame.
  BackpointerToken *next;

  // Best preceding BackpointerToken (could be a token on this frame, connected
  // to this via an epsilon transition, or on a previous frame).  This is only
  // required for an efficient GetBestPath function in
  // LatticeFasterOnlineDecoderTpl; it plays no part in the lattice generation
  // (the "links" list is what stores the forward links, for that).
  Token *backpointer;

  inline void SetBackpointer(Token *backpointer) {
    this->backpointer = backpointer;
  }

  inline BackpointerToken(BaseFloat tot_cost, BaseFloat extra_cost,
                          ForwardLinkT *links, Token *next,
                          Token *backpointer):
      tot_cost(tot_cost), extra_cost(extra_cost), links(links), next(next),
      backpointer(backpointer) { }
};

}  // namespace decoder
/** This is the "normal" lattice-generating decoder.
    See \ref lattices_generation \ref decoders_faster and \ref decoders_simple
    for more information.

    The decoder is templated on the FST type and the token type.  The token type
    will normally be StdToken, but also may be BackpointerToken which is to support
    quick lookup of the current best path (see lattice-faster-online-decoder.h)

    The FST you invoke this decoder with is expected to be of type
    Fst::Fst<fst::StdArc>, a.k.a. StdFst, or GrammarFst.  If you invoke it with
    FST == StdFst and it notices that the actual FST type is
    fst::VectorFst<fst::StdArc> or fst::ConstFst<fst::StdArc>, the decoder object
    will internally cast itself to one that is templated on those more specific
    types; this is an optimization for speed.
 */
template <typename FST, typename Token = decoder::StdToken>
class LatticeFasterDecoderTpl {
 public:
  using Arc = typename FST::Arc;
  using Label = typename Arc::Label;
  using StateId = typename Arc::StateId;
  using Weight = typename Arc::Weight;
  using ForwardLinkT = decoder::ForwardLink<Token>;

  // Instantiate this class once for each thing you have to decode.
  // This version of the constructor does not take ownership of
  // 'fst'.
  LatticeFasterDecoderTpl(const FST &fst,
                          const LatticeFasterDecoderConfig &config);

  // This version of the constructor takes ownership of the fst, and will delete
  // it when this object is destroyed.
  LatticeFasterDecoderTpl(const LatticeFasterDecoderConfig &config,
                          FST *fst);

  void SetOptions(const LatticeFasterDecoderConfig &config) {
    config_ = config;
  }

  const LatticeFasterDecoderConfig &GetOptions() const {
    return config_;
  }

  ~LatticeFasterDecoderTpl();

  /// Decodes until there are no more frames left in the "decodable" object.
  /// note, this may block waiting for input if the "decodable" object blocks.
  /// Returns true if any kind of traceback is available (not necessarily from a
  /// final state).
  bool Decode(DecodableInterface *decodable);

  /// says whether a final-state was active on the last frame.  If it was not, the
  /// lattice (or traceback) will end with states that are not final-states.
  bool ReachedFinal() const {
    return FinalRelativeCost() != std::numeric_limits<BaseFloat>::infinity();
  }

  /// Outputs an FST corresponding to the single best path through the lattice.
  /// Returns true if result is nonempty (using the return status is deprecated,
  /// it will become void).  If "use_final_probs" is true AND we reached the
  /// final-state of the graph then it will include those as final-probs, else
  /// it will treat all final-probs as one.  Note: this just calls GetRawLattice()
  /// and figures out the shortest path.
  bool GetBestPath(Lattice *ofst,
                   bool use_final_probs = true) const;

  /// Outputs an FST corresponding to the raw, state-level
  /// tracebacks.  Returns true if result is nonempty.
  /// If "use_final_probs" is true AND we reached the final-state
  /// of the graph then it will include those as final-probs, else
  /// it will treat all final-probs as one.
  /// The raw lattice will be topologically sorted.
  ///
  /// See also GetRawLatticePruned in lattice-faster-online-decoder.h,
  /// which also supports a pruning beam, in case for some reason
  /// you want it pruned tighter than the regular lattice beam.
  /// We could put that here in future if needed.
  bool GetRawLattice(Lattice *ofst,
                     bool use_final_probs = true) const;

  /// [Deprecated, users should now use GetRawLattice and determinize it
  /// themselves, e.g. using DeterminizeLatticePhonePrunedWrapper].
  /// Outputs an FST corresponding to the lattice-determinized
  /// lattice (one path per word sequence).  Returns true if result is nonempty.
  /// If "use_final_probs" is true AND we reached the final-state of the graph
  /// then it will include those as final-probs, else it will treat all
  /// final-probs as one.
  bool GetLattice(CompactLattice *ofst,
                  bool use_final_probs = true) const;

  /// InitDecoding initializes the decoding, and should only be used if you
  /// intend to call AdvanceDecoding().  If you call Decode(), you don't need to
  /// call this.  You can also call InitDecoding if you have already decoded an
  /// utterance and want to start with a new utterance.
  void InitDecoding();

  /// This will decode until there are no more frames ready in the decodable
  /// object.  You can keep calling it each time more frames become available.
  /// If max_num_frames is specified, it specifies the maximum number of frames
  /// the function will decode before returning.
  void AdvanceDecoding(DecodableInterface *decodable,
                       int32 max_num_frames = -1);

  /// This function may be optionally called after AdvanceDecoding(), when you
  /// do not plan to decode any further.  It does an extra pruning step that
  /// will help to prune the lattices output by GetLattice and (particularly)
  /// GetRawLattice more completely, particularly toward the end of the
  /// utterance.  If you call this, you cannot call AdvanceDecoding again (it
  /// will fail), and you cannot call GetLattice() and related functions with
  /// use_final_probs = false.  Used to be called PruneActiveTokensFinal().
  void FinalizeDecoding();

  /// FinalRelativeCost() serves the same purpose as ReachedFinal(), but gives
  /// more information.  It returns the difference between the best (final-cost
  /// plus cost) of any token on the final frame, and the best cost of any token
  /// on the final frame.  If it is infinity it means no final-states were
  /// present on the final frame.  It will usually be nonnegative.  If it is not
  /// too positive (e.g. < 5 is my first guess, but this is not tested) you can
  /// take it as a good indication that we reached the final-state with
  /// reasonable likelihood.
  BaseFloat FinalRelativeCost() const;

  // Returns the number of frames decoded so far.  The value returned changes
  // whenever we call ProcessEmitting().
  inline int32 NumFramesDecoded() const { return active_toks_.size() - 1; }

 protected:
  // we make things protected instead of private, as code in
  // LatticeFasterOnlineDecoderTpl, which inherits from this, also uses the
  // internals.

  // Deletes the elements of the singly linked list tok->links.
  void DeleteForwardLinks(Token *tok);

  // head of per-frame list of Tokens (list is in topological order),
  // and something saying whether we ever pruned it using PruneForwardLinks.
  struct TokenList {
    Token *toks;
    bool must_prune_forward_links;
    bool must_prune_tokens;
    TokenList(): toks(NULL), must_prune_forward_links(true),
                 must_prune_tokens(true) { }
  };

  using Elem = typename HashList<StateId, Token*>::Elem;
  // Equivalent to:
  //  struct Elem {
  //    StateId key;
  //    Token *val;
  //    Elem *tail;
  //  };

  void PossiblyResizeHash(size_t num_toks);

  // FindOrAddToken either locates a token in hash of toks_, or if necessary
  // inserts a new, empty token (i.e. with no forward links) for the current
  // frame.  [note: it's inserted if necessary into hash toks_ and also into the
  // singly linked list of tokens active on this frame (whose head is at
  // active_toks_[frame]).  The frame_plus_one argument is the acoustic frame
  // index plus one, which is used to index into the active_toks_ array.
  // Returns the Token pointer.  Sets "changed" (if non-NULL) to true if the
  // token was newly created or the cost changed.
  // If Token == StdToken, the 'backpointer' argument has no purpose (and will
  // hopefully be optimized out).
  inline Elem *FindOrAddToken(StateId state, int32 frame_plus_one,
                              BaseFloat tot_cost, Token *backpointer,
                              bool *changed);

  // prunes outgoing links for all tokens in active_toks_[frame]
  // it's called by PruneActiveTokens
  // all links, that have link_extra_cost > lattice_beam are pruned
  // delta is the amount by which the extra_costs must change
  // before we set *extra_costs_changed = true.
  // If delta is larger, we'll tend to go back less far
  //    toward the beginning of the file.
  // extra_costs_changed is set to true if extra_cost was changed for any token
  // links_pruned is set to true if any link in any token was pruned
  void PruneForwardLinks(int32 frame_plus_one, bool *extra_costs_changed,
                         bool *links_pruned, BaseFloat delta);

  // This function computes the final-costs for tokens active on the final
  // frame.  It outputs to final-costs, if non-NULL, a map from the Token*
  // pointer to the final-prob of the corresponding state, for all Tokens
  // that correspond to states that have final-probs.  This map will be
  // empty if there were no final-probs.  It outputs to
  // final_relative_cost, if non-NULL, the difference between the best
  // forward-cost including the final-prob cost, and the best forward-cost
  // without including the final-prob cost (this will usually be positive), or
  // infinity if there were no final-probs.  [c.f. FinalRelativeCost(), which
  // outputs this quantity].  It outputs to final_best_cost, if
  // non-NULL, the lowest for any token t active on the final frame, of
  // forward-cost[t] + final-cost[t], where final-cost[t] is the final-cost in
  // the graph of the state corresponding to token t, or the best of
  // forward-cost[t] if there were no final-probs active on the final frame.
  // You cannot call this after FinalizeDecoding() has been called; in that
  // case you should get the answer from class-member variables.
  void ComputeFinalCosts(unordered_map<Token*, BaseFloat> *final_costs,
                         BaseFloat *final_relative_cost,
                         BaseFloat *final_best_cost) const;

  // PruneForwardLinksFinal is a version of PruneForwardLinks that we call
  // on the final frame.  If there are final tokens active, it uses
  // the final-probs for pruning, otherwise it treats all tokens as final.
  void PruneForwardLinksFinal();

  // Prune away any tokens on this frame that have no forward links.
  // [we don't do this in PruneForwardLinks because it would give us
  // a problem with dangling pointers].
  // It's called by PruneActiveTokens if any forward links have been pruned
  void PruneTokensForFrame(int32 frame_plus_one);

  // Go backwards through still-alive tokens, pruning them if the
  // forward+backward cost is more than lat_beam away from the best path.  It's
  // possible to prove that this is "correct" in the sense that we won't lose
  // anything outside of lat_beam, regardless of what happens in the future.
  // delta controls when it considers a cost to have changed enough to continue
  // going backward and propagating the change.  larger delta -> will recurse
  // less far.
  void PruneActiveTokens(BaseFloat delta);

  /// Gets the weight cutoff.  Also counts the active tokens.
  BaseFloat GetCutoff(Elem *list_head, size_t *tok_count,
                      BaseFloat *adaptive_beam, Elem **best_elem);

  /// Processes emitting arcs for one frame.  Propagates from prev_toks_ to
  /// cur_toks_.  Returns the cost cutoff for subsequent ProcessNonemitting() to
  /// use.
  BaseFloat ProcessEmitting(DecodableInterface *decodable);

  /// Processes nonemitting (epsilon) arcs for one frame.  Called after
  /// ProcessEmitting() on each frame.  The cost cutoff is computed by the
  /// preceding ProcessEmitting().
  void ProcessNonemitting(BaseFloat cost_cutoff);

  // HashList defined in ../util/hash-list.h.  It actually allows us to maintain
  // more than one list (e.g. for current and previous frames), but only one of
  // them at a time can be indexed by StateId.  It is indexed by frame-index
  // plus one, where the frame-index is zero-based, as used in decodable object.
  // That is, the emitting probs of frame t are accounted for in tokens at
  // toks_[t+1].  The zeroth frame is for nonemitting transition at the start of
  // the graph.
  HashList<StateId, Token*> toks_;

  std::vector<TokenList> active_toks_;  // Lists of tokens, indexed by
  // frame (members of TokenList are toks, must_prune_forward_links,
  // must_prune_tokens).

  std::vector<const Elem*> queue_;  // temp variable used in ProcessNonemitting,
  std::vector<BaseFloat> tmp_array_;  // used in GetCutoff.

  // fst_ is a pointer to the FST we are decoding from.
  const FST *fst_;
  // delete_fst_ is true if the pointer fst_ needs to be deleted when this
  // object is destroyed.
  bool delete_fst_;

  std::vector<BaseFloat> cost_offsets_;  // This contains, for each
  // frame, an offset that was added to the acoustic log-likelihoods on that
  // frame in order to keep everything in a nice dynamic range i.e. close to
  // zero, to reduce roundoff errors.
  LatticeFasterDecoderConfig config_;
  int32 num_toks_;  // current total #toks allocated...
  bool warned_;

  /// decoding_finalized_ is true if someone called FinalizeDecoding().  [note,
  /// calling this is optional].  If true, it's forbidden to decode more.  Also,
  /// if this is set, then the output of ComputeFinalCosts() is in the next
  /// three variables.  The reason we need to do this is that after
  /// FinalizeDecoding() calls PruneTokensForFrame() for the final frame, some
  /// of the tokens on the last frame are freed, so we free the list from toks_
  /// to avoid having dangling pointers hanging around.
  bool decoding_finalized_;
  /// For the meaning of the next 3 variables, see the comment for
  /// decoding_finalized_ above, and ComputeFinalCosts().
  unordered_map<Token*, BaseFloat> final_costs_;
  BaseFloat final_relative_cost_;
  BaseFloat final_best_cost_;

  // Memory pools for storing tokens and forward links.
  // We use it to decrease the work put on allocator and to move some of data
  // together.  Too small block sizes will result in more work to allocator but
  // bigger ones increase the memory usage.
  fst::MemoryPool<Token> token_pool_;
  fst::MemoryPool<ForwardLinkT> forward_link_pool_;

  // There are various cleanup tasks... the toks_ structure contains
  // singly linked lists of Token pointers, where Elem is the list type.
  // It also indexes them in a hash, indexed by state (this hash is only
  // maintained for the most recent frame).  toks_.Clear()
  // deletes them from the hash and returns the list of Elems.  The
  // function DeleteElems calls toks_.Delete(elem) for each elem in
  // the list, which returns ownership of the Elem to the toks_ structure
  // for reuse, but does not delete the Token pointer.  The Token pointers
  // are reference-counted and are ultimately deleted in PruneTokensForFrame,
  // but are also linked together on each frame by their own linked-list,
  // using the "next" pointer.  We delete them manually.
  void DeleteElems(Elem *list);

  // This function takes a singly linked list of tokens for a single frame, and
  // outputs a list of them in topological order (it will crash if no such order
  // can be found, which will typically be due to decoding graphs with epsilon
  // cycles, which are not allowed).  Note: the output list may contain NULLs,
  // which the caller should pass over; it just happens to be more efficient for
  // the algorithm to output a list that contains NULLs.
  static void TopSortTokens(Token *tok_list,
                            std::vector<Token*> *topsorted_list);

  void ClearActiveTokens();

  KALDI_DISALLOW_COPY_AND_ASSIGN(LatticeFasterDecoderTpl);
};
typedef LatticeFasterDecoderTpl<fst::StdFst, decoder::StdToken>
    LatticeFasterDecoder;

}  // end namespace kaldi.
#endif
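
A minimal offline-usage sketch (not part of the commit): `hclg_fst` and `decodable` are assumed inputs (any fst::StdFst decoding graph and any DecodableInterface implementation); the config uses the defaults from the header above.

#include "decoder/lattice-faster-decoder.h"

// Sketch only: decode one utterance and extract the best path as a linear FST.
kaldi::Lattice DecodeUtteranceSketch(const fst::StdFst &hclg_fst,
                                     kaldi::DecodableInterface *decodable) {
  kaldi::LatticeFasterDecoderConfig config;  // defaults: beam 16.0, lattice-beam 10.0
  kaldi::LatticeFasterDecoder decoder(hclg_fst, config);
  kaldi::Lattice best_path;
  if (decoder.Decode(decodable) && decoder.ReachedFinal())
    decoder.GetBestPath(&best_path, /*use_final_probs=*/true);
  return best_path;  // empty if decoding produced no traceback
}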
speechx/speechx/kaldi/decoder/lattice-faster-online-decoder.cc
0 → 100644
// decoder/lattice-faster-online-decoder.cc
// Copyright 2009-2012 Microsoft Corporation Mirko Hannemann
// 2013-2014 Johns Hopkins University (Author: Daniel Povey)
// 2014 Guoguo Chen
// 2014 IMSL, PKU-HKUST (author: Wei Shi)
// 2018 Zhehuai Chen
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
// see note at the top of lattice-faster-decoder.cc, about how to maintain this
// file in sync with lattice-faster-decoder.cc
#include "decoder/lattice-faster-online-decoder.h"
#include "lat/lattice-functions.h"
namespace kaldi {

template <typename FST>
bool LatticeFasterOnlineDecoderTpl<FST>::TestGetBestPath(
    bool use_final_probs) const {
  Lattice lat1;
  {
    Lattice raw_lat;
    this->GetRawLattice(&raw_lat, use_final_probs);
    ShortestPath(raw_lat, &lat1);
  }
  Lattice lat2;
  GetBestPath(&lat2, use_final_probs);
  BaseFloat delta = 0.1;
  int32 num_paths = 1;
  if (!fst::RandEquivalent(lat1, lat2, num_paths, delta, rand())) {
    KALDI_WARN << "Best-path test failed";
    return false;
  } else {
    return true;
  }
}

// Outputs an FST corresponding to the single best path through the lattice.
template <typename FST>
bool LatticeFasterOnlineDecoderTpl<FST>::GetBestPath(
    Lattice *olat, bool use_final_probs) const {
  olat->DeleteStates();
  BaseFloat final_graph_cost;
  BestPathIterator iter = BestPathEnd(use_final_probs, &final_graph_cost);
  if (iter.Done())
    return false;  // would have printed warning.
  StateId state = olat->AddState();
  olat->SetFinal(state, LatticeWeight(final_graph_cost, 0.0));
  while (!iter.Done()) {
    LatticeArc arc;
    iter = TraceBackBestPath(iter, &arc);
    arc.nextstate = state;
    StateId new_state = olat->AddState();
    olat->AddArc(new_state, arc);
    state = new_state;
  }
  olat->SetStart(state);
  return true;
}

template <typename FST>
typename LatticeFasterOnlineDecoderTpl<FST>::BestPathIterator
LatticeFasterOnlineDecoderTpl<FST>::BestPathEnd(
    bool use_final_probs, BaseFloat *final_cost_out) const {
  if (this->decoding_finalized_ && !use_final_probs)
    KALDI_ERR << "You cannot call FinalizeDecoding() and then call "
              << "BestPathEnd() with use_final_probs == false";
  KALDI_ASSERT(this->NumFramesDecoded() > 0 &&
               "You cannot call BestPathEnd if no frames were decoded.");

  unordered_map<Token*, BaseFloat> final_costs_local;

  const unordered_map<Token*, BaseFloat> &final_costs =
      (this->decoding_finalized_ ? this->final_costs_ : final_costs_local);
  if (!this->decoding_finalized_ && use_final_probs)
    this->ComputeFinalCosts(&final_costs_local, NULL, NULL);

  // Singly linked list of tokens on last frame (access list through "next"
  // pointer).
  BaseFloat best_cost = std::numeric_limits<BaseFloat>::infinity();
  BaseFloat best_final_cost = 0;
  Token *best_tok = NULL;
  for (Token *tok = this->active_toks_.back().toks;
       tok != NULL; tok = tok->next) {
    BaseFloat cost = tok->tot_cost, final_cost = 0.0;
    if (use_final_probs && !final_costs.empty()) {
      // if we are instructed to use final-probs, and any final tokens were
      // active on final frame, include the final-prob in the cost of the token.
      typename unordered_map<Token*, BaseFloat>::const_iterator
          iter = final_costs.find(tok);
      if (iter != final_costs.end()) {
        final_cost = iter->second;
        cost += final_cost;
      } else {
        cost = std::numeric_limits<BaseFloat>::infinity();
      }
    }
    if (cost < best_cost) {
      best_cost = cost;
      best_tok = tok;
      best_final_cost = final_cost;
    }
  }
  if (best_tok == NULL) {  // this should not happen, and is likely a code error
    // or caused by infinities in likelihoods, but I'm not making
    // it a fatal error for now.
    KALDI_WARN << "No final token found.";
  }
  if (final_cost_out)
    *final_cost_out = best_final_cost;
  return BestPathIterator(best_tok, this->NumFramesDecoded() - 1);
}

template <typename FST>
typename LatticeFasterOnlineDecoderTpl<FST>::BestPathIterator
LatticeFasterOnlineDecoderTpl<FST>::TraceBackBestPath(
    BestPathIterator iter, LatticeArc *oarc) const {
  KALDI_ASSERT(!iter.Done() && oarc != NULL);
  Token *tok = static_cast<Token*>(iter.tok);
  int32 cur_t = iter.frame, step_t = 0;
  if (tok->backpointer != NULL) {
    // retrieve the correct forward link (with the best link cost)
    BaseFloat best_cost = std::numeric_limits<BaseFloat>::infinity();
    ForwardLinkT *link;
    for (link = tok->backpointer->links;
         link != NULL; link = link->next) {
      if (link->next_tok == tok) {  // this is a link to "tok"
        BaseFloat graph_cost = link->graph_cost,
            acoustic_cost = link->acoustic_cost;
        BaseFloat cost = graph_cost + acoustic_cost;
        if (cost < best_cost) {
          oarc->ilabel = link->ilabel;
          oarc->olabel = link->olabel;
          if (link->ilabel != 0) {
            KALDI_ASSERT(static_cast<size_t>(cur_t) <
                         this->cost_offsets_.size());
            acoustic_cost -= this->cost_offsets_[cur_t];
            step_t = -1;
          } else {
            step_t = 0;
          }
          oarc->weight = LatticeWeight(graph_cost, acoustic_cost);
          best_cost = cost;
        }
      }
    }
    if (link == NULL &&
        best_cost == std::numeric_limits<BaseFloat>::infinity()) {
      // Did not find correct link.
      KALDI_ERR << "Error tracing best-path back (likely "
                << "bug in token-pruning algorithm)";
    }
  } else {
    oarc->ilabel = 0;
    oarc->olabel = 0;
    oarc->weight = LatticeWeight::One();  // zero costs.
  }
  return BestPathIterator(tok->backpointer, cur_t + step_t);
}

template <typename FST>
bool LatticeFasterOnlineDecoderTpl<FST>::GetRawLatticePruned(
    Lattice *ofst, bool use_final_probs, BaseFloat beam) const {
  typedef LatticeArc Arc;
  typedef Arc::StateId StateId;
  typedef Arc::Weight Weight;
  typedef Arc::Label Label;

  // Note: you can't use the old interface (Decode()) if you want to
  // get the lattice with use_final_probs = false.  You'd have to do
  // InitDecoding() and then AdvanceDecoding().
  if (this->decoding_finalized_ && !use_final_probs)
    KALDI_ERR << "You cannot call FinalizeDecoding() and then call "
              << "GetRawLattice() with use_final_probs == false";

  unordered_map<Token*, BaseFloat> final_costs_local;

  const unordered_map<Token*, BaseFloat> &final_costs =
      (this->decoding_finalized_ ? this->final_costs_ : final_costs_local);
  if (!this->decoding_finalized_ && use_final_probs)
    this->ComputeFinalCosts(&final_costs_local, NULL, NULL);

  ofst->DeleteStates();
  // num-frames plus one (since frames are one-based, and we have
  // an extra frame for the start-state).
  int32 num_frames = this->active_toks_.size() - 1;
  KALDI_ASSERT(num_frames > 0);
  for (int32 f = 0; f <= num_frames; f++) {
    if (this->active_toks_[f].toks == NULL) {
      KALDI_WARN << "No tokens active on frame " << f
                 << ": not producing lattice.\n";
      return false;
    }
  }
  unordered_map<Token*, StateId> tok_map;
  std::queue<std::pair<Token*, int32> > tok_queue;
  // First initialize the queue and states.  Put the initial state on the queue;
  // this is the last token in the list active_toks_[0].toks.
  for (Token *tok = this->active_toks_[0].toks;
       tok != NULL; tok = tok->next) {
    if (tok->next == NULL) {
      tok_map[tok] = ofst->AddState();
      ofst->SetStart(tok_map[tok]);
      std::pair<Token*, int32> tok_pair(tok, 0);  // #frame = 0
      tok_queue.push(tok_pair);
    }
  }

  // Next create states for "good" tokens
  while (!tok_queue.empty()) {
    std::pair<Token*, int32> cur_tok_pair = tok_queue.front();
    tok_queue.pop();
    Token *cur_tok = cur_tok_pair.first;
    int32 cur_frame = cur_tok_pair.second;
    KALDI_ASSERT(cur_frame >= 0 &&
                 cur_frame <= this->cost_offsets_.size());

    typename unordered_map<Token*, StateId>::const_iterator iter =
        tok_map.find(cur_tok);
    KALDI_ASSERT(iter != tok_map.end());
    StateId cur_state = iter->second;

    for (ForwardLinkT *l = cur_tok->links;
         l != NULL; l = l->next) {
      Token *next_tok = l->next_tok;
      if (next_tok->extra_cost < beam) {
        // so both the current and the next token are good; create the arc
        int32 next_frame = l->ilabel == 0 ? cur_frame : cur_frame + 1;
        StateId nextstate;
        if (tok_map.find(next_tok) == tok_map.end()) {
          nextstate = tok_map[next_tok] = ofst->AddState();
          tok_queue.push(std::pair<Token*, int32>(next_tok, next_frame));
        } else {
          nextstate = tok_map[next_tok];
        }
        BaseFloat cost_offset = (l->ilabel != 0 ?
                                 this->cost_offsets_[cur_frame] : 0);
        Arc arc(l->ilabel, l->olabel,
                Weight(l->graph_cost, l->acoustic_cost - cost_offset),
                nextstate);
        ofst->AddArc(cur_state, arc);
      }
    }
    if (cur_frame == num_frames) {
      if (use_final_probs && !final_costs.empty()) {
        typename unordered_map<Token*, BaseFloat>::const_iterator iter =
            final_costs.find(cur_tok);
        if (iter != final_costs.end())
          ofst->SetFinal(cur_state, LatticeWeight(iter->second, 0));
      } else {
        ofst->SetFinal(cur_state, LatticeWeight::One());
      }
    }
  }
  return (ofst->NumStates() != 0);
}

// Instantiate the template for the FST types that we'll need.
template class LatticeFasterOnlineDecoderTpl<fst::Fst<fst::StdArc> >;
template class LatticeFasterOnlineDecoderTpl<fst::VectorFst<fst::StdArc> >;
template class LatticeFasterOnlineDecoderTpl<fst::ConstFst<fst::StdArc> >;
template class LatticeFasterOnlineDecoderTpl<fst::ConstGrammarFst>;
template class LatticeFasterOnlineDecoderTpl<fst::VectorGrammarFst>;

}  // end namespace kaldi.
speechx/speechx/kaldi/decoder/lattice-faster-online-decoder.h
0 → 100644
// decoder/lattice-faster-online-decoder.h
// Copyright 2009-2013 Microsoft Corporation; Mirko Hannemann;
// 2013-2014 Johns Hopkins University (Author: Daniel Povey)
// 2014 Guoguo Chen
// 2018 Zhehuai Chen
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
// see note at the top of lattice-faster-decoder.h, about how to maintain this
// file in sync with lattice-faster-decoder.h
#ifndef KALDI_DECODER_LATTICE_FASTER_ONLINE_DECODER_H_
#define KALDI_DECODER_LATTICE_FASTER_ONLINE_DECODER_H_
#include "util/stl-utils.h"
#include "util/hash-list.h"
#include "fst/fstlib.h"
#include "itf/decodable-itf.h"
#include "fstext/fstext-lib.h"
#include "lat/determinize-lattice-pruned.h"
#include "lat/kaldi-lattice.h"
#include "decoder/lattice-faster-decoder.h"
namespace kaldi {
/** LatticeFasterOnlineDecoderTpl is as LatticeFasterDecoderTpl but also
supports an efficient way to get the best path (see the function
BestPathEnd()), which is useful in endpointing and in situations where you
might want to frequently access the best path.
This is only templated on the FST type, since the Token type is required to
be BackpointerToken. Actually it only makes sense to instantiate
LatticeFasterDecoderTpl with Token == BackpointerToken if you do so indirectly via
this child class.
*/
template <typename FST>
class LatticeFasterOnlineDecoderTpl:
      public LatticeFasterDecoderTpl<FST, decoder::BackpointerToken> {
 public:
  using Arc = typename FST::Arc;
  using Label = typename Arc::Label;
  using StateId = typename Arc::StateId;
  using Weight = typename Arc::Weight;
  using Token = decoder::BackpointerToken;
  using ForwardLinkT = decoder::ForwardLink<Token>;

  // Instantiate this class once for each thing you have to decode.
  // This version of the constructor does not take ownership of
  // 'fst'.
  LatticeFasterOnlineDecoderTpl(const FST &fst,
                                const LatticeFasterDecoderConfig &config):
      LatticeFasterDecoderTpl<FST, Token>(fst, config) { }

  // This version of the initializer takes ownership of 'fst', and will delete
  // it when this object is destroyed.
  LatticeFasterOnlineDecoderTpl(const LatticeFasterDecoderConfig &config,
                                FST *fst):
      LatticeFasterDecoderTpl<FST, Token>(config, fst) { }

  struct BestPathIterator {
    void *tok;
    int32 frame;
    // note, "frame" is the frame-index of the frame you'll get the
    // transition-id for next time, if you call TraceBackBestPath on this
    // iterator (assuming it's not an epsilon transition).  Note that this
    // is one less than you might reasonably expect, e.g. it's -1 for
    // the nonemitting transitions before the first frame.
    BestPathIterator(void *t, int32 f): tok(t), frame(f) { }
    bool Done() const { return tok == NULL; }
  };

  /// Outputs an FST corresponding to the single best path through the lattice.
  /// This is quite efficient because it doesn't get the entire raw lattice and
  /// find the best path through it; instead, it uses the BestPathEnd and
  /// BestPathIterator so it basically traces it back through the lattice.
  /// Returns true if result is nonempty (using the return status is deprecated,
  /// it will become void).  If "use_final_probs" is true AND we reached the
  /// final-state of the graph then it will include those as final-probs, else
  /// it will treat all final-probs as one.
  bool GetBestPath(Lattice *ofst,
                   bool use_final_probs = true) const;

  /// This function does a self-test of GetBestPath().  Returns true on
  /// success; returns false and prints a warning on failure.
  bool TestGetBestPath(bool use_final_probs = true) const;

  /// This function returns an iterator that can be used to trace back
  /// the best path.  If use_final_probs == true and at least one final state
  /// survived till the end, it will use the final-probs in working out the best
  /// final Token, and will output the final cost to *final_cost (if non-NULL),
  /// else it will use only the forward likelihood, and will put zero in
  /// *final_cost (if non-NULL).
  /// Requires that NumFramesDecoded() > 0.
  BestPathIterator BestPathEnd(bool use_final_probs,
                               BaseFloat *final_cost = NULL) const;

  /// This function can be used in conjunction with BestPathEnd() to trace back
  /// the best path one link at a time (e.g. this can be useful in endpoint
  /// detection).  By "link" we mean a link in the graph; not all links cross
  /// frame boundaries, but each time you see a nonzero ilabel you can interpret
  /// that as a frame.  The return value is the updated iterator.  It outputs
  /// the ilabel and olabel, and the (graph and acoustic) weight to the "arc"
  /// pointer, while leaving its "nextstate" variable unchanged.
  BestPathIterator TraceBackBestPath(BestPathIterator iter,
                                     LatticeArc *arc) const;

  /// Behaves the same as GetRawLattice but only processes tokens whose
  /// extra_cost is smaller than the best-cost plus the specified beam.
  /// It is only worthwhile to call this function if beam is less than
  /// the lattice_beam specified in the config; otherwise, it would
  /// return essentially the same thing as GetRawLattice, but more slowly.
  bool GetRawLatticePruned(Lattice *ofst,
                           bool use_final_probs,
                           BaseFloat beam) const;

  KALDI_DISALLOW_COPY_AND_ASSIGN(LatticeFasterOnlineDecoderTpl);
};

typedef LatticeFasterOnlineDecoderTpl<fst::StdFst> LatticeFasterOnlineDecoder;

}  // end namespace kaldi.
#endif
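
A hedged sketch of incremental use of the online decoder: all parameters are assumed inputs, and the single AdvanceDecoding() call stands in for a loop that runs as feature chunks arrive.

#include "decoder/lattice-faster-online-decoder.h"

// Sketch only: advance the decoder as frames arrive and trace back the current
// best hypothesis without finalizing.
void OnlineDecodingSketch(const fst::StdFst &hclg_fst,
                          const kaldi::LatticeFasterDecoderConfig &config,
                          kaldi::DecodableInterface *decodable) {
  kaldi::LatticeFasterOnlineDecoder decoder(hclg_fst, config);
  decoder.InitDecoding();
  decoder.AdvanceDecoding(decodable);  // call again as more frames arrive
  if (decoder.NumFramesDecoded() > 0) {
    kaldi::BaseFloat final_cost;
    kaldi::LatticeFasterOnlineDecoder::BestPathIterator iter =
        decoder.BestPathEnd(/*use_final_probs=*/false, &final_cost);
    while (!iter.Done()) {
      kaldi::LatticeArc arc;
      iter = decoder.TraceBackBestPath(iter, &arc);
      // Each nonzero arc.ilabel corresponds to one decoded frame; arc.olabel
      // is a word id in the graph's output symbol table.
    }
  }
  decoder.FinalizeDecoding();  // optional extra pruning once input is done
}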
speechx/speechx/kaldi/lat/determinize-lattice-pruned-test.cc
0 → 100644
// lat/determinize-lattice-pruned-test.cc
// Copyright 2009-2012 Microsoft Corporation
// 2012-2013 Johns Hopkins University (Author: Daniel Povey)
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#include "lat/determinize-lattice-pruned.h"
#include "fstext/lattice-utils.h"
#include "fstext/fst-test-utils.h"
#include "lat/kaldi-lattice.h"
#include "lat/lattice-functions.h"
namespace fst {
// Caution: these tests are not as generic as you might think from all the
// templates in the code.  They are basically only valid for LatticeArc.
// This is partly due to the fact that certain templates need to be instantiated
// in other .cc files in this directory.

// test that determinization proceeds correctly on general
// FSTs (not guaranteed determinizable, but we use the
// max-states option to stop it getting out of control).
template<class Arc> void TestDeterminizeLatticePruned() {
  typedef kaldi::int32 Int;
  typedef typename Arc::Weight Weight;
  typedef ArcTpl<CompactLatticeWeightTpl<Weight, Int> > CompactArc;

  for(int i = 0; i < 100; i++) {
    RandFstOptions opts;
    opts.n_states = 4;
    opts.n_arcs = 10;
    opts.n_final = 2;
    opts.allow_empty = false;
    opts.weight_multiplier = 0.5;  // impt for the randomly generated weights
    opts.acyclic = true;
    // to be exactly representable in float,
    // or this test fails because numerical differences can cause symmetry in
    // weights to be broken, which causes the wrong path to be chosen as far
    // as the string part is concerned.

    VectorFst<Arc> *fst = RandPairFst<Arc>(opts);

    bool sorted = TopSort(fst);
    KALDI_ASSERT(sorted);

    ILabelCompare<Arc> ilabel_comp;
    if (kaldi::Rand() % 2 == 0)
      ArcSort(fst, ilabel_comp);

    std::cout << "FST before lattice-determinizing is:\n";
    {
      FstPrinter<Arc> fstprinter(*fst, NULL, NULL, NULL, false, true, "\t");
      fstprinter.Print(&std::cout, "standard output");
    }
    VectorFst<Arc> det_fst;
    try {
      DeterminizeLatticePrunedOptions lat_opts;
      lat_opts.max_mem = ((kaldi::Rand() % 2 == 0) ? 100 : 1000);
      lat_opts.max_states = ((kaldi::Rand() % 2 == 0) ? -1 : 20);
      lat_opts.max_arcs = ((kaldi::Rand() % 2 == 0) ? -1 : 30);
      bool ans = DeterminizeLatticePruned<Weight>(*fst, 10.0, &det_fst, lat_opts);

      std::cout << "FST after lattice-determinizing is:\n";
      {
        FstPrinter<Arc> fstprinter(det_fst, NULL, NULL, NULL, false, true, "\t");
        fstprinter.Print(&std::cout, "standard output");
      }
      KALDI_ASSERT(det_fst.Properties(kIDeterministic, true) & kIDeterministic);
      // OK, now determinize it a different way and check equivalence.
      // [note: it's not normal determinization, it's taking the best path
      // for any input-symbol sequence....

      VectorFst<Arc> pruned_fst(*fst);
      if (pruned_fst.NumStates() != 0)
        kaldi::PruneLattice(10.0, &pruned_fst);

      VectorFst<CompactArc> compact_pruned_fst, compact_pruned_det_fst;
      ConvertLattice<Weight, Int>(pruned_fst, &compact_pruned_fst, false);
      std::cout << "Compact pruned FST is:\n";
      {
        FstPrinter<CompactArc> fstprinter(compact_pruned_fst, NULL, NULL, NULL,
                                          false, true, "\t");
        fstprinter.Print(&std::cout, "standard output");
      }
      ConvertLattice<Weight, Int>(det_fst, &compact_pruned_det_fst, false);

      std::cout << "Compact version of determinized FST is:\n";
      {
        FstPrinter<CompactArc> fstprinter(compact_pruned_det_fst, NULL, NULL,
                                          NULL, false, true, "\t");
        fstprinter.Print(&std::cout, "standard output");
      }

      if (ans)
        KALDI_ASSERT(RandEquivalent(compact_pruned_det_fst, compact_pruned_fst,
                                    5/*paths*/, 0.01/*delta*/,
                                    kaldi::Rand()/*seed*/,
                                    100/*path length, max*/));
    } catch (...) {
      std::cout << "Failed to lattice-determinize this FST "
                   "(probably not determinizable)\n";
    }
    delete fst;
  }
}

// test that determinization proceeds without crash on acyclic FSTs
// (guaranteed determinizable in this sense).
template<class Arc> void TestDeterminizeLatticePruned2() {
  typedef typename Arc::Weight Weight;
  RandFstOptions opts;
  opts.acyclic = true;
  for(int i = 0; i < 100; i++) {
    VectorFst<Arc> *fst = RandPairFst<Arc>(opts);
    std::cout << "FST before lattice-determinizing is:\n";
    {
      FstPrinter<Arc> fstprinter(*fst, NULL, NULL, NULL, false, true, "\t");
      fstprinter.Print(&std::cout, "standard output");
    }
    VectorFst<Arc> ofst;
    DeterminizeLatticePruned<Weight>(*fst, 10.0, &ofst);
    std::cout << "FST after lattice-determinizing is:\n";
    {
      FstPrinter<Arc> fstprinter(ofst, NULL, NULL, NULL, false, true, "\t");
      fstprinter.Print(&std::cout, "standard output");
    }
    delete fst;
  }
}

}  // end namespace fst

int main() {
  using namespace fst;
  TestDeterminizeLatticePruned<kaldi::LatticeArc>();
  TestDeterminizeLatticePruned2<kaldi::LatticeArc>();
  std::cout << "Tests succeeded\n";
}
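
A hedged sketch, mirroring the calls exercised by the test above, of pruned determinization applied to a raw decoder lattice: `raw_lat` is an assumed input (e.g. from LatticeFasterDecoder::GetRawLattice()), and the beam and memory limit are illustrative values.

#include "lat/determinize-lattice-pruned.h"
#include "lat/kaldi-lattice.h"

// Sketch only: determinize a raw state-level lattice into one path per word
// sequence.  The input must be topologically sorted, hence the TopSort call.
void DeterminizePrunedSketch(kaldi::Lattice *raw_lat,
                             kaldi::CompactLattice *det_lat) {
  fst::DeterminizeLatticePrunedOptions opts;
  opts.max_mem = 50000000;  // illustrative byte limit on working memory
  fst::TopSort(raw_lat);
  fst::DeterminizeLatticePruned(*raw_lat, 10.0 /*illustrative beam*/,
                                det_lat, opts);
}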
speechx/speechx/kaldi/lat/determinize-lattice-pruned.cc
0 → 100644
// lat/determinize-lattice-pruned.cc
// Copyright 2009-2012 Microsoft Corporation
// 2012-2013 Johns Hopkins University (Author: Daniel Povey)
// 2014 Guoguo Chen
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#include <vector>
#include <climits>
#include "fstext/determinize-lattice.h" // for LatticeStringRepository
#include "fstext/fstext-utils.h"
#include "lat/lattice-functions.h" // for PruneLattice
#include "lat/minimize-lattice.h" // for minimization
#include "lat/push-lattice.h" // for minimization
#include "lat/determinize-lattice-pruned.h"
namespace
fst
{
using
std
::
vector
;
using
std
::
pair
;
using
std
::
greater
;
// class LatticeDeterminizerPruned is templated on the same types that
// CompactLatticeWeight is templated on: the base weight (Weight), typically
// LatticeWeightTpl<float> etc. but could also be e.g. TropicalWeight, and the
// IntType, typically int32, used for the output symbols in the compact
// representation of strings [note: the output symbols would usually be
// p.d.f. id's in the anticipated use of this code] It has a special requirement
// on the Weight type: that there should be a Compare function on the weights
// such that Compare(w1, w2) returns -1 if w1 < w2, 0 if w1 == w2, and +1 if w1 >
// w2. This requires that there be a total order on the weights.
template
<
class
Weight
,
class
IntType
>
class
LatticeDeterminizerPruned
{
public:
// Output to Gallic acceptor (so the strings go on weights, and there is a 1-1 correspondence
// between our states and the states in ofst. If destroy == true, release memory as we go
// (but we cannot output again).
typedef
CompactLatticeWeightTpl
<
Weight
,
IntType
>
CompactWeight
;
typedef
ArcTpl
<
CompactWeight
>
CompactArc
;
// arc in compact, acceptor form of lattice
typedef
ArcTpl
<
Weight
>
Arc
;
// arc in non-compact version of lattice
// Output to standard FST with CompactWeightTpl<Weight> as its weight type (the
// weight stores the original output-symbol strings). If destroy == true,
// release memory as we go (but we cannot output again).
void
Output
(
MutableFst
<
CompactArc
>
*
ofst
,
bool
destroy
=
true
)
{
KALDI_ASSERT
(
determinized_
);
typedef
typename
Arc
::
StateId
StateId
;
StateId
nStates
=
static_cast
<
StateId
>
(
output_states_
.
size
());
if
(
destroy
)
FreeMostMemory
();
ofst
->
DeleteStates
();
ofst
->
SetStart
(
kNoStateId
);
if
(
nStates
==
0
)
{
return
;
}
for
(
StateId
s
=
0
;
s
<
nStates
;
s
++
)
{
OutputStateId
news
=
ofst
->
AddState
();
KALDI_ASSERT
(
news
==
s
);
}
ofst
->
SetStart
(
0
);
// now process transitions.
for
(
StateId
this_state_id
=
0
;
this_state_id
<
nStates
;
this_state_id
++
)
{
OutputState
&
this_state
=
*
(
output_states_
[
this_state_id
]);
vector
<
TempArc
>
&
this_vec
(
this_state
.
arcs
);
typename
vector
<
TempArc
>::
const_iterator
iter
=
this_vec
.
begin
(),
end
=
this_vec
.
end
();
for
(;
iter
!=
end
;
++
iter
)
{
const
TempArc
&
temp_arc
(
*
iter
);
CompactArc
new_arc
;
vector
<
Label
>
olabel_seq
;
repository_
.
ConvertToVector
(
temp_arc
.
string
,
&
olabel_seq
);
CompactWeight
weight
(
temp_arc
.
weight
,
olabel_seq
);
if
(
temp_arc
.
nextstate
==
kNoStateId
)
{
// is really final weight.
ofst
->
SetFinal
(
this_state_id
,
weight
);
}
else
{
// is really an arc.
new_arc
.
nextstate
=
temp_arc
.
nextstate
;
new_arc
.
ilabel
=
temp_arc
.
ilabel
;
new_arc
.
olabel
=
temp_arc
.
ilabel
;
// acceptor. input == output.
new_arc
.
weight
=
weight
;
// includes string and weight.
ofst
->
AddArc
(
this_state_id
,
new_arc
);
}
}
// Free up memory. Do this inside the loop as ofst is also allocating memory,
// and we want to reduce the maximum amount ever allocated.
if
(
destroy
)
{
vector
<
TempArc
>
temp
;
temp
.
swap
(
this_vec
);
}
}
if
(
destroy
)
{
FreeOutputStates
();
repository_
.
Destroy
();
}
}
// Output to standard FST with Weight as its weight type. We will create extra
// states to handle sequences of symbols on the output. If destroy == true,
// release memory as we go (but we cannot output again).
void
Output
(
MutableFst
<
Arc
>
*
ofst
,
bool
destroy
=
true
)
{
// Outputs to standard fst.
OutputStateId
nStates
=
static_cast
<
OutputStateId
>
(
output_states_
.
size
());
ofst
->
DeleteStates
();
if
(
nStates
==
0
)
{
ofst
->
SetStart
(
kNoStateId
);
return
;
}
if
(
destroy
)
FreeMostMemory
();
// Add basic states-- but we will add extra ones to account for strings on output.
for
(
OutputStateId
s
=
0
;
s
<
nStates
;
s
++
)
{
OutputStateId
news
=
ofst
->
AddState
();
KALDI_ASSERT
(
news
==
s
);
}
ofst
->
SetStart
(
0
);
for
(
OutputStateId
this_state_id
=
0
;
this_state_id
<
nStates
;
this_state_id
++
)
{
OutputState
&
this_state
=
*
(
output_states_
[
this_state_id
]);
vector
<
TempArc
>
&
this_vec
(
this_state
.
arcs
);
typename
vector
<
TempArc
>::
const_iterator
iter
=
this_vec
.
begin
(),
end
=
this_vec
.
end
();
for
(;
iter
!=
end
;
++
iter
)
{
const
TempArc
&
temp_arc
(
*
iter
);
vector
<
Label
>
seq
;
repository_
.
ConvertToVector
(
temp_arc
.
string
,
&
seq
);
if
(
temp_arc
.
nextstate
==
kNoStateId
)
{
// Really a final weight.
// Make a sequence of states going to a final state, with the strings
// as labels. Put the weight on the first arc.
OutputStateId
cur_state
=
this_state_id
;
for
(
size_t
i
=
0
;
i
<
seq
.
size
();
i
++
)
{
OutputStateId
next_state
=
ofst
->
AddState
();
Arc
arc
;
arc
.
nextstate
=
next_state
;
arc
.
weight
=
(
i
==
0
?
temp_arc
.
weight
:
Weight
::
One
());
arc
.
ilabel
=
0
;
// epsilon.
arc
.
olabel
=
seq
[
i
];
ofst
->
AddArc
(
cur_state
,
arc
);
cur_state
=
next_state
;
}
ofst
->
SetFinal
(
cur_state
,
(
seq
.
size
()
==
0
?
temp_arc
.
weight
:
Weight
::
One
()));
}
else
{
// Really an arc.
OutputStateId
cur_state
=
this_state_id
;
// Have to be careful with this integer comparison (i+1 < seq.size()) because unsigned.
// i < seq.size()-1 could fail for zero-length sequences.
for
(
size_t
i
=
0
;
i
+
1
<
seq
.
size
();
i
++
)
{
// for all but the last element of seq, create new state.
OutputStateId
next_state
=
ofst
->
AddState
();
Arc
arc
;
arc
.
nextstate
=
next_state
;
arc
.
weight
=
(
i
==
0
?
temp_arc
.
weight
:
Weight
::
One
());
arc
.
ilabel
=
(
i
==
0
?
temp_arc
.
ilabel
:
0
);
// put ilabel on first element of seq.
arc
.
olabel
=
seq
[
i
];
ofst
->
AddArc
(
cur_state
,
arc
);
cur_state
=
next_state
;
}
// Add the final arc in the sequence.
Arc
arc
;
arc
.
nextstate
=
temp_arc
.
nextstate
;
arc
.
weight
=
(
seq
.
size
()
<=
1
?
temp_arc
.
weight
:
Weight
::
One
());
arc
.
ilabel
=
(
seq
.
size
()
<=
1
?
temp_arc
.
ilabel
:
0
);
arc
.
olabel
=
(
seq
.
size
()
>
0
?
seq
.
back
()
:
0
);
ofst
->
AddArc
(
cur_state
,
arc
);
}
}
// Free up memory. Do this inside the loop as ofst is also allocating memory
if
(
destroy
)
{
vector
<
TempArc
>
temp
;
temp
.
swap
(
this_vec
);
}
}
if
(
destroy
)
{
FreeOutputStates
();
repository_
.
Destroy
();
}
}
// Initializer. After initializing the object you will typically
// call Determinize() and then call one of the Output functions.
// Note: ifst.Copy() will generally do a
// shallow copy. We do it like this for memory safety, rather than
// keeping a reference or pointer to ifst_.
LatticeDeterminizerPruned
(
const
ExpandedFst
<
Arc
>
&
ifst
,
double
beam
,
DeterminizeLatticePrunedOptions
opts
)
:
num_arcs_
(
0
),
num_elems_
(
0
),
ifst_
(
ifst
.
Copy
()),
beam_
(
beam
),
opts_
(
opts
),
equal_
(
opts_
.
delta
),
determinized_
(
false
),
minimal_hash_
(
3
,
hasher_
,
equal_
),
initial_hash_
(
3
,
hasher_
,
equal_
)
{
KALDI_ASSERT
(
Weight
::
Properties
()
&
kIdempotent
);
// this algorithm won't
// work correctly otherwise.
}
void
FreeOutputStates
()
{
for
(
size_t
i
=
0
;
i
<
output_states_
.
size
();
i
++
)
delete
output_states_
[
i
];
vector
<
OutputState
*>
temp
;
temp
.
swap
(
output_states_
);
}
// frees all memory except the info (in output_states_[ ]->arcs)
// that we need to output the FST.
void
FreeMostMemory
()
{
if
(
ifst_
)
{
delete
ifst_
;
ifst_
=
NULL
;
}
{
MinimalSubsetHash
tmp
;
tmp
.
swap
(
minimal_hash_
);
}
for
(
size_t
i
=
0
;
i
<
output_states_
.
size
();
i
++
)
{
vector
<
Element
>
empty_subset
;
empty_subset
.
swap
(
output_states_
[
i
]
->
minimal_subset
);
}
for
(
typename
InitialSubsetHash
::
iterator
iter
=
initial_hash_
.
begin
();
iter
!=
initial_hash_
.
end
();
++
iter
)
delete
iter
->
first
;
{
InitialSubsetHash
tmp
;
tmp
.
swap
(
initial_hash_
);
}
{
vector
<
char
>
tmp
;
tmp
.
swap
(
isymbol_or_final_
);
}
{
// Free up the queue. I'm not sure how to make sure all
// the memory is really freed (no swap() function)... doesn't really
// matter much though.
while
(
!
queue_
.
empty
())
{
Task
*
t
=
queue_
.
top
();
delete
t
;
        queue_.pop();
      }
    }
    { vector<pair<Label, Element> > tmp;
      tmp.swap(all_elems_tmp_); }
  }

  ~LatticeDeterminizerPruned() {
    FreeMostMemory();
    FreeOutputStates();
    // rest is deleted by destructors.
  }

  void RebuildRepository() { // rebuild the string repository,
    // freeing stuff we don't need.. we call this when memory usage
    // passes a supplied threshold.  We need to accumulate all the
    // strings we need the repository to "remember", then tell it
    // to clean the repository.
    std::vector<StringId> needed_strings;
    for (size_t i = 0; i < output_states_.size(); i++) {
      AddStrings(output_states_[i]->minimal_subset, &needed_strings);
      for (size_t j = 0; j < output_states_[i]->arcs.size(); j++)
        needed_strings.push_back(output_states_[i]->arcs[j].string);
    }

    { // the queue doesn't allow us access to the underlying vector,
      // so we have to resort to a temporary collection.
      std::vector<Task*> tasks;
      while (!queue_.empty()) {
        Task *task = queue_.top();
        queue_.pop();
        tasks.push_back(task);
        AddStrings(task->subset, &needed_strings);
      }
      for (size_t i = 0; i < tasks.size(); i++)
        queue_.push(tasks[i]);
    }

    // the following loop covers strings present in initial_hash_.
    for (typename InitialSubsetHash::const_iterator
             iter = initial_hash_.begin();
         iter != initial_hash_.end(); ++iter) {
      const vector<Element> &vec = *(iter->first);
      Element elem = iter->second;
      AddStrings(vec, &needed_strings);
      needed_strings.push_back(elem.string);
    }
    std::sort(needed_strings.begin(), needed_strings.end());
    needed_strings.erase(std::unique(needed_strings.begin(),
                                     needed_strings.end()),
                         needed_strings.end()); // uniq the strings.
    KALDI_LOG << "Rebuilding repository.";

    repository_.Rebuild(needed_strings);
  }

  bool CheckMemoryUsage() {
    int32 repo_size = repository_.MemSize(),
        arcs_size = num_arcs_ * sizeof(TempArc),
        elems_size = num_elems_ * sizeof(Element),
        total_size = repo_size + arcs_size + elems_size;
    if (opts_.max_mem > 0 && total_size > opts_.max_mem) {
      // We passed the memory threshold.
      // This is usually due to the repository getting large, so we
      // clean this out.
      RebuildRepository();
      int32 new_repo_size = repository_.MemSize(),
          new_total_size = new_repo_size + arcs_size + elems_size;

      KALDI_VLOG(2) << "Rebuilt repository in determinize-lattice: repository shrank from "
                    << repo_size << " to " << new_repo_size
                    << " bytes (approximately)";

      if (new_total_size > static_cast<int32>(opts_.max_mem * 0.8)) {
        // Rebuilding didn't help enough-- we need a margin to stop
        // having to rebuild too often.  We'll just return to the user at
        // this point, with a partial lattice that's pruned tighter than
        // the specified beam.  Here we figure out what the effective
        // beam was.
        double effective_beam = beam_;
        if (!queue_.empty()) { // Note: queue should probably not be empty; we're
                               // just being paranoid here.
          Task *task = queue_.top();
          double total_weight = backward_costs_[ifst_->Start()]; // best weight of FST.
          effective_beam = task->priority_cost - total_weight;
        }
        KALDI_WARN << "Did not reach requested beam in determinize-lattice: "
                   << "size exceeds maximum " << opts_.max_mem
                   << " bytes; (repo,arcs,elems) = (" << repo_size << ","
                   << arcs_size << "," << elems_size
                   << "), after rebuilding, repo size was " << new_repo_size
                   << ", effective beam was " << effective_beam
                   << " vs. requested beam " << beam_;
        return false;
      }
    }
    return true;
  }

  bool Determinize(double *effective_beam) {
    KALDI_ASSERT(!determinized_);
    // This determinizes the input fst but leaves it in the "special format"
    // in "output_arcs_".  Must be called after Initialize().  To get the
    // output, call one of the Output routines.

    InitializeDeterminization(); // some start-up tasks.
    while (!queue_.empty()) {
      Task *task = queue_.top();
      // Note: the queue contains only tasks that are "within the beam".
      // We also have to check whether we have reached one of the user-specified
      // maximums, of estimated memory, arcs, or states.  The condition for
      // ending is:
      // num-states is more than user specified, OR
      // num-arcs is more than user specified, OR
      // memory passed a user-specified threshold and cleanup failed
      //  to get it below that threshold.
      size_t num_states = output_states_.size();
      if ((opts_.max_states > 0 && num_states > opts_.max_states) ||
          (opts_.max_arcs > 0 && num_arcs_ > opts_.max_arcs) ||
          (num_states % 10 == 0 && !CheckMemoryUsage())) { // note: at some point
        // it was num_states % 100, not num_states % 10, but I encountered an example
        // where memory was exhausted before we reached state #100.
        KALDI_VLOG(1) << "Lattice determinization terminated but not "
                      << " because of lattice-beam.  (#states, #arcs) is ( "
                      << output_states_.size() << ", " << num_arcs_
                      << " ), versus limits ( " << opts_.max_states << ", "
                      << opts_.max_arcs << " ) (else, may be memory limit).";
        break;
        // we terminate the determinization here-- whatever we already expanded is
        // what we'll return...  because we expanded stuff in order of total
        // (forward-backward) weight, the stuff we returned first is the most
        // important.
      }
      queue_.pop();
      ProcessTransition(task->state, task->label, &(task->subset));
      delete task;
    }
    determinized_ = true;
    if (effective_beam != NULL) {
      if (queue_.empty()) *effective_beam = beam_;
      else
        *effective_beam = queue_.top()->priority_cost -
            backward_costs_[ifst_->Start()];
    }
    return (queue_.empty()); // return success if queue was empty, i.e. we processed
    // all tasks and did not break out of the loop early due to reaching a memory,
    // arc or state limit.
  }
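
  // Illustrative note (added for this listing, not part of the original
  // source; variable names are hypothetical).  A caller would typically
  // drive this class roughly as follows -- the real entry point is the free
  // function DeterminizeLatticePruned() further below, which also handles
  // beam-narrowing retries:
  //
  //   LatticeDeterminizerPruned<Weight, IntType> det(ifst, beam, opts);
  //   double effective_beam;
  //   bool within_beam = det.Determinize(&effective_beam);
  //   det.Output(&ofst);  // output is usable even if within_beam is false.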
 private:

  typedef typename Arc::Label Label;
  typedef typename Arc::StateId StateId;  // use this when we don't know if it's input or output.
  typedef typename Arc::StateId InputStateId;  // state in the input FST.
  typedef typename Arc::StateId OutputStateId;  // same as above but distinguish
                                                // states in output Fst.

  typedef LatticeStringRepository<IntType> StringRepositoryType;
  typedef const typename StringRepositoryType::Entry* StringId;

  // Element of a subset [of original states]
  struct Element {
    StateId state; // use StateId as this is usually InputStateId but in one case
                   // OutputStateId.
    StringId string;
    Weight weight;
    bool operator != (const Element &other) const {
      return (state != other.state || string != other.string ||
              weight != other.weight);
    }
    // This operator is only intended for the priority_queue in the function
    // EpsilonClosure().
    bool operator > (const Element &other) const {
      return state > other.state;
    }
    // This operator is only intended to support sorting in EpsilonClosure()
    bool operator < (const Element &other) const {
      return state < other.state;
    }
  };

  // Arcs in the format we temporarily create in this class (a representation, essentially of
  // a Gallic Fst).
  struct TempArc {
    Label ilabel;
    StringId string;  // Look it up in the StringRepository, it's a sequence of Labels.
    OutputStateId nextstate;  // or kNoState for final weights.
    Weight weight;
  };

  // Hashing function used in hash of subsets.
  // A subset is a pointer to vector<Element>.
  // The Elements are in sorted order on state id, and without repeated states.
  // Because the order of Elements is fixed, we can use a hashing function that is
  // order-dependent.  However the weights are not included in the hashing function--
  // we hash subsets that differ only in weight to the same key.  This is not optimal
  // in terms of the O(N) performance but typically if we have a lot of determinized
  // states that differ only in weight then the input probably was pathological in some way,
  // or even non-determinizable.
  //   We don't quantize the weights, in order to avoid inexactness in simple cases.
  // Instead we apply the delta when comparing subsets for equality, and allow a small
  // difference.

  class SubsetKey {
   public:
    size_t operator ()(const vector<Element> *subset) const {  // hashes only the state and string.
      size_t hash = 0, factor = 1;
      for (typename vector<Element>::const_iterator iter = subset->begin();
           iter != subset->end(); ++iter) {
        hash *= factor;
        hash += iter->state + reinterpret_cast<size_t>(iter->string);
        factor *= 23531;  // these numbers are primes.
      }
      return hash;
    }
  };

  // This is the equality operator on subsets.  It checks for exact match on state-id
  // and string, and approximate match on weights.
  class SubsetEqual {
   public:
    bool operator ()(const vector<Element> *s1,
                     const vector<Element> *s2) const {
      size_t sz = s1->size();
      KALDI_ASSERT(sz >= 0);
      if (sz != s2->size()) return false;
      typename vector<Element>::const_iterator iter1 = s1->begin(),
          iter1_end = s1->end(), iter2 = s2->begin();
      for (; iter1 < iter1_end; ++iter1, ++iter2) {
        if (iter1->state != iter2->state ||
            iter1->string != iter2->string ||
            !ApproxEqual(iter1->weight, iter2->weight, delta_)) return false;
      }
      return true;
    }
    float delta_;
    SubsetEqual(float delta): delta_(delta) {}
    SubsetEqual(): delta_(kDelta) {}
  };

  // Operator that says whether two Elements have the same states.
  // Used only for debug.
  class SubsetEqualStates {
   public:
    bool operator ()(const vector<Element> *s1,
                     const vector<Element> *s2) const {
      size_t sz = s1->size();
      KALDI_ASSERT(sz >= 0);
      if (sz != s2->size()) return false;
      typename vector<Element>::const_iterator iter1 = s1->begin(),
          iter1_end = s1->end(), iter2 = s2->begin();
      for (; iter1 < iter1_end; ++iter1, ++iter2) {
        if (iter1->state != iter2->state) return false;
      }
      return true;
    }
  };

  // Define the hash type we use to map subsets (in minimal
  // representation) to OutputStateId.
  typedef unordered_map<const vector<Element>*, OutputStateId,
                        SubsetKey, SubsetEqual> MinimalSubsetHash;

  // Define the hash type we use to map subsets (in initial
  // representation) to OutputStateId, together with an
  // extra weight. [note: we interpret the Element.state in here
  // as an OutputStateId even though it's declared as InputStateId;
  // these types are the same anyway].
  typedef unordered_map<const vector<Element>*, Element,
                        SubsetKey, SubsetEqual> InitialSubsetHash;

  // converts the representation of the subset from canonical (all states) to
  // minimal (only states with output symbols on arcs leaving them, and final
  // states).  Output is not necessarily normalized, even if input_subset was.
  void ConvertToMinimal(vector<Element> *subset) {
    KALDI_ASSERT(!subset->empty());
    typename vector<Element>::iterator cur_in = subset->begin(),
        cur_out = subset->begin(), end = subset->end();
    while (cur_in != end) {
      if (IsIsymbolOrFinal(cur_in->state)) {  // keep it...
        *cur_out = *cur_in;
        cur_out++;
      }
      cur_in++;
    }
    subset->resize(cur_out - subset->begin());
  }

  // Takes a minimal, normalized subset, and converts it to an OutputStateId.
  // Involves a hash lookup, and possibly adding a new OutputStateId.
  // If it creates a new OutputStateId, it creates a new record for it, works
  // out its final-weight, and puts stuff on the queue relating to its
  // transitions.
  OutputStateId MinimalToStateId(const vector<Element> &subset,
                                 const double forward_cost) {
    typename MinimalSubsetHash::const_iterator iter
        = minimal_hash_.find(&subset);
    if (iter != minimal_hash_.end()) { // Found a matching subset.
      OutputStateId state_id = iter->second;
      const OutputState &state = *(output_states_[state_id]);
      // Below is just a check that the algorithm is working...
      if (forward_cost < state.forward_cost - 0.1) {
        // for large weights, this check could fail due to roundoff.
        KALDI_WARN << "New cost is less (check the difference is small) "
                   << forward_cost << ", " << state.forward_cost;
      }
      return state_id;
    }
    OutputStateId state_id = static_cast<OutputStateId>(output_states_.size());
    OutputState *new_state = new OutputState(subset, forward_cost);
    minimal_hash_[&(new_state->minimal_subset)] = state_id;
    output_states_.push_back(new_state);
    num_elems_ += subset.size();
    // Note: in the previous algorithm, we pushed the new state-id onto the queue
    // at this point.  Here, the queue happens elsewhere, and we directly process
    // the state (which result in stuff getting added to the queue).
    ProcessFinal(state_id); // will work out the final-prob.
    ProcessTransitions(state_id); // will process transitions and add stuff to the queue.
    return state_id;
  }

  // Given a normalized initial subset of elements (i.e. before epsilon closure),
  // compute the corresponding output-state.
  OutputStateId InitialToStateId(const vector<Element> &subset_in,
                                 double forward_cost,
                                 Weight *remaining_weight,
                                 StringId *common_prefix) {
    typename InitialSubsetHash::const_iterator iter
        = initial_hash_.find(&subset_in);
    if (iter != initial_hash_.end()) { // Found a matching subset.
      const Element &elem = iter->second;
      *remaining_weight = elem.weight;
      *common_prefix = elem.string;
      if (elem.weight == Weight::Zero())
        KALDI_WARN << "Zero weight!";
      return elem.state;
    }
    // else no matching subset-- have to work it out.
    vector<Element> subset(subset_in);
    // Follow through epsilons.  Will add no duplicate states.  note: after
    // EpsilonClosure, it is the same as "canonical" subset, except not
    // normalized (actually we never compute the normalized canonical subset,
    // only the normalized minimal one).
    EpsilonClosure(&subset); // follow epsilons.
    ConvertToMinimal(&subset); // remove all but emitting and final states.

    Element elem; // will be used to store remaining weight and string, and
                  // OutputStateId, in initial_hash_;
    NormalizeSubset(&subset, &elem.weight, &elem.string); // normalize subset; put
    // common string and weight in "elem".  The subset is now a minimal,
    // normalized subset.

    forward_cost += ConvertToCost(elem.weight);
    OutputStateId ans = MinimalToStateId(subset, forward_cost);
    *remaining_weight = elem.weight;
    *common_prefix = elem.string;
    if (elem.weight == Weight::Zero())
      KALDI_WARN << "Zero weight!";

    // Before returning "ans", add the initial subset to the hash,
    // so that we can bypass the epsilon-closure etc., next time
    // we process the same initial subset.
    vector<Element> *initial_subset_ptr = new vector<Element>(subset_in);
    elem.state = ans;
    initial_hash_[initial_subset_ptr] = elem;
    num_elems_ += initial_subset_ptr->size(); // keep track of memory usage.
    return ans;
  }

  // returns the Compare value (-1 if a < b, 0 if a == b, 1 if a > b) according
  // to the ordering we defined on strings for the CompactLatticeWeightTpl.
  // see function
  // inline int Compare (const CompactLatticeWeightTpl<WeightType,IntType> &w1,
  //                     const CompactLatticeWeightTpl<WeightType,IntType> &w2)
  // in lattice-weight.h.
  // this is the same as that, but optimized for our data structures.
  inline int Compare(const Weight &a_w, StringId a_str,
                     const Weight &b_w, StringId b_str) const {
    int weight_comp = fst::Compare(a_w, b_w);
    if (weight_comp != 0) return weight_comp;
    // now comparing strings.
    if (a_str == b_str) return 0;
    vector<IntType> a_vec, b_vec;
    repository_.ConvertToVector(a_str, &a_vec);
    repository_.ConvertToVector(b_str, &b_vec);
    // First compare their lengths.
    int a_len = a_vec.size(), b_len = b_vec.size();
    // use opposite order on the string lengths (c.f. Compare in
    // lattice-weight.h)
    if (a_len > b_len) return -1;
    else if (a_len < b_len) return 1;
    for (int i = 0; i < a_len; i++) {
      if (a_vec[i] < b_vec[i]) return -1;
      else if (a_vec[i] > b_vec[i]) return 1;
    }
    KALDI_ASSERT(0); // because we checked if a_str == b_str above, shouldn't reach here
    return 0;
  }
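
  // Worked example (added for clarity, not in the original source): with
  // equal weights, Compare orders strings by *decreasing* length first, so
  // for hypothetical label sequences a_str = [5, 7, 2] and b_str = [5, 7],
  // Compare(w, a_str, w, b_str) returns -1: the shorter string compares as
  // "greater" and is the one kept wherever the code below tests for
  // Compare(...) == 1; equal-length strings fall back to lexicographic order.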
  // This function computes epsilon closure of subset of states by following epsilon links.
  // Called by InitialToStateId and Initialize.
  // Has no side effects except on the string repository.  The "output_subset" is not
  // necessarily normalized (in the sense of there being no common substring), unless
  // input_subset was.
  void EpsilonClosure(vector<Element> *subset) {
    // at input, subset must have only one example of each StateId.  [will still
    // be so at output].  This function follows input-epsilons, and augments the
    // subset accordingly.

    std::priority_queue<Element, vector<Element>, greater<Element> > queue;
    unordered_map<InputStateId, Element> cur_subset;
    typedef typename unordered_map<InputStateId, Element>::iterator MapIter;
    typedef typename vector<Element>::const_iterator VecIter;

    for (VecIter iter = subset->begin(); iter != subset->end(); ++iter) {
      queue.push(*iter);
      cur_subset[iter->state] = *iter;
    }

    // find whether input fst is known to be sorted on input label.
    bool sorted = ((ifst_->Properties(kILabelSorted, false) & kILabelSorted) != 0);
    bool replaced_elems = false; // relates to an optimization, see below.
    int counter = 0; // stops infinite loops here for non-lattice-determinizable input
    // (e.g. input with negative-cost epsilon loops); useful in testing.
    while (queue.size() != 0) {
      Element elem = queue.top();
      queue.pop();

      // The next if-statement is a kind of optimization.  It's to prevent us
      // unnecessarily repeating the processing of a state.  "cur_subset" always
      // contains only one Element with a particular state.  The issue is that
      // whenever we modify the Element corresponding to that state in "cur_subset",
      // both the new (optimal) and old (less-optimal) Element will still be in
      // "queue".  The next if-statement stops us from wasting compute by
      // processing the old Element.
      if (replaced_elems && cur_subset[elem.state] != elem)
        continue;
      if (opts_.max_loop > 0 && counter++ > opts_.max_loop) {
        KALDI_ERR << "Lattice determinization aborted since looped more than "
                  << opts_.max_loop << " times during epsilon closure.";
      }

      for (ArcIterator<ExpandedFst<Arc> > aiter(*ifst_, elem.state);
           !aiter.Done(); aiter.Next()) {
        const Arc &arc = aiter.Value();
        if (sorted && arc.ilabel != 0) break;
        // Break from the loop: due to sorting there will be no
        // more transitions with epsilons as input labels.
        if (arc.ilabel == 0 &&
            arc.weight != Weight::Zero()) {  // Epsilon transition.
          Element next_elem;
          next_elem.state = arc.nextstate;
          next_elem.weight = Times(elem.weight, arc.weight);
          // next_elem.string is not set up yet... create it only
          // when we know we need it (this is an optimization)

          MapIter iter = cur_subset.find(next_elem.state);
          if (iter == cur_subset.end()) {
            // was no such StateId: insert and add to queue.
            next_elem.string = (arc.olabel == 0 ? elem.string :
                                repository_.Successor(elem.string, arc.olabel));
            cur_subset[next_elem.state] = next_elem;
            queue.push(next_elem);
          } else {
            // was not inserted because one already there.  In normal
            // determinization we'd add the weights.  Here, we find which one
            // has the better weight, and keep its corresponding string.
            int comp = fst::Compare(next_elem.weight, iter->second.weight);
            if (comp == 0) { // A tie on weights.  This should be a rare case;
                             // we don't optimize for it.
              next_elem.string = (arc.olabel == 0 ? elem.string :
                                  repository_.Successor(elem.string,
                                                        arc.olabel));
              comp = Compare(next_elem.weight, next_elem.string,
                             iter->second.weight, iter->second.string);
            }
            if (comp == 1) { // next_elem is better, so use its (weight, string)
              next_elem.string = (arc.olabel == 0 ? elem.string :
                                  repository_.Successor(elem.string, arc.olabel));
              iter->second.string = next_elem.string;
              iter->second.weight = next_elem.weight;
              queue.push(next_elem);
              replaced_elems = true;
            }
            // else it is the same or worse, so use original one.
          }
        }
      }
    }

    { // copy cur_subset to subset.
      subset->clear();
      subset->reserve(cur_subset.size());
      MapIter iter = cur_subset.begin(), end = cur_subset.end();
      for (; iter != end; ++iter) subset->push_back(iter->second);
      // sort by state ID, because the subset hash function is order-dependent (see SubsetKey)
      std::sort(subset->begin(), subset->end());
    }
  }

  // This function works out the final-weight of the determinized state.
  // called by ProcessSubset.
  // Has no side effects except on the variable repository_, and
  // output_states_[output_state_id].arcs
  void ProcessFinal(OutputStateId output_state_id) {
    OutputState &state = *(output_states_[output_state_id]);
    const vector<Element> &minimal_subset = state.minimal_subset;
    // processes final-weights for this subset.  state.minimal_subset_ may be
    // empty if the graph is not connected/trimmed, I think, so don't check
    // that it's nonempty.
    StringId final_string = repository_.EmptyString();  // set it to keep the
    // compiler happy; if it doesn't get set in the loop, we won't use the value anyway.
    Weight final_weight = Weight::Zero();
    bool is_final = false;
    typename vector<Element>::const_iterator iter = minimal_subset.begin(),
        end = minimal_subset.end();
    for (; iter != end; ++iter) {
      const Element &elem = *iter;
      Weight this_final_weight = Times(elem.weight, ifst_->Final(elem.state));
      StringId this_final_string = elem.string;
      if (this_final_weight != Weight::Zero() &&
          (!is_final || Compare(this_final_weight, this_final_string,
                                final_weight, final_string) == 1)) { // the new
        // (weight, string) pair is more in semiring than our current
        // one.
        is_final = true;
        final_weight = this_final_weight;
        final_string = this_final_string;
      }
    }
    if (is_final &&
        ConvertToCost(final_weight) + state.forward_cost <= cutoff_) {
      // store final weights in TempArc structure, just like a transition.
      // Note: we only store the final-weight if it's inside the pruning beam, hence
      // the stuff with Compare.
      TempArc temp_arc;
      temp_arc.ilabel = 0;
      temp_arc.nextstate = kNoStateId;  // special marker meaning "final weight".
      temp_arc.string = final_string;
      temp_arc.weight = final_weight;
      state.arcs.push_back(temp_arc);
      num_arcs_++;
    }
  }

  // NormalizeSubset normalizes the subset "elems" by
  // removing any common string prefix (putting it in common_str),
  // and dividing by the total weight (putting it in tot_weight).
  void NormalizeSubset(vector<Element> *elems,
                       Weight *tot_weight,
                       StringId *common_str) {
    if (elems->empty()) { // just set common_str, tot_weight
      // to defaults and return...
      KALDI_WARN << "empty subset";
      *common_str = repository_.EmptyString();
      *tot_weight = Weight::Zero();
      return;
    }
    size_t size = elems->size();
    vector<IntType> common_prefix;
    repository_.ConvertToVector((*elems)[0].string, &common_prefix);
    Weight weight = (*elems)[0].weight;
    for (size_t i = 1; i < size; i++) {
      weight = Plus(weight, (*elems)[i].weight);
      repository_.ReduceToCommonPrefix((*elems)[i].string, &common_prefix);
    }
    KALDI_ASSERT(weight != Weight::Zero()); // we made sure to ignore arcs with zero
    // weights on them, so we shouldn't have zero here.
    size_t prefix_len = common_prefix.size();
    for (size_t i = 0; i < size; i++) {
      (*elems)[i].weight = Divide((*elems)[i].weight, weight, DIVIDE_LEFT);
      (*elems)[i].string =
          repository_.RemovePrefix((*elems)[i].string, prefix_len);
    }
    *common_str = repository_.ConvertFromVector(common_prefix);
    *tot_weight = weight;
  }
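
  // Worked example (added for clarity, not in the original source), using
  // tropical-like costs: given elems = {(s1, cost 2.0, string [3, 4, 9]),
  // (s2, cost 3.5, string [3, 4])}, the common prefix is [3, 4] and the
  // total weight is the best cost min(2.0, 3.5) = 2.0, so after the call
  // elems becomes {(s1, 0.0, [9]), (s2, 1.5, [])}, with
  // *common_str = [3, 4] and *tot_weight = 2.0.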
  // Take a subset of Elements that is sorted on state, and
  // merge any Elements that have the same state (taking the best
  // (weight, string) pair in the semiring).
  void MakeSubsetUnique(vector<Element> *subset) {
    typedef typename vector<Element>::iterator IterType;

    // This KALDI_ASSERT is designed to fail (usually) if the subset is not sorted on
    // state.
    KALDI_ASSERT(subset->size() < 2 ||
                 (*subset)[0].state <= (*subset)[1].state);

    IterType cur_in = subset->begin(), cur_out = cur_in, end = subset->end();
    size_t num_out = 0;
    // Merge elements with same state-id
    while (cur_in != end) {  // while we have more elements to process.
      // At this point, cur_out points to location of next place we want to put an element,
      // cur_in points to location of next element we want to process.
      if (cur_in != cur_out) *cur_out = *cur_in;
      cur_in++;
      while (cur_in != end && cur_in->state == cur_out->state) {
        if (Compare(cur_in->weight, cur_in->string,
                    cur_out->weight, cur_out->string) == 1) {
          // if *cur_in > *cur_out in semiring, then take *cur_in.
          cur_out->string = cur_in->string;
          cur_out->weight = cur_in->weight;
        }
        cur_in++;
      }
      cur_out++;
      num_out++;
    }
    subset->resize(num_out);
  }

  // ProcessTransition was called from "ProcessTransitions" in the non-pruned
  // code, but now we in effect put the calls to ProcessTransition on a priority
  // queue, and it now gets called directly from Determinize().  This function
  // processes a transition from state "ostate_id".  The set "subset" of Elements
  // represents a set of next-states with associated weights and strings, each
  // one arising from an arc from some state in a determinized-state; the
  // next-states are unique (there is only one Entry associated with each)
  void ProcessTransition(OutputStateId ostate_id, Label ilabel,
                         vector<Element> *subset) {
    double forward_cost = output_states_[ostate_id]->forward_cost;
    StringId common_str;
    Weight tot_weight;
    NormalizeSubset(subset, &tot_weight, &common_str);
    forward_cost += ConvertToCost(tot_weight);

    OutputStateId nextstate;
    {
      Weight next_tot_weight;
      StringId next_common_str;
      nextstate = InitialToStateId(*subset,
                                   forward_cost,
                                   &next_tot_weight,
                                   &next_common_str);
      common_str = repository_.Concatenate(common_str, next_common_str);
      tot_weight = Times(tot_weight, next_tot_weight);
    }

    // Now add an arc to the next state (would have been created if necessary by
    // InitialToStateId).
    TempArc temp_arc;
    temp_arc.ilabel = ilabel;
    temp_arc.nextstate = nextstate;
    temp_arc.string = common_str;
    temp_arc.weight = tot_weight;
    output_states_[ostate_id]->arcs.push_back(temp_arc);  // record the arc.
    num_arcs_++;
  }

  // "less than" operator for pair<Label, Element>.  Used in ProcessTransitions.
  // Lexicographical order, which only compares the state when ordering the
  // "Element" member of the pair.
  class PairComparator {
   public:
    inline bool operator () (const pair<Label, Element> &p1,
                             const pair<Label, Element> &p2) {
      if (p1.first < p2.first) return true;
      else if (p1.first > p2.first) return false;
      else {
        return p1.second.state < p2.second.state;
      }
    }
  };

  // ProcessTransitions processes emitting transitions (transitions with
  // ilabels) out of this subset of states.  It actually only creates records
  // ("Task") that get added to the queue.  The transitions will be processed in
  // priority order from Determinize().  This function does not consider final
  // states.  Partitions the emitting transitions up by ilabel (by sorting on
  // ilabel), and for each unique ilabel, it creates a Task record that contains
  // the information we need to process the transition.
  void ProcessTransitions(OutputStateId output_state_id) {
    const vector<Element> &minimal_subset =
        output_states_[output_state_id]->minimal_subset;
    // it's possible that minimal_subset could be empty if there are
    // unreachable parts of the graph, so don't check that it's nonempty.
    vector<pair<Label, Element> > &all_elems(all_elems_tmp_); // use class member
    // to avoid memory allocation/deallocation.
    {
      // Push back into "all_elems", elements corresponding to all
      // non-epsilon-input transitions out of all states in "minimal_subset".
      typename vector<Element>::const_iterator iter = minimal_subset.begin(),
          end = minimal_subset.end();
      for (; iter != end; ++iter) {
        const Element &elem = *iter;
        for (ArcIterator<ExpandedFst<Arc> > aiter(*ifst_, elem.state);
             !aiter.Done(); aiter.Next()) {
          const Arc &arc = aiter.Value();
          if (arc.ilabel != 0 &&
              arc.weight != Weight::Zero()) {  // Non-epsilon transition -- ignore epsilons here.
            pair<Label, Element> this_pr;
            this_pr.first = arc.ilabel;
            Element &next_elem(this_pr.second);
            next_elem.state = arc.nextstate;
            next_elem.weight = Times(elem.weight, arc.weight);
            if (arc.olabel == 0) // output epsilon
              next_elem.string = elem.string;
            else
              next_elem.string = repository_.Successor(elem.string, arc.olabel);
            all_elems.push_back(this_pr);
          }
        }
      }
    }
    PairComparator pc;
    std::sort(all_elems.begin(), all_elems.end(), pc);
    // now sorted first on input label, then on state.
    typedef typename vector<pair<Label, Element> >::const_iterator PairIter;
    PairIter cur = all_elems.begin(), end = all_elems.end();
    while (cur != end) {
      // The old code (non-pruned) called ProcessTransition; here, instead,
      // we'll put the calls into a priority queue.
      Task *task = new Task;
      // Process ranges that share the same input symbol.
      Label ilabel = cur->first;
      task->state = output_state_id;
      task->priority_cost = std::numeric_limits<double>::infinity();
      task->label = ilabel;
      while (cur != end && cur->first == ilabel) {
        task->subset.push_back(cur->second);
        const Element &element = cur->second;
        // Note: we'll later include the term "forward_cost" in the
        // priority_cost.
        task->priority_cost = std::min(task->priority_cost,
                                       ConvertToCost(element.weight) +
                                       backward_costs_[element.state]);
        cur++;
      }
      // After the command below, the "priority_cost" is a value comparable to
      // the total-weight of the input FST, like a total-path weight... of
      // course, it will typically be less (in the semiring) than that.
      // note: we represent it just as a double.
      task->priority_cost += output_states_[output_state_id]->forward_cost;

      if (task->priority_cost > cutoff_) {
        // This task would never get done as it's past the pruning cutoff.
        delete task;
      } else {
        MakeSubsetUnique(&(task->subset)); // remove duplicate Elements with the same state.
        queue_.push(task); // Push the task onto the queue.  The queue keeps it
        // in prioritized order, so we always process the one with the "best"
        // weight (highest in the semiring).

        { // this is a check.
          double best_cost = backward_costs_[ifst_->Start()],
              tolerance = 0.01 + 1.0e-04 * std::abs(best_cost);
          if (task->priority_cost < best_cost - tolerance) {
            KALDI_WARN << "Cost below best cost was encountered:"
                       << task->priority_cost << " < " << best_cost;
          }
        }
      }
    }
    all_elems.clear(); // as it's a reference to a class variable; we want it to stay
    // empty.
  }

  bool IsIsymbolOrFinal(InputStateId state) { // returns true if this state
    // of the input FST either is final or has an osymbol on an arc out of it.
    // Uses the vector isymbol_or_final_ as a cache for this info.
    KALDI_ASSERT(state >= 0);
    if (isymbol_or_final_.size() <= state)
      isymbol_or_final_.resize(state + 1, static_cast<char>(OSF_UNKNOWN));
    if (isymbol_or_final_[state] == static_cast<char>(OSF_NO))
      return false;
    else if (isymbol_or_final_[state] == static_cast<char>(OSF_YES))
      return true;
    // else work it out...
    isymbol_or_final_[state] = static_cast<char>(OSF_NO);
    if (ifst_->Final(state) != Weight::Zero())
      isymbol_or_final_[state] = static_cast<char>(OSF_YES);
    for (ArcIterator<ExpandedFst<Arc> > aiter(*ifst_, state);
         !aiter.Done(); aiter.Next()) {
      const Arc &arc = aiter.Value();
      if (arc.ilabel != 0 && arc.weight != Weight::Zero()) {
        isymbol_or_final_[state] = static_cast<char>(OSF_YES);
        return true;
      }
    }
    return IsIsymbolOrFinal(state); // will only recurse once.
  }

  void ComputeBackwardWeight() {
    // Sets up the backward_costs_ array, and the cutoff_ variable.
    KALDI_ASSERT(beam_ > 0);

    // Only handle the topologically sorted case.
    backward_costs_.resize(ifst_->NumStates());
    for (StateId s = ifst_->NumStates() - 1; s >= 0; s--) {
      double &cost = backward_costs_[s];
      cost = ConvertToCost(ifst_->Final(s));
      for (ArcIterator<ExpandedFst<Arc> > aiter(*ifst_, s);
           !aiter.Done(); aiter.Next()) {
        const Arc &arc = aiter.Value();
        cost = std::min(cost,
                        ConvertToCost(arc.weight) +
                        backward_costs_[arc.nextstate]);
      }
    }

    if (ifst_->Start() == kNoStateId) return; // we'll be returning
    // an empty FST.

    double best_cost = backward_costs_[ifst_->Start()];
    if (best_cost == std::numeric_limits<double>::infinity())
      KALDI_WARN << "Total weight of input lattice is zero.";
    cutoff_ = best_cost + beam_;
  }

  void InitializeDeterminization() {
    // We insist that the input lattice be topologically sorted.  This is not a
    // fundamental limitation of the algorithm (which in principle should be
    // applicable to even cyclic FSTs), but it helps us more efficiently
    // compute the backward_costs_ array.  There may be some other reason we
    // require this, that escapes me at the moment.
    KALDI_ASSERT(ifst_->Properties(kTopSorted, true) != 0);
    ComputeBackwardWeight();
#if !(__GNUC__ == 4 && __GNUC_MINOR__ == 0)
    if (ifst_->Properties(kExpanded, false) != 0) { // if we know the number of
      // states in ifst_, it might be a bit more efficient
      // to pre-size the hashes so we're not constantly rebuilding them.
      StateId num_states =
          down_cast<const ExpandedFst<Arc>*, const Fst<Arc> >(ifst_)->NumStates();
      minimal_hash_.rehash(num_states / 2 + 3);
      initial_hash_.rehash(num_states / 2 + 3);
    }
#endif
    InputStateId start_id = ifst_->Start();
    if (start_id != kNoStateId) {
      /* Create determinized-state corresponding to the start state....
         Unlike all the other states, we don't "normalize" the representation
         of this determinized-state before we put it into minimal_hash_.  This is actually
         what we want, as otherwise we'd have problems dealing with any extra weight
         and string and might have to create a "super-initial" state which would make
         the output nondeterministic.  Normalization is only needed to make the
         determinized output more minimal anyway, it's not needed for correctness.
         Note, we don't put anything in the initial_hash_.  The initial_hash_ is only
         a lookaside buffer anyway, so this isn't a problem-- it will get populated
         later if it needs to be.
      */
      vector<Element> subset(1);
      subset[0].state = start_id;
      subset[0].weight = Weight::One();
      subset[0].string = repository_.EmptyString();  // Id of empty sequence.
      EpsilonClosure(&subset); // follow through epsilon-input links
      ConvertToMinimal(&subset); // remove all but final states and
      // states with input-labels on arcs out of them.
      // Weight::One() is the "forward-weight" of this determinized state...
      // i.e. the minimal cost from the start of the determinized FST to this
      // state [One() because it's the start state].
      OutputState *initial_state = new OutputState(subset, 0);
      KALDI_ASSERT(output_states_.empty());
      output_states_.push_back(initial_state);
      num_elems_ += subset.size();
      OutputStateId initial_state_id = 0;
      minimal_hash_[&(initial_state->minimal_subset)] = initial_state_id;
      ProcessFinal(initial_state_id);
      ProcessTransitions(initial_state_id); // this will add tasks to
      // the queue, which we'll start processing in Determinize().
    }
  }

  KALDI_DISALLOW_COPY_AND_ASSIGN(LatticeDeterminizerPruned);

  struct OutputState {
    vector<Element> minimal_subset;
    vector<TempArc> arcs; // arcs out of the state-- those that have been processed.
    // Note: the final-weight is included here with kNoStateId as the state id.  We
    // always process the final-weight regardless of the beam; when producing the
    // output we may have to ignore some of these.
    double forward_cost; // Represents minimal cost from start-state
    // to this state.  Used in prioritization of tasks, and pruning.
    // Note: we know this minimal cost from when we first create the OutputState;
    // this is because of the priority-queue we use, that ensures that the
    // "best" path into the state will be expanded first.
    OutputState(const vector<Element> &minimal_subset,
                double forward_cost): minimal_subset(minimal_subset),
                                      forward_cost(forward_cost) { }
  };

  vector<OutputState*> output_states_; // All the info about the output states.

  int num_arcs_; // keep track of memory usage: number of arcs in output_states_[ ]->arcs
  int num_elems_; // keep track of memory usage: number of elems in output_states_ and
  // the keys of initial_hash_

  const ExpandedFst<Arc> *ifst_;
  std::vector<double> backward_costs_; // This vector stores, for every state in ifst_,
  // the minimal cost to the end-state (i.e. the sum of weights; they are guaranteed to
  // have "take-the-minimum" semantics).  We get the double from the ConvertToCost()
  // function on the lattice weights.

  double beam_;
  double cutoff_; // beam plus total-weight of input (and note, the weight is
  // guaranteed to be "tropical-like" so the sum does represent a min-cost.

  DeterminizeLatticePrunedOptions opts_;
  SubsetKey hasher_;  // object that computes keys-- has no data members.
  SubsetEqual equal_; // object that compares subsets-- only data member is delta_.
  bool determinized_; // set to true when user called Determinize(); used to make
  // sure this object is used correctly.
  MinimalSubsetHash minimal_hash_;  // hash from Subset to OutputStateId.  Subset is "minimal
                                    // representation" (only includes final states and states with
                                    // nonzero ilabel on arcs out of them).  Owns the pointers
                                    // in its keys.
  InitialSubsetHash initial_hash_;  // hash from Subset to Element, which
                                    // represents the OutputStateId together
                                    // with an extra weight and string.  Subset
                                    // is "initial representation".  The extra
                                    // weight and string is needed because after
                                    // we convert to minimal representation and
                                    // normalize, there may be an extra weight
                                    // and string.  Owns the pointers
                                    // in its keys.

  struct Task {
    OutputStateId state; // State from which we're processing the transition.
    Label label; // Label on the transition we're processing out of this state.
    vector<Element> subset; // Weighted subset of states (with strings)-- not normalized.
    double priority_cost; // Cost used in deciding priority of tasks.  Note:
    // we assume there is a ConvertToCost() function that converts the semiring to double.
  };

  struct TaskCompare {
    inline int operator() (const Task *t1, const Task *t2) {
      // view this like operator <, which is the default template parameter
      // to std::priority_queue.
      // returns true if t1 is worse than t2.
      return (t1->priority_cost > t2->priority_cost);
    }
  };

  // This priority queue contains "Task"s to be processed; these correspond
  // to transitions out of determinized states.  We process these in priority
  // order according to the best weight of any path passing through these
  // determinized states... it's possible to work this out.
  std::priority_queue<Task*, vector<Task*>, TaskCompare> queue_;

  vector<pair<Label, Element> > all_elems_tmp_; // temporary vector used in ProcessTransitions.

  enum IsymbolOrFinal { OSF_UNKNOWN = 0, OSF_NO = 1, OSF_YES = 2 };

  vector<char> isymbol_or_final_; // A kind of cache; it says whether
  // each state is (emitting or final) where emitting means it has at least one
  // non-epsilon output arc.  Only accessed by IsIsymbolOrFinal()

  LatticeStringRepository<IntType> repository_; // defines a compact and fast way of
  // storing sequences of labels.

  void AddStrings(const vector<Element> &vec,
                  vector<StringId> *needed_strings) {
    for (typename std::vector<Element>::const_iterator iter = vec.begin();
         iter != vec.end(); ++iter)
      needed_strings->push_back(iter->string);
  }
};

// normally Weight would be LatticeWeight<float> (which has two floats),
// or possibly TropicalWeightTpl<float>, and IntType would be int32.
// Caution: there are two versions of the function DeterminizeLatticePruned,
// with identical code but different output FST types.
template<class Weight, class IntType>
bool DeterminizeLatticePruned(
    const ExpandedFst<ArcTpl<Weight> > &ifst,
    double beam,
    MutableFst<ArcTpl<CompactLatticeWeightTpl<Weight, IntType> > > *ofst,
    DeterminizeLatticePrunedOptions opts) {
  ofst->SetInputSymbols(ifst.InputSymbols());
  ofst->SetOutputSymbols(ifst.OutputSymbols());
  if (ifst.NumStates() == 0) {
    ofst->DeleteStates();
    return true;
  }
  KALDI_ASSERT(opts.retry_cutoff >= 0.0 && opts.retry_cutoff < 1.0);
  int32 max_num_iters = 10;  // avoid the potential for infinite loops if
                             // retrying.
  VectorFst<ArcTpl<Weight> > temp_fst;

  for (int32 iter = 0; iter < max_num_iters; iter++) {
    LatticeDeterminizerPruned<Weight, IntType> det(iter == 0 ? ifst : temp_fst,
                                                   beam, opts);
    double effective_beam;
    bool ans = det.Determinize(&effective_beam);
    // if it returns false it will typically still produce reasonable output,
    // just with a narrower beam than "beam".  If the user specifies an infinite
    // beam we don't do this beam-narrowing.
    if (effective_beam >= beam * opts.retry_cutoff ||
        beam == std::numeric_limits<double>::infinity() ||
        iter + 1 == max_num_iters) {
      det.Output(ofst);
      return ans;
    } else {
      // The code below to set "beam" is a heuristic.
      // If effective_beam is very small, we want to reduce by a lot.
      // But never change the beam by more than a factor of two.
      if (effective_beam < 0.0) effective_beam = 0.0;
      double new_beam = beam * sqrt(effective_beam / beam);
      if (new_beam < 0.5 * beam) new_beam = 0.5 * beam;
      beam = new_beam;
      if (iter == 0) temp_fst = ifst;
      kaldi::PruneLattice(beam, &temp_fst);
      KALDI_LOG << "Pruned state-level lattice with beam " << beam
                << " and retrying determinization with that beam.";
    }
  }
  return false; // Suppress compiler warning; this code is unreachable.
}
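
// Usage sketch (added for illustration, not in the original source; "lat"
// and the beam value are hypothetical).  Determinizing a topologically
// sorted, Invert()ed state-level kaldi::Lattice "lat" (words on the input
// side) into a CompactLattice within a beam of 8.0:
//
//   fst::DeterminizeLatticePrunedOptions opts;
//   kaldi::CompactLattice clat;
//   bool within_beam =
//       fst::DeterminizeLatticePruned<kaldi::LatticeWeight, kaldi::int32>(
//           lat, 8.0, &clat, opts);
//   // clat holds usable output even when within_beam is false; it was just
//   // produced with a narrower effective beam.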
// normally Weight would be LatticeWeight<float> (which has two floats),
// or possibly TropicalWeightTpl<float>, and IntType would be int32.
// Caution: there are two versions of the function DeterminizeLatticePruned,
// with identical code but different output FST types.
template<class Weight>
bool DeterminizeLatticePruned(
    const ExpandedFst<ArcTpl<Weight> > &ifst,
    double beam,
    MutableFst<ArcTpl<Weight> > *ofst,
    DeterminizeLatticePrunedOptions opts) {
  typedef int32 IntType;
  ofst->SetInputSymbols(ifst.InputSymbols());
  ofst->SetOutputSymbols(ifst.OutputSymbols());
  KALDI_ASSERT(opts.retry_cutoff >= 0.0 && opts.retry_cutoff < 1.0);
  if (ifst.NumStates() == 0) {
    ofst->DeleteStates();
    return true;
  }
  int32 max_num_iters = 10;  // avoid the potential for infinite loops if
                             // retrying.
  VectorFst<ArcTpl<Weight> > temp_fst;

  for (int32 iter = 0; iter < max_num_iters; iter++) {
    LatticeDeterminizerPruned<Weight, IntType> det(iter == 0 ? ifst : temp_fst,
                                                   beam, opts);
    double effective_beam;
    bool ans = det.Determinize(&effective_beam);
    // if it returns false it will typically still
    // produce reasonable output, just with a
    // narrower beam than "beam".
    if (effective_beam >= beam * opts.retry_cutoff ||
        iter + 1 == max_num_iters) {
      det.Output(ofst);
      return ans;
    } else {
      // The code below to set "beam" is a heuristic.
      // If effective_beam is very small, we want to reduce by a lot.
      // But never change the beam by more than a factor of two.
      if (effective_beam < 0) effective_beam = 0;
      double new_beam = beam * sqrt(effective_beam / beam);
      if (new_beam < 0.5 * beam) new_beam = 0.5 * beam;
      KALDI_WARN << "Effective beam " << effective_beam
                 << " was less than beam " << beam << " * cutoff "
                 << opts.retry_cutoff << ", pruning raw "
                 << "lattice with new beam " << new_beam << " and retrying.";
      beam = new_beam;
      if (iter == 0) temp_fst = ifst;
      kaldi::PruneLattice(beam, &temp_fst);
    }
  }
  return false; // Suppress compiler warning; this code is unreachable.
}

template<class Weight>
typename ArcTpl<Weight>::Label DeterminizeLatticeInsertPhones(
    const kaldi::TransitionInformation &trans_model,
    MutableFst<ArcTpl<Weight> > *fst) {
  // Define some types.
  typedef ArcTpl<Weight> Arc;
  typedef typename Arc::StateId StateId;
  typedef typename Arc::Label Label;

  // Work out the first phone symbol. This is more related to the phone
  // insertion function, so we put it here and make it the returning value of
  // DeterminizeLatticeInsertPhones().
  Label first_phone_label = HighestNumberedInputSymbol(*fst) + 1;

  // Insert phones here.
  for (StateIterator<MutableFst<Arc> > siter(*fst);
       !siter.Done(); siter.Next()) {
    StateId state = siter.Value();
    if (state == fst->Start())
      continue;
    for (MutableArcIterator<MutableFst<Arc> > aiter(fst, state);
         !aiter.Done(); aiter.Next()) {
      Arc arc = aiter.Value();

      // Note: the words are on the input symbol side and transition-id's are on
      // the output symbol side.
      if ((arc.olabel != 0) &&
          (trans_model.TransitionIdIsStartOfPhone(arc.olabel)) &&
          (!trans_model.IsSelfLoop(arc.olabel))) {
        Label phone =
            static_cast<Label>(trans_model.TransitionIdToPhone(arc.olabel));

        // Skips <eps>.
        KALDI_ASSERT(phone != 0);

        if (arc.ilabel == 0) {
          // If there is no word on the arc, insert the phone directly.
          arc.ilabel = first_phone_label + phone;
        } else {
          // Otherwise, add an additional arc.
          StateId additional_state = fst->AddState();
          StateId next_state = arc.nextstate;
          arc.nextstate = additional_state;
          fst->AddArc(additional_state,
                      Arc(first_phone_label + phone, 0,
                          Weight::One(), next_state));
        }
      }

      aiter.SetValue(arc);
    }
  }

  return first_phone_label;
}

template<class Weight>
void DeterminizeLatticeDeletePhones(
    typename ArcTpl<Weight>::Label first_phone_label,
    MutableFst<ArcTpl<Weight> > *fst) {
  // Define some types.
  typedef ArcTpl<Weight> Arc;
  typedef typename Arc::StateId StateId;
  typedef typename Arc::Label Label;

  // Delete phones here.
  for (StateIterator<MutableFst<Arc> > siter(*fst);
       !siter.Done(); siter.Next()) {
    StateId state = siter.Value();
    for (MutableArcIterator<MutableFst<Arc> > aiter(fst, state);
         !aiter.Done(); aiter.Next()) {
      Arc arc = aiter.Value();

      if (arc.ilabel >= first_phone_label)
        arc.ilabel = 0;

      aiter.SetValue(arc);
    }
  }
}

// instantiate for type LatticeWeight
template
void DeterminizeLatticeDeletePhones(
    ArcTpl<kaldi::LatticeWeight>::Label first_phone_label,
    MutableFst<ArcTpl<kaldi::LatticeWeight> > *fst);

/** This function does a first pass determinization with phone symbols inserted
    at phone boundary. It uses a transition model to work out the transition-id
    to phone map. First, phones will be inserted into the word level lattice.
    Second, determinization will be applied on top of the phone + word lattice.
    Finally, the inserted phones will be removed, converting the lattice back to
    a word level lattice. The output lattice of this pass is not deterministic,
    since we remove the phone symbols as a last step. It is supposed to be
    followed by another pass of determinization at the word level. It could also
    be useful for some other applications such as fMLLR estimation, confidence
    estimation, discriminative training, etc.
*/
template<class Weight, class IntType>
bool DeterminizeLatticePhonePrunedFirstPass(
    const kaldi::TransitionInformation &trans_model,
    double beam,
    MutableFst<ArcTpl<Weight> > *fst,
    const DeterminizeLatticePrunedOptions &opts) {
  // First, insert the phones.
  typename ArcTpl<Weight>::Label first_phone_label =
      DeterminizeLatticeInsertPhones(trans_model, fst);
  TopSort(fst);

  // Second, do determinization with phone inserted.
  bool ans = DeterminizeLatticePruned<Weight>(*fst, beam, fst, opts);

  // Finally, remove the inserted phones.
  DeterminizeLatticeDeletePhones(first_phone_label, fst);
  TopSort(fst);

  return ans;
}

// "Destructive" version of DeterminizeLatticePhonePruned() where the input
// lattice might be modified.
template<class Weight, class IntType>
bool DeterminizeLatticePhonePruned(
    const kaldi::TransitionInformation &trans_model,
    MutableFst<ArcTpl<Weight> > *ifst,
    double beam,
    MutableFst<ArcTpl<CompactLatticeWeightTpl<Weight, IntType> > > *ofst,
    DeterminizeLatticePhonePrunedOptions opts) {
  // Returning status.
  bool ans = true;

  // Make sure at least one of opts.phone_determinize and opts.word_determinize
  // is not false, otherwise calling this function doesn't make any sense.
  if ((opts.phone_determinize || opts.word_determinize) == false) {
    KALDI_WARN << "Both --phone-determinize and --word-determinize are set to "
               << "false, copying lattice without determinization.";
    // We are expecting the words on the input side.
    ConvertLattice<Weight, IntType>(*ifst, ofst, false);
    return ans;
  }

  // Determinization options.
  DeterminizeLatticePrunedOptions det_opts;
  det_opts.delta = opts.delta;
  det_opts.max_mem = opts.max_mem;

  // If --phone-determinize is true, do the determinization on phone + word
  // lattices.
  if (opts.phone_determinize) {
    KALDI_VLOG(3) << "Doing first pass of determinization on phone + word "
                  << "lattices.";
    ans = DeterminizeLatticePhonePrunedFirstPass<Weight, IntType>(
        trans_model, beam, ifst, det_opts) && ans;

    // If --word-determinize is false, we've finished the job and return here.
    if (!opts.word_determinize) {
      // We are expecting the words on the input side.
      ConvertLattice<Weight, IntType>(*ifst, ofst, false);
      return ans;
    }
  }

  // If --word-determinize is true, do the determinization on word lattices.
  if (opts.word_determinize) {
    KALDI_VLOG(3) << "Doing second pass of determinization on word lattices.";
    ans = DeterminizeLatticePruned<Weight, IntType>(
        *ifst, beam, ofst, det_opts) && ans;
  }

  // If --minimize is true, push and minimize after determinization.
  if (opts.minimize) {
    KALDI_VLOG(3) << "Pushing and minimizing on word lattices.";
    ans = PushCompactLatticeStrings<Weight, IntType>(ofst) && ans;
    ans = PushCompactLatticeWeights<Weight, IntType>(ofst) && ans;
    ans = MinimizeCompactLattice<Weight, IntType>(ofst) && ans;
  }

  return ans;
}

// Normal version of DeterminizeLatticePhonePruned(), where the input lattice
// will be kept as unchanged.
template<class Weight, class IntType>
bool DeterminizeLatticePhonePruned(
    const kaldi::TransitionInformation &trans_model,
    const ExpandedFst<ArcTpl<Weight> > &ifst,
    double beam,
    MutableFst<ArcTpl<CompactLatticeWeightTpl<Weight, IntType> > > *ofst,
    DeterminizeLatticePhonePrunedOptions opts) {
  VectorFst<ArcTpl<Weight> > temp_fst(ifst);
  return DeterminizeLatticePhonePruned(trans_model, &temp_fst,
                                       beam, ofst, opts);
}

bool DeterminizeLatticePhonePrunedWrapper(
    const kaldi::TransitionInformation &trans_model,
    MutableFst<kaldi::LatticeArc> *ifst,
    double beam,
    MutableFst<kaldi::CompactLatticeArc> *ofst,
    DeterminizeLatticePhonePrunedOptions opts) {
  bool ans = true;
  Invert(ifst);
  if (ifst->Properties(fst::kTopSorted, true) == 0) {
    if (!TopSort(ifst)) {
      // Cannot topologically sort the lattice -- determinization will fail.
      KALDI_ERR << "Topological sorting of state-level lattice failed (probably"
                << " your lexicon has empty words or your LM has epsilon cycles"
                << ").";
    }
  }
  ILabelCompare<kaldi::LatticeArc> ilabel_comp;
  ArcSort(ifst, ilabel_comp);
  ans = DeterminizeLatticePhonePruned<kaldi::LatticeWeight, kaldi::int32>(
      trans_model, ifst, beam, ofst, opts);
  Connect(ofst);
  return ans;
}
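
// Usage sketch (added for illustration, not in the original source;
// "trans_model", "lat" and the beam value are hypothetical).  The wrapper
// above is the usual one-call entry point: it Invert()s, top-sorts and
// arc-sorts the raw lattice itself before determinizing:
//
//   fst::DeterminizeLatticePhonePrunedOptions opts;  // defaults are sensible.
//   kaldi::CompactLattice det_clat;
//   bool within_beam = fst::DeterminizeLatticePhonePrunedWrapper(
//       trans_model, &lat, /*beam=*/8.0, &det_clat, opts);
//   // det_clat is usable even if within_beam is false; in that case the
//   // effective beam was narrower than requested.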
// Instantiate the templates for the types we might need.
// Note: there are actually four templates, each of which
// we instantiate for a single type.
template
bool DeterminizeLatticePruned<kaldi::LatticeWeight>(
    const ExpandedFst<kaldi::LatticeArc> &ifst,
    double prune,
    MutableFst<kaldi::CompactLatticeArc> *ofst,
    DeterminizeLatticePrunedOptions opts);

template
bool DeterminizeLatticePruned<kaldi::LatticeWeight>(
    const ExpandedFst<kaldi::LatticeArc> &ifst,
    double prune,
    MutableFst<kaldi::LatticeArc> *ofst,
    DeterminizeLatticePrunedOptions opts);

template
bool DeterminizeLatticePhonePruned<kaldi::LatticeWeight, kaldi::int32>(
    const kaldi::TransitionInformation &trans_model,
    const ExpandedFst<kaldi::LatticeArc> &ifst,
    double prune,
    MutableFst<kaldi::CompactLatticeArc> *ofst,
    DeterminizeLatticePhonePrunedOptions opts);

template
bool DeterminizeLatticePhonePruned<kaldi::LatticeWeight, kaldi::int32>(
    const kaldi::TransitionInformation &trans_model,
    MutableFst<kaldi::LatticeArc> *ifst,
    double prune,
    MutableFst<kaldi::CompactLatticeArc> *ofst,
    DeterminizeLatticePhonePrunedOptions opts);

}  // namespace fst
speechx/speechx/kaldi/lat/determinize-lattice-pruned.h
0 → 100644
// lat/determinize-lattice-pruned.h
// Copyright 2009-2012 Microsoft Corporation
// 2012-2013 Johns Hopkins University (Author: Daniel Povey)
// 2014 Guoguo Chen
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#ifndef KALDI_LAT_DETERMINIZE_LATTICE_PRUNED_H_
#define KALDI_LAT_DETERMINIZE_LATTICE_PRUNED_H_
#include <fst/fstlib.h>
#include <fst/fst-decl.h>
#include <algorithm>
#include <map>
#include <set>
#include <vector>
#include "fstext/lattice-weight.h"
#include "itf/transition-information.h"
#include "itf/options-itf.h"
#include "lat/kaldi-lattice.h"
namespace
fst
{
/// \addtogroup fst_extensions
/// @{
// For example of usage, see test-determinize-lattice-pruned.cc
/*
DeterminizeLatticePruned implements a special form of determinization with
epsilon removal, optimized for a phase of lattice generation. This algorithm
also does pruning at the same time-- the combination is more efficient as it
somtimes prevents us from creating a lot of states that would later be pruned
away. This allows us to increase the lattice-beam and not have the algorithm
blow up. Also, because our algorithm processes states in order from those
that appear on high-scoring paths down to those that appear on low-scoring
paths, we can easily terminate the algorithm after a certain specified number
of states or arcs.
The input is an FST with weight-type BaseWeightType (usually a pair of floats,
with a lexicographical type of order, such as LatticeWeightTpl<float>).
Typically this would be a state-level lattice, with input symbols equal to
words, and output-symbols equal to p.d.f's (so like the inverse of HCLG). Imagine representing this as an
acceptor of type CompactLatticeWeightTpl<float>, in which the input/output
symbols are words, and the weights contain the original weights together with
strings (with zero or one symbol in them) containing the original output labels
(the p.d.f.'s). We determinize this using acceptor determinization with
epsilon removal. Remember (from lattice-weight.h) that
CompactLatticeWeightTpl has a special kind of semiring where we always take
the string corresponding to the best cost (of type BaseWeightType), and
discard the other. This corresponds to taking the best output-label sequence
(of p.d.f.'s) for each input-label sequence (of words). We couldn't use the
Gallic weight for this, or it would die as soon as it detected that the input
FST was non-functional. In our case, any acyclic FST (and many cyclic ones)
can be determinized.
We assume that there is a function
Compare(const BaseWeightType &a, const BaseWeightType &b)
that returns (-1, 0, 1) according to whether (a < b, a == b, a > b) in the
total order on the BaseWeightType... this information should be the
same as NaturalLess would give, but it's more efficient to do it this way.
You can define this for things like TropicalWeight if you need to instantiate
this class for that weight type.
We implement this determinization in a special way to make it efficient for
the types of FSTs that we will apply it to. One issue is that if we
explicitly represent the strings (in CompactLatticeWeightTpl) as vectors of
type vector<IntType>, the algorithm takes time quadratic in the length of
words (in states), because propagating each arc involves copying a whole
vector (of integers representing p.d.f.'s). Instead we use a hash structure
where each string is a pointer (Entry*), and uses a hash from (Entry*,
IntType), to the successor string (and a way to get the latest IntType and the
ancestor Entry*). [this is the class LatticeStringRepository].
Another issue is that rather than representing a determinized-state as a
collection of (state, weight), we represent it in a couple of reduced forms.
Suppose a determinized-state is a collection of (state, weight) pairs; call
this the "canonical representation". Note: these collections are always
normalized to remove any common weight and string part. Define end-states as
the subset of states that have an arc out of them with a label on, or are
final. If we represent a determinized-state a the set of just its (end-state,
weight) pairs, this will be a valid and more compact representation, and will
lead to a smaller set of determinized states (like early minimization). Call
this collection of (end-state, weight) pairs the "minimal representation". As
a mechanism to reduce compute, we can also consider another representation.
In the determinization algorithm, we start off with a set of (begin-state,
weight) pairs (where the "begin-states" are initial or have a label on the
transition into them), and the "canonical representation" consists of the
epsilon-closure of this set (i.e. follow epsilons). Call this set of
(begin-state, weight) pairs, appropriately normalized, the "initial
representation". If two initial representations are the same, the "canonical
representation" and hence the "minimal representation" will be the same. We
can use this to reduce compute. Note that if two initial representations are
different, this does not preclude the other representations from being the same.
*/
struct DeterminizeLatticePrunedOptions {
  float delta;    // A small offset used to measure equality of weights.
  int max_mem;    // If >0, determinization will fail and return false
                  // when the algorithm's (approximate) memory consumption crosses this threshold.
  int max_loop;   // If >0, can be used to detect non-determinizable input
                  // (a case that wouldn't be caught by max_mem).
  int max_states;
  int max_arcs;
  float retry_cutoff;
  DeterminizeLatticePrunedOptions(): delta(kDelta),
                                     max_mem(-1),
                                     max_loop(-1),
                                     max_states(-1),
                                     max_arcs(-1),
                                     retry_cutoff(0.5) { }
  void Register(kaldi::OptionsItf *opts) {
    opts->Register("delta", &delta, "Tolerance used in determinization");
    opts->Register("max-mem", &max_mem, "Maximum approximate memory usage in "
                   "determinization (real usage might be many times this)");
    opts->Register("max-arcs", &max_arcs, "Maximum number of arcs in "
                   "output FST (total, not per state)");
    opts->Register("max-states", &max_states, "Maximum number of states in "
                   "output FST (total, not per state)");
    opts->Register("max-loop", &max_loop, "Option used to detect a particular "
                   "type of determinization failure, typically due to invalid input "
                   "(e.g., negative-cost loops)");
    opts->Register("retry-cutoff", &retry_cutoff, "Controls pruning un-determinized "
                   "lattice and retrying determinization: if effective-beam < "
                   "retry-cutoff * beam, we prune the raw lattice and retry.  Avoids "
                   "ever getting empty output for long segments.");
  }
};
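For context, this is the usual pattern for wiring such an options struct to the command line via kaldi::ParseOptions (a hedged sketch, not part of this commit; the usage string is hypothetical):

#include "lat/determinize-lattice-pruned.h"
#include "util/parse-options.h"

int main(int argc, char *argv[]) {
  kaldi::ParseOptions po("Usage: sketch [options]");
  fst::DeterminizeLatticePrunedOptions opts;
  opts.Register(&po);   // exposes --delta, --max-mem, --max-loop, etc.
  po.Read(argc, argv);  // fills 'opts' from the command line.
  // 'opts' can now be passed to DeterminizeLatticePruned() declared below.
  return 0;
}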
struct DeterminizeLatticePhonePrunedOptions {
  // delta: a small offset used to measure equality of weights.
  float delta;
  // max_mem: if > 0, determinization will fail and return false when the
  // algorithm's (approximate) memory consumption crosses this threshold.
  int max_mem;
  // phone_determinize: if true, do a first pass determinization on both phones
  // and words.
  bool phone_determinize;
  // word_determinize: if true, do a second pass determinization on words only.
  bool word_determinize;
  // minimize: if true, push and minimize after determinization.
  bool minimize;
  DeterminizeLatticePhonePrunedOptions(): delta(kDelta),
                                          max_mem(50000000),
                                          phone_determinize(true),
                                          word_determinize(true),
                                          minimize(false) {}
  void Register(kaldi::OptionsItf *opts) {
    opts->Register("delta", &delta, "Tolerance used in determinization");
    opts->Register("max-mem", &max_mem, "Maximum approximate memory usage in "
                   "determinization (real usage might be many times this).");
    opts->Register("phone-determinize", &phone_determinize, "If true, do an "
                   "initial pass of determinization on both phones and words (see"
                   " also --word-determinize)");
    opts->Register("word-determinize", &word_determinize, "If true, do a second "
                   "pass of determinization on words only (see also "
                   "--phone-determinize)");
    opts->Register("minimize", &minimize, "If true, push and minimize after "
                   "determinization.");
  }
};
/**
This function implements the normal version of DeterminizeLattice, in which the
output strings are represented using sequences of arcs, where all but the
first one has an epsilon on the input side. It also prunes using the beam
in the "prune" parameter. The input FST must be topologically sorted in order
for the algorithm to work. For efficiency it is recommended to sort ilabel as well.
Returns true on success, and false if it had to terminate the determinization
earlier than specified by the "prune" beam-- that is, if it terminated because
of the max_mem, max_loop or max_arcs constraints in the options.
CAUTION: you may want to use the version below which outputs to CompactLattice.
*/
template<class Weight>
bool DeterminizeLatticePruned(
    const ExpandedFst<ArcTpl<Weight> > &ifst,
    double prune,
    MutableFst<ArcTpl<Weight> > *ofst,
    DeterminizeLatticePrunedOptions opts = DeterminizeLatticePrunedOptions());
/* This is a version of DeterminizeLattice with a slightly more "natural" output format,
where the output sequences are encoded using the CompactLatticeArcTpl template
(i.e. the sequences of output symbols are represented directly as strings). The input
FST must be topologically sorted in order for the algorithm to work. For efficiency
it is recommended to sort the ilabel for the input FST as well.
Returns true on normal success, and false if it had to terminate the determinization
earlier than specified by the "prune" beam-- that is, if it terminated because
of the max_mem, max_loop or max_arcs constraints in the options.
CAUTION: if Lattice is the input, you need to Invert() before calling this,
so words are on the input side.
*/
template<class Weight, class IntType>
bool DeterminizeLatticePruned(
    const ExpandedFst<ArcTpl<Weight> >& ifst,
    double prune,
    MutableFst<ArcTpl<CompactLatticeWeightTpl<Weight, IntType> > > *ofst,
    DeterminizeLatticePrunedOptions opts = DeterminizeLatticePrunedOptions());
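Putting the caution above into practice, a hedged sketch of the calling convention (assuming a kaldi::Lattice with transition-ids on the input side, as decoders produce; overload resolution is assumed to pick the CompactLattice version):

#include "lat/determinize-lattice-pruned.h"
#include "lat/kaldi-lattice.h"

// Determinizes '*lat' into '*clat'; returns false if determinization was cut
// short by the pruning / max-mem / max-loop constraints.
bool DeterminizeIntoCompact(kaldi::Lattice *lat, double beam,
                            kaldi::CompactLattice *clat) {
  fst::Invert(lat);  // words move to the input side, as required above.
  if (!lat->Properties(fst::kTopSorted, true))
    fst::TopSort(lat);
  fst::ILabelCompare<kaldi::LatticeArc> ilabel_comp;
  fst::ArcSort(lat, ilabel_comp);  // ilabel sorting, recommended for speed.
  fst::DeterminizeLatticePrunedOptions opts;
  return fst::DeterminizeLatticePruned(*lat, beam, clat, opts);
}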
/** This function takes in lattices and inserts phones at phone boundaries. It
uses the transition model to work out the transition_id to phone map. The
returned value is the starting index of the phone labels. Typically we pick
(maximum_output_label_index + 1) as this value. The inserted phones are then
mapped to (return_value + original_phone_label) in the new lattice. The
returned value will be used by DeterminizeLatticeDeletePhones(), which
works out the phones according to this value.
*/
template<class Weight>
typename ArcTpl<Weight>::Label DeterminizeLatticeInsertPhones(
    const kaldi::TransitionInformation &trans_model,
    MutableFst<ArcTpl<Weight> > *fst);
/** This function takes in lattices and deletes "phones" from them. The "phones"
here are actually any label that is larger than first_phone_label because
when we insert phones into the lattice, we map the original phone label to
(first_phone_label + original_phone_label). It is supposed to be used
together with DeterminizeLatticeInsertPhones()
*/
template<class Weight>
void DeterminizeLatticeDeletePhones(
    typename ArcTpl<Weight>::Label first_phone_label,
    MutableFst<ArcTpl<Weight> > *fst);
/** This function is a wrapper of DeterminizeLatticePhonePrunedFirstPass() and
DeterminizeLatticePruned(). If --phone-determinize is set to true, it first
calls DeterminizeLatticePhonePrunedFirstPass() to do the initial pass of
determinization on the phone + word lattices. If --word-determinize is set
true, it then does a second pass of determinization on the word lattices by
calling DeterminizeLatticePruned(). If both are set to false, then it gives
a warning and copies the lattices without determinization.
Note: the point of doing first a phone-level determinization pass and then
a word-level determinization pass is that it allows us to determinize
deeper lattices without "failing early" and returning a too-small lattice
due to the max-mem constraint. The result should be the same as word-level
determinization in general, but for deeper lattices it is a bit faster,
despite the fact that we now have two passes of determinization by default.
*/
template<class Weight, class IntType>
bool DeterminizeLatticePhonePruned(
    const kaldi::TransitionInformation &trans_model,
    const ExpandedFst<ArcTpl<Weight> > &ifst,
    double prune,
    MutableFst<ArcTpl<CompactLatticeWeightTpl<Weight, IntType> > > *ofst,
    DeterminizeLatticePhonePrunedOptions opts =
        DeterminizeLatticePhonePrunedOptions());
/** "Destructive" version of DeterminizeLatticePhonePruned() where the input
lattice might be changed.
*/
template<class Weight, class IntType>
bool DeterminizeLatticePhonePruned(
    const kaldi::TransitionInformation &trans_model,
    MutableFst<ArcTpl<Weight> > *ifst,
    double prune,
    MutableFst<ArcTpl<CompactLatticeWeightTpl<Weight, IntType> > > *ofst,
    DeterminizeLatticePhonePrunedOptions opts =
        DeterminizeLatticePhonePrunedOptions());
/** This function is a wrapper of DeterminizeLatticePhonePruned() that works for
Lattice type FSTs. It simplifies the calling process by calling
TopSort(), Invert() and ArcSort() for you.
Unlike other determinization routines, the function
requires "ifst" to have transition-id's on the input side and words on the
output side.
This function can be used as the top-level interface to all the determinization
code.
*/
bool DeterminizeLatticePhonePrunedWrapper(
    const kaldi::TransitionInformation &trans_model,
    MutableFst<kaldi::LatticeArc> *ifst,
    double prune,
    MutableFst<kaldi::CompactLatticeArc> *ofst,
    DeterminizeLatticePhonePrunedOptions opts =
        DeterminizeLatticePhonePrunedOptions());
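A hedged top-level usage sketch of this wrapper; it assumes the model is a kaldi::TransitionModel (which implements TransitionInformation in recent Kaldi) and that 'lat' is a raw state-level lattice, which the wrapper may modify in place:

#include "hmm/transition-model.h"
#include "lat/determinize-lattice-pruned.h"
#include "lat/kaldi-lattice.h"

bool DeterminizeWithPhonePasses(const kaldi::TransitionModel &trans_model,
                                kaldi::Lattice *lat, double beam,
                                kaldi::CompactLattice *det_clat) {
  fst::DeterminizeLatticePhonePrunedOptions opts;  // defaults: both passes on.
  return fst::DeterminizeLatticePhonePrunedWrapper(trans_model, lat, beam,
                                                   det_clat, opts);
}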
/// @} end "addtogroup fst_extensions"
}  // end namespace fst
#endif
speechx/speechx/kaldi/lat/kaldi-lattice.cc
0 → 100644
// lat/kaldi-lattice.cc
// Copyright 2009-2011 Microsoft Corporation
// 2013 Johns Hopkins University (author: Daniel Povey)
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#include "lat/kaldi-lattice.h"
#include "fst/script/print-impl.h"
namespace kaldi {

/// Converts lattice types if necessary, deleting its input.
template<class OrigWeightType>
CompactLattice* ConvertToCompactLattice(fst::VectorFst<OrigWeightType> *ifst) {
  if (!ifst) return NULL;
  CompactLattice *ofst = new CompactLattice();
  ConvertLattice(*ifst, ofst);
  delete ifst;
  return ofst;
}

// This overrides the template if there is no type conversion going on
// (for efficiency).
template<>
CompactLattice* ConvertToCompactLattice(CompactLattice *ifst) {
  return ifst;
}

/// Converts lattice types if necessary, deleting its input.
template<class OrigWeightType>
Lattice* ConvertToLattice(fst::VectorFst<OrigWeightType> *ifst) {
  if (!ifst) return NULL;
  Lattice *ofst = new Lattice();
  ConvertLattice(*ifst, ofst);
  delete ifst;
  return ofst;
}

// This overrides the template if there is no type conversion going on
// (for efficiency).
template<>
Lattice* ConvertToLattice(Lattice *ifst) {
  return ifst;
}
bool WriteCompactLattice(std::ostream &os, bool binary,
                         const CompactLattice &t) {
  if (binary) {
    fst::FstWriteOptions opts;
    // Leave all the options default.  Normally these lattices wouldn't have any
    // osymbols/isymbols so no point directing it not to write them (who knows
    // what we'd want to do if we had them).
    return t.Write(os, opts);
  } else {
    // Text-mode output.  Note: we expect that t.InputSymbols() and
    // t.OutputSymbols() would always return NULL.  The corresponding input
    // routine would not work if the FST actually had symbols attached.
    // Write a newline after the key, so the first line of the FST appears
    // on its own line.
    os << '\n';
    bool acceptor = true, write_one = false;
    fst::FstPrinter<CompactLatticeArc> printer(t, t.InputSymbols(),
                                               t.OutputSymbols(),
                                               NULL, acceptor, write_one, "\t");
    printer.Print(&os, "<unknown>");
    if (os.fail())
      KALDI_WARN << "Stream failure detected.";
    // Write another newline as a terminating character.  The read routine will
    // detect this [this is a Kaldi mechanism, not something in the original
    // OpenFst code].
    os << '\n';
    return os.good();
  }
}
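To make the archive framing concrete, a small sketch of writing one text-form entry using the function above (the helper name and key handling are illustrative only):

#include <ostream>
#include <string>
#include "lat/kaldi-lattice.h"

// Writes an archive-style key followed by the human-readable lattice;
// WriteCompactLattice itself emits the '\n' after the key and the
// terminating '\n' that the read routine looks for.
void WriteTextLatticeEntry(const std::string &key,
                           const kaldi::CompactLattice &clat,
                           std::ostream &os) {
  os << key;
  kaldi::WriteCompactLattice(os, false, clat);
}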
/// LatticeReader provides (static) functions for reading both Lattice
/// and CompactLattice, in text form.
class LatticeReader {
  typedef LatticeArc Arc;
  typedef LatticeWeight Weight;
  typedef CompactLatticeArc CArc;
  typedef CompactLatticeWeight CWeight;
  typedef Arc::Label Label;
  typedef Arc::StateId StateId;
 public:
  // everything is static in this class.

  /** This function reads from the FST text format; it does not know in advance
      whether it's a Lattice or CompactLattice in the stream so it tries to
      read both formats until it becomes clear which is the correct one.
  */
  static std::pair<Lattice*, CompactLattice*> ReadText(std::istream &is) {
    typedef std::pair<Lattice*, CompactLattice*> PairT;
    using std::string;
    using std::vector;
    Lattice *fst = new Lattice();
    CompactLattice *cfst = new CompactLattice();
    string line;
    size_t nline = 0;
    string separator = FLAGS_fst_field_separator + "\r\n";
    while (std::getline(is, line)) {
      nline++;
      vector<string> col;
      // on Windows we'll write in text and read in binary mode.
      SplitStringToVector(line, separator.c_str(), true, &col);
      if (col.size() == 0) break;  // Empty line is a signal to stop, in our
                                   // archive format.
      if (col.size() > 5) {
        KALDI_WARN << "Reading lattice: bad line in FST: " << line;
        delete fst;
        delete cfst;
        return PairT(static_cast<Lattice*>(NULL),
                     static_cast<CompactLattice*>(NULL));
      }
      StateId s;
      if (!ConvertStringToInteger(col[0], &s)) {
        KALDI_WARN << "FstCompiler: bad line in FST: " << line;
        delete fst;
        delete cfst;
        return PairT(static_cast<Lattice*>(NULL),
                     static_cast<CompactLattice*>(NULL));
      }
      if (fst)
        while (s >= fst->NumStates())
          fst->AddState();
      if (cfst)
        while (s >= cfst->NumStates())
          cfst->AddState();
      if (nline == 1) {
        if (fst) fst->SetStart(s);
        if (cfst) cfst->SetStart(s);
      }

      if (fst) {  // we still have fst; try to read that arc.
        bool ok = true;
        Arc arc;
        Weight w;
        StateId d = s;
        switch (col.size()) {
          case 1:
            fst->SetFinal(s, Weight::One());
            break;
          case 2:
            if (!StrToWeight(col[1], true, &w)) ok = false;
            else fst->SetFinal(s, w);
            break;
          case 3:  // 3 columns not ok for Lattice format; it's not an acceptor.
            ok = false;
            break;
          case 4:
            ok = ConvertStringToInteger(col[1], &arc.nextstate) &&
                ConvertStringToInteger(col[2], &arc.ilabel) &&
                ConvertStringToInteger(col[3], &arc.olabel);
            if (ok) {
              d = arc.nextstate;
              arc.weight = Weight::One();
              fst->AddArc(s, arc);
            }
            break;
          case 5:
            ok = ConvertStringToInteger(col[1], &arc.nextstate) &&
                ConvertStringToInteger(col[2], &arc.ilabel) &&
                ConvertStringToInteger(col[3], &arc.olabel) &&
                StrToWeight(col[4], false, &arc.weight);
            if (ok) {
              d = arc.nextstate;
              fst->AddArc(s, arc);
            }
            break;
          default:
            ok = false;
        }
        while (d >= fst->NumStates())
          fst->AddState();
        if (!ok) {
          delete fst;
          fst = NULL;
        }
      }
      if (cfst) {
        bool ok = true;
        CArc arc;
        CWeight w;
        StateId d = s;
        switch (col.size()) {
          case 1:
            cfst->SetFinal(s, CWeight::One());
            break;
          case 2:
            if (!StrToCWeight(col[1], true, &w)) ok = false;
            else cfst->SetFinal(s, w);
            break;
          case 3:  // compact-lattice is acceptor format: state, next-state, label.
            ok = ConvertStringToInteger(col[1], &arc.nextstate) &&
                ConvertStringToInteger(col[2], &arc.ilabel);
            if (ok) {
              d = arc.nextstate;
              arc.olabel = arc.ilabel;
              arc.weight = CWeight::One();
              cfst->AddArc(s, arc);
            }
            break;
          case 4:
            ok = ConvertStringToInteger(col[1], &arc.nextstate) &&
                ConvertStringToInteger(col[2], &arc.ilabel) &&
                StrToCWeight(col[3], false, &arc.weight);
            if (ok) {
              d = arc.nextstate;
              arc.olabel = arc.ilabel;
              cfst->AddArc(s, arc);
            }
            break;
          case 5: default:
            ok = false;
        }
        while (d >= cfst->NumStates())
          cfst->AddState();
        if (!ok) {
          delete cfst;
          cfst = NULL;
        }
      }
      if (!fst && !cfst) {
        KALDI_WARN << "Bad line in lattice text format: " << line;
        // read until we get an empty line, so at least we
        // have a chance to read the next one (although this might
        // be a bit futile since the calling code will get unhappy
        // about failing to read this one).
        while (std::getline(is, line)) {
          SplitStringToVector(line, separator.c_str(), true, &col);
          if (col.empty()) break;
        }
        return PairT(static_cast<Lattice*>(NULL),
                     static_cast<CompactLattice*>(NULL));
      }
    }
    return PairT(fst, cfst);
  }

  static bool StrToWeight(const std::string &s, bool allow_zero, Weight *w) {
    std::istringstream strm(s);
    strm >> *w;
    if (!strm || (!allow_zero && *w == Weight::Zero())) {
      return false;
    }
    return true;
  }

  static bool StrToCWeight(const std::string &s, bool allow_zero, CWeight *w) {
    std::istringstream strm(s);
    strm >> *w;
    if (!strm || (!allow_zero && *w == CWeight::Zero())) {
      return false;
    }
    return true;
  }
};
CompactLattice* ReadCompactLatticeText(std::istream &is) {
  std::pair<Lattice*, CompactLattice*> lat_pair = LatticeReader::ReadText(is);
  if (lat_pair.second != NULL) {
    delete lat_pair.first;
    return lat_pair.second;
  } else if (lat_pair.first != NULL) {
    // note: ConvertToCompactLattice frees its input.
    return ConvertToCompactLattice(lat_pair.first);
  } else {
    return NULL;
  }
}

Lattice* ReadLatticeText(std::istream &is) {
  std::pair<Lattice*, CompactLattice*> lat_pair = LatticeReader::ReadText(is);
  if (lat_pair.first != NULL) {
    delete lat_pair.second;
    return lat_pair.first;
  } else if (lat_pair.second != NULL) {
    // note: ConvertToLattice frees its input.
    return ConvertToLattice(lat_pair.second);
  } else {
    return NULL;
  }
}
bool ReadCompactLattice(std::istream &is, bool binary,
                        CompactLattice **clat) {
  KALDI_ASSERT(*clat == NULL);
  if (binary) {
    fst::FstHeader hdr;
    if (!hdr.Read(is, "<unknown>")) {
      KALDI_WARN << "Reading compact lattice: error reading FST header.";
      return false;
    }
    if (hdr.FstType() != "vector") {
      KALDI_WARN << "Reading compact lattice: unsupported FST type: "
                 << hdr.FstType();
      return false;
    }
    fst::FstReadOptions ropts("<unspecified>", &hdr);

    typedef fst::CompactLatticeWeightTpl<fst::LatticeWeightTpl<float>, int32> T1;
    typedef fst::CompactLatticeWeightTpl<fst::LatticeWeightTpl<double>, int32> T2;
    typedef fst::LatticeWeightTpl<float> T3;
    typedef fst::LatticeWeightTpl<double> T4;
    typedef fst::VectorFst<fst::ArcTpl<T1> > F1;
    typedef fst::VectorFst<fst::ArcTpl<T2> > F2;
    typedef fst::VectorFst<fst::ArcTpl<T3> > F3;
    typedef fst::VectorFst<fst::ArcTpl<T4> > F4;

    CompactLattice *ans = NULL;
    if (hdr.ArcType() == T1::Type()) {
      ans = ConvertToCompactLattice(F1::Read(is, ropts));
    } else if (hdr.ArcType() == T2::Type()) {
      ans = ConvertToCompactLattice(F2::Read(is, ropts));
    } else if (hdr.ArcType() == T3::Type()) {
      ans = ConvertToCompactLattice(F3::Read(is, ropts));
    } else if (hdr.ArcType() == T4::Type()) {
      ans = ConvertToCompactLattice(F4::Read(is, ropts));
    } else {
      KALDI_WARN << "FST with arc type " << hdr.ArcType()
                 << " cannot be converted to CompactLattice.\n";
      return false;
    }
    if (ans == NULL) {
      KALDI_WARN << "Error reading compact lattice (after reading header).";
      return false;
    }
    *clat = ans;
    return true;
  } else {
    // The next line would normally consume the \r on Windows, plus any
    // extra spaces that might have got in there somehow.
    while (std::isspace(is.peek()) && is.peek() != '\n') is.get();
    if (is.peek() == '\n') is.get();  // consume the newline.
    else {  // saw spaces but no newline... this is not expected.
      KALDI_WARN << "Reading compact lattice: unexpected sequence of spaces "
                 << " at file position " << is.tellg();
      return false;
    }
    *clat = ReadCompactLatticeText(is);  // that routine will warn on error.
    return (*clat != NULL);
  }
}
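A hedged sketch of reading a single binary CompactLattice from a file with the function above (the helper name is illustrative; note the requirement that the output pointer be NULL on entry):

#include <fstream>
#include <string>
#include "lat/kaldi-lattice.h"

kaldi::CompactLattice *ReadOneCompactLattice(const std::string &filename) {
  std::ifstream is(filename.c_str(), std::ios::binary);
  kaldi::CompactLattice *clat = NULL;  // must be NULL before the call.
  if (!kaldi::ReadCompactLattice(is, true, &clat))
    return NULL;  // ReadCompactLattice already warned on failure.
  return clat;    // caller owns and must delete.
}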
bool CompactLatticeHolder::Read(std::istream &is) {
  Clear();  // in case anything currently stored.
  int c = is.peek();
  if (c == -1) {
    KALDI_WARN << "End of stream detected reading CompactLattice.";
    return false;
  } else if (isspace(c)) {  // The text form of the lattice begins
    // with space (normally, '\n'), so this means it's text (the binary form
    // cannot begin with space because it starts with the FST Type() which is
    // not space).
    return ReadCompactLattice(is, false, &t_);
  } else if (c != 214) {  // 214 is first char of FST magic number,
    // on little-endian machines which is all we support (\326 octal)
    KALDI_WARN << "Reading compact lattice: does not appear to be an FST "
               << " [non-space but no magic number detected], file pos is "
               << is.tellg();
    return false;
  } else {
    return ReadCompactLattice(is, true, &t_);
  }
}
bool WriteLattice(std::ostream &os, bool binary, const Lattice &t) {
  if (binary) {
    fst::FstWriteOptions opts;
    // Leave all the options default.  Normally these lattices wouldn't have any
    // osymbols/isymbols so no point directing it not to write them (who knows
    // what we'd want to do if we had them).
    return t.Write(os, opts);
  } else {
    // Text-mode output.  Note: we expect that t.InputSymbols() and
    // t.OutputSymbols() would always return NULL.  The corresponding input
    // routine would not work if the FST actually had symbols attached.
    // Write a newline after the key, so the first line of the FST appears
    // on its own line.
    os << '\n';
    bool acceptor = false, write_one = false;
    fst::FstPrinter<LatticeArc> printer(t, t.InputSymbols(),
                                        t.OutputSymbols(),
                                        NULL, acceptor, write_one, "\t");
    printer.Print(&os, "<unknown>");
    if (os.fail())
      KALDI_WARN << "Stream failure detected.";
    // Write another newline as a terminating character.  The read routine will
    // detect this [this is a Kaldi mechanism, not something in the original
    // OpenFst code].
    os << '\n';
    return os.good();
  }
}
bool ReadLattice(std::istream &is, bool binary, Lattice **lat) {
  KALDI_ASSERT(*lat == NULL);
  if (binary) {
    fst::FstHeader hdr;
    if (!hdr.Read(is, "<unknown>")) {
      KALDI_WARN << "Reading lattice: error reading FST header.";
      return false;
    }
    if (hdr.FstType() != "vector") {
      KALDI_WARN << "Reading lattice: unsupported FST type: "
                 << hdr.FstType();
      return false;
    }
    fst::FstReadOptions ropts("<unspecified>", &hdr);

    typedef fst::CompactLatticeWeightTpl<fst::LatticeWeightTpl<float>, int32> T1;
    typedef fst::CompactLatticeWeightTpl<fst::LatticeWeightTpl<double>, int32> T2;
    typedef fst::LatticeWeightTpl<float> T3;
    typedef fst::LatticeWeightTpl<double> T4;
    typedef fst::VectorFst<fst::ArcTpl<T1> > F1;
    typedef fst::VectorFst<fst::ArcTpl<T2> > F2;
    typedef fst::VectorFst<fst::ArcTpl<T3> > F3;
    typedef fst::VectorFst<fst::ArcTpl<T4> > F4;

    Lattice *ans = NULL;
    if (hdr.ArcType() == T1::Type()) {
      ans = ConvertToLattice(F1::Read(is, ropts));
    } else if (hdr.ArcType() == T2::Type()) {
      ans = ConvertToLattice(F2::Read(is, ropts));
    } else if (hdr.ArcType() == T3::Type()) {
      ans = ConvertToLattice(F3::Read(is, ropts));
    } else if (hdr.ArcType() == T4::Type()) {
      ans = ConvertToLattice(F4::Read(is, ropts));
    } else {
      KALDI_WARN << "FST with arc type " << hdr.ArcType()
                 << " cannot be converted to Lattice.\n";
      return false;
    }
    if (ans == NULL) {
      KALDI_WARN << "Error reading lattice (after reading header).";
      return false;
    }
    *lat = ans;
    return true;
  } else {
    // The next line would normally consume the \r on Windows, plus any
    // extra spaces that might have got in there somehow.
    while (std::isspace(is.peek()) && is.peek() != '\n') is.get();
    if (is.peek() == '\n') is.get();  // consume the newline.
    else {  // saw spaces but no newline... this is not expected.
      KALDI_WARN << "Reading lattice: unexpected sequence of spaces "
                 << " at file position " << is.tellg();
      return false;
    }
    *lat = ReadLatticeText(is);  // that routine will warn on error.
    return (*lat != NULL);
  }
}
/* Since we don't write the binary headers for this type of holder,
   we use a different method to work out whether we're in binary mode.
*/
bool LatticeHolder::Read(std::istream &is) {
  Clear();  // in case anything currently stored.
  int c = is.peek();
  if (c == -1) {
    KALDI_WARN << "End of stream detected reading Lattice.";
    return false;
  } else if (isspace(c)) {  // The text form of the lattice begins
    // with space (normally, '\n'), so this means it's text (the binary form
    // cannot begin with space because it starts with the FST Type() which is
    // not space).
    return ReadLattice(is, false, &t_);
  } else if (c != 214) {  // 214 is first char of FST magic number,
    // on little-endian machines which is all we support (\326 octal)
    KALDI_WARN << "Reading lattice: does not appear to be an FST "
               << " [non-space but no magic number detected], file pos is "
               << is.tellg();
    return false;
  } else {
    return ReadLattice(is, true, &t_);
  }
}

}  // end namespace kaldi
speechx/speechx/kaldi/lat/kaldi-lattice.h
0 → 100644
// lat/kaldi-lattice.h
// Copyright 2009-2011 Microsoft Corporation
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#ifndef KALDI_LAT_KALDI_LATTICE_H_
#define KALDI_LAT_KALDI_LATTICE_H_
#include "fstext/fstext-lib.h"
#include "base/kaldi-common.h"
#include "util/common-utils.h"
namespace kaldi {
// will import some things above...

typedef fst::LatticeWeightTpl<BaseFloat> LatticeWeight;

// careful: kaldi::int32 is not always the same C type as fst::int32
typedef fst::CompactLatticeWeightTpl<LatticeWeight, int32> CompactLatticeWeight;

typedef fst::CompactLatticeWeightCommonDivisorTpl<LatticeWeight, int32>
    CompactLatticeWeightCommonDivisor;

typedef fst::ArcTpl<LatticeWeight> LatticeArc;

typedef fst::ArcTpl<CompactLatticeWeight> CompactLatticeArc;

typedef fst::VectorFst<LatticeArc> Lattice;

typedef fst::VectorFst<CompactLatticeArc> CompactLattice;
// The following functions for writing and reading lattices in binary or text
// form are provided here in case you need to include lattices in larger,
// Kaldi-type objects with their own Read and Write functions. Caution: these
// functions return false on stream failure rather than throwing an exception as
// most similar Kaldi functions would do.
bool WriteCompactLattice(std::ostream &os, bool binary,
                         const CompactLattice &clat);
bool WriteLattice(std::ostream &os, bool binary,
                  const Lattice &lat);

// the following function requires that *clat be
// NULL when called.
bool ReadCompactLattice(std::istream &is, bool binary,
                        CompactLattice **clat);
// the following function requires that *lat be
// NULL when called.
bool ReadLattice(std::istream &is, bool binary,
                 Lattice **lat);
class CompactLatticeHolder {
 public:
  typedef CompactLattice T;

  CompactLatticeHolder() { t_ = NULL; }

  static bool Write(std::ostream &os, bool binary, const T &t) {
    // Note: we don't include the binary-mode header when writing
    // this object to disk; this ensures that if we write to single
    // files, the result can be read by OpenFst.
    return WriteCompactLattice(os, binary, t);
  }

  bool Read(std::istream &is);

  static bool IsReadInBinary() { return true; }

  T &Value() {
    KALDI_ASSERT(t_ != NULL && "Called Value() on empty CompactLatticeHolder");
    return *t_;
  }

  void Clear() { delete t_; t_ = NULL; }

  void Swap(CompactLatticeHolder *other) {
    std::swap(t_, other->t_);
  }

  bool ExtractRange(const CompactLatticeHolder &other, const std::string &range) {
    KALDI_ERR << "ExtractRange is not defined for this type of holder.";
    return false;
  }

  ~CompactLatticeHolder() { Clear(); }

 private:
  T *t_;
};
class LatticeHolder {
 public:
  typedef Lattice T;

  LatticeHolder() { t_ = NULL; }

  static bool Write(std::ostream &os, bool binary, const T &t) {
    // Note: we don't include the binary-mode header when writing
    // this object to disk; this ensures that if we write to single
    // files, the result can be read by OpenFst.
    return WriteLattice(os, binary, t);
  }

  bool Read(std::istream &is);

  static bool IsReadInBinary() { return true; }

  T &Value() {
    KALDI_ASSERT(t_ != NULL && "Called Value() on empty LatticeHolder");
    return *t_;
  }

  void Clear() { delete t_; t_ = NULL; }

  void Swap(LatticeHolder *other) {
    std::swap(t_, other->t_);
  }

  bool ExtractRange(const LatticeHolder &other, const std::string &range) {
    KALDI_ERR << "ExtractRange is not defined for this type of holder.";
    return false;
  }

  ~LatticeHolder() { Clear(); }

 private:
  T *t_;
};
typedef TableWriter<LatticeHolder> LatticeWriter;
typedef SequentialTableReader<LatticeHolder> SequentialLatticeReader;
typedef RandomAccessTableReader<LatticeHolder> RandomAccessLatticeReader;

typedef TableWriter<CompactLatticeHolder> CompactLatticeWriter;
typedef SequentialTableReader<CompactLatticeHolder>
    SequentialCompactLatticeReader;
typedef RandomAccessTableReader<CompactLatticeHolder>
    RandomAccessCompactLatticeReader;

}  // namespace kaldi

#endif  // KALDI_LAT_KALDI_LATTICE_H_
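For orientation, these holder typedefs are what enable Kaldi's Table-based archive I/O over lattices; a hedged sketch of the usual read-process-write loop (the archive paths are hypothetical):

#include <string>
#include "lat/kaldi-lattice.h"

int main() {
  kaldi::SequentialCompactLatticeReader reader("ark:in.lats");
  kaldi::CompactLatticeWriter writer("ark:out.lats");
  for (; !reader.Done(); reader.Next()) {
    const std::string &key = reader.Key();
    kaldi::CompactLattice clat = reader.Value();
    // ... process clat here ...
    writer.Write(key, clat);
  }
  return 0;
}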
speechx/speechx/kaldi/lat/lattice-functions.cc
0 → 100644
// lat/lattice-functions.cc
// Copyright 2009-2011 Saarland University (Author: Arnab Ghoshal)
// 2012-2013 Johns Hopkins University (Author: Daniel Povey); Chao Weng;
// Bagher BabaAli
// 2013 Cisco Systems (author: Neha Agrawal) [code modified
// from original code in ../gmmbin/gmm-rescore-lattice.cc]
// 2014 Guoguo Chen
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#include "base/kaldi-math.h"
#include "lat/lattice-functions.h"
namespace kaldi {
using std::map;
using std::vector;

void GetPerFrameAcousticCosts(const Lattice &nbest,
                              Vector<BaseFloat> *per_frame_loglikes) {
  using namespace fst;
  typedef Lattice::Arc::Weight Weight;
  vector<BaseFloat> loglikes;

  int32 cur_state = nbest.Start();
  int32 prev_frame = -1;
  BaseFloat eps_acwt = 0.0;
  while (1) {
    Weight w = nbest.Final(cur_state);
    if (w != Weight::Zero()) {
      KALDI_ASSERT(nbest.NumArcs(cur_state) == 0);
      if (per_frame_loglikes != NULL) {
        SubVector<BaseFloat> subvec(&(loglikes[0]), loglikes.size());
        Vector<BaseFloat> vec(subvec);
        *per_frame_loglikes = vec;
      }
      break;
    } else {
      KALDI_ASSERT(nbest.NumArcs(cur_state) == 1);
      fst::ArcIterator<Lattice> iter(nbest, cur_state);
      const Lattice::Arc &arc = iter.Value();
      BaseFloat acwt = arc.weight.Value2();
      if (arc.ilabel != 0) {
        if (eps_acwt > 0) {
          acwt += eps_acwt;
          eps_acwt = 0.0;
        }
        loglikes.push_back(acwt);
        prev_frame++;
      } else if (acwt == acwt) {  // (acwt == acwt) is false only for NaN,
                                  // i.e. skip NaN-weighted epsilon arcs.
        if (prev_frame > -1) {
          loglikes[prev_frame] += acwt;
        } else {
          eps_acwt += acwt;
        }
      }
      cur_state = arc.nextstate;
    }
  }
}
int32 LatticeStateTimes(const Lattice &lat, vector<int32> *times) {
  if (!lat.Properties(fst::kTopSorted, true))
    KALDI_ERR << "Input lattice must be topologically sorted.";
  KALDI_ASSERT(lat.Start() == 0);
  int32 num_states = lat.NumStates();
  times->clear();
  times->resize(num_states, -1);
  (*times)[0] = 0;
  for (int32 state = 0; state < num_states; state++) {
    int32 cur_time = (*times)[state];
    for (fst::ArcIterator<Lattice> aiter(lat, state); !aiter.Done();
         aiter.Next()) {
      const LatticeArc &arc = aiter.Value();

      if (arc.ilabel != 0) {  // Non-epsilon input label on arc
        // next time instance
        if ((*times)[arc.nextstate] == -1) {
          (*times)[arc.nextstate] = cur_time + 1;
        } else {
          KALDI_ASSERT((*times)[arc.nextstate] == cur_time + 1);
        }
      } else {  // epsilon input label on arc
        // Same time instance
        if ((*times)[arc.nextstate] == -1)
          (*times)[arc.nextstate] = cur_time;
        else
          KALDI_ASSERT((*times)[arc.nextstate] == cur_time);
      }
    }
  }
  return (*std::max_element(times->begin(), times->end()));
}
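A small usage sketch (helper name illustrative): epsilon input arcs stay within a frame and non-epsilon arcs advance time by one, so on a topologically sorted lattice state_times[s] is well defined and the maximum is the utterance length in frames.

#include <vector>
#include "lat/kaldi-lattice.h"
#include "lat/lattice-functions.h"

kaldi::int32 NumFramesOf(const kaldi::Lattice &lat) {
  std::vector<kaldi::int32> state_times;  // filled with one time per state.
  return kaldi::LatticeStateTimes(lat, &state_times);
}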
int32 CompactLatticeStateTimes(const CompactLattice &lat,
                               vector<int32> *times) {
  if (!lat.Properties(fst::kTopSorted, true))
    KALDI_ERR << "Input lattice must be topologically sorted.";
  KALDI_ASSERT(lat.Start() == 0);
  int32 num_states = lat.NumStates();
  times->clear();
  times->resize(num_states, -1);
  (*times)[0] = 0;
  int32 utt_len = -1;
  for (int32 state = 0; state < num_states; state++) {
    int32 cur_time = (*times)[state];
    for (fst::ArcIterator<CompactLattice> aiter(lat, state); !aiter.Done();
         aiter.Next()) {
      const CompactLatticeArc &arc = aiter.Value();
      int32 arc_len = static_cast<int32>(arc.weight.String().size());
      if ((*times)[arc.nextstate] == -1)
        (*times)[arc.nextstate] = cur_time + arc_len;
      else
        KALDI_ASSERT((*times)[arc.nextstate] == cur_time + arc_len);
    }
    if (lat.Final(state) != CompactLatticeWeight::Zero()) {
      int32 this_utt_len = (*times)[state] + lat.Final(state).String().size();
      if (utt_len == -1) utt_len = this_utt_len;
      else {
        if (this_utt_len != utt_len) {
          KALDI_WARN << "Utterance does not seem to have a consistent length.";
          utt_len = std::max(utt_len, this_utt_len);
        }
      }
    }
  }
  if (utt_len == -1) {
    KALDI_WARN << "Utterance does not have a final-state.";
    return 0;
  }
  return utt_len;
}
bool ComputeCompactLatticeAlphas(const CompactLattice &clat,
                                 vector<double> *alpha) {
  using namespace fst;

  // typedef the arc, weight types
  typedef CompactLattice::Arc Arc;
  typedef Arc::Weight Weight;
  typedef Arc::StateId StateId;

  // Make sure the lattice is topologically sorted.
  if (clat.Properties(fst::kTopSorted, true) == 0) {
    KALDI_WARN << "Input lattice must be topologically sorted.";
    return false;
  }
  if (clat.Start() != 0) {
    KALDI_WARN << "Input lattice must start from state 0.";
    return false;
  }

  int32 num_states = clat.NumStates();
  (*alpha).resize(0);
  (*alpha).resize(num_states, kLogZeroDouble);

  // Now propagate alphas forward. Note that we don't account the weight of the
  // final state to alpha[final_state] -- we account it to beta[final_state];
  (*alpha)[0] = 0.0;
  for (StateId s = 0; s < num_states; s++) {
    double this_alpha = (*alpha)[s];
    for (ArcIterator<CompactLattice> aiter(clat, s); !aiter.Done();
         aiter.Next()) {
      const Arc &arc = aiter.Value();
      double arc_like = -(arc.weight.Weight().Value1() +
                          arc.weight.Weight().Value2());
      (*alpha)[arc.nextstate] = LogAdd((*alpha)[arc.nextstate],
                                       this_alpha + arc_like);
    }
  }

  return true;
}
bool ComputeCompactLatticeBetas(const CompactLattice &clat,
                                vector<double> *beta) {
  using namespace fst;

  // typedef the arc, weight types
  typedef CompactLattice::Arc Arc;
  typedef Arc::Weight Weight;
  typedef Arc::StateId StateId;

  // Make sure the lattice is topologically sorted.
  if (clat.Properties(fst::kTopSorted, true) == 0) {
    KALDI_WARN << "Input lattice must be topologically sorted.";
    return false;
  }
  if (clat.Start() != 0) {
    KALDI_WARN << "Input lattice must start from state 0.";
    return false;
  }

  int32 num_states = clat.NumStates();
  (*beta).resize(0);
  (*beta).resize(num_states, kLogZeroDouble);

  // Now propagate betas backward. Note that beta[final_state] contains the
  // weight of the final state in the lattice -- compare that with alpha.
  for (StateId s = num_states - 1; s >= 0; s--) {
    Weight f = clat.Final(s);
    double this_beta = -(f.Weight().Value1() + f.Weight().Value2());
    for (ArcIterator<CompactLattice> aiter(clat, s); !aiter.Done();
         aiter.Next()) {
      const Arc &arc = aiter.Value();
      double arc_like = -(arc.weight.Weight().Value1() +
                          arc.weight.Weight().Value2());
      double arc_beta = (*beta)[arc.nextstate] + arc_like;
      this_beta = LogAdd(this_beta, arc_beta);
    }
    (*beta)[s] = this_beta;
  }

  return true;
}
template<class LatType>  // could be Lattice or CompactLattice
bool PruneLattice(BaseFloat beam, LatType *lat) {
  typedef typename LatType::Arc Arc;
  typedef typename Arc::Weight Weight;
  typedef typename Arc::StateId StateId;

  KALDI_ASSERT(beam > 0.0);
  if (!lat->Properties(fst::kTopSorted, true)) {
    if (fst::TopSort(lat) == false) {
      KALDI_WARN << "Cycles detected in lattice";
      return false;
    }
  }
  // We assume states before "start" are not reachable, since
  // the lattice is topologically sorted.
  int32 start = lat->Start();
  int32 num_states = lat->NumStates();
  if (num_states == 0) return false;
  std::vector<double> forward_cost(
      num_states, std::numeric_limits<double>::infinity());  // viterbi forward.
  forward_cost[start] = 0.0;  // lattice can't have cycles so couldn't be
                              // less than this.
  double best_final_cost = std::numeric_limits<double>::infinity();
  // Update the forward probs.
  // Thanks to Jing Zheng for finding a bug here.
  for (int32 state = 0; state < num_states; state++) {
    double this_forward_cost = forward_cost[state];
    for (fst::ArcIterator<LatType> aiter(*lat, state);
         !aiter.Done();
         aiter.Next()) {
      const Arc &arc(aiter.Value());
      StateId nextstate = arc.nextstate;
      KALDI_ASSERT(nextstate > state && nextstate < num_states);
      double next_forward_cost = this_forward_cost +
          ConvertToCost(arc.weight);
      if (forward_cost[nextstate] > next_forward_cost)
        forward_cost[nextstate] = next_forward_cost;
    }
    Weight final_weight = lat->Final(state);
    double this_final_cost = this_forward_cost +
        ConvertToCost(final_weight);
    if (this_final_cost < best_final_cost)
      best_final_cost = this_final_cost;
  }
  int32 bad_state = lat->AddState();  // this state is not final.
  double cutoff = best_final_cost + beam;
  // Go backwards updating the backward probs (which share memory with the
  // forward probs), and pruning arcs and deleting final-probs.  We prune arcs
  // by making them point to the non-final state "bad_state".  We'll then use
  // Trim() to remove unnecessary arcs and states.  [this is just easier than
  // doing it ourselves.]
  std::vector<double> &backward_cost(forward_cost);
  for (int32 state = num_states - 1; state >= 0; state--) {
    double this_forward_cost = forward_cost[state];
    double this_backward_cost = ConvertToCost(lat->Final(state));
    if (this_backward_cost + this_forward_cost > cutoff
        && this_backward_cost != std::numeric_limits<double>::infinity())
      lat->SetFinal(state, Weight::Zero());
    for (fst::MutableArcIterator<LatType> aiter(lat, state);
         !aiter.Done();
         aiter.Next()) {
      Arc arc(aiter.Value());
      StateId nextstate = arc.nextstate;
      KALDI_ASSERT(nextstate > state && nextstate < num_states);
      double arc_cost = ConvertToCost(arc.weight),
          arc_backward_cost = arc_cost + backward_cost[nextstate],
          this_fb_cost = this_forward_cost + arc_backward_cost;
      if (arc_backward_cost < this_backward_cost)
        this_backward_cost = arc_backward_cost;
      if (this_fb_cost > cutoff) {  // Prune the arc.
        arc.nextstate = bad_state;
        aiter.SetValue(arc);
      }
    }
    backward_cost[state] = this_backward_cost;
  }
  fst::Connect(lat);
  return (lat->NumStates() > 0);
}

// instantiate the template for Lattice and CompactLattice.
template bool PruneLattice(BaseFloat beam, Lattice *lat);
template bool PruneLattice(BaseFloat beam, CompactLattice *lat);
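A hedged usage sketch of the pruning above (helper name and beam value are illustrative): a path survives exactly when its best forward-plus-backward cost is within 'beam' of the overall best path.

#include "lat/lattice-functions.h"

bool PruneToBeam(kaldi::CompactLattice *clat, kaldi::BaseFloat beam) {
  // Returns false if the lattice had cycles or nothing survived pruning.
  return kaldi::PruneLattice(beam, clat);
}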
BaseFloat LatticeForwardBackward(const Lattice &lat, Posterior *post,
                                 double *acoustic_like_sum) {
  // Note, Posterior is defined as follows: Indexed [frame], then a list
  // of (transition-id, posterior-probability) pairs.
  // typedef std::vector<std::vector<std::pair<int32, BaseFloat> > > Posterior;
  using namespace fst;
  typedef Lattice::Arc Arc;
  typedef Arc::Weight Weight;
  typedef Arc::StateId StateId;

  if (acoustic_like_sum) *acoustic_like_sum = 0.0;

  // Make sure the lattice is topologically sorted.
  if (lat.Properties(fst::kTopSorted, true) == 0)
    KALDI_ERR << "Input lattice must be topologically sorted.";
  KALDI_ASSERT(lat.Start() == 0);

  int32 num_states = lat.NumStates();
  vector<int32> state_times;
  int32 max_time = LatticeStateTimes(lat, &state_times);
  std::vector<double> alpha(num_states, kLogZeroDouble);
  std::vector<double> &beta(alpha);  // we re-use the same memory for
  // this, but it's semantically distinct so we name it differently.
  double tot_forward_prob = kLogZeroDouble;

  post->clear();
  post->resize(max_time);

  alpha[0] = 0.0;
  // Propagate alphas forward.
  for (StateId s = 0; s < num_states; s++) {
    double this_alpha = alpha[s];
    for (ArcIterator<Lattice> aiter(lat, s); !aiter.Done(); aiter.Next()) {
      const Arc &arc = aiter.Value();
      double arc_like = -ConvertToCost(arc.weight);
      alpha[arc.nextstate] = LogAdd(alpha[arc.nextstate],
                                    this_alpha + arc_like);
    }
    Weight f = lat.Final(s);
    if (f != Weight::Zero()) {
      double final_like = this_alpha - (f.Value1() + f.Value2());
      tot_forward_prob = LogAdd(tot_forward_prob, final_like);
      KALDI_ASSERT(state_times[s] == max_time &&
                   "Lattice is inconsistent (final-prob not at max_time)");
    }
  }
  for (StateId s = num_states - 1; s >= 0; s--) {
    Weight f = lat.Final(s);
    double this_beta = -(f.Value1() + f.Value2());
    for (ArcIterator<Lattice> aiter(lat, s); !aiter.Done(); aiter.Next()) {
      const Arc &arc = aiter.Value();
      double arc_like = -ConvertToCost(arc.weight),
          arc_beta = beta[arc.nextstate] + arc_like;
      this_beta = LogAdd(this_beta, arc_beta);
      int32 transition_id = arc.ilabel;

      // The following "if" is an optimization to avoid un-needed exp().
      if (transition_id != 0 || acoustic_like_sum != NULL) {
        double posterior = Exp(alpha[s] + arc_beta - tot_forward_prob);

        if (transition_id != 0)  // Arc has a transition-id on it [not epsilon]
          (*post)[state_times[s]].push_back(
              std::make_pair(transition_id,
                             static_cast<kaldi::BaseFloat>(posterior)));
        if (acoustic_like_sum != NULL)
          *acoustic_like_sum -= posterior * arc.weight.Value2();
      }
    }
    if (acoustic_like_sum != NULL && f != Weight::Zero()) {
      double final_logprob = -ConvertToCost(f),
          posterior = Exp(alpha[s] + final_logprob - tot_forward_prob);
      *acoustic_like_sum -= posterior * f.Value2();
    }
    beta[s] = this_beta;
  }
  double tot_backward_prob = beta[0];
  if (!ApproxEqual(tot_forward_prob, tot_backward_prob, 1e-8)) {
    KALDI_WARN << "Total forward probability over lattice = "
               << tot_forward_prob << ", while total backward probability = "
               << tot_backward_prob;
  }
  // Now combine any posteriors with the same transition-id.
  for (int32 t = 0; t < max_time; t++)
    MergePairVectorSumming(&((*post)[t]));
  return tot_backward_prob;
}
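A hedged sketch of consuming the Posterior produced above (helper name illustrative): each frame holds (transition-id, posterior) pairs, and since every successful path crosses each frame with exactly one non-epsilon arc, the per-frame posteriors should sum to roughly 1.

#include "hmm/posterior.h"
#include "lat/lattice-functions.h"

void CheckPosteriorsSumToOne(const kaldi::Lattice &lat) {
  kaldi::Posterior post;
  double acoustic_like_sum = 0.0;
  kaldi::LatticeForwardBackward(lat, &post, &acoustic_like_sum);
  for (size_t t = 0; t < post.size(); t++) {
    kaldi::BaseFloat sum = 0.0;
    for (size_t i = 0; i < post[t].size(); i++)
      sum += post[t][i].second;  // .first is the transition-id.
    KALDI_ASSERT(sum > 0.9 && sum < 1.1);  // loose numerical tolerance.
  }
}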
void LatticeActivePhones(const Lattice &lat, const TransitionInformation &trans,
                         const vector<int32> &silence_phones,
                         vector<std::set<int32> > *active_phones) {
  KALDI_ASSERT(IsSortedAndUniq(silence_phones));
  vector<int32> state_times;
  int32 num_states = lat.NumStates();
  int32 max_time = LatticeStateTimes(lat, &state_times);
  active_phones->clear();
  active_phones->resize(max_time);
  for (int32 state = 0; state < num_states; state++) {
    int32 cur_time = state_times[state];
    for (fst::ArcIterator<Lattice> aiter(lat, state); !aiter.Done();
         aiter.Next()) {
      const LatticeArc &arc = aiter.Value();
      if (arc.ilabel != 0) {  // Non-epsilon arc
        int32 phone = trans.TransitionIdToPhone(arc.ilabel);
        if (!std::binary_search(silence_phones.begin(),
                                silence_phones.end(), phone))
          (*active_phones)[cur_time].insert(phone);
      }
    }  // end looping over arcs
  }  // end looping over states
}
void ConvertLatticeToPhones(const TransitionInformation &trans,
                            Lattice *lat) {
  typedef LatticeArc Arc;
  int32 num_states = lat->NumStates();
  for (int32 state = 0; state < num_states; state++) {
    for (fst::MutableArcIterator<Lattice> aiter(lat, state);
         !aiter.Done(); aiter.Next()) {
      Arc arc(aiter.Value());
      arc.olabel = 0;  // remove any word.
      if ((arc.ilabel != 0)  // has a transition-id on input..
          && (trans.TransitionIdIsStartOfPhone(arc.ilabel))
          && (!trans.IsSelfLoop(arc.ilabel))) {
        // && trans.IsFinal(arc.ilabel))  // there is one of these per phone...
        arc.olabel = trans.TransitionIdToPhone(arc.ilabel);
      }
      aiter.SetValue(arc);
    }  // end looping over arcs
  }  // end looping over states
}
static inline double LogAddOrMax(bool viterbi, double a, double b) {
  if (viterbi)
    return std::max(a, b);
  else
    return LogAdd(a, b);
}
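As a stand-alone illustration of the log-add branch above: LogAdd computes log(exp(a) + exp(b)), which must be done by factoring out the larger exponent so that, for example, adding log-probs around -1000 stays finite even though exp(-1000) underflows. A minimal sketch in plain C++ (function name invented; Kaldi's own LogAdd lives in base/kaldi-math.h):

#include <algorithm>
#include <cmath>
#include <limits>

double StableLogAdd(double a, double b) {
  double hi = std::max(a, b), lo = std::min(a, b);
  if (lo == -std::numeric_limits<double>::infinity()) return hi;
  return hi + std::log1p(std::exp(lo - hi));  // exp(lo - hi) <= 1, no overflow.
}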
template<typename LatticeType>
double ComputeLatticeAlphasAndBetas(const LatticeType &lat,
                                    bool viterbi,
                                    vector<double> *alpha,
                                    vector<double> *beta) {
  typedef typename LatticeType::Arc Arc;
  typedef typename Arc::Weight Weight;
  typedef typename Arc::StateId StateId;

  StateId num_states = lat.NumStates();
  KALDI_ASSERT(lat.Properties(fst::kTopSorted, true) == fst::kTopSorted);
  KALDI_ASSERT(lat.Start() == 0);
  alpha->clear();
  beta->clear();
  alpha->resize(num_states, kLogZeroDouble);
  beta->resize(num_states, kLogZeroDouble);

  double tot_forward_prob = kLogZeroDouble;
  (*alpha)[0] = 0.0;
  // Propagate alphas forward.
  for (StateId s = 0; s < num_states; s++) {
    double this_alpha = (*alpha)[s];
    for (fst::ArcIterator<LatticeType> aiter(lat, s); !aiter.Done();
         aiter.Next()) {
      const Arc &arc = aiter.Value();
      double arc_like = -ConvertToCost(arc.weight);
      (*alpha)[arc.nextstate] = LogAddOrMax(viterbi, (*alpha)[arc.nextstate],
                                            this_alpha + arc_like);
    }
    Weight f = lat.Final(s);
    if (f != Weight::Zero()) {
      double final_like = this_alpha - ConvertToCost(f);
      tot_forward_prob = LogAddOrMax(viterbi, tot_forward_prob, final_like);
    }
  }
  for (StateId s = num_states - 1; s >= 0; s--) {  // it's guaranteed signed.
    double this_beta = -ConvertToCost(lat.Final(s));
    for (fst::ArcIterator<LatticeType> aiter(lat, s); !aiter.Done();
         aiter.Next()) {
      const Arc &arc = aiter.Value();
      double arc_like = -ConvertToCost(arc.weight),
          arc_beta = (*beta)[arc.nextstate] + arc_like;
      this_beta = LogAddOrMax(viterbi, this_beta, arc_beta);
    }
    (*beta)[s] = this_beta;
  }
  double tot_backward_prob = (*beta)[lat.Start()];
  if (!ApproxEqual(tot_forward_prob, tot_backward_prob, 1e-8)) {
    KALDI_WARN << "Total forward probability over lattice = "
               << tot_forward_prob << ", while total backward probability = "
               << tot_backward_prob;
  }
  // Split the difference when returning... they should be the same.
  return 0.5 * (tot_backward_prob + tot_forward_prob);
}

// instantiate the template for Lattice and CompactLattice
template
double ComputeLatticeAlphasAndBetas(const Lattice &lat,
                                    bool viterbi,
                                    vector<double> *alpha,
                                    vector<double> *beta);

template
double ComputeLatticeAlphasAndBetas(const CompactLattice &lat,
                                    bool viterbi,
                                    vector<double> *alpha,
                                    vector<double> *beta);
/// This is used in CompactLatticeLimitDepth.
struct LatticeArcRecord {
  BaseFloat logprob;  // logprob <= 0 is the best Viterbi logprob of this arc,
                      // minus the overall best-cost of the lattice.
  CompactLatticeArc::StateId state;  // state in the lattice.
  size_t arc;  // arc index within the state.
  bool operator < (const LatticeArcRecord &other) const {
    return logprob < other.logprob;
  }
};
void CompactLatticeLimitDepth(int32 max_depth_per_frame,
                              CompactLattice *clat) {
  typedef CompactLatticeArc Arc;
  typedef Arc::Weight Weight;
  typedef Arc::StateId StateId;

  if (clat->Start() == fst::kNoStateId) {
    KALDI_WARN << "Limiting depth of empty lattice.";
    return;
  }
  if (clat->Properties(fst::kTopSorted, true) == 0) {
    if (!TopSort(clat))
      KALDI_ERR << "Topological sorting of lattice failed.";
  }

  vector<int32> state_times;
  int32 T = CompactLatticeStateTimes(*clat, &state_times);

  // The alpha and beta quantities here are "viterbi" alphas and betas.
  std::vector<double> alpha;
  std::vector<double> beta;
  bool viterbi = true;
  double best_prob = ComputeLatticeAlphasAndBetas(*clat, viterbi,
                                                  &alpha, &beta);

  std::vector<std::vector<LatticeArcRecord> > arc_records(T);

  StateId num_states = clat->NumStates();
  for (StateId s = 0; s < num_states; s++) {
    for (fst::ArcIterator<CompactLattice> aiter(*clat, s); !aiter.Done();
         aiter.Next()) {
      const Arc &arc = aiter.Value();
      LatticeArcRecord arc_record;
      arc_record.state = s;
      arc_record.arc = aiter.Position();
      arc_record.logprob =
          (alpha[s] + beta[arc.nextstate] - ConvertToCost(arc.weight))
           - best_prob;
      KALDI_ASSERT(arc_record.logprob < 0.1);  // Should be zero or negative.
      int32 num_frames = arc.weight.String().size(), start_t = state_times[s];
      for (int32 t = start_t; t < start_t + num_frames; t++) {
        KALDI_ASSERT(t < T);
        arc_records[t].push_back(arc_record);
      }
    }
  }
  StateId dead_state = clat->AddState();  // A non-coaccessible state which we
                                          // use to remove arcs (make them end
                                          // there).
  size_t max_depth = max_depth_per_frame;
  for (int32 t = 0; t < T; t++) {
    size_t size = arc_records[t].size();
    if (size > max_depth) {
      // we sort from worst to best, so we keep the later-numbered ones,
      // and delete the lower-numbered ones.
      size_t cutoff = size - max_depth;
      std::nth_element(arc_records[t].begin(),
                       arc_records[t].begin() + cutoff,
                       arc_records[t].end());
      for (size_t index = 0; index < cutoff; index++) {
        LatticeArcRecord record(arc_records[t][index]);
        fst::MutableArcIterator<CompactLattice> aiter(clat, record.state);
        aiter.Seek(record.arc);
        Arc arc = aiter.Value();
        if (arc.nextstate != dead_state) {  // not already killed.
          arc.nextstate = dead_state;
          aiter.SetValue(arc);
        }
      }
    }
  }
  Connect(clat);
  TopSortCompactLatticeIfNeeded(clat);
}
void TopSortCompactLatticeIfNeeded(CompactLattice *clat) {
  if (clat->Properties(fst::kTopSorted, true) == 0) {
    if (fst::TopSort(clat) == false) {
      KALDI_ERR << "Topological sorting failed";
    }
  }
}

void TopSortLatticeIfNeeded(Lattice *lat) {
  if (lat->Properties(fst::kTopSorted, true) == 0) {
    if (fst::TopSort(lat) == false) {
      KALDI_ERR << "Topological sorting failed";
    }
  }
}
/// Returns the depth of the lattice, defined as the average number of
/// arcs crossing any given frame.  Returns 1 for empty lattices.
/// Requires that input is topologically sorted.
BaseFloat CompactLatticeDepth(const CompactLattice &clat,
                              int32 *num_frames) {
  typedef CompactLattice::Arc::StateId StateId;
  if (clat.Properties(fst::kTopSorted, true) == 0) {
    KALDI_ERR << "Lattice input to CompactLatticeDepth was not topologically "
              << "sorted.";
  }
  if (clat.Start() == fst::kNoStateId) {
    *num_frames = 0;
    return 1.0;
  }
  size_t num_arc_frames = 0;
  int32 t;
  {
    vector<int32> state_times;
    t = CompactLatticeStateTimes(clat, &state_times);
  }
  if (num_frames != NULL)
    *num_frames = t;
  for (StateId s = 0; s < clat.NumStates(); s++) {
    for (fst::ArcIterator<CompactLattice> aiter(clat, s); !aiter.Done();
         aiter.Next()) {
      const CompactLatticeArc &arc = aiter.Value();
      num_arc_frames += arc.weight.String().size();
    }
    num_arc_frames += clat.Final(s).String().size();
  }
  return num_arc_frames / static_cast<BaseFloat>(t);
}
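A brief usage sketch (helper name illustrative): a depth of 1.0 means the lattice has collapsed to a single path, and typical decoding lattices have depth well above that before pruning.

#include "lat/lattice-functions.h"

kaldi::BaseFloat ReportDepth(const kaldi::CompactLattice &clat) {
  kaldi::int32 num_frames = 0;
  kaldi::BaseFloat depth = kaldi::CompactLatticeDepth(clat, &num_frames);
  KALDI_LOG << "depth = " << depth << " over " << num_frames << " frames";
  return depth;
}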
void CompactLatticeDepthPerFrame(const CompactLattice &clat,
                                 std::vector<int32> *depth_per_frame) {
  typedef CompactLattice::Arc::StateId StateId;
  if (clat.Properties(fst::kTopSorted, true) == 0) {
    KALDI_ERR << "Lattice input to CompactLatticeDepthPerFrame was not "
              << "topologically sorted.";
  }
  if (clat.Start() == fst::kNoStateId) {
    depth_per_frame->clear();
    return;
  }
  vector<int32> state_times;
  int32 T = CompactLatticeStateTimes(clat, &state_times);

  depth_per_frame->clear();
  if (T <= 0) {
    return;
  } else {
    depth_per_frame->resize(T, 0);
    for (StateId s = 0; s < clat.NumStates(); s++) {
      int32 start_time = state_times[s];
      for (fst::ArcIterator<CompactLattice> aiter(clat, s); !aiter.Done();
           aiter.Next()) {
        const CompactLatticeArc &arc = aiter.Value();
        int32 len = arc.weight.String().size();
        for (int32 t = start_time; t < start_time + len; t++) {
          KALDI_ASSERT(t < T);
          (*depth_per_frame)[t]++;
        }
      }
      int32 final_len = clat.Final(s).String().size();
      for (int32 t = start_time; t < start_time + final_len; t++) {
        KALDI_ASSERT(t < T);
        (*depth_per_frame)[t]++;
      }
    }
  }
}
void ConvertCompactLatticeToPhones(const TransitionInformation &trans,
                                   CompactLattice *clat) {
  typedef CompactLatticeArc Arc;
  typedef Arc::Weight Weight;
  int32 num_states = clat->NumStates();
  for (int32 state = 0; state < num_states; state++) {
    for (fst::MutableArcIterator<CompactLattice> aiter(clat, state);
         !aiter.Done(); aiter.Next()) {
      Arc arc(aiter.Value());
      std::vector<int32> phone_seq;
      const std::vector<int32> &tid_seq = arc.weight.String();
      for (std::vector<int32>::const_iterator iter = tid_seq.begin();
           iter != tid_seq.end(); ++iter) {
        if (trans.IsFinal(*iter))  // note: there is one of these per phone...
          phone_seq.push_back(trans.TransitionIdToPhone(*iter));
      }
      arc.weight.SetString(phone_seq);
      aiter.SetValue(arc);
    }  // end looping over arcs
    Weight f = clat->Final(state);
    if (f != Weight::Zero()) {
      std::vector<int32> phone_seq;
      const std::vector<int32> &tid_seq = f.String();
      for (std::vector<int32>::const_iterator iter = tid_seq.begin();
           iter != tid_seq.end(); ++iter) {
        if (trans.IsFinal(*iter))  // note: there is one of these per phone...
          phone_seq.push_back(trans.TransitionIdToPhone(*iter));
      }
      f.SetString(phone_seq);
      clat->SetFinal(state, f);
    }
  }  // end looping over states
}
bool LatticeBoost(const TransitionInformation &trans,
                  const std::vector<int32> &alignment,
                  const std::vector<int32> &silence_phones,
                  BaseFloat b,
                  BaseFloat max_silence_error,
                  Lattice *lat) {
  TopSortLatticeIfNeeded(lat);

  // get all stored properties (test == false means don't test if not known).
  uint64 props = lat->Properties(fst::kFstProperties, false);

  KALDI_ASSERT(IsSortedAndUniq(silence_phones));
  KALDI_ASSERT(max_silence_error >= 0.0 && max_silence_error <= 1.0);
  vector<int32> state_times;
  int32 num_states = lat->NumStates();
  int32 num_frames = LatticeStateTimes(*lat, &state_times);
  KALDI_ASSERT(num_frames == static_cast<int32>(alignment.size()));
  for (int32 state = 0; state < num_states; state++) {
    int32 cur_time = state_times[state];
    for (fst::MutableArcIterator<Lattice> aiter(lat, state); !aiter.Done();
         aiter.Next()) {
      LatticeArc arc = aiter.Value();
      if (arc.ilabel != 0) {  // Non-epsilon arc
        if (arc.ilabel < 0 || arc.ilabel > trans.NumTransitionIds()) {
          KALDI_WARN << "Lattice has out-of-range transition-ids: "
                     << "lattice/model mismatch?";
          return false;
        }
        int32 phone = trans.TransitionIdToPhone(arc.ilabel),
            ref_phone = trans.TransitionIdToPhone(alignment[cur_time]);
        BaseFloat frame_error;
        if (phone == ref_phone) {
          frame_error = 0.0;
        } else {  // an error...
          if (std::binary_search(silence_phones.begin(),
                                 silence_phones.end(), phone))
            frame_error = max_silence_error;
          else
            frame_error = 1.0;
        }
        BaseFloat delta_cost = -b * frame_error;  // negative cost if
        // frame is wrong, to boost likelihood of arcs with errors on them.
        // Add this cost to the graph part.
        arc.weight.SetValue1(arc.weight.Value1() + delta_cost);
        aiter.SetValue(arc);
      }
    }
  }
  // All we changed is the weights, so any properties that were
  // known before, are still known, except for whether or not the
  // lattice was weighted.
  lat->SetProperties(props,
                     ~(fst::kWeighted | fst::kUnweighted));

  return true;
}
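A hedged sketch of calling this for boosted-MMI style training (helper name and constants are illustrative, not prescribed values): each arc whose phone disagrees with the numerator alignment gets its graph cost lowered by b * frame_error, raising the relative likelihood of erroneous paths.

#include <vector>
#include "hmm/transition-model.h"
#include "lat/lattice-functions.h"

bool BoostForBMMI(const kaldi::TransitionModel &trans,
                  const std::vector<kaldi::int32> &num_ali,
                  const std::vector<kaldi::int32> &silence_phones,  // sorted
                  kaldi::Lattice *lat) {
  kaldi::BaseFloat b = 0.1, max_silence_error = 0.5;  // hypothetical settings.
  return kaldi::LatticeBoost(trans, num_ali, silence_phones,
                             b, max_silence_error, lat);
}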
BaseFloat LatticeForwardBackwardMpeVariants(
    const TransitionInformation &trans,
    const std::vector<int32> &silence_phones,
    const Lattice &lat,
    const std::vector<int32> &num_ali,
    std::string criterion,
    bool one_silence_class,
    Posterior *post) {
  using namespace fst;
  typedef Lattice::Arc Arc;
  typedef Arc::Weight Weight;
  typedef Arc::StateId StateId;

  KALDI_ASSERT(criterion == "mpfe" || criterion == "smbr");
  bool is_mpfe = (criterion == "mpfe");

  if (lat.Properties(fst::kTopSorted, true) == 0)
    KALDI_ERR << "Input lattice must be topologically sorted.";
  KALDI_ASSERT(lat.Start() == 0);

  int32 num_states = lat.NumStates();
  vector<int32> state_times;
  int32 max_time = LatticeStateTimes(lat, &state_times);
  KALDI_ASSERT(max_time == static_cast<int32>(num_ali.size()));
  std::vector<double> alpha(num_states, kLogZeroDouble),
      alpha_smbr(num_states, 0),  // forward variable for sMBR
      beta(num_states, kLogZeroDouble),
      beta_smbr(num_states, 0);  // backward variable for sMBR

  double tot_forward_prob = kLogZeroDouble;
  double tot_forward_score = 0;

  post->clear();
  post->resize(max_time);

  alpha[0] = 0.0;
  // First Pass Forward,
  for (StateId s = 0; s < num_states; s++) {
    double this_alpha = alpha[s];
    for (ArcIterator<Lattice> aiter(lat, s); !aiter.Done(); aiter.Next()) {
      const Arc &arc = aiter.Value();
      double arc_like = -ConvertToCost(arc.weight);
      alpha[arc.nextstate] = LogAdd(alpha[arc.nextstate],
                                    this_alpha + arc_like);
    }
    Weight f = lat.Final(s);
    if (f != Weight::Zero()) {
      double final_like = this_alpha - (f.Value1() + f.Value2());
      tot_forward_prob = LogAdd(tot_forward_prob, final_like);
      KALDI_ASSERT(state_times[s] == max_time &&
                   "Lattice is inconsistent (final-prob not at max_time)");
    }
  }
  // First Pass Backward,
  for (StateId s = num_states - 1; s >= 0; s--) {
    Weight f = lat.Final(s);
    double this_beta = -(f.Value1() + f.Value2());
    for (ArcIterator<Lattice> aiter(lat, s); !aiter.Done(); aiter.Next()) {
      const Arc &arc = aiter.Value();
      double arc_like = -ConvertToCost(arc.weight),
          arc_beta = beta[arc.nextstate] + arc_like;
      this_beta = LogAdd(this_beta, arc_beta);
    }
    beta[s] = this_beta;
  }
  // First Pass Forward-Backward Check
  double tot_backward_prob = beta[0];
  // may lose the condition somehow here 1e-6 (was 1e-8)
  if (!ApproxEqual(tot_forward_prob, tot_backward_prob, 1e-6)) {
    KALDI_ERR << "Total forward probability over lattice = "
              << tot_forward_prob
              << ", while total backward probability = "
              << tot_backward_prob;
  }

  alpha_smbr[0] = 0.0;
  // Second Pass Forward, calculate forward for MPFE/SMBR
  for (StateId s = 0; s < num_states; s++) {
    double this_alpha = alpha[s];
    for (ArcIterator<Lattice> aiter(lat, s); !aiter.Done(); aiter.Next()) {
      const Arc &arc = aiter.Value();
      double arc_like = -ConvertToCost(arc.weight);
      double frame_acc = 0.0;
      if (arc.ilabel != 0) {
        int32 cur_time = state_times[s];
        int32 phone = trans.TransitionIdToPhone(arc.ilabel),
            ref_phone = trans.TransitionIdToPhone(num_ali[cur_time]);
        bool phone_is_sil = std::binary_search(silence_phones.begin(),
                                               silence_phones.end(), phone),
            ref_phone_is_sil = std::binary_search(silence_phones.begin(),
                                                  silence_phones.end(),
                                                  ref_phone),
            both_sil = phone_is_sil && ref_phone_is_sil;
        if (!is_mpfe) {  // smbr.
          int32 pdf = trans.TransitionIdToPdf(arc.ilabel),
              ref_pdf = trans.TransitionIdToPdf(num_ali[cur_time]);
          if (!one_silence_class)  // old behavior
            frame_acc = (pdf == ref_pdf && !phone_is_sil) ? 1.0 : 0.0;
          else
            frame_acc = (pdf == ref_pdf || both_sil) ? 1.0 : 0.0;
        } else {
          if (!one_silence_class)  // old behavior
            frame_acc = (phone == ref_phone && !phone_is_sil) ? 1.0 : 0.0;
          else
            frame_acc = (phone == ref_phone || both_sil) ? 1.0 : 0.0;
        }
      }
      double arc_scale = Exp(alpha[s] + arc_like - alpha[arc.nextstate]);
      alpha_smbr[arc.nextstate] += arc_scale * (alpha_smbr[s] + frame_acc);
    }
    Weight f = lat.Final(s);
    if (f != Weight::Zero()) {
      double final_like = this_alpha - (f.Value1() + f.Value2());
      double arc_scale = Exp(final_like - tot_forward_prob);
      tot_forward_score += arc_scale * alpha_smbr[s];
      KALDI_ASSERT(state_times[s] == max_time &&
                   "Lattice is inconsistent (final-prob not at max_time)");
    }
  }
  // Second Pass Backward, collect Mpe style posteriors
  for (StateId s = num_states - 1; s >= 0; s--) {
    for (ArcIterator<Lattice> aiter(lat, s); !aiter.Done(); aiter.Next()) {
      const Arc &arc = aiter.Value();
      double arc_like = -ConvertToCost(arc.weight),
          arc_beta = beta[arc.nextstate] + arc_like;
      double frame_acc = 0.0;
      int32 transition_id = arc.ilabel;
      if (arc.ilabel != 0) {
        int32 cur_time = state_times[s];
        int32 phone = trans.TransitionIdToPhone(arc.ilabel),
            ref_phone = trans.TransitionIdToPhone(num_ali[cur_time]);
        bool phone_is_sil = std::binary_search(silence_phones.begin(),
                                               silence_phones.end(), phone),
            ref_phone_is_sil = std::binary_search(silence_phones.begin(),
                                                  silence_phones.end(),
                                                  ref_phone),
            both_sil = phone_is_sil && ref_phone_is_sil;
        if (!is_mpfe) {  // smbr.
          int32 pdf = trans.TransitionIdToPdf(arc.ilabel),
              ref_pdf = trans.TransitionIdToPdf(num_ali[cur_time]);
          if (!one_silence_class)  // old behavior
            frame_acc = (pdf == ref_pdf && !phone_is_sil) ? 1.0 : 0.0;
          else
            frame_acc = (pdf == ref_pdf || both_sil) ? 1.0 : 0.0;
        } else {
          if (!one_silence_class)  // old behavior
            frame_acc = (phone == ref_phone && !phone_is_sil) ? 1.0 : 0.0;
          else
            frame_acc = (phone == ref_phone || both_sil) ? 1.0 : 0.0;
        }
      }
      double arc_scale = Exp(beta[arc.nextstate] + arc_like - beta[s]);
      // check arc_scale NAN,
      // this is to prevent partial paths in Lattices
      // i.e., paths don't survive to the final state
      if (KALDI_ISNAN(arc_scale)) arc_scale = 0;
      beta_smbr[s] += arc_scale * (beta_smbr[arc.nextstate] + frame_acc);

      if (transition_id != 0) {  // Arc has a transition-id on it [not epsilon]
        double posterior = Exp(alpha[s] + arc_beta - tot_forward_prob);
        double acc_diff = alpha_smbr[s] + frame_acc +
            beta_smbr[arc.nextstate] - tot_forward_score;
        double posterior_smbr = posterior * acc_diff;
        (*post)[state_times[s]].push_back(
            std::make_pair(transition_id,
                           static_cast<BaseFloat>(posterior_smbr)));
      }
    }
  }

  // Second Pass Forward Backward check
  double tot_backward_score = beta_smbr[0];  // Initial state id == 0
  // may lose the condition somehow here 1e-5/1e-4
  if (!ApproxEqual(tot_forward_score, tot_backward_score, 1e-4)) {
    KALDI_ERR << "Total forward score over lattice = " << tot_forward_score
              << ", while total backward score = " << tot_backward_score;
  }

  // Output the computed posteriors
  for (int32 t = 0; t < max_time; t++)
    MergePairVectorSumming(&((*post)[t]));
  return tot_forward_score;
}
bool CompactLatticeToWordAlignment(const CompactLattice &clat,
                                   std::vector<int32> *words,
                                   std::vector<int32> *begin_times,
                                   std::vector<int32> *lengths) {
  words->clear();
  begin_times->clear();
  lengths->clear();
  typedef CompactLattice::Arc Arc;
  typedef Arc::Label Label;
  typedef CompactLattice::StateId StateId;
  typedef CompactLattice::Weight Weight;
  using namespace fst;
  StateId state = clat.Start();
  int32 cur_time = 0;
  if (state == kNoStateId) {
    KALDI_WARN << "Empty lattice.";
    return false;
  }
  while (1) {
    Weight final = clat.Final(state);
    size_t num_arcs = clat.NumArcs(state);
    if (final != Weight::Zero()) {
      if (num_arcs != 0) {
        KALDI_WARN << "Lattice is not linear.";
        return false;
      }
      if (!final.String().empty()) {
        KALDI_WARN << "Lattice has alignments on final-weight: probably "
            "was not word-aligned (alignments will be approximate)";
      }
      return true;
    } else {
      if (num_arcs != 1) {
        KALDI_WARN << "Lattice is not linear: num-arcs = " << num_arcs;
        return false;
      }
      fst::ArcIterator<CompactLattice> aiter(clat, state);
      const Arc &arc = aiter.Value();
      Label word_id = arc.ilabel;  // Note: ilabel==olabel, since acceptor.
      // Also note: word_id may be zero; we output it anyway.
      int32 length = arc.weight.String().size();
      words->push_back(word_id);
      begin_times->push_back(cur_time);
      lengths->push_back(length);
      cur_time += length;
      state = arc.nextstate;
    }
  }
}
void CompactLatticeShortestPath(const CompactLattice &clat,
                                CompactLattice *shortest_path) {
  using namespace fst;
  if (clat.Properties(fst::kTopSorted, true) == 0) {
    CompactLattice clat_copy(clat);
    if (!TopSort(&clat_copy))
      KALDI_ERR << "Was not able to topologically sort lattice (cycles found?)";
    CompactLatticeShortestPath(clat_copy, shortest_path);
    return;
  }
  // Now we can assume it's topologically sorted.
  shortest_path->DeleteStates();
  if (clat.Start() == kNoStateId) return;
  typedef CompactLatticeArc Arc;
  typedef Arc::StateId StateId;
  typedef CompactLatticeWeight Weight;
  vector<std::pair<double, StateId> > best_cost_and_pred(clat.NumStates() + 1);
  StateId superfinal = clat.NumStates();
  for (StateId s = 0; s <= clat.NumStates(); s++) {
    best_cost_and_pred[s].first = std::numeric_limits<double>::infinity();
    best_cost_and_pred[s].second = fst::kNoStateId;
  }
  best_cost_and_pred[clat.Start()].first = 0;
  for (StateId s = 0; s < clat.NumStates(); s++) {
    double my_cost = best_cost_and_pred[s].first;
    for (ArcIterator<CompactLattice> aiter(clat, s);
         !aiter.Done(); aiter.Next()) {
      const Arc &arc = aiter.Value();
      double arc_cost = ConvertToCost(arc.weight),
          next_cost = my_cost + arc_cost;
      if (next_cost < best_cost_and_pred[arc.nextstate].first) {
        best_cost_and_pred[arc.nextstate].first = next_cost;
        best_cost_and_pred[arc.nextstate].second = s;
      }
    }
    double final_cost = ConvertToCost(clat.Final(s)),
        tot_final = my_cost + final_cost;
    if (tot_final < best_cost_and_pred[superfinal].first) {
      best_cost_and_pred[superfinal].first = tot_final;
      best_cost_and_pred[superfinal].second = s;
    }
  }
  std::vector<StateId> states;  // states on best path.
  StateId cur_state = superfinal, start_state = clat.Start();
  while (cur_state != start_state) {
    StateId prev_state = best_cost_and_pred[cur_state].second;
    if (prev_state == kNoStateId) {
      KALDI_WARN << "Failure in best-path algorithm for lattice (infinite costs?)";
      return;  // return empty best-path.
    }
    states.push_back(prev_state);
    KALDI_ASSERT(cur_state != prev_state && "Lattice with cycles");
    cur_state = prev_state;
  }
  std::reverse(states.begin(), states.end());
  for (size_t i = 0; i < states.size(); i++)
    shortest_path->AddState();
  for (StateId s = 0; static_cast<size_t>(s) < states.size(); s++) {
    if (s == 0) shortest_path->SetStart(s);
    if (static_cast<size_t>(s + 1) < states.size()) {  // transition to next state.
      bool have_arc = false;
      Arc cur_arc;
      for (ArcIterator<CompactLattice> aiter(clat, states[s]);
           !aiter.Done(); aiter.Next()) {
        const Arc &arc = aiter.Value();
        if (arc.nextstate == states[s + 1]) {
          if (!have_arc ||
              ConvertToCost(arc.weight) < ConvertToCost(cur_arc.weight)) {
            cur_arc = arc;
            have_arc = true;
          }
        }
      }
      KALDI_ASSERT(have_arc && "Code error.");
      shortest_path->AddArc(s, Arc(cur_arc.ilabel, cur_arc.olabel,
                                   cur_arc.weight, s + 1));
    } else {  // final-prob.
      shortest_path->SetFinal(s, clat.Final(states[s]));
    }
  }
}
void ExpandCompactLattice(const CompactLattice &clat,
                          double epsilon,
                          CompactLattice *expand_clat) {
  using namespace fst;
  typedef CompactLattice::Arc Arc;
  typedef Arc::Weight Weight;
  typedef Arc::StateId StateId;
  typedef std::pair<StateId, StateId> StatePair;
  typedef unordered_map<StatePair, StateId, PairHasher<StateId> > MapType;
  typedef MapType::iterator IterType;

  if (clat.Start() == kNoStateId) return;
  // Make sure the input lattice is topologically sorted.
  if (clat.Properties(kTopSorted, true) == 0) {
    CompactLattice clat_copy(clat);
    KALDI_LOG << "Topsort this lattice.";
    if (!TopSort(&clat_copy))
      KALDI_ERR << "Was not able to topologically sort lattice (cycles found?)";
    ExpandCompactLattice(clat_copy, epsilon, expand_clat);
    return;
  }

  // Compute backward logprobs betas for the expanded lattice.
  // Note: the backward logprobs in the original lattice <clat> and the
  // expanded lattice <expand_clat> are the same.
  int32 num_states = clat.NumStates();
  std::vector<double> beta(num_states, kLogZeroDouble);
  ComputeCompactLatticeBetas(clat, &beta);
  double tot_backward_logprob = beta[0];
  std::vector<double> alpha;
  alpha.push_back(0.0);

  expand_clat->DeleteStates();
  MapType state_map;  // Map from state pair (orig_state, copy_state) to
                      // copy_state, where orig_state is a state in the original
                      // lattice, and copy_state is its corresponding one in the
                      // expanded lattice.
  unordered_map<StateId, StateId> states;  // Map from orig_state to its
                                           // copy_state for states with incoming
                                           // arcs' posteriors <= epsilon.
  std::queue<StatePair> state_queue;

  // Set start state in the expanded lattice.
  StateId start_state = expand_clat->AddState();
  expand_clat->SetStart(start_state);
  StatePair start_pair(clat.Start(), start_state);
  state_queue.push(start_pair);
  std::pair<IterType, bool> result =
      state_map.insert(std::make_pair(start_pair, start_state));
  KALDI_ASSERT(result.second == true);

  // Expand <clat> and update forward logprobs alphas in <expand_clat>.
  while (!state_queue.empty()) {
    StatePair s = state_queue.front();
    StateId s1 = s.first,
            s2 = s.second;
    state_queue.pop();

    Weight f = clat.Final(s1);
    if (f != Weight::Zero()) {
      KALDI_ASSERT(state_map.find(s) != state_map.end());
      expand_clat->SetFinal(state_map[s], f);
    }

    for (ArcIterator<CompactLattice> aiter(clat, s1);
         !aiter.Done(); aiter.Next()) {
      const Arc &arc = aiter.Value();
      StateId orig_state = arc.nextstate;
      double arc_like = -ConvertToCost(arc.weight),
             this_alpha = alpha[s2] + arc_like,
             arc_post = Exp(this_alpha + beta[orig_state] -
                            tot_backward_logprob);

      // Generate the expanded lattice.
      StateId copy_state;
      if (arc_post > epsilon) {
        copy_state = expand_clat->AddState();
        StatePair next_pair(orig_state, copy_state);
        std::pair<IterType, bool> result =
            state_map.insert(std::make_pair(next_pair, copy_state));
        KALDI_ASSERT(result.second == true);
        state_queue.push(next_pair);
      } else {
        unordered_map<StateId, StateId>::iterator iter = states.find(orig_state);
        if (iter == states.end()) {  // The counterpart state of orig_state
                                     // has not been created in <expand_clat> yet.
          copy_state = expand_clat->AddState();
          StatePair next_pair(orig_state, copy_state);
          std::pair<IterType, bool> result =
              state_map.insert(std::make_pair(next_pair, copy_state));
          KALDI_ASSERT(result.second == true);
          state_queue.push(next_pair);
          states[orig_state] = copy_state;
        } else {
          copy_state = iter->second;
        }
      }
      // Create an arc from state_map[s] to copy_state in the expanded lattice.
      expand_clat->AddArc(state_map[s], Arc(arc.ilabel, arc.olabel,
                                            arc.weight, copy_state));
      // Compute forward logprobs alpha for the expanded lattice.
      if ((alpha.size() - 1) < copy_state) {  // The first time to compute alpha
                                              // for copy_state in <expand_clat>.
        alpha.push_back(this_alpha);
      } else {  // Accumulate alpha.
        alpha[copy_state] = LogAdd(alpha[copy_state], this_alpha);
      }
    }
  }  // end while
}
void CompactLatticeBestCostsAndTracebacks(
    const CompactLattice &clat,
    CostTraceType *forward_best_cost_and_pred,
    CostTraceType *backward_best_cost_and_pred) {
  // typedef the arc, weight types
  typedef CompactLatticeArc Arc;
  typedef Arc::Weight Weight;
  typedef Arc::StateId StateId;

  forward_best_cost_and_pred->clear();
  backward_best_cost_and_pred->clear();
  forward_best_cost_and_pred->resize(clat.NumStates());
  backward_best_cost_and_pred->resize(clat.NumStates());
  // Initialize the cost and predecessor state for each state.
  for (StateId s = 0; s < clat.NumStates(); s++) {
    (*forward_best_cost_and_pred)[s].first =
        std::numeric_limits<double>::infinity();
    (*backward_best_cost_and_pred)[s].first =
        std::numeric_limits<double>::infinity();
    (*forward_best_cost_and_pred)[s].second = fst::kNoStateId;
    (*backward_best_cost_and_pred)[s].second = fst::kNoStateId;
  }

  StateId start_state = clat.Start();
  (*forward_best_cost_and_pred)[start_state].first = 0;
  // Traverse the lattice forwards to compute the best cost from the start
  // state to each state and the best predecessor state of each state.
  for (StateId s = 0; s < clat.NumStates(); s++) {
    double cur_cost = (*forward_best_cost_and_pred)[s].first;
    for (fst::ArcIterator<CompactLattice> aiter(clat, s);
         !aiter.Done(); aiter.Next()) {
      const Arc &arc = aiter.Value();
      double next_cost = cur_cost + ConvertToCost(arc.weight);
      if (next_cost < (*forward_best_cost_and_pred)[arc.nextstate].first) {
        (*forward_best_cost_and_pred)[arc.nextstate].first = next_cost;
        (*forward_best_cost_and_pred)[arc.nextstate].second = s;
      }
    }
  }
  // Traverse the lattice backwards to compute the best cost from a final
  // state to each state and the best predecessor state of each state.
  for (StateId s = clat.NumStates() - 1; s >= 0; s--) {
    double this_cost = ConvertToCost(clat.Final(s));
    for (fst::ArcIterator<CompactLattice> aiter(clat, s);
         !aiter.Done(); aiter.Next()) {
      const Arc &arc = aiter.Value();
      double next_cost = (*backward_best_cost_and_pred)[arc.nextstate].first +
          ConvertToCost(arc.weight);
      if (next_cost < this_cost) {
        this_cost = next_cost;
        (*backward_best_cost_and_pred)[s].second = arc.nextstate;
      }
    }
    (*backward_best_cost_and_pred)[s].first = this_cost;
  }
}
void AddNnlmScoreToCompactLattice(const MapT &nnlm_scores,
                                  CompactLattice *clat) {
  if (clat->Start() == fst::kNoStateId) return;
  // Make sure the input lattice is topologically sorted.
  if (clat->Properties(fst::kTopSorted, true) == 0) {
    KALDI_LOG << "Topsort this lattice.";
    if (!TopSort(clat))
      KALDI_ERR << "Was not able to topologically sort lattice (cycles found?)";
    AddNnlmScoreToCompactLattice(nnlm_scores, clat);
    return;
  }

  // typedef the arc, weight types
  typedef CompactLatticeArc Arc;
  typedef Arc::Weight Weight;
  typedef Arc::StateId StateId;
  typedef std::pair<int32, int32> StatePair;

  int32 num_states = clat->NumStates();
  unordered_map<StatePair, bool, PairHasher<int32> > final_state_check;
  for (StateId s = 0; s < num_states; s++) {
    for (fst::MutableArcIterator<CompactLattice> aiter(clat, s);
         !aiter.Done(); aiter.Next()) {
      Arc arc(aiter.Value());
      StatePair arc_index = std::make_pair(static_cast<int32>(s),
                                           static_cast<int32>(arc.nextstate));
      MapT::const_iterator it = nnlm_scores.find(arc_index);
      double nnlm_score;
      if (it != nnlm_scores.end())
        nnlm_score = it->second;
      else
        KALDI_ERR << "Some arc does not have neural language model score.";
      if (arc.ilabel != 0) {  // if there is a word on this arc
        LatticeWeight weight = arc.weight.Weight();
        // Add associated neural LM score to each arc.
        weight.SetValue1(weight.Value1() + nnlm_score);
        arc.weight.SetWeight(weight);
        aiter.SetValue(arc);
      }
      Weight clat_final = clat->Final(arc.nextstate);
      StatePair final_pair = std::make_pair(arc.nextstate, arc.nextstate);
      // Add neural LM scores to each final state only once.
      if (clat_final != CompactLatticeWeight::Zero() &&
          final_state_check.find(final_pair) == final_state_check.end()) {
        MapT::const_iterator final_it = nnlm_scores.find(final_pair);
        double final_nnlm_score = 0.0;
        if (final_it != nnlm_scores.end())
          final_nnlm_score = final_it->second;
        // Add neural LM scores to the final weight.
        Weight final_weight(LatticeWeight(clat_final.Weight().Value1() +
                                          final_nnlm_score,
                                          clat_final.Weight().Value2()),
                            clat_final.String());
        clat->SetFinal(arc.nextstate, final_weight);
        final_state_check[final_pair] = true;
      }
    }  // end looping over arcs
  }  // end looping over states
}
void AddWordInsPenToCompactLattice(BaseFloat word_ins_penalty,
                                   CompactLattice *clat) {
  typedef CompactLatticeArc Arc;
  int32 num_states = clat->NumStates();

  // scan the lattice
  for (int32 state = 0; state < num_states; state++) {
    for (fst::MutableArcIterator<CompactLattice> aiter(clat, state);
         !aiter.Done(); aiter.Next()) {
      Arc arc(aiter.Value());

      if (arc.ilabel != 0) {  // if there is a word on this arc
        LatticeWeight weight = arc.weight.Weight();
        // add word insertion penalty to lattice
        weight.SetValue1(weight.Value1() + word_ins_penalty);
        arc.weight.SetWeight(weight);
        aiter.SetValue(arc);
      }
    }  // end looping over arcs
  }  // end looping over states
}
struct ClatRescoreTuple {
  ClatRescoreTuple(int32 state, int32 arc, int32 tid):
      state_id(state), arc_id(arc), tid(tid) { }
  int32 state_id;
  int32 arc_id;
  int32 tid;
};
/** RescoreCompactLatticeInternal is the internal code for both
RescoreCompactLattice and RescoreCompactLatticeSpeedup. For
RescoreCompactLattice, "tmodel" will be NULL and speedup_factor will be 1.0.
*/
bool RescoreCompactLatticeInternal(
    const TransitionInformation *tmodel,
    BaseFloat speedup_factor,
    DecodableInterface *decodable,
    CompactLattice *clat) {
  KALDI_ASSERT(speedup_factor >= 1.0);
  if (clat->NumStates() == 0) {
    KALDI_WARN << "Rescoring empty lattice";
    return false;
  }
  if (!clat->Properties(fst::kTopSorted, true)) {
    if (fst::TopSort(clat) == false) {
      KALDI_WARN << "Cycles detected in lattice.";
      return false;
    }
  }
  std::vector<int32> state_times;
  int32 utt_len = kaldi::CompactLatticeStateTimes(*clat, &state_times);

  std::vector<std::vector<ClatRescoreTuple> > time_to_state(utt_len);

  int32 num_states = clat->NumStates();
  KALDI_ASSERT(num_states == state_times.size());
  for (size_t state = 0; state < num_states; state++) {
    KALDI_ASSERT(state_times[state] >= 0);
    int32 t = state_times[state];
    int32 arc_id = 0;
    for (fst::MutableArcIterator<CompactLattice> aiter(clat, state);
         !aiter.Done(); aiter.Next(), arc_id++) {
      CompactLatticeArc arc = aiter.Value();
      std::vector<int32> arc_string = arc.weight.String();

      for (size_t offset = 0; offset < arc_string.size(); offset++) {
        if (t < utt_len) {  // end state may be past this..
          int32 tid = arc_string[offset];
          time_to_state[t + offset].push_back(
              ClatRescoreTuple(state, arc_id, tid));
        } else {
          if (t != utt_len) {
            KALDI_WARN << "There appears to be lattice/feature mismatch, "
                       << "aborting.";
            return false;
          }
        }
      }
    }
    if (clat->Final(state) != CompactLatticeWeight::Zero()) {
      arc_id = -1;
      std::vector<int32> arc_string = clat->Final(state).String();
      for (size_t offset = 0; offset < arc_string.size(); offset++) {
        KALDI_ASSERT(t + offset < utt_len);  // already checked in
        // CompactLatticeStateTimes, so would be code error.
        time_to_state[t + offset].push_back(
            ClatRescoreTuple(state, arc_id, arc_string[offset]));
      }
    }
  }

  for (int32 t = 0; t < utt_len; t++) {
    if ((t < utt_len - 1) && decodable->IsLastFrame(t)) {
      KALDI_WARN << "Features are too short for lattice: utt-len is "
                 << utt_len << ", " << t << " is last frame";
      return false;
    }
    // frame_scale is the scale we put on the computed acoustic probs for this
    // frame. It will always be 1.0 if tmodel == NULL (i.e. if we are not doing
    // the "speedup" code). For frames with multiple pdf-ids it will be one.
    // For frames with only one pdf-id, it will equal speedup_factor (>=1.0)
    // with probability 1.0 / speedup_factor, and zero otherwise. If it is zero,
    // we can avoid computing the probabilities.
    BaseFloat frame_scale = 1.0;
    KALDI_ASSERT(!time_to_state[t].empty());
    if (tmodel != NULL) {
      int32 pdf_id = tmodel->TransitionIdToPdf(time_to_state[t][0].tid);
      bool frame_has_multiple_pdfs = false;
      for (size_t i = 1; i < time_to_state[t].size(); i++) {
        if (tmodel->TransitionIdToPdf(time_to_state[t][i].tid) != pdf_id) {
          frame_has_multiple_pdfs = true;
          break;
        }
      }
      if (frame_has_multiple_pdfs) {
        frame_scale = 1.0;
      } else {
        if (WithProb(1.0 / speedup_factor)) {
          frame_scale = speedup_factor;
        } else {
          frame_scale = 0.0;
        }
      }
      if (frame_scale == 0.0)
        continue;  // the code below would be pointless.
    }
    for (size_t i = 0; i < time_to_state[t].size(); i++) {
      int32 state = time_to_state[t][i].state_id;
      int32 arc_id = time_to_state[t][i].arc_id;
      int32 tid = time_to_state[t][i].tid;

      if (arc_id == -1) {  // Final state
        // Access the trans_id
        CompactLatticeWeight curr_clat_weight = clat->Final(state);

        // Calculate likelihood
        BaseFloat log_like = decodable->LogLikelihood(t, tid) * frame_scale;
        // update weight
        CompactLatticeWeight new_clat_weight = curr_clat_weight;
        LatticeWeight new_lat_weight = new_clat_weight.Weight();
        new_lat_weight.SetValue2(-log_like +
                                 curr_clat_weight.Weight().Value2());
        new_clat_weight.SetWeight(new_lat_weight);
        clat->SetFinal(state, new_clat_weight);
      } else {
        fst::MutableArcIterator<CompactLattice> aiter(clat, state);

        aiter.Seek(arc_id);
        CompactLatticeArc arc = aiter.Value();

        // Calculate likelihood
        BaseFloat log_like = decodable->LogLikelihood(t, tid) * frame_scale;
        // update weight
        LatticeWeight new_weight = arc.weight.Weight();
        new_weight.SetValue2(-log_like + arc.weight.Weight().Value2());
        arc.weight.SetWeight(new_weight);
        aiter.SetValue(arc);
      }
    }
  }
  return true;
}
bool RescoreCompactLatticeSpeedup(
    const TransitionInformation &tmodel,
    BaseFloat speedup_factor,
    DecodableInterface *decodable,
    CompactLattice *clat) {
  return RescoreCompactLatticeInternal(&tmodel, speedup_factor,
                                       decodable, clat);
}
bool RescoreCompactLattice(DecodableInterface *decodable,
                           CompactLattice *clat) {
  return RescoreCompactLatticeInternal(NULL, 1.0, decodable, clat);
}
bool RescoreLattice(DecodableInterface *decodable,
                    Lattice *lat) {
  if (lat->NumStates() == 0) {
    KALDI_WARN << "Rescoring empty lattice";
    return false;
  }
  if (!lat->Properties(fst::kTopSorted, true)) {
    if (fst::TopSort(lat) == false) {
      KALDI_WARN << "Cycles detected in lattice.";
      return false;
    }
  }
  std::vector<int32> state_times;
  int32 utt_len = kaldi::LatticeStateTimes(*lat, &state_times);

  std::vector<std::vector<int32> > time_to_state(utt_len);

  int32 num_states = lat->NumStates();
  KALDI_ASSERT(num_states == state_times.size());
  for (size_t state = 0; state < num_states; state++) {
    int32 t = state_times[state];
    // Don't check t >= 0 because non-accessible states could have t = -1.
    KALDI_ASSERT(t <= utt_len);
    if (t >= 0 && t < utt_len)
      time_to_state[t].push_back(state);
  }

  for (int32 t = 0; t < utt_len; t++) {
    if ((t < utt_len - 1) && decodable->IsLastFrame(t)) {
      KALDI_WARN << "Features are too short for lattice: utt-len is "
                 << utt_len << ", " << t << " is last frame";
      return false;
    }
    for (size_t i = 0; i < time_to_state[t].size(); i++) {
      int32 state = time_to_state[t][i];
      for (fst::MutableArcIterator<Lattice> aiter(lat, state);
           !aiter.Done(); aiter.Next()) {
        LatticeArc arc = aiter.Value();
        if (arc.ilabel != 0) {
          int32 trans_id = arc.ilabel;  // Note: it doesn't necessarily
          // have to be a transition-id, just whatever the Decodable
          // object is expecting, but it's normally a transition-id.

          BaseFloat log_like = decodable->LogLikelihood(t, trans_id);
          arc.weight.SetValue2(-log_like + arc.weight.Value2());
          aiter.SetValue(arc);
        }
      }
    }
  }
  return true;
}
int32 LongestSentenceLength(const Lattice &lat) {
  typedef Lattice::Arc Arc;
  typedef Arc::Label Label;
  typedef Arc::StateId StateId;

  if (lat.Properties(fst::kTopSorted, true) == 0) {
    Lattice lat_copy(lat);
    if (!TopSort(&lat_copy))
      KALDI_ERR << "Was not able to topologically sort lattice (cycles found?)";
    return LongestSentenceLength(lat_copy);
  }
  std::vector<int32> max_length(lat.NumStates(), 0);
  int32 lattice_max_length = 0;
  for (StateId s = 0; s < lat.NumStates(); s++) {
    int32 this_max_length = max_length[s];
    for (fst::ArcIterator<Lattice> aiter(lat, s); !aiter.Done(); aiter.Next()) {
      const Arc &arc = aiter.Value();
      bool arc_has_word = (arc.olabel != 0);
      StateId nextstate = arc.nextstate;
      KALDI_ASSERT(static_cast<size_t>(nextstate) < max_length.size());
      if (arc_has_word) {
        // A lattice should ideally not have cycles anyway; a cycle with a word
        // on is something very bad.
        KALDI_ASSERT(nextstate > s && "Lattice has cycles with words on.");
        max_length[nextstate] = std::max(max_length[nextstate],
                                         this_max_length + 1);
      } else {
        max_length[nextstate] = std::max(max_length[nextstate],
                                         this_max_length);
      }
    }
    if (lat.Final(s) != LatticeWeight::Zero())
      lattice_max_length = std::max(lattice_max_length, max_length[s]);
  }
  return lattice_max_length;
}
int32 LongestSentenceLength(const CompactLattice &clat) {
  typedef CompactLattice::Arc Arc;
  typedef Arc::Label Label;
  typedef Arc::StateId StateId;

  if (clat.Properties(fst::kTopSorted, true) == 0) {
    CompactLattice clat_copy(clat);
    if (!TopSort(&clat_copy))
      KALDI_ERR << "Was not able to topologically sort lattice (cycles found?)";
    return LongestSentenceLength(clat_copy);
  }
  std::vector<int32> max_length(clat.NumStates(), 0);
  int32 lattice_max_length = 0;
  for (StateId s = 0; s < clat.NumStates(); s++) {
    int32 this_max_length = max_length[s];
    for (fst::ArcIterator<CompactLattice> aiter(clat, s);
         !aiter.Done(); aiter.Next()) {
      const Arc &arc = aiter.Value();
      bool arc_has_word = (arc.ilabel != 0);  // note: olabel == ilabel.
      // also note: for normal CompactLattice, e.g. as produced by
      // determinization, all arcs will have nonzero labels, but the user might
      // decide to replace some of the labels with zero for some reason, and we
      // want to support this.
      StateId nextstate = arc.nextstate;
      KALDI_ASSERT(static_cast<size_t>(nextstate) < max_length.size());
      KALDI_ASSERT(nextstate > s && "CompactLattice has cycles");
      if (arc_has_word)
        max_length[nextstate] = std::max(max_length[nextstate],
                                         this_max_length + 1);
      else
        max_length[nextstate] = std::max(max_length[nextstate],
                                         this_max_length);
    }
    if (clat.Final(s) != CompactLatticeWeight::Zero())
      lattice_max_length = std::max(lattice_max_length, max_length[s]);
  }
  return lattice_max_length;
}
void ComposeCompactLatticeDeterministic(
    const CompactLattice &clat,
    fst::DeterministicOnDemandFst<fst::StdArc> *det_fst,
    CompactLattice *composed_clat) {
  // StdFst::Arc and CompactLatticeArc have the same StateId type.
  typedef fst::StdArc::StateId StateId;
  typedef fst::StdArc::Weight Weight1;
  typedef CompactLatticeArc::Weight Weight2;
  typedef std::pair<StateId, StateId> StatePair;
  typedef unordered_map<StatePair, StateId, PairHasher<StateId> > MapType;
  typedef MapType::iterator IterType;

  // Empties the output FST.
  KALDI_ASSERT(composed_clat != NULL);
  composed_clat->DeleteStates();

  MapType state_map;
  std::queue<StatePair> state_queue;

  // Sets start state in <composed_clat>.
  StateId start_state = composed_clat->AddState();
  StatePair start_pair(clat.Start(), det_fst->Start());
  composed_clat->SetStart(start_state);
  state_queue.push(start_pair);
  std::pair<IterType, bool> result =
      state_map.insert(std::make_pair(start_pair, start_state));
  KALDI_ASSERT(result.second == true);

  // Starts composition here.
  while (!state_queue.empty()) {
    // Gets the first state in the queue.
    StatePair s = state_queue.front();
    StateId s1 = s.first;
    StateId s2 = s.second;
    state_queue.pop();

    Weight2 clat_final = clat.Final(s1);
    if (clat_final.Weight().Value1() !=
        std::numeric_limits<BaseFloat>::infinity()) {
      // Test for whether the final-prob of state s1 was zero.
      Weight1 det_fst_final = det_fst->Final(s2);
      if (det_fst_final.Value() !=
          std::numeric_limits<BaseFloat>::infinity()) {
        // Test for whether the final-prob of state s2 was zero. If neither
        // source-state final prob was zero, then we should create final state
        // in fst_composed. We compute the product manually since this is more
        // efficient.
        Weight2 final_weight(LatticeWeight(clat_final.Weight().Value1() +
                                           det_fst_final.Value(),
                                           clat_final.Weight().Value2()),
                             clat_final.String());
        // we can assume final_weight is not Zero(), since neither of
        // the sources was zero.
        KALDI_ASSERT(state_map.find(s) != state_map.end());
        composed_clat->SetFinal(state_map[s], final_weight);
      }
    }

    // Loops over pair of edges at s1 and s2.
    for (fst::ArcIterator<CompactLattice> aiter(clat, s1);
         !aiter.Done(); aiter.Next()) {
      const CompactLatticeArc &arc1 = aiter.Value();
      fst::StdArc arc2;
      StateId next_state1 = arc1.nextstate, next_state2;
      bool matched = false;

      if (arc1.olabel == 0) {
        // If the symbol on <arc1> is <epsilon>, we transit to the next state
        // for <clat>, but keep <det_fst> at the current state.
        matched = true;
        next_state2 = s2;
      } else {
        // Otherwise try to find the matched arc in <det_fst>.
        matched = det_fst->GetArc(s2, arc1.olabel, &arc2);
        if (matched) {
          next_state2 = arc2.nextstate;
        }
      }

      // If matched arc is found in <det_fst>, then we have to add new arcs to
      // <composed_clat>.
      if (matched) {
        StatePair next_state_pair(next_state1, next_state2);
        IterType siter = state_map.find(next_state_pair);
        StateId next_state;

        // Adds composed state to <state_map>.
        if (siter == state_map.end()) {
          // If the composed state has not been created yet, create it.
          next_state = composed_clat->AddState();
          std::pair<const StatePair, StateId> next_state_map(next_state_pair,
                                                             next_state);
          std::pair<IterType, bool> result = state_map.insert(next_state_map);
          KALDI_ASSERT(result.second);
          state_queue.push(next_state_pair);
        } else {
          // If the composed state is already in <state_map>, we can directly
          // use that.
          next_state = siter->second;
        }

        // Adds arc to <composed_clat>.
        if (arc1.olabel == 0) {
          composed_clat->AddArc(state_map[s],
                                CompactLatticeArc(arc1.ilabel, 0,
                                                  arc1.weight, next_state));
        } else {
          Weight2 composed_weight(
              LatticeWeight(arc1.weight.Weight().Value1() +
                            arc2.weight.Value(),
                            arc1.weight.Weight().Value2()),
              arc1.weight.String());
          composed_clat->AddArc(state_map[s],
                                CompactLatticeArc(arc1.ilabel, arc2.olabel,
                                                  composed_weight,
                                                  next_state));
        }
      }
    }
  }
  fst::Connect(composed_clat);
}
void ComputeAcousticScoresMap(
    const Lattice &lat,
    unordered_map<std::pair<int32, int32>, std::pair<BaseFloat, int32>,
                  PairHasher<int32> > *acoustic_scores) {
  // typedef the arc, weight types
  typedef Lattice::Arc Arc;
  typedef Arc::Weight LatticeWeight;
  typedef Arc::StateId StateId;

  acoustic_scores->clear();

  std::vector<int32> state_times;
  LatticeStateTimes(lat, &state_times);  // Assumes the input is top sorted

  KALDI_ASSERT(lat.Start() == 0);

  for (StateId s = 0; s < lat.NumStates(); s++) {
    int32 t = state_times[s];
    for (fst::ArcIterator<Lattice> aiter(lat, s); !aiter.Done();
         aiter.Next()) {
      const Arc &arc = aiter.Value();
      const LatticeWeight &weight = arc.weight;

      int32 tid = arc.ilabel;

      if (tid != 0) {
        unordered_map<std::pair<int32, int32>, std::pair<BaseFloat, int32>,
            PairHasher<int32> >::iterator it =
                acoustic_scores->find(std::make_pair(t, tid));
        if (it == acoustic_scores->end()) {
          acoustic_scores->insert(std::make_pair(std::make_pair(t, tid),
                                      std::make_pair(weight.Value2(), 1)));
        } else {
          if (it->second.second == 2
              && it->second.first / it->second.second != weight.Value2()) {
            KALDI_VLOG(2) << "Transitions on the same frame have different "
                          << "acoustic costs for tid " << tid << "; "
                          << it->second.first / it->second.second
                          << " vs " << weight.Value2();
          }
          it->second.first += weight.Value2();
          it->second.second++;
        }
      } else {
        // Arcs with epsilon input label (tid) must have 0 acoustic cost
        KALDI_ASSERT(weight.Value2() == 0);
      }
    }

    LatticeWeight f = lat.Final(s);
    if (f != LatticeWeight::Zero()) {
      // Final acoustic cost must be 0 as we are reading from
      // non-determinized, non-compact lattice
      KALDI_ASSERT(f.Value2() == 0.0);
    }
  }
}
void ReplaceAcousticScoresFromMap(
    const unordered_map<std::pair<int32, int32>, std::pair<BaseFloat, int32>,
                        PairHasher<int32> > &acoustic_scores,
    Lattice *lat) {
  // typedef the arc, weight types
  typedef Lattice::Arc Arc;
  typedef Arc::Weight LatticeWeight;
  typedef Arc::StateId StateId;

  TopSortLatticeIfNeeded(lat);

  std::vector<int32> state_times;
  LatticeStateTimes(*lat, &state_times);

  KALDI_ASSERT(lat->Start() == 0);

  for (StateId s = 0; s < lat->NumStates(); s++) {
    int32 t = state_times[s];
    for (fst::MutableArcIterator<Lattice> aiter(lat, s);
         !aiter.Done(); aiter.Next()) {
      Arc arc(aiter.Value());

      int32 tid = arc.ilabel;
      if (tid != 0) {
        unordered_map<std::pair<int32, int32>, std::pair<BaseFloat, int32>,
            PairHasher<int32> >::const_iterator it =
                acoustic_scores.find(std::make_pair(t, tid));
        if (it == acoustic_scores.end()) {
          KALDI_ERR << "Could not find tid " << tid << " at time " << t
                    << " in the acoustic scores map.";
        } else {
          arc.weight.SetValue2(it->second.first / it->second.second);
        }
      } else {
        // For epsilon arcs, set acoustic cost to 0.0
        arc.weight.SetValue2(0.0);
      }
      aiter.SetValue(arc);
    }

    LatticeWeight f = lat->Final(s);
    if (f != LatticeWeight::Zero()) {
      // Set final acoustic cost to 0.0
      f.SetValue2(0.0);
      lat->SetFinal(s, f);
    }
  }
}

}  // namespace kaldi
speechx/speechx/kaldi/lat/lattice-functions.h
0 → 100644
View file @
d14ee800
// lat/lattice-functions.h
// Copyright 2009-2012 Saarland University (author: Arnab Ghoshal)
// 2012-2013 Johns Hopkins University (Author: Daniel Povey);
// Bagher BabaAli
// 2014 Guoguo Chen
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#ifndef KALDI_LAT_LATTICE_FUNCTIONS_H_
#define KALDI_LAT_LATTICE_FUNCTIONS_H_
#include <vector>
#include <map>
#include "base/kaldi-common.h"
#include "fstext/fstext-lib.h"
#include "itf/decodable-itf.h"
#include "itf/transition-information.h"
#include "lat/kaldi-lattice.h"
namespace kaldi {

// Redundant with the typedef in hmm/posterior.h. We want functions
// using the Posterior type to be usable without a dependency on the
// hmm library.
typedef std::vector<std::vector<std::pair<int32, BaseFloat> > > Posterior;
/**
This function extracts the per-frame log likelihoods from a linear
lattice (which we refer to as an 'nbest' lattice elsewhere in Kaldi code).
The dimension of *per_frame_loglikes will be set to the
number of input symbols in 'nbest'. The elements of
'*per_frame_loglikes' will be set to the .Value2() elements of the lattice
weights, which represent the acoustic costs; you may want to scale this
vector afterward by -1/acoustic_scale to get the original loglikes.
If there are acoustic costs on input-epsilon arcs or the final-prob in 'nbest'
(and this should not normally be the case in situations where it makes
sense to call this function), they will be included to the cost of the
preceding input symbol, or the following input symbol for input-epsilons
encountered prior to any input symbol. If 'nbest' has no input symbols,
'per_frame_loglikes' will be set to the empty vector.
**/
void GetPerFrameAcousticCosts(const Lattice &nbest,
                              Vector<BaseFloat> *per_frame_loglikes);
/// This function iterates over the states of a topologically sorted lattice and
/// counts the time instance corresponding to each state. The times are returned
/// in a vector of integers 'times' which is resized to have a size equal to the
/// number of states in the lattice. The function also returns the maximum time
/// in the lattice (this will equal the number of frames in the file).
int32 LatticeStateTimes(const Lattice &lat, std::vector<int32> *times);
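// A minimal usage sketch (added for illustration; 'lat' is a hypothetical
// topologically sorted Lattice, not part of this header):
//
//   std::vector<int32> times;
//   int32 num_frames = LatticeStateTimes(lat, &times);
//   // times[s] is now the frame index of state s, and num_frames equals
//   // the utterance length in frames.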
/// As LatticeStateTimes, but in the CompactLattice format. Note: must
/// be topologically sorted. Returns length of the utterance in frames, which
/// might not be the same as the maximum time in the lattice, due to frames
/// in the final-prob.
int32 CompactLatticeStateTimes(const CompactLattice &clat,
                               std::vector<int32> *times);
/// This function does the forward-backward over lattices and computes the
/// posterior probabilities of the arcs. It returns the total log-probability
/// of the lattice. The Posterior quantities contain pairs of (transition-id, weight)
/// on each frame.
/// If the pointer "acoustic_like_sum" is provided, this value is set to
/// the sum over the arcs, of the posterior of the arc times the
/// acoustic likelihood [i.e. negated acoustic score] on that link.
/// This is used in combination with other quantities to work out
/// the objective function in MMI discriminative training.
BaseFloat LatticeForwardBackward(const Lattice &lat,
                                 Posterior *arc_post,
                                 double *acoustic_like_sum = NULL);
// This function is something similar to LatticeForwardBackward(), but it is on
// the CompactLattice lattice format. Also we only need the alpha in the forward
// path, not the posteriors.
bool ComputeCompactLatticeAlphas(const CompactLattice &lat,
                                 std::vector<double> *alpha);
// A sibling of the function CompactLatticeAlphas()... We compute the beta from
// the backward path here.
bool ComputeCompactLatticeBetas(const CompactLattice &lat,
                                std::vector<double> *beta);
// Computes (normal or Viterbi) alphas and betas; returns (total-prob, or
// best-path negated cost) Note: in either case, the alphas and betas are
// negated costs. Requires that lat be topologically sorted. This code
// will work for either CompactLattice or Lattice.
template<typename LatticeType>
double ComputeLatticeAlphasAndBetas(const LatticeType &lat,
                                    bool viterbi,
                                    std::vector<double> *alpha,
                                    std::vector<double> *beta);
/// Topologically sort the compact lattice if not already topologically sorted.
/// Will crash if the lattice cannot be topologically sorted.
void TopSortCompactLatticeIfNeeded(CompactLattice *clat);
/// Topologically sort the lattice if not already topologically sorted.
/// Will crash if lattice cannot be topologically sorted.
void TopSortLatticeIfNeeded(Lattice *clat);
/// Returns the depth of the lattice, defined as the average number of arcs (or
/// final-prob strings) crossing any given frame. Returns 1 for empty lattices.
/// Requires that clat is topologically sorted!
BaseFloat CompactLatticeDepth(const CompactLattice &clat,
                              int32 *num_frames = NULL);
/// This function returns, for each frame, the number of arcs crossing that
/// frame.
void CompactLatticeDepthPerFrame(const CompactLattice &clat,
                                 std::vector<int32> *depth_per_frame);
/// This function limits the depth of the lattice, per frame: that means, it
/// does not allow more than a specified number of arcs active on any given
/// frame. This can be used to reduce the size of the "very deep" portions of
/// the lattice.
void CompactLatticeLimitDepth(int32 max_arcs_per_frame,
                              CompactLattice *clat);
/// Given a lattice, and a transition model to map pdf-ids to phones,
/// outputs for each frame the set of phones active on that frame. If
/// sil_phones (which must be sorted and uniq) is nonempty, it excludes
/// phones in this list.
void LatticeActivePhones(const Lattice &lat,
                         const TransitionInformation &trans,
                         const std::vector<int32> &sil_phones,
                         std::vector<std::set<int32> > *active_phones);
/// Given a lattice, and a transition model to map pdf-ids to phones,
/// replace the output symbols (presumably words), with phones; we
/// use the TransitionModel to work out the phone sequence. Note
/// that the phone labels are not exactly aligned with the phone
/// boundaries. We put a phone label to coincide with any transition
/// to the final, nonemitting state of a phone (this state always exists,
/// we ensure this in HmmTopology::Check()). This would be the last
/// transition-id in the phone if reordering is not done (but typically
/// we do reorder).
/// Also see PhoneAlignLattice, in phone-align-lattice.h.
void ConvertLatticeToPhones(const TransitionInformation &trans_model,
                            Lattice *lat);
/// Prunes a lattice or compact lattice. Returns true on success, false if
/// there was some kind of failure.
template<class LatticeType>
bool PruneLattice(BaseFloat beam, LatticeType *lat);
/// Given a lattice, and a transition model to map pdf-ids to phones,
/// replace the sequences of transition-ids with sequences of phones.
/// Note that this is different from ConvertLatticeToPhones, in that
/// we replace the transition-ids not the words.
void ConvertCompactLatticeToPhones(const TransitionInformation &trans_model,
                                   CompactLattice *clat);
/// Boosts LM probabilities by b * [number of frame errors]; equivalently, adds
/// -b*[number of frame errors] to the graph-component of the cost of each arc/path.
/// There is a frame error if a particular transition-id on a particular frame
/// corresponds to a phone not matching transcription's alignment for that frame.
/// This is used in "margin-inspired" discriminative training, esp. Boosted MMI.
/// The TransitionInformation is used to map transition-ids in the lattice
/// input-side to phones; the phones appearing in
/// "silence_phones" are treated specially in that we replace the frame error f
/// (either zero or 1) for a frame, with the minimum of f or max_silence_error.
/// For the normal recipe, max_silence_error would be zero.
/// Returns true on success, false if there was some kind of mismatch.
/// At input, silence_phones must be sorted and unique.
bool LatticeBoost(const TransitionInformation &trans,
                  const std::vector<int32> &alignment,
                  const std::vector<int32> &silence_phones,
                  BaseFloat b,
                  BaseFloat max_silence_error,
                  Lattice *lat);
/**
This function implements either the MPFE (minimum phone frame error) or SMBR
(state-level minimum bayes risk) forward-backward, depending on whether
"criterion" is "mpfe" or "smbr". It returns the MPFE
criterion or SMBR criterion for this utterance, and outputs the posteriors (which
may be positive or negative) into "post".
@param [in] trans The transition model. Used to map the
transition-ids to phones or pdfs.
@param [in] silence_phones A list of integer ids of silence phones. The
silence frames i.e. the frames where num_ali
corresponds to a silence phones are treated specially.
The behavior is determined by 'one_silence_class'
being false (traditional behavior) or true.
Usually in our setup, several phones including
the silence, vocalized noise, non-spoken noise
and unk are treated as "silence phones"
@param [in] lat The denominator lattice
@param [in] num_ali The numerator alignment
@param [in] criterion The objective function. Must be "mpfe" or "smbr"
for MPFE (minimum phone frame error) or sMBR
(state minimum bayes risk) training.
@param [in] one_silence_class Determines how the silence frames are treated.
Setting this to false gives the old traditional behavior,
where the silence frames (according to num_ali) are
treated as incorrect. However, this means that the
insertions are not penalized by the objective.
Setting this to true gives the new behaviour, where we
treat silence as any other phone, except that all pdfs
of silence phones are collapsed into a single class for
the frame-error computation. This can possibly reduce
the insertions in the trained model. This is closer to
the WER metric that we actually care about, since WER is
generally computed after filtering out noises, but
does penalize insertions.
@param [out] post The "MBR posteriors" i.e. derivatives w.r.t to the
pseudo log-likelihoods of states at each frame.
*/
BaseFloat LatticeForwardBackwardMpeVariants(
    const TransitionInformation &trans,
    const std::vector<int32> &silence_phones,
    const Lattice &lat,
    const std::vector<int32> &num_ali,
    std::string criterion,
    bool one_silence_class,
    Posterior *post);
/// This function takes a CompactLattice that should only contain a single
/// linear sequence (e.g. derived from lattice-1best), and that should have been
/// processed so that the arcs in the CompactLattice align correctly with the
/// word boundaries (e.g. by lattice-align-words). It outputs 3 vectors of the
/// same size, which give, for each word in the lattice (in sequence), the word
/// label and the begin time and length in frames. This is done even for zero
/// (epsilon) words, generally corresponding to optional silence-- if you don't
/// want them, just ignore them in the output.
/// This function will print a warning and return false, if the lattice
/// did not have the correct format (e.g. if it is empty or it is not
/// linear).
bool CompactLatticeToWordAlignment(const CompactLattice &clat,
                                   std::vector<int32> *words,
                                   std::vector<int32> *begin_times,
                                   std::vector<int32> *lengths);
/// A form of the shortest-path/best-path algorithm that's specially coded for
/// CompactLattice. Requires that clat be acyclic.
void CompactLatticeShortestPath(const CompactLattice &clat,
                                CompactLattice *shortest_path);
/// This function expands a CompactLattice to ensure high-probability paths
/// have unique histories. Arcs with posteriors larger than epsilon get split.
void ExpandCompactLattice(const CompactLattice &clat,
                          double epsilon,
                          CompactLattice *expand_clat);
/// For each state, compute forward and backward best (viterbi) costs and its
/// traceback states (for generating best paths later). The forward best cost
/// for a state is the cost of the best path from the start state to the state.
/// The traceback state of this state is its predecessor state in the best path.
/// The backward best cost for a state is the cost of the best path from the
/// state to a final one. Its traceback state is the successor state in the best
/// path in the forward direction.
/// Note: final weights of states are in backward_best_cost_and_pred.
/// Requires the input CompactLattice clat be acyclic.
typedef std::vector<std::pair<double, CompactLatticeArc::StateId> >
    CostTraceType;
void CompactLatticeBestCostsAndTracebacks(
    const CompactLattice &clat,
    CostTraceType *forward_best_cost_and_pred,
    CostTraceType *backward_best_cost_and_pred);
/// This function adds estimated neural language model scores of words in a
/// minimal list of hypotheses that covers a lattice, to the graph scores on the
/// arcs. The list of hypotheses are generated by latbin/lattice-path-cover.
typedef unordered_map<std::pair<int32, int32>, double,
                      PairHasher<int32> > MapT;
void AddNnlmScoreToCompactLattice(const MapT &nnlm_scores,
                                  CompactLattice *clat);
/// This function adds the word insertion penalty to the graph score of each
/// word in the compact lattice
void AddWordInsPenToCompactLattice(BaseFloat word_ins_penalty,
                                   CompactLattice *clat);
/// This function *adds* the negated scores obtained from the Decodable object,
/// to the acoustic scores on the arcs. If you want to replace them, you should
/// use ScaleCompactLattice to first set the acoustic scores to zero. Returns
/// true on success, false on error (typically some kind of mismatched inputs).
bool RescoreCompactLattice(DecodableInterface *decodable,
                           CompactLattice *clat);
/// This function returns the number of words in the longest sentence in a
/// Lattice (i.e. the maximum, over all paths, of the count of
/// olabels on that path).
int32 LongestSentenceLength(const Lattice &lat);
/// This function returns the number of words in the longest sentence in a
/// CompactLattice, i.e. the maximum, over all paths, of the count of
/// labels on that path... note, in CompactLattice, the ilabels and olabels
/// are identical because it is an acceptor.
int32 LongestSentenceLength(const CompactLattice &lat);
/// This function is like RescoreCompactLattice, but it is modified to avoid
/// computing probabilities on most frames where all the pdf-ids are the same.
/// (it needs the transition-model to work out whether two transition-ids map to
/// the same pdf-id, and it assumes that the lattice has transition-ids on it).
/// The naive thing would be to just set all probabilities to zero on frames
/// where all the pdf-ids are the same (because this value won't affect the
/// lattice posterior). But this would become confusing when we compute
/// corpus-level diagnostics such as the MMI objective function. Instead,
/// imagine speedup_factor = 100 (it must be >= 1.0)... with probability (1.0 /
/// speedup_factor) we compute those likelihoods and multiply them by
/// speedup_factor; otherwise we set them to zero. This gives the right
/// expected probability so our corpus-level diagnostics will be about right.
bool RescoreCompactLatticeSpeedup(
    const TransitionInformation &tmodel,
    BaseFloat speedup_factor,
    DecodableInterface *decodable,
    CompactLattice *clat);
/// This function *adds* the negated scores obtained from the Decodable object,
/// to the acoustic scores on the arcs. If you want to replace them, you should
/// use ScaleCompactLattice to first set the acoustic scores to zero. Returns
/// true on success, false on error (e.g. some kind of mismatched inputs).
/// The input labels, if nonzero, are interpreted as transition-ids or whatever
/// other index the Decodable object expects.
bool RescoreLattice(DecodableInterface *decodable,
                    Lattice *lat);
/// This function Composes a CompactLattice format lattice with a
/// DeterministicOnDemandFst<fst::StdFst> format fst, and outputs another
/// CompactLattice format lattice. The first element (the one that corresponds
/// to LM weight) in CompactLatticeWeight is used for composition.
///
/// Note that the DeterministicOnDemandFst interface is not "const", therefore
/// we cannot use "const" for <det_fst>.
void ComposeCompactLatticeDeterministic(
    const CompactLattice &clat,
    fst::DeterministicOnDemandFst<fst::StdArc> *det_fst,
    CompactLattice *composed_clat);
/// This function computes the mapping from the pair
/// (frame-index, transition-id) to the pair
/// (sum-of-acoustic-scores, num-of-occurences) over all occurences of the
/// transition-id in that frame.
/// frame-index in the lattice.
/// This function is useful for retaining the acoustic scores in a
/// non-compact lattice after a process like determinization where the
/// frame-level acoustic scores are typically lost.
/// The function ReplaceAcousticScoresFromMap is used to restore the
/// acoustic scores computed by this function.
///
/// @param [in] lat Input lattice. Expected to be top-sorted. Otherwise the
/// function will crash.
/// @param [out] acoustic_scores
///                  Pointer to a map from the pair (frame-index,
///                  transition-id) to a pair (sum-of-acoustic-scores,
///                  num-of-occurrences).
///                  Usually the acoustic scores for a pdf-id (and hence
///                  transition-id) on a frame will be the same for all the
///                  occurrences of the pdf-id in that frame.
///                  But if not, we will take the average of the acoustic
///                  scores. Hence, we store both the sum-of-acoustic-scores
///                  and the num-of-occurrences of the transition-id in that
///                  frame.
void ComputeAcousticScoresMap(
    const Lattice &lat,
    unordered_map<std::pair<int32, int32>,
                  std::pair<BaseFloat, int32>,
                  PairHasher<int32> > *acoustic_scores);
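// Given the populated map, the per-frame average score for a transition-id is
// sum / count; a minimal lookup sketch (editor's illustration):
//
// \code{.cc}
//   unordered_map<std::pair<int32, int32>, std::pair<BaseFloat, int32>,
//                 PairHasher<int32> > acoustic_scores;
//   ComputeAcousticScoresMap(lat, &acoustic_scores);
//   auto it = acoustic_scores.find(std::make_pair(frame, tid));
//   if (it != acoustic_scores.end()) {
//     BaseFloat avg = it->second.first / it->second.second;  // sum / count
//   }
// \endcode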
/// This function restores acoustic scores computed using the function
/// ComputeAcousticScoresMap into the lattice.
///
/// @param [in] acoustic_scores
///                  A map from the pair (frame-index, transition-id) to a
///                  pair (sum-of-acoustic-scores, num-of-occurrences) of
///                  the occurrences of the transition-id in that frame.
///                  See the comments for ComputeAcousticScoresMap for
///                  details.
/// @param [out] lat Pointer to the output lattice.
void ReplaceAcousticScoresFromMap(
    const unordered_map<std::pair<int32, int32>,
                        std::pair<BaseFloat, int32>,
                        PairHasher<int32> > &acoustic_scores,
    Lattice *lat);
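// The intended round trip, sketched under the assumption that some
// determinization step in between discards frame-level acoustic scores (the
// exact determinization call is deliberately left abstract here):
//
// \code{.cc}
//   ComputeAcousticScoresMap(lat, &acoustic_scores);   // 1. save the scores
//   // 2. ... determinize, producing det_lat without frame-level scores ...
//   ReplaceAcousticScoresFromMap(acoustic_scores, &det_lat);  // 3. restore
// \endcode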
}  // namespace kaldi
#endif // KALDI_LAT_LATTICE_FUNCTIONS_H_
speechx/speechx/nnet/ctc_decodable.h
0 → 100644
speechx/speechx/nnet/decodable-itf.h
0 → 100644
// itf/decodable-itf.h
// Copyright 2009-2011 Microsoft Corporation; Saarland University;
// Mirko Hannemann; Go Vivace Inc.;
// 2013 Johns Hopkins University (author: Daniel Povey)
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABILITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#ifndef KALDI_ITF_DECODABLE_ITF_H_
#define KALDI_ITF_DECODABLE_ITF_H_ 1
#include "base/kaldi-common.h"
namespace kaldi {
/// @ingroup Interfaces
/// @{
/**
DecodableInterface provides a link between the (acoustic-modeling and
feature-processing) code and the decoder. The idea is to make this
interface as small as possible, and to make it as agnostic as possible about
the form of the acoustic model (e.g. don't assume the probabilities are a
function of just a vector of floats), and about the decoder (e.g. don't
assume it accesses frames in strict left-to-right order). For normal
models, without on-line operation, the "decodable" sub-class will just be a
wrapper around a matrix of features and an acoustic model, and it will
answer the question 'what is the acoustic likelihood for this index and this
frame?'.
For online decoding, where the features arrive in real time, it is
important to understand the IsLastFrame() and NumFramesReady() functions.
There are two ways these are used: the old online-decoding code, in ../online/,
and the new online-decoding code, in ../online2/. In the old online-decoding
code, the decoder would do:
\code{.cc}
for (int frame = 0; !decodable.IsLastFrame(frame); frame++) {
// Process this frame
}
\endcode
and the call to IsLastFrame would block if the features had not arrived yet.
The decodable object would have to know when to terminate the decoding. This
online-decoding mode is still supported, it is what happens when you call, for
example, LatticeFasterDecoder::Decode().
We realized that this "blocking" mode of decoding is not very convenient
because it forces the program to be multi-threaded and makes it complex to
control endpointing. In the "new" decoding code, you don't call (for example)
LatticeFasterDecoder::Decode(), you call LatticeFasterDecoder::InitDecoding(),
and then each time you get more features, you provide them to the decodable
object, and you call LatticeFasterDecoder::AdvanceDecoding(), which does
something like this:
\code{.cc}
while (num_frames_decoded_ < decodable.NumFramesReady()) {
// Decode one more frame [increments num_frames_decoded_]
}
\endcode
So the decodable object never has IsLastFrame() called. For decoding where
you are starting with a matrix of features, the NumFramesReady() function will
always just return the number of frames in the file, and IsLastFrame() will
return true for the last frame.
For truly online decoding, the "old" online decodable objects in ../online/
have a "blocking" IsLastFrame() and will crash if you call NumFramesReady().
The "new" online decodable objects in ../online2/ return the number of frames
currently accessible if you call NumFramesReady(). You will likely not need
to call IsLastFrame(), but we implement it to only return true for the last
frame of the file once we've decided to terminate decoding.
*/
class DecodableInterface {
public:
/// Returns the log likelihood, which will be negated in the decoder.
/// The "frame" starts from zero. You should verify that NumFramesReady() > frame
/// before calling this.
virtual BaseFloat LogLikelihood(int32 frame, int32 index) = 0;
/// Returns true if this is the last frame. Frames are zero-based, so the
/// first frame is zero. IsLastFrame(-1) will return false, unless the file
/// is empty (which is a case that I'm not sure all the code will handle, so
/// be careful). Caution: the behavior of this function in an online setting
/// is being changed somewhat. In future it may return false in cases where
/// we haven't yet decided to terminate decoding, but later true if we decide
/// to terminate decoding. The plan in future is to rely more on
/// NumFramesReady(), and in future, IsLastFrame() would always return false
/// in an online-decoding setting, and would only return true in a
/// decoding-from-matrix setting where we want to allow the last delta or LDA
/// features to be flushed out for compatibility with the baseline setup.
virtual bool IsLastFrame(int32 frame) const = 0;
/// The call NumFramesReady() will return the number of frames currently available
/// for this decodable object. This is for use in setups where you don't want the
/// decoder to block while waiting for input. This is newly added as of Jan 2014,
/// and I hope, going forward, to rely on this mechanism more than IsLastFrame to
/// know when to stop decoding.
virtual int32 NumFramesReady() const {
  KALDI_ERR << "NumFramesReady() not implemented for this decodable type.";
  return -1;
}
/// Returns the number of states in the acoustic model
/// (they will be indexed one-based, i.e. from 1 to NumIndices();
/// this is for compatibility with OpenFst).
virtual int32 NumIndices() const = 0;

virtual ~DecodableInterface() {}
};
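// A minimal decoding-from-matrix implementation, as an editor's sketch (real
// Kaldi wrappers such as DecodableMatrixScaled live in decoder/ and differ in
// detail):
//
// \code{.cc}
//   class DecodableMatrixSketch : public DecodableInterface {
//    public:
//     explicit DecodableMatrixSketch(const Matrix<BaseFloat> &loglikes)
//         : loglikes_(loglikes) {}
//     virtual BaseFloat LogLikelihood(int32 frame, int32 index) {
//       return loglikes_(frame, index - 1);  // indices are one-based
//     }
//     virtual int32 NumFramesReady() const { return loglikes_.NumRows(); }
//     virtual bool IsLastFrame(int32 frame) const {
//       return frame == loglikes_.NumRows() - 1;
//     }
//     virtual int32 NumIndices() const { return loglikes_.NumCols(); }
//    private:
//     Matrix<BaseFloat> loglikes_;
//   };
// \endcode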
/// @}
}  // namespace kaldi
#endif // KALDI_ITF_DECODABLE_ITF_H_
speechx/speechx/nnet/decodable.h
0 → 100644
#include "nnet/decodable-itf.h"
#include "base/common.h"
namespace
ppsepeech
{
struct
DecodeableConfig
;
class
Decodeable
:
public
kaldi
::
DecodableInterface
{
public:
virtual
Init
(
Decodeable
config
)
=
0
;
virtual
Acceptlikeihood
()
=
0
;
private:
std
::
share_ptr
<
FeatureExtractorInterface
>
frontend_
;
std
::
share_ptr
<
NnetInterface
>
nnet_
;
//Cache nnet_cache_;
}
}
// namespace ppspeech
\ No newline at end of file
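// A hedged sketch of how a concrete subclass might wire the frontend and the
// nnet together; every name below (Read(), nnet_cache_, the one-based index
// convention) is illustrative, not part of this commit:
//
// \code{.cc}
//   BaseFloat MyDecodable::LogLikelihood(int32 frame, int32 index) {
//     if (frame >= nnet_cache_.NumRows()) {
//       kaldi::Matrix<kaldi::BaseFloat> feats;
//       frontend_->Read(&feats);                  // pull buffered features
//       nnet_->FeedForward(feats, &nnet_cache_);  // run the acoustic model
//     }
//     return nnet_cache_(frame, index - 1);  // Kaldi indices are one-based
//   }
// \endcode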
speechx/speechx/nnet/dnn_decodable.h
0 → 100644
speechx/speechx/nnet/nnet_interface.h
...
...
@@ -8,7 +8,7 @@ namespace ppspeech {
 class NnetInterface {
  public:
-  virtual ~NnetForwardInterface() {}
+  virtual ~NnetInterface() {}
   virtual void FeedForward(const kaldi::Matrix<BaseFloat>& features,
                            kaldi::Matrix<kaldi::BaseFloat>* inferences) const = 0;
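// A hedged sketch of a trivial NnetInterface implementation (useful as a test
// stub; a real implementation backed by a Paddle model would replace the
// body, and the class name here is invented for illustration):
//
// \code{.cc}
//   class IdentityNnet : public ppspeech::NnetInterface {
//    public:
//     virtual void FeedForward(const kaldi::Matrix<BaseFloat>& features,
//                              kaldi::Matrix<kaldi::BaseFloat>* inferences) const {
//       *inferences = features;  // pass the features through unchanged
//     }
//   };
// \endcode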
...
...