Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
hapi
提交
5c490902
H
hapi
项目概览
PaddlePaddle
/
hapi
通知
11
Star
2
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
4
列表
看板
标记
里程碑
合并请求
7
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
H
hapi
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
4
Issue
4
列表
看板
标记
里程碑
合并请求
7
合并请求
7
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
5c490902
编写于
12月 17, 2019
作者:
Y
Yang Zhang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add model API and demo
上级
725e3e65
变更
2
显示空白变更内容
内联
并排
Showing
2 changed file
with
469 addition
and
0 deletion
+469
-0
mnist.py
mnist.py
+135
-0
model.py
model.py
+334
-0
未找到文件。
mnist.py
0 → 100644
浏览文件 @
5c490902
import
contextlib
import
numpy
as
np
import
paddle
from
paddle
import
fluid
from
paddle.fluid.optimizer
import
MomentumOptimizer
from
paddle.fluid.dygraph.nn
import
Conv2D
,
Pool2D
,
FC
from
model
import
Model
,
shape_hints
class SimpleImgConvPool(fluid.dygraph.Layer):
    """A Conv2D layer followed by a Pool2D layer.

    Parameters mirror those of ``Conv2D``/``Pool2D``; ``conv_*`` arguments
    configure the convolution, ``pool_*`` arguments the pooling stage.
    """

    def __init__(self,
                 num_channels,
                 num_filters,
                 filter_size,
                 pool_size,
                 pool_stride,
                 pool_padding=0,
                 pool_type='max',
                 global_pooling=False,
                 conv_stride=1,
                 conv_padding=0,
                 conv_dilation=1,
                 conv_groups=None,
                 act=None,
                 use_cudnn=False,
                 param_attr=None,
                 bias_attr=None):
        super(SimpleImgConvPool, self).__init__('SimpleConv')

        # FIX: `param_attr`, `bias_attr` and `act` were accepted by this
        # constructor but silently ignored (hard-coded None / never passed).
        # They are now forwarded to the convolution layer.
        self._conv2d = Conv2D(
            num_channels=num_channels,
            num_filters=num_filters,
            filter_size=filter_size,
            stride=conv_stride,
            padding=conv_padding,
            dilation=conv_dilation,
            groups=conv_groups,
            param_attr=param_attr,
            bias_attr=bias_attr,
            act=act,
            use_cudnn=use_cudnn)

        self._pool2d = Pool2D(
            pool_size=pool_size,
            pool_type=pool_type,
            pool_stride=pool_stride,
            pool_padding=pool_padding,
            global_pooling=global_pooling,
            use_cudnn=use_cudnn)

    def forward(self, inputs):
        """Run convolution then pooling on `inputs` and return the result."""
        x = self._conv2d(inputs)
        x = self._pool2d(x)
        return x
class MNIST(Model):
    """LeNet-style convolutional classifier for MNIST digits.

    Two conv-pool stages followed by a softmax FC head; input images are
    expected as (N, 1, 28, 28) per the `shape_hints` on `forward`.
    """

    def __init__(self):
        super(MNIST, self).__init__()
        self._simple_img_conv_pool_1 = SimpleImgConvPool(
            1, 20, 5, 2, 2, act="relu")
        self._simple_img_conv_pool_2 = SimpleImgConvPool(
            20, 50, 5, 2, 2, act="relu")

        # 50 channels of 4x4 feature maps reach the FC layer; scale the
        # normal initializer accordingly.
        pool_2_shape = 50 * 4 * 4
        SIZE = 10
        scale = (2.0 / (pool_2_shape**2 * SIZE))**0.5
        self._fc = FC(self.full_name(),
                      10,
                      param_attr=fluid.param_attr.ParamAttr(
                          initializer=fluid.initializer.NormalInitializer(
                              loc=0.0, scale=scale)),
                      act="softmax")

    @shape_hints(inputs=[None, 1, 28, 28])
    def forward(self, inputs):
        # FIX: both branches of the original ran the same two conv-pool
        # stages; only the classification head differed, so the duplicated
        # calls are hoisted out of the conditional.
        x = self._simple_img_conv_pool_1(inputs)
        x = self._simple_img_conv_pool_2(x)
        if self.mode != 'test':
            # XXX demo purpose: the FC head is skipped in test mode
            x = self._fc(x)
        return x
