Unverified commit c42d662e, authored by yaoxuefeng, committed via GitHub

modify roll test=develop (#25321)

Parent: bdc2c2db
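This commit renames the roll op's `dims` attribute/parameter to `axis` across the C++ operator definition, the CPU kernels, the Python API, and the unit tests. At a call site the change is just the keyword name:

```python
# Before this commit
z = paddle.roll(x, shifts=1, dims=0)
# After this commit
z = paddle.roll(x, shifts=1, axis=0)
```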
@@ -33,7 +33,7 @@ class RollOp : public framework::OperatorWithKernel {
                       platform::errors::InvalidArgument(
                           "Output(Out) of RollOp should not be null."));
-    auto dims = ctx->Attrs().Get<std::vector<int64_t>>("dims");
+    auto dims = ctx->Attrs().Get<std::vector<int64_t>>("axis");
     auto shifts = ctx->Attrs().Get<std::vector<int64_t>>("shifts");
 
     PADDLE_ENFORCE_EQ(dims.size(), shifts.size(),
@@ -92,7 +92,7 @@ class RollOpMaker : public framework::OpProtoAndCheckerMaker {
              "of the tensor are shifted.")
         .SetDefault({});
     AddAttr<std::vector<int64_t>>(
-        "dims",
+        "axis",
         "Axis along which to roll. It must have the same size "
         "with shifts.")
         .SetDefault({});
...
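Note that the `PADDLE_ENFORCE_EQ(dims.size(), shifts.size(), ...)` check in `InferShape` above requires one shift per axis; a mismatched call fails at op construction. An illustrative example (not from the diff):

```python
# OK: one shift per axis
z = paddle.roll(x, shifts=[101, -1], axis=[0, -2])
# Rejected by the PADDLE_ENFORCE_EQ above: 2 shifts but only 1 axis
# z = paddle.roll(x, shifts=[1, 2], axis=[0])
```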
...@@ -82,7 +82,7 @@ class RollKernel : public framework::OpKernel<T> { ...@@ -82,7 +82,7 @@ class RollKernel : public framework::OpKernel<T> {
auto& input = input_var->Get<LoDTensor>(); auto& input = input_var->Get<LoDTensor>();
auto* output = output_var->GetMutable<LoDTensor>(); auto* output = output_var->GetMutable<LoDTensor>();
std::vector<int64_t> shifts = context.Attr<std::vector<int64_t>>("shifts"); std::vector<int64_t> shifts = context.Attr<std::vector<int64_t>>("shifts");
std::vector<int64_t> dims = context.Attr<std::vector<int64_t>>("dims"); std::vector<int64_t> dims = context.Attr<std::vector<int64_t>>("axis");
std::vector<T> out_vec; std::vector<T> out_vec;
TensorToVector(input, context.device_context(), &out_vec); TensorToVector(input, context.device_context(), &out_vec);
...@@ -94,8 +94,8 @@ class RollKernel : public framework::OpKernel<T> { ...@@ -94,8 +94,8 @@ class RollKernel : public framework::OpKernel<T> {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
dims[i] < input_dim.size() && dims[i] >= (0 - input_dim.size()), true, dims[i] < input_dim.size() && dims[i] >= (0 - input_dim.size()), true,
platform::errors::OutOfRange( platform::errors::OutOfRange(
"Attr(dims[%d]) is out of range, It's expected " "Attr(axis[%d]) is out of range, It's expected "
"to be in range of [-%d, %d]. But received Attr(dims[%d]) = %d.", "to be in range of [-%d, %d]. But received Attr(axis[%d]) = %d.",
i, input_dim.size(), input_dim.size() - 1, i, dims[i])); i, input_dim.size(), input_dim.size() - 1, i, dims[i]));
shift_along_dim(out_vec.data(), input_dim, dims[i], shifts[i]); shift_along_dim(out_vec.data(), input_dim, dims[i], shifts[i]);
} }
@@ -114,7 +114,7 @@ class RollGradKernel : public framework::OpKernel<T> {
     auto& input = input_var->Get<LoDTensor>();
     auto* output = output_var->GetMutable<LoDTensor>();
     std::vector<int64_t> shifts = context.Attr<std::vector<int64_t>>("shifts");
-    std::vector<int64_t> dims = context.Attr<std::vector<int64_t>>("dims");
+    std::vector<int64_t> dims = context.Attr<std::vector<int64_t>>("axis");
 
     std::vector<T> out_vec;
     TensorToVector(input, context.device_context(), &out_vec);
...
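For intuition: the kernel copies the tensor into a flat vector (`TensorToVector`) and then applies `shift_along_dim` once per (axis, shift) pair. A rough NumPy sketch of that semantics — a reference model only, not the actual C++ implementation:

```python
import numpy as np

def roll_reference(x, shifts, axis):
    # Apply each (axis, shift) pair in turn. np.roll wraps elements that
    # move past the last position back around to the first, which is the
    # same wrap-around behavior the kernel implements.
    out = np.array(x)
    for a, s in zip(axis, shifts):
        out = np.roll(out, s, axis=a)
    return out

x = np.arange(12).reshape(3, 4)
# Mirrors the multi-axis usage the tests below exercise.
print(roll_reference(x, shifts=[1, -1], axis=[0, 1]))
```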
@@ -28,17 +28,17 @@ class TestRollOp(OpTest):
         self.op_type = "roll"
         self.init_dtype_type()
         self.inputs = {'X': np.random.random(self.x_shape).astype(self.dtype)}
-        self.attrs = {'shifts': self.shifts, 'dims': self.dims}
+        self.attrs = {'shifts': self.shifts, 'axis': self.axis}
         self.outputs = {
             'Out': np.roll(self.inputs['X'], self.attrs['shifts'],
-                           self.attrs['dims'])
+                           self.attrs['axis'])
         }
 
     def init_dtype_type(self):
         self.dtype = np.float64
         self.x_shape = (100, 4, 5)
         self.shifts = [101, -1]
-        self.dims = [0, -2]
+        self.axis = [0, -2]
 
     def test_check_output(self):
         self.check_output()
@@ -52,7 +52,7 @@ class TestRollOpCase2(TestRollOp):
         self.dtype = np.float32
         self.x_shape = (100, 10, 5)
         self.shifts = [8, -1]
-        self.dims = [-1, -2]
+        self.axis = [-1, -2]
 
 
 class TestRollAPI(unittest.TestCase):
@@ -78,7 +78,7 @@ class TestRollAPI(unittest.TestCase):
         # case 2:
         with program_guard(Program(), Program()):
             x = fluid.layers.data(name='x', shape=[-1, 3])
-            z = paddle.roll(x, shifts=1, dims=0)
+            z = paddle.roll(x, shifts=1, axis=0)
             exe = fluid.Executor(fluid.CPUPlace())
             res, = exe.run(feed={'x': self.data_x},
                            fetch_list=[z.name],
@@ -101,12 +101,26 @@ class TestRollAPI(unittest.TestCase):
         # case 2:
         with fluid.dygraph.guard():
             x = fluid.dygraph.to_variable(self.data_x)
-            z = paddle.roll(x, shifts=1, dims=0)
+            z = paddle.roll(x, shifts=1, axis=0)
             np_z = z.numpy()
             expect_out = np.array([[7.0, 8.0, 9.0], [1.0, 2.0, 3.0],
                                    [4.0, 5.0, 6.0]])
             self.assertTrue(np.allclose(expect_out, np_z))
 
+    def test_roll_op_false(self):
+        self.input_data()
+
+        def test_axis_out_range():
+            with program_guard(Program(), Program()):
+                x = fluid.layers.data(name='x', shape=[-1, 3])
+                z = paddle.roll(x, shifts=1, axis=10)
+                exe = fluid.Executor(fluid.CPUPlace())
+                res, = exe.run(feed={'x': self.data_x},
+                               fetch_list=[z.name],
+                               return_numpy=False)
+
+        self.assertRaises(ValueError, test_axis_out_range)
+
 
 if __name__ == "__main__":
     unittest.main()
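The new `test_roll_op_false` case expects a `ValueError` when `axis` is outside `[-len(x.shape), len(x.shape))`; that validation is implemented in the Python layer (see the API hunk below). A standalone sketch of the same check, using a hypothetical helper name:

```python
def _check_roll_axis(axis, ndim):
    # Hypothetical helper mirroring the range check added in this commit:
    # every axis must lie in [-ndim, ndim).
    for a in axis:
        if a >= ndim or a < -ndim:
            raise ValueError(
                "axis is out of range, it should be in range [{}, {}), "
                "but received {}".format(-ndim, ndim, axis))

_check_roll_axis([10], ndim=2)  # raises ValueError, as the test expects
```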
@@ -104,23 +104,24 @@ def flip(input, dims, name=None):
     return out
 
 
-def roll(input, shifts, dims=None):
+def roll(x, shifts, axis=None, name=None):
     """
     :alias_main: paddle.roll
     :alias: paddle.roll,paddle.tensor.roll,paddle.tensor.manipulation.roll
 
-    Roll the `input` tensor along the given dimension(s). Elements that are shifted beyond
-    the last position are re-introduced at the first position. If a dimension is not specified,
+    Roll the `x` tensor along the given axis (or axes). Elements that are
+    shifted beyond the last position are re-introduced at the first position
+    according to 'shifts'. If an axis is not specified,
     the tensor will be flattened before rolling and then restored to the original shape.
 
     Args:
-        input (Variable): The input tensor variable.
+        x (Variable): The input tensor variable.
         shifts (int|list|tuple): The number of places by which the elements
-            of the `input` tensor are shifted.
-        dims (int|list|tuple|None): Dimentions along which to roll.
+            of the `x` tensor are shifted.
+        axis (int|list|tuple|None): Axis (or axes) along which to roll.
 
     Returns:
-        Variable: A Tensor with same data type as `input`.
+        Variable: A Tensor with same data type as `x`.
 
     Examples:
         .. code-block:: python
@@ -131,48 +132,56 @@ def roll(input, shifts, dims=None):
             data = np.array([[1.0, 2.0, 3.0],
                              [4.0, 5.0, 6.0],
                              [7.0, 8.0, 9.0]])
-            with fluid.dygraph.guard():
-                x = fluid.dygraph.to_variable(data)
+            paddle.enable_imperative()
+            x = paddle.imperative.to_variable(data)
             out_z1 = paddle.roll(x, shifts=1)
             print(out_z1.numpy())
             #[[9. 1. 2.]
             # [3. 4. 5.]
             # [6. 7. 8.]]
-            out_z2 = paddle.roll(x, shifts=1, dims=0)
+            out_z2 = paddle.roll(x, shifts=1, axis=0)
             print(out_z2.numpy())
             #[[7. 8. 9.]
             # [1. 2. 3.]
             # [4. 5. 6.]]
     """
     helper = LayerHelper("roll", **locals())
-    origin_shape = input.shape
+    origin_shape = x.shape
     if type(shifts) == int:
         shifts = [shifts]
-    if type(dims) == int:
-        dims = [dims]
+    if type(axis) == int:
+        axis = [axis]
 
-    if dims:
-        check_type(dims, 'dims', (list, tuple), 'roll')
+    len_origin_shape = len(origin_shape)
+    if axis:
+        for i in range(len(axis)):
+            if axis[i] >= len_origin_shape or axis[i] < -len_origin_shape:
+                raise ValueError(
+                    "axis is out of range, it should be in range [{}, {}), but received {}".
+                    format(-len_origin_shape, len_origin_shape, axis))
+
+    if axis:
+        check_type(axis, 'axis', (list, tuple), 'roll')
     check_type(shifts, 'shifts', (list, tuple), 'roll')
 
     if in_dygraph_mode():
-        if dims is None:
-            input = core.ops.reshape(input, 'shape', [-1, 1])
-            dims = [0]
-        out = core.ops.roll(input, 'dims', dims, 'shifts', shifts)
+        if axis is None:
+            x = core.ops.reshape(x, 'shape', [-1, 1])
+            axis = [0]
+        out = core.ops.roll(x, 'axis', axis, 'shifts', shifts)
         return core.ops.reshape(out, 'shape', origin_shape)
 
-    out = helper.create_variable_for_type_inference(input.dtype)
-    if dims is None:
-        input = reshape(input, shape=[-1, 1])
-        dims = [0]
+    out = helper.create_variable_for_type_inference(x.dtype)
+    if axis is None:
+        x = reshape(x, shape=[-1, 1])
+        axis = [0]
     helper.append_op(
         type='roll',
-        inputs={'X': input},
+        inputs={'X': x},
         outputs={'Out': out},
-        attrs={'dims': dims,
+        attrs={'axis': axis,
                'shifts': shifts})
     out = reshape(out, shape=origin_shape, inplace=True)
     return out
...
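Note the `axis=None` path above: the tensor is reshaped to `[-1, 1]`, rolled along axis 0, and then reshaped back to its original shape. A NumPy equivalent of the docstring's first example, reproducing the documented output:

```python
import numpy as np

data = np.array([[1.0, 2.0, 3.0],
                 [4.0, 5.0, 6.0],
                 [7.0, 8.0, 9.0]])
# paddle.roll(x, shifts=1) with no axis flattens, rolls, restores the shape.
rolled = np.roll(data.reshape(-1), 1).reshape(data.shape)
print(rolled)
# [[9. 1. 2.]
#  [3. 4. 5.]
#  [6. 7. 8.]]
```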