未验证 提交 7c73910e 编写于 作者: zhouweiwei2014's avatar zhouweiwei2014 提交者: GitHub

Support 0D for equal tensor with scalar (#50857)

上级 2dec64d0
...@@ -978,7 +978,18 @@ class TestSundryAPI(unittest.TestCase): ...@@ -978,7 +978,18 @@ class TestSundryAPI(unittest.TestCase):
self.assertEqual(out.shape, [0]) self.assertEqual(out.shape, [0])
np.testing.assert_array_equal(out.numpy(), np.array([])) np.testing.assert_array_equal(out.numpy(), np.array([]))
def test_pow_factor(self): def test_equal_scalar(self):
x = paddle.rand([])
out = paddle.equal(x, 2.0)
self.assertEqual(out.shape, [])
self.assertEqual(out, False)
x1 = paddle.full([], 2.0)
out1 = paddle.equal(x1, 2.0)
self.assertEqual(out1.shape, [])
self.assertEqual(out1, True)
def test_pow_scalar(self):
x = paddle.rand([]) x = paddle.rand([])
x.stop_gradient = False x.stop_gradient = False
out = paddle.pow(x, 2.0) out = paddle.pow(x, 2.0)
...@@ -2235,7 +2246,17 @@ class TestSundryAPIStatic(unittest.TestCase): ...@@ -2235,7 +2246,17 @@ class TestSundryAPIStatic(unittest.TestCase):
self.assertEqual(res[3].shape, ()) self.assertEqual(res[3].shape, ())
@prog_scope() @prog_scope()
def test_pow_factor(self): def test_equal_scalar(self):
x = paddle.rand([])
out = paddle.equal(x, 2.0)
prog = paddle.static.default_main_program()
res = self.exe.run(prog, fetch_list=[out])
self.assertEqual(res[0].shape, ())
self.assertEqual(res[0], False)
@prog_scope()
def test_pow_scalar(self):
x = paddle.rand([]) x = paddle.rand([])
x.stop_gradient = False x.stop_gradient = False
out = paddle.pow(x, 2.0) out = paddle.pow(x, 2.0)
......
...@@ -457,7 +457,7 @@ def equal(x, y, name=None): ...@@ -457,7 +457,7 @@ def equal(x, y, name=None):
) )
) )
if not isinstance(y, Variable): if not isinstance(y, Variable):
y = full(shape=[1], dtype=x.dtype, fill_value=y) y = full(shape=[], dtype=x.dtype, fill_value=y)
if in_dygraph_mode(): if in_dygraph_mode():
return _C_ops.equal(x, y) return _C_ops.equal(x, y)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册