Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
hapi
提交
839db7b1
H
hapi
项目概览
PaddlePaddle
/
hapi
通知
11
Star
2
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
4
列表
看板
标记
里程碑
合并请求
7
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
H
hapi
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
4
Issue
4
列表
看板
标记
里程碑
合并请求
7
合并请求
7
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
839db7b1
编写于
3月 23, 2020
作者:
L
LielinJiang
提交者:
GitHub
3月 23, 2020
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #12 from LielinJiang/multiple-gpus-append-op
make hapi support multiple gpus train and eval
上级
3da16763
280dd0e3
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
528 addition
and
146 deletion
+528
-146
callbacks.py
callbacks.py
+9
-8
distributed.py
distributed.py
+269
-0
mnist.py
mnist.py
+23
-77
model.py
model.py
+197
-28
progressbar.py
progressbar.py
+3
-3
tests/test_model.py
tests/test_model.py
+27
-30
未找到文件。
callbacks.py
浏览文件 @
839db7b1
...
...
@@ -16,7 +16,7 @@ import six
import
copy
from
progressbar
import
ProgressBar
from
distributed
import
get_local_rank
def
config_callbacks
(
callbacks
=
None
,
model
=
None
,
...
...
@@ -193,7 +193,7 @@ class ProgBarLogger(Callback):
self
.
steps
=
self
.
params
[
'steps'
]
self
.
epoch
=
epoch
self
.
train_step
=
0
if
self
.
verbose
and
self
.
epochs
:
if
self
.
verbose
and
self
.
epochs
and
get_local_rank
()
==
0
:
print
(
'Epoch %d/%d'
%
(
epoch
+
1
,
self
.
epochs
))
self
.
train_progbar
=
ProgressBar
(
num
=
self
.
steps
,
verbose
=
self
.
verbose
)
...
...
@@ -211,7 +211,7 @@ class ProgBarLogger(Callback):
logs
=
logs
or
{}
self
.
train_step
=
step
if
self
.
train_step
%
self
.
log_freq
==
0
and
self
.
verbose
:
if
self
.
train_step
%
self
.
log_freq
==
0
and
self
.
verbose
and
get_local_rank
()
==
0
:
# if steps is not None, last step will update in on_epoch_end
if
self
.
steps
and
self
.
train_step
<
self
.
steps
:
self
.
_updates
(
logs
,
'train'
)
...
...
@@ -220,7 +220,7 @@ class ProgBarLogger(Callback):
def
on_epoch_end
(
self
,
epoch
,
logs
=
None
):
logs
=
logs
or
{}
if
self
.
verbose
:
if
self
.
verbose
and
get_local_rank
()
==
0
:
self
.
_updates
(
logs
,
'train'
)
def
on_eval_begin
(
self
,
logs
=
None
):
...
...
@@ -230,7 +230,8 @@ class ProgBarLogger(Callback):
self
.
evaled_samples
=
0
self
.
eval_progbar
=
ProgressBar
(
num
=
self
.
eval_steps
,
verbose
=
self
.
verbose
)
print
(
'Eval begin...'
)
if
get_local_rank
()
==
0
:
print
(
'Eval begin...'
)
def
on_eval_batch_end
(
self
,
step
,
logs
=
None
):
logs
=
logs
or
{}
...
...
@@ -240,7 +241,7 @@ class ProgBarLogger(Callback):
def
on_eval_end
(
self
,
logs
=
None
):
logs
=
logs
or
{}
if
self
.
verbose
:
if
self
.
verbose
and
get_local_rank
()
==
0
:
self
.
_updates
(
logs
,
'eval'
)
print
(
'Eval samples: %d'
%
(
self
.
evaled_samples
))
...
...
@@ -254,13 +255,13 @@ class ModelCheckpoint(Callback):
self
.
epoch
=
epoch
def
on_epoch_end
(
self
,
epoch
,
logs
=
None
):
if
self
.
model
and
self
.
epoch
%
self
.
save_freq
==
0
:
if
self
.
model
and
self
.
epoch
%
self
.
save_freq
==
0
and
get_local_rank
()
==
0
:
path
=
'{}/{}'
.
format
(
self
.
save_file
,
epoch
)
print
(
'save checkpoint at {}'
.
format
(
path
))
self
.
model
.
save
(
path
)
def
on_train_end
(
self
,
logs
=
None
):
if
self
.
model
:
if
self
.
model
and
get_local_rank
()
==
0
:
path
=
'{}/final'
.
format
(
self
.
save_file
)
print
(
'save checkpoint at {}'
.
format
(
path
))
self
.
model
.
save
(
path
)
distributed.py
0 → 100644
浏览文件 @
839db7b1
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
sys
import
time
import
math
import
socket
import
contextlib
from
contextlib
import
closing
from
six
import
string_types
import
numpy
as
np
from
collections
import
OrderedDict
from
paddle
import
fluid
import
paddle.fluid.unique_name
as
nameGen
from
paddle.fluid
import
core
from
paddle.fluid
import
framework
from
paddle.fluid.layers
import
collective
from
paddle.fluid.dygraph
import
to_variable
,
no_grad
,
layers
from
paddle.fluid.framework
import
Variable
from
paddle.fluid.executor
import
global_scope
from
paddle.fluid.dygraph.parallel
import
Env
,
DataParallel
,
ParallelStrategy
from
paddle.fluid.layers.collective
import
_c_allreduce
,
_c_allgather
,
_c_broadcast
,
\
_c_sync_comm_stream
,
_c_sync_calc_stream
from
paddle.fluid.io
import
BatchSampler
,
DataLoader
__parallel_context_init
=
False
class
DistributedBatchSampler
(
BatchSampler
):
"""Sampler that restricts data loading to a subset of the dataset.
In such case, each process can pass a DistributedBatchSampler instance
as a DataLoader sampler, and load a subset of the original dataset that
is exclusive to it.
.. note::
Dataset is assumed to be of constant size.
Args:
data_source: this could be a `fluid.io.Dataset` implement
or other python object which implemented
`__len__` for BatchSampler to get sample
number of data source.
batch_size(int): sample indice number in a mini-batch indices.
shuffle(bool): whther to shuffle indices order before genrating
batch indices. Default False.
drop_last(bool): whether drop the last incomplete batch dataset size
is not divisible by the batch size. Default False
"""
def
__init__
(
self
,
dataset
,
batch_size
,
shuffle
=
False
,
drop_last
=
False
):
self
.
dataset
=
dataset
assert
isinstance
(
batch_size
,
int
)
and
batch_size
>
0
,
\
"batch_size should be a positive integer"
self
.
batch_size
=
batch_size
assert
isinstance
(
shuffle
,
bool
),
\
"shuffle should be a boolean value"
self
.
shuffle
=
shuffle
assert
isinstance
(
drop_last
,
bool
),
\
"drop_last should be a boolean number"
self
.
drop_last
=
drop_last
self
.
nranks
=
get_nranks
()
self
.
local_rank
=
get_local_rank
()
self
.
epoch
=
0
self
.
num_samples
=
int
(
math
.
ceil
(
len
(
self
.
dataset
)
*
1.0
/
self
.
nranks
))
self
.
total_size
=
self
.
num_samples
*
self
.
nranks
def
__iter__
(
self
):
num_samples
=
len
(
self
.
dataset
)
indices
=
np
.
arange
(
num_samples
).
tolist
()
indices
+=
indices
[:(
self
.
total_size
-
len
(
indices
))]
assert
len
(
indices
)
==
self
.
total_size
if
self
.
shuffle
:
np
.
random
.
RandomState
(
self
.
epoch
).
shuffle
(
indices
)
self
.
epoch
+=
1
# subsample
indices
=
indices
[
self
.
local_rank
*
self
.
num_samples
:
(
self
.
local_rank
+
1
)
*
self
.
num_samples
]
assert
len
(
indices
)
==
self
.
num_samples
_sample_iter
=
iter
(
indices
)
batch_indices
=
[]
for
idx
in
_sample_iter
:
batch_indices
.
append
(
idx
)
if
len
(
batch_indices
)
==
self
.
batch_size
:
yield
batch_indices
batch_indices
=
[]
if
not
self
.
drop_last
and
len
(
batch_indices
)
>
0
:
yield
batch_indices
def
__len__
(
self
):
num_samples
=
self
.
num_samples
num_samples
+=
int
(
not
self
.
drop_last
)
*
(
self
.
batch_size
-
1
)
return
num_samples
//
self
.
batch_size
def
_all_gather
(
x
,
nranks
,
ring_id
=
0
,
use_calc_stream
=
True
):
return
_c_allgather
(
x
,
nranks
,
ring_id
=
ring_id
,
use_calc_stream
=
use_calc_stream
)
def
get_local_rank
():
return
Env
().
local_rank
def
get_nranks
():
return
Env
().
nranks
def
wait_server_ready
(
endpoints
):
assert
not
isinstance
(
endpoints
,
string_types
)
while
True
:
all_ok
=
True
not_ready_endpoints
=
[]
for
ep
in
endpoints
:
ip_port
=
ep
.
split
(
":"
)
with
closing
(
socket
.
socket
(
socket
.
AF_INET
,
socket
.
SOCK_STREAM
))
as
sock
:
sock
.
settimeout
(
2
)
result
=
sock
.
connect_ex
((
ip_port
[
0
],
int
(
ip_port
[
1
])))
if
result
!=
0
:
all_ok
=
False
not_ready_endpoints
.
append
(
ep
)
if
not
all_ok
:
sys
.
stderr
.
write
(
"server not ready, wait 3 sec to retry...
\n
"
)
sys
.
stderr
.
write
(
"not ready endpoints:"
+
str
(
not_ready_endpoints
)
+
"
\n
"
)
sys
.
stderr
.
flush
()
time
.
sleep
(
3
)
else
:
break
def
init_communicator
(
program
,
rank
,
nranks
,
wait_port
,
current_endpoint
,
endpoints
):
if
nranks
<
2
:
return
other_endpoints
=
endpoints
[:]
other_endpoints
.
remove
(
current_endpoint
)
if
rank
==
0
and
wait_port
:
wait_server_ready
(
other_endpoints
)
block
=
program
.
global_block
()
nccl_id_var
=
block
.
create_var
(
name
=
nameGen
.
generate
(
'nccl_id'
),
persistable
=
True
,
type
=
core
.
VarDesc
.
VarType
.
RAW
)
block
.
append_op
(
type
=
'c_gen_nccl_id'
,
inputs
=
{},
outputs
=
{
'Out'
:
nccl_id_var
},
attrs
=
{
'rank'
:
rank
,
'endpoint'
:
current_endpoint
,
'other_endpoints'
:
other_endpoints
})
block
.
append_op
(
type
=
'c_comm_init'
,
inputs
=
{
'X'
:
nccl_id_var
},
outputs
=
{},
attrs
=
{
'nranks'
:
nranks
,
'rank'
:
rank
,
'ring_id'
:
0
,
})
def
prepare_distributed_context
(
place
=
None
):
if
place
is
None
:
place
=
fluid
.
CUDAPlace
(
Env
().
dev_id
)
if
Env
().
nranks
>
1
\
else
fluid
.
CUDAPlace
(
0
)
strategy
=
ParallelStrategy
()
strategy
.
nranks
=
Env
().
nranks
strategy
.
local_rank
=
Env
().
local_rank
strategy
.
trainer_endpoints
=
Env
().
trainer_endpoints
strategy
.
current_endpoint
=
Env
().
current_endpoint
if
strategy
.
nranks
<
2
:
return
global
__parallel_context_init
if
not
__parallel_context_init
and
isinstance
(
place
,
core
.
CUDAPlace
):
def
_init_context
():
communicator_prog
=
framework
.
Program
()
init_communicator
(
communicator_prog
,
strategy
.
local_rank
,
strategy
.
nranks
,
True
,
strategy
.
current_endpoint
,
strategy
.
trainer_endpoints
)
exe
=
fluid
.
Executor
(
place
)
exe
.
run
(
communicator_prog
)
if
fluid
.
in_dygraph_mode
():
fluid
.
disable_dygraph
()
_init_context
()
fluid
.
enable_dygraph
(
place
)
else
:
_init_context
()
else
:
assert
(
"Only support CUDAPlace for now."
)
__parallel_context_init
=
True
return
strategy
class
DistributedDataParallel
(
DataParallel
):
def
__init__
(
self
,
layers
,
strategy
=
None
):
if
strategy
is
None
:
strategy
=
ParallelStrategy
()
strategy
.
nranks
=
Env
().
nranks
strategy
.
local_rank
=
Env
().
local_rank
strategy
.
trainer_endpoints
=
Env
().
trainer_endpoints
strategy
.
current_endpoint
=
Env
().
current_endpoint
super
(
DistributedDataParallel
,
self
).
__init__
(
layers
,
strategy
)
@
no_grad
def
apply_collective_grads
(
self
):
"""
AllReduce the Parameters' gradient.
"""
if
not
self
.
_is_data_parallel_mode
():
return
grad_var_set
=
set
()
grad_vars
=
[]
for
param
in
self
.
_layers
.
parameters
():
# NOTE(zcd): The grad_ivar maybe no generated.
if
param
.
trainable
and
param
.
_grad_ivar
():
g_var
=
param
.
_grad_ivar
()
grad_vars
.
append
(
g_var
)
assert
g_var
not
in
grad_var_set
grad_var_set
.
add
(
g_var
)
mega_bytes
=
128
*
1024
*
1024
group_idx
=
0
memory_counter
=
0
grad_var_groups
=
OrderedDict
()
dtype
=
grad_vars
[
0
].
dtype
for
g_var
in
grad_vars
:
# Note: the dtype of the same group should be the same.
bytes
=
np
.
prod
(
g_var
.
shape
)
*
core
.
size_of_dtype
(
g_var
.
dtype
)
if
memory_counter
<
mega_bytes
and
dtype
==
g_var
.
dtype
:
memory_counter
+=
bytes
else
:
memory_counter
=
bytes
group_idx
+=
1
grad_var_groups
.
setdefault
(
group_idx
,
[]).
append
(
g_var
)
coalesced_grads_and_vars
=
self
.
_coalesce_tensors
(
grad_var_groups
)
for
coalesced_grad
,
_
,
_
in
coalesced_grads_and_vars
:
collective
.
_c_allreduce
(
coalesced_grad
,
coalesced_grad
,
use_calc_stream
=
True
)
self
.
_split_tensors
(
coalesced_grads_and_vars
)
mnist.py
浏览文件 @
839db7b1
...
...
@@ -21,12 +21,12 @@ import os
import
numpy
as
np
import
paddle
from
paddle
import
fluid
from
paddle.fluid.optimizer
import
Momentum
from
paddle.fluid.dygraph.nn
import
Conv2D
,
Pool2D
,
Linear
from
paddle.fluid.io
import
MNIST
as
MnistDataset
from
model
import
Model
,
CrossEntropy
,
Input
from
model
import
Model
,
CrossEntropy
,
Input
,
init_context
from
metrics
import
Accuracy
...
...
@@ -97,6 +97,7 @@ class MNIST(Model):
act
=
"softmax"
)
def
forward
(
self
,
inputs
):
inputs
=
fluid
.
layers
.
reshape
(
inputs
,
[
-
1
,
1
,
28
,
28
])
x
=
self
.
_simple_img_conv_pool_1
(
inputs
)
x
=
self
.
_simple_img_conv_pool_2
(
x
)
x
=
fluid
.
layers
.
flatten
(
x
,
axis
=
1
)
...
...
@@ -104,81 +105,26 @@ class MNIST(Model):
return
x
def
accuracy
(
pred
,
label
,
topk
=
(
1
,
)):
maxk
=
max
(
topk
)
pred
=
np
.
argsort
(
pred
)[:,
::
-
1
][:,
:
maxk
]
correct
=
(
pred
==
np
.
repeat
(
label
,
maxk
,
1
))
batch_size
=
label
.
shape
[
0
]
res
=
[]
for
k
in
topk
:
correct_k
=
correct
[:,
:
k
].
sum
()
res
.
append
(
100.0
*
correct_k
/
batch_size
)
return
res
def
main
():
@
contextlib
.
contextmanager
def
null_guard
():
yield
guard
=
fluid
.
dygraph
.
guard
()
if
FLAGS
.
dynamic
else
null_guard
()
if
not
os
.
path
.
exists
(
'mnist_checkpoints'
):
os
.
mkdir
(
'mnist_checkpoints'
)
train_loader
=
fluid
.
io
.
xmap_readers
(
lambda
b
:
[
np
.
array
([
x
[
0
]
for
x
in
b
]).
reshape
(
-
1
,
1
,
28
,
28
),
np
.
array
([
x
[
1
]
for
x
in
b
]).
reshape
(
-
1
,
1
)],
paddle
.
batch
(
fluid
.
io
.
shuffle
(
paddle
.
dataset
.
mnist
.
train
(),
6e4
),
batch_size
=
FLAGS
.
batch_size
,
drop_last
=
True
),
1
,
1
)
val_loader
=
fluid
.
io
.
xmap_readers
(
lambda
b
:
[
np
.
array
([
x
[
0
]
for
x
in
b
]).
reshape
(
-
1
,
1
,
28
,
28
),
np
.
array
([
x
[
1
]
for
x
in
b
]).
reshape
(
-
1
,
1
)],
paddle
.
batch
(
paddle
.
dataset
.
mnist
.
test
(),
batch_size
=
FLAGS
.
batch_size
,
drop_last
=
True
),
1
,
1
)
with
guard
:
model
=
MNIST
()
optim
=
Momentum
(
learning_rate
=
FLAGS
.
lr
,
momentum
=
.
9
,
parameter_list
=
model
.
parameters
())
inputs
=
[
Input
([
None
,
1
,
28
,
28
],
'float32'
,
name
=
'image'
)]
labels
=
[
Input
([
None
,
1
],
'int64'
,
name
=
'label'
)]
model
.
prepare
(
optim
,
CrossEntropy
(),
Accuracy
(
topk
=
(
1
,
2
)),
inputs
,
labels
)
if
FLAGS
.
resume
is
not
None
:
model
.
load
(
FLAGS
.
resume
)
for
e
in
range
(
FLAGS
.
epoch
):
train_loss
=
0.0
val_loss
=
0.0
print
(
"======== train epoch {} ========"
.
format
(
e
))
for
idx
,
batch
in
enumerate
(
train_loader
()):
losses
,
metrics
=
model
.
train
(
batch
[
0
],
batch
[
1
])
train_loss
+=
np
.
sum
(
losses
)
if
idx
%
10
==
0
:
print
(
"{:04d}: loss {:0.3f} top1: {:0.3f}% top2: {:0.3f}%"
.
format
(
idx
,
train_loss
/
(
idx
+
1
),
metrics
[
0
][
0
],
metrics
[
0
][
1
]))
for
metric
in
model
.
_metrics
:
res
=
metric
.
accumulate
()
print
(
"train epoch {:03d}: top1: {:0.3f}%, top2: {:0.3f}"
.
format
(
e
,
res
[
0
],
res
[
1
]))
metric
.
reset
()
print
(
"======== eval epoch {} ========"
.
format
(
e
))
for
idx
,
batch
in
enumerate
(
val_loader
()):
losses
,
metrics
=
model
.
eval
(
batch
[
0
],
batch
[
1
])
val_loss
+=
np
.
sum
(
losses
)
if
idx
%
10
==
0
:
print
(
"{:04d}: loss {:0.3f} top1: {:0.3f}% top2: {:0.3f}%"
.
format
(
idx
,
val_loss
/
(
idx
+
1
),
metrics
[
0
][
0
],
metrics
[
0
][
1
]))
for
metric
in
model
.
_metrics
:
res
=
metric
.
accumulate
()
print
(
"eval epoch {:03d}: top1: {:0.3f}%, top2: {:0.3f}"
.
format
(
e
,
res
[
0
],
res
[
1
]))
metric
.
reset
()
model
.
save
(
'mnist_checkpoints/{:02d}'
.
format
(
e
))
init_context
(
'dynamic'
if
FLAGS
.
dynamic
else
'static'
)
train_dataset
=
MnistDataset
(
mode
=
'train'
)
val_dataset
=
MnistDataset
(
mode
=
'test'
)
inputs
=
[
Input
([
None
,
784
],
'float32'
,
name
=
'image'
)]
labels
=
[
Input
([
None
,
1
],
'int64'
,
name
=
'label'
)]
model
=
MNIST
()
optim
=
Momentum
(
learning_rate
=
FLAGS
.
lr
,
momentum
=
.
9
,
parameter_list
=
model
.
parameters
())
model
.
prepare
(
optim
,
CrossEntropy
(),
Accuracy
(
topk
=
(
1
,
2
)),
inputs
,
labels
)
if
FLAGS
.
resume
is
not
None
:
model
.
load
(
FLAGS
.
resume
)
model
.
fit
(
train_dataset
,
val_dataset
,
epochs
=
FLAGS
.
epoch
,
batch_size
=
FLAGS
.
batch_size
)
if
__name__
==
'__main__'
:
...
...
@@ -186,7 +132,7 @@ if __name__ == '__main__':
parser
.
add_argument
(
"-d"
,
"--dynamic"
,
action
=
'store_true'
,
help
=
"enable dygraph mode"
)
parser
.
add_argument
(
"-e"
,
"--epoch"
,
default
=
100
,
type
=
int
,
help
=
"number of epoch"
)
"-e"
,
"--epoch"
,
default
=
2
,
type
=
int
,
help
=
"number of epoch"
)
parser
.
add_argument
(
'--lr'
,
'--learning-rate'
,
...
...
model.py
浏览文件 @
839db7b1
...
...
@@ -23,15 +23,22 @@ import warnings
from
collections
import
Iterable
from
collections
import
OrderedDict
from
collections
import
OrderedDict
from
paddle
import
fluid
from
paddle.fluid.framework
import
in_dygraph_mode
,
Variable
from
paddle.fluid.executor
import
global_scope
from
paddle.fluid.io
import
is_belong_to_optimizer
from
paddle.fluid.dygraph.base
import
to_variable
from
paddle.fluid.incubate.fleet.collective
import
fleet
,
DistributedStrategy
import
paddle.fluid.incubate.fleet.base.role_maker
as
role_maker
import
distributed
from
distributed
import
DistributedBatchSampler
from
paddle.fluid.io
import
DataLoader
from
metrics
import
Metric
from
callbacks
import
config_callbacks
__all__
=
[
'Model'
,
'Loss'
,
'CrossEntropy'
,
'Input'
]
...
...
@@ -78,6 +85,18 @@ def extract_args(func):
return
inspect
.
getargspec
(
func
)[
0
]
def
init_context
(
backend
):
assert
isinstance
(
backend
,
str
)
and
backend
.
lower
()
in
[
'dynamic'
,
'static'
],
\
"Expected backend in ['dynamic', 'static'], but got {}"
.
format
(
backend
)
place
=
fluid
.
CUDAPlace
(
distributed
.
Env
().
dev_id
)
if
\
distributed
.
Env
().
nranks
>
1
else
fluid
.
CUDAPlace
(
0
)
distributed
.
prepare_distributed_context
(
place
)
backend
=
backend
.
lower
()
if
backend
==
'dynamic'
:
fluid
.
enable_dygraph
(
place
)
class
Input
(
fluid
.
dygraph
.
Layer
):
def
__init__
(
self
,
shape
=
None
,
dtype
=
None
,
name
=
None
):
super
(
Input
,
self
).
__init__
()
...
...
@@ -136,6 +155,12 @@ class StaticGraphAdapter(object):
self
.
_progs
=
{}
self
.
_compiled_progs
=
{}
self
.
_merge_count
=
{
'eval_total'
:
0
,
'test_total'
:
0
,
'eval_batch'
:
0
,
'test_batch'
:
0
}
self
.
_nranks
=
distributed
.
Env
().
nranks
self
.
_local_rank
=
distributed
.
Env
().
local_rank
@
property
def
mode
(
self
):
return
self
.
model
.
mode
...
...
@@ -336,6 +361,22 @@ class StaticGraphAdapter(object):
metric_states
=
restore_flatten_list
(
rets
[
num_loss
:],
metric_splits
)
metrics
=
[]
for
metric
,
state
in
zip
(
self
.
model
.
_metrics
,
metric_states
):
# cut off padding size
if
self
.
mode
!=
'train'
and
self
.
model
.
_test_dataloader
is
not
None
\
and
isinstance
(
self
.
model
.
_test_dataloader
,
DataLoader
)
\
and
self
.
_nranks
>
1
:
total_size
=
len
(
self
.
model
.
_test_dataloader
.
dataset
)
# TODO: fixme if have better way to get batch size
samples
=
state
[
0
].
shape
[
0
]
current_count
=
self
.
_merge_count
.
get
(
self
.
mode
+
'_total'
,
0
)
if
current_count
+
samples
>=
total_size
:
state
=
[
s
[:
total_size
-
current_count
,
...]
for
s
in
state
]
self
.
_merge_count
[
self
.
mode
+
'_total'
]
=
0
self
.
_merge_count
[
self
.
mode
+
'_batch'
]
=
total_size
-
current_count
else
:
self
.
_merge_count
[
self
.
mode
+
'_total'
]
+=
samples
self
.
_merge_count
[
self
.
mode
+
'_batch'
]
=
samples
metrics
.
append
(
metric
.
update
(
*
state
))
return
(
losses
,
metrics
)
if
len
(
metrics
)
>
0
else
losses
...
...
@@ -364,6 +405,7 @@ class StaticGraphAdapter(object):
# HACK workaround learning rate map issue
lr_var
=
self
.
model
.
_optimizer
.
_learning_rate_map
[
self
.
_orig_prog
]
self
.
model
.
_optimizer
.
_learning_rate_map
[
prog
]
=
lr_var
losses
=
[]
metrics
=
[]
with
fluid
.
program_guard
(
prog
,
self
.
_startup_prog
):
...
...
@@ -375,21 +417,39 @@ class StaticGraphAdapter(object):
lbls
=
self
.
model
.
_labels
if
self
.
model
.
_labels
else
[]
inputs
=
[
k
.
forward
()
for
k
in
to_list
(
ins
)]
labels
=
[
k
.
forward
()
for
k
in
to_list
(
lbls
)]
self
.
_label_vars
[
mode
]
=
labels
outputs
=
to_list
(
self
.
model
.
forward
(
*
inputs
))
if
mode
!=
'test'
:
if
self
.
model
.
_loss_function
:
if
mode
!=
'test'
and
self
.
model
.
_loss_function
:
losses
=
self
.
model
.
_loss_function
(
outputs
,
labels
)
for
metric
in
self
.
model
.
_metrics
:
metrics
.
append
(
to_list
(
metric
.
add_metric_op
(
outputs
,
labels
)))
if
mode
==
'train'
and
self
.
model
.
_optimizer
:
self
.
_loss_endpoint
=
fluid
.
layers
.
sum
(
losses
)
self
.
model
.
_optimizer
.
minimize
(
self
.
_loss_endpoint
)
if
self
.
_nranks
>
1
and
mode
!=
'train'
:
outputs
=
[
distributed
.
_all_gather
(
o
,
self
.
_nranks
)
for
o
in
outputs
]
if
mode
!=
'test'
:
labels
=
[
distributed
.
_all_gather
(
l
,
self
.
_nranks
)
for
l
in
labels
]
if
mode
!=
'test'
:
for
metric
in
self
.
model
.
_metrics
:
metrics
.
append
(
to_list
(
metric
.
add_metric_op
(
outputs
,
labels
)))
if
mode
==
'train'
and
self
.
model
.
_optimizer
:
self
.
_loss_endpoint
=
fluid
.
layers
.
sum
(
losses
)
if
self
.
_nranks
>
1
:
role
=
role_maker
.
PaddleCloudRoleMaker
(
is_collective
=
True
)
fleet
.
init
(
role
)
dist_strategy
=
DistributedStrategy
()
dist_strategy
.
mode
=
"collective"
dist_strategy
.
collective_mode
=
"grad_allreduce"
self
.
model
.
_optimizer
=
fleet
.
distributed_optimizer
(
self
.
model
.
_optimizer
,
strategy
=
dist_strategy
)
self
.
model
.
_optimizer
.
minimize
(
self
.
_loss_endpoint
)
if
mode
!=
'train'
:
# clone again to put it in test mode
prog
=
prog
.
clone
(
for_test
=
True
)
self
.
_input_vars
[
mode
]
=
inputs
self
.
_label_vars
[
mode
]
=
labels
self
.
_progs
[
mode
]
=
prog
self
.
_endpoints
[
mode
]
=
{
"output"
:
outputs
,
...
...
@@ -397,6 +457,7 @@ class StaticGraphAdapter(object):
"metric"
:
metrics
}
def
_compile_and_initialize
(
self
,
prog
,
mode
):
compiled_prog
=
self
.
_compiled_progs
.
get
(
mode
,
None
)
if
compiled_prog
is
not
None
:
...
...
@@ -414,19 +475,30 @@ class StaticGraphAdapter(object):
# even if `forward()` may run different code path for different mode
# therefore startup program only needs to run once
if
self
.
_executor
is
None
:
self
.
_executor
=
fluid
.
Executor
(
places
[
0
])
if
self
.
_nranks
>
1
and
device
.
lower
()
==
'gpu'
:
gpu_id
=
int
(
distributed
.
Env
().
dev_id
)
place
=
fluid
.
CUDAPlace
(
gpu_id
)
if
device
.
lower
()
==
'gpu'
else
fluid
.
CPUPlace
()
else
:
place
=
places
[
0
]
self
.
_executor
=
fluid
.
Executor
(
place
)
# XXX incremental initialization
uninitialized
=
[]
for
var_py
in
self
.
_startup_prog
.
list_vars
():
var
=
fluid
.
global_scope
().
find_var
(
var_py
.
name
)
if
var
and
var
.
get_tensor
().
_is_initialized
():
if
not
var_py
.
name
.
startswith
(
'nccl_id'
)
and
var
and
\
var
.
get_tensor
().
_is_initialized
():
continue
uninitialized
.
append
(
var_py
)
if
uninitialized
:
startup_prog
=
self
.
_startup_prog
.
_prune
(
uninitialized
)
self
.
_executor
.
run
(
startup_prog
)
compiled_prog
=
fluid
.
CompiledProgram
(
prog
)
if
self
.
_nranks
<
2
:
compiled_prog
=
fluid
.
CompiledProgram
(
prog
)
else
:
compiled_prog
=
prog
#fleet.main_program
if
len
(
places
)
>
1
:
loss_name
=
None
if
mode
==
'train'
and
self
.
_loss_endpoint
is
not
None
:
...
...
@@ -440,6 +512,13 @@ class DynamicGraphAdapter(object):
def
__init__
(
self
,
model
):
super
(
DynamicGraphAdapter
,
self
).
__init__
()
self
.
model
=
model
self
.
_nranks
=
distributed
.
Env
().
nranks
self
.
_local_rank
=
distributed
.
Env
().
local_rank
self
.
_merge_count
=
{
'eval_total'
:
0
,
'test_total'
:
0
,
'eval_batch'
:
0
,
'test_batch'
:
0
}
if
self
.
_nranks
>
1
:
self
.
ddp_model
=
distributed
.
DistributedDataParallel
(
self
.
model
)
@
property
def
mode
(
self
):
...
...
@@ -458,18 +537,27 @@ class DynamicGraphAdapter(object):
inputs
=
to_list
(
inputs
)
if
labels
is
not
None
:
labels
=
[
to_variable
(
l
)
for
l
in
to_list
(
labels
)]
outputs
=
to_list
(
self
.
model
.
forward
(
*
[
to_variable
(
x
)
for
x
in
inputs
]))
losses
=
self
.
model
.
_loss_function
(
outputs
,
labels
)
final_loss
=
fluid
.
layers
.
sum
(
losses
)
final_loss
.
backward
()
if
self
.
_nranks
>
1
:
outputs
=
self
.
ddp_model
.
forward
(
*
[
to_variable
(
x
)
for
x
in
inputs
])
losses
=
self
.
model
.
_loss_function
(
outputs
,
labels
)
final_loss
=
fluid
.
layers
.
sum
(
losses
)
final_loss
=
self
.
ddp_model
.
scale_loss
(
final_loss
)
final_loss
.
backward
()
self
.
ddp_model
.
apply_collective_grads
()
else
:
outputs
=
self
.
model
.
forward
(
*
[
to_variable
(
x
)
for
x
in
inputs
])
losses
=
self
.
model
.
_loss_function
(
outputs
,
labels
)
final_loss
=
fluid
.
layers
.
sum
(
losses
)
final_loss
.
backward
()
self
.
model
.
_optimizer
.
minimize
(
final_loss
)
self
.
model
.
clear_gradients
()
metrics
=
[]
for
metric
in
self
.
model
.
_metrics
:
metric_outs
=
metric
.
add_metric_op
(
outputs
,
to_list
(
labels
))
m
=
metric
.
update
(
*
[
to_numpy
(
m
)
for
m
in
to_list
(
metric_outs
)])
metric_outs
=
metric
.
add_metric_op
(
to_list
(
outputs
)
,
to_list
(
labels
))
m
=
metric
.
update
(
*
[
to_numpy
(
m
)
for
m
in
to_list
(
metric_outs
)])
metrics
.
append
(
m
)
return
([
to_numpy
(
l
)
for
l
in
losses
],
metrics
)
\
if
len
(
metrics
)
>
0
else
[
to_numpy
(
l
)
for
l
in
losses
]
...
...
@@ -479,18 +567,34 @@ class DynamicGraphAdapter(object):
inputs
=
to_list
(
inputs
)
if
labels
is
not
None
:
labels
=
[
to_variable
(
l
)
for
l
in
to_list
(
labels
)]
outputs
=
to_list
(
self
.
model
.
forward
(
*
[
to_variable
(
x
)
for
x
in
inputs
]))
outputs
=
self
.
model
.
forward
(
*
[
to_variable
(
x
)
for
x
in
inputs
])
if
self
.
model
.
_loss_function
:
losses
=
self
.
model
.
_loss_function
(
outputs
,
labels
)
else
:
losses
=
[]
if
self
.
_nranks
>
1
:
outputs
=
[
distributed
.
_all_gather
(
o
,
self
.
_nranks
)
for
o
in
to_list
(
outputs
)]
labels
=
[
distributed
.
_all_gather
(
l
,
self
.
_nranks
)
for
l
in
labels
]
metrics
=
[]
for
metric
in
self
.
model
.
_metrics
:
metric_outs
=
metric
.
add_metric_op
(
outputs
,
labels
)
m
=
metric
.
update
(
*
[
to_numpy
(
m
)
for
m
in
to_list
(
metric_outs
)])
# cut off padding value.
if
self
.
model
.
_test_dataloader
is
not
None
and
self
.
_nranks
>
1
\
and
isinstance
(
self
.
model
.
_test_dataloader
,
DataLoader
):
total_size
=
len
(
self
.
model
.
_test_dataloader
.
dataset
)
samples
=
outputs
[
0
].
shape
[
0
]
current_count
=
self
.
_merge_count
.
get
(
self
.
mode
+
'_total'
,
0
)
if
current_count
+
samples
>=
total_size
:
outputs
=
[
o
[:
total_size
-
metric
.
count
[
0
]]
for
o
in
outputs
]
labels
=
[
l
[:
total_size
-
metric
.
count
[
0
]]
for
l
in
labels
]
self
.
_merge_count
[
self
.
mode
+
'_total'
]
=
0
self
.
_merge_count
[
self
.
mode
+
'_batch'
]
=
total_size
-
current_count
else
:
self
.
_merge_count
[
self
.
mode
+
'_total'
]
+=
samples
self
.
_merge_count
[
self
.
mode
+
'_batch'
]
=
samples
metric_outs
=
metric
.
add_metric_op
(
to_list
(
outputs
),
labels
)
m
=
metric
.
update
(
*
[
to_numpy
(
m
)
for
m
in
to_list
(
metric_outs
)])
metrics
.
append
(
m
)
# To be consistent with static graph
...
...
@@ -503,6 +607,8 @@ class DynamicGraphAdapter(object):
self
.
mode
=
'test'
inputs
=
[
to_variable
(
x
)
for
x
in
to_list
(
inputs
)]
outputs
=
self
.
model
.
forward
(
*
inputs
)
if
self
.
_nranks
>
2
:
outputs
=
[
distributed
.
_all_gather
(
o
,
self
.
_nranks
)
for
o
in
to_list
(
outputs
)]
return
[
to_numpy
(
o
)
for
o
in
to_list
(
outputs
)]
def
parameters
(
self
,
*
args
,
**
kwargs
):
...
...
@@ -586,7 +692,15 @@ class Model(fluid.dygraph.Layer):
self
.
_optimizer
=
None
self
.
_device
=
None
self
.
_device_ids
=
None
if
in_dygraph_mode
():
self
.
_optimizer
=
None
self
.
_test_dataloader
=
None
# init multiple gpus context
self
.
_place
=
fluid
.
CUDAPlace
(
distributed
.
Env
().
dev_id
)
\
if
distributed
.
Env
().
nranks
>
1
else
fluid
.
CUDAPlace
(
0
)
# init backend
if
fluid
.
in_dygraph_mode
():
self
.
_adapter
=
DynamicGraphAdapter
(
self
)
else
:
self
.
_adapter
=
StaticGraphAdapter
(
self
)
...
...
@@ -601,7 +715,8 @@ class Model(fluid.dygraph.Layer):
return
self
.
_adapter
.
test
(
*
args
,
**
kwargs
)
def
save
(
self
,
*
args
,
**
kwargs
):
return
self
.
_adapter
.
save
(
*
args
,
**
kwargs
)
if
distributed
.
get_local_rank
()
==
0
:
return
self
.
_adapter
.
save
(
*
args
,
**
kwargs
)
def
load
(
self
,
path
,
skip_mismatch
=
False
,
reset_optimizer
=
False
):
"""
...
...
@@ -714,6 +829,7 @@ class Model(fluid.dygraph.Layer):
the variable to the environment variable and set its value to 1.
The default is None.
"""
self
.
_optimizer
=
optimizer
if
loss_function
:
if
not
isinstance
(
loss_function
,
Loss
):
...
...
@@ -736,6 +852,7 @@ class Model(fluid.dygraph.Layer):
self
.
_inputs
=
inputs
self
.
_labels
=
labels
self
.
_device
=
device
if
device
is
None
:
self
.
_device
=
'GPU'
if
fluid
.
is_compiled_with_cuda
()
else
'CPU'
self
.
_device_ids
=
device_ids
...
...
@@ -744,13 +861,19 @@ class Model(fluid.dygraph.Layer):
def
fit
(
self
,
train_dataset
=
None
,
eval_dataset
=
None
,
train_loader
=
None
,
eval_loader
=
None
,
batch_size
=
1
,
epochs
=
1
,
eval_freq
=
1
,
log_freq
=
10
,
save_freq
=
1
,
verbose
=
2
,
drop_last
=
False
,
shuffle
=
True
,
num_workers
=
0
,
callbacks
=
None
,
):
"""
FIXME: add more comments and usage
...
...
@@ -767,7 +890,43 @@ class Model(fluid.dygraph.Layer):
callbacks (Callback|None): list of `Callback` instances to apply
during training.
"""
assert
train_dataset
is
not
None
or
train_loader
is
not
None
,
\
"train_dataset or train_loader must be given"
assert
(
train_loader
is
not
None
and
train_dataset
is
None
)
or
\
(
train_loader
is
None
and
train_dataset
is
not
None
),
\
"train_dataset should not be set when train_loader is given"
if
fluid
.
in_dygraph_mode
():
feed_list
=
None
else
:
feed_list
=
[
x
.
forward
()
for
x
in
self
.
_inputs
+
self
.
_labels
]
if
train_loader
is
None
:
train_sampler
=
DistributedBatchSampler
(
train_dataset
,
batch_size
=
batch_size
,
shuffle
=
shuffle
,
drop_last
=
drop_last
)
train_loader
=
DataLoader
(
train_dataset
,
batch_sampler
=
train_sampler
,
places
=
self
.
_place
,
feed_list
=
feed_list
,
num_workers
=
num_workers
,
return_list
=
True
)
if
eval_loader
is
None
and
eval_dataset
is
not
None
:
eval_sampler
=
DistributedBatchSampler
(
eval_dataset
,
batch_size
=
batch_size
)
eval_loader
=
DataLoader
(
eval_dataset
,
batch_sampler
=
eval_sampler
,
places
=
self
.
_place
,
feed_list
=
feed_list
,
num_workers
=
num_workers
,
return_list
=
True
)
do_eval
=
eval_loader
is
not
None
self
.
_test_dataloader
=
eval_loader
metrics_name
=
self
.
_metrics_name
()
cbks
=
config_callbacks
(
callbacks
,
...
...
@@ -786,6 +945,12 @@ class Model(fluid.dygraph.Layer):
'metrics_name'
:
metrics_name
,
}
for
step
,
data
in
enumerate
(
data_loader
):
if
not
fluid
.
in_dygraph_mode
():
data
=
data
[
0
]
batch_size
=
data
[
0
].
shape
()[
0
]
else
:
batch_size
=
data
[
0
].
shape
[
0
]
cbks
.
on_batch_begin
(
mode
,
step
,
logs
)
if
mode
==
'train'
:
outs
=
self
.
train
(
*
data
)
...
...
@@ -800,12 +965,16 @@ class Model(fluid.dygraph.Layer):
for
metric
in
self
.
_metrics
:
res
=
metric
.
accumulate
()
metrics
.
extend
(
to_list
(
res
))
assert
len
(
metrics_name
)
==
len
(
metrics
)
for
k
,
v
in
zip
(
metrics_name
,
metrics
):
logs
[
k
]
=
v
logs
[
'step'
]
=
step
logs
[
'batch_size'
]
=
data
[
0
].
shape
[
0
]
if
mode
==
'train'
or
self
.
_adapter
.
_merge_count
.
get
(
mode
+
'_batch'
,
0
)
<=
0
:
logs
[
'batch_size'
]
=
batch_size
*
distributed
.
Env
().
nranks
else
:
logs
[
'batch_size'
]
=
self
.
_adapter
.
_merge_count
[
mode
+
'_batch'
]
cbks
.
on_batch_end
(
mode
,
step
,
logs
)
self
.
_reset_metrics
()
...
...
progressbar.py
浏览文件 @
839db7b1
...
...
@@ -107,7 +107,7 @@ class ProgressBar(object):
eta
=
time_per_unit
*
(
self
.
_num
-
current_num
)
if
eta
>
3600
:
eta_format
=
'%d:%02d:%02d'
%
(
eta
//
3600
,
(
eta
%
3600
)
//
60
,
eta
%
60
)
60
,
eta
%
60
)
elif
eta
>
60
:
eta_format
=
'%d:%02d'
%
(
eta
//
60
,
eta
%
60
)
else
:
...
...
@@ -148,8 +148,8 @@ class ProgressBar(object):
else
:
info
+=
' %.4e'
%
v
elif
isinstance
(
v
,
np
.
ndarray
)
and
\
isinstance
(
v
.
size
,
1
)
and
\
isinstance
(
v
.
dtype
,
(
np
.
float32
,
np
.
float64
)):
isinstance
(
v
.
size
,
1
)
and
\
isinstance
(
v
.
dtype
,
(
np
.
float32
,
np
.
float64
)):
if
abs
(
v
[
0
])
>
1e-3
:
info
+=
' %.4f'
%
v
[
0
]
else
:
...
...
tests/test_model.py
浏览文件 @
839db7b1
...
...
@@ -18,15 +18,21 @@ from __future__ import print_function
import
unittest
import
os
import
sys
sys
.
path
.
append
(
'../'
)
import
numpy
as
np
import
contextlib
import
paddle
from
paddle
import
fluid
from
paddle.fluid.dygraph.nn
import
Conv2D
,
Pool2D
,
Linear
from
model
import
Model
,
CrossEntropy
,
Input
,
Loss
from
model
import
Model
,
CrossEntropy
,
Input
,
Loss
,
init_context
from
metrics
import
Accuracy
from
callbacks
import
ProgBarLogger
from
paddle.fluid.io
import
BatchSampler
,
DataLoader
from
paddle.fluid.io
import
MNIST
as
MnistDataset
class
SimpleImgConvPool
(
fluid
.
dygraph
.
Layer
):
...
...
@@ -96,6 +102,7 @@ class MNIST(Model):
act
=
"softmax"
)
def
forward
(
self
,
inputs
):
inputs
=
fluid
.
layers
.
reshape
(
inputs
,
[
-
1
,
1
,
28
,
28
])
x
=
self
.
_simple_img_conv_pool_1
(
inputs
)
x
=
self
.
_simple_img_conv_pool_2
(
x
)
x
=
fluid
.
layers
.
flatten
(
x
,
axis
=
1
)
...
...
@@ -103,11 +110,6 @@ class MNIST(Model):
return
x
@
contextlib
.
contextmanager
def
null_guard
():
yield
class
MLP
(
Model
):
def
__init__
(
self
):
super
(
MLP
,
self
).
__init__
()
...
...
@@ -139,31 +141,26 @@ class MyCrossEntropy(Loss):
class
TestModel
(
unittest
.
TestCase
):
def
fit
(
self
,
dynamic
,
is_mlp
=
False
):
im_shape
=
(
-
1
,
784
)
if
is_mlp
else
(
-
1
,
1
,
28
,
28
)
guard
=
fluid
.
dygraph
.
guard
()
if
dynamic
else
null_guard
()
init_context
(
'dynamic'
if
dynamic
else
'static'
)
im_shape
=
(
-
1
,
784
)
batch_size
=
128
train_loader
=
fluid
.
io
.
xmap_readers
(
lambda
b
:
[
np
.
array
([
x
[
0
]
for
x
in
b
]).
reshape
(
im_shape
),
np
.
array
([
x
[
1
]
for
x
in
b
]).
reshape
(
-
1
,
1
)],
paddle
.
batch
(
fluid
.
io
.
shuffle
(
paddle
.
dataset
.
mnist
.
train
(),
6e4
),
batch_size
=
batch_size
,
drop_last
=
True
),
1
,
1
)
val_loader
=
fluid
.
io
.
xmap_readers
(
lambda
b
:
[
np
.
array
([
x
[
0
]
for
x
in
b
]).
reshape
(
im_shape
),
np
.
array
([
x
[
1
]
for
x
in
b
]).
reshape
(
-
1
,
1
)],
paddle
.
batch
(
paddle
.
dataset
.
mnist
.
test
(),
batch_size
=
batch_size
,
drop_last
=
False
),
1
,
1
)
with
guard
:
inputs
=
[
Input
(
im_shape
,
'float32'
,
name
=
'image'
)]
labels
=
[
Input
([
None
,
1
],
'int64'
,
name
=
'label'
)]
model
=
MNIST
()
if
not
is_mlp
else
MLP
()
optim
=
fluid
.
optimizer
.
Momentum
(
learning_rate
=
0.01
,
momentum
=
.
9
,
parameter_list
=
model
.
parameters
())
loss
=
CrossEntropy
()
if
not
is_mlp
else
MyCrossEntropy
()
model
.
prepare
(
optim
,
loss
,
Accuracy
(),
inputs
,
labels
)
cbk
=
ProgBarLogger
(
50
)
model
.
fit
(
train_loader
,
val_loader
,
epochs
=
2
,
callbacks
=
cbk
)
inputs
=
[
Input
(
im_shape
,
'float32'
,
name
=
'image'
)]
labels
=
[
Input
([
None
,
1
],
'int64'
,
name
=
'label'
)]
train_dataset
=
MnistDataset
(
mode
=
'train'
)
val_dataset
=
MnistDataset
(
mode
=
'test'
)
model
=
MNIST
()
if
not
is_mlp
else
MLP
()
optim
=
fluid
.
optimizer
.
Momentum
(
learning_rate
=
0.01
,
momentum
=
.
9
,
parameter_list
=
model
.
parameters
())
loss
=
CrossEntropy
()
if
not
is_mlp
else
MyCrossEntropy
()
model
.
prepare
(
optim
,
loss
,
Accuracy
(),
inputs
,
labels
)
cbk
=
ProgBarLogger
(
50
)
model
.
fit
(
train_dataset
,
val_dataset
,
epochs
=
2
,
batch_size
=
batch_size
,
callbacks
=
cbk
)
def
test_fit_static
(
self
):
self
.
fit
(
False
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录