Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
qq_38905368
tensorflow
提交
d8c94dba
T
tensorflow
项目概览
qq_38905368
/
tensorflow
与 Fork 源项目一致
从无法访问的项目Fork
通知
5
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
T
tensorflow
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
d8c94dba
编写于
11月 01, 2016
作者:
A
A. Unique TensorFlower
提交者:
TensorFlower Gardener
11月 01, 2016
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Remove weight_parameters from OpStats and graph_metrics.
Change: 137885496
上级
07ae2d1d
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
5 addition
and
321 deletion
+5
-321
tensorflow/python/framework/ops.py
tensorflow/python/framework/ops.py
+5
-7
tensorflow/python/framework/ops_test.py
tensorflow/python/framework/ops_test.py
+0
-16
tensorflow/python/ops/math_ops.py
tensorflow/python/ops/math_ops.py
+0
-13
tensorflow/python/ops/nn_ops.py
tensorflow/python/ops/nn_ops.py
+0
-59
tensorflow/python/tools/BUILD
tensorflow/python/tools/BUILD
+0
-35
tensorflow/python/tools/graph_metrics.py
tensorflow/python/tools/graph_metrics.py
+0
-141
tensorflow/python/tools/graph_metrics_test.py
tensorflow/python/tools/graph_metrics_test.py
+0
-50
未找到文件。
tensorflow/python/framework/ops.py
浏览文件 @
d8c94dba
...
@@ -1881,23 +1881,21 @@ class RegisterStatistics(object):
...
@@ -1881,23 +1881,21 @@ class RegisterStatistics(object):
Well-known types of statistics include these so far:
Well-known types of statistics include these so far:
- weight_parameters: For operations like MatMul, Conv, and BiasAdd that take
learned weights as inputs, this statistic captures how many numerical values
are used. This is good to know because the weights take up most of the size
of a typical serialized graph on disk.
- flops: When running a graph, the bulk of the computation happens doing
- flops: When running a graph, the bulk of the computation happens doing
numerical calculations like matrix multiplications. This type allows a node
numerical calculations like matrix multiplications. This type allows a node
to return how many floating-point operations it takes to complete. The
to return how many floating-point operations it takes to complete. The
total number of FLOPs for a graph is a good guide to its expected latency.
total number of FLOPs for a graph is a good guide to its expected latency.
You can add your own statistics just by picking a new type string, registering
You can add your own statistics just by picking a new type string, registering
functions for the ops you care about, and then calling something like
functions for the ops you care about, and then calling get_stats_for_node_def.
python/tools/graph_metrics.py with the new type as an argument.
If a statistic for an op is registered multiple times, a KeyError will be
If a statistic for an op is registered multiple times, a KeyError will be
raised.
raised.
Since the statistics is counted on a per-op basis. It is not suitable for
model parameters (capacity), which is expected to be counted only once, even
if it is shared by multiple ops. (e.g. RNN)
For example, you can define a new metric called doohickey for a Foo operation
For example, you can define a new metric called doohickey for a Foo operation
by placing this in your code:
by placing this in your code:
...
...
tensorflow/python/framework/ops_test.py
浏览文件 @
d8c94dba
...
@@ -1427,12 +1427,6 @@ class AsGraphDefTest(test_util.TensorFlowTestCase):
...
@@ -1427,12 +1427,6 @@ class AsGraphDefTest(test_util.TensorFlowTestCase):
"""
,
gd
)
"""
,
gd
)
# NOTE(petewarden): Dummy stats registrations for ops used in the tests.
@
ops
.
RegisterStatistics
(
"a"
,
"weight_parameters"
)
def
_calc_a_weight_params
(
unused_graph
,
unused_node
):
return
ops
.
OpStats
(
"weight_parameters"
,
10
)
@
ops
.
RegisterStatistics
(
"a"
,
"flops"
)
@
ops
.
RegisterStatistics
(
"a"
,
"flops"
)
def
_calc_a_forward_flops
(
unused_graph
,
unused_node
):
def
_calc_a_forward_flops
(
unused_graph
,
unused_node
):
return
ops
.
OpStats
(
"flops"
,
20
)
return
ops
.
OpStats
(
"flops"
,
20
)
...
@@ -1443,8 +1437,6 @@ class StatisticsTest(test_util.TensorFlowTestCase):
...
@@ -1443,8 +1437,6 @@ class StatisticsTest(test_util.TensorFlowTestCase):
def
testRegisteredNode
(
self
):
def
testRegisteredNode
(
self
):
graph
=
ops
.
Graph
()
graph
=
ops
.
Graph
()
node
=
ops
.
_NodeDef
(
"a"
,
"an_a"
)
node
=
ops
.
_NodeDef
(
"a"
,
"an_a"
)
weight_params
=
ops
.
get_stats_for_node_def
(
graph
,
node
,
"weight_parameters"
)
self
.
assertEqual
(
10
,
weight_params
.
value
)
flops
=
ops
.
get_stats_for_node_def
(
graph
,
node
,
"flops"
)
flops
=
ops
.
get_stats_for_node_def
(
graph
,
node
,
"flops"
)
self
.
assertEqual
(
20
,
flops
.
value
)
self
.
assertEqual
(
20
,
flops
.
value
)
missing_stat
=
ops
.
get_stats_for_node_def
(
graph
,
node
,
"missing_stat"
)
missing_stat
=
ops
.
get_stats_for_node_def
(
graph
,
node
,
"missing_stat"
)
...
@@ -1457,19 +1449,11 @@ class StatisticsTest(test_util.TensorFlowTestCase):
...
@@ -1457,19 +1449,11 @@ class StatisticsTest(test_util.TensorFlowTestCase):
self
.
assertEqual
(
None
,
weight_params
.
value
)
self
.
assertEqual
(
None
,
weight_params
.
value
)
def
testAccumulateStatistics
(
self
):
def
testAccumulateStatistics
(
self
):
weight_params_total
=
ops
.
OpStats
(
"weight_parameters"
)
self
.
assertEqual
(
None
,
weight_params_total
.
value
)
flops_total
=
ops
.
OpStats
(
"flops"
)
flops_total
=
ops
.
OpStats
(
"flops"
)
self
.
assertEqual
(
None
,
flops_total
.
value
)
self
.
assertEqual
(
None
,
flops_total
.
value
)
first_weight_params
=
ops
.
OpStats
(
"weight_parameters"
,
100
)
weight_params_total
+=
first_weight_params
self
.
assertEqual
(
100
,
weight_params_total
.
value
)
second_flops
=
ops
.
OpStats
(
"flops"
,
3
)
second_flops
=
ops
.
OpStats
(
"flops"
,
3
)
flops_total
+=
second_flops
flops_total
+=
second_flops
self
.
assertEqual
(
3
,
flops_total
.
value
)
self
.
assertEqual
(
3
,
flops_total
.
value
)
second_weight_params
=
ops
.
OpStats
(
"weight_parameters"
,
200
)
weight_params_total
+=
second_weight_params
self
.
assertEqual
(
300
,
weight_params_total
.
value
)
class
ColocationGroupTest
(
test_util
.
TensorFlowTestCase
):
class
ColocationGroupTest
(
test_util
.
TensorFlowTestCase
):
...
...
tensorflow/python/ops/math_ops.py
浏览文件 @
d8c94dba
...
@@ -1534,19 +1534,6 @@ def _calc_mat_mul_flops(graph, node):
...
@@ -1534,19 +1534,6 @@ def _calc_mat_mul_flops(graph, node):
return
ops
.
OpStats
(
"flops"
,
(
k
*
output_count
*
2
))
return
ops
.
OpStats
(
"flops"
,
(
k
*
output_count
*
2
))
@
ops
.
RegisterStatistics
(
"MatMul"
,
"weight_parameters"
)
def
_calc_mat_mul_weight_parameters
(
graph
,
node
):
"""Calculates the on-disk size of the weights for MatMul."""
# We assume here that the weights are always in the second input to the op,
# which is generally true by convention for fully-connected layers, but not
# enforced or checked.
weights_shape
=
graph_util
.
tensor_shape_from_node_def_name
(
graph
,
node
.
input
[
1
])
weights_shape
.
assert_is_fully_defined
()
return
ops
.
OpStats
(
"weight_parameters"
,
(
int
(
weights_shape
[
1
])
*
int
(
weights_shape
[
0
])))
def
_as_indexed_slices
(
x
,
optimize
=
True
):
def
_as_indexed_slices
(
x
,
optimize
=
True
):
"""Convert 'x' to IndexedSlices.
"""Convert 'x' to IndexedSlices.
...
...
tensorflow/python/ops/nn_ops.py
浏览文件 @
d8c94dba
...
@@ -1809,24 +1809,6 @@ def _calc_conv_flops(graph, node):
...
@@ -1809,24 +1809,6 @@ def _calc_conv_flops(graph, node):
filter_width
*
2
))
filter_width
*
2
))
@
ops
.
RegisterStatistics
(
"Conv2D"
,
"weight_parameters"
)
def
_calc_conv_weight_params
(
graph
,
node
):
"""Calculates the on-disk size of the weights for Conv2D."""
input_shape
=
graph_util
.
tensor_shape_from_node_def_name
(
graph
,
node
.
input
[
0
])
input_shape
.
assert_is_fully_defined
()
filter_shape
=
graph_util
.
tensor_shape_from_node_def_name
(
graph
,
node
.
input
[
1
])
filter_shape
.
assert_is_fully_defined
()
output_shape
=
graph_util
.
tensor_shape_from_node_def_name
(
graph
,
node
.
name
)
output_shape
.
assert_is_fully_defined
()
filter_height
=
int
(
filter_shape
[
0
])
filter_width
=
int
(
filter_shape
[
1
])
filter_in_depth
=
int
(
filter_shape
[
2
])
filter_out_depth
=
int
(
filter_shape
[
3
])
return
ops
.
OpStats
(
"weight_parameters"
,
(
filter_height
*
filter_width
*
filter_in_depth
*
filter_out_depth
))
@
ops
.
RegisterStatistics
(
"DepthwiseConv2dNative"
,
"flops"
)
@
ops
.
RegisterStatistics
(
"DepthwiseConv2dNative"
,
"flops"
)
def
_calc_depthwise_conv_flops
(
graph
,
node
):
def
_calc_depthwise_conv_flops
(
graph
,
node
):
"""Calculates the compute resources needed for DepthwiseConv2dNative."""
"""Calculates the compute resources needed for DepthwiseConv2dNative."""
...
@@ -1843,25 +1825,6 @@ def _calc_depthwise_conv_flops(graph, node):
...
@@ -1843,25 +1825,6 @@ def _calc_depthwise_conv_flops(graph, node):
return
ops
.
OpStats
(
"flops"
,
(
output_count
*
filter_height
*
filter_width
*
2
))
return
ops
.
OpStats
(
"flops"
,
(
output_count
*
filter_height
*
filter_width
*
2
))
@
ops
.
RegisterStatistics
(
"DepthwiseConv2dNative"
,
"weight_parameters"
)
def
_calc_depthwise_conv_weight_params
(
graph
,
node
):
"""Calculates the on-disk size of the weights for DepthwiseConv2dNative."""
input_shape
=
graph_util
.
tensor_shape_from_node_def_name
(
graph
,
node
.
input
[
0
])
input_shape
.
assert_is_fully_defined
()
filter_shape
=
graph_util
.
tensor_shape_from_node_def_name
(
graph
,
node
.
input
[
1
])
filter_shape
.
assert_is_fully_defined
()
output_shape
=
graph_util
.
tensor_shape_from_node_def_name
(
graph
,
node
.
name
)
output_shape
.
assert_is_fully_defined
()
filter_height
=
int
(
filter_shape
[
0
])
filter_width
=
int
(
filter_shape
[
1
])
filter_in_depth
=
int
(
filter_shape
[
2
])
filter_channel_multiplier
=
int
(
filter_shape
[
3
])
return
ops
.
OpStats
(
"weight_parameters"
,
(
filter_height
*
filter_width
*
filter_in_depth
*
filter_channel_multiplier
))
ops
.
RegisterShape
(
"Conv3D"
)(
common_shapes
.
call_cpp_shape_fn
)
ops
.
RegisterShape
(
"Conv3D"
)(
common_shapes
.
call_cpp_shape_fn
)
ops
.
RegisterShape
(
"MaxPool3D"
)(
common_shapes
.
call_cpp_shape_fn
)
ops
.
RegisterShape
(
"MaxPool3D"
)(
common_shapes
.
call_cpp_shape_fn
)
ops
.
RegisterShape
(
"AvgPool3D"
)(
common_shapes
.
call_cpp_shape_fn
)
ops
.
RegisterShape
(
"AvgPool3D"
)(
common_shapes
.
call_cpp_shape_fn
)
...
@@ -1882,15 +1845,6 @@ def _calc_bias_add_flops(graph, node):
...
@@ -1882,15 +1845,6 @@ def _calc_bias_add_flops(graph, node):
return
ops
.
OpStats
(
"flops"
,
input_count
)
return
ops
.
OpStats
(
"flops"
,
input_count
)
@
ops
.
RegisterStatistics
(
"BiasAdd"
,
"weight_parameters"
)
def
_calc_bias_add_weight_params
(
graph
,
node
):
"""Calculates the on-disk weight parameters for BiasAdd."""
bias_shape
=
graph_util
.
tensor_shape_from_node_def_name
(
graph
,
node
.
input
[
1
])
bias_shape
.
assert_is_fully_defined
()
bias_count
=
np
.
prod
(
bias_shape
.
as_list
())
return
ops
.
OpStats
(
"weight_parameters"
,
bias_count
)
def
xw_plus_b
(
x
,
weights
,
biases
,
name
=
None
):
# pylint: disable=invalid-name
def
xw_plus_b
(
x
,
weights
,
biases
,
name
=
None
):
# pylint: disable=invalid-name
"""Computes matmul(x, weights) + biases.
"""Computes matmul(x, weights) + biases.
...
@@ -2112,19 +2066,6 @@ def _calc_dilation2d_flops(graph, node):
...
@@ -2112,19 +2066,6 @@ def _calc_dilation2d_flops(graph, node):
return
ops
.
OpStats
(
"flops"
,
(
output_count
*
filter_height
*
filter_width
*
2
))
return
ops
.
OpStats
(
"flops"
,
(
output_count
*
filter_height
*
filter_width
*
2
))
@
ops
.
RegisterStatistics
(
"Dilation2D"
,
"weight_parameters"
)
def
_calc_dilation2d_weight_params
(
graph
,
node
):
"""Calculates the on-disk size of the weights for Dilation2D."""
filter_shape
=
graph_util
.
tensor_shape_from_node_def_name
(
graph
,
node
.
input
[
1
])
filter_shape
.
assert_is_fully_defined
()
filter_height
=
int
(
filter_shape
[
0
])
filter_width
=
int
(
filter_shape
[
1
])
filter_depth
=
int
(
filter_shape
[
2
])
return
ops
.
OpStats
(
"weight_parameters"
,
(
filter_height
*
filter_width
*
filter_depth
))
def
erosion2d
(
value
,
kernel
,
strides
,
rates
,
padding
,
name
=
None
):
def
erosion2d
(
value
,
kernel
,
strides
,
rates
,
padding
,
name
=
None
):
"""Computes the grayscale erosion of 4-D `value` and 3-D `kernel` tensors.
"""Computes the grayscale erosion of 4-D `value` and 3-D `kernel` tensors.
...
...
tensorflow/python/tools/BUILD
浏览文件 @
d8c94dba
...
@@ -41,41 +41,6 @@ py_test(
...
@@ -41,41 +41,6 @@ py_test(
],
],
)
)
py_library
(
name
=
"graph_metrics_lib"
,
srcs
=
[
"graph_metrics.py"
],
srcs_version
=
"PY2AND3"
,
deps
=
[
"//tensorflow:tensorflow_py"
,
],
)
py_binary
(
name
=
"graph_metrics"
,
srcs
=
[
"graph_metrics.py"
,
],
main
=
"graph_metrics.py"
,
srcs_version
=
"PY2AND3"
,
deps
=
[
"//tensorflow:tensorflow_py"
,
],
)
py_test
(
name
=
"graph_metrics_test"
,
size
=
"small"
,
srcs
=
[
"graph_metrics_test.py"
,
],
srcs_version
=
"PY2AND3"
,
deps
=
[
":graph_metrics_lib"
,
"//tensorflow/python:framework_test_lib"
,
"//tensorflow/python:platform_test"
,
],
)
py_binary
(
py_binary
(
name
=
"inspect_checkpoint"
,
name
=
"inspect_checkpoint"
,
srcs
=
[
srcs
=
[
...
...
tensorflow/python/tools/graph_metrics.py
已删除
100644 → 0
浏览文件 @
07ae2d1d
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Gives estimates of computation and parameter sizes for a GraphDef.
This script takes a GraphDef representing a network, and produces rough
estimates of the number of floating-point operations needed to implement it and
how many parameters are stored. You need to pass in the input size, and the
results are only approximate, since it only calculates them for a subset of
common operations.
If you have downloaded the Inception graph for the label_image example, an
example of using this script would be:
bazel-bin/third_party/tensorflow/python/tools/graph_metrics
\
--graph tensorflow_inception_graph.pb
\
--statistics=weight_parameters,flops
"""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
locale
import
tensorflow
as
tf
from
google.protobuf
import
text_format
from
tensorflow.core.framework
import
graph_pb2
from
tensorflow.python.framework
import
ops
FLAGS
=
tf
.
flags
.
FLAGS
tf
.
flags
.
DEFINE_string
(
"graph"
,
""
,
"""TensorFlow 'GraphDef' file to load."""
)
tf
.
flags
.
DEFINE_bool
(
"input_binary"
,
True
,
"""Whether the input files are in binary format."""
)
tf
.
flags
.
DEFINE_string
(
"input_layer"
,
"Mul:0"
,
"""The name of the input node."""
)
tf
.
flags
.
DEFINE_integer
(
"batch_size"
,
1
,
"""The batch size to use for the calculations."""
)
tf
.
flags
.
DEFINE_string
(
"statistics"
,
"weight_parameters,flops"
,
"""Which statistic types to examine."""
)
tf
.
flags
.
DEFINE_string
(
"input_shape_override"
,
""
,
"""If this is set, the comma-separated values will be"""
""" used to set the shape of the input layer."""
)
tf
.
flags
.
DEFINE_boolean
(
"print_nodes"
,
False
,
"""Whether to show statistics for each op."""
)
def
print_stat
(
prefix
,
statistic_type
,
value
):
if
value
is
None
:
friendly_value
=
"None"
else
:
friendly_value
=
locale
.
format
(
"%d"
,
value
,
grouping
=
True
)
print
(
"%s%s=%s"
%
(
prefix
,
statistic_type
,
friendly_value
))
def
main
(
unused_args
):
if
not
tf
.
gfile
.
Exists
(
FLAGS
.
graph
):
print
(
"Input graph file '"
+
FLAGS
.
graph
+
"' does not exist!"
)
return
-
1
graph_def
=
graph_pb2
.
GraphDef
()
with
open
(
FLAGS
.
graph
,
"rb"
)
as
f
:
if
FLAGS
.
input_binary
:
graph_def
.
ParseFromString
(
f
.
read
())
else
:
text_format
.
Merge
(
f
.
read
(),
graph_def
)
statistic_types
=
FLAGS
.
statistics
.
split
(
","
)
if
FLAGS
.
input_shape_override
:
input_shape_override
=
map
(
int
,
FLAGS
.
input_shape_override
.
split
(
","
))
else
:
input_shape_override
=
None
total_stats
,
node_stats
=
calculate_graph_metrics
(
graph_def
,
statistic_types
,
FLAGS
.
input_layer
,
input_shape_override
,
FLAGS
.
batch_size
)
if
FLAGS
.
print_nodes
:
for
node
in
graph_def
.
node
:
for
statistic_type
in
statistic_types
:
current_stats
=
node_stats
[
statistic_type
][
node
.
name
]
print_stat
(
node
.
name
+
"("
+
node
.
op
+
"): "
,
statistic_type
,
current_stats
.
value
)
for
statistic_type
in
statistic_types
:
value
=
total_stats
[
statistic_type
].
value
print_stat
(
"Total: "
,
statistic_type
,
value
)
def
calculate_graph_metrics
(
graph_def
,
statistic_types
,
input_layer
,
input_shape_override
,
batch_size
):
"""Looks at the performance statistics of all nodes in the graph."""
_
=
tf
.
import_graph_def
(
graph_def
,
name
=
""
)
total_stats
=
{}
node_stats
=
{}
for
statistic_type
in
statistic_types
:
total_stats
[
statistic_type
]
=
ops
.
OpStats
(
statistic_type
)
node_stats
[
statistic_type
]
=
{}
# Make sure we get pretty-printed numbers with separators.
locale
.
setlocale
(
locale
.
LC_ALL
,
""
)
with
tf
.
Session
()
as
sess
:
input_tensor
=
sess
.
graph
.
get_tensor_by_name
(
input_layer
)
input_shape_tensor
=
input_tensor
.
get_shape
()
if
input_shape_tensor
:
input_shape
=
input_shape_tensor
.
as_list
()
else
:
input_shape
=
None
if
input_shape_override
:
input_shape
=
input_shape_override
if
input_shape
is
None
:
raise
ValueError
(
"""No input shape was provided on the command line,"""
""" and the input op itself had no default shape, so"""
""" shape inference couldn't be performed. This is"""
""" required for metrics calculations."""
)
input_shape
[
0
]
=
batch_size
input_tensor
.
set_shape
(
input_shape
)
for
node
in
graph_def
.
node
:
# Ensure that the updated input shape has been fully-propagated before we
# ask for the statistics, since they may depend on the output size.
op
=
sess
.
graph
.
get_operation_by_name
(
node
.
name
)
ops
.
set_shapes_for_outputs
(
op
)
for
statistic_type
in
statistic_types
:
current_stats
=
ops
.
get_stats_for_node_def
(
sess
.
graph
,
node
,
statistic_type
)
node_stats
[
statistic_type
][
node
.
name
]
=
current_stats
total_stats
[
statistic_type
]
+=
current_stats
return
total_stats
,
node_stats
if
__name__
==
"__main__"
:
tf
.
app
.
run
()
tensorflow/python/tools/graph_metrics_test.py
已删除
100644 → 0
浏览文件 @
07ae2d1d
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests the graph metrics tool."""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
tensorflow
as
tf
from
tensorflow.python.tools
import
graph_metrics
class
GraphMetricsTest
(
tf
.
test
.
TestCase
):
def
testGraphMetrics
(
self
):
with
tf
.
Graph
().
as_default
():
input_node
=
tf
.
placeholder
(
tf
.
float32
,
shape
=
[
10
,
20
],
name
=
"input_node"
)
weights_node
=
tf
.
constant
(
0.0
,
dtype
=
tf
.
float32
,
shape
=
[
20
,
5
],
name
=
"weights_node"
)
tf
.
matmul
(
input_node
,
weights_node
,
name
=
"matmul_node"
)
sess
=
tf
.
Session
()
graph_def
=
sess
.
graph
.
as_graph_def
()
statistic_types
=
[
"weight_parameters"
,
"flops"
]
total_stats
,
node_stats
=
graph_metrics
.
calculate_graph_metrics
(
graph_def
,
statistic_types
,
"input_node:0"
,
None
,
10
)
expected
=
{
"weight_parameters"
:
100
,
"flops"
:
2000
}
for
statistic_type
in
statistic_types
:
current_stats
=
node_stats
[
statistic_type
][
"matmul_node"
]
self
.
assertEqual
(
expected
[
statistic_type
],
current_stats
.
value
)
for
statistic_type
in
statistic_types
:
current_stats
=
total_stats
[
statistic_type
]
self
.
assertEqual
(
expected
[
statistic_type
],
current_stats
.
value
)
if
__name__
==
"__main__"
:
tf
.
test
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录