未验证 提交 7c73910e 编写于 作者: zhouweiwei2014's avatar zhouweiwei2014 提交者: GitHub

Support 0D for equal tensor with scalar (#50857)

上级 2dec64d0
......@@ -978,7 +978,18 @@ class TestSundryAPI(unittest.TestCase):
self.assertEqual(out.shape, [0])
np.testing.assert_array_equal(out.numpy(), np.array([]))
def test_pow_factor(self):
def test_equal_scalar(self):
x = paddle.rand([])
out = paddle.equal(x, 2.0)
self.assertEqual(out.shape, [])
self.assertEqual(out, False)
x1 = paddle.full([], 2.0)
out1 = paddle.equal(x1, 2.0)
self.assertEqual(out1.shape, [])
self.assertEqual(out1, True)
def test_pow_scalar(self):
x = paddle.rand([])
x.stop_gradient = False
out = paddle.pow(x, 2.0)
......@@ -2235,7 +2246,17 @@ class TestSundryAPIStatic(unittest.TestCase):
self.assertEqual(res[3].shape, ())
@prog_scope()
def test_pow_factor(self):
def test_equal_scalar(self):
x = paddle.rand([])
out = paddle.equal(x, 2.0)
prog = paddle.static.default_main_program()
res = self.exe.run(prog, fetch_list=[out])
self.assertEqual(res[0].shape, ())
self.assertEqual(res[0], False)
@prog_scope()
def test_pow_scalar(self):
x = paddle.rand([])
x.stop_gradient = False
out = paddle.pow(x, 2.0)
......
......@@ -457,7 +457,7 @@ def equal(x, y, name=None):
)
)
if not isinstance(y, Variable):
y = full(shape=[1], dtype=x.dtype, fill_value=y)
y = full(shape=[], dtype=x.dtype, fill_value=y)
if in_dygraph_mode():
return _C_ops.equal(x, y)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册