未验证 提交 00ded2ea 编写于 作者: W WangZhen 提交者: GitHub

Fix div error when dtype is int64 in static mode (#53705)

* Fix div error when dtype is int64 in static mode

* Fix out dtype
上级 fb8ea98c
......@@ -409,11 +409,19 @@ def monkey_patch_variable():
self = other_var
other_var = tmp
if (
op_type == "divide" or op_type == "elementwise_div"
) and self.dtype in _supported_int_dtype_:
self = astype(self, 'float32')
other_var = astype(other_var, 'float32')
# NOTE(zhiqiu): the output of compare operator should be bool.
if method_name in compare_ops:
out = create_new_tmp_var(current_block(self), dtype="bool")
else:
out = create_new_tmp_var(current_block(self), dtype=lhs_dtype)
out = create_new_tmp_var(
current_block(self), dtype=safe_get_dtype(self)
)
axis = -1
if other_var.ndim > 0 and other_var.shape[0] == -1:
......
......@@ -101,5 +101,24 @@ class TestTensorSize(unittest.TestCase):
np.testing.assert_allclose(dygraph_res, static_res, rtol=1e-5)
@paddle.jit.to_static
def true_div(x, y):
z = x / y
return z
class TestTrueDiv(unittest.TestCase):
def _run(self, to_static):
paddle.jit.enable_to_static(to_static)
x = paddle.to_tensor([3], dtype='int64')
y = paddle.to_tensor([4], dtype='int64')
return true_div(x, y).numpy()
def test_ture_div(self):
dygraph_res = self._run(to_static=False)
static_res = self._run(to_static=True)
np.testing.assert_allclose(dygraph_res, static_res, rtol=1e-5)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册