Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
DeepSpeech
提交
0fc79f47
D
DeepSpeech
项目概览
PaddlePaddle
/
DeepSpeech
大约 2 年 前同步成功
通知
210
Star
8425
Fork
1598
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
245
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DeepSpeech
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
245
Issue
245
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
0fc79f47
编写于
3月 30, 2022
作者:
小湉湉
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add CNNDecoder, test=tts
上级
b5315657
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
417 addition
and
10 deletion
+417
-10
examples/csmsc/tts3/conf/cnndecoder.yaml
examples/csmsc/tts3/conf/cnndecoder.yaml
+107
-0
examples/csmsc/tts3/local/synthesize_streaming.sh
examples/csmsc/tts3/local/synthesize_streaming.sh
+92
-0
examples/csmsc/tts3/run_cnndecoder.sh
examples/csmsc/tts3/run_cnndecoder.sh
+48
-0
paddlespeech/t2s/models/fastspeech2/fastspeech2.py
paddlespeech/t2s/models/fastspeech2/fastspeech2.py
+37
-10
paddlespeech/t2s/modules/transformer/encoder.py
paddlespeech/t2s/modules/transformer/encoder.py
+133
-0
未找到文件。
examples/csmsc/tts3/conf/cnndecoder.yaml
0 → 100644
浏览文件 @
0fc79f47
# use CNND
###########################################################
# FEATURE EXTRACTION SETTING #
###########################################################
fs
:
24000
# sr
n_fft
:
2048
# FFT size (samples).
n_shift
:
300
# Hop size (samples). 12.5ms
win_length
:
1200
# Window length (samples). 50ms
# If set to null, it will be the same as fft_size.
window
:
"
hann"
# Window function.
# Only used for feats_type != raw
fmin
:
80
# Minimum frequency of Mel basis.
fmax
:
7600
# Maximum frequency of Mel basis.
n_mels
:
80
# The number of mel basis.
# Only used for the model using pitch features (e.g. FastSpeech2)
f0min
:
80
# Minimum f0 for pitch extraction.
f0max
:
400
# Maximum f0 for pitch extraction.
###########################################################
# DATA SETTING #
###########################################################
batch_size
:
64
num_workers
:
4
###########################################################
# MODEL SETTING #
###########################################################
model
:
adim
:
384
# attention dimension
aheads
:
2
# number of attention heads
elayers
:
4
# number of encoder layers
eunits
:
1536
# number of encoder ff units
dlayers
:
4
# number of decoder layers
dunits
:
1536
# number of decoder ff units
positionwise_layer_type
:
conv1d
# type of position-wise layer
positionwise_conv_kernel_size
:
3
# kernel size of position wise conv layer
duration_predictor_layers
:
2
# number of layers of duration predictor
duration_predictor_chans
:
256
# number of channels of duration predictor
duration_predictor_kernel_size
:
3
# filter size of duration predictor
postnet_layers
:
5
# number of layers of postnset
postnet_filts
:
5
# filter size of conv layers in postnet
postnet_chans
:
256
# number of channels of conv layers in postnet
use_scaled_pos_enc
:
True
# whether to use scaled positional encoding
encoder_normalize_before
:
True
# whether to perform layer normalization before the input
decoder_normalize_before
:
True
# whether to perform layer normalization before the input
reduction_factor
:
1
# reduction factor
encoder_type
:
transformer
# encoder type
decoder_type
:
cnndecoder
# decoder type
init_type
:
xavier_uniform
# initialization type
init_enc_alpha
:
1.0
# initial value of alpha of encoder scaled position encoding
init_dec_alpha
:
1.0
# initial value of alpha of decoder scaled position encoding
transformer_enc_dropout_rate
:
0.2
# dropout rate for transformer encoder layer
transformer_enc_positional_dropout_rate
:
0.2
# dropout rate for transformer encoder positional encoding
transformer_enc_attn_dropout_rate
:
0.2
# dropout rate for transformer encoder attention layer
cnn_dec_dropout_rate
:
0.2
# dropout rate for cnn decoder layer
cnn_postnet_dropout_rate
:
0.2
cnn_postnet_resblock_kernel_sizes
:
[
256
,
256
]
# kernel sizes for residual block of cnn_postnet
cnn_postnet_kernel_size
:
5
# kernel size of cnn_postnet
cnn_decoder_embedding_dim
:
256
pitch_predictor_layers
:
5
# number of conv layers in pitch predictor
pitch_predictor_chans
:
256
# number of channels of conv layers in pitch predictor
pitch_predictor_kernel_size
:
5
# kernel size of conv leyers in pitch predictor
pitch_predictor_dropout
:
0.5
# dropout rate in pitch predictor
pitch_embed_kernel_size
:
1
# kernel size of conv embedding layer for pitch
pitch_embed_dropout
:
0.0
# dropout rate after conv embedding layer for pitch
stop_gradient_from_pitch_predictor
:
True
# whether to stop the gradient from pitch predictor to encoder
energy_predictor_layers
:
2
# number of conv layers in energy predictor
energy_predictor_chans
:
256
# number of channels of conv layers in energy predictor
energy_predictor_kernel_size
:
3
# kernel size of conv leyers in energy predictor
energy_predictor_dropout
:
0.5
# dropout rate in energy predictor
energy_embed_kernel_size
:
1
# kernel size of conv embedding layer for energy
energy_embed_dropout
:
0.0
# dropout rate after conv embedding layer for energy
stop_gradient_from_energy_predictor
:
False
# whether to stop the gradient from energy predictor to encoder
###########################################################
# UPDATER SETTING #
###########################################################
updater
:
use_masking
:
True
# whether to apply masking for padded part in loss calculation
###########################################################
# OPTIMIZER SETTING #
###########################################################
optimizer
:
optim
:
adam
# optimizer type
learning_rate
:
0.001
# learning rate
###########################################################
# TRAINING SETTING #
###########################################################
max_epoch
:
1000
num_snapshots
:
5
###########################################################
# OTHER SETTING #
###########################################################
seed
:
10086
examples/csmsc/tts3/local/synthesize_streaming.sh
0 → 100755
浏览文件 @
0fc79f47
#!/bin/bash
config_path
=
$1
train_output_path
=
$2
ckpt_name
=
$3
stage
=
0
stop_stage
=
0
# pwgan
if
[
${
stage
}
-le
0
]
&&
[
${
stop_stage
}
-ge
0
]
;
then
FLAGS_allocator_strategy
=
naive_best_fit
\
FLAGS_fraction_of_gpu_memory_to_use
=
0.01
\
python3
${
BIN_DIR
}
/../synthesize_streaming.py
\
--am
=
fastspeech2_csmsc
\
--am_config
=
${
config_path
}
\
--am_ckpt
=
${
train_output_path
}
/checkpoints/
${
ckpt_name
}
\
--am_stat
=
dump/train/speech_stats.npy
\
--voc
=
pwgan_csmsc
\
--voc_config
=
pwg_baker_ckpt_0.4/pwg_default.yaml
\
--voc_ckpt
=
pwg_baker_ckpt_0.4/pwg_snapshot_iter_400000.pdz
\
--voc_stat
=
pwg_baker_ckpt_0.4/pwg_stats.npy
\
--lang
=
zh
\
--text
=
${
BIN_DIR
}
/../sentences.txt
\
--output_dir
=
${
train_output_path
}
/test_e2e
\
--phones_dict
=
dump/phone_id_map.txt
\
--inference_dir
=
${
train_output_path
}
/inference
fi
# for more GAN Vocoders
# multi band melgan
if
[
${
stage
}
-le
1
]
&&
[
${
stop_stage
}
-ge
1
]
;
then
FLAGS_allocator_strategy
=
naive_best_fit
\
FLAGS_fraction_of_gpu_memory_to_use
=
0.01
\
python3
${
BIN_DIR
}
/../synthesize_streaming.py
\
--am
=
fastspeech2_csmsc
\
--am_config
=
${
config_path
}
\
--am_ckpt
=
${
train_output_path
}
/checkpoints/
${
ckpt_name
}
\
--am_stat
=
dump/train/speech_stats.npy
\
--voc
=
mb_melgan_csmsc
\
--voc_config
=
mb_melgan_csmsc_ckpt_0.1.1/default.yaml
\
--voc_ckpt
=
mb_melgan_csmsc_ckpt_0.1.1/snapshot_iter_1000000.pdz
\
--voc_stat
=
mb_melgan_csmsc_ckpt_0.1.1/feats_stats.npy
\
--lang
=
zh
\
--text
=
${
BIN_DIR
}
/../sentences.txt
\
--output_dir
=
${
train_output_path
}
/test_e2e
\
--phones_dict
=
dump/phone_id_map.txt
\
--inference_dir
=
${
train_output_path
}
/inference
fi
# the pretrained models haven't release now
# style melgan
# style melgan's Dygraph to Static Graph is not ready now
if
[
${
stage
}
-le
2
]
&&
[
${
stop_stage
}
-ge
2
]
;
then
FLAGS_allocator_strategy
=
naive_best_fit
\
FLAGS_fraction_of_gpu_memory_to_use
=
0.01
\
python3
${
BIN_DIR
}
/../synthesize_streaming.py
\
--am
=
fastspeech2_csmsc
\
--am_config
=
${
config_path
}
\
--am_ckpt
=
${
train_output_path
}
/checkpoints/
${
ckpt_name
}
\
--am_stat
=
dump/train/speech_stats.npy
\
--voc
=
style_melgan_csmsc
\
--voc_config
=
style_melgan_csmsc_ckpt_0.1.1/default.yaml
\
--voc_ckpt
=
style_melgan_csmsc_ckpt_0.1.1/snapshot_iter_1500000.pdz
\
--voc_stat
=
style_melgan_csmsc_ckpt_0.1.1/feats_stats.npy
\
--lang
=
zh
\
--text
=
${
BIN_DIR
}
/../sentences.txt
\
--output_dir
=
${
train_output_path
}
/test_e2e
\
--phones_dict
=
dump/phone_id_map.txt
# --inference_dir=${train_output_path}/inference
fi
# hifigan
if
[
${
stage
}
-le
3
]
&&
[
${
stop_stage
}
-ge
3
]
;
then
echo
"in hifigan syn_e2e"
FLAGS_allocator_strategy
=
naive_best_fit
\
FLAGS_fraction_of_gpu_memory_to_use
=
0.01
\
python3
${
BIN_DIR
}
/../synthesize_streaming.py
\
--am
=
fastspeech2_csmsc
\
--am_config
=
${
config_path
}
\
--am_ckpt
=
${
train_output_path
}
/checkpoints/
${
ckpt_name
}
\
--am_stat
=
dump/train/speech_stats.npy
\
--voc
=
hifigan_csmsc
\
--voc_config
=
hifigan_csmsc_ckpt_0.1.1/default.yaml
\
--voc_ckpt
=
hifigan_csmsc_ckpt_0.1.1/snapshot_iter_2500000.pdz
\
--voc_stat
=
hifigan_csmsc_ckpt_0.1.1/feats_stats.npy
\
--lang
=
zh
\
--text
=
${
BIN_DIR
}
/../sentences.txt
\
--output_dir
=
${
train_output_path
}
/test_e2e
\
--phones_dict
=
dump/phone_id_map.txt
\
--inference_dir
=
${
train_output_path
}
/inference
fi
examples/csmsc/tts3/run_cnndecoder.sh
0 → 100755
浏览文件 @
0fc79f47
#!/bin/bash
set
-e
source
path.sh
gpus
=
0,1
stage
=
0
stop_stage
=
100
conf_path
=
conf/cnndecoder.yaml
train_output_path
=
exp/cnndecoder
ckpt_name
=
snapshot_iter_153.pdz
# with the following command, you can choose the stage range you want to run
# such as `./run.sh --stage 0 --stop-stage 0`
# this can not be mixed use with `$1`, `$2` ...
source
${
MAIN_ROOT
}
/utils/parse_options.sh
||
exit
1
if
[
${
stage
}
-le
0
]
&&
[
${
stop_stage
}
-ge
0
]
;
then
# prepare data
./local/preprocess.sh
${
conf_path
}
||
exit
-1
fi
if
[
${
stage
}
-le
1
]
&&
[
${
stop_stage
}
-ge
1
]
;
then
# train model, all `ckpt` under `train_output_path/checkpoints/` dir
CUDA_VISIBLE_DEVICES
=
${
gpus
}
./local/train.sh
${
conf_path
}
${
train_output_path
}
||
exit
-1
fi
if
[
${
stage
}
-le
2
]
&&
[
${
stop_stage
}
-ge
2
]
;
then
# synthesize, vocoder is pwgan
CUDA_VISIBLE_DEVICES
=
${
gpus
}
./local/synthesize.sh
${
conf_path
}
${
train_output_path
}
${
ckpt_name
}
||
exit
-1
fi
if
[
${
stage
}
-le
3
]
&&
[
${
stop_stage
}
-ge
3
]
;
then
# synthesize_e2e, vocoder is pwgan
CUDA_VISIBLE_DEVICES
=
${
gpus
}
./local/synthesize_e2e.sh
${
conf_path
}
${
train_output_path
}
${
ckpt_name
}
||
exit
-1
fi
if
[
${
stage
}
-le
4
]
&&
[
${
stop_stage
}
-ge
4
]
;
then
# inference with static model
CUDA_VISIBLE_DEVICES
=
${
gpus
}
./local/inference.sh
${
train_output_path
}
||
exit
-1
fi
if
[
${
stage
}
-le
5
]
&&
[
${
stop_stage
}
-ge
5
]
;
then
# synthesize_e2e, vocoder is pwgan
CUDA_VISIBLE_DEVICES
=
${
gpus
}
./local/synthesize_streaming.sh
${
conf_path
}
${
train_output_path
}
${
ckpt_name
}
||
exit
-1
fi
paddlespeech/t2s/models/fastspeech2/fastspeech2.py
浏览文件 @
0fc79f47
...
...
@@ -14,6 +14,7 @@
# Modified from espnet(https://github.com/espnet/espnet)
"""Fastspeech2 related modules for paddle"""
from
typing
import
Dict
from
typing
import
List
from
typing
import
Sequence
from
typing
import
Tuple
from
typing
import
Union
...
...
@@ -32,6 +33,8 @@ from paddlespeech.t2s.modules.predictor.duration_predictor import DurationPredic
from
paddlespeech.t2s.modules.predictor.length_regulator
import
LengthRegulator
from
paddlespeech.t2s.modules.predictor.variance_predictor
import
VariancePredictor
from
paddlespeech.t2s.modules.tacotron2.decoder
import
Postnet
from
paddlespeech.t2s.modules.transformer.encoder
import
CNNDecoder
from
paddlespeech.t2s.modules.transformer.encoder
import
CNNPostnet
from
paddlespeech.t2s.modules.transformer.encoder
import
ConformerEncoder
from
paddlespeech.t2s.modules.transformer.encoder
import
TransformerEncoder
...
...
@@ -97,6 +100,12 @@ class FastSpeech2(nn.Layer):
zero_triu
:
bool
=
False
,
conformer_enc_kernel_size
:
int
=
7
,
conformer_dec_kernel_size
:
int
=
31
,
# for CNN Decoder
cnn_dec_dropout_rate
:
float
=
0.2
,
cnn_postnet_dropout_rate
:
float
=
0.2
,
cnn_postnet_resblock_kernel_sizes
:
List
[
int
]
=
[
256
,
256
],
cnn_postnet_kernel_size
:
int
=
5
,
cnn_decoder_embedding_dim
:
int
=
256
,
# duration predictor
duration_predictor_layers
:
int
=
2
,
duration_predictor_chans
:
int
=
384
,
...
...
@@ -392,6 +401,13 @@ class FastSpeech2(nn.Layer):
activation_type
=
conformer_activation_type
,
use_cnn_module
=
use_cnn_in_conformer
,
cnn_module_kernel
=
conformer_dec_kernel_size
,
)
elif
decoder_type
==
'cnndecoder'
:
self
.
decoder
=
CNNDecoder
(
emb_dim
=
adim
,
odim
=
odim
,
kernel_size
=
cnn_postnet_kernel_size
,
dropout_rate
=
cnn_dec_dropout_rate
,
resblock_kernel_sizes
=
cnn_postnet_resblock_kernel_sizes
)
else
:
raise
ValueError
(
f
"
{
decoder_type
}
is not supported."
)
...
...
@@ -399,14 +415,21 @@ class FastSpeech2(nn.Layer):
self
.
feat_out
=
nn
.
Linear
(
adim
,
odim
*
reduction_factor
)
# define postnet
self
.
postnet
=
(
None
if
postnet_layers
==
0
else
Postnet
(
idim
=
idim
,
odim
=
odim
,
n_layers
=
postnet_layers
,
n_chans
=
postnet_chans
,
n_filts
=
postnet_filts
,
use_batch_norm
=
use_batch_norm
,
dropout_rate
=
postnet_dropout_rate
,
))
if
decoder_type
==
'cnndecoder'
:
self
.
postnet
=
CNNPostnet
(
odim
=
odim
,
kernel_size
=
cnn_postnet_kernel_size
,
dropout_rate
=
cnn_postnet_dropout_rate
,
resblock_kernel_sizes
=
cnn_postnet_resblock_kernel_sizes
)
else
:
self
.
postnet
=
(
None
if
postnet_layers
==
0
else
Postnet
(
idim
=
idim
,
odim
=
odim
,
n_layers
=
postnet_layers
,
n_chans
=
postnet_chans
,
n_filts
=
postnet_filts
,
use_batch_norm
=
use_batch_norm
,
dropout_rate
=
postnet_dropout_rate
,
))
nn
.
initializer
.
set_global_initializer
(
None
)
...
...
@@ -562,6 +585,7 @@ class FastSpeech2(nn.Layer):
[
olen
//
self
.
reduction_factor
for
olen
in
olens
.
numpy
()])
else
:
olens_in
=
olens
# (B, 1, T)
h_masks
=
self
.
_source_mask
(
olens_in
)
else
:
h_masks
=
None
...
...
@@ -569,8 +593,11 @@ class FastSpeech2(nn.Layer):
zs
,
_
=
self
.
decoder
(
hs
,
h_masks
)
# (B, Lmax, odim)
before_outs
=
self
.
feat_out
(
zs
).
reshape
(
(
paddle
.
shape
(
zs
)[
0
],
-
1
,
self
.
odim
))
if
self
.
decoder_type
==
'cnndecoder'
:
before_outs
=
zs
else
:
before_outs
=
self
.
feat_out
(
zs
).
reshape
(
(
paddle
.
shape
(
zs
)[
0
],
-
1
,
self
.
odim
))
# postnet -> (B, Lmax//r * r, odim)
if
self
.
postnet
is
None
:
...
...
paddlespeech/t2s/modules/transformer/encoder.py
浏览文件 @
0fc79f47
...
...
@@ -515,3 +515,136 @@ class ConformerEncoder(BaseEncoder):
if
self
.
intermediate_layers
is
not
None
:
return
xs
,
masks
,
intermediate_outputs
return
xs
,
masks
class
Conv1dResidualBlock
(
nn
.
Layer
):
"""
Special module for simplified version of Encoder class.
"""
def
__init__
(
self
,
idim
:
int
=
256
,
odim
:
int
=
256
,
kernel_size
:
int
=
5
,
dropout_rate
:
float
=
0.2
):
super
().
__init__
()
self
.
main_block
=
nn
.
Sequential
(
nn
.
Conv1D
(
idim
,
odim
,
kernel_size
=
kernel_size
,
padding
=
kernel_size
//
2
),
nn
.
ReLU
(),
nn
.
BatchNorm1D
(
odim
),
nn
.
Dropout
(
p
=
dropout_rate
))
self
.
conv1d_residual
=
nn
.
Conv1D
(
idim
,
odim
,
kernel_size
=
1
)
def
forward
(
self
,
xs
):
"""Encode input sequence.
Args:
xs (Tensor): Input tensor (#batch, idim, T).
Returns:
Tensor: Output tensor (#batch, odim, T).
"""
outputs
=
self
.
main_block
(
xs
)
outputs
=
self
.
conv1d_residual
(
xs
)
+
outputs
return
outputs
class
CNNDecoder
(
nn
.
Layer
):
"""
Much simplified decoder than the original one with Prenet.
"""
def
__init__
(
self
,
emb_dim
:
int
=
256
,
odim
:
int
=
80
,
kernel_size
:
int
=
5
,
dropout_rate
:
float
=
0.2
,
resblock_kernel_sizes
:
List
[
int
]
=
[
256
,
256
],
):
super
().
__init__
()
input_shape
=
emb_dim
out_sizes
=
resblock_kernel_sizes
out_sizes
.
append
(
out_sizes
[
-
1
])
in_sizes
=
[
input_shape
]
+
out_sizes
[:
-
1
]
self
.
residual_blocks
=
nn
.
LayerList
([
Conv1dResidualBlock
(
idim
=
in_channels
,
odim
=
out_channels
,
kernel_size
=
kernel_size
,
dropout_rate
=
dropout_rate
,
)
for
in_channels
,
out_channels
in
zip
(
in_sizes
,
out_sizes
)
])
self
.
conv1d
=
nn
.
Conv1D
(
in_channels
=
out_sizes
[
-
1
],
out_channels
=
odim
,
kernel_size
=
1
)
def
forward
(
self
,
xs
,
masks
=
None
):
"""Encode input sequence.
Args:
xs (Tensor): Input tensor (#batch, time, idim).
masks (Tensor): Mask tensor (#batch, 1, time).
Returns:
Tensor: Output tensor (#batch, time, odim).
"""
# print("input.shape in CNNDecoder:",xs.shape)
# exchange the temporal dimension and the feature dimension
xs
=
xs
.
transpose
([
0
,
2
,
1
])
if
masks
is
not
None
:
xs
=
xs
*
masks
for
layer
in
self
.
residual_blocks
:
outputs
=
layer
(
xs
)
if
masks
is
not
None
:
# input_mask B * 1 * T
outputs
=
outputs
*
masks
xs
=
outputs
outputs
=
self
.
conv1d
(
outputs
)
if
masks
is
not
None
:
outputs
=
outputs
*
masks
outputs
=
outputs
.
transpose
([
0
,
2
,
1
])
# print("outputs.shape in CNNDecoder:",outputs.shape)
return
outputs
,
masks
class
CNNPostnet
(
nn
.
Layer
):
def
__init__
(
self
,
odim
:
int
=
80
,
kernel_size
:
int
=
5
,
dropout_rate
:
float
=
0.2
,
resblock_kernel_sizes
:
List
[
int
]
=
[
256
,
256
],
):
super
().
__init__
()
out_sizes
=
resblock_kernel_sizes
in_sizes
=
[
odim
]
+
out_sizes
[:
-
1
]
self
.
residual_blocks
=
nn
.
LayerList
([
Conv1dResidualBlock
(
idim
=
in_channels
,
odim
=
out_channels
,
kernel_size
=
kernel_size
,
dropout_rate
=
dropout_rate
)
for
in_channels
,
out_channels
in
zip
(
in_sizes
,
out_sizes
)
])
self
.
conv1d
=
nn
.
Conv1D
(
in_channels
=
out_sizes
[
-
1
],
out_channels
=
odim
,
kernel_size
=
1
)
def
forward
(
self
,
xs
,
masks
=
None
):
"""Encode input sequence.
Args:
xs (Tensor): Input tensor (#batch, odim, time).
masks (Tensor): Mask tensor (#batch, 1, time).
Returns:
Tensor: Output tensor (#batch, odim, time).
"""
# print("xs.shape in CNNPostnet:",xs.shape)
for
layer
in
self
.
residual_blocks
:
outputs
=
layer
(
xs
)
if
masks
is
not
None
:
# input_mask B * 1 * T
outputs
=
outputs
*
masks
xs
=
outputs
outputs
=
self
.
conv1d
(
outputs
)
if
masks
is
not
None
:
outputs
=
outputs
*
masks
# print("outputs.shape in CNNPostnet:",outputs.shape)
return
outputs
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录