Commit 3f6afc48 (PaddlePaddle/DeepSpeech)

[TTS]Add slim for TTS (#2729)

Authored on Dec 09, 2022 by 小湉湉; committed via GitHub on Dec 09, 2022.
Parent commit: 6f927d55

Showing 17 changed files with 513 additions and 8 deletions (+513 -8).
examples/csmsc/tts2/local/PTQ_static.sh          +1    -0
examples/csmsc/tts2/run.sh                       +5    -0
examples/csmsc/tts3/local/PTQ_dynamic.sh         +8    -0
examples/csmsc/tts3/local/PTQ_static.sh          +8    -0
examples/csmsc/tts3/run.sh                       +13   -0
examples/csmsc/tts3/run_cnndecoder.sh            +5    -0
examples/csmsc/voc1/local/PTQ_static.sh          +8    -0
examples/csmsc/voc1/run.sh                       +5    -0
examples/csmsc/voc3/local/PTQ_static.sh          +1    -0
examples/csmsc/voc3/run.sh                       +5    -0
examples/csmsc/voc5/local/PTQ_static.sh          +1    -0
examples/csmsc/voc5/run.sh                       +5    -0
paddlespeech/t2s/datasets/am_batch_fn.py         +67   -0
paddlespeech/t2s/datasets/vocoder_batch_fn.py    +47   -8
paddlespeech/t2s/exps/PTQ_dynamic.py             +80   -0
paddlespeech/t2s/exps/PTQ_static.py              +156  -0
paddlespeech/t2s/exps/syn_utils.py               +98   -0
examples/csmsc/tts2/local/PTQ_static.sh (new symlink, mode 120000)
../../tts3/local/PTQ_static.sh
examples/csmsc/tts2/run.sh
@@ -72,3 +72,8 @@
if [ ${stage} -le 8 ] && [ ${stop_stage} -ge 8 ]; then
    CUDA_VISIBLE_DEVICES=${gpus} ./local/lite_predict.sh ${train_output_path} || exit -1
fi

# PTQ_static
if [ ${stage} -le 9 ] && [ ${stop_stage} -ge 9 ]; then
    CUDA_VISIBLE_DEVICES=${gpus} ./local/PTQ_static.sh ${train_output_path} speedyspeech_csmsc || exit -1
fi
examples/csmsc/tts3/local/PTQ_dynamic.sh (new file, mode 100755)
train_output_path=$1
model_name=$2
weight_bits=$3

python3 ${BIN_DIR}/../PTQ_dynamic.py \
    --inference_dir ${train_output_path}/inference \
    --model_name ${model_name} \
    --weight_bits ${weight_bits}
examples/csmsc/tts3/local/PTQ_static.sh (new file, mode 100755)
train_output_path=$1
model_name=$2

python3 ${BIN_DIR}/../PTQ_static.py \
    --dev-metadata=dump/dev/norm/metadata.jsonl \
    --inference_dir ${train_output_path}/inference \
    --model_name ${model_name} \
    --onnx_format=True
examples/csmsc/tts3/run.sh
@@ -76,3 +76,16 @@
if [ ${stage} -le 8 ] && [ ${stop_stage} -ge 8 ]; then
    CUDA_VISIBLE_DEVICES=${gpus} ./local/lite_predict.sh ${train_output_path} || exit -1
fi

# PTQ_dynamic
if [ ${stage} -le 9 ] && [ ${stop_stage} -ge 9 ]; then
    ./local/PTQ_dynamic.sh ${train_output_path} fastspeech2_csmsc 8
    # ./local/PTQ_dynamic.sh ${train_output_path} pwgan_csmsc 8
    # ./local/PTQ_dynamic.sh ${train_output_path} mb_melgan_csmsc 8
    # ./local/PTQ_dynamic.sh ${train_output_path} hifigan_csmsc 8
fi

# PTQ_static
if [ ${stage} -le 10 ] && [ ${stop_stage} -ge 10 ]; then
    CUDA_VISIBLE_DEVICES=${gpus} ./local/PTQ_static.sh ${train_output_path} fastspeech2_csmsc || exit -1
fi
examples/csmsc/tts3/run_cnndecoder.sh
@@ -122,3 +122,8 @@
if [ ${stage} -le 14 ] && [ ${stop_stage} -ge 14 ]; then
    CUDA_VISIBLE_DEVICES=${gpus} ./local/lite_predict_streaming.sh ${train_output_path} || exit -1
fi

# PTQ_static
if [ ${stage} -le 15 ] && [ ${stop_stage} -ge 15 ]; then
    CUDA_VISIBLE_DEVICES=${gpus} ./local/PTQ_static.sh ${train_output_path} fastspeech2_csmsc || exit -1
fi
examples/csmsc/voc1/local/PTQ_static.sh (new file, mode 100755)
train_output_path=$1
model_name=$2

python3 ${BIN_DIR}/../../PTQ_static.py \
    --dev-metadata=dump/dev/norm/metadata.jsonl \
    --inference_dir ${train_output_path}/inference \
    --model_name ${model_name} \
    --onnx_format=True
examples/csmsc/voc1/run.sh
@@ -30,3 +30,8 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
    # synthesize
    CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize.sh ${conf_path} ${train_output_path} ${ckpt_name} || exit -1
fi

# PTQ_static
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
    CUDA_VISIBLE_DEVICES=${gpus} ./local/PTQ_static.sh ${train_output_path} pwgan_csmsc || exit -1
fi
examples/csmsc/voc3/local/PTQ_static.sh (new symlink, mode 120000)
../../voc1/local/PTQ_static.sh
examples/csmsc/voc3/run.sh
@@ -30,3 +30,8 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
    # synthesize
    CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize.sh ${conf_path} ${train_output_path} ${ckpt_name} || exit -1
fi

# PTQ_static
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
    CUDA_VISIBLE_DEVICES=${gpus} ./local/PTQ_static.sh ${train_output_path} mb_melgan_csmsc || exit -1
fi
examples/csmsc/voc5/local/PTQ_static.sh (new symlink, mode 120000)
../../voc1/local/PTQ_static.sh
examples/csmsc/voc5/run.sh
@@ -30,3 +30,8 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
    # synthesize
    CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize.sh ${conf_path} ${train_output_path} ${ckpt_name} || exit -1
fi

# PTQ_static
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
    CUDA_VISIBLE_DEVICES=${gpus} ./local/PTQ_static.sh ${train_output_path} hifigan_csmsc || exit -1
fi
paddlespeech/t2s/datasets/am_batch_fn.py
@@ -538,3 +538,70 @@ def vits_multi_spk_batch_fn(examples):
        spk_id = paddle.to_tensor(spk_id)
        batch["spk_id"] = spk_id
    return batch


# for PaddleSlim
def fastspeech2_single_spk_batch_fn_static(examples):
    text = [np.array(item["text"], dtype=np.int64) for item in examples]
    text = np.array(text)
    # do not need batch axis in infer
    text = text[0]
    batch = {
        "text": text,
    }
    return batch


def fastspeech2_multi_spk_batch_fn_static(examples):
    text = [np.array(item["text"], dtype=np.int64) for item in examples]
    text = np.array(text)
    text = text[0]
    batch = {
        "text": text,
    }
    if "spk_id" in examples[0]:
        spk_id = [
            np.array(item["spk_id"], dtype=np.int64) for item in examples
        ]
        spk_id = np.array(spk_id)
        spk_id = spk_id[0]
        batch["spk_id"] = spk_id
    if "spk_emb" in examples[0]:
        spk_emb = [
            np.array(item["spk_emb"], dtype=np.float32) for item in examples
        ]
        spk_emb = np.array(spk_emb)
        spk_emb = spk_emb[0]
        batch["spk_emb"] = spk_emb
    return batch


def speedyspeech_single_spk_batch_fn_static(examples):
    phones = [np.array(item["phones"], dtype=np.int64) for item in examples]
    tones = [np.array(item["tones"], dtype=np.int64) for item in examples]
    phones = np.array(phones)
    tones = np.array(tones)
    phones = phones[0]
    tones = tones[0]
    batch = {
        "phones": phones,
        "tones": tones,
    }
    return batch


def speedyspeech_multi_spk_batch_fn_static(examples):
    phones = [np.array(item["phones"], dtype=np.int64) for item in examples]
    tones = [np.array(item["tones"], dtype=np.int64) for item in examples]
    phones = np.array(phones)
    tones = np.array(tones)
    phones = phones[0]
    tones = tones[0]
    batch = {
        "phones": phones,
        "tones": tones,
    }
    if "spk_id" in examples[0]:
        spk_id = [
            np.array(item["spk_id"], dtype=np.int64) for item in examples
        ]
        spk_id = np.array(spk_id)
        spk_id = spk_id[0]
        batch["spk_id"] = spk_id
    return batch
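For orientation (not part of the diff): each *_static collate function returns a plain feed dict with the batch axis dropped, since the exported inference graphs consume one utterance at a time. A minimal sketch with made-up phone ids:

import numpy as np
from paddlespeech.t2s.datasets.am_batch_fn import (
    fastspeech2_single_spk_batch_fn_static)

# one toy utterance of four phone ids; real ids come from dump/dev/norm
examples = [{"text": [5, 1, 9, 3]}]
batch = fastspeech2_single_spk_batch_fn_static(examples)
print(batch["text"].dtype, batch["text"].shape)  # int64 (4,) -- no batch axis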
paddlespeech/t2s/datasets/vocoder_batch_fn.py
@@ -55,13 +55,12 @@ class Clip(object):
        Args:
            batch (list): list of tuple of the pair of audio and features. Audio shape (T, ), features shape (T', C).

        Returns:
            Tensor:
                Auxiliary feature batch (B, C, T'), where
                T = (T' - 2 * aux_context_window) * hop_size.
            Tensor:
                Target signal batch (B, 1, T).
        """
        # check length
        batch = [
@@ -106,11 +105,7 @@
        if len(x) < c.shape[0] * self.hop_size:
            x = np.pad(x, (0, c.shape[0] * self.hop_size - len(x)), mode="edge")
        elif len(x) > c.shape[0] * self.hop_size:
            x = x[:c.shape[0] * self.hop_size]

        # check the length is valid
        assert len(x) == c.shape[0] * self.hop_size, \
            f"wave length: ({len(x)}), mel length: ({c.shape[0]})"
@@ -218,3 +213,47 @@ class WaveRNNClip(Clip):
        y = label_2_float(paddle.cast(y, dtype='float32'), self.bits)
        return x, y, mels


# for paddleslim
class Clip_static(Clip):
    """Collate functor for training vocoders.
    """

    def __call__(self, batch):
        """Convert into batch tensors.

        Args:
            batch (list): list of tuple of the pair of audio and features. Audio shape (T, ), features shape (T', C).

        Returns:
            Dict[str, np.array]:
                Auxiliary feature batch (B, C, T'), where
                T = (T' - 2 * aux_context_window) * hop_size.
        """
        # check length
        batch = [
            self._adjust_length(b['wave'], b['feats']) for b in batch
            if b['feats'].shape[0] > self.mel_threshold
        ]
        xs, cs = [b[0] for b in batch], [b[1] for b in batch]

        # make batch with random cut
        c_lengths = [c.shape[0] for c in cs]
        start_frames = np.array([
            np.random.randint(self.start_offset, cl + self.end_offset)
            for cl in c_lengths
        ])
        c_starts = start_frames - self.aux_context_window
        c_ends = start_frames + self.batch_max_frames + self.aux_context_window
        c_batch = np.stack(
            [c[start:end] for c, start, end in zip(cs, c_starts, c_ends)])

        # the infer axis (T', C) differs from the train axis (B, C, T')
        # c_batch = c_batch.transpose([0, 2, 1])  # (B, C, T')
        # do not need batch axis in infer
        c_batch = c_batch[0]
        batch = {"logmel": c_batch}
        return batch
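A minimal sketch (not in the commit) of Clip_static on synthetic data, using the mb_melgan settings that get_dev_dataloader passes below (batch_max_steps=16200, n_shift=300, aux_context_window=0); the shapes and the inherited Clip offsets are assumptions for illustration:

import numpy as np
from paddlespeech.t2s.datasets.vocoder_batch_fn import Clip_static

fn = Clip_static(batch_max_steps=16200, hop_size=300, aux_context_window=0)

# 100 mel frames and the matching 100 * 300 audio samples
wave = np.random.randn(100 * 300).astype(np.float32)
feats = np.random.randn(100, 80).astype(np.float32)
batch = fn([{"wave": wave, "feats": feats}])
# a random cut of batch_max_steps // hop_size = 54 frames, no batch axis
print(batch["logmel"].shape)  # (54, 80)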
paddlespeech/t2s/exps/PTQ_dynamic.py (new file, mode 100644)
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse

import paddle
from paddleslim.quant import quant_post_dynamic


def parse_args():
    parser = argparse.ArgumentParser(
        description="Paddle Slim Dynamic with acoustic model & vocoder.")
    # acoustic model
    parser.add_argument(
        '--model_name',
        type=str,
        default='fastspeech2_csmsc',
        choices=[
            'speedyspeech_csmsc',
            'fastspeech2_csmsc',
            'fastspeech2_aishell3',
            'fastspeech2_ljspeech',
            'fastspeech2_vctk',
            'tacotron2_csmsc',
            'fastspeech2_mix',
            'pwgan_csmsc',
            'pwgan_aishell3',
            'pwgan_ljspeech',
            'pwgan_vctk',
            'mb_melgan_csmsc',
            'hifigan_csmsc',
            'hifigan_aishell3',
            'hifigan_ljspeech',
            'hifigan_vctk',
            'wavernn_csmsc',
        ],
        help='Choose model type of tts task.')
    parser.add_argument(
        "--inference_dir", type=str, help="dir to save inference models")
    parser.add_argument(
        "--weight_bits",
        type=int,
        default=8,
        choices=[8, 16],
        help="The bits for the quantized weight, and it should be 8 or 16. Default is 8.")
    args, _ = parser.parse_known_args()
    return args


# only inference for models trained with csmsc now
def main():
    args = parse_args()
    paddle.enable_static()
    quant_post_dynamic(
        model_dir=args.inference_dir,
        save_model_dir=args.inference_dir,
        model_filename=args.model_name + ".pdmodel",
        params_filename=args.model_name + ".pdiparams",
        save_model_filename=args.model_name + "_" + str(args.weight_bits) +
        "bits.pdmodel",
        save_params_filename=args.model_name + "_" + str(args.weight_bits) +
        "bits.pdiparams",
        weight_bits=args.weight_bits)


if __name__ == "__main__":
    main()
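The script is a thin CLI wrapper around weights-only post-training quantization. A sketch of the equivalent library-level call, with an assumed exp/default/inference export layout, showing where the quantized *_8bits.* files land:

import paddle
from paddleslim.quant import quant_post_dynamic

paddle.enable_static()
# weights-only PTQ: no calibration data needed, activations stay float
quant_post_dynamic(
    model_dir="exp/default/inference",             # assumed export dir
    save_model_dir="exp/default/inference",
    model_filename="fastspeech2_csmsc.pdmodel",
    params_filename="fastspeech2_csmsc.pdiparams",
    save_model_filename="fastspeech2_csmsc_8bits.pdmodel",
    save_params_filename="fastspeech2_csmsc_8bits.pdiparams",
    weight_bits=8)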
paddlespeech/t2s/exps/PTQ_static.py (new file, mode 100644)
import argparse
import random

import jsonlines
import numpy as np
import paddle
from paddleslim.quant import quant_post_static

from paddlespeech.t2s.exps.syn_utils import get_dev_dataloader
from paddlespeech.t2s.utils import str2bool


def parse_args():
    parser = argparse.ArgumentParser(
        description="Paddle Slim Static with acoustic model & vocoder.")
    parser.add_argument(
        "--batch_size", type=int, default=1, help="Minibatch size.")
    parser.add_argument(
        "--batch_num", type=int, default=1, help="Batch number")
    parser.add_argument(
        "--ngpu", type=int, default=1, help="if ngpu=0, use cpu.")
    # model_path save_path
    parser.add_argument(
        "--inference_dir", type=str, help="dir to save inference models")
    parser.add_argument(
        '--model_name',
        type=str,
        default='fastspeech2_csmsc',
        choices=[
            'speedyspeech_csmsc',
            'fastspeech2_csmsc',
            'fastspeech2_aishell3',
            'fastspeech2_ljspeech',
            'fastspeech2_vctk',
            'fastspeech2_mix',
            'pwgan_csmsc',
            'pwgan_aishell3',
            'pwgan_ljspeech',
            'pwgan_vctk',
            'mb_melgan_csmsc',
            'hifigan_csmsc',
            'hifigan_aishell3',
            'hifigan_ljspeech',
            'hifigan_vctk',
        ],
        help='Choose model type of tts task.')
    parser.add_argument(
        "--algo", type=str, default='avg', help="calibration algorithm.")
    parser.add_argument(
        "--round_type",
        type=str,
        default='round',
        help="The method of converting the quantized weights.")
    parser.add_argument(
        "--hist_percent",
        type=float,
        default=0.9999,
        help="The percentile of algo:hist.")
    parser.add_argument(
        "--is_full_quantize",
        type=str2bool,
        default=False,
        help="Whether it is full quantization or not.")
    parser.add_argument(
        "--bias_correction",
        type=str2bool,
        default=False,
        help="Whether to use bias correction.")
    parser.add_argument(
        "--ce_test", type=str2bool, default=False, help="Whether to run CE test.")
    parser.add_argument(
        "--onnx_format",
        type=str2bool,
        default=False,
        help="Whether to export the quantized model in ONNX format.")
    parser.add_argument(
        "--phones-dict", type=str, default=None, help="phone vocabulary file.")
    parser.add_argument(
        "--speaker-dict",
        type=str,
        default=None,
        help="speaker id map file for multiple speaker model.")
    parser.add_argument("--dev-metadata", type=str, help="dev data.")
    # NOTE: type=list makes argparse split each value into characters;
    # main() undoes this with ''.join(item) below.
    parser.add_argument(
        "--quantizable_op_type",
        type=list,
        nargs='+',
        default=[
            "conv2d_transpose", "conv2d", "depthwise_conv2d", "mul", "matmul",
            "matmul_v2"
        ],
        help="The list of op types that will be quantized.")
    args = parser.parse_args()
    return args


def quantize(args):
    shuffle = True
    if args.ce_test:
        # set seed
        seed = 111
        np.random.seed(seed)
        paddle.seed(seed)
        random.seed(seed)
        shuffle = False

    place = paddle.CUDAPlace(0) if args.ngpu > 0 else paddle.CPUPlace()

    with jsonlines.open(args.dev_metadata, 'r') as reader:
        dev_metadata = list(reader)

    dataloader = get_dev_dataloader(
        dev_metadata=dev_metadata,
        am=args.model_name,
        batch_size=args.batch_size,
        speaker_dict=args.speaker_dict,
        shuffle=shuffle)

    exe = paddle.static.Executor(place)
    exe.run()
    print("onnx_format:", args.onnx_format)

    quant_post_static(
        executor=exe,
        model_dir=args.inference_dir,
        quantize_model_path=args.inference_dir + "/" + args.model_name +
        "_quant",
        data_loader=dataloader,
        model_filename=args.model_name + ".pdmodel",
        params_filename=args.model_name + ".pdiparams",
        save_model_filename=args.model_name + ".pdmodel",
        save_params_filename=args.model_name + ".pdiparams",
        batch_size=args.batch_size,
        algo=args.algo,
        round_type=args.round_type,
        hist_percent=args.hist_percent,
        is_full_quantize=args.is_full_quantize,
        bias_correction=args.bias_correction,
        onnx_format=args.onnx_format,
        quantizable_op_type=args.quantizable_op_type)


def main():
    args = parse_args()
    # --quantizable_op_type arrives as lists of characters (type=list),
    # so join each one back into an op-type string.
    new_quantizable_op_type = []
    for item in args.quantizable_op_type:
        new_quantizable_op_type.append(''.join(item))
    args.quantizable_op_type = new_quantizable_op_type
    paddle.enable_static()
    quantize(args)


if __name__ == '__main__':
    main()
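One detail worth spelling out: quant_post_static calibrates by running the batches yielded by data_loader through the inference program, so each batch must map the graph's feed names to arrays. A hedged sketch of that contract with a hypothetical stand-in generator (the "text" key mirrors the fastspeech2 static collate functions; this is not part of the commit):

import numpy as np

def fake_dev_loader(num_batches=4):
    # each yielded dict feeds one calibration pass; "text" stands in for
    # the fastspeech2 inference input produced by the *_static collate fns
    for _ in range(num_batches):
        yield {"text": np.random.randint(1, 100, size=(30,), dtype=np.int64)}

for feed in fake_dev_loader(num_batches=1):
    print({k: (v.dtype, v.shape) for k, v in feed.items()})
    # {'text': (dtype('int64'), (30,))}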
paddlespeech/t2s/exps/syn_utils.py
@@ -25,10 +25,13 @@ import onnxruntime as ort
import paddle
from paddle import inference
from paddle import jit
from paddle.io import DataLoader
from paddle.static import InputSpec
from yacs.config import CfgNode

from paddlespeech.t2s.datasets.am_batch_fn import *
from paddlespeech.t2s.datasets.data_table import DataTable
from paddlespeech.t2s.datasets.vocoder_batch_fn import Clip_static
from paddlespeech.t2s.frontend import English
from paddlespeech.t2s.frontend.mix_frontend import MixFrontend
from paddlespeech.t2s.frontend.zh_frontend import Frontend
@@ -118,6 +121,7 @@ def get_sentences(text_file: Optional[os.PathLike], lang: str='zh'):
    return sentences


# am only
def get_test_dataset(test_metadata: List[Dict[str, Any]],
                     am: str,
                     speaker_dict: Optional[os.PathLike]=None,
                     ...
@@ -158,6 +162,100 @@ def get_test_dataset(test_metadata: List[Dict[str, Any]],
    return test_dataset


# am and voc, for PTQ_static
def get_dev_dataloader(dev_metadata: List[Dict[str, Any]],
                       am: str,
                       batch_size: int=1,
                       speaker_dict: Optional[os.PathLike]=None,
                       voice_cloning: bool=False,
                       n_shift: int=300,
                       batch_max_steps: int=16200,
                       shuffle: bool=True):
    # model: {model_name}_{dataset}
    am_name = am[:am.rindex('_')]
    am_dataset = am[am.rindex('_') + 1:]
    converters = {}
    if am_name == 'fastspeech2':
        fields = ["utt_id", "text"]
        if am_dataset in {"aishell3", "vctk", "mix"} and speaker_dict is not None:
            print("multiple speaker fastspeech2!")
            collate_fn = fastspeech2_multi_spk_batch_fn_static
            fields += ["spk_id"]
        elif voice_cloning:
            print("voice cloning!")
            collate_fn = fastspeech2_multi_spk_batch_fn_static
            fields += ["spk_emb"]
        else:
            print("single speaker fastspeech2!")
            collate_fn = fastspeech2_single_spk_batch_fn_static
    elif am_name == 'speedyspeech':
        fields = ["utt_id", "phones", "tones"]
        if am_dataset in {"aishell3", "vctk", "mix"} and speaker_dict is not None:
            print("multiple speaker speedyspeech!")
            collate_fn = speedyspeech_multi_spk_batch_fn_static
            fields += ["spk_id"]
        else:
            print("single speaker speedyspeech!")
            collate_fn = speedyspeech_single_spk_batch_fn_static
    elif am_name == 'tacotron2':
        fields = ["utt_id", "text"]
        if voice_cloning:
            print("voice cloning!")
            collate_fn = tacotron2_multi_spk_batch_fn_static
            fields += ["spk_emb"]
        else:
            print("single speaker tacotron2!")
            collate_fn = tacotron2_single_spk_batch_fn_static
    else:
        print("voc dataloader")

    # am
    if am_name not in {'pwgan', 'mb_melgan', 'hifigan'}:
        dev_dataset = DataTable(
            data=dev_metadata,
            fields=fields,
            converters=converters, )
        dev_dataloader = DataLoader(
            dev_dataset,
            shuffle=shuffle,
            drop_last=False,
            batch_size=batch_size,
            collate_fn=collate_fn)
    # vocoder
    else:
        # pwgan: batch_max_steps: 25500, aux_context_window: 2
        # mb_melgan: batch_max_steps: 16200, aux_context_window: 0
        # hifigan: batch_max_steps: 8400, aux_context_window: 0
        aux_context_window = 0
        if am_name == 'pwgan':
            aux_context_window = 2
        train_batch_fn = Clip_static(
            batch_max_steps=batch_max_steps,
            hop_size=n_shift,
            aux_context_window=aux_context_window)
        dev_dataset = DataTable(
            data=dev_metadata,
            fields=["wave", "feats"],
            converters={
                "wave": np.load,
                "feats": np.load,
            }, )
        dev_dataloader = DataLoader(
            dev_dataset,
            shuffle=shuffle,
            drop_last=False,
            batch_size=batch_size,
            collate_fn=train_batch_fn)
    return dev_dataloader


# frontend
def get_frontend(lang: str='zh',
                 phones_dict: Optional[os.PathLike]=None,
                 ...
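A usage sketch (not part of the diff), assuming a dumped CSMSC dev set at dump/dev/norm/metadata.jsonl as in local/PTQ_static.sh; for vocoder model names the loader takes the Clip_static branch and yields "logmel" feed dicts:

import jsonlines
from paddlespeech.t2s.exps.syn_utils import get_dev_dataloader

with jsonlines.open("dump/dev/norm/metadata.jsonl", 'r') as reader:
    dev_metadata = list(reader)

loader = get_dev_dataloader(
    dev_metadata=dev_metadata,
    am="mb_melgan_csmsc",    # vocoder branch -> Clip_static collate
    batch_size=1,
    n_shift=300,
    batch_max_steps=16200)
for batch in loader:
    print(batch.keys())      # feed dict keyed "logmel" for vocoders
    break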
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录