Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
hapi
提交
6c770910
H
hapi
项目概览
PaddlePaddle
/
hapi
通知
11
Star
2
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
4
列表
看板
标记
里程碑
合并请求
7
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
H
hapi
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
4
Issue
4
列表
看板
标记
里程碑
合并请求
7
合并请求
7
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
6c770910
编写于
4月 27, 2020
作者:
L
LielinJiang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add unittest for evaluate and predict
上级
a7742a72
变更
4
显示空白变更内容
内联
并排
Showing 4 changed files with 291 additions and 23 deletions
+291
-23
hapi/model.py
hapi/model.py
+0
-1
hapi/test/dist_mnist.py
hapi/test/dist_mnist.py
+2
-2
hapi/test/test_distributed.py
hapi/test/test_distributed.py
+60
-20
hapi/test/test_evaluate_predict.py
hapi/test/test_evaluate_predict.py
+229
-0
未找到文件。
hapi/model.py
浏览文件 @
6c770910
...
...
@@ -41,7 +41,6 @@ from hapi.callbacks import config_callbacks
__all__
=
[
'Model'
,
'Loss'
,
'Input'
,
'set_device'
,
]
...
...
hapi/test/dist_mnist.py
浏览文件 @
6c770910
...
...
@@ -27,8 +27,8 @@ from paddle import fluid
from
paddle.fluid.dygraph.nn
import
Conv2D
,
Pool2D
,
Linear
from
paddle.io
import
BatchSampler
,
DataLoader
from
hapi.model
import
Model
,
Input
,
Loss
,
set_device
from
hapi.loss
import
CrossEntropy
from
hapi.model
import
Model
,
Input
,
set_device
from
hapi.loss
import
Loss
,
CrossEntropy
from
hapi.metrics
import
Accuracy
from
hapi.callbacks
import
ProgBarLogger
from
hapi.datasets
import
MNIST
as
MnistDataset
...
...
hapi/test/test_distributed.py
浏览文件 @
6c770910
...
...
@@ -36,9 +36,6 @@ def get_cluster_from_args(selected_gpus):
node_rank
=
node_ips
.
index
(
node_ip
)
logger
.
debug
(
"parsed from args:node_ips:{} node_ip:{} node_rank:{}"
.
format
(
node_ips
,
node_ip
,
node_rank
))
free_ports
=
None
if
not
use_paddlecloud
and
len
(
node_ips
)
<=
1
and
started_port
is
None
:
free_ports
=
find_free_ports
(
len
(
selected_gpus
))
...
...
@@ -54,10 +51,6 @@ def get_cluster_from_args(selected_gpus):
def
get_gpus
(
selected_gpus
):
if
selected_gpus
is
None
:
gpus_num
=
fluid
.
core
.
get_cuda_device_count
()
selected_gpus
=
[
str
(
x
)
for
x
in
range
(
0
,
gpus_num
)]
else
:
cuda_visible_devices
=
os
.
getenv
(
"CUDA_VISIBLE_DEVICES"
)
if
cuda_visible_devices
is
None
or
cuda_visible_devices
==
""
:
selected_gpus
=
[
x
.
strip
()
for
x
in
selected_gpus
.
split
(
','
)]
...
...
@@ -74,6 +67,53 @@ def get_gpus(selected_gpus):
return
selected_gpus
def start_local_trainers(cluster,
                         pod,
                         training_script,
                         training_script_args,
                         log_dir=None):
    """Start one coverage-instrumented subprocess per trainer in ``pod``.

    Every trainer is launched as
    ``python -m coverage run --branch -p <training_script>`` with the
    distributed-training environment variables derived from ``cluster`` and
    the trainer itself.

    Args:
        cluster: object providing ``trainers_nranks()`` and
            ``trainers_endpoints()``.
        pod: object whose ``trainers`` attribute is iterated; each trainer
            must expose ``gpus``, ``rank`` and ``endpoint``.
        training_script (str): path of the script to run. It is passed to the
            subprocess as a single argument, so a path containing spaces is
            handled correctly.
        training_script_args: accepted for interface compatibility; it is
            not forwarded to the subprocess.
        log_dir: accepted for interface compatibility; no log file is
            opened, so ``log_fn`` of every returned record is ``None``.

    Returns:
        list: one ``TrainerProc`` per trainer, carrying ``proc``, ``rank``,
        ``log_fn`` and ``cmd``.
    """
    current_env = os.environ.copy()
    # Paddle broadcasts the ncclUniqueId over a socket, and a proxy may make
    # the trainers unreachable. Setting the variables to "" makes grpc log a
    # "bad uri" error, so they are removed instead.
    current_env.pop("http_proxy", None)
    current_env.pop("https_proxy", None)

    procs = []
    for t in pod.trainers:
        proc_env = {
            "FLAGS_selected_gpus": "%s" % ",".join([str(g) for g in t.gpus]),
            "PADDLE_TRAINER_ID": "%d" % t.rank,
            "PADDLE_CURRENT_ENDPOINT": "%s" % t.endpoint,
            "PADDLE_TRAINERS_NUM": "%d" % cluster.trainers_nranks(),
            "PADDLE_TRAINER_ENDPOINTS": ",".join(cluster.trainers_endpoints())
        }

        # The same dict is reused for every trainer; each iteration
        # overwrites the per-trainer keys set by the previous one.
        current_env.update(proc_env)

        print("trainer proc env:{}".format(current_env))

        # Human-readable form, kept for logging and for ``TrainerProc.cmd``.
        cmd = "python -m coverage run --branch -p " + training_script
        print("start trainer proc:{} env:{}".format(cmd, proc_env))

        # Build argv explicitly instead of splitting ``cmd`` on spaces, so
        # the script path always stays one argument.
        argv = [
            "python", "-m", "coverage", "run", "--branch", "-p",
            training_script
        ]
        proc = subprocess.Popen(argv, env=current_env)

        tp = TrainerProc()
        tp.proc = proc
        tp.rank = t.rank
        tp.log_fn = None
        tp.cmd = cmd

        procs.append(tp)

    return procs
