diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index c01d9bc38459952f9f422fc8c1ac5b0fbcb3c4d0..bc07bbe67e3dafc08aeb62cd75629966c216ce6e 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -265,6 +265,7 @@ op_library(recurrent_op DEPS executor) op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale) op_library(cos_sim_op DEPS cos_sim_functor) op_library(parallel_do_op DEPS executor) +op_library(unsqueeze_op DEPS reshape_op) op_library(squeeze_op DEPS reshape_op) if (WITH_GPU) diff --git a/paddle/fluid/operators/unsqueeze_op.cc b/paddle/fluid/operators/unsqueeze_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..f2a15fdf572e0de30f9949dda5020e130b0c5585 --- /dev/null +++ b/paddle/fluid/operators/unsqueeze_op.cc @@ -0,0 +1,191 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include +#include "paddle/fluid/framework/op_registry.h" + +namespace paddle { +namespace operators { + +class UnsqueezeOpInferShape : public framework::InferShapeBase { + public: + void operator()(framework::InferShapeContext *ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), + "Input(X) of UnsqueezeOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), + "Output(Out) of UnsqueezeOp should not be null."); + + const auto &axes = ctx->Attrs().Get>("axes"); + const auto &x_dims = ctx->GetInputDim("X"); + // Validity Check: input tensor dims (<6). + PADDLE_ENFORCE(x_dims.size() <= 6, + "Invalid dimensions, the rank of Input(X) " + "should be in the range of [1, 6] (Eigen limit)"); + auto out_dims = GetOutputShape(axes, x_dims); + ctx->SetOutputDim("Out", out_dims); + if (x_dims[0] == out_dims[0]) { + // Only pass LoD when the first dimension of output and Input(X) + // are the same. + ctx->ShareLoD("X", "Out"); + } + } + + static framework::DDim GetOutputShape(const std::vector unsqz_dims, + const framework::DDim &in_dims) { + int output_size = in_dims.size() + static_cast(unsqz_dims.size()); + int cur_output_size = in_dims.size(); + std::vector output_shape(output_size, 0); + + // Validity Check: rank range. + PADDLE_ENFORCE(output_size <= 6, + "The output tensor's rank should be less than 6."); + + for (int axis : unsqz_dims) { + int cur = axis < 0 ? axis + cur_output_size + 1 : axis; + // Vaildity Check: the axis bound + PADDLE_ENFORCE( + cur >= 0 && cur <= cur_output_size, + "The unsqueeze dims must be within range of current rank."); + // Move old axis, and insert new axis + for (int i = cur_output_size; i >= cur; --i) { + if (output_shape[i] == 1) { + // Move axis + output_shape[i + 1] = 1; + output_shape[i] = 0; + } + } + output_shape[cur] = 1; + // Add the output size. + cur_output_size++; + } + + // Make output shape + for (int in_idx = 0, out_idx = 0; out_idx < output_size; ++out_idx) { + if (output_shape[out_idx] == 0) { + output_shape[out_idx] = in_dims[in_idx++]; + } + } + + return framework::make_ddim(output_shape); + } +}; + +class UnsqueezeOp : public framework::OperatorBase { + public: + using OperatorBase::OperatorBase; + + private: + void RunImpl(const framework::Scope &scope, + const platform::Place &place) const override { + auto &axes = Attr>("axes"); + auto x_dims = scope.FindVar(Input("X"))->Get().dims(); + auto out_dims = UnsqueezeOpInferShape::GetOutputShape(axes, x_dims); + + framework::AttributeMap attrs; + attrs["shape"] = framework::vectorize2int(out_dims); + attrs["inplace"] = Attr("inplace"); + // Invoke Reshape op. + auto reshape_op = framework::OpRegistry::CreateOp( + "reshape", {{"X", {Input("X")}}, {"Shape", {}}}, + {{"Out", {Output("Out")}}}, attrs); + reshape_op->Run(scope, place); + } +}; + +class UnsqueezeOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", "(Tensor). The input tensor of unsqueeze operator."); + AddOutput("Out", "(Tensor). The output tensor of unsqueeze operator."); + AddAttr>("axes", + "(std::vector). List of integers," + " indicating the dimensions to be inserted") + .AddCustomChecker([](const std::vector &axes) { + PADDLE_ENFORCE(!axes.empty(), + "Invalid axes, The unsqueeze axes is empty."); + // Validity Check: axes dims (<6). + PADDLE_ENFORCE(static_cast(axes.size()) < 6, + "Invalid dimensions, dynamic dimensions should be " + "within [1, 6] dimensions (Eigen limit)."); + // Validity Check: the range of unsqueeze aixs. + for (int axis : axes) { + PADDLE_ENFORCE(axis < 6, + "Invalid dimensions, input axis should be" + " within [1, 6] dimensions (Eigen limit)."); + } + }); + AddAttr( + "inplace", + "(default: false) Unsqueeze the source tensor's shape without " + "memory copy. When Attr(inplace) is set true, the output " + "tensor shares memory with Input(X), otherwise, a new output " + "tensor is created, and its data are copied from Input(x).") + .SetDefault(false); + AddComment(R"DOC( + Unsqueeze Operator. + + Insert single-dimensional entries to the shape of a tensor. + Takes one required argument axes, a list of dimensions that will be inserted. + Dimension indices in axes are as seen in the output tensor. + + For example: + Given a tensor such that tensor with shape [3, 4, 5], + then Unsqueeze(tensor, axes=[0, 4]) has shape [1, 3, 4, 5, 1] + )DOC"); + } +}; + +class UnsqueezeGradInferShape : public framework::InferShapeBase { + public: + void operator()(framework::InferShapeContext *ctx) const override { + ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); + ctx->ShareLoD("X", framework::GradVarName("X")); + } +}; + +class UnsqueezeGradOp : public framework::OperatorBase { + public: + using OperatorBase::OperatorBase; + + private: + void RunImpl(const framework::Scope &scope, + const platform::Place &place) const override { + auto dx_name = Output(framework::GradVarName("X")); + auto dout_name = Input(framework::GradVarName("Out")); + auto x_dims = scope.FindVar(Input("X"))->Get().dims(); + + framework::AttributeMap attrs; + attrs["shape"] = framework::vectorize2int(x_dims); + attrs["inplace"] = Attr("inplace"); + + auto reshape_op = framework::OpRegistry::CreateOp( + "reshape", {{"X", {dout_name}}, {"Shape", {}}}, {{"Out", {dx_name}}}, + attrs); + reshape_op->Run(scope, place); + } +}; + +} // namespace operators +} // namespace paddle + +// Tell linker to use reshape op. +USE_OP(reshape); + +namespace ops = paddle::operators; +REGISTER_OPERATOR(unsqueeze, ops::UnsqueezeOp, ops::UnsqueezeOpMaker, + ops::UnsqueezeOpInferShape, + paddle::framework::DefaultGradOpDescMaker); +REGISTER_OPERATOR(unsqueeze_grad, ops::UnsqueezeGradOp, + ops::UnsqueezeGradInferShape); diff --git a/python/paddle/fluid/tests/unittests/test_unsqueeze_op.py b/python/paddle/fluid/tests/unittests/test_unsqueeze_op.py new file mode 100644 index 0000000000000000000000000000000000000000..7a4aa0a40b5eb494f6027e800ca6b466bbe1c302 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_unsqueeze_op.py @@ -0,0 +1,111 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import numpy as np + +from op_test import OpTest + + +# Correct: General. +class TestUnsqueezeOp(OpTest): + def setUp(self): + self.init_test_case() + self.op_type = "unsqueeze" + self.inputs = {"X": np.random.random(self.ori_shape).astype("float32")} + self.init_attrs() + self.outputs = {"Out": self.inputs["X"].reshape(self.new_shape)} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(["X"], "Out") + + def init_test_case(self): + self.ori_shape = (3, 5) + self.axes = (1, 2) + self.new_shape = (3, 1, 1, 5) + + def init_attrs(self): + self.attrs = {"axes": self.axes, "inplace": False} + + +# Correct: Single input index. +class TestUnsqueezeOp1(TestUnsqueezeOp): + def init_test_case(self): + self.ori_shape = (3, 5) + self.axes = (-1, ) + self.new_shape = (3, 5, 1) + + +# Correct: Mixed input axis. +class TestUnsqueezeOp2(TestUnsqueezeOp): + def init_test_case(self): + self.ori_shape = (3, 5) + self.axes = (0, -1) + self.new_shape = (1, 3, 5, 1) + + +# Correct: There is duplicated axis. +class TestUnsqueezeOp3(TestUnsqueezeOp): + def init_test_case(self): + self.ori_shape = (3, 2, 5) + self.axes = (0, 3, 3) + self.new_shape = (1, 3, 2, 1, 1, 5) + + +# Correct: Reversed axes. +class TestUnsqueezeOp4(TestUnsqueezeOp): + def init_test_case(self): + self.ori_shape = (3, 2, 5) + self.axes = (3, 1, 1) + self.new_shape = (3, 1, 1, 2, 5, 1) + + +# Correct: Inplace. +class TestUnsqueezeOpInplace1(TestUnsqueezeOp): + def init_test_case(self): + self.ori_shape = (3, 5) + self.axes = (0, 2) + self.new_shape = (1, 3, 1, 5) + + def init_attrs(self): + self.attrs = {"axes": self.axes, "inplace": True} + + +# Correct: Inplace. There is mins index. +class TestUnsqueezeOpInplace2(TestUnsqueezeOp): + def init_test_case(self): + self.ori_shape = (3, 5) + self.axes = (0, -2) + self.new_shape = (1, 3, 1, 5) + + def init_attrs(self): + self.attrs = {"axes": self.axes, "inplace": True} + + +# Correct: Inplace. There is duplicated axis. +class TestUnsqueezeOpInplace3(TestUnsqueezeOp): + def init_test_case(self): + self.ori_shape = (3, 2, 5) + self.axes = (0, 3, 3) + self.new_shape = (1, 3, 2, 1, 1, 5) + + def init_attrs(self): + self.attrs = {"axes": self.axes, "inplace": True} + + +if __name__ == "__main__": + unittest.main()