/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/squeeze_op.h"

#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/infermeta/unary.h"

namespace paddle {
namespace operators {

// Computes the result shape of squeezing `in_dims` along `squeeze_dims`.
//
// squeeze_dims: axes to remove; negative values count from the end (numpy
//               style). An empty list means "remove every dimension of size 1".
// in_dims:      the input tensor's dimensions.
// is_runtime:   at run time only dims equal to 1 may be squeezed; at compile
//               time a dim of -1 (unknown) is also allowed, since it may
//               resolve to 1 later.
//
// Raises InvalidArgument (via PADDLE_ENFORCE) when an axis is outside
// [-rank, rank-1]. Axes that are valid but do not point at a size-1 dim are
// silently kept (no error here; see the op's DOC for the contract).
framework::DDim GetOutputShape(const std::vector<int> squeeze_dims,
                               const framework::DDim &in_dims,
                               bool is_runtime) {
  size_t num_squeeze_dims = squeeze_dims.size();
  std::vector<bool> should_squeeze(in_dims.size(), false);

  // Mark dimensions need to be squeezed.
  if (num_squeeze_dims == 0) {
    // No axes given: squeeze every dimension that is statically 1.
    for (int i = 0; i < in_dims.size(); ++i) {
      if (in_dims[i] == 1) {
        should_squeeze[i] = true;
      }
    }
  } else {
    for (size_t i = 0; i < num_squeeze_dims; ++i) {
      // Normalize negative axes to their positive equivalents.
      int current = squeeze_dims[i] < 0 ? squeeze_dims[i] + in_dims.size()
                                        : squeeze_dims[i];

      PADDLE_ENFORCE_GE(
          current,
          0,
          platform::errors::InvalidArgument(
              "Each axis in Attr(axes) should be in the range of [%d, %d]. "
              "But current axis is:%d, input tensor's shape = [%s].",
              -in_dims.size(),
              in_dims.size() - 1,
              current,
              in_dims));
      PADDLE_ENFORCE_LT(
          current,
          in_dims.size(),
          platform::errors::InvalidArgument(
              "Each axis in Attr(axes) should be in the range of [%d, %d]. "
              "But current axis is:%d, input tensor's shape = [%s].",
              -in_dims.size(),
              in_dims.size() - 1,
              current,
              in_dims));

      if (!should_squeeze[current]) {
        if (is_runtime) {
          // At run time, dim of 1 is allowed to squeeze
          if (in_dims[current] == 1) {
            should_squeeze[current] = true;
          }
        } else {
          // At compile time, dim of -1 or 1 is allowed to squeeze
          if (in_dims[current] == 1 || in_dims[current] == -1) {
            should_squeeze[current] = true;
          }
        }
      }
    }
  }
  // Make output dimensions: keep every dim that was not marked.
  std::vector<int64_t> output_shape;
  for (int i = 0; i < in_dims.size(); ++i) {
    if (!should_squeeze[i]) {
      output_shape.push_back(in_dims[i]);
    }
  }
  return phi::make_ddim(output_shape);
}

class SqueezeOp : public framework::OperatorWithKernel {
94
 public:
95 96 97
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext *ctx) const override {
98 99
    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "Squeeze");
    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "Squeeze");
100

Y
yuyang18 已提交
101
    const auto &x_dims = ctx->GetInputDim("X");
102
    // Check input tensor dims (<6) Eigen limit.
103 104
    PADDLE_ENFORCE_LE(x_dims.size(),
                      6,
105 106 107 108
                      platform::errors::InvalidArgument(
                          "The dimensions of Input(X) "
                          "should be in the range of [1, 6] (Eigen limit)."
                          "But received X's dimensions = %d, X's shape=[%s].",
109 110
                          x_dims.size(),
                          x_dims));
111

Y
yuyang18 已提交
112
    const auto &axes = ctx->Attrs().Get<std::vector<int>>("axes");
L
Leo Chen 已提交
113
    auto out_dims = GetOutputShape(axes, x_dims, false);
114
    ctx->SetOutputDim("Out", out_dims);
115 116 117 118 119
    if (x_dims[0] == out_dims[0]) {
      // Only pass LoD when the first dimension of output and Input(X)
      // are the same.
      ctx->ShareLoD("X", "Out");
    }
120 121
  }

122 123 124
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
125 126 127
    auto input_data_type =
        framework::OperatorWithKernel::IndicateVarDataType(ctx, "X");

