From 4d35112255bd053976d7d0d5fc2c4fa1d3d5b470 Mon Sep 17 00:00:00 2001 From: cc <52520497+juncaipeng@users.noreply.github.com> Date: Wed, 13 May 2020 15:54:37 +0800 Subject: [PATCH] [Fix bug] Init scale node in OutScaleForTrainingPass and enable test_quantization_scale_pass UT (#24393) * Init scale node in OutScaleForTrainingPass, test=develop * Enable test_quantization_scale, test=develop --- .../contrib/slim/quantization/quantization_pass.py | 10 ++++++++-- python/paddle/fluid/contrib/slim/tests/CMakeLists.txt | 3 --- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py b/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py index cde41e687fa..a6ab2aa86d0 100644 --- a/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py +++ b/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py @@ -1170,6 +1170,14 @@ class OutScaleForTrainingPass(object): var_type=core.VarDesc.VarType.LOD_TENSOR, shape=[1], var_dtype=in_node.dtype()) + data_type = 'float64' if in_node.dtype() \ + == core.VarDesc.VarType.FP64 else 'float32' + _init_var_node( + scale_node, + np.ones( + [1], dtype=data_type), + self._scope, + self._place) ins = {'X': in_node} outs = {'Out': out_node, 'OutScale': scale_node} if not self._is_test: @@ -1178,8 +1186,6 @@ class OutScaleForTrainingPass(object): var_type=core.VarDesc.VarType.LOD_TENSOR, var_dtype=in_node.dtype(), shape=[1]) - data_type = 'float64' if in_node.dtype( - ) == core.VarDesc.VarType.FP64 else 'float32' _init_var_node( state_in_node, np.ones( diff --git a/python/paddle/fluid/contrib/slim/tests/CMakeLists.txt b/python/paddle/fluid/contrib/slim/tests/CMakeLists.txt index 095489bc736..93ca87ce0f4 100644 --- a/python/paddle/fluid/contrib/slim/tests/CMakeLists.txt +++ b/python/paddle/fluid/contrib/slim/tests/CMakeLists.txt @@ -127,9 +127,6 @@ if(WIN32) list(REMOVE_ITEM TEST_OPS test_weight_quantization_mobilenetv1) endif() -# Disable unittest for random error temporary -list(REMOVE_ITEM TEST_OPS test_quantization_scale_pass) - if(LINUX AND WITH_MKLDNN) #### Image classification dataset: ImageNet (small) -- GitLab