Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
DeepSpeech
提交
eac36205
D
DeepSpeech
项目概览
PaddlePaddle
/
DeepSpeech
大约 1 年 前同步成功
通知
206
Star
8425
Fork
1598
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
245
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DeepSpeech
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
245
Issue
245
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
未验证
提交
eac36205
编写于
9月 16, 2022
作者:
小湉湉
提交者:
GitHub
9月 16, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add typehint for g2pw (#2390)
上级
68c2ec75
变更
4
隐藏空白更改
内联
并排
Showing
4 changed files
with
71 additions
and
58 deletions
+71
-58
paddlespeech/t2s/frontend/g2pw/__init__.py
paddlespeech/t2s/frontend/g2pw/__init__.py
+1
-1
paddlespeech/t2s/frontend/g2pw/dataset.py
paddlespeech/t2s/frontend/g2pw/dataset.py
+34
-32
paddlespeech/t2s/frontend/g2pw/onnx_api.py
paddlespeech/t2s/frontend/g2pw/onnx_api.py
+30
-20
paddlespeech/t2s/frontend/g2pw/utils.py
paddlespeech/t2s/frontend/g2pw/utils.py
+6
-5
未找到文件。
paddlespeech/t2s/frontend/g2pw/__init__.py
浏览文件 @
eac36205
from
paddlespeech.t2s.frontend.g2pw
.onnx_api
import
G2PWOnnxConverter
from
.onnx_api
import
G2PWOnnxConverter
paddlespeech/t2s/frontend/g2pw/dataset.py
浏览文件 @
eac36205
...
...
@@ -15,6 +15,10 @@
Credits
This code is modified from https://github.com/GitYCC/g2pW
"""
from
typing
import
Dict
from
typing
import
List
from
typing
import
Tuple
import
numpy
as
np
from
paddlespeech.t2s.frontend.g2pw.utils
import
tokenize_and_map
...
...
@@ -23,22 +27,17 @@ ANCHOR_CHAR = '▁'
def
prepare_onnx_input
(
tokenizer
,
labels
,
char2phonemes
,
chars
,
texts
,
query_ids
,
phonemes
=
None
,
pos_tags
=
None
,
use_mask
=
False
,
use_char_phoneme
=
False
,
use_pos
=
False
,
window_size
=
None
,
max_len
=
512
):
labels
:
List
[
str
],
char2phonemes
:
Dict
[
str
,
List
[
int
]],
chars
:
List
[
str
],
texts
:
List
[
str
],
query_ids
:
List
[
int
],
use_mask
:
bool
=
False
,
window_size
:
int
=
None
,
max_len
:
int
=
512
)
->
Dict
[
str
,
np
.
array
]:
if
window_size
is
not
None
:
truncated_texts
,
truncated_query_ids
=
_truncate_texts
(
window_size
,
texts
,
query_ids
)
truncated_texts
,
truncated_query_ids
=
_truncate_texts
(
window_size
=
window_size
,
texts
=
texts
,
query_ids
=
query_ids
)
input_ids
=
[]
token_type_ids
=
[]
attention_masks
=
[]
...
...
@@ -51,13 +50,19 @@ def prepare_onnx_input(tokenizer,
query_id
=
(
truncated_query_ids
if
window_size
else
query_ids
)[
idx
]
try
:
tokens
,
text2token
,
token2text
=
tokenize_and_map
(
tokenizer
,
text
)
tokens
,
text2token
,
token2text
=
tokenize_and_map
(
tokenizer
=
tokenizer
,
text
=
text
)
except
Exception
:
print
(
f
'warning: text "
{
text
}
" is invalid'
)
return
{}
text
,
query_id
,
tokens
,
text2token
,
token2text
=
_truncate
(
max_len
,
text
,
query_id
,
tokens
,
text2token
,
token2text
)
max_len
=
max_len
,
text
=
text
,
query_id
=
query_id
,
tokens
=
tokens
,
text2token
=
text2token
,
token2text
=
token2text
)
processed_tokens
=
[
'[CLS]'
]
+
tokens
+
[
'[SEP]'
]
...
...
@@ -91,7 +96,8 @@ def prepare_onnx_input(tokenizer,
return
outputs
def
_truncate_texts
(
window_size
,
texts
,
query_ids
):
def
_truncate_texts
(
window_size
:
int
,
texts
:
List
[
str
],
query_ids
:
List
[
int
])
->
Tuple
[
List
[
str
],
List
[
int
]]:
truncated_texts
=
[]
truncated_query_ids
=
[]
for
text
,
query_id
in
zip
(
texts
,
query_ids
):
...
...
@@ -105,7 +111,12 @@ def _truncate_texts(window_size, texts, query_ids):
return
truncated_texts
,
truncated_query_ids
def
_truncate
(
max_len
,
text
,
query_id
,
tokens
,
text2token
,
token2text
):
def
_truncate
(
max_len
:
int
,
text
:
str
,
query_id
:
int
,
tokens
:
List
[
str
],
text2token
:
List
[
int
],
token2text
:
List
[
Tuple
[
int
]]):
truncate_len
=
max_len
-
2
if
len
(
tokens
)
<=
truncate_len
:
return
(
text
,
query_id
,
tokens
,
text2token
,
token2text
)
...
...
@@ -132,18 +143,8 @@ def _truncate(max_len, text, query_id, tokens, text2token, token2text):
],
[(
s
-
start
,
e
-
start
)
for
s
,
e
in
token2text
[
token_start
:
token_end
]])
def
prepare_data
(
sent_path
,
lb_path
=
None
):
raw_texts
=
open
(
sent_path
).
read
().
rstrip
().
split
(
'
\n
'
)
query_ids
=
[
raw
.
index
(
ANCHOR_CHAR
)
for
raw
in
raw_texts
]
texts
=
[
raw
.
replace
(
ANCHOR_CHAR
,
''
)
for
raw
in
raw_texts
]
if
lb_path
is
None
:
return
texts
,
query_ids
else
:
phonemes
=
open
(
lb_path
).
read
().
rstrip
().
split
(
'
\n
'
)
return
texts
,
query_ids
,
phonemes
def
get_phoneme_labels
(
polyphonic_chars
):
def
get_phoneme_labels
(
polyphonic_chars
:
List
[
List
[
str
]]
)
->
Tuple
[
List
[
str
],
Dict
[
str
,
List
[
int
]]]:
labels
=
sorted
(
list
(
set
([
phoneme
for
char
,
phoneme
in
polyphonic_chars
])))
char2phonemes
=
{}
for
char
,
phoneme
in
polyphonic_chars
:
...
...
@@ -153,7 +154,8 @@ def get_phoneme_labels(polyphonic_chars):
return
labels
,
char2phonemes
def
get_char_phoneme_labels
(
polyphonic_chars
):
def
get_char_phoneme_labels
(
polyphonic_chars
:
List
[
List
[
str
]]
)
->
Tuple
[
List
[
str
],
Dict
[
str
,
List
[
int
]]]:
labels
=
sorted
(
list
(
set
([
f
'
{
char
}
{
phoneme
}
'
for
char
,
phoneme
in
polyphonic_chars
])))
char2phonemes
=
{}
...
...
paddlespeech/t2s/frontend/g2pw/onnx_api.py
浏览文件 @
eac36205
...
...
@@ -17,6 +17,10 @@ Credits
"""
import
json
import
os
from
typing
import
Any
from
typing
import
Dict
from
typing
import
List
from
typing
import
Tuple
import
numpy
as
np
import
onnxruntime
...
...
@@ -37,7 +41,8 @@ from paddlespeech.utils.env import MODEL_HOME
model_version
=
'1.1'
def
predict
(
session
,
onnx_input
,
labels
):
def
predict
(
session
,
onnx_input
:
Dict
[
str
,
Any
],
labels
:
List
[
str
])
->
Tuple
[
List
[
str
],
List
[
float
]]:
all_preds
=
[]
all_confidences
=
[]
probs
=
session
.
run
([],
{
...
...
@@ -61,10 +66,10 @@ def predict(session, onnx_input, labels):
class
G2PWOnnxConverter
:
def
__init__
(
self
,
model_dir
=
MODEL_HOME
,
style
=
'bopomofo'
,
model_source
=
None
,
enable_non_tradional_chinese
=
False
):
model_dir
:
os
.
PathLike
=
MODEL_HOME
,
style
:
str
=
'bopomofo'
,
model_source
:
str
=
None
,
enable_non_tradional_chinese
:
bool
=
False
):
uncompress_path
=
download_and_decompress
(
g2pw_onnx_models
[
'G2PWModel'
][
model_version
],
model_dir
)
...
...
@@ -76,7 +81,8 @@ class G2PWOnnxConverter:
os
.
path
.
join
(
uncompress_path
,
'g2pW.onnx'
),
sess_options
=
sess_options
)
self
.
config
=
load_config
(
os
.
path
.
join
(
uncompress_path
,
'config.py'
),
use_default
=
True
)
config_path
=
os
.
path
.
join
(
uncompress_path
,
'config.py'
),
use_default
=
True
)
self
.
model_source
=
model_source
if
model_source
else
self
.
config
.
model_source
self
.
enable_opencc
=
enable_non_tradional_chinese
...
...
@@ -103,9 +109,9 @@ class G2PWOnnxConverter:
.
strip
().
split
(
'
\n
'
)
]
self
.
labels
,
self
.
char2phonemes
=
get_char_phoneme_labels
(
self
.
polyphonic_chars
polyphonic_chars
=
self
.
polyphonic_chars
)
if
self
.
config
.
use_char_phoneme
else
get_phoneme_labels
(
self
.
polyphonic_chars
)
polyphonic_chars
=
self
.
polyphonic_chars
)
self
.
chars
=
sorted
(
list
(
self
.
char2phonemes
.
keys
()))
...
...
@@ -146,7 +152,7 @@ class G2PWOnnxConverter:
if
self
.
enable_opencc
:
self
.
cc
=
OpenCC
(
's2tw'
)
def
_convert_bopomofo_to_pinyin
(
self
,
bopomofo
)
:
def
_convert_bopomofo_to_pinyin
(
self
,
bopomofo
:
str
)
->
str
:
tone
=
bopomofo
[
-
1
]
assert
tone
in
'12345'
component
=
self
.
bopomofo_convert_dict
.
get
(
bopomofo
[:
-
1
])
...
...
@@ -156,7 +162,7 @@ class G2PWOnnxConverter:
print
(
f
'Warning: "
{
bopomofo
}
" cannot convert to pinyin'
)
return
None
def
__call__
(
self
,
sentences
)
:
def
__call__
(
self
,
sentences
:
List
[
str
])
->
List
[
List
[
str
]]
:
if
isinstance
(
sentences
,
str
):
sentences
=
[
sentences
]
...
...
@@ -169,23 +175,25 @@ class G2PWOnnxConverter:
sentences
=
translated_sentences
texts
,
query_ids
,
sent_ids
,
partial_results
=
self
.
_prepare_data
(
sentences
)
sentences
=
sentences
)
if
len
(
texts
)
==
0
:
# sentences contain no polyphonic words
return
partial_results
onnx_input
=
prepare_onnx_input
(
self
.
tokenizer
,
self
.
labels
,
self
.
char2phonemes
,
self
.
chars
,
texts
,
query_ids
,
tokenizer
=
self
.
tokenizer
,
labels
=
self
.
labels
,
char2phonemes
=
self
.
char2phonemes
,
chars
=
self
.
chars
,
texts
=
texts
,
query_ids
=
query_ids
,
use_mask
=
self
.
config
.
use_mask
,
use_char_phoneme
=
self
.
config
.
use_char_phoneme
,
window_size
=
None
)
preds
,
confidences
=
predict
(
self
.
session_g2pW
,
onnx_input
,
self
.
labels
)
preds
,
confidences
=
predict
(
session
=
self
.
session_g2pW
,
onnx_input
=
onnx_input
,
labels
=
self
.
labels
)
if
self
.
config
.
use_char_phoneme
:
preds
=
[
pred
.
split
(
' '
)[
1
]
for
pred
in
preds
]
...
...
@@ -195,7 +203,9 @@ class G2PWOnnxConverter:
return
results
def
_prepare_data
(
self
,
sentences
):
def
_prepare_data
(
self
,
sentences
:
List
[
str
]
)
->
Tuple
[
List
[
str
],
List
[
int
],
List
[
int
],
List
[
List
[
str
]]]:
texts
,
query_ids
,
sent_ids
,
partial_results
=
[],
[],
[],
[]
for
sent_id
,
sent
in
enumerate
(
sentences
):
# pypinyin works better for Simplified Chinese than for Traditional Chinese
...
...
paddlespeech/t2s/frontend/g2pw/utils.py
浏览文件 @
eac36205
...
...
@@ -15,10 +15,11 @@
Credits
This code is modified from https://github.com/GitYCC/g2pW
"""
import
os
import
re
def
wordize_and_map
(
text
):
def
wordize_and_map
(
text
:
str
):
words
=
[]
index_map_from_text_to_word
=
[]
index_map_from_word_to_text
=
[]
...
...
@@ -54,8 +55,8 @@ def wordize_and_map(text):
return
words
,
index_map_from_text_to_word
,
index_map_from_word_to_text
def
tokenize_and_map
(
tokenizer
,
text
):
words
,
text2word
,
word2text
=
wordize_and_map
(
text
)
def
tokenize_and_map
(
tokenizer
,
text
:
str
):
words
,
text2word
,
word2text
=
wordize_and_map
(
text
=
text
)
tokens
=
[]
index_map_from_token_to_text
=
[]
...
...
@@ -82,7 +83,7 @@ def tokenize_and_map(tokenizer, text):
return
tokens
,
index_map_from_text_to_token
,
index_map_from_token_to_text
def
_load_config
(
config_path
):
def
_load_config
(
config_path
:
os
.
PathLike
):
import
importlib.util
spec
=
importlib
.
util
.
spec_from_file_location
(
'__init__'
,
config_path
)
config
=
importlib
.
util
.
module_from_spec
(
spec
)
...
...
@@ -130,7 +131,7 @@ default_config_dict = {
}
def
load_config
(
config_path
,
use_default
=
False
):
def
load_config
(
config_path
:
os
.
PathLike
,
use_default
:
bool
=
False
):
config
=
_load_config
(
config_path
)
if
use_default
:
for
attr
,
val
in
default_config_dict
.
items
():
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录