/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/unsqueeze_op.h"

#include <memory>
#include <string>
#include <vector>

#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/infermeta/unary.h"

namespace paddle {
namespace operators {

class UnsqueezeOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext *ctx) const override {
    PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
                      platform::errors::InvalidArgument(
                          "Input(X) of "
                          "Unsqueeze operator should not be null."));
    PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
                      platform::errors::InvalidArgument(
                          "Output(Out) of "
                          "Unsqueeze operator should not be null."));

    const auto &axes = ctx->Attrs().Get<std::vector<int>>("axes");
    const auto &x_dims = ctx->GetInputDim("X");
    // Validity Check: input tensor rank (<= 6).
    PADDLE_ENFORCE_LE(x_dims.size(), 6,
                      platform::errors::InvalidArgument(
                          "Invalid "
                          "dimensions, the rank of Input(X) "
                          "should be in the range of [1, 6] (Eigen limit)"));
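    // Attr(axes) takes effect only when non-empty; otherwise the axes come
    // from AxesTensorList or AxesTensor, whose values are unknown at compile
    // time, so every output dimension below is marked unknown (-1).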
    if (!axes.empty()) {
      auto out_dims = GetOutputShape(axes, x_dims);
      ctx->SetOutputDim("Out", out_dims);
      if (x_dims[0] == out_dims[0]) {
        // Only pass LoD when the first dimension of output and Input(X)
        // are the same.
        ctx->ShareLoD("X", "Out");
      }
    } else if (ctx->HasInputs("AxesTensorList")) {
      auto AxesTensorList = ctx->Inputs("AxesTensorList");
      int output_size = x_dims.size() + static_cast<int>(AxesTensorList.size());
      PADDLE_ENFORCE_LE(output_size, 6,
                        platform::errors::InvalidArgument(
                            "The output tensor's rank should be no more "
                            "than 6."));
      std::vector<int> vec_out_dims(output_size, -1);
      ctx->SetOutputDim("Out", phi::make_ddim(vec_out_dims));
    } else if (ctx->HasInput("AxesTensor")) {
      auto axes_dims = ctx->GetInputDim("AxesTensor");
      PADDLE_ENFORCE_EQ(axes_dims.size(), 1,
                        platform::errors::InvalidArgument(
                            "Input(AxesTensor)'s dimension of "
                            "Op(unsqueeze) must be 1. "
                            "But received AxesTensor's shape = [%s], "
                            "AxesTensor's dimension = %d.",
                            axes_dims, axes_dims.size()));
      PADDLE_ENFORCE_GE(
          axes_dims[0], 0,
          platform::errors::InvalidArgument(
              "Input(AxesTensor)'s shape must be known. But received "
              "AxesTensor's shape = [%s]",
              axes_dims));
      int output_size = x_dims.size() + static_cast<int>(axes_dims[0]);
      PADDLE_ENFORCE_LE(output_size, 6,
                        platform::errors::InvalidArgument(
                            "The output tensor's rank should be no more "
                            "than 6."));
      std::vector<int> vec_out_dims(output_size, -1);
      ctx->SetOutputDim("Out", phi::make_ddim(vec_out_dims));
    }
  }

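  // Statically computes the unsqueezed shape. For example (matching the
  // operator doc below), in_dims = [3, 4, 5] with unsqz_dims = {0, 4} yields
  // [1, 3, 4, 5, 1]: a 1 marks an inserted axis and a 0 marks a slot later
  // filled from in_dims.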
  static framework::DDim GetOutputShape(const std::vector<int> unsqz_dims,
                                        const framework::DDim &in_dims) {
    int output_size = in_dims.size() + static_cast<int>(unsqz_dims.size());
    int cur_output_size = in_dims.size();
    std::vector<int64_t> output_shape(output_size, 0);

    // Validity Check: rank range.
    PADDLE_ENFORCE_LE(output_size, 6,
                      platform::errors::InvalidArgument(
                          "The output tensor's rank should be no more "
                          "than 6."));

    for (int axis : unsqz_dims) {
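      // A negative axis counts from the end of the current output shape,
      // e.g. axis = -1 inserts the new dimension at the back.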
      int cur = axis < 0 ? axis + cur_output_size + 1 : axis;
      // Validity Check: the axis bound.
      PADDLE_ENFORCE_GE(
          cur, 0,
          platform::errors::InvalidArgument("The insert dimension value should "
                                            "not be less than 0"));
      PADDLE_ENFORCE_LE(cur, cur_output_size,
                        platform::errors::InvalidArgument(
                            "The insert dimension value should not be larger "
                            "than the dimension size of input tensor"));
      // Move old axis, and insert new axis
      for (int i = cur_output_size; i >= cur; --i) {
        if (output_shape[i] == 1) {
          // Move axis
          output_shape[i + 1] = 1;
          output_shape[i] = 0;
        }
      }
      output_shape[cur] = 1;
      // Add the output size.
      cur_output_size++;
    }

    // Make output shape.
    for (int in_idx = 0, out_idx = 0; out_idx < output_size; ++out_idx) {
      if (output_shape[out_idx] == 0) {
        output_shape[out_idx] = in_dims[in_idx++];
      }
    }

    return phi::make_ddim(output_shape);
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
    return framework::OpKernelType(
        framework::TransToProtoVarType(
            ctx.Input<framework::LoDTensor>("X")->type()),
        ctx.device_context());
  }

  framework::OpKernelType GetKernelTypeForVar(
      const std::string &var_name, const framework::Tensor &tensor,
      const framework::OpKernelType &expected_kernel_type) const override {
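    // AxesTensor and AxesTensorList carry axis indices rather than tensor
    // data, so keep the expected kernel type for them instead of
    // transforming their place/layout.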
    if (var_name == "AxesTensor" || var_name == "AxesTensorList") {
      return expected_kernel_type;
    }
    return framework::OpKernelType(expected_kernel_type.data_type_,
                                   tensor.place(), tensor.layout());
  }
};

class UnsqueezeOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "(Tensor). The input tensor of unsqueeze operator.");
    AddInput("AxesTensor",
             "(Tensor<int32>, optional). The dimensions to be inserted. "
             "If it exists, it will replace Attr(axes).")
        .AsDispensable();
    AddInput(
        "AxesTensorList",
        "(vector<Tensor<int32>>, optional). The dimensions to be inserted. "
        "If it exists, it will replace Attr(axes). "
        "The shape of the element in vector must be [1].")
        .AsDuplicable()
        .AsDispensable();
    AddOutput("Out", "(Tensor). The output tensor of unsqueeze operator.");
    AddAttr<std::vector<int>>("axes",
172
                              "(std::vector<int>). List of integers,"
173
                              " indicating the dimensions to be inserted")
174
        .SetDefault({})
175 176
        .AddCustomChecker([](const std::vector<int> &axes) {
          // Validity Check: axes dims (<6).
177
          PADDLE_ENFORCE_LT(static_cast<int>(axes.size()), 6,
178 179 180 181
                            platform::errors::InvalidArgument(
                                "Invalid "
                                "dimensions, dynamic dimensions should be "
                                "within [1, 6] dimensions (Eigen limit)."));
          // Validity Check: the range of unsqueeze axis.
          for (int axis : axes) {
            PADDLE_ENFORCE_LT(axis, 6,
                              platform::errors::InvalidArgument(
                                  "Invalid "
                                  "dimensions, input axis should be "
                                  "within [1, 6] dimensions (Eigen limit)."));
          }
        });
    AddComment(R"DOC(
    Unsqueeze Operator.

    Insert single-dimensional entries into the shape of a tensor.
    Takes one required argument axes, a list of dimensions that will be inserted.
    Dimension indices in axes are as seen in the output tensor.

    For example:
      Given a tensor with shape [3, 4, 5],
      then Unsqueeze(tensor, axes=[0, 4]) has shape [1, 3, 4, 5, 1]
    )DOC");
  }
};

class UnsqueezeGradOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext *ctx) const override {
    ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
    ctx->ShareLoD("X", framework::GradVarName("X"));
  }

  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
    return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
                                       ctx, framework::GradVarName("Out")),
                                   ctx.device_context());
  }
};

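// unsqueeze_grad needs the forward input X only to recover its shape; the
// grad op maker therefore forwards X alongside Out@GRAD.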
template <typename T>
class UnsqueezeGradOpMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

  void Apply(GradOpPtr<T> grad_op) const override {
    grad_op->SetType("unsqueeze_grad");
    grad_op->SetInput("X", this->Input("X"));
    grad_op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
    grad_op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
    grad_op->SetAttrMap(this->Attrs());
  }
};

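// The double grad of unsqueeze is unsqueeze itself: running the forward op
// on ddX with the same axes reproduces ddOut.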
template <typename T>
class UnsqueezeDoubleGradOpMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

  void Apply(GradOpPtr<T> grad_op) const override {
    grad_op->SetType("unsqueeze");
    grad_op->SetInput("X", this->OutputGrad(framework::GradVarName("X")));
    grad_op->SetOutput("Out", this->InputGrad(framework::GradVarName("Out")));
    grad_op->SetAttrMap(this->Attrs());
  }
};

// FIXME(zcd): unsqueeze2 adds an intermediate output (XShape) based on
// unsqueeze. XShape is used to carry the shape and lod of X, which will be
// used in unsqueeze_grad; in this way, the framework can reuse the memory of
// X immediately after unsqueeze2_op finishes.
// Considering compatibility issues, we could not fix unsqueeze2_op.
class Unsqueeze2Op : public UnsqueezeOp {
 public:
  using UnsqueezeOp::UnsqueezeOp;
};

class Unsqueeze2OpMaker : public UnsqueezeOpMaker {
 public:
  void Make() override {
    UnsqueezeOpMaker::Make();
    AddOutput("XShape",
              "XShape is just used to store the shape and lod of X, which will "
              "be used in UnsqueezeGradOp.")
        .AsIntermediate()
        .AsExtra();
  }
};

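// unsqueeze2_grad takes XShape rather than X, so the memory of X can be
// released (or reused) as soon as the forward op finishes.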
template <typename T>
class Unsqueeze2GradOpMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

  void Apply(GradOpPtr<T> grad_op) const override {
    grad_op->SetType("unsqueeze2_grad");
    grad_op->SetInput("XShape", this->Output("XShape"));
    grad_op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
    grad_op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
    grad_op->SetAttrMap(this->Attrs());
  }
};

class Unsqueeze2GradOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;
  void InferShape(framework::InferShapeContext *context) const override {
    PADDLE_ENFORCE_EQ(
        context->HasInput("XShape"), true,
        platform::errors::InvalidArgument("Input(XShape) shouldn't be null."));
    PADDLE_ENFORCE_EQ(context->HasInput(framework::GradVarName("Out")), true,
                      platform::errors::InvalidArgument(
                          "Input(Out@GRAD) shouldn't be null."));
    auto xshape_dims = context->GetInputDim("XShape");
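    // x_dims is recovered by dropping XShape's leading placeholder
    // dimension; the remaining entries hold the original shape of X.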
    auto x_dims = phi::slice_ddim(xshape_dims, 1, xshape_dims.size());
    context->SetOutputDim(framework::GradVarName("X"), x_dims);
    context->ShareLoD("XShape", framework::GradVarName("X"));
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
    return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
                                       ctx, framework::GradVarName("Out")),
                                   ctx.device_context());
  }
};

template <typename T>
class Unsqueeze2DoubleGradOpMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

  void Apply(GradOpPtr<T> grad_op) const override {
    grad_op->SetType("unsqueeze2");
    grad_op->SetInput("X", this->OutputGrad(framework::GradVarName("X")));
    grad_op->SetOutput("Out", this->InputGrad(framework::GradVarName("Out")));
    grad_op->SetOutput("XShape", this->Input("XShape"));
    grad_op->SetAttrMap(this->Attrs());
  }
};

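// Unsqueeze only rewrites shape metadata, so the forward and backward
// kernels may operate in place on the same buffer.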
DECLARE_INPLACE_OP_INFERER(UnsqueezeInplaceInferer, {"X", "Out"});
DECLARE_INPLACE_OP_INFERER(UnsqueezeGradInplaceInferer,
                           {framework::GradVarName("Out"),
                            framework::GradVarName("X")});
DECLARE_NO_NEED_BUFFER_VARS_INFERER(UnsqueezeGradOpNoNeedBufferVarInferer, "X");
}  // namespace operators
}  // namespace paddle

DECLARE_INFER_SHAPE_FUNCTOR(unsqueeze2, Unsqueeze2InferShapeFunctor,
                            PD_INFER_META(phi::UnsqueezeInferMeta));

namespace ops = paddle::operators;
REGISTER_OPERATOR(unsqueeze, ops::UnsqueezeOp, ops::UnsqueezeOpMaker,
                  ops::UnsqueezeGradOpMaker<paddle::framework::OpDesc>,
                  ops::UnsqueezeGradOpMaker<paddle::imperative::OpBase>);

REGISTER_OPERATOR(unsqueeze_grad, ops::UnsqueezeGradOp,
                  ops::UnsqueezeDoubleGradOpMaker<paddle::framework::OpDesc>,
                  ops::UnsqueezeDoubleGradOpMaker<paddle::imperative::OpBase>,
                  ops::UnsqueezeGradOpNoNeedBufferVarInferer);

REGISTER_OPERATOR(unsqueeze2, ops::Unsqueeze2Op, ops::Unsqueeze2OpMaker,
                  ops::Unsqueeze2GradOpMaker<paddle::framework::OpDesc>,
                  ops::Unsqueeze2GradOpMaker<paddle::imperative::OpBase>,
                  Unsqueeze2InferShapeFunctor, ops::UnsqueezeInplaceInferer);

REGISTER_OPERATOR(unsqueeze2_grad, ops::Unsqueeze2GradOp,
                  ops::Unsqueeze2DoubleGradOpMaker<paddle::framework::OpDesc>,
                  ops::Unsqueeze2DoubleGradOpMaker<paddle::imperative::OpBase>,
                  ops::UnsqueezeGradInplaceInferer);

REGISTER_OP_CPU_KERNEL(
    unsqueeze, ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, float>,
    ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, double>,
    ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, bool>,
    ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, int>,
    ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, int16_t>,
    ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, uint8_t>,
    ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, int8_t>,
    ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, int64_t>,
    ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext,
                         paddle::platform::complex<float>>,
    ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext,
                         paddle::platform::complex<double>>,
    ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext,
                         paddle::platform::bfloat16>);
REGISTER_OP_CPU_KERNEL(
    unsqueeze_grad,
    ops::UnsqueezeGradKernel<paddle::platform::CPUDeviceContext, float>,
    ops::UnsqueezeGradKernel<paddle::platform::CPUDeviceContext, double>,
    ops::UnsqueezeGradKernel<paddle::platform::CPUDeviceContext, bool>,
    ops::UnsqueezeGradKernel<paddle::platform::CPUDeviceContext, int>,
    ops::UnsqueezeGradKernel<paddle::platform::CPUDeviceContext, int16_t>,
    ops::UnsqueezeGradKernel<paddle::platform::CPUDeviceContext, uint8_t>,
    ops::UnsqueezeGradKernel<paddle::platform::CPUDeviceContext, int8_t>,
    ops::UnsqueezeGradKernel<paddle::platform::CPUDeviceContext, int64_t>,
    ops::UnsqueezeGradKernel<paddle::platform::CPUDeviceContext,
                             paddle::platform::complex<float>>,
    ops::UnsqueezeGradKernel<paddle::platform::CPUDeviceContext,
                             paddle::platform::complex<double>>,
    ops::UnsqueezeGradKernel<paddle::platform::CPUDeviceContext,
                             paddle::platform::bfloat16>);