depthwise_conv.cu 29.4 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

15
#include <algorithm>
A
Abhinav Arora 已提交
16
#include <vector>
17
#include "cub/cub.cuh"
Y
Yi Wang 已提交
18
#include "paddle/fluid/operators/math/depthwise_conv.h"
19
#include "paddle/fluid/platform/cuda_device_function.h"
D
dzhwinter 已提交
20
#include "paddle/fluid/platform/cuda_primitives.h"
Z
zlx 已提交
21 22 23 24 25

namespace paddle {
namespace operators {
namespace math {

// Adds `value` from every calling thread into `*sum` with one atomic per warp:
// the warp's values are first combined with cub::WarpReduce, then lane 0 issues
// platform::CudaAtomicAdd with the warp total.
//
// Preconditions (they hold for the caller in this file,
// KernelDepthwiseConvFilterGrad): every thread of the block calls this function
// exactly once, and all threads of a warp pass the same `sum`.
//
// The block does not need to be a multiple of 32 threads (e.g. a 7x7 block).
// The reduction is limited to the lanes that actually exist in the last warp,
// so no value is taken from a non-existent lane.
template <typename T>
__device__ __inline__ void CudaAtomicAddWithWarp(T* sum, T value) {
  typedef cub::WarpReduce<T> WarpReduce;
  const int kWarpSize = 32;
  // A block has at most 1024 threads, i.e. at most 32 warps. CUB expects
  // WarpReduce::TempStorage in shared memory, one instance per warp.
  const int kMaxWarpsPerBlock = 1024 / kWarpSize;
  __shared__ typename WarpReduce::TempStorage temp_storage[kMaxWarpsPerBlock];

  // Warps are formed from consecutive linear thread indices.
  const int tid = static_cast<int>(
      threadIdx.x + blockDim.x * (threadIdx.y + blockDim.y * threadIdx.z));
  const int block_threads =
      static_cast<int>(blockDim.x * blockDim.y * blockDim.z);
  const int warp_id = tid / kWarpSize;
  const int valid_lanes = min(kWarpSize, block_threads - warp_id * kWarpSize);

  value = WarpReduce(temp_storage[warp_id]).Sum(value, valid_lanes);
  if (cub::LaneId() == 0) platform::CudaAtomicAdd(sum, value);
}

// Parameter list shared by the forward-pass kernels below
// (KernelDepthwiseConv, KernelDepthwiseConvCFilter and the
// KernelDepthwiseConvSp dispatcher). Tensors are NCHW; `output_data`
// receives the convolution result.
#define ARG_DEFINE_KernelDepthwiseConv                                       \
  const T *const input_data, const T *const filter_data,                     \
      const int batch_size, const int output_channels,                       \
      const int output_height, const int output_width,                       \
      const int input_channels, const int input_height,                      \
      const int input_width, const int filter_multiplier,                    \
      const int filter_height, const int filter_width,                       \
      const int stride_height, const int stride_width,                       \
      const int padding_height, const int padding_width,                     \
      const int dilate_height, const int dilate_width,                       \
      T *const output_data

// A Cuda kernel to compute the depthwise convolution forward pass
// in NCHW format.
//
// Generic variant: the filter size is a runtime value. Launch layout (see
// DepthwiseConvFunctor):
//   - blockIdx.x selects the output channel, blockIdx.y the batch image;
//   - the block's threads step over the output plane, threadIdx.x along the
//     width and threadIdx.y along the height.
// Output channel c_out reads input channel c_out / filter_multiplier. Taps that
// fall outside the input contribute nothing (zero padding). With
// fuse_relu_before_conv, each input element x is replaced by max(0, x) before
// it is multiplied with the filter.
template <typename T, bool fuse_relu_before_conv>
__device__ __inline__ void KernelDepthwiseConv(ARG_DEFINE_KernelDepthwiseConv) {
  // These depend only on the block, so compute them once.
  const int batch = blockIdx.y;
  const int c_out = blockIdx.x;
  const int c_in = c_out / filter_multiplier;
  const T* weight = filter_data + c_out * filter_height * filter_width;
  const int in_offset =
      ((batch * input_channels + c_in) * input_height) * input_width;

  for (int w_out = threadIdx.x; w_out < output_width; w_out += blockDim.x) {
    for (int h_out = threadIdx.y; h_out < output_height; h_out += blockDim.y) {
      const int h_in_start = -padding_height + h_out * stride_height;
      const int w_in_start = -padding_width + w_out * stride_width;

      T value = 0;
      for (int h_f = 0; h_f < filter_height; h_f++) {
        const int h_in = h_in_start + h_f * dilate_height;
        if (h_in < 0 || h_in >= input_height) continue;  // padded row
        for (int w_f = 0; w_f < filter_width; w_f++) {
          const int w_in = w_in_start + w_f * dilate_width;
          if (w_in < 0 || w_in >= input_width) continue;  // padded column
          T in_val = input_data[in_offset + h_in * input_width + w_in];
          if (fuse_relu_before_conv) in_val = max(static_cast<T>(0), in_val);
          value += weight[h_f * filter_width + w_f] * in_val;
        }
      }

      const int index =
          ((batch * output_channels + c_out) * output_height + h_out) *
              output_width +
          w_out;
      output_data[index] = value;
    }
  }
}

// Forward-pass kernel body specialized for a square c_filter x c_filter filter
// known at compile time. The block's filter taps are copied once into the
// per-thread array r_weight and then reused for every output pixel.
// Launch layout, zero padding and fuse_relu_before_conv behave as in
// KernelDepthwiseConv. filter_height / filter_width are ignored in favour of
// c_filter.
template <typename T, int c_filter, bool fuse_relu_before_conv>
__device__ __inline__ void KernelDepthwiseConvCFilter(
    ARG_DEFINE_KernelDepthwiseConv) {
  const int kWeightSize = c_filter * c_filter;
  T r_weight[kWeightSize];
  const int batch = blockIdx.y;
  const int c_out = blockIdx.x;
  const T* weight = filter_data + c_out * kWeightSize;
#pragma unroll
  for (int i = 0; i < kWeightSize; i++) r_weight[i] = weight[i];

  const int c_in = c_out / filter_multiplier;
  const int in_offset =
      ((batch * input_channels + c_in) * input_height) * input_width;

  for (int w_out = threadIdx.x; w_out < output_width; w_out += blockDim.x) {
    for (int h_out = threadIdx.y; h_out < output_height; h_out += blockDim.y) {
      const int h_in_start = -padding_height + h_out * stride_height;
      const int w_in_start = -padding_width + w_out * stride_width;

      T value = 0;
      // Fixed trip counts: full unrolling lets r_weight be indexed with
      // compile-time constants.
#pragma unroll
      for (int h_f = 0; h_f < c_filter; h_f++) {
        const int h_in = h_in_start + h_f * dilate_height;
#pragma unroll
        for (int w_f = 0; w_f < c_filter; w_f++) {
          const int w_in = w_in_start + w_f * dilate_width;
          if (h_in >= 0 && h_in < input_height && w_in >= 0 &&
              w_in < input_width) {
            T in_val = input_data[in_offset + h_in * input_width + w_in];
            if (fuse_relu_before_conv) {
              in_val = max(static_cast<T>(0), in_val);
            }
            value += r_weight[h_f * c_filter + w_f] * in_val;
          }
        }
      }

      const int index =
          ((batch * output_channels + c_out) * output_height + h_out) *
              output_width +
          w_out;
      output_data[index] = value;
    }
  }
}

// Forward-pass kernel entry point. The template parameters choose a
// specialization at compile time:
//   c_filter_multiplier == 0 : filter_multiplier and the strides are taken from
//                              the runtime arguments;
//   c_filter_multiplier != 0 : c_filter_multiplier and c_stride (used for both
//                              directions) replace the runtime values;
//   c_filter == -1           : generic filter size (KernelDepthwiseConv);
//   otherwise                : square c_filter x c_filter filter
//                              (KernelDepthwiseConvCFilter).
template <typename T, int c_filter_multiplier, int c_stride, int c_filter,
          bool fuse_relu_before_conv>
__global__ void KernelDepthwiseConvSp(ARG_DEFINE_KernelDepthwiseConv) {
  if (c_filter_multiplier == 0) {
    if (c_filter == -1) {
      KernelDepthwiseConv<T, fuse_relu_before_conv>(
          input_data, filter_data, batch_size, output_channels, output_height,
          output_width, input_channels, input_height, input_width,
          filter_multiplier, filter_height, filter_width, stride_height,
          stride_width, padding_height, padding_width, dilate_height,
          dilate_width, output_data);
    } else {
      KernelDepthwiseConvCFilter<T, c_filter, fuse_relu_before_conv>(
          input_data, filter_data, batch_size, output_channels, output_height,
          output_width, input_channels, input_height, input_width,
          filter_multiplier, filter_height, filter_width, stride_height,
          stride_width, padding_height, padding_width, dilate_height,
          dilate_width, output_data);
    }
  } else {
    // filter_width must be forwarded as the width. The dispatcher also sends
    // non-square filters to the c_filter == -1 path, and passing filter_height
    // for both dimensions would convolve them with the wrong width.
    if (c_filter == -1) {
      KernelDepthwiseConv<T, fuse_relu_before_conv>(
          input_data, filter_data, batch_size, output_channels, output_height,
          output_width, input_channels, input_height, input_width,
          c_filter_multiplier, filter_height, filter_width, c_stride, c_stride,
          padding_height, padding_width, dilate_height, dilate_width,
          output_data);
    } else {
      KernelDepthwiseConvCFilter<T, c_filter, fuse_relu_before_conv>(
          input_data, filter_data, batch_size, output_channels, output_height,
          output_width, input_channels, input_height, input_width,
          c_filter_multiplier, filter_height, filter_width, c_stride, c_stride,
          padding_height, padding_width, dilate_height, dilate_width,
          output_data);
    }
  }
}

// Parameter list shared by the kernels below that compute the depthwise
// convolution backprop w.r.t. the input (KernelDepthwiseConvInputGrad,
// KernelDepthwiseConvInputGradCFilter and KernelDepthwiseConvInputGradSp).
// Those kernels read input_data only when fuse_relu_before_conv is set;
// input_grad_data receives the result.
#define ARG_DEFINE_KernelDepthwiseConvInputGrad                              \
  const T *const input_data, const T *const output_grad_data,                \
      const T *const filter_data, const int batch_size,                      \
      const int output_channels, const int output_height,                    \
      const int output_width, const int input_channels,                      \
      const int input_height, const int input_width,                         \
      const int filter_multiplier, const int filter_height,                  \
      const int filter_width, const int stride_height,                       \
      const int stride_width, const int padding_height,                      \
      const int padding_width, const int dilate_height,                      \
      const int dilate_width, T *const input_grad_data

// Generic kernel body for the depthwise convolution backprop w.r.t. the input
// (runtime filter size), NCHW layout.
//
// Launch layout (see DepthwiseConvInputGradFunctor):
//   - blockIdx.x selects the input channel, blockIdx.y the batch image;
//   - threadIdx.x / threadIdx.y step over the input width / height.
// Each input pixel gathers output_grad * weight over the filter_multiplier
// output channels fed by its channel, walking the filter taps in reverse order.
// With fuse_relu_before_conv, the gradient is 0 wherever the forward input was
// <= 0.
template <typename T, bool fuse_relu_before_conv>
__device__ __inline__ void KernelDepthwiseConvInputGrad(
    ARG_DEFINE_KernelDepthwiseConvInputGrad) {
  const int batch = blockIdx.y;
  const int c_in = blockIdx.x;
  const int c_out_start = c_in * filter_multiplier;

  for (int w_in = threadIdx.x; w_in < input_width; w_in += blockDim.x) {
    for (int h_in = threadIdx.y; h_in < input_height; h_in += blockDim.y) {
      const int index =
          ((batch * input_channels + c_in) * input_height + h_in) *
              input_width +
          w_in;
      if (fuse_relu_before_conv && input_data[index] <= 0) {
        input_grad_data[index] = 0;
        continue;
      }

      // h_out / w_out below are output coordinates multiplied by the stride
      // (a tap only contributes when they divide evenly). The first position
      // visited corresponds to the last filter row / column.
      const int h_out_start =
          h_in - (filter_height - 1) * dilate_height + padding_height;
      const int h_out_end = h_in + padding_height;
      const int w_out_start =
          w_in - (filter_width - 1) * dilate_width + padding_width;
      const int w_out_end = w_in + padding_width;

      T value = 0;
      for (int c_out = c_out_start; c_out < c_out_start + filter_multiplier;
           c_out++) {
        // Start one past this channel's last tap and step backwards.
        int filter_offset = (c_out + 1) * filter_height * filter_width;
        for (int h_out = h_out_start; h_out <= h_out_end;
             h_out += dilate_height) {
          for (int w_out = w_out_start; w_out <= w_out_end;
               w_out += dilate_width) {
            filter_offset--;
            const int s_h_out = h_out / stride_height;
            const int s_w_out = w_out / stride_width;
            if (h_out % stride_height == 0 && w_out % stride_width == 0 &&
                s_h_out >= 0 && s_h_out < output_height && s_w_out >= 0 &&
                s_w_out < output_width) {
              const int output_grad_offset =
                  ((batch * output_channels + c_out) * output_height +
                   s_h_out) *
                      output_width +
                  s_w_out;
              value += output_grad_data[output_grad_offset] *
                       filter_data[filter_offset];
            }
          }
        }
      }
      input_grad_data[index] = value;
    }
  }
}

// Kernel body for the backprop w.r.t. the input, specialized for a square
// c_filter x c_filter filter and a compile-time channel multiplier
// c_filter_multiplier. The taps of all c_filter_multiplier output channels fed
// by this block's input channel are copied once, in reversed order, into the
// per-thread array r_weight. Launch layout and fuse_relu_before_conv behave as
// in KernelDepthwiseConvInputGrad.
template <typename T, int c_filter, int c_filter_multiplier,
          bool fuse_relu_before_conv>
__device__ __inline__ void KernelDepthwiseConvInputGradCFilter(
    ARG_DEFINE_KernelDepthwiseConvInputGrad) {
  const int kFilterSize = c_filter * c_filter;
  // "+ 1" keeps the array size positive when c_filter_multiplier == 0:
  // KernelDepthwiseConvInputGradSp instantiates this template for
  // check_case(0, 0, -1) even though that branch never runs.
  const int kWeightSize = kFilterSize * c_filter_multiplier + 1;
  T r_weight[kWeightSize];
  const int batch = blockIdx.y;
  const int c_in = blockIdx.x;

  // The channel loops use the compile-time c_filter_multiplier, the value
  // KernelDepthwiseConvInputGradSp passes as filter_multiplier. This keeps
  // every access inside r_weight and allows full unrolling.
#pragma unroll
  for (int c_i = 0; c_i < c_filter_multiplier; c_i++) {
    const int c_out = c_in * c_filter_multiplier + c_i;
    const T* weight = filter_data + c_out * kFilterSize;
#pragma unroll
    for (int i = 0; i < kFilterSize; i++)
      r_weight[i + c_i * kFilterSize] = weight[kFilterSize - i - 1];
  }

  for (int w_in = threadIdx.x; w_in < input_width; w_in += blockDim.x) {
    for (int h_in = threadIdx.y; h_in < input_height; h_in += blockDim.y) {
      const int index =
          ((batch * input_channels + c_in) * input_height + h_in) *
              input_width +
          w_in;
      if (fuse_relu_before_conv && input_data[index] <= 0) {
        input_grad_data[index] = 0;
        continue;
      }

      // Stride-scaled output coordinates of the first (last-tap) position.
      const int h_out_start =
          h_in - (c_filter - 1) * dilate_height + padding_height;
      const int w_out_start =
          w_in - (c_filter - 1) * dilate_width + padding_width;

      T value = 0;
#pragma unroll
      for (int c_i = 0; c_i < c_filter_multiplier; c_i++) {
        const int c_out = c_in * c_filter_multiplier + c_i;
#pragma unroll
        for (int h_f = 0; h_f < c_filter; h_f++) {
          const int h_out = h_out_start + h_f * dilate_height;
#pragma unroll
          for (int w_f = 0; w_f < c_filter; w_f++) {
            const int w_out = w_out_start + w_f * dilate_width;
            const int s_h_out = h_out / stride_height;
            const int s_w_out = w_out / stride_width;
            if (h_out % stride_height == 0 && w_out % stride_width == 0 &&
                s_h_out >= 0 && s_h_out < output_height && s_w_out >= 0 &&
                s_w_out < output_width) {
              const int output_grad_offset =
                  ((batch * output_channels + c_out) * output_height +
                   s_h_out) *
                      output_width +
                  s_w_out;
              value += output_grad_data[output_grad_offset] *
                       r_weight[h_f * c_filter + w_f + c_i * kFilterSize];
            }
          }
        }
      }
      input_grad_data[index] = value;
    }
  }
}

// Entry point for the backprop w.r.t. the input. The implementation is chosen
// from the template parameters:
//   - c_filter_multiplier != 0 and c_filter != -1: the fully specialized
//     KernelDepthwiseConvInputGradCFilter;
//   - otherwise the generic KernelDepthwiseConvInputGrad. It uses
//     c_filter_multiplier / c_stride when c_filter_multiplier != 0, and the
//     runtime filter_multiplier / strides when it is 0.
template <typename T, int c_filter_multiplier, int c_stride, int c_filter,
          bool fuse_relu_before_conv>
__global__ void KernelDepthwiseConvInputGradSp(
    ARG_DEFINE_KernelDepthwiseConvInputGrad) {
  const bool use_runtime_params = (c_filter_multiplier == 0);

  if (!use_runtime_params && c_filter != -1) {
    KernelDepthwiseConvInputGradCFilter<T, c_filter, c_filter_multiplier,
                                        fuse_relu_before_conv>(
        input_data, output_grad_data, filter_data, batch_size, output_channels,
        output_height, output_width, input_channels, input_height, input_width,
        c_filter_multiplier, filter_height, filter_width, c_stride, c_stride,
        padding_height, padding_width, dilate_height, dilate_width,
        input_grad_data);
    return;
  }

  const int multiplier =
      use_runtime_params ? filter_multiplier : c_filter_multiplier;
  const int stride_h = use_runtime_params ? stride_height : c_stride;
  const int stride_w = use_runtime_params ? stride_width : c_stride;
  KernelDepthwiseConvInputGrad<T, fuse_relu_before_conv>(
      input_data, output_grad_data, filter_data, batch_size, output_channels,
      output_height, output_width, input_channels, input_height, input_width,
      multiplier, filter_height, filter_width, stride_h, stride_w,
      padding_height, padding_width, dilate_height, dilate_width,
      input_grad_data);
}

// Cuda kernel body to compute the depthwise convolution backprop w.r.t. the
// filter, NCHW layout.
//
// Launch layout (see DepthwiseConvFilterGradFunctor): one block per filter
// tap, with blockIdx.x = tap column, blockIdx.y = tap row and
// blockIdx.z = output channel. For each of the `num` batch images, the
// block's threads step over the output plane (threadIdx.x along the width,
// threadIdx.y along the height) and accumulate a per-thread partial sum. The
// sum is then added to this tap's entry of filter_grad_data through
// CudaAtomicAddWithWarp.
//
// Because the result is accumulated with atomic adds, filter_grad_data
// presumably has to be zero-filled before the launch. The functor in this file
// does not clear it — TODO confirm at the call site.
template <typename T, bool fuse_relu_before_conv>
__device__ __inline__ void KernelDepthwiseConvFilterGrad(
    const T* output_grad_data, const T* input_data, const int num,
    const int output_channels, const int output_height, const int output_width,
    const int input_channels, const int input_height, const int input_width,
    const int filter_multiplier, const int filter_height,
    const int filter_width, const int stride_height, const int stride_width,
    const int padding_height, const int padding_width, const int dilate_height,
    const int dilate_width, T* filter_grad_data) {
  // Per-block quantities.
  const int kernel_id = blockIdx.z;  // output channel
  const int c_in = kernel_id / filter_multiplier;
  const int kernel_h = blockIdx.y * dilate_height - padding_height;
  const int kernel_w = blockIdx.x * dilate_width - padding_width;
  // Flat index of this block's tap in filter_grad_data.
  const int gbid =
      ((blockIdx.z * gridDim.y) + blockIdx.y) * gridDim.x + blockIdx.x;

  T s = 0;
  for (int image_w = threadIdx.x; image_w < output_width;
       image_w += blockDim.x) {
    const int image_wk = image_w * stride_width + kernel_w;
    // The whole column reads padding: nothing to accumulate.
    if (image_wk < 0 || image_wk >= input_width) continue;
    for (int bid = 0; bid < num; bid++) {
      for (int image_h = threadIdx.y; image_h < output_height;
           image_h += blockDim.y) {
        const int image_hk = image_h * stride_height + kernel_h;
        if (image_hk < 0 || image_hk >= input_height) continue;

        const int input_id =
            ((bid * input_channels + c_in) * input_height + image_hk) *
                input_width +
            image_wk;
        const int output_id =
            ((bid * output_channels + kernel_id) * output_height + image_h) *
                output_width +
            image_w;
        T in_val = input_data[input_id];
        if (fuse_relu_before_conv) in_val = max(static_cast<T>(0), in_val);
        s += output_grad_data[output_id] * in_val;
      }
    }
  }
  CudaAtomicAddWithWarp(&filter_grad_data[gbid], s);
}

// Entry point for the backprop w.r.t. the filter. A non-zero
// c_filter_multiplier is a compile-time value that replaces the runtime
// filter_multiplier; 0 means "use the runtime argument". All other arguments
// are forwarded unchanged to KernelDepthwiseConvFilterGrad.
template <typename T, int c_filter_multiplier, bool fuse_relu_before_conv>
__global__ void KernelDepthwiseConvFilterGradSp(
    const T* output_grad_data, const T* input_data, const int num,
    const int output_channels, const int output_height, const int output_width,
    const int input_channels, const int input_height, const int input_width,
    const int filter_multiplier, const int filter_height,
    const int filter_width, const int stride_height, const int stride_width,
    const int padding_height, const int padding_width, const int dilate_height,
    const int dilate_width, T* filter_grad_data) {
  const int multiplier =
      (c_filter_multiplier != 0) ? c_filter_multiplier : filter_multiplier;
  KernelDepthwiseConvFilterGrad<T, fuse_relu_before_conv>(
      output_grad_data, input_data, num, output_channels, output_height,
      output_width, input_channels, input_height, input_width, multiplier,
      filter_height, filter_width, stride_height, stride_width, padding_height,
      padding_width, dilate_height, dilate_width, filter_grad_data);
}

/*
 * All tensors are in NCHW format.
 * Ksize, strides, paddings and dilations have two elements each. These two
 * elements represent height and width, respectively.
 */
template <class T, bool fuse_relu_before_conv>
class DepthwiseConvFunctor<platform::CUDADeviceContext, T,
                           fuse_relu_before_conv> {
 public:
  // Computes the depthwise convolution of `input` with `filter` into `output`
  // on context.stream(). `output` must already carry its NCHW dims, because
  // they are read to size the launch before its buffer is allocated.
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& input,
                  const framework::Tensor& filter,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings,
                  const std::vector<int>& dilations,
                  framework::Tensor* output) {
    const int batch_size = input.dims()[0];
    const int input_channels = input.dims()[1];
    const int input_height = input.dims()[2];
    const int input_width = input.dims()[3];
    const int output_channels = output->dims()[1];
    const int output_height = output->dims()[2];
    const int output_width = output->dims()[3];
    const int ksize_height = filter.dims()[2];
    const int ksize_width = filter.dims()[3];
    const int stride_height = strides[0];
    const int stride_width = strides[1];
    const int padding_height = paddings[0];
    const int padding_width = paddings[1];
    const int dilate_height = dilations[0];
    const int dilate_width = dilations[1];

    const T* input_data = input.data<T>();
    const T* filter_data = filter.data<T>();
    T* output_data = output->mutable_data<T>(context.GetPlace());

    // Empty output: nothing to compute. Returning here also avoids the integer
    // division by output_width below and a launch with a zero-sized grid.
    if (batch_size == 0 || output_channels == 0 || output_height == 0 ||
        output_width == 0) {
      return;
    }

    // threads.x = min(output_width, thread). `thread` is 512 by default,
    // output_width for widths in (512, 1024], and ceil(output_width / 2) for
    // widths in (1024, 2048]. threads.y = thread / output_width rows, clamped
    // to [1, output_height].
    int thread = 512;
    if (output_width > 1024 && output_width <= 2048)
      thread = (output_width - 1) / 2 + 1;
    else if (output_width > 512 && output_width <= 1024)
      thread = output_width;
    int blocks = std::min(std::max(thread / output_width, 1), output_height);
    dim3 threads(std::min(output_width, thread), blocks, 1);
    // One block per (output channel, batch image).
    dim3 grid(output_channels, batch_size, 1);
    int filter_multiplier = output_channels / input_channels;

// Launches the first specialization whose compile-time parameters match the
// runtime configuration, then returns. c_filter_multiplier == 0 matches every
// configuration; c_filter == -1 matches any filter size.
#define check_case(c_filter_multiplier, c_stride, c_filter)                  \
  if (c_filter_multiplier == 0 ||                                            \
      (filter_multiplier == c_filter_multiplier &&                           \
       stride_height == stride_width && stride_height == c_stride &&         \
       ((ksize_height == ksize_width && ksize_height == c_filter) ||         \
        c_filter == -1))) {                                                  \
    KernelDepthwiseConvSp<                                                   \
        T, c_filter_multiplier, c_stride, c_filter,                          \
        fuse_relu_before_conv><<<grid, threads, 0, context.stream()>>>(      \
        input_data, filter_data, batch_size, output_channels, output_height, \
        output_width, input_channels, input_height, input_width,             \
        filter_multiplier, ksize_height, ksize_width, stride_height,         \
        stride_width, padding_height, padding_width, dilate_height,          \
        dilate_width, output_data);                                          \
    return;                                                                  \
  }
    check_case(1, 1, 3);
    check_case(1, 1, 5);
    check_case(1, 1, -1);
    check_case(1, 2, 3);
    check_case(1, 2, 5);
    check_case(1, 2, -1);
    check_case(2, 1, 3);
    check_case(2, 1, 5);
    check_case(2, 1, -1);
    check_case(2, 2, 3);
    check_case(2, 2, 5);
    check_case(2, 2, -1);
    check_case(0, 0, -1);
// NOTE(liangdun): 0,0 for other case
// add other case if needed, e.g. check_case(2^n,1)
#undef check_case
  }
};

template <typename T, bool fuse_relu_before_conv>
class DepthwiseConvInputGradFunctor<platform::CUDADeviceContext, T,
                                    fuse_relu_before_conv> {
 public:
  // Computes the gradient of the depthwise convolution w.r.t. its input into
  // `input_grad` on context.stream(). All tensors are NCHW. The kernels read
  // the values of `input` only when fuse_relu_before_conv is set (to zero the
  // gradient where the forward input was <= 0).
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& input,
                  const framework::Tensor& filter,
                  const framework::Tensor& output_grad,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings,
                  const std::vector<int>& dilations,
                  framework::Tensor* input_grad) {
    const int batch_size = input.dims()[0];
    const int input_channels = input.dims()[1];
    const int input_height = input.dims()[2];
    const int input_width = input.dims()[3];
    const int output_channels = output_grad.dims()[1];
    const int output_height = output_grad.dims()[2];
    const int output_width = output_grad.dims()[3];
    const int ksize_height = filter.dims()[2];
    const int ksize_width = filter.dims()[3];
    const int stride_height = strides[0];
    const int stride_width = strides[1];
    const int padding_height = paddings[0];
    const int padding_width = paddings[1];
    const int dilate_height = dilations[0];
    const int dilate_width = dilations[1];

    const T* input_data = input.data<T>();
    const T* filter_data = filter.data<T>();
    const T* output_grad_data = output_grad.data<T>();
    T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace());

