// deformable_conv_op.h
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Part of the following code in this file refs to
// https://github.com/msracver/Deformable-ConvNets/blob/master/faster_rcnn/operator_cxx/deformable_convolution.cu
//
// Copyright (c) 2017 Microsoft
// Licensed under The Apache-2.0 License [see LICENSE for details]
// \file deformable_psroi_pooling.cu
// \brief
// \author Yi Li, Guodong Zhang, Jifeng Dai

#pragma once
#include <algorithm>
#include <cmath>
#include <vector>

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/deformable_conv_func.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/math_function.h"

namespace paddle {
namespace operators {

// Short aliases for the framework tensor type and the CPU device context that
// the gradient kernel class below uses.
using Tensor = framework::Tensor;
using CPUDeviceContext = platform::CPUDeviceContext;

/// CPU col2im step of modulated deformable convolution: scatters every element
/// of the column-buffer gradient `data_col` onto the input gradient `grad_im`.
///
/// Index layouts, as fixed by the arithmetic in the body:
///  - `data_col`   : flat index
///                   ((((c * kernel_h + i) * kernel_w + j) * batch_size + b)
///                    * height_col + h_out) * width_col + w_out
///  - `data_offset`: per (batch, deformable group) a block of
///                   2 * kernel_h * kernel_w planes of height_col x width_col;
///                   plane 2k is the vertical offset and plane 2k + 1 the
///                   horizontal offset of kernel tap k = i * kernel_w + j
///  - `data_mask`  : per (batch, deformable group) kernel_h * kernel_w planes
///  - `grad_im`    : ((b * channels + c) * height + y) * width + x
///
/// @param num_kernels number of `data_col` elements to process
/// @param channel_per_deformable_group image channels sharing one offset/mask
///        group
/// @param grad_im accumulated into (never overwritten), so the caller is
///        expected to have initialised it
///
/// The interpolation weight comes from `DmcnGetGradientWeight`, which is
/// declared outside this block.
template <typename T>
void ModulatedDeformableCol2imCPUKernel(
    const int num_kernels, const T* data_col, const T* data_offset,
    const T* data_mask, const int channels, const int height, const int width,
    const int kernel_h, const int kernel_w, const int pad_h, const int pad_w,
    const int stride_h, const int stride_w, const int dilation_h,
    const int dilation_w, const int channel_per_deformable_group,
    const int batch_size, const int deformable_group, const int height_col,
    const int width_col, T* grad_im) {
  for (int thread = 0; thread < num_kernels; thread++) {
    // Unflatten `thread`, fastest-varying dimension first.
    const int w_out = thread % width_col;
    int rest = thread / width_col;
    const int h_out = rest % height_col;
    rest /= height_col;
    const int b = rest % batch_size;
    rest /= batch_size;
    const int j = rest % kernel_w;
    rest /= kernel_w;
    const int i = rest % kernel_h;
    const int c = rest / kernel_h;

    const int deformable_group_index = c / channel_per_deformable_group;

    // Top-left corner of the (undeformed) receptive field in the input.
    const int w_in = w_out * stride_w - pad_w;
    const int h_in = h_out * stride_h - pad_h;

    const int group_block = b * deformable_group + deformable_group_index;
    const T* data_offset_ptr =
        data_offset +
        group_block * 2 * kernel_h * kernel_w * height_col * width_col;
    const T* data_mask_ptr =
        data_mask + group_block * kernel_h * kernel_w * height_col * width_col;

    const int tap = i * kernel_w + j;
    const int data_offset_h_ptr =
        ((2 * tap) * height_col + h_out) * width_col + w_out;
    const int data_offset_w_ptr =
        ((2 * tap + 1) * height_col + h_out) * width_col + w_out;
    const int data_mask_hw_ptr = (tap * height_col + h_out) * width_col + w_out;

    const T offset_h = data_offset_ptr[data_offset_h_ptr];
    const T offset_w = data_offset_ptr[data_offset_w_ptr];
    const T mask = data_mask_ptr[data_mask_hw_ptr];

    // Fractional sampling position of this tap in the input image.
    const T cur_inv_h_data = h_in + i * dilation_h + offset_h;
    const T cur_inv_w_data = w_in + j * dilation_w + offset_w;

    const T cur_top_grad = data_col[thread] * mask;
    const int cur_h = static_cast<int>(cur_inv_h_data);
    const int cur_w = static_cast<int>(cur_inv_w_data);

    // Only in-bounds pixels closer than 1 to the sampling position on both
    // axes receive gradient; the 5x5 window around the truncated position is
    // simply the search area for them.
    for (int dy = -2; dy <= 2; dy++) {
      for (int dx = -2; dx <= 2; dx++) {
        const int y = cur_h + dy;
        const int x = cur_w + dx;
        // std::abs guarantees the floating-point overload is used for T.
        if (y >= 0 && y < height && x >= 0 && x < width &&
            std::abs(cur_inv_h_data - y) < 1 &&
            std::abs(cur_inv_w_data - x) < 1) {
          const int cur_bottom_grad_pos =
              ((b * channels + c) * height + y) * width + x;
          const T weight = DmcnGetGradientWeight(cur_inv_h_data, cur_inv_w_data,
                                                 y, x, height, width);

          grad_im[cur_bottom_grad_pos] += weight * cur_top_grad;
        }
      }
    }
  }
}

/// Unpacks shape/attribute vectors and runs
/// `ModulatedDeformableCol2imCPUKernel` over every column-buffer element.
///
/// Vector elements are forwarded as follows:
///  - `im_shape`     [0], [1], [2] -> channels, height, width
///  - `col_shape`    [1], [2], [3] -> batch_size, height_col, width_col;
///                   the product of all four entries is the element count
///  - `kernel_shape` [2], [3]      -> kernel_h, kernel_w
///  - `pad`, `stride`, `dilation` [0], [1] -> vertical, horizontal values
///
/// The vectors are taken by const reference so that a call does not copy
/// them. `ctx` is not used by this function.
///
/// @param grad_im forwarded to the kernel as its output buffer
template <typename T>
static inline void ModulatedDeformableCol2imCPU(
    const platform::CPUDeviceContext& ctx, const T* data_col,
    const T* data_offset, const T* data_mask,
    const std::vector<int64_t>& im_shape,
    const std::vector<int64_t>& col_shape,
    const std::vector<int64_t>& kernel_shape, const std::vector<int>& pad,
    const std::vector<int>& stride, const std::vector<int>& dilation,
    const int deformable_group, T* grad_im) {
  // The kernel works with int indices; the narrowing is made explicit here.
  const int channels = static_cast<int>(im_shape[0]);
  const int height = static_cast<int>(im_shape[1]);
  const int width = static_cast<int>(im_shape[2]);
  const int kernel_h = static_cast<int>(kernel_shape[2]);
  const int kernel_w = static_cast<int>(kernel_shape[3]);
  const int batch_size = static_cast<int>(col_shape[1]);
  const int height_col = static_cast<int>(col_shape[2]);
  const int width_col = static_cast<int>(col_shape[3]);

  const int channel_per_deformable_group =
      static_cast<int>(im_shape[0] / deformable_group);
  const int num_kernels = static_cast<int>(col_shape[0] * col_shape[1] *
                                           col_shape[2] * col_shape[3]);

  ModulatedDeformableCol2imCPUKernel(
      num_kernels, data_col, data_offset, data_mask, channels, height, width,
      kernel_h, kernel_w, pad[0], pad[1], stride[0], stride[1], dilation[0],
      dilation[1], channel_per_deformable_group, batch_size, deformable_group,
      height_col, width_col, grad_im);
}

/// Computes the gradients with respect to the sampling offsets and the
/// modulation mask of a modulated deformable convolution on the CPU.
///
/// One iteration handles one element of `grad_offset`, whose flat index
/// decomposes as ((b * offset_channels + c) * height_col + h) * width_col + w.
/// Within a deformable group, even offset channels are vertical offsets and
/// odd ones horizontal offsets. The mask gradient of a kernel tap is written
/// only by the iteration handling that tap's even (vertical) offset channel.
///
/// @param channels not used by this kernel
/// @param channel_per_deformable_group measured in column-buffer rows; the
///        body divides it by kernel_h * kernel_w to address `data_im`
/// @param offset_channels number of offset channels per batch item
/// @param grad_offset overwritten, one value per iteration
/// @param grad_mask overwritten at the positions described above
///
/// `DmcnIm2colBilinear` and `DmcnGetCoordinateWeight` are declared outside
/// this block.
template <typename T>
void ModulatedDeformableCol2imCoordCPUKernel(
    const int num_kernels, const T* data_col, const T* data_im,
    const T* data_offset, const T* data_mask, const int channels,
    const int height, const int width, const int kernel_h, const int kernel_w,
    const int pad_h, const int pad_w, const int stride_h, const int stride_w,
    const int dilation_h, const int dilation_w,
    const int channel_per_deformable_group, const int batch_size,
    const int offset_channels, const int deformable_group, const int height_col,
    const int width_col, T* grad_offset, T* grad_mask) {
  // Column rows belonging to the same kernel tap are kernel_h * kernel_w apart.
  const int col_step = kernel_h * kernel_w;

  for (int idx = 0; idx < num_kernels; idx++) {
    T val = 0;
    T mval = 0;
    const int w = idx % width_col;
    const int h = (idx / width_col) % height_col;
    const int c = (idx / width_col / height_col) % offset_channels;
    const int b = (idx / width_col / height_col) / offset_channels;

    const int deformable_group_index = c / (2 * kernel_h * kernel_w);
    int cnt = 0;
    const T* data_col_ptr = data_col +
                            deformable_group_index *
                                channel_per_deformable_group * batch_size *
                                width_col * height_col;
    const T* data_im_ptr = data_im +
                           (b * deformable_group + deformable_group_index) *
                               channel_per_deformable_group / kernel_h /
                               kernel_w * height * width;
    const T* data_offset_ptr = data_offset +
                               (b * deformable_group + deformable_group_index) *
                                   2 * kernel_h * kernel_w * height_col *
                                   width_col;
    const T* data_mask_ptr = data_mask +
                             (b * deformable_group + deformable_group_index) *
                                 kernel_h * kernel_w * height_col * width_col;

    // Offset channel relative to the start of this deformable group.
    const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w;
    // 0 for an even (vertical) offset channel, 1 for an odd (horizontal) one.
    // NOTE(review): how DmcnGetCoordinateWeight interprets it is not visible
    // here; confirm against the helper.
    const int bp_dir = offset_c % 2;

    // Walk every column row of this group that uses kernel tap offset_c / 2;
    // `cnt` counts visited rows and selects the matching image plane.
    for (int col_c = offset_c / 2; col_c < channel_per_deformable_group;
         col_c += col_step) {
      const int col_pos =
          (((col_c * batch_size + b) * height_col) + h) * width_col + w;

      // Kernel tap and output position recovered from the column index.
      const int kj = (col_pos / width_col / height_col / batch_size) % kernel_w;
      const int ki =
          (col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h;
      const int w_out = col_pos % width_col;
      const int h_out = (col_pos / width_col) % height_col;
      const int w_in = w_out * stride_w - pad_w;
      const int h_in = h_out * stride_h - pad_h;

      const int tap = ki * kernel_w + kj;
      const int data_offset_h_ptr =
          ((2 * tap) * height_col + h_out) * width_col + w_out;
      const int data_offset_w_ptr =
          ((2 * tap + 1) * height_col + h_out) * width_col + w_out;
      const int data_mask_hw_ptr =
          (tap * height_col + h_out) * width_col + w_out;

      const T offset_h = data_offset_ptr[data_offset_h_ptr];
      const T offset_w = data_offset_ptr[data_offset_w_ptr];
      const T mask = data_mask_ptr[data_mask_hw_ptr];

      T inv_h = h_in + ki * dilation_h + offset_h;
      T inv_w = w_in + kj * dilation_w + offset_w;
      const T* im_plane = data_im_ptr + cnt * height * width;

      if (inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width) {
        // Out-of-range samples add nothing to the mask gradient and are
        // handed to the weight helper as the sentinel position (-2, -2).
        inv_h = inv_w = -2;
      } else {
        mval += data_col_ptr[col_pos] *
                DmcnIm2colBilinear(im_plane, width, height, width, inv_h,
                                   inv_w);
      }
      const T weight = DmcnGetCoordinateWeight(inv_h, inv_w, height, width,
                                               im_plane, width, bp_dir);
      val += weight * data_col_ptr[col_pos] * mask;
      cnt += 1;
    }

    grad_offset[idx] = val;
    if (offset_c % 2 == 0) {
      grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h *
                      kernel_w +
                  offset_c / 2) *
                     height_col +
                 h) *
                    width_col +
                w] = mval;
    }
  }
}

