/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <algorithm>
#include <vector>
#ifdef __NVCC__
#include <cub/cub.cuh>
#endif
#ifdef __HIPCC__
#include <hipcub/hipcub.hpp>
namespace cub = hipcub;
#endif
#include "paddle/fluid/operators/math/depthwise_conv.h"
#include "paddle/fluid/platform/cuda_device_function.h"
#include "paddle/fluid/platform/cuda_primitives.h"

namespace paddle {
namespace operators {
namespace math {

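// Reduces `value` across the warp, then lane 0 issues a single atomicAdd,
// so each warp touches global memory once instead of once per thread.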
template <typename T>
__device__ __inline__ void CudaAtomicAddWithWarp(T* sum, T value) {
  typedef cub::WarpReduce<T> WarpReduce;
  typename WarpReduce::TempStorage temp_storage;

#ifdef __HIPCC__
  int block_size = min(blockDim.x * blockDim.y * blockDim.z, warpSize);
  value = WarpReduce(temp_storage).Sum(value, block_size);
#else
  value = WarpReduce(temp_storage).Sum(value);
#endif

  if (cub::LaneId() == 0) platform::CudaAtomicAdd(sum, value);
}

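// Shared parameter list for the forward kernels below; kept as a macro so
// the NCHW, NHWC, and filter-size-specialized variants stay in sync.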
#define ARG_DEFINE_KernelDepthwiseConv                                         \
  const T *const input_data, const T *const filter_data, const int batch_size, \
      const int output_channels, const int output_height,                      \
      const int output_width, const int input_channels,                        \
      const int input_height, const int input_width,                           \
      const int filter_multiplier, const int filter_height,                    \
      const int filter_width, const int stride_height, const int stride_width, \
      const int padding_height, const int padding_width,                       \
      const int dilate_height, const int dilate_width, T *const output_data,   \
      const DataLayout data_layout = DataLayout::kNCHW

// A CUDA kernel to compute the depthwise convolution forward pass
// in NCHW format.
template <typename T, bool fuse_relu_before_conv>
__device__ __inline__ void KernelDepthwiseConvNCHW(
    ARG_DEFINE_KernelDepthwiseConv) {
  int idx = threadIdx.x + blockIdx.x * blockDim.x;
  if (idx >= (output_channels * batch_size * output_height * output_width))
    return;

  const int w_out = idx % output_width;
  const int h_out = (idx / output_width) % output_height;
  const int c_out = (idx / output_width / output_height) % output_channels;
  const int batch = idx / output_width / output_height / output_channels;

  const int c_in = c_out / filter_multiplier;
  const T* weight = filter_data + c_out * filter_height * filter_width;
  T value = 0;
  const int h_in_start = -padding_height + h_out * stride_height;
  const int w_in_start = -padding_width + w_out * stride_width;
  const int h_in_end = h_in_start + filter_height * dilate_height;
  const int w_in_end = w_in_start + filter_width * dilate_width;

  int in_offset =
      ((batch * input_channels + c_in) * input_height) * input_width;

  const int h_end = h_in_end < input_height ? h_in_end : input_height;
  const int w_end = w_in_end < input_width ? w_in_end : input_width;
  const int h_start = h_in_start > 0 ? h_in_start : 0;
  const int w_start = w_in_start > 0 ? w_in_start : 0;
  int weight_offset = 0;

#pragma unroll
  for (int h_in = h_in_start; h_in < h_in_end; h_in += dilate_height) {
#pragma unroll
    for (int w_in = w_in_start; w_in < w_in_end; w_in += dilate_width) {
      if (h_in >= h_start && h_in < h_end && w_in >= w_start && w_in < w_end) {
        int offset = in_offset + h_in * input_width + w_in;
        T in_data = input_data[offset];
        if (fuse_relu_before_conv) {
          value += weight[weight_offset] * max(0.0f, in_data);
        } else {
          value += weight[weight_offset] * in_data;
        }
      }
      weight_offset++;
    }
  }
  int index = batch * output_channels * output_height * output_width +
              c_out * output_height * output_width + h_out * output_width +
              w_out;
  output_data[index] = value;
}

// A CUDA kernel to compute the depthwise convolution forward pass
// in NHWC format.
template <typename T, bool fuse_relu_before_conv>
__device__ __inline__ void KernelDepthwiseConvNHWC(
    ARG_DEFINE_KernelDepthwiseConv) {
  int idx = threadIdx.x + blockIdx.x * blockDim.x;
  if (idx >= (output_channels * batch_size * output_height * output_width))
    return;

  const int c_out = idx % output_channels;
  const int w_out = (idx / output_channels) % output_width;
  const int h_out = (idx / output_channels / output_width) % output_height;
  const int batch = idx / output_width / output_height / output_channels;

  const int c_in = c_out / filter_multiplier;
  const T* weight = filter_data + c_out * filter_height * filter_width;
  T value = 0;
  const int h_in_start = -padding_height + h_out * stride_height;
  const int w_in_start = -padding_width + w_out * stride_width;
  const int h_in_end = h_in_start + filter_height * dilate_height;
  const int w_in_end = w_in_start + filter_width * dilate_width;

  const int h_end = h_in_end < input_height ? h_in_end : input_height;
  const int w_end = w_in_end < input_width ? w_in_end : input_width;
  const int h_start = h_in_start > 0 ? h_in_start : 0;
  const int w_start = w_in_start > 0 ? w_in_start : 0;
  int weight_offset = 0;

#pragma unroll
  for (int h_in = h_in_start; h_in < h_in_end; h_in += dilate_height) {
#pragma unroll
    for (int w_in = w_in_start; w_in < w_in_end; w_in += dilate_width) {
      if (h_in >= h_start && h_in < h_end && w_in >= w_start && w_in < w_end) {
        int offset = ((batch * input_height + h_in) * input_width + w_in) *
                         input_channels +
                     c_in;
        T in_data = input_data[offset];
        if (fuse_relu_before_conv) {
          value += weight[weight_offset] * max(0.0f, in_data);
        } else {
          value += weight[weight_offset] * in_data;
        }
      }
      weight_offset++;
    }
  }
  int index = batch * output_channels * output_height * output_width +
              h_out * output_width * output_channels + w_out * output_channels +
              c_out;
  output_data[index] = value;
}

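// Forward kernel specialized on a compile-time filter size (c_filter): the
// c_filter x c_filter weights are cached in registers, and the inner loops
// have constant trip counts so the compiler can fully unroll them.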
template <typename T, int c_filter, bool fuse_relu_before_conv>
__device__ __inline__ void KernelDepthwiseConvCFilter(
    ARG_DEFINE_KernelDepthwiseConv) {
  const int kWeightSize = c_filter * c_filter;
  T r_weight[kWeightSize];
  const int batch = blockIdx.y;
  const int c_out = blockIdx.x;
  const T* weight = filter_data + c_out * c_filter * c_filter;
  for (int i = 0; i < c_filter * c_filter; i++) r_weight[i] = weight[i];

  for (int w_out = threadIdx.x; w_out < output_width; w_out += blockDim.x) {
    for (int h_out = threadIdx.y; h_out < output_height; h_out += blockDim.y) {
      const int batch = blockIdx.y;
      const int c_out = blockIdx.x;

      const int c_in = c_out / filter_multiplier;
      T value = 0;
      const int h_in_start = -padding_height + h_out * stride_height;
      const int w_in_start = -padding_width + w_out * stride_width;
      const int h_in_end = h_in_start + c_filter * dilate_height;
      const int w_in_end = w_in_start + c_filter * dilate_width;

      int in_offset;
      if (data_layout != DataLayout::kNHWC) {
        in_offset =
            ((batch * input_channels + c_in) * input_height) * input_width;
      } else {
        in_offset = batch * input_height * input_width * input_channels;
      }

      const int h_end = h_in_end < input_height ? h_in_end : input_height;
      const int w_end = w_in_end < input_width ? w_in_end : input_width;
      const int h_start = h_in_start > 0 ? h_in_start : 0;
      const int w_start = w_in_start > 0 ? w_in_start : 0;

      for (int h_in = h_in_start, h_f = 0; h_f < c_filter;
           h_in += dilate_height, h_f++) {
        for (int w_in = w_in_start, w_f = 0; w_f < c_filter;
             w_in += dilate_width, w_f++) {
          if (h_in >= 0 && h_in < input_height && w_in >= 0 &&
              w_in < input_width) {
            int offset;
            if (data_layout != DataLayout::kNHWC) {
              offset = in_offset + h_in * input_width + w_in;
            } else {
              offset = in_offset +
                       (h_in * input_width + w_in) * input_channels + c_in;
            }
            if (fuse_relu_before_conv) {
              value += r_weight[h_f * c_filter + w_f] *
                       max(0.0f, input_data[offset]);
            } else {
              value += r_weight[h_f * c_filter + w_f] * input_data[offset];
            }
          }
        }
      }
      int index;
      if (data_layout != DataLayout::kNHWC) {
        index = ((batch * gridDim.x + c_out) * output_height + h_out) *
                    output_width +
                w_out;
      } else {
        index = ((batch * output_height + h_out) * output_width + w_out) *
                    gridDim.x +
                c_out;
      }
      output_data[index] = value;
    }
  }
}

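// Dispatch kernel: non-zero c_filter_multiplier/c_stride and c_filter != -1
// select compile-time-specialized paths; otherwise the runtime values are
// used and the generic NCHW/NHWC kernels run.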
template <typename T, int c_filter_multiplier, int c_stride, int c_filter,
          bool fuse_relu_before_conv>
__global__ void KernelDepthwiseConvSp(ARG_DEFINE_KernelDepthwiseConv) {
  int final_filter_multiplier = filter_multiplier;
  int h_stride = stride_height;
  int w_stride = stride_width;
  if (c_filter_multiplier != 0) {
    final_filter_multiplier = c_filter_multiplier;
    h_stride = c_stride;
    w_stride = c_stride;
  }
  if (c_filter == -1) {
    if (data_layout == DataLayout::kNCHW) {
      KernelDepthwiseConvNCHW<T, fuse_relu_before_conv>(
          input_data, filter_data, batch_size, output_channels, output_height,
          output_width, input_channels, input_height, input_width,
          final_filter_multiplier, filter_height, filter_width, h_stride,
          w_stride, padding_height, padding_width, dilate_height, dilate_width,
          output_data, data_layout);
    } else {
      KernelDepthwiseConvNHWC<T, fuse_relu_before_conv>(
          input_data, filter_data, batch_size, output_channels, output_height,
          output_width, input_channels, input_height, input_width,
          final_filter_multiplier, filter_height, filter_width, h_stride,
          w_stride, padding_height, padding_width, dilate_height, dilate_width,
          output_data, data_layout);
    }
  } else {
    KernelDepthwiseConvCFilter<T, c_filter, fuse_relu_before_conv>(
        input_data, filter_data, batch_size, output_channels, output_height,
        output_width, input_channels, input_height, input_width,
        final_filter_multiplier, filter_height, filter_width, h_stride,
        w_stride, padding_height, padding_width, dilate_height, dilate_width,
        output_data, data_layout);
  }
}

// CUDA kernel to compute the depthwise convolution backprop w.r.t. input.
#define ARG_DEFINE_KernelDepthwiseConvInputGrad                                \
  const T *const input_data, const T *const output_grad_data,                  \
      const T *const filter_data, const int batch_size,                        \
      const int output_channels, const int output_height,                      \
      const int output_width, const int input_channels,                        \
      const int input_height, const int input_width,                           \
      const int filter_multiplier, const int filter_height,                    \
      const int filter_width, const int stride_height, const int stride_width, \
      const int padding_height, const int padding_width,                       \
      const int dilate_height, const int dilate_width,                         \
      T *const input_grad_data,                                                \
      const DataLayout data_layout = DataLayout::kNCHW

template <typename T, bool fuse_relu_before_conv>
__device__ __inline__ void KernelDepthwiseConvInputGrad(
    ARG_DEFINE_KernelDepthwiseConvInputGrad) {
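  // Each thread accumulates the gradient of one input element by visiting
  // every output position whose receptive field covers it.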
  for (int w_in = threadIdx.x; w_in < input_width; w_in += blockDim.x) {
    for (int h_in = threadIdx.y; h_in < input_height; h_in += blockDim.y) {
      const int batch = blockIdx.y;
      const int c_in = blockIdx.x;

      const int c_out_start = c_in * filter_multiplier;

      int h_out_start =
          h_in - (filter_height - 1) * dilate_height + padding_height;

      int h_out_end = h_in + padding_height;

      int w_out_start =
          w_in - (filter_width - 1) * dilate_width + padding_width;

      int w_out_end = w_in + padding_width;

      T value = 0;
      int index;
      if (data_layout != DataLayout::kNHWC) {
        index =
            ((batch * gridDim.x + c_in) * input_height + h_in) * input_width +
            w_in;
      } else {
        index =
            ((batch * input_height + h_in) * input_width + w_in) * gridDim.x +
            c_in;
      }

      if (fuse_relu_before_conv) {
        if (input_data[index] <= 0) {
          input_grad_data[index] = 0;
          continue;
        }
      }

      for (int c_out = c_out_start; c_out < c_out_start + filter_multiplier;
           c_out++) {
        int filter_offset = (c_out + 1) * filter_height * filter_width;
        for (int h_out = h_out_start; h_out <= h_out_end;
             h_out += dilate_height) {
          for (int w_out = w_out_start; w_out <= w_out_end;
               w_out += dilate_width) {
            filter_offset--;
            int s_h_out = h_out / stride_height;
            int s_w_out = w_out / stride_width;
            if (h_out % stride_height == 0 && w_out % stride_width == 0 &&
                s_h_out >= 0 && s_h_out < output_height && s_w_out >= 0 &&
                s_w_out < output_width) {
              int output_grad_offset;
              if (data_layout != DataLayout::kNHWC) {
                output_grad_offset =
                    ((batch * output_channels + c_out) * output_height +
                     s_h_out) *
                        output_width +
                    s_w_out;
              } else {
                output_grad_offset =
                    ((batch * output_height + s_h_out) * output_width +
                     s_w_out) *
                        output_channels +
                    c_out;
              }
              value += output_grad_data[output_grad_offset] *
                       filter_data[filter_offset];
            }
          }
        }
      }
      input_grad_data[index] = value;
    }
  }
}

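// Input-gradient kernel specialized on compile-time filter size and
// multiplier; the filter weights are cached in registers in reversed order,
// matching the correlation form of the backward pass.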
template <typename T, int c_filter, int c_filter_multiplier,
          bool fuse_relu_before_conv>
__device__ __inline__ void KernelDepthwiseConvInputGradCFilter(
    ARG_DEFINE_KernelDepthwiseConvInputGrad) {
  const int kWeightSize = c_filter * c_filter * c_filter_multiplier + 1;
  T r_weight[kWeightSize];
  const int batch = blockIdx.y;
  const int c_in = blockIdx.x;

  for (int c_i = 0; c_i < filter_multiplier; c_i++) {
    int c_out = c_in * filter_multiplier + c_i;
    const T* weight = filter_data + c_out * c_filter * c_filter;
    for (int i = 0; i < c_filter * c_filter; i++)
      r_weight[i + c_i * c_filter * c_filter] =
          weight[c_filter * c_filter - i - 1];
  }

  for (int w_in = threadIdx.x; w_in < input_width; w_in += blockDim.x) {
    for (int h_in = threadIdx.y; h_in < input_height; h_in += blockDim.y) {
      const int batch = blockIdx.y;
      const int c_in = blockIdx.x;

      int h_out_start = h_in - (c_filter - 1) * dilate_height + padding_height;

      int w_out_start = w_in - (c_filter - 1) * dilate_width + padding_width;

      T value = 0;
      int index;
      if (data_layout != DataLayout::kNHWC) {
        index =
            ((batch * gridDim.x + c_in) * input_height + h_in) * input_width +
            w_in;
      } else {
        index =
            ((batch * input_height + h_in) * input_width + w_in) * gridDim.x +
            c_in;
      }
      if (fuse_relu_before_conv) {
        if (input_data[index] <= 0) {
          input_grad_data[index] = 0;
          continue;
        }
      }

      for (int c_i = 0; c_i < filter_multiplier; c_i++) {
        int c_out = c_in * filter_multiplier + c_i;
        for (int h_out = h_out_start, h_f = 0; h_f < c_filter;
             h_out += dilate_height, h_f++) {
          for (int w_out = w_out_start, w_f = 0; w_f < c_filter;
               w_out += dilate_width, w_f++) {
            int s_h_out = h_out / stride_height;
            int s_w_out = w_out / stride_width;
            if (h_out % stride_height == 0 && w_out % stride_width == 0 &&
                s_h_out >= 0 && s_h_out < output_height && s_w_out >= 0 &&
                s_w_out < output_width) {
              int output_grad_offset;
              if (data_layout != DataLayout::kNHWC) {
                output_grad_offset =
                    ((batch * output_channels + c_out) * output_height +
                     s_h_out) *
                        output_width +
                    s_w_out;
              } else {
                output_grad_offset =
                    ((batch * output_height + s_h_out) * output_width +
                     s_w_out) *
                        output_channels +
                    c_out;
              }
              value +=
                  output_grad_data[output_grad_offset] *
                  r_weight[h_f * c_filter + w_f + c_i * c_filter * c_filter];
            }
          }
        }
      }
      input_grad_data[index] = value;
    }
  }
}

template <typename T, int c_filter_multiplier, int c_stride, int c_filter,
          bool fuse_relu_before_conv>
__global__ void KernelDepthwiseConvInputGradSp(
    ARG_DEFINE_KernelDepthwiseConvInputGrad) {
  if (c_filter_multiplier == 0)
    KernelDepthwiseConvInputGrad<T, fuse_relu_before_conv>(
        input_data, output_grad_data, filter_data, batch_size, output_channels,
        output_height, output_width, input_channels, input_height, input_width,
        filter_multiplier, filter_height, filter_width, stride_height,
        stride_width, padding_height, padding_width, dilate_height,
        dilate_width, input_grad_data, data_layout);
  else if (c_filter == -1)
    KernelDepthwiseConvInputGrad<T, fuse_relu_before_conv>(
        input_data, output_grad_data, filter_data, batch_size, output_channels,
        output_height, output_width, input_channels, input_height, input_width,
        c_filter_multiplier, filter_height, filter_width, c_stride, c_stride,
        padding_height, padding_width, dilate_height, dilate_width,
        input_grad_data, data_layout);
  else
    KernelDepthwiseConvInputGradCFilter<T, c_filter, c_filter_multiplier,
                                        fuse_relu_before_conv>(
        input_data, output_grad_data, filter_data, batch_size, output_channels,
        output_height, output_width, input_channels, input_height, input_width,
        c_filter_multiplier, filter_height, filter_width, c_stride, c_stride,
        padding_height, padding_width, dilate_height, dilate_width,
        input_grad_data, data_layout);
}

// CUDA kernel to compute the depthwise convolution backprop w.r.t. filter.
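// Each thread block corresponds to one filter tap (gridDim = filter_width x
// filter_height x output_channels); its threads sweep the output and batch
// dimensions and combine partial sums with CudaAtomicAddWithWarp.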
template <typename T, bool fuse_relu_before_conv>
__device__ __inline__ void KernelDepthwiseConvFilterGrad(
    const T* output_grad_data, const T* input_data, const int num,
    const int output_channels, const int output_height, const int output_width,
    const int input_channels, const int input_height, const int input_width,
    const int filter_multiplier, const int filter_height,
    const int filter_width, const int stride_height, const int stride_width,
    const int padding_height, const int padding_width, const int dilate_height,
    const int dilate_width, T* filter_grad_data,
    const DataLayout data_layout = DataLayout::kNCHW) {
  T s = 0;

  int gbid = ((blockIdx.z * gridDim.y) + blockIdx.y) * gridDim.x + blockIdx.x;

  for (int image_w = threadIdx.x; image_w < output_width;
       image_w += blockDim.x) {
    for (int bid = 0; bid < num; bid++) {
      for (int image_h = threadIdx.y; image_h < output_height;
           image_h += blockDim.y) {
        int kernel_id = blockIdx.z;
        int kernel_h = blockIdx.y * dilate_height - padding_height;
        int kernel_w = blockIdx.x * dilate_width - padding_width;

        int image_hk = image_h * stride_height + kernel_h;
        int image_wk = image_w * stride_width + kernel_w;
        if (image_hk < 0 || image_hk >= input_height) continue;
        if (image_wk < 0 || image_wk >= input_width) continue;
#define gaid(N, C, H, W) \
  ((((N)*gridDim.z + (C)) * output_height + (H)) * output_width + (W))
#define gaid_nhwc(N, H, W, C) \
  ((((N)*output_height + (H)) * output_width + (W)) * gridDim.z + (C))
        int input_id;
        if (data_layout != DataLayout::kNHWC) {
          input_id = ((bid * (gridDim.z / filter_multiplier) +
                       kernel_id / filter_multiplier) *
                          input_height +
                      image_hk) *
                         input_width +
                     image_wk;
          if (fuse_relu_before_conv) {
            s += output_grad_data[gaid(bid, kernel_id, image_h, image_w)] *
                 max(0.0f, input_data[input_id]);
          } else {
            s += output_grad_data[gaid(bid, kernel_id, image_h, image_w)] *
                 input_data[input_id];
          }
        } else {
          input_id =
              ((bid * input_height + image_hk) * input_width + image_wk) *
                  (gridDim.z / filter_multiplier) +
              kernel_id / filter_multiplier;
          if (fuse_relu_before_conv) {
            s += output_grad_data[gaid_nhwc(bid, image_h, image_w, kernel_id)] *
                 max(0.0f, input_data[input_id]);
          } else {
            s += output_grad_data[gaid_nhwc(bid, image_h, image_w, kernel_id)] *
                 input_data[input_id];
          }
        }

#undef gaid
#undef gaid_nhwc
      }
    }
  }
  CudaAtomicAddWithWarp(&filter_grad_data[gbid], s);
}

template <typename T, int c_filter_multiplier, bool fuse_relu_before_conv>
__global__ void KernelDepthwiseConvFilterGradSp(
    const T* output_grad_data, const T* input_data, const int num,
    const int output_channels, const int output_height, const int output_width,
    const int input_channels, const int input_height, const int input_width,
    const int filter_multiplier, const int filter_height,
    const int filter_width, const int stride_height, const int stride_width,
    const int padding_height, const int padding_width, const int dilate_height,
    const int dilate_width, T* filter_grad_data,
    const DataLayout data_layout = DataLayout::kNCHW) {
  if (c_filter_multiplier == 0)
    KernelDepthwiseConvFilterGrad<T, fuse_relu_before_conv>(
        output_grad_data, input_data, num, output_channels, output_height,
        output_width, input_channels, input_height, input_width,
        filter_multiplier, filter_height, filter_width, stride_height,
        stride_width, padding_height, padding_width, dilate_height,
        dilate_width, filter_grad_data, data_layout);
  else
    KernelDepthwiseConvFilterGrad<T, fuse_relu_before_conv>(
        output_grad_data, input_data, num, output_channels, output_height,
        output_width, input_channels, input_height, input_width,
        c_filter_multiplier, filter_height, filter_width, stride_height,
        stride_width, padding_height, padding_width, dilate_height,
        dilate_width, filter_grad_data, data_layout);
}

/*
 * Tensors are in NCHW format by default; pass data_layout = kNHWC for
 * NHWC input/output. Strides, paddings, and dilations each hold two
 * elements, representing height and width, respectively.
 */
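// A minimal usage sketch (illustrative only; `dev_ctx`, `input`, `filter`,
// and `output` are assumed to be a CUDA device context and suitably shaped
// GPU tensors, and are not defined in this file):
//
//   DepthwiseConvFunctor<platform::CUDADeviceContext, float, false> conv;
//   conv(dev_ctx, input, filter, /*strides=*/{1, 1}, /*paddings=*/{0, 0},
//        /*dilations=*/{1, 1}, &output);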
template <class T, bool fuse_relu_before_conv>
class DepthwiseConvFunctor<platform::CUDADeviceContext, T,
                           fuse_relu_before_conv> {
 public:
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& input,
                  const framework::Tensor& filter,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings,
                  const std::vector<int>& dilations, framework::Tensor* output,
                  const DataLayout data_layout = DataLayout::kNCHW) {
    const int batch_size = input.dims()[0];
    const int input_channels =
        (data_layout != DataLayout::kNHWC ? input.dims()[1] : input.dims()[3]);
    const int input_height =
        (data_layout != DataLayout::kNHWC ? input.dims()[2] : input.dims()[1]);
    const int input_width =
        (data_layout != DataLayout::kNHWC ? input.dims()[3] : input.dims()[2]);
    const int output_channels =
        (data_layout != DataLayout::kNHWC ? output->dims()[1]
                                          : output->dims()[3]);
    const int output_height =
        (data_layout != DataLayout::kNHWC ? output->dims()[2]
                                          : output->dims()[1]);
    const int output_width =
        (data_layout != DataLayout::kNHWC ? output->dims()[3]
                                          : output->dims()[2]);
    const int ksize_height = filter.dims()[2];
    const int ksize_width = filter.dims()[3];
    const int stride_height = strides[0];
    const int stride_width = strides[1];
    const int padding_height = paddings[0];
    const int padding_width = paddings[1];
    const int dilate_height = dilations[0];
    const int dilate_width = dilations[1];

    const T* input_data = input.data<T>();
    const T* filter_data = filter.data<T>();
    T* output_data = output->mutable_data<T>(context.GetPlace());

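    // Launch heuristic: at most 512 threads per block along the output
    // width; leftover threads cover extra rows via threads.y.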
    int thread = 512;
    if (output_width > 1024 && output_width <= 2048)
      thread = (output_width - 1) / 2 + 1;
    else if (output_width > 512 && output_width <= 1024)
      thread = output_width;
    int blocks = std::min(std::max(thread / output_width, 1), output_height);
    dim3 threads(std::min(output_width, thread), blocks, 1);
    dim3 grid(output_channels, batch_size, 1);
    int filter_multiplier = output_channels / input_channels;

    int nums_output =
        batch_size * output_channels * output_height * output_width;
    int block_size = 512;

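// check_case instantiates a specialized kernel when the runtime filter
// multiplier, stride, and filter size match the given compile-time values;
// check_case(0, 0, -1) below is the generic fallback.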
#define check_case(c_filter_multiplier, c_stride, c_filter)                  \
  if (c_filter_multiplier == 0 ||                                            \
      filter_multiplier == c_filter_multiplier &&                            \
          stride_height == stride_width && stride_height == c_stride &&      \
          (ksize_height == ksize_width && ksize_height == c_filter ||        \
           c_filter == -1)) {                                                \
    if (c_filter == -1) {                                                    \
      threads.x = block_size;                                                \
      grid.x = (nums_output + block_size - 1) / block_size;                  \
      threads.y = threads.z = grid.y = grid.z = 1;                           \
    }                                                                        \
    KernelDepthwiseConvSp<                                                   \
        T, c_filter_multiplier, c_stride, c_filter,                          \
        fuse_relu_before_conv><<<grid, threads, 0, context.stream()>>>(      \
        input_data, filter_data, batch_size, output_channels, output_height, \
        output_width, input_channels, input_height, input_width,             \
        filter_multiplier, ksize_height, ksize_width, stride_height,         \
        stride_width, padding_height, padding_width, dilate_height,          \
        dilate_width, output_data, data_layout);                             \
    return;                                                                  \
  }
    check_case(1, 1, 3);
    check_case(1, 1, 5);
    check_case(1, 1, -1);
    check_case(1, 2, 3);
    check_case(1, 2, 5);
    check_case(1, 2, -1);
    check_case(2, 1, 3);
    check_case(2, 1, 5);
    check_case(2, 1, -1);
    check_case(2, 2, 3);
    check_case(2, 2, 5);
    check_case(2, 2, -1);
    check_case(0, 0, -1);
// NOTE(liangdun): check_case(0, 0, -1) covers all remaining cases;
// add more specializations if needed, e.g. check_case(2^n, 1).
#undef check_case
  }
};

template <typename T, bool fuse_relu_before_conv>
class DepthwiseConvInputGradFunctor<platform::CUDADeviceContext, T,
                                    fuse_relu_before_conv> {
 public:
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& input,
                  const framework::Tensor& filter,
                  const framework::Tensor& output_grad,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings,
                  const std::vector<int>& dilations,
                  framework::Tensor* input_grad,
                  const DataLayout data_layout = DataLayout::kNCHW) {
    const int batch_size = input.dims()[0];
    const int input_channels =
        (data_layout != DataLayout::kNHWC ? input.dims()[1] : input.dims()[3]);
    const int input_height =
        (data_layout != DataLayout::kNHWC ? input.dims()[2] : input.dims()[1]);
    const int input_width =
        (data_layout != DataLayout::kNHWC ? input.dims()[3] : input.dims()[2]);
    const int output_channels =
        (data_layout != DataLayout::kNHWC ? output_grad.dims()[1]
                                          : output_grad.dims()[3]);
    const int output_height =
        (data_layout != DataLayout::kNHWC ? output_grad.dims()[2]
                                          : output_grad.dims()[1]);
    const int output_width =
        (data_layout != DataLayout::kNHWC ? output_grad.dims()[3]
                                          : output_grad.dims()[2]);
    const int ksize_height = filter.dims()[2];
    const int ksize_width = filter.dims()[3];
    const int stride_height = strides[0];
    const int stride_width = strides[1];
    const int padding_height = paddings[0];
    const int padding_width = paddings[1];
    const int dilate_height = dilations[0];
    const int dilate_width = dilations[1];

    const T* input_data = input.data<T>();
    const T* filter_data = filter.data<T>();
    const T* output_grad_data = output_grad.data<T>();
    T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace());

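    // Same launch heuristic as the forward pass, but blocks are shaped by
    // the input extent since each thread produces one input gradient.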
    int thread = 512;
    if (input_width > 1024 && input_width <= 2048)
      thread = (input_width - 1) / 2 + 1;
    else if (input_width > 512 && input_width <= 1024)
      thread = input_width;
    int blocks = std::min(std::max(thread / input_width, 1), input_height);
    dim3 threads(std::min(input_width, thread), blocks, 1);
    dim3 grid(input_channels, batch_size, 1);
    int filter_multiplier = output_channels / input_channels;

#define check_case(c_filter_multiplier, c_stride, c_filter)             \
  if (c_filter_multiplier == 0 ||                                       \
      filter_multiplier == c_filter_multiplier &&                       \
          stride_height == stride_width && stride_height == c_stride && \
          (ksize_height == ksize_width && ksize_height == c_filter ||   \
           c_filter == -1)) {                                           \
    KernelDepthwiseConvInputGradSp<                                     \
        T, c_filter_multiplier, c_stride, c_filter,                     \
        fuse_relu_before_conv><<<grid, threads, 0, context.stream()>>>( \
        input_data, output_grad_data, filter_data, batch_size,          \
        output_channels, output_height, output_width, input_channels,   \
        input_height, input_width, filter_multiplier, ksize_height,     \
        ksize_width, stride_height, stride_width, padding_height,       \
        padding_width, dilate_height, dilate_width, input_grad_data,    \
        data_layout);                                                   \
    return;                                                             \
  }
    check_case(1, 1, 3);
    check_case(1, 1, 5);
    check_case(1, 1, -1);
    check_case(1, 2, 3);
    check_case(1, 2, 5);
    check_case(1, 2, -1);
    check_case(2, 1, 3);
    check_case(2, 1, 5);
    check_case(2, 1, -1);
    check_case(2, 2, 3);
    check_case(2, 2, 5);
    check_case(2, 2, -1);
    check_case(0, 0, -1);
// NOTE(liangdun): check_case(0, 0, -1) covers all remaining cases;
// add more specializations if needed, e.g. check_case(2^n, 1).
#undef check_case
  }
};

template <typename T, bool fuse_relu_before_conv>
class DepthwiseConvFilterGradFunctor<platform::CUDADeviceContext, T,
                                     fuse_relu_before_conv> {
 public:
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& input,
                  const framework::Tensor& output_grad,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings,
                  const std::vector<int>& dilations,
                  framework::Tensor* filter_grad,
                  const DataLayout data_layout = DataLayout::kNCHW) {
    const int batch_size = input.dims()[0];
    const int input_channels =
        (data_layout != DataLayout::kNHWC ? input.dims()[1] : input.dims()[3]);
    const int input_height =
        (data_layout != DataLayout::kNHWC ? input.dims()[2] : input.dims()[1]);
    const int input_width =
        (data_layout != DataLayout::kNHWC ? input.dims()[3] : input.dims()[2]);
    const int output_channels =
        (data_layout != DataLayout::kNHWC ? output_grad.dims()[1]
                                          : output_grad.dims()[3]);
    const int output_height =
        (data_layout != DataLayout::kNHWC ? output_grad.dims()[2]
                                          : output_grad.dims()[1]);
    const int output_width =
        (data_layout != DataLayout::kNHWC ? output_grad.dims()[3]
                                          : output_grad.dims()[2]);
    const int ksize_height = filter_grad->dims()[2];
    const int ksize_width = filter_grad->dims()[3];
    const int stride_height = strides[0];
    const int stride_width = strides[1];
    const int padding_height = paddings[0];
    const int padding_width = paddings[1];
    const int dilate_height = dilations[0];
    const int dilate_width = dilations[1];

    const T* input_data = input.data<T>();
    const T* output_grad_data = output_grad.data<T>();
    T* filter_grad_data = filter_grad->mutable_data<T>(context.GetPlace());

    int block_size = 512;
    if (output_width > 1024 && output_width <= 2048)
      block_size = (output_width - 1) / 2 + 1;
    else if (output_width > 512 && output_width <= 1024)
      block_size = output_width;
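    // One block per filter tap: grid = (ksize_width, ksize_height,
    // output_channels); threads cooperate over the output spatial extent.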
    int crop_output_height =
        std::min(std::max(block_size / output_width, 1), output_height);
    dim3 grid(ksize_width, ksize_height, output_channels);
    dim3 threads(std::min(output_width, block_size), crop_output_height, 1);
    int filter_multiplier = output_channels / input_channels;

#define check_case(c_filter_multiplier)                                       \
  if (c_filter_multiplier == 0 || c_filter_multiplier == filter_multiplier) { \
    KernelDepthwiseConvFilterGradSp<                                          \
        T, c_filter_multiplier,                                               \
        fuse_relu_before_conv><<<grid, threads, 0, context.stream()>>>(       \
        output_grad_data, input_data, batch_size, output_channels,            \
        output_height, output_width, input_channels, input_height,            \
        input_width, filter_multiplier, ksize_height, ksize_width,            \
        stride_height, stride_width, padding_height, padding_width,           \
        dilate_height, dilate_width, filter_grad_data, data_layout);          \
    return;                                                                   \
  }
    check_case(1);
    check_case(0);
#undef check_case
  }
};

template class DepthwiseConvFunctor<platform::CUDADeviceContext, float, false>;
template class DepthwiseConvFunctor<platform::CUDADeviceContext, double, false>;

template class DepthwiseConvInputGradFunctor<platform::CUDADeviceContext, float,
                                             false>;
template class DepthwiseConvInputGradFunctor<platform::CUDADeviceContext,
                                             double, false>;

template class DepthwiseConvFilterGradFunctor<platform::CUDADeviceContext,
                                              float, false>;
template class DepthwiseConvFilterGradFunctor<platform::CUDADeviceContext,
                                              double, false>;

template class DepthwiseConvFunctor<platform::CUDADeviceContext, float, true>;
template class DepthwiseConvFunctor<platform::CUDADeviceContext, double, true>;

template class DepthwiseConvInputGradFunctor<platform::CUDADeviceContext, float,
                                             true>;
template class DepthwiseConvInputGradFunctor<platform::CUDADeviceContext,
                                             double, true>;

template class DepthwiseConvFilterGradFunctor<platform::CUDADeviceContext,
                                              float, true>;
template class DepthwiseConvFilterGradFunctor<platform::CUDADeviceContext,
                                              double, true>;

}  // namespace math
}  // namespace operators
}  // namespace paddle