/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

15
#include <algorithm>
A
Abhinav Arora 已提交
16
#include <vector>
17
#include "cub/cub.cuh"
Y
Yi Wang 已提交
18
#include "paddle/fluid/operators/math/depthwise_conv.h"
19
#include "paddle/fluid/platform/cuda_device_function.h"
D
dzhwinter 已提交
20
#include "paddle/fluid/platform/cuda_primitives.h"
Z
zlx 已提交
21 22 23 24 25

namespace paddle {
namespace operators {
namespace math {

26
// Sums `value` over the lanes of the calling warp with cub::WarpReduce, then
// lets lane 0 (the only lane that holds the valid aggregate) add the warp total
// to *sum with a single atomic, instead of issuing one atomic per thread.
// NOTE(review): CUB expects every lane of the warp to participate and documents
// TempStorage as a shared-memory allocation; here it is a per-thread local,
// which presumably relies on the shuffle-based specialization — confirm for
// the targeted architectures.
template <typename T>
__device__ __inline__ void CudaAtomicAddWithWarp(T* sum, T value) {
  using WarpReduceT = cub::WarpReduce<T>;
  typename WarpReduceT::TempStorage reduce_storage;
  const T warp_total = WarpReduceT(reduce_storage).Sum(value);
  if (cub::LaneId() == 0) {
    platform::CudaAtomicAdd(sum, warp_total);
  }
}

34 35 36 37 38 39 40 41
// Parameter list shared by the forward depthwise-convolution kernels below.
// data_layout defaults to NCHW; the kernels index any other layout as NHWC.
#define ARG_DEFINE_KernelDepthwiseConv                                   \
  const T *const input_data, const T *const filter_data,                \
      const int batch_size, const int output_channels,                  \
      const int output_height, const int output_width,                  \
      const int input_channels, const int input_height,                 \
      const int input_width, const int filter_multiplier,               \
      const int filter_height, const int filter_width,                  \
      const int stride_height, const int stride_width,                  \
      const int padding_height, const int padding_width,                \
      const int dilate_height, const int dilate_width,                  \
      T *const output_data, const DataLayout data_layout = DataLayout::kNCHW
44

45 46
// Generic (runtime filter size) body of the depthwise convolution forward pass.
// Launch layout assumed by the index math: blockIdx.x is the output channel
// (gridDim.x is used as the output channel count), blockIdx.y is the batch
// index, and the block's threads stride over the output plane (threadIdx.x
// over width, threadIdx.y over height). Output channel c_out reads input
// channel c_out / filter_multiplier. data_layout == kNCHW selects NCHW
// indexing; any other value is indexed as NHWC. Taps that fall outside the
// input (padding) contribute nothing. When fuse_relu_before_conv is true, the
// input is clamped at zero before it is multiplied by the weight.
template <typename T, bool fuse_relu_before_conv>
__device__ __inline__ void KernelDepthwiseConv(ARG_DEFINE_KernelDepthwiseConv) {
  const int batch = blockIdx.y;
  const int c_out = blockIdx.x;
  const int c_in = c_out / filter_multiplier;
  const bool is_nchw = (data_layout == DataLayout::kNCHW);
  // Row-major filter_height x filter_width taps of this output channel.
  const T* weight = filter_data + c_out * filter_height * filter_width;
  // Offset of this sample (and, for NCHW, of input channel c_in).
  const int in_offset =
      is_nchw ? ((batch * input_channels + c_in) * input_height) * input_width
              : batch * input_height * input_width * input_channels;

  for (int w_out = threadIdx.x; w_out < output_width; w_out += blockDim.x) {
    for (int h_out = threadIdx.y; h_out < output_height; h_out += blockDim.y) {
      // First input coordinate of the receptive field and its exclusive end.
      const int h_first = h_out * stride_height - padding_height;
      const int w_first = w_out * stride_width - padding_width;
      const int h_stop = h_first + filter_height * dilate_height;
      const int w_stop = w_first + filter_width * dilate_width;

      T value = 0;
      int tap = 0;  // advanced for every tap, whether in range or not
      for (int h_in = h_first; h_in < h_stop; h_in += dilate_height) {
        const bool row_inside = h_in >= 0 && h_in < input_height;
        for (int w_in = w_first; w_in < w_stop; w_in += dilate_width, ++tap) {
          if (!row_inside || w_in < 0 || w_in >= input_width) continue;
          const int offset =
              is_nchw ? in_offset + h_in * input_width + w_in
                      : in_offset + (h_in * input_width + w_in) * input_channels +
                            c_in;
          if (fuse_relu_before_conv) {
            value += weight[tap] * max(0.0f, input_data[offset]);
          } else {
            value += weight[tap] * input_data[offset];
          }
        }
      }

      const int index =
          is_nchw ? ((batch * gridDim.x + c_out) * output_height + h_out) *
                            output_width +
                        w_out
                  : ((batch * output_height + h_out) * output_width + w_out) *
                            gridDim.x +
                        c_out;
      output_data[index] = value;
    }
  }
}
110

111
// Forward depthwise convolution specialized for a square c_filter x c_filter
// kernel whose size is a compile-time constant. Launch layout and semantics
// match KernelDepthwiseConv: blockIdx.x is the output channel (gridDim.x is
// used as the output channel count), blockIdx.y the batch index, and the
// block's threads stride over the output plane. The runtime filter_height /
// filter_width arguments are ignored; c_filter replaces both. Each thread
// first copies the c_filter * c_filter taps of its output channel into a
// local array.
template <typename T, int c_filter, bool fuse_relu_before_conv>
__device__ __inline__ void KernelDepthwiseConvCFilter(
    ARG_DEFINE_KernelDepthwiseConv) {
  const int kWeightSize = c_filter * c_filter;
  T r_weight[kWeightSize];
  const int batch = blockIdx.y;
  const int c_out = blockIdx.x;
  const int c_in = c_out / filter_multiplier;
  const T* weight = filter_data + c_out * kWeightSize;
#pragma unroll
  for (int i = 0; i < kWeightSize; i++) r_weight[i] = weight[i];

  const bool is_nchw = (data_layout == DataLayout::kNCHW);
  // Offset of this sample (and, for NCHW, of input channel c_in).
  const int in_offset =
      is_nchw ? ((batch * input_channels + c_in) * input_height) * input_width
              : batch * input_height * input_width * input_channels;

  for (int w_out = threadIdx.x; w_out < output_width; w_out += blockDim.x) {
    for (int h_out = threadIdx.y; h_out < output_height; h_out += blockDim.y) {
      const int h_in_start = -padding_height + h_out * stride_height;
      const int w_in_start = -padding_width + w_out * stride_width;
      T value = 0;
#pragma unroll
      for (int h_in = h_in_start, h_f = 0; h_f < c_filter;
           h_in += dilate_height, h_f++) {
#pragma unroll
        for (int w_in = w_in_start, w_f = 0; w_f < c_filter;
             w_in += dilate_width, w_f++) {
          // Taps that land in the padding contribute nothing.
          if (h_in >= 0 && h_in < input_height && w_in >= 0 &&
              w_in < input_width) {
            const int offset =
                is_nchw ? in_offset + h_in * input_width + w_in
                        : in_offset +
                              (h_in * input_width + w_in) * input_channels +
                              c_in;
            if (fuse_relu_before_conv) {
              value += r_weight[h_f * c_filter + w_f] *
                       max(0.0f, input_data[offset]);
            } else {
              value += r_weight[h_f * c_filter + w_f] * input_data[offset];
            }
          }
        }
      }

      const int index =
          is_nchw ? ((batch * gridDim.x + c_out) * output_height + h_out) *
                            output_width +
                        w_out
                  : ((batch * output_height + h_out) * output_width + w_out) *
                            gridDim.x +
                        c_out;
      output_data[index] = value;
    }
  }
}

183 184
// Forward entry kernel; the template parameters select a specialization.
//  - c_filter_multiplier == 0: nothing is specialized, and the runtime
//    filter_multiplier and strides are forwarded as-is.
//  - c_filter_multiplier != 0: the multiplier and a square stride c_stride are
//    compile-time constants that replace the runtime ones. The host dispatch
//    only picks such an instantiation when they equal the runtime values.
//  - c_filter == -1 uses the generic kernel (runtime filter size); any other
//    value uses the square c_filter x c_filter specialization.
// filter_height and filter_width are always forwarded separately. The host
// dispatch accepts non-square kernels for c_filter == -1, so the generic
// kernel must see the real width.
template <typename T, int c_filter_multiplier, int c_stride, int c_filter,
          bool fuse_relu_before_conv>
__global__ void KernelDepthwiseConvSp(ARG_DEFINE_KernelDepthwiseConv) {
  const bool specialized = (c_filter_multiplier != 0);
  const int multiplier = specialized ? c_filter_multiplier : filter_multiplier;
  const int stride_h = specialized ? c_stride : stride_height;
  const int stride_w = specialized ? c_stride : stride_width;

  if (c_filter == -1) {
    KernelDepthwiseConv<T, fuse_relu_before_conv>(
        input_data, filter_data, batch_size, output_channels, output_height,
        output_width, input_channels, input_height, input_width, multiplier,
        filter_height, filter_width, stride_h, stride_w, padding_height,
        padding_width, dilate_height, dilate_width, output_data, data_layout);
  } else {
    KernelDepthwiseConvCFilter<T, c_filter, fuse_relu_before_conv>(
        input_data, filter_data, batch_size, output_channels, output_height,
        output_width, input_channels, input_height, input_width, multiplier,
        filter_height, filter_width, stride_h, stride_w, padding_height,
        padding_width, dilate_height, dilate_width, output_data, data_layout);
  }
}

Z
zlx 已提交
219
// Parameter list shared by the kernels that compute the depthwise convolution
// backprop w.r.t. the input. data_layout defaults to NCHW; the kernels index
// any other layout as NHWC.
#define ARG_DEFINE_KernelDepthwiseConvInputGrad                           \
  const T *const input_data, const T *const output_grad_data,            \
      const T *const filter_data, const int batch_size,                  \
      const int output_channels, const int output_height,                \
      const int output_width, const int input_channels,                  \
      const int input_height, const int input_width,                     \
      const int filter_multiplier, const int filter_height,              \
      const int filter_width, const int stride_height,                   \
      const int stride_width, const int padding_height,                  \
      const int padding_width, const int dilate_height,                  \
      const int dilate_width, T *const input_grad_data,                  \
      const DataLayout data_layout = DataLayout::kNCHW
232

233
// Generic body of the depthwise convolution backprop w.r.t. the input.
// blockIdx.x is the input channel (gridDim.x is used as the input channel
// count), blockIdx.y the batch index, and the threads stride over the input
// plane (threadIdx.x over width, threadIdx.y over height). Each input element
// accumulates output_grad * weight over the filter_multiplier output channels
// fed by its channel and over every filter tap that maps onto it. Each
// channel's filter is walked from its last tap to its first. With
// fuse_relu_before_conv, elements whose forward input is <= 0 get a zero
// gradient.
template <typename T, bool fuse_relu_before_conv>
__device__ __inline__ void KernelDepthwiseConvInputGrad(
    ARG_DEFINE_KernelDepthwiseConvInputGrad) {
  const int batch = blockIdx.y;
  const int c_in = blockIdx.x;
  const int c_out_begin = c_in * filter_multiplier;
  const int taps_per_channel = filter_height * filter_width;
  const bool is_nchw = (data_layout == DataLayout::kNCHW);

  for (int w_in = threadIdx.x; w_in < input_width; w_in += blockDim.x) {
    for (int h_in = threadIdx.y; h_in < input_height; h_in += blockDim.y) {
      const int index =
          is_nchw
              ? ((batch * gridDim.x + c_in) * input_height + h_in) *
                        input_width +
                    w_in
              : ((batch * input_height + h_in) * input_width + w_in) *
                        gridDim.x +
                    c_in;

      if (fuse_relu_before_conv && input_data[index] <= 0) {
        input_grad_data[index] = 0;
        continue;
      }

      // Inclusive range of pre-stride output coordinates whose receptive
      // field can cover (h_in, w_in).
      const int h_lo =
          h_in - (filter_height - 1) * dilate_height + padding_height;
      const int h_hi = h_in + padding_height;
      const int w_lo = w_in - (filter_width - 1) * dilate_width + padding_width;
      const int w_hi = w_in + padding_width;

      T value = 0;
      for (int c_out = c_out_begin; c_out < c_out_begin + filter_multiplier;
           c_out++) {
        // Starts one past this channel's last tap; decremented before use.
        const T* tap = filter_data + (c_out + 1) * taps_per_channel;
        for (int h_out = h_lo; h_out <= h_hi; h_out += dilate_height) {
          for (int w_out = w_lo; w_out <= w_hi; w_out += dilate_width) {
            --tap;
            if (h_out % stride_height != 0 || w_out % stride_width != 0) {
              continue;
            }
            const int s_h_out = h_out / stride_height;
            const int s_w_out = w_out / stride_width;
            if (s_h_out < 0 || s_h_out >= output_height || s_w_out < 0 ||
                s_w_out >= output_width) {
              continue;
            }
            const int grad_offset =
                is_nchw ? ((batch * output_channels + c_out) * output_height +
                           s_h_out) *
                                  output_width +
                              s_w_out
                        : ((batch * output_height + s_h_out) * output_width +
                           s_w_out) *
                                  output_channels +
                              c_out;
            value += output_grad_data[grad_offset] * (*tap);
          }
        }
      }
      input_grad_data[index] = value;
    }
  }
}

310 311
// Backprop w.r.t. the input for a square c_filter x c_filter kernel and a
// compile-time filter multiplier. The launch layout matches
// KernelDepthwiseConvInputGrad: blockIdx.x is the input channel, blockIdx.y the
// batch index, and the threads stride over the input plane. The taps of the
// c_filter_multiplier output channels fed by this input channel are cached in a
// local array, each channel's taps stored in reverse order, so the flipped tap
// (h_f, w_f) of channel c_i is r_weight[c_i * c_filter^2 + h_f * c_filter + w_f].
// Every loop that touches r_weight is bounded by the compile-time
// c_filter_multiplier the array is sized for. The runtime filter_multiplier
// argument is expected to equal it (KernelDepthwiseConvInputGradSp passes
// c_filter_multiplier for it).
template <typename T, int c_filter, int c_filter_multiplier,
          bool fuse_relu_before_conv>
__device__ __inline__ void KernelDepthwiseConvInputGradCFilter(
    ARG_DEFINE_KernelDepthwiseConvInputGrad) {
  const int kFilterSize = c_filter * c_filter;
  // The "+ 1" keeps the array non-empty for instantiations that are compiled
  // but never executed (e.g. c_filter = -1 with c_filter_multiplier = 0).
  const int kWeightSize = kFilterSize * c_filter_multiplier + 1;
  T r_weight[kWeightSize];
  const int batch = blockIdx.y;
  const int c_in = blockIdx.x;

#pragma unroll
  for (int c_i = 0; c_i < c_filter_multiplier; c_i++) {
    const int c_out = c_in * c_filter_multiplier + c_i;
    const T* weight = filter_data + c_out * kFilterSize;
#pragma unroll
    for (int i = 0; i < kFilterSize; i++)
      r_weight[i + c_i * kFilterSize] = weight[kFilterSize - i - 1];
  }

  const bool is_nchw = (data_layout == DataLayout::kNCHW);
  for (int w_in = threadIdx.x; w_in < input_width; w_in += blockDim.x) {
    for (int h_in = threadIdx.y; h_in < input_height; h_in += blockDim.y) {
      const int index =
          is_nchw
              ? ((batch * gridDim.x + c_in) * input_height + h_in) *
                        input_width +
                    w_in
              : ((batch * input_height + h_in) * input_width + w_in) *
                        gridDim.x +
                    c_in;

      if (fuse_relu_before_conv) {
        if (input_data[index] <= 0) {
          input_grad_data[index] = 0;
          continue;
        }
      }

      // First pre-stride output coordinate whose receptive field can cover
      // (h_in, w_in).
      const int h_out_start =
          h_in - (c_filter - 1) * dilate_height + padding_height;
      const int w_out_start =
          w_in - (c_filter - 1) * dilate_width + padding_width;

      T value = 0;
#pragma unroll
      for (int c_i = 0; c_i < c_filter_multiplier; c_i++) {
        const int c_out = c_in * c_filter_multiplier + c_i;
#pragma unroll
        for (int h_out = h_out_start, h_f = 0; h_f < c_filter;
             h_out += dilate_height, h_f++) {
#pragma unroll
          for (int w_out = w_out_start, w_f = 0; w_f < c_filter;
               w_out += dilate_width, w_f++) {
            const int s_h_out = h_out / stride_height;
            const int s_w_out = w_out / stride_width;
            if (h_out % stride_height == 0 && w_out % stride_width == 0 &&
                s_h_out >= 0 && s_h_out < output_height && s_w_out >= 0 &&
                s_w_out < output_width) {
              const int output_grad_offset =
                  is_nchw ? ((batch * output_channels + c_out) * output_height +
                             s_h_out) *
                                    output_width +
                                s_w_out
                          : ((batch * output_height + s_h_out) * output_width +
                             s_w_out) *
                                    output_channels +
                                c_out;
              value += output_grad_data[output_grad_offset] *
                       r_weight[h_f * c_filter + w_f + c_i * kFilterSize];
            }
          }
        }
      }
      input_grad_data[index] = value;
    }
  }
}

391 392
// Input-gradient entry kernel; the template parameters pick the implementation.
// c_filter_multiplier == 0 runs the generic kernel with every runtime argument.
// Otherwise the multiplier and the square stride come from the template.
// c_filter == -1 then selects the generic runtime-size kernel, and any other
// value selects the square c_filter x c_filter specialization.
template <typename T, int c_filter_multiplier, int c_stride, int c_filter,
          bool fuse_relu_before_conv>
__global__ void KernelDepthwiseConvInputGradSp(
    ARG_DEFINE_KernelDepthwiseConvInputGrad) {
  if (c_filter_multiplier == 0) {
    KernelDepthwiseConvInputGrad<T, fuse_relu_before_conv>(
        input_data, output_grad_data, filter_data, batch_size, output_channels,
        output_height, output_width, input_channels, input_height, input_width,
        filter_multiplier, filter_height, filter_width, stride_height,
        stride_width, padding_height, padding_width, dilate_height,
        dilate_width, input_grad_data, data_layout);
    return;
  }

  if (c_filter == -1) {
    KernelDepthwiseConvInputGrad<T, fuse_relu_before_conv>(
        input_data, output_grad_data, filter_data, batch_size, output_channels,
        output_height, output_width, input_channels, input_height, input_width,
        c_filter_multiplier, filter_height, filter_width, c_stride, c_stride,
        padding_height, padding_width, dilate_height, dilate_width,
        input_grad_data, data_layout);
    return;
  }

  KernelDepthwiseConvInputGradCFilter<T, c_filter, c_filter_multiplier,
                                      fuse_relu_before_conv>(
      input_data, output_grad_data, filter_data, batch_size, output_channels,
      output_height, output_width, input_channels, input_height, input_width,
      c_filter_multiplier, filter_height, filter_width, c_stride, c_stride,
      padding_height, padding_width, dilate_height, dilate_width,
      input_grad_data, data_layout);
}

419
// CUDA kernel body for the depthwise convolution backprop w.r.t. the filter.
// Each block accumulates the gradient of one filter tap. blockIdx.z is the
// output channel (gridDim.z is used as the output channel count), blockIdx.y
// and blockIdx.x are the tap row and column. The result goes to
// filter_grad_data at the flattened block id ((z * gridDim.y) + y) * gridDim.x + x.
// NOTE(review): this presumes a launch with grid(filter_width, filter_height,
// output_channels); the launch is not visible here, so confirm against it.
// Threads stride over the output plane (x over width, y over height) for every
// sample, each building a partial sum of output_grad * input. The partial sums
// are combined per warp and atomically added into the tap's gradient.
// NOTE(review): the warp reduction presumably needs
// blockDim.x * blockDim.y to be a multiple of 32.
// With fuse_relu_before_conv the input is clamped at zero first.
// data_layout == kNCHW selects NCHW indexing; any other value is indexed as NHWC.
template <typename T, bool fuse_relu_before_conv>
__device__ __inline__ void KernelDepthwiseConvFilterGrad(
    const T* output_grad_data, const T* input_data, const int num,
    const int output_channels, const int output_height, const int output_width,
    const int input_channels, const int input_height, const int input_width,
    const int filter_multiplier, const int filter_height,
    const int filter_width, const int stride_height, const int stride_width,
    const int padding_height, const int padding_width, const int dilate_height,
    const int dilate_width, T* filter_grad_data,
    const DataLayout data_layout = DataLayout::kNCHW) {
  T s = 0;

  const int gbid =
      ((blockIdx.z * gridDim.y) + blockIdx.y) * gridDim.x + blockIdx.x;
  const int kernel_id = blockIdx.z;
  // Position of this tap relative to the un-padded receptive-field origin.
  const int kernel_h = blockIdx.y * dilate_height - padding_height;
  const int kernel_w = blockIdx.x * dilate_width - padding_width;
  // Channel counts are derived from the grid rather than the arguments.
  const int out_channels = gridDim.z;
  const int in_channels = gridDim.z / filter_multiplier;
  const int c_in = kernel_id / filter_multiplier;
  const bool is_nchw = (data_layout == DataLayout::kNCHW);

  // Index arithmetic is written out locally instead of through function-local
  // #defines, so no helper macro leaks into the rest of the translation unit.
  for (int image_w = threadIdx.x; image_w < output_width;
       image_w += blockDim.x) {
    for (int bid = 0; bid < num; bid++) {
      for (int image_h = threadIdx.y; image_h < output_height;
           image_h += blockDim.y) {
        const int image_hk = image_h * stride_height + kernel_h;
        const int image_wk = image_w * stride_width + kernel_w;
        if (image_hk < 0 || image_hk >= input_height) continue;
        if (image_wk < 0 || image_wk >= input_width) continue;

        int output_grad_id;
        int input_id;
        if (is_nchw) {
          output_grad_id =
              ((bid * out_channels + kernel_id) * output_height + image_h) *
                  output_width +
              image_w;
          input_id =
              ((bid * in_channels + c_in) * input_height + image_hk) *
                  input_width +
              image_wk;
        } else {
          output_grad_id =
              ((bid * output_height + image_h) * output_width + image_w) *
                  out_channels +
              kernel_id;
          input_id =
              ((bid * input_height + image_hk) * input_width + image_wk) *
                  in_channels +
              c_in;
        }

        if (fuse_relu_before_conv) {
          s += output_grad_data[output_grad_id] *
               max(0.0f, input_data[input_id]);
        } else {
          s += output_grad_data[output_grad_id] * input_data[input_id];
        }
      }
    }
  }
  CudaAtomicAddWithWarp(&filter_grad_data[gbid], s);
}

487
// Filter-gradient entry kernel. A non-zero c_filter_multiplier replaces the
// runtime filter_multiplier with a compile-time constant; zero means "not
// specialized". Every other argument is forwarded unchanged.
template <typename T, int c_filter_multiplier, bool fuse_relu_before_conv>
__global__ void KernelDepthwiseConvFilterGradSp(
    const T* output_grad_data, const T* input_data, const int num,
    const int output_channels, const int output_height, const int output_width,
    const int input_channels, const int input_height, const int input_width,
    const int filter_multiplier, const int filter_height,
    const int filter_width, const int stride_height, const int stride_width,
    const int padding_height, const int padding_width, const int dilate_height,
    const int dilate_width, T* filter_grad_data,
    const DataLayout data_layout = DataLayout::kNCHW) {
  const int multiplier =
      (c_filter_multiplier == 0) ? filter_multiplier : c_filter_multiplier;
  KernelDepthwiseConvFilterGrad<T, fuse_relu_before_conv>(
      output_grad_data, input_data, num, output_channels, output_height,
      output_width, input_channels, input_height, input_width, multiplier,
      filter_height, filter_width, stride_height, stride_width, padding_height,
      padding_width, dilate_height, dilate_width, filter_grad_data,
      data_layout);
}

/*
 * CUDA launcher for the depthwise convolution forward pass.
 * Input and output are indexed as NCHW when data_layout is DataLayout::kNCHW
 * and as NHWC otherwise. The kernel size is always read from filter.dims()[2]
 * (height) and filter.dims()[3] (width). strides, paddings and dilations each
 * hold two elements: height first, then width.
 */
template <class T, bool fuse_relu_before_conv>
class DepthwiseConvFunctor<platform::CUDADeviceContext, T,
                           fuse_relu_before_conv> {
 public:
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& input,
                  const framework::Tensor& filter,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings,
                  const std::vector<int>& dilations, framework::Tensor* output,
                  const DataLayout data_layout = DataLayout::kNCHW) {
    // Axis positions of channels, height and width in the activation shapes.
    const bool is_nchw = (data_layout == DataLayout::kNCHW);
    const int c_axis = is_nchw ? 1 : 3;
    const int h_axis = is_nchw ? 2 : 1;
    const int w_axis = is_nchw ? 3 : 2;

    const int batch_size = input.dims()[0];
    const int input_channels = input.dims()[c_axis];
    const int input_height = input.dims()[h_axis];
    const int input_width = input.dims()[w_axis];
    const int output_channels = output->dims()[c_axis];
    const int output_height = output->dims()[h_axis];
    const int output_width = output->dims()[w_axis];
    const int ksize_height = filter.dims()[2];
    const int ksize_width = filter.dims()[3];
    const int stride_height = strides[0];
    const int stride_width = strides[1];
    const int padding_height = paddings[0];
    const int padding_width = paddings[1];
    const int dilate_height = dilations[0];
    const int dilate_width = dilations[1];
    const int filter_multiplier = output_channels / input_channels;

    const T* input_data = input.data<T>();
    const T* filter_data = filter.data<T>();
    T* output_data = output->mutable_data<T>(context.GetPlace());

    // Block x-size: the whole output row for widths <= 1024, half of it
    // (rounded up) for widths in (1024, 2048], and 512 for wider rows.
    // Block y-size: as many rows as fit in `thread`, clamped to
    // [1, output_height]. One block per (output channel, sample).
    int thread = 512;
    if (output_width > 512 && output_width <= 1024) {
      thread = output_width;
    } else if (output_width > 1024 && output_width <= 2048) {
      thread = (output_width - 1) / 2 + 1;
    }
    const int rows_per_block =
        std::min(std::max(thread / output_width, 1), output_height);
    dim3 threads(std::min(output_width, thread), rows_per_block, 1);
    dim3 grid(output_channels, batch_size, 1);

// Launches the given specialization and returns when the runtime
// configuration matches it. c_filter_multiplier == 0 always matches, so that
// case must stay last as the unspecialized fallback.
#define check_case(c_filter_multiplier, c_stride, c_filter)                  \
  if ((c_filter_multiplier == 0) ||                                         \
      (filter_multiplier == c_filter_multiplier &&                          \
       stride_height == stride_width && stride_height == c_stride &&        \
       ((ksize_height == ksize_width && ksize_height == c_filter) ||        \
        c_filter == -1))) {                                                 \
    KernelDepthwiseConvSp<                                                  \
        T, c_filter_multiplier, c_stride, c_filter,                         \
        fuse_relu_before_conv><<<grid, threads, 0, context.stream()>>>(     \
        input_data, filter_data, batch_size, output_channels,               \
        output_height, output_width, input_channels, input_height,          \
        input_width, filter_multiplier, ksize_height, ksize_width,          \
        stride_height, stride_width, padding_height, padding_width,         \
        dilate_height, dilate_width, output_data, data_layout);             \
    return;                                                                 \
  }
    // Specializations: (multiplier, stride, square filter size); -1 = any size.
    check_case(1, 1, 3);
    check_case(1, 1, 5);
    check_case(1, 1, -1);
    check_case(1, 2, 3);
    check_case(1, 2, 5);
    check_case(1, 2, -1);
    check_case(2, 1, 3);
    check_case(2, 1, 5);
    check_case(2, 1, -1);
    check_case(2, 2, 3);
    check_case(2, 2, 5);
    check_case(2, 2, -1);
    // Fallback for every other configuration; add cases above if needed,
    // e.g. check_case(2^n, 1).
    check_case(0, 0, -1);
#undef check_case
  }
};

602 603 604
// CUDA launcher for the depthwise convolution backprop w.r.t. the input.
// Activations are indexed as NCHW when data_layout is DataLayout::kNCHW and as
// NHWC otherwise. The kernel size is read from filter.dims()[2] and [3].
// strides, paddings and dilations hold (height, width).
template <typename T, bool fuse_relu_before_conv>
class DepthwiseConvInputGradFunctor<platform::CUDADeviceContext, T,
                                    fuse_relu_before_conv> {
 public:
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& input,
                  const framework::Tensor& filter,
                  const framework::Tensor& output_grad,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings,
                  const std::vector<int>& dilations,
                  framework::Tensor* input_grad,
                  const DataLayout data_layout = DataLayout::kNCHW) {
    // Axis positions of channels, height and width in the activation shapes.
    const bool is_nchw = (data_layout == DataLayout::kNCHW);
    const int c_axis = is_nchw ? 1 : 3;
    const int h_axis = is_nchw ? 2 : 1;
    const int w_axis = is_nchw ? 3 : 2;

    const int batch_size = input.dims()[0];
    const int input_channels = input.dims()[c_axis];
    const int input_height = input.dims()[h_axis];
    const int input_width = input.dims()[w_axis];
    const int output_channels = output_grad.dims()[c_axis];
    const int output_height = output_grad.dims()[h_axis];
    const int output_width = output_grad.dims()[w_axis];
    const int ksize_height = filter.dims()[2];
    const int ksize_width = filter.dims()[3];
    const int stride_height = strides[0];
    const int stride_width = strides[1];
    const int padding_height = paddings[0];
    const int padding_width = paddings[1];
    const int dilate_height = dilations[0];
    const int dilate_width = dilations[1];
    const int filter_multiplier = output_channels / input_channels;

    const T* input_data = input.data<T>();
    const T* filter_data = filter.data<T>();
    const T* output_grad_data = output_grad.data<T>();
    T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace());

