/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/unsqueeze_op.h"

#include <cstdint>
#include <memory>
#include <string>
#include <vector>

#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/infermeta/unary.h"

namespace paddle {
namespace operators {

28
class UnsqueezeOp : public framework::OperatorWithKernel {
29
 public:
30 31 32
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext *ctx) const override {
33 34
    PADDLE_ENFORCE_EQ(ctx->HasInput("X"),
                      true,
35 36 37
                      platform::errors::InvalidArgument(
                          "Input(X) of "
                          "Unsqueeze operator should not be null."));
38 39
    PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"),
                      true,
40 41 42
                      platform::errors::InvalidArgument(
                          "Output(Out) of "
                          "Unsqueeze operator should not be null."));
43

44 45
    const auto &axes = ctx->Attrs().Get<std::vector<int>>("axes");
    const auto &x_dims = ctx->GetInputDim("X");
46
    // Validity Check: input tensor dims (<6).
47 48
    PADDLE_ENFORCE_LE(x_dims.size(),
                      6,
49 50 51 52
                      platform::errors::InvalidArgument(
                          "Invalid "
                          "dimensions, the rank of Input(X) "
                          "should be in the range of [1, 6] (Eigen limit)"));
53 54 55 56 57 58 59 60 61 62 63
    if (!axes.empty()) {
      auto out_dims = GetOutputShape(axes, x_dims);
      ctx->SetOutputDim("Out", out_dims);
      if (x_dims[0] == out_dims[0]) {
        // Only pass LoD when the first dimension of output and Input(X)
        // are the same.
        ctx->ShareLoD("X", "Out");
      }
    } else if (ctx->HasInputs("AxesTensorList")) {
      auto AxesTensorList = ctx->Inputs("AxesTensorList");
      int output_size = x_dims.size() + static_cast<int>(AxesTensorList.size());
64 65
      PADDLE_ENFORCE_LE(output_size,
                        6,
66 67
                        platform::errors::InvalidArgument(
                            "The output tensor's rank should be less than 6."));
68
      std::vector<int> vec_out_dims(output_size, -1);
69
      ctx->SetOutputDim("Out", phi::make_ddim(vec_out_dims));
70 71
    } else if (ctx->HasInput("AxesTensor")) {
      auto axes_dims = ctx->GetInputDim("AxesTensor");
72 73
      PADDLE_ENFORCE_EQ(axes_dims.size(),
                        1,
74 75 76 77 78
                        platform::errors::InvalidArgument(
                            "Input(AxesTensor)'s dimension of "
                            "Op(unsqueeze) must be 1. "
                            "But received AxesTensor's shape = [%s], "
                            "AxesTensor's dimension = %d.",
79 80
                            axes_dims,
                            axes_dims.size()));
81
      PADDLE_ENFORCE_GE(
82 83
          axes_dims[0],
          0,
84 85 86 87
          platform::errors::InvalidArgument(
              "Input(AxesTensor)'s shape must be known. But received "
              "AxesTensor's shape = [%s]",
              axes_dims));
88
      int output_size = x_dims.size() + static_cast<int>(axes_dims[0]);
89 90
      PADDLE_ENFORCE_LE(output_size,
                        6,
91 92
                        platform::errors::InvalidArgument(
                            "The output tensor's rank should be less than 6."));
93
      std::vector<int> vec_out_dims(output_size, -1);
94
      ctx->SetOutputDim("Out", phi::make_ddim(vec_out_dims));
95
    }
96 97
  }

98
  static framework::DDim GetOutputShape(const std::vector<int> unsqz_dims,
99
                                        const framework::DDim &in_dims) {
100 101
    int output_size = in_dims.size() + static_cast<int>(unsqz_dims.size());
    int cur_output_size = in_dims.size();
102 103 104
    std::vector<int64_t> output_shape(output_size, 0);

    // Validity Check: rank range.
105 106
    PADDLE_ENFORCE_LE(output_size,
                      6,
107 108
                      platform::errors::InvalidArgument(
                          "The output tensor's rank should be less than 6."));
109 110

    for (int axis : unsqz_dims) {
111
      int cur = axis < 0 ? axis + cur_output_size + 1 : axis;
112
      // Vaildity Check: the axis bound
113
      PADDLE_ENFORCE_GE(
114 115
          cur,
          0,
116 117
          platform::errors::InvalidArgument("The insert dimension value should "
                                            "not be less than 0"));
118 119
      PADDLE_ENFORCE_LE(cur,
                        cur_output_size,
120 121 122
                        platform::errors::InvalidArgument(
                            "The insert dimension value shoud not be larger "
                            "than the dimension size of input tensor"));
123 124 125 126 127 128 129 130 131
      // Move old axis, and insert new axis
      for (int i = cur_output_size; i >= cur; --i) {
        if (output_shape[i] == 1) {
          // Move axis
          output_shape[i + 1] = 1;
          output_shape[i] = 0;
        }
      }
      output_shape[cur] = 1;
132
      // Add the output size.
133
      cur_output_size++;
134 135
    }

136
    // Make output shape
137 138
    for (int in_idx = 0, out_idx = 0; out_idx < output_size; ++out_idx) {
      if (output_shape[out_idx] == 0) {
139 140 141 142
        output_shape[out_idx] = in_dims[in_idx++];
      }
    }

143
    return phi::make_ddim(output_shape);
144
  }
145 146 147 148

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
149 150 151 152
    return framework::OpKernelType(
        framework::TransToProtoVarType(
            ctx.Input<framework::LoDTensor>("X")->type()),
        ctx.device_context());
153 154 155
  }

  framework::OpKernelType GetKernelTypeForVar(
156 157
      const std::string &var_name,
      const framework::Tensor &tensor,
158 159 160 161
      const framework::OpKernelType &expected_kernel_type) const override {
    if (var_name == "AxesTensor" || var_name == "AxesTensorList") {
      return expected_kernel_type;
    }
162 163
    return framework::OpKernelType(
        expected_kernel_type.data_type_, tensor.place(), tensor.layout());
164
  }
