Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
hapi
提交
d390d968
H
hapi
项目概览
PaddlePaddle
/
hapi
通知
11
Star
2
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
4
列表
看板
标记
里程碑
合并请求
7
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
H
hapi
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
4
Issue
4
列表
看板
标记
里程碑
合并请求
7
合并请求
7
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
d390d968
编写于
4月 29, 2020
作者:
Q
qingqing01
提交者:
GitHub
4月 29, 2020
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #64 from qingqing01/ut_and_doc
Add unit testing and doc.
上级
581b4944
b295ffbb
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
660 addition
and
147 deletion
+660
-147
hapi/callbacks.py
hapi/callbacks.py
+10
-11
hapi/datasets/mnist.py
hapi/datasets/mnist.py
+8
-2
hapi/loss.py
hapi/loss.py
+2
-2
hapi/model.py
hapi/model.py
+426
-63
hapi/tests/test_model.py
hapi/tests/test_model.py
+214
-69
未找到文件。
hapi/callbacks.py
浏览文件 @
d390d968
...
@@ -185,6 +185,9 @@ class ProgBarLogger(Callback):
...
@@ -185,6 +185,9 @@ class ProgBarLogger(Callback):
self
.
verbose
=
verbose
self
.
verbose
=
verbose
self
.
log_freq
=
log_freq
self
.
log_freq
=
log_freq
def
_is_print
(
self
):
return
self
.
verbose
and
ParallelEnv
().
local_rank
==
0
def
on_train_begin
(
self
,
logs
=
None
):
def
on_train_begin
(
self
,
logs
=
None
):
self
.
epochs
=
self
.
params
[
'epochs'
]
self
.
epochs
=
self
.
params
[
'epochs'
]
assert
self
.
epochs
assert
self
.
epochs
...
@@ -195,7 +198,7 @@ class ProgBarLogger(Callback):
...
@@ -195,7 +198,7 @@ class ProgBarLogger(Callback):
self
.
steps
=
self
.
params
[
'steps'
]
self
.
steps
=
self
.
params
[
'steps'
]
self
.
epoch
=
epoch
self
.
epoch
=
epoch
self
.
train_step
=
0
self
.
train_step
=
0
if
self
.
verbose
and
self
.
epochs
and
ParallelEnv
().
local_rank
==
0
:
if
self
.
epochs
and
self
.
_is_print
()
:
print
(
'Epoch %d/%d'
%
(
epoch
+
1
,
self
.
epochs
))
print
(
'Epoch %d/%d'
%
(
epoch
+
1
,
self
.
epochs
))
self
.
train_progbar
=
ProgressBar
(
num
=
self
.
steps
,
verbose
=
self
.
verbose
)
self
.
train_progbar
=
ProgressBar
(
num
=
self
.
steps
,
verbose
=
self
.
verbose
)
...
@@ -213,15 +216,13 @@ class ProgBarLogger(Callback):
...
@@ -213,15 +216,13 @@ class ProgBarLogger(Callback):
logs
=
logs
or
{}
logs
=
logs
or
{}
self
.
train_step
+=
1
self
.
train_step
+=
1
if
self
.
train_step
%
self
.
log_freq
==
0
and
self
.
verbose
and
ParallelEnv
(
if
self
.
_is_print
()
and
self
.
train_step
%
self
.
log_freq
==
0
:
).
local_rank
==
0
:
if
self
.
steps
is
None
or
self
.
train_step
<
self
.
steps
:
if
self
.
steps
is
None
or
self
.
train_step
<
self
.
steps
:
self
.
_updates
(
logs
,
'train'
)
self
.
_updates
(
logs
,
'train'
)
def
on_epoch_end
(
self
,
epoch
,
logs
=
None
):
def
on_epoch_end
(
self
,
epoch
,
logs
=
None
):
logs
=
logs
or
{}
logs
=
logs
or
{}
if
self
.
train_step
%
self
.
log_freq
!=
0
and
self
.
verbose
and
ParallelEnv
(
if
self
.
_is_print
()
and
(
self
.
steps
is
not
None
):
).
local_rank
==
0
:
self
.
_updates
(
logs
,
'train'
)
self
.
_updates
(
logs
,
'train'
)
def
on_eval_begin
(
self
,
logs
=
None
):
def
on_eval_begin
(
self
,
logs
=
None
):
...
@@ -231,7 +232,7 @@ class ProgBarLogger(Callback):
...
@@ -231,7 +232,7 @@ class ProgBarLogger(Callback):
self
.
evaled_samples
=
0
self
.
evaled_samples
=
0
self
.
eval_progbar
=
ProgressBar
(
self
.
eval_progbar
=
ProgressBar
(
num
=
self
.
eval_steps
,
verbose
=
self
.
verbose
)
num
=
self
.
eval_steps
,
verbose
=
self
.
verbose
)
if
ParallelEnv
().
local_rank
==
0
:
if
self
.
_is_print
()
:
print
(
'Eval begin...'
)
print
(
'Eval begin...'
)
def
on_eval_batch_end
(
self
,
step
,
logs
=
None
):
def
on_eval_batch_end
(
self
,
step
,
logs
=
None
):
...
@@ -240,16 +241,14 @@ class ProgBarLogger(Callback):
...
@@ -240,16 +241,14 @@ class ProgBarLogger(Callback):
samples
=
logs
.
get
(
'batch_size'
,
1
)
samples
=
logs
.
get
(
'batch_size'
,
1
)
self
.
evaled_samples
+=
samples
self
.
evaled_samples
+=
samples
if
self
.
eval_step
%
self
.
log_freq
==
0
and
self
.
verbose
and
ParallelEnv
(
if
self
.
_is_print
()
and
self
.
eval_step
%
self
.
log_freq
==
0
:
).
local_rank
==
0
:
if
self
.
eval_steps
is
None
or
self
.
eval_step
<
self
.
eval_steps
:
if
self
.
eval_steps
is
None
or
self
.
eval_step
<
self
.
eval_steps
:
self
.
_updates
(
logs
,
'eval'
)
self
.
_updates
(
logs
,
'eval'
)
def
on_eval_end
(
self
,
logs
=
None
):
def
on_eval_end
(
self
,
logs
=
None
):
logs
=
logs
or
{}
logs
=
logs
or
{}
if
self
.
verbose
and
ParallelEnv
().
local_rank
==
0
:
if
self
.
_is_print
()
and
(
self
.
steps
is
not
None
):
if
self
.
eval_step
%
self
.
log_freq
!=
0
:
self
.
_updates
(
logs
,
'eval'
)
self
.
_updates
(
logs
,
'eval'
)
print
(
'Eval samples: %d'
%
(
self
.
evaled_samples
))
print
(
'Eval samples: %d'
%
(
self
.
evaled_samples
))
...
...
hapi/datasets/mnist.py
浏览文件 @
d390d968
...
@@ -45,6 +45,8 @@ class MNIST(Dataset):
...
@@ -45,6 +45,8 @@ class MNIST(Dataset):
:attr:`download` is True. Default None
:attr:`download` is True. Default None
label_path(str): path to label file, can be set None if
label_path(str): path to label file, can be set None if
:attr:`download` is True. Default None
:attr:`download` is True. Default None
chw_format(bool): If set True, the output shape is [1, 28, 28],
otherwise, output shape is [1, 784]. Default True.
mode(str): 'train' or 'test' mode. Default 'train'.
mode(str): 'train' or 'test' mode. Default 'train'.
download(bool): whether auto download mnist dataset if
download(bool): whether auto download mnist dataset if
:attr:`image_path`/:attr:`label_path` unset. Default
:attr:`image_path`/:attr:`label_path` unset. Default
...
@@ -70,13 +72,14 @@ class MNIST(Dataset):
...
@@ -70,13 +72,14 @@ class MNIST(Dataset):
def
__init__
(
self
,
def
__init__
(
self
,
image_path
=
None
,
image_path
=
None
,
label_path
=
None
,
label_path
=
None
,
chw_format
=
True
,
mode
=
'train'
,
mode
=
'train'
,
transform
=
None
,
transform
=
None
,
download
=
True
):
download
=
True
):
assert
mode
.
lower
()
in
[
'train'
,
'test'
],
\
assert
mode
.
lower
()
in
[
'train'
,
'test'
],
\
"mode should be 'train' or 'test', but got {}"
.
format
(
mode
)
"mode should be 'train' or 'test', but got {}"
.
format
(
mode
)
self
.
mode
=
mode
.
lower
()
self
.
mode
=
mode
.
lower
()
self
.
chw_format
=
chw_format
self
.
image_path
=
image_path
self
.
image_path
=
image_path
if
self
.
image_path
is
None
:
if
self
.
image_path
is
None
:
assert
download
,
"image_path not set and auto download disabled"
assert
download
,
"image_path not set and auto download disabled"
...
@@ -144,10 +147,13 @@ class MNIST(Dataset):
...
@@ -144,10 +147,13 @@ class MNIST(Dataset):
for
i
in
range
(
buffer_size
):
for
i
in
range
(
buffer_size
):
self
.
images
.
append
(
images
[
i
,
:])
self
.
images
.
append
(
images
[
i
,
:])
self
.
labels
.
append
(
np
.
array
([
labels
[
i
]]).
astype
(
'int64'
))
self
.
labels
.
append
(
np
.
array
([
labels
[
i
]]).
astype
(
'int64'
))
def
__getitem__
(
self
,
idx
):
def
__getitem__
(
self
,
idx
):
image
,
label
=
self
.
images
[
idx
],
self
.
labels
[
idx
]
image
,
label
=
self
.
images
[
idx
],
self
.
labels
[
idx
]
if
self
.
chw_format
:
image
=
np
.
reshape
(
image
,
[
1
,
28
,
28
])
if
self
.
transform
is
not
None
:
if
self
.
transform
is
not
None
:
image
=
self
.
transform
(
image
)
image
=
self
.
transform
(
image
)
return
image
,
label
return
image
,
label
...
...
hapi/loss.py
浏览文件 @
d390d968
...
@@ -66,7 +66,7 @@ class CrossEntropy(Loss):
...
@@ -66,7 +66,7 @@ class CrossEntropy(Loss):
"""
"""
def
__init__
(
self
,
average
=
True
):
def
__init__
(
self
,
average
=
True
):
super
(
CrossEntropy
,
self
).
__init__
()
super
(
CrossEntropy
,
self
).
__init__
(
average
)
def
forward
(
self
,
outputs
,
labels
):
def
forward
(
self
,
outputs
,
labels
):
return
[
return
[
...
@@ -88,7 +88,7 @@ class SoftmaxWithCrossEntropy(Loss):
...
@@ -88,7 +88,7 @@ class SoftmaxWithCrossEntropy(Loss):
"""
"""
def
__init__
(
self
,
average
=
True
):
def
__init__
(
self
,
average
=
True
):
super
(
SoftmaxWithCrossEntropy
,
self
).
__init__
()
super
(
SoftmaxWithCrossEntropy
,
self
).
__init__
(
average
)
def
forward
(
self
,
outputs
,
labels
):
def
forward
(
self
,
outputs
,
labels
):
return
[
return
[
...
...
hapi/model.py
浏览文件 @
d390d968
...
@@ -639,7 +639,49 @@ class DynamicGraphAdapter(object):
...
@@ -639,7 +639,49 @@ class DynamicGraphAdapter(object):
class
Model
(
fluid
.
dygraph
.
Layer
):
class
Model
(
fluid
.
dygraph
.
Layer
):
"""
"""
FIXME: add more comments and usage
An Model object is network with training and inference features.
Dynamic graph and static graph are supported at the same time,
switched by `fluid.enable_dygraph()`. The usage is as follows.
But note, the switching between dynamic and static should be before
instantiating a Model. The input description, i.e, hapi.Input,
must be required for static graph.
Usage:
.. code-block:: python
import numpy as np
import paddle
import paddle.fluid as fluid
#import paddle.incubate.hapi as hapi
from hapi import Model, Input, set_device
from hapi.loss import CrossEntropy
from hapi.dataset import MNIST
class MyModel(Model):
def __init__(self):
super(MyModel, self).__init__()
self._fc = fluid.dygraph.Linear(784, 10, act='softmax')
def forward(self, x):
y = self._fc(x)
return y
device = set_device('gpu')
# if use static graph, do not set
fluid.enable_dygraph(device)
model = MyModel()
optim = fluid.optimizer.SGD(learning_rate=1e-3,
parameter_list=model.parameters())
inputs = [Input([None, 784], 'float32', name='x')]
labels = [Input([None, 1], 'int64', name='label')]
mnist_data = MNIST(mode='train')
model.prepare(optim,
CrossEntropy(average=True),
hapi.metrics.Accuracy(),
inputs,
labels,
device=device)
model.fit(mnist_data, epochs=2, batch_size=32, verbose=1)
"""
"""
def
__init__
(
self
):
def
__init__
(
self
):
...
@@ -660,18 +702,195 @@ class Model(fluid.dygraph.Layer):
...
@@ -660,18 +702,195 @@ class Model(fluid.dygraph.Layer):
else
:
else
:
self
.
_adapter
=
StaticGraphAdapter
(
self
)
self
.
_adapter
=
StaticGraphAdapter
(
self
)
def
train_batch
(
self
,
*
args
,
**
kwargs
):
def
train_batch
(
self
,
inputs
,
labels
=
None
):
return
self
.
_adapter
.
train_batch
(
*
args
,
**
kwargs
)
"""
Run one training step on a batch of data.
Args:
inputs (list): A list of numpy.ndarray, each is a batch of
input data.
labels (list): A list of numpy.ndarray, each is a batch of
input label. If has no labels, set None. Default is None.
Returns:
A list of scalar training loss if the model has no metrics,
or a tuple (list of scalar loss, list of metrics) if the model
set metrics.
Examples:
.. code-block:: python
import numpy as np
import paddle.fluid as fluid
from hapi import Model, Input, set_device
class MyModel(Model):
def __init__(self):
super(MyModel, self).__init__()
self._fc = Linear(784, 1, act='softmax')
def forward(self, x):
y = self._fc(x)
return y
device = hapi.set_device('gpu')
fluid.enable_dygraph(device)
model = MyModel()
optim = fluid.optimizer.SGD(learning_rate=1e-3,
parameter_list=model.parameters())
inputs = [Input([None, 784], 'float32', name='x')]
labels = [Input([None, 1], 'int64', name='label')]
model.prepare(optim,
CrossEntropy(average=True),
inputs=inputs,
labels=labels,
device=device)
data = np.random.random(size=(4,784)).astype(np.float32)
label = np.random.randint(0, 10, size=(4, 1)).astype(np.int64)
loss = model.train_batch([data], [label])
print(loss)
"""
return
self
.
_adapter
.
train_batch
(
inputs
,
labels
)
def
eval_batch
(
self
,
inputs
,
labels
=
None
):
"""
Run one evaluating step on a batch of data.
Args:
inputs (list): A list of numpy.ndarray, each is a batch of
input data.
labels (list): A list of numpy.ndarray, each is a batch of
input label. If has no labels, set None. Default is None.
Returns:
A list of scalar testing loss if the model has no metrics,
or a tuple (list of scalar loss, list of metrics) if the model
set metrics.
Examples:
.. code-block:: python
import numpy as np
import paddle.fluid as fluid
from hapi import Model, Input, set_device
class MyModel(Model):
def __init__(self):
super(MyModel, self).__init__()
self._fc = fluid.dygraph.Linear(784, 1, act='softmax')
def forward(self, x):
y = self._fc(x)
return y
device = set_device('gpu')
fluid.enable_dygraph(device)
model = MyModel()
optim = fluid.optimizer.SGD(learning_rate=1e-3,
parameter_list=model.parameters())
inputs = [Input([None, 784], 'float32', name='x')]
labels = [Input([None, 1], 'int64', name='label')]
model.prepare(optim,
CrossEntropy(average=True),
inputs=inputs,
labels=labels,
device=device)
data = np.random.random(size=(4,784)).astype(np.float32)
label = np.random.randint(0, 10, size=(4, 1)).astype(np.int64)
loss = model.eval_batch([data], [label])
print(loss)
"""
return
self
.
_adapter
.
eval_batch
(
inputs
,
labels
)
def
test_batch
(
self
,
inputs
):
"""
Run one testing step on a batch of data.
Args:
inputs (list): A list of numpy.ndarray, each is a batch of
input data.
Returns:
A list of numpy.ndarray of predictions, that is the outputs
of Model forward.
Examples:
.. code-block:: python
import numpy as np
import paddle.fluid as fluid
from hapi import Model, Input, set_device
class MyModel(Model):
def __init__(self):
super(MyModel, self).__init__()
self._fc = fluid.dygraph.Linear(784, 1, act='softmax')
def forward(self, x):
y = self._fc(x)
return y
device = set_device('gpu')
fluid.enable_dygraph(device)
model = MyModel()
inputs = [Input([None, 784], 'float32', name='x')]
model.prepare(inputs=inputs,
device=device)
data = np.random.random(size=(4,784)).astype(np.float32)
out = model.eval_batch([data])
print(out)
"""
return
self
.
_adapter
.
test_batch
(
inputs
)
def
save
(
self
,
path
):
"""
This function saves parameters, optimizer infomation to path.
The parameters contains all the trainable Variable, will save to
a file with suffix ".pdparams".
The optimizer information contains all the variable used by optimizer.
For Adam optimizer, contains beta1, beta2, momentum etc. All the
information will save to a file with suffix ".pdopt". (If the optimizer
have no variable need to save (like SGD), the fill will not generated).
def
eval_batch
(
self
,
*
args
,
**
kwargs
):
This function will silently overwrite existing file
return
self
.
_adapter
.
eval_batch
(
*
args
,
**
kwargs
)
at the target location.
def
test_batch
(
self
,
*
args
,
**
kwargs
):
Args:
return
self
.
_adapter
.
test_batch
(
*
args
,
**
kwargs
)
path (str): The file prefix to save model. The format is
'dirname/file_prefix' or 'file_prefix'. if empty str. A exception
will be raised.
def
save
(
self
,
*
args
,
**
kwargs
):
Returns:
None
Examples:
.. code-block:: python
import paddle.fluid as fluid
from hapi import Model, set_device
class MyModel(Model):
def __init__(self):
super(MyModel, self).__init__()
self._fc = fluid.dygraph.Linear(784, 1, act='softmax')
def forward(self, x):
y = self._fc(x)
return y
device = set_device('cpu')
fluid.enable_dygraph(device)
model = MyModel()
model.save('checkpoint/test')
"""
if
ParallelEnv
().
local_rank
==
0
:
if
ParallelEnv
().
local_rank
==
0
:
return
self
.
_adapter
.
save
(
*
args
,
**
kwargs
)
self
.
_adapter
.
save
(
path
)
def
load
(
self
,
path
,
skip_mismatch
=
False
,
reset_optimizer
=
False
):
def
load
(
self
,
path
,
skip_mismatch
=
False
,
reset_optimizer
=
False
):
"""
"""
...
@@ -698,6 +917,29 @@ class Model(fluid.dygraph.Layer):
...
@@ -698,6 +917,29 @@ class Model(fluid.dygraph.Layer):
optimizer states and initialize optimizer states from scratch.
optimizer states and initialize optimizer states from scratch.
Otherwise, restore optimizer states from `path.pdopt` if
Otherwise, restore optimizer states from `path.pdopt` if
a optimizer has been set to the model. Default False.
a optimizer has been set to the model. Default False.
Returns:
None
Examples:
.. code-block:: python
import paddle.fluid as fluid
from hapi import Model, set_device
class MyModel(Model):
def __init__(self):
super(MyModel, self).__init__()
self._fc = fluid.dygraph.Linear(784, 1, act='softmax')
def forward(self, x):
y = self._fc(x)
return y
device = set_device('cpu')
fluid.enable_dygraph(device)
model = MyModel()
model.load('checkpoint/test')
"""
"""
def
_load_state_from_path
(
path
):
def
_load_state_from_path
(
path
):
...
@@ -747,7 +989,31 @@ class Model(fluid.dygraph.Layer):
...
@@ -747,7 +989,31 @@ class Model(fluid.dygraph.Layer):
return
self
.
_adapter
.
load
(
matched_param_state
,
optim_state
)
return
self
.
_adapter
.
load
(
matched_param_state
,
optim_state
)
def
parameters
(
self
,
*
args
,
**
kwargs
):
def
parameters
(
self
,
*
args
,
**
kwargs
):
return
self
.
_adapter
.
parameters
(
*
args
,
**
kwargs
)
"""
Returns a list of parameters of the model.
Returns:
A list of Parameter in static graph.
A list of ParamBase in dynamic graph.
Examples:
.. code-block:: python
from hapi.model import Model, Input, set_device
class MyModel(Model):
def __init__(self):
super(MyModel, self).__init__()
self._fc = fluid.dygraph.Linear(20, 10, act='softmax')
def forward(self, x):
y = self._fc(x)
return y
fluid.enable_dygraph()
model = MyModel()
params = model.parameters()
"""
return
self
.
_adapter
.
parameters
()
def
prepare
(
self
,
def
prepare
(
self
,
optimizer
=
None
,
optimizer
=
None
,
...
@@ -757,27 +1023,32 @@ class Model(fluid.dygraph.Layer):
...
@@ -757,27 +1023,32 @@ class Model(fluid.dygraph.Layer):
labels
=
None
,
labels
=
None
,
device
=
None
):
device
=
None
):
"""
"""
FIXME: add comments
Configures the model before runing.
Args:
Args:
optimizer (Optimizer|None):
o
ptimizer must be set in training
optimizer (Optimizer|None):
O
ptimizer must be set in training
and should be a Optimizer instance. It can be None in eval
and should be a Optimizer instance. It can be None in eval
and test mode.
and test mode.
loss_function (Loss|None):
l
oss function must be set in training
loss_function (Loss|None):
L
oss function must be set in training
and should be a Loss instance. It can be None when there is
and should be a Loss instance. It can be None when there is
no loss.
no loss.
metrics (Metric|list of Metric|None):
i
f metrics is set, all
metrics (Metric|list of Metric|None):
I
f metrics is set, all
metric
will be calculate
and output in train/eval mode.
metric
s will be calculated
and output in train/eval mode.
inputs (Input|list|dict|None):
inputs
, entry points of network,
inputs (Input|list|dict|None):
`inputs`
, entry points of network,
could be a Input layer, or lits of Input layers,
could be a Input layer, or lits of Input layers,
or dict (name: Input), or None. For static graph,
or dict (name: Input), or None. For static graph,
inputs must be set. For dynamic graph, it could be None.
inputs must be set. For dynamic graph, it could be None.
labels (Input|list|None):
labels
, entry points of network,
labels (Input|list|None):
`labels`
, entry points of network,
could be a Input layer or lits of Input layers, or None.
could be a Input layer or lits of Input layers, or None.
For static graph, if set loss_function in Model.prepare(), it
For static graph, if labels is required in loss_function,
must be set. Otherwise, it could be None.
labels must be set. Otherwise, it could be None.
device (str|None): specify device type, 'CPU' or 'GPU'.
device (str|fluid.CUDAPlace|fluid.CPUPlace|None): Specify device
type, 'CPU', 'GPU', fluid.CUDAPlace or fluid.CPUPlace.
If None, automatically select device according to
If None, automatically select device according to
installation package version.
installation package version.
Returns:
None
"""
"""
if
isinstance
(
device
,
fluid
.
CUDAPlace
)
or
\
if
isinstance
(
device
,
fluid
.
CUDAPlace
)
or
\
...
@@ -859,7 +1130,9 @@ class Model(fluid.dygraph.Layer):
...
@@ -859,7 +1130,9 @@ class Model(fluid.dygraph.Layer):
num_workers
=
0
,
num_workers
=
0
,
callbacks
=
None
,
):
callbacks
=
None
,
):
"""
"""
FIXME: add more comments and usage
Trains the model for a fixed number of epochs. If `eval_data` is set,
evaluation will be done at the end of each epoch.
Args:
Args:
train_data (Dataset|DataLoader): An iterable data loader is used for
train_data (Dataset|DataLoader): An iterable data loader is used for
train. An instance of paddle paddle.io.Dataset or
train. An instance of paddle paddle.io.Dataset or
...
@@ -868,30 +1141,117 @@ class Model(fluid.dygraph.Layer):
...
@@ -868,30 +1141,117 @@ class Model(fluid.dygraph.Layer):
evaluation at the end of epoch. If None, will not do evaluation.
evaluation at the end of epoch. If None, will not do evaluation.
An instance of paddle.io.Dataset or paddle.io.Dataloader
An instance of paddle.io.Dataset or paddle.io.Dataloader
is recomended. Default: None.
is recomended. Default: None.
batch_size (int): Integer number. The batch size of train_data and eval_data.
batch_size (int): Integer number. The batch size of train_data
When train_data and eval_data are both the instance of Dataloader, this
and eval_data. When train_data and eval_data are both the
parameter will be ignored. Default: 1.
instance of Dataloader, this parameter will be ignored.
epochs (int): Integer number. The number of epochs to train the model. Default: 1.
Default: 1.
epochs (int): Integer number. The number of epochs to train
the model. Default: 1.
eval_freq (int): The frequency, in number of epochs, an evalutation
eval_freq (int): The frequency, in number of epochs, an evalutation
is performed. Default: 1.
is performed. Default: 1.
log_freq (int): The frequency, in number of steps, the training logs
log_freq (int): The frequency, in number of steps, the training logs
are printed. Default: 10.
are printed. Default: 10.
save_dir(str|None): The directory to save checkpoint during training.
save_dir(str|None): The directory to save checkpoint during training.
If None, will not save checkpoint. Default: None.
If None, will not save checkpoint. Default: None.
save_freq (int): The frequency, in number of epochs, to save checkpoint. Default: 1.
save_freq (int): The frequency, in number of epochs, to save
verbose (int): The verbosity mode, should be 0, 1, or 2.
checkpoint. Default: 1.
0 = silent, 1 = progress bar, 2 = one line per epoch. Default: 2.
verbose (int): The verbosity mode, should be 0, 1, or 2. 0 = silent,
drop_last (bool): whether drop the last incomplete batch of train_data
1 = progress bar, 2 = one line per epoch. Default: 2.
when dataset size is not divisible by the batch size. When train_data
drop_last (bool): Whether drop the last incomplete batch of
is an instance of Dataloader, this parameter will be ignored. Default: False.
train_data when dataset size is not divisible by the batch size.
shuffle (bool): whther to shuffle train_data. When train_data is an instance
When train_data is an instance of Dataloader, this parameter
of Dataloader, this parameter will be ignored. Default: True.
will be ignored. Default: False.
num_workers (int): the number of subprocess to load data, 0 for no subprocess
shuffle (bool): Whther to shuffle train_data. When train_data is
used and loading data in main process. When train_data and eval_data are
an instance of Dataloader, this parameter will be ignored.
both the instance of Dataloader, this parameter will be ignored. Default: 0.
Default: True.
num_workers (int): The number of subprocess to load data, 0 for no
subprocess used and loading data in main process.
When train_data and eval_data are both the instance of
Dataloader, this parameter will be ignored. Default: 0.
callbacks (Callback|None): A list of `Callback` instances to apply
callbacks (Callback|None): A list of `Callback` instances to apply
during training. If None, `ProgBarLogger` and `ModelCheckpoint`
during training. If None, `ProgBarLogger` and `ModelCheckpoint`
are automatically inserted. Default: None.
are automatically inserted. Default: None.
Returns:
None
Examples:
1. An example use Dataset and set btch size, shuffle in fit.
How to make a batch is done internally.
.. code-block:: python
from hapi.model import Model, Input, set_device
from hapi.loss import CrossEntropy
from hapi.metrics import Accuracy
from hapi.datasets import MNIST
from hapi.vision.models import LeNet
dynamic = True
device = set_device(FLAGS.device)
fluid.enable_dygraph(device) if dynamic else None
train_dataset = MNIST(mode='train')
val_dataset = MNIST(mode='test')
inputs = [Input([None, 1, 28, 28], 'float32', name='image')]
labels = [Input([None, 1], 'int64', name='label')]
model = LeNet()
optim = fluid.optimizer.Adam(
learning_rate=0.001, parameter_list=model.parameters())
model.prepare(
optim,
CrossEntropy(),
Accuracy(topk=(1, 2)),
inputs=inputs,
labels=labels,
device=device)
model.fit(train_dataset,
val_dataset,
epochs=2,
batch_size=64,
save_dir='mnist_checkpoint')
2. An example use DataLoader, batch size and shuffle is set in
DataLoader.
.. code-block:: python
from hapi.model import Model, Input, set_device
from hapi.loss import CrossEntropy
from hapi.metrics import Accuracy
from hapi.datasets import MNIST
from hapi.vision.models import LeNet
dynamic = True
device = set_device(FLAGS.device)
fluid.enable_dygraph(device) if dynamic else None
train_dataset = MNIST(mode='train')
train_loader = fluid.io.DataLoader(train_dataset,
places=device, batch_size=64)
val_dataset = MNIST(mode='test')
val_loader = fluid.io.DataLoader(val_dataset,
places=device, batch_size=64)
inputs = [Input([None, 1, 28, 28], 'float32', name='image')]
labels = [Input([None, 1], 'int64', name='label')]
model = LeNet()
optim = fluid.optimizer.Adam(
learning_rate=0.001, parameter_list=model.parameters())
model.prepare(
optim,
CrossEntropy(),
Accuracy(topk=(1, 2)),
inputs=inputs,
labels=labels,
device=device)
model.fit(train_loader,
val_loader,
epochs=2,
save_dir='mnist_checkpoint')
"""
"""
assert
train_data
is
not
None
,
\
assert
train_data
is
not
None
,
\
...
@@ -986,26 +1346,29 @@ class Model(fluid.dygraph.Layer):
...
@@ -986,26 +1346,29 @@ class Model(fluid.dygraph.Layer):
num_workers
=
0
,
num_workers
=
0
,
callbacks
=
None
,
):
callbacks
=
None
,
):
"""
"""
FIXME: add more comments and usage
Evaluate the loss and metrics of the model on input dataset.
Args:
Args:
eval_data (Dataset|DataLoader): An iterable data loader is used for
eval_data (Dataset|DataLoader): An iterable data loader is used for
evaluation. An instance of paddle.io.Dataset or
evaluation. An instance of paddle.io.Dataset or
paddle.io.Dataloader is recomended.
paddle.io.Dataloader is recomended.
batch_size (int): Integer number. The batch size of train_data
and eval_data.
batch_size (int): Integer number. The batch size of train_data
When eval_data is the instance of Dataloader, this argument will be ignored.
and eval_data. When eval_data is the instance of Dataloader,
Default: 1.
this argument will be ignored.
Default: 1.
log_freq (int): The frequency, in number of steps, the eval logs
log_freq (int): The frequency, in number of steps, the eval logs
are printed. Default: 10.
are printed. Default: 10.
verbose (int): The verbosity mode, should be 0, 1, or 2.
verbose (int): The verbosity mode, should be 0, 1, or 2. 0 = silent,
0 = silent, 1 = progress bar, 2 = one line per epoch. Default: 2.
1 = progress bar, 2 = one line per epoch. Default: 2.
num_workers (int): The number of subprocess to load data, 0 for no subprocess
num_workers (int): The number of subprocess to load data,
used and loading data in main process. When train_data and eval_data are
0 for no subprocess used and loading data in main process. When
both the instance of Dataloader, this parameter will be ignored. Default: 0.
train_data and eval_data are both the instance of Dataloader,
this parameter will be ignored. Default: 0.
callbacks (Callback|None): A list of `Callback` instances to apply
callbacks (Callback|None): A list of `Callback` instances to apply
during training. If None, `ProgBarLogger` and `ModelCheckpoint`
during training. If None, `ProgBarLogger` and `ModelCheckpoint`
are automatically inserted. Default: None.
are automatically inserted. Default: None.
Returns:
Returns:
dict: Result of metric.
dict: Result of metric. The key is the names of Metric,
value is a scalar or numpy.array.
"""
"""
if
fluid
.
in_dygraph_mode
():
if
fluid
.
in_dygraph_mode
():
...
@@ -1063,18 +1426,19 @@ class Model(fluid.dygraph.Layer):
...
@@ -1063,18 +1426,19 @@ class Model(fluid.dygraph.Layer):
num_workers
=
0
,
num_workers
=
0
,
stack_outputs
=
False
):
stack_outputs
=
False
):
"""
"""
FIXME: add more comments and usage
Compute the output predictions on testing data.
Args:
Args:
test_data (Dataset|DataLoader): An iterable data loader is used for
test_data (Dataset|DataLoader): An iterable data loader is used for
predict. An instance of paddle.io.Dataset or paddle.io.Dataloader
predict. An instance of paddle.io.Dataset or paddle.io.Dataloader
is recomended.
is recomended.
batch_size (int): Integer number. The batch size of train_data and eval_data.
batch_size (int): Integer number. The batch size of train_data and eval_data.
When train_data and eval_data are both the instance of Dataloader, this
When train_data and eval_data are both the instance of Dataloader, this
argument will be ignored. Default: 1.
argument will be ignored. Default: 1.
num_workers (int):
t
he number of subprocess to load data, 0 for no subprocess
num_workers (int):
T
he number of subprocess to load data, 0 for no subprocess
used and loading data in main process. When train_data and eval_data are
used and loading data in main process. When train_data and eval_data are
both the instance of Dataloader, this argument will be ignored. Default: 0.
both the instance of Dataloader, this argument will be ignored. Default: 0.
stack_output (bool):
w
hether stack output field like a batch, as for an output
stack_output (bool):
W
hether stack output field like a batch, as for an output
filed of a sample is in shape [X, Y], test_data contains N samples, predict
filed of a sample is in shape [X, Y], test_data contains N samples, predict
output field will be in shape [N, X, Y] if stack_output is True, and will
output field will be in shape [N, X, Y] if stack_output is True, and will
be a length N list in shape [[X, Y], [X, Y], ....[X, Y]] if stack_outputs
be a length N list in shape [[X, Y], [X, Y], ....[X, Y]] if stack_outputs
...
@@ -1138,21 +1502,20 @@ class Model(fluid.dygraph.Layer):
...
@@ -1138,21 +1502,20 @@ class Model(fluid.dygraph.Layer):
save_dir
,
save_dir
,
model_filename
=
None
,
model_filename
=
None
,
params_filename
=
None
,
params_filename
=
None
,
program
_only
=
False
):
model
_only
=
False
):
"""
"""
Save inference model must in static mode.
Save inference model must in static mode.
Args:
Args:
dirname(str): The directory path to save the inference model.
dirname(str): The directory path to save the inference model.
model_filename(str|None): The name of file to save the inference program
model_filename(str|None): The name of file to save the inference
itself. If is set None, a default filename
model itself. If is set None, a default filename
:code:`__model__` will be used.
:code:`__model__` will be used.
params_filename(str|None): The name of file to save all related parameters.
params_filename(str|None): The name of file to save all related
If it is set None, parameters will be saved
parameters. If it is set None, parameters will be saved
in separate files .
in separate files .
program_only(bool): If True, It will save inference program only, and do not
model_only(bool): If True, It will save inference model only,
save params of Program.
and do not save parameters. Default: False.
Default: False.
Returns:
Returns:
list: The fetch variables' name list
list: The fetch variables' name list
...
@@ -1177,7 +1540,7 @@ class Model(fluid.dygraph.Layer):
...
@@ -1177,7 +1540,7 @@ class Model(fluid.dygraph.Layer):
main_program
=
infer_prog
,
main_program
=
infer_prog
,
model_filename
=
model_filename
,
model_filename
=
model_filename
,
params_filename
=
params_filename
,
params_filename
=
params_filename
,
program_only
=
program
_only
)
program_only
=
model
_only
)
def
_run_one_epoch
(
self
,
def
_run_one_epoch
(
self
,
data_loader
,
data_loader
,
...
...
hapi/tests/test_model.py
浏览文件 @
d390d968
...
@@ -18,27 +18,25 @@ from __future__ import print_function
...
@@ -18,27 +18,25 @@ from __future__ import print_function
import
unittest
import
unittest
import
os
import
os
import
cv2
import
numpy
as
np
import
numpy
as
np
import
tempfile
import
shutil
import
shutil
import
tempfile
import
paddle
import
paddle
from
paddle
import
fluid
from
paddle
import
fluid
from
paddle.fluid.dygraph.nn
import
Conv2D
,
Pool2D
,
Linear
from
paddle.fluid.dygraph.nn
import
Conv2D
,
Pool2D
,
Linear
from
paddle.fluid.dygraph.container
import
Sequential
from
paddle.fluid.dygraph.container
import
Sequential
from
paddle.io
import
BatchSampler
,
DataLoader
from
paddle.io
import
DataLoader
from
paddle.fluid.dygraph.base
import
to_variable
from
hapi.model
import
Model
,
Input
,
set_device
from
hapi.model
import
Model
,
Input
,
set_device
from
hapi.loss
import
Loss
from
hapi.loss
import
CrossEntropy
from
hapi.metrics
import
Accuracy
from
hapi.metrics
import
Accuracy
from
hapi.datasets
import
MNIST
from
hapi.datasets
import
MNIST
from
hapi.vision.models
import
LeNet
from
hapi.vision.models
import
LeNet
from
hapi.download
import
get_weights_path_from_url
class
LeNetDygraph
(
fluid
.
dygraph
.
Layer
):
class
LeNetDygraph
(
fluid
.
dygraph
.
Layer
):
def
__init__
(
self
,
num_classes
=
10
,
classifier_activation
=
'softmax'
):
def
__init__
(
self
,
num_classes
=
10
,
classifier_activation
=
'softmax'
):
super
(
LeNetDygraph
,
self
).
__init__
()
super
(
LeNetDygraph
,
self
).
__init__
()
self
.
num_classes
=
num_classes
self
.
num_classes
=
num_classes
...
@@ -67,12 +65,16 @@ class LeNetDygraph(fluid.dygraph.Layer):
...
@@ -67,12 +65,16 @@ class LeNetDygraph(fluid.dygraph.Layer):
class
MnistDataset
(
MNIST
):
class
MnistDataset
(
MNIST
):
def
__init__
(
self
,
mode
,
return_label
=
True
):
def
__init__
(
self
,
mode
,
return_label
=
True
,
sample_num
=
None
):
super
(
MnistDataset
,
self
).
__init__
(
mode
=
mode
)
super
(
MnistDataset
,
self
).
__init__
(
mode
=
mode
)
self
.
return_label
=
return_label
self
.
return_label
=
return_label
if
sample_num
:
self
.
images
=
self
.
images
[:
sample_num
]
self
.
labels
=
self
.
labels
[:
sample_num
]
def
__getitem__
(
self
,
idx
):
def
__getitem__
(
self
,
idx
):
img
=
np
.
reshape
(
self
.
images
[
idx
],
[
1
,
28
,
28
])
img
,
label
=
self
.
images
[
idx
],
self
.
labels
[
idx
]
img
=
np
.
reshape
(
img
,
[
1
,
28
,
28
])
if
self
.
return_label
:
if
self
.
return_label
:
return
img
,
np
.
array
(
self
.
labels
[
idx
]).
astype
(
'int64'
)
return
img
,
np
.
array
(
self
.
labels
[
idx
]).
astype
(
'int64'
)
return
img
,
return
img
,
...
@@ -81,15 +83,14 @@ class MnistDataset(MNIST):
...
@@ -81,15 +83,14 @@ class MnistDataset(MNIST):
return
len
(
self
.
images
)
return
len
(
self
.
images
)
def
get_predict_accuracy
(
pred
,
gt
):
def
compute_acc
(
pred
,
label
):
pred
=
np
.
argmax
(
pred
,
-
1
)
pred
=
np
.
argmax
(
pred
,
-
1
)
gt
=
np
.
array
(
gt
)
label
=
np
.
array
(
label
)
correct
=
pred
[:,
np
.
newaxis
]
==
label
correct
=
pred
[:,
np
.
newaxis
]
==
gt
return
np
.
sum
(
correct
)
/
correct
.
shape
[
0
]
return
np
.
sum
(
correct
)
/
correct
.
shape
[
0
]
def
low_level_lenet_dygraph
_train
(
model
,
dataloader
):
def
dynamic
_train
(
model
,
dataloader
):
optim
=
fluid
.
optimizer
.
Adam
(
optim
=
fluid
.
optimizer
.
Adam
(
learning_rate
=
0.001
,
parameter_list
=
model
.
parameters
())
learning_rate
=
0.001
,
parameter_list
=
model
.
parameters
())
model
.
train
()
model
.
train
()
...
@@ -102,7 +103,7 @@ def low_level_lenet_dygraph_train(model, dataloader):
...
@@ -102,7 +103,7 @@ def low_level_lenet_dygraph_train(model, dataloader):
model
.
clear_gradients
()
model
.
clear_gradients
()
def
low_level_
dynamic_evaluate
(
model
,
dataloader
):
def
dynamic_evaluate
(
model
,
dataloader
):
with
fluid
.
dygraph
.
no_grad
():
with
fluid
.
dygraph
.
no_grad
():
model
.
eval
()
model
.
eval
()
cnt
=
0
cnt
=
0
...
@@ -115,56 +116,65 @@ def low_level_dynamic_evaluate(model, dataloader):
...
@@ -115,56 +116,65 @@ def low_level_dynamic_evaluate(model, dataloader):
return
cnt
/
len
(
dataloader
.
dataset
)
return
cnt
/
len
(
dataloader
.
dataset
)
class
TestEvaluatePredict
(
unittest
.
TestCase
):
class
TestModel
(
unittest
.
TestCase
):
def
setUp
(
self
):
@
classmethod
self
.
device
=
set_device
(
'gpu'
)
def
setUpClass
(
cls
):
self
.
train_dataset
=
MnistDataset
(
mode
=
'train'
)
cls
.
device
=
set_device
(
'gpu'
)
self
.
val_dataset
=
MnistDataset
(
mode
=
'test'
)
fluid
.
enable_dygraph
(
cls
.
device
)
self
.
test_dataset
=
MnistDataset
(
mode
=
'test'
,
return_label
=
False
)
fluid
.
enable_dygraph
(
self
.
device
)
train_dataloader
=
fluid
.
io
.
DataLoader
(
self
.
train_dataset
,
places
=
self
.
device
,
batch_size
=
64
)
val_dataloader
=
fluid
.
io
.
DataLoader
(
self
.
val_dataset
,
places
=
self
.
device
,
batch_size
=
64
)
self
.
lenet_dygraph
=
LeNetDygraph
()
low_level_lenet_dygraph_train
(
self
.
lenet_dygraph
,
train_dataloader
)
self
.
acc1
=
low_level_dynamic_evaluate
(
self
.
lenet_dygraph
,
val_dataloader
)
self
.
save_dir
=
tempfile
.
mkdtemp
()
self
.
weight_path
=
os
.
path
.
join
(
self
.
save_dir
,
'lenet'
)
fluid
.
dygraph
.
save_dygraph
(
self
.
lenet_dygraph
.
state_dict
(),
self
.
weight_path
)
fluid
.
disable_dygraph
()
sp_num
=
1280
cls
.
train_dataset
=
MnistDataset
(
mode
=
'train'
,
sample_num
=
sp_num
)
cls
.
val_dataset
=
MnistDataset
(
mode
=
'test'
,
sample_num
=
sp_num
)
cls
.
test_dataset
=
MnistDataset
(
mode
=
'test'
,
return_label
=
False
,
sample_num
=
sp_num
)
def
tearDown
(
self
):
cls
.
train_loader
=
fluid
.
io
.
DataLoader
(
shutil
.
rmtree
(
self
.
save_dir
)
cls
.
train_dataset
,
places
=
cls
.
device
,
batch_size
=
64
)
cls
.
val_loader
=
fluid
.
io
.
DataLoader
(
cls
.
val_dataset
,
places
=
cls
.
device
,
batch_size
=
64
)
cls
.
test_loader
=
fluid
.
io
.
DataLoader
(
cls
.
test_dataset
,
places
=
cls
.
device
,
batch_size
=
64
)
def
evaluate
(
self
,
dynamic
):
seed
=
333
fluid
.
enable_dygraph
(
self
.
device
)
if
dynamic
else
None
fluid
.
default_startup_program
().
random_seed
=
seed
fluid
.
default_main_program
().
random_seed
=
seed
inputs
=
[
Input
([
-
1
,
1
,
28
,
28
],
'float32'
,
name
=
'image'
)]
dy_lenet
=
LeNetDygraph
()
labels
=
[
Input
([
None
,
1
],
'int64'
,
name
=
'label'
)]
cls
.
init_param
=
dy_lenet
.
state_dict
()
dynamic_train
(
dy_lenet
,
cls
.
train_loader
)
val_dataloader
=
fluid
.
io
.
DataLoader
(
cls
.
acc1
=
dynamic_evaluate
(
dy_lenet
,
cls
.
val_loader
)
self
.
val_dataset
,
places
=
self
.
device
,
batch_size
=
64
,
return_list
=
True
)
model
=
LeNet
()
cls
.
inputs
=
[
Input
([
-
1
,
1
,
28
,
28
],
'float32'
,
name
=
'image'
)]
cls
.
labels
=
[
Input
([
None
,
1
],
'int64'
,
name
=
'label'
)]
model
.
load
(
self
.
weight_path
)
cls
.
save_dir
=
tempfile
.
mkdtemp
()
cls
.
weight_path
=
os
.
path
.
join
(
cls
.
save_dir
,
'lenet'
)
fluid
.
dygraph
.
save_dygraph
(
dy_lenet
.
state_dict
(),
cls
.
weight_path
)
model
.
prepare
(
metrics
=
Accuracy
(),
inputs
=
inputs
,
labels
=
labels
)
fluid
.
disable_dygraph
(
)
result
=
model
.
evaluate
(
val_dataloader
)
@
classmethod
def
tearDownClass
(
cls
):
shutil
.
rmtree
(
cls
.
save_dir
)
np
.
testing
.
assert_allclose
(
result
[
'acc'
],
self
.
acc1
)
def
test_fit_dygraph
(
self
):
self
.
fit
(
True
)
if
fluid
.
in_dygraph_mode
():
def
test_fit_static
(
self
):
fluid
.
disable_dygraph
()
self
.
fit
(
False
)
def
test_evaluate_dygraph
(
self
):
self
.
evaluate
(
True
)
def
test_evaluate_static
(
self
):
self
.
evaluate
(
False
)
def
test_predict_dygraph
(
self
):
self
.
predict
(
True
)
def
test_predict_static
(
self
):
self
.
predict
(
False
)
def
predict
(
self
,
dynamic
):
def
predict
(
self
,
dynamic
):
fluid
.
enable_dygraph
(
self
.
device
)
if
dynamic
else
None
fluid
.
enable_dygraph
(
self
.
device
)
if
dynamic
else
None
...
@@ -186,26 +196,161 @@ class TestEvaluatePredict(unittest.TestCase):
...
@@ -186,26 +196,161 @@ class TestEvaluatePredict(unittest.TestCase):
output
=
model
.
predict
(
test_dataloader
,
stack_outputs
=
True
)
output
=
model
.
predict
(
test_dataloader
,
stack_outputs
=
True
)
np
.
testing
.
assert_equal
(
output
[
0
].
shape
[
0
],
len
(
self
.
test_dataset
))
def
fit
(
self
,
dynamic
):
fluid
.
enable_dygraph
(
self
.
device
)
if
dynamic
else
None
acc
=
get_predict_accuracy
(
output
[
0
],
self
.
val_dataset
.
labels
)
seed
=
333
fluid
.
default_startup_program
().
random_seed
=
seed
np
.
testing
.
assert_allclose
(
acc
,
self
.
acc1
)
fluid
.
default_main_program
().
random_seed
=
seed
if
fluid
.
in_dygraph_mode
():
fluid
.
disable_dygraph
()
def
test_evaluate_dygraph
(
self
):
model
=
LeNet
()
self
.
evaluate
(
True
)
optim_new
=
fluid
.
optimizer
.
Adam
(
learning_rate
=
0.001
,
parameter_list
=
model
.
parameters
())
model
.
prepare
(
optim_new
,
loss_function
=
CrossEntropy
(
average
=
False
),
metrics
=
Accuracy
(),
inputs
=
self
.
inputs
,
labels
=
self
.
labels
)
model
.
fit
(
self
.
train_dataset
,
batch_size
=
64
,
shuffle
=
False
)
result
=
model
.
evaluate
(
self
.
val_dataset
,
batch_size
=
64
)
np
.
testing
.
assert_allclose
(
result
[
'acc'
],
self
.
acc1
)
fluid
.
disable_dygraph
()
if
dynamic
else
None
def
test_evaluate_static
(
self
):
def
evaluate
(
self
,
dynamic
):
self
.
evaluate
(
False
)
fluid
.
enable_dygraph
(
self
.
device
)
if
dynamic
else
None
model
=
LeNet
()
model
.
prepare
(
metrics
=
Accuracy
(),
inputs
=
self
.
inputs
,
labels
=
self
.
labels
)
model
.
load
(
self
.
weight_path
)
result
=
model
.
evaluate
(
self
.
val_dataset
,
batch_size
=
64
)
np
.
testing
.
assert_allclose
(
result
[
'acc'
],
self
.
acc1
)
fluid
.
disable_dygraph
()
if
dynamic
else
None
def
test_predict_dygraph
(
self
):
def
predict
(
self
,
dynamic
):
self
.
predict
(
True
)
fluid
.
enable_dygraph
(
self
.
device
)
if
dynamic
else
None
model
=
LeNet
()
model
.
prepare
(
inputs
=
self
.
inputs
)
model
.
load
(
self
.
weight_path
)
output
=
model
.
predict
(
self
.
test_dataset
,
batch_size
=
64
,
stack_outputs
=
True
)
np
.
testing
.
assert_equal
(
output
[
0
].
shape
[
0
],
len
(
self
.
test_dataset
))
def
test_predict_static
(
self
):
acc
=
compute_acc
(
output
[
0
],
self
.
val_dataset
.
labels
)
self
.
predict
(
False
)
np
.
testing
.
assert_allclose
(
acc
,
self
.
acc1
)
fluid
.
disable_dygraph
()
if
dynamic
else
None
class
MyModel
(
Model
):
def
__init__
(
self
):
super
(
MyModel
,
self
).
__init__
()
self
.
_fc
=
Linear
(
20
,
10
,
act
=
'softmax'
)
def
forward
(
self
,
x
):
y
=
self
.
_fc
(
x
)
return
y
class
TestModelFunction
(
unittest
.
TestCase
):
def
set_seed
(
self
,
seed
=
1024
):
fluid
.
default_startup_program
().
random_seed
=
seed
fluid
.
default_main_program
().
random_seed
=
seed
def
test_train_batch
(
self
,
dynamic
=
True
):
dim
=
20
data
=
np
.
random
.
random
(
size
=
(
4
,
dim
)).
astype
(
np
.
float32
)
label
=
np
.
random
.
randint
(
0
,
10
,
size
=
(
4
,
1
)).
astype
(
np
.
int64
)
def
get_expect
():
fluid
.
enable_dygraph
(
fluid
.
CPUPlace
())
self
.
set_seed
()
m
=
MyModel
()
optim
=
fluid
.
optimizer
.
SGD
(
learning_rate
=
0.001
,
parameter_list
=
m
.
parameters
())
m
.
train
()
output
=
m
(
to_variable
(
data
))
l
=
to_variable
(
label
)
loss
=
fluid
.
layers
.
cross_entropy
(
output
,
l
)
avg_loss
=
fluid
.
layers
.
reduce_sum
(
loss
)
avg_loss
.
backward
()
optim
.
minimize
(
avg_loss
)
m
.
clear_gradients
()
fluid
.
disable_dygraph
()
return
avg_loss
.
numpy
()
ref
=
get_expect
()
for
dynamic
in
[
True
,
False
]:
device
=
set_device
(
'cpu'
)
fluid
.
enable_dygraph
(
device
)
if
dynamic
else
None
self
.
set_seed
()
model
=
MyModel
()
optim2
=
fluid
.
optimizer
.
SGD
(
learning_rate
=
0.001
,
parameter_list
=
model
.
parameters
())
inputs
=
[
Input
([
None
,
dim
],
'float32'
,
name
=
'x'
)]
labels
=
[
Input
([
None
,
1
],
'int64'
,
name
=
'label'
)]
model
.
prepare
(
optim2
,
loss_function
=
CrossEntropy
(
average
=
False
),
inputs
=
inputs
,
labels
=
labels
,
device
=
device
)
loss
,
=
model
.
train_batch
([
data
],
[
label
])
np
.
testing
.
assert_allclose
(
loss
.
flatten
(),
ref
.
flatten
())
fluid
.
disable_dygraph
()
if
dynamic
else
None
def
test_test_batch
(
self
,
dynamic
=
True
):
dim
=
20
data
=
np
.
random
.
random
(
size
=
(
4
,
dim
)).
astype
(
np
.
float32
)
def
get_expect
():
fluid
.
enable_dygraph
(
fluid
.
CPUPlace
())
self
.
set_seed
()
m
=
MyModel
()
m
.
eval
()
output
=
m
(
to_variable
(
data
))
fluid
.
disable_dygraph
()
return
output
.
numpy
()
ref
=
get_expect
()
for
dynamic
in
[
True
,
False
]:
device
=
set_device
(
'cpu'
)
fluid
.
enable_dygraph
(
device
)
if
dynamic
else
None
self
.
set_seed
()
model
=
MyModel
()
inputs
=
[
Input
([
None
,
dim
],
'float32'
,
name
=
'x'
)]
model
.
prepare
(
inputs
=
inputs
,
device
=
device
)
out
,
=
model
.
test_batch
([
data
])
np
.
testing
.
assert_allclose
(
out
,
ref
)
fluid
.
disable_dygraph
()
if
dynamic
else
None
def
test_save_load
(
self
):
path
=
tempfile
.
mkdtemp
()
for
dynamic
in
[
True
,
False
]:
device
=
set_device
(
'cpu'
)
fluid
.
enable_dygraph
(
device
)
if
dynamic
else
None
model
=
MyModel
()
inputs
=
[
Input
([
None
,
20
],
'float32'
,
name
=
'x'
)]
model
.
prepare
(
inputs
=
inputs
)
model
.
save
(
path
+
'/test'
)
model
.
load
(
path
+
'/test'
)
shutil
.
rmtree
(
path
)
fluid
.
disable_dygraph
()
if
dynamic
else
None
def
test_parameters
(
self
):
for
dynamic
in
[
True
,
False
]:
device
=
set_device
(
'cpu'
)
fluid
.
enable_dygraph
(
device
)
if
dynamic
else
None
model
=
MyModel
()
inputs
=
[
Input
([
None
,
20
],
'float32'
,
name
=
'x'
)]
model
.
prepare
(
inputs
=
inputs
)
params
=
model
.
parameters
()
self
.
assertTrue
(
params
[
0
].
shape
[
0
]
==
20
)
self
.
assertTrue
(
params
[
0
].
shape
[
1
]
==
10
)
fluid
.
disable_dygraph
()
if
dynamic
else
None
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录