未验证 提交 c3e19549 编写于 作者: M mapingshuo 提交者: GitHub

make reverse op support negative axis (#21925)

* make reverse op support negative axis
上级 03479469
......@@ -31,7 +31,13 @@ class ReverseOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE(!axis.empty(), "'axis' can not be empty.");
for (int a : axis) {
PADDLE_ENFORCE_LT(a, x_dims.size(),
"The axis must be less than input tensor's rank.");
paddle::platform::errors::OutOfRange(
"The axis must be less than input tensor's rank."));
PADDLE_ENFORCE_GE(
a, -x_dims.size(),
paddle::platform::errors::OutOfRange(
"The axis must be greater than the negative number of "
"input tensor's rank."));
}
ctx->SetOutputDim("Out", x_dims);
}
......
......@@ -28,7 +28,11 @@ struct ReverseFunctor {
reverse_axis[i] = false;
}
for (int a : axis) {
if (a >= 0) {
reverse_axis[a] = true;
} else {
reverse_axis[Rank + a] = true;
}
}
auto in_eigen = framework::EigenTensor<T, Rank>::From(in);
......
......@@ -47,23 +47,47 @@ class TestCase0(TestReverseOp):
self.axis = [1]
class TestCase0(TestReverseOp):
def initTestCase(self):
self.x = np.random.random((3, 40)).astype('float64')
self.axis = [-1]
class TestCase1(TestReverseOp):
def initTestCase(self):
self.x = np.random.random((3, 40)).astype('float64')
self.axis = [0, 1]
class TestCase0(TestReverseOp):
def initTestCase(self):
self.x = np.random.random((3, 40)).astype('float64')
self.axis = [0, -1]
class TestCase2(TestReverseOp):
def initTestCase(self):
self.x = np.random.random((3, 4, 10)).astype('float64')
self.axis = [0, 2]
class TestCase2(TestReverseOp):
def initTestCase(self):
self.x = np.random.random((3, 4, 10)).astype('float64')
self.axis = [0, -2]
class TestCase3(TestReverseOp):
def initTestCase(self):
self.x = np.random.random((3, 4, 10)).astype('float64')
self.axis = [1, 2]
class TestCase3(TestReverseOp):
def initTestCase(self):
self.x = np.random.random((3, 4, 10)).astype('float64')
self.axis = [-1, -2]
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册