From 7b1e30fcb30b7a30faa7bbabe50cfd304d27ca94 Mon Sep 17 00:00:00 2001 From: Feiyu Chan Date: Tue, 26 Oct 2021 17:58:35 +0800 Subject: [PATCH] roll_op: support Tensor as input for shifts (#36727) --- paddle/fluid/operators/roll_op.cc | 39 ++++++++++++------- paddle/fluid/operators/roll_op.cu | 20 ++++++++++ paddle/fluid/operators/roll_op.h | 17 ++++++++ .../fluid/tests/unittests/test_roll_op.py | 28 +++++++++++++ python/paddle/tensor/manipulation.py | 23 +++++++---- 5 files changed, 105 insertions(+), 22 deletions(-) diff --git a/paddle/fluid/operators/roll_op.cc b/paddle/fluid/operators/roll_op.cc index b6a8111592..b74dfc984a 100644 --- a/paddle/fluid/operators/roll_op.cc +++ b/paddle/fluid/operators/roll_op.cc @@ -40,21 +40,23 @@ class RollOp : public framework::OperatorWithKernel { auto dims = ctx->Attrs().Get>("axis"); auto shifts = ctx->Attrs().Get>("shifts"); - if (dims.size() != 0) { - PADDLE_ENFORCE_EQ(dims.size(), shifts.size(), - platform::errors::InvalidArgument( - "When dims.size() != 0, dims.size() " - "should be equal to " - "shifts.size(). But received " - "dims.size() = %d, shifts.size() = %d", - dims.size(), shifts.size())); - } else { - PADDLE_ENFORCE_EQ(shifts.size(), 1, - platform::errors::InvalidArgument( - "When dims.size() == 0, shifts.size() " - "should be equal to 1, But received " - "shifts.size() = %d", - shifts.size())); + if (!ctx->HasInput("ShiftsTensor")) { + if (dims.size() != 0) { + PADDLE_ENFORCE_EQ(dims.size(), shifts.size(), + platform::errors::InvalidArgument( + "When dims.size() != 0, dims.size() " + "should be equal to " + "shifts.size(). But received " + "dims.size() = %d, shifts.size() = %d", + dims.size(), shifts.size())); + } else { + PADDLE_ENFORCE_EQ(shifts.size(), 1, + platform::errors::InvalidArgument( + "When dims.size() == 0, shifts.size() " + "should be equal to 1, But received " + "shifts.size() = %d", + shifts.size())); + } } ctx->SetOutputDim("Out", ctx->GetInputDim("X")); @@ -105,6 +107,10 @@ class RollOpMaker : public framework::OpProtoAndCheckerMaker { "The number of places by which the elements " "of the tensor are shifted.") .SetDefault({}); + AddInput("ShiftsTensor", + "The number of places by which the elements of the tensor " + "are shifted.") + .AsDispensable(); AddAttr>( "axis", "Axis along which to roll. It must have the same size " @@ -129,6 +135,9 @@ class RollGradMaker : public framework::SingleGradOpMaker { void Apply(GradOpPtr op) const override { op->SetType("roll_grad"); op->SetInput("X", this->Input("X")); + if (this->HasInput("ShiftsTensor")) { + op->SetInput("ShiftsTensor", this->Input("ShiftsTensor")); + } op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); op->SetAttrMap(this->Attrs()); diff --git a/paddle/fluid/operators/roll_op.cu b/paddle/fluid/operators/roll_op.cu index a170ce2fb1..d70bd58887 100644 --- a/paddle/fluid/operators/roll_op.cu +++ b/paddle/fluid/operators/roll_op.cu @@ -59,6 +59,16 @@ class RollKernel auto* in = context.Input("X"); auto* out = context.Output("Out"); std::vector shifts = context.Attr>("shifts"); + if (context.HasInput("ShiftsTensor")) { + const auto* shifts_tensor = + context.Input("ShiftsTensor"); + PADDLE_ENFORCE_EQ( + shifts_tensor->dims().size(), 1, + platform::errors::InvalidArgument( + "The rank of ShiftsTensor is expected to be 1, got %s", + shifts_tensor->dims().size())); + shifts = GetDataFromTensor(shifts_tensor); + } std::vector dims = context.Attr>("axis"); auto* in_data = in->data(); @@ -134,6 +144,16 @@ class RollGradKernel auto* in = context.Input(framework::GradVarName("Out")); auto* out = context.Output(framework::GradVarName("X")); std::vector shifts = context.Attr>("shifts"); + if (context.HasInput("ShiftsTensor")) { + const auto* shifts_tensor = + context.Input("ShiftsTensor"); + PADDLE_ENFORCE_EQ( + shifts_tensor->dims().size(), 1, + platform::errors::InvalidArgument( + "The rank of ShiftsTensor is expected to be 1, got %s", + shifts_tensor->dims().size())); + shifts = GetDataFromTensor(shifts_tensor); + } std::vector dims = context.Attr>("axis"); auto* in_data = in->data(); diff --git a/paddle/fluid/operators/roll_op.h b/paddle/fluid/operators/roll_op.h index e58ff521d8..affb5f226e 100644 --- a/paddle/fluid/operators/roll_op.h +++ b/paddle/fluid/operators/roll_op.h @@ -16,6 +16,8 @@ #include #include #include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/utils.h" +#include "paddle/fluid/platform/enforce.h" namespace paddle { namespace operators { @@ -85,6 +87,16 @@ class RollKernel : public framework::OpKernel { auto& input = input_var->Get(); auto* output = output_var->GetMutable(); std::vector shifts = context.Attr>("shifts"); + if (context.HasInput("ShiftsTensor")) { + const auto* shifts_tensor = + context.Input("ShiftsTensor"); + PADDLE_ENFORCE_EQ( + shifts_tensor->dims().size(), 1, + platform::errors::InvalidArgument( + "The rank of ShiftsTensor is expected to be 1, got %s", + shifts_tensor->dims().size())); + shifts = GetDataFromTensor(shifts_tensor); + } std::vector dims = context.Attr>("axis"); std::vector out_vec; @@ -123,6 +135,11 @@ class RollGradKernel : public framework::OpKernel { auto& input = input_var->Get(); auto* output = output_var->GetMutable(); std::vector shifts = context.Attr>("shifts"); + if (context.HasInput("ShiftsTensor")) { + const auto* shifts_tensor = + context.Input("ShiftsTensor"); + shifts = GetDataFromTensor(shifts_tensor); + } std::vector dims = context.Attr>("axis"); std::vector out_vec; diff --git a/python/paddle/fluid/tests/unittests/test_roll_op.py b/python/paddle/fluid/tests/unittests/test_roll_op.py index 99121d2953..bca7665b81 100644 --- a/python/paddle/fluid/tests/unittests/test_roll_op.py +++ b/python/paddle/fluid/tests/unittests/test_roll_op.py @@ -122,6 +122,34 @@ class TestRollAPI(unittest.TestCase): self.assertRaises(ValueError, test_axis_out_range) + def test_shifts_as_tensor_dygraph(self): + with fluid.dygraph.guard(): + x = paddle.arange(9).reshape([3, 3]) + shape = paddle.shape(x) + shifts = shape // 2 + axes = [0, 1] + out = paddle.roll(x, shifts=shifts, axis=axes).numpy() + expected_out = np.array([[8, 6, 7], [2, 0, 1], [5, 3, 4]]) + self.assertTrue(np.allclose(out, expected_out)) + + def test_shifts_as_tensor_static(self): + with program_guard(Program(), Program()): + x = paddle.arange(9).reshape([3, 3]).astype('float32') + shape = paddle.shape(x) + shifts = shape // 2 + axes = [0, 1] + out = paddle.roll(x, shifts=shifts, axis=axes) + expected_out = np.array([[8, 6, 7], [2, 0, 1], [5, 3, 4]]) + + exe = fluid.Executor(fluid.CPUPlace()) + [out_np] = exe.run(fetch_list=[out]) + self.assertTrue(np.allclose(out_np, expected_out)) + + if paddle.is_compiled_with_cuda(): + exe = fluid.Executor(fluid.CPUPlace()) + [out_np] = exe.run(fetch_list=[out]) + self.assertTrue(np.allclose(out_np, expected_out)) + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py index 5f7588cb2a..9b9b2d9431 100644 --- a/python/paddle/tensor/manipulation.py +++ b/python/paddle/tensor/manipulation.py @@ -696,15 +696,24 @@ def roll(x, shifts, axis=None, name=None): helper = LayerHelper("roll", **locals()) check_type(axis, 'axis', (list, tuple), 'roll') - check_type(shifts, 'shifts', (list, tuple), 'roll') + out = helper.create_variable_for_type_inference(x.dtype) - helper.append_op( - type='roll', - inputs={'X': x}, - outputs={'Out': out}, - attrs={'axis': axis, - 'shifts': shifts}) + if isinstance(shifts, Variable): + helper.append_op( + type='roll', + inputs={'X': x, + "ShiftsTensor": shifts}, + outputs={'Out': out}, + attrs={'axis': axis}) + else: + check_type(shifts, 'shifts', (list, tuple), 'roll') + helper.append_op( + type='roll', + inputs={'X': x}, + outputs={'Out': out}, + attrs={'axis': axis, + 'shifts': shifts}) return out -- GitLab