Commit d8c94dba authored by A. Unique TensorFlower, committed by TensorFlower Gardener

Remove weight_parameters from OpStats and graph_metrics.

Change: 137885496
Parent 07ae2d1d
......@@ -1881,23 +1881,21 @@ class RegisterStatistics(object):
Well-known types of statistics include these so far:
- weight_parameters: For operations like MatMul, Conv, and BiasAdd that take
learned weights as inputs, this statistic captures how many numerical values
are used. This is good to know because the weights take up most of the size
of a typical serialized graph on disk.
- flops: When running a graph, the bulk of the computation happens doing
numerical calculations like matrix multiplications. This type allows a node
to return how many floating-point operations it takes to complete. The
total number of FLOPs for a graph is a good guide to its expected latency.
You can add your own statistics just by picking a new type string, registering
functions for the ops you care about, and then calling something like
python/tools/graph_metrics.py with the new type as an argument.
functions for the ops you care about, and then calling get_stats_for_node_def.
If a statistic for an op is registered multiple times, a KeyError will be
raised.
Statistics are counted on a per-op basis, so they are not suitable for
model parameters (capacity), which should be counted only once even when
they are shared by multiple ops (e.g. in an RNN).
For example, you can define a new metric called doohickey for a Foo operation
by placing this in your code:
......
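The concrete doohickey example is collapsed in the diff above; as a minimal sketch of the registration pattern the docstring describes (using the Foo/doohickey names from the docstring and the same function signature as the calculators removed below, and assuming `ops` is imported as in graph_metrics.py):

@ops.RegisterStatistics("Foo", "doohickey")
def _calc_foo_doohickey(unused_graph, unused_node):
  # A fixed count for every Foo node; real calculators usually inspect the
  # node's input shapes, as the MatMul and Conv2D functions below did.
  return ops.OpStats("doohickey", 20)

The registered statistic is then looked up per node with ops.get_stats_for_node_def(graph, node_def, "doohickey"); as the tests below exercise, the returned OpStats has a value of None when nothing is registered for that op/statistic pair.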
......@@ -1427,12 +1427,6 @@ class AsGraphDefTest(test_util.TensorFlowTestCase):
""", gd)
# NOTE(petewarden): Dummy stats registrations for ops used in the tests.
@ops.RegisterStatistics("a", "weight_parameters")
def _calc_a_weight_params(unused_graph, unused_node):
return ops.OpStats("weight_parameters", 10)
@ops.RegisterStatistics("a", "flops")
def _calc_a_forward_flops(unused_graph, unused_node):
return ops.OpStats("flops", 20)
......@@ -1443,8 +1437,6 @@ class StatisticsTest(test_util.TensorFlowTestCase):
def testRegisteredNode(self):
graph = ops.Graph()
node = ops._NodeDef("a", "an_a")
weight_params = ops.get_stats_for_node_def(graph, node, "weight_parameters")
self.assertEqual(10, weight_params.value)
flops = ops.get_stats_for_node_def(graph, node, "flops")
self.assertEqual(20, flops.value)
missing_stat = ops.get_stats_for_node_def(graph, node, "missing_stat")
......@@ -1457,19 +1449,11 @@ class StatisticsTest(test_util.TensorFlowTestCase):
self.assertEqual(None, weight_params.value)
def testAccumulateStatistics(self):
weight_params_total = ops.OpStats("weight_parameters")
self.assertEqual(None, weight_params_total.value)
flops_total = ops.OpStats("flops")
self.assertEqual(None, flops_total.value)
first_weight_params = ops.OpStats("weight_parameters", 100)
weight_params_total += first_weight_params
self.assertEqual(100, weight_params_total.value)
second_flops = ops.OpStats("flops", 3)
flops_total += second_flops
self.assertEqual(3, flops_total.value)
second_weight_params = ops.OpStats("weight_parameters", 200)
weight_params_total += second_weight_params
self.assertEqual(300, weight_params_total.value)
class ColocationGroupTest(test_util.TensorFlowTestCase):
......
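The accumulation test above pins down the OpStats arithmetic the metrics tool relies on: a freshly constructed OpStats has a value of None (no data yet), the first += sets the value, and later += calls add to it. A minimal usage sketch of just the behavior the test demonstrates:

flops_total = ops.OpStats("flops")       # value is None until something is accumulated
flops_total += ops.OpStats("flops", 3)   # first accumulation sets the value to 3
flops_total += ops.OpStats("flops", 7)   # further accumulations add; value is now 10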
......@@ -1534,19 +1534,6 @@ def _calc_mat_mul_flops(graph, node):
return ops.OpStats("flops", (k * output_count * 2))
@ops.RegisterStatistics("MatMul", "weight_parameters")
def _calc_mat_mul_weight_parameters(graph, node):
"""Calculates the on-disk size of the weights for MatMul."""
# We assume here that the weights are always in the second input to the op,
# which is generally true by convention for fully-connected layers, but not
# enforced or checked.
weights_shape = graph_util.tensor_shape_from_node_def_name(graph,
node.input[1])
weights_shape.assert_is_fully_defined()
return ops.OpStats("weight_parameters",
(int(weights_shape[1]) * int(weights_shape[0])))
def _as_indexed_slices(x, optimize=True):
"""Convert 'x' to IndexedSlices.
......
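As a concrete check of the MatMul calculators above, using the shapes from the graph_metrics test at the end of this change: a [10, 20] input multiplied by a [20, 5] weights matrix gives weight_parameters = 20 * 5 = 100, and since the output is [10, 5] (output_count = 50) with inner dimension k = 20, flops = k * output_count * 2 = 20 * 50 * 2 = 2000, which are exactly the values the test expects.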
......@@ -1809,24 +1809,6 @@ def _calc_conv_flops(graph, node):
filter_width * 2))
@ops.RegisterStatistics("Conv2D", "weight_parameters")
def _calc_conv_weight_params(graph, node):
"""Calculates the on-disk size of the weights for Conv2D."""
input_shape = graph_util.tensor_shape_from_node_def_name(graph, node.input[0])
input_shape.assert_is_fully_defined()
filter_shape = graph_util.tensor_shape_from_node_def_name(graph,
node.input[1])
filter_shape.assert_is_fully_defined()
output_shape = graph_util.tensor_shape_from_node_def_name(graph, node.name)
output_shape.assert_is_fully_defined()
filter_height = int(filter_shape[0])
filter_width = int(filter_shape[1])
filter_in_depth = int(filter_shape[2])
filter_out_depth = int(filter_shape[3])
return ops.OpStats("weight_parameters", (filter_height * filter_width *
filter_in_depth * filter_out_depth))
@ops.RegisterStatistics("DepthwiseConv2dNative", "flops")
def _calc_depthwise_conv_flops(graph, node):
"""Calculates the compute resources needed for DepthwiseConv2dNative."""
......@@ -1843,25 +1825,6 @@ def _calc_depthwise_conv_flops(graph, node):
return ops.OpStats("flops", (output_count * filter_height * filter_width * 2))
@ops.RegisterStatistics("DepthwiseConv2dNative", "weight_parameters")
def _calc_depthwise_conv_weight_params(graph, node):
"""Calculates the on-disk size of the weights for DepthwiseConv2dNative."""
input_shape = graph_util.tensor_shape_from_node_def_name(graph, node.input[0])
input_shape.assert_is_fully_defined()
filter_shape = graph_util.tensor_shape_from_node_def_name(graph,
node.input[1])
filter_shape.assert_is_fully_defined()
output_shape = graph_util.tensor_shape_from_node_def_name(graph, node.name)
output_shape.assert_is_fully_defined()
filter_height = int(filter_shape[0])
filter_width = int(filter_shape[1])
filter_in_depth = int(filter_shape[2])
filter_channel_multiplier = int(filter_shape[3])
return ops.OpStats("weight_parameters", (filter_height * filter_width *
filter_in_depth *
filter_channel_multiplier))
ops.RegisterShape("Conv3D")(common_shapes.call_cpp_shape_fn)
ops.RegisterShape("MaxPool3D")(common_shapes.call_cpp_shape_fn)
ops.RegisterShape("AvgPool3D")(common_shapes.call_cpp_shape_fn)
......@@ -1882,15 +1845,6 @@ def _calc_bias_add_flops(graph, node):
return ops.OpStats("flops", input_count)
@ops.RegisterStatistics("BiasAdd", "weight_parameters")
def _calc_bias_add_weight_params(graph, node):
"""Calculates the on-disk weight parameters for BiasAdd."""
bias_shape = graph_util.tensor_shape_from_node_def_name(graph, node.input[1])
bias_shape.assert_is_fully_defined()
bias_count = np.prod(bias_shape.as_list())
return ops.OpStats("weight_parameters", bias_count)
def xw_plus_b(x, weights, biases, name=None): # pylint: disable=invalid-name
"""Computes matmul(x, weights) + biases.
......@@ -2112,19 +2066,6 @@ def _calc_dilation2d_flops(graph, node):
return ops.OpStats("flops", (output_count * filter_height * filter_width * 2))
@ops.RegisterStatistics("Dilation2D", "weight_parameters")
def _calc_dilation2d_weight_params(graph, node):
"""Calculates the on-disk size of the weights for Dilation2D."""
filter_shape = graph_util.tensor_shape_from_node_def_name(graph,
node.input[1])
filter_shape.assert_is_fully_defined()
filter_height = int(filter_shape[0])
filter_width = int(filter_shape[1])
filter_depth = int(filter_shape[2])
return ops.OpStats("weight_parameters",
(filter_height * filter_width * filter_depth))
def erosion2d(value, kernel, strides, rates, padding, name=None):
"""Computes the grayscale erosion of 4-D `value` and 3-D `kernel` tensors.
......
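The removed Conv2D, DepthwiseConv2dNative, BiasAdd, and Dilation2D calculators above all follow the same pattern: look up the weight tensor's shape with graph_util.tensor_shape_from_node_def_name and multiply its dimensions together. As an illustration with hypothetical sizes (not taken from this change), a Conv2D with a 3x3 filter mapping 16 input channels to 32 output channels would have been reported as 3 * 3 * 16 * 32 = 4608 weight_parameters, and the BiasAdd over its 32-channel output as 32.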
......@@ -41,41 +41,6 @@ py_test(
],
)
py_library(
name = "graph_metrics_lib",
srcs = ["graph_metrics.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow:tensorflow_py",
],
)
py_binary(
name = "graph_metrics",
srcs = [
"graph_metrics.py",
],
main = "graph_metrics.py",
srcs_version = "PY2AND3",
deps = [
"//tensorflow:tensorflow_py",
],
)
py_test(
name = "graph_metrics_test",
size = "small",
srcs = [
"graph_metrics_test.py",
],
srcs_version = "PY2AND3",
deps = [
":graph_metrics_lib",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform_test",
],
)
py_binary(
name = "inspect_checkpoint",
srcs = [
......
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Gives estimates of computation and parameter sizes for a GraphDef.
This script takes a GraphDef representing a network, and produces rough
estimates of the number of floating-point operations needed to implement it and
how many parameters are stored. You need to pass in the input size, and the
results are only approximate, since it only calculates them for a subset of
common operations.
If you have downloaded the Inception graph for the label_image example, an
example of using this script would be:
bazel-bin/third_party/tensorflow/python/tools/graph_metrics \
--graph tensorflow_inception_graph.pb \
--statistics=weight_parameters,flops
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import locale
import tensorflow as tf
from google.protobuf import text_format
from tensorflow.core.framework import graph_pb2
from tensorflow.python.framework import ops
FLAGS = tf.flags.FLAGS
tf.flags.DEFINE_string("graph", "", """TensorFlow 'GraphDef' file to load.""")
tf.flags.DEFINE_bool("input_binary", True,
"""Whether the input files are in binary format.""")
tf.flags.DEFINE_string("input_layer", "Mul:0",
"""The name of the input node.""")
tf.flags.DEFINE_integer("batch_size", 1,
"""The batch size to use for the calculations.""")
tf.flags.DEFINE_string("statistics", "weight_parameters,flops",
"""Which statistic types to examine.""")
tf.flags.DEFINE_string("input_shape_override", "",
"""If this is set, the comma-separated values will be"""
""" used to set the shape of the input layer.""")
tf.flags.DEFINE_boolean("print_nodes", False,
"""Whether to show statistics for each op.""")
def print_stat(prefix, statistic_type, value):
if value is None:
friendly_value = "None"
else:
friendly_value = locale.format("%d", value, grouping=True)
print("%s%s=%s" % (prefix, statistic_type, friendly_value))
def main(unused_args):
if not tf.gfile.Exists(FLAGS.graph):
print("Input graph file '" + FLAGS.graph + "' does not exist!")
return -1
graph_def = graph_pb2.GraphDef()
with open(FLAGS.graph, "rb") as f:
if FLAGS.input_binary:
graph_def.ParseFromString(f.read())
else:
text_format.Merge(f.read(), graph_def)
statistic_types = FLAGS.statistics.split(",")
if FLAGS.input_shape_override:
input_shape_override = map(int, FLAGS.input_shape_override.split(","))
else:
input_shape_override = None
total_stats, node_stats = calculate_graph_metrics(
graph_def, statistic_types, FLAGS.input_layer, input_shape_override,
FLAGS.batch_size)
if FLAGS.print_nodes:
for node in graph_def.node:
for statistic_type in statistic_types:
current_stats = node_stats[statistic_type][node.name]
print_stat(node.name + "(" + node.op + "): ", statistic_type,
current_stats.value)
for statistic_type in statistic_types:
value = total_stats[statistic_type].value
print_stat("Total: ", statistic_type, value)
def calculate_graph_metrics(graph_def, statistic_types, input_layer,
input_shape_override, batch_size):
"""Looks at the performance statistics of all nodes in the graph."""
_ = tf.import_graph_def(graph_def, name="")
total_stats = {}
node_stats = {}
for statistic_type in statistic_types:
total_stats[statistic_type] = ops.OpStats(statistic_type)
node_stats[statistic_type] = {}
# Make sure we get pretty-printed numbers with separators.
locale.setlocale(locale.LC_ALL, "")
with tf.Session() as sess:
input_tensor = sess.graph.get_tensor_by_name(input_layer)
input_shape_tensor = input_tensor.get_shape()
if input_shape_tensor:
input_shape = input_shape_tensor.as_list()
else:
input_shape = None
if input_shape_override:
input_shape = input_shape_override
if input_shape is None:
raise ValueError("""No input shape was provided on the command line,"""
""" and the input op itself had no default shape, so"""
""" shape inference couldn't be performed. This is"""
""" required for metrics calculations.""")
input_shape[0] = batch_size
input_tensor.set_shape(input_shape)
for node in graph_def.node:
# Ensure that the updated input shape has been fully-propagated before we
# ask for the statistics, since they may depend on the output size.
op = sess.graph.get_operation_by_name(node.name)
ops.set_shapes_for_outputs(op)
for statistic_type in statistic_types:
current_stats = ops.get_stats_for_node_def(sess.graph, node,
statistic_type)
node_stats[statistic_type][node.name] = current_stats
total_stats[statistic_type] += current_stats
return total_stats, node_stats
if __name__ == "__main__":
tf.app.run()
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests the graph metrics tool."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from tensorflow.python.tools import graph_metrics
class GraphMetricsTest(tf.test.TestCase):
def testGraphMetrics(self):
with tf.Graph().as_default():
input_node = tf.placeholder(tf.float32, shape=[10, 20], name="input_node")
weights_node = tf.constant(0.0,
dtype=tf.float32,
shape=[20, 5],
name="weights_node")
tf.matmul(input_node, weights_node, name="matmul_node")
sess = tf.Session()
graph_def = sess.graph.as_graph_def()
statistic_types = ["weight_parameters", "flops"]
total_stats, node_stats = graph_metrics.calculate_graph_metrics(
graph_def, statistic_types, "input_node:0", None, 10)
expected = {"weight_parameters": 100, "flops": 2000}
for statistic_type in statistic_types:
current_stats = node_stats[statistic_type]["matmul_node"]
self.assertEqual(expected[statistic_type], current_stats.value)
for statistic_type in statistic_types:
current_stats = total_stats[statistic_type]
self.assertEqual(expected[statistic_type], current_stats.value)
if __name__ == "__main__":
tf.test.main()