/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <algorithm>
#include <memory>
#include <string>
#include <vector>

#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/slice_op.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/backward.h"
#include "paddle/phi/kernels/funcs/strided_slice.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

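// The forward strided_slice op. Kernel selection depends on where "Input"
// lives: a LoDTensorArray input is validated tensor by tensor, a CUDA-pinned
// tensor is copied to the current device, and an ordinary tensor keeps its
// own place.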
class StridedSliceOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
    auto *in_var = ctx.InputVar("Input");
    auto is_in_var_array = in_var->IsType<framework::LoDTensorArray>();
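    // When the input is a LoDTensorArray, every tensor in the array that is
    // not CUDA-pinned must already live on the place of the current device
    // context.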
    if (is_in_var_array) {
      auto &tensor_array = in_var->Get<framework::LoDTensorArray>();
      for (auto &tensor : tensor_array) {
        if (!platform::is_cuda_pinned_place(tensor.place())) {
          PADDLE_ENFORCE_EQ(
              platform::is_same_place(tensor.place(),
                                      ctx.device_context().GetPlace()),
              true,
              platform::errors::InvalidArgument(
                  "Place of context is %s. Place of input tensor is %s. They "
                  "are should be same, but reveived different place.",
                  string::to_string(ctx.device_context().GetPlace()),
                  string::to_string(tensor.place())));
        }
      }
      return framework::OpKernelType(
          OperatorWithKernel::IndicateVarDataType(ctx, "Input"),
          ctx.device_context());
    }
    // NOTE: a cuda pinned tensor needs to copy its data to the target place
    auto in_tensor = ctx.Input<Tensor>("Input");
    if (platform::is_cuda_pinned_place(in_tensor->place())) {
      return framework::OpKernelType(
          framework::TransToProtoVarType(in_tensor->dtype()),
          ctx.device_context());
    }
    return framework::OpKernelType(
        OperatorWithKernel::IndicateVarDataType(ctx, "Input"),
        in_tensor->place());
  }
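  // For the auxiliary starts/ends/strides tensors (and tensor lists) the
  // expected kernel type is returned unchanged, so no data transform is
  // forced on them; every other variable keeps its own place and layout.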
  framework::OpKernelType GetKernelTypeForVar(
      const std::string &var_name, const Tensor &tensor,
      const framework::OpKernelType &expected_kernel_type) const override {
    if (var_name == "StartsTensor" || var_name == "EndsTensor" ||
        var_name == "StridesTensor") {
      return expected_kernel_type;
    }
    if (var_name == "StartsTensorList" || var_name == "EndsTensorList" ||
        var_name == "StridesTensorList") {
      return expected_kernel_type;
    }
    return framework::OpKernelType(expected_kernel_type.data_type_,
                                   tensor.place(), tensor.layout());
  }
};

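// Forwards the variable type (LoDTensor vs. LoDTensorArray) and the data
// type of "Input" to "Out".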
class StridedSliceOpVarTypeInference : public framework::VarTypeInference {
 public:
  void operator()(framework::InferVarTypeContext *ctx) const override {
    ctx->SetOutputType("Out", ctx->GetInputType("Input"));
    ctx->SetOutputDataType("Out", ctx->GetInputDataType("Input"));
  }
};

class StridedSliceOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("Input", "Tensor of data to extract slices from.");
    AddOutput("Out", "Strided Sliced data tensor.");

    AddInput("StartsTensor",
             "(Tensor<int32>, optional) If provided, slice will use this."
             "It has the highest priority of StartsTensor, StartsTensorList "
             "and attr(starts).")
        .AsDispensable();
    AddInput("EndsTensor",
             "(Tensor<int32>, optional) If provided, slice will use this."
             "It has the highest priority of EndsTensor, EndsTensorList and "
             "attr(ends).")
        .AsDispensable();
    AddInput(
        "StridesTensor",
        "(Tensor<int32>, optional) If provided, slice will use this."
        "It has the highest priority of StridesTensor, StridesTensorList and "
        "attr(ends).")
        .AsDispensable();
    AddInput(
        "StartsTensorList",
        "(vector<Tensor<int32>>, optional) If provided, slice will use this."
        "The shape of the tensor in vector MUST BE [1]."
        "It has higher priority compare with attr(starts).")
        .AsDuplicable()
        .AsDispensable();
    AddInput(
        "EndsTensorList",
        "(vector<Tensor<int32>>, optional) If provided, slice will use this."
        "The shape of the tensor in vector MUST BE [1]."
        "It has higher priority compare with attr(ends).")
        .AsDuplicable()
        .AsDispensable();
    AddInput(
        "StridesTensorList",
        "(vector<Tensor<int32>>, optional) If provided, slice will use this."
        "The shape of the tensor in vector MUST BE [1]."
        "It has higher priority compare with attr(strides).")
        .AsDuplicable()
        .AsDispensable();
    AddAttr<std::vector<int>>(
        "axes", "(list<int>) Axes that `starts`, `ends` and `strides` apply to.");
    AddAttr<std::vector<int>>(
        "starts", "(list<int>) Start indices of the strided slice.")
        .SetDefault({});
    AddAttr<std::vector<int>>("ends",
                              "(list<int>) End indices of the strided slice.")
        .SetDefault({});
    AddAttr<std::vector<int>>(
        "strides", "(list<int>) Stride step from the start to the end.")
        .SetDefault({});
    AddAttr<std::vector<int>>(
        "infer_flags", "(list<int>) Flags of inferring dims in attributes.")
        .SetDefault({});
    AddAttr<std::vector<int>>("decrease_axis", "(list<int>) decrease_axis")
        .SetDefault({});
    AddComment(R"DOC(
Strided Slice Operator.
Instead of calling this op directly, most users will want to use the
NumPy-style slicing syntax.
For example:
data = fluid.layers.fill_constant(shape=[3, 3], value=0, dtype='int64')
y = fluid.layers.strided_slice(data, [0, 1], [1, 0], [2, 3], [1, 1])
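# assuming the arguments are (axes, starts, ends, strides), this call is the
# NumPy-style slice y = data[1:2:1, 0:3:1]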
)DOC");
  }
};

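// The backward strided_slice op: the kernel data type follows Out@GRAD and
// the kernel runs on the place of the execution context.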
class StridedSliceOpGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
    return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
                                       ctx, framework::GradVarName("Out")),
                                   ctx.GetPlace());
  }
  framework::OpKernelType GetKernelTypeForVar(
      const std::string &var_name, const Tensor &tensor,
      const framework::OpKernelType &expected_kernel_type) const override {
    if (var_name == "StartsTensor" || var_name == "EndsTensor" ||
        var_name == "StridesTensor") {
      return expected_kernel_type;
    }
    if (var_name == "StartsTensorList" || var_name == "EndsTensorList" ||
        var_name == "StridesTensorList") {
      return expected_kernel_type;
    }
    return framework::OpKernelType(expected_kernel_type.data_type_,
                                   tensor.place(), tensor.layout());
  }
};

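// Builds the strided_slice_grad op description: all forward inputs,
// including the optional starts/ends/strides tensors and tensor lists, are
// forwarded together with the attributes so the backward kernel can
// reconstruct the sliced region.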
template <typename T>
class StridedSliceOpGradMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  void Apply(GradOpPtr<T> bind) const override {
    bind->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
    bind->SetInput("Input", this->Input("Input"));
    bind->SetInput("StartsTensor", this->Input("StartsTensor"));
    bind->SetInput("EndsTensor", this->Input("EndsTensor"));
    bind->SetInput("StridesTensor", this->Input("StridesTensor"));
    bind->SetInput("StartsTensorList", this->Input("StartsTensorList"));
    bind->SetInput("EndsTensorList", this->Input("EndsTensorList"));
    bind->SetInput("StridesTensorList", this->Input("StridesTensorList"));
    bind->SetOutput(framework::GradVarName("Input"), this->InputGrad("Input"));
    bind->SetAttrMap(this->Attrs());
    bind->SetType("strided_slice_grad");
  }
};
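// Forwards the variable type and data type of Out@GRAD to Input@GRAD.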
class StridedSliceGradOpVarTypeInference : public framework::VarTypeInference {
 public:
  void operator()(framework::InferVarTypeContext *ctx) const override {
    ctx->SetOutputType(framework::GradVarName("Input"),
                       ctx->GetInputType(framework::GradVarName("Out")));
    ctx->SetOutputDataType(
        framework::GradVarName("Input"),
        ctx->GetInputDataType(framework::GradVarName("Out")));
  }
};

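// The backward op only needs the shape of "Input", not its data, so the
// buffer of "Input" can be skipped.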
DECLARE_NO_NEED_BUFFER_VARS_INFERER(StridedSliceOpGradNoNeedBufferVarsInferer,
                                    "Input");

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

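// Shape inference for both the forward and the backward op is delegated to
// the phi InferMeta functions via DECLARE_INFER_SHAPE_FUNCTOR; the resulting
// functors are passed to REGISTER_OPERATOR below.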
DECLARE_INFER_SHAPE_FUNCTOR(strided_slice, StridedSliceInferShape,
                            PD_INFER_META(phi::StridedSliceRawInferMeta));

REGISTER_OPERATOR(strided_slice, ops::StridedSliceOp, ops::StridedSliceOpMaker,
                  ops::StridedSliceOpGradMaker<paddle::framework::OpDesc>,
                  ops::StridedSliceOpGradMaker<paddle::imperative::OpBase>,
                  ops::StridedSliceOpVarTypeInference, StridedSliceInferShape);

DECLARE_INFER_SHAPE_FUNCTOR(strided_slice_grad, StridedSliceGradInferShape,
                            PD_INFER_META(phi::GeneralUnaryGradInferMeta));

REGISTER_OPERATOR(strided_slice_grad, ops::StridedSliceOpGrad,
                  ops::StridedSliceOpGradNoNeedBufferVarsInferer,
                  ops::StridedSliceGradOpVarTypeInference,
                  StridedSliceGradInferShape);