diff --git a/paddle/fluid/operators/reshape_op.cc b/paddle/fluid/operators/reshape_op.cc index f3719e8f438f6365414a1e91192a863fd451209d..9750bc87b001e034cb65463101ba57fbbc105eca 100644 --- a/paddle/fluid/operators/reshape_op.cc +++ b/paddle/fluid/operators/reshape_op.cc @@ -19,6 +19,29 @@ limitations under the License. */ namespace paddle { namespace operators { +using Tensor = framework::Tensor; + +inline std::vector get_new_shape( + const std::vector &list_new_shape_tensor) { + // get tensor from + std::vector vec_new_shape; + for (size_t i = 0; i < list_new_shape_tensor.size(); ++i) { + auto tensor = list_new_shape_tensor[i]; + PADDLE_ENFORCE_EQ(tensor->dims(), framework::make_ddim({1}), + "shape of dim tensor should be [1]"); + if (platform::is_gpu_place(tensor->place())) { + framework::Tensor temp; + TensorCopySync(*tensor, platform::CPUPlace(), &temp); + + vec_new_shape.push_back(static_cast(*temp.data())); + } else { + vec_new_shape.push_back(static_cast(*tensor->data())); + } + } + + return vec_new_shape; +} + class ReshapeOp : public framework::OperatorWithKernel { public: ReshapeOp(const std::string &type, const framework::VariableNameMap &inputs, @@ -32,17 +55,24 @@ class ReshapeOp : public framework::OperatorWithKernel { PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) of ReshapeOp should not be null."); - const std::vector &shape = ctx->Attrs().Get>("shape"); - PADDLE_ENFORCE(!shape.empty(), - "The shape information must be set by Attr(shape)."); + if (ctx->HasInputs("ShapeTensor")) { + // top prority shape + auto inputs_name = ctx->Inputs("ShapeTensor"); + PADDLE_ENFORCE(inputs_name.size() > 0, "shape tensor size can't be zero"); + auto out_dims = std::vector(inputs_name.size(), -1); + ctx->SetOutputDim("Out", framework::make_ddim(out_dims)); + return; + } if (ctx->HasInput("Shape") && ctx->IsRuntime()) { // If true, set the shape of Output(Out) according to Input(Shape) in // ReshapeKernel with ExecutionContext. Also check LoD in ReshapeKernel. ctx->ShareLoD("X", /*->*/ "Out"); return; } - + const std::vector &shape = ctx->Attrs().Get>("shape"); + PADDLE_ENFORCE(!shape.empty(), + "The shape information must be set by Attr(shape)."); auto x_dims = ctx->GetInputDim("X"); auto out_dims = ValidateShape(shape, x_dims); ctx->SetOutputDim("Out", out_dims); @@ -114,6 +144,16 @@ class ReshapeOp : public framework::OperatorWithKernel { return framework::OpKernelType(ctx.Input("X")->type(), ctx.device_context()); } + + framework::OpKernelType GetKernelTypeForVar( + const std::string &var_name, const Tensor &tensor, + const framework::OpKernelType &expected_kernel_type) const override { + if (var_name == "ShapeTensor") { + return expected_kernel_type; + } + return framework::OpKernelType(expected_kernel_type.data_type_, + tensor.place(), tensor.layout()); + } }; class ReshapeOpMaker : public framework::OpProtoAndCheckerMaker { @@ -126,9 +166,18 @@ class ReshapeOpMaker : public framework::OpProtoAndCheckerMaker { "the shape attribute, while the shape attribute still should be " "set correctly to gurantee shape inference in compile time.") .AsDispensable(); + AddInput( + "ShapeTensor", + "(vector>, optional). If provided, reshape will use this" + "The shape of the tensor in vector MUST BE [1]" + "it has the highest priority compare with Input(Shape) and " + "attr(shape).") + .AsDuplicable() + .AsDispensable(); AddOutput("Out", "(Tensor). The output tensor of reshape operator."); AddAttr>( - "shape", "(std::vector) Target shape of reshape operator."); + "shape", "(std::vector) Target shape of reshape operator.") + .SetDefault({}); AddComment(R"DOC( Reshape Operator. @@ -202,24 +251,35 @@ class ReshapeKernel { auto *out = ctx.Output("Out"); auto *in = ctx.Input("X"); - auto *shape_tensor = ctx.HasInput("Shape") - ? ctx.Input("Shape") - : nullptr; - framework::DDim out_dims = out->dims(); - if (shape_tensor) { - auto *shape_data = shape_tensor->data(); - framework::Tensor cpu_shape_tensor; - if (platform::is_gpu_place(shape_tensor->place())) { - TensorCopySync(*shape_tensor, platform::CPUPlace(), &cpu_shape_tensor); - shape_data = cpu_shape_tensor.data(); + auto list_new_shape_tensor = + ctx.MultiInput("ShapeTensor"); + if (list_new_shape_tensor.size() > 0) { + // have shape tensor + auto new_shape = get_new_shape(list_new_shape_tensor); + out_dims = ReshapeOp::ValidateShape(new_shape, in->dims()); + + } else { + auto *shape_tensor = ctx.HasInput("Shape") + ? ctx.Input("Shape") + : nullptr; + + if (shape_tensor) { + auto *shape_data = shape_tensor->data(); + framework::Tensor cpu_shape_tensor; + if (platform::is_gpu_place(shape_tensor->place())) { + TensorCopySync(*shape_tensor, platform::CPUPlace(), + &cpu_shape_tensor); + shape_data = cpu_shape_tensor.data(); + } + auto shape = + std::vector(shape_data, shape_data + shape_tensor->numel()); + out_dims = ReshapeOp::ValidateShape(shape, in->dims()); } - auto shape = - std::vector(shape_data, shape_data + shape_tensor->numel()); - out_dims = ReshapeOp::ValidateShape(shape, in->dims()); } + out->Resize(out_dims); out->mutable_data(ctx.GetPlace(), in->type()); framework::TensorCopy( *in, ctx.GetPlace(), @@ -288,6 +348,7 @@ class Reshape2GradMaker : public framework::SingleGradOpDescMaker { auto *grad_op = new framework::OpDesc(); grad_op->SetType("reshape2_grad"); grad_op->SetInput("XShape", Output("XShape")); + grad_op->SetInput("ShapeTensor", Input("ShapeTensor")); grad_op->SetInput(framework::GradVarName("Out"), OutputGrad("Out")); grad_op->SetOutput(framework::GradVarName("X"), InputGrad("X")); grad_op->SetAttrMap(Attrs()); @@ -320,6 +381,16 @@ class Reshape2GradOp : public framework::OperatorWithKernel { ctx.Input(framework::GradVarName("Out"))->type(), ctx.device_context()); } + + framework::OpKernelType GetKernelTypeForVar( + const std::string &var_name, const Tensor &tensor, + const framework::OpKernelType &expected_kernel_type) const override { + if (var_name == "ShapeTensor") { + return expected_kernel_type; + } + return framework::OpKernelType(expected_kernel_type.data_type_, + tensor.place(), tensor.layout()); + } }; class ReshapeOpInplaceInToOut : public framework::InplaceOpInference { diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 48a5f5af72efbe473cf1cbddcba8f896d2d8107b..6676b4a381a4632d5a3cd3039eefbc5322a25d70 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -6703,6 +6703,7 @@ def reshape(x, shape, actual_shape=None, act=None, inplace=False, name=None): if not (isinstance(shape, list) or isinstance(shape, tuple)): raise ValueError("Input shape must be a python list or tuple.") + inputs = {"X": x} if isinstance(actual_shape, Variable): inputs["Shape"] = actual_shape @@ -6711,7 +6712,12 @@ def reshape(x, shape, actual_shape=None, act=None, inplace=False, name=None): # Validate the shape unk_dim_idx = -1 + contain_var = False for dim_idx, dim_size in enumerate(shape): + if isinstance(dim_size, Variable): + contain_var = True + continue + if dim_size == -1: assert unk_dim_idx == -1, ( "Only one dimension in shape can be unknown.") @@ -6725,13 +6731,35 @@ def reshape(x, shape, actual_shape=None, act=None, inplace=False, name=None): "except one unknown dimension.") helper = LayerHelper("reshape2", **locals()) + if in_dygraph_mode(): + inputs = {'X': x} + attrs = {'shape': shape} + else: + if contain_var: + new_shape_tensor = [] + for dim in shape: + if isinstance(dim, Variable): + dim.stop_gradient = True + new_shape_tensor.append(dim) + else: + assert (isinstance(dim, int)) + temp_out = helper.create_variable_for_type_inference( + 'int32') + fill_constant( + [1], 'int32', dim, force_cpu=True, out=temp_out) + new_shape_tensor.append(temp_out) + inputs['ShapeTensor'] = new_shape_tensor + attrs = {} + + else: + attrs = {'shape': shape} out = x if inplace else helper.create_variable_for_type_inference( dtype=x.dtype) x_shape = helper.create_variable_for_type_inference(dtype=x.dtype) helper.append_op( type="reshape2", inputs=inputs, - attrs={"shape": shape}, + attrs=attrs, outputs={"Out": out, "XShape": x_shape}) diff --git a/python/paddle/fluid/tests/unittests/test_reshape_op.py b/python/paddle/fluid/tests/unittests/test_reshape_op.py index 7691221a5511ce83176ff07c231873f66ca371ed..3221985c442e5a09db468df616f203400af52371 100644 --- a/python/paddle/fluid/tests/unittests/test_reshape_op.py +++ b/python/paddle/fluid/tests/unittests/test_reshape_op.py @@ -37,6 +37,7 @@ class TestReshapeOp(OpTest): self.infered_shape = (5, 10) def test_check_output(self): + self.check_output(no_check_set=['XShape']) def test_check_grad(self): @@ -82,5 +83,51 @@ class TestReshapeOpWithInputShape(OpTest): self.check_grad(["X"], "Out") +class TestReshapeOp_attr_tensor(OpTest): + def setUp(self): + self.init_data() + self.op_type = "reshape2" + + shape_tensor = [] + for index, ele in enumerate(self.new_shape): + shape_tensor.append(("x" + str(index), np.ones( + (1)).astype('int32') * ele)) + + self.inputs = { + "X": np.random.random(self.ori_shape).astype("float32"), + 'ShapeTensor': shape_tensor + } + self.attrs = {} + self.outputs = { + "Out": self.inputs["X"].reshape(self.infered_shape), + 'XShape': np.random.random(self.ori_shape).astype("float32") + } + + def init_data(self): + self.ori_shape = (2, 25) + self.new_shape = (5, 10) + self.infered_shape = (5, 10) + + def test_check_output(self): + self.check_output(no_check_set=['XShape']) + + def test_check_grad(self): + self.check_grad(["X"], "Out") + + +class TestReshapeOpDimInfer1_attr_tensor(TestReshapeOp_attr_tensor): + def init_data(self): + self.ori_shape = (5, 10) + self.new_shape = (5, -1, 5) + self.infered_shape = (5, -1, 5) + + +class TestReshapeOpDimInfer2_attr_tensor(TestReshapeOp_attr_tensor): + def init_data(self): + self.ori_shape = (2, 2, 6) + self.new_shape = (2, 0, 3, -1) + self.infered_shape = (2, 2, 3, -1) + + if __name__ == "__main__": unittest.main()