/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/squeeze_op.h"

#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/infermeta/unary.h"

namespace paddle {
namespace operators {

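// Computes the output shape of squeeze. If squeeze_dims is empty, every
// dimension of size 1 is dropped; otherwise only the listed axes are dropped
// (negative axes count from the end). At compile time (is_runtime == false),
// a dimension of -1 (statically unknown) may also be squeezed.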
framework::DDim GetOutputShape(const std::vector<int> squeeze_dims,
                               const framework::DDim &in_dims,
                               bool is_runtime) {
  size_t num_squeeze_dims = squeeze_dims.size();
  std::vector<bool> should_squeeze(in_dims.size(), false);

  // Mark the dimensions that need to be squeezed.
  if (num_squeeze_dims == 0) {
    for (int i = 0; i < in_dims.size(); ++i) {
      if (in_dims[i] == 1) {
        should_squeeze[i] = true;
      }
    }
  } else {
    for (size_t i = 0; i < num_squeeze_dims; ++i) {
      int current = squeeze_dims[i] < 0 ? squeeze_dims[i] + in_dims.size()
                                        : squeeze_dims[i];

      PADDLE_ENFORCE_GE(
          current,
          0,
          platform::errors::InvalidArgument(
              "Each axis in Attr(axes) should be in the range of [%d, %d]. "
              "But current axis is: %d, input tensor's shape = [%s].",
              -in_dims.size(),
              in_dims.size() - 1,
              current,
              in_dims));
      PADDLE_ENFORCE_LT(
          current,
          in_dims.size(),
          platform::errors::InvalidArgument(
              "Each axis in Attr(axes) should be in the range of [%d, %d]. "
              "But current axis is: %d, input tensor's shape = [%s].",
              -in_dims.size(),
              in_dims.size() - 1,
              current,
              in_dims));

      if (!should_squeeze[current]) {
        if (is_runtime) {
          // At run time, only a dim of 1 may be squeezed.
          if (in_dims[current] == 1) {
            should_squeeze[current] = true;
          }
        } else {
          // At compile time, a dim of -1 (unknown) or 1 may be squeezed.
          if (in_dims[current] == 1 || in_dims[current] == -1) {
            should_squeeze[current] = true;
          }
        }
      }
    }
  }
  // Build the output shape from the dimensions that were not squeezed.
  std::vector<int64_t> output_shape;
  for (int i = 0; i < in_dims.size(); ++i) {
    if (!should_squeeze[i]) {
      output_shape.push_back(in_dims[i]);
    }
  }
  return phi::make_ddim(output_shape);
}

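// squeeze: validates X/Out, checks the rank limit, infers the output shape
// via GetOutputShape, and dispatches the kernel on X's data type.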
class SqueezeOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext *ctx) const override {
    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "Squeeze");
    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "Squeeze");

    const auto &x_dims = ctx->GetInputDim("X");
    // Check that the rank of Input(X) is at most 6 (Eigen limit).
    PADDLE_ENFORCE_LE(x_dims.size(),
                      6,
                      platform::errors::InvalidArgument(
                          "The dimensions of Input(X) "
                          "should be in the range of [1, 6] (Eigen limit). "
                          "But received X's dimensions = %d, X's shape = [%s].",
                          x_dims.size(),
                          x_dims));

    const auto &axes = ctx->Attrs().Get<std::vector<int>>("axes");
    auto out_dims = GetOutputShape(axes, x_dims, false);
    ctx->SetOutputDim("Out", out_dims);
    if (x_dims[0] == out_dims[0]) {
      // Only pass LoD when the first dimension of output and Input(X)
      // are the same.
      ctx->ShareLoD("X", "Out");
    }
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
    auto input_data_type =
        framework::OperatorWithKernel::IndicateVarDataType(ctx, "X");

    // #ifdef PADDLE_WITH_MKLDNN
    //    if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
    //      return framework::OpKernelType(input_data_type, ctx.GetPlace(),
    //                                     framework::DataLayout::kMKLDNN,
    //                                     framework::LibraryType::kMKLDNN);
    //    }
    // #endif
    return framework::OpKernelType(input_data_type, ctx.GetPlace());
  }
};

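// Gradient of squeeze: dX simply takes back the shape and lod of X.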
class SqueezeGradOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext *context) const override {
    context->SetOutputDim(framework::GradVarName("X"),
                          context->GetInputDim("X"));
    context->ShareLoD("X", framework::GradVarName("X"));
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
    auto input_data_type = framework::OperatorWithKernel::IndicateVarDataType(
        ctx, framework::GradVarName("Out"));

    // #ifdef PADDLE_WITH_MKLDNN
    //    if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
    //      return framework::OpKernelType(input_data_type, ctx.GetPlace(),
    //                                     framework::DataLayout::kMKLDNN,
    //                                     framework::LibraryType::kMKLDNN);
    //    }
    // #endif
    return framework::OpKernelType(input_data_type, ctx.GetPlace());
  }
};

class SqueezeOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "(Tensor). The input tensor of squeeze operator.");
    AddOutput("Out", "(Tensor). The output tensor of squeeze operator.");
    AddAttr<std::vector<int>>("axes",
                              "(std::vector<int>). List of integers,"
                              " indicating the dimensions to squeeze.")
        .SetDefault({});
    AddAttr<bool>("use_mkldnn",
                  "(bool, default false) Only used in mkldnn kernel")
        .SetDefault(false)
        .AsExtra();
    AddAttr<std::string>(
        "mkldnn_data_type",
        "(string, default \"float32\"). Data type of mkldnn kernel")
        .SetDefault("float32")
        .InEnum({"float32", "bfloat16"})
        .AsExtra();
    AddComment(R"DOC(
        Squeeze Operator.

        Remove single-dimensional entries from the shape of a tensor.
        Takes a parameter axes with a list of axes to squeeze.
        If axes is not provided, all dimensions of size 1 will be removed from the shape.
        If an axis is selected with shape entry not equal to one, an error is raised.

        Examples:
        Case 1:
          Given
            X.shape = (1, 3, 1, 5)
          and
            axes = [0]
          we get:
            Out.shape = (3, 1, 5)

        Case 2:
          Given
            X.shape = (1, 3, 1, 5)
          and
            axes = []
          we get:
            Out.shape = (3, 5)
    )DOC");
  }
};

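// squeeze2 differs from squeeze by emitting an extra XShape output; its
// shape inference is supplied by SqueezeInferShapeFunctor (registered
// below), so only the kernel type is chosen here.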
class Squeeze2Op : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
    auto input_data_type =
        framework::OperatorWithKernel::IndicateVarDataType(ctx, "X");

    // #ifdef PADDLE_WITH_MKLDNN
    //    if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
    //      return framework::OpKernelType(input_data_type, ctx.GetPlace(),
    //                                     framework::DataLayout::kMKLDNN,
    //                                     framework::LibraryType::kMKLDNN);
    //    }
    // #endif
    return framework::OpKernelType(input_data_type, ctx.GetPlace());
  }
};

template <typename T>
class SqueezeGradOpMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

  void Apply(GradOpPtr<T> grad_op) const override {
    grad_op->SetType("squeeze_grad");
    grad_op->SetInput("X", this->Input("X"));
    grad_op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
    grad_op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
    grad_op->SetAttrMap(this->Attrs());
  }
};

class Squeeze2GradOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext *context) const override {
    OP_INOUT_CHECK(
        context->HasInput("XShape"), "Input", "XShape", "Squeeze2Grad");
    OP_INOUT_CHECK(context->HasInput(framework::GradVarName("Out")),
                   "Input",
                   framework::GradVarName("Out"),
                   "Squeeze2Grad");
    auto xshape_dims = context->GetInputDim("XShape");
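    // XShape stores X's shape behind a leading dummy dimension; drop it to
    // recover x_dims.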
    auto x_dims = phi::slice_ddim(xshape_dims, 1, xshape_dims.size());
    context->SetOutputDim(framework::GradVarName("X"), x_dims);
    context->ShareLoD("XShape", framework::GradVarName("X"));
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
    auto input_data_type = framework::OperatorWithKernel::IndicateVarDataType(
        ctx, framework::GradVarName("Out"));

    // #ifdef PADDLE_WITH_MKLDNN
    //    if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
    //      return framework::OpKernelType(input_data_type, ctx.GetPlace(),
    //                                     framework::DataLayout::kMKLDNN,
    //                                     framework::LibraryType::kMKLDNN);
    //    }
    // #endif
    return framework::OpKernelType(input_data_type, ctx.GetPlace());
  }
};

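// The double grad of squeeze is squeeze itself: applying squeeze with the
// same axes to ddX reproduces ddOut.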
template <typename T>
class SqueezeDoubleGradOpMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

  void Apply(GradOpPtr<T> grad_op) const override {
    grad_op->SetType("squeeze");
    grad_op->SetInput("X", this->OutputGrad(framework::GradVarName("X")));
    grad_op->SetOutput("Out", this->InputGrad(framework::GradVarName("Out")));
    grad_op->SetAttrMap(this->Attrs());
  }
};

// FIXME(zcd): squeeze2 adds an intermediate output (XShape) on top of
// squeeze. XShape carries the shape and lod of X for use in squeeze2_grad,
// so the framework can reuse the memory of X as soon as squeeze2_op
// finishes. For compatibility reasons, squeeze2_op itself cannot be fixed.
class Squeeze2OpMaker : public SqueezeOpMaker {
 public:
  void Make() override {
    SqueezeOpMaker::Make();
    AddOutput("XShape",
              "XShape is just used to store the shape and lod of X, which will "
              "be used in SqueezeGradOp.")
        .AsIntermediate()
        .AsExtra();
  }
};

template <typename T>
class Squeeze2GradOpMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

  void Apply(GradOpPtr<T> grad_op) const override {
    grad_op->SetType("squeeze2_grad");
    grad_op->SetInput("XShape", this->Output("XShape"));
    grad_op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
    grad_op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
    grad_op->SetAttrMap(this->Attrs());
  }
};

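// Likewise, the double grad of squeeze2 re-runs squeeze2 on ddX, reusing
// the forward op's XShape output.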
template <typename T>
class Squeeze2DoubleGradOpMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

  void Apply(GradOpPtr<T> grad_op) const override {
    grad_op->SetType("squeeze2");
    grad_op->SetInput("X", this->OutputGrad(framework::GradVarName("X")));
    grad_op->SetOutput("Out", this->InputGrad(framework::GradVarName("Out")));
    grad_op->SetOutput("XShape", this->Input("XShape"));
    grad_op->SetAttrMap(this->Attrs());
  }
};

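// squeeze may run in place: Out can alias X's buffer (and dX can alias
// dOut's buffer in the grad op) because only the shape metadata changes.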
DECLARE_INPLACE_OP_INFERER(SqueezeInplaceInferer, {"X", "Out"});
DECLARE_INPLACE_OP_INFERER(SqueezeGradInplaceInferer,
                           {framework::GradVarName("Out"),
                            framework::GradVarName("X")});
DECLARE_NO_NEED_BUFFER_VARS_INFERER(SqueezeGradNoNeedBufferVarsInferer, "X");
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

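// Route squeeze2's compile-time shape inference to phi::SqueezeInferMeta.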
DECLARE_INFER_SHAPE_FUNCTOR(squeeze2,
                            SqueezeInferShapeFunctor,
                            PD_INFER_META(phi::SqueezeInferMeta));

REGISTER_OPERATOR(squeeze,
                  ops::SqueezeOp,
                  ops::SqueezeOpMaker,
                  ops::SqueezeGradOpMaker<paddle::framework::OpDesc>,
                  ops::SqueezeGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(squeeze_grad,
                  ops::SqueezeGradOp,
                  ops::SqueezeDoubleGradOpMaker<paddle::framework::OpDesc>,
                  ops::SqueezeDoubleGradOpMaker<paddle::imperative::OpBase>,
                  ops::SqueezeGradNoNeedBufferVarsInferer);

REGISTER_OPERATOR(squeeze2,
                  ops::Squeeze2Op,
                  ops::Squeeze2OpMaker,
                  ops::Squeeze2GradOpMaker<paddle::framework::OpDesc>,
                  ops::Squeeze2GradOpMaker<paddle::imperative::OpBase>,
                  ops::SqueezeInplaceInferer,
                  SqueezeInferShapeFunctor);
REGISTER_OPERATOR(squeeze2_grad,
                  ops::Squeeze2GradOp,
                  ops::Squeeze2DoubleGradOpMaker<paddle::framework::OpDesc>,
                  ops::Squeeze2DoubleGradOpMaker<paddle::imperative::OpBase>,
                  ops::SqueezeGradInplaceInferer);

REGISTER_OP_CPU_KERNEL(
    squeeze,
    ops::SqueezeKernel<paddle::platform::CPUDeviceContext, float>,
    ops::SqueezeKernel<paddle::platform::CPUDeviceContext, double>,
    ops::SqueezeKernel<paddle::platform::CPUDeviceContext, bool>,
    ops::SqueezeKernel<paddle::platform::CPUDeviceContext, int>,
    ops::SqueezeKernel<paddle::platform::CPUDeviceContext, uint8_t>,
    ops::SqueezeKernel<paddle::platform::CPUDeviceContext, int8_t>,
    ops::SqueezeKernel<paddle::platform::CPUDeviceContext, int64_t>,
    ops::SqueezeKernel<paddle::platform::CPUDeviceContext,
                       paddle::platform::complex<float>>,
    ops::SqueezeKernel<paddle::platform::CPUDeviceContext,
                       paddle::platform::complex<double>>,
    ops::SqueezeKernel<paddle::platform::CPUDeviceContext,
                       paddle::platform::bfloat16>);
REGISTER_OP_CPU_KERNEL(
    squeeze_grad,
    ops::SqueezeGradKernel<paddle::platform::CPUDeviceContext, float>,
    ops::SqueezeGradKernel<paddle::platform::CPUDeviceContext, double>,
    ops::SqueezeGradKernel<paddle::platform::CPUDeviceContext, bool>,
    ops::SqueezeGradKernel<paddle::platform::CPUDeviceContext, int>,
    ops::SqueezeGradKernel<paddle::platform::CPUDeviceContext, uint8_t>,
    ops::SqueezeGradKernel<paddle::platform::CPUDeviceContext, int8_t>,
    ops::SqueezeGradKernel<paddle::platform::CPUDeviceContext, int64_t>,
    ops::SqueezeGradKernel<paddle::platform::CPUDeviceContext,
                           paddle::platform::complex<float>>,
    ops::SqueezeGradKernel<paddle::platform::CPUDeviceContext,
                           paddle::platform::complex<double>>,
    ops::SqueezeGradKernel<paddle::platform::CPUDeviceContext,
                           paddle::platform::bfloat16>);