Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
hapi
提交
7710978b
H
hapi
项目概览
PaddlePaddle
/
hapi
通知
11
Star
2
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
4
列表
看板
标记
里程碑
合并请求
7
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
H
hapi
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
4
Issue
4
列表
看板
标记
里程碑
合并请求
7
合并请求
7
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
7710978b
编写于
1月 07, 2020
作者:
Y
Yang Zhang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Touch up demos
上级
470e490b
变更
3
显示空白变更内容
内联
并排
Showing
3 changed file
with
172 addition
and
105 deletion
+172
-105
mnist.py
mnist.py
+98
-56
resnet.py
resnet.py
+43
-23
yolov3.py
yolov3.py
+31
-26
未找到文件。
mnist.py
浏览文件 @
7710978b
...
...
@@ -12,16 +12,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
division
from
__future__
import
print_function
import
argparse
import
contextlib
import
os
import
numpy
as
np
import
paddle
from
paddle
import
fluid
from
paddle.fluid.optimizer
import
Momentum
Optimizer
from
paddle.fluid.optimizer
import
Momentum
from
paddle.fluid.dygraph.nn
import
Conv2D
,
Pool2D
,
Linear
from
model
import
Model
,
shape_hints
,
CrossEntropy
from
model
import
Model
,
CrossEntropy
class
SimpleImgConvPool
(
fluid
.
dygraph
.
Layer
):
...
...
@@ -90,14 +95,7 @@ class MNIST(Model):
loc
=
0.0
,
scale
=
scale
)),
act
=
"softmax"
)
@
shape_hints
(
inputs
=
[
None
,
1
,
28
,
28
])
def
forward
(
self
,
inputs
):
if
self
.
mode
==
'test'
:
# XXX demo purpose
x
=
self
.
_simple_img_conv_pool_1
(
inputs
)
x
=
self
.
_simple_img_conv_pool_2
(
x
)
x
=
fluid
.
layers
.
flatten
(
x
,
axis
=
1
)
x
=
self
.
_fc
(
x
)
else
:
x
=
self
.
_simple_img_conv_pool_1
(
inputs
)
x
=
self
.
_simple_img_conv_pool_2
(
x
)
x
=
fluid
.
layers
.
flatten
(
x
,
axis
=
1
)
...
...
@@ -105,52 +103,96 @@ class MNIST(Model):
return
x
@
contextlib
.
contextmanager
def
null_guard
():
def
accuracy
(
pred
,
label
,
topk
=
(
1
,
)):
maxk
=
max
(
topk
)
pred
=
np
.
argsort
(
pred
)[:,
::
-
1
][:,
:
maxk
]
correct
=
(
pred
==
np
.
repeat
(
label
,
maxk
,
1
))
batch_size
=
label
.
shape
[
0
]
res
=
[]
for
k
in
topk
:
correct_k
=
correct
[:,
:
k
].
sum
()
res
.
append
(
100.0
*
correct_k
/
batch_size
)
return
res
def
main
():
@
contextlib
.
contextmanager
def
null_guard
():
yield
guard
=
fluid
.
dygraph
.
guard
()
if
FLAGS
.
dynamic
else
null_guard
()
if
__name__
==
'__main__'
:
import
sys
if
len
(
sys
.
argv
)
>
1
and
sys
.
argv
[
1
]
==
'--dynamic'
:
guard
=
fluid
.
dygraph
.
guard
()
else
:
guard
=
null_guard
()
if
not
os
.
path
.
exists
(
'mnist_checkpoints'
):
os
.
mkdir
(
'mnist_checkpoints'
)
with
guard
:
train_loader
=
fluid
.
io
.
xmap_readers
(
lambda
b
:
[
np
.
array
([
x
[
0
]
for
x
in
b
]).
reshape
(
-
1
,
1
,
28
,
28
),
np
.
array
([
x
[
1
]
for
x
in
b
]).
reshape
(
-
1
,
1
)],
paddle
.
batch
(
paddle
.
dataset
.
mnist
.
train
(
),
batch_size
=
4
,
drop_last
=
True
),
1
,
1
)
test
_loader
=
fluid
.
io
.
xmap_readers
(
paddle
.
batch
(
fluid
.
io
.
shuffle
(
paddle
.
dataset
.
mnist
.
train
(),
6e4
),
batch_size
=
FLAGS
.
batch_size
,
drop_last
=
True
),
1
,
1
)
val
_loader
=
fluid
.
io
.
xmap_readers
(
lambda
b
:
[
np
.
array
([
x
[
0
]
for
x
in
b
]).
reshape
(
-
1
,
1
,
28
,
28
),
np
.
array
([
x
[
1
]
for
x
in
b
]).
reshape
(
-
1
,
1
)],
paddle
.
batch
(
paddle
.
dataset
.
mnist
.
test
(),
batch_size
=
4
,
drop_last
=
True
),
1
,
1
)
batch_size
=
FLAGS
.
batch_size
,
drop_last
=
True
),
1
,
1
)
device_ids
=
list
(
range
(
FLAGS
.
num_devices
))
with
guard
:
model
=
MNIST
()
sgd
=
MomentumOptimizer
(
learning_rate
=
1e-3
,
momentum
=
0
.9
,
optim
=
Momentum
(
learning_rate
=
FLAGS
.
lr
,
momentum
=
.
9
,
parameter_list
=
model
.
parameters
())
# sgd = SGDOptimizer(learning_rate=1e-3)
model
.
prepare
(
sgd
,
CrossEntropy
())
for
e
in
range
(
2
):
model
.
prepare
(
optim
,
CrossEntropy
())
if
FLAGS
.
resume
is
not
None
:
model
.
load
(
FLAGS
.
resume
)
for
e
in
range
(
FLAGS
.
epoch
):
train_loss
=
0.0
train_acc
=
0.0
val_loss
=
0.0
val_acc
=
0.0
print
(
"======== train epoch {} ========"
.
format
(
e
))
for
idx
,
batch
in
enumerate
(
train_loader
()):
out
,
loss
=
model
.
train
(
batch
[
0
],
batch
[
1
],
device
=
'gpu'
,
device_ids
=
[
0
,
1
,
2
,
3
])
print
(
"=============== output ========="
)
print
(
out
)
print
(
"=============== loss ==========="
)
print
(
loss
)
if
idx
>
10
:
model
.
save
(
"test.{}"
.
format
(
e
))
break
print
(
"==== switch to test mode ====="
)
for
idx
,
batch
in
enumerate
(
test_loader
()):
out
=
model
.
test
(
batch
[
0
],
device
=
'gpu'
,
device_ids
=
[
0
,
1
,
2
,
3
])
print
(
out
)
if
idx
>
10
:
break
model
.
load
(
"test.1"
)
outputs
,
losses
=
model
.
train
(
batch
[
0
],
batch
[
1
],
device
=
'gpu'
,
device_ids
=
device_ids
)
acc
=
accuracy
(
outputs
[
0
],
batch
[
1
])[
0
]
train_loss
+=
np
.
sum
(
losses
)
train_acc
+=
acc
if
idx
%
10
==
0
:
print
(
"{:04d}: loss {:0.3f} top1: {:0.3f}%"
.
format
(
idx
,
train_loss
/
(
idx
+
1
),
train_acc
/
(
idx
+
1
)))
print
(
"======== eval epoch {} ========"
.
format
(
e
))
for
idx
,
batch
in
enumerate
(
val_loader
()):
outputs
,
losses
=
model
.
eval
(
batch
[
0
],
batch
[
1
],
device
=
'gpu'
,
device_ids
=
device_ids
)
acc
=
accuracy
(
outputs
[
0
],
batch
[
1
])[
0
]
val_loss
+=
np
.
sum
(
losses
)
val_acc
+=
acc
if
idx
%
10
==
0
:
print
(
"{:04d}: loss {:0.3f} top1: {:0.3f}%"
.
format
(
idx
,
val_loss
/
(
idx
+
1
),
val_acc
/
(
idx
+
1
)))
model
.
save
(
'mnist_checkpoints/{:02d}'
.
format
(
e
))
if
__name__
==
'__main__'
:
parser
=
argparse
.
ArgumentParser
(
"CNN training on MNIST"
)
parser
.
add_argument
(
"-d"
,
"--dynamic"
,
action
=
'store_true'
,
help
=
"enable dygraph mode"
)
parser
.
add_argument
(
"-e"
,
"--epoch"
,
default
=
100
,
type
=
int
,
help
=
"number of epoch"
)
parser
.
add_argument
(
'--lr'
,
'--learning-rate'
,
default
=
1e-3
,
type
=
float
,
metavar
=
'LR'
,
help
=
'initial learning rate'
)
parser
.
add_argument
(
"-b"
,
"--batch_size"
,
default
=
128
,
type
=
int
,
help
=
"batch size"
)
parser
.
add_argument
(
"-n"
,
"--num_devices"
,
default
=
4
,
type
=
int
,
help
=
"number of devices"
)
parser
.
add_argument
(
"-r"
,
"--resume"
,
default
=
None
,
type
=
str
,
help
=
"checkpoint path to resume"
)
FLAGS
=
parser
.
parse_args
()
main
()
resnet.py
浏览文件 @
7710978b
...
...
@@ -13,12 +13,14 @@
# limitations under the License.
from
__future__
import
division
from
__future__
import
print_function
import
argparse
import
contextlib
import
math
import
os
import
random
import
time
import
cv2
import
numpy
as
np
...
...
@@ -184,19 +186,23 @@ class ResNet(Model):
def
make_optimizer
(
parameter_list
=
None
):
total_images
=
1281167
base_lr
=
0.1
base_lr
=
FLAGS
.
lr
momentum
=
0.9
l2_decay
=
1e-4
weight_decay
=
1e-4
step_per_epoch
=
int
(
math
.
floor
(
float
(
total_images
)
/
FLAGS
.
batch_size
))
boundaries
=
[
step_per_epoch
*
e
for
e
in
[
30
,
60
,
80
]]
lr
=
[
base_lr
*
(
0.1
**
i
)
for
i
in
range
(
len
(
boundaries
)
+
1
)]
values
=
[
base_lr
*
(
0.1
**
i
)
for
i
in
range
(
len
(
boundaries
)
+
1
)]
learning_rate
=
fluid
.
layers
.
piecewise_decay
(
boundaries
=
boundaries
,
values
=
values
)
learning_rate
=
fluid
.
layers
.
linear_lr_warmup
(
learning_rate
=
learning_rate
,
warmup_steps
=
5
*
step_per_epoch
,
start_lr
=
0.
,
end_lr
=
base_lr
)
optimizer
=
fluid
.
optimizer
.
Momentum
(
learning_rate
=
fluid
.
layers
.
piecewise_decay
(
boundaries
=
boundaries
,
values
=
lr
),
learning_rate
=
learning_rate
,
momentum
=
momentum
,
regularization
=
fluid
.
regularizer
.
L2Decay
(
l2
_decay
),
regularization
=
fluid
.
regularizer
.
L2Decay
(
weight
_decay
),
parameter_list
=
parameter_list
)
return
optimizer
...
...
@@ -293,10 +299,14 @@ def image_folder(path, shuffle=False):
def
run
(
model
,
loader
,
mode
=
'train'
):
total_loss
=
0.0
total_acc1
=
0.0
total_acc5
=
0.0
total_loss
=
0.
total_acc1
=
0.
total_acc5
=
0.
total_time
=
0.
start
=
time
.
time
()
device_ids
=
list
(
range
(
FLAGS
.
num_devices
))
start
=
time
.
time
()
for
idx
,
batch
in
enumerate
(
loader
()):
outputs
,
losses
=
getattr
(
model
,
mode
)(
batch
[
0
],
batch
[
1
],
device
=
'gpu'
,
device_ids
=
device_ids
)
...
...
@@ -305,10 +315,14 @@ def run(model, loader, mode='train'):
total_loss
+=
np
.
sum
(
losses
)
total_acc1
+=
top1
total_acc5
+=
top5
if
idx
>
1
:
# skip first two steps
total_time
+=
time
.
time
()
-
start
if
idx
%
10
==
0
:
print
(
"{:04d}: loss {:0.3f} top1: {:0.3f}% top5: {:0.3f}%"
.
format
(
print
((
"{:04d} loss: {:0.3f} top1: {:0.3f}% top5: {:0.3f}% "
"time: {:0.3f}"
).
format
(
idx
,
total_loss
/
(
idx
+
1
),
total_acc1
/
(
idx
+
1
),
total_acc5
/
(
idx
+
1
)))
total_acc5
/
(
idx
+
1
),
total_time
/
max
(
1
,
(
idx
-
1
))))
start
=
time
.
time
()
def
main
():
...
...
@@ -318,10 +332,7 @@ def main():
epoch
=
FLAGS
.
epoch
batch_size
=
FLAGS
.
batch_size
if
FLAGS
.
dynamic
:
guard
=
fluid
.
dygraph
.
guard
()
else
:
guard
=
null_guard
()
guard
=
fluid
.
dygraph
.
guard
()
if
FLAGS
.
dynamic
else
null_guard
()
train_dir
=
os
.
path
.
join
(
FLAGS
.
data
,
'train'
)
val_dir
=
os
.
path
.
join
(
FLAGS
.
data
,
'val'
)
...
...
@@ -352,18 +363,20 @@ def main():
batch_size
=
batch_size
),
process_num
=
2
,
buffer_size
=
4
)
if
not
os
.
path
.
exists
(
'checkpoints'
):
os
.
mkdir
(
'checkpoints'
)
if
not
os
.
path
.
exists
(
'
resnet_
checkpoints'
):
os
.
mkdir
(
'
resnet_
checkpoints'
)
with
guard
:
model
=
ResNet
()
optim
=
make_optimizer
(
parameter_list
=
model
.
parameters
())
model
.
prepare
(
optim
,
CrossEntropy
())
if
FLAGS
.
resume
is
not
None
:
model
.
load
(
FLAGS
.
resume
)
for
e
in
range
(
epoch
):
print
(
"======== train epoch {} ========"
.
format
(
e
))
run
(
model
,
train_loader
)
model
.
save
(
'checkpoints/{:02d}'
.
format
(
e
))
model
.
save
(
'
resnet_
checkpoints/{:02d}'
.
format
(
e
))
print
(
"======== eval epoch {} ========"
.
format
(
e
))
run
(
model
,
val_loader
,
mode
=
'eval'
)
...
...
@@ -372,13 +385,20 @@ if __name__ == '__main__':
parser
=
argparse
.
ArgumentParser
(
"Resnet Training on ImageNet"
)
parser
.
add_argument
(
'data'
,
metavar
=
'DIR'
,
help
=
'path to dataset '
'(should have subdirectories named "train" and "val"'
)
parser
.
add_argument
(
"-d"
,
"--dynamic"
,
action
=
'store_true'
,
help
=
"enable dygraph mode"
)
parser
.
add_argument
(
"-e"
,
"--epoch"
,
default
=
90
,
type
=
int
,
help
=
"number of epoch"
)
parser
.
add_argument
(
"-b"
,
"--batch_size"
,
default
=
512
,
type
=
int
,
help
=
"batch size"
)
'--lr'
,
'--learning-rate'
,
default
=
0.1
,
type
=
float
,
metavar
=
'LR'
,
help
=
'initial learning rate'
)
parser
.
add_argument
(
"-b"
,
"--batch_size"
,
default
=
256
,
type
=
int
,
help
=
"batch size"
)
parser
.
add_argument
(
"-n"
,
"--num_devices"
,
default
=
4
,
type
=
int
,
help
=
"number of devices"
)
parser
.
add_argument
(
"-d"
,
"--dynamic"
,
action
=
'store_true'
,
help
=
"enable dygraph mode"
)
"-r"
,
"--resume"
,
default
=
None
,
type
=
str
,
help
=
"checkpoint path to resume"
)
FLAGS
=
parser
.
parse_args
()
assert
FLAGS
.
data
,
"error: must provide data path"
main
()
yolov3.py
浏览文件 @
7710978b
...
...
@@ -13,6 +13,7 @@
# limitations under the License.
from
__future__
import
division
from
__future__
import
print_function
import
argparse
import
contextlib
...
...
@@ -20,6 +21,8 @@ import os
import
random
import
time
from
functools
import
partial
import
cv2
import
numpy
as
np
from
pycocotools.coco
import
COCO
...
...
@@ -143,7 +146,7 @@ class YOLOv3(Model):
act
=
'leaky_relu'
))
self
.
route_blocks
.
append
(
route
)
@
shape_hints
(
inputs
=
[
None
,
3
,
None
,
None
])
@
shape_hints
(
inputs
=
[
None
,
3
,
None
,
None
]
,
im_shape
=
[
None
,
2
]
)
def
forward
(
self
,
inputs
,
im_shape
):
outputs
=
[]
boxes
=
[]
...
...
@@ -239,25 +242,22 @@ class YoloLoss(Loss):
def
make_optimizer
(
parameter_list
=
None
):
base_lr
=
0.001
boundaries
=
[
400000
,
450000
]
base_lr
=
FLAGS
.
lr
warm_up_iter
=
4000
momentum
=
0.9
weight_decay
=
5e-4
boundaries
=
[
400000
,
450000
]
values
=
[
base_lr
*
(
0.1
**
i
)
for
i
in
range
(
len
(
boundaries
)
+
1
)]
lr
=
fluid
.
layers
.
piecewise_decay
(
learning_rate
=
fluid
.
layers
.
piecewise_decay
(
boundaries
=
boundaries
,
values
=
values
)
lr
=
fluid
.
layers
.
linear_lr_warmup
(
learning_rate
=
lr
,
learning_rate
=
fluid
.
layers
.
linear_lr_warmup
(
learning_rate
=
learning_rate
,
warmup_steps
=
warm_up_iter
,
start_lr
=
0.0
,
end_lr
=
base_lr
)
optimizer
=
fluid
.
optimizer
.
Momentum
(
learning_rate
=
l
r
,
learning_rate
=
l
earning_rate
,
regularization
=
fluid
.
regularizer
.
L2Decay
(
weight_decay
),
momentum
=
momentum
,
parameter_list
=
parameter_list
)
...
...
@@ -392,7 +392,7 @@ def batch_transform(batch, mode='train'):
im_shapes
=
np
.
full
([
len
(
imgs
),
2
],
d
,
dtype
=
np
.
int32
)
gt_boxes
=
np
.
array
(
gt_boxes
)
gt_labels
=
np
.
array
(
gt_labels
)
# XXX since mix up is not used, scores are all
1
s
# XXX since mix up is not used, scores are all
one
s
gt_scores
=
np
.
ones_like
(
gt_labels
,
dtype
=
np
.
float32
)
return
[
imgs
,
im_shapes
],
[
gt_boxes
,
gt_labels
,
gt_scores
]
...
...
@@ -451,20 +451,21 @@ def coco2017(root_dir, mode='train'):
# XXX coco metrics not included for simplicity
def
run
(
model
,
loader
,
mode
=
'train'
):
total_loss
=
0.
0
total_loss
=
0.
total_time
=
0.
device_ids
=
list
(
range
(
FLAGS
.
num_devices
))
start
=
time
.
time
()
for
idx
,
batch
in
enumerate
(
loader
()):
outputs
,
losses
=
getattr
(
model
,
mode
)(
batch
[
0
],
batch
[
1
],
device
=
'gpu'
,
device_ids
=
device_ids
)
total_loss
+=
np
.
sum
(
losses
)
if
idx
>
1
:
# skip first two step
if
idx
>
1
:
# skip first two step
s
total_time
+=
time
.
time
()
-
start
if
idx
%
10
==
0
:
print
(
"{:04d}: loss {:0.3f} time: {:0.3f}"
.
format
(
idx
,
total_loss
/
(
idx
+
1
),
total_time
/
(
idx
-
1
)))
idx
,
total_loss
/
(
idx
+
1
),
total_time
/
max
(
1
,
(
idx
-
1
)
)))
start
=
time
.
time
()
...
...
@@ -475,16 +476,13 @@ def main():
epoch
=
FLAGS
.
epoch
batch_size
=
FLAGS
.
batch_size
if
FLAGS
.
dynamic
:
guard
=
fluid
.
dygraph
.
guard
()
else
:
guard
=
null_guard
()
guard
=
fluid
.
dygraph
.
guard
()
if
FLAGS
.
dynamic
else
null_guard
()
train_loader
=
fluid
.
io
.
xmap_readers
(
lambda
batch
:
batch_transform
(
batch
,
'train'
)
,
batch_transform
,
paddle
.
batch
(
fluid
.
io
.
xmap_readers
(
lambda
inputs
:
sample_transform
(
inputs
,
'train'
)
,
sample_transform
,
coco2017
(
FLAGS
.
data
,
'train'
),
process_num
=
8
,
buffer_size
=
4
*
batch_size
),
...
...
@@ -492,11 +490,14 @@ def main():
drop_last
=
True
),
process_num
=
2
,
buffer_size
=
4
)
val_sample_transform
=
partial
(
sample_transform
,
mode
=
'val'
)
val_batch_transform
=
partial
(
batch_transform
,
mode
=
'val'
)
val_loader
=
fluid
.
io
.
xmap_readers
(
lambda
batch
:
batch_transform
(
batch
,
'train'
)
,
val_batch_transform
,
paddle
.
batch
(
fluid
.
io
.
xmap_readers
(
lambda
inputs
:
sample_transform
(
inputs
,
'val'
)
,
val_sample_transform
,
coco2017
(
FLAGS
.
data
,
'val'
),
process_num
=
8
,
buffer_size
=
4
*
batch_size
),
...
...
@@ -517,7 +518,7 @@ def main():
for
e
in
range
(
epoch
):
print
(
"======== train epoch {} ========"
.
format
(
e
))
run
(
model
,
train_loader
)
model
.
save
(
'checkpoints/{:02d}'
.
format
(
e
))
model
.
save
(
'
yolo_
checkpoints/{:02d}'
.
format
(
e
))
print
(
"======== eval epoch {} ========"
.
format
(
e
))
run
(
model
,
val_loader
,
mode
=
'eval'
)
...
...
@@ -525,16 +526,20 @@ def main():
if
__name__
==
'__main__'
:
parser
=
argparse
.
ArgumentParser
(
"Yolov3 Training on COCO"
)
parser
.
add_argument
(
'data'
,
metavar
=
'DIR'
,
help
=
'path to COCO dataset'
)
parser
.
add_argument
(
"-d"
,
"--dynamic"
,
action
=
'store_true'
,
help
=
"enable dygraph mode"
)
parser
.
add_argument
(
"-e"
,
"--epoch"
,
default
=
300
,
type
=
int
,
help
=
"number of epoch"
)
parser
.
add_argument
(
"-b"
,
"--batch_size"
,
default
=
32
,
type
=
int
,
help
=
"batch size"
)
'--lr'
,
'--learning-rate'
,
default
=
0.001
,
type
=
float
,
metavar
=
'LR'
,
help
=
'initial learning rate'
)
parser
.
add_argument
(
"-
n"
,
"--num_devices"
,
default
=
8
,
type
=
int
,
help
=
"number of devices
"
)
"-
b"
,
"--batch_size"
,
default
=
64
,
type
=
int
,
help
=
"batch size
"
)
parser
.
add_argument
(
"-
d"
,
"--dynamic"
,
action
=
'store_true'
,
help
=
"enable dygraph mode
"
)
"-
n"
,
"--num_devices"
,
default
=
8
,
type
=
int
,
help
=
"number of devices
"
)
parser
.
add_argument
(
"-w"
,
"--weights"
,
default
=
None
,
type
=
str
,
help
=
"path to pretrained weights"
)
FLAGS
=
parser
.
parse_args
()
assert
FLAGS
.
data
,
"error: must provide data path"
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录