Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
DeepSpeech
提交
03e9ea9e
D
DeepSpeech
项目概览
PaddlePaddle
/
DeepSpeech
大约 2 年 前同步成功
通知
210
Star
8425
Fork
1598
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
245
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DeepSpeech
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
245
Issue
245
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
“c299a141ca87fb4022120ae2153fd27b425afce0”上不存在“micro/git@gitcode.net:qq_37101384/mace.git”
提交
03e9ea9e
编写于
7月 12, 2023
作者:
H
Hui Zhang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add roformer
上级
94987f26
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
367 addition
and
12 deletion
+367
-12
examples/aishell/asr1/conf/chunk_roformer.yaml
examples/aishell/asr1/conf/chunk_roformer.yaml
+98
-0
examples/aishell/asr1/conf/chunk_roformer_bidecoder.yaml
examples/aishell/asr1/conf/chunk_roformer_bidecoder.yaml
+98
-0
paddlespeech/s2t/modules/attention.py
paddlespeech/s2t/modules/attention.py
+133
-1
paddlespeech/s2t/modules/embedding.py
paddlespeech/s2t/modules/embedding.py
+5
-2
paddlespeech/s2t/modules/encoder.py
paddlespeech/s2t/modules/encoder.py
+30
-6
paddlespeech/s2t/modules/encoder_layer.py
paddlespeech/s2t/modules/encoder_layer.py
+3
-3
未找到文件。
examples/aishell/asr1/conf/chunk_roformer.yaml
0 → 100644
浏览文件 @
03e9ea9e
############################################
# Network Architecture #
############################################
cmvn_file
:
cmvn_file_type
:
"
json"
# encoder related
encoder
:
conformer
encoder_conf
:
output_size
:
256
# dimension of attention
attention_heads
:
4
linear_units
:
2048
# the number of units of position-wise feed forward
num_blocks
:
12
# the number of encoder blocks
dropout_rate
:
0.1
# sublayer output dropout
positional_dropout_rate
:
0.1
attention_dropout_rate
:
0.0
input_layer
:
conv2d
# encoder input type, you can chose conv2d, conv2d6 and conv2d8
normalize_before
:
True
cnn_module_kernel
:
15
use_cnn_module
:
True
activation_type
:
'
swish'
pos_enc_layer_type
:
'
rpoe_pos'
# abs_pos, rel_pos, rope_pos
selfattention_layer_type
:
'
rel_selfattn'
# unused
causal
:
true
use_dynamic_chunk
:
true
cnn_module_norm
:
'
layer_norm'
# using nn.LayerNorm makes model converge faster
use_dynamic_left_chunk
:
false
# decoder related
decoder
:
transformer
# transformer, bitransformer
decoder_conf
:
attention_heads
:
4
linear_units
:
2048
num_blocks
:
6
r_num_blocks
:
3
# only for bitransformer
dropout_rate
:
0.1
# sublayer output dropout
positional_dropout_rate
:
0.1
self_attention_dropout_rate
:
0.0
src_attention_dropout_rate
:
0.0
# hybrid CTC/attention
model_conf
:
ctc_weight
:
0.3
lsm_weight
:
0.1
# label smoothing option
reverse_weight
:
0.3
# only for bitransformer
length_normalized_loss
:
false
init_type
:
'
kaiming_uniform'
# !Warning: need to convergence
###########################################
# Data #
###########################################
train_manifest
:
data/manifest.train
dev_manifest
:
data/manifest.dev
test_manifest
:
data/manifest.test
###########################################
# Dataloader #
###########################################
vocab_filepath
:
data/lang_char/vocab.txt
spm_model_prefix
:
'
'
unit_type
:
'
char'
preprocess_config
:
conf/preprocess.yaml
feat_dim
:
80
stride_ms
:
10.0
window_ms
:
25.0
sortagrad
:
0
# Feed samples from shortest to longest ; -1: enabled for all epochs, 0: disabled, other: enabled for 'other' epochs
batch_size
:
32
maxlen_in
:
512
# if input length > maxlen-in, batchsize is automatically reduced
maxlen_out
:
150
# if output length > maxlen-out, batchsize is automatically reduced
minibatches
:
0
# for debug
batch_count
:
auto
batch_bins
:
0
batch_frames_in
:
0
batch_frames_out
:
0
batch_frames_inout
:
0
num_workers
:
2
subsampling_factor
:
1
num_encs
:
1
###########################################
# Training #
###########################################
n_epoch
:
240
accum_grad
:
1
global_grad_clip
:
5.0
dist_sampler
:
True
optim
:
adam
optim_conf
:
lr
:
0.001
weight_decay
:
1.0e-6
scheduler
:
warmuplr
scheduler_conf
:
warmup_steps
:
25000
lr_decay
:
1.0
log_interval
:
100
checkpoint
:
kbest_n
:
50
latest_n
:
5
examples/aishell/asr1/conf/chunk_roformer_bidecoder.yaml
0 → 100644
浏览文件 @
03e9ea9e
############################################
# Network Architecture #
############################################
cmvn_file
:
cmvn_file_type
:
"
json"
# encoder related
encoder
:
conformer
encoder_conf
:
output_size
:
256
# dimension of attention
attention_heads
:
4
linear_units
:
2048
# the number of units of position-wise feed forward
num_blocks
:
12
# the number of encoder blocks
dropout_rate
:
0.1
# sublayer output dropout
positional_dropout_rate
:
0.1
attention_dropout_rate
:
0.0
input_layer
:
conv2d
# encoder input type, you can chose conv2d, conv2d6 and conv2d8
normalize_before
:
True
cnn_module_kernel
:
15
use_cnn_module
:
True
activation_type
:
'
swish'
pos_enc_layer_type
:
'
rpoe_pos'
# abs_pos, rel_pos, rope_pos
selfattention_layer_type
:
'
rel_selfattn'
# unused
causal
:
true
use_dynamic_chunk
:
true
cnn_module_norm
:
'
layer_norm'
# using nn.LayerNorm makes model converge faster
use_dynamic_left_chunk
:
false
# decoder related
decoder
:
bitransformer
# transformer, bitransformer
decoder_conf
:
attention_heads
:
4
linear_units
:
2048
num_blocks
:
3
r_num_blocks
:
3
# only for bitransformer
dropout_rate
:
0.1
# sublayer output dropout
positional_dropout_rate
:
0.1
self_attention_dropout_rate
:
0.0
src_attention_dropout_rate
:
0.0
# hybrid CTC/attention
model_conf
:
ctc_weight
:
0.3
lsm_weight
:
0.1
# label smoothing option
reverse_weight
:
0.3
# only for bitransformer
length_normalized_loss
:
false
init_type
:
'
kaiming_uniform'
# !Warning: need to convergence
###########################################
# Data #
###########################################
train_manifest
:
data/manifest.train
dev_manifest
:
data/manifest.dev
test_manifest
:
data/manifest.test
###########################################
# Dataloader #
###########################################
vocab_filepath
:
data/lang_char/vocab.txt
spm_model_prefix
:
'
'
unit_type
:
'
char'
preprocess_config
:
conf/preprocess.yaml
feat_dim
:
80
stride_ms
:
10.0
window_ms
:
25.0
sortagrad
:
0
# Feed samples from shortest to longest ; -1: enabled for all epochs, 0: disabled, other: enabled for 'other' epochs
batch_size
:
32
maxlen_in
:
512
# if input length > maxlen-in, batchsize is automatically reduced
maxlen_out
:
150
# if output length > maxlen-out, batchsize is automatically reduced
minibatches
:
0
# for debug
batch_count
:
auto
batch_bins
:
0
batch_frames_in
:
0
batch_frames_out
:
0
batch_frames_inout
:
0
num_workers
:
2
subsampling_factor
:
1
num_encs
:
1
###########################################
# Training #
###########################################
n_epoch
:
240
accum_grad
:
1
global_grad_clip
:
5.0
dist_sampler
:
True
optim
:
adam
optim_conf
:
lr
:
0.001
weight_decay
:
1.0e-6
scheduler
:
warmuplr
scheduler_conf
:
warmup_steps
:
25000
lr_decay
:
1.0
log_interval
:
100
checkpoint
:
kbest_n
:
50
latest_n
:
5
paddlespeech/s2t/modules/attention.py
浏览文件 @
03e9ea9e
...
...
@@ -26,7 +26,10 @@ from paddlespeech.s2t.utils.log import Log
logger
=
Log
(
__name__
).
getlog
()
__all__
=
[
"MultiHeadedAttention"
,
"RelPositionMultiHeadedAttention"
]
__all__
=
[
"MultiHeadedAttention"
,
"RelPositionMultiHeadedAttention"
,
"RoPERelPositionMultiHeadedAttention"
]
# Relative Positional Encodings
# https://www.jianshu.com/p/c0608efcc26f
...
...
@@ -165,6 +168,7 @@ class MultiHeadedAttention(nn.Layer):
and `head * d_k == size`
"""
# (B,T,D) -> (B,T,H,D/H)
q
,
k
,
v
=
self
.
forward_qkv
(
query
,
key
,
value
)
# when export onnx model, for 1st chunk, we feed
...
...
@@ -373,3 +377,131 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention):
self
.
d_k
)
# (batch, head, time1, time2)
return
self
.
forward_attention
(
v
,
scores
,
mask
),
new_cache
class
RoPERelPositionMultiHeadedAttention
(
MultiHeadedAttention
):
"""Multi-Head Attention layer with RoPE relative position encoding."""
def
__init__
(
self
,
n_head
,
n_feat
,
dropout_rate
,
adaptive_scale
=
False
,
init_weights
=
False
):
"""Construct an RelPositionMultiHeadedAttention object.
Paper: https://arxiv.org/abs/1901.02860
Args:
n_head (int): The number of heads.
n_feat (int): The number of features.
dropout_rate (float): Dropout rate.
"""
super
().
__init__
(
n_head
,
n_feat
,
dropout_rate
)
def
align
(
self
,
tensor
:
paddle
.
Tensor
,
axes
:
List
[
int
],
ndim
=
None
):
"""重新对齐tensor(批量版expand_dims)
axes:原来的第i维对齐新tensor的第axes[i]维;
ndim:新tensor的维度。
"""
assert
len
(
axes
)
==
tensor
.
dim
()
assert
ndim
or
min
(
axes
)
>=
0
ndim
=
ndim
or
max
(
axes
)
+
1
# a[0, None, 1] = a[0, np.newaxis, 1]
indices
=
[
None
]
*
ndim
for
i
in
axes
:
# slice nothing, a[0, slice(None), 1] = a[0, :, 1]
indices
[
i
]
=
slice
(
None
)
return
tensor
[
indices
]
def
apply_rotary_position_embeddings
(
self
,
sinusoidal
,
*
tensors
):
"""应用RoPE到tensors中
其中,sinusoidal.shape=[B, T, D],tensors为tensor的列表,而
tensor.shape=[B, T, ..., D], or (B,T,H,D/H)
"""
assert
len
(
tensors
)
>
0
,
'at least one input tensor'
assert
all
(
[
tensor
.
shape
==
tensors
[
0
].
shape
for
tensor
in
tensors
[
1
:]]),
'all tensors must have the same shape'
ndim
=
tensors
[
0
].
dim
()
# sinusoidal shape same with tensors[0]
# [B,T,D] -> [B,T,1,D]
sinusoidal
=
self
.
align
(
sinusoidal
,
[
0
,
1
,
-
1
],
ndim
)
# http://man.hubwiz.com/docset/TensorFlow.docset/Contents/Resources/Documents/api_docs/python/tf/keras/backend/repeat_elements.html
# like np.repeat, x (s1, s2, s3), axis 1, (s1, s2*rep, s3)
# [b,T, ..., d/2] -> [b,T, ..., d]
cos_pos
=
paddle
.
repeat_interleave
(
sinusoidal
[...,
1
::
2
],
2
,
axis
=-
1
)
sin_pos
=
paddle
.
repeat_interleave
(
sinusoidal
[...,
0
::
2
],
2
,
axis
=-
1
)
outputs
=
[]
for
tensor
in
tensors
:
# x2 = [-x2, x1, -x4, x3, ..., -x_d, x_{d-1}]
tensor2
=
paddle
.
stack
([
-
tensor
[...,
1
::
2
],
tensor
[...,
::
2
]],
ndim
)
tensor2
=
paddle
.
reshape
(
tensor2
,
paddle
.
shape
(
tensor
))
# 公式 34, out = x * cos_pos + x2 * sin_pos
outputs
.
append
(
tensor
*
cos_pos
+
tensor2
*
sin_pos
)
return
outputs
[
0
]
if
len
(
outputs
)
==
1
else
outputs
def
forward
(
self
,
query
:
paddle
.
Tensor
,
key
:
paddle
.
Tensor
,
value
:
paddle
.
Tensor
,
mask
:
paddle
.
Tensor
=
paddle
.
ones
([
0
,
0
,
0
],
dtype
=
paddle
.
bool
),
pos_emb
:
paddle
.
Tensor
=
paddle
.
empty
([
0
]),
cache
:
paddle
.
Tensor
=
paddle
.
zeros
([
0
,
0
,
0
,
0
])
)
->
Tuple
[
paddle
.
Tensor
,
paddle
.
Tensor
]:
"""Compute 'Scaled Dot Product Attention' with rel. positional encoding.
Args:
query (paddle.Tensor): Query tensor (#batch, time1, size).
key (paddle.Tensor): Key tensor (#batch, time2, size).
value (paddle.Tensor): Value tensor (#batch, time2, size).
mask (paddle.Tensor): Mask tensor (#batch, 1, time2) or
(#batch, time1, time2), (0, 0, 0) means fake mask.
pos_emb (paddle.Tensor): Positional embedding tensor
(#batch, time2, size).
cache (paddle.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
where `cache_t == chunk_size * num_decoding_left_chunks`
and `head * d_k == size`
Returns:
paddle.Tensor: Output tensor (#batch, time1, d_model).
paddle.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
where `cache_t == chunk_size * num_decoding_left_chunks`
and `head * d_k == size`
"""
q
,
k
,
v
=
self
.
forward_qkv
(
query
,
key
,
value
)
# q = q.transpose([0, 2, 1, 3]) # (batch, time1, head, d_k)
# when export onnx model, for 1st chunk, we feed
# cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode)
# or cache(1, head, real_cache_t, d_k * 2) (16/4 mode).
# In all modes, `if cache.size(0) > 0` will alwayse be `True`
# and we will always do splitting and
# concatnation(this will simplify onnx export). Note that
# it's OK to concat & split zero-shaped tensors(see code below).
# when export jit model, for 1st chunk, we always feed
# cache(0, 0, 0, 0) since jit supports dynamic if-branch.
# >>> a = torch.ones((1, 2, 0, 4))
# >>> b = torch.ones((1, 2, 3, 4))
# >>> c = torch.cat((a, b), dim=2)
# >>> torch.equal(b, c) # True
# >>> d = torch.split(a, 2, dim=-1)
# >>> torch.equal(d[0], d[1]) # True
if
cache
.
shape
[
0
]
>
0
:
# last dim `d_k * 2` for (key, val)
key_cache
,
value_cache
=
paddle
.
split
(
cache
,
2
,
axis
=-
1
)
k
=
paddle
.
concat
([
key_cache
,
k
],
axis
=
2
)
v
=
paddle
.
concat
([
value_cache
,
v
],
axis
=
2
)
# We do cache slicing in encoder.forward_chunk, since it's
# non-trivial to calculate `next_cache_start` here.
new_cache
=
paddle
.
concat
((
k
,
v
),
axis
=-
1
)
# f{q,k}(x_m, m) = R^d_{\theta, m} W_{q,k} x_m, m is position index
q
,
k
=
self
.
apply_rotary_position_embeddings
(
pos_emb
,
[
q
,
k
])
# dot(q, k)
scores
=
paddle
.
matmul
(
q
,
k
,
transpose_y
=
True
)
/
math
.
sqrt
(
self
.
d_k
)
return
self
.
forward_attention
(
v
,
scores
,
mask
),
new_cache
paddlespeech/s2t/modules/embedding.py
浏览文件 @
03e9ea9e
...
...
@@ -89,14 +89,17 @@ class PositionalEncoding(nn.Layer, PositionalEncodingInterface):
self
.
max_len
=
max_len
self
.
xscale
=
paddle
.
to_tensor
(
math
.
sqrt
(
self
.
d_model
))
self
.
dropout
=
nn
.
Dropout
(
p
=
dropout_rate
)
self
.
base
=
10000.0
self
.
pe
=
paddle
.
zeros
([
1
,
self
.
max_len
,
self
.
d_model
])
#[B=1,T,D]
position
=
paddle
.
arange
(
0
,
self
.
max_len
,
dtype
=
paddle
.
float32
).
unsqueeze
(
1
)
#[T, 1]
# base^{-2(i-1)/d)}, i \in (1,2...,d/2)
div_term
=
paddle
.
exp
(
paddle
.
arange
(
0
,
self
.
d_model
,
2
,
dtype
=
paddle
.
float32
)
*
-
(
math
.
log
(
10000.0
)
/
self
.
d_model
))
-
paddle
.
arange
(
0
,
self
.
d_model
,
2
,
dtype
=
paddle
.
float32
)
*
(
math
.
log
(
self
.
base
)
/
self
.
d_model
))
# [B,T,D]
self
.
pe
[:,
:,
0
::
2
]
=
paddle
.
sin
(
position
*
div_term
)
self
.
pe
[:,
:,
1
::
2
]
=
paddle
.
cos
(
position
*
div_term
)
...
...
paddlespeech/s2t/modules/encoder.py
浏览文件 @
03e9ea9e
...
...
@@ -28,6 +28,7 @@ from paddlespeech.s2t.modules.align import LayerNorm
from
paddlespeech.s2t.modules.align
import
Linear
from
paddlespeech.s2t.modules.attention
import
MultiHeadedAttention
from
paddlespeech.s2t.modules.attention
import
RelPositionMultiHeadedAttention
from
paddlespeech.s2t.modules.attention
import
RoPERelPositionMultiHeadedAttention
from
paddlespeech.s2t.modules.conformer_convolution
import
ConvolutionModule
from
paddlespeech.s2t.modules.embedding
import
NoPositionalEncoding
from
paddlespeech.s2t.modules.embedding
import
PositionalEncoding
...
...
@@ -115,6 +116,8 @@ class BaseEncoder(nn.Layer):
pos_enc_class
=
PositionalEncoding
elif
pos_enc_layer_type
==
"rel_pos"
:
pos_enc_class
=
RelPositionalEncoding
elif
pos_enc_layer_type
==
"rope_pos"
:
pos_enc_class
=
RelPositionalEncoding
elif
pos_enc_layer_type
==
"no_pos"
:
pos_enc_class
=
NoPositionalEncoding
else
:
...
...
@@ -237,7 +240,7 @@ class BaseEncoder(nn.Layer):
chunk_size
=
xs
.
shape
[
1
]
attention_key_size
=
cache_t1
+
chunk_size
# only used when using `RelPositionMultiHeadedAttention`
# only used when using `RelPositionMultiHeadedAttention`
and `RoPERelPositionMultiHeadedAttention`
pos_emb
=
self
.
embed
.
position_encoding
(
offset
=
offset
-
cache_t1
,
size
=
attention_key_size
)
...
...
@@ -474,9 +477,22 @@ class ConformerEncoder(BaseEncoder):
activation
=
get_activation
(
activation_type
)
# self-attention module definition
encoder_selfattn_layer
=
RelPositionMultiHeadedAttention
encoder_selfattn_layer_args
=
(
attention_heads
,
output_size
,
attention_dropout_rate
)
if
pos_enc_layer_type
==
"abs_pos"
:
encoder_selfattn_layer
=
MultiHeadedAttention
encoder_selfattn_layer_args
=
(
attention_heads
,
output_size
,
attention_dropout_rate
)
elif
pos_enc_layer_type
==
"rel_pos"
:
encoder_selfattn_layer
=
RelPositionMultiHeadedAttention
encoder_selfattn_layer_args
=
(
attention_heads
,
encoder_dim
,
attention_dropout_rate
)
elif
pos_enc_layer_type
==
"rope_pos"
:
encoder_selfattn_layer
=
RoPERelPositionMultiHeadedAttention
encoder_selfattn_layer_args
=
(
attention_heads
,
encoder_dim
,
attention_dropout_rate
)
else
:
raise
ValueError
(
f
"pos_enc_layer_type
{
pos_enc_layer_type
}
not supported."
)
# feed-forward module definition
positionwise_layer
=
PositionwiseFeedForward
positionwise_layer_args
=
(
output_size
,
linear_units
,
dropout_rate
,
...
...
@@ -580,15 +596,23 @@ class SqueezeformerEncoder(nn.Layer):
activation
=
get_activation
(
activation_type
)
# self-attention module definition
if
pos_enc_layer_type
!=
"rel
_pos"
:
if
pos_enc_layer_type
==
"abs
_pos"
:
encoder_selfattn_layer
=
MultiHeadedAttention
encoder_selfattn_layer_args
=
(
attention_heads
,
output_size
,
attention_dropout_rate
)
el
se
:
el
if
pos_enc_layer_type
==
"rel_pos"
:
encoder_selfattn_layer
=
RelPositionMultiHeadedAttention
encoder_selfattn_layer_args
=
(
attention_heads
,
encoder_dim
,
attention_dropout_rate
,
adaptive_scale
,
init_weights
)
elif
pos_enc_layer_type
==
"rope_pos"
:
encoder_selfattn_layer
=
RoPERelPositionMultiHeadedAttention
encoder_selfattn_layer_args
=
(
attention_heads
,
encoder_dim
,
attention_dropout_rate
,
adaptive_scale
,
init_weights
)
else
:
raise
ValueError
(
f
"pos_enc_layer_type
{
pos_enc_layer_type
}
not supported."
)
# feed-forward module definition
positionwise_layer
=
PositionwiseFeedForward
...
...
paddlespeech/s2t/modules/encoder_layer.py
浏览文件 @
03e9ea9e
...
...
@@ -48,7 +48,7 @@ class TransformerEncoderLayer(nn.Layer):
Args:
size (int): Input dimension.
self_attn (nn.Layer): Self-attention module instance.
`MultiHeadedAttention`
or `
RelPositionMultiHeadedAttention`
`MultiHeadedAttention`
, `RelPositionMultiHeadedAttention` or `RoPE
RelPositionMultiHeadedAttention`
instance can be used as the argument.
feed_forward (nn.Layer): Feed-forward module instance.
`PositionwiseFeedForward`, instance can be used as the argument.
...
...
@@ -147,7 +147,7 @@ class ConformerEncoderLayer(nn.Layer):
Args:
size (int): Input dimension.
self_attn (nn.Layer): Self-attention module instance.
`MultiHeadedAttention`
or `
RelPositionMultiHeadedAttention`
`MultiHeadedAttention`
, `RelPositionMultiHeadedAttention` or `RoPE
RelPositionMultiHeadedAttention`
instance can be used as the argument.
feed_forward (nn.Layer): Feed-forward module instance.
`PositionwiseFeedForward` instance can be used as the argument.
...
...
@@ -298,7 +298,7 @@ class SqueezeformerEncoderLayer(nn.Layer):
Args:
size (int): Input dimension.
self_attn (paddle.nn.Layer): Self-attention module instance.
`MultiHeadedAttention`
or `
RelPositionMultiHeadedAttention`
`MultiHeadedAttention`
, `RelPositionMultiHeadedAttention` or `RoPE
RelPositionMultiHeadedAttention`
instance can be used as the argument.
feed_forward1 (paddle.nn.Layer): Feed-forward module instance.
`PositionwiseFeedForward` instance can be used as the argument.
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录