    // Empty input: nothing to compute. Returning here also avoids the integer
    // divisions by input_width / input_channels below and a zero-sized launch.
    if (batch_size == 0 || input_channels == 0 || input_height == 0 ||
        input_width == 0) {
      return;
    }

    // threads.x = min(input_width, thread). `thread` is 512 by default,
    // input_width for widths in (512, 1024], and ceil(input_width / 2) for
    // widths in (1024, 2048]. threads.y = thread / input_width rows, clamped to
    // [1, input_height].
    int thread = 512;
    if (input_width > 1024 && input_width <= 2048)
      thread = (input_width - 1) / 2 + 1;
    else if (input_width > 512 && input_width <= 1024)
      thread = input_width;
    int blocks = std::min(std::max(thread / input_width, 1), input_height);
    dim3 threads(std::min(input_width, thread), blocks, 1);
    // One block per (input channel, batch image).
    dim3 grid(input_channels, batch_size, 1);
    int filter_multiplier = output_channels / input_channels;

// Launches the first specialization whose compile-time parameters match the
// runtime configuration, then returns. c_filter_multiplier == 0 matches every
// configuration; c_filter == -1 matches any filter size.
#define check_case(c_filter_multiplier, c_stride, c_filter)             \
  if (c_filter_multiplier == 0 ||                                       \
      (filter_multiplier == c_filter_multiplier &&                      \
       stride_height == stride_width && stride_height == c_stride &&    \
       ((ksize_height == ksize_width && ksize_height == c_filter) ||    \
        c_filter == -1))) {                                             \
    KernelDepthwiseConvInputGradSp<                                     \
        T, c_filter_multiplier, c_stride, c_filter,                     \
        fuse_relu_before_conv><<<grid, threads, 0, context.stream()>>>( \
        input_data, output_grad_data, filter_data, batch_size,          \
        output_channels, output_height, output_width, input_channels,   \
        input_height, input_width, filter_multiplier, ksize_height,     \
        ksize_width, stride_height, stride_width, padding_height,       \
        padding_width, dilate_height, dilate_width, input_grad_data);   \
    return;                                                             \
  }
    check_case(1, 1, 3);
    check_case(1, 1, 5);
    check_case(1, 1, -1);
    check_case(1, 2, 3);
    check_case(1, 2, 5);
    check_case(1, 2, -1);
    check_case(2, 1, 3);
    check_case(2, 1, 5);
    check_case(2, 1, -1);
    check_case(2, 2, 3);
    check_case(2, 2, 5);
    check_case(2, 2, -1);
    check_case(0, 0, -1);
// NOTE(liangdun): 0,0 for other case
// add other case if needed, e.g. check_case(2^n,1)
#undef check_case
  }
};

