提交 7d315faa 编写于 作者: D DesmonDay

fix bug

上级 e6254e35
......@@ -1147,42 +1147,6 @@ class TestSundryAPIStatic(unittest.TestCase):
np.testing.assert_array_equal(out3_2, np.asarray(1))
@prog_scope()
<<<<<<< HEAD
def test_sort(self):
x1 = paddle.rand([])
x1.stop_gradient = False
out1 = paddle.sort(x1, axis=-1)
paddle.static.append_backward(out1)
x2 = paddle.rand([])
x2.stop_gradient = False
out2 = paddle.sort(x2, axis=0)
paddle.static.append_backward(out2)
prog = paddle.static.default_main_program()
res = self.exe.run(prog, fetch_list=[out1, out2])
self.assertEqual(res[0].shape, ())
self.assertEqual(res[1].shape, ())
@prog_scope()
def test_argsort(self):
x1 = paddle.rand([])
x1.stop_gradient = False
out1 = paddle.argsort(x1, axis=-1)
paddle.static.append_backward(out1)
x2 = paddle.rand([])
x2.stop_gradient = False
out2 = paddle.argsort(x2, axis=0)
paddle.static.append_backward(out2)
prog = paddle.static.default_main_program()
res = self.exe.run(prog, fetch_list=[out1, out2])
self.assertEqual(res[0].shape, ())
self.assertEqual(res[1].shape, ())
=======
def test_reshape_list(self):
x1 = paddle.rand([])
x2 = paddle.rand([])
......@@ -1253,7 +1217,42 @@ class TestSundryAPIStatic(unittest.TestCase):
res1, res2 = self.exe.run(program, fetch_list=[x, out])
self.assertEqual(res1.shape, ())
self.assertEqual(res2.shape, ())
>>>>>>> c123dd1e4032efdbfff0bf0c35a58155f2d6e1d9
@prog_scope()
def test_sort(self):
x1 = paddle.rand([])
x1.stop_gradient = False
out1 = paddle.sort(x1, axis=-1)
paddle.static.append_backward(out1)
x2 = paddle.rand([])
x2.stop_gradient = False
out2 = paddle.sort(x2, axis=0)
paddle.static.append_backward(out2)
prog = paddle.static.default_main_program()
res = self.exe.run(prog, fetch_list=[out1, out2])
self.assertEqual(res[0].shape, ())
self.assertEqual(res[1].shape, ())
@prog_scope()
def test_argsort(self):
x1 = paddle.rand([])
x1.stop_gradient = False
out1 = paddle.argsort(x1, axis=-1)
paddle.static.append_backward(out1)
x2 = paddle.rand([])
x2.stop_gradient = False
out2 = paddle.argsort(x2, axis=0)
paddle.static.append_backward(out2)
prog = paddle.static.default_main_program()
res = self.exe.run(prog, fetch_list=[out1, out2])
self.assertEqual(res[0].shape, ())
self.assertEqual(res[1].shape, ())
# Use to test API whose zero-dim input tensors don't have grad and not need to test backward in OpTest.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册