提交 a25be53c 编写于 作者: B bingyanghuang 提交者: Tao Luo

QAT int8 MKL-DNN transformation pass with MUL (#18322)

上级 667f88f9
...@@ -27,9 +27,9 @@ class TransformForMkldnnPass(object): ...@@ -27,9 +27,9 @@ class TransformForMkldnnPass(object):
1. Convert int8 range weights with float32 data type, which are generated by 1. Convert int8 range weights with float32 data type, which are generated by
the QuantizationFreezePass, to float32 range weights with float32 data type the QuantizationFreezePass, to float32 range weights with float32 data type
by using the corresponding scales. This conversion is because MKL-DNN INT8 by using the corresponding scales. This conversion is because MKL-DNN INT8
conv2d kernel now only supports float32 weights input, will do weights conv2d kernel and mul kernel now only support float32 weights input, hence
quantization inside the conv2d kernel. weights quantization will happen inside the conv2d and mul INT8 kernel.
2. Create the new conv2d op with the converted weights and link its output 2. Create the new conv2d or mul op with the converted weights and link its output
to fake_dequantize_abs_max op's output and set conv2d's attribute "force_fp32 to fake_dequantize_abs_max op's output and set conv2d's attribute "force_fp32
_output" as true _output" as true
3. Transform fake_quantize_xx op to quantize op 3. Transform fake_quantize_xx op to quantize op
...@@ -73,13 +73,8 @@ class TransformForMkldnnPass(object): ...@@ -73,13 +73,8 @@ class TransformForMkldnnPass(object):
self.InScale = {} self.InScale = {}
self.max_range = {} self.max_range = {}
self.conv_new_output = {} self.new_output = {}
self.s8_max = 127 self.s8_max = 127
# Temporary code for keeping the mul op as fake quantization
#TODO Intel: Remove the following code when mul int8 mkldnn
# kernel enabled
self.mul_input_id = []
self.mul_output_id = []
def apply(self, graph): def apply(self, graph):
""" """
...@@ -97,7 +92,7 @@ class TransformForMkldnnPass(object): ...@@ -97,7 +92,7 @@ class TransformForMkldnnPass(object):
persistable_vars = [p.name() for p in graph.all_persistable_nodes()] persistable_vars = [p.name() for p in graph.all_persistable_nodes()]
# Collect the InScales and max_range to calculate the new scales for MKL-DNN # Collect the InScales and max_range to calculate the new scales for MKL-DNN
# INT8 conv2d # INT8 conv2d and mul
for op_node in ops: for op_node in ops:
if op_node.name() in self.dequantize_type: if op_node.name() in self.dequantize_type:
input_name = op_node.input("X")[0] input_name = op_node.input("X")[0]
...@@ -105,20 +100,14 @@ class TransformForMkldnnPass(object): ...@@ -105,20 +100,14 @@ class TransformForMkldnnPass(object):
self.InScale[input_name] = self._load_param(self._scope, self.InScale[input_name] = self._load_param(self._scope,
scale_name)[0] scale_name)[0]
self.max_range[input_name] = op_node.op().attr("max_range") self.max_range[input_name] = op_node.op().attr("max_range")
self.conv_new_output[input_name] = op_node.output("Out")[0] self.new_output[input_name] = op_node.output("Out")[0]
# Temporary graph transform on keeping the mul op
# TODO Intel: Remove following code
elif op_node.name() in ['mul']:
input_node = graph._find_node_by_name(op_node.inputs,
op_node.input('X')[0])
output_node = graph._find_node_by_name(op_node.outputs,
op_node.output('Out')[0])
self.mul_input_id.append(input_node.id())
self.mul_output_id.append(output_node.id())
for op_node in ops: for op_node in ops:
if op_node.name() in self._quantizable_ops:
if op_node.name() in self._conv_ops: if op_node.name() in self._conv_ops:
self._transform_to_conv_mkldnn(graph, op_node) self._transform_to_conv_mkldnn(graph, op_node)
else:
self._transform_to_mul_mkldnn(graph, op_node)
elif op_node.name() in self.quantize_type: elif op_node.name() in self.quantize_type:
self._transform_to_quantize_mkldnn(graph, op_node) self._transform_to_quantize_mkldnn(graph, op_node)
elif op_node.name() in self.dequantize_type: elif op_node.name() in self.dequantize_type:
...@@ -132,7 +121,7 @@ class TransformForMkldnnPass(object): ...@@ -132,7 +121,7 @@ class TransformForMkldnnPass(object):
# Convert int8 range weights to fp32 range weights # Convert int8 range weights to fp32 range weights
weight = self._load_param(self._scope, weight_name) weight = self._load_param(self._scope, weight_name)
w_fp32 = np.divide( w_fp32 = np.divide(
np.multiply(weight, 127), self.max_range[output_name]) np.multiply(weight, self.s8_max), self.max_range[output_name])
w_fp32 = w_fp32.reshape(weight.shape) w_fp32 = w_fp32.reshape(weight.shape)
self._restore_var(weight_name, w_fp32) self._restore_var(weight_name, w_fp32)
input_var_node = graph._find_node_by_name(op_node.inputs, input_var_node = graph._find_node_by_name(op_node.inputs,
...@@ -140,8 +129,8 @@ class TransformForMkldnnPass(object): ...@@ -140,8 +129,8 @@ class TransformForMkldnnPass(object):
weight_var_node = graph._find_node_by_name(op_node.inputs, weight_name) weight_var_node = graph._find_node_by_name(op_node.inputs, weight_name)
# Set fake_dequantize_abs_max's output as new output of conv2d # Set fake_dequantize_abs_max's output as new output of conv2d
output_var_node = graph._find_node_by_name( output_var_node = graph._find_node_by_name(graph.all_var_nodes(),
graph.all_var_nodes(), self.conv_new_output[output_name]) self.new_output[output_name])
attrs = { attrs = {
name: op_node.op().attr(name) name: op_node.op().attr(name)
for name in op_node.op().attr_names() for name in op_node.op().attr_names()
...@@ -157,7 +146,7 @@ class TransformForMkldnnPass(object): ...@@ -157,7 +146,7 @@ class TransformForMkldnnPass(object):
# Based on the QAT's scales to calculate the scales of MKL-DNN INT8 conv2d # Based on the QAT's scales to calculate the scales of MKL-DNN INT8 conv2d
scale_in = self.s8_max / self.InScale[output_name] scale_in = self.s8_max / self.InScale[output_name]
scale_w = [] scale_w = []
scale_w.append(self.max_range[output_name] / self.s8_max) scale_w = [self.max_range[output_name] / self.s8_max]
conv_op_node.set_attr("Scale_weights", scale_w) conv_op_node.set_attr("Scale_weights", scale_w)
conv_op_node.set_attr("Scale_in", scale_in) conv_op_node.set_attr("Scale_in", scale_in)
...@@ -169,6 +158,50 @@ class TransformForMkldnnPass(object): ...@@ -169,6 +158,50 @@ class TransformForMkldnnPass(object):
graph.link_to(conv_op_node, output_var_node) graph.link_to(conv_op_node, output_var_node)
graph.safe_remove_nodes(op_node) graph.safe_remove_nodes(op_node)
def _transform_to_mul_mkldnn(self, graph, op_node):
# For MKL-DNN INT8 mul, input Y should be the weights
weight_name = op_node.input("Y")[0]
output_name = op_node.output("Out")[0]
# Convert int8 range weights to fp32 range weights
weight = self._load_param(self._scope, weight_name)
w_fp32 = np.divide(
np.multiply(weight, self.s8_max), self.max_range[output_name])
w_fp32 = w_fp32.reshape(weight.shape)
self._restore_var(weight_name, w_fp32)
input_var_node = graph._find_node_by_name(op_node.inputs,
op_node.input("X")[0])
weight_var_node = graph._find_node_by_name(op_node.inputs, weight_name)
# Set fake_dequantize_abs_max's output as new output of mul
output_var_node = graph._find_node_by_name(graph.all_var_nodes(),
self.new_output[output_name])
attrs = {
name: op_node.op().attr(name)
for name in op_node.op().attr_names()
}
mul_op_node = graph.create_op_node(
op_type='mul',
attrs=attrs,
inputs={'X': input_var_node,
'Y': weight_var_node},
outputs={'Out': output_var_node})
# Based on the QAT's scales to calculate MKL-DNN INT8 mul's scales
scale_in = self.s8_max / self.InScale[output_name]
scale_w = []
scale_w = [self.max_range[output_name] / self.s8_max]
mul_op_node.set_attr("scale_y", scale_w)
mul_op_node.set_attr("scale_x", scale_in)
mul_op_node.set_attr("scale_out", 1.0)
mul_op_node.set_attr("use_mkldnn", 1)
mul_op_node.set_attr("force_fp32_output", 1)
graph.link_to(input_var_node, mul_op_node)
graph.link_to(weight_var_node, mul_op_node)
graph.link_to(mul_op_node, output_var_node)
graph.safe_remove_nodes(op_node)
def _transform_to_quantize_mkldnn(self, graph, op_node): def _transform_to_quantize_mkldnn(self, graph, op_node):
""" """
Transform fake_quantize_xx op to quantize mkldnn op in the graph. Transform fake_quantize_xx op to quantize mkldnn op in the graph.
...@@ -177,9 +210,6 @@ class TransformForMkldnnPass(object): ...@@ -177,9 +210,6 @@ class TransformForMkldnnPass(object):
op_node.input("X")[0]) op_node.input("X")[0])
output_var_node = graph._find_node_by_name(op_node.outputs, output_var_node = graph._find_node_by_name(op_node.outputs,
op_node.output("Out")[0]) op_node.output("Out")[0])
if output_var_node.id() in self.mul_input_id:
return
else:
scale_in = self.s8_max / self._load_param( scale_in = self.s8_max / self._load_param(
self._scope, op_node.input("InScale")[0])[0] self._scope, op_node.input("InScale")[0])[0]
quant_op_node = graph.create_op_node( quant_op_node = graph.create_op_node(
...@@ -199,9 +229,6 @@ class TransformForMkldnnPass(object): ...@@ -199,9 +229,6 @@ class TransformForMkldnnPass(object):
def _remove_fake_dequantize_op(self, graph, op_node): def _remove_fake_dequantize_op(self, graph, op_node):
input_var_node = graph._find_node_by_name(op_node.inputs, input_var_node = graph._find_node_by_name(op_node.inputs,
op_node.input("X")[0]) op_node.input("X")[0])
if input_var_node.id() in self.mul_output_id:
return
else:
graph.safe_remove_nodes(op_node) graph.safe_remove_nodes(op_node)
def _load_param(self, scope, param_name): def _load_param(self, scope, param_name):
......
...@@ -55,9 +55,7 @@ class TestMKLDNNTransformBasedFreezePass(unittest.TestCase): ...@@ -55,9 +55,7 @@ class TestMKLDNNTransformBasedFreezePass(unittest.TestCase):
self.quantizable_op_and_inputs = { self.quantizable_op_and_inputs = {
'conv2d': ['Input', 'Filter'], 'conv2d': ['Input', 'Filter'],
'depthwise_conv2d': ['Input', 'Filter'], 'depthwise_conv2d': ['Input', 'Filter'],
# Mul int8 op is under internal test 'mul': ['X', 'Y']
# TODO Update this when mul op is merged
#'mul': ['X', 'Y']
} }
def check_program(self, program): def check_program(self, program):
...@@ -162,11 +160,15 @@ class TestMKLDNNTransformBasedFreezePass(unittest.TestCase): ...@@ -162,11 +160,15 @@ class TestMKLDNNTransformBasedFreezePass(unittest.TestCase):
activation_quant_type + '_' + weight_quant_type, activation_quant_type + '_' + weight_quant_type,
marked_nodes) marked_nodes)
mkldnn_program = test_graph.to_program() mkldnn_program = test_graph.to_program()
w_mkldnn = np.array(scope.find_var('conv2d_1.w_0').get_tensor())
# Check the transformation weights of conv2d and mul
conv_w_mkldnn = np.array(scope.find_var('conv2d_1.w_0').get_tensor())
mul_w_mkldnn = np.array(scope.find_var('fc_0.w_0').get_tensor())
# Check if weights are still integer # Check if weights are still integer
self.assertFalse(self.isinteger(np.sum(w_mkldnn))) self.assertFalse(self.isinteger(np.sum(conv_w_mkldnn)))
self.assertFalse(self.isinteger(np.sum(mul_w_mkldnn)))
# Check if the conv2d output is rightly linked to fake_dequantize's # Check if the conv2d output and mul output are correctly linked to fake_dequantize's
# output # output
self.check_program(mkldnn_program) self.check_program(mkldnn_program)
if not for_ci: if not for_ci:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册