diff --git a/paddle/fluid/operators/unsqueeze_op.cc b/paddle/fluid/operators/unsqueeze_op.cc index e55de4508bcd46be5d0b4ae1766213eb688b50d3..22f2a0fc3550587e95c8a3c7b47e2a3474ea9bea 100644 --- a/paddle/fluid/operators/unsqueeze_op.cc +++ b/paddle/fluid/operators/unsqueeze_op.cc @@ -37,12 +37,38 @@ class UnsqueezeOp : public framework::OperatorWithKernel { PADDLE_ENFORCE_LE(x_dims.size(), 6, "Invalid dimensions, the rank of Input(X) " "should be in the range of [1, 6] (Eigen limit)"); - auto out_dims = GetOutputShape(axes, x_dims); - ctx->SetOutputDim("Out", out_dims); - if (x_dims[0] == out_dims[0]) { - // Only pass LoD when the first dimension of output and Input(X) - // are the same. - ctx->ShareLoD("X", "Out"); + if (!axes.empty()) { + auto out_dims = GetOutputShape(axes, x_dims); + ctx->SetOutputDim("Out", out_dims); + if (x_dims[0] == out_dims[0]) { + // Only pass LoD when the first dimension of output and Input(X) + // are the same. + ctx->ShareLoD("X", "Out"); + } + } else if (ctx->HasInputs("AxesTensorList")) { + auto AxesTensorList = ctx->Inputs("AxesTensorList"); + int output_size = x_dims.size() + static_cast(AxesTensorList.size()); + PADDLE_ENFORCE_LE(output_size, 6, + "The output tensor's rank should be less than 6."); + std::vector vec_out_dims(output_size, -1); + ctx->SetOutputDim("Out", framework::make_ddim(vec_out_dims)); + } else if (ctx->HasInput("AxesTensor")) { + auto axes_dims = ctx->GetInputDim("AxesTensor"); + PADDLE_ENFORCE_EQ( + axes_dims.size(), 1, + "Input(AxesTensor)'s dimension of Op(unsqueeze) must be 1. " + "But received AxesTensor's shape = [%s], " + "AxesTensor's dimension = %d.", + axes_dims, axes_dims.size()); + PADDLE_ENFORCE_GE(axes_dims[0], 0, + "Input(AxesTensor)'s shape must be known. But received " + "AxesTensor's shape = [%s]", + axes_dims); + int output_size = x_dims.size() + static_cast(axes_dims[0]); + PADDLE_ENFORCE_LE(output_size, 6, + "The output tensor's rank should be less than 6."); + std::vector vec_out_dims(output_size, -1); + ctx->SetOutputDim("Out", framework::make_ddim(vec_out_dims)); } } @@ -83,19 +109,46 @@ class UnsqueezeOp : public framework::OperatorWithKernel { return framework::make_ddim(output_shape); } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext &ctx) const override { + return framework::OpKernelType(ctx.Input("X")->type(), + ctx.device_context()); + } + + framework::OpKernelType GetKernelTypeForVar( + const std::string &var_name, const framework::Tensor &tensor, + const framework::OpKernelType &expected_kernel_type) const override { + if (var_name == "AxesTensor" || var_name == "AxesTensorList") { + return expected_kernel_type; + } + return framework::OpKernelType(expected_kernel_type.data_type_, + tensor.place(), tensor.layout()); + } }; class UnsqueezeOpMaker : public framework::OpProtoAndCheckerMaker { public: void Make() override { AddInput("X", "(Tensor). The input tensor of unsqueeze operator."); + AddInput("AxesTensor", + "(Tensor, optional). The dimensions to be inserted. " + "If it exists, it will replace Attr(axes).") + .AsDispensable(); + AddInput( + "AxesTensorList", + "(vector>, optional). The dimensions to be inserted. " + "If it exists, it will replace Attr(axes)." + "The shape of the element in vector must be [1].") + .AsDuplicable() + .AsDispensable(); AddOutput("Out", "(Tensor). The output tensor of unsqueeze operator."); AddAttr>("axes", "(std::vector). List of integers," " indicating the dimensions to be inserted") + .SetDefault({}) .AddCustomChecker([](const std::vector &axes) { - PADDLE_ENFORCE_EQ(!axes.empty(), true, - "Invalid axes, The unsqueeze axes is empty."); // Validity Check: axes dims (<6). PADDLE_ENFORCE_LT(static_cast(axes.size()), 6, "Invalid dimensions, dynamic dimensions should be " @@ -136,28 +189,12 @@ class UnsqueezeGradOp : public framework::OperatorWithKernel { // will be used in unsqueeze_grad, in this way, the framework can reuse // the memory of X immediately the unsqueeze2_op is finished. // Considering compatibility issues, we could not fix unsqueeze2_op -class Unsqueeze2Op : public framework::OperatorWithKernel { +class Unsqueeze2Op : public UnsqueezeOp { public: - using framework::OperatorWithKernel::OperatorWithKernel; + using UnsqueezeOp::UnsqueezeOp; void InferShape(framework::InferShapeContext *ctx) const override { - PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true, - "Input(X) of Unsqueeze operator should not be null."); - PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true, - "Output(Out) of Unsqueeze operator should not be null."); - - const auto &axes = ctx->Attrs().Get>("axes"); + UnsqueezeOp::InferShape(ctx); const auto &x_dims = ctx->GetInputDim("X"); - // Validity Check: input tensor dims (<6). - PADDLE_ENFORCE_LE(x_dims.size(), 6, - "Invalid dimensions, the rank of Input(X) " - "should be in the range of [1, 6] (Eigen limit)"); - auto out_dims = UnsqueezeOp::GetOutputShape(axes, x_dims); - ctx->SetOutputDim("Out", out_dims); - if (x_dims[0] == out_dims[0]) { - // Only pass LoD when the first dimension of output and Input(X) - // are the same. - ctx->ShareLoD("X", "Out"); - } PADDLE_ENFORCE_EQ( ctx->HasOutput("XShape"), true, @@ -252,12 +289,11 @@ REGISTER_OP_CPU_KERNEL( ops::UnsqueezeGradKernel, ops::UnsqueezeGradKernel); REGISTER_OP_CPU_KERNEL( - unsqueeze2, - ops::Unsqueeze2Kernel, - ops::Unsqueeze2Kernel, - ops::Unsqueeze2Kernel, - ops::Unsqueeze2Kernel, - ops::Unsqueeze2Kernel); + unsqueeze2, ops::UnsqueezeKernel, + ops::UnsqueezeKernel, + ops::UnsqueezeKernel, + ops::UnsqueezeKernel, + ops::UnsqueezeKernel); REGISTER_OP_CPU_KERNEL( unsqueeze2_grad, ops::Unsqueeze2GradKernel, diff --git a/paddle/fluid/operators/unsqueeze_op.cu.cc b/paddle/fluid/operators/unsqueeze_op.cu.cc index fbdec5af94a570f430f9c50a16fe01b69a4f2d14..ffdd61170e5a67db18da6db1298d985fa82eb859 100644 --- a/paddle/fluid/operators/unsqueeze_op.cu.cc +++ b/paddle/fluid/operators/unsqueeze_op.cu.cc @@ -31,11 +31,11 @@ REGISTER_OP_CUDA_KERNEL( ops::UnsqueezeGradKernel); REGISTER_OP_CUDA_KERNEL( unsqueeze2, - ops::Unsqueeze2Kernel, - ops::Unsqueeze2Kernel, - ops::Unsqueeze2Kernel, - ops::Unsqueeze2Kernel, - ops::Unsqueeze2Kernel); + ops::UnsqueezeKernel, + ops::UnsqueezeKernel, + ops::UnsqueezeKernel, + ops::UnsqueezeKernel, + ops::UnsqueezeKernel); REGISTER_OP_CUDA_KERNEL( unsqueeze2_grad, ops::Unsqueeze2GradKernel, diff --git a/paddle/fluid/operators/unsqueeze_op.h b/paddle/fluid/operators/unsqueeze_op.h index 68f0cbe81223126c3f850a6e738c7b581910c69d..75c443da95cb816984cc68ea41871ec218e7a47c 100644 --- a/paddle/fluid/operators/unsqueeze_op.h +++ b/paddle/fluid/operators/unsqueeze_op.h @@ -23,17 +23,66 @@ limitations under the License. */ namespace paddle { namespace operators { +template +inline std::vector GetDataFromTensorList( + const std::vector &list_tensor) { + std::vector vec_new_data; + for (size_t i = 0; i < list_tensor.size(); ++i) { + auto tensor = list_tensor[i]; + PADDLE_ENFORCE_EQ( + tensor->dims(), framework::make_ddim({1}), + "ShapeError: If the element type is Tensor, " + "the element's shape must be [1]. But received the element's shape " + "is [%s]", + tensor->dims()); + if (platform::is_gpu_place(tensor->place())) { + framework::Tensor temp; + TensorCopySync(*tensor, platform::CPUPlace(), &temp); + vec_new_data.push_back((*temp.data())); + } else { + vec_new_data.push_back((*tensor->data())); + } + } + return vec_new_data; +} +template +inline std::vector GetDataFromTensor(const framework::Tensor *x) { + auto *data = x->data(); + framework::Tensor cpu_attr_tensor; + if (platform::is_gpu_place(x->place())) { + TensorCopySync(*x, platform::CPUPlace(), &cpu_attr_tensor); + data = cpu_attr_tensor.data(); + } + auto vec_data = std::vector(data, data + x->numel()); + return vec_data; +} template class UnsqueezeKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext &context) const override { - auto &axes = context.Attr>("axes"); + auto axes = context.Attr>("axes"); auto *in = context.Input("X"); auto *out = context.Output("Out"); auto x_dims = in->dims(); - auto out_dims = GetOutputShape(axes, x_dims); + bool need_resize_out_dims = false; + if (axes.empty()) { + auto axes_tensor_list = + context.MultiInput("AxesTensorList"); + if (axes_tensor_list.size() > 0) { + axes = GetDataFromTensorList(axes_tensor_list); + } else if (context.HasInput("AxesTensor")) { + auto *axes_tensor = context.Input("AxesTensor"); + axes = GetDataFromTensor(axes_tensor); + } + need_resize_out_dims = true; + } + framework::DDim out_dims = out->dims(); + if (need_resize_out_dims) { + out_dims = GetOutputShape(axes, x_dims); + out->Resize(out_dims); + } out->mutable_data(context.GetPlace(), in->type()); framework::TensorCopy( *in, context.GetPlace(), @@ -95,27 +144,6 @@ class UnsqueezeGradKernel : public framework::OpKernel { } }; -template -class Unsqueeze2Kernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext &context) const override { - auto *out = context.Output("Out"); - auto *in = context.Input("X"); - - auto &axes = context.Attr>("axes"); - - auto x_dims = in->dims(); - auto out_dims = - UnsqueezeKernel::GetOutputShape(axes, x_dims); - - out->mutable_data(context.GetPlace(), in->type()); - framework::TensorCopy( - *in, context.GetPlace(), - context.template device_context(), out); - out->Resize(out_dims); - } -}; - template class Unsqueeze2GradKernel : public framework::OpKernel { public: diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 2e0c1044a6edb90f92f82eb16a89d16145f3021c..e09a49d8322d8c687ca538993794ac9a62e4eb6b 100755 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -8997,7 +8997,7 @@ def unsqueeze(input, axes, name=None): Args: input (Variable): The input Tensor to be unsqueezed. It is a N-D Tensor of data types float32, float64, int32. - axes (list): List of integers, indicating the dimensions to be inserted. + axes (int|list|tuple|Variable): Indicates the dimensions to be inserted. The data type is ``int32`` . If ``axes`` is a list or tuple, the elements of it should be integers or Tensors with shape [1]. If ``axes`` is an Variable, it should be an 1-D Tensor . name (str|None): Name for this layer. Returns: @@ -9011,13 +9011,45 @@ def unsqueeze(input, axes, name=None): y = fluid.layers.unsqueeze(input=x, axes=[1]) """ - helper = LayerHelper("unsqueeze", **locals()) + if not isinstance(axes, (int, list, tuple, Variable)): + raise TypeError( + "The type of 'axes' in unsqueeze must be int, list, tuple or Variable, but " + "received %s." % (type(axes))) + helper = LayerHelper("unsqueeze2", **locals()) + inputs = {"X": input} + attrs = {} + + def _to_Variable_list(one_list): + Variable_list = [] + for ele in one_list: + if isinstance(ele, Variable): + ele.stop_gradient = True + Variable_list.append(ele) + else: + assert (isinstance(ele, int)) + temp_out = helper.create_variable_for_type_inference('int32') + fill_constant([1], 'int32', ele, force_cpu=True, out=temp_out) + Variable_list.append(temp_out) + return Variable_list + + if isinstance(axes, int): + axes = [axes] + if isinstance(axes, Variable): + axes.stop_gradient = True + inputs["AxesTensor"] = axes + elif isinstance(axes, (list, tuple)): + contain_var = not all(not isinstance(ele, Variable) for ele in axes) + if contain_var: + inputs["AxesTensorList"] = _to_Variable_list(axes) + else: + attrs["axes"] = axes + out = helper.create_variable_for_type_inference(dtype=input.dtype) x_shape = helper.create_variable_for_type_inference(dtype=input.dtype) helper.append_op( type="unsqueeze2", - inputs={"X": input}, - attrs={"axes": axes}, + inputs=inputs, + attrs=attrs, outputs={"Out": out, "XShape": x_shape}) diff --git a/python/paddle/fluid/tests/unittests/test_unsqueeze2_op.py b/python/paddle/fluid/tests/unittests/test_unsqueeze2_op.py index 14dd2bb06f9a18d0b15a4aee4e9e6bfdf8c41206..c04fc47f2289d6949993e8fb48d886969af4654f 100644 --- a/python/paddle/fluid/tests/unittests/test_unsqueeze2_op.py +++ b/python/paddle/fluid/tests/unittests/test_unsqueeze2_op.py @@ -16,7 +16,7 @@ from __future__ import print_function import unittest import numpy as np - +import paddle.fluid as fluid from op_test import OpTest @@ -79,5 +79,163 @@ class TestUnsqueezeOp4(TestUnsqueezeOp): self.new_shape = (3, 1, 1, 2, 5, 1) +# axes is a list(with tensor) +class TestUnsqueezeOp_AxesTensorList(OpTest): + def setUp(self): + self.init_test_case() + self.op_type = "unsqueeze2" + + axes_tensor_list = [] + for index, ele in enumerate(self.axes): + axes_tensor_list.append(("axes" + str(index), np.ones( + (1)).astype('int32') * ele)) + + self.inputs = { + "X": np.random.random(self.ori_shape).astype("float32"), + "AxesTensorList": axes_tensor_list + } + self.init_attrs() + self.outputs = { + "Out": self.inputs["X"].reshape(self.new_shape), + "XShape": np.random.random(self.ori_shape).astype("float32") + } + + def test_check_output(self): + self.check_output(no_check_set=["XShape"]) + + def test_check_grad(self): + self.check_grad(["X"], "Out") + + def init_test_case(self): + self.ori_shape = (3, 5) + self.axes = (1, 2) + self.new_shape = (3, 1, 1, 5) + + def init_attrs(self): + self.attrs = {} + + +class TestUnsqueezeOp1_AxesTensorList(TestUnsqueezeOp_AxesTensorList): + def init_test_case(self): + self.ori_shape = (3, 5) + self.axes = (-1, ) + self.new_shape = (3, 5, 1) + + +class TestUnsqueezeOp2_AxesTensorList(TestUnsqueezeOp_AxesTensorList): + def init_test_case(self): + self.ori_shape = (3, 5) + self.axes = (0, -1) + self.new_shape = (1, 3, 5, 1) + + +class TestUnsqueezeOp3_AxesTensorList(TestUnsqueezeOp_AxesTensorList): + def init_test_case(self): + self.ori_shape = (3, 2, 5) + self.axes = (0, 3, 3) + self.new_shape = (1, 3, 2, 1, 1, 5) + + +class TestUnsqueezeOp4_AxesTensorList(TestUnsqueezeOp_AxesTensorList): + def init_test_case(self): + self.ori_shape = (3, 2, 5) + self.axes = (3, 1, 1) + self.new_shape = (3, 1, 1, 2, 5, 1) + + +# axes is a Tensor +class TestUnsqueezeOp_AxesTensor(OpTest): + def setUp(self): + self.init_test_case() + self.op_type = "unsqueeze2" + + self.inputs = { + "X": np.random.random(self.ori_shape).astype("float32"), + "AxesTensor": np.array(self.axes).astype("int32") + } + self.init_attrs() + self.outputs = { + "Out": self.inputs["X"].reshape(self.new_shape), + "XShape": np.random.random(self.ori_shape).astype("float32") + } + + def test_check_output(self): + self.check_output(no_check_set=["XShape"]) + + def test_check_grad(self): + self.check_grad(["X"], "Out") + + def init_test_case(self): + self.ori_shape = (3, 5) + self.axes = (1, 2) + self.new_shape = (3, 1, 1, 5) + + def init_attrs(self): + self.attrs = {} + + +class TestUnsqueezeOp1_AxesTensor(TestUnsqueezeOp_AxesTensor): + def init_test_case(self): + self.ori_shape = (3, 5) + self.axes = (-1, ) + self.new_shape = (3, 5, 1) + + +class TestUnsqueezeOp2_AxesTensor(TestUnsqueezeOp_AxesTensor): + def init_test_case(self): + self.ori_shape = (3, 5) + self.axes = (0, -1) + self.new_shape = (1, 3, 5, 1) + + +class TestUnsqueezeOp3_AxesTensor(TestUnsqueezeOp_AxesTensor): + def init_test_case(self): + self.ori_shape = (3, 2, 5) + self.axes = (0, 3, 3) + self.new_shape = (1, 3, 2, 1, 1, 5) + + +class TestUnsqueezeOp4_AxesTensor(TestUnsqueezeOp_AxesTensor): + def init_test_case(self): + self.ori_shape = (3, 2, 5) + self.axes = (3, 1, 1) + self.new_shape = (3, 1, 1, 2, 5, 1) + + +# test api +class TestUnsqueezeAPI(OpTest): + def test_api(self): + input = np.random.random([3, 2, 5]).astype("float32") + x = fluid.data(name='x', shape=[3, 2, 5], dtype="float32") + positive_3 = fluid.layers.fill_constant([1], "int32", 3) + axes_tensor = fluid.data(name='axes_tensor', shape=[3], dtype="int32") + + out_1 = fluid.layers.unsqueeze(x, axes=[3, 1, 1]) + out_2 = fluid.layers.unsqueeze(x, axes=[positive_3, 1, 1]) + out_3 = fluid.layers.unsqueeze(x, axes=axes_tensor) + out_4 = fluid.layers.unsqueeze(x, axes=3) + + exe = fluid.Executor(place=fluid.CPUPlace()) + res_1, res_2, res_3, res_4 = exe.run( + fluid.default_main_program(), + feed={ + "x": input, + "axes_tensor": np.array([3, 1, 1]).astype("int32") + }, + fetch_list=[out_1, out_2, out_3, out_4]) + + assert np.array_equal(res_1, input.reshape([3, 1, 1, 2, 5, 1])) + assert np.array_equal(res_2, input.reshape([3, 1, 1, 2, 5, 1])) + assert np.array_equal(res_3, input.reshape([3, 1, 1, 2, 5, 1])) + assert np.array_equal(res_4, input.reshape([3, 2, 5, 1])) + + def test_error(self): + def test_axes_type(): + x2 = fluid.data(name="x2", shape=[2, 25], dtype="int32") + fluid.layers.unsqueeze(x2, axes=2.1) + + self.assertRaises(TypeError, test_axes_type) + + if __name__ == "__main__": unittest.main()