165 166 167 168 169 170
};

// Proto maker for `unsqueeze` (v1). It declares:
//   - Input(X);
//   - the optional inputs AxesTensor and AxesTensorList, which override
//     Attr(axes) when present;
//   - Output(Out);
//   - Attr(axes), whose checker bounds its length and each axis value.
class UnsqueezeOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "(Tensor). The input tensor of unsqueeze operator.");
    AddInput("AxesTensor",
             "(Tensor<int32>, optional). The dimensions to be inserted. "
             "If it exists, it will replace Attr(axes).")
        .AsDispensable();
    AddInput(
        "AxesTensorList",
        "(vector<Tensor<int32>>, optional). The dimensions to be inserted. "
        "If it exists, it will replace Attr(axes). "
        "The shape of the element in vector must be [1].")
        .AsDuplicable()
        .AsDispensable();
    AddOutput("Out", "(Tensor). The output tensor of unsqueeze operator.");
    AddAttr<std::vector<int>>("axes",
                              "(std::vector<int>). List of integers,"
                              " indicating the dimensions to be inserted")
        .SetDefault({})
        .AddCustomChecker([](const std::vector<int> &axes) {
          // Validity Check: axes dims (<6).
          PADDLE_ENFORCE_LT(static_cast<int>(axes.size()),
                            6,
                            platform::errors::InvalidArgument(
                                "Invalid "
                                "dimensions, dynamic dimensions should be "
                                "within [1, 6] dimensions (Eigen limit)."));
          // Validity Check: the range of unsqueeze axis. Only the upper bound
          // is checked here; negative axes are wrapped and bounds-checked in
          // UnsqueezeOp::GetOutputShape.
          for (int axis : axes) {
            PADDLE_ENFORCE_LT(axis,
                              6,
                              platform::errors::InvalidArgument(
                                  "Invalid "
                                  "dimensions, input axis should be "
                                  "within [1, 6] dimensions (Eigen limit)."));
          }
        });
    AddComment(R"DOC(
    Unsqueeze Operator.

    Insert single-dimensional entries to the shape of a tensor.
    Takes one required argument axes, a list of dimensions that will be inserted.
    Dimension indices in axes are as seen in the output tensor.

    For example:
      Given a tensor such that tensor with shape [3, 4, 5],
      then Unsqueeze(tensor, axes=[0, 4]) has shape [1, 3, 4, 5, 1]
    )DOC");
  }
};

219
class UnsqueezeGradOp : public framework::OperatorWithKernel {
220
 public:
221 222 223
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext *ctx) const override {
224
    ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
225
    ctx->ShareLoD("X", framework::GradVarName("X"));
226
  }
227 228 229 230 231 232 233 234 235 236 237 238 239 240

  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
    return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
                                       ctx, framework::GradVarName("Out")),
                                   ctx.device_context());
  }
};

