depthwise_conv.cu 56.9 KB
Newer Older
1
/* Copyright (c) 2016 paddlepaddle Authors. All Rights Reserved.
Z
zlx 已提交
2 3 4 5 6 7 8 9 10 11 12 13 14

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

15
#include <algorithm>
A
Abhinav Arora 已提交
16
#include <vector>
17 18 19 20 21 22 23
#ifdef __NVCC__
#include <cub/cub.cuh>
#endif
#ifdef __HIPCC__
#include <hipcub/hipcub.hpp>
namespace cub = hipcub;
#endif
Y
Yi Wang 已提交
24
#include "paddle/fluid/operators/math/depthwise_conv.h"
25
#include "paddle/fluid/operators/math/math_function.h"
26
#include "paddle/fluid/platform/cuda_device_function.h"
D
dzhwinter 已提交
27
#include "paddle/fluid/platform/cuda_primitives.h"
Z
zlx 已提交
28 29 30 31 32

namespace paddle {
namespace operators {
namespace math {

33
// Warp-level sum of `val` via cub::WarpReduce. `warp_size` is forwarded to
// cub's Sum() as its second argument (the valid-item count in cub's API).
// NOTE(review): the TempStorage object is a plain local here, not __shared__;
// presumably this relies on cub selecting a shuffle-based reduction - confirm.
template <typename T>
static __forceinline__ __device__ T WarpReduceSum(T val, int warp_size) {
  using Reducer = cub::WarpReduce<T>;
  typename Reducer::TempStorage scratch;
  return Reducer(scratch).Sum(val, warp_size);
}
40

W
wangguanzhong 已提交
41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61
// Block-wide sum built from two rounds of WarpReduceSum.
// Returns the reduced value only to the thread whose linear id is 0; every
// other thread gets 0 back. The function contains __syncthreads(), so it has
// to be reached by all threads of the block.
template <typename T>
__forceinline__ __device__ T BlockReduceSum(T val) {
  static __shared__ T shared[32];  // one slot per warp's partial sum

  const int tid = threadIdx.x + threadIdx.y * blockDim.x +
                  threadIdx.z * blockDim.x * blockDim.y;
  // Blocks smaller than a hardware warp are treated as a single short warp.
  const int warp_size = min(blockDim.x * blockDim.y * blockDim.z, warpSize);
  const int lane_id = tid % warp_size;
  const int warp_id = tid / warp_size;

  // Round 1: reduce inside each warp, lane 0 publishes the partial sum.
  val = WarpReduceSum(val, warp_size);
  if (lane_id == 0) {
    shared[warp_id] = val;
  }
  __syncthreads();  // all partial sums are visible after this point

  // Threads with an id below the warp count pick up one partial sum each;
  // the rest feed zero into the second round.
  const int block_size = blockDim.x * blockDim.y * blockDim.z;
  const int warp_count = (block_size - 1) / warp_size + 1;
  val = (tid < warp_count) ? shared[lane_id] : static_cast<T>(0);

  // Round 2: only the first warp combines the partial sums.
  if (warp_id == 0) {
    val = WarpReduceSum(val, warp_size);
  }
  __syncthreads();

  return (tid == 0) ? val : static_cast<T>(0);
}

73 74 75 76 77 78 79 80
// Parameter list shared by the forward depthwise-convolution kernels in this
// file. Expands inside a template that declares the element type `T`.
#define ARG_DEFINE_KernelDepthwiseConv                                  \
  const T *const input_data, const T *const filter_data,                \
      const int batch_size, const int output_channels,                  \
      const int output_height, const int output_width,                  \
      const int input_channels, const int input_height,                 \
      const int input_width, const int filter_multiplier,               \
      const int filter_height, const int filter_width,                  \
      const int stride_height, const int stride_width,                  \
      const int padding_height, const int padding_width,                \
      const int dilate_height, const int dilate_width, T *const output_data
82

83 84
// Forward depthwise convolution for NCHW-ordered data.
// Each thread builds one flat index from threadIdx.x / blockIdx.x /
// blockDim.x, decodes it as (batch, c_out, h_out, w_out) and accumulates the
// filter window for that single output element. Threads whose index is past
// the number of output elements return without writing. Filter taps that fall
// outside the input are skipped but still advance the filter cursor. When
// fuse_relu_before_conv is true each input value goes through max(0, x)
// before it is multiplied by the weight.
template <typename T, bool fuse_relu_before_conv>
__device__ __inline__ void KernelDepthwiseConvNCHW(
    ARG_DEFINE_KernelDepthwiseConv) {
  const int tid = threadIdx.x + blockIdx.x * blockDim.x;
  if (tid >= (output_channels * batch_size * output_height * output_width)) {
    return;
  }

  // Peel the coordinates off the flat index, fastest-varying first.
  int rest = tid;
  const int w_out = rest % output_width;
  rest /= output_width;
  const int h_out = rest % output_height;
  rest /= output_height;
  const int c_out = rest % output_channels;
  const int batch = rest / output_channels;

  const int c_in = c_out / filter_multiplier;
  const T* const taps = filter_data + c_out * filter_height * filter_width;
  const int plane_offset =
      ((batch * input_channels + c_in) * input_height) * input_width;

  // Half-open input window [first, limit) covered by the dilated filter.
  const int h_first = h_out * stride_height - padding_height;
  const int w_first = w_out * stride_width - padding_width;
  const int h_limit = h_first + filter_height * dilate_height;
  const int w_limit = w_first + filter_width * dilate_width;

  T acc = 0;
  int tap = 0;  // row-major position inside the filter
#pragma unroll
  for (int h_in = h_first; h_in < h_limit; h_in += dilate_height) {
#pragma unroll
    for (int w_in = w_first; w_in < w_limit; w_in += dilate_width, ++tap) {
      const bool inside = h_in >= 0 && h_in < input_height && w_in >= 0 &&
                          w_in < input_width;
      if (!inside) continue;  // padding region contributes nothing

      const T in_data = input_data[plane_offset + h_in * input_width + w_in];
      if (fuse_relu_before_conv) {
        acc += taps[tap] * max(0.0f, in_data);
      } else {
        acc += taps[tap] * in_data;
      }
    }
  }

  const int out_index =
      ((batch * output_channels + c_out) * output_height + h_out) *
          output_width +
      w_out;
  output_data[out_index] = acc;
}
135

136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168
// Forward depthwise convolution for NHWC-ordered data.
// Each thread builds one flat index from threadIdx.x / blockIdx.x /
// blockDim.x and decodes it as (batch, h_out, w_out, c_out) with the channel
// varying fastest. The filter is indexed as
// filter_data[tap * output_channels + c_out]. Out-of-range taps are skipped
// but still advance the tap counter. When fuse_relu_before_conv is true each
// input value goes through max(0, x) first.
template <typename T, bool fuse_relu_before_conv>
__device__ __inline__ void KernelDepthwiseConvNHWC(
    ARG_DEFINE_KernelDepthwiseConv) {
  const int tid = threadIdx.x + blockIdx.x * blockDim.x;
  if (tid >= (output_channels * batch_size * output_height * output_width)) {
    return;
  }

  // Peel the coordinates off the flat index, fastest-varying first.
  int rest = tid;
  const int c_out = rest % output_channels;
  rest /= output_channels;
  const int w_out = rest % output_width;
  rest /= output_width;
  const int h_out = rest % output_height;
  const int batch = rest / output_height;

  const int c_in = c_out / filter_multiplier;

  // Half-open input window [first, limit) covered by the dilated filter.
  const int h_first = h_out * stride_height - padding_height;
  const int w_first = w_out * stride_width - padding_width;
  const int h_limit = h_first + filter_height * dilate_height;
  const int w_limit = w_first + filter_width * dilate_width;

  T acc = 0;
  int tap = 0;  // row-major position inside the filter
#pragma unroll
  for (int h_in = h_first; h_in < h_limit; h_in += dilate_height) {
#pragma unroll
    for (int w_in = w_first; w_in < w_limit; w_in += dilate_width, ++tap) {
      const bool inside = h_in >= 0 && h_in < input_height && w_in >= 0 &&
                          w_in < input_width;
      if (!inside) continue;  // padding region contributes nothing

      const int in_index =
          ((batch * input_height + h_in) * input_width + w_in) *
              input_channels +
          c_in;
      const T in_data = input_data[in_index];
      const T w = filter_data[tap * output_channels + c_out];
      if (fuse_relu_before_conv) {
        acc += w * max(0.0f, in_data);
      } else {
        acc += w * in_data;
      }
    }
  }

  const int out_index =
      ((batch * output_height + h_out) * output_width + w_out) *
          output_channels +
      c_out;
  output_data[out_index] = acc;
}
187

188
// Forward depthwise convolution (NCHW) for a compile-time square filter of
// side c_filter.
// blockIdx.x selects the output channel and blockIdx.y the batch entry; the
// threads of the block stride over the output plane (threadIdx.x over width,
// threadIdx.y over height). The channel's c_filter * c_filter weights are
// copied into a thread-local array once before the spatial loops. The output
// index uses gridDim.x as the channel count. When fuse_relu_before_conv is
// true each input value goes through max(0, x) before being multiplied.
template <typename T, int c_filter, bool fuse_relu_before_conv>
__device__ __inline__ void KernelDepthwiseConvCFilterNCHW(
    ARG_DEFINE_KernelDepthwiseConv) {
  const int kWeightSize = c_filter * c_filter;
  T r_weight[kWeightSize];
  const int batch = blockIdx.y;
  const int c_out = blockIdx.x;
  const T* weight = filter_data + c_out * kWeightSize;
  for (int i = 0; i < kWeightSize; i++) r_weight[i] = weight[i];

  // Both values depend only on the block, so compute them once.
  const int c_in = c_out / filter_multiplier;
  const int in_offset =
      ((batch * input_channels + c_in) * input_height) * input_width;

  for (int w_out = threadIdx.x; w_out < output_width; w_out += blockDim.x) {
    for (int h_out = threadIdx.y; h_out < output_height; h_out += blockDim.y) {
      T value = 0;
      const int h_in_start = -padding_height + h_out * stride_height;
      const int w_in_start = -padding_width + w_out * stride_width;

      for (int h_in = h_in_start, h_f = 0; h_f < c_filter;
           h_in += dilate_height, h_f++) {
        for (int w_in = w_in_start, w_f = 0; w_f < c_filter;
             w_in += dilate_width, w_f++) {
          // Taps that land in the padding region contribute nothing.
          if (h_in >= 0 && h_in < input_height && w_in >= 0 &&
              w_in < input_width) {
            int offset = in_offset + h_in * input_width + w_in;
            if (fuse_relu_before_conv) {
              value += r_weight[h_f * c_filter + w_f] *
                       max(0.0f, input_data[offset]);
            } else {
              value += r_weight[h_f * c_filter + w_f] * input_data[offset];
            }
          }
        }
      }
      int index =
          ((batch * gridDim.x + c_out) * output_height + h_out) * output_width +
          w_out;
      output_data[index] = value;
    }
  }
}

// Forward depthwise convolution (NHWC) for a compile-time square filter of
// side c_filter.
// blockIdx.z selects the batch entry; the output row is
// blockIdx.x * dilate_height + blockIdx.y, and blocks whose row is past
// output_height return immediately. threadIdx.x strides over output channels
// and threadIdx.y over a reordered list of output columns (columns are
// visited grouped by their remainder modulo dilate_width). Weights for the
// current channel are copied into a thread-local array before the column
// loop. When fuse_relu_before_conv is true each input value goes through
// max(0, x) first.
template <typename T, int c_filter, bool fuse_relu_before_conv>
__device__ __inline__ void KernelDepthwiseConvCFilterNHWC(
    ARG_DEFINE_KernelDepthwiseConv) {
  const int batch = blockIdx.z;
  const int h_out = blockIdx.x * dilate_height + blockIdx.y;
  if (h_out >= output_height) return;

  const int batch_in_base = batch * input_height * input_width * input_channels;
  const int row_out_base =
      (batch * output_height + h_out) * output_width * output_channels;
  const int h_first = h_out * stride_height - padding_height;
  // Number of output columns in each dilation group, rounded up.
  const int group_len = (output_width + dilate_width - 1) / dilate_width;
  const int kWeightSize = c_filter * c_filter;
  T r_weight[kWeightSize];

  for (int c_out = threadIdx.x; c_out < output_channels; c_out += blockDim.x) {
    // Gather this channel's taps; in this layout they are output_channels
    // apart in filter_data.
    for (int k = 0; k < c_filter * c_filter; k++) {
      r_weight[k] = filter_data[k * output_channels + c_out];
    }
    const int c_in = c_out / filter_multiplier;

    for (int slot = threadIdx.y; slot < group_len * dilate_width;
         slot += blockDim.y) {
      const int group = slot / group_len;
      const int pos = slot - group * group_len;
      const int w_out = pos * dilate_width + group;
      if (w_out >= output_width) continue;  // slot created by the round-up

      const int w_first = w_out * stride_width - padding_width;
      T acc = 0;
      for (int h_in = h_first, h_f = 0; h_f < c_filter;
           h_in += dilate_height, h_f++) {
        for (int w_in = w_first, w_f = 0; w_f < c_filter;
             w_in += dilate_width, w_f++) {
          const bool inside = h_in >= 0 && h_in < input_height && w_in >= 0 &&
                              w_in < input_width;
          if (!inside) continue;

          const int in_index = batch_in_base +
                               (h_in * input_width + w_in) * input_channels +
                               c_in;
          if (fuse_relu_before_conv) {
            acc += r_weight[h_f * c_filter + w_f] *
                   max(0.0f, input_data[in_index]);
          } else {
            acc += r_weight[h_f * c_filter + w_f] * input_data[in_index];
          }
        }
      }
      output_data[row_out_base + w_out * output_channels + c_out] = acc;
    }
  }
}

296
// Entry kernel for the forward pass. Picks one of the four device
// implementations from the template parameters:
//   * c_filter == -1 -> generic kernels, otherwise the fixed-size (CFilter)
//     kernels;
//   * data_layout == DataLayout::kNHWC -> NHWC variant, otherwise NCHW.
// When c_filter_multiplier is non-zero, the template values
// c_filter_multiplier / c_stride replace the runtime filter_multiplier and
// both strides.
template <typename T, int c_filter_multiplier, int c_stride, int c_filter,
          DataLayout data_layout, bool fuse_relu_before_conv>
__global__ void KernelDepthwiseConvSp(ARG_DEFINE_KernelDepthwiseConv) {
  const bool use_template_values = (c_filter_multiplier != 0);
  const int final_filter_multiplier =
      use_template_values ? c_filter_multiplier : filter_multiplier;
  const int h_stride = use_template_values ? c_stride : stride_height;
  const int w_stride = use_template_values ? c_stride : stride_width;

  const bool is_nhwc = (data_layout == DataLayout::kNHWC);
  const bool generic_filter = (c_filter == -1);

  if (generic_filter && !is_nhwc) {
    KernelDepthwiseConvNCHW<T, fuse_relu_before_conv>(
        input_data, filter_data, batch_size, output_channels, output_height,
        output_width, input_channels, input_height, input_width,
        final_filter_multiplier, filter_height, filter_width, h_stride,
        w_stride, padding_height, padding_width, dilate_height, dilate_width,
        output_data);
  } else if (generic_filter) {
    KernelDepthwiseConvNHWC<T, fuse_relu_before_conv>(
        input_data, filter_data, batch_size, output_channels, output_height,
        output_width, input_channels, input_height, input_width,
        final_filter_multiplier, filter_height, filter_width, h_stride,
        w_stride, padding_height, padding_width, dilate_height, dilate_width,
        output_data);
  } else if (!is_nhwc) {
    KernelDepthwiseConvCFilterNCHW<T, c_filter, fuse_relu_before_conv>(
        input_data, filter_data, batch_size, output_channels, output_height,
        output_width, input_channels, input_height, input_width,
        final_filter_multiplier, filter_height, filter_width, h_stride,
        w_stride, padding_height, padding_width, dilate_height, dilate_width,
        output_data);
  } else {
    KernelDepthwiseConvCFilterNHWC<T, c_filter, fuse_relu_before_conv>(
        input_data, filter_data, batch_size, output_channels, output_height,
        output_width, input_channels, input_height, input_width,
        final_filter_multiplier, filter_height, filter_width, h_stride,
        w_stride, padding_height, padding_width, dilate_height, dilate_width,
        output_data);
  }
}

Z
zlx 已提交
342
// CUDA kernels to compute the depthwise convolution backprop w.r.t. input.
// This macro is the parameter list shared by those kernels; it expands inside
// a template that declares the element type `T`.
#define ARG_DEFINE_KernelDepthwiseConvInputGrad                         \
  const T *const input_data, const T *const output_grad_data,           \
      const T *const filter_data, const int batch_size,                 \
      const int output_channels, const int output_height,               \
      const int output_width, const int input_channels,                 \
      const int input_height, const int input_width,                    \
      const int filter_multiplier, const int filter_height,             \
      const int filter_width, const int stride_height,                  \
      const int stride_width, const int padding_height,                 \
      const int padding_width, const int dilate_height,                 \
      const int dilate_width, T *const input_grad_data
354

355
// Input gradient of the depthwise convolution, NCHW layout, runtime filter
// size.
// blockIdx.x selects the input channel and blockIdx.y the batch entry (the
// element index uses gridDim.x as the channel count); threadIdx.x strides
// over input columns and threadIdx.y over input rows. For every input element
// the kernel walks the filter_multiplier output channels fed by this input
// channel and visits their filter taps in reverse order. A tap contributes
// only if the matching pre-stride output coordinate is a multiple of the
// stride and lands inside the output. When fuse_relu_before_conv is true,
// elements whose input value is <= 0 get gradient 0.
template <typename T, bool fuse_relu_before_conv>
__device__ __inline__ void KernelDepthwiseConvInputGradNCHW(
    ARG_DEFINE_KernelDepthwiseConvInputGrad) {
  const int batch = blockIdx.y;
  const int c_in = blockIdx.x;
  const int c_out_first = c_in * filter_multiplier;
  const int c_out_limit = c_out_first + filter_multiplier;

  for (int w_in = threadIdx.x; w_in < input_width; w_in += blockDim.x) {
    const int w_out_last = w_in + padding_width;
    const int w_out_first = w_out_last - (filter_width - 1) * dilate_width;

    for (int h_in = threadIdx.y; h_in < input_height; h_in += blockDim.y) {
      const int h_out_last = h_in + padding_height;
      const int h_out_first = h_out_last - (filter_height - 1) * dilate_height;

      const int index =
          ((batch * gridDim.x + c_in) * input_height + h_in) * input_width +
          w_in;

      // Fused ReLU: no gradient flows through non-positive inputs.
      if (fuse_relu_before_conv && input_data[index] <= 0) {
        input_grad_data[index] = 0;
        continue;
      }

      T grad = 0;
      for (int c_out = c_out_first; c_out < c_out_limit; ++c_out) {
        // Start one past this channel's last tap and walk backwards.
        int filter_pos = (c_out + 1) * filter_height * filter_width;
        for (int h_out = h_out_first; h_out <= h_out_last;
             h_out += dilate_height) {
          for (int w_out = w_out_first; w_out <= w_out_last;
               w_out += dilate_width) {
            --filter_pos;
            if (h_out % stride_height != 0 || w_out % stride_width != 0) {
              continue;  // no output element maps onto this tap
            }
            const int s_h_out = h_out / stride_height;
            const int s_w_out = w_out / stride_width;
            if (s_h_out < 0 || s_h_out >= output_height || s_w_out < 0 ||
                s_w_out >= output_width) {
              continue;
            }
            const int grad_index =
                ((batch * output_channels + c_out) * output_height + s_h_out) *
                    output_width +
                s_w_out;
            grad += output_grad_data[grad_index] * filter_data[filter_pos];
          }
        }
      }
      input_grad_data[index] = grad;
    }
  }
}

// Input gradient of the depthwise convolution, NHWC layout, runtime filter
// size.
// blockIdx.z selects the batch entry; the input row is
// blockIdx.x * dilate_height + blockIdx.y, and blocks whose row is past
// input_height return immediately. threadIdx.x strides over input channels
// and threadIdx.y over input columns. Filter taps are visited in reverse
// order and read as filter_data[tap * output_channels + c_out]. When
// fuse_relu_before_conv is true, elements whose input value is <= 0 get
// gradient 0.
template <typename T, bool fuse_relu_before_conv>
__device__ __inline__ void KernelDepthwiseConvInputGradNHWC(
    ARG_DEFINE_KernelDepthwiseConvInputGrad) {
  const int batch = blockIdx.z;
  const int h_in = blockIdx.x * dilate_height + blockIdx.y;
  if (h_in >= input_height) return;

  const int h_out_first =
      h_in - (filter_height - 1) * dilate_height + padding_height;

  for (int c_in = threadIdx.x; c_in < input_channels; c_in += blockDim.x) {
    for (int w_in = threadIdx.y; w_in < input_width; w_in += blockDim.y) {
      const int w_out_first =
          w_in - (filter_width - 1) * dilate_width + padding_width;
      const int index = ((batch * input_height + h_in) * input_width + w_in) *
                            input_channels +
                        c_in;

      // Fused ReLU: no gradient flows through non-positive inputs.
      if (fuse_relu_before_conv && input_data[index] <= 0) {
        input_grad_data[index] = 0;
        continue;
      }

      T grad = 0;
      for (int m = 0; m < filter_multiplier; ++m) {
        const int c_out = c_in * filter_multiplier + m;
        // Counts down so the taps are consumed in reverse order.
        int tap = filter_height * filter_width;
        for (int h_out = h_out_first, h_f = 0; h_f < filter_height;
             h_out += dilate_height, ++h_f) {
          for (int w_out = w_out_first, w_f = 0; w_f < filter_width;
               w_out += dilate_width, ++w_f) {
            --tap;
            if (h_out % stride_height != 0 || w_out % stride_width != 0) {
              continue;  // no output element maps onto this tap
            }
            const int s_h_out = h_out / stride_height;
            const int s_w_out = w_out / stride_width;
            if (s_h_out < 0 || s_h_out >= output_height || s_w_out < 0 ||
                s_w_out >= output_width) {
              continue;
            }
            const int grad_index =
                ((batch * output_height + s_h_out) * output_width + s_w_out) *
                    output_channels +
                c_out;
            const int filter_index = tap * output_channels + c_out;
            grad += output_grad_data[grad_index] * filter_data[filter_index];
          }
        }
      }
      input_grad_data[index] = grad;
    }
  }
}

467 468
// Input gradient of the depthwise convolution, NCHW layout, for a
// compile-time square filter of side c_filter and a compile-time channel
// multiplier c_filter_multiplier.
// blockIdx.x selects the input channel and blockIdx.y the batch entry (the
// element index uses gridDim.x as the channel count); threadIdx.x strides
// over input columns and threadIdx.y over input rows. The weights of all
// c_filter_multiplier output channels fed by this input channel are copied,
// reversed, into a thread-local array before the spatial loops.
//
// The channel loops are bounded by the template parameter
// c_filter_multiplier, the same value that sizes r_weight, so the array
// cannot be overrun. The runtime filter_multiplier argument is not read; the
// dispatcher in this file passes c_filter_multiplier for it.
// When fuse_relu_before_conv is true, elements whose input value is <= 0 get
// gradient 0.
template <typename T, int c_filter, int c_filter_multiplier,
          bool fuse_relu_before_conv>
__device__ __inline__ void KernelDepthwiseConvInputGradCFilterNCHW(
    ARG_DEFINE_KernelDepthwiseConvInputGrad) {
  const int kWeightSize = c_filter * c_filter * c_filter_multiplier + 1;
  T r_weight[kWeightSize];
  const int batch = blockIdx.y;
  const int c_in = blockIdx.x;

  // Cache the filters back-to-front so the gradient loops below can index
  // them with the forward (h_f, w_f) position.
  for (int c_i = 0; c_i < c_filter_multiplier; c_i++) {
    int c_out = c_in * c_filter_multiplier + c_i;
    const T* weight = filter_data + c_out * c_filter * c_filter;
    for (int i = 0; i < c_filter * c_filter; i++)
      r_weight[i + c_i * c_filter * c_filter] =
          weight[c_filter * c_filter - i - 1];
  }

  for (int w_in = threadIdx.x; w_in < input_width; w_in += blockDim.x) {
    for (int h_in = threadIdx.y; h_in < input_height; h_in += blockDim.y) {
      int h_out_start = h_in - (c_filter - 1) * dilate_height + padding_height;
      int w_out_start = w_in - (c_filter - 1) * dilate_width + padding_width;

      T value = 0;
      int index =
          ((batch * gridDim.x + c_in) * input_height + h_in) * input_width +
          w_in;
      // Fused ReLU: no gradient flows through non-positive inputs.
      if (fuse_relu_before_conv) {
        if (input_data[index] <= 0) {
          input_grad_data[index] = 0;
          continue;
        }
      }

      for (int c_i = 0; c_i < c_filter_multiplier; c_i++) {
        int c_out = c_in * c_filter_multiplier + c_i;
        for (int h_out = h_out_start, h_f = 0; h_f < c_filter;
             h_out += dilate_height, h_f++) {
          for (int w_out = w_out_start, w_f = 0; w_f < c_filter;
               w_out += dilate_width, w_f++) {
            int s_h_out = h_out / stride_height;
            int s_w_out = w_out / stride_width;
            // A tap contributes only if it lines up with the stride grid and
            // the resulting output coordinate exists.
            if (h_out % stride_height == 0 && w_out % stride_width == 0 &&
                s_h_out >= 0 && s_h_out < output_height && s_w_out >= 0 &&
                s_w_out < output_width) {
              int output_grad_offset =
                  ((batch * output_channels + c_out) * output_height +
                   s_h_out) *
                      output_width +
                  s_w_out;
              value +=
                  output_grad_data[output_grad_offset] *
                  r_weight[h_f * c_filter + w_f + c_i * c_filter * c_filter];
            }
          }
        }
      }
      input_grad_data[index] = value;
    }
  }
}

528
// Input gradient of the depthwise convolution, NHWC layout, for a
// compile-time square filter of side c_filter and a compile-time channel
// multiplier c_filter_multiplier.
// blockIdx.z selects the batch entry; the input row is
// blockIdx.x * dilate_height + blockIdx.y, and blocks whose row is past
// input_height return immediately. threadIdx.x strides over input channels
// and threadIdx.y over a reordered list of input columns (columns are visited
// grouped by their remainder modulo dilate_width). The reversed weights for
// the current input channel are copied into a thread-local array before the
// column loop. When fuse_relu_before_conv is true, elements whose input value
// is <= 0 get gradient 0.
template <typename T, int c_filter, int c_filter_multiplier,
          bool fuse_relu_before_conv>
__device__ __inline__ void KernelDepthwiseConvInputGradCFilterNHWC(
    ARG_DEFINE_KernelDepthwiseConvInputGrad) {
  const int h_in = blockIdx.x * dilate_height + blockIdx.y;
  if (h_in >= input_height) return;

  const int kWeightSize = c_filter * c_filter * c_filter_multiplier + 1;
  T r_weight[kWeightSize];
  const int batch = blockIdx.z;
  // Number of input columns in each dilation group, rounded up.
  const int group_len = (input_width + dilate_width - 1) / dilate_width;
  const int h_out_first =
      h_in - (c_filter - 1) * dilate_height + padding_height;

  for (int c_in = threadIdx.x; c_in < input_channels; c_in += blockDim.x) {
    // Cache the taps back-to-front; in this layout consecutive taps of one
    // channel are output_channels apart in filter_data.
    for (int m = 0; m < c_filter_multiplier; ++m) {
      const int c_out = c_in * c_filter_multiplier + m;
      for (int k = 0; k < c_filter * c_filter; ++k) {
        const int src_tap = c_filter * c_filter - k - 1;
        r_weight[k + m * c_filter * c_filter] =
            filter_data[src_tap * output_channels + c_out];
      }
    }

    for (int slot = threadIdx.y; slot < group_len * dilate_width;
         slot += blockDim.y) {
      const int group = slot / group_len;
      const int pos = slot - group * group_len;
      const int w_in = pos * dilate_width + group;
      if (w_in >= input_width) continue;  // slot created by the round-up

      const int w_out_first =
          w_in - (c_filter - 1) * dilate_width + padding_width;
      const int index = ((batch * input_height + h_in) * input_width + w_in) *
                            input_channels +
                        c_in;

      // Fused ReLU: no gradient flows through non-positive inputs.
      if (fuse_relu_before_conv && input_data[index] <= 0) {
        input_grad_data[index] = 0;
        continue;
      }

      T grad = 0;
      for (int m = 0; m < c_filter_multiplier; ++m) {
        const int c_out = c_in * c_filter_multiplier + m;
        for (int h_out = h_out_first, h_f = 0; h_f < c_filter;
             h_out += dilate_height, ++h_f) {
          for (int w_out = w_out_first, w_f = 0; w_f < c_filter;
               w_out += dilate_width, ++w_f) {
            if (h_out % stride_height != 0 || w_out % stride_width != 0) {
              continue;  // no output element maps onto this tap
            }
            const int s_h_out = h_out / stride_height;
            const int s_w_out = w_out / stride_width;
            if (s_h_out < 0 || s_h_out >= output_height || s_w_out < 0 ||
                s_w_out >= output_width) {
              continue;
            }
            const int grad_index =
                ((batch * output_height + s_h_out) * output_width + s_w_out) *
                    output_channels +
                c_out;
            grad += output_grad_data[grad_index] *
                    r_weight[h_f * c_filter + w_f + m * c_filter * c_filter];
          }
        }
      }
      input_grad_data[index] = grad;
    }
  }
}

// Entry kernel for the input-gradient pass. Picks one of the four device
// implementations from the template parameters:
//   * c_filter_multiplier == 0 or c_filter == -1 -> generic kernels, which
//     receive the effective multiplier/strides computed below;
//   * otherwise the fixed-size (CFilter) kernels, which receive
//     c_filter_multiplier and c_stride directly;
//   * data_layout == DataLayout::kNHWC -> NHWC variant, otherwise NCHW.
template <typename T, int c_filter_multiplier, int c_stride, int c_filter,
          DataLayout data_layout, bool fuse_relu_before_conv>
__global__ void KernelDepthwiseConvInputGradSp(
    ARG_DEFINE_KernelDepthwiseConvInputGrad) {
  // A non-zero template multiplier overrides the runtime multiplier/strides.
  const bool use_template_values = (c_filter_multiplier != 0);
  const int final_filter_multiplier =
      use_template_values ? c_filter_multiplier : filter_multiplier;
  const int h_stride = use_template_values ? c_stride : stride_height;
  const int w_stride = use_template_values ? c_stride : stride_width;

  const bool is_nhwc = (data_layout == DataLayout::kNHWC);
  const bool generic = (c_filter_multiplier == 0 || c_filter == -1);

  if (generic && !is_nhwc) {
    KernelDepthwiseConvInputGradNCHW<T, fuse_relu_before_conv>(
        input_data, output_grad_data, filter_data, batch_size, output_channels,
        output_height, output_width, input_channels, input_height, input_width,
        final_filter_multiplier, filter_height, filter_width, h_stride,
        w_stride, padding_height, padding_width, dilate_height, dilate_width,
        input_grad_data);
  } else if (generic) {
    KernelDepthwiseConvInputGradNHWC<T, fuse_relu_before_conv>(
        input_data, output_grad_data, filter_data, batch_size, output_channels,
        output_height, output_width, input_channels, input_height, input_width,
        final_filter_multiplier, filter_height, filter_width, h_stride,
        w_stride, padding_height, padding_width, dilate_height, dilate_width,
        input_grad_data);
  } else if (!is_nhwc) {
    KernelDepthwiseConvInputGradCFilterNCHW<T, c_filter, c_filter_multiplier,
                                            fuse_relu_before_conv>(
        input_data, output_grad_data, filter_data, batch_size, output_channels,
        output_height, output_width, input_channels, input_height, input_width,
        c_filter_multiplier, filter_height, filter_width, c_stride, c_stride,
        padding_height, padding_width, dilate_height, dilate_width,
        input_grad_data);
  } else {
    KernelDepthwiseConvInputGradCFilterNHWC<T, c_filter, c_filter_multiplier,
                                            fuse_relu_before_conv>(
        input_data, output_grad_data, filter_data, batch_size, output_channels,
        output_height, output_width, input_channels, input_height, input_width,
        c_filter_multiplier, filter_height, filter_width, c_stride, c_stride,
        padding_height, padding_width, dilate_height, dilate_width,
        input_grad_data);
  }
}

648
// Cuda kernel to compute the depthwise convolution backprop w.r.t. filter
// (NCHW layout).
// Each block handles one filter tap of one output channel: blockIdx.z is the
// output channel, blockIdx.y / blockIdx.x the tap row / column, and gridDim.z
// is used as the output channel count. Threads stride over the output plane
// (threadIdx.x over width, threadIdx.y over height) and over all `num` batch
// entries, accumulating output_grad * input products whose input coordinate
// is in range. The per-thread sums are combined with BlockReduceSum and added
// atomically to filter_grad_data at the block's flattened grid index. When
// fuse_relu_before_conv is true each input value goes through max(0, x)
// first.
template <typename T, bool fuse_relu_before_conv>
__device__ __inline__ void KernelDepthwiseConvFilterGradNCHW(
    const T* output_grad_data, const T* input_data, const int num,
    const int output_channels, const int output_height, const int output_width,
    const int input_channels, const int input_height, const int input_width,
    const int filter_multiplier, const int filter_height,
    const int filter_width, const int stride_height, const int stride_width,
    const int padding_height, const int padding_width, const int dilate_height,
    const int dilate_width, T* filter_grad_data) {
  T partial = 0;
  const int gbid =
      ((blockIdx.z * gridDim.y) + blockIdx.y) * gridDim.x + blockIdx.x;

  // These depend only on the block, so they are computed once.
  const int kernel_id = blockIdx.z;
  const int kernel_h = blockIdx.y * dilate_height - padding_height;
  const int kernel_w = blockIdx.x * dilate_width - padding_width;

  for (int image_w = threadIdx.x; image_w < output_width;
       image_w += blockDim.x) {
    const int image_wk = image_w * stride_width + kernel_w;
    // A column outside the input contributes nothing for any row or batch.
    if (image_wk < 0 || image_wk >= input_width) continue;

    for (int bid = 0; bid < num; bid++) {
      for (int image_h = threadIdx.y; image_h < output_height;
           image_h += blockDim.y) {
        const int image_hk = image_h * stride_height + kernel_h;
        if (image_hk < 0 || image_hk >= input_height) continue;

        const auto grad_id =
            (((bid)*gridDim.z + (kernel_id)) * output_height + (image_h)) *
                output_width +
            (image_w);
        const int input_id = ((bid * (gridDim.z / filter_multiplier) +
                               kernel_id / filter_multiplier) *
                                  input_height +
                              image_hk) *
                                 input_width +
                             image_wk;
        if (fuse_relu_before_conv) {
          partial += output_grad_data[grad_id] * max(0.0f, input_data[input_id]);
        } else {
          partial += output_grad_data[grad_id] * input_data[input_id];
        }
      }
    }
  }

  // BlockReduceSum hands the block total to one thread and 0 to the others,
  // so every thread can issue the atomic add.
  T val = BlockReduceSum(partial);
  platform::CudaAtomicAdd(&filter_grad_data[gbid], val);
}

// Filter-gradient accumulation for NHWC tensors with an arbitrary filter size.
//
// Launch layout as read by this function: blockIdx.z selects the batch item,
// blockIdx.y the output row and blockIdx.x the filter tap, decoded as
// blockIdx.x = kernel_ih * filter_width + kernel_iw. threadIdx.x strides over
// the output channels and threadIdx.y strides over the output columns.
//
// Each thread sums its own products and then adds that partial sum to
//   filter_grad_data[(channel * filter_height + kernel_ih) * filter_width +
//                    kernel_iw]
// with platform::CudaAtomicAdd. The destination is only ever incremented here,
// never initialised. When fuse_relu_before_conv is true the input value is
// clamped with max(0.0f, x) before it is multiplied. `num` is not read.
template <typename T, bool fuse_relu_before_conv>
__device__ __inline__ void KernelDepthwiseConvFilterGradNHWC(
    const T* output_grad_data, const T* input_data, const int num,
    const int output_channels, const int output_height, const int output_width,
    const int input_channels, const int input_height, const int input_width,
    const int filter_multiplier, const int filter_height,
    const int filter_width, const int stride_height, const int stride_width,
    const int padding_height, const int padding_width, const int dilate_height,
    const int dilate_width, T* filter_grad_data) {
  const int batch = blockIdx.z;
  const int out_h = blockIdx.y;
  const int tap_h = blockIdx.x / filter_width;
  const int tap_w = blockIdx.x % filter_width;

  // The input row hit by this (output row, filter tap) pair is the same for
  // every channel and column, so its bounds check is done once up front.
  const int in_h =
      out_h * stride_height + (tap_h * dilate_height - padding_height);
  const bool row_in_range = in_h >= 0 && in_h < input_height;
  const int w_offset = tap_w * dilate_width - padding_width;

  for (int channel = threadIdx.x; channel < output_channels;
       channel += blockDim.x) {
    T partial = 0;
    if (row_in_range) {
      for (int out_w = threadIdx.y; out_w < output_width;
           out_w += blockDim.y) {
        const int in_w = out_w * stride_width + w_offset;
        if (in_w < 0 || in_w >= input_width) continue;

        const int in_idx =
            ((batch * input_height + in_h) * input_width + in_w) *
                input_channels +
            channel / filter_multiplier;
        const int out_idx =
            ((batch * output_height + out_h) * output_width + out_w) *
                output_channels +
            channel;
        if (fuse_relu_before_conv) {
          partial +=
              output_grad_data[out_idx] * max(0.0f, input_data[in_idx]);
        } else {
          partial += output_grad_data[out_idx] * input_data[in_idx];
        }
      }
    }
    // The atomic add is issued even when nothing was accumulated.
    const int grad_idx =
        (channel * filter_height + tap_h) * filter_width + tap_w;
    platform::CudaAtomicAdd(&filter_grad_data[grad_idx], partial);
  }
}

// Filter-gradient accumulation for NHWC tensors with a compile-time square
// filter of size c_filter x c_filter.
//
// Launch layout as read by this function: blockIdx.z selects the batch item
// and the output row is blockIdx.x * dilate_height + blockIdx.y (blocks whose
// row falls outside the output return immediately). threadIdx.x strides over
// the output channels; threadIdx.y strides over wi_size * dilate_width column
// slots, where slot i maps to column (i % wi_size) * dilate_width +
// i / wi_size, so columns are visited grouped by their residue modulo
// dilate_width.
//
// Each thread keeps one partial sum per filter tap in r_weight and finally
// adds them to filter_grad_data[tap * output_channels + channel] with
// platform::CudaAtomicAdd, i.e. the destination is indexed tap-major with the
// channel as the fastest index. When fuse_relu_before_conv is true the input
// value is clamped with max(0.0f, x) before it is multiplied.
// `num`, `filter_height` and `filter_width` are not read.
template <typename T, int c_filter, bool fuse_relu_before_conv>
__device__ __inline__ void KernelDepthwiseConvFilterGradCFilterNHWC(
    const T* output_grad_data, const T* input_data, const int num,
    const int output_channels, const int output_height, const int output_width,
    const int input_channels, const int input_height, const int input_width,
    const int filter_multiplier, const int filter_height,
    const int filter_width, const int stride_height, const int stride_width,
    const int padding_height, const int padding_width, const int dilate_height,
    const int dilate_width, T* filter_grad_data) {
  const int bid = blockIdx.z;
  const int image_h = blockIdx.x * dilate_height + blockIdx.y;
  if (image_h >= output_height) {
    return;
  }
  const int kWeightSize = c_filter * c_filter;
  T r_weight[kWeightSize];
  const int wi_size = (output_width + dilate_width - 1) / dilate_width;

  for (int kernel_id = threadIdx.x; kernel_id < output_channels;
       kernel_id += blockDim.x) {
    for (int i = 0; i < kWeightSize; ++i) {
      r_weight[i] = 0;
    }
    for (int i = threadIdx.y; i < wi_size * dilate_width; i += blockDim.y) {
      const int i_dw = i / wi_size;
      const int i_wi = i - i_dw * wi_size;
      const int image_w = i_wi * dilate_width + i_dw;
      if (image_w >= output_width) {
        continue;
      }
      // The output element is the same for every filter tap.
      const int output_id =
          ((bid * output_height + image_h) * output_width + image_w) *
              output_channels +
          kernel_id;
      for (int kernel_ih = 0; kernel_ih < c_filter; ++kernel_ih) {
        // The row check does not depend on kernel_iw, so it is done once per
        // filter row.
        const int image_hk = image_h * stride_height +
                             (kernel_ih * dilate_height - padding_height);
        if (image_hk < 0 || image_hk >= input_height) continue;
        for (int kernel_iw = 0; kernel_iw < c_filter; ++kernel_iw) {
          const int image_wk = image_w * stride_width +
                               (kernel_iw * dilate_width - padding_width);
          if (image_wk < 0 || image_wk >= input_width) continue;
          const int input_id =
              ((bid * input_height + image_hk) * input_width + image_wk) *
                  input_channels +
              kernel_id / filter_multiplier;
          T s = 0;
          if (fuse_relu_before_conv) {
            s = output_grad_data[output_id] * max(0.0f, input_data[input_id]);
          } else {
            s = output_grad_data[output_id] * input_data[input_id];
          }
          r_weight[kernel_ih * c_filter + kernel_iw] += s;
        }
      }
    }
    for (int i = 0; i < kWeightSize; ++i) {
      T* weight = filter_grad_data + i * output_channels + kernel_id;
      platform::CudaAtomicAdd(&weight[0], r_weight[i]);
    }
  }
}

// Kernel entry point for the depthwise-convolution filter gradient.
//
// Template parameters specialise the launch:
//   c_filter_multiplier  when non-zero, it replaces the runtime
//                        filter_multiplier and c_stride replaces both runtime
//                        strides; when zero the runtime values are used.
//   c_filter             compile-time filter size; -1 means "not fixed".
//   data_layout          selects the NCHW or NHWC implementation.
//
// Dispatch:
//   * any layout other than kNHWC  -> KernelDepthwiseConvFilterGradNCHW
//   * kNHWC, c_filter_multiplier == 0 or c_filter == -1
//                                  -> KernelDepthwiseConvFilterGradNHWC
//   * kNHWC otherwise              -> KernelDepthwiseConvFilterGradCFilterNHWC
template <typename T, int c_filter_multiplier, int c_stride, int c_filter,
          DataLayout data_layout, bool fuse_relu_before_conv>
__global__ void KernelDepthwiseConvFilterGradSp(
    const T* output_grad_data, const T* input_data, const int num,
    const int output_channels, const int output_height, const int output_width,
    const int input_channels, const int input_height, const int input_width,
    const int filter_multiplier, const int filter_height,
    const int filter_width, const int stride_height, const int stride_width,
    const int padding_height, const int padding_width, const int dilate_height,
    const int dilate_width, T* filter_grad_data) {
  int final_filter_multiplier = filter_multiplier;
  int h_stride = stride_height;
  int w_stride = stride_width;
  if (c_filter_multiplier != 0) {
    final_filter_multiplier = c_filter_multiplier;
    h_stride = c_stride;
    w_stride = c_stride;
  }

  if (data_layout != DataLayout::kNHWC) {
    // The NCHW path is the same whether or not the filter size is fixed.
    KernelDepthwiseConvFilterGradNCHW<T, fuse_relu_before_conv>(
        output_grad_data, input_data, num, output_channels, output_height,
        output_width, input_channels, input_height, input_width,
        final_filter_multiplier, filter_height, filter_width, h_stride,
        w_stride, padding_height, padding_width, dilate_height, dilate_width,
        filter_grad_data);
  } else if (c_filter_multiplier == 0 || c_filter == -1) {
    KernelDepthwiseConvFilterGradNHWC<T, fuse_relu_before_conv>(
        output_grad_data, input_data, num, output_channels, output_height,
        output_width, input_channels, input_height, input_width,
        final_filter_multiplier, filter_height, filter_width, h_stride,
        w_stride, padding_height, padding_width, dilate_height, dilate_width,
        filter_grad_data);
  } else {
    KernelDepthwiseConvFilterGradCFilterNHWC<T, c_filter,
                                             fuse_relu_before_conv>(
        output_grad_data, input_data, num, output_channels, output_height,
        output_width, input_channels, input_height, input_width,
        final_filter_multiplier, filter_height, filter_width, h_stride,
        w_stride, padding_height, padding_width, dilate_height, dilate_width,
        filter_grad_data);
  }
}

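/*
 * Tensors are in NCHW format by default; pass DataLayout::kNHWC as
 * data_layout for NHWC tensors. The kernel size is taken from the last two
 * dimensions of the filter (height, width).
 * strides, paddings and dilations each have two elements. These two elements
 * represent height and width, respectively.
 */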
// CUDA specialisation of the depthwise-convolution forward functor.
//
// operator() reads the tensor extents according to data_layout, picks a launch
// configuration and launches the KernelDepthwiseConvSp instantiation that
// matches the filter multiplier, stride and filter size on context.stream().
// The launch is asynchronous and its status is not checked here.
template <class T, bool fuse_relu_before_conv>
class DepthwiseConvFunctor<platform::CUDADeviceContext, T,
                           fuse_relu_before_conv> {
 public:
  // input / output: 4-D tensors, NCHW unless data_layout is kNHWC.
  // filter: dims()[2] and dims()[3] are used as kernel height and width.
  // strides, paddings, dilations: {height, width}.
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& input,
                  const framework::Tensor& filter,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings,
                  const std::vector<int>& dilations, framework::Tensor* output,
                  const DataLayout data_layout = DataLayout::kNCHW) {
    const int batch_size = input.dims()[0];
    const int input_channels =
        (data_layout != DataLayout::kNHWC ? input.dims()[1] : input.dims()[3]);
    const int input_height =
        (data_layout != DataLayout::kNHWC ? input.dims()[2] : input.dims()[1]);
    const int input_width =
        (data_layout != DataLayout::kNHWC ? input.dims()[3] : input.dims()[2]);
    const int output_channels =
        (data_layout != DataLayout::kNHWC ? output->dims()[1]
                                          : output->dims()[3]);
    const int output_height =
        (data_layout != DataLayout::kNHWC ? output->dims()[2]
                                          : output->dims()[1]);
    const int output_width =
        (data_layout != DataLayout::kNHWC ? output->dims()[3]
                                          : output->dims()[2]);
    const int ksize_height = filter.dims()[2];
    const int ksize_width = filter.dims()[3];
    const int stride_height = strides[0];
    const int stride_width = strides[1];
    const int padding_height = paddings[0];
    const int padding_width = paddings[1];
    const int dilate_height = dilations[0];
    const int dilate_width = dilations[1];

    const T* input_data = input.data<T>();
    const T* filter_data = filter.data<T>();
    T* output_data = output->mutable_data<T>(context.GetPlace());

    // For NHWC the filter is transposed into a temporary whose dims are the
    // original dims reordered as {2, 3, 0, 1}; the kernels then read that copy.
    framework::Tensor filter_hwc;
    if (data_layout == DataLayout::kNHWC) {
      framework::DDim filter_hwc_dims({filter.dims()[2], filter.dims()[3],
                                       filter.dims()[0], filter.dims()[1]});
      filter_hwc.Resize(filter_hwc_dims);
      filter_hwc.mutable_data<T>(context.GetPlace());
      std::vector<int> perm_axis({2, 3, 0, 1});
      math::TransposeNormal<platform::CUDADeviceContext, T> trans;
      trans(context, filter, &filter_hwc, perm_axis);
      filter_data = filter_hwc.data<T>();
    }

    int thread = 512;
    int blocks;
    dim3 threads;
    dim3 grid;

    if (data_layout != DataLayout::kNHWC) {
      // NCHW: threads.x covers the output width, threads.y extra rows, and the
      // grid is (output_channels, batch_size).
      if (output_width > 1024 && output_width <= 2048) {
        thread = (output_width - 1) / 2 + 1;
      } else if (output_width > 512 && output_width <= 1024) {
        thread = output_width;
      }
#ifdef __HIPCC__
      thread = std::min(thread, 256);
#endif
      blocks = std::min(std::max(thread / output_width, 1), output_height);
      threads = dim3(std::min(output_width, thread), blocks, 1);
      grid = dim3(output_channels, batch_size, 1);
    } else {
      // NHWC: threads.x covers the channels, threads.y the (dilation-padded)
      // width, and the grid is (ceil(height / dilate_height), dilate_height,
      // batch_size).
#ifdef __HIPCC__
      thread = std::min(thread, 256);
#endif
      blocks = std::min(
          std::max(thread / output_channels, 1),
          ((output_width + dilate_width - 1) / dilate_width) * dilate_width);
      threads = dim3(std::min(output_channels, thread), blocks, 1);
      grid = dim3((output_height + dilate_height - 1) / dilate_height,
                  dilate_height, batch_size);
    }
    int filter_multiplier = output_channels / input_channels;
    int nums_output =
        batch_size * output_channels * output_height * output_width;
#ifdef __HIPCC__
    int block_size = 256;
#else
    int block_size = 512;
#endif
    int grid_size = (nums_output + block_size - 1) / block_size;

// check_case(m, s, k) launches the kernel specialised for filter multiplier m,
// stride s and square filter size k when the runtime values match, then
// returns. k == -1 accepts any filter size and switches to a 1-D launch of
// grid_size blocks with block_size threads. m == 0 matches unconditionally.
#define check_case(c_filter_multiplier, c_stride, c_filter)                    \
  if (c_filter_multiplier == 0 ||                                              \
      (filter_multiplier == c_filter_multiplier &&                             \
       stride_height == stride_width && stride_height == c_stride &&           \
       ((ksize_height == ksize_width && ksize_height == c_filter) ||           \
        c_filter == -1))) {                                                    \
    if (c_filter == -1) {                                                      \
      threads.x = block_size;                                                  \
      grid.x = grid_size;                                                      \
      threads.y = threads.z = grid.y = grid.z = 1;                             \
    }                                                                          \
    if (data_layout != DataLayout::kNHWC) {                                    \
      KernelDepthwiseConvSp<                                                   \
          T, c_filter_multiplier, c_stride, c_filter, DataLayout::kNCHW,       \
          fuse_relu_before_conv><<<grid, threads, 0, context.stream()>>>(      \
          input_data, filter_data, batch_size, output_channels, output_height, \
          output_width, input_channels, input_height, input_width,             \
          filter_multiplier, ksize_height, ksize_width, stride_height,         \
          stride_width, padding_height, padding_width, dilate_height,          \
          dilate_width, output_data);                                          \
    } else {                                                                   \
      KernelDepthwiseConvSp<                                                   \
          T, c_filter_multiplier, c_stride, c_filter, DataLayout::kNHWC,       \
          fuse_relu_before_conv><<<grid, threads, 0, context.stream()>>>(      \
          input_data, filter_data, batch_size, output_channels, output_height, \
          output_width, input_channels, input_height, input_width,             \
          filter_multiplier, ksize_height, ksize_width, stride_height,         \
          stride_width, padding_height, padding_width, dilate_height,          \
          dilate_width, output_data);                                          \
    }                                                                          \
    return;                                                                    \
  }
    check_case(1, 1, 3);
    check_case(1, 1, 5);
    check_case(1, 1, -1);
    check_case(1, 2, 3);
    check_case(1, 2, 5);
    check_case(1, 2, -1);
    check_case(2, 1, 3);
    check_case(2, 1, 5);
    check_case(2, 1, -1);
    check_case(2, 2, 3);
    check_case(2, 2, 5);
    check_case(2, 2, -1);
    // NOTE(liangdun): (0, 0, -1) is the fallback for every other case.
    // Add further specialisations above it if needed, e.g. check_case(2^n, 1).
    check_case(0, 0, -1);
#undef check_case
  }
};

// CUDA specialisation of the depthwise-convolution input-gradient functor.
//
// operator() reads the tensor extents according to data_layout, picks a launch
// configuration sized by the input extents and launches the
// KernelDepthwiseConvInputGradSp instantiation that matches the filter
// multiplier, stride and filter size on context.stream(). The launch is
// asynchronous and its status is not checked here.
template <typename T, bool fuse_relu_before_conv>
class DepthwiseConvInputGradFunctor<platform::CUDADeviceContext, T,
                                    fuse_relu_before_conv> {
 public:
  // input / output_grad / input_grad: 4-D tensors, NCHW unless data_layout is
  // kNHWC. filter: dims()[2] and dims()[3] are used as kernel height and
  // width. strides, paddings, dilations: {height, width}.
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& input,
                  const framework::Tensor& filter,
                  const framework::Tensor& output_grad,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings,
                  const std::vector<int>& dilations,
                  framework::Tensor* input_grad,
                  const DataLayout data_layout = DataLayout::kNCHW) {
    const int batch_size = input.dims()[0];
    const int input_channels =
        (data_layout != DataLayout::kNHWC ? input.dims()[1] : input.dims()[3]);
    const int input_height =
        (data_layout != DataLayout::kNHWC ? input.dims()[2] : input.dims()[1]);
    const int input_width =
        (data_layout != DataLayout::kNHWC ? input.dims()[3] : input.dims()[2]);
    const int output_channels =
        (data_layout != DataLayout::kNHWC ? output_grad.dims()[1]
                                          : output_grad.dims()[3]);
    const int output_height =
        (data_layout != DataLayout::kNHWC ? output_grad.dims()[2]
                                          : output_grad.dims()[1]);
    const int output_width =
        (data_layout != DataLayout::kNHWC ? output_grad.dims()[3]
                                          : output_grad.dims()[2]);
    const int ksize_height = filter.dims()[2];
    const int ksize_width = filter.dims()[3];
    const int stride_height = strides[0];
    const int stride_width = strides[1];
    const int padding_height = paddings[0];
    const int padding_width = paddings[1];
    const int dilate_height = dilations[0];
    const int dilate_width = dilations[1];

    const T* input_data = input.data<T>();
    const T* filter_data = filter.data<T>();
    const T* output_grad_data = output_grad.data<T>();
    T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace());

    // For NHWC the filter is transposed into a temporary whose dims are the
    // original dims reordered as {2, 3, 0, 1}; the kernels then read that copy.
    framework::Tensor filter_hwc;
    if (data_layout == DataLayout::kNHWC) {
      framework::DDim filter_hwc_dims({filter.dims()[2], filter.dims()[3],
                                       filter.dims()[0], filter.dims()[1]});
      filter_hwc.Resize(filter_hwc_dims);
      filter_hwc.mutable_data<T>(context.GetPlace());
      std::vector<int> perm_axis({2, 3, 0, 1});
      math::TransposeNormal<platform::CUDADeviceContext, T> trans;
      trans(context, filter, &filter_hwc, perm_axis);
      filter_data = filter_hwc.data<T>();
    }

    int thread = 512;
    int blocks;
    dim3 threads;
    dim3 grid;

    if (data_layout != DataLayout::kNHWC) {
      // NCHW: threads.x covers the input width, threads.y extra rows, and the
      // grid is (input_channels, batch_size).
      if (input_width > 1024 && input_width <= 2048) {
        thread = (input_width - 1) / 2 + 1;
      } else if (input_width > 512 && input_width <= 1024) {
        thread = input_width;
      }
      blocks = std::min(std::max(thread / input_width, 1), input_height);
      threads = dim3(std::min(input_width, thread), blocks, 1);
      grid = dim3(input_channels, batch_size, 1);
    } else {
      // NHWC: threads.x covers the channels, threads.y the (dilation-padded)
      // width, and the grid is (ceil(height / dilate_height), dilate_height,
      // batch_size).
      blocks = std::min(
          std::max(thread / input_channels, 1),
          ((input_width + dilate_width - 1) / dilate_width) * dilate_width);
      threads = dim3(std::min(input_channels, thread), blocks, 1);
      grid = dim3((input_height + dilate_height - 1) / dilate_height,
                  dilate_height, batch_size);
    }
    int filter_multiplier = output_channels / input_channels;

// check_case(m, s, k) launches the kernel specialised for filter multiplier m,
// stride s and square filter size k when the runtime values match, then
// returns. k == -1 accepts any filter size; m == 0 matches unconditionally.
#define check_case(c_filter_multiplier, c_stride, c_filter)               \
  if (c_filter_multiplier == 0 ||                                         \
      (filter_multiplier == c_filter_multiplier &&                        \
       stride_height == stride_width && stride_height == c_stride &&      \
       ((ksize_height == ksize_width && ksize_height == c_filter) ||      \
        c_filter == -1))) {                                               \
    if (data_layout != DataLayout::kNHWC) {                               \
      KernelDepthwiseConvInputGradSp<                                     \
          T, c_filter_multiplier, c_stride, c_filter, DataLayout::kNCHW,  \
          fuse_relu_before_conv><<<grid, threads, 0, context.stream()>>>( \
          input_data, output_grad_data, filter_data, batch_size,          \
          output_channels, output_height, output_width, input_channels,   \
          input_height, input_width, filter_multiplier, ksize_height,     \
          ksize_width, stride_height, stride_width, padding_height,       \
          padding_width, dilate_height, dilate_width, input_grad_data);   \
    } else {                                                              \
      KernelDepthwiseConvInputGradSp<                                     \
          T, c_filter_multiplier, c_stride, c_filter, DataLayout::kNHWC,  \
          fuse_relu_before_conv><<<grid, threads, 0, context.stream()>>>( \
          input_data, output_grad_data, filter_data, batch_size,          \
          output_channels, output_height, output_width, input_channels,   \
          input_height, input_width, filter_multiplier, ksize_height,     \
          ksize_width, stride_height, stride_width, padding_height,       \
          padding_width, dilate_height, dilate_width, input_grad_data);   \
    }                                                                     \
    return;                                                               \
  }
    check_case(1, 1, 3);
    check_case(1, 1, 5);
    check_case(1, 1, -1);
    check_case(1, 2, 3);
    check_case(1, 2, 5);
    check_case(1, 2, -1);
    check_case(2, 1, 3);
    check_case(2, 1, 5);
    check_case(2, 1, -1);
    check_case(2, 2, 3);
    check_case(2, 2, 5);
    check_case(2, 2, -1);
    // NOTE(liangdun): (0, 0, -1) is the fallback for every other case.
    // Add further specialisations above it if needed, e.g. check_case(2^n, 1).
    check_case(0, 0, -1);
#undef check_case
  }
};

// CUDA specialisation of the depthwise-convolution filter-gradient functor.
//
// operator() reads the tensor extents according to data_layout, picks a launch
// configuration and launches the KernelDepthwiseConvFilterGradSp instantiation
// that matches the filter multiplier, stride and filter size on
// context.stream(). The launch is asynchronous and its status is not checked
// here.
//
// NOTE(review): the kernels visible in this file only add into the gradient
// buffer with atomic adds. Except for the NHWC fixed-filter path, which zeroes
// its own temporary, this functor does not clear filter_grad - presumably the
// caller zero-fills it first; confirm against the call sites.
template <typename T, bool fuse_relu_before_conv>
class DepthwiseConvFilterGradFunctor<platform::CUDADeviceContext, T,
                                     fuse_relu_before_conv> {
 public:
  // input / output_grad: 4-D tensors, NCHW unless data_layout is kNHWC.
  // filter_grad: dims()[2] and dims()[3] are used as kernel height and width.
  // strides, paddings, dilations: {height, width}.
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& input,
                  const framework::Tensor& output_grad,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings,
                  const std::vector<int>& dilations,
                  framework::Tensor* filter_grad,
                  const DataLayout data_layout = DataLayout::kNCHW) {
    const int batch_size = input.dims()[0];
    const int input_channels =
        (data_layout != DataLayout::kNHWC ? input.dims()[1] : input.dims()[3]);
    const int input_height =
        (data_layout != DataLayout::kNHWC ? input.dims()[2] : input.dims()[1]);
    const int input_width =
        (data_layout != DataLayout::kNHWC ? input.dims()[3] : input.dims()[2]);
    const int output_channels =
        (data_layout != DataLayout::kNHWC ? output_grad.dims()[1]
                                          : output_grad.dims()[3]);
    const int output_height =
        (data_layout != DataLayout::kNHWC ? output_grad.dims()[2]
                                          : output_grad.dims()[1]);
    const int output_width =
        (data_layout != DataLayout::kNHWC ? output_grad.dims()[3]
                                          : output_grad.dims()[2]);
    const int ksize_height = filter_grad->dims()[2];
    const int ksize_width = filter_grad->dims()[3];
    const int stride_height = strides[0];
    const int stride_width = strides[1];
    const int padding_height = paddings[0];
    const int padding_width = paddings[1];
    const int dilate_height = dilations[0];
    const int dilate_width = dilations[1];

    const T* input_data = input.data<T>();
    const T* output_grad_data = output_grad.data<T>();
    T* filter_grad_data = filter_grad->mutable_data<T>(context.GetPlace());

    int block_size = 512;
    int blocks;
    dim3 threads;
    dim3 grid;
    if (data_layout != DataLayout::kNHWC) {
      // NCHW: one block per (filter column, filter row, output channel);
      // threads.x covers the output width and threads.y extra rows.
      if (output_width > 1024 && output_width <= 2048) {
        block_size = (output_width - 1) / 2 + 1;
      } else if (output_width > 512 && output_width <= 1024) {
        block_size = output_width;
      }
      blocks = std::min(std::max(block_size / output_width, 1), output_height);
      grid = dim3(ksize_width, ksize_height, output_channels);
      threads = dim3(std::min(output_width, block_size), blocks, 1);
    } else {
      // NHWC with a fixed filter size: threads.x covers the channels,
      // threads.y the (dilation-padded) width, and the grid is
      // (ceil(height / dilate_height), dilate_height, batch_size). The
      // c_filter == -1 case below replaces this configuration.
      blocks = std::min(
          std::max(block_size / output_channels, 1),
          ((output_width + dilate_width - 1) / dilate_width) * dilate_width);
      grid = dim3((output_height + dilate_height - 1) / dilate_height,
                  dilate_height, batch_size);
      threads = dim3(std::min(output_channels, block_size), blocks, 1);
    }
    int filter_multiplier = output_channels / input_channels;

// check_case(m, s, k) launches the kernel specialised for filter multiplier m,
// stride s and square filter size k when the runtime values match, then
// returns. k == -1 accepts any filter size; m == 0 matches unconditionally.
//
// NHWC handling inside the macro:
//   * k != -1: the kernel accumulates into a zero-filled temporary whose dims
//     are filter_grad's dims reordered as {2, 3, 0, 1}; afterwards the
//     temporary is transposed with the same permutation into filter_grad.
//   * k == -1: the launch is reconfigured to one block per
//     (filter tap, output row, batch item) and the kernel writes straight
//     into filter_grad.
#define check_case(c_filter_multiplier, c_stride, c_filter)                    \
  if (c_filter_multiplier == 0 ||                                              \
      (filter_multiplier == c_filter_multiplier &&                             \
       stride_height == stride_width && stride_height == c_stride &&           \
       ((ksize_height == ksize_width && ksize_height == c_filter) ||           \
        c_filter == -1))) {                                                    \
    if (data_layout != DataLayout::kNHWC) {                                    \
      KernelDepthwiseConvFilterGradSp<                                         \
          T, c_filter_multiplier, c_stride, c_filter, DataLayout::kNCHW,       \
          fuse_relu_before_conv><<<grid, threads, 0, context.stream()>>>(      \
          output_grad_data, input_data, batch_size, output_channels,           \
          output_height, output_width, input_channels, input_height,           \
          input_width, filter_multiplier, ksize_height, ksize_width,           \
          stride_height, stride_width, padding_height, padding_width,          \
          dilate_height, dilate_width, filter_grad_data);                      \
    } else {                                                                   \
      framework::Tensor filter_grad_hwc;                                       \
      if (c_filter != -1) {                                                    \
        framework::DDim filter_grad_hwc_dims(                                  \
            {filter_grad->dims()[2], filter_grad->dims()[3],                   \
             filter_grad->dims()[0], filter_grad->dims()[1]});                 \
        filter_grad_hwc.Resize(filter_grad_hwc_dims);                          \
        filter_grad_hwc.mutable_data<T>(context.GetPlace());                   \
        math::SetConstant<platform::CUDADeviceContext, T> set_zero;            \
        set_zero(context, &filter_grad_hwc, static_cast<T>(0));                \
        filter_grad_data = filter_grad_hwc.data<T>();                          \
      } else {                                                                 \
        block_size = 512;                                                      \
        if (output_channels > 1024 && output_channels <= 2048) {               \
          block_size = (output_channels - 1) / 2 + 1;                          \
        } else if (output_channels > 512 && output_channels <= 1024) {         \
          block_size = output_channels;                                        \
        }                                                                      \
        blocks =                                                               \
            std::min(std::max(block_size / output_channels, 1), output_width); \
        grid = dim3(ksize_width * ksize_height, output_height, batch_size);    \
        threads = dim3(std::min(output_channels, block_size), blocks, 1);      \
      }                                                                        \
      KernelDepthwiseConvFilterGradSp<                                         \
          T, c_filter_multiplier, c_stride, c_filter, DataLayout::kNHWC,       \
          fuse_relu_before_conv><<<grid, threads, 0, context.stream()>>>(      \
          output_grad_data, input_data, batch_size, output_channels,           \
          output_height, output_width, input_channels, input_height,           \
          input_width, filter_multiplier, ksize_height, ksize_width,           \
          stride_height, stride_width, padding_height, padding_width,          \
          dilate_height, dilate_width, filter_grad_data);                      \
      if (c_filter != -1) {                                                    \
        std::vector<int> perm_axis({2, 3, 0, 1});                              \
        math::TransposeNormal<platform::CUDADeviceContext, T> trans;           \
        trans(context, filter_grad_hwc, filter_grad, perm_axis);               \
      }                                                                        \
    }                                                                          \
    return;                                                                    \
  }
    check_case(1, 1, 3);
    check_case(1, 1, 5);
    check_case(1, 1, -1);
    check_case(1, 2, 3);
    check_case(1, 2, 5);
    check_case(1, 2, -1);
    check_case(2, 1, 3);
    check_case(2, 1, 5);
    check_case(2, 1, -1);
    check_case(2, 2, 3);
    check_case(2, 2, 5);
    check_case(2, 2, -1);
    // (0, 0, -1) is the fallback for every other case.
    check_case(0, 0, -1);
#undef check_case
  }
};

// Explicit instantiations of the three CUDA functors for float and double,
// first with fuse_relu_before_conv = false and then with
// fuse_relu_before_conv = true.
template class DepthwiseConvFunctor<platform::CUDADeviceContext, float, false>;
template class DepthwiseConvFunctor<platform::CUDADeviceContext, double, false>;

template class DepthwiseConvInputGradFunctor<platform::CUDADeviceContext, float,
                                             false>;
template class DepthwiseConvInputGradFunctor<platform::CUDADeviceContext,
                                             double, false>;

template class DepthwiseConvFilterGradFunctor<platform::CUDADeviceContext,
                                              float, false>;
template class DepthwiseConvFilterGradFunctor<platform::CUDADeviceContext,
                                              double, false>;

template class DepthwiseConvFunctor<platform::CUDADeviceContext, float, true>;
template class DepthwiseConvFunctor<platform::CUDADeviceContext, double, true>;

template class DepthwiseConvInputGradFunctor<platform::CUDADeviceContext, float,
                                             true>;
template class DepthwiseConvInputGradFunctor<platform::CUDADeviceContext,
                                             double, true>;

template class DepthwiseConvFilterGradFunctor<platform::CUDADeviceContext,
                                              float, true>;
template class DepthwiseConvFilterGradFunctor<platform::CUDADeviceContext,
                                              double, true>;

}  // namespace math
}  // namespace operators
}  // namespace paddle