/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <algorithm>
#include <vector>
#ifdef __NVCC__
#include <cub/cub.cuh>
#endif
#ifdef __HIPCC__
#include <hipcub/hipcub.hpp>
namespace cub = hipcub;
#endif
#include "paddle/fluid/operators/math/depthwise_conv.h"
#include "paddle/fluid/platform/cuda_device_function.h"
#include "paddle/fluid/platform/cuda_primitives.h"

namespace paddle {
namespace operators {
namespace math {

template <typename T>
__device__ __inline__ void CudaAtomicAddWithWarp(T* sum, T value) {
  typedef cub::WarpReduce<T> WarpReduce;
  typename WarpReduce::TempStorage temp_storage;

#ifdef __HIPCC__
  int block_size = min(blockDim.x * blockDim.y * blockDim.z, warpSize);
  value = WarpReduce(temp_storage).Sum(value, block_size);
#else
  value = WarpReduce(temp_storage).Sum(value);
#endif

  if (cub::LaneId() == 0) platform::CudaAtomicAdd(sum, value);
}
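
// Usage sketch for the helper above (illustrative only): each thread in a
// warp passes its partial sum, the warp reduces it via cub::WarpReduce, and
// only lane 0 issues the global atomic:
//
//   T partial = ...;  // per-thread partial result
//   CudaAtomicAddWithWarp(&global_sum, partial);
//
// This cuts global atomic traffic to one CudaAtomicAdd per warp instead of
// one per thread.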

#define ARG_DEFINE_KernelDepthwiseConv                                         \
  const T *const input_data, const T *const filter_data, const int batch_size, \
      const int output_channels, const int output_height,                      \
      const int output_width, const int input_channels,                        \
      const int input_height, const int input_width,                           \
      const int filter_multiplier, const int filter_height,                    \
      const int filter_width, const int stride_height, const int stride_width, \
      const int padding_height, const int padding_width,                       \
      const int dilate_height, const int dilate_width, T *const output_data,   \
      const DataLayout data_layout = DataLayout::kNCHW
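
// The ARG_DEFINE_KernelDepthwiseConv macro above factors out the parameter
// list shared by all forward depthwise-convolution kernels below, keeping
// their signatures in sync.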

// A CUDA kernel to compute the depthwise convolution forward pass
// in NCHW format.
template <typename T, bool fuse_relu_before_conv>
__device__ __inline__ void KernelDepthwiseConvNCHW(
    ARG_DEFINE_KernelDepthwiseConv) {
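  // One thread computes one output element; the flat index is laid out as
  //   idx = ((batch * C_out + c_out) * H_out + h_out) * W_out + w_out,
  // which the divisions and modulos below invert.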
  int idx = threadIdx.x + blockIdx.x * blockDim.x;
  if (idx >= (output_channels * batch_size * output_height * output_width))
    return;

  const int w_out = idx % output_width;
  const int h_out = (idx / output_width) % output_height;
  const int c_out = (idx / output_width / output_height) % output_channels;
  const int batch = idx / output_width / output_height / output_channels;

  const int c_in = c_out / filter_multiplier;
  const T* weight = filter_data + c_out * filter_height * filter_width;
  T value = 0;
  const int h_in_start = -padding_height + h_out * stride_height;
  const int w_in_start = -padding_width + w_out * stride_width;
  const int h_in_end = h_in_start + filter_height * dilate_height;
  const int w_in_end = w_in_start + filter_width * dilate_width;

  int in_offset =
      ((batch * input_channels + c_in) * input_height) * input_width;

  const int h_end = h_in_end < input_height ? h_in_end : input_height;
  const int w_end = w_in_end < input_width ? w_in_end : input_width;
  const int h_start = h_in_start > 0 ? h_in_start : 0;
  const int w_start = w_in_start > 0 ? w_in_start : 0;
  int weight_offset = 0;

#pragma unroll
  for (int h_in = h_in_start; h_in < h_in_end; h_in += dilate_height) {
#pragma unroll
    for (int w_in = w_in_start; w_in < w_in_end; w_in += dilate_width) {
      if (h_in >= h_start && h_in < h_end && w_in >= w_start && w_in < w_end) {
        int offset = in_offset + h_in * input_width + w_in;
        T in_data = input_data[offset];
        if (fuse_relu_before_conv) {
          value += weight[weight_offset] * max(0.0f, in_data);
        } else {
          value += weight[weight_offset] * in_data;
        }
      }
      weight_offset++;
    }
  }
  int index = batch * output_channels * output_height * output_width +
              c_out * output_height * output_width + h_out * output_width +
              w_out;
  output_data[index] = value;
}

// A CUDA kernel to compute the depthwise convolution forward pass
// in NHWC format.
template <typename T, bool fuse_relu_before_conv>
__device__ __inline__ void KernelDepthwiseConvNHWC(
    ARG_DEFINE_KernelDepthwiseConv) {
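  // Same one-thread-per-output mapping as the NCHW kernel, but in NHWC the
  // channel is the fastest-varying dimension:
  //   idx = ((batch * H_out + h_out) * W_out + w_out) * C_out + c_out.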
  int idx = threadIdx.x + blockIdx.x * blockDim.x;
  if (idx >= (output_channels * batch_size * output_height * output_width))
    return;

  const int c_out = idx % output_channels;
  const int w_out = (idx / output_channels) % output_width;
  const int h_out = (idx / output_channels / output_width) % output_height;
  const int batch = idx / output_width / output_height / output_channels;

  const int c_in = c_out / filter_multiplier;
  const T* weight = filter_data + c_out * filter_height * filter_width;
  T value = 0;
  const int h_in_start = -padding_height + h_out * stride_height;
  const int w_in_start = -padding_width + w_out * stride_width;
  const int h_in_end = h_in_start + filter_height * dilate_height;
  const int w_in_end = w_in_start + filter_width * dilate_width;

  const int h_end = h_in_end < input_height ? h_in_end : input_height;
  const int w_end = w_in_end < input_width ? w_in_end : input_width;
  const int h_start = h_in_start > 0 ? h_in_start : 0;
  const int w_start = w_in_start > 0 ? w_in_start : 0;
  int weight_offset = 0;

#pragma unroll
  for (int h_in = h_in_start; h_in < h_in_end; h_in += dilate_height) {
#pragma unroll
    for (int w_in = w_in_start; w_in < w_in_end; w_in += dilate_width) {
      if (h_in >= h_start && h_in < h_end && w_in >= w_start && w_in < w_end) {
        int offset = ((batch * input_height + h_in) * input_width + w_in) *
                         output_channels +
                     c_in;
        T in_data = input_data[offset];
        if (fuse_relu_before_conv) {
          value += weight[weight_offset] * max(0.0f, in_data);
        } else {
          value += weight[weight_offset] * in_data;
        }
      }
      weight_offset++;
    }
  }
  int index = batch * output_channels * output_height * output_width +
              h_out * output_width * output_channels + w_out * output_channels +
              c_out;
  output_data[index] = value;
}

template <typename T, int c_filter, bool fuse_relu_before_conv>
__device__ __inline__ void KernelDepthwiseConvCFilter(
    ARG_DEFINE_KernelDepthwiseConv) {
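  // c_filter is a compile-time filter size: the weights are first copied into
  // the per-thread register array r_weight, and the filter loops below can be
  // fully unrolled by the compiler.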
  const int kWeightSize = c_filter * c_filter;
  T r_weight[kWeightSize];
  const int batch = blockIdx.y;
  const int c_out = blockIdx.x;
  const T* weight = filter_data + c_out * c_filter * c_filter;
  for (int i = 0; i < c_filter * c_filter; i++) r_weight[i] = weight[i];

  for (int w_out = threadIdx.x; w_out < output_width; w_out += blockDim.x) {
    for (int h_out = threadIdx.y; h_out < output_height; h_out += blockDim.y) {
      const int batch = blockIdx.y;
      const int c_out = blockIdx.x;

      const int c_in = c_out / filter_multiplier;
      T value = 0;
      const int h_in_start = -padding_height + h_out * stride_height;
      const int w_in_start = -padding_width + w_out * stride_width;
      const int h_in_end = h_in_start + c_filter * dilate_height;
      const int w_in_end = w_in_start + c_filter * dilate_width;

      int in_offset;
      if (data_layout != DataLayout::kNHWC) {
        in_offset =
            ((batch * input_channels + c_in) * input_height) * input_width;
      } else {
        in_offset = batch * input_height * input_width * input_channels;
      }

      const int h_end = h_in_end < input_height ? h_in_end : input_height;
      const int w_end = w_in_end < input_width ? w_in_end : input_width;
      const int h_start = h_in_start > 0 ? h_in_start : 0;
      const int w_start = w_in_start > 0 ? w_in_start : 0;

      for (int h_in = h_in_start, h_f = 0; h_f < c_filter;
           h_in += dilate_height, h_f++) {
        for (int w_in = w_in_start, w_f = 0; w_f < c_filter;
             w_in += dilate_width, w_f++) {
          if (h_in >= 0 && h_in < input_height && w_in >= 0 &&
              w_in < input_width) {
            int offset;
            if (data_layout != DataLayout::kNHWC) {
              offset = in_offset + h_in * input_width + w_in;
            } else {
              offset = in_offset +
                       (h_in * input_width + w_in) * input_channels + c_in;
            }
            if (fuse_relu_before_conv) {
              value += r_weight[h_f * c_filter + w_f] *
                       max(0.0f, input_data[offset]);
            } else {
              value += r_weight[h_f * c_filter + w_f] * input_data[offset];
            }
          }
        }
      }
      int index;
      if (data_layout != DataLayout::kNHWC) {
        index = ((batch * gridDim.x + c_out) * output_height + h_out) *
                    output_width +
                w_out;
      } else {
        index = ((batch * output_height + h_out) * output_width + w_out) *
                    gridDim.x +
                c_out;
      }
      output_data[index] = value;
    }
  }
}

template <typename T, int c_filter_multiplier, int c_stride, int c_filter,
          bool fuse_relu_before_conv>
__global__ void KernelDepthwiseConvSp(ARG_DEFINE_KernelDepthwiseConv) {
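  // Compile-time dispatch: c_filter_multiplier == 0 means the multiplier and
  // stride are only known at runtime; c_filter == -1 means the filter size is
  // not specialized, so the generic NCHW/NHWC kernels are used.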
  int final_filter_multiplier = filter_multiplier;
  int h_stride = stride_height;
  int w_stride = stride_width;
  if (c_filter_multiplier != 0) {
    final_filter_multiplier = c_filter_multiplier;
    h_stride = c_stride;
    w_stride = c_stride;
  }
  if (c_filter == -1) {
    if (data_layout == DataLayout::kNCHW) {
      KernelDepthwiseConvNCHW<T, fuse_relu_before_conv>(
          input_data, filter_data, batch_size, output_channels, output_height,
          output_width, input_channels, input_height, input_width,
          final_filter_multiplier, filter_height, filter_width, h_stride,
          w_stride, padding_height, padding_width, dilate_height, dilate_width,
          output_data, data_layout);
    } else {
      KernelDepthwiseConvNHWC<T, fuse_relu_before_conv>(
          input_data, filter_data, batch_size, output_channels, output_height,
          output_width, input_channels, input_height, input_width,
          final_filter_multiplier, filter_height, filter_width, h_stride,
          w_stride, padding_height, padding_width, dilate_height, dilate_width,
          output_data, data_layout);
    }
  } else {
    KernelDepthwiseConvCFilter<T, c_filter, fuse_relu_before_conv>(
        input_data, filter_data, batch_size, output_channels, output_height,
        output_width, input_channels, input_height, input_width,
        final_filter_multiplier, filter_height, filter_width, h_stride,
        w_stride, padding_height, padding_width, dilate_height, dilate_width,
        output_data, data_layout);
  }
}

// A CUDA kernel to compute the depthwise convolution backprop w.r.t. input.
#define ARG_DEFINE_KernelDepthwiseConvInputGrad                                \
  const T *const input_data, const T *const output_grad_data,                  \
      const T *const filter_data, const int batch_size,                        \
      const int output_channels, const int output_height,                      \
      const int output_width, const int input_channels,                        \
      const int input_height, const int input_width,                           \
      const int filter_multiplier, const int filter_height,                    \
      const int filter_width, const int stride_height, const int stride_width, \
      const int padding_height, const int padding_width,                       \
      const int dilate_height, const int dilate_width,                         \
      T *const input_grad_data,                                                \
      const DataLayout data_layout = DataLayout::kNCHW

template <typename T, bool fuse_relu_before_conv>
__device__ __inline__ void KernelDepthwiseConvInputGrad(
    ARG_DEFINE_KernelDepthwiseConvInputGrad) {
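  // Backward w.r.t. input: each thread owns one input element (h_in, w_in)
  // and accumulates contributions from every output position whose receptive
  // field covers it. filter_offset starts one past the end of the filter for
  // channel c_out and is decremented, walking the filter in reverse (the
  // 180-degree rotation used by the gradient computation).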
  for (int w_in = threadIdx.x; w_in < input_width; w_in += blockDim.x) {
    for (int h_in = threadIdx.y; h_in < input_height; h_in += blockDim.y) {
      const int batch = blockIdx.y;
      const int c_in = blockIdx.x;

      const int c_out_start = c_in * filter_multiplier;

      int h_out_start =
          h_in - (filter_height - 1) * dilate_height + padding_height;

      int h_out_end = h_in + padding_height;

      int w_out_start =
          w_in - (filter_width - 1) * dilate_width + padding_width;

      int w_out_end = w_in + padding_width;

      T value = 0;
      int index;
      if (data_layout != DataLayout::kNHWC) {
        index =
            ((batch * gridDim.x + c_in) * input_height + h_in) * input_width +
            w_in;
      } else {
        index =
            ((batch * input_height + h_in) * input_width + w_in) * gridDim.x +
            c_in;
      }

      if (fuse_relu_before_conv) {
        if (input_data[index] <= 0) {
          input_grad_data[index] = 0;
          continue;
        }
      }

      for (int c_out = c_out_start; c_out < c_out_start + filter_multiplier;
           c_out++) {
        int filter_offset = (c_out + 1) * filter_height * filter_width;
        for (int h_out = h_out_start; h_out <= h_out_end;
             h_out += dilate_height) {
          for (int w_out = w_out_start; w_out <= w_out_end;
               w_out += dilate_width) {
            filter_offset--;
            int s_h_out = h_out / stride_height;
            int s_w_out = w_out / stride_width;
            if (h_out % stride_height == 0 && w_out % stride_width == 0 &&
                s_h_out >= 0 && s_h_out < output_height && s_w_out >= 0 &&
                s_w_out < output_width) {
              int output_grad_offset;
              if (data_layout != DataLayout::kNHWC) {
                output_grad_offset =
                    ((batch * output_channels + c_out) * output_height +
                     s_h_out) *
                        output_width +
                    s_w_out;
              } else {
                output_grad_offset =
                    ((batch * output_height + s_h_out) * output_width +
                     s_w_out) *
                        output_channels +
                    c_out;
              }
              value += output_grad_data[output_grad_offset] *
                       filter_data[filter_offset];
            }
          }
        }
      }
      input_grad_data[index] = value;
    }
  }
}

template <typename T, int c_filter, int c_filter_multiplier,
          bool fuse_relu_before_conv>
__device__ __inline__ void KernelDepthwiseConvInputGradCFilter(
    ARG_DEFINE_KernelDepthwiseConvInputGrad) {
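  // Variant of the input-grad kernel with a compile-time filter size: the
  // weights are preloaded into registers in reversed order below, so the
  // unrolled loops can index r_weight forward.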
  const int kWeightSize = c_filter * c_filter * c_filter_multiplier + 1;
  T r_weight[kWeightSize];
  const int batch = blockIdx.y;
  const int c_in = blockIdx.x;

  for (int c_i = 0; c_i < filter_multiplier; c_i++) {
    int c_out = c_in * filter_multiplier + c_i;
    const T* weight = filter_data + c_out * c_filter * c_filter;
    for (int i = 0; i < c_filter * c_filter; i++)
      r_weight[i + c_i * c_filter * c_filter] =
          weight[c_filter * c_filter - i - 1];
  }

  for (int w_in = threadIdx.x; w_in < input_width; w_in += blockDim.x) {
    for (int h_in = threadIdx.y; h_in < input_height; h_in += blockDim.y) {
      const int batch = blockIdx.y;
      const int c_in = blockIdx.x;

      int h_out_start = h_in - (c_filter - 1) * dilate_height + padding_height;

      int w_out_start = w_in - (c_filter - 1) * dilate_width + padding_width;

      T value = 0;
      int index;
      if (data_layout != DataLayout::kNHWC) {
        index =
            ((batch * gridDim.x + c_in) * input_height + h_in) * input_width +
            w_in;
      } else {
        index =
            ((batch * input_height + h_in) * input_width + w_in) * gridDim.x +
            c_in;
      }
      if (fuse_relu_before_conv) {
        if (input_data[index] <= 0) {
          input_grad_data[index] = 0;
          continue;
        }
      }

      for (int c_i = 0; c_i < filter_multiplier; c_i++) {
        int c_out = c_in * filter_multiplier + c_i;
        for (int h_out = h_out_start, h_f = 0; h_f < c_filter;
             h_out += dilate_height, h_f++) {
          for (int w_out = w_out_start, w_f = 0; w_f < c_filter;
               w_out += dilate_width, w_f++) {
            int s_h_out = h_out / stride_height;
            int s_w_out = w_out / stride_width;
            if (h_out % stride_height == 0 && w_out % stride_width == 0 &&
                s_h_out >= 0 && s_h_out < output_height && s_w_out >= 0 &&
                s_w_out < output_width) {
              int output_grad_offset;
              if (data_layout != DataLayout::kNHWC) {
                output_grad_offset =
                    ((batch * output_channels + c_out) * output_height +
                     s_h_out) *
                        output_width +
                    s_w_out;
              } else {
                output_grad_offset =
                    ((batch * output_height + s_h_out) * output_width +
                     s_w_out) *
                        output_channels +
                    c_out;
              }
              value +=
                  output_grad_data[output_grad_offset] *
                  r_weight[h_f * c_filter + w_f + c_i * c_filter * c_filter];
            }
          }
        }
      }
      input_grad_data[index] = value;
    }
  }
}

template <typename T, int c_filter_multiplier, int c_stride, int c_filter,
          bool fuse_relu_before_conv>
__global__ void KernelDepthwiseConvInputGradSp(
    ARG_DEFINE_KernelDepthwiseConvInputGrad) {
  if (c_filter_multiplier == 0)
    KernelDepthwiseConvInputGrad<T, fuse_relu_before_conv>(
        input_data, output_grad_data, filter_data, batch_size, output_channels,
        output_height, output_width, input_channels, input_height, input_width,
        filter_multiplier, filter_height, filter_width, stride_height,
        stride_width, padding_height, padding_width, dilate_height,
        dilate_width, input_grad_data, data_layout);
  else if (c_filter == -1)
    KernelDepthwiseConvInputGrad<T, fuse_relu_before_conv>(
        input_data, output_grad_data, filter_data, batch_size, output_channels,
        output_height, output_width, input_channels, input_height, input_width,
        c_filter_multiplier, filter_height, filter_width, c_stride, c_stride,
        padding_height, padding_width, dilate_height, dilate_width,
        input_grad_data, data_layout);
  else
    KernelDepthwiseConvInputGradCFilter<T, c_filter, c_filter_multiplier,
                                        fuse_relu_before_conv>(
        input_data, output_grad_data, filter_data, batch_size, output_channels,
        output_height, output_width, input_channels, input_height, input_width,
        c_filter_multiplier, filter_height, filter_width, c_stride, c_stride,
        padding_height, padding_width, dilate_height, dilate_width,
        input_grad_data, data_layout);
}

// A CUDA kernel to compute the depthwise convolution backprop w.r.t. filter.
template <typename T, bool fuse_relu_before_conv>
__device__ __inline__ void KernelDepthwiseConvFilterGrad(
    const T* output_grad_data, const T* input_data, const int num,
    const int output_channels, const int output_height, const int output_width,
    const int input_channels, const int input_height, const int input_width,
    const int filter_multiplier, const int filter_height,
    const int filter_width, const int stride_height, const int stride_width,
    const int padding_height, const int padding_width, const int dilate_height,
    const int dilate_width, T* filter_grad_data,
    const DataLayout data_layout = DataLayout::kNCHW) {
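  // Backward w.r.t. the filter: the launcher maps each block to one filter
  // tap and output channel via (blockIdx.x, blockIdx.y, blockIdx.z) =
  // (filter_w, filter_h, c_out); the block's threads stride over the output
  // plane and batch, accumulating input * output_grad into the per-thread
  // partial sum s, which is warp-reduced into filter_grad_data at the end.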
  T s = 0;

  int gbid = ((blockIdx.z * gridDim.y) + blockIdx.y) * gridDim.x + blockIdx.x;

  for (int image_w = threadIdx.x; image_w < output_width;
       image_w += blockDim.x) {
    for (int bid = 0; bid < num; bid++) {
      for (int image_h = threadIdx.y; image_h < output_height;
           image_h += blockDim.y) {
        int kernel_id = blockIdx.z;
        int kernel_h = blockIdx.y * dilate_height - padding_height;
        int kernel_w = blockIdx.x * dilate_width - padding_width;

        int image_hk = image_h * stride_height + kernel_h;
        int image_wk = image_w * stride_width + kernel_w;
        if (image_hk < 0 || image_hk >= input_height) continue;
        if (image_wk < 0 || image_wk >= input_width) continue;
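// gaid / gaid_nhwc flatten (N, C, H, W) / (N, H, W, C) coordinates into an
// output_grad_data index for the NCHW and NHWC layouts, respectively.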
#define gaid(N, C, H, W) \
  ((((N)*gridDim.z + (C)) * output_height + (H)) * output_width + (W))
#define gaid_nhwc(N, H, W, C) \
  ((((N)*output_height + (H)) * output_width + (W)) * gridDim.z + (C))
        int input_id;
        if (data_layout != DataLayout::kNHWC) {
          input_id = ((bid * (gridDim.z / filter_multiplier) +
                       kernel_id / filter_multiplier) *
                          input_height +
                      image_hk) *
                         input_width +
                     image_wk;
          if (fuse_relu_before_conv) {
            s += output_grad_data[gaid(bid, kernel_id, image_h, image_w)] *
                 max(0.0f, input_data[input_id]);
          } else {
            s += output_grad_data[gaid(bid, kernel_id, image_h, image_w)] *
                 input_data[input_id];
          }
        } else {
          input_id =
              ((bid * input_height + image_hk) * input_width + image_wk) *
                  (gridDim.z / filter_multiplier) +
              kernel_id / filter_multiplier;
          if (fuse_relu_before_conv) {
            s += output_grad_data[gaid_nhwc(bid, image_h, image_w, kernel_id)] *
                 max(0.0f, input_data[input_id]);
          } else {
            s += output_grad_data[gaid_nhwc(bid, image_h, image_w, kernel_id)] *
                 input_data[input_id];
          }
        }

#undef gaid
#undef gaid_nhwc
      }
    }
  }
  CudaAtomicAddWithWarp(&filter_grad_data[gbid], s);
}

template <typename T, int c_filter_multiplier, bool fuse_relu_before_conv>
__global__ void KernelDepthwiseConvFilterGradSp(
    const T* output_grad_data, const T* input_data, const int num,
    const int output_channels, const int output_height, const int output_width,
    const int input_channels, const int input_height, const int input_width,
    const int filter_multiplier, const int filter_height,
    const int filter_width, const int stride_height, const int stride_width,
    const int padding_height, const int padding_width, const int dilate_height,
    const int dilate_width, T* filter_grad_data,
    const DataLayout data_layout = DataLayout::kNCHW) {
  if (c_filter_multiplier == 0)
    KernelDepthwiseConvFilterGrad<T, fuse_relu_before_conv>(
        output_grad_data, input_data, num, output_channels, output_height,
        output_width, input_channels, input_height, input_width,
        filter_multiplier, filter_height, filter_width, stride_height,
        stride_width, padding_height, padding_width, dilate_height,
        dilate_width, filter_grad_data, data_layout);
  else
    KernelDepthwiseConvFilterGrad<T, fuse_relu_before_conv>(
        output_grad_data, input_data, num, output_channels, output_height,
        output_width, input_channels, input_height, input_width,
        c_filter_multiplier, filter_height, filter_width, stride_height,
        stride_width, padding_height, padding_width, dilate_height,
        dilate_width, filter_grad_data, data_layout);
}

/*
 * Tensors are in NCHW format by default; NHWC is selected via data_layout.
 * Ksize, strides, paddings, and dilations are two-element vectors whose
 * elements represent the height and width dimensions, respectively.
 */
template <class T, bool fuse_relu_before_conv>
class DepthwiseConvFunctor<platform::CUDADeviceContext, T,
                           fuse_relu_before_conv> {
 public:
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& input,
                  const framework::Tensor& filter,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings,
                  const std::vector<int>& dilations, framework::Tensor* output,
                  const DataLayout data_layout = DataLayout::kNCHW) {
    const int batch_size = input.dims()[0];
    const int input_channels =
        (data_layout != DataLayout::kNHWC ? input.dims()[1] : input.dims()[3]);
    const int input_height =
        (data_layout != DataLayout::kNHWC ? input.dims()[2] : input.dims()[1]);
    const int input_width =
        (data_layout != DataLayout::kNHWC ? input.dims()[3] : input.dims()[2]);
    const int output_channels =
        (data_layout != DataLayout::kNHWC ? output->dims()[1]
                                          : output->dims()[3]);
    const int output_height =
        (data_layout != DataLayout::kNHWC ? output->dims()[2]
                                          : output->dims()[1]);
    const int output_width =
        (data_layout != DataLayout::kNHWC ? output->dims()[3]
                                          : output->dims()[2]);
    const int ksize_height = filter.dims()[2];
    const int ksize_width = filter.dims()[3];
    const int stride_height = strides[0];
    const int stride_width = strides[1];
    const int padding_height = paddings[0];
    const int padding_width = paddings[1];
    const int dilate_height = dilations[0];
    const int dilate_width = dilations[1];

    const T* input_data = input.data<T>();
    const T* filter_data = filter.data<T>();
    T* output_data = output->mutable_data<T>(context.GetPlace());

    int thread = 512;
    if (output_width > 1024 && output_width <= 2048)
      thread = (output_width - 1) / 2 + 1;
    else if (output_width > 512 && output_width <= 1024)
      thread = output_width;
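    // Worked example: output_width = 1500 falls in (1024, 2048], so
    // thread = (1500 - 1) / 2 + 1 = 750 and each thread covers two output
    // columns through the blockDim.x-strided loops of the specialized
    // (c_filter != -1) kernels.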
#ifdef __HIPCC__
    thread = std::min(thread, 256);
#endif
    int blocks = std::min(std::max(thread / output_width, 1), output_height);
    dim3 threads(std::min(output_width, thread), blocks, 1);
    dim3 grid(output_channels, batch_size, 1);
    int filter_multiplier = output_channels / input_channels;

    int nums_output =
        batch_size * output_channels * output_height * output_width;
#ifdef __HIPCC__
    int block_size = 256;
    int grid_size = std::min((nums_output + block_size - 1) / block_size, 256);
#else
    int block_size = 512;
    int grid_size = (nums_output + block_size - 1) / block_size;
#endif

#define check_case(c_filter_multiplier, c_stride, c_filter)                  \
  if (c_filter_multiplier == 0 ||                                            \
      filter_multiplier == c_filter_multiplier &&                            \
          stride_height == stride_width && stride_height == c_stride &&      \
          (ksize_height == ksize_width && ksize_height == c_filter ||        \
           c_filter == -1)) {                                                \
    if (c_filter == -1) {                                                    \
      threads.x = block_size;                                                \
      grid.x = grid_size;                                                    \
      threads.y = threads.z = grid.y = grid.z = 1;                           \
    }                                                                        \
    KernelDepthwiseConvSp<                                                   \
        T, c_filter_multiplier, c_stride, c_filter,                          \
        fuse_relu_before_conv><<<grid, threads, 0, context.stream()>>>(      \
        input_data, filter_data, batch_size, output_channels, output_height, \
        output_width, input_channels, input_height, input_width,             \
        filter_multiplier, ksize_height, ksize_width, stride_height,         \
        stride_width, padding_height, padding_width, dilate_height,          \
        dilate_width, output_data, data_layout);                             \
    return;                                                                  \
  }
    check_case(1, 1, 3);
    check_case(1, 1, 5);
    check_case(1, 1, -1);
    check_case(1, 2, 3);
    check_case(1, 2, 5);
    check_case(1, 2, -1);
    check_case(2, 1, 3);
    check_case(2, 1, 5);
    check_case(2, 1, -1);
    check_case(2, 2, 3);
    check_case(2, 2, 5);
    check_case(2, 2, -1);
    check_case(0, 0, -1);
// NOTE(liangdun): check_case(0, 0, -1) covers all remaining configurations;
// add more specializations if needed, e.g. check_case(2^n, 1, -1).
#undef check_case
  }
};
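
// Illustrative usage of the functor above (the tensor/context setup is
// assumed to exist elsewhere and is not part of this file):
//
//   DepthwiseConvFunctor<platform::CUDADeviceContext, float, false> conv;
//   conv(dev_ctx, input, filter, strides, paddings, dilations, &output);
//
// With fuse_relu_before_conv = true (instantiated at the end of this file),
// the kernels apply max(0, x) to the input before convolving.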

template <typename T, bool fuse_relu_before_conv>
class DepthwiseConvInputGradFunctor<platform::CUDADeviceContext, T,
                                    fuse_relu_before_conv> {
 public:
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& input,
                  const framework::Tensor& filter,
                  const framework::Tensor& output_grad,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings,
                  const std::vector<int>& dilations,
                  framework::Tensor* input_grad,
                  const DataLayout data_layout = DataLayout::kNCHW) {
    const int batch_size = input.dims()[0];
    const int input_channels =
        (data_layout != DataLayout::kNHWC ? input.dims()[1] : input.dims()[3]);
    const int input_height =
        (data_layout != DataLayout::kNHWC ? input.dims()[2] : input.dims()[1]);
    const int input_width =
        (data_layout != DataLayout::kNHWC ? input.dims()[3] : input.dims()[2]);
    const int output_channels =
        (data_layout != DataLayout::kNHWC ? output_grad.dims()[1]
                                          : output_grad.dims()[3]);
    const int output_height =
        (data_layout != DataLayout::kNHWC ? output_grad.dims()[2]
                                          : output_grad.dims()[1]);
    const int output_width =
        (data_layout != DataLayout::kNHWC ? output_grad.dims()[3]
                                          : output_grad.dims()[2]);
    const int ksize_height = filter.dims()[2];
    const int ksize_width = filter.dims()[3];
    const int stride_height = strides[0];
    const int stride_width = strides[1];
    const int padding_height = paddings[0];
    const int padding_width = paddings[1];
    const int dilate_height = dilations[0];
    const int dilate_width = dilations[1];

    const T* input_data = input.data<T>();
    const T* filter_data = filter.data<T>();
    const T* output_grad_data = output_grad.data<T>();
    T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace());

    int thread = 512;
    if (input_width > 1024 && input_width <= 2048)
      thread = (input_width - 1) / 2 + 1;
    else if (input_width > 512 && input_width <= 1024)
      thread = input_width;
    int blocks = std::min(std::max(thread / input_width, 1), input_height);
    dim3 threads(std::min(input_width, thread), blocks, 1);
    dim3 grid(input_channels, batch_size, 1);
    int filter_multiplier = output_channels / input_channels;

#define check_case(c_filter_multiplier, c_stride, c_filter)             \
  if (c_filter_multiplier == 0 ||                                       \
      filter_multiplier == c_filter_multiplier &&                       \
          stride_height == stride_width && stride_height == c_stride && \
          (ksize_height == ksize_width && ksize_height == c_filter ||   \
           c_filter == -1)) {                                           \
    KernelDepthwiseConvInputGradSp<                                     \
        T, c_filter_multiplier, c_stride, c_filter,                     \
        fuse_relu_before_conv><<<grid, threads, 0, context.stream()>>>( \
        input_data, output_grad_data, filter_data, batch_size,          \
        output_channels, output_height, output_width, input_channels,   \
        input_height, input_width, filter_multiplier, ksize_height,     \
        ksize_width, stride_height, stride_width, padding_height,       \
        padding_width, dilate_height, dilate_width, input_grad_data,    \
        data_layout);                                                   \
    return;                                                             \
  }
    check_case(1, 1, 3);
    check_case(1, 1, 5);
    check_case(1, 1, -1);
    check_case(1, 2, 3);
    check_case(1, 2, 5);
    check_case(1, 2, -1);
    check_case(2, 1, 3);
    check_case(2, 1, 5);
    check_case(2, 1, -1);
    check_case(2, 2, 3);
    check_case(2, 2, 5);
    check_case(2, 2, -1);
    check_case(0, 0, -1);
// NOTE(liangdun): check_case(0, 0, -1) covers all remaining configurations;
// add more specializations if needed, e.g. check_case(2^n, 1, -1).
#undef check_case
  }
};

template <typename T, bool fuse_relu_before_conv>
class DepthwiseConvFilterGradFunctor<platform::CUDADeviceContext, T,
                                     fuse_relu_before_conv> {
 public:
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& input,
                  const framework::Tensor& output_grad,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings,
                  const std::vector<int>& dilations,
                  framework::Tensor* filter_grad,
                  const DataLayout data_layout = DataLayout::kNCHW) {
    const int batch_size = input.dims()[0];
    const int input_channels =
        (data_layout != DataLayout::kNHWC ? input.dims()[1] : input.dims()[3]);
    const int input_height =
        (data_layout != DataLayout::kNHWC ? input.dims()[2] : input.dims()[1]);
    const int input_width =
        (data_layout != DataLayout::kNHWC ? input.dims()[3] : input.dims()[2]);
    const int output_channels =
        (data_layout != DataLayout::kNHWC ? output_grad.dims()[1]
                                          : output_grad.dims()[3]);
    const int output_height =
        (data_layout != DataLayout::kNHWC ? output_grad.dims()[2]
                                          : output_grad.dims()[1]);
    const int output_width =
        (data_layout != DataLayout::kNHWC ? output_grad.dims()[3]
                                          : output_grad.dims()[2]);
    const int ksize_height = filter_grad->dims()[2];
    const int ksize_width = filter_grad->dims()[3];
    const int stride_height = strides[0];
    const int stride_width = strides[1];
    const int padding_height = paddings[0];
    const int padding_width = paddings[1];
    const int dilate_height = dilations[0];
    const int dilate_width = dilations[1];

    const T* input_data = input.data<T>();
    const T* output_grad_data = output_grad.data<T>();
    T* filter_grad_data = filter_grad->mutable_data<T>(context.GetPlace());

    int block_size = 512;
    if (output_width > 1024 && output_width <= 2048)
      block_size = (output_width - 1) / 2 + 1;
    else if (output_width > 512 && output_width <= 1024)
      block_size = output_width;
    int crop_output_height =
        std::min(std::max(block_size / output_width, 1), output_height);
    dim3 grid(ksize_width, ksize_height, output_channels);
    dim3 threads(std::min(output_width, block_size), crop_output_height, 1);
    int filter_multiplier = output_channels / input_channels;

#define check_case(c_filter_multiplier)                                       \
  if (c_filter_multiplier == 0 || c_filter_multiplier == filter_multiplier) { \
    KernelDepthwiseConvFilterGradSp<                                          \
        T, c_filter_multiplier,                                               \
        fuse_relu_before_conv><<<grid, threads, 0, context.stream()>>>(       \
        output_grad_data, input_data, batch_size, output_channels,            \
        output_height, output_width, input_channels, input_height,            \
        input_width, filter_multiplier, ksize_height, ksize_width,            \
        stride_height, stride_width, padding_height, padding_width,           \
        dilate_height, dilate_width, filter_grad_data, data_layout);          \
    return;                                                                   \
  }
    check_case(1);
    check_case(0);
#undef check_case
  }
};

template class DepthwiseConvFunctor<platform::CUDADeviceContext, float, false>;
template class DepthwiseConvFunctor<platform::CUDADeviceContext, double, false>;

template class DepthwiseConvInputGradFunctor<platform::CUDADeviceContext, float,
                                             false>;
template class DepthwiseConvInputGradFunctor<platform::CUDADeviceContext,
                                             double, false>;

template class DepthwiseConvFilterGradFunctor<platform::CUDADeviceContext,
                                              float, false>;
template class DepthwiseConvFilterGradFunctor<platform::CUDADeviceContext,
                                              double, false>;

template class DepthwiseConvFunctor<platform::CUDADeviceContext, float, true>;
template class DepthwiseConvFunctor<platform::CUDADeviceContext, double, true>;

template class DepthwiseConvInputGradFunctor<platform::CUDADeviceContext, float,
                                             true>;
template class DepthwiseConvInputGradFunctor<platform::CUDADeviceContext,
                                             double, true>;

template class DepthwiseConvFilterGradFunctor<platform::CUDADeviceContext,
                                              float, true>;
template class DepthwiseConvFilterGradFunctor<platform::CUDADeviceContext,
                                              double, true>;

}  // namespace math
}  // namespace operators
}  // namespace paddle