未验证 提交 7b1e30fc 编写于 作者: F Feiyu Chan 提交者: GitHub

roll_op: support Tensor as input for shifts (#36727)

上级 236ed94d
......@@ -40,21 +40,23 @@ class RollOp : public framework::OperatorWithKernel {
auto dims = ctx->Attrs().Get<std::vector<int64_t>>("axis");
auto shifts = ctx->Attrs().Get<std::vector<int64_t>>("shifts");
if (dims.size() != 0) {
PADDLE_ENFORCE_EQ(dims.size(), shifts.size(),
platform::errors::InvalidArgument(
"When dims.size() != 0, dims.size() "
"should be equal to "
"shifts.size(). But received "
"dims.size() = %d, shifts.size() = %d",
dims.size(), shifts.size()));
} else {
PADDLE_ENFORCE_EQ(shifts.size(), 1,
platform::errors::InvalidArgument(
"When dims.size() == 0, shifts.size() "
"should be equal to 1, But received "
"shifts.size() = %d",
shifts.size()));
if (!ctx->HasInput("ShiftsTensor")) {
if (dims.size() != 0) {
PADDLE_ENFORCE_EQ(dims.size(), shifts.size(),
platform::errors::InvalidArgument(
"When dims.size() != 0, dims.size() "
"should be equal to "
"shifts.size(). But received "
"dims.size() = %d, shifts.size() = %d",
dims.size(), shifts.size()));
} else {
PADDLE_ENFORCE_EQ(shifts.size(), 1,
platform::errors::InvalidArgument(
"When dims.size() == 0, shifts.size() "
"should be equal to 1, But received "
"shifts.size() = %d",
shifts.size()));
}
}
ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
......@@ -105,6 +107,10 @@ class RollOpMaker : public framework::OpProtoAndCheckerMaker {
"The number of places by which the elements "
"of the tensor are shifted.")
.SetDefault({});
AddInput("ShiftsTensor",
"The number of places by which the elements of the tensor "
"are shifted.")
.AsDispensable();
AddAttr<std::vector<int64_t>>(
"axis",
"Axis along which to roll. It must have the same size "
......@@ -129,6 +135,9 @@ class RollGradMaker : public framework::SingleGradOpMaker<T> {
void Apply(GradOpPtr<T> op) const override {
op->SetType("roll_grad");
op->SetInput("X", this->Input("X"));
if (this->HasInput("ShiftsTensor")) {
op->SetInput("ShiftsTensor", this->Input("ShiftsTensor"));
}
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
op->SetAttrMap(this->Attrs());
......
......@@ -59,6 +59,16 @@ class RollKernel<platform::CUDADeviceContext, T>
auto* in = context.Input<LoDTensor>("X");
auto* out = context.Output<LoDTensor>("Out");
std::vector<int64_t> shifts = context.Attr<std::vector<int64_t>>("shifts");
if (context.HasInput("ShiftsTensor")) {
const auto* shifts_tensor =
context.Input<framework::Tensor>("ShiftsTensor");
PADDLE_ENFORCE_EQ(
shifts_tensor->dims().size(), 1,
platform::errors::InvalidArgument(
"The rank of ShiftsTensor is expected to be 1, got %s",
shifts_tensor->dims().size()));
shifts = GetDataFromTensor<int64_t>(shifts_tensor);
}
std::vector<int64_t> dims = context.Attr<std::vector<int64_t>>("axis");
auto* in_data = in->data<T>();
......@@ -134,6 +144,16 @@ class RollGradKernel<platform::CUDADeviceContext, T>
auto* in = context.Input<LoDTensor>(framework::GradVarName("Out"));
auto* out = context.Output<LoDTensor>(framework::GradVarName("X"));
std::vector<int64_t> shifts = context.Attr<std::vector<int64_t>>("shifts");
if (context.HasInput("ShiftsTensor")) {
const auto* shifts_tensor =
context.Input<framework::Tensor>("ShiftsTensor");
PADDLE_ENFORCE_EQ(
shifts_tensor->dims().size(), 1,
platform::errors::InvalidArgument(
"The rank of ShiftsTensor is expected to be 1, got %s",
shifts_tensor->dims().size()));
shifts = GetDataFromTensor<int64_t>(shifts_tensor);
}
std::vector<int64_t> dims = context.Attr<std::vector<int64_t>>("axis");
auto* in_data = in->data<T>();
......
......@@ -16,6 +16,8 @@
#include <memory>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/utils.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace operators {
......@@ -85,6 +87,16 @@ class RollKernel : public framework::OpKernel<T> {
auto& input = input_var->Get<LoDTensor>();
auto* output = output_var->GetMutable<LoDTensor>();
std::vector<int64_t> shifts = context.Attr<std::vector<int64_t>>("shifts");
if (context.HasInput("ShiftsTensor")) {
const auto* shifts_tensor =
context.Input<framework::Tensor>("ShiftsTensor");
PADDLE_ENFORCE_EQ(
shifts_tensor->dims().size(), 1,
platform::errors::InvalidArgument(
"The rank of ShiftsTensor is expected to be 1, got %s",
shifts_tensor->dims().size()));
shifts = GetDataFromTensor<int64_t>(shifts_tensor);
}
std::vector<int64_t> dims = context.Attr<std::vector<int64_t>>("axis");
std::vector<T> out_vec;
......@@ -123,6 +135,11 @@ class RollGradKernel : public framework::OpKernel<T> {
auto& input = input_var->Get<LoDTensor>();
auto* output = output_var->GetMutable<LoDTensor>();
std::vector<int64_t> shifts = context.Attr<std::vector<int64_t>>("shifts");
if (context.HasInput("ShiftsTensor")) {
const auto* shifts_tensor =
context.Input<framework::Tensor>("ShiftsTensor");
shifts = GetDataFromTensor<int64_t>(shifts_tensor);
}
std::vector<int64_t> dims = context.Attr<std::vector<int64_t>>("axis");
std::vector<T> out_vec;
......
......@@ -122,6 +122,34 @@ class TestRollAPI(unittest.TestCase):
self.assertRaises(ValueError, test_axis_out_range)
def test_shifts_as_tensor_dygraph(self):
with fluid.dygraph.guard():
x = paddle.arange(9).reshape([3, 3])
shape = paddle.shape(x)
shifts = shape // 2
axes = [0, 1]
out = paddle.roll(x, shifts=shifts, axis=axes).numpy()
expected_out = np.array([[8, 6, 7], [2, 0, 1], [5, 3, 4]])
self.assertTrue(np.allclose(out, expected_out))
def test_shifts_as_tensor_static(self):
with program_guard(Program(), Program()):
x = paddle.arange(9).reshape([3, 3]).astype('float32')
shape = paddle.shape(x)
shifts = shape // 2
axes = [0, 1]
out = paddle.roll(x, shifts=shifts, axis=axes)
expected_out = np.array([[8, 6, 7], [2, 0, 1], [5, 3, 4]])
exe = fluid.Executor(fluid.CPUPlace())
[out_np] = exe.run(fetch_list=[out])
self.assertTrue(np.allclose(out_np, expected_out))
if paddle.is_compiled_with_cuda():
exe = fluid.Executor(fluid.CPUPlace())
[out_np] = exe.run(fetch_list=[out])
self.assertTrue(np.allclose(out_np, expected_out))
if __name__ == "__main__":
unittest.main()
......@@ -696,15 +696,24 @@ def roll(x, shifts, axis=None, name=None):
helper = LayerHelper("roll", **locals())
check_type(axis, 'axis', (list, tuple), 'roll')
check_type(shifts, 'shifts', (list, tuple), 'roll')
out = helper.create_variable_for_type_inference(x.dtype)
helper.append_op(
type='roll',
inputs={'X': x},
outputs={'Out': out},
attrs={'axis': axis,
'shifts': shifts})
if isinstance(shifts, Variable):
helper.append_op(
type='roll',
inputs={'X': x,
"ShiftsTensor": shifts},
outputs={'Out': out},
attrs={'axis': axis})
else:
check_type(shifts, 'shifts', (list, tuple), 'roll')
helper.append_op(
type='roll',
inputs={'X': x},
outputs={'Out': out},
attrs={'axis': axis,
'shifts': shifts})
return out
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册