提交 a25be53c 编写于 作者: B bingyanghuang 提交者: Tao Luo

QAT int8 MKL-DNN transformation pass with MUL (#18322)

上级 667f88f9
......@@ -27,9 +27,9 @@ class TransformForMkldnnPass(object):
1. Convert int8 range weights with float32 data type, which are generated by
the QuantizationFreezePass, to float32 range weights with float32 data type
by using the corresponding scales. This conversion is because MKL-DNN INT8
conv2d kernel now only supports float32 weights input, will do weights
quantization inside the conv2d kernel.
2. Create the new conv2d op with the converted weights and link its output
conv2d kernel and mul kernel now only support float32 weights input, hence
weights quantization will happen inside the conv2d and mul INT8 kernel.
2. Create the new conv2d or mul op with the converted weights and link its output
to fake_dequantize_abs_max op's output and set conv2d's attribute "force_fp32
_output" as true
3. Transform fake_quantize_xx op to quantize op
......@@ -73,13 +73,8 @@ class TransformForMkldnnPass(object):
self.InScale = {}
self.max_range = {}
self.conv_new_output = {}
self.new_output = {}
self.s8_max = 127
# Temporary code for keeping the mul op as fake quantization
#TODO Intel: Remove the following code when mul int8 mkldnn
# kernel enabled
self.mul_input_id = []
self.mul_output_id = []
def apply(self, graph):
"""
......@@ -97,7 +92,7 @@ class TransformForMkldnnPass(object):
persistable_vars = [p.name() for p in graph.all_persistable_nodes()]
# Collect the InScales and max_range to calculate the new scales for MKL-DNN
# INT8 conv2d
# INT8 conv2d and mul
for op_node in ops:
if op_node.name() in self.dequantize_type:
input_name = op_node.input("X")[0]
......@@ -105,20 +100,14 @@ class TransformForMkldnnPass(object):
self.InScale[input_name] = self._load_param(self._scope,
scale_name)[0]
self.max_range[input_name] = op_node.op().attr("max_range")
self.conv_new_output[input_name] = op_node.output("Out")[0]
# Temporary graph transform on keeping the mul op
# TODO Intel: Remove following code
elif op_node.name() in ['mul']:
input_node = graph._find_node_by_name(op_node.inputs,
op_node.input('X')[0])
output_node = graph._find_node_by_name(op_node.outputs,
op_node.output('Out')[0])
self.mul_input_id.append(input_node.id())
self.mul_output_id.append(output_node.id())
self.new_output[input_name] = op_node.output("Out")[0]
for op_node in ops:
if op_node.name() in self._conv_ops:
self._transform_to_conv_mkldnn(graph, op_node)
if op_node.name() in self._quantizable_ops:
if op_node.name() in self._conv_ops:
self._transform_to_conv_mkldnn(graph, op_node)
else:
self._transform_to_mul_mkldnn(graph, op_node)
elif op_node.name() in self.quantize_type:
self._transform_to_quantize_mkldnn(graph, op_node)
elif op_node.name() in self.dequantize_type:
......@@ -132,7 +121,7 @@ class TransformForMkldnnPass(object):
# Convert int8 range weights to fp32 range weights
weight = self._load_param(self._scope, weight_name)
w_fp32 = np.divide(
np.multiply(weight, 127), self.max_range[output_name])
np.multiply(weight, self.s8_max), self.max_range[output_name])
w_fp32 = w_fp32.reshape(weight.shape)
self._restore_var(weight_name, w_fp32)
input_var_node = graph._find_node_by_name(op_node.inputs,
......@@ -140,8 +129,8 @@ class TransformForMkldnnPass(object):
weight_var_node = graph._find_node_by_name(op_node.inputs, weight_name)
# Set fake_dequantize_abs_max's output as new output of conv2d
output_var_node = graph._find_node_by_name(
graph.all_var_nodes(), self.conv_new_output[output_name])
output_var_node = graph._find_node_by_name(graph.all_var_nodes(),
self.new_output[output_name])
attrs = {
name: op_node.op().attr(name)
for name in op_node.op().attr_names()
......@@ -157,7 +146,7 @@ class TransformForMkldnnPass(object):
# Based on the QAT's scales to calculate the scales of MKL-DNN INT8 conv2d
scale_in = self.s8_max / self.InScale[output_name]
scale_w = []
scale_w.append(self.max_range[output_name] / self.s8_max)
scale_w = [self.max_range[output_name] / self.s8_max]
conv_op_node.set_attr("Scale_weights", scale_w)
conv_op_node.set_attr("Scale_in", scale_in)
......@@ -169,6 +158,50 @@ class TransformForMkldnnPass(object):
graph.link_to(conv_op_node, output_var_node)
graph.safe_remove_nodes(op_node)
def _transform_to_mul_mkldnn(self, graph, op_node):
# For MKL-DNN INT8 mul, input Y should be the weights
weight_name = op_node.input("Y")[0]
output_name = op_node.output("Out")[0]
# Convert int8 range weights to fp32 range weights
weight = self._load_param(self._scope, weight_name)
w_fp32 = np.divide(
np.multiply(weight, self.s8_max), self.max_range[output_name])
w_fp32 = w_fp32.reshape(weight.shape)
self._restore_var(weight_name, w_fp32)
input_var_node = graph._find_node_by_name(op_node.inputs,
op_node.input("X")[0])
weight_var_node = graph._find_node_by_name(op_node.inputs, weight_name)
# Set fake_dequantize_abs_max's output as new output of mul
output_var_node = graph._find_node_by_name(graph.all_var_nodes(),
self.new_output[output_name])
attrs = {
name: op_node.op().attr(name)
for name in op_node.op().attr_names()
}
mul_op_node = graph.create_op_node(
op_type='mul',
attrs=attrs,
inputs={'X': input_var_node,
'Y': weight_var_node},
outputs={'Out': output_var_node})
# Based on the QAT's scales to calculate MKL-DNN INT8 mul's scales
scale_in = self.s8_max / self.InScale[output_name]
scale_w = []
scale_w = [self.max_range[output_name] / self.s8_max]
mul_op_node.set_attr("scale_y", scale_w)
mul_op_node.set_attr("scale_x", scale_in)
mul_op_node.set_attr("scale_out", 1.0)
mul_op_node.set_attr("use_mkldnn", 1)
mul_op_node.set_attr("force_fp32_output", 1)
graph.link_to(input_var_node, mul_op_node)
graph.link_to(weight_var_node, mul_op_node)
graph.link_to(mul_op_node, output_var_node)
graph.safe_remove_nodes(op_node)
def _transform_to_quantize_mkldnn(self, graph, op_node):
"""
Transform fake_quantize_xx op to quantize mkldnn op in the graph.
......@@ -177,32 +210,26 @@ class TransformForMkldnnPass(object):
op_node.input("X")[0])
output_var_node = graph._find_node_by_name(op_node.outputs,
op_node.output("Out")[0])
if output_var_node.id() in self.mul_input_id:
return
else:
scale_in = self.s8_max / self._load_param(
self._scope, op_node.input("InScale")[0])[0]
quant_op_node = graph.create_op_node(
op_type='quantize',
attrs={
'data_format': 'MKLDNNLAYOUT',
'use_mkldnn': 1,
'Scale': scale_in,
'is_negative_input': 1
},
inputs={'Input': input_var_node},
outputs={'Output': output_var_node})
graph.link_to(input_var_node, quant_op_node)
graph.link_to(quant_op_node, output_var_node)
graph.safe_remove_nodes(op_node)
scale_in = self.s8_max / self._load_param(
self._scope, op_node.input("InScale")[0])[0]
quant_op_node = graph.create_op_node(
op_type='quantize',
attrs={
'data_format': 'MKLDNNLAYOUT',
'use_mkldnn': 1,
'Scale': scale_in,
'is_negative_input': 1
},
inputs={'Input': input_var_node},
outputs={'Output': output_var_node})
graph.link_to(input_var_node, quant_op_node)
graph.link_to(quant_op_node, output_var_node)
graph.safe_remove_nodes(op_node)
def _remove_fake_dequantize_op(self, graph, op_node):
input_var_node = graph._find_node_by_name(op_node.inputs,
op_node.input("X")[0])
if input_var_node.id() in self.mul_output_id:
return
else:
graph.safe_remove_nodes(op_node)
graph.safe_remove_nodes(op_node)
def _load_param(self, scope, param_name):
return np.array(scope.find_var(param_name).get_tensor())
......
......@@ -55,9 +55,7 @@ class TestMKLDNNTransformBasedFreezePass(unittest.TestCase):
self.quantizable_op_and_inputs = {
'conv2d': ['Input', 'Filter'],
'depthwise_conv2d': ['Input', 'Filter'],
# Mul int8 op is under internal test
# TODO Update this when mul op is merged
#'mul': ['X', 'Y']
'mul': ['X', 'Y']
}
def check_program(self, program):
......@@ -162,11 +160,15 @@ class TestMKLDNNTransformBasedFreezePass(unittest.TestCase):
activation_quant_type + '_' + weight_quant_type,
marked_nodes)
mkldnn_program = test_graph.to_program()
w_mkldnn = np.array(scope.find_var('conv2d_1.w_0').get_tensor())
# Check the transformation weights of conv2d and mul
conv_w_mkldnn = np.array(scope.find_var('conv2d_1.w_0').get_tensor())
mul_w_mkldnn = np.array(scope.find_var('fc_0.w_0').get_tensor())
# Check if weights are still integer
self.assertFalse(self.isinteger(np.sum(w_mkldnn)))
self.assertFalse(self.isinteger(np.sum(conv_w_mkldnn)))
self.assertFalse(self.isinteger(np.sum(mul_w_mkldnn)))
# Check if the conv2d output is rightly linked to fake_dequantize's
# Check if the conv2d output and mul output are correctly linked to fake_dequantize's
# output
self.check_program(mkldnn_program)
if not for_ci:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册