Unverified · Commit 7b1e30fc · authored by Feiyu Chan, committed by GitHub

roll_op: support Tensor as input for shifts (#36727)

Parent: 236ed94d
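In short, after this change paddle.roll accepts a 1-D integer Tensor for shifts in addition to a Python list or tuple, so the shift amounts can be computed inside the graph. A minimal usage sketch, assuming a Paddle build that includes this commit (the expected values match NumPy and the new unit tests below):

import numpy as np
import paddle

x = paddle.arange(9).reshape([3, 3])  # [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
shifts = paddle.shape(x) // 2         # a 1-D integer Tensor, here [1, 1]
out = paddle.roll(x, shifts=shifts, axis=[0, 1])

# Same result as the list form and as NumPy's roll:
ref = np.roll(np.arange(9).reshape(3, 3), shift=(1, 1), axis=(0, 1))
assert np.array_equal(out.numpy(), ref)  # [[8, 6, 7], [2, 0, 1], [5, 3, 4]]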
@@ -40,6 +40,7 @@ class RollOp : public framework::OperatorWithKernel {
     auto dims = ctx->Attrs().Get<std::vector<int64_t>>("axis");
     auto shifts = ctx->Attrs().Get<std::vector<int64_t>>("shifts");
+    if (!ctx->HasInput("ShiftsTensor")) {
     if (dims.size() != 0) {
       PADDLE_ENFORCE_EQ(dims.size(), shifts.size(),
                         platform::errors::InvalidArgument(
@@ -56,6 +57,7 @@ class RollOp : public framework::OperatorWithKernel {
                             "shifts.size() = %d",
                             shifts.size()));
     }
+    }
     ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
     auto type = ctx->GetInputsVarType("X")[0];
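The existing dims/shifts size check above is now wrapped in if (!ctx->HasInput("ShiftsTensor")): when the shifts arrive as a tensor, their count and values are only known at run time, so InferShape cannot validate them against the axis attribute. A short static-graph sketch of the situation this guards (illustrative names only):

import paddle

paddle.enable_static()
x = paddle.static.data(name='x', shape=[3, 3], dtype='float32')
shifts = paddle.shape(x) // 2  # a Variable; its values exist only at run time
out = paddle.roll(x, shifts=shifts, axis=[0, 1])
# The op now carries a ShiftsTensor input instead of a usable 'shifts' attr,
# so the compile-time size check is skipped; the kernels validate later.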
@@ -105,6 +107,10 @@ class RollOpMaker : public framework::OpProtoAndCheckerMaker {
              "The number of places by which the elements "
              "of the tensor are shifted.")
         .SetDefault({});
+    AddInput("ShiftsTensor",
+             "The number of places by which the elements of the tensor "
+             "are shifted.")
+        .AsDispensable();
     AddAttr<std::vector<int64_t>>(
         "axis",
         "Axis along which to roll. It must have the same size "
@@ -129,6 +135,9 @@ class RollGradMaker : public framework::SingleGradOpMaker<T> {
   void Apply(GradOpPtr<T> op) const override {
     op->SetType("roll_grad");
     op->SetInput("X", this->Input("X"));
+    if (this->HasInput("ShiftsTensor")) {
+      op->SetInput("ShiftsTensor", this->Input("ShiftsTensor"));
+    }
     op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
     op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
     op->SetAttrMap(this->Attrs());
......
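RollGradMaker forwards ShiftsTensor to roll_grad because roll is a pure permutation: the gradient of y = roll(x, s) is the upstream gradient rolled back by -s, so the backward op needs the very same shifts. A NumPy sketch of that identity (not Paddle code, just the math the grad op relies on):

import numpy as np

s, axes = (1, 2), (0, 1)
g_y = np.random.rand(3, 3)                         # some upstream gradient dL/dy
g_x = np.roll(g_y, tuple(-k for k in s), axes)     # dL/dx: roll by negated shifts
assert np.array_equal(np.roll(g_x, s, axes), g_y)  # the forward roll undoes it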
@@ -59,6 +59,16 @@ class RollKernel<platform::CUDADeviceContext, T>
     auto* in = context.Input<LoDTensor>("X");
     auto* out = context.Output<LoDTensor>("Out");
     std::vector<int64_t> shifts = context.Attr<std::vector<int64_t>>("shifts");
+    if (context.HasInput("ShiftsTensor")) {
+      const auto* shifts_tensor =
+          context.Input<framework::Tensor>("ShiftsTensor");
+      PADDLE_ENFORCE_EQ(
+          shifts_tensor->dims().size(), 1,
+          platform::errors::InvalidArgument(
+              "The rank of ShiftsTensor is expected to be 1, got %s",
+              shifts_tensor->dims().size()));
+      shifts = GetDataFromTensor<int64_t>(shifts_tensor);
+    }
     std::vector<int64_t> dims = context.Attr<std::vector<int64_t>>("axis");
     auto* in_data = in->data<T>();
@@ -134,6 +144,16 @@ class RollGradKernel<platform::CUDADeviceContext, T>
     auto* in = context.Input<LoDTensor>(framework::GradVarName("Out"));
     auto* out = context.Output<LoDTensor>(framework::GradVarName("X"));
     std::vector<int64_t> shifts = context.Attr<std::vector<int64_t>>("shifts");
+    if (context.HasInput("ShiftsTensor")) {
+      const auto* shifts_tensor =
+          context.Input<framework::Tensor>("ShiftsTensor");
+      PADDLE_ENFORCE_EQ(
+          shifts_tensor->dims().size(), 1,
+          platform::errors::InvalidArgument(
+              "The rank of ShiftsTensor is expected to be 1, got %s",
+              shifts_tensor->dims().size()));
+      shifts = GetDataFromTensor<int64_t>(shifts_tensor);
+    }
     std::vector<int64_t> dims = context.Attr<std::vector<int64_t>>("axis");
     auto* in_data = in->data<T>();
......
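Both CUDA kernels re-read the shifts at run time with GetDataFromTensor and first enforce that ShiftsTensor is rank 1, since it encodes a flat list of per-axis shift counts. A hedged static-mode sketch of what that check rejects (illustrative feed names; the exact error text comes from the PADDLE_ENFORCE_EQ above):

import numpy as np
import paddle
import paddle.fluid as fluid

paddle.enable_static()
x = paddle.static.data(name='x', shape=[3, 3], dtype='float32')
bad = paddle.static.data(name='s', shape=[2, 1], dtype='int64')  # rank 2, not 1
out = paddle.roll(x, shifts=bad, axis=[0, 1])

exe = fluid.Executor(fluid.CPUPlace())
try:
    exe.run(feed={'x': np.zeros([3, 3], 'float32'),
                  's': np.ones([2, 1], 'int64')},
            fetch_list=[out])
except Exception:
    print('rejected: ShiftsTensor must be rank 1')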
@@ -16,6 +16,8 @@
 #include <memory>
 #include <vector>
 #include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/operators/utils.h"
+#include "paddle/fluid/platform/enforce.h"

 namespace paddle {
 namespace operators {
@@ -85,6 +87,16 @@ class RollKernel : public framework::OpKernel<T> {
     auto& input = input_var->Get<LoDTensor>();
     auto* output = output_var->GetMutable<LoDTensor>();
     std::vector<int64_t> shifts = context.Attr<std::vector<int64_t>>("shifts");
+    if (context.HasInput("ShiftsTensor")) {
+      const auto* shifts_tensor =
+          context.Input<framework::Tensor>("ShiftsTensor");
+      PADDLE_ENFORCE_EQ(
+          shifts_tensor->dims().size(), 1,
+          platform::errors::InvalidArgument(
+              "The rank of ShiftsTensor is expected to be 1, got %s",
+              shifts_tensor->dims().size()));
+      shifts = GetDataFromTensor<int64_t>(shifts_tensor);
+    }
     std::vector<int64_t> dims = context.Attr<std::vector<int64_t>>("axis");

     std::vector<T> out_vec;
@@ -123,6 +135,11 @@ class RollGradKernel : public framework::OpKernel<T> {
     auto& input = input_var->Get<LoDTensor>();
     auto* output = output_var->GetMutable<LoDTensor>();
     std::vector<int64_t> shifts = context.Attr<std::vector<int64_t>>("shifts");
+    if (context.HasInput("ShiftsTensor")) {
+      const auto* shifts_tensor =
+          context.Input<framework::Tensor>("ShiftsTensor");
+      shifts = GetDataFromTensor<int64_t>(shifts_tensor);
+    }
     std::vector<int64_t> dims = context.Attr<std::vector<int64_t>>("axis");

     std::vector<T> out_vec;
......
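The CPU kernels mirror the CUDA path: shifts come from the attribute unless ShiftsTensor is present, in which case they are read from the tensor at run time (note the grad kernel omits the rank-1 check here, relying on the forward op having validated it). The computation itself matches NumPy's roll applied axis by axis; a reference sketch of those semantics:

import numpy as np

def roll_ref(x, shifts, axes):
    # Roll each listed axis independently, as the kernel's per-axis loop does.
    out = x
    for s, a in zip(shifts, axes):
        out = np.roll(out, s, axis=a)
    return out

x = np.arange(9).reshape(3, 3)
assert np.array_equal(roll_ref(x, [1, 1], [0, 1]), np.roll(x, (1, 1), (0, 1)))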
@@ -122,6 +122,34 @@ class TestRollAPI(unittest.TestCase):

         self.assertRaises(ValueError, test_axis_out_range)

+    def test_shifts_as_tensor_dygraph(self):
+        with fluid.dygraph.guard():
+            x = paddle.arange(9).reshape([3, 3])
+            shape = paddle.shape(x)
+            shifts = shape // 2
+            axes = [0, 1]
+            out = paddle.roll(x, shifts=shifts, axis=axes).numpy()
+            expected_out = np.array([[8, 6, 7], [2, 0, 1], [5, 3, 4]])
+            self.assertTrue(np.allclose(out, expected_out))
+
+    def test_shifts_as_tensor_static(self):
+        with program_guard(Program(), Program()):
+            x = paddle.arange(9).reshape([3, 3]).astype('float32')
+            shape = paddle.shape(x)
+            shifts = shape // 2
+            axes = [0, 1]
+            out = paddle.roll(x, shifts=shifts, axis=axes)
+            expected_out = np.array([[8, 6, 7], [2, 0, 1], [5, 3, 4]])
+
+            exe = fluid.Executor(fluid.CPUPlace())
+            [out_np] = exe.run(fetch_list=[out])
+            self.assertTrue(np.allclose(out_np, expected_out))
+
+            if paddle.is_compiled_with_cuda():
+                exe = fluid.Executor(fluid.CUDAPlace(0))
+                [out_np] = exe.run(fetch_list=[out])
+                self.assertTrue(np.allclose(out_np, expected_out))

 if __name__ == "__main__":
     unittest.main()
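The expected output in both new tests follows from shifts = shape // 2 = [1, 1]: arange(9).reshape(3, 3) rolled one position along each axis. Worked out step by step with NumPy:

import numpy as np

x = np.arange(9).reshape(3, 3)     # [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
step1 = np.roll(x, 1, axis=0)      # [[6, 7, 8], [0, 1, 2], [3, 4, 5]]
step2 = np.roll(step1, 1, axis=1)  # [[8, 6, 7], [2, 0, 1], [5, 3, 4]]
assert np.array_equal(step2, np.array([[8, 6, 7], [2, 0, 1], [5, 3, 4]]))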
@@ -696,9 +696,18 @@ def roll(x, shifts, axis=None, name=None):
     helper = LayerHelper("roll", **locals())
     check_type(axis, 'axis', (list, tuple), 'roll')
-    check_type(shifts, 'shifts', (list, tuple), 'roll')
     out = helper.create_variable_for_type_inference(x.dtype)

+    if isinstance(shifts, Variable):
+        helper.append_op(
+            type='roll',
+            inputs={'X': x,
+                    "ShiftsTensor": shifts},
+            outputs={'Out': out},
+            attrs={'axis': axis})
+    else:
+        check_type(shifts, 'shifts', (list, tuple), 'roll')
-    helper.append_op(
-        type='roll',
-        inputs={'X': x},
+        helper.append_op(
+            type='roll',
+            inputs={'X': x},
......