PaddlePaddle / DeepSpeech
Commit 060fd947
Authored on August 15, 2022 by 小湉湉

format g2pw

Parent: ac385053
4 changed files with 226 additions and 121 deletions (+226 -121)
paddlespeech/t2s/frontend/g2pw/dataset.py    +91 -61
paddlespeech/t2s/frontend/g2pw/onnx_api.py   +98 -42
paddlespeech/t2s/frontend/g2pw/utils.py      +16 -5
paddlespeech/t2s/frontend/zh_frontend.py     +21 -13
paddlespeech/t2s/frontend/g2pw/dataset.py
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Credits
This code is modified from https://github.com/GitYCC/g2pW
"""
import numpy as np

from paddlespeech.t2s.frontend.g2pw.utils import tokenize_and_map

ANCHOR_CHAR = '▁'


def prepare_onnx_input(tokenizer,
                       labels,
                       char2phonemes,
                       chars,
                       texts,
                       query_ids,
                       phonemes=None,
                       pos_tags=None,
                       use_mask=False,
                       use_char_phoneme=False,
                       use_pos=False,
                       window_size=None,
                       max_len=512):
    if window_size is not None:
        truncated_texts, truncated_query_ids = _truncate_texts(
            window_size, texts, query_ids)

    input_ids = []
    token_type_ids = []
    attention_masks = []
    phoneme_masks = []
    char_ids = []
    position_ids = []

    for idx in range(len(texts)):
        text = (truncated_texts if window_size else texts)[idx].lower()
        query_id = (truncated_query_ids if window_size else query_ids)[idx]

        try:
            tokens, text2token, token2text = tokenize_and_map(tokenizer, text)
        except Exception:
            print(f'warning: text "{text}" is invalid')
            return {}

        text, query_id, tokens, text2token, token2text = _truncate(
            max_len, text, query_id, tokens, text2token, token2text)

        processed_tokens = ['[CLS]'] + tokens + ['[SEP]']

        input_id = list(
            np.array(tokenizer.convert_tokens_to_ids(processed_tokens)))
        token_type_id = list(np.zeros((len(processed_tokens), ), dtype=int))
        attention_mask = list(np.ones((len(processed_tokens), ), dtype=int))

        query_char = text[query_id]
        phoneme_mask = [1 if i in char2phonemes[query_char] else 0
                        for i in range(len(labels))] \
            if use_mask else [1] * len(labels)
        char_id = chars.index(query_char)
        position_id = text2token[
            query_id] + 1  # [CLS] token locate at first place

        input_ids.append(input_id)
        token_type_ids.append(token_type_id)
        attention_masks.append(attention_mask)
        phoneme_masks.append(phoneme_mask)
        char_ids.append(char_id)
        position_ids.append(position_id)

    outputs = {
        'input_ids': np.array(input_ids),
        'token_type_ids': np.array(token_type_ids),
        'attention_masks': np.array(attention_masks),
        'phoneme_masks': np.array(phoneme_masks).astype(np.float32),
        'char_ids': np.array(char_ids),
        'position_ids': np.array(position_ids),
    }
    return outputs
def _truncate_texts(window_size, texts, query_ids):
    truncated_texts = []
...
@@ -74,6 +104,7 @@ def _truncate_texts(window_size, texts, query_ids):
        truncated_query_ids.append(truncated_query_id)
    return truncated_texts, truncated_query_ids
def _truncate(max_len, text, query_id, tokens, text2token, token2text):
    truncate_len = max_len - 2
    if len(tokens) <= truncate_len:
...
@@ -95,13 +126,11 @@ def _truncate(max_len, text, query_id, tokens, text2token, token2text):
    start = token2text[token_start][0]
    end = token2text[token_end - 1][1]

    return (text[start:end], query_id - start, tokens[token_start:token_end], [
        i - token_start if i is not None else None
        for i in text2token[start:end]
    ], [(s - start, e - start) for s, e in token2text[token_start:token_end]])
def prepare_data(sent_path, lb_path=None):
    raw_texts = open(sent_path).read().rstrip().split('\n')
...
@@ -125,7 +154,8 @@ def get_phoneme_labels(polyphonic_chars):
def get_char_phoneme_labels(polyphonic_chars):
    labels = sorted(
        list(set([f'{char} {phoneme}' for char, phoneme in polyphonic_chars])))
    char2phonemes = {}
    for char, phoneme in polyphonic_chars:
        if char not in char2phonemes:
...
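For intuition, here is a small, self-contained sketch of the phoneme-mask step used in prepare_onnx_input above. The label table and char2phonemes mapping below are invented for the example; with use_mask=True, only the candidate phonemes of the queried character remain unmasked.

# Toy illustration of the phoneme_mask construction in prepare_onnx_input.
# `labels` and `char2phonemes` are made-up stand-ins for the tables built by
# get_phoneme_labels / get_char_phoneme_labels.
labels = ['hang2', 'le5', 'liao3', 'xing2']   # one entry per candidate phoneme
char2phonemes = {'行': [0, 3], '了': [1, 2]}   # char -> indices into `labels`

query_char = '行'
use_mask = True
phoneme_mask = [1 if i in char2phonemes[query_char] else 0
                for i in range(len(labels))] if use_mask else [1] * len(labels)
print(phoneme_mask)  # [1, 0, 0, 1]: only 'hang2' and 'xing2' stay scorable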
paddlespeech/t2s/frontend/g2pw/onnx_api.py
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Credits
This code is modified from https://github.com/GitYCC/g2pW
"""
import json
import os

import numpy as np
import onnxruntime
from opencc import OpenCC
from paddlenlp.transformers import BertTokenizer
from pypinyin import pinyin
from pypinyin import Style

from paddlespeech.cli.utils import download_and_decompress
from paddlespeech.resource.pretrained_models import g2pw_onnx_models
from paddlespeech.t2s.frontend.g2pw.dataset import get_char_phoneme_labels
from paddlespeech.t2s.frontend.g2pw.dataset import get_phoneme_labels
from paddlespeech.t2s.frontend.g2pw.dataset import prepare_onnx_input
from paddlespeech.t2s.frontend.g2pw.utils import load_config
from paddlespeech.utils.env import MODEL_HOME
def predict(session, onnx_input, labels):
    all_preds = []
    all_confidences = []

    probs = session.run([], {
        "input_ids": onnx_input['input_ids'],
        "token_type_ids": onnx_input['token_type_ids'],
        "attention_mask": onnx_input['attention_masks'],
        "phoneme_mask": onnx_input['phoneme_masks'],
        "char_ids": onnx_input['char_ids'],
        "position_ids": onnx_input['position_ids']
    })[0]

    preds = np.argmax(probs, axis=1).tolist()
    max_probs = []
    for index, arr in zip(preds, probs.tolist()):
        max_probs.append(arr[index])
    all_preds += [labels[pred] for pred in preds]
    all_confidences += max_probs
...
@@ -41,39 +57,69 @@ def predict(session, onnx_input, labels):
class G2PWOnnxConverter:
    def __init__(self,
                 model_dir=MODEL_HOME,
                 style='bopomofo',
                 model_source=None,
                 enable_non_tradional_chinese=False):
        if not os.path.exists(os.path.join(model_dir, 'G2PWModel/g2pW.onnx')):
            uncompress_path = download_and_decompress(
                g2pw_onnx_models['G2PWModel']['1.0'], model_dir)

        sess_options = onnxruntime.SessionOptions()
        sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
        sess_options.execution_mode = onnxruntime.ExecutionMode.ORT_SEQUENTIAL
        sess_options.intra_op_num_threads = 2
        self.session_g2pW = onnxruntime.InferenceSession(
            os.path.join(model_dir, 'G2PWModel/g2pW.onnx'),
            sess_options=sess_options)
        self.config = load_config(
            os.path.join(model_dir, 'G2PWModel/config.py'), use_default=True)

        self.model_source = model_source if model_source else self.config.model_source
        self.enable_opencc = enable_non_tradional_chinese

        self.tokenizer = BertTokenizer.from_pretrained(self.config.model_source)

        polyphonic_chars_path = os.path.join(model_dir,
                                             'G2PWModel/POLYPHONIC_CHARS.txt')
        monophonic_chars_path = os.path.join(model_dir,
                                             'G2PWModel/MONOPHONIC_CHARS.txt')
        self.polyphonic_chars = [
            line.split('\t')
            for line in open(polyphonic_chars_path, encoding='utf-8').read()
            .strip().split('\n')
        ]
        self.monophonic_chars = [
            line.split('\t')
            for line in open(monophonic_chars_path, encoding='utf-8').read()
            .strip().split('\n')
        ]
        self.labels, self.char2phonemes = get_char_phoneme_labels(
            self.polyphonic_chars
        ) if self.config.use_char_phoneme else get_phoneme_labels(
            self.polyphonic_chars)

        self.chars = sorted(list(self.char2phonemes.keys()))
        self.pos_tags = [
            'UNK', 'A', 'C', 'D', 'I', 'N', 'P', 'T', 'V', 'DE', 'SHI'
        ]

        with open(
                os.path.join(model_dir,
                             'G2PWModel/bopomofo_to_pinyin_wo_tune_dict.json'),
                'r',
                encoding='utf-8') as fr:
            self.bopomofo_convert_dict = json.load(fr)
        self.style_convert_func = {
            'bopomofo': lambda x: x,
            'pinyin': self._convert_bopomofo_to_pinyin,
        }[style]

        with open(
                os.path.join(model_dir, 'G2PWModel/char_bopomofo_dict.json'),
                'r',
                encoding='utf-8') as fr:
            self.char_bopomofo_dict = json.load(fr)

        if self.enable_opencc:
...
@@ -100,15 +146,23 @@ class G2PWOnnxConverter:
                assert len(translated_sent) == len(sent)

                translated_sentences.append(translated_sent)
            sentences = translated_sentences

        texts, query_ids, sent_ids, partial_results = self._prepare_data(
            sentences)
        if len(texts) == 0:
            # sentences no polyphonic words
            return partial_results

        onnx_input = prepare_onnx_input(
            self.tokenizer,
            self.labels,
            self.char2phonemes,
            self.chars,
            texts,
            query_ids,
            use_mask=self.config.use_mask,
            use_char_phoneme=self.config.use_char_phoneme,
            window_size=None)

        preds, confidences = predict(self.session_g2pW, onnx_input, self.labels)
        if self.config.use_char_phoneme:
...
@@ -123,11 +177,12 @@ class G2PWOnnxConverter:
    def _prepare_data(self, sentences):
        polyphonic_chars = set(self.chars)
        monophonic_chars_dict = {
            char: phoneme
            for char, phoneme in self.monophonic_chars
        }
        texts, query_ids, sent_ids, partial_results = [], [], [], []
        for sent_id, sent in enumerate(sentences):
            pypinyin_result = pinyin(sent, style=Style.TONE3)
            partial_result = [None] * len(sent)
            for i, char in enumerate(sent):
                if char in polyphonic_chars:
...
@@ -135,9 +190,10 @@ class G2PWOnnxConverter:
                    query_ids.append(i)
                    sent_ids.append(sent_id)
                elif char in monophonic_chars_dict:
                    partial_result[i] = self.style_convert_func(
                        monophonic_chars_dict[char])
                elif char in self.char_bopomofo_dict:
                    partial_result[i] = pypinyin_result[i][0]
                    # partial_result[i] = self.style_convert_func(self.char_bopomofo_dict[char][0])

            partial_results.append(partial_result)
        return texts, query_ids, sent_ids, partial_results
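For context, a minimal usage sketch of the converter defined in this file. It mirrors how zh_frontend.py below constructs and calls it; the example sentence is arbitrary, and the ONNX model files are downloaded into MODEL_HOME on first use.

from paddlespeech.t2s.frontend.g2pw.onnx_api import G2PWOnnxConverter

# Same construction as in zh_frontend.py: pinyin-style output, with OpenCC
# conversion enabled so simplified-Chinese input is handled.
g2pw = G2PWOnnxConverter(style='pinyin', enable_non_tradional_chinese=True)

# Calling the converter on a sentence returns per-sentence results;
# zh_frontend.py takes the first element as the per-character pinyin sequence.
pinyins = g2pw('我们今天去银行')[0]
print(pinyins)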
paddlespeech/t2s/frontend/g2pw/utils.py
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Credits
This code is modified from https://github.com/GitYCC/g2pW
"""
import re
import sys


def wordize_and_map(text):
    words = []
...
@@ -92,7 +104,6 @@ default_config_dict = {
    'char-linear': True,
    'pos-linear': False,
    'char+pos-second': True,
    'char+pos-second_lowrank': False,
    'lowrank_size': 0,
    'char+pos-second_fm': False,
...
@@ -130,4 +141,4 @@ def load_config(config_path, use_default=False):
    for dict_k, dict_v in val.items():
        if dict_k not in d:
            d[dict_k] = dict_v

    return config
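A toy illustration of the default-filling loop shown above at the end of load_config. The dictionaries here are stand-ins; the real function reads a config.py module and merges it with the default_config_dict defined earlier in this file.

# Stand-in dicts; load_config() actually reads a config.py and the
# default_config_dict from this module.
defaults = {'use_mask': True, 'window_size': 32}
config = {'window_size': 64}

# Same pattern as the loop above: only keys missing from the config
# are filled in from the defaults.
for dict_k, dict_v in defaults.items():
    if dict_k not in config:
        config[dict_k] = dict_v

print(config)  # {'window_size': 64, 'use_mask': True}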
paddlespeech/t2s/frontend/zh_frontend.py
...
@@ -11,15 +11,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import re
from typing import Dict
from typing import List

import jieba.posseg as psg
import numpy as np
import paddle
import yaml
from g2pM import G2pM
from pypinyin import lazy_pinyin
from pypinyin import load_phrases_dict
...
@@ -58,19 +58,24 @@ def insert_after_character(lst, item):
class Polyphonic():
    def __init__(self):
        with open(
                os.path.join(
                    os.path.dirname(os.path.abspath(__file__)),
                    'polyphonic.yaml'),
                'r',
                encoding='utf-8') as polyphonic_file:
            # parse the yaml file
            polyphonic_dict = yaml.load(polyphonic_file, Loader=yaml.FullLoader)
        self.polyphonic_words = polyphonic_dict["polyphonic"]
    def correct_pronunciation(self, word, pinyin):
        # if the word is in the polyphonic dictionary, return the corrected pronunciation
        if word in self.polyphonic_words.keys():
            pinyin = self.polyphonic_words[word]
        # otherwise return the original pronunciation
        return pinyin
class Frontend():
    def __init__(self,
                 g2p_model="g2pW",
...
@@ -88,7 +93,8 @@ class Frontend():
        elif self.g2p_model == "g2pW":
            self.corrector = Polyphonic()
            self.g2pM_model = G2pM()
            self.g2pW_model = G2PWOnnxConverter(
                style='pinyin', enable_non_tradional_chinese=True)
            self.pinyin2phone = generate_lexicon(
                with_tone=True, with_erhua=False)
...
@@ -187,7 +193,7 @@ class Frontend():
                pinyins = self.g2pW_model(seg)[0]
            except Exception:
                # the g2pW model takes traditional-Chinese input; for simplified
                # words it cannot cover, fall back to g2pM
                print("[%s] not in g2pW dict,use g2pM" % seg)
                pinyins = self.g2pM_model(seg, tone=True, char_split=False)
            pre_word_length = 0
            for word, pos in seg_cut:
...
@@ -199,13 +205,15 @@ class Frontend():
                    continue
                word_pinyins = pinyins[pre_word_length:now_word_length]
                # correct the pronunciation
                word_pinyins = self.corrector.correct_pronunciation(
                    word, word_pinyins)
                for pinyin, char in zip(word_pinyins, word):
                    if pinyin is None:
                        pinyin = char
                    pinyin = pinyin.replace("u:", "v")
                    if pinyin in self.pinyin2phone:
                        initial_final_list = self.pinyin2phone[pinyin].split(
                            " ")
                        if len(initial_final_list) == 2:
                            sub_initials.append(initial_final_list[0])
                            sub_finals.append(initial_final_list[1])
...
@@ -218,7 +226,7 @@ class Frontend():
                        sub_finals.append(pinyin)
                pre_word_length = now_word_length
                sub_finals = self.tone_modifier.modified_tone(word, pos,
                                                              sub_finals)
                if with_erhua:
                    sub_initials, sub_finals = self._merge_erhua(
                        sub_initials, sub_finals, word, pos)
...
@@ -231,7 +239,7 @@ class Frontend():
                continue
            sub_initials, sub_finals = self._get_initials_finals(word)
            sub_finals = self.tone_modifier.modified_tone(word, pos,
                                                          sub_finals)
            if with_erhua:
                sub_initials, sub_finals = self._merge_erhua(
                    sub_initials, sub_finals, word, pos)
...
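As a toy illustration of the initial/final split in the last hunks: the lexicon entries below are invented for the example, whereas in the frontend the mapping comes from generate_lexicon(with_tone=True, with_erhua=False).

# Made-up pinyin2phone entries; the real lexicon is built by generate_lexicon().
pinyin2phone = {'xiang1': 'x iang1', 'hao3': 'h ao3'}

sub_initials, sub_finals = [], []
for pinyin in ['xiang1', 'hao3']:
    pinyin = pinyin.replace("u:", "v")  # normalize ü the same way the frontend does
    if pinyin in pinyin2phone:
        initial_final_list = pinyin2phone[pinyin].split(" ")
        if len(initial_final_list) == 2:
            sub_initials.append(initial_final_list[0])
            sub_finals.append(initial_final_list[1])

print(sub_initials, sub_finals)  # ['x', 'h'] ['iang1', 'ao3']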