Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
5f239a19
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
5f239a19
编写于
9月 04, 2020
作者:
L
LielinJiang
提交者:
GitHub
9月 04, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[cherrypick-2.0-beta] Add summary (#26990)
* add summary api
上级
3231ce9f
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
294 addition
and
3 deletion
+294
-3
python/paddle/__init__.py
python/paddle/__init__.py
+1
-0
python/paddle/hapi/__init__.py
python/paddle/hapi/__init__.py
+3
-2
python/paddle/hapi/model.py
python/paddle/hapi/model.py
+44
-1
python/paddle/hapi/model_summary.py
python/paddle/hapi/model_summary.py
+225
-0
python/paddle/tests/test_model.py
python/paddle/tests/test_model.py
+20
-0
python/paddle/utils/__init__.py
python/paddle/utils/__init__.py
+1
-0
未找到文件。
python/paddle/__init__.py
浏览文件 @
5f239a19
...
@@ -269,5 +269,6 @@ from . import static
...
@@ -269,5 +269,6 @@ from . import static
# high-level api
# high-level api
from
.hapi
import
Model
from
.hapi
import
Model
from
.hapi
import
callbacks
from
.hapi
import
callbacks
from
.hapi
import
summary
import
paddle.text
import
paddle.text
import
paddle.vision
import
paddle.vision
python/paddle/hapi/__init__.py
浏览文件 @
5f239a19
...
@@ -14,14 +14,15 @@
...
@@ -14,14 +14,15 @@
from
.
import
logger
from
.
import
logger
from
.
import
callbacks
from
.
import
callbacks
from
.
import
model_summary
from
.
import
model
from
.
import
model
from
.model
import
*
from
.model
import
*
from
.model_summary
import
summary
from
.dygraph_layer_patch
import
monkey_patch_layer
from
.dygraph_layer_patch
import
monkey_patch_layer
logger
.
setup_logger
()
logger
.
setup_logger
()
__all__
=
[
'callbacks'
]
+
model
.
__all__
__all__
=
[
'callbacks'
]
+
model
.
__all__
+
[
'summary'
]
monkey_patch_layer
()
monkey_patch_layer
()
python/paddle/hapi/model.py
浏览文件 @
5f239a19
...
@@ -47,10 +47,10 @@ from paddle.io import DataLoader, Dataset, DistributedBatchSampler
...
@@ -47,10 +47,10 @@ from paddle.io import DataLoader, Dataset, DistributedBatchSampler
from
paddle.fluid.executor
import
scope_guard
,
Executor
from
paddle.fluid.executor
import
scope_guard
,
Executor
from
paddle.fluid.dygraph.layers
import
Layer
from
paddle.fluid.dygraph.layers
import
Layer
from
paddle.metric
import
Metric
from
paddle.metric
import
Metric
from
paddle.static
import
InputSpec
as
Input
from
paddle.static
import
InputSpec
as
Input
from
.callbacks
import
config_callbacks
from
.callbacks
import
config_callbacks
from
.model_summary
import
summary
__all__
=
[
'Model'
,
]
__all__
=
[
'Model'
,
]
...
@@ -1828,6 +1828,49 @@ class Model(object):
...
@@ -1828,6 +1828,49 @@ class Model(object):
return
logs
,
outputs
return
logs
,
outputs
return
logs
return
logs
def
summary
(
self
,
input_size
=
None
,
batch_size
=
None
,
dtype
=
None
):
"""Prints a string summary of the network.
Args:
input_size (tuple|InputSpec|list[tuple|InputSpec], optional): size of input tensor.
if not set, input_size will get from ``self._inputs`` if network only have
one input, input_size can be tuple or InputSpec. if model have multiple
input, input_size must be a list which contain every input's shape.
Default: None.
batch_size (int, optional): batch size of input tensor, Default: None.
dtypes (str, optional): if dtypes is None, 'float32' will be used, Default: None.
Returns:
Dict: a summary of the network including total params and total trainable params.
Examples:
.. code-block:: python
import paddle
from paddle.static import InputSpec
dynamic = True
device = paddle.set_device('cpu')
paddle.disable_static(device) if dynamic else None
input = InputSpec([None, 1, 28, 28], 'float32', 'image')
label = InputSpec([None, 1], 'int64', 'label')
model = paddle.Model(paddle.vision.LeNet(classifier_activation=None),
input, label)
optim = paddle.optimizer.Adam(
learning_rate=0.001, parameters=model.parameters())
model.prepare(
optim,
paddle.nn.CrossEntropyLoss())
params_info = model.summary()
print(params_info)
"""
return
summary
(
self
.
network
,
self
.
_inputs
,
batch_size
,
dtype
)
def
_verify_spec
(
self
,
specs
,
is_input
=
False
):
def
_verify_spec
(
self
,
specs
,
is_input
=
False
):
out_specs
=
[]
out_specs
=
[]
...
...
python/paddle/hapi/model_summary.py
0 → 100644
浏览文件 @
5f239a19
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
numpy
as
np
import
paddle
import
paddle.nn
as
nn
from
paddle.static
import
InputSpec
from
collections
import
OrderedDict
__all__
=
[
'summary'
]
def
summary
(
net
,
input_size
,
batch_size
=
None
,
dtypes
=
None
):
"""Prints a string summary of the network.
Args:
net (Layer): the network which must be a subinstance of Layer.
input_size (tuple|InputSpec|list[tuple|InputSpec]): size of input tensor. if model only
have one input, input_size can be tuple or InputSpec. if model
have multiple input, input_size must be a list which contain
every input's shape.
batch_size (int, optional): batch size of input tensor, Default: None.
dtypes (str, optional): if dtypes is None, 'float32' will be used, Default: None.
Returns:
Dict: a summary of the network including total params and total trainable params.
Examples:
.. code-block:: python
import paddle
import paddle.nn as nn
class LeNet(nn.Layer):
def __init__(self, num_classes=10):
super(LeNet, self).__init__()
self.num_classes = num_classes
self.features = nn.Sequential(
nn.Conv2d(
1, 6, 3, stride=1, padding=1),
nn.ReLU(),
nn.MaxPool2d(2, 2),
nn.Conv2d(
6, 16, 5, stride=1, padding=0),
nn.ReLU(),
nn.MaxPool2d(2, 2))
if num_classes > 0:
self.fc = nn.Sequential(
nn.Linear(400, 120),
nn.Linear(120, 84),
nn.Linear(
84, 10))
def forward(self, inputs):
x = self.features(inputs)
if self.num_classes > 0:
x = paddle.flatten(x, 1)
x = self.fc(x)
return x
lenet = LeNet()
params_info = paddle.summary(lenet, (1, 28, 28))
print(params_info)
"""
if
isinstance
(
input_size
,
InputSpec
):
_input_size
=
tuple
(
input_size
.
shape
[
1
:])
if
batch_size
is
None
:
batch_size
=
input_size
.
shape
[
0
]
elif
isinstance
(
input_size
,
list
):
_input_size
=
[]
for
item
in
input_size
:
assert
isinstance
(
item
,
(
list
,
InputSpec
)),
'When input_size is list,
\
expect item in input_size is a tuple or InputSpec, but got {}'
.
format
(
type
(
item
))
if
isinstance
(
item
,
InputSpec
):
_input_size
.
append
(
tuple
(
item
.
shape
[
1
:]))
if
batch_size
is
None
:
batch_size
=
item
.
shape
[
0
]
else
:
_input_size
.
append
(
item
)
else
:
_input_size
=
input_size
if
batch_size
is
None
:
batch_size
=
-
1
result
,
params_info
=
summary_string
(
net
,
_input_size
,
batch_size
,
dtypes
)
print
(
result
)
return
params_info
def
summary_string
(
model
,
input_size
,
batch_size
=-
1
,
dtypes
=
None
):
if
dtypes
==
None
:
dtypes
=
[
'float32'
]
*
len
(
input_size
)
summary_str
=
''
depth
=
len
(
list
(
model
.
sublayers
()))
def
register_hook
(
module
):
def
hook
(
module
,
input
,
output
):
class_name
=
str
(
module
.
__class__
).
split
(
"."
)[
-
1
].
split
(
"'"
)[
0
]
try
:
module_idx
=
int
(
module
.
_full_name
.
split
(
'_'
)[
-
1
])
except
:
module_idx
=
len
(
summary
)
m_key
=
"%s-%i"
%
(
class_name
,
module_idx
+
1
)
summary
[
m_key
]
=
OrderedDict
()
summary
[
m_key
][
"input_shape"
]
=
list
(
input
[
0
].
shape
)
summary
[
m_key
][
"input_shape"
][
0
]
=
batch_size
if
isinstance
(
output
,
(
list
,
tuple
)):
summary
[
m_key
][
"output_shape"
]
=
[[
-
1
]
+
list
(
o
.
shape
)[
1
:]
for
o
in
output
]
else
:
summary
[
m_key
][
"output_shape"
]
=
list
(
output
.
shape
)
summary
[
m_key
][
"output_shape"
][
0
]
=
batch_size
params
=
0
if
hasattr
(
module
,
"weight"
):
params
+=
np
.
prod
(
module
.
weight
.
shape
)
summary
[
m_key
][
"trainable"
]
=
module
.
weight
.
trainable
or
(
not
module
.
weight
.
stop_gradient
)
if
hasattr
(
module
,
"bias"
):
params
+=
np
.
prod
(
module
.
bias
.
shape
)
summary
[
m_key
][
"nb_params"
]
=
params
if
(
not
isinstance
(
module
,
nn
.
Sequential
)
and
not
isinstance
(
module
,
nn
.
LayerList
)
and
(
not
(
module
==
model
)
or
depth
<
1
)):
hooks
.
append
(
module
.
register_forward_post_hook
(
hook
))
if
isinstance
(
input_size
,
tuple
):
input_size
=
[
input_size
]
x
=
[
paddle
.
rand
(
[
2
]
+
list
(
in_size
),
dtype
=
dtype
)
for
in_size
,
dtype
in
zip
(
input_size
,
dtypes
)
]
# create properties
summary
=
OrderedDict
()
hooks
=
[]
# register hook
model
.
apply
(
register_hook
)
# make a forward pass
model
(
*
x
)
# remove these hooks
for
h
in
hooks
:
h
.
remove
()
table_width
=
80
summary_str
+=
"-"
*
table_width
+
"
\n
"
line_new
=
"{:>15} {:>20} {:>20} {:>15}"
.
format
(
"Layer (type)"
,
"Input Shape"
,
"Output Shape"
,
"Param #"
)
summary_str
+=
line_new
+
"
\n
"
summary_str
+=
"="
*
table_width
+
"
\n
"
total_params
=
0
total_output
=
0
trainable_params
=
0
for
layer
in
summary
:
# input_shape, output_shape, trainable, nb_params
line_new
=
"{:>15} {:>20} {:>20} {:>15}"
.
format
(
layer
,
str
(
summary
[
layer
][
"input_shape"
]),
str
(
summary
[
layer
][
"output_shape"
]),
"{0:,}"
.
format
(
summary
[
layer
][
"nb_params"
]),
)
total_params
+=
summary
[
layer
][
"nb_params"
]
total_output
+=
np
.
prod
(
summary
[
layer
][
"output_shape"
])
if
"trainable"
in
summary
[
layer
]:
if
summary
[
layer
][
"trainable"
]
==
True
:
trainable_params
+=
summary
[
layer
][
"nb_params"
]
summary_str
+=
line_new
+
"
\n
"
# assume 4 bytes/number (float on cuda).
total_input_size
=
abs
(
np
.
prod
(
sum
(
input_size
,
()))
*
batch_size
*
4.
/
(
1024
**
2.
))
total_output_size
=
abs
(
2.
*
total_output
*
4.
/
(
1024
**
2.
))
# x2 for gradients
total_params_size
=
abs
(
total_params
*
4.
/
(
1024
**
2.
))
total_size
=
total_params_size
+
total_output_size
+
total_input_size
summary_str
+=
"="
*
table_width
+
"
\n
"
summary_str
+=
"Total params: {0:,}"
.
format
(
total_params
)
+
"
\n
"
summary_str
+=
"Trainable params: {0:,}"
.
format
(
trainable_params
)
+
"
\n
"
summary_str
+=
"Non-trainable params: {0:,}"
.
format
(
total_params
-
trainable_params
)
+
"
\n
"
summary_str
+=
"-"
*
table_width
+
"
\n
"
summary_str
+=
"Input size (MB): %0.2f"
%
total_input_size
+
"
\n
"
summary_str
+=
"Forward/backward pass size (MB): %0.2f"
%
total_output_size
+
"
\n
"
summary_str
+=
"Params size (MB): %0.2f"
%
total_params_size
+
"
\n
"
summary_str
+=
"Estimated Total Size (MB): %0.2f"
%
total_size
+
"
\n
"
summary_str
+=
"-"
*
table_width
+
"
\n
"
# return summary
return
summary_str
,
{
'total_params'
:
total_params
,
'trainable_params'
:
trainable_params
}
python/paddle/tests/test_model.py
浏览文件 @
5f239a19
...
@@ -499,6 +499,26 @@ class TestModelFunction(unittest.TestCase):
...
@@ -499,6 +499,26 @@ class TestModelFunction(unittest.TestCase):
self
.
assertTrue
(
params
[
0
].
shape
[
1
]
==
10
)
self
.
assertTrue
(
params
[
0
].
shape
[
1
]
==
10
)
fluid
.
disable_dygraph
()
if
dynamic
else
None
fluid
.
disable_dygraph
()
if
dynamic
else
None
def
test_summary
(
self
):
def
_get_param_from_state_dict
(
state_dict
):
params
=
0
for
k
,
v
in
state_dict
.
items
():
params
+=
np
.
prod
(
v
.
numpy
().
shape
)
return
params
for
dynamic
in
[
True
,
False
]:
device
=
paddle
.
set_device
(
'cpu'
)
fluid
.
enable_dygraph
(
device
)
if
dynamic
else
None
net
=
MyModel
()
inputs
=
[
InputSpec
([
None
,
20
],
'float32'
,
'x'
)]
model
=
Model
(
net
,
inputs
)
model
.
prepare
()
params_info
=
model
.
summary
()
gt_params
=
_get_param_from_state_dict
(
net
.
state_dict
())
np
.
testing
.
assert_allclose
(
params_info
[
'total_params'
],
gt_params
)
print
(
params_info
)
def
test_export_deploy_model
(
self
):
def
test_export_deploy_model
(
self
):
for
dynamic
in
[
True
,
False
]:
for
dynamic
in
[
True
,
False
]:
fluid
.
enable_dygraph
()
if
dynamic
else
None
fluid
.
enable_dygraph
()
if
dynamic
else
None
...
...
python/paddle/utils/__init__.py
浏览文件 @
5f239a19
...
@@ -17,6 +17,7 @@ from .profiler import ProfilerOptions
...
@@ -17,6 +17,7 @@ from .profiler import ProfilerOptions
from
.profiler
import
Profiler
from
.profiler
import
Profiler
from
.profiler
import
get_profiler
from
.profiler
import
get_profiler
from
.deprecated
import
deprecated
from
.deprecated
import
deprecated
from
.
import
download
from
.
import
download
__all__
=
[
'dump_config'
,
'Ploter'
,
'deprecated'
,
'download'
]
__all__
=
[
'dump_config'
,
'Ploter'
,
'deprecated'
,
'download'
]
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录