未验证 提交 d856f876 编写于 作者: W whs 提交者: GitHub

Set attribute "with_quant_attr" into quantized operators (#35583)

上级 2c922d63
...@@ -360,6 +360,7 @@ class ImperativePTQ(object): ...@@ -360,6 +360,7 @@ class ImperativePTQ(object):
in_var_name) in_var_name)
op._set_attr(argname + str(index) + "_threshold", op._set_attr(argname + str(index) + "_threshold",
in_threshold) in_threshold)
op._set_attr("with_quant_attr", True)
else: else:
for out_var_name in utils._get_op_output_var_names( for out_var_name in utils._get_op_output_var_names(
previous_op): previous_op):
...@@ -376,6 +377,7 @@ class ImperativePTQ(object): ...@@ -376,6 +377,7 @@ class ImperativePTQ(object):
op, in_var_name) op, in_var_name)
attr_name = argname + str(index) + "_threshold" attr_name = argname + str(index) + "_threshold"
op._set_attr(attr_name, threshold) op._set_attr(attr_name, threshold)
op._set_attr("with_quant_attr", True)
def _clean_up(self, program): def _clean_up(self, program):
""" """
...@@ -394,6 +396,7 @@ class ImperativePTQ(object): ...@@ -394,6 +396,7 @@ class ImperativePTQ(object):
op._remove_attr(old_attr_name) op._remove_attr(old_attr_name)
next_op._remove_attr(old_attr_name) next_op._remove_attr(old_attr_name)
next_op._set_attr(new_attr_name, threshold) next_op._set_attr(new_attr_name, threshold)
next_op._set_attr("with_quant_attr", True)
for op in utils.program_all_ops(program): for op in utils.program_all_ops(program):
if "quantize_dequantize" in op.type: if "quantize_dequantize" in op.type:
......
...@@ -548,6 +548,7 @@ class ImperativeQuantizeOutputs(object): ...@@ -548,6 +548,7 @@ class ImperativeQuantizeOutputs(object):
op, in_var_name) op, in_var_name)
op._set_attr(argname + str(index) + "_threshold", op._set_attr(argname + str(index) + "_threshold",
in_scale) in_scale)
op._set_attr("with_quant_attr", True)
def _gather_output_scale(): def _gather_output_scale():
target_ops = [] target_ops = []
...@@ -574,6 +575,7 @@ class ImperativeQuantizeOutputs(object): ...@@ -574,6 +575,7 @@ class ImperativeQuantizeOutputs(object):
previous_op._set_attr( previous_op._set_attr(
argname + str(index) + "_threshold", out_scale) argname + str(index) + "_threshold", out_scale)
previous_op._set_attr("out_threshold", out_scale) previous_op._set_attr("out_threshold", out_scale)
previous_op._set_attr("with_quant_attr", True)
for next_op in next_ops: for next_op in next_ops:
next_op._rename_input(out_var_name, in_var_name) next_op._rename_input(out_var_name, in_var_name)
...@@ -589,6 +591,7 @@ class ImperativeQuantizeOutputs(object): ...@@ -589,6 +591,7 @@ class ImperativeQuantizeOutputs(object):
for op in block.ops: for op in block.ops:
if self._is_skip_quant_op(block, op): if self._is_skip_quant_op(block, op):
op._set_attr("skip_quant", True) op._set_attr("skip_quant", True)
op._set_attr("with_quant_attr", True)
def _is_skip_quant_op(self, block, in_op): def _is_skip_quant_op(self, block, in_op):
""" """
......
...@@ -705,6 +705,7 @@ class PostTrainingQuantization(object): ...@@ -705,6 +705,7 @@ class PostTrainingQuantization(object):
self._quantized_var_min[var_name]) self._quantized_var_min[var_name])
op._set_attr(var_name + ".max", op._set_attr(var_name + ".max",
self._quantized_var_max[var_name]) self._quantized_var_max[var_name])
op._set_attr("with_quant_attr", True)
def _collect_activation_abs_min_max(self): def _collect_activation_abs_min_max(self):
''' '''
...@@ -849,6 +850,7 @@ class PostTrainingQuantization(object): ...@@ -849,6 +850,7 @@ class PostTrainingQuantization(object):
"The output ({}) of {} node does not have threshold.".format( "The output ({}) of {} node does not have threshold.".format(
out_var_name, op_node.type) out_var_name, op_node.type)
op_node._set_attr(out_info_name, threshold_map[var_name]) op_node._set_attr(out_info_name, threshold_map[var_name])
op_node._set_attr("with_quant_attr", True)
if op_node.type in self._quantizable_op_type: if op_node.type in self._quantizable_op_type:
op._set_attr("quantization_type", quantized_type) op._set_attr("quantization_type", quantized_type)
...@@ -921,6 +923,7 @@ class PostTrainingQuantization(object): ...@@ -921,6 +923,7 @@ class PostTrainingQuantization(object):
op._set_attr(argname + str(index) + "_threshold", threshold) op._set_attr(argname + str(index) + "_threshold", threshold)
op._set_attr("quantization_type", quantization_type) op._set_attr("quantization_type", quantization_type)
op._set_attr("bit_length", self._weight_bits) op._set_attr("bit_length", self._weight_bits)
op._set_attr("with_quant_attr", True)
def _get_hist_scaling_factor(self, hist, hist_edges): def _get_hist_scaling_factor(self, hist, hist_edges):
''' '''
...@@ -1184,6 +1187,7 @@ class WeightQuantization(object): ...@@ -1184,6 +1187,7 @@ class WeightQuantization(object):
op._set_attr('quantization_type', 'post_weight_abs_max') op._set_attr('quantization_type', 'post_weight_abs_max')
op._set_attr('quantize_weight_bits', weight_bits) op._set_attr('quantize_weight_bits', weight_bits)
op._set_attr(var_name + "_quant_scale", [scale]) # Save as list op._set_attr(var_name + "_quant_scale", [scale]) # Save as list
op._set_attr("with_quant_attr", True)
def _weight_channel_wise_abs_max_quantization( def _weight_channel_wise_abs_max_quantization(
self, scope, place, weight_bits, op, var_name, for_test): self, scope, place, weight_bits, op, var_name, for_test):
...@@ -1225,6 +1229,7 @@ class WeightQuantization(object): ...@@ -1225,6 +1229,7 @@ class WeightQuantization(object):
op._set_attr('quantization_type', 'post_weight_channel_wise_abs_max') op._set_attr('quantization_type', 'post_weight_channel_wise_abs_max')
op._set_attr('quantize_weight_bits', weight_bits) op._set_attr('quantize_weight_bits', weight_bits)
op._set_attr(var_name + "_quant_scale", scales) op._set_attr(var_name + "_quant_scale", scales)
op._set_attr("with_quant_attr", True)
def _conv_channel_wise_quantization(self, weight_data, quantize_range, def _conv_channel_wise_quantization(self, weight_data, quantize_range,
save_weight_dtype): save_weight_dtype):
......
...@@ -442,9 +442,11 @@ class QuantizationTransformPass(object): ...@@ -442,9 +442,11 @@ class QuantizationTransformPass(object):
if user_skipped: if user_skipped:
op_node.op()._set_attr("skip_quant", True) op_node.op()._set_attr("skip_quant", True)
op_node.op()._set_attr("with_quant_attr", True)
def _transform_forward(graph, op): def _transform_forward(graph, op):
op.op()._set_attr("quantization_type", "qat_with_weight") op.op()._set_attr("quantization_type", "qat_with_weight")
op.op()._set_attr("with_quant_attr", True)
inputs = op.inputs inputs = op.inputs
for var_node in inputs: for var_node in inputs:
if var_node.name() not in op.input_arg_names(): if var_node.name() not in op.input_arg_names():
...@@ -1760,6 +1762,7 @@ class OutScaleForInferencePass(object): ...@@ -1760,6 +1762,7 @@ class OutScaleForInferencePass(object):
var_name + " is not the output of the op" var_name + " is not the output of the op"
op_node.op()._set_attr(argname_index[0] + str(argname_index[1]) \ op_node.op()._set_attr(argname_index[0] + str(argname_index[1]) \
+ "_threshold", float(scale_value)) + "_threshold", float(scale_value))
op_node.op()._set_attr("with_quant_attr", True)
graph.resolve_hazard() graph.resolve_hazard()
return graph return graph
...@@ -1875,6 +1878,7 @@ class AddQuantDequantPass(object): ...@@ -1875,6 +1878,7 @@ class AddQuantDequantPass(object):
op_node.op()._set_attr("quantization_type", op_node.op()._set_attr("quantization_type",
"qat_without_weight") "qat_without_weight")
op_node.op()._set_attr("activation_bits", self._quant_bits) op_node.op()._set_attr("activation_bits", self._quant_bits)
op_node.op()._set_attr("with_quant_attr", True)
arg_names = _get_op_input_var_names(op_node) arg_names = _get_op_input_var_names(op_node)
for arg_name in arg_names: for arg_name in arg_names:
in_node = graph._find_node_by_name(op_node.inputs, arg_name) in_node = graph._find_node_by_name(op_node.inputs, arg_name)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册