Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
models
提交
8c4b45c7
M
models
项目概览
PaddlePaddle
/
models
1 年多 前同步成功
通知
226
Star
6828
Fork
2962
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
602
列表
看板
标记
里程碑
合并请求
255
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
models
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
602
Issue
602
列表
看板
标记
里程碑
合并请求
255
合并请求
255
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
8c4b45c7
编写于
10月 25, 2018
作者:
T
typhoonzero
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
update dist resnet model config
上级
9fd693ba
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
276 addition
and
196 deletion
+276
-196
fluid/image_classification/dist_train/args.py
fluid/image_classification/dist_train/args.py
+0
-118
fluid/image_classification/dist_train/dist_train.py
fluid/image_classification/dist_train/dist_train.py
+112
-78
fluid/image_classification/models/learning_rate.py
fluid/image_classification/models/learning_rate.py
+28
-0
fluid/image_classification/models/resnet_dist.py
fluid/image_classification/models/resnet_dist.py
+136
-0
未找到文件。
fluid/image_classification/dist_train/args.py
已删除
100644 → 0
浏览文件 @
9fd693ba
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
argparse
__all__
=
[
'parse_args'
,
]
BENCHMARK_MODELS
=
[
"ResNet50"
,
"ResNet101"
,
"ResNet152"
]
def
parse_args
():
parser
=
argparse
.
ArgumentParser
(
'Distributed Image Classification Training.'
)
parser
.
add_argument
(
'--model'
,
type
=
str
,
choices
=
BENCHMARK_MODELS
,
default
=
'resnet'
,
help
=
'The model to run benchmark with.'
)
parser
.
add_argument
(
'--batch_size'
,
type
=
int
,
default
=
32
,
help
=
'The minibatch size.'
)
# args related to learning rate
parser
.
add_argument
(
'--learning_rate'
,
type
=
float
,
default
=
0.001
,
help
=
'The learning rate.'
)
# TODO(wuyi): add "--use_fake_data" option back.
parser
.
add_argument
(
'--skip_batch_num'
,
type
=
int
,
default
=
5
,
help
=
'The first num of minibatch num to skip, for better performance test'
)
parser
.
add_argument
(
'--iterations'
,
type
=
int
,
default
=
80
,
help
=
'The number of minibatches.'
)
parser
.
add_argument
(
'--pass_num'
,
type
=
int
,
default
=
100
,
help
=
'The number of passes.'
)
parser
.
add_argument
(
'--data_format'
,
type
=
str
,
default
=
'NCHW'
,
choices
=
[
'NCHW'
,
'NHWC'
],
help
=
'The data data_format, now only support NCHW.'
)
parser
.
add_argument
(
'--device'
,
type
=
str
,
default
=
'GPU'
,
choices
=
[
'CPU'
,
'GPU'
],
help
=
'The device type.'
)
parser
.
add_argument
(
'--gpus'
,
type
=
int
,
default
=
1
,
help
=
'If gpus > 1, will use ParallelExecutor to run, else use Executor.'
)
# this option is available only for vgg and resnet.
parser
.
add_argument
(
'--cpus'
,
type
=
int
,
default
=
1
,
help
=
'If cpus > 1, will set ParallelExecutor to use multiple threads.'
)
parser
.
add_argument
(
'--data_set'
,
type
=
str
,
default
=
'flowers'
,
choices
=
[
'cifar10'
,
'flowers'
,
'imagenet'
],
help
=
'Optional dataset for benchmark.'
)
parser
.
add_argument
(
'--no_test'
,
action
=
'store_true'
,
help
=
'If set, do not test the testset during training.'
)
parser
.
add_argument
(
'--memory_optimize'
,
action
=
'store_true'
,
help
=
'If set, optimize runtime memory before start.'
)
parser
.
add_argument
(
'--update_method'
,
type
=
str
,
default
=
'local'
,
choices
=
[
'local'
,
'pserver'
,
'nccl2'
],
help
=
'Choose parameter update method, can be local, pserver, nccl2.'
)
parser
.
add_argument
(
'--no_split_var'
,
action
=
'store_true'
,
default
=
False
,
help
=
'Whether split variables into blocks when update_method is pserver'
)
parser
.
add_argument
(
'--async_mode'
,
action
=
'store_true'
,
default
=
False
,
help
=
'Whether start pserver in async mode to support ASGD'
)
parser
.
add_argument
(
'--no_random'
,
action
=
'store_true'
,
help
=
'If set, keep the random seed and do not shuffle the data.'
)
parser
.
add_argument
(
'--reduce_strategy'
,
type
=
str
,
choices
=
[
'reduce'
,
'all_reduce'
],
default
=
'all_reduce'
,
help
=
'Specify the reduce strategy, can be reduce, all_reduce'
)
parser
.
add_argument
(
'--data_dir'
,
type
=
str
,
default
=
"../data/ILSVRC2012"
,
help
=
"The ImageNet dataset root dir."
)
args
=
parser
.
parse_args
()
return
args
fluid/image_classification/dist_train/dist_train.py
浏览文件 @
8c4b45c7
...
...
@@ -26,9 +26,82 @@ import six
import
sys
sys
.
path
.
append
(
".."
)
import
models
from
args
import
*
from
reader
import
train
,
val
def
parse_args
():
parser
=
argparse
.
ArgumentParser
(
'Distributed Image Classification Training.'
)
parser
.
add_argument
(
'--model'
,
type
=
str
,
default
=
'resnet_dist'
,
help
=
'The model to run.'
)
parser
.
add_argument
(
'--batch_size'
,
type
=
int
,
default
=
32
,
help
=
'The minibatch size per device.'
)
parser
.
add_argument
(
'--learning_rate'
,
type
=
float
,
default
=
0.1
,
help
=
'The learning rate.'
)
parser
.
add_argument
(
'--pass_num'
,
type
=
int
,
default
=
90
,
help
=
'The number of passes.'
)
parser
.
add_argument
(
'--data_format'
,
type
=
str
,
default
=
'NCHW'
,
choices
=
[
'NCHW'
,
'NHWC'
],
help
=
'The data data_format, now only support NCHW.'
)
parser
.
add_argument
(
'--device'
,
type
=
str
,
default
=
'GPU'
,
choices
=
[
'CPU'
,
'GPU'
],
help
=
'The device type.'
)
parser
.
add_argument
(
'--gpus'
,
type
=
int
,
default
=
1
,
help
=
'If gpus > 1, will use ParallelExecutor to run, else use Executor.'
)
parser
.
add_argument
(
'--cpus'
,
type
=
int
,
default
=
1
,
help
=
'If cpus > 1, will set ParallelExecutor to use multiple threads.'
)
parser
.
add_argument
(
'--no_test'
,
action
=
'store_true'
,
help
=
'If set, do not test the testset during training.'
)
parser
.
add_argument
(
'--memory_optimize'
,
action
=
'store_true'
,
help
=
'If set, optimize runtime memory before start.'
)
parser
.
add_argument
(
'--update_method'
,
type
=
str
,
default
=
'local'
,
choices
=
[
'local'
,
'pserver'
,
'nccl2'
],
help
=
'Choose parameter update method, can be local, pserver, nccl2.'
)
parser
.
add_argument
(
'--no_split_var'
,
action
=
'store_true'
,
default
=
False
,
help
=
'Whether split variables into blocks when update_method is pserver'
)
parser
.
add_argument
(
'--async_mode'
,
action
=
'store_true'
,
default
=
False
,
help
=
'Whether start pserver in async mode to support ASGD'
)
parser
.
add_argument
(
'--reduce_strategy'
,
type
=
str
,
choices
=
[
'reduce'
,
'all_reduce'
],
default
=
'all_reduce'
,
help
=
'Specify the reduce strategy, can be reduce, all_reduce'
)
parser
.
add_argument
(
'--data_dir'
,
type
=
str
,
default
=
"../data/ILSVRC2012"
,
help
=
"The ImageNet dataset root dir."
)
args
=
parser
.
parse_args
()
return
args
def
get_model
(
args
,
is_train
,
main_prog
,
startup_prog
):
pyreader
=
None
class_dim
=
1000
...
...
@@ -51,7 +124,7 @@ def get_model(args, is_train, main_prog, startup_prog):
name
=
"train_reader"
if
is_train
else
"test_reader"
,
use_double_buffer
=
True
)
input
,
label
=
fluid
.
layers
.
read_file
(
pyreader
)
model_def
=
models
.
__dict__
[
args
.
model
]()
model_def
=
models
.
__dict__
[
args
.
model
](
is_train
)
predict
=
model_def
.
net
(
input
,
class_dim
=
class_dim
)
cost
=
fluid
.
layers
.
cross_entropy
(
input
=
predict
,
label
=
label
)
...
...
@@ -60,89 +133,64 @@ def get_model(args, is_train, main_prog, startup_prog):
batch_acc1
=
fluid
.
layers
.
accuracy
(
input
=
predict
,
label
=
label
,
k
=
1
)
batch_acc5
=
fluid
.
layers
.
accuracy
(
input
=
predict
,
label
=
label
,
k
=
5
)
# configure optimize
optimizer
=
None
if
is_train
:
start_lr
=
args
.
learning_rate
# n * worker * repeat
end_lr
=
args
.
learning_rate
*
trainer_count
*
args
.
multi_batch_repeat
total_images
=
1281167
/
trainer_count
step
=
int
(
total_images
/
(
args
.
batch_size
*
args
.
gpus
)
+
1
)
epochs
=
[
30
,
60
,
9
0
]
step
=
int
(
total_images
/
(
args
.
batch_size
*
args
.
gpus
*
args
.
multi_batch_repeat
)
+
1
)
warmup_steps
=
step
*
5
# warmup 5 passes
epochs
=
[
30
,
60
,
8
0
]
bd
=
[
step
*
e
for
e
in
epochs
]
base_lr
=
args
.
learning_rate
base_lr
=
end_lr
lr
=
[]
lr
=
[
base_lr
*
(
0.1
**
i
)
for
i
in
range
(
len
(
bd
)
+
1
)]
optimizer
=
fluid
.
optimizer
.
Momentum
(
learning_rate
=
fluid
.
layers
.
piecewise_decay
(
boundaries
=
bd
,
values
=
lr
),
learning_rate
=
models
.
learning_rate
.
lr_warmup
(
fluid
.
layers
.
piecewise_decay
(
boundaries
=
bd
,
values
=
lr
),
warmup_steps
,
start_lr
,
end_lr
),
momentum
=
0.9
,
regularization
=
fluid
.
regularizer
.
L2Decay
(
1e-4
))
optimizer
.
minimize
(
avg_cost
)
if
args
.
memory_optimize
:
fluid
.
memory_optimize
(
main_prog
)
batched_reader
=
None
pyreader
.
decorate_paddle_reader
(
paddle
.
batch
(
reader
if
args
.
no_random
else
paddle
.
reader
.
shuffle
(
reader
,
buf_size
=
5120
),
reader
,
batch_size
=
args
.
batch_size
))
return
avg_cost
,
optimizer
,
[
batch_acc1
,
batch_acc5
],
batched_reader
,
pyreader
def
append_nccl2_prepare
(
trainer_id
,
startup_prog
):
if
trainer_id
>=
0
:
# append gen_nccl_id at the end of startup program
trainer_id
=
int
(
os
.
getenv
(
"PADDLE_TRAINER_ID"
))
port
=
os
.
getenv
(
"PADDLE_PSERVER_PORT"
)
worker_ips
=
os
.
getenv
(
"PADDLE_TRAINER_IPS"
)
worker_endpoints
=
[]
for
ip
in
worker_ips
.
split
(
","
):
worker_endpoints
.
append
(
':'
.
join
([
ip
,
port
]))
num_trainers
=
len
(
worker_endpoints
)
current_endpoint
=
os
.
getenv
(
"PADDLE_CURRENT_IP"
)
+
":"
+
port
worker_endpoints
.
remove
(
current_endpoint
)
nccl_id_var
=
startup_prog
.
global_block
().
create_var
(
name
=
"NCCLID"
,
persistable
=
True
,
type
=
fluid
.
core
.
VarDesc
.
VarType
.
RAW
)
startup_prog
.
global_block
().
append_op
(
type
=
"gen_nccl_id"
,
inputs
=
{},
outputs
=
{
"NCCLID"
:
nccl_id_var
},
attrs
=
{
"endpoint"
:
current_endpoint
,
"endpoint_list"
:
worker_endpoints
,
"trainer_id"
:
trainer_id
})
return
nccl_id_var
,
num_trainers
,
trainer_id
else
:
raise
Exception
(
"must set positive PADDLE_TRAINER_ID env variables for "
"nccl-based dist train."
)
trainer_id
=
int
(
os
.
getenv
(
"PADDLE_TRAINER_ID"
))
port
=
os
.
getenv
(
"PADDLE_PSERVER_PORT"
)
worker_ips
=
os
.
getenv
(
"PADDLE_TRAINER_IPS"
)
worker_endpoints
=
[]
for
ip
in
worker_ips
.
split
(
","
):
worker_endpoints
.
append
(
':'
.
join
([
ip
,
port
]))
current_endpoint
=
os
.
getenv
(
"PADDLE_CURRENT_IP"
)
+
":"
+
port
config
=
fluid
.
DistributeTranspilerConfig
()
config
.
mode
=
"nccl2"
t
=
fluid
.
DistributeTranspiler
(
config
=
config
)
t
.
transpile
(
trainer_id
,
trainers
=
','
.
join
(
worker_endpoints
),
current_endpoint
=
current_endpoint
,
startup_program
=
startup_prog
)
def
dist_transpile
(
trainer_id
,
args
,
train_prog
,
startup_prog
):
if
trainer_id
<
0
:
return
None
,
None
# the port of all pservers, needed by both trainer and pserver
port
=
os
.
getenv
(
"PADDLE_PSERVER_PORT"
,
"6174"
)
# comma separated ips of all pservers, needed by trainer and
# pserver
pserver_ips
=
os
.
getenv
(
"PADDLE_PSERVER_IPS"
,
""
)
eplist
=
[]
for
ip
in
pserver_ips
.
split
(
","
):
eplist
.
append
(
':'
.
join
([
ip
,
port
]))
pserver_endpoints
=
","
.
join
(
eplist
)
# total number of workers/trainers in the job, needed by
# trainer and pserver
trainers
=
int
(
os
.
getenv
(
"PADDLE_TRAINERS"
))
# the IP of the local machine, needed by pserver only
current_endpoint
=
os
.
getenv
(
"PADDLE_CURRENT_IP"
,
""
)
+
":"
+
port
# the role, should be either PSERVER or TRAINER
training_role
=
os
.
getenv
(
"PADDLE_TRAINING_ROLE"
)
config
=
fluid
.
DistributeTranspilerConfig
()
...
...
@@ -150,8 +198,6 @@ def dist_transpile(trainer_id, args, train_prog, startup_prog):
t
=
fluid
.
DistributeTranspiler
(
config
=
config
)
t
.
transpile
(
trainer_id
,
# NOTE: *MUST* use train_prog, for we are using with guard to
# generate different program for train and test.
program
=
train_prog
,
pservers
=
pserver_endpoints
,
trainers
=
trainers
,
...
...
@@ -171,7 +217,7 @@ def dist_transpile(trainer_id, args, train_prog, startup_prog):
)
def
test_parallel
(
exe
,
test_args
,
args
,
test_prog
,
feeder
):
def
test_parallel
(
exe
,
test_args
,
args
,
test_prog
):
acc_evaluators
=
[]
for
i
in
six
.
moves
.
xrange
(
len
(
test_args
[
2
])):
acc_evaluators
.
append
(
fluid
.
metrics
.
Accuracy
())
...
...
@@ -190,13 +236,10 @@ def test_parallel(exe, test_args, args, test_prog, feeder):
return
[
e
.
eval
()
for
e
in
acc_evaluators
]
# NOTE: only need to benchmark using parallelexe
def
train_parallel
(
train_args
,
test_args
,
args
,
train_prog
,
test_prog
,
startup_prog
,
nccl_id_var
,
num_trainers
,
trainer_id
):
over_all_start
=
time
.
time
()
place
=
core
.
CPUPlace
()
if
args
.
device
==
'CPU'
else
core
.
CUDAPlace
(
0
)
feeder
=
None
if
nccl_id_var
and
trainer_id
==
0
:
#FIXME(wuyi): wait other trainer to start listening
...
...
@@ -237,31 +280,27 @@ def train_parallel(train_args, test_args, args, train_prog, test_prog,
if
args
.
update_method
==
"pserver"
:
test_scope
=
None
else
:
# NOTE: use an empty scope to avoid test exe using NCCLID
test_scope
=
fluid
.
Scope
()
test_exe
=
fluid
.
ParallelExecutor
(
True
,
main_program
=
test_prog
,
share_vars_from
=
exe
)
True
,
main_program
=
test_prog
,
share_vars_from
=
exe
,
scope
=
test_scope
)
pyreader
=
train_args
[
4
]
for
pass_id
in
range
(
args
.
pass_num
):
num_samples
=
0
iters
=
0
start_time
=
time
.
time
()
batch_id
=
0
pyreader
.
start
()
while
True
:
if
iters
==
args
.
iterations
:
break
if
iters
==
args
.
skip_batch_num
:
start_time
=
time
.
time
()
num_samples
=
0
fetch_list
=
[
avg_loss
.
name
]
acc_name_list
=
[
v
.
name
for
v
in
train_args
[
2
]]
fetch_list
.
extend
(
acc_name_list
)
try
:
fetch_ret
=
exe
.
run
(
fetch_list
)
if
batch_id
%
30
==
0
:
fetch_ret
=
exe
.
run
(
fetch_list
)
else
:
fetch_ret
=
exe
.
run
([])
except
fluid
.
core
.
EOFException
as
eof
:
break
except
fluid
.
core
.
EnforceNotMet
as
ex
:
...
...
@@ -269,20 +308,17 @@ def train_parallel(train_args, test_args, args, train_prog, test_prog,
break
num_samples
+=
args
.
batch_size
*
args
.
gpus
iters
+=
1
if
batch_id
%
1
==
0
:
if
batch_id
%
30
==
0
:
fetched_data
=
[
np
.
mean
(
np
.
array
(
d
))
for
d
in
fetch_ret
]
print
(
"Pass %d, batch %d, loss %s, accucacys: %s"
%
(
pass_id
,
batch_id
,
fetched_data
[
0
],
fetched_data
[
1
:]))
batch_id
+=
1
print_train_time
(
start_time
,
time
.
time
(),
num_samples
)
pyreader
.
reset
()
# reset reader handle
pyreader
.
reset
()
if
not
args
.
no_test
and
test_args
[
2
]:
test_feeder
=
None
test_ret
=
test_parallel
(
test_exe
,
test_args
,
args
,
test_prog
,
test_feeder
)
test_ret
=
test_parallel
(
test_exe
,
test_args
,
args
,
test_prog
)
print
(
"Pass: %d, Test Accuracy: %s
\n
"
%
(
pass_id
,
[
np
.
mean
(
np
.
array
(
v
))
for
v
in
test_ret
]))
...
...
@@ -316,8 +352,6 @@ def main():
args
=
parse_args
()
print_arguments
(
args
)
print_paddle_envs
()
if
args
.
no_random
:
fluid
.
default_startup_program
().
random_seed
=
1
# the unique trainer id, starting from 0, needed by trainer
# only
...
...
fluid/image_classification/models/learning_rate.py
浏览文件 @
8c4b45c7
...
...
@@ -20,3 +20,31 @@ def cosine_decay(learning_rate, step_each_epoch, epochs=120):
decayed_lr
=
learning_rate
*
\
(
ops
.
cos
(
epoch
*
(
math
.
pi
/
epochs
))
+
1
)
/
2
return
decayed_lr
def
lr_warmup
(
learning_rate
,
warmup_steps
,
start_lr
,
end_lr
):
""" Applies linear learning rate warmup for distributed training
Argument learning_rate can be float or a Variable
lr = lr + (warmup_rate * step / warmup_steps)
"""
assert
(
isinstance
(
end_lr
,
float
))
assert
(
isinstance
(
start_lr
,
float
))
linear_step
=
end_lr
-
start_lr
with
fluid
.
default_main_program
().
_lr_schedule_guard
():
lr
=
fluid
.
layers
.
tensor
.
create_global_var
(
shape
=
[
1
],
value
=
0.0
,
dtype
=
'float32'
,
persistable
=
True
,
name
=
"learning_rate_warmup"
)
global_step
=
fluid
.
layers
.
learning_rate_scheduler
.
_decay_step_counter
()
with
fluid
.
layers
.
control_flow
.
Switch
()
as
switch
:
with
switch
.
case
(
global_step
<
warmup_steps
):
decayed_lr
=
start_lr
+
linear_step
*
(
global_step
/
warmup_steps
)
fluid
.
layers
.
tensor
.
assign
(
decayed_lr
,
lr
)
with
switch
.
default
():
fluid
.
layers
.
tensor
.
assign
(
learning_rate
,
lr
)
return
lr
\ No newline at end of file
fluid/image_classification/models/resnet_dist.py
0 → 100644
浏览文件 @
8c4b45c7
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
paddle
import
paddle.fluid
as
fluid
import
math
__all__
=
[
"ResNet"
,
"ResNet50"
,
"ResNet101"
,
"ResNet152"
]
train_parameters
=
{
"input_size"
:
[
3
,
224
,
224
],
"input_mean"
:
[
0.485
,
0.456
,
0.406
],
"input_std"
:
[
0.229
,
0.224
,
0.225
],
"learning_strategy"
:
{
"name"
:
"piecewise_decay"
,
"batch_size"
:
256
,
"epochs"
:
[
30
,
60
,
90
],
"steps"
:
[
0.1
,
0.01
,
0.001
,
0.0001
]
}
}
class
ResNet
():
def
__init__
(
self
,
layers
=
50
,
is_train
=
True
):
self
.
params
=
train_parameters
self
.
layers
=
layers
self
.
is_train
=
is_train
self
.
weight_decay
=
1e-4
def
net
(
self
,
input
,
class_dim
=
1000
):
layers
=
self
.
layers
supported_layers
=
[
50
,
101
,
152
]
assert
layers
in
supported_layers
,
\
"supported layers are {} but input layer is {}"
.
format
(
supported_layers
,
layers
)
if
layers
==
50
:
depth
=
[
3
,
4
,
6
,
3
]
elif
layers
==
101
:
depth
=
[
3
,
4
,
23
,
3
]
elif
layers
==
152
:
depth
=
[
3
,
8
,
36
,
3
]
num_filters
=
[
64
,
128
,
256
,
512
]
conv
=
self
.
conv_bn_layer
(
input
=
input
,
num_filters
=
64
,
filter_size
=
7
,
stride
=
2
,
act
=
'relu'
)
conv
=
fluid
.
layers
.
pool2d
(
input
=
conv
,
pool_size
=
3
,
pool_stride
=
2
,
pool_padding
=
1
,
pool_type
=
'max'
)
for
block
in
range
(
len
(
depth
)):
for
i
in
range
(
depth
[
block
]):
conv
=
self
.
bottleneck_block
(
input
=
conv
,
num_filters
=
num_filters
[
block
],
stride
=
2
if
i
==
0
and
block
!=
0
else
1
)
pool
=
fluid
.
layers
.
pool2d
(
input
=
conv
,
pool_size
=
7
,
pool_type
=
'avg'
,
global_pooling
=
True
)
stdv
=
1.0
/
math
.
sqrt
(
pool
.
shape
[
1
]
*
1.0
)
out
=
fluid
.
layers
.
fc
(
input
=
pool
,
size
=
class_dim
,
act
=
'softmax'
,
param_attr
=
fluid
.
param_attr
.
ParamAttr
(
initializer
=
fluid
.
initializer
.
Uniform
(
-
stdv
,
stdv
),
regularizer
=
fluid
.
regularizer
.
L2Decay
(
self
.
weight_decay
)),
bias_attr
=
fluid
.
ParamAttr
(
regularizer
=
fluid
.
regularizer
.
L2Decay
(
self
.
weight_decay
))
)
return
out
def
conv_bn_layer
(
self
,
input
,
num_filters
,
filter_size
,
stride
=
1
,
groups
=
1
,
act
=
None
,
bn_init_value
=
1.0
):
conv
=
fluid
.
layers
.
conv2d
(
input
=
input
,
num_filters
=
num_filters
,
filter_size
=
filter_size
,
stride
=
stride
,
padding
=
(
filter_size
-
1
)
//
2
,
groups
=
groups
,
act
=
None
,
bias_attr
=
False
,
param_attr
=
fluid
.
ParamAttr
(
regularizer
=
fluid
.
regularizer
.
L2Decay
(
self
.
weight_decay
)))
return
fluid
.
layers
.
batch_norm
(
input
=
conv
,
act
=
act
,
is_test
=
not
self
.
is_train
,
param_attr
=
fluid
.
ParamAttr
(
initializer
=
fluid
.
initializer
.
Constant
(
bn_init_value
),
regularizer
=
None
))
def
shortcut
(
self
,
input
,
ch_out
,
stride
):
ch_in
=
input
.
shape
[
1
]
if
ch_in
!=
ch_out
or
stride
!=
1
:
return
self
.
conv_bn_layer
(
input
,
ch_out
,
1
,
stride
)
else
:
return
input
def
bottleneck_block
(
self
,
input
,
num_filters
,
stride
):
conv0
=
self
.
conv_bn_layer
(
input
=
input
,
num_filters
=
num_filters
,
filter_size
=
1
,
act
=
'relu'
)
conv1
=
self
.
conv_bn_layer
(
input
=
conv0
,
num_filters
=
num_filters
,
filter_size
=
3
,
stride
=
stride
,
act
=
'relu'
)
# NOTE: default bias is 0.0 already
conv2
=
self
.
conv_bn_layer
(
input
=
conv1
,
num_filters
=
num_filters
*
4
,
filter_size
=
1
,
act
=
None
,
bn_init_value
=
0.0
)
short
=
self
.
shortcut
(
input
,
num_filters
*
4
,
stride
)
return
fluid
.
layers
.
elementwise_add
(
x
=
short
,
y
=
conv2
,
act
=
'relu'
)
def
ResNet50
():
model
=
ResNet
(
layers
=
50
)
return
model
def
ResNet101
():
model
=
ResNet
(
layers
=
101
)
return
model
def
ResNet152
():
model
=
ResNet
(
layers
=
152
)
return
model
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录