template <typename T, bool fuse_relu_before_conv>
class DepthwiseConvFilterGradFunctor<platform::CUDADeviceContext, T,
                                     fuse_relu_before_conv> {
 public:
  // Computes the gradient of the depthwise convolution w.r.t. the filter on
  // context.stream(). All tensors are NCHW. The kernel accumulates into
  // `filter_grad` with atomic adds, and this function does not clear it first.
  // Presumably the caller zero-fills it beforehand — TODO confirm.
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& input,
                  const framework::Tensor& output_grad,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings,
                  const std::vector<int>& dilations,
                  framework::Tensor* filter_grad) {
    const int batch_size = input.dims()[0];
    const int input_channels = input.dims()[1];
    const int input_height = input.dims()[2];
    const int input_width = input.dims()[3];
    const int output_channels = output_grad.dims()[1];
    const int output_height = output_grad.dims()[2];
    const int output_width = output_grad.dims()[3];
    const int ksize_height = filter_grad->dims()[2];
    const int ksize_width = filter_grad->dims()[3];
    const int stride_height = strides[0];
    const int stride_width = strides[1];
    const int padding_height = paddings[0];
    const int padding_width = paddings[1];
    const int dilate_height = dilations[0];
    const int dilate_width = dilations[1];

    const T* input_data = input.data<T>();
    const T* output_grad_data = output_grad.data<T>();
    T* filter_grad_data = filter_grad->mutable_data<T>(context.GetPlace());

    // Empty output gradient: every contribution would be zero. Returning here
    // also avoids the integer division by output_width below and a zero-sized
    // launch.
    if (batch_size == 0 || output_channels == 0 || output_height == 0 ||
        output_width == 0) {
      return;
    }

    // threads.x = min(output_width, block_size). `block_size` is 512 by
    // default, output_width for widths in (512, 1024], and
    // ceil(output_width / 2) for widths in (1024, 2048]. threads.y =
    // block_size / output_width rows, clamped to [1, output_height].
    int block_size = 512;
    if (output_width > 1024 && output_width <= 2048)
      block_size = (output_width - 1) / 2 + 1;
    else if (output_width > 512 && output_width <= 1024)
      block_size = output_width;
    int crop_output_height =
        std::min(std::max(block_size / output_width, 1), output_height);
    // One block per filter tap: (tap column, tap row, output channel).
    dim3 grid(ksize_width, ksize_height, output_channels);
    dim3 threads(std::min(output_width, block_size), crop_output_height, 1);
    int filter_multiplier = output_channels / input_channels;

// Launches the specialization for c_filter_multiplier when it matches (0
// matches every configuration), then returns.
#define check_case(c_filter_multiplier)                                       \
  if (c_filter_multiplier == 0 || c_filter_multiplier == filter_multiplier) { \
    KernelDepthwiseConvFilterGradSp<                                          \
        T, c_filter_multiplier,                                               \
        fuse_relu_before_conv><<<grid, threads, 0, context.stream()>>>(       \
        output_grad_data, input_data, batch_size, output_channels,            \
        output_height, output_width, input_channels, input_height,            \
        input_width, filter_multiplier, ksize_height, ksize_width,            \
        stride_height, stride_width, padding_height, padding_width,           \
        dilate_height, dilate_width, filter_grad_data);                       \
    return;                                                                   \
  }
    check_case(1);
    check_case(0);
#undef check_case
  }
};

// Explicit instantiations of the three CUDA functors for float and double,
// each with and without the fused ReLU (12 instantiations in total).
#define INSTANTIATE_DEPTHWISE_CONV_FUNCTORS(data_type, fuse_relu)       \
  template class DepthwiseConvFunctor<platform::CUDADeviceContext,      \
                                      data_type, fuse_relu>;            \
  template class DepthwiseConvInputGradFunctor<                         \
      platform::CUDADeviceContext, data_type, fuse_relu>;               \
  template class DepthwiseConvFilterGradFunctor<                        \
      platform::CUDADeviceContext, data_type, fuse_relu>

INSTANTIATE_DEPTHWISE_CONV_FUNCTORS(float, false);
INSTANTIATE_DEPTHWISE_CONV_FUNCTORS(double, false);
INSTANTIATE_DEPTHWISE_CONV_FUNCTORS(float, true);
INSTANTIATE_DEPTHWISE_CONV_FUNCTORS(double, true);

#undef INSTANTIATE_DEPTHWISE_CONV_FUNCTORS

}  // namespace math
}  // namespace operators
}  // namespace paddle