/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <algorithm>
#include <memory>
#include <string>
#include <vector>

#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/backward.h"
#include "paddle/phi/kernels/funcs/strided_slice.h"

namespace paddle {
namespace operators {

using Tensor = phi::DenseTensor;

class StridedSliceOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
    auto *in_var = ctx.InputVar("Input");
    auto is_in_var_array = in_var->IsType<framework::LoDTensorArray>();
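    // For LoDTensorArray inputs, every tensor in the array (except
    // CUDA-pinned ones) must already be on the same place as the kernel's
    // device context.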
    if (is_in_var_array) {
      auto &tensor_array = in_var->Get<framework::LoDTensorArray>();
      for (auto &tensor : tensor_array) {
        if (!platform::is_cuda_pinned_place(tensor.place())) {
          PADDLE_ENFORCE_EQ(
              platform::is_same_place(tensor.place(),
                                      ctx.device_context().GetPlace()),
              true,
              platform::errors::InvalidArgument(
                  "Place of context is %s. Place of input tensor is %s. They "
                  "are should be same, but reveived different place.",
                  string::to_string(ctx.device_context().GetPlace()),
                  string::to_string(tensor.place())));
        }
      }
      return framework::OpKernelType(
          OperatorWithKernel::IndicateVarDataType(ctx, "Input"),
          ctx.device_context());
    }
    // NOTE: a CUDA-pinned tensor needs its data copied to the target place
    auto in_tensor = ctx.Input<phi::DenseTensor>("Input");
    if (platform::is_cuda_pinned_place(in_tensor->place())) {
      return framework::OpKernelType(
          framework::TransToProtoVarType(in_tensor->dtype()),
          ctx.device_context());
    }
    return framework::OpKernelType(
        OperatorWithKernel::IndicateVarDataType(ctx, "Input"),
        in_tensor->place());
  }
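
  // The starts/ends/strides tensors (and their list forms) carry slice
  // indices rather than data to be sliced, so they are excluded from data
  // transform and kept on their original place.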
  framework::OpKernelType GetKernelTypeForVar(
      const std::string &var_name,
      const Tensor &tensor,
      const framework::OpKernelType &expected_kernel_type) const override {
    if (var_name == "StartsTensor" || var_name == "EndsTensor" ||
        var_name == "StridesTensor") {
      return expected_kernel_type;
    }
    if (var_name == "StartsTensorList" || var_name == "EndsTensorList" ||
        var_name == "StridesTensorList") {
      return expected_kernel_type;
    }
    return framework::OpKernelType(
        expected_kernel_type.data_type_, tensor.place(), tensor.layout());
  }
};

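// The output variable keeps the type and data type of the input, so slicing
// a LoDTensorArray produces a LoDTensorArray.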
class StridedSliceOpVarTypeInference : public framework::VarTypeInference {
 public:
  void operator()(framework::InferVarTypeContext *ctx) const override {
    ctx->SetOutputType("Out", ctx->GetInputType("Input"));
    ctx->SetOutputDataType("Out", ctx->GetInputDataType("Input"));
  }
};

class StridedSliceOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("Input", "Tensor of data to extract slices from.");
    AddOutput("Out", "Strided Sliced data tensor.");

    AddInput("StartsTensor",
             "(Tensor<int32>, optional) If provided, slice will use this."
             "It has the highest priority of StartsTensor, StartsTensorList "
             "and attr(starts).")
        .AsDispensable();
    AddInput("EndsTensor",
             "(Tensor<int32>, optional) If provided, slice will use this."
             "It has the highest priority of EndsTensor, EndsTensorList and "
             "attr(ends).")
        .AsDispensable();
    AddInput(
        "StridesTensor",
        "(Tensor<int32>, optional) If provided, slice will use this."
        "It has the highest priority of StridesTensor, StridesTensorList and "
        "attr(ends).")
        .AsDispensable();
    AddInput(
        "StartsTensorList",
        "(vector<Tensor<int32>>, optional) If provided, slice will use this."
        "The shape of the tensor in vector MUST BE [1]."
        "It has higher priority compare with attr(starts).")
        .AsDuplicable()
        .AsDispensable();
    AddInput(
        "EndsTensorList",
        "(vector<Tensor<int32>>, optional) If provided, slice will use this."
        "The shape of the tensor in vector MUST BE [1]."
        "It has higher priority compare with attr(ends).")
        .AsDuplicable()
        .AsDispensable();
    AddInput(
        "StridesTensorList",
        "(vector<Tensor<int32>>, optional) If provided, slice will use this."
        "The shape of the tensor in vector MUST BE [1]."
        "It has higher priority compare with attr(strides).")
        .AsDuplicable()
        .AsDispensable();
    AddAttr<std::vector<int>>(
        "axes", "(list<int>) Axes that `starts` and `ends` apply to.");
    AddAttr<std::vector<int>>(
        "starts", "(list<int>) Start indices for the strided slice start.")
        .SetDefault({});
    AddAttr<std::vector<int>>("ends",
                              "(list<int>) End indices the tensor slice end")
        .SetDefault({});
    AddAttr<std::vector<int>>(
        "strides", "(list<int> Stride step from the start to the end)")
        .SetDefault({});
    AddAttr<std::vector<int>>(
        "infer_flags", "(list<int>) Flags of inferring dims in attributes.")
        .SetDefault({});
    AddAttr<std::vector<int>>("decrease_axis", "(list<int>) decrease_axis")
        .SetDefault({});
    AddComment(R"DOC(
Strided Slice Operator.
Instead of calling this op directly, most users will want to use the
NumPy-style slicing syntax.
For example:
data = fluid.layers.fill_constant(shape=[3, 3], value=0, dtype='int64')
y = fluid.layers.strided_slice(data, [0, 1], [1, 0], [2, 3], [1, 1])
)DOC");
  }
};

class StridedSliceOpGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
    return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
                                       ctx, framework::GradVarName("Out")),
                                   ctx.GetPlace());
  }
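
  // As in the forward op, the index tensors are not transformed to the
  // kernel's place.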
  framework::OpKernelType GetKernelTypeForVar(
      const std::string &var_name,
      const Tensor &tensor,
      const framework::OpKernelType &expected_kernel_type) const override {
    if (var_name == "StartsTensor" || var_name == "EndsTensor" ||
        var_name == "StridesTensor") {
      return expected_kernel_type;
    }
    if (var_name == "StartsTensorList" || var_name == "EndsTensorList" ||
        var_name == "StridesTensorList") {
      return expected_kernel_type;
    }
    return framework::OpKernelType(
        expected_kernel_type.data_type_, tensor.place(), tensor.layout());
  }
};

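// Builds the strided_slice_grad op from the forward op: all index inputs and
// attributes are forwarded, and the gradient of "Input" is computed from the
// gradient of "Out".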
template <typename T>
class StridedSliceOpGradMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  void Apply(GradOpPtr<T> bind) const override {
    bind->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
    bind->SetInput("Input", this->Input("Input"));
    bind->SetInput("StartsTensor", this->Input("StartsTensor"));
    bind->SetInput("EndsTensor", this->Input("EndsTensor"));
    bind->SetInput("StridesTensor", this->Input("StridesTensor"));
    bind->SetInput("StartsTensorList", this->Input("StartsTensorList"));
    bind->SetInput("EndsTensorList", this->Input("EndsTensorList"));
    bind->SetInput("StridesTensorList", this->Input("StridesTensorList"));
    bind->SetOutput(framework::GradVarName("Input"), this->InputGrad("Input"));
    bind->SetAttrMap(this->Attrs());
    bind->SetType("strided_slice_grad");
  }
};
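
// The gradient of "Input" inherits the variable type and data type of the
// gradient of "Out".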
class StridedSliceGradOpVarTypeInference : public framework::VarTypeInference {
 public:
  void operator()(framework::InferVarTypeContext *ctx) const override {
    ctx->SetOutputType(framework::GradVarName("Input"),
                       ctx->GetInputType(framework::GradVarName("Out")));
    ctx->SetOutputDataType(
        framework::GradVarName("Input"),
        ctx->GetInputDataType(framework::GradVarName("Out")));
  }
};

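// The backward kernel does not need the data buffer of "Input" (only its
// shape), so the buffer can be released.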
DECLARE_NO_NEED_BUFFER_VARS_INFERER(StridedSliceOpGradNoNeedBufferVarsInferer,
                                    "Input");

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

DECLARE_INFER_SHAPE_FUNCTOR(strided_slice,
                            StridedSliceInferShape,
                            PD_INFER_META(phi::StridedSliceRawInferMeta));

REGISTER_OPERATOR(strided_slice,
                  ops::StridedSliceOp,
                  ops::StridedSliceOpMaker,
                  ops::StridedSliceOpGradMaker<paddle::framework::OpDesc>,
                  ops::StridedSliceOpGradMaker<paddle::imperative::OpBase>,
                  ops::StridedSliceOpVarTypeInference,
                  StridedSliceInferShape);

DECLARE_INFER_SHAPE_FUNCTOR(strided_slice_grad,
                            StridedSliceGradInferShape,
                            PD_INFER_META(phi::GeneralUnaryGradInferMeta));

REGISTER_OPERATOR(strided_slice_grad,
                  ops::StridedSliceOpGrad,
                  ops::StridedSliceOpGradNoNeedBufferVarsInferer,
                  ops::StridedSliceGradOpVarTypeInference,
                  StridedSliceGradInferShape);