未验证 提交 cb22a5c7 编写于 作者: C caozhou 提交者: GitHub

support flip 0D (#49460)

上级 4458a1e5
......@@ -101,6 +101,9 @@ void FlipKernel(const Context& dev_ctx,
DenseTensor* out) {
const size_t total_dims = x.dims().size();
switch (total_dims) {
case 0:
LaunchFlipCudaKernel<T, Context, 0>(dev_ctx, x, axis, out);
break;
case 1:
LaunchFlipCudaKernel<T, Context, 1>(dev_ctx, x, axis, out);
break;
......
......@@ -454,6 +454,15 @@ class TestSundryAPI(unittest.TestCase):
paddle.disable_static()
self.x = paddle.rand([])
def test_flip(self):
x = paddle.rand([])
x.stop_gradient = False
out = paddle.flip(x, axis=[])
out.backward()
self.assertEqual(x.shape, [])
self.assertEqual(out.shape, [])
self.assertEqual(out.grad.shape, [])
def test_linear(self):
x = paddle.randn([3, 2])
w = paddle.full(shape=[2, 4], fill_value=0.5)
......@@ -753,6 +762,18 @@ class TestSundryAPIStatic(unittest.TestCase):
paddle.enable_static()
self.exe = paddle.static.Executor()
@prog_scope()
def test_flip(self):
x = paddle.rand([])
x.stop_gradient = False
out = paddle.flip(x, axis=[])
paddle.static.append_backward(out)
program = paddle.static.default_main_program()
res1, res2 = self.exe.run(program, fetch_list=[x, out])
self.assertEqual(res1.shape, ())
self.assertEqual(res2.shape, ())
@prog_scope()
def test_pow_factor(self):
x = paddle.rand([])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册