128
    // #ifdef PADDLE_WITH_MKLDNN
129 130 131 132 133
    //    if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
    //      return framework::OpKernelType(input_data_type, ctx.GetPlace(),
    //                                     framework::DataLayout::kMKLDNN,
    //                                     framework::LibraryType::kMKLDNN);
    //    }
134
    // #endif
135
    return framework::OpKernelType(input_data_type, ctx.GetPlace());
136
  }
137 138
};

139
class SqueezeGradOp : public framework::OperatorWithKernel {
Y
yuyang18 已提交
140
 public:
141 142 143 144 145 146 147 148 149 150 151
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext *context) const override {
    context->SetOutputDim(framework::GradVarName("X"),
                          context->GetInputDim("X"));
    context->ShareLoD("X", framework::GradVarName("X"));
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
152 153 154
    auto input_data_type = framework::OperatorWithKernel::IndicateVarDataType(
        ctx, framework::GradVarName("Out"));

155
    // #ifdef PADDLE_WITH_MKLDNN
156 157 158 159 160
    //    if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
    //      return framework::OpKernelType(input_data_type, ctx.GetPlace(),
    //                                     framework::DataLayout::kMKLDNN,
    //                                     framework::LibraryType::kMKLDNN);
    //    }
161
    // #endif
162
    return framework::OpKernelType(input_data_type, ctx.GetPlace());
Y
yuyang18 已提交
163 164 165
  }
};

166 167 168
// Proto maker for `squeeze`: declares its inputs, outputs, attributes and
// user-facing documentation.
class SqueezeOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "(Tensor). The input tensor of squeeze operator.");
    AddOutput("Out", "(Tensor). The output tensor of squeeze operator.");
    // `axes` may also be fed as a tensor at run time (SupportTensor).
    AddAttr<std::vector<int>>("axes",
                              "(std::vector<int>). List of integers,"
                              " indicating the dimensions to squeeze.")
        .SetDefault({})
        .SupportTensor();
    // MKLDNN-related attributes are marked AsExtra (not part of the core op).
    AddAttr<bool>("use_mkldnn",
                  "(bool, default false) Only used in mkldnn kernel")
        .SetDefault(false)
        .AsExtra();
    AddAttr<std::string>(
        "mkldnn_data_type",
        "(string, default \"float32\"). Data type of mkldnn kernel")
        .SetDefault("float32")
        .InEnum({"float32", "bfloat16"})
        .AsExtra();
    AddComment(R"DOC(
        Squeeze Operator.

        Remove single-dimensional entries from the shape of a tensor.
        Takes a parameter axes with a list of axes to squeeze.
        If axes is not provided, all the single dimensions will be removed from the shape.
        If an axis is selected with shape entry not equal to one, an error is raised.

        Examples:
        Case 1:
          Given
            X.shape = (1, 3, 1, 5)
          and
            axes = [0]
          we get:
            Out.shape = (3, 1, 5)

        Case 2:
          Given
            X.shape = (1, 3, 1, 5)
          and
            axes = []
          we get:
            Out.shape = (3, 5)
    )DOC");
  }
};

class Squeeze2Op : public framework::OperatorWithKernel {
215
 public:
216
  using framework::OperatorWithKernel::OperatorWithKernel;
217 218 219 220 221
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
    auto input_data_type =
        framework::OperatorWithKernel::IndicateVarDataType(ctx, "X");

222
    // #ifdef PADDLE_WITH_MKLDNN
223 224 225 226 227
    //    if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
    //      return framework::OpKernelType(input_data_type, ctx.GetPlace(),
    //                                     framework::DataLayout::kMKLDNN,
    //                                     framework::LibraryType::kMKLDNN);
    //    }
228
    // #endif
229 230
    return framework::OpKernelType(input_data_type, ctx.GetPlace());
  }
Y
yuyang18 已提交
231
};
232

233 234 235 236 237
// Builds the `squeeze_grad` op for a `squeeze` forward op: it consumes
// X and dOut and produces dX, forwarding all forward attributes.
template <typename T>
class SqueezeGradOpMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

  void Apply(GradOpPtr<T> op) const override {
    op->SetType("squeeze_grad");
    op->SetInput("X", this->Input("X"));
    op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
    op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
    op->SetAttrMap(this->Attrs());
  }
};

