未验证 提交 6390b175 编写于 作者: X XGZhang 提交者: GitHub

support quantization of bert (#36593)

上级 11c2874e
...@@ -272,7 +272,8 @@ class QuantizationTransformPass(object): ...@@ -272,7 +272,8 @@ class QuantizationTransformPass(object):
the quantized ops's inputs. the quantized ops's inputs.
""" """
_supported_quantizable_op_type = [ _supported_quantizable_op_type = [
'conv2d', 'depthwise_conv2d', 'conv2d_transpose', 'mul', 'matmul' 'conv2d', 'depthwise_conv2d', 'conv2d_transpose', 'mul', 'matmul',
'matmul_v2'
] ]
def __init__(self, def __init__(self,
...@@ -520,6 +521,16 @@ class QuantizationTransformPass(object): ...@@ -520,6 +521,16 @@ class QuantizationTransformPass(object):
dequant_var_node = dequantized_vars[var_node.name()] dequant_var_node = dequantized_vars[var_node.name()]
graph.update_input_link(var_node, dequant_var_node, op) graph.update_input_link(var_node, dequant_var_node, op)
def _has_weight(op):
has_weight = False
for var_node in op.inputs:
if var_node.name() not in op.input_arg_names():
continue
name = var_node.name()
if var_node.name() in persistable_vars:
has_weight = True
return has_weight
if not self._is_test: if not self._is_test:
self._create_global_step(graph) self._create_global_step(graph)
ops = graph.all_op_nodes() ops = graph.all_op_nodes()
...@@ -535,11 +546,11 @@ class QuantizationTransformPass(object): ...@@ -535,11 +546,11 @@ class QuantizationTransformPass(object):
# The loop for transforming the forward graph: # The loop for transforming the forward graph:
for op in ops: for op in ops:
if op.name() in self._quantizable_ops: if op.name() in self._quantizable_ops:
if not self._is_skip_quant(graph, op): if not self._is_skip_quant(graph, op) and _has_weight(op):
_transform_forward(graph, op) _transform_forward(graph, op)
# The loop for renaming the inputs of backward op. # The loop for renaming the inputs of backward op.
for op in ops: for op in ops:
if op.name() in self._quantizable_grad_ops: if op.name() in self._quantizable_grad_ops and _has_weight(op):
_transform_backward(graph, op) _transform_backward(graph, op)
graph.resolve_hazard() graph.resolve_hazard()
return graph return graph
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册