/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <algorithm>
#include <memory>
#include <string>
#include <vector>

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/prim/api/composite_backward/composite_backward_api.h"
#include "paddle/fluid/prim/utils/static/composite_grad_desc_maker.h"
#include "paddle/fluid/prim/utils/static/desc_tensor.h"
#include "paddle/phi/kernels/funcs/slice_utils.h"

namespace paddle {
namespace operators {

class SliceOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

33
  void InferShape(framework::InferShapeContext *ctx) const override {
34 35
    OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "slice");
    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "slice");
36

37
    // Case 1: Special treatment when input is a tensor array.
38 39 40
    auto x_var_type = ctx->GetInputsVarType("Input")[0];
    auto axes = ctx->Attrs().Get<std::vector<int>>("axes");
    if (x_var_type == framework::proto::VarType::LOD_TENSOR_ARRAY) {
41 42
      PADDLE_ENFORCE_EQ(axes.size(),
                        1,
43 44 45 46 47 48 49 50 51 52
                        platform::errors::InvalidArgument(
                            "The size of axes must be 1 when the Input of "
                            "SliceOp is LoDTensorArray, "
                            "but received %d.",
                            axes.size()));
      if (ctx->IsRuntime()) {
        // If the var type of input is LOD_TENSOR_ARRAY,
        // the output shape is determined by SliceKernel:Compute in runtime.
        return;
      } else {
L
liym27 已提交
53 54
        // NOTE(liym27): A better way is needed to get accurate dims of tensor
        // array.
55 56 57 58 59 60
        // The resulted dim of GetInputDim("Input") is the dim of the
        // last item written into TensorArray "Input". Maybe it's a bug to fix.
        ctx->SetOutputDim("Out", ctx->GetInputDim("Input"));
        return;
      }
    }
61 62

    // Case 2: input is a tensor.
W
whs 已提交
63
    auto in_dims = ctx->GetInputDim("Input");
64 65
    PADDLE_ENFORCE_LT(in_dims.size(),
                      7,
T
Thunderbrook 已提交
66 67
                      platform::errors::InvalidArgument(
                          "The rank of input should be less than 7."));
W
whs 已提交
68
    framework::DDim out_dims(in_dims);
69

W
whs 已提交
70 71
    auto starts = ctx->Attrs().Get<std::vector<int>>("starts");
    auto ends = ctx->Attrs().Get<std::vector<int>>("ends");
H
Hongyu Liu 已提交
72
    auto decrease_axis = ctx->Attrs().Get<std::vector<int>>("decrease_axis");
73
    auto infer_flags = ctx->Attrs().Get<std::vector<int>>("infer_flags");
74 75 76 77 78 79
    if (infer_flags.empty()) {
      // Initialize infer_flags with 1.
      // To be compatible with other op tests in which infer_flags is not set.
      infer_flags = std::vector<int>(axes.size(), 1);
    }

80 81 82 83
    // 2.1 Check attrs.
    auto starts_size = starts.size();
    auto ends_size = ends.size();

84
    if (ctx->HasInputs("StartsTensorList")) {
85
      starts_size = ctx->Inputs("StartsTensorList").size();
86 87
      PADDLE_ENFORCE_GT(starts_size,
                        0,
T
Thunderbrook 已提交
88 89
                        platform::errors::InvalidArgument(
                            "StartsTensorList size can't be zero"));
90 91
    }
    if (ctx->HasInputs("EndsTensorList")) {
92
      ends_size = ctx->Inputs("EndsTensorList").size();
93 94
      PADDLE_ENFORCE_GT(ends_size,
                        0,
95 96
                        platform::errors::InvalidArgument(
                            "EndsTensorList size can't be zero"));
97 98
    }

99
    if (!ctx->HasInput("StartsTensor")) {
100
      PADDLE_ENFORCE_EQ(
101 102
          starts_size,
          axes.size(),
T
Thunderbrook 已提交
103 104
          platform::errors::InvalidArgument(
              "The size of starts must be equal to the size of axes."));
105
    }
106
    if (!ctx->HasInput("EndsTensor")) {
T
Thunderbrook 已提交
107
      PADDLE_ENFORCE_EQ(
108 109
          ends_size,
          axes.size(),
T
Thunderbrook 已提交
110 111
          platform::errors::InvalidArgument(
              "The size of ends must be equal to the size of axes."));
112
    }
113 114 115 116 117
    for (auto &axis : axes) {
      if (axis < 0) {
        axis = std::max(0, axis + in_dims.size());
      }
    }
118 119
    phi::funcs::CheckAndUpdateSliceAttrs<int>(
        in_dims, axes, &starts, &ends, nullptr, &infer_flags);
H
Hongyu Liu 已提交
120

121 122
    auto slice_dims = phi::funcs::GetSliceDims<int>(
        in_dims, axes, starts, ends, nullptr, &infer_flags);
123
    if (ctx->IsRuntime()) {
124 125
      out_dims = phi::funcs::GetDecreasedDims<int>(
          slice_dims, decrease_axis, &infer_flags);
126
    } else {
H
hong 已提交
127 128
      out_dims =
          phi::funcs::GetDecreasedDims<int>(slice_dims, decrease_axis, nullptr);
H
Hongyu Liu 已提交
129
    }
130

W
whs 已提交
131
    ctx->SetOutputDim("Out", out_dims);
132
    if (axes.size() > 0 && axes[0] != 0) {
J
jerrywgz 已提交
133 134
      ctx->ShareLoD("Input", /*->*/ "Out");
    }
W
whs 已提交
135 136 137
  }

 protected:
