diff --git a/paddle/fluid/operators/reverse_op.cc b/paddle/fluid/operators/reverse_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..a20f7d231fa9ea313581ac0629a87fa5f4a88ce5 --- /dev/null +++ b/paddle/fluid/operators/reverse_op.cc @@ -0,0 +1,107 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/operators/reverse_op.h" +#include + +namespace paddle { +namespace operators { + +class ReverseOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null"); + PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) should not be null"); + const auto& x_dims = ctx->GetInputDim("X"); + const auto& axis = ctx->Attrs().Get>("axis"); + PADDLE_ENFORCE(!axis.empty(), "'axis' can not be empty."); + for (int a : axis) { + PADDLE_ENFORCE_LT(a, x_dims.size(), + "The axis must be less than input tensor's rank."); + } + ctx->SetOutputDim("Out", x_dims); + } +}; + +class ReverseOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", "The LoDTensor to be flipped."); + AddOutput("Out", "The LoDTensor after flipping."); + AddAttr>( + "axis", "The axises that along which order of elements is reversed."); + AddComment(R"DOC( + Reverse Operator. + + Reverse the order of elements in the input LoDTensor along given axises. + + Case 1: + Given + X = [[1, 2, 3, 4, 5] + [6, 7, 8, 9, 10] + [11, 12, 13, 14, 15]], + and + axis = [0], + we get: + Out = [[11, 12, 13, 14, 15] + [6, 7, 8, 9, 10] + [1, 2, 3, 4, 5]]. + + Case 2: + Given + X = [[[1, 2, 3, 4] + [5, 6, 7, 8]] + [[9, 10, 11, 12] + [13, 14, 15, 16]]], + and + axis = [0, 2], + we get: + Out = [[[12, 11, 10, 9] + [16, 15, 14, 13]] + [[4, 3, 2, 1] + [8, 7, 6, 5]]], + )DOC"); + } +}; + +class ReverseGradMaker : public framework::SingleGradOpDescMaker { + public: + using framework::SingleGradOpDescMaker::SingleGradOpDescMaker; + + std::unique_ptr Apply() const override { + auto* grad_op = new framework::OpDesc(); + grad_op->SetType("reverse"); + grad_op->SetInput("X", OutputGrad("Out")); + grad_op->SetOutput("Out", InputGrad("X")); + grad_op->SetAttr("axis", GetAttr("axis")); + return std::unique_ptr(grad_op); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(reverse, ops::ReverseOp, ops::ReverseOpMaker, + ops::ReverseGradMaker); +REGISTER_OPERATOR(reverse_grad, ops::ReverseOp); +REGISTER_OP_CPU_KERNEL( + reverse, ops::ReverseKernel, + ops::ReverseKernel, + ops::ReverseKernel, + ops::ReverseKernel, + ops::ReverseKernel, + ops::ReverseKernel) diff --git a/paddle/fluid/operators/reverse_op.cu b/paddle/fluid/operators/reverse_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..635c41529b38f2dd287b00ed2e5659e11f619e78 --- /dev/null +++ b/paddle/fluid/operators/reverse_op.cu @@ -0,0 +1,24 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/operators/reverse_op.h" + +namespace ops = paddle::operators; +REGISTER_OP_CUDA_KERNEL( + reverse, ops::ReverseKernel, + ops::ReverseKernel, + ops::ReverseKernel, + ops::ReverseKernel, + ops::ReverseKernel, + ops::ReverseKernel) diff --git a/paddle/fluid/operators/reverse_op.h b/paddle/fluid/operators/reverse_op.h new file mode 100644 index 0000000000000000000000000000000000000000..9063cd59bba5c6307b55a500455908a5fd278390 --- /dev/null +++ b/paddle/fluid/operators/reverse_op.h @@ -0,0 +1,87 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include "paddle/fluid/framework/eigen.h" +#include "paddle/fluid/framework/op_registry.h" + +namespace paddle { +namespace operators { +template +struct ReverseFunctor { + void operator()(const DeviceContext& context, const framework::LoDTensor& in, + framework::LoDTensor* out, const std::vector& axis) { + Eigen::array reverse_axis; + for (int i = 0; i < Rank; ++i) { + reverse_axis[i] = false; + } + for (int a : axis) { + reverse_axis[a] = true; + } + + auto in_eigen = framework::EigenTensor::From(in); + auto out_eigen = framework::EigenTensor::From(*out); + auto* dev = context.eigen_device(); + + out_eigen.device(*dev) = in_eigen.reverse(reverse_axis); + } +}; + +template +class ReverseKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* x = context.Input("X"); + auto* out = context.Output("Out"); + out->mutable_data(context.GetPlace()); + const auto& axis = context.Attr>("axis"); + int rank = x->dims().size(); + auto& dev_ctx = context.template device_context(); + + switch (rank) { + case 1: + ReverseFunctor functor1; + functor1(dev_ctx, *x, out, axis); + break; + case 2: + ReverseFunctor functor2; + functor2(dev_ctx, *x, out, axis); + break; + case 3: + ReverseFunctor functor3; + functor3(dev_ctx, *x, out, axis); + break; + case 4: + ReverseFunctor functor4; + functor4(dev_ctx, *x, out, axis); + break; + case 5: + ReverseFunctor functor5; + functor5(dev_ctx, *x, out, axis); + break; + case 6: + ReverseFunctor functor6; + functor6(dev_ctx, *x, out, axis); + break; + default: + PADDLE_THROW( + "Reserve operator doesn't supports tensors whose ranks are greater " + "than 6."); + } + } +}; + +} // namespace operators +} // namespace paddle diff --git a/python/paddle/fluid/layers/tensor.py b/python/paddle/fluid/layers/tensor.py index be34cc81a5d5ca0e781e5984b6c3eeaa4e25eb90..75d3bf879703a1db1108eae45d879164e0024156 100644 --- a/python/paddle/fluid/layers/tensor.py +++ b/python/paddle/fluid/layers/tensor.py @@ -363,6 +363,40 @@ def zeros(shape, dtype, force_cpu=False): return fill_constant(value=0.0, **locals()) +def reverse(x, axis): + """ + **reverse** + + This function reverse the input 'x' along given axises. + + Args: + x(Vairbale): the input to be reversed. + axis(int|tuple|list): Axis that along which order of elements + is reversed. If it is a tuple or a list, reversing + will be apply on each axis in the tuple or list. + + Returns: + Variable: The reversed tensor. + + Examples: + .. code-block:: python + + out = fluid.layers.reverse(x=in, axis=0) + # or: + out = fluid.layers.reverse(x=in, axis=[0,1]) + """ + if isinstance(axis, int): + axis = [axis] + helper = LayerHelper("reverse", **locals()) + out = helper.create_tmp_variable(dtype=x.dtype) + helper.append_op( + type='reverse', + inputs={'Input': x}, + outputs={'Out': [out]}, + attrs={'axis': axis}) + return out + + def save(x, file_path, overwrite=True): """ Saves a variable as a file. diff --git a/python/paddle/fluid/tests/unittests/test_reverse_op.py b/python/paddle/fluid/tests/unittests/test_reverse_op.py new file mode 100644 index 0000000000000000000000000000000000000000..f845575a02869f08299d76b5600074598ca27f6c --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_reverse_op.py @@ -0,0 +1,67 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import numpy as np +from op_test import OpTest + + +class TestReverseOp(OpTest): + def initTestCase(self): + self.x = np.random.random((3, 4)).astype('float32') + self.axis = [0] + + def setUp(self): + self.initTestCase() + self.op_type = "reverse" + self.inputs = {"X": self.x} + self.attrs = {'axis': self.axis} + out = self.x + for a in self.axis: + out = np.flip(out, axis=a) + self.outputs = {'Out': out} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(['X'], 'Out') + + +class TestCase0(TestReverseOp): + def initTestCase(self): + self.x = np.random.random((3, 4)).astype('float32') + self.axis = [1] + + +class TestCase1(TestReverseOp): + def initTestCase(self): + self.x = np.random.random((3, 4)).astype('float32') + self.axis = [0, 1] + + +class TestCase2(TestReverseOp): + def initTestCase(self): + self.x = np.random.random((3, 4, 5)).astype('float32') + self.axis = [0, 2] + + +class TestCase3(TestReverseOp): + def initTestCase(self): + self.x = np.random.random((3, 4, 5)).astype('float32') + self.axis = [1, 2] + + +if __name__ == '__main__': + unittest.main()