PaddlePaddle / DeepSpeech
Commit 2430545d
Authored Oct 13, 2021 by Hui Zhang

update vector ctc prefix score

Parent: 331bd9ea

Showing 2 changed files with 71 additions and 73 deletions (+71 -73)
deepspeech/decoders/scores/ctc.py               +4   -4
deepspeech/decoders/scores/ctc_prefix_score.py  +67  -69
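Note: the substance of this commit is a mechanical torch -> paddle port of the vectorized CTC prefix scorer, plus a CTCPrefixScoreTH -> CTCPrefixScorePD rename. A minimal sketch of the recurring API substitutions, with made-up tensors purely to show the call shapes:

import paddle

x = paddle.rand([2, 3, 4])                         # hypothetical (B, T, O) posteriors
xn = paddle.transpose(x, perm=[1, 0, 2])           # torch: x.transpose(0, 1)
xlens = paddle.to_tensor([3, 2])                   # torch: torch.as_tensor(...)
cat = paddle.concat([xn, xn], axis=0)              # torch: torch.cat(..., dim=0)
lse = paddle.logsumexp(x, axis=1)                  # torch: torch.logsumexp(x, 1)
z = paddle.full([2, 3], -1e10, dtype='float32')    # torch: torch.full(..., device=...)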
deepspeech/decoders/scores/ctc.py @ 2430545d

@@ -4,7 +4,7 @@ import numpy as np
 import paddle
 
 from .ctc_prefix_score import CTCPrefixScore
-from .ctc_prefix_score import CTCPrefixScoreTH
+from .ctc_prefix_score import CTCPrefixScorePD
 from .scorer_interface import BatchPartialScorerInterface
@@ -34,7 +34,7 @@ class CTCPrefixScorer(BatchPartialScorerInterface):
         """
         logp = self.ctc.log_softmax(x.unsqueeze(0)).squeeze(0).numpy()
-        # TODO(karita): use CTCPrefixScoreTH
+        # TODO(karita): use CTCPrefixScorePD
         self.impl = CTCPrefixScore(logp, 0, self.eos, np)
         return 0, self.impl.initial_state()
@@ -54,7 +54,7 @@ class CTCPrefixScorer(BatchPartialScorerInterface):
         if len(state) == 2:  # for CTCPrefixScore
             sc, st = state
             return sc[i], st[i]
-        else:  # for CTCPrefixScoreTH (need new_id > 0)
+        else:  # for CTCPrefixScorePD (need new_id > 0)
             r, log_psi, f_min, f_max, scoring_idmap = state
             s = log_psi[i, new_id].expand(log_psi.size(1))
             if scoring_idmap is not None:
@@ -96,7 +96,7 @@ class CTCPrefixScorer(BatchPartialScorerInterface):
         """
         logp = self.ctc.log_softmax(x.unsqueeze(0))  # assuming batch_size = 1
         xlen = paddle.to_tensor([logp.size(1)])
-        self.impl = CTCPrefixScoreTH(logp, xlen, 0, self.eos)
+        self.impl = CTCPrefixScorePD(logp, xlen, 0, self.eos)
         return None
 
     def batch_score_partial(self, y, ids, state, x):
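Note: the branch shown in the @@ -54,7 hunk boils down to picking one prefix score out of log_psi and broadcasting it across the output vocabulary. A toy version of that step in stock paddle (the hunk itself uses .expand()/.size() tensor helpers that are not plain paddle calls; shapes here are hypothetical):

import paddle

log_psi = paddle.rand([4, 10])      # hypothetical (n_bh, odim) prefix scores
i, new_id = 1, 7
s = paddle.expand(log_psi[i, new_id].reshape([1]), [log_psi.shape[1]])  # (odim,)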
deepspeech/decoders/scores/ctc_prefix_score.py @ 2430545d

@@ -3,13 +3,13 @@
 # Copyright 2018 Mitsubishi Electric Research Labs (Takaaki Hori)
 #  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)
-import torch
+import paddle
 import numpy as np
 import six
 
 
-class CTCPrefixScoreTH():
+class CTCPrefixScorePD():
     """Batch processing of CTCPrefixScore
 
     which is based on Algorithm 2 in WATANABE et al.
@@ -23,8 +23,10 @@ class CTCPrefixScoreTH():
     def __init__(self, x, xlens, blank, eos, margin=0):
         """Construct CTC prefix scorer
 
-        :param torch.Tensor x: input label posterior sequences (B, T, O)
-        :param torch.Tensor xlens: input lengths (B,)
+        `margin` is M in eq.(22,23)
+
+        :param paddle.Tensor x: input label posterior sequences (B, T, O)
+        :param paddle.Tensor xlens: input lengths (B,)
         :param int blank: blank label id
         :param int eos: end-of-sequence id
         :param int margin: margin parameter for windowing (0 means no windowing)
@@ -38,11 +40,8 @@ class CTCPrefixScoreTH():
         self.input_length = x.size(1)
         self.odim = x.size(2)
         self.dtype = x.dtype
-        self.device = (
-            torch.device("cuda:%d" % x.get_device())
-            if x.is_cuda
-            else torch.device("cpu")
-        )
+        self.device = x.place
         # Pad the rest of posteriors in the batch
         # TODO(takaaki-hori): need a better way without for-loops
         for i, l in enumerate(xlens):
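Note: the multi-line CUDA/CPU device selection collapses to x.place because a paddle tensor already carries its placement. A quick illustrative check:

import paddle

x = paddle.rand([2, 3])
print(x.place)   # e.g. Place(cpu) or Place(gpu:0), depending on the install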
@@ -50,20 +49,21 @@ class CTCPrefixScoreTH():
                 x[i, l:, :] = self.logzero
                 x[i, l:, blank] = 0
         # Reshape input x
-        xn = x.transpose(0, 1)  # (B, T, O) -> (T, B, O)
-        xb = xn[:, :, self.blank].unsqueeze(2).expand(-1, -1, self.odim)
-        self.x = torch.stack([xn, xb])  # (2, T, B, O)
-        self.end_frames = torch.as_tensor(xlens) - 1
+        xn = x.transpose([1, 0, 2])  # (B, T, O) -> (T, B, O)
+        xb = xn[:, :, self.blank].unsqueeze(2).expand(-1, -1, self.odim)  # (T,B,O)
+        self.x = paddle.stack([xn, xb])  # (2, T, B, O)
+        self.end_frames = paddle.to_tensor(xlens) - 1  # (B,)
 
         # Setup CTC windowing
         self.margin = margin
         if margin > 0:
-            self.frame_ids = torch.arange(
-                self.input_length, dtype=self.dtype, device=self.device)
+            self.frame_ids = paddle.arange(self.input_length, dtype=self.dtype)
         # Base indices for index conversion
-        self.idx_bh = None
-        self.idx_b = torch.arange(self.batch, device=self.device)
+        # B idx, hyp idx. shape (B*W, 1)
+        self.idx_bh = None
+        # B idx. shape (B,)
+        self.idx_b = paddle.arange(self.batch, place=self.device)
+        # B idx, O idx. shape (B, 1)
         self.idx_bo = (self.idx_b * self.odim).unsqueeze(1)
 
     def __call__(self, y, state, scoring_ids=None, att_w=None):
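Note: the __init__ reshaping above packs the label posteriors and the broadcast blank posteriors into a single (2, T, B, O) tensor. A toy re-creation with stock paddle ops (the hunk itself relies on .expand(-1, -1, odim)/.size() tensor helpers that are not plain paddle; names and shapes here are made up):

import paddle
import paddle.nn.functional as F

B, T, O, blank = 2, 5, 7, 0
x = F.log_softmax(paddle.rand([B, T, O]), axis=-1)            # (B, T, O)
xn = paddle.transpose(x, perm=[1, 0, 2])                      # (T, B, O)
xb = paddle.expand(xn[:, :, blank].unsqueeze(2), [T, B, O])   # blank column broadcast over O
stacked = paddle.stack([xn, xb])                              # (2, T, B, O)
print(stacked.shape)                                          # [2, 5, 2, 7]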
@@ -71,8 +71,8 @@ class CTCPrefixScoreTH():
         :param list y: prefix label sequences
         :param tuple state: previous CTC state
-        :param torch.Tensor pre_scores: scores for pre-selection of hypotheses (BW, O)
-        :param torch.Tensor att_w: attention weights to decide CTC window
+        :param paddle.Tensor scoring_ids: selected next ids to score (BW, O'), O' <= O
+        :param paddle.Tensor att_w: attention weights to decide CTC window
         :return new_state, ctc_local_scores (BW, O)
         """
         output_length = len(y[0]) - 1  # ignore sos
@@ -82,56 +82,53 @@ class CTCPrefixScoreTH():
         self.scoring_num = scoring_ids.size(-1) if scoring_ids is not None else 0
         # prepare state info
         if state is None:
-            r_prev = torch.full(
+            r_prev = paddle.full(
                 (self.input_length, 2, self.batch, n_hyps),
                 self.logzero,
-                dtype=self.dtype,
-                device=self.device,
-            )
-            r_prev[:, 1] = torch.cumsum(self.x[0, :, :, self.blank], 0).unsqueeze(2)
-            r_prev = r_prev.view(-1, 2, n_bh)
-            s_prev = 0.0
-            f_min_prev = 0
-            f_max_prev = 1
+                dtype=self.dtype, )  # (T, 2, B, W)
+            r_prev[:, 1] = paddle.cumsum(self.x[0, :, :, self.blank], 0).unsqueeze(2)
+            r_prev = r_prev.view(-1, 2, n_bh)  # (T, 2, BW)
+            s_prev = 0.0  # score
+            f_min_prev = 0  # eq. 22-23
+            f_max_prev = 1  # eq. 22-23
         else:
             r_prev, s_prev, f_min_prev, f_max_prev = state
 
         # select input dimensions for scoring
         if self.scoring_num > 0:
-            scoring_idmap = torch.full(
-                (n_bh, self.odim), -1, dtype=torch.long, device=self.device
-            )
+            # (BW, O)
+            scoring_idmap = paddle.full((n_bh, self.odim), -1, dtype=paddle.long)
             snum = self.scoring_num
             if self.idx_bh is None or n_bh > len(self.idx_bh):
-                self.idx_bh = torch.arange(n_bh, device=self.device).view(-1, 1)
-            scoring_idmap[self.idx_bh[:n_bh], scoring_ids] = torch.arange(
-                snum, device=self.device)
+                self.idx_bh = paddle.arange(n_bh).view(-1, 1)  # (BW, 1)
+            scoring_idmap[self.idx_bh[:n_bh], scoring_ids] = paddle.arange(snum)
             scoring_idx = (
-                scoring_ids + self.idx_bo.repeat(1, n_hyps).view(-1, 1)
-            ).view(-1)
-            x_ = torch.index_select(
-                self.x.view(2, -1, self.batch * self.odim), 2, scoring_idx
+                scoring_ids + self.idx_bo.repeat(1, n_hyps).view(-1, 1)  # (BW,1)
+            ).view(-1)  # (BWO)
+            # x_ shape (2, T, B*W, O)
+            x_ = paddle.index_select(
+                self.x.view(2, -1, self.batch * self.odim), scoring_idx, 2
             ).view(2, -1, n_bh, snum)
         else:
             scoring_ids = None
             scoring_idmap = None
             snum = self.odim
+            # x_ shape (2, T, B*W, O)
             x_ = self.x.unsqueeze(3).repeat(1, 1, 1, n_hyps, 1).view(2, -1, n_bh, snum)
 
         # new CTC forward probs are prepared as a (T x 2 x BW x S) tensor
         # that corresponds to r_t^n(h) and r_t^b(h) in a batch.
-        r = torch.full(
+        r = paddle.full(
             (self.input_length, 2, n_bh, snum),
             self.logzero,
-            dtype=self.dtype,
-            device=self.device,
-        )
+            dtype=self.dtype, )
         if output_length == 0:
             r[0, 0] = x_[0, 0]
 
-        r_sum = torch.logsumexp(r_prev, 1)
-        log_phi = r_sum.unsqueeze(2).repeat(1, 1, snum)
+        r_sum = paddle.logsumexp(r_prev, 1)  # (T,BW)
+        log_phi = r_sum.unsqueeze(2).repeat(1, 1, snum)  # (T, BW, O)
         if scoring_ids is not None:
             for idx in range(n_bh):
                 pos = scoring_idmap[idx, last_ids[idx]]
@@ -143,40 +140,39 @@ class CTCPrefixScoreTH():
         # decide start and end frames based on attention weights
         if att_w is not None and self.margin > 0:
-            f_arg = torch.matmul(att_w, self.frame_ids)
+            f_arg = paddle.matmul(att_w, self.frame_ids)
             f_min = max(int(f_arg.min().cpu()), f_min_prev)
             f_max = max(int(f_arg.max().cpu()), f_max_prev)
             start = min(f_max_prev, max(f_min - self.margin, output_length, 1))
             end = min(f_max + self.margin, self.input_length)
         else:
             f_min = f_max = 0
+            # if one frame one out, the output_length is the eating frame num now.
             start = max(output_length, 1)
             end = self.input_length
 
         # compute forward probabilities log(r_t^n(h)) and log(r_t^b(h))
         for t in range(start, end):
-            rp = r[t - 1]
-            rr = torch.stack([rp[0], log_phi[t - 1], rp[0], rp[1]]).view(
+            rp = r[t - 1]  # (2 x BW x O')
+            rr = paddle.stack([rp[0], log_phi[t - 1], rp[0], rp[1]]).view(
                 2, 2, n_bh, snum
-            )
-            r[t] = torch.logsumexp(rr, 1) + x_[:, t]
+            )  # (2,2,BW,O')
+            r[t] = paddle.logsumexp(rr, 1) + x_[:, t]
 
         # compute log prefix probabilities log(psi)
-        log_phi_x = torch.cat((log_phi[0].unsqueeze(0), log_phi[:-1]), dim=0) + x_[0]
+        log_phi_x = paddle.concat((log_phi[0].unsqueeze(0), log_phi[:-1]), axis=0) + x_[0]
         if scoring_ids is not None:
-            log_psi = torch.full(
-                (n_bh, self.odim),
-                self.logzero,
-                dtype=self.dtype,
-                device=self.device
-            )
-            log_psi_ = torch.logsumexp(
-                torch.cat((log_phi_x[start:end], r[start - 1, 0].unsqueeze(0)), dim=0),
-                dim=0,
-            )
+            log_psi = paddle.full((n_bh, self.odim), self.logzero, dtype=self.dtype)
+            log_psi_ = paddle.logsumexp(
+                paddle.concat((log_phi_x[start:end], r[start - 1, 0].unsqueeze(0)), axis=0),
+                axis=0, )
             for si in range(n_bh):
                 log_psi[si, scoring_ids[si]] = log_psi_[si]
         else:
-            log_psi = torch.logsumexp(
-                torch.cat((log_phi_x[start:end], r[start - 1, 0].unsqueeze(0)), dim=0),
-                dim=0,
-            )
+            log_psi = paddle.logsumexp(
+                paddle.concat((log_phi_x[start:end], r[start - 1, 0].unsqueeze(0)), axis=0),
+                axis=0, )
 
         for si in range(n_bh):
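Note: the prefix probability computed above is a logsumexp over time of log_phi_x together with the carried-over forward term r[start - 1, 0]. A stripped-down sketch of that reduction with stock paddle ops (shapes are hypothetical):

import paddle

T, n_bh, snum = 6, 4, 3
log_phi_x = paddle.rand([T, n_bh, snum])    # stand-in for log_phi + x_
r_carry = paddle.rand([1, n_bh, snum])      # stand-in for r[start - 1, 0].unsqueeze(0)
log_psi_ = paddle.logsumexp(
    paddle.concat([log_phi_x, r_carry], axis=0), axis=0)   # (n_bh, snum)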
@@ -200,7 +196,7 @@ class CTCPrefixScoreTH():
         n_hyps = n_bh // self.batch
         vidx = (best_ids + (self.idx_b * (n_hyps * self.odim)).view(-1, 1)).view(-1)
         # select hypothesis scores
-        s_new = torch.index_select(s.view(-1), 0, vidx)
+        s_new = paddle.index_select(s.view(-1), vidx, 0)
         s_new = s_new.view(-1, 1).repeat(1, self.odim).view(n_bh, self.odim)
         # convert ids to BHS space (S: scoring_num)
         if scoring_idmap is not None:
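Note: paddle.index_select takes (x, index, axis) while torch.index_select takes (input, dim, index), which is why the 0/vidx arguments swap places here and in the next hunk. A minimal check:

import paddle

s = paddle.arange(12, dtype='float32').reshape([3, 4])
vidx = paddle.to_tensor([2, 0])
picked = paddle.index_select(s.reshape([-1]), vidx, 0)   # tensor values [2., 0.]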
@@ -208,14 +204,14 @@ class CTCPrefixScoreTH():
             hyp_idx = (best_ids // self.odim + (self.idx_b * n_hyps).view(-1, 1)).view(
                 -1
             )
-            label_ids = torch.fmod(best_ids, self.odim).view(-1)
+            label_ids = paddle.fmod(best_ids, self.odim).view(-1)
             score_idx = scoring_idmap[hyp_idx, label_ids]
             score_idx[score_idx == -1] = 0
             vidx = score_idx + hyp_idx * snum
         else:
             snum = self.odim
         # select forward probabilities
-        r_new = torch.index_select(r.view(-1, 2, n_bh * snum), 2, vidx).view(
+        r_new = paddle.index_select(r.view(-1, 2, n_bh * snum), vidx, 2).view(
             -1, 2, n_bh
         )
         return r_new, s_new, f_min, f_max
@@ -223,7 +219,7 @@ class CTCPrefixScoreTH():
     def extend_prob(self, x):
         """Extend CTC prob.
 
-        :param torch.Tensor x: input label posterior sequences (B, T, O)
+        :param paddle.Tensor x: input label posterior sequences (B, T, O)
         """
         if self.x.shape[1] < x.shape[1]:  # self.x (2,T,B,O); x (B,T,O)
@@ -235,12 +231,12 @@ class CTCPrefixScoreTH():
...
@@ -235,12 +231,12 @@ class CTCPrefixScoreTH():
x
[
i
,
l
:,
:]
=
self
.
logzero
x
[
i
,
l
:,
:]
=
self
.
logzero
x
[
i
,
l
:,
self
.
blank
]
=
0
x
[
i
,
l
:,
self
.
blank
]
=
0
tmp_x
=
self
.
x
tmp_x
=
self
.
x
xn
=
x
.
transpose
(
0
,
1
)
# (B, T, O) -> (T, B, O)
xn
=
x
.
transpose
(
[
1
,
0
,
2
]
)
# (B, T, O) -> (T, B, O)
xb
=
xn
[:,
:,
self
.
blank
].
unsqueeze
(
2
).
expand
(
-
1
,
-
1
,
self
.
odim
)
xb
=
xn
[:,
:,
self
.
blank
].
unsqueeze
(
2
).
expand
(
-
1
,
-
1
,
self
.
odim
)
self
.
x
=
torch
.
stack
([
xn
,
xb
])
# (2, T, B, O)
self
.
x
=
paddle
.
stack
([
xn
,
xb
])
# (2, T, B, O)
self
.
x
[:,
:
tmp_x
.
shape
[
1
],
:,
:]
=
tmp_x
self
.
x
[:,
:
tmp_x
.
shape
[
1
],
:,
:]
=
tmp_x
self
.
input_length
=
x
.
size
(
1
)
self
.
input_length
=
x
.
size
(
1
)
self
.
end_frames
=
torch
.
as
_tensor
(
xlens
)
-
1
self
.
end_frames
=
paddle
.
to
_tensor
(
xlens
)
-
1
def
extend_state
(
self
,
state
):
def
extend_state
(
self
,
state
):
"""Compute CTC prefix state.
"""Compute CTC prefix state.
@@ -256,15 +252,14 @@ class CTCPrefixScoreTH():
         else:
             r_prev, s_prev, f_min_prev, f_max_prev = state
 
-        r_prev_new = torch.full(
+        r_prev_new = paddle.full(
             (self.input_length, 2),
             self.logzero,
-            dtype=self.dtype,
-            device=self.device,
-        )
+            dtype=self.dtype, )
         start = max(r_prev.shape[0], 1)
         r_prev_new[0:start] = r_prev
-        for t in six.moves.range(start, self.input_length):
+        for t in range(start, self.input_length):
             r_prev_new[t, 1] = r_prev_new[t - 1, 1] + self.x[0, t, :, self.blank]
 
         return (r_prev_new, s_prev, f_min_prev, f_max_prev)
@@ -285,7 +280,7 @@ class CTCPrefixScore():
         self.blank = blank
         self.eos = eos
         self.input_length = len(x)
-        self.x = x
+        self.x = x  # (T, O)
 
     def initial_state(self):
         """Obtain an initial CTC state
@@ -295,6 +290,7 @@ class CTCPrefixScore():
         # initial CTC state is made of a frame x 2 tensor that corresponds to
         # r_t^n(<sos>) and r_t^b(<sos>), where 0 and 1 of axis=1 represent
         # superscripts n and b (non-blank and blank), respectively.
+        # r shape (T, 2)
         r = self.xp.full((self.input_length, 2), self.logzero, dtype=np.float32)
         r[0, 1] = self.x[0, self.blank]
         for i in six.moves.range(1, self.input_length):
@@ -313,6 +309,7 @@ class CTCPrefixScore():
...
@@ -313,6 +309,7 @@ class CTCPrefixScore():
output_length
=
len
(
y
)
-
1
# ignore sos
output_length
=
len
(
y
)
-
1
# ignore sos
# new CTC states are prepared as a frame x (n or b) x n_labels tensor
# new CTC states are prepared as a frame x (n or b) x n_labels tensor
# that corresponds to r_t^n(h) and r_t^b(h).
# that corresponds to r_t^n(h) and r_t^b(h).
# r shape (T, 2, n_labels)
r
=
self
.
xp
.
ndarray
((
self
.
input_length
,
2
,
len
(
cs
)),
dtype
=
np
.
float32
)
r
=
self
.
xp
.
ndarray
((
self
.
input_length
,
2
,
len
(
cs
)),
dtype
=
np
.
float32
)
xs
=
self
.
x
[:,
cs
]
xs
=
self
.
x
[:,
cs
]
if
output_length
==
0
:
if
output_length
==
0
:
@@ -356,4 +353,5 @@ class CTCPrefixScore():
         # return the log prefix probability and CTC states, where the label axis
         # of the CTC states is moved to the first axis to slice it easily
+        # log_psi shape (n_labels,), state shape (n_labels, T, 2)
         return log_psi, self.xp.rollaxis(r, 2)
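Note: the new shape comments on the pure-numpy CTCPrefixScore spell out what the final rollaxis does: the label axis of r (T, 2, n_labels) is moved to the front so each label's (T, 2) table can be sliced directly. A quick illustration:

import numpy as np

T, n_labels = 6, 5
r = np.zeros((T, 2, n_labels), dtype=np.float32)
state = np.rollaxis(r, 2)    # (n_labels, T, 2); state[c] is the (T, 2) table for label c
print(state.shape)           # (5, 6, 2)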