// crop_tensor_op.cc
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/crop_tensor_op.h"
#include <memory>
#include <string>
#include <vector>

namespace paddle {
namespace operators {

using framework::Tensor;

class CropTensorOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext *ctx) const override {
    PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
                      "Input(X) of Op(crop_tensor) should not be null.");
    PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
                      "Output(Out) of Op(crop_tensor) should not be null.");
34
    auto x_dim = ctx->GetInputDim("X");
35
    auto shape = ctx->Attrs().Get<std::vector<int>>("shape");
36
    auto offsets = ctx->Attrs().Get<std::vector<int>>("offsets");
37 38 39 40 41 42 43 44 45 46
    if (ctx->HasInputs("ShapeTensor")) {
      // top prority shape
      auto inputs_name = ctx->Inputs("ShapeTensor");
      PADDLE_ENFORCE_GT(
          inputs_name.size(), 0,
          "Input(ShapeTensor)'size of Op(crop_tensor) can't be zero. "
          "Please check the Attr(shape)'s size of "
          "Op(fluid.layers.crop_tensor).");
      auto out_dims = std::vector<int>(inputs_name.size(), -1);
      for (size_t i = 0; i < shape.size(); ++i) {
47
        if (shape[i] > 0) {
48
          out_dims[i] = static_cast<int64_t>(shape[i]);
49 50 51 52
        } else {
          if (shape[i] == -1 && offsets[i] != -1 && x_dim[i] != -1) {
            out_dims[i] = x_dim[i] - static_cast<int64_t>(offsets[i]);
          }
53 54 55 56 57 58
        }
      }
      ctx->SetOutputDim("Out", framework::make_ddim(out_dims));

      return;
    }
59

60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85
    if (ctx->HasInput("Shape")) {
      auto shape_dim = ctx->GetInputDim("Shape");
      PADDLE_ENFORCE_EQ(
          shape_dim.size(), 1,
          "Input(Shape)'s dimension size of Op(crop_tensor) must be 1. "
          "Please check the Attr(shape)'s dimension size of "
          "Op(fluid.layers.crop_tensor).");
      PADDLE_ENFORCE_EQ(shape_dim[0], x_dim.size(),
                        "Input(Shape)'s size of Op(crop_tensor) must be equal "
                        "to dimension size of input tensor. "
                        "Please check the Attr(shape)'s size of "
                        "Op(fluid.layers.crop_tensor).");
      if (ctx->IsRuntime()) {
        // If true, set the shape of Output(Out) according to Input(Shape) in
        // CropTensorKernel with ExecutionContext. Also check LoD in
        // CropTensorKernel.
        ctx->ShareLoD("X", /*->*/ "Out");
      } else {
        auto out_dims = std::vector<int>(shape_dim[0], -1);
        ctx->SetOutputDim("Out", framework::make_ddim(out_dims));
      }
      return;
    }
    PADDLE_ENFORCE_EQ(int64_t(shape.size()), x_dim.size(),
                      "Attr(shape)'size of Op(crop_tensor) should be equal to "
                      "dimention size of input tensor.");
86
    std::vector<int64_t> out_shape(shape.size(), -1);
87
    for (size_t i = 0; i < shape.size(); ++i) {
88 89 90 91 92 93 94
      if (shape[i] > 0) {
        out_shape[i] = static_cast<int64_t>(shape[i]);
      } else {
        if (shape[i] == -1 && offsets[i] != -1 && x_dim[i] != -1) {
          out_shape[i] = x_dim[i] - static_cast<int64_t>(offsets[i]);
        }
      }
95
    }
96
    ctx->SetOutputDim("Out", framework::make_ddim(out_shape));
97 98 99 100
  }

  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
101 102 103
    return framework::OpKernelType(
        OperatorWithKernel::IndicateVarDataType(ctx, "X"),
        ctx.device_context());
104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257
  }

  framework::OpKernelType GetKernelTypeForVar(
      const std::string &var_name, const Tensor &tensor,
      const framework::OpKernelType &expected_kernel_type) const override {
    if (var_name == "ShapeTensor" || var_name == "OffsetsTensor" ||
        var_name == "Shape" || var_name == "Offsets") {
      return expected_kernel_type;
    }

    return framework::OpKernelType(expected_kernel_type.data_type_,
                                   tensor.place(), tensor.layout());
  }
};

