未验证 提交 5d3766ff 编写于 作者: Y yaoxuefeng 提交者: GitHub

modify flip test=develop (#25312)

According to paddle 2.0 standard
1, change flip api attr name 'dim' to 'axis'.
2, support empty axis
3, change example code to imperative mode.
上级 f8eccb0b
......@@ -36,46 +36,52 @@ class FlipOp : public framework::OperatorWithKernel {
platform::errors::NotFound(
"Output(Out) of FlipOp should not be null."));
auto x_dims = ctx->GetInputDim("X");
auto flip_dims = ctx->Attrs().Get<std::vector<int>>("dims");
auto flip_dims = ctx->Attrs().Get<std::vector<int>>("axis");
size_t flip_dims_size = flip_dims.size();
// check if dims axis within range
auto min_max_d = std::minmax_element(flip_dims.begin(), flip_dims.end());
PADDLE_ENFORCE_LT(*min_max_d.first, x_dims.size(),
platform::errors::InvalidArgument(
"min(dims) should be less than the input tensor X's "
"dimensions of FlipOp. But received min(dims) = %d, "
"X's dimensions = %d, X's shape = [%s]",
*min_max_d.first, x_dims.size(), x_dims));
PADDLE_ENFORCE_GE(
*min_max_d.first, x_dims.size() * -1,
platform::errors::InvalidArgument(
"min(dims) should be greater than or equal to the input tensor X's "
"dimensions of FlipOp times -1. But received min(dims) = %d, X's "
"dimensions = %d, X's shape = [%s]",
*min_max_d.first, x_dims.size() * -1, x_dims));
PADDLE_ENFORCE_LT(*min_max_d.second, x_dims.size(),
platform::errors::InvalidArgument(
"max(dims) should be less than the input tensor X's "
"dimensions of FlipOp. But received max(dims) = %d, "
"X's dimensions = %d, X's shape = [%s]",
*min_max_d.second, x_dims.size(), x_dims));
PADDLE_ENFORCE_GE(
*min_max_d.second, x_dims.size() * -1,
platform::errors::InvalidArgument(
"max(dims) should be greater than or equal to the input tensor X's "
"dimensions of FlipOp times -1. But received max(dims) = %d, X's "
"dimensions = %d, X's shape = [%s]",
*min_max_d.second, x_dims.size() * -1, x_dims));
// check duplicates in dims
flip_dims.erase(std::unique(flip_dims.begin(), flip_dims.end()),
flip_dims.end());
PADDLE_ENFORCE_EQ(flip_dims.size(), flip_dims_size,
platform::errors::InvalidArgument(
"dims has duplicates, original flip dims size=%d, "
"but unique flip dims size=%d.)",
flip_dims_size, flip_dims.size()));
if (flip_dims_size > 0) {
// check if dims axis within range
auto min_max_d = std::minmax_element(flip_dims.begin(), flip_dims.end());
PADDLE_ENFORCE_LT(
*min_max_d.first, x_dims.size(),
platform::errors::InvalidArgument(
"min(axes) should be less than the input tensor X's "
"axes of FlipOp. But received min(axes) = %d, "
"X's axes = %d, X's shape = [%s]",
*min_max_d.first, x_dims.size(), x_dims));
PADDLE_ENFORCE_GE(*min_max_d.first, x_dims.size() * -1,
platform::errors::InvalidArgument(
"min(axes) should be greater than or equal to the "
"input tensor X's "
"axes of FlipOp times -1. But received "
"min(axes) = %d, X's "
"axes = %d, X's shape = [%s]",
*min_max_d.first, x_dims.size() * -1, x_dims));
PADDLE_ENFORCE_LT(
*min_max_d.second, x_dims.size(),
platform::errors::InvalidArgument(
"max(axes) should be less than the input tensor X's "
"axes of FlipOp. But received max(axes) = %d, "
"X's axes = %d, X's shape = [%s]",
*min_max_d.second, x_dims.size(), x_dims));
PADDLE_ENFORCE_GE(*min_max_d.second, x_dims.size() * -1,
platform::errors::InvalidArgument(
"max(axes) should be greater than or equal to the "
"input tensor X's "
"axes of FlipOp times -1. But received "
"max(axes) = %d, X's "
"axes = %d, X's shape = [%s]",
*min_max_d.second, x_dims.size() * -1, x_dims));
// check duplicates in dims
flip_dims.erase(std::unique(flip_dims.begin(), flip_dims.end()),
flip_dims.end());
PADDLE_ENFORCE_EQ(flip_dims.size(), flip_dims_size,
platform::errors::InvalidArgument(
"axes has duplicates, original flip axes size=%d, "
"but unique flip axes size=%d.)",
flip_dims_size, flip_dims.size()));
}
VLOG(3) << "flip operator x.shape=" << x_dims;
......@@ -104,10 +110,10 @@ class FlipOpMaker : public framework::OpProtoAndCheckerMaker {
void Make() override {
AddInput("X", "(Tensor), The input tensor of flip op.");
AddOutput("Out", "(Tensor), The output tensor of flip op.");
AddAttr<std::vector<int>>("dims", "The axes to flip on.");
AddAttr<std::vector<int>>("axis", "The axes to flip on.");
AddComment(R"DOC(
Flip Operator.
Reverse the order of a n-D tensor along given axis in dims.
Reverse the order of a n-D tensor along given axis in axes.
)DOC");
}
};
......
......@@ -81,7 +81,7 @@ class FlipKernel<platform::CUDADeviceContext, T>
Tensor* out = ctx.Output<Tensor>("Out");
auto* in_data = x->data<T>();
auto* out_data = out->mutable_data<T>(ctx.GetPlace());
auto flip_dims = ctx.template Attr<std::vector<int>>("dims");
auto flip_dims = ctx.template Attr<std::vector<int>>("axis");
const int flip_dims_size = static_cast<int>(flip_dims.size());
auto x_dims = x->dims();
......
......@@ -41,7 +41,7 @@ class FlipKernel<platform::CPUDeviceContext, T>
void Compute(const framework::ExecutionContext& ctx) const override {
const Tensor* x = ctx.Input<Tensor>("X");
Tensor* out = ctx.Output<Tensor>("Out");
auto flip_dims = ctx.template Attr<std::vector<int>>("dims");
auto flip_dims = ctx.template Attr<std::vector<int>>("axis");
auto x_dims = x->dims();
const int total_dims = x_dims.size();
......
......@@ -108,7 +108,7 @@ from .tensor.manipulation import flatten #DEFINE_ALIAS
from .tensor.manipulation import gather #DEFINE_ALIAS
from .tensor.manipulation import gather_nd #DEFINE_ALIAS
from .tensor.manipulation import reshape #DEFINE_ALIAS
from .tensor.manipulation import reverse #DEFINE_ALIAS
from .tensor.manipulation import flip as reverse #DEFINE_ALIAS
from .tensor.manipulation import scatter #DEFINE_ALIAS
from .tensor.manipulation import scatter_nd_add #DEFINE_ALIAS
from .tensor.manipulation import scatter_nd #DEFINE_ALIAS
......
......@@ -30,9 +30,9 @@ class TestFlipOp_API(unittest.TestCase):
startup_program = fluid.Program()
train_program = fluid.Program()
with fluid.program_guard(train_program, startup_program):
dims = [0]
axis = [0]
input = fluid.data(name='input', dtype='float32', shape=[2, 3])
output = paddle.flip(input, dims)
output = paddle.flip(input, axis)
place = fluid.CPUPlace()
if fluid.core.is_compiled_with_cuda():
place = fluid.CUDAPlace(0)
......@@ -68,7 +68,7 @@ class TestFlipOp(OpTest):
self.outputs = {'Out': self.calc_ref_res()}
def init_attrs(self):
self.attrs = {"dims": self.dims}
self.attrs = {"axis": self.axis}
def test_check_output(self):
self.check_output()
......@@ -78,11 +78,11 @@ class TestFlipOp(OpTest):
def init_test_case(self):
self.in_shape = (6, 4, 2, 3)
self.dims = [0, 1]
self.axis = [0, 1]
def calc_ref_res(self):
res = self.inputs['X']
for axis in self.dims:
for axis in self.axis:
res = np.flip(res, axis)
return res
......@@ -90,25 +90,37 @@ class TestFlipOp(OpTest):
class TestFlipOpAxis1(TestFlipOp):
def init_test_case(self):
self.in_shape = (2, 4, 4)
self.dims = [0]
self.axis = [0]
class TestFlipOpAxis2(TestFlipOp):
def init_test_case(self):
self.in_shape = (4, 4, 6, 3)
self.dims = [0, 2]
self.axis = [0, 2]
class TestFlipOpAxis3(TestFlipOp):
def init_test_case(self):
self.in_shape = (4, 3, 1)
self.dims = [0, 1, 2]
self.axis = [0, 1, 2]
class TestFlipOpAxis4(TestFlipOp):
def init_test_case(self):
self.in_shape = (6, 4, 2, 2)
self.dims = [0, 1, 2, 3]
self.axis = [0, 1, 2, 3]
class TestFlipOpEmptyAxis(TestFlipOp):
def init_test_case(self):
self.in_shape = (6, 4, 2, 2)
self.axis = []
class TestFlipOpNegAxis(TestFlipOp):
def init_test_case(self):
self.in_shape = (6, 4, 2, 2)
self.axis = [-1]
if __name__ == "__main__":
......
......@@ -81,7 +81,7 @@ from .manipulation import flatten #DEFINE_ALIAS
from .manipulation import gather #DEFINE_ALIAS
from .manipulation import gather_nd #DEFINE_ALIAS
from .manipulation import reshape #DEFINE_ALIAS
from .manipulation import reverse #DEFINE_ALIAS
from .manipulation import flip as reverse #DEFINE_ALIAS
from .manipulation import scatter #DEFINE_ALIAS
from .manipulation import scatter_nd_add #DEFINE_ALIAS
from .manipulation import scatter_nd #DEFINE_ALIAS
......
......@@ -28,7 +28,6 @@ from ..fluid.layers import expand #DEFINE_ALIAS
from ..fluid.layers import expand_as #DEFINE_ALIAS
from ..fluid.layers import flatten #DEFINE_ALIAS
from ..fluid.layers import reshape #DEFINE_ALIAS
from ..fluid.layers import reverse #DEFINE_ALIAS
from ..fluid.layers import scatter #DEFINE_ALIAS
from ..fluid.layers import slice #DEFINE_ALIAS
from ..fluid.layers import strided_slice #DEFINE_ALIAS
......@@ -51,46 +50,47 @@ __all__ = [
]
def flip(input, dims, name=None):
def flip(x, axis, name=None):
"""
:alias_main: paddle.flip
:alias: paddle.flip,paddle.tensor.flip,paddle.tensor.manipulation.flip
Reverse the order of a n-D tensor along given axis in dims.
Reverse the order of a n-D tensor along given axis in axis.
Args:
input (Variable): A Tensor(or LoDTensor) with shape :math:`[N_1, N_2,..., N_k]` . The data type of the input Tensor
x (Variable): A Tensor(or LoDTensor) with shape :math:`[N_1, N_2,..., N_k]` . The data type of the input Tensor x
should be float32, float64, int32, int64, bool.
dims (list): The axis to flip on.
axis (list): The axis(axes) to flip on. Negative indices for indexing from the end are accepted.
name (str, optional): The default value is None. Normally there is no need for user to set this property.
For more information, please refer to :ref:`api_guide_Name` .
Returns:
Variable: Tensor or LoDTensor calculated by flip layer. The data type is same with input.
Variable: Tensor or LoDTensor calculated by flip layer. The data type is same with input x.
Examples:
.. code-block:: python
import paddle
import paddle.fluid as fluid
import numpy as np
input = fluid.data(name="x", shape=[-1, 2, 2], dtype='float32')
output = paddle.flip(input, dims=[0, 1])
exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_startup_program())
img = np.arange(12).reshape((3,2,2)).astype(np.float32)
res = exe.run(fluid.default_main_program(), feed={'x':img}, fetch_list=[output])
print(res) # [[[10,11][8, 9]],[[6, 7],[4, 5]] [[2, 3],[0, 1]]]
paddle.enable_imperative()
image_shape=(3, 2, 2)
x = np.arange(image_shape[0] * image_shape[1] * image_shape[2]).reshape(image_shape)
x = x.astype('float32')
img = paddle.imperative.to_variable(x)
out = paddle.flip(img, [0,1])
print(out) # [[[10,11][8, 9]],[[6, 7],[4, 5]] [[2, 3],[0, 1]]]
"""
helper = LayerHelper("flip", **locals())
check_type(input, 'X', (Variable), 'flip')
dtype = helper.input_dtype()
check_type(x, 'X', (Variable), 'flip')
dtype = helper.input_dtype('x')
check_dtype(dtype, 'X',
['float16', 'float32', 'float64', 'int32', 'int64', 'bool'],
'flip')
check_type(dims, 'dims', (list, tuple), 'flip')
assert len(dims) > 0, 'len(dims) must be greater than 0.'
check_type(axis, 'axis', (list, tuple), 'flip')
if name is None:
out = helper.create_variable_for_type_inference(dtype)
else:
......@@ -98,12 +98,15 @@ def flip(input, dims, name=None):
helper.append_op(
type="flip",
inputs={"X": input},
inputs={"X": x},
outputs={"Out": out},
attrs={"dims": dims})
attrs={"axis": axis})
return out
reverse = flip #DEFINE_ALIAS
def roll(x, shifts, axis=None, name=None):
"""
:alias_main: paddle.roll
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册