class Squeeze2GradOp : public framework::OperatorWithKernel {
Y
yuyang18 已提交
248
 public:
249 250 251
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext *context) const override {
252 253 254 255 256
    OP_INOUT_CHECK(
        context->HasInput("XShape"), "Input", "XShape", "Squeeze2Grad");
    OP_INOUT_CHECK(context->HasInput(framework::GradVarName("Out")),
                   "Input",
                   framework::GradVarName("Out"),
257
                   "Squeeze2Grad");
258
    auto xshape_dims = context->GetInputDim("XShape");
259
    auto x_dims = phi::slice_ddim(xshape_dims, 1, xshape_dims.size());
260 261 262 263 264 265 266
    context->SetOutputDim(framework::GradVarName("X"), x_dims);
    context->ShareLoD("XShape", framework::GradVarName("X"));
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
267 268 269
    auto input_data_type = framework::OperatorWithKernel::IndicateVarDataType(
        ctx, framework::GradVarName("Out"));

270
    // #ifdef PADDLE_WITH_MKLDNN
271 272 273 274 275
    //    if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
    //      return framework::OpKernelType(input_data_type, ctx.GetPlace(),
    //                                     framework::DataLayout::kMKLDNN,
    //                                     framework::LibraryType::kMKLDNN);
    //    }
276
    // #endif
277
    return framework::OpKernelType(input_data_type, ctx.GetPlace());
278 279 280
  }
};

281 282 283 284 285 286 287 288 289 290 291 292 293
// Double-grad maker for `squeeze`: the second-order gradient of squeeze is
// squeeze itself (ddOut = squeeze(ddX)), so it re-emits a `squeeze` op.
template <typename T>
class SqueezeDoubleGradOpMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

  void Apply(GradOpPtr<T> op) const override {
    op->SetType("squeeze");
    op->SetInput("X", this->OutputGrad(framework::GradVarName("X")));
    op->SetOutput("Out", this->InputGrad(framework::GradVarName("Out")));
    op->SetAttrMap(this->Attrs());
  }
};

// FIXME(zcd): squeeze2 adds an intermediate output(XShape) based on squeeze,
// the XShape is used to carry the shape and lod of X which will be used in
// squeeze_grad, in this way, the framework can reuse the memory of X
// immediately the squeeze2_op is finished.
// Considering compatibility issues, we could not fix squeeze2_op
299
class Squeeze2OpMaker : public framework::OpProtoAndCheckerMaker {
300 301
 public:
  void Make() override {
302 303
    AddInput("X", "(Tensor). The input tensor of squeeze operator.");
    AddOutput("Out", "(Tensor). The output tensor of squeeze operator.");
304 305 306
    AddOutput("XShape",
              "XShape is just used to store the shape and lod of X, which will "
              "be used in SqueezeGradOp.")
C
ceci3 已提交
307 308
        .AsIntermediate()
        .AsExtra();
309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338
    AddAttr<std::vector<int>>("axes",
                              "(std::vector<int>). List of integers,"
                              " indicating the dimensions to squeeze.")
        .SetDefault({})
        .SupportTensor();
    AddComment(R"DOC(
        Squeeze2 Operator.

        Remove single-dimensional entries from the shape of a tensor.
        Takes a parameter axes with a list of axes to squeeze.
        If axes is not provided, all the single dimensions will be removed from the shape.
        If an axis is selected with shape entry not equal to one, an error is raised.

        Examples:
        Case 1:
          Given
            X.shape = (1, 3, 1, 5)
          and
            axes = [0]
          we get:
            Out.shape = (3, 1, 5)

        Case 2:
          Given
            X.shape = (1, 3, 1, 5)
          and
            axes = []
          we get:
            Out.shape = (3, 5)
    )DOC");
339 340 341
  }
};

H
hong 已提交
342 343
// Builds the `squeeze2_grad` op for a `squeeze2` forward op: it consumes
// XShape and dOut (not X itself) and produces dX.
template <typename T>
class Squeeze2GradOpMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

  void Apply(GradOpPtr<T> op) const override {
    op->SetType("squeeze2_grad");
    op->SetInput("XShape", this->Output("XShape"));
    op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
    op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
    op->SetAttrMap(this->Attrs());
  }
};

// Double-grad maker for `squeeze2`: re-emits a `squeeze2` op (ddOut =
// squeeze2(ddX)), wiring XShape through so grad-of-grad shape recovery works.
template <typename T>
class Squeeze2DoubleGradOpMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

  void Apply(GradOpPtr<T> op) const override {
    op->SetType("squeeze2");
    op->SetInput("X", this->OutputGrad(framework::GradVarName("X")));
    op->SetOutput("Out", this->InputGrad(framework::GradVarName("Out")));
    op->SetOutput("XShape", this->Input("XShape"));
    op->SetAttrMap(this->Attrs());
  }
};