    // Block x-size: the whole input row for widths <= 1024, half of it
    // (rounded up) for widths in (1024, 2048], and 512 for wider rows.
    // Block y-size: as many rows as fit in `thread`, clamped to
    // [1, input_height]. One block per (input channel, sample).
    int thread = 512;
    if (input_width > 512 && input_width <= 1024) {
      thread = input_width;
    } else if (input_width > 1024 && input_width <= 2048) {
      thread = (input_width - 1) / 2 + 1;
    }
    const int rows_per_block =
        std::min(std::max(thread / input_width, 1), input_height);
    dim3 threads(std::min(input_width, thread), rows_per_block, 1);
    dim3 grid(input_channels, batch_size, 1);

// Launches the given specialization and returns when the runtime
// configuration matches it. c_filter_multiplier == 0 always matches, so that
// case must stay last as the unspecialized fallback.
#define check_case(c_filter_multiplier, c_stride, c_filter)             \
  if ((c_filter_multiplier == 0) ||                                    \
      (filter_multiplier == c_filter_multiplier &&                     \
       stride_height == stride_width && stride_height == c_stride &&   \
       ((ksize_height == ksize_width && ksize_height == c_filter) ||   \
        c_filter == -1))) {                                            \
    KernelDepthwiseConvInputGradSp<                                    \
        T, c_filter_multiplier, c_stride, c_filter,                    \
        fuse_relu_before_conv><<<grid, threads, 0, context.stream()>>>( \
        input_data, output_grad_data, filter_data, batch_size,         \
        output_channels, output_height, output_width, input_channels,  \
        input_height, input_width, filter_multiplier, ksize_height,    \
        ksize_width, stride_height, stride_width, padding_height,      \
        padding_width, dilate_height, dilate_width, input_grad_data,   \
        data_layout);                                                  \
    return;                                                            \
  }
    // Specializations: (multiplier, stride, square filter size); -1 = any size.
    check_case(1, 1, 3);
    check_case(1, 1, 5);
    check_case(1, 1, -1);
    check_case(1, 2, 3);
    check_case(1, 2, 5);
    check_case(1, 2, -1);
    check_case(2, 1, 3);
    check_case(2, 1, 5);
    check_case(2, 1, -1);
    check_case(2, 2, 3);
    check_case(2, 2, 5);
    check_case(2, 2, -1);
    // Fallback for every other configuration; add cases above if needed,
    // e.g. check_case(2^n, 1).
    check_case(0, 0, -1);
#undef check_case
  }
};

