/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/unsqueeze_op.h" #include #include #include #include "paddle/fluid/framework/op_registry.h" namespace paddle { namespace operators { class UnsqueezeOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; void InferShape(framework::InferShapeContext *ctx) const override { PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true, platform::errors::InvalidArgument( "Input(X) of " "Unsqueeze operator should not be null.")); PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true, platform::errors::InvalidArgument( "Output(Out) of " "Unsqueeze operator should not be null.")); const auto &axes = ctx->Attrs().Get>("axes"); const auto &x_dims = ctx->GetInputDim("X"); // Validity Check: input tensor dims (<6). PADDLE_ENFORCE_LE(x_dims.size(), 6, platform::errors::InvalidArgument( "Invalid " "dimensions, the rank of Input(X) " "should be in the range of [1, 6] (Eigen limit)")); if (!axes.empty()) { auto out_dims = GetOutputShape(axes, x_dims); ctx->SetOutputDim("Out", out_dims); if (x_dims[0] == out_dims[0]) { // Only pass LoD when the first dimension of output and Input(X) // are the same. ctx->ShareLoD("X", "Out"); } } else if (ctx->HasInputs("AxesTensorList")) { auto AxesTensorList = ctx->Inputs("AxesTensorList"); int output_size = x_dims.size() + static_cast(AxesTensorList.size()); PADDLE_ENFORCE_LE(output_size, 6, platform::errors::InvalidArgument( "The output tensor's rank should be less than 6.")); std::vector vec_out_dims(output_size, -1); ctx->SetOutputDim("Out", framework::make_ddim(vec_out_dims)); } else if (ctx->HasInput("AxesTensor")) { auto axes_dims = ctx->GetInputDim("AxesTensor"); PADDLE_ENFORCE_EQ(axes_dims.size(), 1, platform::errors::InvalidArgument( "Input(AxesTensor)'s dimension of " "Op(unsqueeze) must be 1. " "But received AxesTensor's shape = [%s], " "AxesTensor's dimension = %d.", axes_dims, axes_dims.size())); PADDLE_ENFORCE_GE( axes_dims[0], 0, platform::errors::InvalidArgument( "Input(AxesTensor)'s shape must be known. But received " "AxesTensor's shape = [%s]", axes_dims)); int output_size = x_dims.size() + static_cast(axes_dims[0]); PADDLE_ENFORCE_LE(output_size, 6, platform::errors::InvalidArgument( "The output tensor's rank should be less than 6.")); std::vector vec_out_dims(output_size, -1); ctx->SetOutputDim("Out", framework::make_ddim(vec_out_dims)); } } static framework::DDim GetOutputShape(const std::vector unsqz_dims, const framework::DDim &in_dims) { int output_size = in_dims.size() + static_cast(unsqz_dims.size()); int cur_output_size = in_dims.size(); std::vector output_shape(output_size, 0); // Validity Check: rank range. PADDLE_ENFORCE_LE(output_size, 6, platform::errors::InvalidArgument( "The output tensor's rank should be less than 6.")); for (int axis : unsqz_dims) { int cur = axis < 0 ? axis + cur_output_size + 1 : axis; // Vaildity Check: the axis bound PADDLE_ENFORCE_GE(cur, 0, platform::errors::InvalidArgument( "The insert dimension value should " "not be less than 0")); PADDLE_ENFORCE_LE(cur, cur_output_size, platform::errors::InvalidArgument( "The insert dimension value shoud not be larger " "than the dimension size of input tensor")); // Move old axis, and insert new axis for (int i = cur_output_size; i >= cur; --i) { if (output_shape[i] == 1) { // Move axis output_shape[i + 1] = 1; output_shape[i] = 0; } } output_shape[cur] = 1; // Add the output size. cur_output_size++; } // Make output shape for (int in_idx = 0, out_idx = 0; out_idx < output_size; ++out_idx) { if (output_shape[out_idx] == 0) { output_shape[out_idx] = in_dims[in_idx++]; } } return framework::make_ddim(output_shape); } protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { return framework::OpKernelType(ctx.Input("X")->type(), ctx.device_context()); } framework::OpKernelType GetKernelTypeForVar( const std::string &var_name, const framework::Tensor &tensor, const framework::OpKernelType &expected_kernel_type) const override { if (var_name == "AxesTensor" || var_name == "AxesTensorList") { return expected_kernel_type; } return framework::OpKernelType(expected_kernel_type.data_type_, tensor.place(), tensor.layout()); } }; class UnsqueezeOpMaker : public framework::OpProtoAndCheckerMaker { public: void Make() override { AddInput("X", "(Tensor). The input tensor of unsqueeze operator."); AddInput("AxesTensor", "(Tensor, optional). The dimensions to be inserted. " "If it exists, it will replace Attr(axes).") .AsDispensable(); AddInput( "AxesTensorList", "(vector>, optional). The dimensions to be inserted. " "If it exists, it will replace Attr(axes)." "The shape of the element in vector must be [1].") .AsDuplicable() .AsDispensable(); AddOutput("Out", "(Tensor). The output tensor of unsqueeze operator."); AddAttr>("axes", "(std::vector). List of integers," " indicating the dimensions to be inserted") .SetDefault({}) .AddCustomChecker([](const std::vector &axes) { // Validity Check: axes dims (<6). PADDLE_ENFORCE_LT(static_cast(axes.size()), 6, platform::errors::InvalidArgument( "Invalid " "dimensions, dynamic dimensions should be " "within [1, 6] dimensions (Eigen limit).")); // Validity Check: the range of unsqueeze axis. for (int axis : axes) { PADDLE_ENFORCE_LT(axis, 6, platform::errors::InvalidArgument( "Invalid " "dimensions, input axis should be" "within [1, 6] dimensions (Eigen limit).")); } }); AddComment(R"DOC( Unsqueeze Operator. Insert single-dimensional entries to the shape of a tensor. Takes one required argument axes, a list of dimensions that will be inserted. Dimension indices in axes are as seen in the output tensor. For example: Given a tensor such that tensor with shape [3, 4, 5], then Unsqueeze(tensor, axes=[0, 4]) has shape [1, 3, 4, 5, 1] )DOC"); } }; class UnsqueezeGradOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; void InferShape(framework::InferShapeContext *ctx) const override { ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); ctx->ShareLoD("X", framework::GradVarName("X")); } framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( ctx, framework::GradVarName("Out")), ctx.device_context()); } }; template class UnsqueezeGradOpMaker : public framework::SingleGradOpMaker { public: using framework::SingleGradOpMaker::SingleGradOpMaker; void Apply(GradOpPtr grad_op) const override { grad_op->SetType("unsqueeze_grad"); grad_op->SetInput("X", this->Input("X")); grad_op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); grad_op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); grad_op->SetAttrMap(this->Attrs()); } }; // FIXME(zcd): unsqueeze2 adds an intermediate output(XShape) based on // unsqueeze, the XShape is used to carry the shape and lod of X which // will be used in unsqueeze_grad, in this way, the framework can reuse // the memory of X immediately the unsqueeze2_op is finished. // Considering compatibility issues, we could not fix unsqueeze2_op class Unsqueeze2Op : public UnsqueezeOp { public: using UnsqueezeOp::UnsqueezeOp; void InferShape(framework::InferShapeContext *ctx) const override { UnsqueezeOp::InferShape(ctx); const auto &x_dims = ctx->GetInputDim("X"); PADDLE_ENFORCE_EQ( ctx->HasOutput("XShape"), true, platform::errors::InvalidArgument("Output(XShape) of Unsqueeze " "operator should not be null.")); std::vector xshape_dims(x_dims.size() + 1); xshape_dims[0] = 0; for (int i = 0; i < x_dims.size(); ++i) { xshape_dims[i + 1] = x_dims[i]; } ctx->SetOutputDim("XShape", framework::make_ddim(xshape_dims)); ctx->ShareLoD("X", /*->*/ "XShape"); } }; class Unsqueeze2OpMaker : public UnsqueezeOpMaker { public: void Make() override { UnsqueezeOpMaker::Make(); AddOutput("XShape", "XShape is just used to store the shape and lod of X, which will " "be used in UnsqueezeGradOp.") .AsIntermediate(); } }; template class Unsqueeze2GradOpMaker : public framework::SingleGradOpMaker { public: using framework::SingleGradOpMaker::SingleGradOpMaker; void Apply(GradOpPtr grad_op) const override { grad_op->SetType("unsqueeze2_grad"); grad_op->SetInput("XShape", this->Output("XShape")); grad_op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); grad_op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); grad_op->SetAttrMap(this->Attrs()); } }; class Unsqueeze2GradOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; void InferShape(framework::InferShapeContext *context) const override { PADDLE_ENFORCE_EQ( context->HasInput("XShape"), true, platform::errors::InvalidArgument("Input(XShape) shouldn't be null.")); PADDLE_ENFORCE_EQ(context->HasInput(framework::GradVarName("Out")), true, platform::errors::InvalidArgument( "Input(Out@GRAD) shouldn't be null.")); auto xshape_dims = context->GetInputDim("XShape"); auto x_dims = framework::slice_ddim(xshape_dims, 1, xshape_dims.size()); context->SetOutputDim(framework::GradVarName("X"), x_dims); context->ShareLoD("XShape", framework::GradVarName("X")); } protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( ctx, framework::GradVarName("Out")), ctx.device_context()); } }; DECLARE_INPLACE_OP_INFERER(UnsqueezeInplaceInferer, {"X", "Out"}); DECLARE_INPLACE_OP_INFERER(UnsqueezeGradInplaceInferer, {framework::GradVarName("Out"), framework::GradVarName("X")}); DECLARE_NO_NEED_BUFFER_VARS_INFERER(UnsqueezeGradOpNoNeedBufferVarInferer, "X"); } // namespace operators } // namespace paddle namespace ops = paddle::operators; REGISTER_OPERATOR(unsqueeze, ops::UnsqueezeOp, ops::UnsqueezeOpMaker, ops::UnsqueezeGradOpMaker, ops::UnsqueezeGradOpMaker); REGISTER_OPERATOR(unsqueeze_grad, ops::UnsqueezeGradOp, ops::UnsqueezeGradOpNoNeedBufferVarInferer); REGISTER_OPERATOR(unsqueeze2, ops::Unsqueeze2Op, ops::Unsqueeze2OpMaker, ops::Unsqueeze2GradOpMaker, ops::Unsqueeze2GradOpMaker, ops::UnsqueezeInplaceInferer); REGISTER_OPERATOR(unsqueeze2_grad, ops::Unsqueeze2GradOp, ops::UnsqueezeGradInplaceInferer); REGISTER_OP_CPU_KERNEL( unsqueeze, ops::UnsqueezeKernel, ops::UnsqueezeKernel, ops::UnsqueezeKernel, ops::UnsqueezeKernel, ops::UnsqueezeKernel); REGISTER_OP_CPU_KERNEL( unsqueeze_grad, ops::UnsqueezeGradKernel, ops::UnsqueezeGradKernel, ops::UnsqueezeGradKernel, ops::UnsqueezeGradKernel, ops::UnsqueezeGradKernel); REGISTER_OP_CPU_KERNEL( unsqueeze2, ops::UnsqueezeKernel, ops::UnsqueezeKernel, ops::UnsqueezeKernel, ops::UnsqueezeKernel, ops::UnsqueezeKernel); REGISTER_OP_CPU_KERNEL( unsqueeze2_grad, ops::Unsqueeze2GradKernel, ops::Unsqueeze2GradKernel, ops::Unsqueeze2GradKernel, ops::Unsqueeze2GradKernel, ops::Unsqueeze2GradKernel);