class
TestMultipleGpus
(
unittest
.
TestCase
):
def
test_mnist_2gpu
(
self
):
if
fluid
.
core
.
get_cuda_device_count
()
==
0
:
...
...
@@ -95,7 +135,7 @@ class TestMultipleGpus(unittest.TestCase):
alive
=
watch_local_trainers
(
procs
,
cluster
.
trainers_nranks
())
if
not
alive
:
logger
.
info
(
"Local procs complete, POD info:{}"
.
format
(
pod
))
print
(
"Local procs complete, POD info:{}"
.
format
(
pod
))
break
time
.
sleep
(
3
)
...
...
hapi/test/test_evaluate_predict.py
0 → 100644
浏览文件 @
6c770910
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
division
from
__future__
import
print_function
import
unittest
import
os
import
cv2
import
numpy
as
np
import
paddle
from
paddle
import
fluid
from
paddle.fluid.dygraph.nn
import
Conv2D
,
Pool2D
,
Linear
from
paddle.fluid.dygraph.container
import
Sequential
from
paddle.io
import
BatchSampler
,
DataLoader
from
hapi.model
import
Model
,
Input
,
set_device
from
hapi.loss
import
Loss
from
hapi.metrics
import
Accuracy
from
hapi.datasets
import
MNIST
from
hapi.vision.models
import
LeNet
from
hapi.download
import
get_weights_path_from_url
class LeNetDygraph(fluid.dygraph.Layer):
    """LeNet model from
    `"LeCun Y, Bottou L, Bengio Y, et al. Gradient-based learning applied to
    document recognition[J]. Proceedings of the IEEE, 1998, 86(11):
    2278-2324."`_

    Args:
        num_classes (int): output dim of last fc layer. If num_classes <= 0,
            the fc head is not defined and ``forward`` returns the
            convolutional features. Default: 10.
        classifier_activation (str): activation for the last fc layer.
            Default: 'softmax'.
    """

    def __init__(self, num_classes=10, classifier_activation='softmax'):
        super(LeNetDygraph, self).__init__()
        self.num_classes = num_classes
        self.features = Sequential(
            Conv2D(
                1, 6, 3, stride=1, padding=1),
            Pool2D(2, 'max', 2),
            Conv2D(
                6, 16, 5, stride=1, padding=0),
            Pool2D(2, 'max', 2))

        if num_classes > 0:
            # NOTE(review): the 400 input features presumably correspond to
            # 16 channels x 5 x 5 for a 1x28x28 input -- confirm.
            self.fc = Sequential(
                Linear(400, 120),
                Linear(120, 84),
                # Honour ``num_classes`` instead of a hard-coded 10 so the
                # head width follows the constructor argument.
                Linear(
                    84, num_classes, act=classifier_activation))

    def forward(self, inputs):
        """Run the feature extractor and, when defined, the fc head."""
        x = self.features(inputs)

        if self.num_classes > 0:
            x = fluid.layers.flatten(x, 1)
            x = self.fc(x)
        return x
