From 7c8aa99d5b2eab45dfe78bb6a1eb394db8a7e78e Mon Sep 17 00:00:00 2001 From: Liangliang He Date: Tue, 21 Nov 2017 17:48:25 +0800 Subject: [PATCH] Print const and conv2d ops in tf_ops_stats.py --- mace/python/tools/tf_ops_stats.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/mace/python/tools/tf_ops_stats.py b/mace/python/tools/tf_ops_stats.py index bdf47e07..9301b3f1 100644 --- a/mace/python/tools/tf_ops_stats.py +++ b/mace/python/tools/tf_ops_stats.py @@ -1,3 +1,5 @@ +import operator +import functools import argparse import sys import six @@ -39,6 +41,7 @@ def main(unused_args): # extract kernel size for conv_2d tensor_shapes = {} tensor_values = {} + print("=========================consts============================") for op in ops: if op.type == 'Const': for output in op.outputs: @@ -46,9 +49,11 @@ def main(unused_args): tensor = output.eval() tensor_shape = list(tensor.shape) tensor_shapes[tensor_name] = tensor_shape + print("Const %s: %s, %d" % (tensor_name, tensor_shape, functools.reduce(operator.mul, tensor_shape, 1))) if len(tensor_shape) == 1 and tensor_shape[0] < 10: tensor_values[tensor_name] = list(tensor) + print("=========================ops============================") for op in ops: if op.type in ['Conv2D']: padding = op.get_attr('padding') @@ -63,6 +68,7 @@ def main(unused_args): if input_name.endswith('weights:0') and input_name in tensor_shapes: ksize = tensor_shapes[input_name] break + print('%s(padding=%s, strides=%s, ksize=%s, format=%s) %s => %s' % (op.type, padding, strides, ksize, data_format, op.inputs[0].shape.as_list(), op.outputs[0].shape.as_list())) key = '%s(padding=%s, strides=%s, ksize=%s, format=%s)' % (op.type, padding, strides, ksize, data_format) hist_inc(stats, key) elif op.type in ['FusedResizeAndPadConv2D']: @@ -117,6 +123,7 @@ def main(unused_args): key = '%s(block_shape=%s, paddings=%s)' % (op.type, block_shape, paddings) else: key = '%s(block_shape=%s, crops=%s)' % (op.type, block_shape, crops) + print(key) hist_inc(stats, key) elif op.type == 'Pad': paddings = 'Unknown' @@ -130,6 +137,7 @@ def main(unused_args): else: hist_inc(stats, op.type) + print("=========================stats============================") for key, value in sorted(six.iteritems(stats)): print('%s: %d' % (key, value)) -- GitLab