Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
c809530e
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
c809530e
编写于
5月 17, 2021
作者:
S
ShenLiang
提交者:
GitHub
5月 17, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[HybridParallel]Fix precision problem of model parallel (#32897)
* fix precision of mp * fix bug of seed * fix dp * print group
上级
906db719
变更
14
显示空白变更内容
内联
并排
Showing
14 changed file
with
151 addition
and
58 deletion
+151
-58
paddle/fluid/framework/distributed_strategy.proto
paddle/fluid/framework/distributed_strategy.proto
+1
-0
python/paddle/distributed/collective.py
python/paddle/distributed/collective.py
+7
-0
python/paddle/distributed/fleet/base/distributed_strategy.py
python/paddle/distributed/fleet/base/distributed_strategy.py
+4
-1
python/paddle/distributed/fleet/base/fleet_base.py
python/paddle/distributed/fleet/base/fleet_base.py
+12
-3
python/paddle/distributed/fleet/base/topology.py
python/paddle/distributed/fleet/base/topology.py
+3
-3
python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_gradscaler.py
...ptimizers/dygraph_optimizer/hybrid_parallel_gradscaler.py
+1
-1
python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_optimizer.py
...optimizers/dygraph_optimizer/hybrid_parallel_optimizer.py
+2
-2
python/paddle/distributed/fleet/meta_parallel/__init__.py
python/paddle/distributed/fleet/meta_parallel/__init__.py
+1
-1
python/paddle/distributed/fleet/meta_parallel/parallel_layers/mp_layers.py
...tributed/fleet/meta_parallel/parallel_layers/mp_layers.py
+97
-38
python/paddle/distributed/fleet/meta_parallel/parallel_layers/random.py
...distributed/fleet/meta_parallel/parallel_layers/random.py
+9
-4
python/paddle/distributed/fleet/meta_parallel/tensor_parallel.py
...paddle/distributed/fleet/meta_parallel/tensor_parallel.py
+3
-3
python/paddle/distributed/fleet/utils/hybrid_parallel_util.py
...on/paddle/distributed/fleet/utils/hybrid_parallel_util.py
+9
-1
python/paddle/fluid/tests/unittests/hybrid_parallel_mp_layers.py
...paddle/fluid/tests/unittests/hybrid_parallel_mp_layers.py
+1
-1
python/paddle/fluid/tests/unittests/new_group.py
python/paddle/fluid/tests/unittests/new_group.py
+1
-0
未找到文件。
paddle/fluid/framework/distributed_strategy.proto
浏览文件 @
c809530e
...
...
@@ -141,6 +141,7 @@ message PipelineConfig {
message
TensorParallelConfig
{
optional
int32
tensor_parallel_degree
=
1
[
default
=
1
];
optional
int32
tensor_init_seed
=
2
[
default
=
-
1
];
}
message
DistributedStrategy
{
...
...
python/paddle/distributed/collective.py
浏览文件 @
c809530e
...
...
@@ -99,6 +99,13 @@ class Group():
else
:
return
-
1
def
__repr__
(
self
):
debug_str
=
"rank: {}, nranks: {}, id: {}, ranks: "
.
format
(
self
.
rank
,
self
.
nranks
,
self
.
id
)
debug_str
+=
", "
.
join
(
map
(
str
,
self
.
ranks
))
debug_str
+=
". "
return
debug_str
_global_env
=
None
...
...
python/paddle/distributed/fleet/base/distributed_strategy.py
浏览文件 @
c809530e
...
...
@@ -949,6 +949,8 @@ class DistributedStrategy(object):
**Notes**:
**Detailed arguments for tensor_parallel_configs**
**tensor_parallel_degree**: degree of tensor parallel
**tensor_init_seed**: parameter initialization random seed
Examples:
...
...
@@ -957,7 +959,8 @@ class DistributedStrategy(object):
import paddle.distributed.fleet as fleet
strategy = fleet.DistributedStrategy()
strategy.tensor_parallel = True
strategy.tensor_parallel_configs = {"tensor_parallel_degree": 4}
strategy.tensor_parallel_configs = {"tensor_parallel_degree": 4,
"tensor_init_seed": 123}
"""
return
get_msg_dict
(
self
.
strategy
.
tensor_parallel_configs
)
...
...
python/paddle/distributed/fleet/base/fleet_base.py
浏览文件 @
c809530e
...
...
@@ -17,6 +17,7 @@ import copy
import
warnings
import
paddle
import
os
import
numpy
as
np
from
paddle.fluid.framework
import
dygraph_only
from
paddle.fluid
import
compiler
from
.role_maker
import
UserDefinedRoleMaker
,
PaddleCloudRoleMaker
,
RoleMakerBase
...
...
@@ -28,7 +29,7 @@ from paddle.fluid.wrapped_decorator import wrap_decorator
from
paddle.fluid.dygraph
import
parallel_helper
from
.
import
topology
as
tp
from
.topology
import
ParallelMode
from
..meta_parallel
import
ModelParallel
from
..meta_parallel
import
TensorParallel
,
model_parallel_random_seed
from
..meta_parallel
import
PipelineParallel
from
..meta_optimizers
import
HybridParallelOptimizer
from
..meta_optimizers
import
HybridParallelGradScaler
...
...
@@ -279,6 +280,14 @@ class Fleet(object):
self
.
_hcg
=
tp
.
HybridCommunicateGroup
(
self
.
_topology
)
if
self
.
mp_degree
>
1
:
tensor_parallel_configs
=
self
.
_user_defined_strategy
.
tensor_parallel_configs
tensor_init_seed
=
tensor_parallel_configs
[
"tensor_init_seed"
]
if
tensor_init_seed
==
-
1
:
model_parallel_random_seed
()
else
:
model_parallel_random_seed
(
tensor_init_seed
)
def
get_hybrid_communicate_group
(
self
):
assert
self
.
_hcg
is
not
None
return
self
.
_hcg
...
...
@@ -829,8 +838,8 @@ class Fleet(object):
last_comm_group_size_MB
,
find_unused_parameters
=
self
.
_user_defined_strategy
.
find_unused_parameters
)
elif
self
.
_hcg
.
get_parallel_mode
()
==
ParallelMode
.
MODEL
_PARALLEL
:
distributed_model
=
Model
Parallel
(
elif
self
.
_hcg
.
get_parallel_mode
()
==
ParallelMode
.
TENSOR
_PARALLEL
:
distributed_model
=
Tensor
Parallel
(
model
,
self
.
_hcg
,
strategy
=
self
.
_user_defined_strategy
)
elif
self
.
_hcg
.
get_parallel_mode
()
==
ParallelMode
.
PIPELINE_PARALLEL
:
distributed_model
=
PipelineParallel
(
...
...
python/paddle/distributed/fleet/base/topology.py
浏览文件 @
c809530e
...
...
@@ -28,7 +28,7 @@ _HYBRID_PARALLEL_GROUP = None
class
ParallelMode
(
object
):
DATA_PARALLEL
=
0
MODEL
_PARALLEL
=
1
TENSOR
_PARALLEL
=
1
PIPELINE_PARALLEL
=
2
...
...
@@ -155,12 +155,12 @@ class HybridCommunicateGroup(object):
_HYBRID_PARALLEL_GROUP
=
self
def
get_parallel_mode
(
self
):
# there are three modes : DataParallel /
Model
Parallel / PipelineParallel
# there are three modes : DataParallel /
Tensor
Parallel / PipelineParallel
if
self
.
_mp_degree
==
1
and
self
.
_pp_degree
==
1
:
return
ParallelMode
.
DATA_PARALLEL
elif
self
.
_mp_degree
>
1
and
self
.
_pp_degree
==
1
:
# initialize the seed
return
ParallelMode
.
MODEL
_PARALLEL
return
ParallelMode
.
TENSOR
_PARALLEL
elif
self
.
_pp_degree
>
1
:
return
ParallelMode
.
PIPELINE_PARALLEL
...
...
python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_gradscaler.py
浏览文件 @
c809530e
...
...
@@ -31,7 +31,7 @@ class HybridParallelGradScaler:
self
.
_scaler
=
scaler
self
.
_hcg
=
hcg
self
.
_is_mp
=
(
self
.
_hcg
.
get_parallel_mode
()
==
ParallelMode
.
MODEL
_PARALLEL
)
self
.
_hcg
.
get_parallel_mode
()
==
ParallelMode
.
TENSOR
_PARALLEL
)
def
scale
(
self
,
var
):
return
self
.
_scaler
.
scale
(
var
)
...
...
python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_optimizer.py
浏览文件 @
c809530e
...
...
@@ -90,12 +90,12 @@ class HybridParallelOptimizer:
self
.
_strategy
=
strategy
self
.
_hcg
=
hcg
self
.
_is_mp
=
(
self
.
_hcg
.
get_parallel_mode
()
==
ParallelMode
.
MODEL
_PARALLEL
)
self
.
_hcg
.
get_parallel_mode
()
==
ParallelMode
.
TENSOR
_PARALLEL
)
self
.
_need_dp
=
(
self
.
_hcg
.
get_data_parallel_world_size
()
>
1
)
if
isinstance
(
self
.
_inner_opt
.
_grad_clip
,
ClipGradByGlobalNorm
)
and
self
.
_is_mp
:
logger
.
warning
(
"using ClipGradByGlobalNorm in
Model
Parallel, the origin "
\
logger
.
warning
(
"using ClipGradByGlobalNorm in
Tensor
Parallel, the origin "
\
"optmizer'grad clip will be changed."
)
self
.
_inner_opt
.
_grad_clip
=
HybridParallelClipGrad
(
self
.
_inner_opt
.
_grad_clip
,
hcg
)
...
...
python/paddle/distributed/fleet/meta_parallel/__init__.py
浏览文件 @
c809530e
...
...
@@ -20,7 +20,7 @@ from .parallel_layers import PipelineLayer # noqa: F401
from
.parallel_layers
import
RNGStatesTracker
# noqa: F401
from
.parallel_layers
import
model_parallel_random_seed
# noqa: F401
from
.parallel_layers
import
get_rng_state_tracker
# noqa: F401
from
.
model_parallel
import
Model
Parallel
# noqa: F401
from
.
tensor_parallel
import
Tensor
Parallel
# noqa: F401
from
.pipeline_parallel
import
PipelineParallel
# noqa: F401
__all__
=
[]
python/paddle/distributed/fleet/meta_parallel/parallel_layers/mp_layers.py
浏览文件 @
c809530e
...
...
@@ -41,6 +41,7 @@ class VocabParallelEmbedding(Layer):
self
.
rank
=
tp
.
_HYBRID_PARALLEL_GROUP
.
get_model_parallel_rank
()
self
.
origin_num_embeddings
=
num_embeddings
self
.
is_mp
=
(
self
.
world_size
>
1
)
per_part_size
=
(
num_embeddings
+
self
.
world_size
-
1
)
//
self
.
world_size
...
...
@@ -50,16 +51,36 @@ class VocabParallelEmbedding(Layer):
per_part_size
+=
1
# make the last row as the padding index
self
.
per_part_size
=
per_part_size
self
.
embedding
=
paddle
.
nn
.
Embedding
(
per_part_size
,
embedding_dim
,
padding_idx
=
per_part_size
-
1
,
sparse
=
False
,
weight_attr
=
weight_attr
,
name
=
name
)
self
.
embedding
.
weight
.
is_distributed
=
True
self
.
_dtype
=
self
.
_helper
.
get_default_dtype
()
self
.
_size
=
[
per_part_size
,
embedding_dim
]
self
.
_weight_attr
=
weight_attr
self
.
_name
=
name
if
self
.
is_mp
:
with
get_rng_state_tracker
().
rng_state
():
self
.
weight
=
self
.
create_parameter
(
attr
=
self
.
_weight_attr
,
shape
=
self
.
_size
,
dtype
=
self
.
_dtype
,
is_bias
=
False
)
self
.
weight
[
per_part_size
-
1
]
=
0.0
self
.
weight
.
is_distributed
=
True
else
:
self
.
weight
=
self
.
create_parameter
(
attr
=
self
.
_weight_attr
,
shape
=
[
num_embeddings
,
embedding_dim
],
dtype
=
self
.
_dtype
,
is_bias
=
False
)
def
forward
(
self
,
x
):
if
not
self
.
is_mp
:
return
F
.
embedding
(
x
,
weight
=
self
.
weight
,
padding_idx
=
None
,
sparse
=
False
,
name
=
self
.
_name
)
origin_input_shape
=
x
.
shape
if
len
(
origin_input_shape
)
==
2
:
x
=
paddle
.
unsqueeze
(
x
,
axis
=-
1
)
...
...
@@ -72,8 +93,13 @@ class VocabParallelEmbedding(Layer):
if
len
(
origin_input_shape
)
==
2
:
x_shard
=
paddle
.
squeeze
(
x_shard
,
axis
=-
1
)
emb_out
=
self
.
embedding
(
x_shard
)
if
self
.
world_size
>
1
:
emb_out
=
F
.
embedding
(
x_shard
,
weight
=
self
.
weight
,
padding_idx
=
self
.
per_part_size
-
1
,
sparse
=
False
,
name
=
self
.
_name
)
emb_out
=
paddle
.
distributed
.
collective
.
_mp_allreduce
(
emb_out
,
group
=
self
.
model_parallel_group
,
...
...
@@ -96,8 +122,9 @@ class ColumnParallelLinear(Layer):
)
self
.
world_size
=
tp
.
_HYBRID_PARALLEL_GROUP
.
get_model_parallel_world_size
(
)
self
.
_name
=
name
self
.
is_mp
=
(
self
.
world_size
>
1
)
self
.
name
=
name
self
.
gather_output
=
gather_output
assert
out_features
%
self
.
world_size
==
0
,
(
"Number of column of the weight for linear ({}) must be"
...
...
@@ -108,10 +135,20 @@ class ColumnParallelLinear(Layer):
self
.
_weight_attr
=
weight_attr
self
.
_dtype
=
self
.
_helper
.
get_default_dtype
()
if
self
.
is_mp
:
with
get_rng_state_tracker
().
rng_state
():
self
.
weight
=
self
.
create_parameter
(
shape
=
[
in_features
,
self
.
output_size_per_partition
],
attr
=
self
.
_weight_attr
,
dtype
=
self
.
_dtype
)
dtype
=
self
.
_dtype
,
is_bias
=
False
)
else
:
self
.
weight
=
self
.
create_parameter
(
shape
=
[
in_features
,
self
.
output_size_per_partition
],
attr
=
self
.
_weight_attr
,
dtype
=
self
.
_dtype
,
is_bias
=
False
)
self
.
weight
.
is_distributed
=
True
if
has_bias
:
...
...
@@ -119,18 +156,24 @@ class ColumnParallelLinear(Layer):
self
.
bias
=
self
.
create_parameter
(
shape
=
[
self
.
output_size_per_partition
],
attr
=
paddle
.
nn
.
initializer
.
Constant
(
value
=
0.0
),
dtype
=
self
.
_dtype
)
dtype
=
self
.
_dtype
,
is_bias
=
True
)
self
.
bias
.
is_distributed
=
True
else
:
self
.
bias
=
None
def
forward
(
self
,
x
):
# use inner api to process identity
if
self
.
is_mp
:
input_parallel
=
paddle
.
distributed
.
collective
.
_c_identity
(
x
,
group
=
self
.
model_parallel_group
)
else
:
input_parallel
=
x
output_parallel
=
F
.
linear
(
input_parallel
,
self
.
weight
,
self
.
bias
,
name
=
self
.
name
)
if
self
.
gather_output
:
input_parallel
,
self
.
weight
,
self
.
bias
,
name
=
self
.
_name
)
if
self
.
gather_output
and
self
.
is_mp
:
output
=
paddle
.
distributed
.
collective
.
_c_concat
(
output_parallel
,
nranks
=
self
.
world_size
,
...
...
@@ -155,7 +198,7 @@ class RowParallelLinear(Layer):
self
.
input_is_parallel
=
input_is_parallel
self
.
_weight_attr
=
weight_attr
self
.
_dtype
=
self
.
_helper
.
get_default_dtype
()
self
.
name
=
name
self
.
_
name
=
name
self
.
model_parallel_group
=
tp
.
_HYBRID_PARALLEL_GROUP
.
get_model_parallel_group
(
)
...
...
@@ -163,6 +206,7 @@ class RowParallelLinear(Layer):
)
self
.
rank
=
tp
.
_HYBRID_PARALLEL_GROUP
.
get_model_parallel_rank
()
self
.
is_mp
=
(
self
.
world_size
>
1
)
assert
in_features
%
self
.
world_size
==
0
,
(
"Number of row of the weight for linear ({}) must be"
" divisible by model parallel size ({})"
.
format
(
in_features
,
...
...
@@ -170,22 +214,33 @@ class RowParallelLinear(Layer):
self
.
input_size_per_partition
=
in_features
//
self
.
world_size
if
self
.
is_mp
:
with
get_rng_state_tracker
().
rng_state
():
self
.
weight
=
self
.
create_parameter
(
shape
=
[
self
.
input_size_per_partition
,
self
.
out_features
],
attr
=
self
.
_weight_attr
,
dtype
=
self
.
_dtype
,
is_bias
=
False
)
else
:
self
.
weight
=
self
.
create_parameter
(
shape
=
[
self
.
input_size_per_partition
,
self
.
out_features
],
attr
=
self
.
_weight_attr
,
dtype
=
self
.
_dtype
)
dtype
=
self
.
_dtype
,
is_bias
=
False
)
self
.
weight
.
is_distributed
=
True
if
has_bias
:
self
.
bias
=
self
.
create_parameter
(
shape
=
[
self
.
out_features
],
attr
=
paddle
.
nn
.
initializer
.
Constant
(
value
=
0.0
),
dtype
=
self
.
_dtype
)
dtype
=
self
.
_dtype
,
is_bias
=
True
)
else
:
self
.
bias
=
None
def
forward
(
self
,
x
):
if
self
.
input_is_parallel
:
if
self
.
input_is_parallel
or
(
not
self
.
is_mp
)
:
input_parallel
=
x
else
:
# split last dim
...
...
@@ -195,12 +250,16 @@ class RowParallelLinear(Layer):
nranks
=
self
.
world_size
,
group
=
self
.
model_parallel_group
)
output_parallel
=
F
.
linear
(
input_parallel
,
self
.
weight
,
name
=
self
.
name
)
output_parallel
=
F
.
linear
(
input_parallel
,
self
.
weight
,
name
=
self
.
_name
)
if
self
.
is_mp
:
output_
=
paddle
.
distributed
.
collective
.
_mp_allreduce
(
output_parallel
,
group
=
self
.
model_parallel_group
,
use_calc_stream
=
True
,
use_model_parallel
=
True
)
else
:
output_
=
output_parallel
output
=
output_
+
self
.
bias
if
self
.
bias
is
not
None
else
output_
return
output
python/paddle/distributed/fleet/meta_parallel/parallel_layers/random.py
浏览文件 @
c809530e
...
...
@@ -14,6 +14,7 @@
import
paddle
import
contextlib
import
numpy
as
np
__all__
=
[]
...
...
@@ -65,14 +66,18 @@ def get_rng_state_tracker():
return
RNG_STATE_TRACKER
def
model_parallel_random_seed
(
seed
=
2048
):
def
model_parallel_random_seed
(
seed
=
None
):
import
paddle.distributed.fleet
as
fleet
hcg
=
fleet
.
get_hybrid_communicate_group
()
rank
=
hcg
.
get_model_parallel_rank
()
local_seed
=
seed
+
1024
+
rank
if
seed
:
global_seed
=
seed
local_seed
=
seed
*
1024
+
rank
*
100
else
:
global_seed
=
np
.
random
.
randint
(
0
,
655350
)
local_seed
=
np
.
random
.
randint
(
rank
*
10000
,
(
rank
+
1
)
*
10000
-
1
)
RNG_STATE_TRACKER
.
reset
()
paddle
.
seed
(
global_seed
)
RNG_STATE_TRACKER
.
add
(
MODEL_PARALLEL_RNG
,
local_seed
)
paddle
.
seed
(
global_seed
)
python/paddle/distributed/fleet/meta_parallel/
model
_parallel.py
→
python/paddle/distributed/fleet/meta_parallel/
tensor
_parallel.py
浏览文件 @
c809530e
...
...
@@ -22,15 +22,15 @@ from ..utils.log_util import logger
__all__
=
[]
class
Model
Parallel
(
MetaParallelBase
):
class
Tensor
Parallel
(
MetaParallelBase
):
def
__init__
(
self
,
layers
,
hcg
,
**
kwargs
):
super
(
Model
Parallel
,
self
).
__init__
(
layers
,
hcg
,
**
kwargs
)
super
(
Tensor
Parallel
,
self
).
__init__
(
layers
,
hcg
,
**
kwargs
)
def
_prepare_for_model
(
self
):
logger
.
info
(
"start broadcast mp parameters"
)
broadcast_mp_parameters
(
self
.
_layers
,
self
.
_hcg
)
logger
.
info
(
"start broadcast
m
p parameters"
)
logger
.
info
(
"start broadcast
d
p parameters"
)
broadcast_dp_parameters
(
self
.
_layers
,
self
.
_hcg
)
logger
.
info
(
"mp's parameters is ready"
)
...
...
python/paddle/distributed/fleet/utils/hybrid_parallel_util.py
浏览文件 @
c809530e
...
...
@@ -44,7 +44,15 @@ def _apply_collective_grads(parameters, comm_group):
for
coalesced_grad
,
_
,
_
in
coalesced_grads_and_vars
:
# need to div nranks
coalesced_grad
=
coalesced_grad
/
comm_group
.
nranks
div_factor
=
paddle
.
to_tensor
(
comm_group
.
nranks
,
dtype
=
coalesced_grad
.
dtype
)
paddle
.
fluid
.
framework
.
_dygraph_tracer
().
trace_op
(
type
=
"elementwise_div"
,
inputs
=
{
'X'
:
coalesced_grad
,
'Y'
:
div_factor
},
outputs
=
{
'Out'
:
coalesced_grad
},
attrs
=
{
'axis'
:
-
1
})
paddle
.
distributed
.
all_reduce
(
coalesced_grad
,
group
=
comm_group
)
_split_tensors
(
coalesced_grads_and_vars
)
...
...
python/paddle/fluid/tests/unittests/hybrid_parallel_mp_layers.py
浏览文件 @
c809530e
...
...
@@ -231,7 +231,7 @@ class TestDistTraning(unittest.TestCase):
# model_b
check_group
=
dist
.
new_group
(
list
(
range
(
self
.
model_parallel_size
)))
integral_w
=
[]
partial_w
=
model_a
.
embedding
.
embedding
.
weight
.
clone
().
detach
()
partial_w
=
model_a
.
embedding
.
weight
.
clone
().
detach
()
paddle
.
distributed
.
all_gather
(
integral_w
,
partial_w
,
group
=
check_group
)
result_w
=
[]
for
idx
in
range
(
len
(
integral_w
)):
...
...
python/paddle/fluid/tests/unittests/new_group.py
浏览文件 @
c809530e
...
...
@@ -27,6 +27,7 @@ class TestNewGroupAPI(object):
def
test_all
(
self
):
gp
=
paddle
.
distributed
.
new_group
([
0
,
1
])
print
(
"gp info:"
,
gp
)
print
(
"test new group api ok"
)
tmp
=
np
.
array
([
0
,
0
,
0
])
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录