// squeeze may run in place: Out is allowed to reuse X's buffer.
DECLARE_INPLACE_OP_INFERER(SqueezeInplaceInferer, {"X", "Out"});
// Likewise, squeeze_grad may write dX over dOut's buffer.
DECLARE_INPLACE_OP_INFERER(SqueezeGradInplaceInferer,
                           {framework::GradVarName("Out"),
                            framework::GradVarName("X")});
// Declares that squeeze_grad does not need X's data buffer (shape only).
DECLARE_NO_NEED_BUFFER_VARS_INFERER(SqueezeGradNoNeedBufferVarsInferer, "X");
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

// squeeze2's shape inference is delegated to the PHI infermeta function.
DECLARE_INFER_SHAPE_FUNCTOR(squeeze2,
                            SqueezeInferShapeFunctor,
                            PD_INFER_META(phi::SqueezeWithXShapeInferMeta));

// Legacy `squeeze` / `squeeze_grad` registration.
REGISTER_OPERATOR(squeeze,
                  ops::SqueezeOp,
                  ops::SqueezeOpMaker,
                  ops::SqueezeGradOpMaker<paddle::framework::OpDesc>,
                  ops::SqueezeGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(squeeze_grad,
                  ops::SqueezeGradOp,
                  ops::SqueezeDoubleGradOpMaker<paddle::framework::OpDesc>,
                  ops::SqueezeDoubleGradOpMaker<paddle::imperative::OpBase>,
                  ops::SqueezeGradNoNeedBufferVarsInferer);

// `squeeze2` / `squeeze2_grad` (with XShape) registration.
REGISTER_OPERATOR(squeeze2,
                  ops::Squeeze2Op,
                  ops::Squeeze2OpMaker,
                  ops::Squeeze2GradOpMaker<paddle::framework::OpDesc>,
                  ops::Squeeze2GradOpMaker<paddle::imperative::OpBase>,
                  ops::SqueezeInplaceInferer,
                  SqueezeInferShapeFunctor);
REGISTER_OPERATOR(squeeze2_grad,
                  ops::Squeeze2GradOp,
                  ops::Squeeze2DoubleGradOpMaker<paddle::framework::OpDesc>,
                  ops::Squeeze2DoubleGradOpMaker<paddle::imperative::OpBase>,
                  ops::SqueezeGradInplaceInferer);

// CPU kernels for the legacy squeeze op, one instantiation per dtype.
REGISTER_OP_CPU_KERNEL(
    squeeze,
    ops::SqueezeKernel<phi::CPUContext, float>,
    ops::SqueezeKernel<phi::CPUContext, double>,
    ops::SqueezeKernel<phi::CPUContext, bool>,
    ops::SqueezeKernel<phi::CPUContext, int>,
    ops::SqueezeKernel<phi::CPUContext, uint8_t>,
    ops::SqueezeKernel<phi::CPUContext, int8_t>,
    ops::SqueezeKernel<phi::CPUContext, int64_t>,
    ops::SqueezeKernel<phi::CPUContext, paddle::platform::complex<float>>,
    ops::SqueezeKernel<phi::CPUContext, paddle::platform::complex<double>>,
    ops::SqueezeKernel<phi::CPUContext, paddle::platform::bfloat16>);
REGISTER_OP_CPU_KERNEL(
    squeeze_grad,
    ops::SqueezeGradKernel<phi::CPUContext, float>,
    ops::SqueezeGradKernel<phi::CPUContext, double>,
    ops::SqueezeGradKernel<phi::CPUContext, bool>,
    ops::SqueezeGradKernel<phi::CPUContext, int>,
    ops::SqueezeGradKernel<phi::CPUContext, uint8_t>,
    ops::SqueezeGradKernel<phi::CPUContext, int8_t>,
    ops::SqueezeGradKernel<phi::CPUContext, int64_t>,
    ops::SqueezeGradKernel<phi::CPUContext, paddle::platform::complex<float>>,
    ops::SqueezeGradKernel<phi::CPUContext, paddle::platform::complex<double>>,
    ops::SqueezeGradKernel<phi::CPUContext, paddle::platform::bfloat16>);