From 1b7de8945593ea4cb493fd081adc2b24604a53d8 Mon Sep 17 00:00:00 2001 From: Zeng Jinle <32832641+sneaxiy@users.noreply.github.com> Date: Thu, 26 Sep 2019 07:38:21 +0800 Subject: [PATCH] fix math_op_path.py when integers, test=develop (#20008) --- python/paddle/fluid/layers/math_op_patch.py | 6 +++++- .../fluid/tests/unittests/test_math_op_patch.py | 14 ++++++++++++++ 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/python/paddle/fluid/layers/math_op_patch.py b/python/paddle/fluid/layers/math_op_patch.py index 9564eff73f2..bd56762f451 100644 --- a/python/paddle/fluid/layers/math_op_patch.py +++ b/python/paddle/fluid/layers/math_op_patch.py @@ -143,7 +143,11 @@ def monkey_patch_variable(): reverse=False, scalar_method=None): def __impl__(self, other_var): - if scalar_method is not None: + # FIXME(zjl): elementwise_div between integers cannot be converted to scale, + # which may lose accuracy. This is a hot fix for release 1.6. + if scalar_method is not None and not ( + op_type == 'elementwise_div' and + self.dtype in _supported_int_dtype_): if isinstance(other_var, float): if self.dtype in _supported_int_dtype_: assert other_var == int(other_var), \ diff --git a/python/paddle/fluid/tests/unittests/test_math_op_patch.py b/python/paddle/fluid/tests/unittests/test_math_op_patch.py index f6cdb17def9..c90640b65af 100644 --- a/python/paddle/fluid/tests/unittests/test_math_op_patch.py +++ b/python/paddle/fluid/tests/unittests/test_math_op_patch.py @@ -186,6 +186,20 @@ class TestMathOpPatches(unittest.TestCase): fetch_list=[c]) self.assertTrue(numpy.allclose(a_np - b_np, c_np)) + @prog_scope() + def test_integer_div(self): + a = fluid.layers.data(name="a", shape=[1], dtype='int64') + b = a / 7 + place = fluid.CPUPlace() + exe = fluid.Executor(place) + a_np = numpy.array([3, 4, 10, 14, 9, 18]).astype('int64') + b_np, = exe.run(fluid.default_main_program(), + feed={"a": a_np}, + fetch_list=[b]) + + b_np_actual = (a_np / 7).astype('int64') + self.assertTrue(numpy.array_equal(b_np, b_np_actual)) + if __name__ == '__main__': unittest.main() -- GitLab