// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/operators/deformable_conv_op.h"
#include <memory>
#include "paddle/fluid/operators/conv_op.h"

namespace paddle {
namespace operators {
class DeformableConvOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("Input",
             "(Tensor) The input of deformable conv op. "
             "The shape of input is "
             "[N, channel_in, H, W]");
    AddInput("Offset",
             "(Tensor) The input offset. "
             "The shape of the offset is "
             "[N, deformable_groups * kernel_w * kernel_h * 2, H, W");
    AddInput("Mask",
             "(Tensor) The input mask. "
             "The shape of the mask is "
             "[N, deformable_groups * kernel_w * kernel_h, H, W].");
    AddInput("Filter",
             "(Tensor) The Input Filter "
             "The shape of the wight is "
             "[num_filters, channel_in, kernel_h, kernel_w.");
    AddOutput("Output",
              "(Tensor) The output. "
              "The shape of the output tensor is "
              "[N, num_filters, out_height, out_width]].");
    AddAttr<std::vector<int>>("strides",
                              "(vector<int> default:{1, 1}), the "
                              "strides(h_stride, w_stride) of "
                              "convolution operator.")
        .SetDefault({1, 1});
    AddAttr<std::vector<int>>("paddings",
                              "(vector<int> default:{0,0}), the "
                              "paddings(h_pad, w_pad) of "
                              "convolution operator. ")
        .SetDefault({0, 0});
    AddAttr<std::vector<int>>("dilations",
                              "(vector<int> default:{1, 1}), the "
                              "dilations(h_dilation, w_dilation) of "
                              "convolution operator.")
        .SetDefault({1, 1});
    AddAttr<int>(
        "groups",
        "(int default:1), the groups number of the convolution operator. "
        "According to grouped convolution in Alex Krizhevsky's Deep CNN paper: "
        "when group=2, the first half of the filters is only connected to the "
        "first half of the input channels, while the second half of the "
        "filters "
        "is only connected to the second half of the input channels.")
        .SetDefault(1);
    AddAttr<int>("deformable_groups",
                 "(int default:1), the number of the deformable groups.")
        .SetDefault(1);
    AddAttr<int>("im2col_step",
                 "im2col maximum number of image per computation")
        .SetDefault(64);
    AddComment(R"DOC(
**Deformable Convolution Operator**

Compute 2-D deformable convolution on 4-D input.

Given input image x and output feature map y, the deformable convolution operation can be expressed as follows:

$$
y(p) = \\sum_{k=1}^{K}{w_k * x(p + p_k + \\Delta p_k) * \\Delta m_k}
$$

Where $$\\Delta p_k$$ and $$\\Delta m_k$$ are the learnable offset and modulation scalar for the k-th location, respectively.

Refer to 'Deformable ConvNets v2: More Deformable, Better Results'
<https://arxiv.org/abs/1811.11168v2>.

Example:
  Input:
       Input shape: $(N, C_{in}, H_{in}, W_{in})$
       Filter shape: $(C_{out}, C_{in}, H_f, W_f)$
       Offset shape: $(N, 2 * deformable_groups * H_f * W_f, H_{out}, W_{out})$
       Mask shape: $(N, deformable_groups * H_f * W_f, H_{out}, W_{out})$
  Output:
       Output shape: $(N, C_{out}, H_{out}, W_{out})$
                      where $H_{out}, W_{out}$ must be equal to the height and width of the offset and mask maps, respectively.
  Where
$$
       H_{out}= \frac{(H_{in} + 2 * paddings[0] - (dilations[0] * (H_f - 1) + 1))}{strides[0]}+ 1 \\
       W_{out}= \frac{(W_{in} + 2 * paddings[1] - (dilations[1] * (W_f - 1) + 1))}{strides[1]}+ 1
$$
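
For example, with the illustrative values $H_{in} = W_{in} = 32$, $H_f = W_f = 3$, paddings = {1, 1},
strides = {1, 1} and dilations = {1, 1}, the formula above gives
$$
       H_{out} = W_{out} = \frac{(32 + 2 * 1 - (1 * (3 - 1) + 1))}{1} + 1 = 32
$$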
)DOC");
  }
};

class DeformableConvOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;
  void InferShape(framework::InferShapeContext *ctx) const override {
    OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "deformable_conv");
    OP_INOUT_CHECK(ctx->HasInput("Offset"), "Input", "Offset",
                   "deformable_conv)");
    OP_INOUT_CHECK(ctx->HasInput("Mask"), "Input", "Mask", "deformable_conv");
    OP_INOUT_CHECK(ctx->HasInput("Filter"), "Input", "Filter",
                   "deformable_conv");
    OP_INOUT_CHECK(ctx->HasOutput("Output"), "Output", "Output",
                   "deformable_conv");

    auto in_dims = ctx->GetInputDim("Input");
    auto filter_dims = ctx->GetInputDim("Filter");
    auto offset_dims = ctx->GetInputDim("Offset");
    auto mask_dims = ctx->GetInputDim("Mask");

    std::vector<int> strides = ctx->Attrs().Get<std::vector<int>>("strides");
    std::vector<int> paddings = ctx->Attrs().Get<std::vector<int>>("paddings");
    std::vector<int> dilations =
        ctx->Attrs().Get<std::vector<int>>("dilations");
    int groups = ctx->Attrs().Get<int>("groups");
    int deformable_groups = ctx->Attrs().Get<int>("deformable_groups");
    int im2col_step = ctx->Attrs().Get<int>("im2col_step");

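    // Basic rank and attribute consistency checks between the input, the
    // filter and the convolution attributes.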
    PADDLE_ENFORCE_EQ(
        in_dims.size(), 4,
        platform::errors::InvalidArgument(
            "Conv input should be a 4-D tensor, but got %u-D", in_dims.size()));
    PADDLE_ENFORCE_EQ(in_dims.size(), filter_dims.size(),
                      platform::errors::InvalidArgument(
                          "Conv input dimension and filter dimension should be "
                          "the same. But received input dimension: [%d], "
                          "filter dimension: [%d]",
                          in_dims.size(), filter_dims.size()));
    PADDLE_ENFORCE_EQ(in_dims.size() - strides.size(), 2U,
                      platform::errors::InvalidArgument(
                          "Conv input dimension and strides "
                          "dimension should be consistent. But received input "
                          "dimension:[%d], strides dimension:[%d]",
                          in_dims.size(), strides.size()));
    PADDLE_ENFORCE_EQ(paddings.size(), strides.size(),
                      platform::errors::InvalidArgument(
                          "Conv paddings dimension and Conv strides dimension "
                          "should be the same. But received paddings "
                          "dimension: [%d], strides dimension: [%d]",
                          paddings.size(), strides.size()));

    PADDLE_ENFORCE_EQ(
        in_dims[1], filter_dims[1] * groups,
        platform::errors::InvalidArgument(
            "The number of input channels should be equal to filter "
            "channels * groups. But received input channels: [%d], "
            "filter channels * groups: [%d]",
            in_dims[1], filter_dims[1] * groups));
    PADDLE_ENFORCE_EQ(
        filter_dims[0] % groups, 0,
        platform::errors::InvalidArgument(
            "The number of output channels should be divisible by groups. "
            "But received output channels: [%d], groups: [%d]",
            filter_dims[0], groups));
    PADDLE_ENFORCE_EQ(
        filter_dims[0] % deformable_groups, 0,
        platform::errors::InvalidArgument(
            "The number of output channels should be divisible by "
            "deformable groups. But received output channels: [%d], "
            "deformable groups: [%d]",
            filter_dims[0], deformable_groups));

    if (in_dims[0] > im2col_step) {
      PADDLE_ENFORCE_EQ(
          in_dims[0] % im2col_step, 0U,
          platform::errors::InvalidArgument(
              "Input batch size must be divisible by im2col_step when it is "
              "larger than im2col_step. But received batch size: [%d], "
              "im2col_step: [%d]",
              in_dims[0], im2col_step));
    }

    for (size_t i = 0; i < strides.size(); ++i) {
      PADDLE_ENFORCE_GT(strides[i], 0U,
                        platform::errors::InvalidArgument(
                            "stride %d should be greater than 0", i));
    }
    for (size_t i = 0; i < dilations.size(); ++i) {
      PADDLE_ENFORCE_GT(dilations[i], 0U,
                        platform::errors::InvalidArgument(
                            "dilation %d should be greater than 0", i));
    }

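    // The first two output dimensions are the batch size and the number of
    // filters; the spatial dimensions are computed below. At compile time,
    // an unknown (non-positive) input or filter dimension is propagated as -1.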
    std::vector<int64_t> output_shape({in_dims[0], filter_dims[0]});
    for (size_t i = 0; i < strides.size(); ++i) {
      if ((!ctx->IsRuntime()) &&
          (in_dims[i + 2] <= 0 || filter_dims[i + 2] <= 0)) {
        output_shape.push_back(-1);
      } else {
        output_shape.push_back(ConvOutputSize(in_dims[i + 2],
                                              filter_dims[i + 2], dilations[i],
                                              paddings[i], strides[i]));
      }
    }

    PADDLE_ENFORCE_EQ(
        output_shape[1] % deformable_groups, 0U,
        platform::errors::InvalidArgument(
            "The number of output channels should be divisible by the "
            "number of deformable groups. But received output channels: "
            "[%d], deformable groups: [%d]",
            output_shape[1], deformable_groups));

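    // The consistency checks between the output, offset and mask shapes need
    // concrete dimension values, so they are only performed at runtime.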
    if (ctx->IsRuntime()) {
      PADDLE_ENFORCE_EQ(output_shape[2], offset_dims[2],
                        platform::errors::InvalidArgument(
                            "output height must be equal to offset map "
                            "height. But received [%d] vs [%d]",
                            output_shape[2], offset_dims[2]));
      PADDLE_ENFORCE_EQ(output_shape[3], offset_dims[3],
                        platform::errors::InvalidArgument(
                            "output width must be equal to offset map width. "
                            "But received [%d] vs [%d]",
                            output_shape[3], offset_dims[3]));

      PADDLE_ENFORCE_EQ(offset_dims[1] % (filter_dims[2] * filter_dims[3]), 0U,
                        platform::errors::InvalidArgument(
                            "The number of offset channels must be divisible "
                            "by kernel_h * kernel_w. But received [%d] vs [%d]",
                            offset_dims[1], filter_dims[2] * filter_dims[3]));
      PADDLE_ENFORCE_EQ(
          offset_dims[1] / (2 * filter_dims[2] * filter_dims[3]),
          deformable_groups,
          platform::errors::InvalidArgument(
              "The number of offset channels divided by "
              "(2 * kernel_h * kernel_w) must be equal to deformable groups. "
              "But received [%d] vs [%d]",
              offset_dims[1] / (2 * filter_dims[2] * filter_dims[3]),
              deformable_groups));
      PADDLE_ENFORCE_EQ(output_shape[2], mask_dims[2],
                        platform::errors::InvalidArgument(
                            "output height must be equal to mask map height. "
                            "But received [%d] vs [%d]",
                            output_shape[2], mask_dims[2]));
      PADDLE_ENFORCE_EQ(output_shape[3], mask_dims[3],
                        platform::errors::InvalidArgument(
                            "output width must be equal to mask map width. "
                            "But received [%d] vs [%d]",
                            output_shape[3], mask_dims[3]));

      PADDLE_ENFORCE_EQ(mask_dims[1] % (filter_dims[2] * filter_dims[3]), 0U,
                        platform::errors::InvalidArgument(
                            "The number of mask channels must be divisible by "
                            "kernel_h * kernel_w. But received [%d] vs [%d]",
                            mask_dims[1], filter_dims[2] * filter_dims[3]));
      PADDLE_ENFORCE_EQ(mask_dims[1] / (filter_dims[2] * filter_dims[3]),
                        deformable_groups,
                        platform::errors::InvalidArgument(
                            "The number of mask channels divided by "
                            "(kernel_h * kernel_w) must be equal to "
                            "deformable groups. But received [%d] vs [%d]",
                            mask_dims[1] / (filter_dims[2] * filter_dims[3]),
                            deformable_groups));
    }

    ctx->SetOutputDim("Output", phi::make_ddim(output_shape));
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
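    // The kernel is chosen according to the data type of the "Input" tensor
    // and the current device context.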
    return framework::OpKernelType(
        OperatorWithKernel::IndicateVarDataType(ctx, "Input"),
        ctx.device_context());
  }
};

template <typename T>
class DeformableConvGradOpMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  void Apply(GradOpPtr<T> op) const override {
    op->SetType("deformable_conv_grad");
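    // The gradient kernel needs the forward inputs and the gradient of
    // Output; the forward attributes are forwarded unchanged.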
    op->SetInput("Input", this->Input("Input"));
    op->SetInput("Filter", this->Input("Filter"));
    op->SetInput("Offset", this->Input("Offset"));
    op->SetInput("Mask", this->Input("Mask"));
    op->SetInput(framework::GradVarName("Output"), this->OutputGrad("Output"));

    op->SetOutput(framework::GradVarName("Input"), this->InputGrad("Input"));
    op->SetOutput(framework::GradVarName("Filter"), this->InputGrad("Filter"));
    op->SetOutput(framework::GradVarName("Offset"), this->InputGrad("Offset"));
    op->SetOutput(framework::GradVarName("Mask"), this->InputGrad("Mask"));

    op->SetAttrMap(this->Attrs());
  }
};

class DeformableConvGradOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext *ctx) const override {
    auto in_dims = ctx->GetInputDim("Input");
    auto filter_dims = ctx->GetInputDim("Filter");
    auto offset_dims = ctx->GetInputDim("Offset");
    auto mask_dims = ctx->GetInputDim("Mask");

    OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Output")), "Input",
                   "Output@Grad", "deformable_conv_grad");
    if (ctx->HasOutput(framework::GradVarName("Input"))) {
      ctx->SetOutputDim(framework::GradVarName("Input"), in_dims);
    }
    if (ctx->HasOutput(framework::GradVarName("Filter"))) {
      ctx->SetOutputDim(framework::GradVarName("Filter"), filter_dims);
    }
    if (ctx->HasOutput(framework::GradVarName("Offset"))) {
      ctx->SetOutputDim(framework::GradVarName("Offset"), offset_dims);
    }
    if (ctx->HasOutput(framework::GradVarName("Mask"))) {
      ctx->SetOutputDim(framework::GradVarName("Mask"), mask_dims);
    }
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
    return framework::OpKernelType(
        OperatorWithKernel::IndicateVarDataType(ctx, "Input"),
        ctx.device_context());
  }
};
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
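
// Register the forward op together with its gradient op makers for both the
// static graph (OpDesc) and imperative (OpBase) modes.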
REGISTER_OPERATOR(deformable_conv, ops::DeformableConvOp,
                  ops::DeformableConvOpMaker,
                  ops::DeformableConvGradOpMaker<paddle::framework::OpDesc>,
                  ops::DeformableConvGradOpMaker<paddle::imperative::OpBase>);

REGISTER_OPERATOR(deformable_conv_grad, ops::DeformableConvGradOp);

REGISTER_OP_CPU_KERNEL(deformable_conv_grad,
                       ops::DeformableConvGradCPUKernel<float>,
                       ops::DeformableConvGradCPUKernel<double>);