@contextlib.contextmanager
def null_guard():
    """No-op context manager.

    Drop-in stand-in for ``fluid.dygraph.guard()`` so the training script
    below can use a single ``with guard:`` block in both static-graph and
    dygraph modes.
    """
    yield
if __name__ == '__main__':
    import sys
    # `--dynamic` switches to dygraph execution; otherwise run static graph.
    if len(sys.argv) > 1 and sys.argv[1] == '--dynamic':
        guard = fluid.dygraph.guard()
    else:
        guard = null_guard()

    def _reshape_batch(b):
        # FIX: this mapper was duplicated verbatim as two identical lambdas
        # for the train and test loaders; it is now defined once.
        # Returns [images (N, 1, 28, 28), labels (N, 1)].
        return [
            np.array([x[0] for x in b]).reshape(-1, 1, 28, 28),
            np.array([x[1] for x in b]).reshape(-1, 1)
        ]

    with guard:
        # sgd = SGDOptimizer(learning_rate=1e-3)
        sgd = MomentumOptimizer(learning_rate=1e-3, momentum=0.9)
        train_loader = fluid.io.xmap_readers(
            _reshape_batch,
            paddle.batch(
                paddle.dataset.mnist.train(), batch_size=4, drop_last=True),
            1, 1)
        test_loader = fluid.io.xmap_readers(
            _reshape_batch,
            paddle.batch(
                paddle.dataset.mnist.test(), batch_size=4, drop_last=True),
            1, 1)
        model = MNIST()
        model.prepare(sgd, 'cross_entropy')
        # Short demo run: a handful of batches per epoch, then a checkpoint.
        for e in range(2):
            for idx, batch in enumerate(train_loader()):
                out = model.train(
                    batch[0], batch[1], device='gpu', device_ids=[0, 1, 2, 3])
                print(out)
                if idx > 10:
                    model.save("test.{}".format(e))
                    break
        print("==== switch to test mode =====")
        for idx, batch in enumerate(test_loader()):
            out = model.test(batch[0], device='gpu', device_ids=[0, 1, 2, 3])
            print(out)
            if idx > 10:
                break
        model.load("test.1")
model.py
0 → 100644
浏览文件 @
5c490902
from
__future__
import
absolute_import
import
inspect
from
collections
import
OrderedDict
import
numpy
as
np
from
paddle
import
fluid
from
paddle.fluid.framework
import
in_dygraph_mode
from
paddle.fluid.dygraph.base
import
to_variable
__all__ = ['Model', 'shape_hints']

# Maps a loss-function name (an attribute of `fluid.layers`) to the dtype
# required for its label input.  Losses not listed here fall back to the
# prediction's dtype (see StaticGraphAdapter._infer_label_var).
LOSS_DTYPE_MAP = {'cross_entropy': 'int64'}
def to_list(value):
    """Wrap *value* in a list unless it is already a list or tuple."""
    return value if isinstance(value, (list, tuple)) else [value]
def extract_args(func):
    """Return the positional argument names of *func*.

    Uses ``inspect.getfullargspec`` when available (Python 3), falling back
    to ``inspect.getargspec`` (Python 2).
    """
    spec_fn = getattr(inspect, 'getfullargspec', inspect.getargspec)
    return spec_fn(func)[0]
def shape_hints(**hints):
    """Decorator factory attaching static shape hints to a forward method.

    Each keyword names a parameter of the decorated function and maps it to
    a shape (list/tuple, ``None`` for unknown dims).  The hints are stored
    on the function as ``func.shape_hints`` for the static-graph adapter.
    """
    assert hints, "hints can not be empty"
    for hint in hints.values():
        assert isinstance(hint, (list, tuple)), \
            "shape hint must be a list or tuple"

    def wrapper(func):
        # every hinted name must be an actual parameter of `func`
        unknown = set(hints) - set(extract_args(func))
        assert not unknown, \
            "shape hint for arguments that are not present in forward method" \
            + ": ({})".format(", ".join(unknown))
        func.shape_hints = hints
        return func

    return wrapper