138
  phi::KernelKey GetExpectedKernelType(
139
      const framework::ExecutionContext &ctx) const override {
140
    auto *in_var = ctx.InputVar("Input");
141 142
    if (in_var->IsType<phi::DenseTensor>()) {
      auto &in_tensor = in_var->Get<phi::DenseTensor>();
143
      PADDLE_ENFORCE_EQ(
144 145
          in_tensor.IsInitialized(),
          true,
146 147
          platform::errors::InvalidArgument(
              "The tensor Input (Input) of Slice op is not initialized."));
148 149
      // NOTE: cuda pinned tensor need to copy its data to target place
      if (platform::is_cuda_pinned_place(in_tensor.place())) {
150 151
        return phi::KernelKey(framework::TransToProtoVarType(in_tensor.dtype()),
                              ctx.GetPlace());
152
      }
153 154 155 156

#ifdef PADDLE_WITH_MKLDNN
      auto input_data_type =
          framework::OperatorWithKernel::IndicateVarDataType(ctx, "Input");
157 158 159 160
      auto vec_dims = phi::vectorize(in_tensor.dims());
      bool all_zero_dims = std::all_of(
          vec_dims.cbegin(), vec_dims.cend(), [](int64_t i) { return i == 0; });
      if (!all_zero_dims && this->CanMKLDNNBeUsed(ctx, input_data_type)) {
161 162 163 164
        // OneDNN uses blocking format, which cannot be always supported with
        // reorders, because if blocked dimension is not divisible by 8 or
        // 16(depending on which blocking format is used) submemory cannot be
        // created, so in that scenario a fallback is needed
165 166
        if (ctx.Input<phi::DenseTensor>("Input")
                ->mem_desc()
167 168 169 170 171
                .data.format_desc.blocking.inner_nblks == 0) {
          return phi::KernelKey(phi::Backend::ONEDNN,
                                phi::DataLayout::ONEDNN,
                                phi::TransToPhiDataType(input_data_type));
        }
172 173 174
      }
#endif

175 176
      return phi::KernelKey(framework::TransToProtoVarType(in_tensor.dtype()),
                            in_tensor.place());
177
    }
178 179
    return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "Input"),
                          ctx.GetPlace());
180
  }
181

182
  phi::KernelKey GetKernelTypeForVar(
183
      const std::string &var_name,
184
      const phi::DenseTensor &tensor,
185
      const phi::KernelKey &expected_kernel_type) const override {
186
    if (var_name == "StartsTensor" || var_name == "EndsTensor") {
187 188 189
      return phi::KernelKey(phi::Backend::ALL_BACKEND,
                            expected_kernel_type.layout(),
                            expected_kernel_type.dtype());
190 191
    }
    if (var_name == "StartsTensorList" || var_name == "EndsTensorList") {
192 193 194
      return phi::KernelKey(phi::Backend::ALL_BACKEND,
                            expected_kernel_type.layout(),
                            expected_kernel_type.dtype());
195
    }
196 197
    return phi::KernelKey(
        tensor.place(), tensor.layout(), expected_kernel_type.dtype());
W
whs 已提交
198 199 200
  }
};

