未验证 提交 d856f876 编写于 作者: W whs 提交者: GitHub

Set attribute "with_quant_attr" into quantized operators (#35583)

上级 2c922d63
......@@ -360,6 +360,7 @@ class ImperativePTQ(object):
in_var_name)
op._set_attr(argname + str(index) + "_threshold",
in_threshold)
op._set_attr("with_quant_attr", True)
else:
for out_var_name in utils._get_op_output_var_names(
previous_op):
......@@ -376,6 +377,7 @@ class ImperativePTQ(object):
op, in_var_name)
attr_name = argname + str(index) + "_threshold"
op._set_attr(attr_name, threshold)
op._set_attr("with_quant_attr", True)
def _clean_up(self, program):
"""
......@@ -394,6 +396,7 @@ class ImperativePTQ(object):
op._remove_attr(old_attr_name)
next_op._remove_attr(old_attr_name)
next_op._set_attr(new_attr_name, threshold)
next_op._set_attr("with_quant_attr", True)
for op in utils.program_all_ops(program):
if "quantize_dequantize" in op.type:
......
......@@ -548,6 +548,7 @@ class ImperativeQuantizeOutputs(object):
op, in_var_name)
op._set_attr(argname + str(index) + "_threshold",
in_scale)
op._set_attr("with_quant_attr", True)
def _gather_output_scale():
target_ops = []
......@@ -574,6 +575,7 @@ class ImperativeQuantizeOutputs(object):
previous_op._set_attr(
argname + str(index) + "_threshold", out_scale)
previous_op._set_attr("out_threshold", out_scale)
previous_op._set_attr("with_quant_attr", True)
for next_op in next_ops:
next_op._rename_input(out_var_name, in_var_name)
......@@ -589,6 +591,7 @@ class ImperativeQuantizeOutputs(object):
for op in block.ops:
if self._is_skip_quant_op(block, op):
op._set_attr("skip_quant", True)
op._set_attr("with_quant_attr", True)
def _is_skip_quant_op(self, block, in_op):
"""
......
......@@ -705,6 +705,7 @@ class PostTrainingQuantization(object):
self._quantized_var_min[var_name])
op._set_attr(var_name + ".max",
self._quantized_var_max[var_name])
op._set_attr("with_quant_attr", True)
def _collect_activation_abs_min_max(self):
'''
......@@ -849,6 +850,7 @@ class PostTrainingQuantization(object):
"The output ({}) of {} node does not have threshold.".format(
out_var_name, op_node.type)
op_node._set_attr(out_info_name, threshold_map[var_name])
op_node._set_attr("with_quant_attr", True)
if op_node.type in self._quantizable_op_type:
op._set_attr("quantization_type", quantized_type)
......@@ -921,6 +923,7 @@ class PostTrainingQuantization(object):
op._set_attr(argname + str(index) + "_threshold", threshold)
op._set_attr("quantization_type", quantization_type)
op._set_attr("bit_length", self._weight_bits)
op._set_attr("with_quant_attr", True)
def _get_hist_scaling_factor(self, hist, hist_edges):
'''
......@@ -1184,6 +1187,7 @@ class WeightQuantization(object):
op._set_attr('quantization_type', 'post_weight_abs_max')
op._set_attr('quantize_weight_bits', weight_bits)
op._set_attr(var_name + "_quant_scale", [scale]) # Save as list
op._set_attr("with_quant_attr", True)
def _weight_channel_wise_abs_max_quantization(
self, scope, place, weight_bits, op, var_name, for_test):
......@@ -1225,6 +1229,7 @@ class WeightQuantization(object):
op._set_attr('quantization_type', 'post_weight_channel_wise_abs_max')
op._set_attr('quantize_weight_bits', weight_bits)
op._set_attr(var_name + "_quant_scale", scales)
op._set_attr("with_quant_attr", True)
def _conv_channel_wise_quantization(self, weight_data, quantize_range,
save_weight_dtype):
......
......@@ -442,9 +442,11 @@ class QuantizationTransformPass(object):
if user_skipped:
op_node.op()._set_attr("skip_quant", True)
op_node.op()._set_attr("with_quant_attr", True)
def _transform_forward(graph, op):
op.op()._set_attr("quantization_type", "qat_with_weight")
op.op()._set_attr("with_quant_attr", True)
inputs = op.inputs
for var_node in inputs:
if var_node.name() not in op.input_arg_names():
......@@ -1760,6 +1762,7 @@ class OutScaleForInferencePass(object):
var_name + " is not the output of the op"
op_node.op()._set_attr(argname_index[0] + str(argname_index[1]) \
+ "_threshold", float(scale_value))
op_node.op()._set_attr("with_quant_attr", True)
graph.resolve_hazard()
return graph
......@@ -1875,6 +1878,7 @@ class AddQuantDequantPass(object):
op_node.op()._set_attr("quantization_type",
"qat_without_weight")
op_node.op()._set_attr("activation_bits", self._quant_bits)
op_node.op()._set_attr("with_quant_attr", True)
arg_names = _get_op_input_var_names(op_node)
for arg_name in arg_names:
in_node = graph._find_node_by_name(op_node.inputs, arg_name)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册