未验证 提交 4d351122 编写于 作者: C cc 提交者: GitHub

[Fix bug] Init scale node in OutScaleForTrainingPass and enable...

[Fix bug] Init scale node in OutScaleForTrainingPass and enable test_quantization_scale_pass UT (#24393)

* Init scale node in OutScaleForTrainingPass, test=develop
* Enable test_quantization_scale, test=develop
上级 05c3bc3b
...@@ -1170,6 +1170,14 @@ class OutScaleForTrainingPass(object): ...@@ -1170,6 +1170,14 @@ class OutScaleForTrainingPass(object):
var_type=core.VarDesc.VarType.LOD_TENSOR, var_type=core.VarDesc.VarType.LOD_TENSOR,
shape=[1], shape=[1],
var_dtype=in_node.dtype()) var_dtype=in_node.dtype())
data_type = 'float64' if in_node.dtype() \
== core.VarDesc.VarType.FP64 else 'float32'
_init_var_node(
scale_node,
np.ones(
[1], dtype=data_type),
self._scope,
self._place)
ins = {'X': in_node} ins = {'X': in_node}
outs = {'Out': out_node, 'OutScale': scale_node} outs = {'Out': out_node, 'OutScale': scale_node}
if not self._is_test: if not self._is_test:
...@@ -1178,8 +1186,6 @@ class OutScaleForTrainingPass(object): ...@@ -1178,8 +1186,6 @@ class OutScaleForTrainingPass(object):
var_type=core.VarDesc.VarType.LOD_TENSOR, var_type=core.VarDesc.VarType.LOD_TENSOR,
var_dtype=in_node.dtype(), var_dtype=in_node.dtype(),
shape=[1]) shape=[1])
data_type = 'float64' if in_node.dtype(
) == core.VarDesc.VarType.FP64 else 'float32'
_init_var_node( _init_var_node(
state_in_node, state_in_node,
np.ones( np.ones(
......
...@@ -127,9 +127,6 @@ if(WIN32) ...@@ -127,9 +127,6 @@ if(WIN32)
list(REMOVE_ITEM TEST_OPS test_weight_quantization_mobilenetv1) list(REMOVE_ITEM TEST_OPS test_weight_quantization_mobilenetv1)
endif() endif()
# Disable unittest for random error temporary
list(REMOVE_ITEM TEST_OPS test_quantization_scale_pass)
if(LINUX AND WITH_MKLDNN) if(LINUX AND WITH_MKLDNN)
#### Image classification dataset: ImageNet (small) #### Image classification dataset: ImageNet (small)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册