diff --git a/python/paddle/static/quantization/quantization_pass.py b/python/paddle/static/quantization/quantization_pass.py index b70386656912199a5f6e1b0b2cfc3d72050fb774..cc76669c661c43b0463559a596e6f306d81524e3 100644 --- a/python/paddle/static/quantization/quantization_pass.py +++ b/python/paddle/static/quantization/quantization_pass.py @@ -3150,7 +3150,7 @@ class QuantWeightPass: self._save_int_weight = save_int_weight assert self._scope is not None, "scope must not be None." assert self._place is not None, "place must not be None." - self._quantized_ops = {} + self._quantized_ops = set() def apply(self, graph): assert isinstance( @@ -3189,6 +3189,7 @@ class QuantWeightPass: quant_axis = _op.op().attr("quant_axis") bits_length = _op.op().attr("bit_length") if x_node.name() not in self._quantized_ops: + self._quantized_ops.add(x_node.name()) quantized_param_v = utils.quant_tensor( param_v.copy(), scale_v, @@ -3211,30 +3212,10 @@ class QuantWeightPass: quantized_param_v = quantized_param_v.astype( save_weight_dtype ) - quant_weight_node = graph.create_persistable_node( - name=self._quantized_var_name(x_node.name()), - var_type=core.VarDesc.VarType.LOD_TENSOR, - shape=x_node.shape(), - var_dtype=core.VarDesc.VarType.INT8, - ) - _init_var_node( - quant_weight_node, - quantized_param_v, - self._scope, - self._place, - ) - self._quantized_ops[x_node.name()] = quant_weight_node + self._restore_var(x_node.name(), quantized_param_v) for next_op_node in out_node.outputs: - if ( - self._quantized_ops[x_node.name()].node - in graph.graph.nodes() - ): - graph.update_input_link( - out_node, - self._quantized_ops[x_node.name()], - next_op_node, - ) + graph.update_input_link(out_node, x_node, next_op_node) graph.safe_remove_nodes(_op) self._remove_unused_var_nodes(graph) @@ -3260,11 +3241,9 @@ class QuantWeightPass: def _load_var(self, name): return np.array(self._scope.find_var(name).get_tensor()) - def _quantized_var_name(self, var_name): - """ - Return quantized variable name for the input `var_name`. - """ - return "%s.quantized" % (var_name) + def _restore_var(self, name, array): + tensor = self._scope.find_var(name).get_tensor() + tensor.set(array, self._place) class AddQuantDequantForInferencePass: