PaddlePaddle / DeepSpeech
Commit 2b8c08e3 (unverified)
Authored April 21, 2022 by Hui Zhang; committed by GitHub on April 21, 2022
Merge pull request #1673 from Jackwaterveg/CER
[asr] Add new cer tools
Parents: f39de8d7, 8d1ee826
Showing 6 changed files with 667 additions and 566 deletions
examples/aishell/asr0/local/test.sh  +39 -10
examples/aishell/asr1/local/test.sh  +81 -42
examples/aishell/asr1/run.sh  +1 -1
paddlespeech/s2t/exps/deepspeech2/model.py  +1 -1
utils/compute-wer.py  +455 -512
utils/format_rsl.py  +90 -0
examples/aishell/asr0/local/test.sh @ 2b8c08e3
@@ -5,6 +5,8 @@ if [ $# != 4 ];then
     exit -1
 fi
+stage=0
+stop_stage=100
 ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
 echo "using $ngpu gpus..."
@@ -19,18 +21,45 @@ if [ $? -ne 0 ]; then
     exit 1
 fi
-python3 -u ${BIN_DIR}/test.py \
---ngpu ${ngpu} \
---config ${config_path} \
---decode_cfg ${decode_config_path} \
---result_file ${ckpt_prefix}.rsl \
---checkpoint_path ${ckpt_prefix} \
---model_type ${model_type}
-if [ $? -ne 0 ]; then
-    echo "Failed in evaluation!"
-    exit 1
+if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
+    # format the reference test file
+    python utils/format_rsl.py \
+        --origin_ref data/manifest.test.raw \
+        --trans_ref data/manifest.test.text
+    python3 -u ${BIN_DIR}/test.py \
+    --ngpu ${ngpu} \
+    --config ${config_path} \
+    --decode_cfg ${decode_config_path} \
+    --result_file ${ckpt_prefix}.rsl \
+    --checkpoint_path ${ckpt_prefix} \
+    --model_type ${model_type}
+    if [ $? -ne 0 ]; then
+        echo "Failed in evaluation!"
+        exit 1
+    fi
+    # format the hyp file
+    python utils/format_rsl.py \
+        --origin_hyp ${ckpt_prefix}.rsl \
+        --trans_hyp ${ckpt_prefix}.rsl.text
+    python utils/compute-wer.py --char=1 --v=1 \
+        data/manifest.test.text ${ckpt_prefix}.rsl.text > ${ckpt_prefix}.error
 fi
+if [ ${stage} -le 101 ] && [ ${stop_stage} -ge 101 ]; then
+    python utils/format_rsl.py \
+        --origin_ref data/manifest.test.raw \
+        --trans_ref_sclite data/manifest.test.text.sclite
+    python utils/format_rsl.py \
+        --origin_hyp ${ckpt_prefix}.rsl \
+        --trans_hyp_sclite ${ckpt_prefix}.rsl.text.sclite
+    mkdir -p ${ckpt_prefix}_sclite
+    sclite -i wsj -r data/manifest.test.text.sclite -h ${ckpt_prefix}.rsl.text.sclite -e utf-8 -o all -O ${ckpt_prefix}_sclite -c NOASCII
+fi
 exit 0
examples/aishell/asr1/local/test.sh @ 2b8c08e3
@@ -5,6 +5,8 @@ if [ $# != 3 ];then
     exit -1
 fi
+stage=0
+stop_stage=100
 ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
 echo "using $ngpu gpus..."
@@ -24,7 +26,13 @@ fi
 #fi
-for type in attention ctc_greedy_search; do
+if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
+    # format the reference test file
+    python utils/format_rsl.py \
+        --origin_ref data/manifest.test.raw \
+        --trans_ref data/manifest.test.text
+    for type in attention ctc_greedy_search; do
     echo "decoding ${type}"
     if [ ${chunk_mode} == true ];then
         # stream decoding only support batchsize=1
@@ -46,10 +54,18 @@ for type in attention ctc_greedy_search; do
     if [ $? -ne 0 ]; then
         echo "Failed in evaluation!"
         exit 1
     fi
-done
+    # format the hyp file
+    python utils/format_rsl.py \
+        --origin_hyp ${output_dir}/${type}.rsl \
+        --trans_hyp ${output_dir}/${type}.rsl.text
+    python utils/compute-wer.py --char=1 --v=1 \
+        data/manifest.test.text ${output_dir}/${type}.rsl.text > ${output_dir}/${type}.error
+    done
 for type in ctc_prefix_beam_search attention_rescoring; do
     echo "decoding ${type}"
     batch_size=1
     output_dir=${ckpt_prefix}
@@ -67,6 +83,29 @@ for type in ctc_prefix_beam_search attention_rescoring; do
         echo "Failed in evaluation!"
         exit 1
     fi
-done
+    python utils/format_rsl.py \
+        --origin_hyp ${output_dir}/${type}.rsl --trans_hyp ${output_dir}/${type}.rsl.text
+    python utils/compute-wer.py --char=1 --v=1 \
+        data/manifest.test.text ${output_dir}/${type}.rsl.text > ${output_dir}/${type}.error
+    done
+fi
+if [ ${stage} -le 101 ] && [ ${stop_stage} -ge 101 ]; then
+    # format the reference test file for sclite
+    python utils/format_rsl.py \
+        --origin_ref data/manifest.test.raw \
+        --trans_ref_sclite data/manifest.test.text.sclite
+    output_dir=${ckpt_prefix}
+    for type in attention ctc_greedy_search ctc_prefix_beam_search attention_rescoring; do
+        python utils/format_rsl.py \
+            --origin_hyp ${output_dir}/${type}.rsl --trans_hyp_sclite ${output_dir}/${type}.rsl.text.sclite
+        mkdir -p ${output_dir}/${type}_sclite
+        sclite -i wsj -r data/manifest.test.text.sclite -h ${output_dir}/${type}.rsl.text.sclite -e utf-8 -o all -O ${output_dir}/${type}_sclite -c NOASCII
+    done
+fi
 exit 0
examples/aishell/asr1/run.sh @ 2b8c08e3
@@ -7,7 +7,7 @@ stage=0
 stop_stage=50
 conf_path=conf/conformer.yaml
 decode_conf_path=conf/tuning/decode.yaml
-avg_num=20
+avg_num=30
 audio_file=data/demo_01_03.wav
 source ${MAIN_ROOT}/utils/parse_options.sh || exit 1;
...
paddlespeech/s2t/exps/deepspeech2/model.py @ 2b8c08e3
@@ -278,7 +278,7 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
                 len_refs += len_ref
                 num_ins += 1
                 if fout:
-                    fout.write({"utt": utt, "ref": target, "hyp": result})
+                    fout.write({"utt": utt, "refs": [target], "hyps": [result]})
                 logger.info(f"Utt: {utt}")
                 logger.info(f"Ref: {target}")
                 logger.info(f"Hyp: {result}")
...
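The model.py hunk above switches each result record to list-valued "refs" and "hyps" fields. The sketch below shows one way such JSON-lines records could be flattened into the "utt text" lines that a scorer reads. The contents of utils/format_rsl.py are not shown on this page, so the helper name `records_to_text`, the one-record-per-line layout, and the choice of taking the first list entry are all assumptions made for illustration.

```python
import json


def records_to_text(jsonl_lines, key="hyps"):
    """Flatten JSON-lines result records into 'utt text' lines.

    Hypothetical helper: only the "utt", "refs" and "hyps" keys come from
    the diff; everything else here is an assumption.
    """
    out = []
    for line in jsonl_lines:
        line = line.strip()
        if not line:
            continue  # skip blank lines
        rec = json.loads(line)
        values = rec.get(key, [])
        # The fields hold lists after this commit; use the first entry.
        text = values[0] if values else ""
        out.append(f"{rec['utt']} {text}")
    return out


if __name__ == "__main__":
    sample = ['{"utt": "utt_0001", "refs": ["a b c"], "hyps": ["a b d"]}']
    for row in records_to_text(sample, key="hyps"):
        print(row)
```

Calling the helper once with `key="refs"` and once with `key="hyps"` would produce the reference and hypothesis text files that the test scripts pass to compute-wer.py.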
utils/compute-wer.py @ 2b8c08e3
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-
 # CopyRight WeNet Apache-2.0 License
+import re, sys, unicodedata
 import codecs
-import sys
-import unicodedata
 remove_tag = True
 spacelist = [' ', '\t', '\r', '\n']
-puncts = [
-    '!', ',', '?', '、', '。', '!', ',', ';', '?', ':', '「', '」', '︰', '『', '』',
-    '《', '》'
-]
+puncts = ['!', ',', '?',
+          '、', '。', '!', ',', ';', '?',
+          ':', '「', '」', '︰', '『', '』', '《', '》']
-def characterize(string):
+def characterize(string) :
     res = []
     i = 0
     while i < len(string):
...
@@ -32,12 +30,11 @@ def characterize(string):
         else:
             # some input looks like: <unk><noise>, we want to separate it to two words.
             sep = ' '
-            if char == '<':
-                sep = '>'
+            if char == '<': sep = '>'
             j = i + 1
             while j < len(string):
                 c = string[j]
                 if ord(c) >= 128 or (c in spacelist) or (c == sep):
                     break
                 j += 1
             if j < len(string) and string[j] == '>':
...
@@ -46,13 +43,10 @@ def characterize(string):
             i = j
     return res
 def stripoff_tags(x):
-    if not x:
-        return ''
+    if not x: return ''
     chars = []
-    i = 0
-    T = len(x)
+    i = 0; T = len(x)
     while i < T:
         if x[i] == '<':
             while i < T and x[i] != '>':
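The characterize hunks above show only part of the tokenizer used when --char=1 is set. A simplified sketch of the idea follows: each "letter, other" character (for example a CJK ideograph) becomes its own token, runs of ASCII text stay together, and a `<tag>` is kept as one token. The function name `split_tokens` and the use of the Unicode category 'Lo' are assumptions, since the corresponding lines are elided from the diff; punctuation filtering is left out.

```python
import unicodedata

SPACES = {' ', '\t', '\r', '\n'}


def split_tokens(text):
    """Simplified character-level tokenizer (illustrative sketch only)."""
    res = []
    i = 0
    while i < len(text):
        ch = text[i]
        if ch in SPACES or unicodedata.category(ch) == 'Zs':
            i += 1  # skip whitespace
        elif unicodedata.category(ch) == 'Lo':
            res.append(ch)  # e.g. one CJK ideograph per token
            i += 1
        else:
            # A token starting with '<' runs up to the closing '>';
            # anything else runs up to whitespace or a non-ASCII char.
            sep = '>' if ch == '<' else ' '
            j = i + 1
            while j < len(text):
                c = text[j]
                if ord(c) >= 128 or c in SPACES or c == sep:
                    break
                j += 1
            if j < len(text) and text[j] == '>':
                j += 1  # keep the closing '>' inside the tag token
            res.append(text[i:j])
            i = j
    return res
```

With this sketch, an input such as `<unk><noise>` is split into two tokens, which is the case the comment in the hunk describes.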
...
@@ -84,9 +78,8 @@ def normalize(sentence, ignore_words, cs, split=None):
             new_sentence.append(x)
     return new_sentence
-class Calculator:
-    def __init__(self):
+class Calculator :
+    def __init__(self) :
         self.data = {}
         self.space = []
         self.cost = {}
...
@@ -94,87 +87,66 @@ class Calculator:
         self.cost['sub'] = 1
         self.cost['del'] = 1
         self.cost['ins'] = 1
-    def calculate(self, lab, rec):
+    def calculate(self, lab, rec) :
         # Initialization
         lab.insert(0, '')
         rec.insert(0, '')
         while len(self.space) < len(lab) :
             self.space.append([])
         for row in self.space:
             for element in row:
                 element['dist'] = 0
                 element['error'] = 'non'
             while len(row) < len(rec) :
                 row.append({'dist': 0, 'error': 'non'})
         for i in range(len(lab)) :
             self.space[i][0]['dist'] = i
             self.space[i][0]['error'] = 'del'
         for j in range(len(rec)) :
             self.space[0][j]['dist'] = j
             self.space[0][j]['error'] = 'ins'
         self.space[0][0]['error'] = 'non'
         for token in lab:
             if token not in self.data and len(token) > 0:
-                self.data[token] = {
-                    'all': 0,
-                    'cor': 0,
-                    'sub': 0,
-                    'ins': 0,
-                    'del': 0
-                }
-        for token in rec:
-            if token not in self.data and len(token) > 0:
-                self.data[token] = {
-                    'all': 0,
-                    'cor': 0,
-                    'sub': 0,
-                    'ins': 0,
-                    'del': 0
-                }
+                self.data[token] = {'all': 0, 'cor': 0, 'sub': 0, 'ins': 0, 'del': 0}
+        for token in rec:
+            if token not in self.data and len(token) > 0:
+                self.data[token] = {'all': 0, 'cor': 0, 'sub': 0, 'ins': 0, 'del': 0}
         # Computing edit distance
         for i, lab_token in enumerate(lab) :
             for j, rec_token in enumerate(rec) :
                 if i == 0 or j == 0:
                     continue
                 min_dist = sys.maxsize
                 min_error = 'none'
                 dist = self.space[i - 1][j]['dist'] + self.cost['del']
                 error = 'del'
                 if dist < min_dist:
                     min_dist = dist
                     min_error = error
                 dist = self.space[i][j - 1]['dist'] + self.cost['ins']
                 error = 'ins'
                 if dist < min_dist:
                     min_dist = dist
                     min_error = error
                 if lab_token == rec_token:
                     dist = self.space[i - 1][j - 1]['dist'] + self.cost['cor']
                     error = 'cor'
                 else:
                     dist = self.space[i - 1][j - 1]['dist'] + self.cost['sub']
                     error = 'sub'
                 if dist < min_dist:
                     min_dist = dist
                     min_error = error
                 self.space[i][j]['dist'] = min_dist
                 self.space[i][j]['error'] = min_error
         # Tracing back
-        result = {
-            'lab': [],
-            'rec': [],
-            'all': 0,
-            'cor': 0,
-            'sub': 0,
-            'ins': 0,
-            'del': 0
-        }
+        result = {'lab':[], 'rec':[], 'all':0, 'cor':0, 'sub':0, 'ins':0, 'del':0}
         i = len(lab) - 1
         j = len(rec) - 1
         while True:
             if self.space[i][j]['error'] == 'cor':  # correct
                 if len(lab[i]) > 0:
                     self.data[lab[i]]['all'] = self.data[lab[i]]['all'] + 1
                     self.data[lab[i]]['cor'] = self.data[lab[i]]['cor'] + 1
                     result['all'] = result['all'] + 1
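The hunk above fills a distance table and then traces back through it to count correct, substituted, deleted and inserted tokens. A compact sketch of the same idea follows, with unit costs for substitution, deletion and insertion as set in `self.cost`. The names `edit_counts` and `error_rate` are made up for this example, and the tie-breaking order between equally cheap alignments is an assumption that may differ from the script's, so per-type counts can differ on some inputs even when the total error count matches.

```python
def edit_counts(ref, hyp):
    """Return (n_ref, cor, sub, dele, ins) from a Levenshtein alignment."""
    n, m = len(ref), len(hyp)
    # dist[i][j]: minimum number of edits turning ref[:i] into hyp[:j]
    dist = [[0] * (m + 1) for _ in range(n + 1)]
    for i in range(n + 1):
        dist[i][0] = i  # delete everything
    for j in range(m + 1):
        dist[0][j] = j  # insert everything
    for i in range(1, n + 1):
        for j in range(1, m + 1):
            diag = dist[i - 1][j - 1] + (0 if ref[i - 1] == hyp[j - 1] else 1)
            dist[i][j] = min(dist[i - 1][j] + 1, dist[i][j - 1] + 1, diag)

    # Trace back from the bottom-right corner, counting each edit type.
    i, j = n, m
    cor = sub = dele = ins = 0
    while i > 0 or j > 0:
        if (i > 0 and j > 0 and ref[i - 1] == hyp[j - 1]
                and dist[i][j] == dist[i - 1][j - 1]):
            cor += 1
            i -= 1
            j -= 1
        elif i > 0 and j > 0 and dist[i][j] == dist[i - 1][j - 1] + 1:
            sub += 1
            i -= 1
            j -= 1
        elif i > 0 and dist[i][j] == dist[i - 1][j] + 1:
            dele += 1
            i -= 1
        else:
            ins += 1
            j -= 1
    return n, cor, sub, dele, ins


def error_rate(ref, hyp):
    """Error rate in percent: (S + D + I) / N * 100, or 0.0 when N is 0."""
    n, _, sub, dele, ins = edit_counts(ref, hyp)
    return 100.0 * (sub + dele + ins) / n if n else 0.0
```

Passing lists of characters gives a character error rate, and passing lists of words gives a word error rate; the formula is the same one the script prints as WER.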
...
@@ -183,8 +155,8 @@ class Calculator:
                 result['rec'].insert(0, rec[j])
                 i = i - 1
                 j = j - 1
             elif self.space[i][j]['error'] == 'sub':  # substitution
                 if len(lab[i]) > 0:
                     self.data[lab[i]]['all'] = self.data[lab[i]]['all'] + 1
                     self.data[lab[i]]['sub'] = self.data[lab[i]]['sub'] + 1
                     result['all'] = result['all'] + 1
...
@@ -193,8 +165,8 @@ class Calculator:
                 result['rec'].insert(0, rec[j])
                 i = i - 1
                 j = j - 1
             elif self.space[i][j]['error'] == 'del':  # deletion
                 if len(lab[i]) > 0:
                     self.data[lab[i]]['all'] = self.data[lab[i]]['all'] + 1
                     self.data[lab[i]]['del'] = self.data[lab[i]]['del'] + 1
                     result['all'] = result['all'] + 1
...
@@ -202,64 +174,57 @@ class Calculator:
                 result['lab'].insert(0, lab[i])
                 result['rec'].insert(0, "")
                 i = i - 1
             elif self.space[i][j]['error'] == 'ins':  # insertion
                 if len(rec[j]) > 0:
                     self.data[rec[j]]['ins'] = self.data[rec[j]]['ins'] + 1
                 result['ins'] = result['ins'] + 1
                 result['lab'].insert(0, "")
                 result['rec'].insert(0, rec[j])
                 j = j - 1
             elif self.space[i][j]['error'] == 'non':  # starting point
                 break
             else:  # shouldn't reach here
-                print(
-                    'this should not happen , i = {i} , j = {j} , error = {error}'.format(i=i, j=j, error=self.space[i][j]['error']))
+                print('this should not happen , i = {i} , j = {j} , error = {error}'.format(i = i, j = j, error = self.space[i][j]['error']))
         return result
-    def overall(self):
+    def overall(self) :
         result = {'all': 0, 'cor': 0, 'sub': 0, 'ins': 0, 'del': 0}
         for token in self.data:
             result['all'] = result['all'] + self.data[token]['all']
             result['cor'] = result['cor'] + self.data[token]['cor']
             result['sub'] = result['sub'] + self.data[token]['sub']
             result['ins'] = result['ins'] + self.data[token]['ins']
             result['del'] = result['del'] + self.data[token]['del']
         return result
-    def cluster(self, data):
+    def cluster(self, data) :
         result = {'all': 0, 'cor': 0, 'sub': 0, 'ins': 0, 'del': 0}
         for token in data:
             if token in self.data:
                 result['all'] = result['all'] + self.data[token]['all']
                 result['cor'] = result['cor'] + self.data[token]['cor']
                 result['sub'] = result['sub'] + self.data[token]['sub']
                 result['ins'] = result['ins'] + self.data[token]['ins']
                 result['del'] = result['del'] + self.data[token]['del']
         return result
-    def keys(self):
+    def keys(self) :
         return list(self.data.keys())
 def width(string):
     return sum(1 + (unicodedata.east_asian_width(c) in "AFW") for c in string)
-def default_cluster(word):
+def default_cluster(word) :
     unicode_names = [unicodedata.name(char) for char in word]
-    for i in reversed(range(len(unicode_names))):
-        if unicode_names[i].startswith('DIGIT'):  # 1
+    for i in reversed(range(len(unicode_names))) :
+        if unicode_names[i].startswith('DIGIT') :  # 1
             unicode_names[i] = 'Number'  # 'DIGIT'
         elif (unicode_names[i].startswith('CJK UNIFIED IDEOGRAPH') or
               unicode_names[i].startswith('CJK COMPATIBILITY IDEOGRAPH')) :
             # 明 / 郎
             unicode_names[i] = 'Mandarin'  # 'CJK IDEOGRAPH'
         elif (unicode_names[i].startswith('LATIN CAPITAL LETTER') or
               unicode_names[i].startswith('LATIN SMALL LETTER')) :
             # A / a
             unicode_names[i] = 'English'  # 'LATIN LETTER'
         elif unicode_names[i].startswith('HIRAGANA LETTER') :  # は こ め
             unicode_names[i] = 'Japanese'  # 'GANA LETTER'
         elif (unicode_names[i].startswith('AMPERSAND') or
               unicode_names[i].startswith('APOSTROPHE') or
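The `width` function in the hunk above counts East Asian ambiguous, fullwidth and wide characters as two display columns, so that reference and hypothesis tokens line up in a monospaced terminal. A minimal sketch of how that width can drive padding is shown below; `width` is copied from the diff, while `pad_pair` is a hypothetical helper written for this example.

```python
import unicodedata


def width(string):
    # A (ambiguous), F (fullwidth) and W (wide) characters take two columns;
    # the boolean adds 1 to the base width of 1 for those characters.
    return sum(1 + (unicodedata.east_asian_width(c) in "AFW") for c in string)


def pad_pair(lab, rec, padding_symbol=" "):
    """Pad two tokens to the same display width (hypothetical helper)."""
    length = max(width(lab), width(rec))
    return (lab + padding_symbol * (length - width(lab)),
            rec + padding_symbol * (length - width(rec)))
```

The script applies the same difference, `length - len_lab` and `length - len_rec`, when it prints the aligned lab and rec rows.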
...
@@ -271,40 +236,34 @@ def default_cluster(word):
               unicode_names[i].startswith('LOW LINE') or
               unicode_names[i].startswith('NUMBER SIGN') or
               unicode_names[i].startswith('PLUS SIGN') or
               unicode_names[i].startswith('SEMICOLON')) :
             # & / ' / @ / ℃ / = / . / - / _ / # / + / ;
             del unicode_names[i]
         else:
             return 'Other'
     if len(unicode_names) == 0:
         return 'Other'
     if len(unicode_names) == 1:
         return unicode_names[0]
     for i in range(len(unicode_names) - 1) :
         if unicode_names[i] != unicode_names[i + 1]:
             return 'Other'
     return unicode_names[0]
-def usage():
-    print(
-        "compute-wer.py : compute word error rate (WER) and align recognition results and references."
-    )
-    print(
-        " usage : python compute-wer.py [--cs={0,1}] [--cluster=foo] [--ig=ignore_file] [--char={0,1}] [--v={0,1}] [--padding-symbol={space,underline}] test.ref test.hyp > test.wer"
-    )
+def usage() :
+    print("compute-wer.py : compute word error rate (WER) and align recognition results and references.")
+    print(" usage : python compute-wer.py [--cs={0,1}] [--cluster=foo] [--ig=ignore_file] [--char={0,1}] [--v={0,1}] [--padding-symbol={space,underline}] test.ref test.hyp > test.wer")
 if __name__ == '__main__':
     if len(sys.argv) == 1:
         usage()
         sys.exit(0)
     calculator = Calculator()
     cluster_file = ''
     ignore_words = set()
     tochar = False
     verbose = 1
     padding_symbol = ' '
     case_sensitive = False
     max_words_per_line = sys.maxsize
     split = None
...
@@ -363,10 +322,10 @@ if __name__ == '__main__':
             if sys.argv[1].startswith(a):
                 b = sys.argv[1][len(a):].lower()
                 del sys.argv[1]
                 verbose = 0
                 try:
                     verbose = int(b)
-                except Exception as e:
+                except:
                     if b == 'true' or b != '0':
                         verbose = 1
                 continue
...
@@ -375,9 +334,9 @@ if __name__ == '__main__':
                 b = sys.argv[1][len(a):].lower()
                 del sys.argv[1]
                 if b == 'space':
                     padding_symbol = ' '
                 elif b == 'underline':
                     padding_symbol = '_'
                 continue
             if True or sys.argv[1].startswith('-'):
                 #ignore invalid switch
...
@@ -385,7 +344,7 @@ if __name__ == '__main__':
                 continue
     if not case_sensitive:
         ig = set([w.upper() for w in ignore_words])
         ignore_words = ig
     default_clusters = {}
...
@@ -409,20 +368,17 @@ if __name__ == '__main__':
             array = characterize(line)
         else:
             array = line.strip().split()
-        if len(array) == 0:
-            continue
+        if len(array) == 0: continue
         fid = array[0]
-        rec_set[fid] = normalize(array[1:], ignore_words, case_sensitive,
-                                 split)
+        rec_set[fid] = normalize(array[1:], ignore_words, case_sensitive, split)
     # compute error rate on the interaction of reference file and hyp file
     for line in open(ref_file, 'r', encoding='utf-8') :
         if tochar:
             array = characterize(line)
         else:
             array = line.rstrip('\n').split()
-        if len(array) == 0:
-            continue
+        if len(array) == 0: continue
         fid = array[0]
         if fid not in rec_set:
             continue
...
@@ -431,127 +387,114 @@ if __name__ == '__main__':
         if verbose:
             print('\nutt: %s' % fid)
         for word in rec + lab:
             if word not in default_words:
                 default_cluster_name = default_cluster(word)
                 if default_cluster_name not in default_clusters:
                     default_clusters[default_cluster_name] = {}
                 if word not in default_clusters[default_cluster_name]:
                     default_clusters[default_cluster_name][word] = 1
                 default_words[word] = default_cluster_name
         result = calculator.calculate(lab, rec)
         if verbose:
             if result['all'] != 0:
-                wer = float(result['ins'] + result['sub'] + result[
-                    'del']) * 100.0 / result['all']
+                wer = float(result['ins'] + result['sub'] + result['del']) * 100.0 / result['all']
             else:
                 wer = 0.0
             print('WER: %4.2f %%' % wer, end=' ')
-            print('N=%d C=%d S=%d D=%d I=%d' %
-                  (result['all'], result['cor'], result['sub'], result['del'],
-                   result['ins']))
+            print('N=%d C=%d S=%d D=%d I=%d' % (result['all'], result['cor'], result['sub'], result['del'],
+                  result['ins']))
             space = {}
             space['lab'] = []
             space['rec'] = []
             for idx in range(len(result['lab'])) :
                 len_lab = width(result['lab'][idx])
                 len_rec = width(result['rec'][idx])
                 length = max(len_lab, len_rec)
                 space['lab'].append(length - len_lab)
                 space['rec'].append(length - len_rec)
             upper_lab = len(result['lab'])
             upper_rec = len(result['rec'])
             lab1, rec1 = 0, 0
             while lab1 < upper_lab or rec1 < upper_rec:
                 if verbose > 1:
                     print('lab(%s):' % fid.encode('utf-8'), end=' ')
                 else:
                     print('lab:', end=' ')
                 lab2 = min(upper_lab, lab1 + max_words_per_line)
                 for idx in range(lab1, lab2):
                     token = result['lab'][idx]
                     print('{token}'.format(token=token), end='')
                     for n in range(space['lab'][idx]) :
                         print(padding_symbol, end='')
                     print(' ', end='')
                 print()
                 if verbose > 1:
                     print('rec(%s):' % fid.encode('utf-8'), end=' ')
                 else:
                     print('rec:', end=' ')
                 rec2 = min(upper_rec, rec1 + max_words_per_line)
                 for idx in range(rec1, rec2):
                     token = result['rec'][idx]
                     print('{token}'.format(token=token), end='')
                     for n in range(space['rec'][idx]) :
                         print(padding_symbol, end='')
                     print(' ', end='')
                 print('\n', end='\n')
                 lab1 = lab2
                 rec1 = rec2
     if verbose:
-        print(
-            '==========================================================================='
-        )
+        print('===========================================================================')
         print()
     result = calculator.overall()
     if result['all'] != 0:
-        wer = float(result['ins'] + result['sub'] + result[
-            'del']) * 100.0 / result['all']
+        wer = float(result['ins'] + result['sub'] + result['del']) * 100.0 / result['all']
     else:
         wer = 0.0
     print('Overall -> %4.2f %%' % wer, end=' ')
-    print('N=%d C=%d S=%d D=%d I=%d' %
-          (result['all'], result['cor'], result['sub'], result['del'],
-           result['ins']))
+    print('N=%d C=%d S=%d D=%d I=%d' % (result['all'], result['cor'], result['sub'], result['del'],
+          result['ins']))
     if not verbose:
         print()
     if verbose:
         for cluster_id in default_clusters:
-            result = calculator.cluster(
-                [k for k in default_clusters[cluster_id]])
-            if result['all'] != 0:
-                wer = float(result['ins'] + result['sub'] + result[
-                    'del']) * 100.0 / result['all']
-            else:
+            result = calculator.cluster([ k for k in default_clusters[cluster_id] ])
+            if result['all'] != 0:
+                wer = float(result['ins'] + result['sub'] + result['del']) * 100.0 / result['all']
+            else:
                 wer = 0.0
             print('%s -> %4.2f %%' % (cluster_id, wer), end=' ')
-            print('N=%d C=%d S=%d D=%d I=%d' %
-                  (result['all'], result['cor'], result['sub'], result['del'],
-                   result['ins']))
+            print('N=%d C=%d S=%d D=%d I=%d' % (result['all'], result['cor'], result['sub'], result['del'],
+                  result['ins']))
     if len(cluster_file) > 0:  # compute separated WERs for word clusters
         cluster_id = ''
         cluster = []
         for line in open(cluster_file, 'r', encoding='utf-8') :
             for token in line.decode('utf-8').rstrip('\n').split() :
                 # end of cluster reached, like </Keyword>
                 if token[0:2] == '</' and token[len(token) - 1] == '>' and \
                    token.lstrip('</').rstrip('>') == cluster_id:
                     result = calculator.cluster(cluster)
                     if result['all'] != 0:
-                        wer = float(result['ins'] + result['sub'] + result[
-                            'del']) * 100.0 / result['all']
+                        wer = float(result['ins'] + result['sub'] + result['del']) * 100.0 / result['all']
                     else:
                         wer = 0.0
                     print('%s -> %4.2f %%' % (cluster_id, wer), end=' ')
-                    print('N=%d C=%d S=%d D=%d I=%d' %
-                          (result['all'], result['cor'], result['sub'],
-                           result['del'], result['ins']))
+                    print('N=%d C=%d S=%d D=%d I=%d' % (result['all'], result['cor'], result['sub'],
+                          result['del'], result['ins']))
                     cluster_id = ''
                     cluster = []
                 # begin of cluster reached, like <Keyword>
elif
token
[
0
]
==
'<'
and
token
[
len
(
token
)
-
1
]
==
'>'
and
\
elif
token
[
0
]
==
'<'
and
token
[
len
(
token
)
-
1
]
==
'>'
and
\
cluster_id
==
''
:
cluster_id
==
''
:
cluster_id
=
token
.
lstrip
(
'<'
).
rstrip
(
'>'
)
cluster_id
=
token
.
lstrip
(
'<'
).
rstrip
(
'>'
)
cluster
=
[]
cluster
=
[]
# general terms, like WEATHER / CAR / ...
# general terms, like WEATHER / CAR / ...
else
:
else
:
cluster
.
append
(
token
)
cluster
.
append
(
token
)
print
()
print
()
print
(
print
(
'==========================================================================='
)
'==========================================================================='
\ No newline at end of file
)
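The error-rate arithmetic that the script repeats for the overall result and for each cluster can be condensed into one helper. The following is a minimal sketch, not code from this commit: the names `error_rate` and `format_summary` are hypothetical, and the only assumption is a result dict with the keys `all`, `cor`, `sub`, `del` and `ins`, as printed above.

```python
def error_rate(result):
    """Return (S + D + I) / N as a percentage, or 0.0 when N is 0.

    Hypothetical helper; compute-wer.py inlines this arithmetic instead.
    """
    if result['all'] == 0:
        return 0.0
    errors = result['ins'] + result['sub'] + result['del']
    return errors * 100.0 / result['all']


def format_summary(name, result):
    """Build the 'name -> rate % N=.. C=.. S=.. D=.. I=..' summary line."""
    return '%s -> %4.2f %% N=%d C=%d S=%d D=%d I=%d' % (
        name, error_rate(result), result['all'], result['cor'],
        result['sub'], result['del'], result['ins'])


# Made-up counts, for illustration only.
counts = {'all': 10, 'cor': 8, 'sub': 1, 'del': 1, 'ins': 1}
print(format_summary('Overall', counts))
```

Because insertions count as errors but are not part of N, the rate can exceed 100 % when the hypothesis is much longer than the reference.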
utils/format_rsl.py  0 → 100644
View file @ 2b8c08e3
import os
import argparse
import jsonlines


def trans_hyp(origin_hyp, trans_hyp=None, trans_hyp_sclite=None):
    """
    Args:
        origin_hyp: The input json file which contains the model output
        trans_hyp: The output file for calculating CER/WER
        trans_hyp_sclite: The output file for calculating CER/WER using sclite
    """
    input_dict = {}

    # Keep only the first hypothesis of every utterance.
    with open(origin_hyp, "r", encoding="utf8") as f:
        for item in jsonlines.Reader(f):
            input_dict[item["utt"]] = item["hyps"][0]
    if trans_hyp is not None:
        # One "<utt> <text>" line per utterance.
        with open(trans_hyp, "w", encoding="utf8") as f:
            for key in input_dict.keys():
                f.write(key + " " + input_dict[key] + "\n")
    if trans_hyp_sclite is not None:
        # One "<text>(<utt>.wav)" line per utterance, for sclite.
        with open(trans_hyp_sclite, "w", encoding="utf8") as f:
            for key in input_dict.keys():
                line = input_dict[key] + "(" + key + ".wav" + ")" + "\n"
                f.write(line)


def trans_ref(origin_ref, trans_ref=None, trans_ref_sclite=None):
    """
    Args:
        origin_ref: The input json file which contains the reference text
        trans_ref: The output file for calculating CER/WER
        trans_ref_sclite: The output file for calculating CER/WER using sclite
    """
    input_dict = {}

    with open(origin_ref, "r", encoding="utf8") as f:
        for item in jsonlines.Reader(f):
            input_dict[item["utt"]] = item["text"]
    if trans_ref is not None:
        with open(trans_ref, "w", encoding="utf8") as f:
            for key in input_dict.keys():
                f.write(key + " " + input_dict[key] + "\n")

    if trans_ref_sclite is not None:
        with open(trans_ref_sclite, "w", encoding="utf8") as f:
            for key in input_dict.keys():
                line = input_dict[key] + "(" + key + ".wav" + ")" + "\n"
                f.write(line)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        prog='format hyp file for compute CER/WER', add_help=True)
    parser.add_argument(
        '--origin_hyp', type=str, default=None, help='origin hyp file')
    parser.add_argument(
        '--trans_hyp',
        type=str,
        default=None,
        help='hyp file for calculating CER/WER')
    parser.add_argument(
        '--trans_hyp_sclite',
        type=str,
        default=None,
        help='hyp file for calculating CER/WER by sclite')

    parser.add_argument(
        '--origin_ref', type=str, default=None, help='origin ref file')
    parser.add_argument(
        '--trans_ref',
        type=str,
        default=None,
        help='ref file for calculating CER/WER')
    parser.add_argument(
        '--trans_ref_sclite',
        type=str,
        default=None,
        help='ref file for calculating CER/WER by sclite')
    parser_args = parser.parse_args()

    if parser_args.origin_hyp is not None:
        trans_hyp(
            origin_hyp=parser_args.origin_hyp,
            trans_hyp=parser_args.trans_hyp,
            trans_hyp_sclite=parser_args.trans_hyp_sclite, )

    if parser_args.origin_ref is not None:
        trans_ref(
            origin_ref=parser_args.origin_ref,
            trans_ref=parser_args.trans_ref,
            trans_ref_sclite=parser_args.trans_ref_sclite, )
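To make the two output layouts concrete, here is a small in-memory sketch of what the script writes for one hypothesis record. It is an illustration under stated assumptions, not part of the commit: it parses the JSON-lines input with the standard `json` module instead of `jsonlines`, the helper names `parse_hyps`, `to_text_line` and `to_sclite_line` are hypothetical, and the utterance id `utt_0001` is made up.

```python
import json


def parse_hyps(jsonl_text):
    """Map each utterance id to its first hypothesis (JSON-lines input)."""
    hyps = {}
    for raw in jsonl_text.splitlines():
        if not raw.strip():
            continue  # skip blank lines
        item = json.loads(raw)
        hyps[item["utt"]] = item["hyps"][0]
    return hyps


def to_text_line(utt, text):
    """Layout written to --trans_hyp / --trans_ref: '<utt> <text>'."""
    return utt + " " + text + "\n"


def to_sclite_line(utt, text):
    """Layout written to the *_sclite outputs: '<text>(<utt>.wav)'."""
    return text + "(" + utt + ".wav" + ")" + "\n"


# Hypothetical record, for illustration only.
sample = '{"utt": "utt_0001", "hyps": ["hello world"]}\n'
hyps = parse_hyps(sample)
for utt, text in hyps.items():
    print(to_text_line(utt, text), end="")
    print(to_sclite_line(utt, text), end="")
```

The first layout puts the utterance id at the start of each line, while the sclite layout carries it in a trailing parenthesized id.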