Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
hapi
提交
d390d968
H
hapi
项目概览
PaddlePaddle
/
hapi
通知
11
Star
2
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
4
列表
看板
标记
里程碑
合并请求
7
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
H
hapi
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
4
Issue
4
列表
看板
标记
里程碑
合并请求
7
合并请求
7
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
d390d968
编写于
4月 29, 2020
作者:
Q
qingqing01
提交者:
GitHub
4月 29, 2020
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #64 from qingqing01/ut_and_doc
Add unit testing and doc.
上级
581b4944
b295ffbb
变更
5
展开全部
隐藏空白更改
内联
并排
Showing
5 changed file
with
660 addition
and
147 deletion
+660
-147
hapi/callbacks.py
hapi/callbacks.py
+10
-11
hapi/datasets/mnist.py
hapi/datasets/mnist.py
+8
-2
hapi/loss.py
hapi/loss.py
+2
-2
hapi/model.py
hapi/model.py
+426
-63
hapi/tests/test_model.py
hapi/tests/test_model.py
+214
-69
未找到文件。
hapi/callbacks.py
浏览文件 @
d390d968
...
@@ -185,6 +185,9 @@ class ProgBarLogger(Callback):
...
@@ -185,6 +185,9 @@ class ProgBarLogger(Callback):
self
.
verbose
=
verbose
self
.
verbose
=
verbose
self
.
log_freq
=
log_freq
self
.
log_freq
=
log_freq
def
_is_print
(
self
):
return
self
.
verbose
and
ParallelEnv
().
local_rank
==
0
def
on_train_begin
(
self
,
logs
=
None
):
def
on_train_begin
(
self
,
logs
=
None
):
self
.
epochs
=
self
.
params
[
'epochs'
]
self
.
epochs
=
self
.
params
[
'epochs'
]
assert
self
.
epochs
assert
self
.
epochs
...
@@ -195,7 +198,7 @@ class ProgBarLogger(Callback):
...
@@ -195,7 +198,7 @@ class ProgBarLogger(Callback):
self
.
steps
=
self
.
params
[
'steps'
]
self
.
steps
=
self
.
params
[
'steps'
]
self
.
epoch
=
epoch
self
.
epoch
=
epoch
self
.
train_step
=
0
self
.
train_step
=
0
if
self
.
verbose
and
self
.
epochs
and
ParallelEnv
().
local_rank
==
0
:
if
self
.
epochs
and
self
.
_is_print
()
:
print
(
'Epoch %d/%d'
%
(
epoch
+
1
,
self
.
epochs
))
print
(
'Epoch %d/%d'
%
(
epoch
+
1
,
self
.
epochs
))
self
.
train_progbar
=
ProgressBar
(
num
=
self
.
steps
,
verbose
=
self
.
verbose
)
self
.
train_progbar
=
ProgressBar
(
num
=
self
.
steps
,
verbose
=
self
.
verbose
)
...
@@ -213,15 +216,13 @@ class ProgBarLogger(Callback):
...
@@ -213,15 +216,13 @@ class ProgBarLogger(Callback):
logs
=
logs
or
{}
logs
=
logs
or
{}
self
.
train_step
+=
1
self
.
train_step
+=
1
if
self
.
train_step
%
self
.
log_freq
==
0
and
self
.
verbose
and
ParallelEnv
(
if
self
.
_is_print
()
and
self
.
train_step
%
self
.
log_freq
==
0
:
).
local_rank
==
0
:
if
self
.
steps
is
None
or
self
.
train_step
<
self
.
steps
:
if
self
.
steps
is
None
or
self
.
train_step
<
self
.
steps
:
self
.
_updates
(
logs
,
'train'
)
self
.
_updates
(
logs
,
'train'
)
def
on_epoch_end
(
self
,
epoch
,
logs
=
None
):
def
on_epoch_end
(
self
,
epoch
,
logs
=
None
):
logs
=
logs
or
{}
logs
=
logs
or
{}
if
self
.
train_step
%
self
.
log_freq
!=
0
and
self
.
verbose
and
ParallelEnv
(
if
self
.
_is_print
()
and
(
self
.
steps
is
not
None
):
).
local_rank
==
0
:
self
.
_updates
(
logs
,
'train'
)
self
.
_updates
(
logs
,
'train'
)
def
on_eval_begin
(
self
,
logs
=
None
):
def
on_eval_begin
(
self
,
logs
=
None
):
...
@@ -231,7 +232,7 @@ class ProgBarLogger(Callback):
...
@@ -231,7 +232,7 @@ class ProgBarLogger(Callback):
self
.
evaled_samples
=
0
self
.
evaled_samples
=
0
self
.
eval_progbar
=
ProgressBar
(
self
.
eval_progbar
=
ProgressBar
(
num
=
self
.
eval_steps
,
verbose
=
self
.
verbose
)
num
=
self
.
eval_steps
,
verbose
=
self
.
verbose
)
if
ParallelEnv
().
local_rank
==
0
:
if
self
.
_is_print
()
:
print
(
'Eval begin...'
)
print
(
'Eval begin...'
)
def
on_eval_batch_end
(
self
,
step
,
logs
=
None
):
def
on_eval_batch_end
(
self
,
step
,
logs
=
None
):
...
@@ -240,16 +241,14 @@ class ProgBarLogger(Callback):
...
@@ -240,16 +241,14 @@ class ProgBarLogger(Callback):
samples
=
logs
.
get
(
'batch_size'
,
1
)
samples
=
logs
.
get
(
'batch_size'
,
1
)
self
.
evaled_samples
+=
samples
self
.
evaled_samples
+=
samples
if
self
.
eval_step
%
self
.
log_freq
==
0
and
self
.
verbose
and
ParallelEnv
(
if
self
.
_is_print
()
and
self
.
eval_step
%
self
.
log_freq
==
0
:
).
local_rank
==
0
:
if
self
.
eval_steps
is
None
or
self
.
eval_step
<
self
.
eval_steps
:
if
self
.
eval_steps
is
None
or
self
.
eval_step
<
self
.
eval_steps
:
self
.
_updates
(
logs
,
'eval'
)
self
.
_updates
(
logs
,
'eval'
)
def
on_eval_end
(
self
,
logs
=
None
):
def
on_eval_end
(
self
,
logs
=
None
):
logs
=
logs
or
{}
logs
=
logs
or
{}
if
self
.
verbose
and
ParallelEnv
().
local_rank
==
0
:
if
self
.
_is_print
()
and
(
self
.
steps
is
not
None
):
if
self
.
eval_step
%
self
.
log_freq
!=
0
:
self
.
_updates
(
logs
,
'eval'
)
self
.
_updates
(
logs
,
'eval'
)
print
(
'Eval samples: %d'
%
(
self
.
evaled_samples
))
print
(
'Eval samples: %d'
%
(
self
.
evaled_samples
))
...
...
hapi/datasets/mnist.py
浏览文件 @
d390d968
...
@@ -45,6 +45,8 @@ class MNIST(Dataset):
...
@@ -45,6 +45,8 @@ class MNIST(Dataset):
:attr:`download` is True. Default None
:attr:`download` is True. Default None
label_path(str): path to label file, can be set None if
label_path(str): path to label file, can be set None if
:attr:`download` is True. Default None
:attr:`download` is True. Default None
chw_format(bool): If set True, the output shape is [1, 28, 28],
otherwise, output shape is [1, 784]. Default True.
mode(str): 'train' or 'test' mode. Default 'train'.
mode(str): 'train' or 'test' mode. Default 'train'.
download(bool): whether auto download mnist dataset if
download(bool): whether auto download mnist dataset if
:attr:`image_path`/:attr:`label_path` unset. Default
:attr:`image_path`/:attr:`label_path` unset. Default
...
@@ -70,13 +72,14 @@ class MNIST(Dataset):
...
@@ -70,13 +72,14 @@ class MNIST(Dataset):
def
__init__
(
self
,
def
__init__
(
self
,
image_path
=
None
,
image_path
=
None
,
label_path
=
None
,
label_path
=
None
,
chw_format
=
True
,
mode
=
'train'
,
mode
=
'train'
,
transform
=
None
,
transform
=
None
,
download
=
True
):
download
=
True
):
assert
mode
.
lower
()
in
[
'train'
,
'test'
],
\
assert
mode
.
lower
()
in
[
'train'
,
'test'
],
\
"mode should be 'train' or 'test', but got {}"
.
format
(
mode
)
"mode should be 'train' or 'test', but got {}"
.
format
(
mode
)
self
.
mode
=
mode
.
lower
()
self
.
mode
=
mode
.
lower
()
self
.
chw_format
=
chw_format
self
.
image_path
=
image_path
self
.
image_path
=
image_path
if
self
.
image_path
is
None
:
if
self
.
image_path
is
None
:
assert
download
,
"image_path not set and auto download disabled"
assert
download
,
"image_path not set and auto download disabled"
...
@@ -144,10 +147,13 @@ class MNIST(Dataset):
...
@@ -144,10 +147,13 @@ class MNIST(Dataset):
for
i
in
range
(
buffer_size
):
for
i
in
range
(
buffer_size
):
self
.
images
.
append
(
images
[
i
,
:])
self
.
images
.
append
(
images
[
i
,
:])
self
.
labels
.
append
(
np
.
array
([
labels
[
i
]]).
astype
(
'int64'
))
self
.
labels
.
append
(
np
.
array
([
labels
[
i
]]).
astype
(
'int64'
))
def
__getitem__
(
self
,
idx
):
def
__getitem__
(
self
,
idx
):
image
,
label
=
self
.
images
[
idx
],
self
.
labels
[
idx
]
image
,
label
=
self
.
images
[
idx
],
self
.
labels
[
idx
]
if
self
.
chw_format
:
image
=
np
.
reshape
(
image
,
[
1
,
28
,
28
])
if
self
.
transform
is
not
None
:
if
self
.
transform
is
not
None
:
image
=
self
.
transform
(
image
)
image
=
self
.
transform
(
image
)
return
image
,
label
return
image
,
label
...
...
hapi/loss.py
浏览文件 @
d390d968
...
@@ -66,7 +66,7 @@ class CrossEntropy(Loss):
...
@@ -66,7 +66,7 @@ class CrossEntropy(Loss):
"""
"""
def
__init__
(
self
,
average
=
True
):
def
__init__
(
self
,
average
=
True
):
super
(
CrossEntropy
,
self
).
__init__
()
super
(
CrossEntropy
,
self
).
__init__
(
average
)
def
forward
(
self
,
outputs
,
labels
):
def
forward
(
self
,
outputs
,
labels
):
return
[
return
[
...
@@ -88,7 +88,7 @@ class SoftmaxWithCrossEntropy(Loss):
...
@@ -88,7 +88,7 @@ class SoftmaxWithCrossEntropy(Loss):
"""
"""
def
__init__
(
self
,
average
=
True
):
def
__init__
(
self
,
average
=
True
):
super
(
SoftmaxWithCrossEntropy
,
self
).
__init__
()
super
(
SoftmaxWithCrossEntropy
,
self
).
__init__
(
average
)
def
forward
(
self
,
outputs
,
labels
):
def
forward
(
self
,
outputs
,
labels
):
return
[
return
[
...
...
hapi/model.py
浏览文件 @
d390d968
此差异已折叠。
点击以展开。
hapi/tests/test_model.py
浏览文件 @
d390d968
...
@@ -18,27 +18,25 @@ from __future__ import print_function
...
@@ -18,27 +18,25 @@ from __future__ import print_function
import
unittest
import
unittest
import
os
import
os
import
cv2
import
numpy
as
np
import
numpy
as
np
import
tempfile
import
shutil
import
shutil
import
tempfile
import
paddle
import
paddle
from
paddle
import
fluid
from
paddle
import
fluid
from
paddle.fluid.dygraph.nn
import
Conv2D
,
Pool2D
,
Linear
from
paddle.fluid.dygraph.nn
import
Conv2D
,
Pool2D
,
Linear
from
paddle.fluid.dygraph.container
import
Sequential
from
paddle.fluid.dygraph.container
import
Sequential
from
paddle.io
import
BatchSampler
,
DataLoader
from
paddle.io
import
DataLoader
from
paddle.fluid.dygraph.base
import
to_variable
from
hapi.model
import
Model
,
Input
,
set_device
from
hapi.model
import
Model
,
Input
,
set_device
from
hapi.loss
import
Loss
from
hapi.loss
import
CrossEntropy
from
hapi.metrics
import
Accuracy
from
hapi.metrics
import
Accuracy
from
hapi.datasets
import
MNIST
from
hapi.datasets
import
MNIST
from
hapi.vision.models
import
LeNet
from
hapi.vision.models
import
LeNet
from
hapi.download
import
get_weights_path_from_url
class
LeNetDygraph
(
fluid
.
dygraph
.
Layer
):
class
LeNetDygraph
(
fluid
.
dygraph
.
Layer
):
def
__init__
(
self
,
num_classes
=
10
,
classifier_activation
=
'softmax'
):
def
__init__
(
self
,
num_classes
=
10
,
classifier_activation
=
'softmax'
):
super
(
LeNetDygraph
,
self
).
__init__
()
super
(
LeNetDygraph
,
self
).
__init__
()
self
.
num_classes
=
num_classes
self
.
num_classes
=
num_classes
...
@@ -67,12 +65,16 @@ class LeNetDygraph(fluid.dygraph.Layer):
...
@@ -67,12 +65,16 @@ class LeNetDygraph(fluid.dygraph.Layer):
class
MnistDataset
(
MNIST
):
class
MnistDataset
(
MNIST
):
def
__init__
(
self
,
mode
,
return_label
=
True
):
def
__init__
(
self
,
mode
,
return_label
=
True
,
sample_num
=
None
):
super
(
MnistDataset
,
self
).
__init__
(
mode
=
mode
)
super
(
MnistDataset
,
self
).
__init__
(
mode
=
mode
)
self
.
return_label
=
return_label
self
.
return_label
=
return_label
if
sample_num
:
self
.
images
=
self
.
images
[:
sample_num
]
self
.
labels
=
self
.
labels
[:
sample_num
]
def
__getitem__
(
self
,
idx
):
def
__getitem__
(
self
,
idx
):
img
=
np
.
reshape
(
self
.
images
[
idx
],
[
1
,
28
,
28
])
img
,
label
=
self
.
images
[
idx
],
self
.
labels
[
idx
]
img
=
np
.
reshape
(
img
,
[
1
,
28
,
28
])
if
self
.
return_label
:
if
self
.
return_label
:
return
img
,
np
.
array
(
self
.
labels
[
idx
]).
astype
(
'int64'
)
return
img
,
np
.
array
(
self
.
labels
[
idx
]).
astype
(
'int64'
)
return
img
,
return
img
,
...
@@ -81,15 +83,14 @@ class MnistDataset(MNIST):
...
@@ -81,15 +83,14 @@ class MnistDataset(MNIST):
return
len
(
self
.
images
)
return
len
(
self
.
images
)
def
get_predict_accuracy
(
pred
,
gt
):
def
compute_acc
(
pred
,
label
):
pred
=
np
.
argmax
(
pred
,
-
1
)
pred
=
np
.
argmax
(
pred
,
-
1
)
gt
=
np
.
array
(
gt
)
label
=
np
.
array
(
label
)
correct
=
pred
[:,
np
.
newaxis
]
==
label
correct
=
pred
[:,
np
.
newaxis
]
==
gt
return
np
.
sum
(
correct
)
/
correct
.
shape
[
0
]
return
np
.
sum
(
correct
)
/
correct
.
shape
[
0
]
def
low_level_lenet_dygraph
_train
(
model
,
dataloader
):
def
dynamic
_train
(
model
,
dataloader
):
optim
=
fluid
.
optimizer
.
Adam
(
optim
=
fluid
.
optimizer
.
Adam
(
learning_rate
=
0.001
,
parameter_list
=
model
.
parameters
())
learning_rate
=
0.001
,
parameter_list
=
model
.
parameters
())
model
.
train
()
model
.
train
()
...
@@ -102,7 +103,7 @@ def low_level_lenet_dygraph_train(model, dataloader):
...
@@ -102,7 +103,7 @@ def low_level_lenet_dygraph_train(model, dataloader):
model
.
clear_gradients
()
model
.
clear_gradients
()
def
low_level_
dynamic_evaluate
(
model
,
dataloader
):
def
dynamic_evaluate
(
model
,
dataloader
):
with
fluid
.
dygraph
.
no_grad
():
with
fluid
.
dygraph
.
no_grad
():
model
.
eval
()
model
.
eval
()
cnt
=
0
cnt
=
0
...
@@ -115,56 +116,65 @@ def low_level_dynamic_evaluate(model, dataloader):
...
@@ -115,56 +116,65 @@ def low_level_dynamic_evaluate(model, dataloader):
return
cnt
/
len
(
dataloader
.
dataset
)
return
cnt
/
len
(
dataloader
.
dataset
)
class
TestEvaluatePredict
(
unittest
.
TestCase
):
class
TestModel
(
unittest
.
TestCase
):
def
setUp
(
self
):
@
classmethod
self
.
device
=
set_device
(
'gpu'
)
def
setUpClass
(
cls
):
self
.
train_dataset
=
MnistDataset
(
mode
=
'train'
)
cls
.
device
=
set_device
(
'gpu'
)
self
.
val_dataset
=
MnistDataset
(
mode
=
'test'
)
fluid
.
enable_dygraph
(
cls
.
device
)
self
.
test_dataset
=
MnistDataset
(
mode
=
'test'
,
return_label
=
False
)
fluid
.
enable_dygraph
(
self
.
device
)
train_dataloader
=
fluid
.
io
.
DataLoader
(
self
.
train_dataset
,
places
=
self
.
device
,
batch_size
=
64
)
val_dataloader
=
fluid
.
io
.
DataLoader
(
self
.
val_dataset
,
places
=
self
.
device
,
batch_size
=
64
)
self
.
lenet_dygraph
=
LeNetDygraph
()
low_level_lenet_dygraph_train
(
self
.
lenet_dygraph
,
train_dataloader
)
self
.
acc1
=
low_level_dynamic_evaluate
(
self
.
lenet_dygraph
,
val_dataloader
)
self
.
save_dir
=
tempfile
.
mkdtemp
()
self
.
weight_path
=
os
.
path
.
join
(
self
.
save_dir
,
'lenet'
)
fluid
.
dygraph
.
save_dygraph
(
self
.
lenet_dygraph
.
state_dict
(),
self
.
weight_path
)
fluid
.
disable_dygraph
()
sp_num
=
1280
cls
.
train_dataset
=
MnistDataset
(
mode
=
'train'
,
sample_num
=
sp_num
)
cls
.
val_dataset
=
MnistDataset
(
mode
=
'test'
,
sample_num
=
sp_num
)
cls
.
test_dataset
=
MnistDataset
(
mode
=
'test'
,
return_label
=
False
,
sample_num
=
sp_num
)
def
tearDown
(
self
):
cls
.
train_loader
=
fluid
.
io
.
DataLoader
(
shutil
.
rmtree
(
self
.
save_dir
)
cls
.
train_dataset
,
places
=
cls
.
device
,
batch_size
=
64
)
cls
.
val_loader
=
fluid
.
io
.
DataLoader
(
cls
.
val_dataset
,
places
=
cls
.
device
,
batch_size
=
64
)
cls
.
test_loader
=
fluid
.
io
.
DataLoader
(
cls
.
test_dataset
,
places
=
cls
.
device
,
batch_size
=
64
)
def
evaluate
(
self
,
dynamic
):
seed
=
333
fluid
.
enable_dygraph
(
self
.
device
)
if
dynamic
else
None
fluid
.
default_startup_program
().
random_seed
=
seed
fluid
.
default_main_program
().
random_seed
=
seed
inputs
=
[
Input
([
-
1
,
1
,
28
,
28
],
'float32'
,
name
=
'image'
)]
dy_lenet
=
LeNetDygraph
()
labels
=
[
Input
([
None
,
1
],
'int64'
,
name
=
'label'
)]
cls
.
init_param
=
dy_lenet
.
state_dict
()
dynamic_train
(
dy_lenet
,
cls
.
train_loader
)
val_dataloader
=
fluid
.
io
.
DataLoader
(
cls
.
acc1
=
dynamic_evaluate
(
dy_lenet
,
cls
.
val_loader
)
self
.
val_dataset
,
places
=
self
.
device
,
batch_size
=
64
,
return_list
=
True
)
model
=
LeNet
()
cls
.
inputs
=
[
Input
([
-
1
,
1
,
28
,
28
],
'float32'
,
name
=
'image'
)]
cls
.
labels
=
[
Input
([
None
,
1
],
'int64'
,
name
=
'label'
)]
model
.
load
(
self
.
weight_path
)
cls
.
save_dir
=
tempfile
.
mkdtemp
()
cls
.
weight_path
=
os
.
path
.
join
(
cls
.
save_dir
,
'lenet'
)
fluid
.
dygraph
.
save_dygraph
(
dy_lenet
.
state_dict
(),
cls
.
weight_path
)
model
.
prepare
(
metrics
=
Accuracy
(),
inputs
=
inputs
,
labels
=
labels
)
fluid
.
disable_dygraph
(
)
result
=
model
.
evaluate
(
val_dataloader
)
@
classmethod
def
tearDownClass
(
cls
):
shutil
.
rmtree
(
cls
.
save_dir
)
np
.
testing
.
assert_allclose
(
result
[
'acc'
],
self
.
acc1
)
def
test_fit_dygraph
(
self
):
self
.
fit
(
True
)
if
fluid
.
in_dygraph_mode
():
def
test_fit_static
(
self
):
fluid
.
disable_dygraph
()
self
.
fit
(
False
)
def
test_evaluate_dygraph
(
self
):
self
.
evaluate
(
True
)
def
test_evaluate_static
(
self
):
self
.
evaluate
(
False
)
def
test_predict_dygraph
(
self
):
self
.
predict
(
True
)
def
test_predict_static
(
self
):
self
.
predict
(
False
)
def
predict
(
self
,
dynamic
):
def
predict
(
self
,
dynamic
):
fluid
.
enable_dygraph
(
self
.
device
)
if
dynamic
else
None
fluid
.
enable_dygraph
(
self
.
device
)
if
dynamic
else
None
...
@@ -186,26 +196,161 @@ class TestEvaluatePredict(unittest.TestCase):
...
@@ -186,26 +196,161 @@ class TestEvaluatePredict(unittest.TestCase):
output
=
model
.
predict
(
test_dataloader
,
stack_outputs
=
True
)
output
=
model
.
predict
(
test_dataloader
,
stack_outputs
=
True
)
np
.
testing
.
assert_equal
(
output
[
0
].
shape
[
0
],
len
(
self
.
test_dataset
))
def
fit
(
self
,
dynamic
):
fluid
.
enable_dygraph
(
self
.
device
)
if
dynamic
else
None
acc
=
get_predict_accuracy
(
output
[
0
],
self
.
val_dataset
.
labels
)
seed
=
333
fluid
.
default_startup_program
().
random_seed
=
seed
np
.
testing
.
assert_allclose
(
acc
,
self
.
acc1
)
fluid
.
default_main_program
().
random_seed
=
seed
if
fluid
.
in_dygraph_mode
():
fluid
.
disable_dygraph
()
def
test_evaluate_dygraph
(
self
):
model
=
LeNet
()
self
.
evaluate
(
True
)
optim_new
=
fluid
.
optimizer
.
Adam
(
learning_rate
=
0.001
,
parameter_list
=
model
.
parameters
())
model
.
prepare
(
optim_new
,
loss_function
=
CrossEntropy
(
average
=
False
),
metrics
=
Accuracy
(),
inputs
=
self
.
inputs
,
labels
=
self
.
labels
)
model
.
fit
(
self
.
train_dataset
,
batch_size
=
64
,
shuffle
=
False
)
result
=
model
.
evaluate
(
self
.
val_dataset
,
batch_size
=
64
)
np
.
testing
.
assert_allclose
(
result
[
'acc'
],
self
.
acc1
)
fluid
.
disable_dygraph
()
if
dynamic
else
None
def
test_evaluate_static
(
self
):
def
evaluate
(
self
,
dynamic
):
self
.
evaluate
(
False
)
fluid
.
enable_dygraph
(
self
.
device
)
if
dynamic
else
None
model
=
LeNet
()
model
.
prepare
(
metrics
=
Accuracy
(),
inputs
=
self
.
inputs
,
labels
=
self
.
labels
)
model
.
load
(
self
.
weight_path
)
result
=
model
.
evaluate
(
self
.
val_dataset
,
batch_size
=
64
)
np
.
testing
.
assert_allclose
(
result
[
'acc'
],
self
.
acc1
)
fluid
.
disable_dygraph
()
if
dynamic
else
None
def
test_predict_dygraph
(
self
):
def
predict
(
self
,
dynamic
):
self
.
predict
(
True
)
fluid
.
enable_dygraph
(
self
.
device
)
if
dynamic
else
None
model
=
LeNet
()
model
.
prepare
(
inputs
=
self
.
inputs
)
model
.
load
(
self
.
weight_path
)
output
=
model
.
predict
(
self
.
test_dataset
,
batch_size
=
64
,
stack_outputs
=
True
)
np
.
testing
.
assert_equal
(
output
[
0
].
shape
[
0
],
len
(
self
.
test_dataset
))
def
test_predict_static
(
self
):
acc
=
compute_acc
(
output
[
0
],
self
.
val_dataset
.
labels
)
self
.
predict
(
False
)
np
.
testing
.
assert_allclose
(
acc
,
self
.
acc1
)
fluid
.
disable_dygraph
()
if
dynamic
else
None
class
MyModel
(
Model
):
def
__init__
(
self
):
super
(
MyModel
,
self
).
__init__
()
self
.
_fc
=
Linear
(
20
,
10
,
act
=
'softmax'
)
def
forward
(
self
,
x
):
y
=
self
.
_fc
(
x
)
return
y
class
TestModelFunction
(
unittest
.
TestCase
):
def
set_seed
(
self
,
seed
=
1024
):
fluid
.
default_startup_program
().
random_seed
=
seed
fluid
.
default_main_program
().
random_seed
=
seed
def
test_train_batch
(
self
,
dynamic
=
True
):
dim
=
20
data
=
np
.
random
.
random
(
size
=
(
4
,
dim
)).
astype
(
np
.
float32
)
label
=
np
.
random
.
randint
(
0
,
10
,
size
=
(
4
,
1
)).
astype
(
np
.
int64
)
def
get_expect
():
fluid
.
enable_dygraph
(
fluid
.
CPUPlace
())
self
.
set_seed
()
m
=
MyModel
()
optim
=
fluid
.
optimizer
.
SGD
(
learning_rate
=
0.001
,
parameter_list
=
m
.
parameters
())
m
.
train
()
output
=
m
(
to_variable
(
data
))
l
=
to_variable
(
label
)
loss
=
fluid
.
layers
.
cross_entropy
(
output
,
l
)
avg_loss
=
fluid
.
layers
.
reduce_sum
(
loss
)
avg_loss
.
backward
()
optim
.
minimize
(
avg_loss
)
m
.
clear_gradients
()
fluid
.
disable_dygraph
()
return
avg_loss
.
numpy
()
ref
=
get_expect
()
for
dynamic
in
[
True
,
False
]:
device
=
set_device
(
'cpu'
)
fluid
.
enable_dygraph
(
device
)
if
dynamic
else
None
self
.
set_seed
()
model
=
MyModel
()
optim2
=
fluid
.
optimizer
.
SGD
(
learning_rate
=
0.001
,
parameter_list
=
model
.
parameters
())
inputs
=
[
Input
([
None
,
dim
],
'float32'
,
name
=
'x'
)]
labels
=
[
Input
([
None
,
1
],
'int64'
,
name
=
'label'
)]
model
.
prepare
(
optim2
,
loss_function
=
CrossEntropy
(
average
=
False
),
inputs
=
inputs
,
labels
=
labels
,
device
=
device
)
loss
,
=
model
.
train_batch
([
data
],
[
label
])
np
.
testing
.
assert_allclose
(
loss
.
flatten
(),
ref
.
flatten
())
fluid
.
disable_dygraph
()
if
dynamic
else
None
def
test_test_batch
(
self
,
dynamic
=
True
):
dim
=
20
data
=
np
.
random
.
random
(
size
=
(
4
,
dim
)).
astype
(
np
.
float32
)
def
get_expect
():
fluid
.
enable_dygraph
(
fluid
.
CPUPlace
())
self
.
set_seed
()
m
=
MyModel
()
m
.
eval
()
output
=
m
(
to_variable
(
data
))
fluid
.
disable_dygraph
()
return
output
.
numpy
()
ref
=
get_expect
()
for
dynamic
in
[
True
,
False
]:
device
=
set_device
(
'cpu'
)
fluid
.
enable_dygraph
(
device
)
if
dynamic
else
None
self
.
set_seed
()
model
=
MyModel
()
inputs
=
[
Input
([
None
,
dim
],
'float32'
,
name
=
'x'
)]
model
.
prepare
(
inputs
=
inputs
,
device
=
device
)
out
,
=
model
.
test_batch
([
data
])
np
.
testing
.
assert_allclose
(
out
,
ref
)
fluid
.
disable_dygraph
()
if
dynamic
else
None
def
test_save_load
(
self
):
path
=
tempfile
.
mkdtemp
()
for
dynamic
in
[
True
,
False
]:
device
=
set_device
(
'cpu'
)
fluid
.
enable_dygraph
(
device
)
if
dynamic
else
None
model
=
MyModel
()
inputs
=
[
Input
([
None
,
20
],
'float32'
,
name
=
'x'
)]
model
.
prepare
(
inputs
=
inputs
)
model
.
save
(
path
+
'/test'
)
model
.
load
(
path
+
'/test'
)
shutil
.
rmtree
(
path
)
fluid
.
disable_dygraph
()
if
dynamic
else
None
def
test_parameters
(
self
):
for
dynamic
in
[
True
,
False
]:
device
=
set_device
(
'cpu'
)
fluid
.
enable_dygraph
(
device
)
if
dynamic
else
None
model
=
MyModel
()
inputs
=
[
Input
([
None
,
20
],
'float32'
,
name
=
'x'
)]
model
.
prepare
(
inputs
=
inputs
)
params
=
model
.
parameters
()
self
.
assertTrue
(
params
[
0
].
shape
[
0
]
==
20
)
self
.
assertTrue
(
params
[
0
].
shape
[
1
]
==
10
)
fluid
.
disable_dygraph
()
if
dynamic
else
None
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录