691 692 693
template <typename T, bool fuse_relu_before_conv>
class DepthwiseConvFilterGradFunctor<platform::CUDADeviceContext, T,
                                     fuse_relu_before_conv> {
 public:
  // Launches KernelDepthwiseConvFilterGradSp on `context.stream()` to compute
  // the depthwise-convolution filter gradient into `filter_grad`.
  //
  // `input` and `output_grad` are interpreted in `data_layout` (NCHW, or NHWC
  // for any other value); the kernel height/width are read from dims 2 and 3
  // of `filter_grad`. `strides`, `paddings` and `dilations` are consumed as
  // {height, width} pairs. The grid holds one block per
  // (filter column, filter row, output channel) triple.
  //
  // NOTE(review): whether `filter_grad` must be zero-filled by the caller
  // depends on how the kernel writes its result -- confirm against the
  // KernelDepthwiseConvFilterGradSp definition.
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& input,
                  const framework::Tensor& output_grad,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings,
                  const std::vector<int>& dilations,
                  framework::Tensor* filter_grad,
                  const DataLayout data_layout = DataLayout::kNCHW) {
    // Channel / height / width axis positions: NCHW -> (1, 2, 3),
    // otherwise NHWC -> (3, 1, 2). Batch is always axis 0.
    const bool channel_first = (data_layout == DataLayout::kNCHW);
    const int c_axis = channel_first ? 1 : 3;
    const int h_axis = channel_first ? 2 : 1;
    const int w_axis = channel_first ? 3 : 2;

    const auto& in_dims = input.dims();
    const auto& dout_dims = output_grad.dims();
    const auto& dfilter_dims = filter_grad->dims();

    const int batch_size = in_dims[0];
    const int input_channels = in_dims[c_axis];
    const int input_height = in_dims[h_axis];
    const int input_width = in_dims[w_axis];
    const int output_channels = dout_dims[c_axis];
    const int output_height = dout_dims[h_axis];
    const int output_width = dout_dims[w_axis];
    const int ksize_height = dfilter_dims[2];
    const int ksize_width = dfilter_dims[3];

    const int stride_height = strides[0], stride_width = strides[1];
    const int padding_height = paddings[0], padding_width = paddings[1];
    const int dilate_height = dilations[0], dilate_width = dilations[1];

    const T* input_data = input.data<T>();
    const T* output_grad_data = output_grad.data<T>();
    T* filter_grad_data = filter_grad->mutable_data<T>(context.GetPlace());

    // Thread-block x extent: one thread per output column for widths in
    // (512, 1024], ceil(width / 2) threads for widths in (1024, 2048],
    // and 512 otherwise.
    const int block_size =
        (output_width > 512 && output_width <= 1024)
            ? output_width
            : (output_width > 1024 && output_width <= 2048)
                  ? (output_width - 1) / 2 + 1
                  : 512;
    // Number of output rows handled by one block: as many as fit into
    // `block_size` threads, clamped to [1, output_height].
    const int rows_per_block =
        std::min(std::max(block_size / output_width, 1), output_height);
    const int filter_multiplier = output_channels / input_channels;

    const dim3 grid(ksize_width, ksize_height, output_channels);
    const dim3 threads(std::min(output_width, block_size), rows_per_block, 1);

    // A filter multiplier of exactly 1 gets its own kernel instantiation;
    // template argument 0 is the fallback used for every other value.
    if (filter_multiplier == 1) {
      KernelDepthwiseConvFilterGradSp<T, 1, fuse_relu_before_conv>
          <<<grid, threads, 0, context.stream()>>>(
              output_grad_data, input_data, batch_size, output_channels,
              output_height, output_width, input_channels, input_height,
              input_width, filter_multiplier, ksize_height, ksize_width,
              stride_height, stride_width, padding_height, padding_width,
              dilate_height, dilate_width, filter_grad_data, data_layout);
    } else {
      KernelDepthwiseConvFilterGradSp<T, 0, fuse_relu_before_conv>
          <<<grid, threads, 0, context.stream()>>>(
              output_grad_data, input_data, batch_size, output_channels,
              output_height, output_width, input_channels, input_height,
              input_width, filter_multiplier, ksize_height, ksize_width,
              stride_height, stride_width, padding_height, padding_width,
              dilate_height, dilate_width, filter_grad_data, data_layout);
    }
  }
};

