PaddlePaddle / DeepSpeech

Commit eac36205 (unverified)
Authored September 16, 2022 by 小湉湉; committed via GitHub on September 16, 2022.
add typehint for g2pw (#2390)
Parent: 68c2ec75
Showing 4 changed files with 71 additions and 58 deletions (+71 −58).
paddlespeech/t2s/frontend/g2pw/__init__.py   +1  −1
paddlespeech/t2s/frontend/g2pw/dataset.py    +34 −32
paddlespeech/t2s/frontend/g2pw/onnx_api.py   +30 −20
paddlespeech/t2s/frontend/g2pw/utils.py      +6  −5
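The pattern is the same in all four files: the needed names are imported from typing, function signatures gain parameter and return annotations, and positional call sites are rewritten with keyword arguments. A minimal sketch of that style, using a simplified stand-in body (the real get_phoneme_labels lives in dataset.py below; the sample data here is hypothetical, not from the commit):

from typing import Dict
from typing import List
from typing import Tuple


def get_phoneme_labels(polyphonic_chars: List[List[str]]
                       ) -> Tuple[List[str], Dict[str, List[int]]]:
    # Simplified stand-in body, not the implementation from this commit.
    labels = sorted({phoneme for _, phoneme in polyphonic_chars})
    char2phonemes: Dict[str, List[int]] = {}
    for char, phoneme in polyphonic_chars:
        char2phonemes.setdefault(char, []).append(labels.index(phoneme))
    return labels, char2phonemes


# Call sites move to keyword arguments, mirroring the updated onnx_api.py:
labels, char2phonemes = get_phoneme_labels(
    polyphonic_chars=[['长', 'chang2'], ['长', 'zhang3']])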
paddlespeech/t2s/frontend/g2pw/__init__.py  (+1 −1)

-from paddlespeech.t2s.frontend.g2pw.onnx_api import G2PWOnnxConverter
+from .onnx_api import G2PWOnnxConverter
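The switch to a relative import leaves the public entry point unchanged; a downstream caller still imports the converter the same way (usage sketch, assuming paddlespeech is installed):

from paddlespeech.t2s.frontend.g2pw import G2PWOnnxConverter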
paddlespeech/t2s/frontend/g2pw/dataset.py  (+34 −32)

@@ -15,6 +15,10 @@
 Credits
     This code is modified from https://github.com/GitYCC/g2pW
 """
+from typing import Dict
+from typing import List
+from typing import Tuple
+
 import numpy as np

 from paddlespeech.t2s.frontend.g2pw.utils import tokenize_and_map
@@ -23,22 +27,17 @@ ANCHOR_CHAR = '▁'
 def prepare_onnx_input(tokenizer,
-                       labels,
-                       char2phonemes,
-                       chars,
-                       texts,
-                       query_ids,
-                       phonemes=None,
-                       pos_tags=None,
-                       use_mask=False,
-                       use_char_phoneme=False,
-                       use_pos=False,
-                       window_size=None,
-                       max_len=512):
+                       labels: List[str],
+                       char2phonemes: Dict[str, List[int]],
+                       chars: List[str],
+                       texts: List[str],
+                       query_ids: List[int],
+                       use_mask: bool=False,
+                       window_size: int=None,
+                       max_len: int=512) -> Dict[str, np.array]:
     if window_size is not None:
-        truncated_texts, truncated_query_ids = _truncate_texts(window_size,
-                                                               texts, query_ids)
+        truncated_texts, truncated_query_ids = _truncate_texts(
+            window_size=window_size, texts=texts, query_ids=query_ids)
     input_ids = []
     token_type_ids = []
     attention_masks = []
@@ -51,13 +50,19 @@ def prepare_onnx_input(tokenizer,
         query_id = (truncated_query_ids if window_size else query_ids)[idx]
         try:
-            tokens, text2token, token2text = tokenize_and_map(tokenizer, text)
+            tokens, text2token, token2text = tokenize_and_map(
+                tokenizer=tokenizer, text=text)
         except Exception:
             print(f'warning: text "{text}" is invalid')
             return {}

-        text, query_id, tokens, text2token, token2text = _truncate(
-            max_len, text, query_id, tokens, text2token, token2text)
+        text, query_id, tokens, text2token, token2text = _truncate(
+            max_len=max_len,
+            text=text,
+            query_id=query_id,
+            tokens=tokens,
+            text2token=text2token,
+            token2text=token2text)

         processed_tokens = ['[CLS]'] + tokens + ['[SEP]']
@@ -91,7 +96,8 @@ def prepare_onnx_input(tokenizer,
     return outputs


-def _truncate_texts(window_size, texts, query_ids):
+def _truncate_texts(window_size: int, texts: List[str],
+                    query_ids: List[int]) -> Tuple[List[str], List[int]]:
     truncated_texts = []
     truncated_query_ids = []
     for text, query_id in zip(texts, query_ids):
@@ -105,7 +111,12 @@ def _truncate_texts(window_size, texts, query_ids):
     return truncated_texts, truncated_query_ids


-def _truncate(max_len, text, query_id, tokens, text2token, token2text):
+def _truncate(max_len: int,
+              text: str,
+              query_id: int,
+              tokens: List[str],
+              text2token: List[int],
+              token2text: List[Tuple[int]]):
     truncate_len = max_len - 2
     if len(tokens) <= truncate_len:
         return (text, query_id, tokens, text2token, token2text)
@@ -132,18 +143,8 @@ def _truncate(max_len, text, query_id, tokens, text2token, token2text):
     ], [(s - start, e - start) for s, e in token2text[token_start:token_end]])


-def prepare_data(sent_path, lb_path=None):
-    raw_texts = open(sent_path).read().rstrip().split('\n')
-    query_ids = [raw.index(ANCHOR_CHAR) for raw in raw_texts]
-    texts = [raw.replace(ANCHOR_CHAR, '') for raw in raw_texts]
-    if lb_path is None:
-        return texts, query_ids
-    else:
-        phonemes = open(lb_path).read().rstrip().split('\n')
-        return texts, query_ids, phonemes
-
-
-def get_phoneme_labels(polyphonic_chars):
+def get_phoneme_labels(polyphonic_chars: List[List[str]]
+                       ) -> Tuple[List[str], Dict[str, List[int]]]:
     labels = sorted(list(set([phoneme for char, phoneme in polyphonic_chars])))
     char2phonemes = {}
     for char, phoneme in polyphonic_chars:
@@ -153,7 +154,8 @@ def get_phoneme_labels(polyphonic_chars):
     return labels, char2phonemes


-def get_char_phoneme_labels(polyphonic_chars):
+def get_char_phoneme_labels(polyphonic_chars: List[List[str]]
+                            ) -> Tuple[List[str], Dict[str, List[int]]]:
     labels = sorted(
         list(set([f'{char} {phoneme}' for char, phoneme in polyphonic_chars])))
     char2phonemes = {}
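With the annotations above, prepare_onnx_input now documents that it takes the texts, query ids, and label tables and returns a Dict[str, np.array] of model inputs. A hypothetical call in the keyword-argument style this commit adopts (the tokenizer, labels, char2phonemes, and chars objects are placeholders that would come from a loaded G2PW model; this is a sketch, not code from the commit):

onnx_input = prepare_onnx_input(
    tokenizer=tokenizer,          # pretrained tokenizer (placeholder)
    labels=labels,                # List[str], e.g. from get_phoneme_labels
    char2phonemes=char2phonemes,  # Dict[str, List[int]] (placeholder)
    chars=chars,                  # List[str] of polyphonic characters
    texts=['他长大了'],
    query_ids=[1],                # index of the polyphonic character in each text
    use_mask=True,
    window_size=None,
    max_len=512)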
paddlespeech/t2s/frontend/g2pw/onnx_api.py  (+30 −20)

@@ -17,6 +17,10 @@ Credits
 """
 import json
 import os
+from typing import Any
+from typing import Dict
+from typing import List
+from typing import Tuple

 import numpy as np
 import onnxruntime
@@ -37,7 +41,8 @@ from paddlespeech.utils.env import MODEL_HOME
 model_version = '1.1'


-def predict(session, onnx_input, labels):
+def predict(session, onnx_input: Dict[str, Any],
+            labels: List[str]) -> Tuple[List[str], List[float]]:
     all_preds = []
     all_confidences = []
     probs = session.run([], {
@@ -61,10 +66,10 @@ def predict(session, onnx_input, labels):
 class G2PWOnnxConverter:
     def __init__(self,
-                 model_dir=MODEL_HOME,
-                 style='bopomofo',
-                 model_source=None,
-                 enable_non_tradional_chinese=False):
+                 model_dir: os.PathLike=MODEL_HOME,
+                 style: str='bopomofo',
+                 model_source: str=None,
+                 enable_non_tradional_chinese: bool=False):
         uncompress_path = download_and_decompress(
             g2pw_onnx_models['G2PWModel'][model_version], model_dir)
@@ -76,7 +81,8 @@ class G2PWOnnxConverter:
             os.path.join(uncompress_path, 'g2pW.onnx'),
             sess_options=sess_options)
-        self.config = load_config(
-            os.path.join(uncompress_path, 'config.py'), use_default=True)
+        self.config = load_config(
+            config_path=os.path.join(uncompress_path, 'config.py'),
+            use_default=True)

         self.model_source = model_source if model_source else self.config.model_source
         self.enable_opencc = enable_non_tradional_chinese
@@ -103,9 +109,9 @@ class G2PWOnnxConverter:
             .strip().split('\n')
         ]
-        self.labels, self.char2phonemes = get_char_phoneme_labels(
-            self.polyphonic_chars
-        ) if self.config.use_char_phoneme else get_phoneme_labels(
-            self.polyphonic_chars)
+        self.labels, self.char2phonemes = get_char_phoneme_labels(
+            polyphonic_chars=self.polyphonic_chars
+        ) if self.config.use_char_phoneme else get_phoneme_labels(
+            polyphonic_chars=self.polyphonic_chars)
         self.chars = sorted(list(self.char2phonemes.keys()))
@@ -146,7 +152,7 @@ class G2PWOnnxConverter:
         if self.enable_opencc:
             self.cc = OpenCC('s2tw')

-    def _convert_bopomofo_to_pinyin(self, bopomofo):
+    def _convert_bopomofo_to_pinyin(self, bopomofo: str) -> str:
         tone = bopomofo[-1]
         assert tone in '12345'
         component = self.bopomofo_convert_dict.get(bopomofo[:-1])
@@ -156,7 +162,7 @@ class G2PWOnnxConverter:
             print(f'Warning: "{bopomofo}" cannot convert to pinyin')
             return None

-    def __call__(self, sentences):
+    def __call__(self, sentences: List[str]) -> List[List[str]]:
         if isinstance(sentences, str):
             sentences = [sentences]
@@ -169,23 +175,25 @@ class G2PWOnnxConverter:
             sentences = translated_sentences

-        texts, query_ids, sent_ids, partial_results = self._prepare_data(
-            sentences)
+        texts, query_ids, sent_ids, partial_results = self._prepare_data(
+            sentences=sentences)
         if len(texts) == 0:
             # sentences no polyphonic words
             return partial_results

         onnx_input = prepare_onnx_input(
-            self.tokenizer,
-            self.labels,
-            self.char2phonemes,
-            self.chars,
-            texts,
-            query_ids,
+            tokenizer=self.tokenizer,
+            labels=self.labels,
+            char2phonemes=self.char2phonemes,
+            chars=self.chars,
+            texts=texts,
+            query_ids=query_ids,
             use_mask=self.config.use_mask,
-            use_char_phoneme=self.config.use_char_phoneme,
             window_size=None)

-        preds, confidences = predict(self.session_g2pW, onnx_input, self.labels)
+        preds, confidences = predict(
+            session=self.session_g2pW,
+            onnx_input=onnx_input,
+            labels=self.labels)
         if self.config.use_char_phoneme:
             preds = [pred.split(' ')[1] for pred in preds]
@@ -195,7 +203,9 @@ class G2PWOnnxConverter:
         return results

-    def _prepare_data(self, sentences):
+    def _prepare_data(
+            self, sentences: List[str]
+    ) -> Tuple[List[str], List[int], List[int], List[List[str]]]:
         texts, query_ids, sent_ids, partial_results = [], [], [], []
         for sent_id, sent in enumerate(sentences):
             # pypinyin works well for Simplified Chinese than Traditional Chinese
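Taken together, the converter's call path is now typed end to end: __call__(self, sentences: List[str]) -> List[List[str]]. A rough usage sketch, assuming the default G2PW ONNX model can be downloaded to MODEL_HOME on first construction (illustrative only, not code from this commit):

from paddlespeech.t2s.frontend.g2pw import G2PWOnnxConverter

# First construction downloads and unpacks the ONNX model; needs network access.
g2pw = G2PWOnnxConverter(style='bopomofo', enable_non_tradional_chinese=False)

# One list of predicted pronunciations per input sentence.
results = g2pw(['他长大了'])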
paddlespeech/t2s/frontend/g2pw/utils.py  (+6 −5)

@@ -15,10 +15,11 @@
 Credits
     This code is modified from https://github.com/GitYCC/g2pW
 """
+import os
 import re


-def wordize_and_map(text):
+def wordize_and_map(text: str):
     words = []
     index_map_from_text_to_word = []
     index_map_from_word_to_text = []
@@ -54,8 +55,8 @@ def wordize_and_map(text):
     return words, index_map_from_text_to_word, index_map_from_word_to_text


-def tokenize_and_map(tokenizer, text):
-    words, text2word, word2text = wordize_and_map(text)
+def tokenize_and_map(tokenizer, text: str):
+    words, text2word, word2text = wordize_and_map(text=text)
     tokens = []
     index_map_from_token_to_text = []
@@ -82,7 +83,7 @@ def tokenize_and_map(tokenizer, text):
     return tokens, index_map_from_text_to_token, index_map_from_token_to_text


-def _load_config(config_path):
+def _load_config(config_path: os.PathLike):
     import importlib.util
     spec = importlib.util.spec_from_file_location('__init__', config_path)
     config = importlib.util.module_from_spec(spec)
@@ -130,7 +131,7 @@ default_config_dict = {
 }


-def load_config(config_path, use_default=False):
+def load_config(config_path: os.PathLike, use_default: bool=False):
     config = _load_config(config_path)
     if use_default:
         for attr, val in default_config_dict.items():
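The config loader is now typed to accept an os.PathLike. A hypothetical use, pointing it at the config.py shipped inside the downloaded G2PWModel directory (the path below is a placeholder, not from the commit):

import os

config = load_config(
    config_path=os.path.join('G2PWModel', 'config.py'), use_default=True)
# Attributes referenced elsewhere in this commit:
print(config.model_source, config.use_mask, config.use_char_phoneme)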