// Builds the `unsqueeze_grad` op for `unsqueeze` (v1).
//
// The grad op receives:
//   - inputs: forward Input(X) and the incoming Out@GRAD;
//   - output: X@GRAD;
//   - attributes: a copy of the forward op's attributes.
template <typename T>
class UnsqueezeGradOpMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

  void Apply(GradOpPtr<T> grad_op) const override {
    const auto out_grad_name = framework::GradVarName("Out");
    const auto x_grad_name = framework::GradVarName("X");

    grad_op->SetType("unsqueeze_grad");
    grad_op->SetInput("X", this->Input("X"));
    grad_op->SetInput(out_grad_name, this->OutputGrad("Out"));
    grad_op->SetOutput(x_grad_name, this->InputGrad("X"));
    grad_op->SetAttrMap(this->Attrs());
  }
};

250 251 252 253 254 255 256 257 258 259 260 261 262
template <typename T>
class UnsqueezeDoubleGradOpMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

  void Apply(GradOpPtr<T> grad_op) const override {
    grad_op->SetType("unsqueeze");
    grad_op->SetInput("X", this->OutputGrad(framework::GradVarName("X")));
    grad_op->SetOutput("Out", this->InputGrad(framework::GradVarName("Out")));
    grad_op->SetAttrMap(this->Attrs());
  }
};

263 264 265 266 267
// FIXME(zcd): unsqueeze2 adds an intermediate output(XShape) based on
// unsqueeze, the XShape is used to carry the shape and lod of X which
// will be used in unsqueeze_grad, in this way, the framework can reuse
// the memory of X immediately the unsqueeze2_op is finished.
// Considering compatibility issues, we could not fix unsqueeze2_op
268
class Unsqueeze2Op : public UnsqueezeOp {
269
 public:
270
  using UnsqueezeOp::UnsqueezeOp;
271 272 273 274 275 276 277 278 279
};

// Proto maker for `unsqueeze2`.
//
// It declares everything from UnsqueezeOpMaker plus the intermediate, extra
// Output(XShape). XShape is consumed by `unsqueeze2_grad` (see
// Unsqueeze2GradOpMaker), not by `unsqueeze_grad`.
class Unsqueeze2OpMaker : public UnsqueezeOpMaker {
 public:
  void Make() override {
    UnsqueezeOpMaker::Make();
    AddOutput("XShape",
              "XShape is just used to store the shape and lod of X, which will "
              "be used in Unsqueeze2GradOp.")
        .AsIntermediate()
        .AsExtra();
  }
};

H
hong 已提交
285 286
template <typename T>
class Unsqueeze2GradOpMaker : public framework::SingleGradOpMaker<T> {
287
 public:
H
hong 已提交
288
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
289

290
  void Apply(GradOpPtr<T> grad_op) const override {
291
    grad_op->SetType("unsqueeze2_grad");
H
hong 已提交
292 293 294 295
    grad_op->SetInput("XShape", this->Output("XShape"));
    grad_op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
    grad_op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
    grad_op->SetAttrMap(this->Attrs());
296 297 298
  }
};

299
class Unsqueeze2GradOp : public framework::OperatorWithKernel {
300
 public:
301 302
  using framework::OperatorWithKernel::OperatorWithKernel;
  void InferShape(framework::InferShapeContext *context) const override {
303
    PADDLE_ENFORCE_EQ(
304 305
        context->HasInput("XShape"),
        true,
306
        platform::errors::InvalidArgument("Input(XShape) shouldn't be null."));
307 308
    PADDLE_ENFORCE_EQ(context->HasInput(framework::GradVarName("Out")),
                      true,
309 310
                      platform::errors::InvalidArgument(
                          "Input(Out@GRAD) shouldn't be null."));
311
    auto xshape_dims = context->GetInputDim("XShape");
312
    auto x_dims = phi::slice_ddim(xshape_dims, 1, xshape_dims.size());
313 314 315 316
    context->SetOutputDim(framework::GradVarName("X"), x_dims);
    context->ShareLoD("XShape", framework::GradVarName("X"));
  }

317 318 319
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
320 321 322
    return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
                                       ctx, framework::GradVarName("Out")),
                                   ctx.device_context());
323 324
  }
};
325

326 327 328 329 330 331 332 333 334 335 336 337 338 339
template <typename T>
class Unsqueeze2DoubleGradOpMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

  void Apply(GradOpPtr<T> grad_op) const override {
    grad_op->SetType("unsqueeze2");
    grad_op->SetInput("X", this->OutputGrad(framework::GradVarName("X")));
    grad_op->SetOutput("Out", this->InputGrad(framework::GradVarName("Out")));
    grad_op->SetOutput("XShape", this->Input("XShape"));
    grad_op->SetAttrMap(this->Attrs());
  }
};