// Declares the inputs, output, attributes and user-facing documentation of
// the crop_tensor operator.
class CropTensorOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  // Registers the operator prototype: required input X, four dispensable
  // shape/offsets inputs, output Out, and the `offsets`/`shape` attributes
  // (both default to an empty list).
  void Make() override {
    AddInput("X",
             "The input of crop_tensor op. "
             "The input should be a k-D tensor(k > 0 and k < 7).");
    AddInput("Shape",
             "The input used to describe shape of output, which is a "
             "1-D vector whose size equals to the rank of input 'X'. The "
             "elements data type must be int. It has a higher priority than "
             "the shape attribute")
        .AsDispensable();
    AddInput("Offsets",
             "The input used to describe offsets in runtime, which is a "
             "1-D vector whose size equals to the rank of input 'X'. The "
             "elements data type must be int. It has a higher priority than "
             "the offsets attribute")
        .AsDispensable();
    AddInput("ShapeTensor",
             "(vector<Tensor<int32>>, optional). If provided, crop_tensor will "
             "use this. The shape of the tensor in vector MUST BE [1]. "
             "It has the highest priority compare with Input(Shape) and "
             "attr(shape).")
        .AsDuplicable()
        .AsDispensable();
    AddInput("OffsetsTensor",
             "(vector<Tensor<int32>>, optional). If provided, crop_tensor will "
             "use this. The shape of the tensor in vector MUST BE [1]. "
             "It has the highest priority compare with Input(Offsets) and "
             "attr(offsets).")
        .AsDuplicable()
        .AsDispensable();
    AddOutput("Out",
              "The output of crop_tensor op, "
              "which is of the same dimensions as X.");
    AddAttr<std::vector<int>>("offsets",
                              "A list<int> describing offsets to be cropped. "
                              "The size of offsets list should be the same as "
                              "the dimension size of input X.")
        .SetDefault(std::vector<int>());
    AddAttr<std::vector<int>>("shape",
                              "A list<int> describing the shape of output. "
                              "The size of shape list should be the same as "
                              "the dimension size of input X.")
        .SetDefault(std::vector<int>());
    AddComment(R"DOC(
CropTensor Operator.

Crop input into output, as specified by offsets and shape.

There are three ways to set the offsets:
1. Input 'OffsetsTensor': It is a tensor list. It should be set as a list that 
                          contains tensor variable in python configure script. 
                          This way is suitable for dynamic offsets.
2. Input 'Offsets': It is a variable and can be output of other operators. 
                    This way is suitable for dynamic offsets.
3. Attribute 'offsets': It will be set in python configure script. This way 
                        is suitable for fixed offsets.

You CANNOT use these three ways at the same time. An exception will be raised 
if input 'OffsetsTensor' or 'Offsets' is configured and meanwhile the attribute 'offsets' is 
not empty.

There are three ways to set shape:
1. Input 'ShapeTensor': It is a tensor list. It should be set as a list that contains
                        tensor variable in python configure script. This way is suitable 
                        for dynamic shape.
2. Input 'Shape': It is a Variable and can be output of other operators. This way is suitable 
                  for dynamic shape.
3. Attribute 'shape': crop input X into the shape described by a list<int>. The size of shape 
                      list should be the same as the dimension size of input X. This way is 
                      suitable for fixed shape.

The input should be a k-D tensor(k > 0 and k < 7). As an example:

Case 1:
Given

    X = [[0, 1, 2, 0, 0]
         [0, 3, 4, 0, 0]
         [0, 0, 0, 0, 0]],

and

    offsets = [0, 1],

and

    shape = [2, 2],

we get:

    Out = [[1, 2],
           [3, 4]].


Case 2:
Given

    X = [[0, 1, 2, 5, 0]
         [0, 3, 4, 6, 0]
         [0, 0, 0, 0, 0]],

and offsets is a list that contains tensor variable,
in runtime offsets_var's value is 1.

    offsets = [0, offsets_var],

and shape is a list that contains tensor variable,
in runtime dim's value is 2.

    shape = [dim, 3]

we get:

    Out = [[1, 2, 5],
           [3, 4, 6]].
)DOC");
  }
};