class StaticGraphAdapter(object):
    """Executes a `Model` with static-graph (non-dygraph) programs.

    Builds one program per mode ('train'/'eval'/'test') on first use, feeds
    numpy inputs by the names declared in the model's `forward` signature,
    and caches compiled programs keyed by mode + device ids.
    """

    def __init__(self, model):
        super(StaticGraphAdapter, self).__init__()
        self.model = model
        # with `_build_once` gone, parameters are now created in `__init__`
        # so we need to keep track of the parameters already created
        self._startup_prog = fluid.default_startup_program()
        self._main_prog = fluid.default_main_program()
        # HACK separate models by cleanup global scope
        self._scope = fluid.executor.global_scope()
        fluid.executor.g_scope = fluid.core.Scope()

        self._label_vars = None  # label variables
        self._endpoints = {}     # mode -> fetch list
        self._loss_endpoint = None
        self._executor = None
        self._progs = {}          # mode -> Program
        self._compiled_progs = {} # "<mode>_<ids>" -> CompiledProgram
        # parse shape hints: argument order of `forward` defines feed order
        self._input_desc = OrderedDict([
            (n, None) for n in extract_args(self.model.forward)
            if n != 'self'
        ])
        if hasattr(self.model.forward, 'shape_hints'):
            self._input_desc.update(self.model.forward.shape_hints)

    @property
    def mode(self):
        return self.model.mode

    @mode.setter
    def mode(self, value):
        self.model.mode = value

    def train(self, inputs, labels, device='CPU', device_ids=None):
        assert self.model._optimizer and self.model._loss_functions, \
            "model not ready, please call `model.prepare()` first"
        self.mode = 'train'
        return self._run(inputs, labels, device, device_ids)

    def eval(self, inputs, labels, device='CPU', device_ids=None):
        assert self.model._loss_functions, \
            "model not ready, please call `model.prepare()` first"
        self.mode = 'eval'
        return self._run(inputs, labels, device, device_ids)

    def test(self, inputs, device='CPU', device_ids=None):
        self.mode = 'test'
        return self._run(inputs, None, device, device_ids)

    def save(self, path):
        # prefer the train program so optimizer state is saved too
        prog = self._progs.get('train', None)
        if prog is None or self.model._optimizer is None:
            print("optimizer not initialized, save parameters only")
            prog = self._main_prog
        with fluid.executor.scope_guard(self._scope):
            fluid.save(prog, path)

    def load(self, path):
        prog = self._main_prog
        with fluid.executor.scope_guard(self._scope):
            fluid.load(prog, path, self._executor)

    def _run(self, inputs, labels=None, device='CPU', device_ids=None):
        """Build (if needed), compile (if needed) and run the mode program."""
        inputs = to_list(inputs)
        if labels is not None:
            labels = to_list(labels)
        assert len(inputs) == len(self._input_desc), "number of inputs" \
            + " does not match number of arguments of `forward` method"
        with fluid.executor.scope_guard(self._scope):
            if self._progs.get(self.mode, None) is None:
                self._make_program(self._infer_input_vars(inputs))

            # FIX: `device_ids` defaults to None, which previously crashed
            # here with a TypeError when iterated; treat None as "no ids".
            ids = sorted(str(i) for i in (device_ids or []))
            prog_hash = '_'.join([self.mode] + ids)
            compiled_prog = self._compiled_progs.get(prog_hash, None)
            if compiled_prog is None:
                compiled_prog = self._compile_and_initialize(
                    self._progs[self.mode], device, device_ids)
                self._compiled_progs[prog_hash] = compiled_prog

            feed = {}
            input_names = [name for name in self._input_desc.keys()]
            for idx, n in enumerate(input_names):
                # train and test may take different arguments
                if inputs[idx] is not None:
                    feed[n] = inputs[idx]
            if labels is not None:
                for idx, v in enumerate(self._label_vars):
                    feed[v.name] = labels[idx]
            outputs = self._executor.run(
                compiled_prog,
                scope=self._scope,
                feed=feed,
                fetch_list=self._endpoints[self.mode])
            return outputs

    def _make_program(self, inputs):
        """Clone the main program for the current mode and add loss/optimizer."""
        prog = self._main_prog.clone(self.mode != 'train')
        with fluid.program_guard(prog, self._startup_prog):
            outputs = to_list(self.model.forward(*inputs))
            label_vars = []
            if self.mode != 'test':
                losses = []
                for o, l in zip(outputs, self.model._loss_functions):
                    if l is None:
                        continue
                    label_var = self._infer_label_var(o, l)
                    label_vars.append(label_var)
                    loss_fn = getattr(fluid.layers, l)
                    loss = loss_fn(o, label_var)
                    losses.append(fluid.layers.reduce_mean(loss))
                outputs = losses
                if self.mode == 'train':
                    self._label_vars = label_vars
                    self._loss_endpoint = fluid.layers.sum(losses)
                    self.model._optimizer.minimize(self._loss_endpoint)
        self._progs[self.mode] = prog
        self._endpoints[self.mode] = outputs

    def _infer_input_vars(self, inputs):
        """Create `fluid.data` vars from shape hints or the actual ndarrays."""
        input_vars = []
        for idx, i in enumerate(inputs):
            if i is None:
                # train and test may take different arguments
                input_vars.append(None)
                continue
            ndarray = np.array(i)
            name = list(self._input_desc.keys())[idx]
            shape = list(self._input_desc.values())[idx]
            if shape is None:
                # no hint: take the observed shape, batch dim left variable
                shape = (None, ) + ndarray.shape[1:]
            input_vars.append(fluid.data(name, shape, ndarray.dtype))
        return input_vars

    # TODO wrap loss in callable classes
    # - same call signaure
    # - infer_shape method? or same shape as y_pred (e.g., one hot)
    # - split multiple dtype loss functions (e.g., soft label)
    def _infer_label_var(self, output, loss):
        name = output.name + '.label'
        shape = output.shape
        # XXX could get ugly very quickly
        if loss == 'cross_entropy':
            shape = shape[:-1] + (1, )
        dtype = LOSS_DTYPE_MAP.get(loss, output.dtype)
        return fluid.data(name, shape, dtype)

    def _compile_and_initialize(self, prog, device='CPU', device_ids=None):
        """Compile `prog` for the device and lazily run startup for new vars."""
        if device.lower() == 'cpu':
            place = fluid.CPUPlace()
        elif device.lower() == 'gpu' and isinstance(device_ids, (list, tuple)):
            place = fluid.CUDAPlace(device_ids[0])
        else:
            # FIX: `raise "device not supported"` raises a TypeError in
            # Python 3 (exceptions must derive from BaseException).
            raise ValueError("device not supported")

        compiled_prog = fluid.CompiledProgram(prog)
        if device.lower() == 'gpu' and len(device_ids) > 0:
            places = [fluid.CUDAPlace(i) for i in device_ids]
            loss_name = None
            if self._loss_endpoint is not None:
                loss_name = self._loss_endpoint.name
            compiled_prog = compiled_prog.with_data_parallel(
                loss_name=loss_name, places=places)

        if self._executor is None:
            self._executor = fluid.Executor(place)
            # XXX only run startup once as *ALL* weights should be initialized
            # upon construction of the model
            # XXX incremental initialization, lifted from GuoSheng code
            uninitialized = []
            for var_py in self._startup_prog.list_vars():
                var = fluid.global_scope().find_var(var_py.name)
                if var and var.get_tensor()._is_initialized():
                    continue
                uninitialized.append(var_py)
            if uninitialized:
                startup_prog = self._startup_prog._prune(uninitialized)
                self._executor.run(startup_prog)
        return compiled_prog