201 202 203 204 205 206
class SliceOpVarTypeInference : public framework::VarTypeInference {
 public:
  void operator()(framework::InferVarTypeContext *ctx) const override {
    auto x_name = "Input";
    auto out_name = "Out";
    auto decrease_axis = ctx->GetAttr("decrease_axis");
R
Ruibiao Chen 已提交
207 208
    auto not_decrease =
        paddle::get<std::vector<int>>(decrease_axis).size() == 0;
209
    if (not_decrease) {
210
      // The default type of out is phi::DenseTensor.
211
      // However, if no axis is decreased and the type of input is not
212
      // phi::DenseTensor, the type of out should be the same as input.
213 214 215 216 217 218 219 220
      // For example, input is a LoDTensorArray and no axis is decreased, the
      // output should be a LoDTensorArray.
      ctx->SetOutputType(out_name, ctx->GetInputType(x_name));
      ctx->SetOutputDataType(out_name, ctx->GetInputDataType(x_name));
    }
  }
};

W
whs 已提交
221 222 223
class SliceOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248
    AddInput("Input", "(Tensor) Tensor of data to extract slices from.");
    AddInput("StartsTensor",
             "(Tensor<int32>, optional) If provided, slice will use this."
             "It has the highest priority of StartsTensor, StartsTensorList "
             "and attr(starts).")
        .AsDispensable();
    AddInput("EndsTensor",
             "(Tensor<int32>, optional) If provided, slice will use this."
             "It has the highest priority of EndsTensor, EndsTensorList and "
             "attr(ends).")
        .AsDispensable();
    AddInput(
        "StartsTensorList",
        "(vector<Tensor<int32>>, optional) If provided, slice will use this."
        "The shape of the tensor in vector MUST BE [1]."
        "It has higher priority compare with attr(starts).")
        .AsDuplicable()
        .AsDispensable();
    AddInput(
        "EndsTensorList",
        "(vector<Tensor<int32>>, optional) If provided, slice will use this."
        "The shape of the tensor in vector MUST BE [1]."
        "It has higher priority compare with attr(ends).")
        .AsDuplicable()
        .AsDispensable();
W
whs 已提交
249 250 251 252 253 254 255
    AddOutput("Out", "Sliced data tensor.");
    AddAttr<std::vector<int>>(
        "axes",
        "(list<int>) Axes that `starts` and `ends` apply to. It's optional."
        "If not present, will be treated as [0, 1, ..., len(`starts`) - 1].");
    AddAttr<std::vector<int>>(
        "starts",
256 257 258 259 260
        "(list<int>) Starting indices of corresponding axis in `axes`")
        .SetDefault({});
    AddAttr<std::vector<int>>(
        "ends", "(list<int>) Ending indices of corresponding axis in `axes`.")
        .SetDefault({});
W
whs 已提交
261
    AddAttr<std::vector<int>>(
262 263
        "infer_flags", "(list<int>) Flags of inferring dims in attributes.")
        .SetDefault({});
H
Hongyu Liu 已提交
264 265
    AddAttr<std::vector<int>>("decrease_axis", "(list<int>) decrease_axis")
        .SetDefault({});
W
whs 已提交
266 267 268 269 270
    AddComment(R"DOC(
Slice Operator.

Produces a slice of the input tensor along multiple axes. Similar to numpy:
https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html
271
Slice uses `axes`, `starts` and `ends` attributes to specify the start and
W
whs 已提交
272
end dimension for each axis in the list of axes, it uses this information
273 274
to slice the input data tensor. If a negative value is passed for any of
the start or end indices, it represents number of elements before the end
W
whs 已提交
275
of that dimension. If the value passed to start or end is larger than
276 277
the n (the number of elements in this dimension), it represents n.
For slicing to the end of a dimension with unknown size, it is recommended
278
to pass in INT_MAX. The size of axes must be equal to starts\' and ends\'.
279 280
Following examples will explain how slice works:

281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298
.. code-block:: text

    Case1:
        Given:
            data = [ [1, 2, 3, 4], [5, 6, 7, 8], ]
            axes = [0, 1]
            starts = [1, 0]
            ends = [2, 3]
        Then:
            result = [ [5, 6, 7], ]

    Case2:
        Given:
            data = [ [1, 2, 3, 4], [5, 6, 7, 8], ]
            starts = [0, 1]
            ends = [-1, 1000]
        Then:
            result = [ [2, 3, 4], ]
W
whs 已提交
299 300 301 302
)DOC");
  }
};

303 304 305 306
class SliceOpGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

