depthwise_conv.h 71.7 KB
Newer Older
H
hong 已提交
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Z
zlx 已提交
2 3 4 5 6 7 8 9 10 11 12 13 14

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

H
hong 已提交
15
#pragma once
A
Abhinav Arora 已提交
16
#include <vector>
17

H
hong 已提交
18 19 20 21
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/phi/core/hostdevice.h"

22 23 24 25 26 27 28
#ifdef __NVCC__
#include <cub/cub.cuh>
#endif
#ifdef __HIPCC__
#include <hipcub/hipcub.hpp>
namespace cub = hipcub;
#endif
H
hong 已提交
29

30 31
#include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
32
#include "paddle/phi/kernels/funcs/math_function.h"
Z
zlx 已提交
33 34 35 36 37

namespace paddle {
namespace operators {
namespace math {

H
hong 已提交
38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89
// Short alias for framework::DataLayout, used by the declarations and kernels
// throughout this header.
typedef framework::DataLayout DataLayout;

/*
 * \brief Depthwise convolution, forward direction.
 *
 * Only the interface is declared here; the device-specific specialisations
 * are defined outside this header (presumably in the matching .cu file --
 * confirm).
 *
 *   context      device context the work is issued on
 *   input        input feature map
 *   filter       depthwise filter
 *   strides      convolution strides (presumably {height, width})
 *   paddings     padding sizes (presumably {height, width})
 *   dilations    dilation factors (presumably {height, width})
 *   output       destination tensor
 *   data_layout  layout of input/output; NCHW unless overridden
 *
 * fuse_relu_before_conv is the same flag the forward kernels below use to
 * clamp the input at zero before convolving; the functor is assumed to
 * forward it to them.
 */
template <typename DeviceContext, typename T,
          bool fuse_relu_before_conv = false>
class DepthwiseConvFunctor {
 public:
  void operator()(DeviceContext const& context,
                  framework::Tensor const& input,
                  framework::Tensor const& filter,
                  std::vector<int> const& strides,
                  std::vector<int> const& paddings,
                  std::vector<int> const& dilations,
                  framework::Tensor* output,
                  DataLayout const data_layout = DataLayout::kNCHW);
};

/*
 * \brief Depthwise convolution, gradient with respect to the input.
 *
 * Declaration only; specialisations are defined outside this header.
 *
 *   input, filter               operands of the forward convolution
 *   output_grad                 gradient flowing back from the output
 *   strides, paddings, dilations  same hyper-parameters as the forward pass
 *   input_grad                  destination for the input gradient
 *   data_layout                 tensor layout; NCHW unless overridden
 */
template <typename DeviceContext, typename T,
          bool fuse_relu_before_conv = false>
class DepthwiseConvInputGradFunctor {
 public:
  void operator()(DeviceContext const& context,
                  framework::Tensor const& input,
                  framework::Tensor const& filter,
                  framework::Tensor const& output_grad,
                  std::vector<int> const& strides,
                  std::vector<int> const& paddings,
                  std::vector<int> const& dilations,
                  framework::Tensor* input_grad,
                  DataLayout const data_layout = DataLayout::kNCHW);
};

/*
 * \brief Depthwise convolution, gradient with respect to the filter.
 *
 * Declaration only; specialisations are defined outside this header.
 *
 *   input                       input of the forward convolution
 *   output_grad                 gradient flowing back from the output
 *   strides, paddings, dilations  same hyper-parameters as the forward pass
 *   filter_grad                 destination for the filter gradient
 *   data_layout                 tensor layout; NCHW unless overridden
 */
template <typename DeviceContext, typename T,
          bool fuse_relu_before_conv = false>
class DepthwiseConvFilterGradFunctor {
 public:
  void operator()(DeviceContext const& context,
                  framework::Tensor const& input,
                  framework::Tensor const& output_grad,
                  std::vector<int> const& strides,
                  std::vector<int> const& paddings,
                  std::vector<int> const& dilations,
                  framework::Tensor* filter_grad,
                  DataLayout const data_layout = DataLayout::kNCHW);
};

90
/*
 * Sums `val` over the first `warp_size` lanes of the calling warp with
 * cub::WarpReduce (warp_size is passed as cub's `valid_items`). Per cub's
 * WarpReduce::Sum contract the total is only defined in lane 0.
 *
 * NOTE(review): TempStorage is a per-thread local here rather than
 * __shared__; this relies on cub choosing its shuffle-based implementation
 * (which needs no real temp storage) for every T instantiated -- confirm.
 */
template <typename T>
static __forceinline__ __device__ T WarpReduceSum(T val, int warp_size) {
  using WarpReduceT = cub::WarpReduce<T>;
  typename WarpReduceT::TempStorage storage;
  return WarpReduceT(storage).Sum(val, warp_size);
}
97

W
wangguanzhong 已提交
98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118
/*
 * Block-wide sum of `val` over every thread of a (up to 3-D) thread block.
 *
 * Stage 1: each warp reduces its own lanes and lane 0 of every warp stores the
 * partial in a 32-slot __shared__ array. Stage 2: the first `num_warps`
 * threads reload one partial each and warp 0 reduces them. Only the thread
 * with flat id 0 returns the block total; every other thread returns zero.
 *
 * Contains __syncthreads(), so every thread of the block must call it. The
 * 32-slot scratch array limits the block to 32 warps.
 */
template <typename T>
__forceinline__ __device__ T BlockReduceSum(T val) {
  static __shared__ T warp_partials[32];

  const int block_threads = blockDim.x * blockDim.y * blockDim.z;
  const int flat_tid =
      threadIdx.x + blockDim.x * (threadIdx.y + blockDim.y * threadIdx.z);
  // Blocks smaller than a warp are treated as one narrower "warp".
  const int warp_width = min(block_threads, warpSize);
  const int lane = flat_tid % warp_width;
  const int warp_id = flat_tid / warp_width;

  // Stage 1: intra-warp partial sums, published by each warp's lane 0.
  val = WarpReduceSum(val, warp_width);
  if (lane == 0) {
    warp_partials[warp_id] = val;
  }
  __syncthreads();

  // Stage 2: only slots written by an existing warp are read back.
  const int num_warps = (block_threads - 1) / warp_width + 1;
  val = (flat_tid < num_warps) ? warp_partials[lane] : static_cast<T>(0);
  if (warp_id == 0) {
    val = WarpReduceSum(val, warp_width);
  }
  __syncthreads();

  return (flat_tid == 0) ? val : static_cast<T>(0);
}

130 131 132 133 134 135 136 137
// Parameter list shared by the forward depthwise-convolution kernels in this
// file. It must be expanded where `T` names the element type. Input and filter
// pointers come first, then shapes and convolution hyper-parameters, and the
// output pointer comes last.
#define ARG_DEFINE_KernelDepthwiseConv \
  const T *const input_data,           \
      const T *const filter_data,      \
      const int batch_size,            \
      const int output_channels,       \
      const int output_height,         \
      const int output_width,          \
      const int input_channels,        \
      const int input_height,          \
      const int input_width,           \
      const int filter_multiplier,     \
      const int filter_height,         \
      const int filter_width,          \
      const int stride_height,         \
      const int stride_width,          \
      const int padding_height,        \
      const int padding_width,         \
      const int dilate_height,         \
      const int dilate_width,          \
      T *const output_data
139

140 141
// Forward pass of the depthwise convolution in NCHW layout (device helper
// called from KernelDepthwiseConvSp).
//
// One output element per thread. The flat index
// threadIdx.x + blockIdx.x * blockDim.x is decoded as (batch, c_out, h_out,
// w_out), with w_out varying fastest. Threads past the end of the output
// return immediately. Output channel c_out reads input channel
// c_out / filter_multiplier and the filter_height x filter_width weights that
// start at filter_data + c_out * filter_height * filter_width. Taps that land
// in the padding contribute nothing.
//
// When fuse_relu_before_conv is set, each input value is clamped at zero
// before it is multiplied by its weight.
template <typename T, bool fuse_relu_before_conv>
__device__ __inline__ void KernelDepthwiseConvNCHW(
    ARG_DEFINE_KernelDepthwiseConv) {
  int idx = threadIdx.x + blockIdx.x * blockDim.x;
  if (idx >= (output_channels * batch_size * output_height * output_width))
    return;

  const int w_out = idx % output_width;
  const int h_out = (idx / output_width) % output_height;
  const int c_out = (idx / output_width / output_height) % output_channels;
  const int batch = idx / output_width / output_height / output_channels;

  const int c_in = c_out / filter_multiplier;
  const T* weight = filter_data + c_out * filter_height * filter_width;
  T value(0);
  const int h_in_start = -padding_height + h_out * stride_height;
  const int w_in_start = -padding_width + w_out * stride_width;
  const int h_in_end = h_in_start + filter_height * dilate_height;
  const int w_in_end = w_in_start + filter_width * dilate_width;

  int in_offset =
      ((batch * input_channels + c_in) * input_height) * input_width;

  // Part of the receptive field that lies inside the input (not padding).
  const int h_end = h_in_end < input_height ? h_in_end : input_height;
  const int w_end = w_in_end < input_width ? w_in_end : input_width;
  const int h_start = h_in_start > 0 ? h_in_start : 0;
  const int w_start = w_in_start > 0 ? w_in_start : 0;
  // Row-major filter tap counter. It advances for padded taps too, so it stays
  // aligned with (h_in, w_in).
  int weight_offset = 0;

#pragma unroll
  for (int h_in = h_in_start; h_in < h_in_end; h_in += dilate_height) {
#pragma unroll
    for (int w_in = w_in_start; w_in < w_in_end; w_in += dilate_width) {
      if (h_in >= h_start && h_in < h_end && w_in >= w_start && w_in < w_end) {
        int offset = in_offset + h_in * input_width + w_in;
        T in_data = input_data[offset];
        if (fuse_relu_before_conv) {
          // Clamp in T itself. A round trip through double would give the same
          // value but pull FP64 arithmetic into float/half kernels.
          in_data = in_data > T(0) ? in_data : T(0);
        }
        value += weight[weight_offset] * in_data;
      }
      weight_offset++;
    }
  }
  // (batch, c_out, h_out, w_out) was decoded exactly from idx, so idx is
  // already this element's NCHW offset.
  output_data[idx] = value;
}
193

194 195 196 197 198 199 200 201 202 203 204 205 206 207 208
// Forward pass of the depthwise convolution in NHWC layout (device helper
// called from KernelDepthwiseConvSp).
//
// One output element per thread. The flat index
// threadIdx.x + blockIdx.x * blockDim.x is decoded as (batch, h_out, w_out,
// c_out), with c_out varying fastest. Threads past the end of the output
// return immediately. Output channel c_out reads input channel
// c_out / filter_multiplier. The weight for row-major tap t is
// filter_data[t * output_channels + c_out]. Taps in the padding contribute
// nothing.
//
// When fuse_relu_before_conv is set, each input value is clamped at zero
// before it is multiplied by its weight.
template <typename T, bool fuse_relu_before_conv>
__device__ __inline__ void KernelDepthwiseConvNHWC(
    ARG_DEFINE_KernelDepthwiseConv) {
  int idx = threadIdx.x + blockIdx.x * blockDim.x;
  if (idx >= (output_channels * batch_size * output_height * output_width))
    return;

  const int c_out = idx % output_channels;
  const int w_out = (idx / output_channels) % output_width;
  const int h_out = (idx / output_channels / output_width) % output_height;
  const int batch = idx / output_width / output_height / output_channels;

  const int c_in = c_out / filter_multiplier;
  T value(0);
  const int h_in_start = -padding_height + h_out * stride_height;
  const int w_in_start = -padding_width + w_out * stride_width;
  const int h_in_end = h_in_start + filter_height * dilate_height;
  const int w_in_end = w_in_start + filter_width * dilate_width;

  // Part of the receptive field that lies inside the input (not padding).
  const int h_end = h_in_end < input_height ? h_in_end : input_height;
  const int w_end = w_in_end < input_width ? w_in_end : input_width;
  const int h_start = h_in_start > 0 ? h_in_start : 0;
  const int w_start = w_in_start > 0 ? w_in_start : 0;
  // Row-major filter tap counter. It advances for padded taps too.
  int weight_offset = 0;

#pragma unroll
  for (int h_in = h_in_start; h_in < h_in_end; h_in += dilate_height) {
#pragma unroll
    for (int w_in = w_in_start; w_in < w_in_end; w_in += dilate_width) {
      if (h_in >= h_start && h_in < h_end && w_in >= w_start && w_in < w_end) {
        int offset = ((batch * input_height + h_in) * input_width + w_in) *
                         input_channels +
                     c_in;
        T in_data = input_data[offset];
        if (fuse_relu_before_conv) {
          // Clamp in T itself instead of round-tripping through double.
          in_data = in_data > T(0) ? in_data : T(0);
        }
        value +=
            filter_data[weight_offset * output_channels + c_out] * in_data;
      }
      weight_offset++;
    }
  }
  // (batch, h_out, w_out, c_out) was decoded exactly from idx, so idx is
  // already this element's NHWC offset.
  output_data[idx] = value;
}
245

246
// Forward depthwise convolution in NCHW layout for a compile-time square
// c_filter x c_filter filter.
//
// Launch mapping, as read from the indexing below: blockIdx.x selects the
// output channel and blockIdx.y the batch element. Threads stride over the
// output plane with threadIdx.x/blockDim.x along the width and
// threadIdx.y/blockDim.y along the height. The output offset uses gridDim.x
// as the channel count, so the grid is expected to have
// gridDim.x == output_channels.
//
// When fuse_relu_before_conv is set, each input value is clamped at zero
// before it is multiplied by its weight.
template <typename T, int c_filter, bool fuse_relu_before_conv>
__device__ __inline__ void KernelDepthwiseConvCFilterNCHW(
    ARG_DEFINE_KernelDepthwiseConv) {
  const int kWeightSize = c_filter * c_filter;
  T r_weight[kWeightSize];
  const int batch = blockIdx.y;
  const int c_out = blockIdx.x;
  const int c_in = c_out / filter_multiplier;

  // Copy this channel's weights into a per-thread array once; they are reused
  // for every output pixel this thread computes.
  const T* weight = filter_data + c_out * c_filter * c_filter;
  for (int i = 0; i < c_filter * c_filter; i++) r_weight[i] = weight[i];

  // Start of the (batch, c_in) input plane. It is loop-invariant.
  const int in_offset =
      ((batch * input_channels + c_in) * input_height) * input_width;

  for (int w_out = threadIdx.x; w_out < output_width; w_out += blockDim.x) {
    for (int h_out = threadIdx.y; h_out < output_height; h_out += blockDim.y) {
      T value(0);
      const int h_in_start = -padding_height + h_out * stride_height;
      const int w_in_start = -padding_width + w_out * stride_width;

      for (int h_in = h_in_start, h_f = 0; h_f < c_filter;
           h_in += dilate_height, h_f++) {
        for (int w_in = w_in_start, w_f = 0; w_f < c_filter;
             w_in += dilate_width, w_f++) {
          // Taps that fall into the padding contribute nothing.
          if (h_in >= 0 && h_in < input_height && w_in >= 0 &&
              w_in < input_width) {
            T in_data = input_data[in_offset + h_in * input_width + w_in];
            if (fuse_relu_before_conv) {
              // Clamp in T itself instead of round-tripping through double.
              in_data = in_data > T(0) ? in_data : T(0);
            }
            value += r_weight[h_f * c_filter + w_f] * in_data;
          }
        }
      }
      int index =
          ((batch * gridDim.x + c_out) * output_height + h_out) * output_width +
          w_out;
      output_data[index] = value;
    }
  }
}

// Forward depthwise convolution in NHWC layout for a compile-time square
// c_filter x c_filter filter.
//
// Launch mapping, as read from the indexing below:
// - blockIdx.z is the batch element, and h_out = blockIdx.x * dilate_height
//   + blockIdx.y. Blocks whose h_out lies past the output height exit early.
// - threadIdx.x strides over output channels.
// - threadIdx.y strides over a permuted width index i in
//   [0, wi_size * dilate_width). It is split as i = i_dw * wi_size + i_wi and
//   mapped to w_out = i_wi * dilate_width + i_dw. Consecutive i therefore walk
//   the columns of one dilation phase (w_out % dilate_width == i_dw).
//
// When fuse_relu_before_conv is set, each input value is clamped at zero
// before it is multiplied by its weight.
template <typename T, int c_filter, bool fuse_relu_before_conv>
__device__ __inline__ void KernelDepthwiseConvCFilterNHWC(
    ARG_DEFINE_KernelDepthwiseConv) {
  const int batch = blockIdx.z;
  int h_out = blockIdx.x * dilate_height + blockIdx.y;
  if (h_out >= output_height) {
    return;
  }
  int in_offset = batch * input_height * input_width * input_channels;
  int out_offset =
      (batch * output_height + h_out) * output_width * output_channels;
  const int h_in_start = -padding_height + h_out * stride_height;
  // ceil(output_width / dilate_width): the largest column count of a phase.
  const int wi_size = (output_width + dilate_width - 1) / dilate_width;
  const int kWeightSize = c_filter * c_filter;
  T r_weight[kWeightSize];

  for (int c_out = threadIdx.x; c_out < output_channels; c_out += blockDim.x) {
    // Gather this channel's taps. Tap i of every output channel is stored
    // contiguously, so consecutive channels are output_channels apart.
    for (int i = 0; i < c_filter * c_filter; i++) {
      const T* weight = filter_data + i * output_channels + c_out;
      r_weight[i] = weight[0];
    }
    const int c_in = c_out / filter_multiplier;
    for (int i = threadIdx.y; i < wi_size * dilate_width; i += blockDim.y) {
      int i_dw = i / wi_size;
      int i_wi = i - i_dw * wi_size;
      int w_out = i_wi * dilate_width + i_dw;
      if (w_out >= output_width) {
        continue;
      }
      T value(0);
      const int w_in_start = -padding_width + w_out * stride_width;
      for (int h_in = h_in_start, h_f = 0; h_f < c_filter;
           h_in += dilate_height, h_f++) {
        for (int w_in = w_in_start, w_f = 0; w_f < c_filter;
             w_in += dilate_width, w_f++) {
          // Taps that fall into the padding contribute nothing.
          if (h_in >= 0 && h_in < input_height && w_in >= 0 &&
              w_in < input_width) {
            int offset =
                in_offset + (h_in * input_width + w_in) * input_channels + c_in;
            T in_data = input_data[offset];
            if (fuse_relu_before_conv) {
              // Clamp in T itself. The former max(0.0, double) evaluated the
              // ReLU entirely in double precision.
              in_data = in_data > T(0) ? in_data : T(0);
            }
            value += r_weight[h_f * c_filter + w_f] * in_data;
          }
        }
      }
      int index = out_offset + w_out * output_channels + c_out;
      output_data[index] = value;
    }
  }
}

H
hong 已提交
354 355 356 357 358 359
// Forward depthwise-convolution kernel: picks one of the four device helpers
// above from its template parameters.
//
//   c_filter_multiplier, c_stride : when c_filter_multiplier != 0, these
//       compile-time values replace the runtime filter_multiplier and both
//       strides (c_stride is used for height and width alike).
//   c_filter   : -1 selects the generic helpers. Any other value selects the
//       fixed c_filter x c_filter helpers.
//   data_layout: kNHWC selects the NHWC helpers. Anything else selects NCHW.
//
// The launch configuration must match the thread mapping of whichever helper
// gets selected (see each helper's comment).
template <typename T,
          int c_filter_multiplier,
          int c_stride,
          int c_filter,
          DataLayout data_layout,
          bool fuse_relu_before_conv>
__global__ void KernelDepthwiseConvSp(ARG_DEFINE_KernelDepthwiseConv) {
  const bool use_static_config = (c_filter_multiplier != 0);
  const int final_filter_multiplier =
      use_static_config ? c_filter_multiplier : filter_multiplier;
  const int h_stride = use_static_config ? c_stride : stride_height;
  const int w_stride = use_static_config ? c_stride : stride_width;
  const bool is_nhwc = (data_layout == DataLayout::kNHWC);

// Argument list shared by all four helper calls. It differs from this
// kernel's own parameters only in the filter multiplier and the strides.
#define KERNEL_DEPTHWISE_CONV_SP_ARGS                                         \
  input_data, filter_data, batch_size, output_channels, output_height,        \
      output_width, input_channels, input_height, input_width,                \
      final_filter_multiplier, filter_height, filter_width, h_stride,         \
      w_stride, padding_height, padding_width, dilate_height, dilate_width,   \
      output_data

  if (c_filter != -1) {
    if (is_nhwc) {
      KernelDepthwiseConvCFilterNHWC<T, c_filter, fuse_relu_before_conv>(
          KERNEL_DEPTHWISE_CONV_SP_ARGS);
    } else {
      KernelDepthwiseConvCFilterNCHW<T, c_filter, fuse_relu_before_conv>(
          KERNEL_DEPTHWISE_CONV_SP_ARGS);
    }
  } else if (is_nhwc) {
    KernelDepthwiseConvNHWC<T, fuse_relu_before_conv>(
        KERNEL_DEPTHWISE_CONV_SP_ARGS);
  } else {
    KernelDepthwiseConvNCHW<T, fuse_relu_before_conv>(
        KERNEL_DEPTHWISE_CONV_SP_ARGS);
  }

#undef KERNEL_DEPTHWISE_CONV_SP_ARGS
}

Z
zlx 已提交
458
// Parameter list shared by the depthwise-convolution input-gradient kernels
// (backprop w.r.t. the input). It must be expanded where `T` names the element
// type. The forward input, the output gradient and the filter come first, then
// shapes and hyper-parameters, and the input-gradient destination comes last.
#define ARG_DEFINE_KernelDepthwiseConvInputGrad \
  const T *const input_data,                    \
      const T *const output_grad_data,          \
      const T *const filter_data,               \
      const int batch_size,                     \
      const int output_channels,                \
      const int output_height,                  \
      const int output_width,                   \
      const int input_channels,                 \
      const int input_height,                   \
      const int input_width,                    \
      const int filter_multiplier,              \
      const int filter_height,                  \
      const int filter_width,                   \
      const int stride_height,                  \
      const int stride_width,                   \
      const int padding_height,                 \
      const int padding_width,                  \
      const int dilate_height,                  \
      const int dilate_width,                   \
      T *const input_grad_data
470

471
// Input-gradient of the depthwise convolution in NCHW layout, for a generic
// filter size and channel multiplier.
//
// Launch mapping, as read from the indexing below: blockIdx.x is the input
// channel and blockIdx.y the batch element. Threads stride over the input
// plane with threadIdx.x along the width and threadIdx.y along the height.
// The input offset uses gridDim.x as the channel count.
//
// For each input pixel, this helper sums output_grad * weight over every
// output position and filter tap that read that pixel in the forward pass.
// That covers output channels [c_in * filter_multiplier,
// (c_in + 1) * filter_multiplier). With fuse_relu_before_conv, pixels whose
// input value is <= 0 get a zero gradient.
template <typename T, bool fuse_relu_before_conv>
__device__ __inline__ void KernelDepthwiseConvInputGradNCHW(
    ARG_DEFINE_KernelDepthwiseConvInputGrad) {
  const int batch = blockIdx.y;
  const int c_in = blockIdx.x;
  const int taps = filter_height * filter_width;
  const int first_c_out = c_in * filter_multiplier;
  const int end_c_out = first_c_out + filter_multiplier;  // exclusive

  for (int w_in = threadIdx.x; w_in < input_width; w_in += blockDim.x) {
    for (int h_in = threadIdx.y; h_in < input_height; h_in += blockDim.y) {
      const int in_idx =
          ((batch * gridDim.x + c_in) * input_height + h_in) * input_width +
          w_in;

      if (fuse_relu_before_conv && input_data[in_idx] <= T(0)) {
        input_grad_data[in_idx] = 0;
        continue;
      }

      // Range of (pre-stride) output coordinates that may have read this
      // pixel, with one candidate per filter tap.
      const int h_first =
          h_in - (filter_height - 1) * dilate_height + padding_height;
      const int h_last = h_in + padding_height;
      const int w_first =
          w_in - (filter_width - 1) * dilate_width + padding_width;
      const int w_last = w_in + padding_width;

      T acc(0);
      for (int c_out = first_c_out; c_out < end_c_out; ++c_out) {
        // Taps are walked back to front: the candidate (h_first, w_first)
        // pairs with the last weight of this channel's filter.
        int tap = (c_out + 1) * taps;
        for (int h_out = h_first; h_out <= h_last; h_out += dilate_height) {
          for (int w_out = w_first; w_out <= w_last; w_out += dilate_width) {
            --tap;
            if (h_out % stride_height != 0 || w_out % stride_width != 0) {
              continue;
            }
            const int oh = h_out / stride_height;
            const int ow = w_out / stride_width;
            if (oh < 0 || oh >= output_height || ow < 0 ||
                ow >= output_width) {
              continue;
            }
            const int grad_off =
                ((batch * output_channels + c_out) * output_height + oh) *
                    output_width +
                ow;
            acc += output_grad_data[grad_off] * filter_data[tap];
          }
        }
      }
      input_grad_data[in_idx] = acc;
    }
  }
}

// Input-gradient of the depthwise convolution in NHWC layout, for a generic
// filter size and channel multiplier.
//
// Launch mapping, as read from the indexing below: blockIdx.z is the batch
// element, and h_in = blockIdx.x * dilate_height + blockIdx.y. Blocks whose
// h_in lies past the input height exit. threadIdx.x strides over input
// channels and threadIdx.y over input columns. The weight for tap t of output
// channel c_out is filter_data[t * output_channels + c_out]. With
// fuse_relu_before_conv, pixels whose input value is <= 0 get a zero gradient.
template <typename T, bool fuse_relu_before_conv>
__device__ __inline__ void KernelDepthwiseConvInputGradNHWC(
    ARG_DEFINE_KernelDepthwiseConvInputGrad) {
  const int batch = blockIdx.z;
  const int h_in = blockIdx.x * dilate_height + blockIdx.y;
  if (h_in >= input_height) {
    return;
  }
  const int num_taps = filter_height * filter_width;
  const int h_out_first =
      h_in - (filter_height - 1) * dilate_height + padding_height;

  for (int c_in = threadIdx.x; c_in < input_channels; c_in += blockDim.x) {
    for (int w_in = threadIdx.y; w_in < input_width; w_in += blockDim.y) {
      const int grad_idx =
          ((batch * input_height + h_in) * input_width + w_in) *
              input_channels +
          c_in;
      if (fuse_relu_before_conv && input_data[grad_idx] <= T(0)) {
        input_grad_data[grad_idx] = 0;
        continue;
      }
      const int w_out_first =
          w_in - (filter_width - 1) * dilate_width + padding_width;

      T sum(0);
      for (int c_i = 0; c_i < filter_multiplier; ++c_i) {
        const int c_out = c_in * filter_multiplier + c_i;
        for (int h_f = 0; h_f < filter_height; ++h_f) {
          const int h_out = h_out_first + h_f * dilate_height;
          for (int w_f = 0; w_f < filter_width; ++w_f) {
            const int w_out = w_out_first + w_f * dilate_width;
            if (h_out % stride_height != 0 || w_out % stride_width != 0) {
              continue;
            }
            const int oh = h_out / stride_height;
            const int ow = w_out / stride_width;
            if (oh < 0 || oh >= output_height || ow < 0 ||
                ow >= output_width) {
              continue;
            }
            // Taps are paired in reverse: (h_f, w_f) == (0, 0) uses the last
            // weight.
            const int tap = num_taps - 1 - (h_f * filter_width + w_f);
            const int grad_off =
                ((batch * output_height + oh) * output_width + ow) *
                    output_channels +
                c_out;
            sum += output_grad_data[grad_off] *
                   filter_data[tap * output_channels + c_out];
          }
        }
      }
      input_grad_data[grad_idx] = sum;
    }
  }
}

H
hong 已提交
583 584 585
// Input-gradient of the depthwise convolution in NCHW layout, for a
// compile-time square c_filter x c_filter filter and a compile-time channel
// multiplier c_filter_multiplier.
//
// Launch mapping, as read from the indexing below: blockIdx.x is the input
// channel and blockIdx.y the batch element. Threads stride over the input
// plane with threadIdx.x along the width and threadIdx.y along the height.
// The input offset uses gridDim.x as the channel count.
//
// Each thread first copies, in reversed tap order, the filters of the
// c_filter_multiplier output channels fed by this input channel into
// r_weight. In each channel slice, the tap that pairs with
// (h_out_start, w_out_start) is therefore at index 0. With
// fuse_relu_before_conv, pixels whose input value is <= 0 get a zero gradient.
template <typename T,
          int c_filter,
          int c_filter_multiplier,
          bool fuse_relu_before_conv>
__device__ __inline__ void KernelDepthwiseConvInputGradCFilterNCHW(
    ARG_DEFINE_KernelDepthwiseConvInputGrad) {
  // The +1 keeps the array non-empty for instantiations with
  // c_filter_multiplier == 0 (presumably compiled but never executed).
  const int kWeightSize = c_filter * c_filter * c_filter_multiplier + 1;
  T r_weight[kWeightSize];
  const int batch = blockIdx.y;
  const int c_in = blockIdx.x;

  // All multiplier loops use the compile-time c_filter_multiplier, as the
  // NHWC variant does. r_weight is sized from it, so indexing can never run
  // past kWeightSize, and the loops have constant trip counts.
  // NOTE(review): the dispatcher is expected to pass
  // filter_multiplier == c_filter_multiplier here -- confirm.
  for (int c_i = 0; c_i < c_filter_multiplier; c_i++) {
    int c_out = c_in * c_filter_multiplier + c_i;
    const T* weight = filter_data + c_out * c_filter * c_filter;
    for (int i = 0; i < c_filter * c_filter; i++)
      r_weight[i + c_i * c_filter * c_filter] =
          weight[c_filter * c_filter - i - 1];
  }

  for (int w_in = threadIdx.x; w_in < input_width; w_in += blockDim.x) {
    for (int h_in = threadIdx.y; h_in < input_height; h_in += blockDim.y) {
      int h_out_start = h_in - (c_filter - 1) * dilate_height + padding_height;
      int w_out_start = w_in - (c_filter - 1) * dilate_width + padding_width;

      T value(0);
      int index =
          ((batch * gridDim.x + c_in) * input_height + h_in) * input_width +
          w_in;
      if (fuse_relu_before_conv) {
        if (input_data[index] <= T(0)) {
          input_grad_data[index] = 0;
          continue;
        }
      }

      for (int c_i = 0; c_i < c_filter_multiplier; c_i++) {
        int c_out = c_in * c_filter_multiplier + c_i;
        for (int h_out = h_out_start, h_f = 0; h_f < c_filter;
             h_out += dilate_height, h_f++) {
          for (int w_out = w_out_start, w_f = 0; w_f < c_filter;
               w_out += dilate_width, w_f++) {
            int s_h_out = h_out / stride_height;
            int s_w_out = w_out / stride_width;
            // Only positions on the stride grid and inside the output
            // contributed to this input pixel.
            if (h_out % stride_height == 0 && w_out % stride_width == 0 &&
                s_h_out >= 0 && s_h_out < output_height && s_w_out >= 0 &&
                s_w_out < output_width) {
              int output_grad_offset =
                  ((batch * output_channels + c_out) * output_height +
                   s_h_out) *
                      output_width +
                  s_w_out;
              value +=
                  output_grad_data[output_grad_offset] *
                  r_weight[h_f * c_filter + w_f + c_i * c_filter * c_filter];
            }
          }
        }
      }
      input_grad_data[index] = value;
    }
  }
}

H
hong 已提交
646 647 648
// Input-gradient of the depthwise convolution in NHWC layout, for a
// compile-time square c_filter x c_filter filter and a compile-time channel
// multiplier c_filter_multiplier.
//
// Launch mapping, as read from the indexing below:
// - blockIdx.z is the batch element, and h_in = blockIdx.x * dilate_height
//   + blockIdx.y. Blocks whose h_in lies past the input height exit.
// - threadIdx.x strides over input channels.
// - threadIdx.y strides over a permuted width index that visits the input
//   columns one dilation phase at a time.
//
// The reversed filters of the output channels fed by c_in are cached in
// r_weight before the column loop. With fuse_relu_before_conv, pixels whose
// input value is <= 0 get a zero gradient.
template <typename T,
          int c_filter,
          int c_filter_multiplier,
          bool fuse_relu_before_conv>
__device__ __inline__ void KernelDepthwiseConvInputGradCFilterNHWC(
    ARG_DEFINE_KernelDepthwiseConvInputGrad) {
  const int h_in = blockIdx.x * dilate_height + blockIdx.y;
  if (h_in >= input_height) {
    return;
  }
  const int taps = c_filter * c_filter;
  const int kWeightSize = c_filter * c_filter * c_filter_multiplier + 1;
  T r_weight[kWeightSize];
  const int batch = blockIdx.z;
  // ceil(input_width / dilate_width): the largest column count of a phase.
  const int cols_per_phase = (input_width + dilate_width - 1) / dilate_width;
  const int h_out_first =
      h_in - (c_filter - 1) * dilate_height + padding_height;

  for (int c_in = threadIdx.x; c_in < input_channels; c_in += blockDim.x) {
    // Reversed taps of each output channel fed by c_in. Tap t of channel
    // c_out is read from filter_data[t * output_channels + c_out].
    for (int c_i = 0; c_i < c_filter_multiplier; ++c_i) {
      const int c_out = c_in * c_filter_multiplier + c_i;
      for (int t = 0; t < taps; ++t) {
        r_weight[c_i * taps + t] =
            filter_data[(taps - t - 1) * output_channels + c_out];
      }
    }

    const int span = cols_per_phase * dilate_width;
    for (int i = threadIdx.y; i < span; i += blockDim.y) {
      // Decode i into (dilation phase, column within that phase).
      const int phase = i / cols_per_phase;
      const int w_in = (i % cols_per_phase) * dilate_width + phase;
      if (w_in >= input_width) {
        continue;
      }

      const int grad_idx =
          ((batch * input_height + h_in) * input_width + w_in) *
              input_channels +
          c_in;
      if (fuse_relu_before_conv && input_data[grad_idx] <= T(0)) {
        input_grad_data[grad_idx] = 0;
        continue;
      }
      const int w_out_first =
          w_in - (c_filter - 1) * dilate_width + padding_width;

      T sum(0);
      for (int c_i = 0; c_i < c_filter_multiplier; ++c_i) {
        const int c_out = c_in * c_filter_multiplier + c_i;
        for (int h_f = 0; h_f < c_filter; ++h_f) {
          const int h_out = h_out_first + h_f * dilate_height;
          for (int w_f = 0; w_f < c_filter; ++w_f) {
            const int w_out = w_out_first + w_f * dilate_width;
            if (h_out % stride_height != 0 || w_out % stride_width != 0) {
              continue;
            }
            const int oh = h_out / stride_height;
            const int ow = w_out / stride_width;
            if (oh < 0 || oh >= output_height || ow < 0 ||
                ow >= output_width) {
              continue;
            }
            const int grad_off =
                ((batch * output_height + oh) * output_width + ow) *
                    output_channels +
                c_out;
            sum += output_grad_data[grad_off] *
                   r_weight[c_i * taps + h_f * c_filter + w_f];
          }
        }
      }
      input_grad_data[grad_idx] = sum;
    }
  }
}

H
hong 已提交
718 719 720 721 722 723
// Kernel entry point for the depthwise-conv gradient w.r.t. the input.
//
// The parameter list is produced by ARG_DEFINE_KernelDepthwiseConvInputGrad
// (defined elsewhere in this file). The kernel only dispatches to a device
// helper, selected by the compile-time specialization parameters:
//   * c_filter_multiplier == 0 or c_filter == -1: the generic NCHW / NHWC
//     helpers are used.
//   * otherwise: the CFilter helpers templated on c_filter and
//     c_filter_multiplier are used.
// A non-zero c_filter_multiplier also replaces the runtime multiplier and both
// strides with the compile-time c_filter_multiplier / c_stride.
template <typename T,
          int c_filter_multiplier,
          int c_stride,
          int c_filter,
          DataLayout data_layout,
          bool fuse_relu_before_conv>
__global__ void KernelDepthwiseConvInputGradSp(
    ARG_DEFINE_KernelDepthwiseConvInputGrad) {
  const bool multiplier_is_static = (c_filter_multiplier != 0);
  const int multiplier =
      multiplier_is_static ? c_filter_multiplier : filter_multiplier;
  const int stride_h = multiplier_is_static ? c_stride : stride_height;
  const int stride_w = multiplier_is_static ? c_stride : stride_width;

  const bool use_generic_path = !multiplier_is_static || c_filter == -1;
  const bool is_nchw = (data_layout != DataLayout::kNHWC);

  if (use_generic_path && is_nchw) {
    KernelDepthwiseConvInputGradNCHW<T, fuse_relu_before_conv>(
        input_data,
        output_grad_data,
        filter_data,
        batch_size,
        output_channels,
        output_height,
        output_width,
        input_channels,
        input_height,
        input_width,
        multiplier,
        filter_height,
        filter_width,
        stride_h,
        stride_w,
        padding_height,
        padding_width,
        dilate_height,
        dilate_width,
        input_grad_data);
  } else if (use_generic_path) {
    KernelDepthwiseConvInputGradNHWC<T, fuse_relu_before_conv>(
        input_data,
        output_grad_data,
        filter_data,
        batch_size,
        output_channels,
        output_height,
        output_width,
        input_channels,
        input_height,
        input_width,
        multiplier,
        filter_height,
        filter_width,
        stride_h,
        stride_w,
        padding_height,
        padding_width,
        dilate_height,
        dilate_width,
        input_grad_data);
  } else if (is_nchw) {
    // Specialized path: multiplier and stride come straight from the template.
    KernelDepthwiseConvInputGradCFilterNCHW<T,
                                            c_filter,
                                            c_filter_multiplier,
                                            fuse_relu_before_conv>(
        input_data,
        output_grad_data,
        filter_data,
        batch_size,
        output_channels,
        output_height,
        output_width,
        input_channels,
        input_height,
        input_width,
        c_filter_multiplier,
        filter_height,
        filter_width,
        c_stride,
        c_stride,
        padding_height,
        padding_width,
        dilate_height,
        dilate_width,
        input_grad_data);
  } else {
    KernelDepthwiseConvInputGradCFilterNHWC<T,
                                            c_filter,
                                            c_filter_multiplier,
                                            fuse_relu_before_conv>(
        input_data,
        output_grad_data,
        filter_data,
        batch_size,
        output_channels,
        output_height,
        output_width,
        input_channels,
        input_height,
        input_width,
        c_filter_multiplier,
        filter_height,
        filter_width,
        c_stride,
        c_stride,
        padding_height,
        padding_width,
        dilate_height,
        dilate_width,
        input_grad_data);
  }
}

836
// Cuda kernel to compute the depthwise convolution backprop w.r.t. filter.
//
// NCHW variant, as read from the indexing below:
//   * blockIdx.z is the output channel, and gridDim.z is used as the number of
//     output channels when indexing output_grad_data.
//   * blockIdx.y / blockIdx.x select the filter row / column.
//   * Inside a block, threadIdx.x strides over the output width and
//     threadIdx.y over the output height, for every image in the batch.
// Each thread's partial sum goes to BlockReduceSum (defined elsewhere), and
// thread (0, 0) stores the result into filter_grad_data[gbid]. Presumably
// BlockReduceSum is a block-wide collective, which is why the loops below only
// `continue` and never return early.
template <typename T, bool fuse_relu_before_conv>
__device__ __inline__ void KernelDepthwiseConvFilterGradNCHW(
    const T* output_grad_data,
    const T* input_data,
    const int num,
    const int output_channels,
    const int output_height,
    const int output_width,
    const int input_channels,
    const int input_height,
    const int input_width,
    const int filter_multiplier,
    const int filter_height,
    const int filter_width,
    const int stride_height,
    const int stride_width,
    const int padding_height,
    const int padding_width,
    const int dilate_height,
    const int dilate_width,
    T* filter_grad_data) {
  T s(0);
  // Flat index of the single filter-gradient element owned by this block.
  int gbid = ((blockIdx.z * gridDim.y) + blockIdx.y) * gridDim.x + blockIdx.x;

  // These depend only on the block, so compute them once, not per pixel.
  int kernel_id = blockIdx.z;
  int kernel_h = blockIdx.y * dilate_height - padding_height;
  int kernel_w = blockIdx.x * dilate_width - padding_width;

  for (int image_w = threadIdx.x; image_w < output_width;
       image_w += blockDim.x) {
    for (int bid = 0; bid < num; bid++) {
      for (int image_h = threadIdx.y; image_h < output_height;
           image_h += blockDim.y) {
        int image_hk = image_h * stride_height + kernel_h;
        int image_wk = image_w * stride_width + kernel_w;
        if (image_hk < 0 || image_hk >= input_height) continue;
        if (image_wk < 0 || image_wk >= input_width) continue;
        int input_id = ((bid * (gridDim.z / filter_multiplier) +
                         kernel_id / filter_multiplier) *
                            input_height +
                        image_hk) *
                           input_width +
                       image_wk;
        T input_value = input_data[input_id];
        if (fuse_relu_before_conv) {
          // ReLU evaluated in T itself. Promoting to double here would run
          // FP64 math in the innermost loop. NaN still maps to 0.
          input_value = input_value > T(0) ? input_value : T(0);
        }
        s += output_grad_data[((bid * gridDim.z + kernel_id) * output_height +
                               image_h) *
                                  output_width +
                              image_w] *
             input_value;
      }
    }
  }

  T val = BlockReduceSum(s);
  if (threadIdx.y == 0 && threadIdx.x == 0) filter_grad_data[gbid] = val;
}

// Depthwise-conv filter gradient, NHWC variant.
//
// Launch mapping, as read from the indexing below:
//   * blockIdx.z is the batch index and blockIdx.y the output row.
//   * blockIdx.x encodes the filter tap as kernel_ih * filter_width +
//     kernel_iw.
//   * threadIdx.x strides over output channels and threadIdx.y over output
//     columns.
// Each thread accumulates a partial sum per output channel and adds it to the
// filter gradient with platform::CudaAtomicAdd. Several blocks therefore
// contribute to the same element, and filter_grad_data is presumably zeroed by
// the caller (not visible here).
template <typename T, bool fuse_relu_before_conv>
__device__ __inline__ void KernelDepthwiseConvFilterGradNHWC(
    const T* output_grad_data,
    const T* input_data,
    const int num,
    const int output_channels,
    const int output_height,
    const int output_width,
    const int input_channels,
    const int input_height,
    const int input_width,
    const int filter_multiplier,
    const int filter_height,
    const int filter_width,
    const int stride_height,
    const int stride_width,
    const int padding_height,
    const int padding_width,
    const int dilate_height,
    const int dilate_width,
    T* filter_grad_data) {
  int bid = blockIdx.z;
  int image_h = blockIdx.y;
  int kernel_iw = blockIdx.x % filter_width;
  int kernel_ih = blockIdx.x / filter_width;

  // Independent of kernel_id and image_w: computed once per thread.
  int kernel_h = kernel_ih * dilate_height - padding_height;
  int kernel_w = kernel_iw * dilate_width - padding_width;
  int image_hk = image_h * stride_height + kernel_h;

  for (int kernel_id = threadIdx.x; kernel_id < output_channels;
       kernel_id += blockDim.x) {
    T s(0);
    int gbid =
        ((kernel_id * filter_height) + kernel_ih) * filter_width + kernel_iw;
    for (int image_w = threadIdx.y; image_w < output_width;
         image_w += blockDim.y) {
      int image_wk = image_w * stride_width + kernel_w;
      if (image_hk < 0 || image_hk >= input_height) continue;
      if (image_wk < 0 || image_wk >= input_width) continue;
      int input_id =
          ((bid * input_height + image_hk) * input_width + image_wk) *
              input_channels +
          kernel_id / filter_multiplier;
      int output_id =
          ((bid * output_height + image_h) * output_width + image_w) *
              output_channels +
          kernel_id;
      T input_value = input_data[input_id];
      if (fuse_relu_before_conv) {
        // ReLU in T; avoids promoting every element to double. NaN -> 0.
        input_value = input_value > T(0) ? input_value : T(0);
      }
      s += output_grad_data[output_id] * input_value;
    }
    platform::CudaAtomicAdd(&filter_grad_data[gbid], s);
  }
}

// Depthwise-conv filter gradient, NHWC variant specialized for a square
// c_filter x c_filter filter.
//
// Launch mapping, as read from the indexing below:
//   * blockIdx.z is the batch index.
//   * The output row is blockIdx.x * dilate_height + blockIdx.y.
//   * threadIdx.x strides over output channels, and threadIdx.y walks output
//     columns in a dilation-interleaved order.
// Per output channel, each thread accumulates all c_filter * c_filter taps in
// the local array r_weight. It then adds them to filter_grad_data, whose
// layout is [tap][output_channel], with platform::CudaAtomicAdd.
template <typename T, int c_filter, bool fuse_relu_before_conv>
__device__ __inline__ void KernelDepthwiseConvFilterGradCFilterNHWC(
    const T* output_grad_data,
    const T* input_data,
    const int num,
    const int output_channels,
    const int output_height,
    const int output_width,
    const int input_channels,
    const int input_height,
    const int input_width,
    const int filter_multiplier,
    const int filter_height,
    const int filter_width,
    const int stride_height,
    const int stride_width,
    const int padding_height,
    const int padding_width,
    const int dilate_height,
    const int dilate_width,
    T* filter_grad_data) {
  const int bid = blockIdx.z;
  int image_h = blockIdx.x * dilate_height + blockIdx.y;
  if (image_h >= output_height) {
    return;
  }
  const int kWeightSize = c_filter * c_filter;
  T r_weight[kWeightSize];
  const int wi_size = (output_width + dilate_width - 1) / dilate_width;

  for (int kernel_id = threadIdx.x; kernel_id < output_channels;
       kernel_id += blockDim.x) {
    for (int i = 0; i < kWeightSize; ++i) {
      r_weight[i] = 0;
    }
    for (int i = threadIdx.y; i < wi_size * dilate_width; i += blockDim.y) {
      // Map i to a column so that consecutive i share the same dilation phase.
      int i_dw = i / wi_size;
      int i_wi = i - i_dw * wi_size;
      int image_w = i_wi * dilate_width + i_dw;
      if (image_w >= output_width) {
        continue;
      }
      // The output-gradient element is the same for every filter tap, so it
      // is loaded once per pixel instead of once per tap.
      const int output_id =
          ((bid * output_height + image_h) * output_width + image_w) *
              output_channels +
          kernel_id;
      const T out_grad = output_grad_data[output_id];
      for (int kernel_ih = 0; kernel_ih < c_filter; ++kernel_ih) {
        int kernel_h = kernel_ih * dilate_height - padding_height;
        int image_hk = image_h * stride_height + kernel_h;
        if (image_hk < 0 || image_hk >= input_height) continue;
        for (int kernel_iw = 0; kernel_iw < c_filter; ++kernel_iw) {
          int kernel_w = kernel_iw * dilate_width - padding_width;
          int image_wk = image_w * stride_width + kernel_w;
          if (image_wk < 0 || image_wk >= input_width) continue;
          int input_id =
              ((bid * input_height + image_hk) * input_width + image_wk) *
                  input_channels +
              kernel_id / filter_multiplier;
          T input_value = input_data[input_id];
          if (fuse_relu_before_conv) {
            // ReLU in T; avoids FP64 promotion in the inner loop. NaN -> 0.
            input_value = input_value > T(0) ? input_value : T(0);
          }
          r_weight[kernel_ih * c_filter + kernel_iw] += out_grad * input_value;
        }
      }
    }
    for (int i = 0; i < kWeightSize; ++i) {
      T* weight = filter_grad_data + i * output_channels + kernel_id;
      platform::CudaAtomicAdd(&weight[0], r_weight[i]);
    }
  }
}

H
hong 已提交
1032 1033 1034 1035 1036 1037 1038 1039 1040 1041 1042 1043 1044 1045 1046 1047 1048 1049 1050 1051 1052 1053 1054 1055 1056
// Kernel entry point for the depthwise-conv gradient w.r.t. the filter.
//
// A non-zero c_filter_multiplier overrides the runtime filter_multiplier and
// both strides with the compile-time c_filter_multiplier / c_stride.
// Dispatch:
//   * NCHW always uses the generic NCHW helper.
//   * NHWC uses the generic helper when c_filter_multiplier == 0 or
//     c_filter == -1, and otherwise the c_filter-specialized helper.
template <typename T,
          int c_filter_multiplier,
          int c_stride,
          int c_filter,
          DataLayout data_layout,
          bool fuse_relu_before_conv>
__global__ void KernelDepthwiseConvFilterGradSp(const T* output_grad_data,
                                                const T* input_data,
                                                const int num,
                                                const int output_channels,
                                                const int output_height,
                                                const int output_width,
                                                const int input_channels,
                                                const int input_height,
                                                const int input_width,
                                                const int filter_multiplier,
                                                const int filter_height,
                                                const int filter_width,
                                                const int stride_height,
                                                const int stride_width,
                                                const int padding_height,
                                                const int padding_width,
                                                const int dilate_height,
                                                const int dilate_width,
                                                T* filter_grad_data) {
  const bool multiplier_is_static = (c_filter_multiplier != 0);
  const int multiplier =
      multiplier_is_static ? c_filter_multiplier : filter_multiplier;
  const int stride_h = multiplier_is_static ? c_stride : stride_height;
  const int stride_w = multiplier_is_static ? c_stride : stride_width;

  if (data_layout != DataLayout::kNHWC) {
    // NCHW has one implementation for both the generic and specialized cases.
    KernelDepthwiseConvFilterGradNCHW<T, fuse_relu_before_conv>(
        output_grad_data,
        input_data,
        num,
        output_channels,
        output_height,
        output_width,
        input_channels,
        input_height,
        input_width,
        multiplier,
        filter_height,
        filter_width,
        stride_h,
        stride_w,
        padding_height,
        padding_width,
        dilate_height,
        dilate_width,
        filter_grad_data);
  } else if (!multiplier_is_static || c_filter == -1) {
    KernelDepthwiseConvFilterGradNHWC<T, fuse_relu_before_conv>(
        output_grad_data,
        input_data,
        num,
        output_channels,
        output_height,
        output_width,
        input_channels,
        input_height,
        input_width,
        multiplier,
        filter_height,
        filter_width,
        stride_h,
        stride_w,
        padding_height,
        padding_width,
        dilate_height,
        dilate_width,
        filter_grad_data);
  } else {
    KernelDepthwiseConvFilterGradCFilterNHWC<T,
                                             c_filter,
                                             fuse_relu_before_conv>(
        output_grad_data,
        input_data,
        num,
        output_channels,
        output_height,
        output_width,
        input_channels,
        input_height,
        input_width,
        multiplier,
        filter_height,
        filter_width,
        stride_h,
        stride_w,
        padding_height,
        padding_width,
        dilate_height,
        dilate_width,
        filter_grad_data);
  }
}

/*
 * All tensors are in NCHW format.
 * Ksize, strides, paddings are two elements. These two elements represent
 * height and width, respectively.
 */
1163
/*
 * GPU forward pass of the depthwise convolution.
 *
 * Tensor dimensions are read as NCHW, or as NHWC when data_layout == kNHWC.
 * The filter's dims [2] / [3] are taken as the kernel height / width.
 * strides, paddings and dilations hold {height, width}.
 * For NHWC, the filter is first transposed with perm {2, 3, 0, 1} into a
 * temporary tensor.
 */
template <class T, bool fuse_relu_before_conv>
class DepthwiseConvFunctor<phi::GPUContext, T, fuse_relu_before_conv> {
 public:
  void operator()(const phi::GPUContext& context,
                  const framework::Tensor& input,
                  const framework::Tensor& filter,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings,
                  const std::vector<int>& dilations,
                  framework::Tensor* output,
                  const DataLayout data_layout = DataLayout::kNCHW) {
    const bool is_nhwc = (data_layout == DataLayout::kNHWC);

    const int batch_size = input.dims()[0];
    const int input_channels = is_nhwc ? input.dims()[3] : input.dims()[1];
    const int input_height = is_nhwc ? input.dims()[1] : input.dims()[2];
    const int input_width = is_nhwc ? input.dims()[2] : input.dims()[3];
    const int output_channels =
        is_nhwc ? output->dims()[3] : output->dims()[1];
    const int output_height = is_nhwc ? output->dims()[1] : output->dims()[2];
    const int output_width = is_nhwc ? output->dims()[2] : output->dims()[3];
    const int ksize_height = filter.dims()[2];
    const int ksize_width = filter.dims()[3];
    const int stride_height = strides[0];
    const int stride_width = strides[1];
    const int padding_height = paddings[0];
    const int padding_width = paddings[1];
    const int dilate_height = dilations[0];
    const int dilate_width = dilations[1];

    const T* input_data = input.data<T>();
    const T* filter_data = filter.data<T>();
    T* output_data = output->mutable_data<T>(context.GetPlace());

    // NHWC kernels read the filter from a transposed copy.
    framework::Tensor filter_hwc;
    if (is_nhwc) {
      framework::DDim filter_hwc_dims({filter.dims()[2],
                                       filter.dims()[3],
                                       filter.dims()[0],
                                       filter.dims()[1]});
      filter_hwc.Resize(filter_hwc_dims);
      filter_hwc.mutable_data<T>(context.GetPlace());
      std::vector<int> perm_axis({2, 3, 0, 1});
      phi::funcs::TransposeNormal<phi::GPUContext, T> trans;
      trans(context, filter, &filter_hwc, perm_axis);
      filter_data = filter_hwc.data<T>();
    }

    // 2-D launch configuration used by the specialized (c_filter != -1)
    // kernels; the c_filter == -1 cases switch to a flat 1-D launch below.
    int thread = 512;
    dim3 threads;
    dim3 grid;
    if (is_nhwc) {
#ifdef __HIPCC__
      thread = std::min(thread, 256);
#endif
      const int padded_width =
          ((output_width + dilate_width - 1) / dilate_width) * dilate_width;
      const int rows =
          std::min(std::max(thread / output_channels, 1), padded_width);
      threads = dim3(std::min(output_channels, thread), rows, 1);
      grid = dim3((output_height + dilate_height - 1) / dilate_height,
                  dilate_height,
                  batch_size);
    } else {
      if (output_width > 1024 && output_width <= 2048) {
        thread = (output_width - 1) / 2 + 1;
      } else if (output_width > 512 && output_width <= 1024) {
        thread = output_width;
      }
#ifdef __HIPCC__
      thread = std::min(thread, 256);
#endif
      const int rows =
          std::min(std::max(thread / output_width, 1), output_height);
      threads = dim3(std::min(output_width, thread), rows, 1);
      grid = dim3(output_channels, batch_size, 1);
    }

    int filter_multiplier = output_channels / input_channels;
    int nums_output =
        batch_size * output_channels * output_height * output_width;
#ifdef __HIPCC__
    int block_size = 256;
#else
    int block_size = 512;
#endif
    int grid_size = (nums_output + block_size - 1) / block_size;

// Launches KernelDepthwiseConvSp for one specialization and one layout.
#define DWCONV_FWD_LAUNCH(c_filter_multiplier, c_stride, c_filter, layout) \
  KernelDepthwiseConvSp<T,                                                 \
                        c_filter_multiplier,                               \
                        c_stride,                                          \
                        c_filter,                                          \
                        layout,                                            \
                        fuse_relu_before_conv>                             \
      <<<grid, threads, 0, context.stream()>>>(input_data,                 \
                                               filter_data,                \
                                               batch_size,                 \
                                               output_channels,            \
                                               output_height,              \
                                               output_width,               \
                                               input_channels,             \
                                               input_height,               \
                                               input_width,                \
                                               filter_multiplier,          \
                                               ksize_height,               \
                                               ksize_width,                \
                                               stride_height,              \
                                               stride_width,               \
                                               padding_height,             \
                                               padding_width,              \
                                               dilate_height,              \
                                               dilate_width,               \
                                               output_data)

// Launches and returns if the runtime shape matches this specialization.
// c_filter_multiplier == 0 matches anything; c_filter == -1 matches any
// filter size and uses a flat 1-D launch.
#define DWCONV_FWD_TRY(c_filter_multiplier, c_stride, c_filter)          \
  if ((c_filter_multiplier) == 0 ||                                       \
      (filter_multiplier == (c_filter_multiplier) &&                      \
       stride_height == stride_width && stride_height == (c_stride) &&    \
       ((ksize_height == ksize_width && ksize_height == (c_filter)) ||    \
        (c_filter) == -1))) {                                             \
    if ((c_filter) == -1) {                                               \
      threads.x = block_size;                                             \
      grid.x = grid_size;                                                 \
      threads.y = threads.z = grid.y = grid.z = 1;                        \
    }                                                                     \
    if (is_nhwc) {                                                        \
      DWCONV_FWD_LAUNCH(                                                  \
          c_filter_multiplier, c_stride, c_filter, DataLayout::kNHWC);    \
    } else {                                                              \
      DWCONV_FWD_LAUNCH(                                                  \
          c_filter_multiplier, c_stride, c_filter, DataLayout::kNCHW);    \
    }                                                                     \
    return;                                                               \
  }

    DWCONV_FWD_TRY(1, 1, 3);
    DWCONV_FWD_TRY(1, 1, 5);
    DWCONV_FWD_TRY(1, 1, -1);
    DWCONV_FWD_TRY(1, 2, 3);
    DWCONV_FWD_TRY(1, 2, 5);
    DWCONV_FWD_TRY(1, 2, -1);
    DWCONV_FWD_TRY(2, 1, 3);
    DWCONV_FWD_TRY(2, 1, 5);
    DWCONV_FWD_TRY(2, 1, -1);
    DWCONV_FWD_TRY(2, 2, 3);
    DWCONV_FWD_TRY(2, 2, 5);
    DWCONV_FWD_TRY(2, 2, -1);
    // NOTE(liangdun): (0, 0) is the catch-all for every other case.
    // Add more specializations if needed, e.g. DWCONV_FWD_TRY(2^n, 1, ...).
    DWCONV_FWD_TRY(0, 0, -1);
#undef DWCONV_FWD_TRY
#undef DWCONV_FWD_LAUNCH
  }
};

1340
template <typename T, bool fuse_relu_before_conv>
H
hong 已提交
1341
class DepthwiseConvInputGradFunctor<phi::GPUContext, T, fuse_relu_before_conv> {
Z
zlx 已提交
1342
 public:
H
hong 已提交
1343
  void operator()(const phi::GPUContext& context,
Z
zlx 已提交
1344
                  const framework::Tensor& input,
1345 1346
                  const framework::Tensor& filter,
                  const framework::Tensor& output_grad,
X
xzl 已提交
1347 1348
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings,
1349
                  const std::vector<int>& dilations,
1350 1351
                  framework::Tensor* input_grad,
                  const DataLayout data_layout = DataLayout::kNCHW) {
Z
zlx 已提交
1352
    const int batch_size = input.dims()[0];
1353
    const int input_channels =
1354
        (data_layout != DataLayout::kNHWC ? input.dims()[1] : input.dims()[3]);
1355
    const int input_height =
1356
        (data_layout != DataLayout::kNHWC ? input.dims()[2] : input.dims()[1]);
1357
    const int input_width =
1358
        (data_layout != DataLayout::kNHWC ? input.dims()[3] : input.dims()[2]);
1359
    const int output_channels =
1360
        (data_layout != DataLayout::kNHWC ? output_grad.dims()[1]
1361 1362
                                          : output_grad.dims()[3]);
    const int output_height =
1363
        (data_layout != DataLayout::kNHWC ? output_grad.dims()[2]
1364 1365
                                          : output_grad.dims()[1]);
    const int output_width =
1366
        (data_layout != DataLayout::kNHWC ? output_grad.dims()[3]
1367
                                          : output_grad.dims()[2]);
1368 1369 1370
    const int ksize_height = filter.dims()[2];
    const int ksize_width = filter.dims()[3];
    const int stride_height = strides[0];
Z
zlx 已提交
1371 1372 1373
    const int stride_width = strides[1];
    const int padding_height = paddings[0];
    const int padding_width = paddings[1];
1374 1375
    const int dilate_height = dilations[0];
    const int dilate_width = dilations[1];
Z
zlx 已提交
1376

1377
    const T* input_data = input.data<T>();
1378
    const T* filter_data = filter.data<T>();
Z
zlx 已提交
1379 1380 1381
    const T* output_grad_data = output_grad.data<T>();
    T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace());

1382 1383
    framework::Tensor filter_hwc;
    if (data_layout == DataLayout::kNHWC) {
H
hong 已提交
1384 1385 1386 1387
      framework::DDim filter_hwc_dims({filter.dims()[2],
                                       filter.dims()[3],
                                       filter.dims()[0],
                                       filter.dims()[1]});
1388 1389 1390
      filter_hwc.Resize(filter_hwc_dims);
      filter_hwc.mutable_data<T>(context.GetPlace());
      std::vector<int> perm_axis({2, 3, 0, 1});
H
hong 已提交
1391
      phi::funcs::TransposeNormal<phi::GPUContext, T> trans;
1392 1393 1394 1395
      trans(context, filter, &filter_hwc, perm_axis);
      filter_data = filter_hwc.data<T>();
    }

1396
    int thread = 512;
1397 1398 1399
    int blocks;
    dim3 threads;
    dim3 grid;
W
wangguanzhong 已提交
1400

1401 1402 1403 1404 1405 1406 1407 1408 1409 1410 1411 1412 1413 1414 1415
    if (data_layout != DataLayout::kNHWC) {
      if (input_width > 1024 && input_width <= 2048) {
        thread = (input_width - 1) / 2 + 1;
      } else if (input_width > 512 && input_width <= 1024) {
        thread = input_width;
      }
      blocks = std::min(std::max(thread / input_width, 1), input_height);
      threads = dim3(std::min(input_width, thread), blocks, 1);
      grid = dim3(input_channels, batch_size, 1);
    } else {
      blocks = std::min(
          std::max(thread / input_channels, 1),
          ((input_width + dilate_width - 1) / dilate_width) * dilate_width);
      threads = dim3(std::min(input_channels, thread), blocks, 1);
      grid = dim3((input_height + dilate_height - 1) / dilate_height,
H
hong 已提交
1416 1417
                  dilate_height,
                  batch_size);
1418
    }
1419 1420
    int filter_multiplier = output_channels / input_channels;

1421 1422 1423 1424 1425 1426 1427 1428 1429 1430 1431 1432 1433 1434 1435 1436 1437 1438 1439 1440 1441 1442 1443 1444 1445 1446 1447 1448 1449 1450 1451 1452 1453 1454 1455 1456 1457 1458 1459 1460 1461 1462 1463 1464 1465 1466 1467 1468 1469 1470 1471 1472 1473 1474 1475 1476 1477 1478 1479 1480 1481 1482
#define check_case(c_filter_multiplier, c_stride, c_filter)             \
  if (c_filter_multiplier == 0 ||                                       \
      filter_multiplier == c_filter_multiplier &&                       \
          stride_height == stride_width && stride_height == c_stride && \
          (ksize_height == ksize_width && ksize_height == c_filter ||   \
           c_filter == -1)) {                                           \
    if (data_layout != DataLayout::kNHWC) {                             \
      KernelDepthwiseConvInputGradSp<T,                                 \
                                     c_filter_multiplier,               \
                                     c_stride,                          \
                                     c_filter,                          \
                                     DataLayout::kNCHW,                 \
                                     fuse_relu_before_conv>             \
          <<<grid, threads, 0, context.stream()>>>(input_data,          \
                                                   output_grad_data,    \
                                                   filter_data,         \
                                                   batch_size,          \
                                                   output_channels,     \
                                                   output_height,       \
                                                   output_width,        \
                                                   input_channels,      \
                                                   input_height,        \
                                                   input_width,         \
                                                   filter_multiplier,   \
                                                   ksize_height,        \
                                                   ksize_width,         \
                                                   stride_height,       \
                                                   stride_width,        \
                                                   padding_height,      \
                                                   padding_width,       \
                                                   dilate_height,       \
                                                   dilate_width,        \
                                                   input_grad_data);    \
    } else {                                                            \
      KernelDepthwiseConvInputGradSp<T,                                 \
                                     c_filter_multiplier,               \
                                     c_stride,                          \
                                     c_filter,                          \
                                     DataLayout::kNHWC,                 \
                                     fuse_relu_before_conv>             \
          <<<grid, threads, 0, context.stream()>>>(input_data,          \
                                                   output_grad_data,    \
                                                   filter_data,         \
                                                   batch_size,          \
                                                   output_channels,     \
                                                   output_height,       \
                                                   output_width,        \
                                                   input_channels,      \
                                                   input_height,        \
                                                   input_width,         \
                                                   filter_multiplier,   \
                                                   ksize_height,        \
                                                   ksize_width,         \
                                                   stride_height,       \
                                                   stride_width,        \
                                                   padding_height,      \
                                                   padding_width,       \
                                                   dilate_height,       \
                                                   dilate_width,        \
                                                   input_grad_data);    \
    }                                                                   \
    return;                                                             \
1483
  }
1484 1485 1486 1487 1488 1489 1490 1491 1492 1493 1494 1495 1496 1497 1498
    check_case(1, 1, 3);
    check_case(1, 1, 5);
    check_case(1, 1, -1);
    check_case(1, 2, 3);
    check_case(1, 2, 5);
    check_case(1, 2, -1);
    check_case(2, 1, 3);
    check_case(2, 1, 5);
    check_case(2, 1, -1);
    check_case(2, 2, 3);
    check_case(2, 2, 5);
    check_case(2, 2, -1);
    check_case(0, 0, -1);
// NOTE(liangdun): 0,0 for other case
// add other case if needed, e.g. check_case(2^n,1)
1499
#undef check_case
Z
zlx 已提交
1500 1501 1502
  }
};

/*
 * GPU specialization of DepthwiseConvFilterGradFunctor.
 *
 * Computes the gradient of a depthwise convolution with respect to its filter
 * from the forward input and the output gradient, and stores it in
 * filter_grad (allocated here on context.GetPlace()).
 *
 * Dispatch: the local check_case macro picks the first compile-time
 * specialization of KernelDepthwiseConvFilterGradSp whose
 * (filter_multiplier, stride, kernel size) matches the runtime values. A
 * stride/kernel size only matches when it is square (height == width).
 * c_filter == -1 accepts any kernel size. The final check_case(0, 0, -1)
 * matches unconditionally and is the generic fallback.
 *
 * Layout handling:
 *   - NCHW: a single kernel launch writes filter_grad_data directly.
 *   - NHWC with a specialized kernel size (c_filter != -1): the kernel writes
 *     into a zero-initialized temporary shaped
 *     [kh, kw, filter_dims[0], filter_dims[1]]. That temporary is then
 *     transposed back into filter_grad with perm {2, 3, 0, 1}.
 *   - NHWC with the generic kernel size (c_filter == -1): the launch
 *     configuration is recomputed from output_channels, and the kernel writes
 *     filter_grad directly.
 *
 * NOTE(review): whether the direct-write paths overwrite or accumulate into
 * filter_grad (and hence whether the caller must pre-zero it) is decided by
 * KernelDepthwiseConvFilterGradSp, which is not visible here. Confirm this.
 */
template <typename T, bool fuse_relu_before_conv>
class DepthwiseConvFilterGradFunctor<phi::GPUContext,
                                     T,
                                     fuse_relu_before_conv> {
 public:
  /*
   * \param context      GPU device context; all work is issued on
   *                     context.stream().
   * \param input        forward input, laid out according to data_layout.
   * \param output_grad  gradient of the forward output, same layout as input.
   * \param strides      {stride_height, stride_width}.
   * \param paddings     {padding_height, padding_width}.
   * \param dilations    {dilate_height, dilate_width}.
   * \param filter_grad  output tensor. dims()[2] and dims()[3] give the kernel
   *                     height and width.
   * \param data_layout  kNHWC selects the channels-last kernels. Any other
   *                     value is treated as NCHW.
   */
  void operator()(const phi::GPUContext& context,
                  const framework::Tensor& input,
                  const framework::Tensor& output_grad,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings,
                  const std::vector<int>& dilations,
                  framework::Tensor* filter_grad,
                  const DataLayout data_layout = DataLayout::kNCHW) {
    const int batch_size = input.dims()[0];
    const int input_channels =
        (data_layout != DataLayout::kNHWC ? input.dims()[1] : input.dims()[3]);
    const int input_height =
        (data_layout != DataLayout::kNHWC ? input.dims()[2] : input.dims()[1]);
    const int input_width =
        (data_layout != DataLayout::kNHWC ? input.dims()[3] : input.dims()[2]);
    const int output_channels =
        (data_layout != DataLayout::kNHWC ? output_grad.dims()[1]
                                          : output_grad.dims()[3]);
    const int output_height =
        (data_layout != DataLayout::kNHWC ? output_grad.dims()[2]
                                          : output_grad.dims()[1]);
    const int output_width =
        (data_layout != DataLayout::kNHWC ? output_grad.dims()[3]
                                          : output_grad.dims()[2]);
    const int ksize_height = filter_grad->dims()[2];
    const int ksize_width = filter_grad->dims()[3];
    const int stride_height = strides[0];
    const int stride_width = strides[1];
    const int padding_height = paddings[0];
    const int padding_width = paddings[1];
    const int dilate_height = dilations[0];
    const int dilate_width = dilations[1];

    const T* input_data = input.data<T>();
    const T* output_grad_data = output_grad.data<T>();
    T* filter_grad_data = filter_grad->mutable_data<T>(context.GetPlace());

    // An empty output gradient contributes nothing, so the filter gradient
    // is zero. Return before the launch-configuration arithmetic below:
    // it divides by output_width / output_channels, and a zero-sized grid
    // dimension (e.g. batch_size == 0 in NHWC) is an invalid launch.
    if (output_grad.numel() == 0) {
      phi::funcs::SetConstant<phi::GPUContext, T> set_zero;
      set_zero(context, filter_grad, static_cast<T>(0));
      return;
    }

    int block_size = 512;
    int blocks;
    dim3 threads;
    dim3 grid;
    if (data_layout != DataLayout::kNHWC) {
      // NCHW: one block per (filter column, filter row, output channel).
      // threads.x spans the output width. For widths up to 2048 this keeps
      // threads.x <= 1024; threads.y adds rows up to block_size in total.
      if (output_width > 1024 && output_width <= 2048) {
        block_size = (output_width - 1) / 2 + 1;
      } else if (output_width > 512 && output_width <= 1024) {
        block_size = output_width;
      }
      blocks = std::min(std::max(block_size / output_width, 1), output_height);
      grid = dim3(ksize_width, ksize_height, output_channels);
      threads = dim3(std::min(output_width, block_size), blocks, 1);
    } else {
      // NHWC: grid = (ceil(output_height / dilate_height), dilate_height,
      // batch). threads.x spans channels; threads.y is capped at the output
      // width rounded up to a multiple of dilate_width.
      blocks = std::min(
          std::max(block_size / output_channels, 1),
          ((output_width + dilate_width - 1) / dilate_width) * dilate_width);
      grid = dim3((output_height + dilate_height - 1) / dilate_height,
                  dilate_height,
                  batch_size);
      threads = dim3(std::min(output_channels, block_size), blocks, 1);
    }
    int filter_multiplier = output_channels / input_channels;

#define check_case(c_filter_multiplier, c_stride, c_filter)                    \
  if (c_filter_multiplier == 0 ||                                              \
      (filter_multiplier == c_filter_multiplier &&                             \
       stride_height == stride_width && stride_height == c_stride &&           \
       ((ksize_height == ksize_width && ksize_height == c_filter) ||           \
        c_filter == -1))) {                                                    \
    if (data_layout != DataLayout::kNHWC) {                                    \
      KernelDepthwiseConvFilterGradSp<T,                                       \
                                      c_filter_multiplier,                     \
                                      c_stride,                                \
                                      c_filter,                                \
                                      DataLayout::kNCHW,                       \
                                      fuse_relu_before_conv>                   \
          <<<grid, threads, 0, context.stream()>>>(output_grad_data,           \
                                                   input_data,                 \
                                                   batch_size,                 \
                                                   output_channels,            \
                                                   output_height,              \
                                                   output_width,               \
                                                   input_channels,             \
                                                   input_height,               \
                                                   input_width,                \
                                                   filter_multiplier,          \
                                                   ksize_height,               \
                                                   ksize_width,                \
                                                   stride_height,              \
                                                   stride_width,               \
                                                   padding_height,             \
                                                   padding_width,              \
                                                   dilate_height,              \
                                                   dilate_width,               \
                                                   filter_grad_data);          \
    } else {                                                                   \
      framework::Tensor filter_grad_hwc;                                       \
      if (c_filter != -1) {                                                    \
        framework::DDim filter_grad_hwc_dims({filter_grad->dims()[2],          \
                                              filter_grad->dims()[3],          \
                                              filter_grad->dims()[0],          \
                                              filter_grad->dims()[1]});        \
        filter_grad_hwc.Resize(filter_grad_hwc_dims);                          \
        filter_grad_hwc.mutable_data<T>(context.GetPlace());                   \
        phi::funcs::SetConstant<phi::GPUContext, T> set_zero;                  \
        set_zero(context, &filter_grad_hwc, static_cast<T>(0));                \
        filter_grad_data = filter_grad_hwc.data<T>();                          \
      } else {                                                                 \
        block_size = 512;                                                      \
        if (output_channels > 1024 && output_channels <= 2048) {               \
          block_size = (output_channels - 1) / 2 + 1;                          \
        } else if (output_channels > 512 && output_channels <= 1024) {         \
          block_size = output_channels;                                        \
        }                                                                      \
        blocks =                                                               \
            std::min(std::max(block_size / output_channels, 1), output_width); \
        grid = dim3(ksize_width * ksize_height, output_height, batch_size);    \
        threads = dim3(std::min(output_channels, block_size), blocks, 1);      \
      }                                                                        \
      KernelDepthwiseConvFilterGradSp<T,                                       \
                                      c_filter_multiplier,                     \
                                      c_stride,                                \
                                      c_filter,                                \
                                      DataLayout::kNHWC,                       \
                                      fuse_relu_before_conv>                   \
          <<<grid, threads, 0, context.stream()>>>(output_grad_data,           \
                                                   input_data,                 \
                                                   batch_size,                 \
                                                   output_channels,            \
                                                   output_height,              \
                                                   output_width,               \
                                                   input_channels,             \
                                                   input_height,               \
                                                   input_width,                \
                                                   filter_multiplier,          \
                                                   ksize_height,               \
                                                   ksize_width,                \
                                                   stride_height,              \
                                                   stride_width,               \
                                                   padding_height,             \
                                                   padding_width,              \
                                                   dilate_height,              \
                                                   dilate_width,               \
                                                   filter_grad_data);          \
      if (c_filter != -1) {                                                    \
        std::vector<int> perm_axis({2, 3, 0, 1});                              \
        phi::funcs::TransposeNormal<phi::GPUContext, T> trans;                 \
        trans(context, filter_grad_hwc, filter_grad, perm_axis);               \
      }                                                                        \
    }                                                                          \
    return;                                                                    \
  }
    // Specializations are tried in order. The first match launches and
    // returns. (0, 0, -1) matches unconditionally and must stay last.
    check_case(1, 1, 3);
    check_case(1, 1, 5);
    check_case(1, 1, -1);
    check_case(1, 2, 3);
    check_case(1, 2, 5);
    check_case(1, 2, -1);
    check_case(2, 1, 3);
    check_case(2, 1, 5);
    check_case(2, 1, -1);
    check_case(2, 2, 3);
    check_case(2, 2, 5);
    check_case(2, 2, -1);
    check_case(0, 0, -1);
#undef check_case
  }
};

// Explicit instantiations of the GPU depthwise-convolution functors
// (forward, input-gradient, filter-gradient) for each supported element type,
// both without (false) and with (true) the fused ReLU before the convolution.
#define INSTANTIATE_DEPTHWISE_CONV_GPU_FUNCTORS(DTYPE, FUSE_RELU)         \
  template class DepthwiseConvFunctor<phi::GPUContext, DTYPE, FUSE_RELU>; \
  template class DepthwiseConvInputGradFunctor<phi::GPUContext,           \
                                               DTYPE,                     \
                                               FUSE_RELU>;                \
  template class DepthwiseConvFilterGradFunctor<phi::GPUContext,          \
                                                DTYPE,                    \
                                                FUSE_RELU>

INSTANTIATE_DEPTHWISE_CONV_GPU_FUNCTORS(float, false);
INSTANTIATE_DEPTHWISE_CONV_GPU_FUNCTORS(double, false);
INSTANTIATE_DEPTHWISE_CONV_GPU_FUNCTORS(platform::float16, false);

INSTANTIATE_DEPTHWISE_CONV_GPU_FUNCTORS(float, true);
INSTANTIATE_DEPTHWISE_CONV_GPU_FUNCTORS(double, true);
INSTANTIATE_DEPTHWISE_CONV_GPU_FUNCTORS(platform::float16, true);

#undef INSTANTIATE_DEPTHWISE_CONV_GPU_FUNCTORS

}  // namespace math
}  // namespace operators
}  // namespace paddle