/// Unpacks shape/attribute vectors and runs
/// `ModulatedDeformableCol2imCoordCPUKernel` to fill the offset and mask
/// gradients.
///
/// Vector elements are forwarded as follows:
///  - `im_shape`     [0], [1], [2] -> channels, height, width
///  - `col_shape`    [0] / deformable_groups -> channel_per_deformable_group;
///                   [1], [2], [3] -> batch_size, height_col, width_col
///  - `kernel_shape` [2], [3]      -> kernel_h, kernel_w
///  - `paddings`, `strides`, `dilations` [0], [1] -> vertical, horizontal
///
/// The kernel is given 2 * kernel_h * kernel_w * deformable_groups offset
/// channels, and one iteration per offset element of the step.
///
/// The vectors are taken by const reference so that a call does not copy
/// them. `ctx` is not used by this function.
template <typename T>
static inline void ModulatedDeformableCol2imCoordCPU(
    const platform::CPUDeviceContext& ctx, const T* data_col, const T* data_im,
    const T* data_offset, const T* data_mask,
    const std::vector<int64_t>& im_shape,
    const std::vector<int64_t>& col_shape,
    const std::vector<int64_t>& kernel_shape, const std::vector<int>& paddings,
    const std::vector<int>& strides, const std::vector<int>& dilations,
    const int deformable_groups, T* grad_offset, T* grad_mask) {
  // The kernel works with int indices; the narrowing is made explicit here.
  const int channels = static_cast<int>(im_shape[0]);
  const int height = static_cast<int>(im_shape[1]);
  const int width = static_cast<int>(im_shape[2]);
  const int kernel_h = static_cast<int>(kernel_shape[2]);
  const int kernel_w = static_cast<int>(kernel_shape[3]);
  const int batch_size = static_cast<int>(col_shape[1]);
  const int height_col = static_cast<int>(col_shape[2]);
  const int width_col = static_cast<int>(col_shape[3]);

  const int offset_channels = static_cast<int>(
      2 * kernel_shape[2] * kernel_shape[3] * deformable_groups);
  const int num_kernels = static_cast<int>(
      2 * kernel_shape[2] * kernel_shape[3] * col_shape[1] * col_shape[2] *
      col_shape[3] * deformable_groups);
  const int channel_per_deformable_group =
      static_cast<int>(col_shape[0] / deformable_groups);

  ModulatedDeformableCol2imCoordCPUKernel(
      num_kernels, data_col, data_im, data_offset, data_mask, channels, height,
      width, kernel_h, kernel_w, paddings[0], paddings[1], strides[0],
      strides[1], dilations[0], dilations[1], channel_per_deformable_group,
      batch_size, offset_channels, deformable_groups, height_col, width_col,
      grad_offset, grad_mask);
}