307
  void InferShape(framework::InferShapeContext *ctx) const override {
T
Thunderbrook 已提交
308
    PADDLE_ENFORCE_EQ(
309 310
        ctx->HasInput("Input"),
        true,
T
Thunderbrook 已提交
311
        platform::errors::InvalidArgument("Input should not be null"));
312 313
    PADDLE_ENFORCE_EQ(ctx->HasInput(framework::GradVarName("Out")),
                      true,
T
Thunderbrook 已提交
314 315
                      platform::errors::InvalidArgument(
                          "Input(Out@GRAD) should not be null"));
316 317 318 319 320 321 322 323
    auto x_var_type = ctx->GetInputsVarType("Input")[0];
    if (x_var_type == framework::proto::VarType::LOD_TENSOR_ARRAY) {
      // If the var type of input is LOD_TENSOR_ARRAY,
      // the output shape is determined by SliceGradKernel:Compute in runtime.
      if (ctx->IsRuntime()) {
        return;
      }
    }
324 325 326 327 328 329
    auto x_dims = ctx->GetInputDim("Input");
    auto x_grad_name = framework::GradVarName("Input");
    if (ctx->HasOutput(x_grad_name)) {
      ctx->SetOutputDim(x_grad_name, x_dims);
    }
  }
330

331
  phi::KernelKey GetExpectedKernelType(
332
      const framework::ExecutionContext &ctx) const override {
333 334 335 336 337 338 339 340 341
    auto input_data_type = framework::OperatorWithKernel::IndicateVarDataType(
        ctx, framework::GradVarName("Out"));

#ifdef PADDLE_WITH_MKLDNN
    if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
      // OneDNN uses blocking format, which cannot be always supported with
      // reorders, because if blocked dimension is not divisible by 8 or
      // 16(depending on which blocking format is used) submemory cannot be
      // created, so in that scenario a fallback is needed
342 343
      if (ctx.Input<phi::DenseTensor>(framework::GradVarName("Out"))
              ->mem_desc()
344 345 346 347 348
              .data.format_desc.blocking.inner_nblks == 0) {
        return phi::KernelKey(phi::Backend::ONEDNN,
                              phi::DataLayout::ONEDNN,
                              phi::TransToPhiDataType(input_data_type));
      }
349 350
    }
#endif
351
    return phi::KernelKey(input_data_type, ctx.GetPlace());
352
  }
353

354
  phi::KernelKey GetKernelTypeForVar(
355
      const std::string &var_name,
356
      const phi::DenseTensor &tensor,
357
      const phi::KernelKey &expected_kernel_type) const override {
358
    if (var_name == "StartsTensor" || var_name == "EndsTensor") {
359 360 361
      return phi::KernelKey(phi::Backend::ALL_BACKEND,
                            expected_kernel_type.layout(),
                            expected_kernel_type.dtype());
362 363
    }
    if (var_name == "StartsTensorList" || var_name == "EndsTensorList") {
364 365 366
      return phi::KernelKey(phi::Backend::ALL_BACKEND,
                            expected_kernel_type.layout(),
                            expected_kernel_type.dtype());
367
    }
368 369
    return phi::KernelKey(
        tensor.place(), tensor.layout(), expected_kernel_type.dtype());
370
  }
371 372
};

373 374 375 376 377 378 379
class SliceOpGradVarTypeInference : public framework::VarTypeInference {
 public:
  void operator()(framework::InferVarTypeContext *ctx) const override {
    auto x = "Input";
    auto d_out = framework::GradVarName("Out");
    auto out = framework::GradVarName("Input");
    // The types of grad_input and input should always be the same.
380 381
    // The default type of out is phi::DenseTensor, but the type of input can be
    // phi::DenseTensor or phi::DenseTensorArray,
382 383 384 385 386 387
    // so set the type of both to be the same.
    ctx->SetOutputType(out, ctx->GetInputType(x));
    ctx->SetOutputDataType(out, ctx->GetInputDataType(d_out));
  }
};

H
hong 已提交
388 389
template <typename T>
class SliceOpGradMaker : public framework::SingleGradOpMaker<T> {
390
 public:
H
hong 已提交
391
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
392 393

 protected:
394
  void Apply(GradOpPtr<T> bind) const override {
H
hong 已提交
395
    bind->SetInput("Input", this->Input("Input"));
H
hong 已提交
396 397 398 399 400 401 402 403 404 405 406 407
    if (this->HasInput("StartsTensor")) {
      bind->SetInput("StartsTensor", this->Input("StartsTensor"));
    }
    if (this->HasInput("EndsTensor")) {
      bind->SetInput("EndsTensor", this->Input("EndsTensor"));
    }
    if (this->HasInput("StartsTensorList")) {
      bind->SetInput("StartsTensorList", this->Input("StartsTensorList"));
    }
    if (this->HasInput("EndsTensorList")) {
      bind->SetInput("EndsTensorList", this->Input("EndsTensorList"));
    }
H
hong 已提交
408 409 410
    bind->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
    bind->SetOutput(framework::GradVarName("Input"), this->InputGrad("Input"));
    bind->SetAttrMap(this->Attrs());
411 412 413 414
    bind->SetType("slice_grad");
  }
};