761 762
// Explicitly instantiate the forward, input-gradient and filter-gradient CUDA
// functors for float and double, each with fuse_relu_before_conv = false and
// fuse_relu_before_conv = true.
#define INSTANTIATE_DEPTHWISE_CONV_CUDA_FUNCTORS(T, FUSE_RELU)                 \
  template class DepthwiseConvFunctor<platform::CUDADeviceContext, T,          \
                                      FUSE_RELU>;                              \
  template class DepthwiseConvInputGradFunctor<platform::CUDADeviceContext, T, \
                                               FUSE_RELU>;                     \
  template class DepthwiseConvFilterGradFunctor<platform::CUDADeviceContext,   \
                                                T, FUSE_RELU>

INSTANTIATE_DEPTHWISE_CONV_CUDA_FUNCTORS(float, false);
INSTANTIATE_DEPTHWISE_CONV_CUDA_FUNCTORS(double, false);
INSTANTIATE_DEPTHWISE_CONV_CUDA_FUNCTORS(float, true);
INSTANTIATE_DEPTHWISE_CONV_CUDA_FUNCTORS(double, true);

#undef INSTANTIATE_DEPTHWISE_CONV_CUDA_FUNCTORS
Z
zlx 已提交
786 787 788 789

}  // namespace math
}  // namespace operators
}  // namespace paddle