未验证 提交 270ba570 编写于 作者: H handiz 提交者: GitHub

fix acc diff problem caused by pr #44116 (#44311)

上级 cb44b694
...@@ -1440,11 +1440,18 @@ class OutScaleForTrainingPass(object): ...@@ -1440,11 +1440,18 @@ class OutScaleForTrainingPass(object):
[core.VarDesc.VarType.FP64, core.VarDesc.VarType.FP32]: [core.VarDesc.VarType.FP64, core.VarDesc.VarType.FP32]:
continue continue
try:
graph._find_node_by_name(
graph.all_var_nodes(),
self._scale_name(in_node.name()))
continue
except:
scale_node = graph.create_persistable_node( scale_node = graph.create_persistable_node(
name=self._scale_name(in_node.name()), name=self._scale_name(in_node.name()),
var_type=core.VarDesc.VarType.LOD_TENSOR, var_type=core.VarDesc.VarType.LOD_TENSOR,
shape=[1], shape=[1],
var_dtype=in_node.dtype()) var_dtype=in_node.dtype())
data_type = 'float64' if in_node.dtype() \ data_type = 'float64' if in_node.dtype() \
== core.VarDesc.VarType.FP64 else 'float32' == core.VarDesc.VarType.FP64 else 'float32'
_init_var_node(scale_node, np.ones([1], dtype=data_type), _init_var_node(scale_node, np.ones([1], dtype=data_type),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册