// Backward operator definition for crop_tensor (registered below as
// crop_tensor_grad).
class CropTensorOpGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  // Requires Input(X) and Input(Out@GRAD). When the gradient output for X is
  // present, it is given the same dimensions as X.
  void InferShape(framework::InferShapeContext *ctx) const override {
    PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
                      "Input(X) of Op(crop_tensor) should not be null.");
    PADDLE_ENFORCE_EQ(ctx->HasInput(framework::GradVarName("Out")), true,
                      "Input(Out@GRAD) of Op(crop_tensor) should not be null.");
    auto x_dims = ctx->GetInputDim("X");
    auto x_grad_name = framework::GradVarName("X");
    if (ctx->HasOutput(x_grad_name)) {
      ctx->SetOutputDim(x_grad_name, x_dims);
    }
  }

  // The kernel is chosen from the data type of the gradient of Out and the
  // device context of the execution.
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
    return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
                                       ctx, framework::GradVarName("Out")),
                                   ctx.device_context());
  }

  // Returns `expected_kernel_type` unchanged for the shape/offsets inputs
  // (ShapeTensor, OffsetsTensor, Shape, Offsets). For every other variable
  // the kernel type keeps the expected data type but uses the tensor's own
  // place and layout.
  framework::OpKernelType GetKernelTypeForVar(
      const std::string &var_name, const Tensor &tensor,
      const framework::OpKernelType &expected_kernel_type) const override {
    if (var_name == "ShapeTensor" || var_name == "OffsetsTensor" ||
        var_name == "Shape" || var_name == "Offsets") {
      return expected_kernel_type;
    }

    return framework::OpKernelType(expected_kernel_type.data_type_,
                                   tensor.place(), tensor.layout());
  }
};

H
hong 已提交
276 277
template <typename T>
class CropTensorGradOpMaker : public framework::SingleGradOpMaker<T> {
278
 public:
H
hong 已提交
279
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
280 281

 protected:
H
hong 已提交
282 283
  std::unique_ptr<T> Apply() const override {
    std::unique_ptr<T> op(new T());
284
    op->SetType("crop_tensor_grad");
H
hong 已提交
285 286 287 288
    op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
    op->SetInput("X", this->Input("X"));
    if (this->HasInput("OffsetsTensor")) {
      op->SetInput("OffsetsTensor", this->Input("OffsetsTensor"));
289
    }
H
hong 已提交
290 291
    if (this->HasInput("Offsets")) {
      op->SetInput("Offsets", this->Input("Offsets"));
292
    }
H
hong 已提交
293 294
    op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
    op->SetAttrMap(this->Attrs());
295 296 297 298 299 300 301 302 303
    return op;
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

// Forward op, registered together with its proto maker and the grad-op makers
// for both op description types (OpDesc and imperative OpBase).
REGISTER_OPERATOR(crop_tensor, ops::CropTensorOp, ops::CropTensorOpMaker,
                  ops::CropTensorGradOpMaker<paddle::framework::OpDesc>,
                  ops::CropTensorGradOpMaker<paddle::imperative::OpBase>);
// Backward op.
REGISTER_OPERATOR(crop_tensor_grad, ops::CropTensorOpGrad);
// CPU kernels of the forward op for float, double, int and int64_t.
REGISTER_OP_CPU_KERNEL(
    crop_tensor,
    ops::CropTensorKernel<paddle::platform::CPUDeviceContext, float>,
    ops::CropTensorKernel<paddle::platform::CPUDeviceContext, double>,
    ops::CropTensorKernel<paddle::platform::CPUDeviceContext, int>,
    ops::CropTensorKernel<paddle::platform::CPUDeviceContext, int64_t>);
// CPU kernels of the backward op for the same element types.
REGISTER_OP_CPU_KERNEL(
    crop_tensor_grad,
    ops::CropTensorGradKernel<paddle::platform::CPUDeviceContext, float>,
    ops::CropTensorGradKernel<paddle::platform::CPUDeviceContext, double>,
    ops::CropTensorGradKernel<paddle::platform::CPUDeviceContext, int>,
    ops::CropTensorGradKernel<paddle::platform::CPUDeviceContext, int64_t>);