diff --git a/paddle/fluid/operators/reverse_op.cc b/paddle/fluid/operators/reverse_op.cc index 5393e277a6a29e103f2073ef8c84be55deab4613..4e99be4e521ac2355848f50bfb271be8c464868c 100644 --- a/paddle/fluid/operators/reverse_op.cc +++ b/paddle/fluid/operators/reverse_op.cc @@ -26,6 +26,15 @@ namespace operators { class ReverseOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + auto input_data_type = + framework::OperatorWithKernel::IndicateVarDataType(ctx, "X"); + + return framework::OpKernelType(input_data_type, ctx.GetPlace()); + } }; class ReverseOpVarTypeInference : public framework::VarTypeInference { @@ -42,7 +51,8 @@ class ReverseOpMaker : public framework::OpProtoAndCheckerMaker { AddInput("X", "The LoDTensor to be flipped."); AddOutput("Out", "The LoDTensor after flipping."); AddAttr>( - "axis", "The axises that along which order of elements is reversed."); + "axis", "The axises that along which order of elements is reversed.") + .SupportTensor(); AddComment(R"DOC( Reverse Operator. diff --git a/paddle/phi/api/yaml/legacy_api.yaml b/paddle/phi/api/yaml/legacy_api.yaml index 9dd54e7fb44a5cc3ccbe47eb8166ad599681d747..f71674ec91b16ed6495072113a4034b35172fc9d 100755 --- a/paddle/phi/api/yaml/legacy_api.yaml +++ b/paddle/phi/api/yaml/legacy_api.yaml @@ -2161,7 +2161,7 @@ backward: reshape_grad - api : reverse - args : (Tensor x, int[] axis) + args : (Tensor x, IntArray axis) output : Tensor infer_meta : func : ReverseInferMeta @@ -2170,7 +2170,7 @@ backward : reverse_grad - api : reverse_array - args : (Tensor[] x, int[] axis) + args : (Tensor[] x, IntArray axis) output : Tensor[]{x.size()} infer_meta : func : ReverseArrayInferMeta diff --git a/paddle/phi/api/yaml/legacy_backward.yaml b/paddle/phi/api/yaml/legacy_backward.yaml index 176491491b90c6c1f7c5bfaadfaf41d114b2c8ac..26884f260f75de5e1c92283fc44fee395071fde0 100755 --- a/paddle/phi/api/yaml/legacy_backward.yaml +++ b/paddle/phi/api/yaml/legacy_backward.yaml @@ -1963,8 +1963,8 @@ inplace : (out_grad -> x_grad) - backward_api : reverse_array_grad - forward : reverse_array (Tensor[] x, int[] axis) -> Tensor[](out) - args : (Tensor[] out_grad, int[] axis) + forward : reverse_array (Tensor[] x, IntArray axis) -> Tensor[](out) + args : (Tensor[] out_grad, IntArray axis) output : Tensor[](x_grad){out_grad.size()} infer_meta : func : ReverseArrayInferMeta @@ -1972,8 +1972,8 @@ func : reverse - backward_api : reverse_grad - forward : reverse (Tensor x, int[] axis) -> Tensor(out) - args : (Tensor out_grad, int[] axis) + forward : reverse (Tensor x, IntArray axis) -> Tensor(out) + args : (Tensor out_grad, IntArray axis) output : Tensor(x_grad) infer_meta : func : ReverseInferMeta diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index 8eba84a24e0ede70f538025799ade5e6eb90b036..76142c4eea1af4b0521635f00c60509c3c1daf4a 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -2744,13 +2744,22 @@ void ReshapeWithXShapeInferMeta(const MetaTensor& x, } void ReverseInferMeta(const MetaTensor& x, - const std::vector& axis, - MetaTensor* out) { - PADDLE_ENFORCE_NE(axis.empty(), + const IntArray& axis, + MetaTensor* out, + MetaConfig config) { + // NOTE(Aurelius84): In Reverse Op, output TensorMeta is always same + // as input, so we only verify axis when it is not from Tensor or in + // runtime. + if (!config.is_runtime && axis.FromTensor()) { + out->share_meta(x); + return; + } + auto& axis_data = axis.GetData(); + PADDLE_ENFORCE_NE(axis_data.empty(), true, phi::errors::InvalidArgument("'axis' can not be empty.")); const auto& x_dims = x.dims(); - for (int a : axis) { + for (int a : axis_data) { PADDLE_ENFORCE_LT(a, x_dims.size(), phi::errors::OutOfRange( @@ -2771,22 +2780,27 @@ void ReverseInferMeta(const MetaTensor& x, } void ReverseArrayInferMeta(const std::vector& x, - const std::vector& axis, - std::vector out) { + const IntArray& axis, + std::vector out, + MetaConfig config) { + if (!config.is_runtime && axis.FromTensor()) { + return; + } + auto& axis_data = axis.GetData(); PADDLE_ENFORCE_EQ( - axis.size(), + axis_data.size(), 1, phi::errors::InvalidArgument( "The size of axis must be 1 when the Input(X) is LoDTensorArray, " "but received %d.", - axis.size())); + axis_data.size())); PADDLE_ENFORCE_EQ( - axis[0], + axis_data[0], 0, phi::errors::InvalidArgument("The value of axis should be 1 when " "the Input(X) is LoDTensorArray, " "but received %d.", - axis[0])); + axis_data[0])); } void RollInferMeta(const MetaTensor& x, diff --git a/paddle/phi/infermeta/unary.h b/paddle/phi/infermeta/unary.h index 1479bcf64930ae3d17d547b621d31c9c2c4150c1..736360e7400656add0efbb6719a3dd808e0f2a6b 100644 --- a/paddle/phi/infermeta/unary.h +++ b/paddle/phi/infermeta/unary.h @@ -388,12 +388,14 @@ void ReshapeWithXShapeInferMeta(const MetaTensor& x, MetaConfig config = MetaConfig()); void ReverseInferMeta(const MetaTensor& x, - const std::vector& axis, - MetaTensor* out); + const IntArray& axis, + MetaTensor* out, + MetaConfig config = MetaConfig()); void ReverseArrayInferMeta(const std::vector& x, - const std::vector& axis, - std::vector out); + const IntArray& axis, + std::vector out, + MetaConfig config = MetaConfig()); void RollInferMeta(const MetaTensor& x, const IntArray& shifts, diff --git a/paddle/phi/kernels/impl/reverse_kernel_impl.h b/paddle/phi/kernels/impl/reverse_kernel_impl.h index 16ee333f83fa911c92eca6eb38f921d345f6da54..0580ab7db01e8c1336824c2bf46750d01abb8b1d 100644 --- a/paddle/phi/kernels/impl/reverse_kernel_impl.h +++ b/paddle/phi/kernels/impl/reverse_kernel_impl.h @@ -25,12 +25,13 @@ struct ReverseFunctor { void operator()(const Context& dev_ctx, const DenseTensor& in, DenseTensor* out, - const std::vector& axis) { + const IntArray& axis) { + auto& axis_data = axis.GetData(); Eigen::DSizes reverse_axis; for (int i = 0; i < Rank; ++i) { reverse_axis[i] = false; } - for (int a : axis) { + for (int a : axis_data) { if (a >= 0) { reverse_axis[a] = true; } else { @@ -50,7 +51,7 @@ struct ReverseFunctor { template void ReverseKernel(const Context& dev_ctx, const DenseTensor& x, - const std::vector& axis, + const IntArray& axis, DenseTensor* out) { dev_ctx.template Alloc(out); int rank = x.dims().size(); diff --git a/paddle/phi/kernels/reverse_kernel.cc b/paddle/phi/kernels/reverse_kernel.cc index d89e68e7389fd2e6f70dba3f72fa15ceb2b6d20c..b42923ac5dde47e036bab227f033988df1229521 100644 --- a/paddle/phi/kernels/reverse_kernel.cc +++ b/paddle/phi/kernels/reverse_kernel.cc @@ -23,7 +23,7 @@ namespace phi { template void ReverseArrayKernel(const Context& dev_ctx, const std::vector& x, - const std::vector& axis, + const IntArray& axis, std::vector out) { PADDLE_ENFORCE_EQ( x.size(), diff --git a/paddle/phi/kernels/reverse_kernel.h b/paddle/phi/kernels/reverse_kernel.h index 2b81f4018c25d896745637f032c25dbe5551ef26..1ccfa344d5c92733c883883dbcc701547ff9bdf4 100644 --- a/paddle/phi/kernels/reverse_kernel.h +++ b/paddle/phi/kernels/reverse_kernel.h @@ -16,6 +16,7 @@ #include +#include "paddle/phi/common/int_array.h" #include "paddle/phi/core/dense_tensor.h" namespace phi { @@ -23,13 +24,13 @@ namespace phi { template void ReverseKernel(const Context& dev_ctx, const DenseTensor& x, - const std::vector& axis, + const IntArray& axis, DenseTensor* out); template void ReverseArrayKernel(const Context& dev_ctx, const std::vector& x, - const std::vector& axis, + const IntArray& axis, std::vector out); } // namespace phi diff --git a/python/paddle/fluid/layers/tensor.py b/python/paddle/fluid/layers/tensor.py index f823f98a53610f459d24add1321f4e0783ad1dae..c073b003345c1f504b30b78d36ba7f3b1f179795 100644 --- a/python/paddle/fluid/layers/tensor.py +++ b/python/paddle/fluid/layers/tensor.py @@ -1279,7 +1279,7 @@ def reverse(x, axis): check_variable_and_dtype(x, 'x', ('float32', 'float64', 'int32', 'int64', 'uint8'), 'reverse') - check_type(axis, 'axis', (int, tuple, list), 'reverse') + check_type(axis, 'axis', (int, tuple, list, Variable), 'reverse') if isinstance(axis, int): axis = [axis] if in_dygraph_mode(): diff --git a/python/paddle/fluid/tests/unittests/test_attribute_var.py b/python/paddle/fluid/tests/unittests/test_attribute_var.py index 5d0316edfa4e519b203854f423186deac51f2634..cabbfb826b53b6b464962fb74a126420b4075481 100644 --- a/python/paddle/fluid/tests/unittests/test_attribute_var.py +++ b/python/paddle/fluid/tests/unittests/test_attribute_var.py @@ -37,6 +37,9 @@ class UnittestBase(unittest.TestCase): self.shapes = None self.save_path = None + def path_prefix(self): + return type(self).__name__ + def infer_prog(self): config = paddle_infer.Config(self.save_path + '.pdmodel', self.save_path + '.pdiparams') @@ -44,15 +47,21 @@ class UnittestBase(unittest.TestCase): input_names = predictor.get_input_names() for i, shape in enumerate(self.shapes): input_handle = predictor.get_input_handle(input_names[i]) - fake_input = np.random.randn(*shape).astype("float32") + self.fake_input = np.random.randn(*shape).astype("float32") input_handle.reshape(shape) - input_handle.copy_from_cpu(fake_input) + input_handle.copy_from_cpu(self.fake_input) predictor.run() output_names = predictor.get_output_names() - output_handle = predictor.get_output_handle(output_names[0]) - output_data = output_handle.copy_to_cpu() + res = [] + for out_name in output_names: + output_handle = predictor.get_output_handle(out_name) + output_data = output_handle.copy_to_cpu() + res.append(output_data) + + if len(output_names) == 1: + res = res[0] - return output_data + return res class TestDropout(UnittestBase): diff --git a/python/paddle/fluid/tests/unittests/test_reverse_op.py b/python/paddle/fluid/tests/unittests/test_reverse_op.py index 57b87bd896347d681bfdddaeae739a86e7d83294..60f2d0cb1ae719d90009baadf32cec42de5e88b8 100644 --- a/python/paddle/fluid/tests/unittests/test_reverse_op.py +++ b/python/paddle/fluid/tests/unittests/test_reverse_op.py @@ -14,6 +14,7 @@ from __future__ import print_function +import os import unittest import numpy as np from op_test import OpTest @@ -21,6 +22,9 @@ import paddle import paddle.fluid as fluid from paddle.fluid import core +from paddle.fluid.framework import program_guard, Program +from test_attribute_var import UnittestBase + class TestReverseOp(OpTest): @@ -195,6 +199,130 @@ class TestReverseLoDTensorArray(unittest.TestCase): self.run_program(arr_len=3, axis=1) +class TestReverseAxisTensor(UnittestBase): + + def init_info(self): + self.shapes = [[2, 3, 4]] + self.save_path = os.path.join(self.temp_dir.name, self.path_prefix()) + + def test_static(self): + main_prog = Program() + starup_prog = Program() + with program_guard(main_prog, starup_prog): + fc = paddle.nn.Linear(4, 10) + x = paddle.randn([2, 3, 4]) + x.stop_gradient = False + feat = fc(x) # [2,3,10] + + out = self.call_func(feat) + + sgd = paddle.optimizer.SGD() + sgd.minimize(paddle.mean(out)) + self.assertTrue(self.var_prefix() in str(main_prog)) + + exe = paddle.static.Executor() + exe.run(starup_prog) + res = exe.run(fetch_list=[feat, out]) + gt = res[0][::-1, :, ::-1] + np.testing.assert_allclose(res[1], gt) + + paddle.static.save_inference_model(self.save_path, [x], [feat, out], + exe) + # Test for Inference Predictor + infer_outs = self.infer_prog() + gt = infer_outs[0][::-1, :, ::-1] + np.testing.assert_allclose(infer_outs[1], gt) + + def path_prefix(self): + return 'reverse_tensor' + + def var_prefix(self): + return "Var[" + + def call_func(self, x): + # axes is a Variable + axes = paddle.assign([0, 2]) + out = paddle.fluid.layers.reverse(x, axes) + return out + + +class TestReverseAxisListTensor(TestReverseAxisTensor): + + def path_prefix(self): + return 'reverse_tensors' + + def var_prefix(self): + return "Vars[" + + def call_func(self, x): + # axes is a List[Variable] + axes = [paddle.assign([0]), paddle.assign([2])] + out = paddle.fluid.layers.reverse(x, axes) + return out + + +class TestAReverseEagerAPI(UnittestBase): + + def test_api(self): + paddle.disable_static() + x = paddle.randn([4, 10]) + y = paddle.randn([4, 10]) + + out = paddle._C_ops.final_state_reverse_array([x, y], [0]) + np.testing.assert_allclose(x.numpy(), out[1].numpy()) + np.testing.assert_allclose(y.numpy(), out[0].numpy()) + + paddle.enable_static() + + +class TestReverseTensorArrayAxisTensor(UnittestBase): + + def init_info(self): + self.shapes = [[2, 3, 4]] + self.save_path = os.path.join(self.temp_dir.name, + 'reverse_tensor_array') + + def test_static(self): + main_prog = Program() + starup_prog = Program() + with program_guard(main_prog, starup_prog): + fc = paddle.nn.Linear(4, 2) + x = paddle.randn([2, 3, 4]) + x.stop_gradient = False + feat = fc(x) # [2,3,10] + # tensor_array.shape: [[2,3,10], [2,3,10]] + tensor_array = paddle.fluid.layers.create_array(dtype='float32') + idx0 = paddle.full(shape=[1], fill_value=0, dtype="int64") + val0 = paddle.randn([2, 3, 2]) + paddle.fluid.layers.array_write(val0, idx0, tensor_array) + idx1 = paddle.full(shape=[1], fill_value=1, dtype="int64") + paddle.fluid.layers.array_write(feat, idx1, tensor_array) + # axes is a Variable + axes = paddle.assign([0]) + # tensor_array.shape: [[2,3,10], [2,3,10]] + reverse_array = paddle.fluid.layers.reverse(tensor_array, axes) + + out, _ = paddle.fluid.layers.tensor_array_to_tensor(reverse_array, + axis=0) + + sgd = paddle.optimizer.SGD() + sgd.minimize(paddle.mean(out)) + self.assertTrue("Var[" in str(main_prog)) + + exe = paddle.static.Executor() + exe.run(starup_prog) + res = exe.run(fetch_list=[val0, feat, out]) + np.testing.assert_allclose(res[1], res[-1][0:2]) + np.testing.assert_allclose(res[0], res[-1][2:4]) + + paddle.static.save_inference_model(self.save_path, [x], + [val0, feat, out], exe) + # Test for Inference Predictor + infer_outs = self.infer_prog() + np.testing.assert_allclose(infer_outs[1], infer_outs[-1][0:2]) + np.testing.assert_allclose(infer_outs[0], infer_outs[-1][2:4]) + + if __name__ == '__main__': paddle.enable_static() unittest.main()