Unverified commit 7b1e30fc, authored by Feiyu Chan, committed by GitHub

roll_op: support Tensor as input for shifts (#36727)

Parent: 236ed94d
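In short: `paddle.roll` previously accepted `shifts` only as a Python list or tuple baked into the op's attributes; this commit adds a dispensable `ShiftsTensor` input so the shift amounts can be a `Tensor` computed at runtime. A minimal dygraph sketch of the new usage, adapted from the tests added below:

```python
import paddle

x = paddle.arange(9).reshape([3, 3])
shifts = paddle.shape(x) // 2  # a Tensor ([1, 1] here), not known until graph-build time
out = paddle.roll(x, shifts=shifts, axis=[0, 1])
print(out.numpy())
# [[8 6 7]
#  [2 0 1]
#  [5 3 4]]
```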
@@ -40,21 +40,23 @@ class RollOp : public framework::OperatorWithKernel {
     auto dims = ctx->Attrs().Get<std::vector<int64_t>>("axis");
     auto shifts = ctx->Attrs().Get<std::vector<int64_t>>("shifts");
-    if (dims.size() != 0) {
-      PADDLE_ENFORCE_EQ(dims.size(), shifts.size(),
-                        platform::errors::InvalidArgument(
-                            "When dims.size() != 0, dims.size() "
-                            "should be equal to "
-                            "shifts.size(). But received "
-                            "dims.size() = %d, shifts.size() = %d",
-                            dims.size(), shifts.size()));
-    } else {
-      PADDLE_ENFORCE_EQ(shifts.size(), 1,
-                        platform::errors::InvalidArgument(
-                            "When dims.size() == 0, shifts.size() "
-                            "should be equal to 1, But received "
-                            "shifts.size() = %d",
-                            shifts.size()));
+    if (!ctx->HasInput("ShiftsTensor")) {
+      if (dims.size() != 0) {
+        PADDLE_ENFORCE_EQ(dims.size(), shifts.size(),
+                          platform::errors::InvalidArgument(
+                              "When dims.size() != 0, dims.size() "
+                              "should be equal to "
+                              "shifts.size(). But received "
+                              "dims.size() = %d, shifts.size() = %d",
+                              dims.size(), shifts.size()));
+      } else {
+        PADDLE_ENFORCE_EQ(shifts.size(), 1,
+                          platform::errors::InvalidArgument(
+                              "When dims.size() == 0, shifts.size() "
+                              "should be equal to 1, But received "
+                              "shifts.size() = %d",
+                              shifts.size()));
+      }
     }
     ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
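A consequence of the new guard: when `ShiftsTensor` is wired, the shift values are unknown at `InferShape` time, so the shifts/axis size-consistency check is skipped there and left to runtime. With list shifts the check still fires; a hedged dygraph sketch (the exact exception type surfaced from `PADDLE_ENFORCE_EQ` may vary):

```python
import paddle

x = paddle.arange(9).reshape([3, 3])
try:
    # len(shifts) != len(axis): rejected by the PADDLE_ENFORCE_EQ above
    paddle.roll(x, shifts=[1, 1, 1], axis=[0, 1])
except Exception as err:
    print("rejected:", type(err).__name__)
```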
@@ -105,6 +107,10 @@ class RollOpMaker : public framework::OpProtoAndCheckerMaker {
              "The number of places by which the elements "
              "of the tensor are shifted.")
         .SetDefault({});
+    AddInput("ShiftsTensor",
+             "The number of places by which the elements of the tensor "
+             "are shifted.")
+        .AsDispensable();
     AddAttr<std::vector<int64_t>>(
         "axis",
         "Axis along which to roll. It must have the same size "
@@ -129,6 +135,9 @@ class RollGradMaker : public framework::SingleGradOpMaker<T> {
   void Apply(GradOpPtr<T> op) const override {
     op->SetType("roll_grad");
     op->SetInput("X", this->Input("X"));
+    if (this->HasInput("ShiftsTensor")) {
+      op->SetInput("ShiftsTensor", this->Input("ShiftsTensor"));
+    }
     op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
     op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
     op->SetAttrMap(this->Attrs());
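`RollGradMaker` forwards `ShiftsTensor` to `roll_grad` because the backward pass must roll by the same runtime-determined amounts, just in the opposite direction. Since roll is a pure permutation, its gradient is a roll by the negated shifts; a small numpy check of that identity:

```python
import numpy as np

x = np.arange(9, dtype=np.float64).reshape(3, 3)
shifts, axes = (1, 2), (0, 1)

y = np.roll(x, shifts, axes)
# Rolling by the negated shifts undoes the permutation, which is exactly
# how roll_grad maps dOut back onto dX.
assert (np.roll(y, tuple(-s for s in shifts), axes) == x).all()
```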
@@ -59,6 +59,16 @@ class RollKernel<platform::CUDADeviceContext, T>
     auto* in = context.Input<LoDTensor>("X");
     auto* out = context.Output<LoDTensor>("Out");
     std::vector<int64_t> shifts = context.Attr<std::vector<int64_t>>("shifts");
+    if (context.HasInput("ShiftsTensor")) {
+      const auto* shifts_tensor =
+          context.Input<framework::Tensor>("ShiftsTensor");
+      PADDLE_ENFORCE_EQ(
+          shifts_tensor->dims().size(), 1,
+          platform::errors::InvalidArgument(
+              "The rank of ShiftsTensor is expected to be 1, got %s",
+              shifts_tensor->dims().size()));
+      shifts = GetDataFromTensor<int64_t>(shifts_tensor);
+    }
     std::vector<int64_t> dims = context.Attr<std::vector<int64_t>>("axis");
     auto* in_data = in->data<T>();
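Each kernel first takes `shifts` from the attribute and then overwrites it from `ShiftsTensor` when that input is wired, enforcing that the tensor is rank-1 before flattening it with `GetDataFromTensor`. For reference, the per-axis semantics the kernel implements: the element at index i along a rolled axis of size n moves to (i + shift) mod n. A numpy sketch of the semantics only, not of the actual CUDA indexing:

```python
import numpy as np

def roll_1d(a, shift):
    # out[(i + shift) % n] = a[i] for every index i along the axis
    n = len(a)
    out = np.empty_like(a)
    for i in range(n):
        out[(i + shift) % n] = a[i]
    return out

a = np.arange(5)
assert (roll_1d(a, 2) == np.roll(a, 2)).all()  # [3, 4, 0, 1, 2]
```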
@@ -134,6 +144,16 @@ class RollGradKernel<platform::CUDADeviceContext, T>
     auto* in = context.Input<LoDTensor>(framework::GradVarName("Out"));
     auto* out = context.Output<LoDTensor>(framework::GradVarName("X"));
     std::vector<int64_t> shifts = context.Attr<std::vector<int64_t>>("shifts");
+    if (context.HasInput("ShiftsTensor")) {
+      const auto* shifts_tensor =
+          context.Input<framework::Tensor>("ShiftsTensor");
+      PADDLE_ENFORCE_EQ(
+          shifts_tensor->dims().size(), 1,
+          platform::errors::InvalidArgument(
+              "The rank of ShiftsTensor is expected to be 1, got %s",
+              shifts_tensor->dims().size()));
+      shifts = GetDataFromTensor<int64_t>(shifts_tensor);
+    }
     std::vector<int64_t> dims = context.Attr<std::vector<int64_t>>("axis");
     auto* in_data = in->data<T>();
@@ -16,6 +16,8 @@
 #include <memory>
 #include <vector>
 #include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/operators/utils.h"
+#include "paddle/fluid/platform/enforce.h"

 namespace paddle {
 namespace operators {
@@ -85,6 +87,16 @@ class RollKernel : public framework::OpKernel<T> {
     auto& input = input_var->Get<LoDTensor>();
     auto* output = output_var->GetMutable<LoDTensor>();
     std::vector<int64_t> shifts = context.Attr<std::vector<int64_t>>("shifts");
+    if (context.HasInput("ShiftsTensor")) {
+      const auto* shifts_tensor =
+          context.Input<framework::Tensor>("ShiftsTensor");
+      PADDLE_ENFORCE_EQ(
+          shifts_tensor->dims().size(), 1,
+          platform::errors::InvalidArgument(
+              "The rank of ShiftsTensor is expected to be 1, got %s",
+              shifts_tensor->dims().size()));
+      shifts = GetDataFromTensor<int64_t>(shifts_tensor);
+    }
     std::vector<int64_t> dims = context.Attr<std::vector<int64_t>>("axis");
     std::vector<T> out_vec;
@@ -123,6 +135,11 @@ class RollGradKernel : public framework::OpKernel<T> {
     auto& input = input_var->Get<LoDTensor>();
     auto* output = output_var->GetMutable<LoDTensor>();
     std::vector<int64_t> shifts = context.Attr<std::vector<int64_t>>("shifts");
+    if (context.HasInput("ShiftsTensor")) {
+      const auto* shifts_tensor =
+          context.Input<framework::Tensor>("ShiftsTensor");
+      shifts = GetDataFromTensor<int64_t>(shifts_tensor);
+    }
     std::vector<int64_t> dims = context.Attr<std::vector<int64_t>>("axis");
     std::vector<T> out_vec;
@@ -122,6 +122,34 @@ class TestRollAPI(unittest.TestCase):

         self.assertRaises(ValueError, test_axis_out_range)

+    def test_shifts_as_tensor_dygraph(self):
+        with fluid.dygraph.guard():
+            x = paddle.arange(9).reshape([3, 3])
+            shape = paddle.shape(x)
+            shifts = shape // 2
+            axes = [0, 1]
+            out = paddle.roll(x, shifts=shifts, axis=axes).numpy()
+            expected_out = np.array([[8, 6, 7], [2, 0, 1], [5, 3, 4]])
+            self.assertTrue(np.allclose(out, expected_out))
+
+    def test_shifts_as_tensor_static(self):
+        with program_guard(Program(), Program()):
+            x = paddle.arange(9).reshape([3, 3]).astype('float32')
+            shape = paddle.shape(x)
+            shifts = shape // 2
+            axes = [0, 1]
+
+            out = paddle.roll(x, shifts=shifts, axis=axes)
+            expected_out = np.array([[8, 6, 7], [2, 0, 1], [5, 3, 4]])
+
+            exe = fluid.Executor(fluid.CPUPlace())
+            [out_np] = exe.run(fetch_list=[out])
+            self.assertTrue(np.allclose(out_np, expected_out))
+
+            if paddle.is_compiled_with_cuda():
+                # Fixed: the GPU branch must use CUDAPlace; the original
+                # duplicated the CPUPlace executor here.
+                exe = fluid.Executor(fluid.CUDAPlace(0))
+                [out_np] = exe.run(fetch_list=[out])
+                self.assertTrue(np.allclose(out_np, expected_out))
+

 if __name__ == "__main__":
     unittest.main()
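The expected output used by both tests checks out against numpy's reference roll: with shifts = shape // 2 = [1, 1] on both axes:

```python
import numpy as np

x = np.arange(9).reshape(3, 3)
print(np.roll(x, shift=(1, 1), axis=(0, 1)))
# [[8 6 7]
#  [2 0 1]
#  [5 3 4]]
```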
@@ -696,15 +696,24 @@ def roll(x, shifts, axis=None, name=None):

     helper = LayerHelper("roll", **locals())
     check_type(axis, 'axis', (list, tuple), 'roll')
-    check_type(shifts, 'shifts', (list, tuple), 'roll')

     out = helper.create_variable_for_type_inference(x.dtype)
-    helper.append_op(
-        type='roll',
-        inputs={'X': x},
-        outputs={'Out': out},
-        attrs={'axis': axis,
-               'shifts': shifts})
+
+    if isinstance(shifts, Variable):
+        helper.append_op(
+            type='roll',
+            inputs={'X': x,
+                    "ShiftsTensor": shifts},
+            outputs={'Out': out},
+            attrs={'axis': axis})
+    else:
+        check_type(shifts, 'shifts', (list, tuple), 'roll')
+        helper.append_op(
+            type='roll',
+            inputs={'X': x},
+            outputs={'Out': out},
+            attrs={'axis': axis,
+                   'shifts': shifts})
     return out
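The Python layer now dispatches on the type of `shifts`: a `Variable` is wired into the `ShiftsTensor` input (and the `shifts` attribute is omitted), while a list or tuple keeps the original attribute-only path. The two calls below should therefore be equivalent; a sketch assuming dygraph mode:

```python
import paddle

x = paddle.arange(9).reshape([3, 3])
a = paddle.roll(x, shifts=[1, 1], axis=[0, 1])                    # attribute path
b = paddle.roll(x, shifts=paddle.to_tensor([1, 1]), axis=[0, 1])  # ShiftsTensor path
assert (a.numpy() == b.numpy()).all()
```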