class DynamicGraphAdapter(object):
    """Executes a `Model` eagerly in dygraph (imperative) mode."""

    def __init__(self, model):
        super(DynamicGraphAdapter, self).__init__()
        self.model = model

    @property
    def mode(self):
        return self.model.mode

    @mode.setter
    def mode(self, value):
        self.model.mode = value

    def train(self, inputs, labels, device='CPU', device_ids=None):
        """Run one optimization step; returns the per-output mean losses."""
        assert self.model._optimizer and self.model._loss_functions, \
            "model not ready, please call `model.prepare()` first"
        super(Model, self.model).train()
        self.mode = 'train'
        inputs = to_list(inputs)
        labels = to_list(labels)
        outputs = self.model.forward(*[to_variable(x) for x in inputs])
        losses = self._loss(outputs, labels)
        final_loss = fluid.layers.sum(losses)
        final_loss.backward()
        self.model._optimizer.minimize(final_loss)
        self.model.clear_gradients()
        return losses

    def eval(self, inputs, labels, device='CPU', device_ids=None):
        """Compute losses without updating parameters."""
        assert self.model._loss_functions, \
            "model not ready, please call `model.prepare()` first"
        # FIX: was `.train()`, which left layers (dropout/BN-style) in
        # training mode during evaluation.
        super(Model, self.model).eval()
        self.mode = 'eval'
        inputs = to_list(inputs)
        labels = to_list(labels)
        outputs = self.model.forward(*[to_variable(x) for x in inputs])
        return self._loss(outputs, labels)

    def test(self, inputs, device='CPU', device_ids=None):
        """Run a forward pass only, in inference mode."""
        # FIX: was `.train()`; inference should run in eval mode.
        super(Model, self.model).eval()
        self.mode = 'test'
        inputs = to_list(inputs)
        return self.model.forward(*[to_variable(x) for x in inputs])

    def save(self, path):
        params = self.model.state_dict()
        fluid.save_dygraph(params, path)
        if self.model._optimizer is None:
            print("model does not have an optimizer, save parameters only")
            return
        # avoid calling state_dict() twice (was computed once to test and
        # once to save)
        optim = self.model._optimizer.state_dict()
        if optim:
            fluid.save_dygraph(optim, path)

    def load(self, path):
        params, optim = fluid.load_dygraph(path)
        self.model.set_dict(params)
        if optim is None:
            print("optimizer state file not found, load parameters only")
            return
        self.model._optimizer.set_dict(optim)

    def _loss(self, pred, labels):
        """Apply each configured loss to its (prediction, label) pair."""
        losses = []
        for o, l, t in zip(to_list(pred), self.model._loss_functions, labels):
            if l is None:
                continue
            loss_fn = getattr(fluid.layers, l)
            loss = loss_fn(o, to_variable(t))
            losses.append(fluid.layers.reduce_mean(loss))
        return losses
