From c3e19549184e16dd04a560fc18edb098f5fc5883 Mon Sep 17 00:00:00 2001 From: mapingshuo Date: Fri, 27 Dec 2019 16:23:43 +0800 Subject: [PATCH] make reverse op support negative axis (#21925) * make reverse op support negative axis --- paddle/fluid/operators/reverse_op.cc | 8 ++++++- paddle/fluid/operators/reverse_op.h | 6 ++++- .../fluid/tests/unittests/test_reverse_op.py | 24 +++++++++++++++++++ 3 files changed, 36 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/operators/reverse_op.cc b/paddle/fluid/operators/reverse_op.cc index 1c3b3d3c2f..f50450bc2a 100644 --- a/paddle/fluid/operators/reverse_op.cc +++ b/paddle/fluid/operators/reverse_op.cc @@ -31,7 +31,13 @@ class ReverseOp : public framework::OperatorWithKernel { PADDLE_ENFORCE(!axis.empty(), "'axis' can not be empty."); for (int a : axis) { PADDLE_ENFORCE_LT(a, x_dims.size(), - "The axis must be less than input tensor's rank."); + paddle::platform::errors::OutOfRange( + "The axis must be less than input tensor's rank.")); + PADDLE_ENFORCE_GE( + a, -x_dims.size(), + paddle::platform::errors::OutOfRange( + "The axis must be greater than the negative number of " + "input tensor's rank.")); } ctx->SetOutputDim("Out", x_dims); } diff --git a/paddle/fluid/operators/reverse_op.h b/paddle/fluid/operators/reverse_op.h index 9063cd59bb..c5535a6dc2 100644 --- a/paddle/fluid/operators/reverse_op.h +++ b/paddle/fluid/operators/reverse_op.h @@ -28,7 +28,11 @@ struct ReverseFunctor { reverse_axis[i] = false; } for (int a : axis) { - reverse_axis[a] = true; + if (a >= 0) { + reverse_axis[a] = true; + } else { + reverse_axis[Rank + a] = true; + } } auto in_eigen = framework::EigenTensor::From(in); diff --git a/python/paddle/fluid/tests/unittests/test_reverse_op.py b/python/paddle/fluid/tests/unittests/test_reverse_op.py index 09252f47d2..80f0562132 100644 --- a/python/paddle/fluid/tests/unittests/test_reverse_op.py +++ b/python/paddle/fluid/tests/unittests/test_reverse_op.py @@ -47,23 +47,47 @@ class TestCase0(TestReverseOp): self.axis = [1] +class TestCase0(TestReverseOp): + def initTestCase(self): + self.x = np.random.random((3, 40)).astype('float64') + self.axis = [-1] + + class TestCase1(TestReverseOp): def initTestCase(self): self.x = np.random.random((3, 40)).astype('float64') self.axis = [0, 1] +class TestCase0(TestReverseOp): + def initTestCase(self): + self.x = np.random.random((3, 40)).astype('float64') + self.axis = [0, -1] + + class TestCase2(TestReverseOp): def initTestCase(self): self.x = np.random.random((3, 4, 10)).astype('float64') self.axis = [0, 2] +class TestCase2(TestReverseOp): + def initTestCase(self): + self.x = np.random.random((3, 4, 10)).astype('float64') + self.axis = [0, -2] + + class TestCase3(TestReverseOp): def initTestCase(self): self.x = np.random.random((3, 4, 10)).astype('float64') self.axis = [1, 2] +class TestCase3(TestReverseOp): + def initTestCase(self): + self.x = np.random.random((3, 4, 10)).astype('float64') + self.axis = [-1, -2] + + if __name__ == '__main__': unittest.main() -- GitLab