/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/im2sequence_op.h"
#include <memory>
#include <string>
#include <vector>

namespace paddle {
namespace operators {

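// Im2Sequence extracts kernel-sized patches from a batch of NCHW images and
// lays them out as a LoDTensor of flattened patches, one sub-sequence per
// input image.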
class Im2SequenceOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
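  // Infers the compile-time shape of Out and marks it as a level-1
  // LoDTensor (one sub-sequence per input image).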
  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
                      platform::errors::NotFound(
                          "The input 'X' of Im2SequenceOp is not found."));
    PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
                      platform::errors::NotFound(
                          "The output 'Out' of Im2SequenceOp is not found."));
    auto in_dim = ctx->GetInputDim("X");

    PADDLE_ENFORCE_EQ(
        in_dim.size(), 4,
        platform::errors::InvalidArgument(
            "The dimesions size of input 'X' in Im2SequenceOp should be 4. But "
            "received dimesions size=[%d], dimesions=[%s].",
            in_dim.size(), in_dim));
    auto img_channels = in_dim[1];

    auto kernels = ctx->Attrs().Get<std::vector<int>>("kernels");
    auto strides = ctx->Attrs().Get<std::vector<int>>("strides");
    auto paddings = ctx->Attrs().Get<std::vector<int>>("paddings");
    if (!ctx->IsRuntime()) {
      // set lod level for compile-time
      framework::VarDesc* out_desc =
          boost::get<framework::VarDesc*>(ctx->GetOutputVarPtrs("Out")[0]);
      out_desc->SetLoDLevel(1);
    }

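    // At compile time only the feature width (C * kernel_h * kernel_w) is
    // final; the number of rows (total time steps across the batch) is
    // resolved by the kernel at runtime.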
    ctx->SetOutputDim("Out",
                      {in_dim[0], img_channels * kernels[0] * kernels[1]});
  }
};

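// Declares the inputs, outputs, attributes and documentation of im2sequence.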
class Im2SequenceOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X",
             "(Tensor) The input tensor has NCHW format."
             "N: batch size"
             "C: channels"
             "H: height"
             "W: width");
    AddInput("Y",
             "(Tensor) The input tensor of image real size(H, W)."
             "2-D with shape [batchsize, 2]")
        .AsDispensable();
    AddOutput("Out", "(LodTensor) The output data of im2sequence op,");
    AddAttr<std::vector<int>>("kernels",
                              "(vector<int>), the "
                              "kernels(kernel_height, kernel_width)");
    AddAttr<std::vector<int>>("strides",
                              "(vector<int> default:{1, 1}), the "
                              "strides(h_stride, w_stride)")
        .SetDefault({1, 1});
    AddAttr<std::vector<int>>("paddings",
                              "(vector<int> default:{0, 0, 0, 0}), the "
                              "paddings(up_pad, left_pad, down_pad, right_pad)")
        .SetDefault({0, 0, 0, 0});
    AddAttr<std::vector<int>>("out_stride",
                              "the attribute is valid only when input(Y)"
                              "is not NULL.this attribute represents the"
                              "scaling of the pic through the CNN"
                              "(vector<int> dedault:{1,1}),the out_stride"
                              " (out_stride_height, out_stride_width)")
        .SetDefault({1, 1});
    AddComment(R"DOC(
This op uses kernels to scan images and converts these images to sequences.
After expanding, the number of time steps is output_height * output_width
and the dimension of each time step is kernel_height * kernel_width * channels,
in which:

output_height =
    1 + (padding_up + padding_down + img_height - kernel_height + stride_height - 1) /
            stride_height;
output_width =
    1 + (padding_left + padding_right + img_width - kernel_width + stride_width - 1) /
            stride_width;

This op can be used after a convolutional neural network, and before a recurrent neural network.

Given:

x = [[[[ 6.  2.  1.]
       [ 8.  3.  5.]
       [ 0.  2.  6.]]

      [[ 2.  4.  4.]
       [ 6.  3.  0.]
       [ 6.  4.  7.]]]

     [[[ 6.  7.  1.]
       [ 5.  7.  9.]
       [ 2.  4.  8.]]

      [[ 1.  2.  1.]
       [ 1.  3.  5.]
       [ 9.  0.  8.]]]]
x.dims = {2, 2, 3, 3}

And:

kernels = [2, 2]
strides = [1, 1]
paddings = [0, 0, 0, 0]
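So, per image: output_height = 1 + (0 + 0 + 3 - 2 + 1 - 1) / 1 = 2 and
output_width = 1 + (0 + 0 + 3 - 2 + 1 - 1) / 1 = 2, i.e. 2 * 2 = 4 time steps,
each of dimension channels * kernel_height * kernel_width = 2 * 2 * 2 = 8.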

Then:

output.data = [[ 6.  2.  8.  3.  2.  4.  6.  3.]
               [ 2.  1.  3.  5.  4.  4.  3.  0.]
               [ 8.  3.  0.  2.  6.  3.  6.  4.]
               [ 3.  5.  2.  6.  3.  0.  4.  7.]
               [ 6.  7.  5.  7.  1.  2.  1.  3.]
               [ 7.  1.  7.  9.  2.  1.  3.  5.]
               [ 5.  7.  2.  4.  1.  3.  9.  0.]
               [ 7.  9.  4.  8.  3.  5.  0.  8.]]
output.dims = {8, 8}
output.lod = [[0, 4, 8]]

)DOC");
  }
};

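// Backward op: the gradient with respect to X has exactly the shape of X.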
class Im2SequenceGradOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
                      platform::errors::NotFound(
                          "The input 'X' of Im2SequenceGradOp is not found."));
    PADDLE_ENFORCE_EQ(ctx->HasInput(framework::GradVarName("Out")), true,
                      platform::errors::NotFound(
                          "The input %s of Im2SequenceGradOp is not found.",
                          framework::GradVarName("Out")));
    ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
  }
};

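// Builds the im2sequence_grad op for both static-graph (OpDesc) and
// imperative (OpBase) modes, forwarding X, the gradient of Out, and the
// forward op's attributes.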
template <typename T>
class Im2SequenceGradMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  void Apply(GradOpPtr<T> op) const override {
    op->SetType("im2sequence_grad");
    op->SetInput("X", this->Input("X"));
    op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
    op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
    op->SetAttrMap(this->Attrs());
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OPERATOR(im2sequence, ops::Im2SequenceOp, ops::Im2SequenceOpMaker,
                  ops::Im2SequenceGradMaker<paddle::framework::OpDesc>,
                  ops::Im2SequenceGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(im2sequence_grad, ops::Im2SequenceGradOp);
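// Only CPU float kernels are registered here; the GPU kernels are registered
// separately in the corresponding .cu file.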
REGISTER_OP_CPU_KERNEL(
    im2sequence,
    ops::Im2SequenceKernel<paddle::platform::CPUDeviceContext, float>);
REGISTER_OP_CPU_KERNEL(
    im2sequence_grad,
    ops::Im2SequenceGradKernel<paddle::platform::CPUDeviceContext, float>);