X
xiaoguoguo626807 已提交
415 416 417 418 419 420 421 422 423 424 425
class SliceCompositeGradOpMaker : public prim::CompositeGradOpMakerBase {
  using prim::CompositeGradOpMakerBase::CompositeGradOpMakerBase;

 public:
  void Apply() override {
    paddle::experimental::Tensor input = this->GetSingleForwardInput("Input");
    paddle::experimental::Tensor out_grad = this->GetSingleOutputGrad("Out");
    paddle::experimental::Tensor input_grad = this->GetSingleInputGrad("Input");

    auto dx_ptr = this->GetOutputPtr(&input_grad);
    std::string dx_name = this->GetOutputName(input_grad);
426 427 428 429 430
    auto axes = this->Attr<std::vector<int>>("axes");
    auto starts = this->Attr<std::vector<int>>("starts");
    auto ends = this->Attr<std::vector<int>>("ends");
    auto infer_flags = this->Attr<std::vector<int>>("infer_flags");
    auto decrease_axis = this->Attr<std::vector<int>>("decrease_axis");
X
xiaoguoguo626807 已提交
431
    VLOG(6) << "Runing slice_grad composite func";
432 433 434 435 436 437
    std::vector<int64_t> new_axes =
        std::vector<int64_t>(axes.begin(), axes.end());
    std::vector<int64_t> new_infer_flags =
        std::vector<int64_t>(infer_flags.begin(), infer_flags.end());
    std::vector<int64_t> new_decrease_axis =
        std::vector<int64_t>(decrease_axis.begin(), decrease_axis.end());
X
xiaoguoguo626807 已提交
438 439
    prim::slice_grad<prim::DescTensor>(input,
                                       out_grad,
440
                                       new_axes,
X
xiaoguoguo626807 已提交
441 442
                                       paddle::experimental::IntArray(starts),
                                       paddle::experimental::IntArray(ends),
443 444
                                       new_infer_flags,
                                       new_decrease_axis,
X
xiaoguoguo626807 已提交
445 446 447 448
                                       dx_ptr);
    this->RecoverOutputName(input_grad, dx_name);
  }
};
449 450 451 452 453 454
template <typename T>
class SliceDoubleOpGradMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
455
  void Apply(GradOpPtr<T> bind) const override {
456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474
    if (this->HasInput("StartsTensor")) {
      bind->SetInput("StartsTensor", this->Input("StartsTensor"));
    }
    if (this->HasInput("EndsTensor")) {
      bind->SetInput("EndsTensor", this->Input("EndsTensor"));
    }
    if (this->HasInput("StartsTensorList")) {
      bind->SetInput("StartsTensorList", this->Input("StartsTensorList"));
    }
    if (this->HasInput("EndsTensorList")) {
      bind->SetInput("EndsTensorList", this->Input("EndsTensorList"));
    }
    bind->SetInput("Input", this->OutputGrad(framework::GradVarName("Input")));
    bind->SetOutput("Out", this->InputGrad(framework::GradVarName("Out")));
    bind->SetAttrMap(this->Attrs());
    bind->SetType("slice");
  }
};

475
DECLARE_NO_NEED_BUFFER_VARS_INFERER(SliceOpGradNoNeedBufferVarsInferer,
Z
Zeng Jinle 已提交
476
                                    "Input");
477

W
whs 已提交
478 479 480 481
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
482 483 484
REGISTER_OPERATOR(slice,
                  ops::SliceOp,
                  ops::SliceOpMaker,
H
hong 已提交
485
                  ops::SliceOpGradMaker<paddle::framework::OpDesc>,
486
                  ops::SliceOpGradMaker<paddle::imperative::OpBase>,
487
                  ops::SliceCompositeGradOpMaker,
488
                  ops::SliceOpVarTypeInference);
489 490
REGISTER_OPERATOR(slice_grad,
                  ops::SliceOpGrad,
491 492
                  ops::SliceDoubleOpGradMaker<paddle::framework::OpDesc>,
                  ops::SliceDoubleOpGradMaker<paddle::imperative::OpBase>,
493
                  ops::SliceOpGradNoNeedBufferVarsInferer,
494
                  ops::SliceOpGradVarTypeInference);