Commit d390d968 (unverified)
Authored by qingqing01 on Apr 29, 2020; committed by GitHub on Apr 29, 2020

Merge pull request #64 from qingqing01/ut_and_doc

Add unit testing and doc.

Parents: 581b4944, b295ffbb
Showing 5 changed files with 660 additions and 147 deletions (+660 −147)
hapi/callbacks.py        +10  −11
hapi/datasets/mnist.py   +8   −2
hapi/loss.py             +2   −2
hapi/model.py            +426 −63
hapi/tests/test_model.py +214 −69
hapi/callbacks.py
@@ -185,6 +185,9 @@ class ProgBarLogger(Callback):
         self.verbose = verbose
         self.log_freq = log_freq
 
+    def _is_print(self):
+        return self.verbose and ParallelEnv().local_rank == 0
+
     def on_train_begin(self, logs=None):
         self.epochs = self.params['epochs']
         assert self.epochs
@@ -195,7 +198,7 @@ class ProgBarLogger(Callback):
         self.steps = self.params['steps']
         self.epoch = epoch
         self.train_step = 0
-        if self.verbose and self.epochs and ParallelEnv().local_rank == 0:
+        if self.epochs and self._is_print():
             print('Epoch %d/%d' % (epoch + 1, self.epochs))
         self.train_progbar = ProgressBar(num=self.steps, verbose=self.verbose)
@@ -213,15 +216,13 @@ class ProgBarLogger(Callback):
         logs = logs or {}
         self.train_step += 1
 
-        if self.train_step % self.log_freq == 0 and self.verbose and ParallelEnv(
-        ).local_rank == 0:
+        if self._is_print() and self.train_step % self.log_freq == 0:
             if self.steps is None or self.train_step < self.steps:
                 self._updates(logs, 'train')
 
     def on_epoch_end(self, epoch, logs=None):
         logs = logs or {}
-        if self.train_step % self.log_freq != 0 and self.verbose and ParallelEnv(
-        ).local_rank == 0:
+        if self._is_print() and (self.steps is not None):
             self._updates(logs, 'train')
 
     def on_eval_begin(self, logs=None):
@@ -231,7 +232,7 @@ class ProgBarLogger(Callback):
         self.evaled_samples = 0
         self.eval_progbar = ProgressBar(
             num=self.eval_steps, verbose=self.verbose)
-        if ParallelEnv().local_rank == 0:
+        if self._is_print():
             print('Eval begin...')
 
     def on_eval_batch_end(self, step, logs=None):
@@ -240,15 +241,13 @@ class ProgBarLogger(Callback):
         samples = logs.get('batch_size', 1)
         self.evaled_samples += samples
 
-        if self.eval_step % self.log_freq == 0 and self.verbose and ParallelEnv(
-        ).local_rank == 0:
+        if self._is_print() and self.eval_step % self.log_freq == 0:
             if self.eval_steps is None or self.eval_step < self.eval_steps:
                 self._updates(logs, 'eval')
 
     def on_eval_end(self, logs=None):
         logs = logs or {}
-        if self.verbose and ParallelEnv().local_rank == 0:
-            if self.eval_step % self.log_freq != 0:
-                self._updates(logs, 'eval')
+        if self._is_print() and (self.steps is not None):
+            self._updates(logs, 'eval')
             print('Eval samples: %d' % (self.evaled_samples))
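The refactor above folds the repeated "verbose and ParallelEnv().local_rank == 0" guard into a single _is_print() helper, so every logging call site stays in sync. A minimal runnable sketch of the same pattern, with a stand-in for paddle's ParallelEnv (the real class lives in paddle.fluid.dygraph.parallel):

import os

class _EnvStub(object):
    # Stand-in for fluid.dygraph.ParallelEnv: rank 0 unless launched distributed.
    @property
    def local_rank(self):
        return int(os.environ.get("PADDLE_TRAINER_ID", "0"))

class LoggerSketch(object):
    def __init__(self, verbose=1):
        self.verbose = verbose

    def _is_print(self):
        # One guard instead of repeating "verbose and rank == 0" at every call site.
        return bool(self.verbose) and _EnvStub().local_rank == 0

    def on_eval_begin(self, logs=None):
        if self._is_print():
            print('Eval begin...')

LoggerSketch(verbose=1).on_eval_begin()  # prints only on the rank-0 worker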
hapi/datasets/mnist.py
@@ -45,6 +45,8 @@ class MNIST(Dataset):
             :attr:`download` is True. Default None
         label_path(str): path to label file, can be set None if
             :attr:`download` is True. Default None
+        chw_format(bool): If set True, the output shape is [1, 28, 28],
+            otherwise, output shape is [1, 784]. Default True.
         mode(str): 'train' or 'test' mode. Default 'train'.
         download(bool): whether auto download mnist dataset if
             :attr:`image_path`/:attr:`label_path` unset. Default
@@ -70,13 +72,14 @@ class MNIST(Dataset):
     def __init__(self,
                  image_path=None,
                  label_path=None,
+                 chw_format=True,
                  mode='train',
                  transform=None,
                  download=True):
         assert mode.lower() in ['train', 'test'], \
             "mode should be 'train' or 'test', but got {}".format(mode)
         self.mode = mode.lower()
+        self.chw_format = chw_format
         self.image_path = image_path
         if self.image_path is None:
             assert download, "image_path not set and auto download disabled"
@@ -144,10 +147,13 @@ class MNIST(Dataset):
         for i in range(buffer_size):
             self.images.append(images[i, :])
-            self.labels.append(np.array([labels[i]]).astype('int64'))
+            self.labels.append(
+                np.array([labels[i]]).astype('int64'))
 
     def __getitem__(self, idx):
         image, label = self.images[idx], self.labels[idx]
+        if self.chw_format:
+            image = np.reshape(image, [1, 28, 28])
         if self.transform is not None:
             image = self.transform(image)
         return image, label
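What the new chw_format flag controls, as a hedged standalone sketch (synthetic data, not a real MNIST sample): True reshapes the stored vector to the [1, 28, 28] channel-first layout; False returns it flat, as documented above.

import numpy as np

flat = np.random.rand(1, 784).astype('float32')  # flat stored format

def to_sample(image, chw_format=True):
    # Mirrors the __getitem__ branch added in this diff.
    return np.reshape(image, [1, 28, 28]) if chw_format else image

print(to_sample(flat).shape)                    # (1, 28, 28)
print(to_sample(flat, chw_format=False).shape)  # (1, 784)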
hapi/loss.py
@@ -66,7 +66,7 @@ class CrossEntropy(Loss):
     """
 
     def __init__(self, average=True):
-        super(CrossEntropy, self).__init__()
+        super(CrossEntropy, self).__init__(average)
 
     def forward(self, outputs, labels):
         return [
@@ -88,7 +88,7 @@ class SoftmaxWithCrossEntropy(Loss):
     """
 
     def __init__(self, average=True):
-        super(SoftmaxWithCrossEntropy, self).__init__()
+        super(SoftmaxWithCrossEntropy, self).__init__(average)
 
     def forward(self, outputs, labels):
         return [
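The two hunks above fix the same bug: the subclasses accepted `average` but never forwarded it, so the base class always used its default. A minimal sketch with a stand-in base class (not the hapi.loss.Loss source):

class BaseLoss(object):           # stand-in for hapi.loss.Loss
    def __init__(self, average=True):
        self.average = average

class CrossEntropySketch(BaseLoss):
    def __init__(self, average=True):
        # Before the fix: super(...).__init__() silently left average=True.
        super(CrossEntropySketch, self).__init__(average)

print(CrossEntropySketch(average=False).average)  # False only after the fix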
hapi/model.py
@@ -639,7 +639,49 @@ class DynamicGraphAdapter(object):
 
 class Model(fluid.dygraph.Layer):
     """
-    FIXME: add more comments and usage
+    A Model object is a network with training and inference features.
+    Dynamic graph and static graph are supported at the same time,
+    switched by `fluid.enable_dygraph()`. The usage is as follows.
+    Note that the switch between dynamic and static graph should happen
+    before instantiating a Model. The input description, i.e. hapi.Input,
+    is required for static graph.
+
+    Usage:
+        .. code-block:: python
+
+        import numpy as np
+        import paddle
+        import paddle.fluid as fluid
+        from hapi import Model, Input, set_device
+        from hapi.loss import CrossEntropy
+        from hapi.metrics import Accuracy
+        from hapi.datasets import MNIST
+
+        class MyModel(Model):
+            def __init__(self):
+                super(MyModel, self).__init__()
+                self._fc = fluid.dygraph.Linear(784, 10, act='softmax')
+            def forward(self, x):
+                y = self._fc(x)
+                return y
+
+        device = set_device('gpu')
+        # if using static graph, do not set
+        fluid.enable_dygraph(device)
+
+        model = MyModel()
+        optim = fluid.optimizer.SGD(learning_rate=1e-3,
+            parameter_list=model.parameters())
+
+        inputs = [Input([None, 784], 'float32', name='x')]
+        labels = [Input([None, 1], 'int64', name='label')]
+
+        mnist_data = MNIST(mode='train')
+        model.prepare(optim,
+                      CrossEntropy(average=True),
+                      Accuracy(),
+                      inputs,
+                      labels,
+                      device=device)
+        model.fit(mnist_data, epochs=2, batch_size=32, verbose=1)
     """
 
     def __init__(self):
@@ -660,18 +702,195 @@ class Model(fluid.dygraph.Layer):
         else:
             self._adapter = StaticGraphAdapter(self)
 
-    def train_batch(self, *args, **kwargs):
-        return self._adapter.train_batch(*args, **kwargs)
+    def train_batch(self, inputs, labels=None):
+        """
+        Run one training step on a batch of data.
+
+        Args:
+            inputs (list): A list of numpy.ndarray, each is a batch of
+                input data.
+            labels (list): A list of numpy.ndarray, each is a batch of
+                input labels. If the model has no labels, set None.
+                Default is None.
+
+        Returns:
+            A list of scalar training loss if the model has no metrics,
+            or a tuple (list of scalar loss, list of metrics) if the model
+            set metrics.
+
+        Examples:
+
+            .. code-block:: python
+
+              import numpy as np
+              import paddle.fluid as fluid
+              from hapi import Model, Input, set_device
+              from hapi.loss import CrossEntropy
+
+              class MyModel(Model):
+                  def __init__(self):
+                      super(MyModel, self).__init__()
+                      self._fc = fluid.dygraph.Linear(784, 10, act='softmax')
+                  def forward(self, x):
+                      y = self._fc(x)
+                      return y
+
+              device = set_device('gpu')
+              fluid.enable_dygraph(device)
+
+              model = MyModel()
+              optim = fluid.optimizer.SGD(learning_rate=1e-3,
+                  parameter_list=model.parameters())
+
+              inputs = [Input([None, 784], 'float32', name='x')]
+              labels = [Input([None, 1], 'int64', name='label')]
+              model.prepare(optim,
+                            CrossEntropy(average=True),
+                            inputs=inputs,
+                            labels=labels,
+                            device=device)
+              data = np.random.random(size=(4, 784)).astype(np.float32)
+              label = np.random.randint(0, 10, size=(4, 1)).astype(np.int64)
+              loss = model.train_batch([data], [label])
+              print(loss)
+        """
+        return self._adapter.train_batch(inputs, labels)
 
-    def eval_batch(self, *args, **kwargs):
-        return self._adapter.eval_batch(*args, **kwargs)
+    def eval_batch(self, inputs, labels=None):
+        """
+        Run one evaluating step on a batch of data.
+
+        Args:
+            inputs (list): A list of numpy.ndarray, each is a batch of
+                input data.
+            labels (list): A list of numpy.ndarray, each is a batch of
+                input labels. If the model has no labels, set None.
+                Default is None.
+
+        Returns:
+            A list of scalar testing loss if the model has no metrics,
+            or a tuple (list of scalar loss, list of metrics) if the model
+            set metrics.
+
+        Examples:
+
+            .. code-block:: python
+
+              import numpy as np
+              import paddle.fluid as fluid
+              from hapi import Model, Input, set_device
+              from hapi.loss import CrossEntropy
+
+              class MyModel(Model):
+                  def __init__(self):
+                      super(MyModel, self).__init__()
+                      self._fc = fluid.dygraph.Linear(784, 10, act='softmax')
+                  def forward(self, x):
+                      y = self._fc(x)
+                      return y
+
+              device = set_device('gpu')
+              fluid.enable_dygraph(device)
+
+              model = MyModel()
+              optim = fluid.optimizer.SGD(learning_rate=1e-3,
+                  parameter_list=model.parameters())
+
+              inputs = [Input([None, 784], 'float32', name='x')]
+              labels = [Input([None, 1], 'int64', name='label')]
+              model.prepare(optim,
+                            CrossEntropy(average=True),
+                            inputs=inputs,
+                            labels=labels,
+                            device=device)
+              data = np.random.random(size=(4, 784)).astype(np.float32)
+              label = np.random.randint(0, 10, size=(4, 1)).astype(np.int64)
+              loss = model.eval_batch([data], [label])
+              print(loss)
+        """
+        return self._adapter.eval_batch(inputs, labels)
 
-    def test_batch(self, *args, **kwargs):
-        return self._adapter.test_batch(*args, **kwargs)
+    def test_batch(self, inputs):
+        """
+        Run one testing step on a batch of data.
+
+        Args:
+            inputs (list): A list of numpy.ndarray, each is a batch of
+                input data.
+
+        Returns:
+            A list of numpy.ndarray of predictions, that is the outputs
+            of Model forward.
+
+        Examples:
+
+            .. code-block:: python
+
+              import numpy as np
+              import paddle.fluid as fluid
+              from hapi import Model, Input, set_device
+
+              class MyModel(Model):
+                  def __init__(self):
+                      super(MyModel, self).__init__()
+                      self._fc = fluid.dygraph.Linear(784, 1, act='softmax')
+                  def forward(self, x):
+                      y = self._fc(x)
+                      return y
+
+              device = set_device('gpu')
+              fluid.enable_dygraph(device)
+
+              model = MyModel()
+              inputs = [Input([None, 784], 'float32', name='x')]
+              model.prepare(inputs=inputs,
+                            device=device)
+              data = np.random.random(size=(4, 784)).astype(np.float32)
+              out = model.test_batch([data])
+              print(out)
+        """
+        return self._adapter.test_batch(inputs)
 
-    def save(self, *args, **kwargs):
-        if ParallelEnv().local_rank == 0:
-            return self._adapter.save(*args, **kwargs)
+    def save(self, path):
+        """
+        This function saves parameters and optimizer information to path.
+
+        The parameters contain all the trainable Variables and will be saved
+        to a file with the suffix ".pdparams".
+        The optimizer information contains all the variables used by the
+        optimizer. For the Adam optimizer, it contains beta1, beta2, momentum,
+        etc. All the information will be saved to a file with the suffix
+        ".pdopt". (If the optimizer has no variables to save, like SGD, the
+        file will not be generated.)
+
+        This function will silently overwrite existing files
+        at the target location.
+
+        Args:
+            path (str): The file prefix to save the model. The format is
+                'dirname/file_prefix' or 'file_prefix'. If it is an empty
+                string, an exception will be raised.
+
+        Returns:
+            None
+
+        Examples:
+
+            .. code-block:: python
+
+              import paddle.fluid as fluid
+              from hapi import Model, set_device
+
+              class MyModel(Model):
+                  def __init__(self):
+                      super(MyModel, self).__init__()
+                      self._fc = fluid.dygraph.Linear(784, 1, act='softmax')
+                  def forward(self, x):
+                      y = self._fc(x)
+                      return y
+
+              device = set_device('cpu')
+              fluid.enable_dygraph(device)
+              model = MyModel()
+              model.save('checkpoint/test')
+        """
+        if ParallelEnv().local_rank == 0:
+            self._adapter.save(path)
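Note the guard kept at the end of the new save(): only the rank-0 worker writes the checkpoint, so a multi-process launch does not race on the same files. A hedged sketch of the pattern (rank stub assumed, not paddle's API):

import os

def local_rank():
    # Stand-in for ParallelEnv().local_rank.
    return int(os.environ.get("PADDLE_TRAINER_ID", "0"))

def save_checkpoint(path, write_fn):
    # Every worker may call this, but only rank 0 actually writes.
    if local_rank() == 0:
        write_fn(path)

save_checkpoint('checkpoint/test', lambda p: print('would write %s.pdparams' % p))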
@@ -698,6 +917,29 @@ class Model(fluid.dygraph.Layer):
             optimizer states and initialize optimizer states from scratch.
             Otherwise, restore optimizer states from `path.pdopt` if
             an optimizer has been set to the model. Default False.
+
+        Returns:
+            None
+
+        Examples:
+
+            .. code-block:: python
+
+              import paddle.fluid as fluid
+              from hapi import Model, set_device
+
+              class MyModel(Model):
+                  def __init__(self):
+                      super(MyModel, self).__init__()
+                      self._fc = fluid.dygraph.Linear(784, 1, act='softmax')
+                  def forward(self, x):
+                      y = self._fc(x)
+                      return y
+
+              device = set_device('cpu')
+              fluid.enable_dygraph(device)
+              model = MyModel()
+              model.load('checkpoint/test')
+        """
+
         def _load_state_from_path(path):
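Putting the two docstrings together: a save/load round trip under the documented file layout (".pdparams" for parameters, ".pdopt" for optimizer state when the optimizer has any). A hedged usage sketch built only from calls shown in this diff:

import paddle.fluid as fluid
from hapi import Model, set_device

class MyModel(Model):
    def __init__(self):
        super(MyModel, self).__init__()
        self._fc = fluid.dygraph.Linear(784, 10, act='softmax')

    def forward(self, x):
        return self._fc(x)

device = set_device('cpu')
fluid.enable_dygraph(device)
model = MyModel()
model.save('checkpoint/test')   # writes checkpoint/test.pdparams (and .pdopt if an optimizer is set)
model.load('checkpoint/test')   # restores from the same prefix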
@@ -747,7 +989,31 @@ class Model(fluid.dygraph.Layer):
 
         return self._adapter.load(matched_param_state, optim_state)
 
-    def parameters(self, *args, **kwargs):
-        return self._adapter.parameters(*args, **kwargs)
+    def parameters(self, *args, **kwargs):
+        """
+        Returns a list of parameters of the model.
+
+        Returns:
+            A list of Parameter in static graph.
+            A list of ParamBase in dynamic graph.
+
+        Examples:
+
+            .. code-block:: python
+
+              import paddle.fluid as fluid
+              from hapi.model import Model, Input, set_device
+
+              class MyModel(Model):
+                  def __init__(self):
+                      super(MyModel, self).__init__()
+                      self._fc = fluid.dygraph.Linear(20, 10, act='softmax')
+                  def forward(self, x):
+                      y = self._fc(x)
+                      return y
+
+              fluid.enable_dygraph()
+              model = MyModel()
+              params = model.parameters()
+        """
+        return self._adapter.parameters()
 
     def prepare(self,
                 optimizer=None,
@@ -757,27 +1023,32 @@ class Model(fluid.dygraph.Layer):
                 labels=None,
                 device=None):
         """
-        FIXME: add comments
+        Configures the model before running.
+
         Args:
-            optimizer (Optimizer|None): optimizer must be set in training
+            optimizer (Optimizer|None): Optimizer must be set in training
                 and should be an Optimizer instance. It can be None in eval
                 and test mode.
-            loss_function (Loss|None): loss function must be set in training
+            loss_function (Loss|None): Loss function must be set in training
                 and should be a Loss instance. It can be None when there is
                 no loss.
-            metrics (Metric|list of Metric|None): if metrics is set, all
-                metric will be calculate and output in train/eval mode.
-            inputs (Input|list|dict|None): inputs, entry points of network,
+            metrics (Metric|list of Metric|None): If metrics is set, all
+                metrics will be calculated and output in train/eval mode.
+            inputs (Input|list|dict|None): `inputs`, entry points of network,
                 could be an Input layer, or lists of Input layers,
                 or a dict (name: Input), or None. For static graph,
                 inputs must be set. For dynamic graph, it could be None.
-            labels (Input|list|None): labels, entry points of network,
+            labels (Input|list|None): `labels`, entry points of network,
                 could be an Input layer or lists of Input layers, or None.
-                For static graph, if set loss_function in Model.prepare(), it
-                must be set. Otherwise, it could be None.
-            device (str|None): specify device type, 'CPU' or 'GPU'.
+                For static graph, if labels is required in loss_function,
+                labels must be set. Otherwise, it could be None.
+            device (str|fluid.CUDAPlace|fluid.CPUPlace|None): Specify device
+                type, 'CPU', 'GPU', fluid.CUDAPlace or fluid.CPUPlace.
                 If None, automatically select device according to
                 installation package version.
+
+        Returns:
+            None
         """
 
         if isinstance(device, fluid.CUDAPlace) or \
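The widened device contract (string, fluid place object, or None for auto-detect) suggests a normalization step like the following sketch. This is an illustration of the documented behavior, not the code in this commit:

import paddle.fluid as fluid

def normalize_device(device=None):
    # Accept str / place / None, per the prepare() docstring above.
    if isinstance(device, (fluid.CUDAPlace, fluid.CPUPlace)):
        return device
    if isinstance(device, str):
        return fluid.CUDAPlace(0) if device.lower() == 'gpu' else fluid.CPUPlace()
    # None: fall back to whatever the installed package supports.
    return fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda() else fluid.CPUPlace()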
@@ -859,7 +1130,9 @@ class Model(fluid.dygraph.Layer):
                 num_workers=0,
                 callbacks=None, ):
         """
-        FIXME: add more comments and usage
+        Trains the model for a fixed number of epochs. If `eval_data` is set,
+        evaluation will be done at the end of each epoch.
+
         Args:
             train_data (Dataset|DataLoader): An iterable data loader is used for
                 train. An instance of paddle.io.Dataset or
@@ -868,30 +1141,117 @@ class Model(fluid.dygraph.Layer):
                 evaluation at the end of epoch. If None, will not do evaluation.
                 An instance of paddle.io.Dataset or paddle.io.Dataloader
                 is recommended. Default: None.
-            batch_size (int): Integer number. The batch size of train_data and eval_data.
-                When train_data and eval_data are both the instance of Dataloader, this
-                parameter will be ignored. Default: 1.
-            epochs (int): Integer number. The number of epochs to train the model. Default: 1.
+            batch_size (int): Integer number. The batch size of train_data
+                and eval_data. When train_data and eval_data are both the
+                instance of Dataloader, this parameter will be ignored.
+                Default: 1.
+            epochs (int): Integer number. The number of epochs to train
+                the model. Default: 1.
             eval_freq (int): The frequency, in number of epochs, an evaluation
                 is performed. Default: 1.
             log_freq (int): The frequency, in number of steps, the training logs
                 are printed. Default: 10.
             save_dir(str|None): The directory to save checkpoint during training.
                 If None, will not save checkpoint. Default: None.
-            save_freq (int): The frequency, in number of epochs, to save checkpoint. Default: 1.
-            verbose (int): The verbosity mode, should be 0, 1, or 2.
-                0 = silent, 1 = progress bar, 2 = one line per epoch. Default: 2.
-            drop_last (bool): whether drop the last incomplete batch of train_data
-                when dataset size is not divisible by the batch size. When train_data
-                is an instance of Dataloader, this parameter will be ignored. Default: False.
-            shuffle (bool): whether to shuffle train_data. When train_data is an instance
-                of Dataloader, this parameter will be ignored. Default: True.
-            num_workers (int): the number of subprocess to load data, 0 for no subprocess
-                used and loading data in main process. When train_data and eval_data are
-                both the instance of Dataloader, this parameter will be ignored. Default: 0.
+            save_freq (int): The frequency, in number of epochs, to save
+                checkpoint. Default: 1.
+            verbose (int): The verbosity mode, should be 0, 1, or 2. 0 = silent,
+                1 = progress bar, 2 = one line per epoch. Default: 2.
+            drop_last (bool): Whether drop the last incomplete batch of
+                train_data when dataset size is not divisible by the batch size.
+                When train_data is an instance of Dataloader, this parameter
+                will be ignored. Default: False.
+            shuffle (bool): Whether to shuffle train_data. When train_data is
+                an instance of Dataloader, this parameter will be ignored.
+                Default: True.
+            num_workers (int): The number of subprocess to load data, 0 for no
+                subprocess used and loading data in main process.
+                When train_data and eval_data are both the instance of
+                Dataloader, this parameter will be ignored. Default: 0.
             callbacks (Callback|None): A list of `Callback` instances to apply
                 during training. If None, `ProgBarLogger` and `ModelCheckpoint`
                 are automatically inserted. Default: None.
+
+        Returns:
+            None
+
+        Examples:
+            1. An example that uses Dataset and sets batch size and shuffle
+               in fit. How to make a batch is done internally.
+
+            .. code-block:: python
+
+              import paddle.fluid as fluid
+              from hapi.model import Model, Input, set_device
+              from hapi.loss import CrossEntropy
+              from hapi.metrics import Accuracy
+              from hapi.datasets import MNIST
+              from hapi.vision.models import LeNet
+
+              dynamic = True
+              device = set_device('gpu')
+              fluid.enable_dygraph(device) if dynamic else None
+
+              train_dataset = MNIST(mode='train')
+              val_dataset = MNIST(mode='test')
+
+              inputs = [Input([None, 1, 28, 28], 'float32', name='image')]
+              labels = [Input([None, 1], 'int64', name='label')]
+
+              model = LeNet()
+              optim = fluid.optimizer.Adam(
+                  learning_rate=0.001, parameter_list=model.parameters())
+              model.prepare(
+                  optim,
+                  CrossEntropy(),
+                  Accuracy(topk=(1, 2)),
+                  inputs=inputs,
+                  labels=labels,
+                  device=device)
+              model.fit(train_dataset,
+                        val_dataset,
+                        epochs=2,
+                        batch_size=64,
+                        save_dir='mnist_checkpoint')
+
+            2. An example that uses DataLoader; batch size and shuffle are
+               set in the DataLoader.
+
+            .. code-block:: python
+
+              import paddle.fluid as fluid
+              from hapi.model import Model, Input, set_device
+              from hapi.loss import CrossEntropy
+              from hapi.metrics import Accuracy
+              from hapi.datasets import MNIST
+              from hapi.vision.models import LeNet
+
+              dynamic = True
+              device = set_device('gpu')
+              fluid.enable_dygraph(device) if dynamic else None
+
+              train_dataset = MNIST(mode='train')
+              train_loader = fluid.io.DataLoader(train_dataset,
+                  places=device, batch_size=64)
+              val_dataset = MNIST(mode='test')
+              val_loader = fluid.io.DataLoader(val_dataset,
+                  places=device, batch_size=64)
+
+              inputs = [Input([None, 1, 28, 28], 'float32', name='image')]
+              labels = [Input([None, 1], 'int64', name='label')]
+
+              model = LeNet()
+              optim = fluid.optimizer.Adam(
+                  learning_rate=0.001, parameter_list=model.parameters())
+              model.prepare(
+                  optim,
+                  CrossEntropy(),
+                  Accuracy(topk=(1, 2)),
+                  inputs=inputs,
+                  labels=labels,
+                  device=device)
+              model.fit(train_loader,
+                        val_loader,
+                        epochs=2,
+                        save_dir='mnist_checkpoint')
         """
 
         assert train_data is not None, \
@@ -986,26 +1346,29 @@ class Model(fluid.dygraph.Layer):
                 num_workers=0,
                 callbacks=None, ):
         """
-        FIXME: add more comments and usage
+        Evaluate the loss and metrics of the model on input dataset.
+
         Args:
             eval_data (Dataset|DataLoader): An iterable data loader is used for
                 evaluation. An instance of paddle.io.Dataset or
                 paddle.io.Dataloader is recommended.
-            batch_size (int): Integer number. The batch size of train_data and eval_data.
-                When eval_data is the instance of Dataloader, this argument will be ignored.
-                Default: 1.
+            batch_size (int): Integer number. The batch size of train_data
+                and eval_data. When eval_data is the instance of Dataloader,
+                this argument will be ignored. Default: 1.
             log_freq (int): The frequency, in number of steps, the eval logs
                 are printed. Default: 10.
-            verbose (int): The verbosity mode, should be 0, 1, or 2.
-                0 = silent, 1 = progress bar, 2 = one line per epoch. Default: 2.
-            num_workers (int): The number of subprocess to load data, 0 for no subprocess
-                used and loading data in main process. When train_data and eval_data are
-                both the instance of Dataloader, this parameter will be ignored. Default: 0.
+            verbose (int): The verbosity mode, should be 0, 1, or 2. 0 = silent,
+                1 = progress bar, 2 = one line per epoch. Default: 2.
+            num_workers (int): The number of subprocess to load data,
+                0 for no subprocess used and loading data in main process. When
+                train_data and eval_data are both the instance of Dataloader,
+                this parameter will be ignored. Default: 0.
             callbacks (Callback|None): A list of `Callback` instances to apply
                 during training. If None, `ProgBarLogger` and `ModelCheckpoint`
                 are automatically inserted. Default: None.
+
         Returns:
-            dict: Result of metric.
+            dict: Result of metric. The key is the names of Metric,
+                value is a scalar or numpy.array.
         """
 
         if fluid.in_dygraph_mode():
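Per the new Returns section, evaluate() hands back a dict keyed by metric name. Continuing from the fit() example above (model and val_dataset as prepared there; the 'acc' key matches the Accuracy metric, which is also what the updated tests assert against):

result = model.evaluate(val_dataset, batch_size=64)
print(result['acc'])  # scalar or numpy.array, as documented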
@@ -1063,7 +1426,8 @@ class Model(fluid.dygraph.Layer):
                 num_workers=0,
                 stack_outputs=False):
         """
-        FIXME: add more comments and usage
+        Compute the output predictions on testing data.
+
         Args:
             test_data (Dataset|DataLoader): An iterable data loader is used for
                 predict. An instance of paddle.io.Dataset or paddle.io.Dataloader
@@ -1071,10 +1435,10 @@ class Model(fluid.dygraph.Layer):
             batch_size (int): Integer number. The batch size of train_data and eval_data.
                 When train_data and eval_data are both the instance of Dataloader, this
                 argument will be ignored. Default: 1.
-            num_workers (int): the number of subprocess to load data, 0 for no subprocess
+            num_workers (int): The number of subprocess to load data, 0 for no subprocess
                 used and loading data in main process. When train_data and eval_data are
                 both the instance of Dataloader, this argument will be ignored. Default: 0.
-            stack_output (bool): whether stack output field like a batch, as for an output
+            stack_output (bool): Whether to stack output field like a batch, as for an output
                 field of a sample is in shape [X, Y], test_data contains N samples, predict
                 output field will be in shape [N, X, Y] if stack_output is True, and will
                 be a length N list in shape [[X, Y], [X, Y], ..., [X, Y]] if stack_outputs
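The stack_output wording above, shown concretely: with per-sample outputs of shape [X, Y], stacking yields one [N, X, Y] array instead of an N-element list. A numpy-only sketch of the documented semantics:

import numpy as np

outputs = [np.ones((2, 3)) * i for i in range(4)]  # N=4 outputs of shape [X=2, Y=3]

as_list = list(outputs)                # stack_outputs=False: length-N list
stacked = np.stack(outputs, axis=0)    # stack_outputs=True: shape (4, 2, 3)
print(len(as_list), stacked.shape)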
@@ -1138,21 +1502,20 @@ class Model(fluid.dygraph.Layer):
                              save_dir,
                              model_filename=None,
                              params_filename=None,
-                             program_only=False):
+                             model_only=False):
         """
         Save inference model; must be called in static mode.
 
         Args:
             dirname(str): The directory path to save the inference model.
             model_filename(str|None): The name of the file to save the inference
-                program itself. If it is set None, a default filename
+                model itself. If it is set None, a default filename
                 :code:`__model__` will be used.
             params_filename(str|None): The name of the file to save all related
                 parameters. If it is set None, parameters will be saved
                 in separate files.
-            program_only(bool): If True, it will save the inference program only,
-                and do not save the parameters of the Program. Default: False.
+            model_only(bool): If True, it will save the inference model only,
+                and do not save parameters. Default: False.
 
         Returns:
             list: The fetch variables' name list
@@ -1177,7 +1540,7 @@ class Model(fluid.dygraph.Layer):
             main_program=infer_prog,
             model_filename=model_filename,
             params_filename=params_filename,
-            program_only=program_only)
+            program_only=model_only)
 
     def _run_one_epoch(self, data_loader,
hapi/tests/test_model.py
@@ -18,27 +18,25 @@ from __future__ import print_function
 import unittest
 
 import os
-import cv2
 import numpy as np
-import tempfile
 import shutil
+import tempfile
 
 import paddle
 from paddle import fluid
 from paddle.fluid.dygraph.nn import Conv2D, Pool2D, Linear
 from paddle.fluid.dygraph.container import Sequential
-from paddle.io import BatchSampler, DataLoader
+from paddle.io import DataLoader
 from paddle.fluid.dygraph.base import to_variable
 
 from hapi.model import Model, Input, set_device
-from hapi.loss import Loss
+from hapi.loss import CrossEntropy
 from hapi.metrics import Accuracy
 from hapi.datasets import MNIST
 from hapi.vision.models import LeNet
-from hapi.download import get_weights_path_from_url
 
 
 class LeNetDygraph(fluid.dygraph.Layer):
     def __init__(self, num_classes=10, classifier_activation='softmax'):
         super(LeNetDygraph, self).__init__()
         self.num_classes = num_classes
@@ -67,12 +65,16 @@ class LeNetDygraph(fluid.dygraph.Layer):
 
 class MnistDataset(MNIST):
-    def __init__(self, mode, return_label=True):
+    def __init__(self, mode, return_label=True, sample_num=None):
         super(MnistDataset, self).__init__(mode=mode)
         self.return_label = return_label
+        if sample_num:
+            self.images = self.images[:sample_num]
+            self.labels = self.labels[:sample_num]
 
     def __getitem__(self, idx):
-        img = np.reshape(self.images[idx], [1, 28, 28])
+        img, label = self.images[idx], self.labels[idx]
+        img = np.reshape(img, [1, 28, 28])
         if self.return_label:
             return img, np.array(self.labels[idx]).astype('int64')
         return img,
@@ -81,15 +83,14 @@ class MnistDataset(MNIST):
         return len(self.images)
 
 
-def get_predict_accuracy(pred, gt):
+def compute_acc(pred, label):
     pred = np.argmax(pred, -1)
-    gt = np.array(gt)
-    correct = pred[:, np.newaxis] == gt
+    label = np.array(label)
+    correct = pred[:, np.newaxis] == label
     return np.sum(correct) / correct.shape[0]
 
 
-def low_level_lenet_dygraph_train(model, dataloader):
+def dynamic_train(model, dataloader):
     optim = fluid.optimizer.Adam(
         learning_rate=0.001, parameter_list=model.parameters())
     model.train()
@@ -102,7 +103,7 @@ def low_level_lenet_dygraph_train(model, dataloader):
             model.clear_gradients()
 
 
-def low_level_dynamic_evaluate(model, dataloader):
+def dynamic_evaluate(model, dataloader):
     with fluid.dygraph.no_grad():
         model.eval()
         cnt = 0
@@ -115,56 +116,65 @@ def low_level_dynamic_evaluate(model, dataloader):
         return cnt / len(dataloader.dataset)
 
 
-class TestEvaluatePredict(unittest.TestCase):
-    def setUp(self):
-        self.device = set_device('gpu')
-        self.train_dataset = MnistDataset(mode='train')
-        self.val_dataset = MnistDataset(mode='test')
-        self.test_dataset = MnistDataset(mode='test', return_label=False)
-
-        fluid.enable_dygraph(self.device)
-        train_dataloader = fluid.io.DataLoader(
-            self.train_dataset, places=self.device, batch_size=64)
-        val_dataloader = fluid.io.DataLoader(
-            self.val_dataset, places=self.device, batch_size=64)
-
-        self.lenet_dygraph = LeNetDygraph()
-        low_level_lenet_dygraph_train(self.lenet_dygraph, train_dataloader)
-        self.acc1 = low_level_dynamic_evaluate(self.lenet_dygraph,
-                                               val_dataloader)
-
-        self.save_dir = tempfile.mkdtemp()
-        self.weight_path = os.path.join(self.save_dir, 'lenet')
-        fluid.dygraph.save_dygraph(self.lenet_dygraph.state_dict(),
-                                   self.weight_path)
-
-        fluid.disable_dygraph()
-
-    def tearDown(self):
-        shutil.rmtree(self.save_dir)
-
-    def evaluate(self, dynamic):
-        fluid.enable_dygraph(self.device) if dynamic else None
-        inputs = [Input([-1, 1, 28, 28], 'float32', name='image')]
-        labels = [Input([None, 1], 'int64', name='label')]
-
-        val_dataloader = fluid.io.DataLoader(
-            self.val_dataset,
-            places=self.device,
-            batch_size=64,
-            return_list=True)
-
-        model = LeNet()
-        model.load(self.weight_path)
-        model.prepare(metrics=Accuracy(), inputs=inputs, labels=labels)
-
-        result = model.evaluate(val_dataloader)
-        np.testing.assert_allclose(result['acc'], self.acc1)
-
-        if fluid.in_dygraph_mode():
-            fluid.disable_dygraph()
-
-    def predict(self, dynamic):
-        fluid.enable_dygraph(self.device) if dynamic else None
+class TestModel(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        cls.device = set_device('gpu')
+        fluid.enable_dygraph(cls.device)
+
+        sp_num = 1280
+        cls.train_dataset = MnistDataset(mode='train', sample_num=sp_num)
+        cls.val_dataset = MnistDataset(mode='test', sample_num=sp_num)
+        cls.test_dataset = MnistDataset(
+            mode='test', return_label=False, sample_num=sp_num)
+
+        cls.train_loader = fluid.io.DataLoader(
+            cls.train_dataset, places=cls.device, batch_size=64)
+        cls.val_loader = fluid.io.DataLoader(
+            cls.val_dataset, places=cls.device, batch_size=64)
+        cls.test_loader = fluid.io.DataLoader(
+            cls.test_dataset, places=cls.device, batch_size=64)
+
+        seed = 333
+        fluid.default_startup_program().random_seed = seed
+        fluid.default_main_program().random_seed = seed
+
+        dy_lenet = LeNetDygraph()
+        cls.init_param = dy_lenet.state_dict()
+        dynamic_train(dy_lenet, cls.train_loader)
+
+        cls.acc1 = dynamic_evaluate(dy_lenet, cls.val_loader)
+
+        cls.inputs = [Input([-1, 1, 28, 28], 'float32', name='image')]
+        cls.labels = [Input([None, 1], 'int64', name='label')]
+
+        cls.save_dir = tempfile.mkdtemp()
+        cls.weight_path = os.path.join(cls.save_dir, 'lenet')
+        fluid.dygraph.save_dygraph(dy_lenet.state_dict(), cls.weight_path)
+
+        fluid.disable_dygraph()
+
+    @classmethod
+    def tearDownClass(cls):
+        shutil.rmtree(cls.save_dir)
+
+    def test_fit_dygraph(self):
+        self.fit(True)
+
+    def test_fit_static(self):
+        self.fit(False)
+
+    def test_evaluate_dygraph(self):
+        self.evaluate(True)
+
+    def test_evaluate_static(self):
+        self.evaluate(False)
+
+    def test_predict_dygraph(self):
+        self.predict(True)
+
+    def test_predict_static(self):
+        self.predict(False)
@@ -186,26 +196,161 @@ class TestEvaluatePredict(unittest.TestCase):
-        output = model.predict(test_dataloader, stack_outputs=True)
-        np.testing.assert_equal(output[0].shape[0], len(self.test_dataset))
+    def fit(self, dynamic):
+        fluid.enable_dygraph(self.device) if dynamic else None
+        seed = 333
+        fluid.default_startup_program().random_seed = seed
+        fluid.default_main_program().random_seed = seed
 
-        acc = get_predict_accuracy(output[0], self.val_dataset.labels)
+        model = LeNet()
+        optim_new = fluid.optimizer.Adam(
+            learning_rate=0.001, parameter_list=model.parameters())
+        model.prepare(
+            optim_new,
+            loss_function=CrossEntropy(average=False),
+            metrics=Accuracy(),
+            inputs=self.inputs,
+            labels=self.labels)
+        model.fit(self.train_dataset, batch_size=64, shuffle=False)
 
-        np.testing.assert_allclose(acc, self.acc1)
+        result = model.evaluate(self.val_dataset, batch_size=64)
+        np.testing.assert_allclose(result['acc'], self.acc1)
+        fluid.disable_dygraph() if dynamic else None
 
-        if fluid.in_dygraph_mode():
-            fluid.disable_dygraph()
+    def evaluate(self, dynamic):
+        fluid.enable_dygraph(self.device) if dynamic else None
+        model = LeNet()
+        model.prepare(
+            metrics=Accuracy(), inputs=self.inputs, labels=self.labels)
+        model.load(self.weight_path)
+        result = model.evaluate(self.val_dataset, batch_size=64)
+        np.testing.assert_allclose(result['acc'], self.acc1)
+        fluid.disable_dygraph() if dynamic else None
 
-    def test_evaluate_dygraph(self):
-        self.evaluate(True)
+    def predict(self, dynamic):
+        fluid.enable_dygraph(self.device) if dynamic else None
+        model = LeNet()
+        model.prepare(inputs=self.inputs)
+        model.load(self.weight_path)
+        output = model.predict(
+            self.test_dataset, batch_size=64, stack_outputs=True)
+        np.testing.assert_equal(output[0].shape[0], len(self.test_dataset))
 
-    def test_evaluate_static(self):
-        self.evaluate(False)
+        acc = compute_acc(output[0], self.val_dataset.labels)
+        np.testing.assert_allclose(acc, self.acc1)
+        fluid.disable_dygraph() if dynamic else None
 
-    def test_predict_dygraph(self):
-        self.predict(True)
 
-    def test_predict_static(self):
-        self.predict(False)
+class MyModel(Model):
+    def __init__(self):
+        super(MyModel, self).__init__()
+        self._fc = Linear(20, 10, act='softmax')
+
+    def forward(self, x):
+        y = self._fc(x)
+        return y
+
+
+class TestModelFunction(unittest.TestCase):
+    def set_seed(self, seed=1024):
+        fluid.default_startup_program().random_seed = seed
+        fluid.default_main_program().random_seed = seed
+
+    def test_train_batch(self, dynamic=True):
+        dim = 20
+        data = np.random.random(size=(4, dim)).astype(np.float32)
+        label = np.random.randint(0, 10, size=(4, 1)).astype(np.int64)
+
+        def get_expect():
+            fluid.enable_dygraph(fluid.CPUPlace())
+            self.set_seed()
+            m = MyModel()
+            optim = fluid.optimizer.SGD(learning_rate=0.001,
+                                        parameter_list=m.parameters())
+            m.train()
+            output = m(to_variable(data))
+            l = to_variable(label)
+            loss = fluid.layers.cross_entropy(output, l)
+            avg_loss = fluid.layers.reduce_sum(loss)
+            avg_loss.backward()
+            optim.minimize(avg_loss)
+            m.clear_gradients()
+            fluid.disable_dygraph()
+            return avg_loss.numpy()
+
+        ref = get_expect()
+        for dynamic in [True, False]:
+            device = set_device('cpu')
+            fluid.enable_dygraph(device) if dynamic else None
+            self.set_seed()
+            model = MyModel()
+            optim2 = fluid.optimizer.SGD(learning_rate=0.001,
+                                         parameter_list=model.parameters())
+            inputs = [Input([None, dim], 'float32', name='x')]
+            labels = [Input([None, 1], 'int64', name='label')]
+            model.prepare(
+                optim2,
+                loss_function=CrossEntropy(average=False),
+                inputs=inputs,
+                labels=labels,
+                device=device)
+            loss, = model.train_batch([data], [label])
+            np.testing.assert_allclose(loss.flatten(), ref.flatten())
+            fluid.disable_dygraph() if dynamic else None
+
+    def test_test_batch(self, dynamic=True):
+        dim = 20
+        data = np.random.random(size=(4, dim)).astype(np.float32)
+
+        def get_expect():
+            fluid.enable_dygraph(fluid.CPUPlace())
+            self.set_seed()
+            m = MyModel()
+            m.eval()
+            output = m(to_variable(data))
+            fluid.disable_dygraph()
+            return output.numpy()
+
+        ref = get_expect()
+        for dynamic in [True, False]:
+            device = set_device('cpu')
+            fluid.enable_dygraph(device) if dynamic else None
+            self.set_seed()
+            model = MyModel()
+            inputs = [Input([None, dim], 'float32', name='x')]
+            model.prepare(inputs=inputs, device=device)
+            out, = model.test_batch([data])
+            np.testing.assert_allclose(out, ref)
+            fluid.disable_dygraph() if dynamic else None
+
+    def test_save_load(self):
+        path = tempfile.mkdtemp()
+        for dynamic in [True, False]:
+            device = set_device('cpu')
+            fluid.enable_dygraph(device) if dynamic else None
+            model = MyModel()
+            inputs = [Input([None, 20], 'float32', name='x')]
+            model.prepare(inputs=inputs)
+            model.save(path + '/test')
+            model.load(path + '/test')
+            shutil.rmtree(path)
+            fluid.disable_dygraph() if dynamic else None
+
+    def test_parameters(self):
+        for dynamic in [True, False]:
+            device = set_device('cpu')
+            fluid.enable_dygraph(device) if dynamic else None
+            model = MyModel()
+            inputs = [Input([None, 20], 'float32', name='x')]
+            model.prepare(inputs=inputs)
+            params = model.parameters()
+            self.assertTrue(params[0].shape[0] == 20)
+            self.assertTrue(params[0].shape[1] == 10)
+            fluid.disable_dygraph() if dynamic else None
 
 
 if __name__ == '__main__':
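The key trick in the test refactor: setUpClass now caps every dataset at sample_num=1280 so the whole suite trains on a small slice instead of full MNIST. A standalone sketch of the same slicing (list-backed stand-in for the MNIST arrays, not the test code itself):

class CappedDataset(object):
    # Mirrors MnistDataset(..., sample_num=N) from the test above.
    def __init__(self, images, labels, sample_num=None):
        if sample_num:
            images = images[:sample_num]
            labels = labels[:sample_num]
        self.images, self.labels = images, labels

    def __len__(self):
        return len(self.images)

ds = CappedDataset(list(range(60000)), list(range(60000)), sample_num=1280)
print(len(ds))  # 1280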