// Input(X) of `unsqueeze_grad` is declared no-need-buffer. Its InferShape
// reads only the dims and LoD of X.
// NOTE(review): presumably the kernel in unsqueeze_op.h does not read X's
// data either; confirm.
DECLARE_NO_NEED_BUFFER_VARS_INFERER(UnsqueezeGradOpNoNeedBufferVarInferer, "X");

// In-place candidate pair {X -> Out}; registered for `unsqueeze2`.
DECLARE_INPLACE_OP_INFERER(UnsqueezeInplaceInferer, {"X", "Out"});

// In-place candidate pair {Out@GRAD -> X@GRAD}; registered for
// `unsqueeze2_grad`.
DECLARE_INPLACE_OP_INFERER(UnsqueezeGradInplaceInferer,
                           {framework::GradVarName("Out"),
                            framework::GradVarName("X")});
}  // namespace operators
}  // namespace paddle

348 349
DECLARE_INFER_SHAPE_FUNCTOR(unsqueeze2,
                            Unsqueeze2InferShapeFunctor,
350 351
                            PD_INFER_META(phi::UnsqueezeInferMeta));

352
namespace ops = paddle::operators;
353 354 355
REGISTER_OPERATOR(unsqueeze,
                  ops::UnsqueezeOp,
                  ops::UnsqueezeOpMaker,
356 357
                  ops::UnsqueezeGradOpMaker<paddle::framework::OpDesc>,
                  ops::UnsqueezeGradOpMaker<paddle::imperative::OpBase>);
358

359 360
REGISTER_OPERATOR(unsqueeze_grad,
                  ops::UnsqueezeGradOp,
361 362
                  ops::UnsqueezeDoubleGradOpMaker<paddle::framework::OpDesc>,
                  ops::UnsqueezeDoubleGradOpMaker<paddle::imperative::OpBase>,
363
                  ops::UnsqueezeGradOpNoNeedBufferVarInferer);
364

365 366 367
REGISTER_OPERATOR(unsqueeze2,
                  ops::Unsqueeze2Op,
                  ops::Unsqueeze2OpMaker,
H
hong 已提交
368 369
                  ops::Unsqueeze2GradOpMaker<paddle::framework::OpDesc>,
                  ops::Unsqueeze2GradOpMaker<paddle::imperative::OpBase>,
370 371
                  Unsqueeze2InferShapeFunctor,
                  ops::UnsqueezeInplaceInferer);
372

373 374
REGISTER_OPERATOR(unsqueeze2_grad,
                  ops::Unsqueeze2GradOp,
375 376
                  ops::Unsqueeze2DoubleGradOpMaker<paddle::framework::OpDesc>,
                  ops::Unsqueeze2DoubleGradOpMaker<paddle::imperative::OpBase>,
377
                  ops::UnsqueezeGradInplaceInferer);
378 379

REGISTER_OP_CPU_KERNEL(
380 381
    unsqueeze,
    ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, float>,
382
    ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, double>,
383
    ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, bool>,
384
    ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, int>,
385
    ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, int16_t>,
386
    ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, uint8_t>,
387
    ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, int8_t>,
388 389 390 391
    ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, int64_t>,
    ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext,
                         paddle::platform::complex<float>>,
    ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext,
392 393 394
                         paddle::platform::complex<double>>,
    ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext,
                         paddle::platform::bfloat16>);
395 396 397 398
REGISTER_OP_CPU_KERNEL(
    unsqueeze_grad,
    ops::UnsqueezeGradKernel<paddle::platform::CPUDeviceContext, float>,
    ops::UnsqueezeGradKernel<paddle::platform::CPUDeviceContext, double>,
399
    ops::UnsqueezeGradKernel<paddle::platform::CPUDeviceContext, bool>,
400
    ops::UnsqueezeGradKernel<paddle::platform::CPUDeviceContext, int>,
401
    ops::UnsqueezeGradKernel<paddle::platform::CPUDeviceContext, int16_t>,
402
    ops::UnsqueezeGradKernel<paddle::platform::CPUDeviceContext, uint8_t>,
403
    ops::UnsqueezeGradKernel<paddle::platform::CPUDeviceContext, int8_t>,
404 405 406 407
    ops::UnsqueezeGradKernel<paddle::platform::CPUDeviceContext, int64_t>,
    ops::UnsqueezeGradKernel<paddle::platform::CPUDeviceContext,
                             paddle::platform::complex<float>>,
    ops::UnsqueezeGradKernel<paddle::platform::CPUDeviceContext,
408 409 410
                             paddle::platform::complex<double>>,
    ops::UnsqueezeGradKernel<paddle::platform::CPUDeviceContext,
                             paddle::platform::bfloat16>);