Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
a71ea009
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
a71ea009
编写于
12月 01, 2020
作者:
Y
yukavio
提交者:
GitHub
12月 01, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add unit test (#29228)
上级
46b73e6c
变更
5
显示空白变更内容
内联
并排
Showing
5 changed file
with
34 addition
and
12 deletion
+34
-12
python/paddle/__init__.py
python/paddle/__init__.py
+1
-0
python/paddle/hapi/__init__.py
python/paddle/hapi/__init__.py
+2
-0
python/paddle/hapi/dynamic_flops.py
python/paddle/hapi/dynamic_flops.py
+8
-2
python/paddle/hapi/static_flops.py
python/paddle/hapi/static_flops.py
+3
-10
python/paddle/tests/test_model.py
python/paddle/tests/test_model.py
+20
-0
未找到文件。
python/paddle/__init__.py
浏览文件 @
a71ea009
...
...
@@ -273,6 +273,7 @@ from . import onnx
from
.hapi
import
Model
from
.hapi
import
callbacks
from
.hapi
import
summary
from
.hapi
import
flops
import
paddle.text
import
paddle.vision
...
...
python/paddle/hapi/__init__.py
浏览文件 @
a71ea009
...
...
@@ -19,7 +19,9 @@ from . import model_summary
from
.
import
model
from
.model
import
*
from
.model_summary
import
summary
from
.dynamic_flops
import
flops
logger
.
setup_logger
()
__all__
=
[
'callbacks'
]
+
model
.
__all__
+
[
'summary'
]
__all__
=
model
.
__all__
+
[
'flops'
]
python/paddle/hapi/dynamic_flops.py
浏览文件 @
a71ea009
...
...
@@ -16,7 +16,7 @@ import paddle
import
warnings
import
paddle.nn
as
nn
import
numpy
as
np
from
.static_flops
import
static_flops
,
_verify_dependent_package
from
.static_flops
import
static_flops
__all__
=
[
'flops'
]
...
...
@@ -264,7 +264,13 @@ def dynamic_flops(model, inputs, custom_ops=None, print_detail=False):
model
.
train
()
for
handler
in
handler_collection
:
handler
.
remove
()
_verify_dependent_package
()
try
:
from
prettytable
import
PrettyTable
except
ImportError
:
raise
ImportError
(
"paddle.flops() requires package `prettytable`, place install it firstly using `pip install prettytable`. "
)
table
=
PrettyTable
(
[
"Layer Name"
,
"Input Shape"
,
"Output Shape"
,
"Params"
,
"Flops"
])
...
...
python/paddle/hapi/static_flops.py
浏览文件 @
a71ea009
...
...
@@ -166,22 +166,15 @@ def count_element_op(op):
return
total_ops
def
_verify_dependent_package
():
"""
Verify whether `prettytable` is installed.
"""
def
_graph_flops
(
graph
,
detail
=
False
):
assert
isinstance
(
graph
,
GraphWrapper
)
flops
=
0
try
:
from
prettytable
import
PrettyTable
except
ImportError
:
raise
ImportError
(
"paddle.flops() requires package `prettytable`, place install it firstly using `pip install prettytable`. "
)
def
_graph_flops
(
graph
,
detail
=
False
):
assert
isinstance
(
graph
,
GraphWrapper
)
flops
=
0
_verify_dependent_package
()
table
=
PrettyTable
([
"OP Type"
,
'Param name'
,
"Flops"
])
for
op
in
graph
.
ops
():
param_name
=
''
...
...
python/paddle/tests/test_model.py
浏览文件 @
a71ea009
...
...
@@ -33,6 +33,8 @@ from paddle.nn.layer.loss import CrossEntropyLoss
from
paddle.metric
import
Accuracy
from
paddle.vision.datasets
import
MNIST
from
paddle.vision.models
import
LeNet
import
paddle.vision.models
as
models
import
paddle.fluid.dygraph.jit
as
jit
from
paddle.io
import
DistributedBatchSampler
,
Dataset
from
paddle.hapi.model
import
prepare_distributed_context
from
paddle.fluid.dygraph.jit
import
declarative
...
...
@@ -564,6 +566,24 @@ class TestModelFunction(unittest.TestCase):
nlp_net
=
paddle
.
nn
.
GRU
(
input_size
=
2
,
hidden_size
=
3
,
num_layers
=
3
)
paddle
.
summary
(
nlp_net
,
(
1
,
1
,
2
))
def
test_static_flops
(
self
):
paddle
.
disable_static
()
net
=
models
.
__dict__
[
'mobilenet_v2'
](
pretrained
=
False
)
inputs
=
paddle
.
randn
([
1
,
3
,
224
,
224
])
static_program
=
jit
.
_trace
(
net
,
inputs
=
[
inputs
])[
1
]
paddle
.
flops
(
static_program
,
[
1
,
3
,
224
,
224
],
print_detail
=
True
)
def
test_dynamic_flops
(
self
):
net
=
models
.
__dict__
[
'mobilenet_v2'
](
pretrained
=
False
)
def
customize_dropout
(
m
,
x
,
y
):
m
.
total_ops
+=
0
paddle
.
flops
(
net
,
[
1
,
3
,
224
,
224
],
custom_ops
=
{
paddle
.
nn
.
Dropout
:
customize_dropout
},
print_detail
=
True
)
def
test_export_deploy_model
(
self
):
self
.
set_seed
()
np
.
random
.
seed
(
201
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录