提交 626317cb 编写于 作者: R Raghuraman Krishnamoorthi 提交者: TensorFlower Gardener

Generalize quantization rewriter to handle separable convolutions. Insert...

 Generalize quantization rewriter to handle separable convolutions. Insert fake quant ops for weights in both depthwise and regular convolutions inside a separable convolution op. Also insert fake quant ops for activations produced by the first depthwise convolution.

PiperOrigin-RevId: 207009650
上级 2ff02637
......@@ -261,6 +261,16 @@ def _FindLayersToQuantize(graph):
layer_output_pattern = graph_matcher.OneofPattern(
[batch_to_space_pattern, layer_pattern])
# For separable convolutions, we are looking for a conv, followed by a conv
# with no activations between the two.
sep_conv_pattern = graph_matcher.OpTypePattern(
'|'.join(_QUANTIZABLE_TYPES),
inputs=[
graph_matcher.OneofPattern([layer_output_pattern]),
graph_matcher.OpTypePattern('*')
],
ordered_inputs=False)
folded_bias_mul_pattern = graph_matcher.OpTypePattern(
'Mul',
inputs=[graph_matcher.OpTypePattern('*'), layer_output_pattern],
......@@ -393,6 +403,17 @@ def _FindLayersToQuantize(graph):
layer_matches.append(
_LayerMatch(layer_op, weight_tensor, activation_op, None, None, None))
# Look for separable convolutions here
sep_conv_matcher = graph_matcher.GraphMatcher(sep_conv_pattern)
for match_result in sep_conv_matcher.match_graph(graph):
layer_op = match_result.get_op(layer_pattern)
weight_tensor = match_result.get_tensor(weight_identity_pattern)
activation_op = match_result.get_op(layer_pattern)
if layer_op not in matched_layer_set:
matched_layer_set.add(layer_op)
layer_matches.append(
_LayerMatch(layer_op, weight_tensor, activation_op, None, None, None))
return layer_matches
......
......@@ -122,12 +122,67 @@ class QuantizeTest(test_util.TensorFlowTestCase):
array_ops.identity(node, name='control_dependency')
quantize.Quantize(graph, is_training, weight_bits=8, activation_bits=8)
# Check if output of bias add is quantized
quantization_node_name = 'FakeQuantWithMinMaxVars'
conv_quant = graph.get_operation_by_name('test/test/conv_quant/' +
quantization_node_name)
self.assertEqual(conv_quant.type, quantization_node_name)
for op in graph.get_operations():
if op.type == quantization_node_name:
quant_op = graph.get_operation_by_name(op.name)
# Scan through all FakeQuant operations, ensuring that the activation
# identity op isn't in the consumers of the operation.
consumers = []
for output in quant_op.outputs:
consumers.extend(output.consumers())
self.assertNotIn('test/relu6', [c.name for c in consumers])
def testInsertQuantOpInSeparableConv2d(self):
    """Run the separable conv2d quant-op insertion check.

    Hands the bound `_TestInsertQuantOpInSeparableConv2d` method to
    `self._RunTestOverParameters`, which is responsible for invoking it.
    """
    separable_conv_check = self._TestInsertQuantOpInSeparableConv2d
    self._RunTestOverParameters(separable_conv_check)
def _TestInsertQuantOpInSeparableConv2d(self, is_training):
graph = ops.Graph()
with graph.as_default():
batch_size, height, width, depth = 5, 128, 128, 3
input1 = array_ops.zeros((batch_size, height, width, depth))
input2 = array_ops.zeros((batch_size, height / 2, width / 2, depth))
conv = separable_conv2d(
input1,
3, [5, 5],
stride=2,
depth_multiplier=1.0,
padding='SAME',
weights_initializer=self._WeightInit(0.09),
activation_fn=None,
scope='test/test')
node = math_ops.add(conv, input2, name='test/add')
node = nn_ops.relu6(node, name='test/relu6')
update_barrier = control_flow_ops.no_op(name='update_barrier')
with ops.control_dependencies([update_barrier]):
array_ops.identity(node, name='control_dependency')
quantize.Quantize(graph, is_training, weight_bits=8, activation_bits=8)
# Check if output of bias add is quantized
quantization_node_name = 'FakeQuantWithMinMaxVars'
conv_quant = graph.get_operation_by_name('test/test/conv_quant/' +
quantization_node_name)
self.assertEqual(conv_quant.type, quantization_node_name)
# Check if weights for both convs inside separable conv are quantized
pointwise_weight_quant = graph.get_operation_by_name(
'test/test/weights_quant/' + quantization_node_name)
self.assertEqual(pointwise_weight_quant.type, quantization_node_name)
depthwise_weight_quant = graph.get_operation_by_name(
'test/test/separable_conv2d/weights_quant/' + quantization_node_name)
self.assertEqual(depthwise_weight_quant.type, quantization_node_name)
# Check if activations after first depthwise conv are quantized.
depthwise_act_quant = graph.get_operation_by_name(
'test/test/separable_conv2d/act_quant/' + quantization_node_name)
self.assertEqual(depthwise_act_quant.type, quantization_node_name)
for op in graph.get_operations():
if op.type == quantization_node_name:
quant_op = graph.get_operation_by_name(op.name)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册