class MnistDataset(MNIST):
    """MNIST dataset yielding ``[1, 28, 28]`` images, optionally with labels.

    Args:
        mode: forwarded to the ``MNIST`` base class as ``mode``.
        return_label (bool): when True each sample is ``(image, label)``;
            otherwise each sample is the one-element tuple ``(image,)``.
            Default: True.
    """

    def __init__(self, mode, return_label=True):
        super(MnistDataset, self).__init__(mode=mode)
        self.return_label = return_label

    def __getitem__(self, idx):
        image = np.reshape(self.images[idx], [1, 28, 28])
        if not self.return_label:
            # Keep the sample a tuple even without a label.
            return (image, )
        label = np.array(self.labels[idx]).astype('int64')
        return image, label

    def __len__(self):
        return len(self.images)
def get_predict_accuracy(pred, gt):
    """Return the fraction of samples whose arg-max prediction equals the label.

    Args:
        pred: array-like of per-class scores; the predicted class is the
            arg-max over the last axis.
        gt: array-like of integer labels, either flat ``(N,)`` or as a
            column ``(N, 1)``.

    Returns:
        Number of correct predictions divided by the number of samples.
    """
    pred = np.argmax(pred, -1)
    # Normalize the labels to a column. A flat (N,) label array would
    # otherwise broadcast against the (N, 1) predictions into an N x N
    # comparison and inflate the count of matches.
    gt = np.asarray(gt).reshape(-1, 1)

    correct = pred[:, np.newaxis] == gt

    return np.sum(correct) / correct.shape[0]
def low_level_lenet_dygraph_train(model, dataloader):
    """Train ``model`` for one pass over ``dataloader`` with Adam.

    The per-sample cross-entropy is reduced with a sum (not a mean) before
    back-propagation, and gradients are cleared after every step.

    Args:
        model: dygraph layer providing ``parameters()``, ``train()`` and
            ``clear_gradients()``.
        dataloader: iterable yielding ``(inputs, labels)`` pairs.
    """
    adam = fluid.optimizer.Adam(
        learning_rate=0.001, parameter_list=model.parameters())
    model.train()
    for batch in dataloader:
        images, targets = batch
        predictions = model(images)
        per_sample_loss = fluid.layers.cross_entropy(predictions, targets)
        total_loss = fluid.layers.reduce_sum(per_sample_loss)
        total_loss.backward()
        adam.minimize(total_loss)
        model.clear_gradients()
def low_level_dynamic_evaluate(model, dataloader):
    """Compute top-1 accuracy of ``model`` over ``dataloader`` without gradients.

    Args:
        model: dygraph layer; it is switched to eval mode.
        dataloader: iterable yielding ``(inputs, labels)`` pairs and exposing
            the underlying ``dataset`` (its length is the denominator).

    Returns:
        Number of correct predictions divided by ``len(dataloader.dataset)``.
    """
    with fluid.dygraph.no_grad():
        model.eval()
        num_correct = 0
        for images, targets in dataloader:
            scores = model(images)
            # Column of predicted classes, compared against the label batch.
            predicted = np.argmax(scores.numpy(), -1)[:, np.newaxis]
            hits = predicted == targets.numpy()
            num_correct += hits.astype('int').sum()

    return num_correct / len(dataloader.dataset)
