From bd89a273080028bdc761c3bbadab0f51a657e7a6 Mon Sep 17 00:00:00 2001
From: liym27 <33742067+liym27@users.noreply.github.com>
Date: Tue, 17 Sep 2019 14:37:14 +0800
Subject: [PATCH] add tensor support for argument shape in reshape op; (#19268)

add support for parameter inference when argument shape is a list
containing integers and tensor Variables;
test=develop

fix reshape op according to reviews:
1. improve error message;
2. improve test of test_api.
test=develop,test=document_preview

fix reshape op: add error message in nn.py, test=develop

add stop_gradient=True when attr(shape) is a tensor Variable.
change examples in API reshape. test=develop,test=document_preview
---
 paddle/fluid/API.spec                              |   2 +-
 paddle/fluid/operators/reshape_op.cc               |  80 ++++++----
 python/paddle/fluid/layers/nn.py                   | 138 +++++++++++-------
 .../fluid/tests/unittests/test_reshape_op.py       | 117 +++++++++++++--
 4 files changed, 245 insertions(+), 92 deletions(-)
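
For context, the user-visible change is that fluid.layers.reshape now accepts
tensor Variables inside the shape list. A minimal usage sketch, mirroring the
docstring examples added below (fluid 1.x API only; illustrative, not part of
the patch):

    import paddle.fluid as fluid

    # `shape` may now mix Python ints with int32 Variables of shape [1].
    data = fluid.layers.fill_constant([2, 25], "int32", 3)
    dim = fluid.layers.fill_constant([1], "int32", 5)
    # `dim` is only known at run time, so this dimension is inferred as -1
    # (unknown) at compile time and resolved to 5 when the op runs.
    reshaped = fluid.layers.reshape(data, shape=[dim, 10])
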
diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec
index c382a559b12..7c8c31abbaa 100755
--- a/paddle/fluid/API.spec
+++ b/paddle/fluid/API.spec
@@ -177,7 +177,7 @@ paddle.fluid.layers.softmax_with_cross_entropy (ArgSpec(args=['logits', 'label',
 paddle.fluid.layers.smooth_l1 (ArgSpec(args=['x', 'y', 'inside_weight', 'outside_weight', 'sigma'], varargs=None, keywords=None, defaults=(None, None, None)), ('document', 'ecb75c1b00c4c76c98b482f633b7a10c'))
 paddle.fluid.layers.one_hot (ArgSpec(args=['input', 'depth', 'allow_out_of_range'], varargs=None, keywords=None, defaults=(False,)), ('document', 'ec4115591be842868c86b2e5334245c6'))
 paddle.fluid.layers.autoincreased_step_counter (ArgSpec(args=['counter_name', 'begin', 'step'], varargs=None, keywords=None, defaults=(None, 1, 1)), ('document', '98e7927f09ee2270535b29f048e481ec'))
-paddle.fluid.layers.reshape (ArgSpec(args=['x', 'shape', 'actual_shape', 'act', 'inplace', 'name'], varargs=None, keywords=None, defaults=(None, None, False, None)), ('document', '6196c9ec3075ca5a9c058ea1f8492256'))
+paddle.fluid.layers.reshape (ArgSpec(args=['x', 'shape', 'actual_shape', 'act', 'inplace', 'name'], varargs=None, keywords=None, defaults=(None, None, False, None)), ('document', 'ca73fdc4551c5765c92eb00f24874289'))
 paddle.fluid.layers.squeeze (ArgSpec(args=['input', 'axes', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', 'ebbac07662a6e22e8e299ced880c7775'))
 paddle.fluid.layers.unsqueeze (ArgSpec(args=['input', 'axes', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', 'b9bd3129d36a70e7c4385df51ff71c62'))
 paddle.fluid.layers.lod_reset (ArgSpec(args=['x', 'y', 'target_lod'], varargs=None, keywords=None, defaults=(None, None)), ('document', '74498d37dd622ac472cb36887fce09ea'))
diff --git a/paddle/fluid/operators/reshape_op.cc b/paddle/fluid/operators/reshape_op.cc
index 6341fa935ec..0059921c046 100644
--- a/paddle/fluid/operators/reshape_op.cc
+++ b/paddle/fluid/operators/reshape_op.cc
@@ -50,29 +50,56 @@ class ReshapeOp : public framework::OperatorWithKernel {
       : OperatorWithKernel(type, inputs, outputs, attrs) {}
 
   void InferShape(framework::InferShapeContext *ctx) const override {
-    PADDLE_ENFORCE(ctx->HasInput("X"),
-                   "Input(X) of ReshapeOp should not be null.");
-    PADDLE_ENFORCE(ctx->HasOutput("Out"),
-                   "Output(Out) of ReshapeOp should not be null.");
+    PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
+                      "Input(X) of ReshapeOp should not be null.");
+    PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
+                      "Output(Out) of ReshapeOp should not be null.");
 
     if (ctx->HasInputs("ShapeTensor")) {
       // top priority shape
-      auto inputs_name = ctx->Inputs("ShapeTensor");
-      PADDLE_ENFORCE(inputs_name.size() > 0, "shape tensor size can't be zero");
-      auto out_dims = std::vector<int>(inputs_name.size(), -1);
-      ctx->SetOutputDim("Out", framework::make_ddim(out_dims));
+      auto ShapeTensor = ctx->Inputs("ShapeTensor");
+      PADDLE_ENFORCE_GT(ShapeTensor.size(), 0,
+                        "The size of Input(ShapeTensor) can't be zero");
+      auto infer_shape = ctx->Attrs().Get<std::vector<int>>("shape");
+      const int64_t copy_dim_val = 0;
+      auto in_dims = ctx->GetInputDim("X");
+      for (size_t i = 0; i < infer_shape.size(); ++i) {
+        if (infer_shape[i] == copy_dim_val) {
+          PADDLE_ENFORCE_LT(
+              static_cast<int>(i), in_dims.size(),
+              "The dimension of data to copy from input must be less "
+              "than the dimension of input.");
+          infer_shape[i] = in_dims[i];
+        }
+      }
+      auto infer_out_dims = framework::make_ddim(infer_shape);
+      ctx->SetOutputDim("Out", infer_out_dims);
+      return;
+    }
+
+    const std::vector<int> &shape = ctx->Attrs().Get<std::vector<int>>("shape");
+    if (ctx->HasInput("Shape") && shape.empty()) {
+      auto shape_dims = ctx->GetInputDim("Shape");
+      int num_ele = 1;
+      for (int i = 0; i < shape_dims.size(); ++i) {
+        num_ele *= shape_dims[i];
+      }
+      auto vec_dims = std::vector<int>(num_ele, -1);
+      auto out_dims = framework::make_ddim(vec_dims);
+      ctx->SetOutputDim("Out", out_dims);
+      ctx->ShareLoD("X", /*->*/ "Out");
       return;
     }
-    if (ctx->HasInput("Shape") && ctx->IsRuntime()) {
+
+    if (ctx->HasInput("Shape") && !shape.empty() && ctx->IsRuntime()) {
       // If true, set the shape of Output(Out) according to Input(Shape) in
       // ReshapeKernel with ExecutionContext. Also check LoD in ReshapeKernel.
       ctx->ShareLoD("X", /*->*/ "Out");
       return;
     }
-    const std::vector<int> &shape = ctx->Attrs().Get<std::vector<int>>("shape");
-    PADDLE_ENFORCE(!shape.empty(),
-                   "The shape information must be set by Attr(shape).");
+
+    PADDLE_ENFORCE_EQ(!shape.empty(), true,
+                      "The shape information must be set by Attr(shape).");
     auto x_dims = ctx->GetInputDim("X");
     auto out_dims = ValidateShape(shape, x_dims);
     ctx->SetOutputDim("Out", out_dims);
@@ -99,18 +126,18 @@ class ReshapeOp : public framework::OperatorWithKernel {
     int unk_dim_idx = -1;
     for (size_t i = 0; i < shape.size(); ++i) {
       if (shape[i] == unk_dim_val) {
-        PADDLE_ENFORCE(
-            unk_dim_idx == -1,
+        PADDLE_ENFORCE_EQ(
+            unk_dim_idx, -1,
             "Only one input dimension of Attr(shape) can be unknown.");
         unk_dim_idx = i;
       } else if (shape[i] == copy_dim_val) {
-        PADDLE_ENFORCE(
-            static_cast<int>(i) < in_dims.size(),
+        PADDLE_ENFORCE_LT(
+            static_cast<int>(i), in_dims.size(),
             "The index of dimension to copy from input shape must be less "
             "than the size of input shape.");
       } else {
-        PADDLE_ENFORCE(
-            shape[i] > 0,
+        PADDLE_ENFORCE_GT(
+            shape[i], 0,
             "Each input dimension of Attr(shape) must not be negative except "
             "one unknown dimension.");
       }
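
For intuition, the -1/0 resolution that ValidateShape performs can be restated
in a few lines of Python. This is an illustrative sketch of the rule only, not
code from the patch (the function name is made up):

    from functools import reduce
    import operator

    def resolve_shape(shape, in_dims):
        # 0 copies the corresponding dimension of the input.
        out = list(shape)
        for i, d in enumerate(out):
            if d == 0:
                assert i < len(in_dims), "0 may only index an existing input dim"
                out[i] = in_dims[i]
        # A single -1 is inferred so that the total element count is preserved.
        if -1 in out:
            assert out.count(-1) == 1, "at most one dimension can be -1"
            known = reduce(operator.mul, (d for d in out if d != -1), 1)
            out[out.index(-1)] = reduce(operator.mul, in_dims, 1) // known
        return out

    # e.g. resolve_shape([-1, 0, 3, 2], (2, 4, 6)) == [2, 4, 3, 2]
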
@@ -231,9 +258,9 @@ class ReshapeGradOp : public framework::OperatorWithKernel {
       : OperatorWithKernel(type, inputs, outputs, attrs) {}
 
   void InferShape(framework::InferShapeContext *ctx) const override {
-    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) shouldn't be null.");
-    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
-                   "Input(Out@GRAD) shouldn't be null.");
+    PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true, "Input(X) shouldn't be null.");
+    PADDLE_ENFORCE_EQ(ctx->HasInput(framework::GradVarName("Out")), true,
+                      "Input(Out@GRAD) shouldn't be null.");
     ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
   }
 
@@ -314,8 +341,8 @@ class Reshape2Op : public ReshapeOp {
       : ReshapeOp(type, inputs, outputs, attrs) {}
 
   void InferShape(framework::InferShapeContext *ctx) const override {
-    PADDLE_ENFORCE(ctx->HasOutput("XShape"),
-                   "Output(XShape) of ReshapeOp should not be null.");
+    PADDLE_ENFORCE_EQ(ctx->HasOutput("XShape"), true,
+                      "Output(XShape) of ReshapeOp should not be null.");
     const auto &x_dims = ctx->GetInputDim("X");
     std::vector<int64_t> xshape_dims(x_dims.size() + 1);
     xshape_dims[0] = 0;
@@ -365,9 +392,10 @@ class Reshape2GradOp : public framework::OperatorWithKernel {
       : OperatorWithKernel(type, inputs, outputs, attrs) {}
 
   void InferShape(framework::InferShapeContext *ctx) const override {
-    PADDLE_ENFORCE(ctx->HasInput("XShape"), "Input(XShape) shouldn't be null.");
-    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
-                   "Input(Out@GRAD) shouldn't be null.");
+    PADDLE_ENFORCE_EQ(ctx->HasInput("XShape"), true,
+                      "Input(XShape) shouldn't be null.");
+    PADDLE_ENFORCE_EQ(ctx->HasInput(framework::GradVarName("Out")), true,
+                      "Input(Out@GRAD) shouldn't be null.");
     auto xshape_dims = ctx->GetInputDim("XShape");
     auto x_dims = framework::slice_ddim(xshape_dims, 1, xshape_dims.size());
     ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
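
The InferShape changes above establish a priority order among the three ways a
target shape can reach the op. A restatement in illustrative Python (names are
made up; not patch code):

    def pick_shape_source(has_shape_tensor, has_shape_input, attr_shape):
        """Which input ReshapeOp consults, highest priority first (sketch)."""
        if has_shape_tensor:                    # shape list containing Variables
            return "Inputs(ShapeTensor)"
        if has_shape_input and not attr_shape:  # shape passed as one 1-D Variable
            return "Input(Shape)"
        if has_shape_input and attr_shape:      # actual_shape next to a plain list
            return "Input(Shape) at runtime"
        assert attr_shape, "The shape information must be set by Attr(shape)."
        return "Attr(shape)"
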
diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py
index 805faac339f..cf925d71977 100755
--- a/python/paddle/fluid/layers/nn.py
+++ b/python/paddle/fluid/layers/nn.py
@@ -7025,9 +7025,9 @@ def reshape(x, shape, actual_shape=None, act=None, inplace=False, name=None):
     Gives a new shape to the input Tensor without changing its data.
 
     The target shape can be given by :attr:`shape` or :attr:`actual_shape`.
-    :attr:`shape` is a list of integer while :attr:`actual_shape` is a tensor
+    :attr:`shape` is a list of integers or tensor Variables while :attr:`actual_shape` is a tensor
     variable. :attr:`actual_shape` has a higher priority than :attr:`shape`
-    if it is provided, while :attr:`shape` still should be set correctly to
+    if it is provided and :attr:`shape` contains only integers, while :attr:`shape` still should be set correctly to
     guarantee shape inference in compile-time.
 
     Some tricks exist when specifying the target shape.
@@ -7059,15 +7059,22 @@ def reshape(x, shape, actual_shape=None, act=None, inplace=False, name=None):
     besides -1, 0 means the actual dimension value is going to be copied from
     the corresponding dimension of x.
 
+    **Warning:** the parameter :attr:`actual_shape` will be deprecated in the
+    future; please use :attr:`shape` instead.
+
     Args:
         x(variable): The input tensor.
-        shape(list): The new shape. At most one dimension of the new shape can
-            be -1.
+        shape(list|tuple|Variable): The new shape. At most one dimension of the
+            new shape can be -1. If :attr:`shape` is a list or tuple, it may
+            contain Variables, and the shape of each Variable must be [1].
+
         actual_shape(variable): An optional input. If provided, reshape
            according to this given shape rather than
            :attr:`shape` specifying shape. That is to say :attr:`actual_shape`
-           has a higher priority than :attr:`shape`.
+           has a higher priority than :attr:`shape(list|tuple)` but not :attr:`shape(Variable)`. \
+           This argument :attr:`actual_shape` will be removed in a future version. \
+           Instructions for updating: :attr:`actual_shape` is deprecated,
+           only use :attr:`shape` instead.
         act (str): The non-linear activation to be applied to the reshaped tensor
            variable.
         inplace(bool): If ``inplace`` is `True`, the input and output of ``layers.reshape``
@@ -7089,64 +7096,89 @@ def reshape(x, shape, actual_shape=None, act=None, inplace=False, name=None):
 
         .. code-block:: python
 
             import paddle.fluid as fluid
-            data = fluid.layers.data(
-                name='data', shape=[2, 4, 6], dtype='float32')
-            reshaped = fluid.layers.reshape(
-                x=data, shape=[-1, 0, 3, 2], inplace=True)
+
+            # example 1:
+            # attr shape is a list which doesn't contain tensor Variables.
+            data_1 = fluid.layers.data(
+                name='data_1', shape=[2, 4, 6], dtype='float32')
+            reshaped_1 = fluid.layers.reshape(
+                x=data_1, shape=[-1, 0, 3, 2], inplace=True)
+
+            # example 2:
+            # attr shape is a list which contains tensor Variables.
+            data_2 = fluid.layers.fill_constant([2, 25], "int32", 3)
+            dim = fluid.layers.fill_constant([1], "int32", 5)
+            reshaped_2 = fluid.layers.reshape(data_2, shape=[dim, 10])
     """
-    if not (isinstance(shape, list) or isinstance(shape, tuple)):
-        raise ValueError("Input shape must be a python list or tuple.")
+    if not isinstance(shape, (list, tuple, Variable)):
+        raise TypeError(
+            "Input shape must be a Variable or a python list or tuple.")
 
-    inputs = {"X": x}
-    if isinstance(actual_shape, Variable):
-        inputs["Shape"] = actual_shape
-    elif actual_shape is not None:
-        raise TypeError("actual_shape should either be Variable or None")
-
-    # Validate the shape
-    unk_dim_idx = -1
-    contain_var = False
-    for dim_idx, dim_size in enumerate(shape):
-        if isinstance(dim_size, Variable):
-            contain_var = True
-            continue
-
-        if dim_size == -1:
-            assert unk_dim_idx == -1, (
-                "Only one dimension in shape can be unknown.")
-            unk_dim_idx = dim_idx
-        elif dim_size == 0:
-            assert dim_idx < len(x.shape), (
-                "The indice of 0s in shape can not exceed Rank(X).")
-        else:
-            assert dim_size > 0, (
-                "Each dimension size given in shape must not be negtive "
-                "except one unknown dimension.")
+    if not isinstance(actual_shape, Variable) and (actual_shape is not None):
+        raise TypeError("actual_shape should either be Variable or None.")
 
     helper = LayerHelper("reshape2", **locals())
+    inputs = {"X": x}
+    attrs = {}
+
+    def contain_var(one_list):
+        for ele in one_list:
+            if isinstance(ele, Variable):
+                return True
+        return False
+
+    def get_new_shape_tensor(list_shape):
+        new_shape_tensor = []
+        for dim in list_shape:
+            if isinstance(dim, Variable):
+                dim.stop_gradient = True
+                new_shape_tensor.append(dim)
+            else:
+                assert (isinstance(dim, int))
+                temp_out = helper.create_variable_for_type_inference('int32')
+                fill_constant([1], 'int32', dim, force_cpu=True, out=temp_out)
+                new_shape_tensor.append(temp_out)
+        return new_shape_tensor
+
+    def get_attr_shape(list_shape):
+        unk_dim_idx = -1
+        attrs_shape = []
+        for dim_idx, dim_size in enumerate(list_shape):
+            if isinstance(dim_size, Variable):
+                attrs_shape.append(-1)
+            else:
+                attrs_shape.append(dim_size)
+                if dim_size == -1:
+                    assert unk_dim_idx == -1, (
+                        "Only one dimension in shape can be unknown.")
+                    unk_dim_idx = dim_idx
+                elif dim_size == 0:
+                    assert dim_idx < len(x.shape), (
+                        "The index of 0s in shape can not exceed Rank(X).")
+                else:
+                    assert dim_size > 0, (
+                        "Each dimension size given in shape must not be negative "
+                        "except one unknown dimension.")
+        return attrs_shape
+
     if in_dygraph_mode():
         inputs = {'X': x}
         attrs = {'shape': shape}
     else:
-        if contain_var:
-            new_shape_tensor = []
-            for dim in shape:
-                if isinstance(dim, Variable):
-                    dim.stop_gradient = True
-                    new_shape_tensor.append(dim)
-                else:
-                    assert (isinstance(dim, int))
-                    temp_out = helper.create_variable_for_type_inference(
-                        'int32')
-                    fill_constant(
-                        [1], 'int32', dim, force_cpu=True, out=temp_out)
-                    new_shape_tensor.append(temp_out)
-            inputs['ShapeTensor'] = new_shape_tensor
-            attrs = {}
-        else:
-            attrs = {'shape': shape}
+        if isinstance(shape, Variable):
+            shape.stop_gradient = True
+            inputs["Shape"] = shape
+        elif isinstance(shape, (list, tuple)):
+            assert len(shape) > 0, (
+                "The size of argument(shape) can't be zero.")
+            attrs["shape"] = get_attr_shape(shape)
+            if contain_var(shape):
+                inputs['ShapeTensor'] = get_new_shape_tensor(shape)
+            elif isinstance(actual_shape, Variable):
+                actual_shape.stop_gradient = True
+                inputs["Shape"] = actual_shape
+
     out = x if inplace else helper.create_variable_for_type_inference(
         dtype=x.dtype)
     x_shape = helper.create_variable_for_type_inference(dtype=x.dtype)
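
The docstring above also allows the whole target shape to be a single 1-D
Variable. A minimal end-to-end sketch of that form, assuming the fluid 1.x
API; it mirrors "Situation 4" in the tests that follow:

    import numpy as np
    import paddle.fluid as fluid

    x = fluid.layers.data(
        name="x", shape=[2, 25], append_batch_size=False, dtype="float32")
    # The entire target shape is a 1-D int32 tensor fed at run time.
    shape = fluid.layers.data(
        name="shape", shape=[3], append_batch_size=False, dtype="int32")
    out = fluid.layers.reshape(x, shape=shape)

    exe = fluid.Executor(fluid.CPUPlace())
    res, = exe.run(fluid.default_main_program(),
                   feed={"x": np.random.random([2, 25]).astype("float32"),
                         "shape": np.array([2, 5, 5]).astype("int32")},
                   fetch_list=[out])
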
diff --git a/python/paddle/fluid/tests/unittests/test_reshape_op.py b/python/paddle/fluid/tests/unittests/test_reshape_op.py
index 3221985c442..beaffd055c1 100644
--- a/python/paddle/fluid/tests/unittests/test_reshape_op.py
+++ b/python/paddle/fluid/tests/unittests/test_reshape_op.py
@@ -18,8 +18,10 @@ import unittest
 import numpy as np
 
 from op_test import OpTest
+import paddle.fluid as fluid
 
 
+# Situation 1: shape is a list (no tensors), no actual shape (Tensor)
 class TestReshapeOp(OpTest):
     def setUp(self):
         self.init_data()
@@ -58,24 +60,28 @@ class TestReshapeOpDimInfer2(TestReshapeOp):
         self.infered_shape = (2, 2, 3, -1)
 
 
+# Situation 2: shape is a list (no tensors), with actual shape (Tensor)
 class TestReshapeOpWithInputShape(OpTest):
     def setUp(self):
-        ori_shape = (6, 5)
-        new_shape = (0, -1, 5)
-        actual_shape = (2, 3, 5)
-
+        self.init_data()
         self.op_type = "reshape2"
+
         self.inputs = {
-            "X": np.random.random(ori_shape).astype("float32"),
+            "X": np.random.random(self.ori_shape).astype("float32"),
             "Shape": np.array(
-                actual_shape, dtype="int32")
+                self.actual_shape, dtype="int32")
         }
-        self.attrs = {"shape": new_shape}
+        self.attrs = {"shape": self.new_shape}
         self.outputs = {
-            "Out": self.inputs["X"].reshape(actual_shape),
-            'XShape': np.random.random(ori_shape).astype("float32")
+            "Out": self.inputs["X"].reshape(self.actual_shape),
+            'XShape': np.random.random(self.ori_shape).astype("float32")
         }
 
+    def init_data(self):
+        self.ori_shape = (6, 5)
+        self.new_shape = (0, -1, 5)
+        self.actual_shape = (2, 3, 5)
+
     def test_check_output(self):
         self.check_output(no_check_set=['XShape'])
 
@@ -83,7 +89,8 @@ class TestReshapeOpWithInputShape(OpTest):
         self.check_grad(["X"], "Out")
 
 
-class TestReshapeOp_attr_tensor(OpTest):
+# Situation 3: shape is a list (with tensors), no actual shape (Tensor)
+class TestReshapeOp_attr_ShapeTensor(OpTest):
     def setUp(self):
         self.init_data()
         self.op_type = "reshape2"
@@ -97,6 +104,52 @@
             "X": np.random.random(self.ori_shape).astype("float32"),
             'ShapeTensor': shape_tensor
         }
+        self.attrs = {'shape': self.shape}
         self.outputs = {
             "Out": self.inputs["X"].reshape(self.infered_shape),
             'XShape': np.random.random(self.ori_shape).astype("float32")
         }
+
+    def init_data(self):
+        self.ori_shape = (2, 25)
+        self.new_shape = (5, 10)
+        self.infered_shape = (5, 10)
+        self.shape = (-1, -1)
+
+    def test_check_output(self):
+        self.check_output(no_check_set=['XShape'])
+
+    def test_check_grad(self):
+        self.check_grad(["X"], "Out")
+
+
+class TestReshapeOpDimInfer1_attr_ShapeTensor(TestReshapeOp_attr_ShapeTensor):
+    def init_data(self):
+        self.ori_shape = (5, 10)
+        self.new_shape = (5, -1, 5)
+        self.infered_shape = (5, -1, 5)
+        self.shape = (5, -1, -1)
+
+
+class TestReshapeOpDimInfer2_attr_ShapeTensor(TestReshapeOp_attr_ShapeTensor):
+    def init_data(self):
+        self.ori_shape = (2, 2, 6)
+        self.new_shape = (2, 0, 3, -1)
+        self.infered_shape = (2, 2, 3, -1)
+        self.shape = (2, 0, 3, -1)
+
+
+# Situation 4: shape is a Tensor, no actual shape (Tensor)
+class TestReshapeOp_attr_OnlyShape(OpTest):
+    def setUp(self):
+        self.init_data()
+        self.op_type = "reshape2"
+
+        self.inputs = {
+            "X": np.random.random(self.ori_shape).astype("float32"),
+            "Shape": np.array(
+                self.new_shape, dtype="int32")
+        }
         self.attrs = {}
         self.outputs = {
             "Out": self.inputs["X"].reshape(self.infered_shape),
@@ -115,18 +168,58 @@
         self.check_grad(["X"], "Out")
 
 
-class TestReshapeOpDimInfer1_attr_tensor(TestReshapeOp_attr_tensor):
+class TestReshapeOpDimInfer1_attr_OnlyShape(TestReshapeOp_attr_OnlyShape):
     def init_data(self):
         self.ori_shape = (5, 10)
         self.new_shape = (5, -1, 5)
         self.infered_shape = (5, -1, 5)
+        self.shape = (5, -1, -1)
 
 
-class TestReshapeOpDimInfer2_attr_tensor(TestReshapeOp_attr_tensor):
+class TestReshapeOpDimInfer2_attr_OnlyShape(TestReshapeOp_attr_OnlyShape):
     def init_data(self):
         self.ori_shape = (2, 2, 6)
         self.new_shape = (2, 0, 3, -1)
         self.infered_shape = (2, 2, 3, -1)
+        self.shape = (2, 0, 3, -1)
+
+
+# Test python API
+class TestReshapeAPI(OpTest):
+    def test_1(self):
+        input = np.random.random([2, 25]).astype("float32")
+        shape = [2, 5, 5]
+        positive_five = fluid.layers.fill_constant([1], "int32", 5)
+        x = fluid.layers.data(
+            name="x", shape=[2, 25], append_batch_size=False, dtype="float32")
+
+        actual_shape = fluid.layers.data(
+            name="shape",
+            shape=[1, 3],
+            append_batch_size=False,
+            dtype="int32")
+
+        # Situation 1: shape is a list (no tensors), no actual shape (Tensor)
+        out_1 = fluid.layers.reshape(x, shape)
+        # Situation 2: shape is a list (no tensors), with actual shape (Tensor)
+        out_2 = fluid.layers.reshape(x, shape=shape, actual_shape=actual_shape)
+        # Situation 3: shape is a list (with tensors), no actual shape (Tensor)
+        out_3 = fluid.layers.reshape(x, shape=[positive_five, 10])
+        # Situation 4: shape is a Tensor, no actual shape (Tensor)
+        out_4 = fluid.layers.reshape(x, shape=actual_shape)
+
+        exe = fluid.Executor(place=fluid.CPUPlace())
+        res_1, res_2, res_3, res_4 = exe.run(
+            fluid.default_main_program(),
+            feed={"x": input,
+                  "shape": np.array([2, 5, 5]).astype("int32")},
+            fetch_list=[out_1, out_2, out_3, out_4])
+
+        assert np.array_equal(res_1, input.reshape(shape))
+        assert np.array_equal(res_2, input.reshape(shape))
+        assert np.array_equal(res_3, input.reshape([5, 10]))
+        assert np.array_equal(res_4, input.reshape(shape))
 
 
 if __name__ == "__main__":
     unittest.main()
--