diff --git a/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py b/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py index d3ce543320ef4130a54a8f4351e96b595be4d55c..e2502e7f5d44795f729ceabf553a3038f0ae131c 100644 --- a/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py +++ b/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py @@ -1440,11 +1440,18 @@ class OutScaleForTrainingPass(object): [core.VarDesc.VarType.FP64, core.VarDesc.VarType.FP32]: continue - scale_node = graph.create_persistable_node( - name=self._scale_name(in_node.name()), - var_type=core.VarDesc.VarType.LOD_TENSOR, - shape=[1], - var_dtype=in_node.dtype()) + try: + graph._find_node_by_name( + graph.all_var_nodes(), + self._scale_name(in_node.name())) + continue + except: + scale_node = graph.create_persistable_node( + name=self._scale_name(in_node.name()), + var_type=core.VarDesc.VarType.LOD_TENSOR, + shape=[1], + var_dtype=in_node.dtype()) + data_type = 'float64' if in_node.dtype() \ == core.VarDesc.VarType.FP64 else 'float32' _init_var_node(scale_node, np.ones([1], dtype=data_type),