From feee67cabd5f66614ec74e3326249c43d89d26eb Mon Sep 17 00:00:00 2001 From: WangZhen <23097963+0x45f@users.noreply.github.com> Date: Thu, 11 May 2023 22:58:13 +0800 Subject: [PATCH] Fix div error when dtype is int64 in static mode (#53705) (#53713) * Fix div error when dtype is int64 in static mode * Fix out dtype --- python/paddle/fluid/layers/math_op_patch.py | 10 +++++++++- test/dygraph_to_static/test_tensor_methods.py | 19 +++++++++++++++++++ 2 files changed, 28 insertions(+), 1 deletion(-) diff --git a/python/paddle/fluid/layers/math_op_patch.py b/python/paddle/fluid/layers/math_op_patch.py index 01426a0c792..41299a624c1 100644 --- a/python/paddle/fluid/layers/math_op_patch.py +++ b/python/paddle/fluid/layers/math_op_patch.py @@ -409,11 +409,19 @@ def monkey_patch_variable(): self = other_var other_var = tmp + if ( + op_type == "divide" or op_type == "elementwise_div" + ) and self.dtype in _supported_int_dtype_: + self = astype(self, 'float32') + other_var = astype(other_var, 'float32') + # NOTE(zhiqiu): the output of compare operator should be bool. if method_name in compare_ops: out = create_new_tmp_var(current_block(self), dtype="bool") else: - out = create_new_tmp_var(current_block(self), dtype=lhs_dtype) + out = create_new_tmp_var( + current_block(self), dtype=safe_get_dtype(self) + ) axis = -1 if other_var.ndim > 0 and other_var.shape[0] == -1: diff --git a/test/dygraph_to_static/test_tensor_methods.py b/test/dygraph_to_static/test_tensor_methods.py index b1a512c08fd..fff5476167d 100644 --- a/test/dygraph_to_static/test_tensor_methods.py +++ b/test/dygraph_to_static/test_tensor_methods.py @@ -101,5 +101,24 @@ class TestTensorSize(unittest.TestCase): np.testing.assert_allclose(dygraph_res, static_res, rtol=1e-5) +@paddle.jit.to_static +def true_div(x, y): + z = x / y + return z + + +class TestTrueDiv(unittest.TestCase): + def _run(self, to_static): + paddle.jit.enable_to_static(to_static) + x = paddle.to_tensor([3], dtype='int64') + y = paddle.to_tensor([4], dtype='int64') + return true_div(x, y).numpy() + + def test_ture_div(self): + dygraph_res = self._run(to_static=False) + static_res = self._run(to_static=True) + np.testing.assert_allclose(dygraph_res, static_res, rtol=1e-5) + + if __name__ == '__main__': unittest.main() -- GitLab