diff --git a/paddle/fluid/operators/reverse_op.cc b/paddle/fluid/operators/reverse_op.cc
index 1f0184197472369f911dd4ec10f19f0a8d637672..8b2b9f464b407ba27333e354854a70a233986853 100644
--- a/paddle/fluid/operators/reverse_op.cc
+++ b/paddle/fluid/operators/reverse_op.cc
@@ -24,14 +24,31 @@ class ReverseOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
 
   void InferShape(framework::InferShapeContext* ctx) const override {
-    PADDLE_ENFORCE_EQ(
-        ctx->HasInput("X"), true,
-        platform::errors::InvalidArgument("Input(X) should not be null"));
-    PADDLE_ENFORCE_EQ(
-        ctx->HasOutput("Out"), true,
-        platform::errors::InvalidArgument("Output(Out) should not be null"));
-    const auto& x_dims = ctx->GetInputDim("X");
+    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "Reverse");
+    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "Reverse");
+
+    auto x_var_type = ctx->GetInputsVarType("X")[0];
     const auto& axis = ctx->Attrs().Get<std::vector<int>>("axis");
+    if (x_var_type == framework::proto::VarType::LOD_TENSOR_ARRAY) {
+      PADDLE_ENFORCE_EQ(
+          axis.size(), 1,
+          platform::errors::InvalidArgument(
+              "The size of axis must be 1 when the Input(X) is "
+              "LoDTensorArray, but received %d.",
+              axis.size()));
+      PADDLE_ENFORCE_EQ(axis[0], 0,
+                        platform::errors::InvalidArgument(
+                            "The value of axis should be 0 when "
+                            "the Input(X) is LoDTensorArray, "
+                            "but received %d.",
+                            axis[0]));
+      // In runtime, the output shape is determined by RunImpl.
+      if (!ctx->IsRuntime()) {
+        const auto& x_dims = ctx->GetInputDim("X");
+        ctx->SetOutputDim("Out", x_dims);
+      }
+      return;
+    }
+    const auto& x_dims = ctx->GetInputDim("X");
     PADDLE_ENFORCE_NE(axis.empty(), true, platform::errors::InvalidArgument(
                                               "'axis' can not be empty."));
     for (int a : axis) {
@@ -51,6 +68,14 @@ class ReverseOp : public framework::OperatorWithKernel {
   }
 };
 
+class ReverseOpVarTypeInference : public framework::VarTypeInference {
+ public:
+  void operator()(framework::InferVarTypeContext* ctx) const override {
+    ctx->SetOutputType("Out", ctx->GetInputType("X"));
+    ctx->SetOutputDataType("Out", ctx->GetInputDataType("X"));
+  }
+};
+
 class ReverseOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
   void Make() override {
@@ -111,8 +136,9 @@ class ReverseGradMaker : public framework::SingleGradOpMaker {
 namespace ops = paddle::operators;
 REGISTER_OPERATOR(reverse, ops::ReverseOp, ops::ReverseOpMaker,
                   ops::ReverseGradMaker<paddle::framework::OpDesc>,
-                  ops::ReverseGradMaker<paddle::imperative::OpBase>);
-REGISTER_OPERATOR(reverse_grad, ops::ReverseOp);
+                  ops::ReverseGradMaker<paddle::imperative::OpBase>,
+                  ops::ReverseOpVarTypeInference);
+REGISTER_OPERATOR(reverse_grad, ops::ReverseOp, ops::ReverseOpVarTypeInference);
 REGISTER_OP_CPU_KERNEL(
     reverse, ops::ReverseKernel<paddle::platform::CPUDeviceContext, int>,
     ops::ReverseKernel<paddle::platform::CPUDeviceContext, uint8_t>,
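For context, the new InferShape branch only admits axis == [0] when X is a LoDTensorArray: the array's slots form a single logical dimension, so no other axis is meaningful. A rough Python sketch of the same validation rule (the helper name is illustrative, not part of the patch):

    def check_tensor_array_axis(axis):
        # Mirror of the InferShape enforcement: normalize the attribute to a
        # list, then require exactly one entry whose value is 0.
        axis = [axis] if isinstance(axis, int) else list(axis)
        if len(axis) != 1:
            raise ValueError("The size of axis must be 1 when the Input(X) "
                             "is LoDTensorArray, but received %d." % len(axis))
        if axis[0] != 0:
            raise ValueError("The value of axis should be 0 when the Input(X) "
                             "is LoDTensorArray, but received %d." % axis[0])

    check_tensor_array_axis(0)          # ok
    check_tensor_array_axis([0])        # ok
    # check_tensor_array_axis((0, 1))   # would raise ValueError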
diff --git a/paddle/fluid/operators/reverse_op.h b/paddle/fluid/operators/reverse_op.h
index 24489e618dba41241648d3f5c844b8fda6bcb54a..2813f7a4864a9ee84cefd8c824ee6f277b192dec 100644
--- a/paddle/fluid/operators/reverse_op.h
+++ b/paddle/fluid/operators/reverse_op.h
@@ -47,10 +47,30 @@ template <typename DeviceContext, typename T>
 class ReverseKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
+    auto* x_var = context.InputVar("X");
+    const auto& axis = context.Attr<std::vector<int>>("axis");
+    if (x_var->IsType<framework::LoDTensorArray>()) {
+      auto& x_array = x_var->Get<framework::LoDTensorArray>();
+      auto* out_array = context.Output<framework::LoDTensorArray>("Out");
+
+      out_array->resize(x_array.size());
+      for (size_t offset = 0; offset < x_array.size(); offset++) {
+        auto& x_tensor = x_array.at(offset);
+        PADDLE_ENFORCE_GT(
+            x_tensor.memory_size(), 0,
+            platform::errors::PreconditionNotMet(
+                "The input LoDTensorArray X[%d] holds no memory.", offset));
+        auto out_offset = x_array.size() - offset - 1;
+        auto* out_tensor = &out_array->at(out_offset);
+
+        out_tensor->set_lod(x_tensor.lod());
+        TensorCopy(x_tensor, context.GetPlace(), out_tensor);
+      }
+      return;
+    }
     auto* x = context.Input<framework::LoDTensor>("X");
     auto* out = context.Output<framework::LoDTensor>("Out");
     out->mutable_data<T>(context.GetPlace());
-    const auto& axis = context.Attr<std::vector<int>>("axis");
     int rank = x->dims().size();
     auto& dev_ctx = context.template device_context<DeviceContext>();
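The kernel branch above writes slot offset of X into slot size - 1 - offset of Out, copying each tensor and its LoD unchanged; only the slot order is reversed. An equivalent NumPy sketch, with plain ndarrays standing in for the LoDTensor entries:

    import numpy as np

    def reverse_tensor_array(x_array):
        out_array = [None] * len(x_array)
        for offset, x_tensor in enumerate(x_array):
            # mirrors the PADDLE_ENFORCE_GT(memory_size, 0) precondition
            if x_tensor.size == 0:
                raise ValueError(
                    "The input LoDTensorArray X[%d] holds no memory." % offset)
            out_array[len(x_array) - offset - 1] = x_tensor.copy()
        return out_array

    arr = [np.array([[0, 1], [2, 3]]), np.array([[4, 5, 6]])]
    print(reverse_tensor_array(arr))  # [array([[4, 5, 6]]), array([[0, 1], [2, 3]])]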
diff --git a/python/paddle/fluid/layers/tensor.py b/python/paddle/fluid/layers/tensor.py
index e81af19b6fec2d05593f2035b0a5d3059b2c9702..a7f10584b73f99e65f3a39a94c77e0b5980e2b0c 100644
--- a/python/paddle/fluid/layers/tensor.py
+++ b/python/paddle/fluid/layers/tensor.py
@@ -1094,11 +1094,37 @@ def reverse(x, axis):
 
     The OP reverses the tensor :attr:`x` along the given :attr:`axis`.
 
+    .. code-block:: text
+
+        Case 1:
+
+            Given a LoDTensor:
+                x = [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
+                axis = [0, 1]
+
+            Then:
+                output = [[8, 7, 6], [5, 4, 3], [2, 1, 0]]
+
+        Case 2:
+
+            Given a LoDTensorArray:
+                x = {[[0, 1], [2, 3]],
+                     [[4, 5, 6]],
+                     [[7], [8], [9]]}
+                axis = 0
+
+            Then:
+                output = {[[7], [8], [9]],
+                          [[4, 5, 6]],
+                          [[0, 1], [2, 3]]}
+
     Parameters:
-        x (Variable): A tensor to be reversed, its data type supports bool, float32, float64, int32, int64 and uint8.
+        x (Variable): A tensor or LoDTensorArray to be reversed, its data type supports bool, float32, float64, int32, int64 and uint8.
+            If input is a LoDTensorArray, it returns a new reversed LoDTensorArray without changing the internal order of each inner tensor.
         axis (int|tuple|list): A dimension or a set of dimensions of :attr:`x` to reverse. Must be
             in the range [-rank( :attr:`x` ), rank( :attr:`x` )). If it is a tuple or a list, reversing
-            will be apply on each axis in the tuple or list.
+            will be applied to each axis in the tuple or list. If input is a LoDTensorArray,
+            axis must be 0, or a single-element list [0] or tuple (0, ).
 
     Returns:
         Variable: The reversed tensor with the same shape and data type as :attr:`x`.
@@ -1111,6 +1137,16 @@ def reverse(x, axis):
           data = fluid.layers.assign(np.array([[0, 1, 2], [3, 4, 5], [6, 7, 8]], dtype='float32')) # [[0., 1., 2.], [3., 4., 5.], [6., 7., 8.]]
           result1 = fluid.layers.reverse(data, 0) # [[6., 7., 8.], [3., 4., 5.], [0., 1., 2.]]
           result2 = fluid.layers.reverse(data, [0, 1]) # [[8., 7., 6.], [5., 4., 3.], [2., 1., 0.]]
+
+          # example of LoDTensorArray
+          data1 = fluid.layers.assign(np.array([[0, 1, 2]], dtype='float32'))
+          data2 = fluid.layers.assign(np.array([[3, 4, 5]], dtype='float32'))
+          tensor_array = fluid.layers.create_array(dtype='float32')
+          i = fluid.layers.fill_constant(shape=[1], dtype='int64', value=0)
+          fluid.layers.array_write(data1, i, tensor_array)
+          fluid.layers.array_write(data2, i + 1, tensor_array)
+
+          reversed_tensor_array = fluid.layers.reverse(tensor_array, 0) # {[[3, 4, 5]], [[0, 1, 2]]}
     """
     check_variable_and_dtype(
         x, 'x', ('float32', 'float64', 'int32', 'int64', 'uint8'), 'reverse')
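The docstring example builds the reversed array but never executes it. A complete static-graph round trip might look like the sketch below (assuming the fluid 1.x API; tensor_array_to_tensor concatenates the slots so the result can be fetched as one dense tensor, the same trick the unit tests below rely on):

    import numpy as np
    import paddle.fluid as fluid

    data1 = fluid.layers.assign(np.array([[0, 1, 2]], dtype='float32'))
    data2 = fluid.layers.assign(np.array([[3, 4, 5]], dtype='float32'))
    tensor_array = fluid.layers.create_array(dtype='float32')
    i = fluid.layers.fill_constant(shape=[1], dtype='int64', value=0)
    fluid.layers.array_write(data1, i, tensor_array)
    fluid.layers.array_write(data2, i + 1, tensor_array)

    reversed_array = fluid.layers.reverse(tensor_array, 0)
    # Concatenate the reversed slots along axis 0 so they can be fetched.
    output, _ = fluid.layers.tensor_array_to_tensor(reversed_array, axis=0)

    exe = fluid.Executor(fluid.CPUPlace())
    result = exe.run(fetch_list=[output])
    print(result[0])  # [[3. 4. 5.] [0. 1. 2.]]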
diff --git a/python/paddle/fluid/tests/unittests/test_reverse_op.py b/python/paddle/fluid/tests/unittests/test_reverse_op.py
index 21d15b05715f70104dbbd89a75acc484a6a941a1..5aaf0b85d504feb233aaf146250cd8ca7c6648e1 100644
--- a/python/paddle/fluid/tests/unittests/test_reverse_op.py
+++ b/python/paddle/fluid/tests/unittests/test_reverse_op.py
@@ -110,5 +110,76 @@ class TestCase4(unittest.TestCase):
         self.assertRaises(core.EnforceNotMet, _run_program)
 
 
+class TestReverseLoDTensorArray(unittest.TestCase):
+    def setUp(self):
+        self.shapes = [[5, 25], [5, 20], [5, 5]]
+        self.place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda(
+        ) else fluid.CPUPlace()
+        self.exe = fluid.Executor(self.place)
+
+    def run_program(self, arr_len, axis=0):
+        main_program = fluid.Program()
+
+        with fluid.program_guard(main_program):
+            inputs, inputs_data = [], []
+            for i in range(arr_len):
+                x = fluid.data("x%s" % i, self.shapes[i], dtype='float32')
+                x.stop_gradient = False
+                inputs.append(x)
+                inputs_data.append(
+                    np.random.random(self.shapes[i]).astype('float32'))
+
+            tensor_array = fluid.layers.create_array(dtype='float32')
+            for i in range(arr_len):
+                idx = fluid.layers.array_length(tensor_array)
+                fluid.layers.array_write(inputs[i], idx, tensor_array)
+
+            reverse_array = fluid.layers.reverse(tensor_array, axis=axis)
+            output, _ = fluid.layers.tensor_array_to_tensor(reverse_array)
+            loss = fluid.layers.reduce_sum(output)
+            fluid.backward.append_backward(loss)
+            input_grads = list(
+                map(main_program.global_block().var,
+                    [x.name + "@GRAD" for x in inputs]))
+
+        feed_dict = dict(zip([x.name for x in inputs], inputs_data))
+        res = self.exe.run(main_program,
+                           feed=feed_dict,
+                           fetch_list=input_grads + [output.name])
+
+        return np.hstack(inputs_data[::-1]), res
+
+    def test_case1(self):
+        gt, res = self.run_program(arr_len=3)
+        self.check_output(gt, res)
+        # test with tuple type of axis
+        gt, res = self.run_program(arr_len=3, axis=(0, ))
+        self.check_output(gt, res)
+
+    def test_case2(self):
+        gt, res = self.run_program(arr_len=1)
+        self.check_output(gt, res)
+        # test with list type of axis
+        gt, res = self.run_program(arr_len=1, axis=[0])
+        self.check_output(gt, res)
+
+    def check_output(self, gt, res):
+        arr_len = len(res) - 1
+        reversed_array = res[-1]
+        # check output
+        self.assertTrue(np.array_equal(gt, reversed_array))
+        # check grad
+        for i in range(arr_len):
+            self.assertTrue(np.array_equal(res[i], np.ones_like(res[i])))
+
+    def test_raise_error(self):
+        # The len(axis) should be 1 if input(X) is LoDTensorArray
+        with self.assertRaises(Exception):
+            self.run_program(arr_len=3, axis=[0, 1])
+        # The value of axis should be 0 if input(X) is LoDTensorArray
+        with self.assertRaises(Exception):
+            self.run_program(arr_len=3, axis=1)
+
+
 if __name__ == '__main__':
     unittest.main()
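The two enforce paths added in reverse_op.cc can also be triggered outside the test harness. A minimal sketch, again assuming the fluid 1.x API (the exact error type, and whether it surfaces at graph-construction time or at first run, may vary by Paddle version):

    import paddle.fluid as fluid

    tensor_array = fluid.layers.create_array(dtype='float32')
    try:
        fluid.layers.reverse(tensor_array, axis=[0, 1])  # size of axis != 1
    except Exception as e:
        print(type(e).__name__)
    try:
        fluid.layers.reverse(tensor_array, axis=1)  # value of axis != 0
    except Exception as e:
        print(type(e).__name__)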