/// CPU im2col step of modulated deformable convolution: for every
/// (channel, batch item, output position) it samples the input at each kernel
/// tap's offset position and stores `sample * mask` in the column buffer.
///
/// Index layouts, as fixed by the arithmetic in the body:
///  - iteration index: ((c_im * batch_size + b) * height_col + h_col)
///                     * width_col + w_col
///  - `data_im`      : (b * num_channels + c_im) * height * width + ...
///  - `data_col`     : row c_im * kernel_h * kernel_w + tap, then batch,
///                     h_col, w_col; consecutive taps are
///                     batch_size * height_col * width_col elements apart
///  - `data_offset` / `data_mask`: per (batch, deformable group) blocks of
///                     2 * kernel_h * kernel_w and kernel_h * kernel_w planes
///
/// Samples whose position is not strictly inside (-1, height) x (-1, width)
/// are written as zero. `DmcnIm2colBilinear` is declared outside this block.
template <typename T>
void ModulatedDeformableIm2colCPUKernel(
    const int num_kernels, const T* data_im, const T* data_offset,
    const T* data_mask, const int height, const int width, const int kernel_h,
    const int kernel_w, const int pad_h, const int pad_w, const int stride_h,
    const int stride_w, const int dilation_h, const int dilation_w,
    const int channel_per_deformable_group, const int batch_size,
    const int num_channels, const int deformable_group, const int height_col,
    const int width_col, T* data_col) {
  // Distance in `data_col` between two consecutive kernel taps.
  const int tap_stride = batch_size * height_col * width_col;

  for (int idx = 0; idx < num_kernels; idx++) {
    const int w_col = idx % width_col;
    const int h_col = (idx / width_col) % height_col;
    const int b_col = (idx / width_col) / height_col % batch_size;
    const int c_im = (idx / width_col / height_col) / batch_size;
    const int c_col = c_im * kernel_h * kernel_w;

    const int deformable_group_index = c_im / channel_per_deformable_group;

    // Top-left corner of the (undeformed) receptive field in the input.
    const int h_in = h_col * stride_h - pad_h;
    const int w_in = w_col * stride_w - pad_w;

    T* data_col_ptr =
        data_col +
        ((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col;
    const T* data_im_ptr =
        data_im + (b_col * num_channels + c_im) * height * width;
    const T* data_offset_ptr =
        data_offset +
        (b_col * deformable_group + deformable_group_index) * 2 * kernel_h *
            kernel_w * height_col * width_col;
    const T* data_mask_ptr =
        data_mask +
        (b_col * deformable_group + deformable_group_index) * kernel_h *
            kernel_w * height_col * width_col;

    for (int ki = 0; ki < kernel_h; ++ki) {
      for (int kj = 0; kj < kernel_w; ++kj) {
        const int tap = ki * kernel_w + kj;
        const int data_offset_h_ptr =
            ((2 * tap) * height_col + h_col) * width_col + w_col;
        const int data_offset_w_ptr =
            ((2 * tap + 1) * height_col + h_col) * width_col + w_col;
        const int data_mask_hw_ptr =
            (tap * height_col + h_col) * width_col + w_col;

        const T offset_h = data_offset_ptr[data_offset_h_ptr];
        const T offset_w = data_offset_ptr[data_offset_w_ptr];
        const T mask = data_mask_ptr[data_mask_hw_ptr];

        // Fractional sampling position of this tap in the input image.
        const T h_im = h_in + ki * dilation_h + offset_h;
        const T w_im = w_in + kj * dilation_w + offset_w;

        T val = static_cast<T>(0);
        if (h_im > -1 && w_im > -1 && h_im < height && w_im < width) {
          val =
              DmcnIm2colBilinear(data_im_ptr, width, height, width, h_im, w_im);
        }
        *data_col_ptr = val * mask;
        data_col_ptr += tap_stride;
      }
    }
  }
}

/// Unpacks shape/attribute vectors and runs
/// `ModulatedDeformableIm2colCPUKernel`, which fills `data_col` with the
/// bilinearly sampled, mask-scaled input columns.
///
/// Vector elements are forwarded as follows:
///  - `im_shape`     [0], [1], [2] -> num_channels, height, width
///  - `col_shape`    [1], [2], [3] -> batch_size, height_col, width_col
///  - `filter_shape` [2], [3]      -> kernel_h, kernel_w
///  - `paddings`, `strides`, `dilations` [0], [1] -> vertical, horizontal
///
/// The kernel runs once per (input channel, batch item, output position).
///
/// The vectors are taken by const reference so that a call does not copy
/// them. `ctx` is not used by this function.
template <typename T>
static inline void ModulatedDeformableIm2colCPU(
    const platform::CPUDeviceContext& ctx, const T* data_im,
    const T* data_offset, const T* data_mask,
    const std::vector<int64_t>& im_shape,
    const std::vector<int64_t>& col_shape,
    const std::vector<int64_t>& filter_shape, const std::vector<int>& paddings,
    const std::vector<int>& strides, const std::vector<int>& dilations,
    const int deformable_groups, T* data_col) {
  // The kernel works with int indices; the narrowing is made explicit here.
  const int num_channels = static_cast<int>(im_shape[0]);
  const int height = static_cast<int>(im_shape[1]);
  const int width = static_cast<int>(im_shape[2]);
  const int kernel_h = static_cast<int>(filter_shape[2]);
  const int kernel_w = static_cast<int>(filter_shape[3]);
  const int batch_size = static_cast<int>(col_shape[1]);
  const int height_col = static_cast<int>(col_shape[2]);
  const int width_col = static_cast<int>(col_shape[3]);

  const int channel_per_deformable_group =
      static_cast<int>(im_shape[0] / deformable_groups);
  const int num_kernels = static_cast<int>(im_shape[0] * col_shape[1] *
                                           col_shape[2] * col_shape[3]);

  // get outputs of im2col with offset by bilinear interpolation
  ModulatedDeformableIm2colCPUKernel(
      num_kernels, data_im, data_offset, data_mask, height, width, kernel_h,
      kernel_w, paddings[0], paddings[1], strides[0], strides[1], dilations[0],
      dilations[1], channel_per_deformable_group, batch_size, num_channels,
      deformable_groups, height_col, width_col, data_col);
}

/// Accumulates a partial filter gradient into the running total:
/// `filter_grad[k] += dweight_3d[k]` for every k in [0, nthreads).
///
/// @param nthreads   number of elements to accumulate; values <= 0 do nothing
/// @param n          unused, kept so existing call sites keep compiling
/// @param height     unused, kept so existing call sites keep compiling
/// @param width      unused, kept so existing call sites keep compiling
/// @param dweight_3d partial gradient, at least `nthreads` elements
/// @param filter_grad running total, at least `nthreads` elements; updated in
///                    place
template <typename T>
void FilterGradAddupCPUKernel(const int nthreads, const int n, const int height,
                              const int width, const T* dweight_3d,
                              T* filter_grad) {
  for (int k = 0; k < nthreads; ++k) {
    filter_grad[k] += dweight_3d[k];
  }
}

/// CPU backward kernel of the modulated deformable convolution operator.
///
/// The batch is processed in `batch_size / im2col_step` steps. Each step:
///  1. computes the column-buffer gradient as filter^T * output_grad, one
///     matrix product per convolution group;
///  2. if both the Offset and Mask gradients are requested, derives them from
///     the column-buffer gradient (ModulatedDeformableCol2imCoordCPU);
///  3. if the Input gradient is requested, scatters the column-buffer gradient
///     onto it (ModulatedDeformableCol2imCPU);
///  4. overwrites the column buffer with the deformably sampled input columns
///     (ModulatedDeformableIm2colCPU);
///  5. if the Filter gradient is requested, adds output_grad * columns^T to it.
///
/// Offset and Mask gradients are only produced when both outputs are present.
template <typename T>
class DeformableConvGradCPUKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    const Tensor* output_grad =
        ctx.Input<Tensor>(framework::GradVarName("Output"));
    Tensor* input_grad = ctx.Output<Tensor>(framework::GradVarName("Input"));
    Tensor* filter_grad = ctx.Output<Tensor>(framework::GradVarName("Filter"));
    Tensor* offset_grad = ctx.Output<Tensor>(framework::GradVarName("Offset"));
    Tensor* mask_grad = ctx.Output<Tensor>(framework::GradVarName("Mask"));

    const Tensor* input = ctx.Input<Tensor>("Input");
    Tensor offset = *ctx.Input<Tensor>("Offset");
    Tensor mask = *ctx.Input<Tensor>("Mask");
    Tensor filter = *ctx.Input<Tensor>("Filter");
    // Nothing to compute when no gradient output is requested.
    if (!input_grad && !filter_grad && !offset_grad && !mask_grad) return;

    int groups = ctx.Attr<int>("groups");
    int deformable_groups = ctx.Attr<int>("deformable_groups");
    int im2col_step = ctx.Attr<int>("im2col_step");
    std::vector<int> strides = ctx.Attr<std::vector<int>>("strides");
    std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings");
    std::vector<int> dilations = ctx.Attr<std::vector<int>>("dilations");

    auto& dev_ctx = ctx.template device_context<CPUDeviceContext>();
    const int batch_size = static_cast<int>(input->dims()[0]);

    // Per-sample input shape (input dims without the leading batch dim).
    framework::DDim input_shape =
        phi::slice_ddim(input->dims(), 1, input->dims().size());
    std::vector<int64_t> input_shape_vec = phi::vectorize(input_shape);
    std::vector<int64_t> filter_shape_vec(phi::vectorize(filter.dims()));
    std::vector<int64_t> output_shape_vec(phi::vectorize(output_grad->dims()));

    // Column buffer shape:
    // [in_channels * kernel_h * kernel_w, im2col_step, out dims from index 2].
    std::vector<int64_t> col_buffer_shape_vec(filter_shape_vec.size());
    col_buffer_shape_vec[0] =
        input->dims()[1] * filter.dims()[2] * filter.dims()[3];
    col_buffer_shape_vec[1] = im2col_step;
    for (size_t j = 0; j < filter_shape_vec.size() - 2; ++j) {
      col_buffer_shape_vec[j + 2] = output_shape_vec[j + 2];
    }
    framework::DDim col_shape(phi::make_ddim(col_buffer_shape_vec));

    Tensor col_buffer =
        ctx.AllocateTmpTensor<T, CPUDeviceContext>(col_shape, dev_ctx);
    // output_buffer only aliases output_grad's storage, so no temporary tensor
    // is allocated for it.
    Tensor output_buffer;
    output_buffer.ShareDataWith(*output_grad);

    int64_t M =
        input_shape_vec[0] / groups * filter_shape_vec[2] * filter_shape_vec[3];
    int64_t N = im2col_step * output_shape_vec[2] * output_shape_vec[3];
    int64_t K = output_shape_vec[1] / groups;

    framework::DDim weight_3d_shape = {groups, K, M};
    framework::DDim out_grad_4d_shape = {batch_size / im2col_step, groups, K,
                                         N};
    framework::DDim col_buffer_3d_shape = {groups, M, N};
    framework::DDim filter_grad_shape = {groups, K, M};

    // Reshaped views that share storage with the tensors they are built from.
    Tensor weight_3d;
    weight_3d.ShareDataWith(filter).Resize(weight_3d_shape);
    Tensor out_grad_4d;
    out_grad_4d.ShareDataWith(output_buffer).Resize(out_grad_4d_shape);
    Tensor col_buffer_3d;
    col_buffer_3d.ShareDataWith(col_buffer).Resize(col_buffer_3d_shape);

    phi::funcs::SetConstant<CPUDeviceContext, T> set_zero;
    auto blas = phi::funcs::GetBlas<CPUDeviceContext, T>(dev_ctx);

    col_buffer.mutable_data<T>(ctx.GetPlace());
    col_buffer_3d.mutable_data<T>(ctx.GetPlace());
    out_grad_4d.mutable_data<T>(ctx.GetPlace());

    // Elements per batch item; kept in 64 bits so that the pointer offsets
    // computed from them below cannot overflow int.
    const int64_t input_dim = input->numel() / input->dims()[0];
    const int64_t input_offset_dim = offset.numel() / offset.dims()[0];
    const int64_t input_mask_dim = mask.numel() / mask.dims()[0];

    // Scratch tensor for one step's filter gradient. Every slice is fully
    // overwritten by MatMul (beta = 0) each step, so one allocation suffices.
    Tensor dweight_3d;
    if (filter_grad) {
      filter_grad->mutable_data<T>(ctx.GetPlace());
      filter_grad->Resize(filter_grad_shape);
      set_zero(dev_ctx, filter_grad, static_cast<T>(0));
      dweight_3d = ctx.AllocateTmpTensor<T, CPUDeviceContext>(filter_grad_shape,
                                                              dev_ctx);
    }

    if (input_grad) {
      input_grad->mutable_data<T>(ctx.GetPlace());
      set_zero(dev_ctx, input_grad, static_cast<T>(0));
    }

    if (offset_grad && mask_grad) {
      offset_grad->mutable_data<T>(ctx.GetPlace());
      mask_grad->mutable_data<T>(ctx.GetPlace());
      set_zero(dev_ctx, offset_grad, static_cast<T>(0));
      set_zero(dev_ctx, mask_grad, static_cast<T>(0));
    }

    const T* input_ptr = input->data<T>();
    const T* offset_ptr = offset.data<T>();
    const T* mask_ptr = mask.data<T>();

    // NOTE(review): integer division drops a trailing partial step; this
    // presumably relies on batch_size being a multiple of im2col_step being
    // enforced elsewhere - confirm.
    const int num_steps = batch_size / im2col_step;
    for (int i = 0; i < num_steps; ++i) {
      Tensor out_grad_3d = out_grad_4d.Slice(i, i + 1).Resize(
          phi::slice_ddim(out_grad_4d.dims(), 1, out_grad_4d.dims().size()));

      // Column-buffer gradient: weight^T * output_grad, per group.
      for (int g = 0; g < groups; ++g) {
        Tensor weight_3d_slice = weight_3d.Slice(g, g + 1).Resize(
            phi::slice_ddim(weight_3d.dims(), 1, weight_3d.dims().size()));
        Tensor out_grad_3d_slice = out_grad_3d.Slice(g, g + 1).Resize(
            phi::slice_ddim(out_grad_3d.dims(), 1, out_grad_3d.dims().size()));
        Tensor col_buffer_3d_slice =
            col_buffer_3d.Slice(g, g + 1).Resize(phi::slice_ddim(
                col_buffer_3d.dims(), 1, col_buffer_3d.dims().size()));

        blas.MatMul(weight_3d_slice, true, out_grad_3d_slice, false, T(1.0),
                    &col_buffer_3d_slice, T(0.0));
      }
      col_buffer.Resize(col_shape);

      T* col_buffer_ptr = col_buffer.data<T>();

      // Start of this step's batch items inside each per-batch tensor.
      const int64_t first_item = static_cast<int64_t>(i) * im2col_step;
      const T* input_step = input_ptr + first_item * input_dim;
      const T* offset_step = offset_ptr + first_item * input_offset_dim;
      const T* mask_step = mask_ptr + first_item * input_mask_dim;

      if (mask_grad && offset_grad) {
        T* offset_grad_ptr = offset_grad->data<T>();
        T* mask_grad_ptr = mask_grad->data<T>();
        // get grad of offset and mask
        ModulatedDeformableCol2imCoordCPU(
            dev_ctx, col_buffer_ptr, input_step, offset_step, mask_step,
            input_shape_vec, col_buffer_shape_vec, filter_shape_vec, paddings,
            strides, dilations, deformable_groups,
            offset_grad_ptr + first_item * input_offset_dim,
            mask_grad_ptr + first_item * input_mask_dim);
      }
      if (input_grad) {
        T* input_grad_ptr = input_grad->data<T>();
        // get grad of input
        ModulatedDeformableCol2imCPU(
            dev_ctx, col_buffer_ptr, offset_step, mask_step, input_shape_vec,
            col_buffer_shape_vec, filter_shape_vec, paddings, strides,
            dilations, deformable_groups,
            input_grad_ptr + first_item * input_dim);
        input_grad->Resize(input->dims());
      }

      // From here on the column buffer holds the sampled input columns, not
      // the gradient computed above.
      ModulatedDeformableIm2colCPU(
          dev_ctx, input_step, offset_step, mask_step, input_shape_vec,
          col_buffer_shape_vec, filter_shape_vec, paddings, strides, dilations,
          deformable_groups, col_buffer_ptr);

      col_buffer_3d.Resize(col_buffer_3d_shape);

      if (filter_grad) {
        for (int g = 0; g < groups; ++g) {
          Tensor out_grad_3d_slice =
              out_grad_3d.Slice(g, g + 1).Resize(phi::slice_ddim(
                  out_grad_3d.dims(), 1, out_grad_3d.dims().size()));
          Tensor col_buffer_3d_slice =
              col_buffer_3d.Slice(g, g + 1).Resize(phi::slice_ddim(
                  col_buffer_3d.dims(), 1, col_buffer_3d.dims().size()));
          Tensor dweight_3d_slice = dweight_3d.Slice(g, g + 1).Resize(
              phi::slice_ddim(dweight_3d.dims(), 1, dweight_3d.dims().size()));

          blas.MatMul(out_grad_3d_slice, false, col_buffer_3d_slice, true,
                      T(1.0), &dweight_3d_slice, T(0.0));
        }
        // update grad of weights
        FilterGradAddupCPUKernel(dweight_3d.numel(), groups, K, M,
                                 dweight_3d.data<T>(), filter_grad->data<T>());
      }
    }
    if (filter_grad) {
      // Restore the filter gradient's original (non-grouped) shape.
      filter_grad->Resize(filter.dims());
    }
  }
};

}  // namespace operators
}  // namespace paddle