/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

H
hong 已提交
15
#pragma once
A
Abhinav Arora 已提交
16
#include <vector>
17

H
hong 已提交
18 19 20 21
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/phi/core/hostdevice.h"

22 23 24 25 26 27 28
#ifdef __NVCC__
#include <cub/cub.cuh>
#endif
#ifdef __HIPCC__
#include <hipcub/hipcub.hpp>
namespace cub = hipcub;
#endif
H
hong 已提交
29

30 31
#include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
32
#include "paddle/phi/kernels/funcs/math_function.h"
Z
zlx 已提交
33 34 35 36 37

namespace paddle {
namespace operators {
namespace math {

H
hong 已提交
38 39 40 41 42 43 44 45 46 47 48 49
using DataLayout = framework::DataLayout;

/*
 * \brief Functors for depthwise convolution, covering the forward pass and
 * the backpropagation passes.
 */

/*
 * \brief Forward-pass functor interface.
 *
 * This primary template only declares operator(); no definition is given
 * here.
 *
 * \tparam DeviceContext          type of the execution context argument.
 * \tparam T                      element type.
 * \tparam fuse_relu_before_conv  compile-time flag, false by default
 *                                (presumably applies ReLU to the input before
 *                                convolving -- confirm in the definitions).
 *
 * \param context     device context.
 * \param input       input tensor.
 * \param filter      filter tensor.
 * \param strides     convolution strides.
 * \param paddings    convolution paddings.
 * \param dilations   convolution dilations.
 * \param output      destination tensor.
 * \param data_layout tensor layout; defaults to DataLayout::kNCHW.
 */
template <typename DeviceContext,
          typename T,
          bool fuse_relu_before_conv = false>
class DepthwiseConvFunctor {
 public:
  void operator()(const DeviceContext& context,
                  const phi::DenseTensor& input,
                  const phi::DenseTensor& filter,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings,
                  const std::vector<int>& dilations,
                  phi::DenseTensor* output,
                  const DataLayout data_layout = DataLayout::kNCHW);
};

/*
 * \brief Functor interface for the gradient with respect to the input.
 *
 * This primary template only declares operator(); no definition is given
 * here.
 *
 * \param context     device context.
 * \param input       forward-pass input tensor.
 * \param filter      filter tensor.
 * \param output_grad gradient tensor for the forward output.
 * \param strides     convolution strides.
 * \param paddings    convolution paddings.
 * \param dilations   convolution dilations.
 * \param input_grad  destination tensor for the input gradient.
 * \param data_layout tensor layout; defaults to DataLayout::kNCHW.
 */
template <typename DeviceContext,
          typename T,
          bool fuse_relu_before_conv = false>
class DepthwiseConvInputGradFunctor {
 public:
  void operator()(const DeviceContext& context,
                  const phi::DenseTensor& input,
                  const phi::DenseTensor& filter,
                  const phi::DenseTensor& output_grad,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings,
                  const std::vector<int>& dilations,
                  phi::DenseTensor* input_grad,
                  const DataLayout data_layout = DataLayout::kNCHW);
};

/*
 * \brief Functor interface for the gradient with respect to the filter.
 *
 * This primary template only declares operator(); no definition is given
 * here.
 *
 * \param context     device context.
 * \param input       forward-pass input tensor.
 * \param output_grad gradient tensor for the forward output.
 * \param strides     convolution strides.
 * \param paddings    convolution paddings.
 * \param dilations   convolution dilations.
 * \param filter_grad destination tensor for the filter gradient.
 * \param data_layout tensor layout; defaults to DataLayout::kNCHW.
 */
template <typename DeviceContext,
          typename T,
          bool fuse_relu_before_conv = false>
class DepthwiseConvFilterGradFunctor {
 public:
  void operator()(const DeviceContext& context,
                  const phi::DenseTensor& input,
                  const phi::DenseTensor& output_grad,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings,
                  const std::vector<int>& dilations,
                  phi::DenseTensor* filter_grad,
                  const DataLayout data_layout = DataLayout::kNCHW);
};

90
// Warp-level sum built on cub::WarpReduce.
//
// `val` is this thread's contribution and `warp_size` is forwarded unchanged
// as the second argument of cub::WarpReduce::Sum. The value returned is
// whatever cub::WarpReduce::Sum returns for the calling thread.
template <typename T>
static __forceinline__ __device__ T WarpReduceSum(T val, int warp_size) {
  using Reducer = cub::WarpReduce<T>;
  typename Reducer::TempStorage scratch;
  return Reducer(scratch).Sum(val, warp_size);
}
97

W
wangguanzhong 已提交
98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118
// Block-wide sum of `val`.
//
// Each warp first reduces its own values with WarpReduceSum and the partial
// result is parked in shared memory (one slot per warp, 32 slots in total).
// The first warp then reduces those partials. Only the thread whose
// linearised id is 0 returns the reduced value; every other thread returns
// T(0).
//
// Contains two __syncthreads() barriers, so it has to be reached by every
// thread of the block.
template <typename T>
__forceinline__ __device__ T BlockReduceSum(T val) {
  static __shared__ T warp_partials[32];

  // Linearise the (x, y, z) thread coordinates.
  const int tid = threadIdx.x + threadIdx.y * blockDim.x +
                  threadIdx.z * blockDim.x * blockDim.y;
  // A block smaller than a hardware warp is treated as a single short warp.
  const int warp_size = min(blockDim.x * blockDim.y * blockDim.z, warpSize);
  const int lane_id = tid % warp_size;
  const int warp_id = tid / warp_size;

  // Stage 1: partial reduction inside every warp.
  val = WarpReduceSum(val, warp_size);
  if (lane_id == 0) {
    warp_partials[warp_id] = val;
  }
  __syncthreads();  // all partials must be visible before they are read

  // Stage 2: only as many threads as there were warps pick up a partial.
  const int block_size = blockDim.x * blockDim.y * blockDim.z;
  const int num_warps = (block_size - 1) / warp_size + 1;
  if (tid >= num_warps) {
    val = static_cast<T>(0);
  } else {
    val = warp_partials[lane_id];
  }

  if (warp_id == 0) {
    val = WarpReduceSum(val, warp_size);  // final reduction in the first warp
  }
  __syncthreads();

  // Thread 0 keeps the total; everybody else reports zero.
  if (tid != 0) {
    val = static_cast<T>(0);
  }
  return val;
}

130 131 132 133 134 135 136 137
// Parameter list shared by the forward depthwise-convolution kernels that
// follow. It expands inside a function signature where `T` is a template
// parameter.
#define ARG_DEFINE_KernelDepthwiseConv                                  \
  const T *const input_data, const T *const filter_data,                \
      const int batch_size, const int output_channels,                  \
      const int output_height, const int output_width,                  \
      const int input_channels, const int input_height,                 \
      const int input_width, const int filter_multiplier,               \
      const int filter_height, const int filter_width,                  \
      const int stride_height, const int stride_width,                  \
      const int padding_height, const int padding_width,                \
      const int dilate_height, const int dilate_width,                  \
      T *const output_data
139

140 141
// Device routine for the depthwise convolution forward pass in NCHW format.
//
// One thread produces one output element. The flat thread index (built from
// threadIdx.x / blockIdx.x / blockDim.x only) is decoded as
// (batch, c_out, h_out, w_out) with w_out fastest. Threads whose index falls
// beyond batch_size * output_channels * output_height * output_width return
// without writing.
//
// When c_filter is -1 the loop extents come from filter_height/filter_width;
// otherwise c_filter is used for both. With fuse_relu_before_conv set,
// negative input values are clamped to zero before being multiplied.
template <typename T, int c_filter, bool fuse_relu_before_conv>
__device__ __inline__ void KernelDepthwiseConvNCHW(
    ARG_DEFINE_KernelDepthwiseConv) {
  const int kernel_w = (c_filter == -1) ? filter_width : c_filter;
  const int kernel_h = (c_filter == -1) ? filter_height : c_filter;

  const int idx = threadIdx.x + blockIdx.x * blockDim.x;
  if (idx >= (output_channels * batch_size * output_height * output_width)) {
    return;
  }

  // Peel the coordinates off the flat index, fastest dimension first.
  const int w_out = idx % output_width;
  const int row = idx / output_width;
  const int h_out = row % output_height;
  const int plane = row / output_height;
  const int c_out = plane % output_channels;
  const int batch = plane / output_channels;

  const int c_in = c_out / filter_multiplier;

  // Start of the (batch, c_in) input plane and of this channel's filter.
  const int plane_offset =
      ((batch * input_channels + c_in) * input_height) * input_width;
  const T* tap = filter_data + c_out * filter_height * filter_width;

  const int h_first = h_out * stride_height - padding_height;
  const int w_first = w_out * stride_width - padding_width;

  T acc(0);
#pragma unroll
  for (int fh = 0, h_in = h_first; fh < kernel_h;
       ++fh, h_in += dilate_height) {
#pragma unroll
    for (int fw = 0, w_in = w_first; fw < kernel_w;
         ++fw, w_in += dilate_width, ++tap) {
      // Taps that land in the padding contribute nothing; the filter pointer
      // still advances through the loop increment.
      if (h_in < 0 || h_in >= input_height || w_in < 0 ||
          w_in >= input_width) {
        continue;
      }
      T sample = input_data[plane_offset + h_in * input_width + w_in];
      if (fuse_relu_before_conv) {
        sample = static_cast<T>(max(0.0f, static_cast<double>(sample)));
      }
      acc += tap[0] * sample;
    }
  }
  output_data[idx] = acc;
}
190

191 192 193 194 195 196 197 198 199 200 201 202 203 204 205
// Device routine for the depthwise convolution forward pass in NHWC format.
//
// One thread produces one output element. The flat thread index (x dimension
// of the launch only) is decoded as (batch, h_out, w_out, c_out) with c_out
// fastest. Filter values are read at tap * output_channels + c_out, i.e. the
// output channel is the fastest-varying filter dimension.
template <typename T, bool fuse_relu_before_conv>
__device__ __inline__ void KernelDepthwiseConvNHWC(
    ARG_DEFINE_KernelDepthwiseConv) {
  const int idx = threadIdx.x + blockIdx.x * blockDim.x;
  if (idx >= (output_channels * batch_size * output_height * output_width)) {
    return;
  }

  // Peel the coordinates off the flat index, fastest dimension first.
  int rest = idx;
  const int c_out = rest % output_channels;
  rest /= output_channels;
  const int w_out = rest % output_width;
  rest /= output_width;
  const int h_out = rest % output_height;
  const int batch = rest / output_height;

  const int c_in = c_out / filter_multiplier;

  // Extent of the dilated window in input coordinates.
  const int h_first = h_out * stride_height - padding_height;
  const int w_first = w_out * stride_width - padding_width;
  const int h_last = h_first + filter_height * dilate_height;
  const int w_last = w_first + filter_width * dilate_width;

  T acc(0);
  int tap = 0;  // running index over filter positions, row-major
#pragma unroll
  for (int h_in = h_first; h_in < h_last; h_in += dilate_height) {
#pragma unroll
    for (int w_in = w_first; w_in < w_last; w_in += dilate_width, ++tap) {
      // Inside the window, clamping to the image reduces to these bounds.
      if (h_in < 0 || h_in >= input_height || w_in < 0 ||
          w_in >= input_width) {
        continue;
      }
      const int in_index =
          ((batch * input_height + h_in) * input_width + w_in) *
              input_channels +
          c_in;
      T sample = input_data[in_index];
      const T* weight = filter_data + tap * output_channels + c_out;
      if (fuse_relu_before_conv) {
        sample = static_cast<T>(max(0.0f, static_cast<double>(sample)));
      }
      acc += weight[0] * sample;
    }
  }

  const int out_index =
      batch * output_channels * output_height * output_width +
      h_out * output_width * output_channels + w_out * output_channels +
      c_out;
  output_data[out_index] = acc;
}
243

244
// Forward depthwise convolution in NCHW format for a square filter whose
// edge length c_filter is known at compile time.
//
// Work distribution: blockIdx.x selects the output channel and blockIdx.y the
// batch entry; the threads of a block stride over (w_out, h_out) using
// threadIdx.x / threadIdx.y. The output index is built with gridDim.x as the
// channel count (assumes the grid is launched with gridDim.x equal to
// output_channels -- TODO confirm at the launch site).
//
// The c_filter * c_filter weights of the channel are copied into a local
// array once per thread. With fuse_relu_before_conv set, negative input
// values are clamped to zero before being multiplied.
template <typename T, int c_filter, bool fuse_relu_before_conv>
__device__ __inline__ void KernelDepthwiseConvCFilterNCHW(
    ARG_DEFINE_KernelDepthwiseConv) {
  const int kWeightSize = c_filter * c_filter;
  T r_weight[kWeightSize];
  const int batch = blockIdx.y;
  const int c_out = blockIdx.x;
  const T* weight = filter_data + c_out * c_filter * c_filter;
  for (int i = 0; i < c_filter * c_filter; i++) r_weight[i] = weight[i];

  // Both values depend only on the block coordinates, so compute them once
  // instead of once per output element.
  const int c_in = c_out / filter_multiplier;
  const int in_offset =
      ((batch * input_channels + c_in) * input_height) * input_width;

  for (int w_out = threadIdx.x; w_out < output_width; w_out += blockDim.x) {
    for (int h_out = threadIdx.y; h_out < output_height; h_out += blockDim.y) {
      T value(0);
      const int h_in_start = -padding_height + h_out * stride_height;
      const int w_in_start = -padding_width + w_out * stride_width;

      for (int h_in = h_in_start, h_f = 0; h_f < c_filter;
           h_in += dilate_height, h_f++) {
        for (int w_in = w_in_start, w_f = 0; w_f < c_filter;
             w_in += dilate_width, w_f++) {
          // Taps that fall into the padding contribute nothing.
          if (h_in >= 0 && h_in < input_height && w_in >= 0 &&
              w_in < input_width) {
            int offset = in_offset + h_in * input_width + w_in;
            if (fuse_relu_before_conv) {
              value += r_weight[h_f * c_filter + w_f] *
                       static_cast<T>(
                           max(0.0f, static_cast<double>(input_data[offset])));
            } else {
              value += r_weight[h_f * c_filter + w_f] * input_data[offset];
            }
          }
        }
      }
      int index =
          ((batch * gridDim.x + c_out) * output_height + h_out) * output_width +
          w_out;
      output_data[index] = value;
    }
  }
}

// Forward depthwise convolution in NHWC format for a square filter whose
// edge length c_filter is known at compile time.
//
// Work distribution: blockIdx.z selects the batch entry and the output row is
// blockIdx.x * dilate_height + blockIdx.y (blocks past the last row return
// immediately). Inside a block, threadIdx.x strides over output channels and
// threadIdx.y strides over output columns. Columns are visited grouped by
// their remainder modulo dilate_width: slot i maps to column
// (i % wi_size) * dilate_width + i / wi_size.
//
// The c_filter * c_filter weights of the current channel are copied into a
// local array before its columns are processed.
template <typename T, int c_filter, bool fuse_relu_before_conv>
__device__ __inline__ void KernelDepthwiseConvCFilterNHWC(
    ARG_DEFINE_KernelDepthwiseConv) {
  const int batch = blockIdx.z;
  const int h_out = blockIdx.x * dilate_height + blockIdx.y;
  if (h_out >= output_height) {
    return;
  }

  const int kWeightSize = c_filter * c_filter;
  T r_weight[kWeightSize];

  const int in_base = batch * input_height * input_width * input_channels;
  const int out_row_base =
      (batch * output_height + h_out) * output_width * output_channels;
  const int h_first = h_out * stride_height - padding_height;
  const int wi_size = (output_width + dilate_width - 1) / dilate_width;
  const int slot_count = wi_size * dilate_width;

  for (int c_out = threadIdx.x; c_out < output_channels; c_out += blockDim.x) {
    // Gather this channel's weights; consecutive taps are output_channels
    // elements apart.
    for (int tap = 0; tap < c_filter * c_filter; ++tap) {
      r_weight[tap] = filter_data[tap * output_channels + c_out];
    }
    const int c_in = c_out / filter_multiplier;

    for (int slot = threadIdx.y; slot < slot_count; slot += blockDim.y) {
      const int phase = slot / wi_size;
      const int step = slot % wi_size;
      const int w_out = step * dilate_width + phase;
      if (w_out >= output_width) {
        continue;
      }

      const int w_first = w_out * stride_width - padding_width;
      T acc(0);
      for (int h_f = 0, h_in = h_first; h_f < c_filter;
           ++h_f, h_in += dilate_height) {
        if (h_in < 0 || h_in >= input_height) {
          continue;  // whole filter row lies in the padding
        }
        for (int w_f = 0, w_in = w_first; w_f < c_filter;
             ++w_f, w_in += dilate_width) {
          if (w_in < 0 || w_in >= input_width) {
            continue;
          }
          const int in_index =
              in_base + (h_in * input_width + w_in) * input_channels + c_in;
          if (fuse_relu_before_conv) {
            acc += r_weight[h_f * c_filter + w_f] *
                   static_cast<T>(
                       max(0.0, static_cast<double>(input_data[in_index])));
          } else {
            acc += r_weight[h_f * c_filter + w_f] * input_data[in_index];
          }
        }
      }
      output_data[out_row_base + w_out * output_channels + c_out] = acc;
    }
  }
}

H
hong 已提交
354 355 356 357 358 359
// Kernel entry point for the forward depthwise convolution.
//
// Template parameters select the implementation:
//   * c_filter_multiplier != 0: c_filter_multiplier and c_stride replace the
//     runtime filter_multiplier and both runtime strides.
//   * c_filter == -1: the helpers that read the filter size at run time are
//     used; otherwise the helpers specialised on c_filter are used.
//   * data_layout: kNHWC picks the NHWC helper, anything else the NCHW one.
//
// The thread/block layout a launch must provide depends on the helper that
// ends up being called.
template <typename T,
          int c_filter_multiplier,
          int c_stride,
          int c_filter,
          DataLayout data_layout,
          bool fuse_relu_before_conv>
__global__ void KernelDepthwiseConvSp(ARG_DEFINE_KernelDepthwiseConv) {
  const bool has_fixed_params = (c_filter_multiplier != 0);
  const int final_filter_multiplier =
      has_fixed_params ? c_filter_multiplier : filter_multiplier;
  const int h_stride = has_fixed_params ? c_stride : stride_height;
  const int w_stride = has_fixed_params ? c_stride : stride_width;

  const bool runtime_filter = (c_filter == -1);
  const bool is_nhwc = (data_layout == DataLayout::kNHWC);

  if (runtime_filter && !is_nhwc) {
    KernelDepthwiseConvNCHW<T, c_filter, fuse_relu_before_conv>(
        input_data, filter_data, batch_size, output_channels, output_height,
        output_width, input_channels, input_height, input_width,
        final_filter_multiplier, filter_height, filter_width, h_stride,
        w_stride, padding_height, padding_width, dilate_height, dilate_width,
        output_data);
    return;
  }

  if (runtime_filter) {
    KernelDepthwiseConvNHWC<T, fuse_relu_before_conv>(
        input_data, filter_data, batch_size, output_channels, output_height,
        output_width, input_channels, input_height, input_width,
        final_filter_multiplier, filter_height, filter_width, h_stride,
        w_stride, padding_height, padding_width, dilate_height, dilate_width,
        output_data);
    return;
  }

  if (!is_nhwc) {
    KernelDepthwiseConvCFilterNCHW<T, c_filter, fuse_relu_before_conv>(
        input_data, filter_data, batch_size, output_channels, output_height,
        output_width, input_channels, input_height, input_width,
        final_filter_multiplier, filter_height, filter_width, h_stride,
        w_stride, padding_height, padding_width, dilate_height, dilate_width,
        output_data);
    return;
  }

  KernelDepthwiseConvCFilterNHWC<T, c_filter, fuse_relu_before_conv>(
      input_data, filter_data, batch_size, output_channels, output_height,
      output_width, input_channels, input_height, input_width,
      final_filter_multiplier, filter_height, filter_width, h_stride,
      w_stride, padding_height, padding_width, dilate_height, dilate_width,
      output_data);
}

Z
zlx 已提交
459
// Parameter list shared by the kernels that compute the depthwise
// convolution gradient with respect to the input. It expands inside a
// function signature where `T` is a template parameter.
#define ARG_DEFINE_KernelDepthwiseConvInputGrad                         \
  const T *const input_data, const T *const output_grad_data,           \
      const T *const filter_data, const int batch_size,                 \
      const int output_channels, const int output_height,               \
      const int output_width, const int input_channels,                 \
      const int input_height, const int input_width,                    \
      const int filter_multiplier, const int filter_height,             \
      const int filter_width, const int stride_height,                  \
      const int stride_width, const int padding_height,                 \
      const int padding_width, const int dilate_height,                 \
      const int dilate_width, T *const input_grad_data
471

472
// Input gradient of the depthwise convolution in NCHW format, with the
// filter size read at run time.
//
// Work distribution: blockIdx.x selects the input channel and blockIdx.y the
// batch entry; threads stride over (w_in, h_in) via threadIdx.x / threadIdx.y.
// The input index is built with gridDim.x as the channel count (assumes
// gridDim.x equals input_channels -- TODO confirm at the launch site).
//
// For every input position the routine walks the filter_multiplier output
// channels fed by this input channel and, for each, the filter taps in
// reverse order, accumulating output_grad * filter wherever the candidate
// output coordinate is a multiple of the stride and lies inside the output.
// With fuse_relu_before_conv set, positions whose input is <= 0 get a zero
// gradient.
template <typename T, bool fuse_relu_before_conv>
__device__ __inline__ void KernelDepthwiseConvInputGradNCHW(
    ARG_DEFINE_KernelDepthwiseConvInputGrad) {
  const int batch = blockIdx.y;
  const int c_in = blockIdx.x;
  const int first_c_out = c_in * filter_multiplier;
  const int taps_per_channel = filter_height * filter_width;

  for (int w_in = threadIdx.x; w_in < input_width; w_in += blockDim.x) {
    for (int h_in = threadIdx.y; h_in < input_height; h_in += blockDim.y) {
      const int index =
          ((batch * gridDim.x + c_in) * input_height + h_in) * input_width +
          w_in;

      if (fuse_relu_before_conv) {
        if (input_data[index] <= T(0)) {
          input_grad_data[index] = 0;
          continue;
        }
      }

      // Inclusive range of un-strided output coordinates touching this input.
      const int h_lo = h_in - (filter_height - 1) * dilate_height +
                       padding_height;
      const int h_hi = h_in + padding_height;
      const int w_lo = w_in - (filter_width - 1) * dilate_width +
                       padding_width;
      const int w_hi = w_in + padding_width;

      T grad(0);
      for (int c_out = first_c_out; c_out < first_c_out + filter_multiplier;
           ++c_out) {
        // Begin at the last tap of this channel and walk backwards.
        int tap = (c_out + 1) * taps_per_channel - 1;
        for (int h_out = h_lo; h_out <= h_hi; h_out += dilate_height) {
          for (int w_out = w_lo; w_out <= w_hi;
               w_out += dilate_width, --tap) {
            if (h_out % stride_height != 0 || w_out % stride_width != 0) {
              continue;  // no output element maps onto this offset
            }
            const int oh = h_out / stride_height;
            const int ow = w_out / stride_width;
            if (oh < 0 || oh >= output_height || ow < 0 ||
                ow >= output_width) {
              continue;
            }
            const int og_index =
                ((batch * output_channels + c_out) * output_height + oh) *
                    output_width +
                ow;
            grad += output_grad_data[og_index] * filter_data[tap];
          }
        }
      }
      input_grad_data[index] = grad;
    }
  }
}

// Input gradient of the depthwise convolution in NHWC format, with the
// filter size read at run time.
//
// Work distribution: blockIdx.z selects the batch entry and the input row is
// blockIdx.x * dilate_height + blockIdx.y (blocks past the last row return
// immediately). threadIdx.x strides over input channels, threadIdx.y over
// input columns.
//
// Filter taps are consumed in reverse order and read at
// tap * output_channels + c_out. With fuse_relu_before_conv set, positions
// whose input is <= 0 get a zero gradient.
template <typename T, bool fuse_relu_before_conv>
__device__ __inline__ void KernelDepthwiseConvInputGradNHWC(
    ARG_DEFINE_KernelDepthwiseConvInputGrad) {
  const int batch = blockIdx.z;
  const int h_in = blockIdx.x * dilate_height + blockIdx.y;
  if (h_in >= input_height) {
    return;
  }

  const int last_tap = filter_height * filter_width - 1;
  const int h_lo = h_in - (filter_height - 1) * dilate_height + padding_height;

  for (int c_in = threadIdx.x; c_in < input_channels; c_in += blockDim.x) {
    for (int w_in = threadIdx.y; w_in < input_width; w_in += blockDim.y) {
      const int index =
          ((batch * input_height + h_in) * input_width + w_in) *
              input_channels +
          c_in;
      if (fuse_relu_before_conv) {
        if (input_data[index] <= T(0)) {
          input_grad_data[index] = 0;
          continue;
        }
      }

      const int w_lo =
          w_in - (filter_width - 1) * dilate_width + padding_width;

      T grad(0);
      for (int c_i = 0; c_i < filter_multiplier; ++c_i) {
        const int c_out = c_in * filter_multiplier + c_i;
        for (int h_f = 0, h_out = h_lo; h_f < filter_height;
             ++h_f, h_out += dilate_height) {
          for (int w_f = 0, w_out = w_lo; w_f < filter_width;
               ++w_f, w_out += dilate_width) {
            if (h_out % stride_height != 0 || w_out % stride_width != 0) {
              continue;  // no output element maps onto this offset
            }
            const int oh = h_out / stride_height;
            const int ow = w_out / stride_width;
            if (oh < 0 || oh >= output_height || ow < 0 ||
                ow >= output_width) {
              continue;
            }
            // Taps are visited back to front: position (h_f, w_f) of the walk
            // uses the mirrored filter entry.
            const int tap = last_tap - (h_f * filter_width + w_f);
            const int og_index =
                ((batch * output_height + oh) * output_width + ow) *
                    output_channels +
                c_out;
            grad += output_grad_data[og_index] *
                    filter_data[tap * output_channels + c_out];
          }
        }
      }
      input_grad_data[index] = grad;
    }
  }
}

H
hong 已提交
584 585 586
// Input gradient of the depthwise convolution in NCHW format for a square
// filter whose edge length c_filter is known at compile time.
//
// Work distribution: blockIdx.x selects the input channel and blockIdx.y the
// batch entry; threads stride over (w_in, h_in) via threadIdx.x / threadIdx.y.
// The input index is built with gridDim.x as the channel count (assumes
// gridDim.x equals input_channels -- TODO confirm at the launch site).
//
// The weights of every output channel fed by this input channel are first
// copied, reversed, into a local array sized from c_filter and
// c_filter_multiplier. The copy and accumulation loops are bounded by the
// runtime filter_multiplier (presumably equal to c_filter_multiplier on this
// path -- confirm in the caller).
template <typename T,
          int c_filter,
          int c_filter_multiplier,
          bool fuse_relu_before_conv>
__device__ __inline__ void KernelDepthwiseConvInputGradCFilterNCHW(
    ARG_DEFINE_KernelDepthwiseConvInputGrad) {
  const int kTaps = c_filter * c_filter;
  const int kWeightSize = kTaps * c_filter_multiplier + 1;
  T r_weight[kWeightSize];

  const int batch = blockIdx.y;
  const int c_in = blockIdx.x;

  // Stage the reversed filters, one kTaps-sized slice per output channel.
  for (int c_i = 0; c_i < filter_multiplier; ++c_i) {
    const int c_out = c_in * filter_multiplier + c_i;
    const T* weight = filter_data + c_out * kTaps;
    T* slice = r_weight + c_i * kTaps;
    for (int i = 0; i < kTaps; ++i) {
      slice[i] = weight[kTaps - i - 1];
    }
  }

  for (int w_in = threadIdx.x; w_in < input_width; w_in += blockDim.x) {
    for (int h_in = threadIdx.y; h_in < input_height; h_in += blockDim.y) {
      const int index =
          ((batch * gridDim.x + c_in) * input_height + h_in) * input_width +
          w_in;
      if (fuse_relu_before_conv) {
        if (input_data[index] <= T(0)) {
          input_grad_data[index] = 0;
          continue;
        }
      }

      const int h_lo = h_in - (c_filter - 1) * dilate_height + padding_height;
      const int w_lo = w_in - (c_filter - 1) * dilate_width + padding_width;

      T grad(0);
      for (int c_i = 0; c_i < filter_multiplier; ++c_i) {
        const int c_out = c_in * filter_multiplier + c_i;
        for (int h_f = 0, h_out = h_lo; h_f < c_filter;
             ++h_f, h_out += dilate_height) {
          for (int w_f = 0, w_out = w_lo; w_f < c_filter;
               ++w_f, w_out += dilate_width) {
            if (h_out % stride_height != 0 || w_out % stride_width != 0) {
              continue;  // no output element maps onto this offset
            }
            const int oh = h_out / stride_height;
            const int ow = w_out / stride_width;
            if (oh < 0 || oh >= output_height || ow < 0 ||
                ow >= output_width) {
              continue;
            }
            const int og_index =
                ((batch * output_channels + c_out) * output_height + oh) *
                    output_width +
                ow;
            grad += output_grad_data[og_index] *
                    r_weight[h_f * c_filter + w_f + c_i * kTaps];
          }
        }
      }
      input_grad_data[index] = grad;
    }
  }
}

H
hong 已提交
647 648 649
// Input gradient of the depthwise convolution in NHWC format for a square
// filter (edge c_filter) and a channel multiplier (c_filter_multiplier) that
// are both known at compile time.
//
// Work distribution: blockIdx.z selects the batch entry and the input row is
// blockIdx.x * dilate_height + blockIdx.y (blocks past the last row return
// immediately). threadIdx.x strides over input channels; threadIdx.y strides
// over input columns, which are visited grouped by their remainder modulo
// dilate_width.
//
// For each input channel the weights of its c_filter_multiplier output
// channels are copied, reversed, into a local array before the columns are
// processed.
template <typename T,
          int c_filter,
          int c_filter_multiplier,
          bool fuse_relu_before_conv>
__device__ __inline__ void KernelDepthwiseConvInputGradCFilterNHWC(
    ARG_DEFINE_KernelDepthwiseConvInputGrad) {
  const int h_in = blockIdx.x * dilate_height + blockIdx.y;
  if (h_in >= input_height) {
    return;
  }

  const int kTaps = c_filter * c_filter;
  const int kWeightSize = kTaps * c_filter_multiplier + 1;
  T r_weight[kWeightSize];

  const int batch = blockIdx.z;
  const int wi_size = (input_width + dilate_width - 1) / dilate_width;
  const int slot_count = wi_size * dilate_width;
  const int h_lo = h_in - (c_filter - 1) * dilate_height + padding_height;

  for (int c_in = threadIdx.x; c_in < input_channels; c_in += blockDim.x) {
    // Stage the reversed filters of every output channel fed by c_in.
    for (int c_i = 0; c_i < c_filter_multiplier; ++c_i) {
      const int c_out = c_in * c_filter_multiplier + c_i;
      for (int i = 0; i < kTaps; ++i) {
        r_weight[c_i * kTaps + i] =
            filter_data[(kTaps - i - 1) * output_channels + c_out];
      }
    }

    for (int slot = threadIdx.y; slot < slot_count; slot += blockDim.y) {
      const int phase = slot / wi_size;
      const int step = slot % wi_size;
      const int w_in = step * dilate_width + phase;
      if (w_in >= input_width) {
        continue;
      }

      const int index =
          ((batch * input_height + h_in) * input_width + w_in) *
              input_channels +
          c_in;
      if (fuse_relu_before_conv) {
        if (input_data[index] <= T(0)) {
          input_grad_data[index] = 0;
          continue;
        }
      }

      const int w_lo = w_in - (c_filter - 1) * dilate_width + padding_width;

      T grad(0);
      for (int c_i = 0; c_i < c_filter_multiplier; ++c_i) {
        const int c_out = c_in * c_filter_multiplier + c_i;
        for (int h_f = 0, h_out = h_lo; h_f < c_filter;
             ++h_f, h_out += dilate_height) {
          for (int w_f = 0, w_out = w_lo; w_f < c_filter;
               ++w_f, w_out += dilate_width) {
            if (h_out % stride_height != 0 || w_out % stride_width != 0) {
              continue;  // no output element maps onto this offset
            }
            const int oh = h_out / stride_height;
            const int ow = w_out / stride_width;
            if (oh < 0 || oh >= output_height || ow < 0 ||
                ow >= output_width) {
              continue;
            }
            const int og_index =
                ((batch * output_height + oh) * output_width + ow) *
                    output_channels +
                c_out;
            grad += output_grad_data[og_index] *
                    r_weight[h_f * c_filter + w_f + c_i * kTaps];
          }
        }
      }
      input_grad_data[index] = grad;
    }
  }
}

H
hong 已提交
719 720 721 722 723 724
// Entry kernel for the depthwise convolution backward pass w.r.t. the input.
//
// The parameter list is supplied by the ARG_DEFINE_KernelDepthwiseConvInputGrad
// macro, which is declared outside this block. The template arguments select
// the device routine that does the work:
//   * c_filter_multiplier == 0 or c_filter == -1: the generic NCHW / NHWC
//     routine. If c_filter_multiplier is non-zero, it and c_stride are used in
//     place of the runtime filter multiplier and strides.
//   * otherwise: the CFilter NCHW / NHWC routine, specialised on the
//     compile-time c_filter and c_filter_multiplier and called with c_stride
//     for both strides.
// data_layout picks the NHWC variant when it equals DataLayout::kNHWC and the
// NCHW variant for every other value.
template <typename T,
          int c_filter_multiplier,
          int c_stride,
          int c_filter,
          DataLayout data_layout,
          bool fuse_relu_before_conv>
__global__ void KernelDepthwiseConvInputGradSp(
    ARG_DEFINE_KernelDepthwiseConvInputGrad) {
  // Compile-time values win over the runtime ones whenever a non-zero
  // multiplier was baked into the instantiation.
  const int final_filter_multiplier =
      (c_filter_multiplier != 0) ? c_filter_multiplier : filter_multiplier;
  const int h_stride = (c_filter_multiplier != 0) ? c_stride : stride_height;
  const int w_stride = (c_filter_multiplier != 0) ? c_stride : stride_width;

  if (c_filter_multiplier == 0 || c_filter == -1) {
    if (data_layout != DataLayout::kNHWC) {
      KernelDepthwiseConvInputGradNCHW<T, fuse_relu_before_conv>(
          input_data,
          output_grad_data,
          filter_data,
          batch_size,
          output_channels,
          output_height,
          output_width,
          input_channels,
          input_height,
          input_width,
          final_filter_multiplier,
          filter_height,
          filter_width,
          h_stride,
          w_stride,
          padding_height,
          padding_width,
          dilate_height,
          dilate_width,
          input_grad_data);
    } else {
      KernelDepthwiseConvInputGradNHWC<T, fuse_relu_before_conv>(
          input_data,
          output_grad_data,
          filter_data,
          batch_size,
          output_channels,
          output_height,
          output_width,
          input_channels,
          input_height,
          input_width,
          final_filter_multiplier,
          filter_height,
          filter_width,
          h_stride,
          w_stride,
          padding_height,
          padding_width,
          dilate_height,
          dilate_width,
          input_grad_data);
    }
  } else {
    if (data_layout != DataLayout::kNHWC) {
      KernelDepthwiseConvInputGradCFilterNCHW<T,
                                              c_filter,
                                              c_filter_multiplier,
                                              fuse_relu_before_conv>(
          input_data,
          output_grad_data,
          filter_data,
          batch_size,
          output_channels,
          output_height,
          output_width,
          input_channels,
          input_height,
          input_width,
          c_filter_multiplier,
          filter_height,
          filter_width,
          c_stride,
          c_stride,
          padding_height,
          padding_width,
          dilate_height,
          dilate_width,
          input_grad_data);
    } else {
      KernelDepthwiseConvInputGradCFilterNHWC<T,
                                              c_filter,
                                              c_filter_multiplier,
                                              fuse_relu_before_conv>(
          input_data,
          output_grad_data,
          filter_data,
          batch_size,
          output_channels,
          output_height,
          output_width,
          input_channels,
          input_height,
          input_width,
          c_filter_multiplier,
          filter_height,
          filter_width,
          c_stride,
          c_stride,
          padding_height,
          padding_width,
          dilate_height,
          dilate_width,
          input_grad_data);
    }
  }
}

// CUDA kernel to compute the depthwise convolution backprop w.r.t. filter.
// Accumulates one filter-gradient element per thread block for NCHW tensors.
//
// Indexing, as read from the code below:
//   * blockIdx.z selects the output channel, and gridDim.z is used as the
//     number of output channels when indexing output_grad_data;
//   * blockIdx.y / blockIdx.x select the filter row / column;
//   * threadIdx.x strides over output columns and threadIdx.y over output
//     rows, while every thread walks all `num` batch entries.
// The per-thread partial sums are combined with BlockReduceSum, and thread
// (0, 0) stores the block total into filter_grad_data. When
// fuse_relu_before_conv is set, the input value is clamped at zero before it
// is multiplied with the output gradient.
//
// output_channels, input_channels, filter_height and filter_width are accepted
// for signature parity with the other filter-grad routines and are not read.
template <typename T, bool fuse_relu_before_conv>
__device__ __inline__ void KernelDepthwiseConvFilterGradNCHW(
    const T* output_grad_data,
    const T* input_data,
    const int num,
    const int output_channels,
    const int output_height,
    const int output_width,
    const int input_channels,
    const int input_height,
    const int input_width,
    const int filter_multiplier,
    const int filter_height,
    const int filter_width,
    const int stride_height,
    const int stride_width,
    const int padding_height,
    const int padding_width,
    const int dilate_height,
    const int dilate_width,
    T* filter_grad_data) {
  T s(0);
  // Flat index of the filter-gradient element owned by this block.
  const int gbid =
      ((blockIdx.z * gridDim.y) + blockIdx.y) * gridDim.x + blockIdx.x;

  // These values only depend on the block, so compute them once up front.
  const int kernel_id = blockIdx.z;
  const int kernel_h = blockIdx.y * dilate_height - padding_height;
  const int kernel_w = blockIdx.x * dilate_width - padding_width;
  const int input_channel_count = gridDim.z / filter_multiplier;
  const int input_channel = kernel_id / filter_multiplier;

  for (int image_w = threadIdx.x; image_w < output_width;
       image_w += blockDim.x) {
    // The input column is fixed for this output column; skip the whole
    // column when the filter tap falls outside the input.
    const int image_wk = image_w * stride_width + kernel_w;
    if (image_wk < 0 || image_wk >= input_width) continue;

    for (int bid = 0; bid < num; bid++) {
      for (int image_h = threadIdx.y; image_h < output_height;
           image_h += blockDim.y) {
        const int image_hk = image_h * stride_height + kernel_h;
        if (image_hk < 0 || image_hk >= input_height) continue;

        const int output_id =
            ((bid * gridDim.z + kernel_id) * output_height + image_h) *
                output_width +
            image_w;
        const int input_id =
            ((bid * input_channel_count + input_channel) * input_height +
             image_hk) *
                input_width +
            image_wk;
        if (fuse_relu_before_conv) {
          s += output_grad_data[output_id] *
               static_cast<T>(
                   max(0.0f, static_cast<double>(input_data[input_id])));
        } else {
          s += output_grad_data[output_id] * input_data[input_id];
        }
      }
    }
  }

  T val = BlockReduceSum(s);
  if (threadIdx.y == 0 && threadIdx.x == 0) filter_grad_data[gbid] = val;
}

// Accumulates filter gradients for NHWC tensors with a runtime filter size.
//
// Indexing, as read from the code below:
//   * blockIdx.z is the batch index and blockIdx.y the output row;
//   * blockIdx.x encodes the filter tap as kernel_ih * filter_width +
//     kernel_iw;
//   * threadIdx.x strides over output channels and threadIdx.y over output
//     columns.
// Each thread adds its partial sum into
// filter_grad_data[(channel * filter_height + kernel_ih) * filter_width +
// kernel_iw] with platform::CudaAtomicAdd, so results from different blocks
// and threads are summed into the same buffer. When fuse_relu_before_conv is
// set, the input value is clamped at zero before the multiply.
//
// `num` is accepted for signature parity and is not read here.
template <typename T, bool fuse_relu_before_conv>
__device__ __inline__ void KernelDepthwiseConvFilterGradNHWC(
    const T* output_grad_data,
    const T* input_data,
    const int num,
    const int output_channels,
    const int output_height,
    const int output_width,
    const int input_channels,
    const int input_height,
    const int input_width,
    const int filter_multiplier,
    const int filter_height,
    const int filter_width,
    const int stride_height,
    const int stride_width,
    const int padding_height,
    const int padding_width,
    const int dilate_height,
    const int dilate_width,
    T* filter_grad_data) {
  const int bid = blockIdx.z;
  const int image_h = blockIdx.y;
  const int kernel_iw = blockIdx.x % filter_width;
  const int kernel_ih = blockIdx.x / filter_width;

  // The filter offsets and the input row depend only on the block, so they
  // are computed once instead of once per output column.
  const int kernel_h = kernel_ih * dilate_height - padding_height;
  const int kernel_w = kernel_iw * dilate_width - padding_width;
  const int image_hk = image_h * stride_height + kernel_h;
  const bool row_in_range = image_hk >= 0 && image_hk < input_height;

  for (int kernel_id = threadIdx.x; kernel_id < output_channels;
       kernel_id += blockDim.x) {
    T s(0);
    const int gbid =
        ((kernel_id * filter_height) + kernel_ih) * filter_width + kernel_iw;
    const int input_channel = kernel_id / filter_multiplier;

    if (row_in_range) {
      for (int image_w = threadIdx.y; image_w < output_width;
           image_w += blockDim.y) {
        const int image_wk = image_w * stride_width + kernel_w;
        if (image_wk < 0 || image_wk >= input_width) continue;

        const int output_id =
            ((bid * output_height + image_h) * output_width + image_w) *
                output_channels +
            kernel_id;
        const int input_id =
            ((bid * input_height + image_hk) * input_width + image_wk) *
                input_channels +
            input_channel;
        if (fuse_relu_before_conv) {
          s += output_grad_data[output_id] *
               static_cast<T>(
                   max(0.0f, static_cast<double>(input_data[input_id])));
        } else {
          s += output_grad_data[output_id] * input_data[input_id];
        }
      }
    }
    // The add is issued even when nothing was accumulated, matching the
    // behaviour of the straightforward loop.
    platform::CudaAtomicAdd(&filter_grad_data[gbid], s);
  }
}

// Accumulates filter gradients for NHWC tensors with a compile-time square
// filter of c_filter x c_filter taps.
//
// Indexing, as read from the code below:
//   * blockIdx.z is the batch index;
//   * the output row is blockIdx.x * dilate_height + blockIdx.y, and blocks
//     whose row is past output_height exit immediately;
//   * threadIdx.x strides over output channels;
//   * threadIdx.y strides over a remapped column index i, where
//     image_w = (i % wi_size) * dilate_width + i / wi_size.
// Each thread keeps its c_filter * c_filter partial sums in the local array
// r_weight and then adds tap i into
// filter_grad_data[i * output_channels + channel] with
// platform::CudaAtomicAdd. When fuse_relu_before_conv is set, the input value
// is clamped at zero before the multiply.
//
// num, filter_height and filter_width are accepted for signature parity and
// are not read here.
template <typename T, int c_filter, bool fuse_relu_before_conv>
__device__ __inline__ void KernelDepthwiseConvFilterGradCFilterNHWC(
    const T* output_grad_data,
    const T* input_data,
    const int num,
    const int output_channels,
    const int output_height,
    const int output_width,
    const int input_channels,
    const int input_height,
    const int input_width,
    const int filter_multiplier,
    const int filter_height,
    const int filter_width,
    const int stride_height,
    const int stride_width,
    const int padding_height,
    const int padding_width,
    const int dilate_height,
    const int dilate_width,
    T* filter_grad_data) {
  const int bid = blockIdx.z;
  const int image_h = blockIdx.x * dilate_height + blockIdx.y;
  if (image_h >= output_height) {
    return;
  }
  const int kWeightSize = c_filter * c_filter;
  T r_weight[kWeightSize];
  const int wi_size = (output_width + dilate_width - 1) / dilate_width;

  for (int kernel_id = threadIdx.x; kernel_id < output_channels;
       kernel_id += blockDim.x) {
    for (int i = 0; i < kWeightSize; ++i) {
      r_weight[i] = 0;
    }
    const int input_channel = kernel_id / filter_multiplier;

    for (int i = threadIdx.y; i < wi_size * dilate_width; i += blockDim.y) {
      const int i_dw = i / wi_size;
      const int i_wi = i - i_dw * wi_size;
      const int image_w = i_wi * dilate_width + i_dw;
      if (image_w >= output_width) {
        continue;
      }
      // The output element is the same for every filter tap.
      const int output_id =
          ((bid * output_height + image_h) * output_width + image_w) *
              output_channels +
          kernel_id;

      for (int kernel_ih = 0; kernel_ih < c_filter; ++kernel_ih) {
        // The input row depends on the filter row only; reject it before
        // walking the filter columns.
        const int kernel_h = kernel_ih * dilate_height - padding_height;
        const int image_hk = image_h * stride_height + kernel_h;
        if (image_hk < 0 || image_hk >= input_height) continue;

        for (int kernel_iw = 0; kernel_iw < c_filter; ++kernel_iw) {
          const int kernel_w = kernel_iw * dilate_width - padding_width;
          const int image_wk = image_w * stride_width + kernel_w;
          if (image_wk < 0 || image_wk >= input_width) continue;

          const int input_id =
              ((bid * input_height + image_hk) * input_width + image_wk) *
                  input_channels +
              input_channel;
          T s(0);
          if (fuse_relu_before_conv) {
            s = output_grad_data[output_id] *
                static_cast<T>(
                    max(0.0f, static_cast<double>(input_data[input_id])));
          } else {
            s = output_grad_data[output_id] * input_data[input_id];
          }
          r_weight[kernel_ih * c_filter + kernel_iw] += s;
        }
      }
    }

    for (int i = 0; i < kWeightSize; ++i) {
      T* weight = filter_grad_data + i * output_channels + kernel_id;
      platform::CudaAtomicAdd(&weight[0], r_weight[i]);
    }
  }
}

H
hong 已提交
1036 1037 1038 1039 1040 1041 1042 1043 1044 1045 1046 1047 1048 1049 1050 1051 1052 1053 1054 1055 1056 1057 1058 1059 1060
// Entry kernel for the depthwise convolution backward pass w.r.t. the filter.
//
// When c_filter_multiplier is non-zero, it and c_stride replace the runtime
// filter multiplier and strides. The device routine is then chosen as
// follows:
//   * data_layout != DataLayout::kNHWC: KernelDepthwiseConvFilterGradNCHW,
//     regardless of the other template arguments;
//   * NHWC with c_filter_multiplier == 0 or c_filter == -1:
//     KernelDepthwiseConvFilterGradNHWC (runtime filter size);
//   * NHWC otherwise: KernelDepthwiseConvFilterGradCFilterNHWC, specialised
//     on the compile-time c_filter.
template <typename T,
          int c_filter_multiplier,
          int c_stride,
          int c_filter,
          DataLayout data_layout,
          bool fuse_relu_before_conv>
__global__ void KernelDepthwiseConvFilterGradSp(const T* output_grad_data,
                                                const T* input_data,
                                                const int num,
                                                const int output_channels,
                                                const int output_height,
                                                const int output_width,
                                                const int input_channels,
                                                const int input_height,
                                                const int input_width,
                                                const int filter_multiplier,
                                                const int filter_height,
                                                const int filter_width,
                                                const int stride_height,
                                                const int stride_width,
                                                const int padding_height,
                                                const int padding_width,
                                                const int dilate_height,
                                                const int dilate_width,
                                                T* filter_grad_data) {
  const int final_filter_multiplier =
      (c_filter_multiplier != 0) ? c_filter_multiplier : filter_multiplier;
  const int h_stride = (c_filter_multiplier != 0) ? c_stride : stride_height;
  const int w_stride = (c_filter_multiplier != 0) ? c_stride : stride_width;

  if (data_layout != DataLayout::kNHWC) {
    // There is a single NCHW routine; it serves both the generic and the
    // specialised instantiations.
    KernelDepthwiseConvFilterGradNCHW<T, fuse_relu_before_conv>(
        output_grad_data,
        input_data,
        num,
        output_channels,
        output_height,
        output_width,
        input_channels,
        input_height,
        input_width,
        final_filter_multiplier,
        filter_height,
        filter_width,
        h_stride,
        w_stride,
        padding_height,
        padding_width,
        dilate_height,
        dilate_width,
        filter_grad_data);
  } else if (c_filter_multiplier == 0 || c_filter == -1) {
    KernelDepthwiseConvFilterGradNHWC<T, fuse_relu_before_conv>(
        output_grad_data,
        input_data,
        num,
        output_channels,
        output_height,
        output_width,
        input_channels,
        input_height,
        input_width,
        final_filter_multiplier,
        filter_height,
        filter_width,
        h_stride,
        w_stride,
        padding_height,
        padding_width,
        dilate_height,
        dilate_width,
        filter_grad_data);
  } else {
    KernelDepthwiseConvFilterGradCFilterNHWC<T,
                                             c_filter,
                                             fuse_relu_before_conv>(
        output_grad_data,
        input_data,
        num,
        output_channels,
        output_height,
        output_width,
        input_channels,
        input_height,
        input_width,
        final_filter_multiplier,
        filter_height,
        filter_width,
        h_stride,
        w_stride,
        padding_height,
        padding_width,
        dilate_height,
        dilate_width,
        filter_grad_data);
  }
}

/*
 * Tensors are read as NCHW unless data_layout is DataLayout::kNHWC, in which
 * case they are read as NHWC. Strides, paddings and dilations each hold two
 * elements: the height value followed by the width value.
 */
// GPU specialisation of the depthwise convolution forward functor.
//
// operator() reads the tensor extents according to data_layout, transposes
// the filter with permutation {2, 3, 0, 1} into a temporary tensor for NHWC,
// picks a launch configuration and launches the first KernelDepthwiseConvSp
// instantiation whose (filter multiplier, stride, filter size) triple matches
// the runtime values. check_case(0, 0, -1) is the catch-all.
template <class T, bool fuse_relu_before_conv>
class DepthwiseConvFunctor<phi::GPUContext, T, fuse_relu_before_conv> {
 public:
  void operator()(const phi::GPUContext& context,
                  const phi::DenseTensor& input,
                  const phi::DenseTensor& filter,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings,
                  const std::vector<int>& dilations,
                  phi::DenseTensor* output,
                  const DataLayout data_layout = DataLayout::kNCHW) {
    const int batch_size = input.dims()[0];
    const int input_channels =
        (data_layout != DataLayout::kNHWC ? input.dims()[1] : input.dims()[3]);
    const int input_height =
        (data_layout != DataLayout::kNHWC ? input.dims()[2] : input.dims()[1]);
    const int input_width =
        (data_layout != DataLayout::kNHWC ? input.dims()[3] : input.dims()[2]);
    const int output_channels =
        (data_layout != DataLayout::kNHWC ? output->dims()[1]
                                          : output->dims()[3]);
    const int output_height =
        (data_layout != DataLayout::kNHWC ? output->dims()[2]
                                          : output->dims()[1]);
    const int output_width =
        (data_layout != DataLayout::kNHWC ? output->dims()[3]
                                          : output->dims()[2]);
    const int ksize_height = filter.dims()[2];
    const int ksize_width = filter.dims()[3];
    const int stride_height = strides[0];
    const int stride_width = strides[1];
    const int padding_height = paddings[0];
    const int padding_width = paddings[1];
    const int dilate_height = dilations[0];
    const int dilate_width = dilations[1];

    const T* input_data = input.data<T>();
    const T* filter_data = filter.data<T>();
    T* output_data = output->mutable_data<T>(context.GetPlace());

    // For NHWC the kernels read the filter from a transposed copy whose dims
    // are (filter dim 2, dim 3, dim 0, dim 1).
    phi::DenseTensor filter_hwc;
    if (data_layout == DataLayout::kNHWC) {
      framework::DDim filter_hwc_dims({filter.dims()[2],
                                       filter.dims()[3],
                                       filter.dims()[0],
                                       filter.dims()[1]});
      filter_hwc.Resize(filter_hwc_dims);
      filter_hwc.mutable_data<T>(context.GetPlace());
      std::vector<int> perm_axis({2, 3, 0, 1});
      phi::funcs::TransposeNormal<phi::GPUContext, T> trans;
      trans(context, filter, &filter_hwc, perm_axis);
      filter_data = filter_hwc.data<T>();
    }

    int thread = 512;
    int blocks;
    dim3 threads;
    dim3 grid;

    if (data_layout != DataLayout::kNHWC) {
      // NCHW: threads.x covers output columns, threads.y output rows; one
      // block per (output channel, batch entry).
      if (output_width > 1024 && output_width <= 2048) {
        thread = (output_width - 1) / 2 + 1;
      } else if (output_width > 512 && output_width <= 1024) {
        thread = output_width;
      }
#ifdef __HIPCC__
      thread = std::min(thread, 256);
#endif
      blocks = std::min(std::max(thread / output_width, 1), output_height);
      threads = dim3(std::min(output_width, thread), blocks, 1);
      grid = dim3(output_channels, batch_size, 1);
    } else {
      // NHWC: threads.x covers output channels; the grid is
      // (ceil(output_height / dilate_height), dilate_height, batch_size).
#ifdef __HIPCC__
      thread = std::min(thread, 256);
#endif
      blocks = std::min(
          std::max(thread / output_channels, 1),
          ((output_width + dilate_width - 1) / dilate_width) * dilate_width);
      threads = dim3(std::min(output_channels, thread), blocks, 1);
      grid = dim3((output_height + dilate_height - 1) / dilate_height,
                  dilate_height,
                  batch_size);
    }
    int filter_multiplier = output_channels / input_channels;
    int nums_output =
        batch_size * output_channels * output_height * output_width;
#ifdef __HIPCC__
    int block_size = 256;
#else
    int block_size = 512;
#endif
    int grid_size = (nums_output + block_size - 1) / block_size;

// check_case launches KernelDepthwiseConvSp<c_filter_multiplier, c_stride,
// c_filter> and returns when the runtime configuration matches the triple.
// A c_filter of -1 accepts any filter size and switches to a 1-D launch of
// grid_size blocks with block_size threads each.
#define check_case(c_filter_multiplier, c_stride, c_filter)             \
  if (c_filter_multiplier == 0 ||                                       \
      (filter_multiplier == c_filter_multiplier &&                      \
       stride_height == stride_width && stride_height == c_stride &&    \
       ((ksize_height == ksize_width && ksize_height == c_filter) ||    \
        c_filter == -1))) {                                             \
    if (c_filter == -1) {                                               \
      threads.x = block_size;                                           \
      grid.x = grid_size;                                               \
      threads.y = threads.z = grid.y = grid.z = 1;                      \
    }                                                                   \
    if (data_layout != DataLayout::kNHWC) {                             \
      KernelDepthwiseConvSp<T,                                          \
                            c_filter_multiplier,                        \
                            c_stride,                                   \
                            c_filter,                                   \
                            DataLayout::kNCHW,                          \
                            fuse_relu_before_conv>                      \
          <<<grid, threads, 0, context.stream()>>>(input_data,          \
                                                   filter_data,         \
                                                   batch_size,          \
                                                   output_channels,     \
                                                   output_height,       \
                                                   output_width,        \
                                                   input_channels,      \
                                                   input_height,        \
                                                   input_width,         \
                                                   filter_multiplier,   \
                                                   ksize_height,        \
                                                   ksize_width,         \
                                                   stride_height,       \
                                                   stride_width,        \
                                                   padding_height,      \
                                                   padding_width,       \
                                                   dilate_height,       \
                                                   dilate_width,        \
                                                   output_data);        \
    } else {                                                            \
      KernelDepthwiseConvSp<T,                                          \
                            c_filter_multiplier,                        \
                            c_stride,                                   \
                            c_filter,                                   \
                            DataLayout::kNHWC,                          \
                            fuse_relu_before_conv>                      \
          <<<grid, threads, 0, context.stream()>>>(input_data,          \
                                                   filter_data,         \
                                                   batch_size,          \
                                                   output_channels,     \
                                                   output_height,       \
                                                   output_width,        \
                                                   input_channels,      \
                                                   input_height,        \
                                                   input_width,         \
                                                   filter_multiplier,   \
                                                   ksize_height,        \
                                                   ksize_width,         \
                                                   stride_height,       \
                                                   stride_width,        \
                                                   padding_height,      \
                                                   padding_width,       \
                                                   dilate_height,       \
                                                   dilate_width,        \
                                                   output_data);        \
    }                                                                   \
    return;                                                             \
  }
    check_case(1, 1, 3);
    check_case(1, 1, 5);
    check_case(1, 1, -1);
    check_case(1, 2, 3);
    check_case(1, 2, 5);
    check_case(1, 2, -1);
    check_case(2, 1, 3);
    check_case(2, 1, 5);
    check_case(2, 1, -1);
    check_case(2, 2, 3);
    check_case(2, 2, 5);
    check_case(2, 2, -1);
    check_case(0, 0, -1);
// NOTE(liangdun): (0, 0) covers every other case.
// Add more cases if needed, e.g. check_case(2^n, 1).
#undef check_case
  }
};

1344
template <typename T, bool fuse_relu_before_conv>
H
hong 已提交
1345
class DepthwiseConvInputGradFunctor<phi::GPUContext, T, fuse_relu_before_conv> {
Z
zlx 已提交
1346
 public:
H
hong 已提交
1347
  void operator()(const phi::GPUContext& context,
1348 1349 1350
                  const phi::DenseTensor& input,
                  const phi::DenseTensor& filter,
                  const phi::DenseTensor& output_grad,
X
xzl 已提交
1351 1352
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings,
1353
                  const std::vector<int>& dilations,
1354
                  phi::DenseTensor* input_grad,
1355
                  const DataLayout data_layout = DataLayout::kNCHW) {
Z
zlx 已提交
1356
    const int batch_size = input.dims()[0];
1357
    const int input_channels =
1358
        (data_layout != DataLayout::kNHWC ? input.dims()[1] : input.dims()[3]);
1359
    const int input_height =
1360
        (data_layout != DataLayout::kNHWC ? input.dims()[2] : input.dims()[1]);
1361
    const int input_width =
1362
        (data_layout != DataLayout::kNHWC ? input.dims()[3] : input.dims()[2]);
1363
    const int output_channels =
1364
        (data_layout != DataLayout::kNHWC ? output_grad.dims()[1]
1365 1366
                                          : output_grad.dims()[3]);
    const int output_height =
1367
        (data_layout != DataLayout::kNHWC ? output_grad.dims()[2]
1368 1369
                                          : output_grad.dims()[1]);
    const int output_width =
1370
        (data_layout != DataLayout::kNHWC ? output_grad.dims()[3]
1371
                                          : output_grad.dims()[2]);
1372 1373 1374
    const int ksize_height = filter.dims()[2];
    const int ksize_width = filter.dims()[3];
    const int stride_height = strides[0];
Z
zlx 已提交
1375 1376 1377
    const int stride_width = strides[1];
    const int padding_height = paddings[0];
    const int padding_width = paddings[1];
1378 1379
    const int dilate_height = dilations[0];
    const int dilate_width = dilations[1];
Z
zlx 已提交
1380

1381
    const T* input_data = input.data<T>();
1382
    const T* filter_data = filter.data<T>();
Z
zlx 已提交
1383 1384 1385
    const T* output_grad_data = output_grad.data<T>();
    T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace());

1386
    phi::DenseTensor filter_hwc;
1387
    if (data_layout == DataLayout::kNHWC) {
H
hong 已提交
1388 1389 1390 1391
      framework::DDim filter_hwc_dims({filter.dims()[2],
                                       filter.dims()[3],
                                       filter.dims()[0],
                                       filter.dims()[1]});
1392 1393 1394
      filter_hwc.Resize(filter_hwc_dims);
      filter_hwc.mutable_data<T>(context.GetPlace());
      std::vector<int> perm_axis({2, 3, 0, 1});
H
hong 已提交
1395
      phi::funcs::TransposeNormal<phi::GPUContext, T> trans;
1396 1397 1398 1399
      trans(context, filter, &filter_hwc, perm_axis);
      filter_data = filter_hwc.data<T>();
    }

1400
    int thread = 512;
1401 1402 1403
    int blocks;
    dim3 threads;
    dim3 grid;
W
wangguanzhong 已提交
1404

1405 1406 1407 1408 1409 1410 1411 1412 1413 1414 1415 1416 1417 1418 1419
    if (data_layout != DataLayout::kNHWC) {
      if (input_width > 1024 && input_width <= 2048) {
        thread = (input_width - 1) / 2 + 1;
      } else if (input_width > 512 && input_width <= 1024) {
        thread = input_width;
      }
      blocks = std::min(std::max(thread / input_width, 1), input_height);
      threads = dim3(std::min(input_width, thread), blocks, 1);
      grid = dim3(input_channels, batch_size, 1);
    } else {
      blocks = std::min(
          std::max(thread / input_channels, 1),
          ((input_width + dilate_width - 1) / dilate_width) * dilate_width);
      threads = dim3(std::min(input_channels, thread), blocks, 1);
      grid = dim3((input_height + dilate_height - 1) / dilate_height,
H
hong 已提交
1420 1421
                  dilate_height,
                  batch_size);
1422
    }
1423 1424
    int filter_multiplier = output_channels / input_channels;

1425 1426 1427 1428 1429 1430 1431 1432 1433 1434 1435 1436 1437 1438 1439 1440 1441 1442 1443 1444 1445 1446 1447 1448 1449 1450 1451 1452 1453 1454 1455 1456 1457 1458 1459 1460 1461 1462 1463 1464 1465 1466 1467 1468 1469 1470 1471 1472 1473 1474 1475 1476 1477 1478 1479 1480 1481 1482 1483 1484 1485 1486
#define check_case(c_filter_multiplier, c_stride, c_filter)             \
  if (c_filter_multiplier == 0 ||                                       \
      filter_multiplier == c_filter_multiplier &&                       \
          stride_height == stride_width && stride_height == c_stride && \
          (ksize_height == ksize_width && ksize_height == c_filter ||   \
           c_filter == -1)) {                                           \
    if (data_layout != DataLayout::kNHWC) {                             \
      KernelDepthwiseConvInputGradSp<T,                                 \
                                     c_filter_multiplier,               \
                                     c_stride,                          \
                                     c_filter,                          \
                                     DataLayout::kNCHW,                 \
                                     fuse_relu_before_conv>             \
          <<<grid, threads, 0, context.stream()>>>(input_data,          \
                                                   output_grad_data,    \
                                                   filter_data,         \
                                                   batch_size,          \
                                                   output_channels,     \
                                                   output_height,       \
                                                   output_width,        \
                                                   input_channels,      \
                                                   input_height,        \
                                                   input_width,         \
                                                   filter_multiplier,   \
                                                   ksize_height,        \
                                                   ksize_width,         \
                                                   stride_height,       \
                                                   stride_width,        \
                                                   padding_height,      \
                                                   padding_width,       \
                                                   dilate_height,       \
                                                   dilate_width,        \
                                                   input_grad_data);    \
    } else {                                                            \
      KernelDepthwiseConvInputGradSp<T,                                 \
                                     c_filter_multiplier,               \
                                     c_stride,                          \
                                     c_filter,                          \
                                     DataLayout::kNHWC,                 \
                                     fuse_relu_before_conv>             \
          <<<grid, threads, 0, context.stream()>>>(input_data,          \
                                                   output_grad_data,    \
                                                   filter_data,         \
                                                   batch_size,          \
                                                   output_channels,     \
                                                   output_height,       \
                                                   output_width,        \
                                                   input_channels,      \
                                                   input_height,        \
                                                   input_width,         \
                                                   filter_multiplier,   \
                                                   ksize_height,        \
                                                   ksize_width,         \
                                                   stride_height,       \
                                                   stride_width,        \
                                                   padding_height,      \
                                                   padding_width,       \
                                                   dilate_height,       \
                                                   dilate_width,        \
                                                   input_grad_data);    \
    }                                                                   \
    return;                                                             \
1487
  }
1488 1489 1490 1491 1492 1493 1494 1495 1496 1497 1498 1499 1500 1501 1502
    check_case(1, 1, 3);
    check_case(1, 1, 5);
    check_case(1, 1, -1);
    check_case(1, 2, 3);
    check_case(1, 2, 5);
    check_case(1, 2, -1);
    check_case(2, 1, 3);
    check_case(2, 1, 5);
    check_case(2, 1, -1);
    check_case(2, 2, 3);
    check_case(2, 2, 5);
    check_case(2, 2, -1);
    check_case(0, 0, -1);
// NOTE(liangdun): check_case(0, 0, -1) is the catch-all for every other case.
// Add more specialized cases above it if needed, e.g. check_case(2^n, 1, ...).
1503
#undef check_case
Z
zlx 已提交
1504 1505 1506
  }
};

1507
/*
 * \brief GPU specialization of the depthwise-convolution filter-gradient
 * functor.
 *
 * operator() reads the problem sizes from the tensor dims, chooses a launch
 * configuration and dispatches to the KernelDepthwiseConvFilterGradSp
 * instantiation that matches (filter_multiplier, stride, filter size).
 */
template <typename T, bool fuse_relu_before_conv>
class DepthwiseConvFilterGradFunctor<phi::GPUContext,
                                     T,
                                     fuse_relu_before_conv> {
 public:
  /*
   * \param context      Provides the stream the kernel is launched on and the
   *                     place used for allocations.
   * \param input        Forward input. dims()[0] is the batch size; channels,
   *                     height and width are read from dims()[1], [2], [3],
   *                     or from dims()[3], [1], [2] when data_layout is kNHWC.
   * \param output_grad  Gradient w.r.t. the forward output; indexed the same
   *                     way as input.
   * \param strides      {stride_height, stride_width}.
   * \param paddings     {padding_height, padding_width}.
   * \param dilations    {dilate_height, dilate_width}.
   * \param filter_grad  Output tensor. dims()[2] and dims()[3] are read as the
   *                     kernel height and width.
   * \param data_layout  kNHWC selects the channel-last path; every other
   *                     value is handled as NCHW.
   */
  void operator()(const phi::GPUContext& context,
                  const phi::DenseTensor& input,
                  const phi::DenseTensor& output_grad,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings,
                  const std::vector<int>& dilations,
                  phi::DenseTensor* filter_grad,
                  const DataLayout data_layout = DataLayout::kNCHW) {
    const int batch_size = input.dims()[0];
    const int input_channels =
        (data_layout != DataLayout::kNHWC ? input.dims()[1] : input.dims()[3]);
    const int input_height =
        (data_layout != DataLayout::kNHWC ? input.dims()[2] : input.dims()[1]);
    const int input_width =
        (data_layout != DataLayout::kNHWC ? input.dims()[3] : input.dims()[2]);
    const int output_channels =
        (data_layout != DataLayout::kNHWC ? output_grad.dims()[1]
                                          : output_grad.dims()[3]);
    const int output_height =
        (data_layout != DataLayout::kNHWC ? output_grad.dims()[2]
                                          : output_grad.dims()[1]);
    const int output_width =
        (data_layout != DataLayout::kNHWC ? output_grad.dims()[3]
                                          : output_grad.dims()[2]);
    const int ksize_height = filter_grad->dims()[2];
    const int ksize_width = filter_grad->dims()[3];
    const int stride_height = strides[0];
    const int stride_width = strides[1];
    const int padding_height = paddings[0];
    const int padding_width = paddings[1];
    const int dilate_height = dilations[0];
    const int dilate_width = dilations[1];

    const T* input_data = input.data<T>();
    const T* output_grad_data = output_grad.data<T>();
    // Not const: the NHWC fixed-filter-size path below redirects this pointer
    // to a scratch tensor.
    T* filter_grad_data = filter_grad->mutable_data<T>(context.GetPlace());

    // Default launch configuration. block_size starts at 512 and, on the NCHW
    // path, is widened when output_width lies in (512, 2048].
    int block_size = 512;
    int blocks;
    dim3 threads;
    dim3 grid;
    if (data_layout != DataLayout::kNHWC) {
      if (output_width > 1024 && output_width <= 2048) {
        block_size = (output_width - 1) / 2 + 1;
      } else if (output_width > 512 && output_width <= 1024) {
        block_size = output_width;
      }
      blocks = std::min(std::max(block_size / output_width, 1), output_height);
      grid = dim3(ksize_width, ksize_height, output_channels);
      threads = dim3(std::min(output_width, block_size), blocks, 1);
    } else {
      blocks = std::min(
          std::max(block_size / output_channels, 1),
          ((output_width + dilate_width - 1) / dilate_width) * dilate_width);
      grid = dim3((output_height + dilate_height - 1) / dilate_height,
                  dilate_height,
                  batch_size);
      threads = dim3(std::min(output_channels, block_size), blocks, 1);
    }
    int filter_multiplier = output_channels / input_channels;

// check_case(m, s, k) launches the kernel instantiated with the compile-time
// triple <m, s, k> and returns from operator() when the runtime configuration
// matches:
//   - filter_multiplier == m,
//   - stride_height == stride_width == s,
//   - ksize_height == ksize_width == k, or k == -1 (any filter size).
// m == 0 matches unconditionally, so check_case(0, 0, -1) is the fallback.
// The && chain is parenthesised explicitly; && already binds tighter than ||,
// so the grouping is the same as without the parentheses.
//
// NHWC path:
//   - k != -1: the kernel is handed a zero-filled scratch tensor whose dims
//     are filter_grad's dims permuted by {2, 3, 0, 1}; afterwards the scratch
//     is transposed into filter_grad with the same permutation.
//   - k == -1: block_size / blocks / grid / threads are recomputed from
//     output_channels, and the kernel receives filter_grad's buffer directly.
// The macro overwrites the locals above; this is safe only because every
// matching case ends with `return`.
// NOTE: only block comments may be placed inside the macro body, since a `//`
// comment would swallow the line continuation.
#define check_case(c_filter_multiplier, c_stride, c_filter)                    \
  if (c_filter_multiplier == 0 ||                                              \
      (filter_multiplier == c_filter_multiplier &&                             \
       stride_height == stride_width && stride_height == c_stride &&           \
       ((ksize_height == ksize_width && ksize_height == c_filter) ||           \
        c_filter == -1))) {                                                    \
    if (data_layout != DataLayout::kNHWC) {                                    \
      KernelDepthwiseConvFilterGradSp<T,                                       \
                                      c_filter_multiplier,                     \
                                      c_stride,                                \
                                      c_filter,                                \
                                      DataLayout::kNCHW,                       \
                                      fuse_relu_before_conv>                   \
          <<<grid, threads, 0, context.stream()>>>(output_grad_data,           \
                                                   input_data,                 \
                                                   batch_size,                 \
                                                   output_channels,            \
                                                   output_height,              \
                                                   output_width,               \
                                                   input_channels,             \
                                                   input_height,               \
                                                   input_width,                \
                                                   filter_multiplier,          \
                                                   ksize_height,               \
                                                   ksize_width,                \
                                                   stride_height,              \
                                                   stride_width,               \
                                                   padding_height,             \
                                                   padding_width,              \
                                                   dilate_height,              \
                                                   dilate_width,               \
                                                   filter_grad_data);          \
    } else {                                                                   \
      phi::DenseTensor filter_grad_hwc;                                        \
      if (c_filter != -1) {                                                    \
        framework::DDim filter_grad_hwc_dims({filter_grad->dims()[2],          \
                                              filter_grad->dims()[3],          \
                                              filter_grad->dims()[0],          \
                                              filter_grad->dims()[1]});        \
        filter_grad_hwc.Resize(filter_grad_hwc_dims);                          \
        filter_grad_hwc.mutable_data<T>(context.GetPlace());                   \
        phi::funcs::SetConstant<phi::GPUContext, T> set_zero;                  \
        set_zero(context, &filter_grad_hwc, static_cast<T>(0));                \
        filter_grad_data = filter_grad_hwc.data<T>();                          \
      } else {                                                                 \
        block_size = 512;                                                      \
        if (output_channels > 1024 && output_channels <= 2048) {               \
          block_size = (output_channels - 1) / 2 + 1;                          \
        } else if (output_channels > 512 && output_channels <= 1024) {         \
          block_size = output_channels;                                        \
        }                                                                      \
        blocks =                                                               \
            std::min(std::max(block_size / output_channels, 1), output_width); \
        grid = dim3(ksize_width * ksize_height, output_height, batch_size);    \
        threads = dim3(std::min(output_channels, block_size), blocks, 1);      \
      }                                                                        \
      KernelDepthwiseConvFilterGradSp<T,                                       \
                                      c_filter_multiplier,                     \
                                      c_stride,                                \
                                      c_filter,                                \
                                      DataLayout::kNHWC,                       \
                                      fuse_relu_before_conv>                   \
          <<<grid, threads, 0, context.stream()>>>(output_grad_data,           \
                                                   input_data,                 \
                                                   batch_size,                 \
                                                   output_channels,            \
                                                   output_height,              \
                                                   output_width,               \
                                                   input_channels,             \
                                                   input_height,               \
                                                   input_width,                \
                                                   filter_multiplier,          \
                                                   ksize_height,               \
                                                   ksize_width,                \
                                                   stride_height,              \
                                                   stride_width,               \
                                                   padding_height,             \
                                                   padding_width,              \
                                                   dilate_height,              \
                                                   dilate_width,               \
                                                   filter_grad_data);          \
      if (c_filter != -1) {                                                    \
        std::vector<int> perm_axis({2, 3, 0, 1});                              \
        phi::funcs::TransposeNormal<phi::GPUContext, T> trans;                 \
        trans(context, filter_grad_hwc, filter_grad, perm_axis);               \
      }                                                                        \
    }                                                                          \
    return;                                                                    \
  }
    // Most specific cases first; the first match launches and returns.
    check_case(1, 1, 3);
    check_case(1, 1, 5);
    check_case(1, 1, -1);
    check_case(1, 2, 3);
    check_case(1, 2, 5);
    check_case(1, 2, -1);
    check_case(2, 1, 3);
    check_case(2, 1, 5);
    check_case(2, 1, -1);
    check_case(2, 2, 3);
    check_case(2, 2, 5);
    check_case(2, 2, -1);
    // Generic fallback: always matches.
    check_case(0, 0, -1);
#undef check_case
  }
};

H
hong 已提交
1679 1680
// Explicit instantiations of the GPU functors for float, double and float16.

// fuse_relu_before_conv = false
template class DepthwiseConvFunctor<phi::GPUContext, float, false>;
template class DepthwiseConvFunctor<phi::GPUContext, double, false>;
template class DepthwiseConvFunctor<phi::GPUContext, platform::float16, false>;

template class DepthwiseConvInputGradFunctor<phi::GPUContext, float, false>;
template class DepthwiseConvInputGradFunctor<phi::GPUContext, double, false>;
template class DepthwiseConvInputGradFunctor<phi::GPUContext,
                                             platform::float16,
                                             false>;

template class DepthwiseConvFilterGradFunctor<phi::GPUContext, float, false>;
template class DepthwiseConvFilterGradFunctor<phi::GPUContext, double, false>;
template class DepthwiseConvFilterGradFunctor<phi::GPUContext,
                                              platform::float16,
                                              false>;

// fuse_relu_before_conv = true
template class DepthwiseConvFunctor<phi::GPUContext, float, true>;
template class DepthwiseConvFunctor<phi::GPUContext, double, true>;
template class DepthwiseConvFunctor<phi::GPUContext, platform::float16, true>;

template class DepthwiseConvInputGradFunctor<phi::GPUContext, float, true>;
template class DepthwiseConvInputGradFunctor<phi::GPUContext, double, true>;
template class DepthwiseConvInputGradFunctor<phi::GPUContext,
                                             platform::float16,
                                             true>;

template class DepthwiseConvFilterGradFunctor<phi::GPUContext, float, true>;
template class DepthwiseConvFilterGradFunctor<phi::GPUContext, double, true>;
template class DepthwiseConvFilterGradFunctor<phi::GPUContext,
                                              platform::float16,
                                              true>;
Z
zlx 已提交
1710 1711 1712 1713

}  // namespace math
}  // namespace operators
}  // namespace paddle