class Model(fluid.dygraph.Layer):
    """High-level model base class with mode-aware execution.

    Subclasses implement ``forward``.  All public operations (``train``,
    ``eval``, ``test``, ``save``, ``load``) are delegated to an adapter
    chosen once at construction: ``DynamicGraphAdapter`` when dygraph mode
    is active, ``StaticGraphAdapter`` otherwise.  Call ``prepare`` before
    training to install an optimizer and loss function(s).
    """

    def __init__(self):
        super(Model, self).__init__(self.__class__.__name__)
        self.mode = 'train'
        self._loss_functions = []
        self._optimizer = None
        # bind the execution backend exactly once, at construction time
        self._adapter = (DynamicGraphAdapter(self) if in_dygraph_mode()
                         else StaticGraphAdapter(self))

    def train(self, *args, **kwargs):
        return self._adapter.train(*args, **kwargs)

    def eval(self, *args, **kwargs):
        return self._adapter.eval(*args, **kwargs)

    def test(self, *args, **kwargs):
        return self._adapter.test(*args, **kwargs)

    def save(self, *args, **kwargs):
        return self._adapter.save(*args, **kwargs)

    def load(self, *args, **kwargs):
        return self._adapter.load(*args, **kwargs)

    def prepare(self, optimizer, loss_functions):
        """Install the optimizer and loss function name(s) for training."""
        self._optimizer = optimizer
        self._loss_functions = to_list(loss_functions)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录