From 270ba5709774b61bc98f18b6e278b6faf5e90ccc Mon Sep 17 00:00:00 2001 From: handiz <35895648+ZhangHandi@users.noreply.github.com> Date: Thu, 14 Jul 2022 10:34:59 +0800 Subject: [PATCH] fix acc diff problem caused by pr #44116 (#44311) --- .../slim/quantization/quantization_pass.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py b/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py index d3ce543320..e2502e7f5d 100644 --- a/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py +++ b/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py @@ -1440,11 +1440,18 @@ class OutScaleForTrainingPass(object): [core.VarDesc.VarType.FP64, core.VarDesc.VarType.FP32]: continue - scale_node = graph.create_persistable_node( - name=self._scale_name(in_node.name()), - var_type=core.VarDesc.VarType.LOD_TENSOR, - shape=[1], - var_dtype=in_node.dtype()) + try: + graph._find_node_by_name( + graph.all_var_nodes(), + self._scale_name(in_node.name())) + continue + except: + scale_node = graph.create_persistable_node( + name=self._scale_name(in_node.name()), + var_type=core.VarDesc.VarType.LOD_TENSOR, + shape=[1], + var_dtype=in_node.dtype()) + data_type = 'float64' if in_node.dtype() \ == core.VarDesc.VarType.FP64 else 'float32' _init_var_node(scale_node, np.ones([1], dtype=data_type), -- GitLab