class TestEvaluatePredict(unittest.TestCase):
    """Check high-level ``Model.evaluate`` / ``Model.predict`` against a
    reference accuracy obtained with a hand-written dygraph train/eval loop.
    """

    def setUp(self):
        """Train a reference LeNet in dygraph mode and record its accuracy."""
        self.device = set_device('gpu')

        self.train_dataset = MnistDataset(mode='train')
        self.val_dataset = MnistDataset(mode='test')
        self.test_dataset = MnistDataset(mode='test', return_label=False)

        # NOTE(review): dygraph mode is enabled here and nothing in this
        # class disables it again, so the "static" test variants may still
        # run in dygraph mode -- confirm.
        fluid.enable_dygraph(self.device)
        reference_train_loader = fluid.io.DataLoader(
            self.train_dataset, places=self.device, batch_size=64)
        reference_val_loader = fluid.io.DataLoader(
            self.val_dataset, places=self.device, batch_size=64)

        self.lenet_dygraph = LeNetDygraph()
        low_level_lenet_dygraph_train(self.lenet_dygraph,
                                      reference_train_loader)
        self.acc1 = low_level_dynamic_evaluate(self.lenet_dygraph,
                                               reference_val_loader)

    def _build_model(self, dynamic):
        """Create the data loaders and a prepared ``LeNet`` sharing the
        reference weights; shared by ``evaluate`` and ``predict``.
        """
        if dynamic:
            fluid.enable_dygraph(self.device)

        inputs = [Input([-1, 1, 28, 28], 'float32', name='image')]
        labels = [Input([None, 1], 'int64', name='label')]

        # Static graph mode needs explicit feed variables; dygraph does not.
        feed_list = None
        if not fluid.in_dygraph_mode():
            # NOTE(review): this feed list (image + label) is also given to
            # the label-free test loader below -- confirm that is intended.
            feed_list = [spec.forward() for spec in inputs + labels]

        loader_sources = (
            ('train_dataloader', self.train_dataset),
            ('val_dataloader', self.val_dataset),
            ('test_dataloader', self.test_dataset), )
        for attr_name, dataset in loader_sources:
            loader = fluid.io.DataLoader(
                dataset,
                places=self.device,
                batch_size=64,
                feed_list=feed_list)
            setattr(self, attr_name, loader)

        model = LeNet()
        model.load_dict(self.lenet_dygraph.state_dict())
        model.prepare(metrics=Accuracy(), inputs=inputs, labels=labels)
        return model

    def evaluate(self, dynamic):
        """Assert ``Model.evaluate`` reproduces the reference accuracy."""
        model = self._build_model(dynamic)

        result = model.evaluate(self.val_dataloader)

        np.testing.assert_allclose(result['acc'], self.acc1)

    def predict(self, dynamic):
        """Assert ``Model.predict`` covers every sample and reproduces the
        reference accuracy.
        """
        model = self._build_model(dynamic)

        output = model.predict(self.test_dataloader, stack_outputs=True)
        scores = output[0]

        np.testing.assert_equal(scores.shape[0], len(self.test_dataset))

        # The label-free test set and the validation set are both built from
        # mode='test', so validation labels serve as ground truth.
        acc = get_predict_accuracy(scores, self.val_dataset.labels)

        np.testing.assert_allclose(acc, self.acc1)

    def test_evaluate_dygraph(self):
        self.evaluate(True)

    def test_evaluate_static(self):
        self.evaluate(False)

    def test_predict_dygraph(self):
        self.predict(True)

    def test_predict_static(self):
        self.predict(False)
# Run the test cases in this module when it is executed as a script.
if __name__ == '__main__':
    unittest.main()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录