Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
s920243400
PaddleDetection
提交
276bb001
P
PaddleDetection
项目概览
s920243400
/
PaddleDetection
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleDetection
通知
2
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
未验证
提交
276bb001
编写于
6月 19, 2020
作者:
L
littletomatodonkey
提交者:
GitHub
6月 19, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add fleet training (#961)
add fleet training
上级
d5702896
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
443 addition
and
5 deletion
+443
-5
docs/tutorials/MULTI_MACHINE_TRAINING_cn.md
docs/tutorials/MULTI_MACHINE_TRAINING_cn.md
+69
-0
ppdet/data/reader.py
ppdet/data/reader.py
+8
-5
tools/train_multi_machine.py
tools/train_multi_machine.py
+366
-0
未找到文件。
docs/tutorials/MULTI_MACHINE_TRAINING_cn.md
0 → 100644
浏览文件 @
276bb001
# 多机训练
## 简介
*
分布式训练的高性能,是飞桨的核心优势技术之一,在分类任务上,分布式训练可以达到几乎线性的加速比。
[
Fleet
](
https://github.com/PaddlePaddle/Fleet
)
是用于 PaddlePaddle 分布式训练的高层 API,基于这套接口用户可以很容易切换到分布式训练程序。
为了可以同时支持单机训练和多机训练,
[
PaddleDetection
](
https://github.com/PaddlePaddle/PaddleDetection/
)
采用 Fleet API 接口,可以同时支持单机训练和多机训练。
更多的分布式训练可以参考
[
Fleet API设计文档
](
https://github.com/PaddlePaddle/Fleet/blob/develop/README.md
)
。
## 使用方法
*
使用
`tools/train_multi_machine.py`
可以启动基于Fleet的训练,目前同时支持单机单卡、单机多卡与多机多卡的训练过程。
*
可选参数列表与
`tools/train.py`
完全相同,可以参考
[
入门使用文档
](
./GETTING_STARTED_cn.md
)
。
### 单机训练
*
训练脚本如下所示。
```
bash
# 设置PYTHONPATH路径
export
PYTHONPATH
=
$PYTHONPATH
:.
# 设置GPU卡号信息
export
CUDA_VISIBLE_DEVICES
=
0,1,2,3,4,5,6,7
# 启动训练
python
-m
paddle.distributed.launch
\
--selected_gpus
0,1,2,3,4,5,6,7
\
tools/train_multi_machine.py
\
-c
configs/faster_rcnn_r50_fpn_1x.yml
```
### 多机训练
*
训练脚本如下所示,其中ip1和ip2分别表示不同机器的ip地址,
`PADDLE_TRAINER_ID`
环境变量也是根据
`cluster_node_ips`
提供的ip顺序依次增大。
*
注意:在这里如果需要启动多机实验,需要保证不同的机器的运行代码是完全相同的。
```
export PYTHONPATH=$PYTHONPATH:.
# 设置GPU卡号信息
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
# 启动训练
node_ip=`hostname -i`
python -m paddle.distributed.launch \
--use_paddlecloud \
--cluster_node_ips ${ip1},${ip2} \
--node_ip ${node_ip} \
tools/train_multi_machine.py \
-c configs/faster_rcnn_r50_fpn_1x.yml \
```
## 训练时间统计
*
以Faster RCNN R50_vd FPN 1x实验为例,下面给出了基于Fleet分布式训练,不同机器的训练时间对比。
*
这里均是在V100 GPU上展开的实验。
*
1x实验指的是8卡,单卡batch size为2时,训练的minibatch数量为90000(当训练卡数或者batch size变化时,对应的学习率和总的迭代轮数也需要变化)。
| 模型 | 训练策略 | 机器数量 | 每台机器的GPU数量 | 训练时间 | COCO bbox mAP | 加速比 |
| :----------------------: | :------------: | :------------: | :---------------: | :----------: | :-----------: | :-----------: |
| Faster RCNN R50_vd FPN | 1x | 1 | 4 | 15.1h | 38.3% | - |
| Faster RCNN R50_vd FPN | 1x | 2 | 4 | 9.8h | 38.2% | 76% |
| Faster RCNN R50_vd FPN | 1x | 1 | 8 | 8.6h | 38.2% | - |
| Faster RCNN R50_vd FPN | 1x | 2 | 8 | 5.1h | 38.0% | 84% |
*
由上图可知,2机8卡相比于单机8卡,加速比可以达到84%,2即4卡相比于单机4卡,加速比可以达到76%,而且精度几乎没有损失。
*
1x实验相当于COCO数据集训练了约13个epoch,因此在trainer数量很多的时候,每个trainer可能无法训练完1个epoch,这会导致精度出现一些差异,这可以通过适当增加迭代轮数实现精度的对齐,我们实验发现,在训练多尺度3x实验时(配置文件:
[
configs/rcnn_enhance/faster_rcnn_dcn_r50_vd_fpn_3x_server_side.yml
](
../../configs/rcnn_enhance/faster_rcnn_dcn_r50_vd_fpn_3x_server_side.yml
)
),分布式训练与单机训练的模型精度是可以对齐的。
ppdet/data/reader.py
浏览文件 @
276bb001
...
@@ -16,6 +16,7 @@ from __future__ import absolute_import
...
@@ -16,6 +16,7 @@ from __future__ import absolute_import
from
__future__
import
division
from
__future__
import
division
from
__future__
import
print_function
from
__future__
import
print_function
import
os
import
copy
import
copy
import
functools
import
functools
import
collections
import
collections
...
@@ -278,6 +279,11 @@ class Reader(object):
...
@@ -278,6 +279,11 @@ class Reader(object):
def
reset
(
self
):
def
reset
(
self
):
"""implementation of Dataset.reset
"""implementation of Dataset.reset
"""
"""
if
self
.
_epoch
<
0
:
self
.
_epoch
=
0
else
:
self
.
_epoch
+=
1
self
.
indexes
=
[
i
for
i
in
range
(
self
.
size
())]
self
.
indexes
=
[
i
for
i
in
range
(
self
.
size
())]
if
self
.
_class_aware_sampling
:
if
self
.
_class_aware_sampling
:
self
.
indexes
=
np
.
random
.
choice
(
self
.
indexes
=
np
.
random
.
choice
(
...
@@ -287,6 +293,8 @@ class Reader(object):
...
@@ -287,6 +293,8 @@ class Reader(object):
p
=
self
.
img_weights
)
p
=
self
.
img_weights
)
if
self
.
_shuffle
:
if
self
.
_shuffle
:
trainer_id
=
int
(
os
.
getenv
(
"PADDLE_TRAINER_ID"
,
0
))
np
.
random
.
seed
(
self
.
_epoch
+
trainer_id
)
np
.
random
.
shuffle
(
self
.
indexes
)
np
.
random
.
shuffle
(
self
.
indexes
)
if
self
.
_mixup_epoch
>
0
and
len
(
self
.
indexes
)
<
2
:
if
self
.
_mixup_epoch
>
0
and
len
(
self
.
indexes
)
<
2
:
...
@@ -298,11 +306,6 @@ class Reader(object):
...
@@ -298,11 +306,6 @@ class Reader(object):
"less than 2 samples"
)
"less than 2 samples"
)
self
.
_cutmix_epoch
=
-
1
self
.
_cutmix_epoch
=
-
1
if
self
.
_epoch
<
0
:
self
.
_epoch
=
0
else
:
self
.
_epoch
+=
1
self
.
_pos
=
0
self
.
_pos
=
0
def
__next__
(
self
):
def
__next__
(
self
):
...
...
tools/train_multi_machine.py
0 → 100644
浏览文件 @
276bb001
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
os
,
sys
# add python path of PadleDetection to sys.path
parent_path
=
os
.
path
.
abspath
(
os
.
path
.
join
(
__file__
,
*
([
'..'
]
*
2
)))
if
parent_path
not
in
sys
.
path
:
sys
.
path
.
append
(
parent_path
)
import
time
import
numpy
as
np
import
random
import
datetime
import
six
from
collections
import
deque
from
paddle.fluid
import
profiler
from
paddle
import
fluid
from
paddle.fluid.layers.learning_rate_scheduler
import
_decay_step_counter
from
paddle.fluid.optimizer
import
ExponentialMovingAverage
from
ppdet.experimental
import
mixed_precision_context
from
ppdet.core.workspace
import
load_config
,
merge_config
,
create
from
ppdet.data.reader
import
create_reader
from
ppdet.utils
import
dist_utils
from
ppdet.utils.eval_utils
import
parse_fetches
,
eval_run
,
eval_results
from
ppdet.utils.stats
import
TrainingStats
from
ppdet.utils.cli
import
ArgsParser
from
ppdet.utils.check
import
check_gpu
,
check_version
,
check_config
import
ppdet.utils.checkpoint
as
checkpoint
from
paddle.fluid.incubate.fleet.collective
import
fleet
,
DistributedStrategy
# new line 1
from
paddle.fluid.incubate.fleet.base
import
role_maker
# new line 2
import
logging
FORMAT
=
'%(asctime)s-%(levelname)s: %(message)s'
logging
.
basicConfig
(
level
=
logging
.
INFO
,
format
=
FORMAT
)
logger
=
logging
.
getLogger
(
__name__
)
def
main
():
role
=
role_maker
.
PaddleCloudRoleMaker
(
is_collective
=
True
)
# new line 3
fleet
.
init
(
role
)
# new line 4
env
=
os
.
environ
num_trainers
=
int
(
env
.
get
(
'PADDLE_TRAINERS_NUM'
,
0
))
assert
num_trainers
!=
0
,
"multi-machine training process must be started using distributed.launch..."
trainer_id
=
int
(
env
.
get
(
"PADDLE_TRAINER_ID"
,
0
))
# set different seeds for different trainers
random
.
seed
(
trainer_id
)
np
.
random
.
seed
(
trainer_id
)
if
FLAGS
.
enable_ce
:
random
.
seed
(
0
)
np
.
random
.
seed
(
0
)
cfg
=
load_config
(
FLAGS
.
config
)
merge_config
(
FLAGS
.
opt
)
check_config
(
cfg
)
# check if set use_gpu=True in paddlepaddle cpu version
check_gpu
(
cfg
.
use_gpu
)
# check if paddlepaddle version is satisfied
check_version
()
save_only
=
getattr
(
cfg
,
'save_prediction_only'
,
False
)
if
save_only
:
raise
NotImplementedError
(
'The config file only support prediction,'
' training stage is not implemented now'
)
main_arch
=
cfg
.
architecture
assert
cfg
.
use_gpu
==
True
,
"GPU must be supported for multi-machine training..."
devices_num
=
fluid
.
core
.
get_cuda_device_count
()
if
'FLAGS_selected_gpus'
in
env
:
device_id
=
int
(
env
[
'FLAGS_selected_gpus'
])
else
:
device_id
=
0
place
=
fluid
.
CUDAPlace
(
device_id
)
if
cfg
.
use_gpu
else
fluid
.
CPUPlace
()
exe
=
fluid
.
Executor
(
place
)
lr_builder
=
create
(
'LearningRate'
)
optim_builder
=
create
(
'OptimizerBuilder'
)
# build program
startup_prog
=
fluid
.
Program
()
train_prog
=
fluid
.
Program
()
if
FLAGS
.
enable_ce
:
startup_prog
.
random_seed
=
1000
train_prog
.
random_seed
=
1000
with
fluid
.
program_guard
(
train_prog
,
startup_prog
):
with
fluid
.
unique_name
.
guard
():
model
=
create
(
main_arch
)
if
FLAGS
.
fp16
:
assert
(
getattr
(
model
.
backbone
,
'norm_type'
,
None
)
!=
'affine_channel'
),
\
'--fp16 currently does not support affine channel, '
\
' please modify backbone settings to use batch norm'
with
mixed_precision_context
(
FLAGS
.
loss_scale
,
FLAGS
.
fp16
)
as
ctx
:
inputs_def
=
cfg
[
'TrainReader'
][
'inputs_def'
]
feed_vars
,
train_loader
=
model
.
build_inputs
(
**
inputs_def
)
train_fetches
=
model
.
train
(
feed_vars
)
loss
=
train_fetches
[
'loss'
]
if
FLAGS
.
fp16
:
loss
*=
ctx
.
get_loss_scale_var
()
lr
=
lr_builder
()
optimizer
=
optim_builder
(
lr
)
dist_strategy
=
DistributedStrategy
()
sync_bn
=
getattr
(
model
.
backbone
,
'norm_type'
,
None
)
==
'sync_bn'
dist_strategy
.
sync_batch_norm
=
sync_bn
dist_strategy
.
nccl_comm_num
=
1
exec_strategy
=
fluid
.
ExecutionStrategy
()
exec_strategy
.
num_threads
=
3
exec_strategy
.
num_iteration_per_drop_scope
=
30
dist_strategy
.
exec_strategy
=
exec_strategy
dist_strategy
.
fuse_all_reduce_ops
=
True
optimizer
=
fleet
.
distributed_optimizer
(
optimizer
,
strategy
=
dist_strategy
)
# new line 5
optimizer
.
minimize
(
loss
)
if
FLAGS
.
fp16
:
loss
/=
ctx
.
get_loss_scale_var
()
if
'use_ema'
in
cfg
and
cfg
[
'use_ema'
]:
global_steps
=
_decay_step_counter
()
ema
=
ExponentialMovingAverage
(
cfg
[
'ema_decay'
],
thres_steps
=
global_steps
)
ema
.
update
()
# parse train fetches
train_keys
,
train_values
,
_
=
parse_fetches
(
train_fetches
)
train_values
.
append
(
lr
)
if
FLAGS
.
eval
:
eval_prog
=
fluid
.
Program
()
with
fluid
.
program_guard
(
eval_prog
,
startup_prog
):
with
fluid
.
unique_name
.
guard
():
model
=
create
(
main_arch
)
inputs_def
=
cfg
[
'EvalReader'
][
'inputs_def'
]
feed_vars
,
eval_loader
=
model
.
build_inputs
(
**
inputs_def
)
fetches
=
model
.
eval
(
feed_vars
)
eval_prog
=
eval_prog
.
clone
(
True
)
eval_reader
=
create_reader
(
cfg
.
EvalReader
,
devices_num
=
1
)
eval_loader
.
set_sample_list_generator
(
eval_reader
,
place
)
# parse eval fetches
extra_keys
=
[]
if
cfg
.
metric
==
'COCO'
:
extra_keys
=
[
'im_info'
,
'im_id'
,
'im_shape'
]
if
cfg
.
metric
==
'VOC'
:
extra_keys
=
[
'gt_bbox'
,
'gt_class'
,
'is_difficult'
]
if
cfg
.
metric
==
'WIDERFACE'
:
extra_keys
=
[
'im_id'
,
'im_shape'
,
'gt_bbox'
]
eval_keys
,
eval_values
,
eval_cls
=
parse_fetches
(
fetches
,
eval_prog
,
extra_keys
)
exe
.
run
(
startup_prog
)
compiled_train_prog
=
fleet
.
main_program
if
FLAGS
.
eval
:
compiled_eval_prog
=
fluid
.
CompiledProgram
(
eval_prog
)
fuse_bn
=
getattr
(
model
.
backbone
,
'norm_type'
,
None
)
==
'affine_channel'
ignore_params
=
cfg
.
finetune_exclude_pretrained_params
\
if
'finetune_exclude_pretrained_params'
in
cfg
else
[]
start_iter
=
0
if
FLAGS
.
resume_checkpoint
:
checkpoint
.
load_checkpoint
(
exe
,
train_prog
,
FLAGS
.
resume_checkpoint
)
start_iter
=
checkpoint
.
global_step
()
elif
cfg
.
pretrain_weights
and
fuse_bn
and
not
ignore_params
:
checkpoint
.
load_and_fusebn
(
exe
,
train_prog
,
cfg
.
pretrain_weights
)
elif
cfg
.
pretrain_weights
:
checkpoint
.
load_params
(
exe
,
train_prog
,
cfg
.
pretrain_weights
,
ignore_params
=
ignore_params
)
train_reader
=
create_reader
(
cfg
.
TrainReader
,
(
cfg
.
max_iters
-
start_iter
)
*
devices_num
,
cfg
,
devices_num
=
devices_num
)
train_loader
.
set_sample_list_generator
(
train_reader
,
place
)
# whether output bbox is normalized in model output layer
is_bbox_normalized
=
False
if
hasattr
(
model
,
'is_bbox_normalized'
)
and
\
callable
(
model
.
is_bbox_normalized
):
is_bbox_normalized
=
model
.
is_bbox_normalized
()
# if map_type not set, use default 11point, only use in VOC eval
map_type
=
cfg
.
map_type
if
'map_type'
in
cfg
else
'11point'
train_stats
=
TrainingStats
(
cfg
.
log_smooth_window
,
train_keys
)
train_loader
.
start
()
start_time
=
time
.
time
()
end_time
=
time
.
time
()
cfg_name
=
os
.
path
.
basename
(
FLAGS
.
config
).
split
(
'.'
)[
0
]
save_dir
=
os
.
path
.
join
(
cfg
.
save_dir
,
cfg_name
)
time_stat
=
deque
(
maxlen
=
cfg
.
log_smooth_window
)
best_box_ap_list
=
[
0.0
,
0
]
#[map, iter]
# use VisualDL to log data
if
FLAGS
.
use_vdl
:
assert
six
.
PY3
,
"VisualDL requires Python >= 3.5"
from
visualdl
import
LogWriter
vdl_writer
=
LogWriter
(
FLAGS
.
vdl_log_dir
)
vdl_loss_step
=
0
vdl_mAP_step
=
0
for
it
in
range
(
start_iter
,
cfg
.
max_iters
):
start_time
=
end_time
end_time
=
time
.
time
()
time_stat
.
append
(
end_time
-
start_time
)
time_cost
=
np
.
mean
(
time_stat
)
eta_sec
=
(
cfg
.
max_iters
-
it
)
*
time_cost
eta
=
str
(
datetime
.
timedelta
(
seconds
=
int
(
eta_sec
)))
outs
=
exe
.
run
(
compiled_train_prog
,
fetch_list
=
train_values
)
stats
=
{
k
:
np
.
array
(
v
).
mean
()
for
k
,
v
in
zip
(
train_keys
,
outs
[:
-
1
])}
# use vdl-paddle to log loss
if
FLAGS
.
use_vdl
:
if
it
%
cfg
.
log_iter
==
0
:
for
loss_name
,
loss_value
in
stats
.
items
():
vdl_writer
.
add_scalar
(
loss_name
,
loss_value
,
vdl_loss_step
)
vdl_loss_step
+=
1
train_stats
.
update
(
stats
)
logs
=
train_stats
.
log
()
if
it
%
cfg
.
log_iter
==
0
and
trainer_id
==
0
:
strs
=
'iter: {}, lr: {:.6f}, {}, time: {:.3f}, eta: {}'
.
format
(
it
,
np
.
mean
(
outs
[
-
1
]),
logs
,
time_cost
,
eta
)
logger
.
info
(
strs
)
# NOTE : profiler tools, used for benchmark
if
FLAGS
.
is_profiler
and
it
==
5
:
profiler
.
start_profiler
(
"All"
)
elif
FLAGS
.
is_profiler
and
it
==
10
:
profiler
.
stop_profiler
(
"total"
,
FLAGS
.
profiler_path
)
return
if
(
it
>
0
and
it
%
cfg
.
snapshot_iter
==
0
or
it
==
cfg
.
max_iters
-
1
)
\
and
trainer_id
==
0
:
save_name
=
str
(
it
)
if
it
!=
cfg
.
max_iters
-
1
else
"model_final"
if
'use_ema'
in
cfg
and
cfg
[
'use_ema'
]:
exe
.
run
(
ema
.
apply_program
)
checkpoint
.
save
(
exe
,
train_prog
,
os
.
path
.
join
(
save_dir
,
save_name
))
if
FLAGS
.
eval
:
# evaluation
resolution
=
None
if
'Mask'
in
cfg
.
architecture
:
resolution
=
model
.
mask_head
.
resolution
results
=
eval_run
(
exe
,
compiled_eval_prog
,
eval_loader
,
eval_keys
,
eval_values
,
eval_cls
,
cfg
,
resolution
=
resolution
)
box_ap_stats
=
eval_results
(
results
,
cfg
.
metric
,
cfg
.
num_classes
,
resolution
,
is_bbox_normalized
,
FLAGS
.
output_eval
,
map_type
,
cfg
[
'EvalReader'
][
'dataset'
])
# use vdl_paddle to log mAP
if
FLAGS
.
use_vdl
:
vdl_writer
.
add_scalar
(
"mAP"
,
box_ap_stats
[
0
],
vdl_mAP_step
)
vdl_mAP_step
+=
1
if
box_ap_stats
[
0
]
>
best_box_ap_list
[
0
]:
best_box_ap_list
[
0
]
=
box_ap_stats
[
0
]
best_box_ap_list
[
1
]
=
it
checkpoint
.
save
(
exe
,
train_prog
,
os
.
path
.
join
(
save_dir
,
"best_model"
))
logger
.
info
(
"Best test box ap: {}, in iter: {}"
.
format
(
best_box_ap_list
[
0
],
best_box_ap_list
[
1
]))
if
'use_ema'
in
cfg
and
cfg
[
'use_ema'
]:
exe
.
run
(
ema
.
restore_program
)
train_loader
.
reset
()
if
__name__
==
'__main__'
:
parser
=
ArgsParser
()
parser
.
add_argument
(
"-r"
,
"--resume_checkpoint"
,
default
=
None
,
type
=
str
,
help
=
"Checkpoint path for resuming training."
)
parser
.
add_argument
(
"--fp16"
,
action
=
'store_true'
,
default
=
False
,
help
=
"Enable mixed precision training."
)
parser
.
add_argument
(
"--loss_scale"
,
default
=
8.
,
type
=
float
,
help
=
"Mixed precision training loss scale."
)
parser
.
add_argument
(
"--eval"
,
action
=
'store_true'
,
default
=
False
,
help
=
"Whether to perform evaluation in train"
)
parser
.
add_argument
(
"--output_eval"
,
default
=
None
,
type
=
str
,
help
=
"Evaluation directory, default is current directory."
)
parser
.
add_argument
(
"--use_vdl"
,
type
=
bool
,
default
=
False
,
help
=
"whether to record the data to VisualDL."
)
parser
.
add_argument
(
'--vdl_log_dir'
,
type
=
str
,
default
=
"vdl_log_dir/scalar"
,
help
=
'VisualDL logging directory for scalar.'
)
parser
.
add_argument
(
"--enable_ce"
,
type
=
bool
,
default
=
False
,
help
=
"If set True, enable continuous evaluation job."
"This flag is only used for internal test."
)
#NOTE:args for profiler tools, used for benchmark
parser
.
add_argument
(
'--is_profiler'
,
type
=
int
,
default
=
0
,
help
=
'The switch of profiler tools. (used for benchmark)'
)
parser
.
add_argument
(
'--profiler_path'
,
type
=
str
,
default
=
"./detection.profiler"
,
help
=
'The profiler output file path. (used for benchmark)'
)
FLAGS
=
parser
.
parse_args
()
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录