/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <algorithm>
#include <vector>
#include "cub/cub.cuh"
#include "paddle/fluid/operators/math/depthwise_conv.h"
#include "paddle/fluid/platform/cuda_device_function.h"
#include "paddle/fluid/platform/cuda_primitives.h"

namespace paddle {
namespace operators {
namespace math {

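// Reduce `value` across the 32 lanes of a warp with a CUB warp reduction,
// then let lane 0 issue a single atomicAdd, trading 32 global atomics for
// one. Used below when accumulating filter gradients.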
template <typename T>
__device__ __inline__ void CudaAtomicAddWithWarp(T* sum, T value) {
  typedef cub::WarpReduce<T> WarpReduce;
  typename WarpReduce::TempStorage temp_storage;
  value = WarpReduce(temp_storage).Sum(value);
  if (cub::LaneId() == 0) platform::CudaAtomicAdd(sum, value);
}

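// The forward kernels below share this long parameter list, so it is written
// once as a macro rather than repeated for every variant.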
#define ARG_DEFINE_KernelDepthwiseConv                                         \
  const T *const input_data, const T *const filter_data, const int batch_size, \
      const int output_channels, const int output_height,                      \
      const int output_width, const int input_channels,                        \
      const int input_height, const int input_width,                           \
      const int filter_multiplier, const int filter_height,                    \
      const int filter_width, const int stride_height, const int stride_width, \
      const int padding_height, const int padding_width,                       \
      const int dilate_height, const int dilate_width, T *const output_data,   \
      const DataLayout data_layout = DataLayout::kNCHW

// A CUDA kernel to compute the depthwise convolution forward pass
// in NCHW format.
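// Each thread computes exactly one output element; the flattened thread
// index decodes as (batch, c_out, h_out, w_out), with w_out innermost.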
template <typename T, bool fuse_relu_before_conv>
__device__ __inline__ void KernelDepthwiseConvNCHW(
    ARG_DEFINE_KernelDepthwiseConv) {
  int idx = threadIdx.x + blockIdx.x * blockDim.x;
  if (idx >= (output_channels * batch_size * output_height * output_width))
    return;

  const int w_out = idx % output_width;
  const int h_out = (idx / output_width) % output_height;
  const int c_out = (idx / output_width / output_height) % output_channels;
  const int batch = idx / output_width / output_height / output_channels;

  const int c_in = c_out / filter_multiplier;
  const T* weight = filter_data + c_out * filter_height * filter_width;
  T value = 0;
  const int h_in_start = -padding_height + h_out * stride_height;
  const int w_in_start = -padding_width + w_out * stride_width;
  const int h_in_end = h_in_start + filter_height * dilate_height;
  const int w_in_end = w_in_start + filter_width * dilate_width;

  int in_offset =
      ((batch * input_channels + c_in) * input_height) * input_width;

  const int h_end = h_in_end < input_height ? h_in_end : input_height;
  const int w_end = w_in_end < input_width ? w_in_end : input_width;
  const int h_start = h_in_start > 0 ? h_in_start : 0;
  const int w_start = w_in_start > 0 ? w_in_start : 0;
  int weight_offset = 0;

#pragma unroll
  for (int h_in = h_in_start; h_in < h_in_end; h_in += dilate_height) {
#pragma unroll
    for (int w_in = w_in_start; w_in < w_in_end; w_in += dilate_width) {
      if (h_in >= h_start && h_in < h_end && w_in >= w_start && w_in < w_end) {
        int offset = in_offset + h_in * input_width + w_in;
        T in_data = input_data[offset];
        if (fuse_relu_before_conv) {
          value += weight[weight_offset] * max(0.0f, in_data);
        } else {
          value += weight[weight_offset] * in_data;
        }
      }
      weight_offset++;
    }
  }
  int index = batch * output_channels * output_height * output_width +
              c_out * output_height * output_width + h_out * output_width +
              w_out;
  output_data[index] = value;
}

// A CUDA kernel to compute the depthwise convolution forward pass
// in NHWC format.
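// Same one-thread-per-output-element scheme as the NCHW kernel, but the
// flattened index decodes with channels innermost: (batch, h_out, w_out,
// c_out).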
template <typename T, bool fuse_relu_before_conv>
__device__ __inline__ void KernelDepthwiseConvNHWC(
    ARG_DEFINE_KernelDepthwiseConv) {
  int idx = threadIdx.x + blockIdx.x * blockDim.x;
  if (idx >= (output_channels * batch_size * output_height * output_width))
    return;

  const int c_out = idx % output_channels;
  const int w_out = (idx / output_channels) % output_width;
  const int h_out = (idx / output_channels / output_width) % output_height;
  const int batch = idx / output_width / output_height / output_channels;

  const int c_in = c_out / filter_multiplier;
  const T* weight = filter_data + c_out * filter_height * filter_width;
  T value = 0;
  const int h_in_start = -padding_height + h_out * stride_height;
  const int w_in_start = -padding_width + w_out * stride_width;
  const int h_in_end = h_in_start + filter_height * dilate_height;
  const int w_in_end = w_in_start + filter_width * dilate_width;

  const int h_end = h_in_end < input_height ? h_in_end : input_height;
  const int w_end = w_in_end < input_width ? w_in_end : input_width;
  const int h_start = h_in_start > 0 ? h_in_start : 0;
  const int w_start = w_in_start > 0 ? w_in_start : 0;
  int weight_offset = 0;

#pragma unroll
  for (int h_in = h_in_start; h_in < h_in_end; h_in += dilate_height) {
#pragma unroll
    for (int w_in = w_in_start; w_in < w_in_end; w_in += dilate_width) {
      if (h_in >= h_start && h_in < h_end && w_in >= w_start && w_in < w_end) {
        int offset = ((batch * input_height + h_in) * input_width + w_in) *
                         output_channels +
                     c_in;
        T in_data = input_data[offset];
        if (fuse_relu_before_conv) {
          value += weight[weight_offset] * max(0.0f, in_data);
        } else {
          value += weight[weight_offset] * in_data;
        }
      }
      weight_offset++;
    }
  }
  int index = batch * output_channels * output_height * output_width +
              h_out * output_width * output_channels + w_out * output_channels +
              c_out;
  output_data[index] = value;
}

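// Forward variant with the filter size fixed at compile time (c_filter).
// The c_filter x c_filter weights for the block's output channel are
// preloaded into the register array r_weight before the output loops run.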
template <typename T, int c_filter, bool fuse_relu_before_conv>
__device__ __inline__ void KernelDepthwiseConvCFilter(
    ARG_DEFINE_KernelDepthwiseConv) {
  const int kWeightSize = c_filter * c_filter;
  T r_weight[kWeightSize];
  const int batch = blockIdx.y;
  const int c_out = blockIdx.x;
  const T* weight = filter_data + c_out * c_filter * c_filter;
  for (int i = 0; i < c_filter * c_filter; i++) r_weight[i] = weight[i];

  for (int w_out = threadIdx.x; w_out < output_width; w_out += blockDim.x) {
    for (int h_out = threadIdx.y; h_out < output_height; h_out += blockDim.y) {
      const int batch = blockIdx.y;
      const int c_out = blockIdx.x;

      const int c_in = c_out / filter_multiplier;
      T value = 0;
      const int h_in_start = -padding_height + h_out * stride_height;
      const int w_in_start = -padding_width + w_out * stride_width;
      const int h_in_end = h_in_start + c_filter * dilate_height;
      const int w_in_end = w_in_start + c_filter * dilate_width;

      int in_offset;
      if (data_layout != DataLayout::kNHWC) {
        in_offset =
            ((batch * input_channels + c_in) * input_height) * input_width;
      } else {
        in_offset = batch * input_height * input_width * input_channels;
      }

      const int h_end = h_in_end < input_height ? h_in_end : input_height;
      const int w_end = w_in_end < input_width ? w_in_end : input_width;
      const int h_start = h_in_start > 0 ? h_in_start : 0;
      const int w_start = w_in_start > 0 ? w_in_start : 0;

      for (int h_in = h_in_start, h_f = 0; h_f < c_filter;
           h_in += dilate_height, h_f++) {
        for (int w_in = w_in_start, w_f = 0; w_f < c_filter;
             w_in += dilate_width, w_f++) {
          if (h_in >= 0 && h_in < input_height && w_in >= 0 &&
              w_in < input_width) {
            int offset;
            if (data_layout != DataLayout::kNHWC) {
              offset = in_offset + h_in * input_width + w_in;
            } else {
              offset = in_offset +
                       (h_in * input_width + w_in) * input_channels + c_in;
            }
            if (fuse_relu_before_conv) {
              value += r_weight[h_f * c_filter + w_f] *
                       max(0.0f, input_data[offset]);
            } else {
              value += r_weight[h_f * c_filter + w_f] * input_data[offset];
            }
          }
        }
      }
      int index;
      if (data_layout != DataLayout::kNHWC) {
        index = ((batch * gridDim.x + c_out) * output_height + h_out) *
                    output_width +
                w_out;
      } else {
        index = ((batch * output_height + h_out) * output_width + w_out) *
                    gridDim.x +
                c_out;
      }
      output_data[index] = value;
    }
  }
}

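// Forward dispatch kernel: nonzero template arguments pin the filter
// multiplier, stride, and filter size at compile time, while c_filter == -1
// falls back to the generic NCHW/NHWC kernels above.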
template <typename T, int c_filter_multiplier, int c_stride, int c_filter,
          bool fuse_relu_before_conv>
__global__ void KernelDepthwiseConvSp(ARG_DEFINE_KernelDepthwiseConv) {
  int final_filter_multiplier = filter_multiplier;
  int h_stride = stride_height;
  int w_stride = stride_width;
  if (c_filter_multiplier != 0) {
    final_filter_multiplier = c_filter_multiplier;
    h_stride = c_stride;
    w_stride = c_stride;
  }
  if (c_filter == -1) {
    if (data_layout == DataLayout::kNCHW) {
      KernelDepthwiseConvNCHW<T, fuse_relu_before_conv>(
          input_data, filter_data, batch_size, output_channels, output_height,
          output_width, input_channels, input_height, input_width,
          final_filter_multiplier, filter_height, filter_width, h_stride,
          w_stride, padding_height, padding_width, dilate_height, dilate_width,
          output_data, data_layout);
    } else {
      KernelDepthwiseConvNHWC<T, fuse_relu_before_conv>(
          input_data, filter_data, batch_size, output_channels, output_height,
          output_width, input_channels, input_height, input_width,
          final_filter_multiplier, filter_height, filter_width, h_stride,
          w_stride, padding_height, padding_width, dilate_height, dilate_width,
          output_data, data_layout);
    }
  } else {
    KernelDepthwiseConvCFilter<T, c_filter, fuse_relu_before_conv>(
        input_data, filter_data, batch_size, output_channels, output_height,
        output_width, input_channels, input_height, input_width,
        final_filter_multiplier, filter_height, filter_width, h_stride,
        w_stride, padding_height, padding_width, dilate_height, dilate_width,
        output_data, data_layout);
  }
}

// A CUDA kernel to compute the depthwise convolution backprop w.r.t. input.
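// Each thread gathers the gradient of one input element by visiting every
// output position that could have read it, skipping candidates that do not
// line up with the stride.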
#define ARG_DEFINE_KernelDepthwiseConvInputGrad                                \
  const T *const input_data, const T *const output_grad_data,                  \
      const T *const filter_data, const int batch_size,                        \
      const int output_channels, const int output_height,                      \
      const int output_width, const int input_channels,                        \
      const int input_height, const int input_width,                           \
      const int filter_multiplier, const int filter_height,                    \
      const int filter_width, const int stride_height, const int stride_width, \
      const int padding_height, const int padding_width,                       \
      const int dilate_height, const int dilate_width,                         \
      T *const input_grad_data,                                                \
      const DataLayout data_layout = DataLayout::kNCHW

template <typename T, bool fuse_relu_before_conv>
__device__ __inline__ void KernelDepthwiseConvInputGrad(
    ARG_DEFINE_KernelDepthwiseConvInputGrad) {
  for (int w_in = threadIdx.x; w_in < input_width; w_in += blockDim.x) {
    for (int h_in = threadIdx.y; h_in < input_height; h_in += blockDim.y) {
      const int batch = blockIdx.y;
      const int c_in = blockIdx.x;

      const int c_out_start = c_in * filter_multiplier;

      int h_out_start =
          h_in - (filter_height - 1) * dilate_height + padding_height;

      int h_out_end = h_in + padding_height;

      int w_out_start =
          w_in - (filter_width - 1) * dilate_width + padding_width;

      int w_out_end = w_in + padding_width;

      T value = 0;
      int index;
      if (data_layout != DataLayout::kNHWC) {
        index =
            ((batch * gridDim.x + c_in) * input_height + h_in) * input_width +
            w_in;
      } else {
        index =
            ((batch * input_height + h_in) * input_width + w_in) * gridDim.x +
            c_in;
      }

      if (fuse_relu_before_conv) {
        if (input_data[index] <= 0) {
          input_grad_data[index] = 0;
          continue;
        }
      }

      for (int c_out = c_out_start; c_out < c_out_start + filter_multiplier;
           c_out++) {
        int filter_offset = (c_out + 1) * filter_height * filter_width;
        for (int h_out = h_out_start; h_out <= h_out_end;
             h_out += dilate_height) {
          for (int w_out = w_out_start; w_out <= w_out_end;
               w_out += dilate_width) {
            filter_offset--;
            int s_h_out = h_out / stride_height;
            int s_w_out = w_out / stride_width;
            if (h_out % stride_height == 0 && w_out % stride_width == 0 &&
                s_h_out >= 0 && s_h_out < output_height && s_w_out >= 0 &&
                s_w_out < output_width) {
              int output_grad_offset;
              if (data_layout != DataLayout::kNHWC) {
                output_grad_offset =
                    ((batch * output_channels + c_out) * output_height +
                     s_h_out) *
                        output_width +
                    s_w_out;
              } else {
                output_grad_offset =
                    ((batch * output_height + s_h_out) * output_width +
                     s_w_out) *
                        output_channels +
                    c_out;
              }
              value += output_grad_data[output_grad_offset] *
                       filter_data[filter_offset];
            }
          }
        }
      }
      input_grad_data[index] = value;
    }
  }
}

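// Input-gradient variant with compile-time filter size: the weights are
// cached in registers in reversed order, so the gather becomes a
// correlation with the flipped filter.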
template <typename T, int c_filter, int c_filter_multiplier,
          bool fuse_relu_before_conv>
__device__ __inline__ void KernelDepthwiseConvInputGradCFilter(
    ARG_DEFINE_KernelDepthwiseConvInputGrad) {
  const int kWeightSize = c_filter * c_filter * c_filter_multiplier + 1;
  T r_weight[kWeightSize];
  const int batch = blockIdx.y;
  const int c_in = blockIdx.x;

  for (int c_i = 0; c_i < filter_multiplier; c_i++) {
    int c_out = c_in * filter_multiplier + c_i;
    const T* weight = filter_data + c_out * c_filter * c_filter;
    for (int i = 0; i < c_filter * c_filter; i++)
      r_weight[i + c_i * c_filter * c_filter] =
          weight[c_filter * c_filter - i - 1];
  }

  for (int w_in = threadIdx.x; w_in < input_width; w_in += blockDim.x) {
    for (int h_in = threadIdx.y; h_in < input_height; h_in += blockDim.y) {
      const int batch = blockIdx.y;
      const int c_in = blockIdx.x;

      int h_out_start = h_in - (c_filter - 1) * dilate_height + padding_height;

      int w_out_start = w_in - (c_filter - 1) * dilate_width + padding_width;

      T value = 0;
      int index;
      if (data_layout != DataLayout::kNHWC) {
        index =
            ((batch * gridDim.x + c_in) * input_height + h_in) * input_width +
            w_in;
      } else {
        index =
            ((batch * input_height + h_in) * input_width + w_in) * gridDim.x +
            c_in;
      }
      if (fuse_relu_before_conv) {
        if (input_data[index] <= 0) {
          input_grad_data[index] = 0;
          continue;
        }
      }

      for (int c_i = 0; c_i < filter_multiplier; c_i++) {
        int c_out = c_in * filter_multiplier + c_i;
        for (int h_out = h_out_start, h_f = 0; h_f < c_filter;
             h_out += dilate_height, h_f++) {
          for (int w_out = w_out_start, w_f = 0; w_f < c_filter;
               w_out += dilate_width, w_f++) {
            int s_h_out = h_out / stride_height;
            int s_w_out = w_out / stride_width;
            if (h_out % stride_height == 0 && w_out % stride_width == 0 &&
                s_h_out >= 0 && s_h_out < output_height && s_w_out >= 0 &&
                s_w_out < output_width) {
              int output_grad_offset;
              if (data_layout != DataLayout::kNHWC) {
                output_grad_offset =
                    ((batch * output_channels + c_out) * output_height +
                     s_h_out) *
                        output_width +
                    s_w_out;
              } else {
                output_grad_offset =
                    ((batch * output_height + s_h_out) * output_width +
                     s_w_out) *
                        output_channels +
                    c_out;
              }
              value +=
                  output_grad_data[output_grad_offset] *
                  r_weight[h_f * c_filter + w_f + c_i * c_filter * c_filter];
            }
          }
        }
      }
      input_grad_data[index] = value;
    }
  }
}

template <typename T, int c_filter_multiplier, int c_stride, int c_filter,
          bool fuse_relu_before_conv>
__global__ void KernelDepthwiseConvInputGradSp(
    ARG_DEFINE_KernelDepthwiseConvInputGrad) {
  if (c_filter_multiplier == 0)
    KernelDepthwiseConvInputGrad<T, fuse_relu_before_conv>(
        input_data, output_grad_data, filter_data, batch_size, output_channels,
        output_height, output_width, input_channels, input_height, input_width,
        filter_multiplier, filter_height, filter_width, stride_height,
        stride_width, padding_height, padding_width, dilate_height,
        dilate_width, input_grad_data, data_layout);
  else if (c_filter == -1)
    KernelDepthwiseConvInputGrad<T, fuse_relu_before_conv>(
        input_data, output_grad_data, filter_data, batch_size, output_channels,
        output_height, output_width, input_channels, input_height, input_width,
        c_filter_multiplier, filter_height, filter_width, c_stride, c_stride,
        padding_height, padding_width, dilate_height, dilate_width,
        input_grad_data, data_layout);
  else
    KernelDepthwiseConvInputGradCFilter<T, c_filter, c_filter_multiplier,
                                        fuse_relu_before_conv>(
        input_data, output_grad_data, filter_data, batch_size, output_channels,
        output_height, output_width, input_channels, input_height, input_width,
        c_filter_multiplier, filter_height, filter_width, c_stride, c_stride,
        padding_height, padding_width, dilate_height, dilate_width,
        input_grad_data, data_layout);
}

// A CUDA kernel to compute the depthwise convolution backprop w.r.t. filter.
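// One thread block handles one (filter_w, filter_h, out_channel) position:
// each thread accumulates a partial sum over the output elements it covers,
// and the partials are combined with CudaAtomicAddWithWarp.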
template <typename T, bool fuse_relu_before_conv>
__device__ __inline__ void KernelDepthwiseConvFilterGrad(
    const T* output_grad_data, const T* input_data, const int num,
    const int output_channels, const int output_height, const int output_width,
    const int input_channels, const int input_height, const int input_width,
    const int filter_multiplier, const int filter_height,
    const int filter_width, const int stride_height, const int stride_width,
    const int padding_height, const int padding_width, const int dilate_height,
    const int dilate_width, T* filter_grad_data,
    const DataLayout data_layout = DataLayout::kNCHW) {
  T s = 0;

  int gbid = ((blockIdx.z * gridDim.y) + blockIdx.y) * gridDim.x + blockIdx.x;

  for (int image_w = threadIdx.x; image_w < output_width;
       image_w += blockDim.x) {
    for (int bid = 0; bid < num; bid++) {
      for (int image_h = threadIdx.y; image_h < output_height;
           image_h += blockDim.y) {
        int kernel_id = blockIdx.z;
        int kernel_h = blockIdx.y * dilate_height - padding_height;
        int kernel_w = blockIdx.x * dilate_width - padding_width;

        int image_hk = image_h * stride_height + kernel_h;
        int image_wk = image_w * stride_width + kernel_w;
        if (image_hk < 0 || image_hk >= input_height) continue;
        if (image_wk < 0 || image_wk >= input_width) continue;
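// gaid and gaid_nhwc flatten (batch, channel, h, w) into an offset in the
// output-gradient tensor for the NCHW and NHWC layouts, respectively.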
#define gaid(N, C, H, W) \
  ((((N)*gridDim.z + (C)) * output_height + (H)) * output_width + (W))
#define gaid_nhwc(N, H, W, C) \
  ((((N)*output_height + (H)) * output_width + (W)) * gridDim.z + (C))
        int input_id;
        if (data_layout != DataLayout::kNHWC) {
          input_id = ((bid * (gridDim.z / filter_multiplier) +
                       kernel_id / filter_multiplier) *
                          input_height +
                      image_hk) *
                         input_width +
                     image_wk;
          if (fuse_relu_before_conv) {
            s += output_grad_data[gaid(bid, kernel_id, image_h, image_w)] *
                 max(0.0f, input_data[input_id]);
          } else {
            s += output_grad_data[gaid(bid, kernel_id, image_h, image_w)] *
                 input_data[input_id];
          }
        } else {
          input_id =
              ((bid * input_height + image_hk) * input_width + image_wk) *
                  (gridDim.z / filter_multiplier) +
              kernel_id / filter_multiplier;
          if (fuse_relu_before_conv) {
            s += output_grad_data[gaid_nhwc(bid, image_h, image_w, kernel_id)] *
                 max(0.0f, input_data[input_id]);
          } else {
            s += output_grad_data[gaid_nhwc(bid, image_h, image_w, kernel_id)] *
                 input_data[input_id];
          }
        }

#undef gaid
      }
    }
  }
  CudaAtomicAddWithWarp(&filter_grad_data[gbid], s);
}

template <typename T, int c_filter_multiplier, bool fuse_relu_before_conv>
__global__ void KernelDepthwiseConvFilterGradSp(
    const T* output_grad_data, const T* input_data, const int num,
    const int output_channels, const int output_height, const int output_width,
    const int input_channels, const int input_height, const int input_width,
    const int filter_multiplier, const int filter_height,
    const int filter_width, const int stride_height, const int stride_width,
    const int padding_height, const int padding_width, const int dilate_height,
    const int dilate_width, T* filter_grad_data,
    const DataLayout data_layout = DataLayout::kNCHW) {
  if (c_filter_multiplier == 0)
    KernelDepthwiseConvFilterGrad<T, fuse_relu_before_conv>(
        output_grad_data, input_data, num, output_channels, output_height,
        output_width, input_channels, input_height, input_width,
        filter_multiplier, filter_height, filter_width, stride_height,
        stride_width, padding_height, padding_width, dilate_height,
        dilate_width, filter_grad_data, data_layout);
  else
    KernelDepthwiseConvFilterGrad<T, fuse_relu_before_conv>(
        output_grad_data, input_data, num, output_channels, output_height,
        output_width, input_channels, input_height, input_width,
        c_filter_multiplier, filter_height, filter_width, stride_height,
        stride_width, padding_height, padding_width, dilate_height,
        dilate_width, filter_grad_data, data_layout);
}

/*
 * All tensors are in NCHW format by default; NHWC is selected through the
 * data_layout argument. Strides, paddings, and dilations each hold two
 * elements, representing height and width, respectively.
 */
template <class T, bool fuse_relu_before_conv>
class DepthwiseConvFunctor<platform::CUDADeviceContext, T,
                           fuse_relu_before_conv> {
 public:
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& input,
                  const framework::Tensor& filter,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings,
                  const std::vector<int>& dilations, framework::Tensor* output,
                  const DataLayout data_layout = DataLayout::kNCHW) {
    const int batch_size = input.dims()[0];
    const int input_channels =
        (data_layout != DataLayout::kNHWC ? input.dims()[1] : input.dims()[3]);
    const int input_height =
        (data_layout != DataLayout::kNHWC ? input.dims()[2] : input.dims()[1]);
    const int input_width =
        (data_layout != DataLayout::kNHWC ? input.dims()[3] : input.dims()[2]);
    const int output_channels =
        (data_layout != DataLayout::kNHWC ? output->dims()[1]
                                          : output->dims()[3]);
    const int output_height =
        (data_layout != DataLayout::kNHWC ? output->dims()[2]
                                          : output->dims()[1]);
    const int output_width =
        (data_layout != DataLayout::kNHWC ? output->dims()[3]
                                          : output->dims()[2]);
    const int ksize_height = filter.dims()[2];
    const int ksize_width = filter.dims()[3];
    const int stride_height = strides[0];
    const int stride_width = strides[1];
    const int padding_height = paddings[0];
    const int padding_width = paddings[1];
    const int dilate_height = dilations[0];
    const int dilate_width = dilations[1];

    const T* input_data = input.data<T>();
    const T* filter_data = filter.data<T>();
    T* output_data = output->mutable_data<T>(context.GetPlace());

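    // Launch shape: at most 512 threads in x. Widths in (512, 1024] get one
    // thread per output column, widths in (1024, 2048] one thread per two
    // columns; any leftover threads are stacked as extra rows in y.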
    int thread = 512;
    if (output_width > 1024 && output_width <= 2048)
      thread = (output_width - 1) / 2 + 1;
    else if (output_width > 512 && output_width <= 1024)
      thread = output_width;
    int blocks = std::min(std::max(thread / output_width, 1), output_height);
    dim3 threads(std::min(output_width, thread), blocks, 1);
    dim3 grid(output_channels, batch_size, 1);
    int filter_multiplier = output_channels / input_channels;

    int nums_output =
        batch_size * output_channels * output_height * output_width;
    int block_size = 512;

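// check_case launches a specialized instantiation when the runtime
// filter_multiplier, stride, and filter size match its arguments;
// check_case(0, 0, -1) below is the catch-all and must come last.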
#define check_case(c_filter_multiplier, c_stride, c_filter)                  \
  if (c_filter_multiplier == 0 ||                                            \
      filter_multiplier == c_filter_multiplier &&                            \
          stride_height == stride_width && stride_height == c_stride &&      \
          (ksize_height == ksize_width && ksize_height == c_filter ||        \
           c_filter == -1)) {                                                \
    if (c_filter == -1) {                                                    \
      threads.x = block_size;                                                \
      grid.x = (nums_output + block_size - 1) / block_size;                  \
      threads.y = threads.z = grid.y = grid.z = 1;                           \
    }                                                                        \
    KernelDepthwiseConvSp<                                                   \
        T, c_filter_multiplier, c_stride, c_filter,                          \
        fuse_relu_before_conv><<<grid, threads, 0, context.stream()>>>(      \
        input_data, filter_data, batch_size, output_channels, output_height, \
        output_width, input_channels, input_height, input_width,             \
        filter_multiplier, ksize_height, ksize_width, stride_height,         \
        stride_width, padding_height, padding_width, dilate_height,          \
        dilate_width, output_data, data_layout);                             \
    return;                                                                  \
  }
    check_case(1, 1, 3);
    check_case(1, 1, 5);
    check_case(1, 1, -1);
    check_case(1, 2, 3);
    check_case(1, 2, 5);
    check_case(1, 2, -1);
    check_case(2, 1, 3);
    check_case(2, 1, 5);
    check_case(2, 1, -1);
    check_case(2, 2, 3);
    check_case(2, 2, 5);
    check_case(2, 2, -1);
    check_case(0, 0, -1);
// NOTE(liangdun): check_case(0, 0, -1) covers all remaining cases;
// add further specializations if needed, e.g. check_case(2^n, 1).
#undef check_case
  }
};
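// A hypothetical usage sketch (ctx, input, filter, and output are
// illustrative names, not part of this file):
//   DepthwiseConvFunctor<platform::CUDADeviceContext, float, false> conv;
//   conv(ctx, input, filter, /*strides=*/{1, 1}, /*paddings=*/{0, 0},
//        /*dilations=*/{1, 1}, &output);
// strides, paddings, and dilations are {height, width} pairs, and output
// must be allocated to the convolved shape beforehand.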

template <typename T, bool fuse_relu_before_conv>
class DepthwiseConvInputGradFunctor<platform::CUDADeviceContext, T,
                                    fuse_relu_before_conv> {
 public:
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& input,
                  const framework::Tensor& filter,
                  const framework::Tensor& output_grad,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings,
                  const std::vector<int>& dilations,
                  framework::Tensor* input_grad,
                  const DataLayout data_layout = DataLayout::kNCHW) {
    const int batch_size = input.dims()[0];
    const int input_channels =
        (data_layout != DataLayout::kNHWC ? input.dims()[1] : input.dims()[3]);
    const int input_height =
        (data_layout != DataLayout::kNHWC ? input.dims()[2] : input.dims()[1]);
    const int input_width =
        (data_layout != DataLayout::kNHWC ? input.dims()[3] : input.dims()[2]);
    const int output_channels =
        (data_layout != DataLayout::kNHWC ? output_grad.dims()[1]
                                          : output_grad.dims()[3]);
    const int output_height =
        (data_layout != DataLayout::kNHWC ? output_grad.dims()[2]
                                          : output_grad.dims()[1]);
    const int output_width =
        (data_layout != DataLayout::kNHWC ? output_grad.dims()[3]
                                          : output_grad.dims()[2]);
    const int ksize_height = filter.dims()[2];
    const int ksize_width = filter.dims()[3];
    const int stride_height = strides[0];
    const int stride_width = strides[1];
    const int padding_height = paddings[0];
    const int padding_width = paddings[1];
    const int dilate_height = dilations[0];
    const int dilate_width = dilations[1];

    const T* input_data = input.data<T>();
    const T* filter_data = filter.data<T>();
    const T* output_grad_data = output_grad.data<T>();
    T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace());

    int thread = 512;
    if (input_width > 1024 && input_width <= 2048)
      thread = (input_width - 1) / 2 + 1;
    else if (input_width > 512 && input_width <= 1024)
      thread = input_width;
    int blocks = std::min(std::max(thread / input_width, 1), input_height);
    dim3 threads(std::min(input_width, thread), blocks, 1);
    dim3 grid(input_channels, batch_size, 1);
    int filter_multiplier = output_channels / input_channels;

#define check_case(c_filter_multiplier, c_stride, c_filter)             \
  if (c_filter_multiplier == 0 ||                                       \
      filter_multiplier == c_filter_multiplier &&                       \
          stride_height == stride_width && stride_height == c_stride && \
          (ksize_height == ksize_width && ksize_height == c_filter ||   \
           c_filter == -1)) {                                           \
    KernelDepthwiseConvInputGradSp<                                     \
        T, c_filter_multiplier, c_stride, c_filter,                     \
        fuse_relu_before_conv><<<grid, threads, 0, context.stream()>>>( \
        input_data, output_grad_data, filter_data, batch_size,          \
        output_channels, output_height, output_width, input_channels,   \
        input_height, input_width, filter_multiplier, ksize_height,     \
        ksize_width, stride_height, stride_width, padding_height,       \
        padding_width, dilate_height, dilate_width, input_grad_data,    \
        data_layout);                                                   \
    return;                                                             \
  }
    check_case(1, 1, 3);
    check_case(1, 1, 5);
    check_case(1, 1, -1);
    check_case(1, 2, 3);
    check_case(1, 2, 5);
    check_case(1, 2, -1);
    check_case(2, 1, 3);
    check_case(2, 1, 5);
    check_case(2, 1, -1);
    check_case(2, 2, 3);
    check_case(2, 2, 5);
    check_case(2, 2, -1);
    check_case(0, 0, -1);
// NOTE(liangdun): check_case(0, 0, -1) covers all remaining cases;
// add further specializations if needed, e.g. check_case(2^n, 1).
#undef check_case
  }
};

template <typename T, bool fuse_relu_before_conv>
class DepthwiseConvFilterGradFunctor<platform::CUDADeviceContext, T,
                                     fuse_relu_before_conv> {
 public:
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& input,
                  const framework::Tensor& output_grad,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings,
                  const std::vector<int>& dilations,
                  framework::Tensor* filter_grad,
                  const DataLayout data_layout = DataLayout::kNCHW) {
    const int batch_size = input.dims()[0];
    const int input_channels =
        (data_layout != DataLayout::kNHWC ? input.dims()[1] : input.dims()[3]);
    const int input_height =
        (data_layout != DataLayout::kNHWC ? input.dims()[2] : input.dims()[1]);
    const int input_width =
        (data_layout != DataLayout::kNHWC ? input.dims()[3] : input.dims()[2]);
    const int output_channels =
        (data_layout != DataLayout::kNHWC ? output_grad.dims()[1]
                                          : output_grad.dims()[3]);
    const int output_height =
        (data_layout != DataLayout::kNHWC ? output_grad.dims()[2]
                                          : output_grad.dims()[1]);
    const int output_width =
        (data_layout != DataLayout::kNHWC ? output_grad.dims()[3]
                                          : output_grad.dims()[2]);
    const int ksize_height = filter_grad->dims()[2];
    const int ksize_width = filter_grad->dims()[3];
    const int stride_height = strides[0];
    const int stride_width = strides[1];
    const int padding_height = paddings[0];
    const int padding_width = paddings[1];
    const int dilate_height = dilations[0];
    const int dilate_width = dilations[1];

    const T* input_data = input.data<T>();
    const T* output_grad_data = output_grad.data<T>();
    T* filter_grad_data = filter_grad->mutable_data<T>(context.GetPlace());

    int block_size = 512;
    if (output_width > 1024 && output_width <= 2048)
      block_size = (output_width - 1) / 2 + 1;
    else if (output_width > 512 && output_width <= 1024)
      block_size = output_width;
    int crop_output_height =
        std::min(std::max(block_size / output_width, 1), output_height);
    dim3 grid(ksize_width, ksize_height, output_channels);
    dim3 threads(std::min(output_width, block_size), crop_output_height, 1);
    int filter_multiplier = output_channels / input_channels;

#define check_case(c_filter_multiplier)                                       \
  if (c_filter_multiplier == 0 || c_filter_multiplier == filter_multiplier) { \
    KernelDepthwiseConvFilterGradSp<                                          \
        T, c_filter_multiplier,                                               \
        fuse_relu_before_conv><<<grid, threads, 0, context.stream()>>>(       \
        output_grad_data, input_data, batch_size, output_channels,            \
        output_height, output_width, input_channels, input_height,            \
        input_width, filter_multiplier, ksize_height, ksize_width,            \
        stride_height, stride_width, padding_height, padding_width,           \
        dilate_height, dilate_width, filter_grad_data, data_layout);          \
    return;                                                                   \
  }
    check_case(1);
    check_case(0);
#undef check_case
  }
};

template class DepthwiseConvFunctor<platform::CUDADeviceContext, float, false>;
template class DepthwiseConvFunctor<platform::CUDADeviceContext, double, false>;

template class DepthwiseConvInputGradFunctor<platform::CUDADeviceContext, float,
                                             false>;
template class DepthwiseConvInputGradFunctor<platform::CUDADeviceContext,
                                             double, false>;

template class DepthwiseConvFilterGradFunctor<platform::CUDADeviceContext,
                                              float, false>;
template class DepthwiseConvFilterGradFunctor<platform::CUDADeviceContext,
                                              double, false>;

template class DepthwiseConvFunctor<platform::CUDADeviceContext, float, true>;
template class DepthwiseConvFunctor<platform::CUDADeviceContext, double, true>;

template class DepthwiseConvInputGradFunctor<platform::CUDADeviceContext, float,
                                             true>;
template class DepthwiseConvInputGradFunctor<platform::CUDADeviceContext,
                                             double, true>;

template class DepthwiseConvFilterGradFunctor<platform::CUDADeviceContext,
                                              float, true>;
template class DepthwiseConvFilterGradFunctor<platform::CUDADeviceContext,
                                              double, true>;

}  // namespace math
}  // namespace operators
}  // namespace paddle