PaddlePaddle / DeepSpeech
Commit c7d9b115
Authored Apr 21, 2022 by Hui Zhang

format

Parent: caf72258

Showing 18 changed files with 910 additions and 823 deletions (+910 -823)
.flake8                                                    +2   -0
paddlespeech/cli/asr/infer.py                              +4   -2
paddlespeech/s2t/models/u2/u2.py                           +1   -1
paddlespeech/s2t/modules/ctc.py                            +1   -1
paddlespeech/server/README.md                              +1   -1
paddlespeech/server/README_cn.md                           +1   -1
paddlespeech/server/bin/paddlespeech_client.py             +1   -0
paddlespeech/server/engine/asr/online/ctc_search.py        +2   -0
paddlespeech/server/tests/asr/online/websocket_client.py   +2   -2
paddlespeech/t2s/exps/synthesize.py                        +1   -1
paddlespeech/vector/cluster/diarization.py                 +1   -1
speechx/examples/ngram/zh/local/text_to_lexicon.py         +6   -10
speechx/examples/text_lm/local/mmseg.py                    +325 -313
speechx/examples/wfst/README.md                            +1   -1
utils/DER.py                                               +1   -1
utils/compute-wer.py                                       +509 -455
utils/format_rsl.py                                        +46  -31
utils/fst/prepare_dict.py                                  +5   -2
.flake8

@@ -12,6 +12,8 @@ exclude =
     .git,
     # python cache
     __pycache__,
+    # third party
+    utils/compute-wer.py,
     third_party/,
 # Provide a comma-separate list of glob patterns to include for checks.
 filename =
paddlespeech/cli/asr/infer.py

@@ -40,6 +40,7 @@ from paddlespeech.s2t.utils.utility import UpdateConfig
__all__ = ['ASRExecutor']


@cli_register(
    name='paddlespeech.asr', description='Speech to text infer command.')
class ASRExecutor(BaseExecutor):

@@ -148,7 +149,7 @@ class ASRExecutor(BaseExecutor):
                os.path.dirname(os.path.abspath(self.cfg_path)))
            logger.info(self.cfg_path)
            logger.info(self.ckpt_path)
            #Init body.
            self.config = CfgNode(new_allowed=True)
            self.config.merge_from_file(self.cfg_path)

@@ -278,7 +279,8 @@ class ASRExecutor(BaseExecutor):
            self._outputs["result"] = result_transcripts[0]
        elif "conformer" in model_type or "transformer" in model_type:
            logger.info(
                f"we will use the transformer like model : {model_type}")
            try:
                result_transcripts = self.model.decode(
                    audio,
paddlespeech/s2t/models/u2/u2.py

@@ -279,7 +279,7 @@ class U2BaseModel(ASRInterface, nn.Layer):
            # TODO(Hui Zhang): if end_flag.sum() == running_size:
            if end_flag.cast(paddle.int64).sum() == running_size:
                break

            # 2.1 Forward decoder step
            hyps_mask = subsequent_mask(i).unsqueeze(0).repeat(
                running_size, 1, 1).to(device)  # (B*N, i, i)
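For context on this hunk: in WeNet-style attention decoders, `subsequent_mask(i)` is, in essence, a lower-triangular causal mask so that step t attends only to steps up to t. The sketch below mimics only the shape manipulation above, with a NumPy stand-in for the real paddle implementation:

```python
import numpy as np

def subsequent_mask(size: int) -> np.ndarray:
    # Lower-triangular mask: entry (i, j) is True when position i may attend to j (j <= i).
    return np.tril(np.ones((size, size), dtype=bool))

running_size, i = 6, 4  # B*N parallel hypotheses, current decode step
hyps_mask = np.repeat(subsequent_mask(i)[None, :, :], running_size, axis=0)
print(hyps_mask.shape)  # (6, 4, 4), matching the (B*N, i, i) comment in the diff
```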
paddlespeech/s2t/modules/ctc.py

@@ -180,7 +180,7 @@ class CTCDecoder(CTCDecoderBase):
        # init once
        if self._ext_scorer is not None:
            return

        if language_model_path != '':
            logger.info("begin to initialize the external scorer "
                        "for decoding")
paddlespeech/server/README.md

@@ -47,4 +47,4 @@ paddlespeech_server start --config_file conf/ws_conformer_application.yaml
 ```
 paddlespeech_client asr_online --server_ip 127.0.0.1 --port 8090 --input input_16k.wav
 ```
\ No newline at end of file
paddlespeech/server/README_cn.md

@@ -48,4 +48,4 @@ paddlespeech_server start --config_file conf/ws_conformer_application.yaml
 ```
 paddlespeech_client asr_online --server_ip 127.0.0.1 --port 8090 --input zh.wav
 ```
\ No newline at end of file
paddlespeech/server/bin/paddlespeech_client.py

@@ -305,6 +305,7 @@ class ASRClientExecutor(BaseExecutor):
        return res['asr_results']


@cli_client_register(
    name='paddlespeech_client.cls', description='visit cls service')
class CLSClientExecutor(BaseExecutor):
paddlespeech/server/engine/asr/online/ctc_search.py

@@ -12,7 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from collections import defaultdict
+
 import paddle
+
 from paddlespeech.cli.log import logger
 from paddlespeech.s2t.utils.utility import log_add
paddlespeech/server/tests/asr/online/websocket_client.py

@@ -36,7 +36,7 @@ class ASRAudioHandler:
        x_len = len(samples)
        chunk_size = 85 * 16  #80ms, sample_rate = 16kHz
        if x_len % chunk_size != 0:
            padding_len_x = chunk_size - x_len % chunk_size
        else:
            padding_len_x = 0

@@ -92,7 +92,7 @@ class ASRAudioHandler:
                separators=(',', ': '))
        await ws.send(audio_info)
        msg = await ws.recv()
        # decode the bytes to str
        msg = json.loads(msg)
        logging.info("final receive msg={}".format(msg))
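The padding arithmetic in the first hunk rounds a sample count up to a whole number of 1360-sample chunks. A standalone sketch of just that logic, with NumPy standing in for the real audio buffer:

```python
import numpy as np

samples = np.zeros(50000, dtype=np.int16)  # pretend 16 kHz PCM stream
x_len = len(samples)
chunk_size = 85 * 16                       # as in the diff: 1360 samples per chunk

# Pad with zeros so the stream splits into equal-sized chunks.
padding_len_x = chunk_size - x_len % chunk_size if x_len % chunk_size != 0 else 0
padded = np.concatenate([samples, np.zeros(padding_len_x, dtype=samples.dtype)])
assert len(padded) % chunk_size == 0       # 50320 samples, 37 full chunks
```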
paddlespeech/t2s/exps/synthesize.py

@@ -52,7 +52,7 @@ def evaluate(args):
    # acoustic model
    am_name = args.am[:args.am.rindex('_')]
    am_dataset = args.am[args.am.rindex('_') + 1:]

    am_inference = get_am_inference(
        am=args.am,
        am_config=am_config,
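The two slices in this hunk split an `--am` identifier of the form `<model>_<dataset>` at its last underscore; for example (the value below is illustrative):

```python
am = 'fastspeech2_csmsc'  # hypothetical --am value of the form <model>_<dataset>
am_name = am[:am.rindex('_')]         # 'fastspeech2'
am_dataset = am[am.rindex('_') + 1:]  # 'csmsc'
```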
paddlespeech/vector/cluster/diarization.py

@@ -20,11 +20,11 @@ A few sklearn functions are modified in this script as per requirement.
 import argparse
 import copy
 import warnings
-from distutils.util import strtobool

 import numpy as np
 import scipy
 import sklearn
+from distutils.util import strtobool
 from scipy import linalg
 from scipy import sparse
 from scipy.sparse.csgraph import connected_components
speechx/examples/ngram/zh/local/text_to_lexicon.py

@@ -2,6 +2,7 @@
import argparse
from collections import Counter


def main(args):
    counter = Counter()
    with open(args.text, 'r') as fin, open(args.lexicon, 'w') as fout:

@@ -12,7 +13,7 @@ def main(args):
                words = text.split()
            else:
                words = line.split()
            counter.update(words)

        for word in counter:

@@ -20,21 +21,16 @@ def main(args):
            fout.write(f"{word}\t{val}\n")
        fout.flush()


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='text(line:utt1 中国 人) to lexicon(line:中国 中 国).')
    parser.add_argument(
        '--has_key', default=True, help='text path, with utt or not')
    parser.add_argument(
        '--text', required=True, help='text path. line: utt1 中国 人 or 中国 人')
    parser.add_argument(
        '--lexicon', required=True, help='lexicon path. line:中国 中 国')
    args = parser.parse_args()
    print(args)
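Per the script's own help strings, each word in the input text becomes one lexicon line: the word, a tab, then its characters separated by spaces. A one-line sketch of that mapping:

```python
def word_to_lexicon_entry(word: str) -> str:
    # '中国' -> '中国\t中 国', matching the script's description line.
    return f"{word}\t{' '.join(word)}"

assert word_to_lexicon_entry('中国') == '中国\t中 国'
```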
speechx/examples/text_lm/local/mmseg.py

#!/usr/bin/env python3
# modify from https://sites.google.com/site/homepageoffuyanwei/Home/remarksandexcellentdiscussion/page-2


class Word:
    def __init__(self, text='', freq=0):
        self.text = text
        self.freq = freq
        self.length = len(text)


class Chunk:
    def __init__(self, w1, w2=None, w3=None):
        self.words = []
        self.words.append(w1)
        if w2:
            self.words.append(w2)
        if w3:
            self.words.append(w3)

    # total length of the chunk
    def totalWordLength(self):
        length = 0
        for word in self.words:
            length += len(word.text)
        return length

    # average word length
    def averageWordLength(self):
        return float(self.totalWordLength()) / float(len(self.words))

    # standard deviation of word lengths
    def standardDeviation(self):
        average = self.averageWordLength()
        sum = 0.0
        for word in self.words:
            tmp = (len(word.text) - average)
            sum += float(tmp) * float(tmp)
        return sum

    # morphemic freedom (sum of word frequencies)
    def wordFrequency(self):
        sum = 0
        for word in self.words:
            sum += word.freq
        return sum


class ComplexCompare:
    def takeHightest(self, chunks, comparator):
        i = 1
        for j in range(1, len(chunks)):
            rlt = comparator(chunks[j], chunks[0])
            if rlt > 0:
                i = 0
            if rlt >= 0:
                chunks[i], chunks[j] = chunks[j], chunks[i]
                i += 1
        return chunks[0:i]

    # the four filter rules of the mmseg algorithm, its core
    def mmFilter(self, chunks):
        def comparator(a, b):
            return a.totalWordLength() - b.totalWordLength()

        return self.takeHightest(chunks, comparator)

    def lawlFilter(self, chunks):
        def comparator(a, b):
            return a.averageWordLength() - b.averageWordLength()

        return self.takeHightest(chunks, comparator)

    def svmlFilter(self, chunks):
        def comparator(a, b):
            return b.standardDeviation() - a.standardDeviation()

        return self.takeHightest(chunks, comparator)

    def logFreqFilter(self, chunks):
        def comparator(a, b):
            return a.wordFrequency() - b.wordFrequency()

        return self.takeHightest(chunks, comparator)


# word and character dictionaries
dictWord = {}
maxWordLength = 0


def loadDictChars(filepath):
    global maxWordLength
    fsock = open(filepath)
    for line in fsock:
        freq, word = line.split()
        word = word.strip()
        dictWord[word] = (len(word), int(freq))
        maxWordLength = len(word) if maxWordLength < len(
            word) else maxWordLength
    fsock.close()


def loadDictWords(filepath):
    global maxWordLength
    fsock = open(filepath)
    for line in fsock.readlines():
        word = line.strip()
        dictWord[word] = (len(word), 0)
        maxWordLength = len(word) if maxWordLength < len(
            word) else maxWordLength
    fsock.close()


# check whether word is in the dictionary dictWord
def getDictWord(word):
    result = dictWord.get(word)
    if result:
        return Word(word, result[1])
    return None


# load the dictionaries
def run():
    from os.path import join, dirname
    loadDictChars(join(dirname(__file__), 'data', 'chars.dic'))
    loadDictWords(join(dirname(__file__), 'data', 'words.dic'))


class Analysis:
    def __init__(self, text):
        self.text = text
        self.cacheSize = 3
        self.pos = 0
        self.textLength = len(self.text)
        self.cache = []
        self.cacheIndex = 0
        self.complexCompare = ComplexCompare()

        # small caching trick, usefulness unclear
        for i in range(self.cacheSize):
            self.cache.append([-1, Word()])

        # make sure the dictionary is loaded only once
        if not dictWord:
            run()

    def __iter__(self):
        while True:
            token = self.getNextToken()
            if token is None:
                raise StopIteration
            yield token

    def getNextChar(self):
        return self.text[self.pos]

    # is this a Chinese character (excluding Chinese punctuation)?
    def isChineseChar(self, charater):
        return 0x4e00 <= ord(charater) < 0x9fa6

    # is this an ASCII character?
    def isASCIIChar(self, ch):
        import string
        if ch in string.whitespace:
            return False
        if ch in string.punctuation:
            return False
        return ch in string.printable

    # get the next segmentation token
    def getNextToken(self):
        while self.pos < self.textLength:
            if self.isChineseChar(self.getNextChar()):
                token = self.getChineseWords()
            else:
                token = self.getASCIIWords() + '/'
            if len(token) > 0:
                return token
        return None

    # cut out a non-Chinese token
    def getASCIIWords(self):
        # Skip pre-word whitespaces and punctuations
        while self.pos < self.textLength:
            ch = self.getNextChar()
            if self.isASCIIChar(ch) or self.isChineseChar(ch):
                break
            self.pos += 1
        # start of the English word
        start = self.pos
        # find the end of the English word
        while self.pos < self.textLength:
            ch = self.getNextChar()
            if not self.isASCIIChar(ch):
                break
            self.pos += 1
        end = self.pos
        # Skip chinese word whitespaces and punctuations
        while self.pos < self.textLength:
            ch = self.getNextChar()
            if self.isASCIIChar(ch) or self.isChineseChar(ch):
                break
            self.pos += 1
        # return the English word
        return self.text[start:end]

    # cut a Chinese word, applying the four rules above
    def getChineseWords(self):
        chunks = self.createChunks()
        if len(chunks) > 1:
            chunks = self.complexCompare.mmFilter(chunks)
        if len(chunks) > 1:
            chunks = self.complexCompare.lawlFilter(chunks)
        if len(chunks) > 1:
            chunks = self.complexCompare.svmlFilter(chunks)
        if len(chunks) > 1:
            chunks = self.complexCompare.logFreqFilter(chunks)
        if len(chunks) == 0:
            return ''
        # only one segmentation is left in the end
        word = chunks[0].words
        token = ""
        length = 0
        for x in word:
            if x.length != -1:
                token += x.text + "/"
                length += len(x.text)
        self.pos += length
        return token

    # triple loop enumerating segmentations; recursion would also work
    def createChunks(self):
        chunks = []
        originalPos = self.pos
        words1 = self.getMatchChineseWords()

        for word1 in words1:
            self.pos += len(word1.text)
            if self.pos < self.textLength:
                words2 = self.getMatchChineseWords()
                for word2 in words2:
                    self.pos += len(word2.text)
                    if self.pos < self.textLength:
                        words3 = self.getMatchChineseWords()
                        for word3 in words3:
                            # print(word3.length, word3.text)
                            if word3.length == -1:
                                chunk = Chunk(word1, word2)
                                # print("Ture")
                            else:
                                chunk = Chunk(word1, word2, word3)
                            chunks.append(chunk)
                    elif self.pos == self.textLength:
                        chunks.append(Chunk(word1, word2))
                    self.pos -= len(word2.text)
            elif self.pos == self.textLength:
                chunks.append(Chunk(word1))
            self.pos -= len(word1.text)
        self.pos = originalPos
        return chunks

    # forward maximum matching against the dictionary to cut Chinese text
    def getMatchChineseWords(self):
        # use cache, check it
        for i in range(self.cacheSize):
            if self.cache[i][0] == self.pos:
                return self.cache[i][1]

        originalPos = self.pos
        words = []
        index = 0
        while self.pos < self.textLength:
            if index >= maxWordLength:
                break
            if not self.isChineseChar(self.getNextChar()):
                break
            self.pos += 1
            index += 1

            text = self.text[originalPos:self.pos]
            word = getDictWord(text)
            if word:
                words.append(word)

        self.pos = originalPos
        # no dictionary word: insert an 'X' and mark its length as -1
        if not words:
            word = Word()
            word.length = -1
            word.text = 'X'
            words.append(word)

        self.cache[self.cacheIndex] = (self.pos, words)
        self.cacheIndex += 1
        if self.cacheIndex >= self.cacheSize:
            self.cacheIndex = 0
        return words


if __name__ == "__main__":

    def cuttest(text):
        #cut = Analysis(text)
        tmp = ""
        try:
            for word in iter(Analysis(text)):
                tmp += word
...
@@ -310,71 +320,73 @@ if __name__=="__main__":
    print("================================")
    cuttest(u"研究生命来源")
    cuttest(u"南京市长江大桥欢迎您")
    cuttest(u"请把手抬高一点儿")
    cuttest(u"长春市长春节致词。")
    cuttest(u"长春市长春药店。")
    cuttest(u"我的和服务必在明天做好。")
    cuttest(u"我发现有很多人喜欢他。")
    cuttest(u"我喜欢看电视剧大长今。")
    cuttest(u"半夜给拎起来陪看欧洲杯糊着两眼半晌没搞明白谁和谁踢。")
    cuttest(u"李智伟高高兴兴以及王晓薇出去玩,后来智伟和晓薇又单独去玩了。")
    cuttest(u"一次性交出去很多钱。 ")
    cuttest(u"这是一个伸手不见五指的黑夜。我叫孙悟空,我爱北京,我爱Python和C++。")
    cuttest(u"我不喜欢日本和服。")
    cuttest(u"雷猴回归人间。")
    cuttest(u"工信处女干事每月经过下属科室都要亲口交代24口交换机等技术性器件的安装工作")
    cuttest(u"我需要廉租房")
    cuttest(u"永和服装饰品有限公司")
    cuttest(u"我爱北京天安门")
    cuttest(u"abc")
    cuttest(u"隐马尔可夫")
    cuttest(u"雷猴是个好网站")
    cuttest(u"“Microsoft”一词由“MICROcomputer(微型计算机)”和“SOFTware(软件)”两部分组成")
    cuttest(u"草泥马和欺实马是今年的流行词汇")
    cuttest(u"伊藤洋华堂总府店")
    cuttest(u"中国科学院计算技术研究所")
    cuttest(u"罗密欧与朱丽叶")
    cuttest(u"我购买了道具和服装")
    cuttest(u"PS: 我觉得开源有一个好处,就是能够敦促自己不断改进,避免敞帚自珍")
    cuttest(u"湖北省石首市")
    cuttest(u"总经理完成了这件事情")
    cuttest(u"电脑修好了")
    cuttest(u"做好了这件事情就一了百了了")
    cuttest(u"人们审美的观点是不同的")
    cuttest(u"我们买了一个美的空调")
    cuttest(u"线程初始化时我们要注意")
    cuttest(u"一个分子是由好多原子组织成的")
    cuttest(u"祝你马到功成")
    cuttest(u"他掉进了无底洞里")
    cuttest(u"中国的首都是北京")
    cuttest(u"孙君意")
    cuttest(u"外交部发言人马朝旭")
    cuttest(u"领导人会议和第四届东亚峰会")
    cuttest(u"在过去的这五年")
    cuttest(u"还需要很长的路要走")
    cuttest(u"60周年首都阅兵")
    cuttest(u"你好人们审美的观点是不同的")
    cuttest(u"买水果然后来世博园")
    cuttest(u"买水果然后去世博园")
    cuttest(u"但是后来我才知道你是对的")
    cuttest(u"存在即合理")
    cuttest(u"的的的的的在的的的的就以和和和")
    cuttest(u"I love你,不以为耻,反以为rong")
    cuttest(u" ")
    cuttest(u"")
    cuttest(u"hello你好人们审美的观点是不同的")
    cuttest(u"很好但主要是基于网页形式")
    cuttest(u"hello你好人们审美的观点是不同的")
    cuttest(u"为什么我不能拥有想要的生活")
    cuttest(u"后来我才")
    cuttest(u"此次来中国是为了")
    cuttest(u"使用了它就可以解决一些问题")
    cuttest(u",使用了它就可以解决一些问题")
    cuttest(u"其实使用了它就可以解决一些问题")
    cuttest(u"好人使用了它就可以解决一些问题")
    cuttest(u"是因为和国家")
    cuttest(u"老年搜索还支持")
    cuttest(
        u"干脆就把那部蒙人的闲法给废了拉倒!RT @laoshipukong : 27日,全国人大常委会第三次审议侵权责任法草案,删除了有关医疗损害责任“举证倒置”的规定。在医患纠纷中本已处于弱势地位的消费者由此将陷入万劫不复的境地。 "
    )
    cuttest("2022年12月30日是星期几?")
    cuttest("二零二二年十二月三十日是星期几?")
\ No newline at end of file
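A minimal driver for the segmenter, mirroring `cuttest` above; it assumes `data/chars.dic` and `data/words.dic` exist next to the script, as `run()` expects:

```python
text = u"研究生命来源"
tmp = ""
try:
    for word in iter(Analysis(text)):
        tmp += word  # tokens come back '/'-terminated, e.g. '研究/生命/来源/'
except RuntimeError:
    # `raise StopIteration` inside a generator surfaces as RuntimeError on
    # Python 3.7+ (PEP 479), so iteration ends here.
    pass
print(tmp)
```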
speechx/examples/wfst/README.md

@@ -183,4 +183,4 @@ data/
 ├── lexiconp_disambig.txt
 ├── lexiconp.txt
 └── units.list
 ```
\ No newline at end of file
utils/DER.py

@@ -26,9 +26,9 @@ import argparse
 import os
 import re
 import subprocess
-from distutils.util import strtobool

 import numpy as np
+from distutils.util import strtobool

 FILE_IDS = re.compile(r"(?<=Speaker Diarization for).+(?=\*\*\*)")
 SCORED_SPEAKER_TIME = re.compile(r"(?<=SCORED SPEAKER TIME =)[\d.]+")
utils/compute-wer.py

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# CopyRight WeNet Apache-2.0 License
import codecs
import re
import sys
import unicodedata

remove_tag = True
spacelist = [' ', '\t', '\r', '\n']
puncts = [
    '!', ',', '?', '、', '。', '！', '，', '；', '？', '：', '「', '」', '︰', '『', '』',
    '《', '》'
]


def characterize(string):
    res = []
    i = 0
    while i < len(string):
        char = string[i]
        if char in puncts:
            i += 1
            continue
        cat1 = unicodedata.category(char)
        #https://unicodebook.readthedocs.io/unicode.html#unicode-categories
        if cat1 == 'Zs' or cat1 == 'Cn' or char in spacelist:
            # space or not assigned
            i += 1
            continue
        if cat1 == 'Lo':  # letter-other
            res.append(char)
            i += 1
        else:
            # some input looks like: <unk><noise>, we want to separate it to two words.
            sep = ' '
            if char == '<':
                sep = '>'
            j = i + 1
            while j < len(string):
                c = string[j]
                if ord(c) >= 128 or (c in spacelist) or (c == sep):
                    break
                j += 1
            if j < len(string) and string[j] == '>':
                j += 1
            res.append(string[i:j])
            i = j
    return res


def stripoff_tags(x):
    if not x:
        return ''
    chars = []
    i = 0
    T = len(x)
    while i < T:
        if x[i] == '<':
            while i < T and x[i] != '>':
                i += 1
            i += 1
        else:
            chars.append(x[i])
            i += 1
    return ''.join(chars)


def normalize(sentence, ignore_words, cs, split=None):
...
@@ -65,436 +70,485 @@ def normalize(sentence, ignore_words, cs, split=None):
    for token in sentence:
        x = token
        if not cs:
            x = x.upper()
        if x in ignore_words:
            continue
        if remove_tag:
            x = stripoff_tags(x)
        if not x:
            continue
        if split and x in split:
            new_sentence += split[x]
        else:
            new_sentence.append(x)
    return new_sentence


class Calculator:
    def __init__(self):
        self.data = {}
        self.space = []
        self.cost = {}
        self.cost['cor'] = 0
        self.cost['sub'] = 1
        self.cost['del'] = 1
        self.cost['ins'] = 1

    def calculate(self, lab, rec):
        # Initialization
        lab.insert(0, '')
        rec.insert(0, '')
        while len(self.space) < len(lab):
            self.space.append([])
        for row in self.space:
            for element in row:
                element['dist'] = 0
                element['error'] = 'non'
            while len(row) < len(rec):
                row.append({'dist': 0, 'error': 'non'})
        for i in range(len(lab)):
            self.space[i][0]['dist'] = i
            self.space[i][0]['error'] = 'del'
        for j in range(len(rec)):
            self.space[0][j]['dist'] = j
            self.space[0][j]['error'] = 'ins'
        self.space[0][0]['error'] = 'non'
        for token in lab:
            if token not in self.data and len(token) > 0:
                self.data[token] = {
                    'all': 0,
                    'cor': 0,
                    'sub': 0,
                    'ins': 0,
                    'del': 0
                }
        for token in rec:
            if token not in self.data and len(token) > 0:
                self.data[token] = {
                    'all': 0,
                    'cor': 0,
                    'sub': 0,
                    'ins': 0,
                    'del': 0
                }
        # Computing edit distance
        for i, lab_token in enumerate(lab):
            for j, rec_token in enumerate(rec):
                if i == 0 or j == 0:
                    continue
                min_dist = sys.maxsize
                min_error = 'none'
                dist = self.space[i - 1][j]['dist'] + self.cost['del']
                error = 'del'
                if dist < min_dist:
                    min_dist = dist
                    min_error = error
                dist = self.space[i][j - 1]['dist'] + self.cost['ins']
                error = 'ins'
                if dist < min_dist:
                    min_dist = dist
                    min_error = error
                if lab_token == rec_token:
                    dist = self.space[i - 1][j - 1]['dist'] + self.cost['cor']
                    error = 'cor'
                else:
                    dist = self.space[i - 1][j - 1]['dist'] + self.cost['sub']
                    error = 'sub'
                if dist < min_dist:
                    min_dist = dist
                    min_error = error
                self.space[i][j]['dist'] = min_dist
                self.space[i][j]['error'] = min_error
        # Tracing back
        result = {
            'lab': [],
            'rec': [],
            'all': 0,
            'cor': 0,
            'sub': 0,
            'ins': 0,
            'del': 0
        }
        i = len(lab) - 1
        j = len(rec) - 1
        while True:
            if self.space[i][j]['error'] == 'cor':  # correct
                if len(lab[i]) > 0:
                    self.data[lab[i]]['all'] = self.data[lab[i]]['all'] + 1
                    self.data[lab[i]]['cor'] = self.data[lab[i]]['cor'] + 1
                    result['all'] = result['all'] + 1
                    result['cor'] = result['cor'] + 1
                result['lab'].insert(0, lab[i])
                result['rec'].insert(0, rec[j])
                i = i - 1
                j = j - 1
            elif self.space[i][j]['error'] == 'sub':  # substitution
                if len(lab[i]) > 0:
                    self.data[lab[i]]['all'] = self.data[lab[i]]['all'] + 1
                    self.data[lab[i]]['sub'] = self.data[lab[i]]['sub'] + 1
                    result['all'] = result['all'] + 1
                    result['sub'] = result['sub'] + 1
                result['lab'].insert(0, lab[i])
                result['rec'].insert(0, rec[j])
                i = i - 1
                j = j - 1
            elif self.space[i][j]['error'] == 'del':  # deletion
                if len(lab[i]) > 0:
                    self.data[lab[i]]['all'] = self.data[lab[i]]['all'] + 1
                    self.data[lab[i]]['del'] = self.data[lab[i]]['del'] + 1
                    result['all'] = result['all'] + 1
                    result['del'] = result['del'] + 1
                result['lab'].insert(0, lab[i])
                result['rec'].insert(0, "")
                i = i - 1
            elif self.space[i][j]['error'] == 'ins':  # insertion
                if len(rec[j]) > 0:
                    self.data[rec[j]]['ins'] = self.data[rec[j]]['ins'] + 1
                    result['ins'] = result['ins'] + 1
                result['lab'].insert(0, "")
                result['rec'].insert(0, rec[j])
                j = j - 1
            elif self.space[i][j]['error'] == 'non':  # starting point
                break
            else:  # shouldn't reach here
                print(
                    'this should not happen , i = {i} , j = {j} , error = {error}'.
                    format(i=i, j=j, error=self.space[i][j]['error']))
        return result

    def overall(self):
        result = {'all': 0, 'cor': 0, 'sub': 0, 'ins': 0, 'del': 0}
        for token in self.data:
            result['all'] = result['all'] + self.data[token]['all']
            result['cor'] = result['cor'] + self.data[token]['cor']
            result['sub'] = result['sub'] + self.data[token]['sub']
            result['ins'] = result['ins'] + self.data[token]['ins']
            result['del'] = result['del'] + self.data[token]['del']
        return result

    def cluster(self, data):
        result = {'all': 0, 'cor': 0, 'sub': 0, 'ins': 0, 'del': 0}
        for token in data:
            if token in self.data:
                result['all'] = result['all'] + self.data[token]['all']
                result['cor'] = result['cor'] + self.data[token]['cor']
                result['sub'] = result['sub'] + self.data[token]['sub']
                result['ins'] = result['ins'] + self.data[token]['ins']
                result['del'] = result['del'] + self.data[token]['del']
        return result

    def keys(self):
        return list(self.data.keys())


def width(string):
    return sum(1 + (unicodedata.east_asian_width(c) in "AFW") for c in string)


def default_cluster(word):
    unicode_names = [unicodedata.name(char) for char in word]
    for i in reversed(range(len(unicode_names))):
        if unicode_names[i].startswith('DIGIT'):  # 1
            unicode_names[i] = 'Number'  # 'DIGIT'
        elif (unicode_names[i].startswith('CJK UNIFIED IDEOGRAPH') or
              unicode_names[i].startswith('CJK COMPATIBILITY IDEOGRAPH')):
            # 明 / 郎
            unicode_names[i] = 'Mandarin'  # 'CJK IDEOGRAPH'
        elif (unicode_names[i].startswith('LATIN CAPITAL LETTER') or
              unicode_names[i].startswith('LATIN SMALL LETTER')):
            # A / a
            unicode_names[i] = 'English'  # 'LATIN LETTER'
        elif unicode_names[i].startswith('HIRAGANA LETTER'):  # は こ め
            unicode_names[i] = 'Japanese'  # 'GANA LETTER'
        elif (unicode_names[i].startswith('AMPERSAND') or
              unicode_names[i].startswith('APOSTROPHE') or
              unicode_names[i].startswith('COMMERCIAL AT') or
              unicode_names[i].startswith('DEGREE CELSIUS') or
              unicode_names[i].startswith('EQUALS SIGN') or
              unicode_names[i].startswith('FULL STOP') or
              unicode_names[i].startswith('HYPHEN-MINUS') or
              unicode_names[i].startswith('LOW LINE') or
              unicode_names[i].startswith('NUMBER SIGN') or
              unicode_names[i].startswith('PLUS SIGN') or
              unicode_names[i].startswith('SEMICOLON')):
            # & / ' / @ / ℃ / = / . / - / _ / # / + / ;
            del unicode_names[i]
        else:
            return 'Other'
    if len(unicode_names) == 0:
        return 'Other'
    if len(unicode_names) == 1:
        return unicode_names[0]
    for i in range(len(unicode_names) - 1):
        if unicode_names[i] != unicode_names[i + 1]:
            return 'Other'
    return unicode_names[0]


def usage():
    print(
        "compute-wer.py : compute word error rate (WER) and align recognition results and references."
    )
    print(
        " usage : python compute-wer.py [--cs={0,1}] [--cluster=foo] [--ig=ignore_file] [--char={0,1}] [--v={0,1}] [--padding-symbol={space,underline}] test.ref test.hyp > test.wer"
    )


if __name__ == '__main__':
    if len(sys.argv) == 1:
        usage()
        sys.exit(0)
    calculator = Calculator()
    cluster_file = ''
    ignore_words = set()
    tochar = False
    verbose = 1
    padding_symbol = ' '
    case_sensitive = False
    max_words_per_line = sys.maxsize
    split = None
    while len(sys.argv) > 3:
        a = '--maxw='
        if sys.argv[1].startswith(a):
            b = sys.argv[1][len(a):]
            del sys.argv[1]
            max_words_per_line = int(b)
            continue
        a = '--rt='
        if sys.argv[1].startswith(a):
            b = sys.argv[1][len(a):].lower()
            del sys.argv[1]
            remove_tag = (b == 'true') or (b != '0')
            continue
        a = '--cs='
        if sys.argv[1].startswith(a):
            b = sys.argv[1][len(a):].lower()
            del sys.argv[1]
            case_sensitive = (b == 'true') or (b != '0')
            continue
        a = '--cluster='
        if sys.argv[1].startswith(a):
            cluster_file = sys.argv[1][len(a):]
            del sys.argv[1]
            continue
        a = '--splitfile='
        if sys.argv[1].startswith(a):
            split_file = sys.argv[1][len(a):]
            del sys.argv[1]
            split = dict()
            with codecs.open(split_file, 'r', 'utf-8') as fh:
                for line in fh:  # line in unicode
                    words = line.strip().split()
                    if len(words) >= 2:
                        split[words[0]] = words[1:]
            continue
        a = '--ig='
        if sys.argv[1].startswith(a):
            ignore_file = sys.argv[1][len(a):]
            del sys.argv[1]
            with codecs.open(ignore_file, 'r', 'utf-8') as fh:
                for line in fh:  # line in unicode
                    line = line.strip()
                    if len(line) > 0:
                        ignore_words.add(line)
            continue
        a = '--char='
        if sys.argv[1].startswith(a):
            b = sys.argv[1][len(a):].lower()
            del sys.argv[1]
            tochar = (b == 'true') or (b != '0')
            continue
        a = '--v='
        if sys.argv[1].startswith(a):
            b = sys.argv[1][len(a):].lower()
            del sys.argv[1]
            verbose = 0
            try:
                verbose = int(b)
            except:
                if b == 'true' or b != '0':
                    verbose = 1
            continue
        a = '--padding-symbol='
        if sys.argv[1].startswith(a):
            b = sys.argv[1][len(a):].lower()
            del sys.argv[1]
            if b == 'space':
                padding_symbol = ' '
            elif b == 'underline':
                padding_symbol = '_'
            continue
        if True or sys.argv[1].startswith('-'):
            #ignore invalid switch
            del sys.argv[1]
            continue

    if not case_sensitive:
        ig = set([w.upper() for w in ignore_words])
        ignore_words = ig

    default_clusters = {}
    default_words = {}

    ref_file = sys.argv[1]
    hyp_file = sys.argv[2]
    rec_set = {}
    if split and not case_sensitive:
        newsplit = dict()
        for w in split:
            words = split[w]
            for i in range(len(words)):
                words[i] = words[i].upper()
            newsplit[w.upper()] = words
        split = newsplit

    with codecs.open(hyp_file, 'r', 'utf-8') as fh:
        for line in fh:
            if tochar:
                array = characterize(line)
            else:
                array = line.strip().split()
            if len(array) == 0:
                continue
            fid = array[0]
            rec_set[fid] = normalize(array[1:], ignore_words, case_sensitive,
                                     split)

    # compute error rate on the interaction of reference file and hyp file
    for line in open(ref_file, 'r', encoding='utf-8'):
        if tochar:
            array = characterize(line)
        else:
            array = line.rstrip('\n').split()
        if len(array) == 0:
            continue
        fid = array[0]
        if fid not in rec_set:
            continue
        lab = normalize(array[1:], ignore_words, case_sensitive, split)
        rec = rec_set[fid]
        if verbose:
            print('\nutt: %s' % fid)

        for word in rec + lab:
            if word not in default_words:
                default_cluster_name = default_cluster(word)
                if default_cluster_name not in default_clusters:
                    default_clusters[default_cluster_name] = {}
                if word not in default_clusters[default_cluster_name]:
                    default_clusters[default_cluster_name][word] = 1
                default_words[word] = default_cluster_name

        result = calculator.calculate(lab, rec)
        if verbose:
            if result['all'] != 0:
                wer = float(result['ins'] + result['sub'] + result[
                    'del']) * 100.0 / result['all']
            else:
                wer = 0.0
            print('WER: %4.2f %%' % wer, end=' ')
            print('N=%d C=%d S=%d D=%d I=%d' %
                  (result['all'], result['cor'], result['sub'], result['del'],
                   result['ins']))
            space = {}
            space['lab'] = []
            space['rec'] = []
            for idx in range(len(result['lab'])):
                len_lab = width(result['lab'][idx])
                len_rec = width(result['rec'][idx])
                length = max(len_lab, len_rec)
                space['lab'].append(length - len_lab)
                space['rec'].append(length - len_rec)
            upper_lab = len(result['lab'])
            upper_rec = len(result['rec'])
            lab1, rec1 = 0, 0
            while lab1 < upper_lab or rec1 < upper_rec:
                if verbose > 1:
                    print('lab(%s):' % fid.encode('utf-8'), end=' ')
                else:
                    print('lab:', end=' ')
                lab2 = min(upper_lab, lab1 + max_words_per_line)
                for idx in range(lab1, lab2):
                    token = result['lab'][idx]
                    print('{token}'.format(token=token), end='')
                    for n in range(space['lab'][idx]):
                        print(padding_symbol, end='')
                    print(' ', end='')
                print()
                if verbose > 1:
                    print('rec(%s):' % fid.encode('utf-8'), end=' ')
                else:
                    print('rec:', end=' ')
                rec2 = min(upper_rec, rec1 + max_words_per_line)
                for idx in range(rec1, rec2):
                    token = result['rec'][idx]
                    print('{token}'.format(token=token), end='')
                    for n in range(space['rec'][idx]):
                        print(padding_symbol, end='')
                    print(' ', end='')
                print('\n', end='\n')
                lab1 = lab2
                rec1 = rec2

    if verbose:
        print(
            '==========================================================================='
        )
        print()

    result = calculator.overall()
    if result['all'] != 0:
        wer = float(result['ins'] + result['sub'] + result[
            'del']) * 100.0 / result['all']
    else:
        wer = 0.0
    print('Overall -> %4.2f %%' % wer, end=' ')
    print('N=%d C=%d S=%d D=%d I=%d' %
          (result['all'], result['cor'], result['sub'], result['del'],
           result['ins']))
    if not verbose:
        print()

    if verbose:
        for cluster_id in default_clusters:
            result = calculator.cluster(
                [k for k in default_clusters[cluster_id]])
            if result['all'] != 0:
                wer = float(result['ins'] + result['sub'] + result[
                    'del']) * 100.0 / result['all']
            else:
                wer = 0.0
            print('%s -> %4.2f %%' % (cluster_id, wer), end=' ')
            print('N=%d C=%d S=%d D=%d I=%d' %
                  (result['all'], result['cor'], result['sub'], result['del'],
                   result['ins']))
(
'Overall -> %4.2f %%'
%
wer
,
end
=
' '
)
print
(
'N=%d C=%d S=%d D=%d I=%d'
%
print
(
'N=%d C=%d S=%d D=%d I=%d'
%
(
result
[
'all'
],
result
[
'cor'
],
result
[
'sub'
],
result
[
'del'
],
result
[
'ins'
]))
(
result
[
'all'
],
result
[
'cor'
],
result
[
'sub'
],
result
[
'del'
],
if
len
(
cluster_file
)
>
0
:
# compute separated WERs for word clusters
result
[
'ins'
]))
cluster_id
=
''
if
not
verbose
:
cluster
=
[]
print
()
for
line
in
open
(
cluster_file
,
'r'
,
encoding
=
'utf-8'
)
:
for
token
in
line
.
decode
(
'utf-8'
).
rstrip
(
'
\n
'
).
split
()
:
if
verbose
:
# end of cluster reached, like </Keyword>
for
cluster_id
in
default_clusters
:
if
token
[
0
:
2
]
==
'</'
and
token
[
len
(
token
)
-
1
]
==
'>'
and
\
result
=
calculator
.
cluster
(
token
.
lstrip
(
'</'
).
rstrip
(
'>'
)
==
cluster_id
:
[
k
for
k
in
default_clusters
[
cluster_id
]])
result
=
calculator
.
cluster
(
cluster
)
if
result
[
'all'
]
!=
0
:
if
result
[
'all'
]
!=
0
:
wer
=
float
(
result
[
'ins'
]
+
result
[
'sub'
]
+
result
[
wer
=
float
(
result
[
'ins'
]
+
result
[
'sub'
]
+
result
[
'del'
])
*
100.0
/
result
[
'all'
]
'del'
])
*
100.0
/
result
[
'all'
]
else
:
else
:
wer
=
0.0
wer
=
0.0
print
(
'%s -> %4.2f %%'
%
(
cluster_id
,
wer
),
end
=
' '
)
print
(
'%s -> %4.2f %%'
%
(
cluster_id
,
wer
),
end
=
' '
)
print
(
'N=%d C=%d S=%d D=%d I=%d'
%
print
(
'N=%d C=%d S=%d D=%d I=%d'
%
(
result
[
'all'
],
result
[
'cor'
],
result
[
'sub'
],
result
[
'del'
],
result
[
'ins'
]))
(
result
[
'all'
],
result
[
'cor'
],
result
[
'sub'
],
result
[
'del'
],
cluster_id
=
''
result
[
'ins'
]))
cluster
=
[]
if
len
(
cluster_file
)
>
0
:
# compute separated WERs for word clusters
# begin of cluster reached, like <Keyword>
cluster_id
=
''
elif
token
[
0
]
==
'<'
and
token
[
len
(
token
)
-
1
]
==
'>'
and
\
cluster
=
[]
cluster_id
==
''
:
for
line
in
open
(
cluster_file
,
'r'
,
encoding
=
'utf-8'
):
cluster_id
=
token
.
lstrip
(
'<'
).
rstrip
(
'>'
)
for
token
in
line
.
decode
(
'utf-8'
).
rstrip
(
'
\n
'
).
split
():
cluster
=
[]
# end of cluster reached, like </Keyword>
# general terms, like WEATHER / CAR / ...
if
token
[
0
:
2
]
==
'</'
and
token
[
len
(
token
)
-
1
]
==
'>'
and
\
else
:
token
.
lstrip
(
'</'
).
rstrip
(
'>'
)
==
cluster_id
:
cluster
.
append
(
token
)
result
=
calculator
.
cluster
(
cluster
)
print
()
if
result
[
'all'
]
!=
0
:
print
(
'==========================================================================='
)
wer
=
float
(
result
[
'ins'
]
+
result
[
'sub'
]
+
result
[
\ No newline at end of file
'del'
])
*
100.0
/
result
[
'all'
]
else
:
wer
=
0.0
print
(
'%s -> %4.2f %%'
%
(
cluster_id
,
wer
),
end
=
' '
)
print
(
'N=%d C=%d S=%d D=%d I=%d'
%
(
result
[
'all'
],
result
[
'cor'
],
result
[
'sub'
],
result
[
'del'
],
result
[
'ins'
]))
cluster_id
=
''
cluster
=
[]
# begin of cluster reached, like <Keyword>
elif
token
[
0
]
==
'<'
and
token
[
len
(
token
)
-
1
]
==
'>'
and
\
cluster_id
==
''
:
cluster_id
=
token
.
lstrip
(
'<'
).
rstrip
(
'>'
)
cluster
=
[]
# general terms, like WEATHER / CAR / ...
else
:
cluster
.
append
(
token
)
print
()
print
(
'==========================================================================='
)
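Everything above funnels through the `Calculator` class defined earlier in compute-wer.py. As a rough sketch of the bookkeeping it performs — illustrative only, with made-up names, not the script's actual implementation — the N/C/S/D/I counts behind `WER = (ins + sub + del) * 100 / all` fall out of a standard Levenshtein alignment between reference and hypothesis:

def wer_counts(lab, rec):
    """Return (all, cor, sub, del, ins) for reference `lab` vs hypothesis `rec`."""
    n, m = len(lab), len(rec)
    # cost[i][j]: minimum edit distance between lab[:i] and rec[:j]
    cost = [[0] * (m + 1) for _ in range(n + 1)]
    for i in range(1, n + 1):
        cost[i][0] = i
    for j in range(1, m + 1):
        cost[0][j] = j
    for i in range(1, n + 1):
        for j in range(1, m + 1):
            diff = 0 if lab[i - 1] == rec[j - 1] else 1
            cost[i][j] = min(
                cost[i - 1][j - 1] + diff,  # correct / substitution
                cost[i - 1][j] + 1,  # deletion
                cost[i][j - 1] + 1)  # insertion
    # walk the table back to count each error type
    cor = sub = dele = ins = 0
    i, j = n, m
    while i > 0 or j > 0:
        if i > 0 and j > 0 and \
                cost[i][j] == cost[i - 1][j - 1] + (lab[i - 1] != rec[j - 1]):
            cor += lab[i - 1] == rec[j - 1]
            sub += lab[i - 1] != rec[j - 1]
            i, j = i - 1, j - 1
        elif i > 0 and cost[i][j] == cost[i - 1][j] + 1:
            dele += 1
            i -= 1
        else:
            ins += 1
            j -= 1
    return n, cor, sub, dele, ins


ref = list('今天天气不错')  # character-level, as with the script's --char=1
hyp = list('今天气很不错')
N, C, S, D, I = wer_counts(ref, hyp)
print('WER: %4.2f %%' % (float(S + D + I) * 100.0 / N))

The trailing cluster pass expects a cluster file with entries such as `<Keyword> 天气 导航 </Keyword>`: the tokens between a matching `<Name>`/`</Name>` pair form one cluster, and each cluster's WER is printed separately.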
utils/format_rsl.py  (view file @ c7d9b115)
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse

import jsonlines


def trans_hyp(origin_hyp, trans_hyp=None, trans_hyp_sclite=None):
    """
    Args:
        origin_hyp: The input json file which contains the model output
...
@@ -17,19 +27,18 @@ def trans_hyp(origin_hyp,
    with open(origin_hyp, "r+", encoding="utf8") as f:
        for item in jsonlines.Reader(f):
            input_dict[item["utt"]] = item["hyps"][0]
    if trans_hyp is not None:
        with open(trans_hyp, "w+", encoding="utf8") as f:
            for key in input_dict.keys():
                f.write(key + " " + input_dict[key] + "\n")
    if trans_hyp_sclite is not None:
        with open(trans_hyp_sclite, "w+") as f:
            for key in input_dict.keys():
                line = input_dict[key] + "(" + key + ".wav" + ")" + "\n"
                f.write(line)


def trans_ref(origin_ref, trans_ref=None, trans_ref_sclite=None):
    """
    Args:
        origin_hyp: The input json file which contains the model output
...
@@ -49,42 +58,48 @@ def trans_ref(origin_ref,
    if trans_ref_sclite is not None:
        with open(trans_ref_sclite, "w") as f:
            for key in input_dict.keys():
                line = input_dict[key] + "(" + key + ".wav" + ")" + "\n"
                f.write(line)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        prog='format hyp file for compute CER/WER', add_help=True)
    parser.add_argument(
        '--origin_hyp', type=str, default=None, help='origin hyp file')
    parser.add_argument(
        '--trans_hyp',
        type=str,
        default=None,
        help='hyp file for calculating CER/WER')
    parser.add_argument(
        '--trans_hyp_sclite',
        type=str,
        default=None,
        help='hyp file for calculating CER/WER by sclite')
    parser.add_argument(
        '--origin_ref', type=str, default=None, help='origin ref file')
    parser.add_argument(
        '--trans_ref',
        type=str,
        default=None,
        help='ref file for calculating CER/WER')
    parser.add_argument(
        '--trans_ref_sclite',
        type=str,
        default=None,
        help='ref file for calculating CER/WER by sclite')
    parser_args = parser.parse_args()

    if parser_args.origin_hyp is not None:
        trans_hyp(
            origin_hyp=parser_args.origin_hyp,
            trans_hyp=parser_args.trans_hyp,
            trans_hyp_sclite=parser_args.trans_hyp_sclite, )
    if parser_args.origin_ref is not None:
        trans_ref(
            origin_ref=parser_args.origin_ref,
            trans_ref=parser_args.trans_ref,
            trans_ref_sclite=parser_args.trans_ref_sclite, )
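format_rsl.py bridges the recognizer's jsonline output and the plain-text formats that compute-wer.py and sclite expect. A small usage sketch follows; the utterance id, transcript, and file names are made up for illustration, and it assumes `utils/` is on `PYTHONPATH` so the script imports as a module:

import jsonlines
from format_rsl import trans_hyp  # assumes utils/ is on PYTHONPATH

# One decoding result per line, as the script expects:
# {"utt": <utterance id>, "hyps": [<best hypothesis>, ...]}
with jsonlines.open('asr_result.jsonl', mode='w') as writer:
    writer.write({'utt': 'BAC009S0724W0121', 'hyps': ['广州市房地产中介协会分析']})

trans_hyp(
    origin_hyp='asr_result.jsonl',
    trans_hyp='hyp.txt',  # "utt transcript" lines, for compute-wer.py
    trans_hyp_sclite='hyp.trn')  # "transcript(utt.wav)" lines, for sclite

`trans_ref` does the same for the reference side, so the two files line up by utterance id.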
utils/fst/prepare_dict.py  (view file @ c7d9b115)
...
@@ -35,7 +35,7 @@ def main(args):
    # used to filter polyphone and invalid word
    lexicon_table = set()
    in_n = 0  # in-lexicon word count
    out_n = 0  # out-of-lexicon word count
    with open(args.in_lexicon, 'r') as fin, \
            open(args.out_lexicon, 'w') as fout:
        for line in fin:
...
@@ -82,7 +82,10 @@ def main(args):
            lexicon_table.add(word)
            out_n += 1
    print(f"Filter lexicon by unit table: filter out {in_n - out_n}, {out_n}/{in_n}")


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
...
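The elided body of `main()` keeps a lexicon entry only when its units are covered by the unit table, while `lexicon_table` screens out repeated (polyphone) entries. A minimal sketch of that filtering idea; the file names and the units.txt layout are assumptions here, and the real script's per-line handling is collapsed in the hunks above:

# Illustrative sketch, not the script verbatim: keep an entry only if the
# word is new and every unit of its pronunciation is in the unit table.
unit_table = set()
with open('data/lang_char/units.txt', 'r') as fin:  # hypothetical path
    for line in fin:
        unit_table.add(line.strip().split()[0])

lexicon_table = set()  # drops duplicate (polyphone) words
in_n = 0  # in-lexicon word count
out_n = 0  # out-of-lexicon word count
with open('lexicon.txt', 'r') as fin, \
        open('lexicon.filtered.txt', 'w') as fout:
    for line in fin:
        parts = line.strip().split()
        if not parts:
            continue
        word, units = parts[0], parts[1:]
        in_n += 1
        if word in lexicon_table:
            continue
        if units and all(u in unit_table for u in units):
            fout.write(line)
            lexicon_table.add(word)
            out_n += 1
print(f"Filter lexicon by unit table: filter out {in_n - out_n}, {out_n}/{in_n}")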