block_expand_op.cc 5.3 KB
Newer Older
G
gongweibao 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/block_expand_op.h"

namespace paddle {
namespace operators {

class BlockExpandOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  // Checks that input "X" is a 4-D (NCHW) tensor and that the block/stride/
  // padding attributes are usable, then sets the shape of "Out" to
  //   [N, output_height, output_width, C, blockHeight, blockWidth]
  // where output_height / output_width come from
  // get_blockexpand_output_shape().
  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("X"),
                   "Input of BlockExpandOp should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput("Out"),
                   "Output of BlockExpandOp op should not be null.");

    auto in_dim = ctx->GetInputDim("X");
    PADDLE_ENFORCE_EQ(in_dim.size(), 4, "Input format  must be NCHW.");
    PADDLE_ENFORCE_GE(in_dim[0], 1, "Input batchsize must >= 1.");

    int block_height = ctx->Attrs().Get<int>("blockHeight");
    int block_width = ctx->Attrs().Get<int>("blockWidth");
    int stride_height = ctx->Attrs().Get<int>("strideHeight");
    int stride_width = ctx->Attrs().Get<int>("strideWidth");
    int padding_height = ctx->Attrs().Get<int>("paddingHeight");
    int padding_width = ctx->Attrs().Get<int>("paddingWidth");

    // Reject attribute values for which the output-shape formula is
    // meaningless. The documented formula divides by the stride, so a zero
    // stride must never reach get_blockexpand_output_shape().
    PADDLE_ENFORCE_GT(block_height, 0, "blockHeight must be > 0.");
    PADDLE_ENFORCE_GT(block_width, 0, "blockWidth must be > 0.");
    PADDLE_ENFORCE_GT(stride_height, 0, "strideHeight must be > 0.");
    PADDLE_ENFORCE_GT(stride_width, 0, "strideWidth must be > 0.");
    PADDLE_ENFORCE_GE(padding_height, 0, "paddingHeight must be >= 0.");
    PADDLE_ENFORCE_GE(padding_width, 0, "paddingWidth must be >= 0.");

    int N = in_dim[0];
    int C = in_dim[1];
    int img_height = in_dim[2];
    int img_width = in_dim[3];

    int output_height = 0;
    int output_width = 0;
    get_blockexpand_output_shape(img_height, img_width, block_height,
                                 block_width, stride_height, stride_width,
                                 padding_height, padding_width, output_height,
                                 output_width);

    // A block larger than the padded image can drive the formula negative;
    // fail here instead of producing a tensor with a negative dimension.
    PADDLE_ENFORCE_GT(output_height, 0,
                      "Output height must be > 0; check blockHeight against "
                      "the padded input height.");
    PADDLE_ENFORCE_GT(output_width, 0,
                      "Output width must be > 0; check blockWidth against "
                      "the padded input width.");

    // The result of im2col is [output_height, output_width,
    // inputChannels, filterHeight, filterWidth], and it is easy to
    // reshape into [seqLength, stepSize], where seqLength is equal
    // output_height * output_width, stepSize is equal
    // input_channels * blockHeight * blockWidth.
    // Out keeps the batch dimension N as its leading axis.
    ctx->SetOutputDim(
        "Out", {N, output_height, output_width, C, block_height, block_width});

    // TODO: the LoD of X is not propagated to Out; confirm whether Out should
    // carry LoD (its description in the op maker says LodTensor).
    // ctx->ShareLoD("X", /*->*/ "Out");
  }
};

class BlockExpandOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  // Declares the input "X", the output "Out", the six integer attributes that
  // control the expansion (block size, stride and padding in each spatial
  // direction), and the operator's user-facing documentation.
  BlockExpandOpMaker(framework::OpProto* proto,
                     framework::OpAttrChecker* op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("X", R"DOC(
(Tensor)The input tensor has NCHW format.
    N: batch size
    C: channels
    H: height
    W: width
)DOC");
    AddOutput("Out", "(LodTensor)The output data of block_expand op,");
    AddAttr<int>("blockHeight", "(int)height of block.");
    AddAttr<int>("blockWidth", "(int)width of block.");
    AddAttr<int>("strideHeight", "(int)height of stride.");
    AddAttr<int>("strideWidth", "(int)width of stride.");
    AddAttr<int>("paddingHeight", "(int)height of padding.");
    AddAttr<int>("paddingWidth", "(int)width of padding.");
    AddComment(R"DOC(
Expand feature map to minibatch matrix.
- matrix height is: output_height * output_width
- matrix width is: blockHeight * blockWidth * channels

output_height =
    1 + (2 * paddingHeight + img_height - blockHeight + strideHeight - 1) /
            strideHeight;
output_width =
    1 + (2 * paddingWidth + img_width - blockWidth + strideWidth - 1) /
            strideWidth;

The expand method is the same as in ExpandConvLayer, but the transposed value
is saved. After expanding, the number of time steps is
output_height * output_width and the dimension of each time step is
blockHeight * blockWidth * channels.
This layer can be used after a convolutional neural network and before a
recurrent neural network.
)DOC");
  }
};

class BlockExpandGradOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  // Shape inference for the gradient op. Out@GRAD is the incoming gradient
  // (an INPUT of this op); X@GRAD is what this op produces, and it has
  // exactly the shape of the forward input X.
  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null");
    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
                   "Input(Out@GRAD) should not be null");

    auto in_dim = ctx->GetInputDim("X");

    // Only set X@GRAD's shape when that output is actually present
    // (presumably it can be pruned when no gradient for X is needed).
    auto x_grad_name = framework::GradVarName("X");
    if (ctx->HasOutput(x_grad_name)) {
      ctx->SetOutputDim(x_grad_name, in_dim);
    }
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

// Register the forward operator together with its gradient operator.
REGISTER_OP(block_expand, ops::BlockExpandOp, ops::BlockExpandOpMaker,
            block_expand_grad, ops::BlockExpandGradOp);

// CPU float kernels. BlockExpandKernel / BlockExpandGradKernel are not
// defined in this file; presumably they come from block_expand_op.h.
REGISTER_OP_CPU_KERNEL(
    block_expand, ops::BlockExpandKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(
    block_expand_grad,
    ops::BlockExpandGradKernel<paddle::platform::CPUPlace, float>);