/* Copyright (c) 2016 paddlepaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <algorithm>
#include <vector>
#include "paddle/fluid/operators/math/depthwise_conv.h"
#include "paddle/fluid/platform/cuda_primitives.h"

namespace paddle {
namespace operators {
namespace math {

// Sums `val` across the lanes of the calling warp with shuffle-down steps of
// 16, 8, 4, 2 and 1 lanes. Afterwards lane 0 holds the sum of all 32 lanes;
// the other lanes hold partial sums. On CUDA >= 9 the shuffle uses a full
// 32-lane participation mask, so every lane of the warp must make this call.
template <typename T>
__inline__ __device__ T warpReduceSum(T val) {
#if CUDA_VERSION >= 9000
#define FULL_MASK 0xffffffff
#endif
  for (int delta = 16; delta >= 1; delta >>= 1) {
#if CUDA_VERSION < 9000
    val += __shfl_down(val, delta);
#else
    val += __shfl_down_sync(FULL_MASK, val, delta);
#endif
  }
  return val;
}
// Returns the calling thread's lane index within its warp, read directly
// from the PTX special register %laneid.
__forceinline__ __device__ unsigned lane_id() {
  unsigned lane;
  asm volatile("mov.u32 %0, %laneid;" : "=r"(lane));
  return lane;
}

// Returns the value of the PTX special register %warpid.
// NOTE(review): per the PTX ISA, %warpid identifies the hardware warp slot,
// may change during execution (e.g. after preemption) and is not the warp's
// index within its block; derive a logical warp index from threadIdx instead
// when that is what is needed. Not referenced elsewhere in this file.
__forceinline__ __device__ unsigned warp_id() {
  unsigned hw_warp;
  asm volatile("mov.u32 %0, %warpid;" : "=r"(hw_warp));
  return hw_warp;
}

// Device routine computing the depthwise convolution forward pass for NCHW
// tensors.
//
// Expected launch layout (as set up by DepthwiseConvFunctor): blockIdx.x is
// the output channel and blockIdx.y the batch sample; threads step over the
// output plane, threadIdx.x along the width and threadIdx.y along the height.
// Output channel c_out reads input channel c_out / filter_multiplier, and its
// filter_height * filter_width taps start at filter_data + c_out * taps.
// Taps that land outside the input (i.e. in the padding) contribute nothing.
// The output index uses gridDim.x as the output channel count.
template <typename T>
__device__ __inline__ void KernelDepthwiseConv(
    const T* const input_data, const T* const filter_data, const int batch_size,
    const int output_channels, const int output_height, const int output_width,
    const int input_channels, const int input_height, const int input_width,
    const int filter_multiplier, const int filter_height,
    const int filter_width, const int stride_height, const int stride_width,
    const int padding_height, const int padding_width, const int dilate_height,
    const int dilate_width, T* const output_data) {
  const int c_out = blockIdx.x;
  const int batch = blockIdx.y;
  const int c_in = c_out / filter_multiplier;
  // Taps of this output channel, visited row by row.
  const T* weight = filter_data + c_out * filter_height * filter_width;
  // Offset of the (batch, c_in) input plane.
  const int plane =
      ((batch * input_channels + c_in) * input_height) * input_width;

  for (int w_out = threadIdx.x; w_out < output_width; w_out += blockDim.x) {
    for (int h_out = threadIdx.y; h_out < output_height; h_out += blockDim.y) {
      // First input coordinate of the receptive field (negative inside the
      // padding) and the exclusive end of the dilated window.
      const int h_first = h_out * stride_height - padding_height;
      const int w_first = w_out * stride_width - padding_width;
      const int h_last = h_first + filter_height * dilate_height;
      const int w_last = w_first + filter_width * dilate_width;

      T acc = 0;
      int tap = 0;
      for (int y = h_first; y < h_last; y += dilate_height) {
        for (int x = w_first; x < w_last; x += dilate_width, ++tap) {
          const bool inside =
              y >= 0 && y < input_height && x >= 0 && x < input_width;
          if (inside) {
            acc += weight[tap] * input_data[plane + y * input_width + x];
          }
        }
      }

      const int out_index =
          ((batch * gridDim.x + c_out) * output_height + h_out) * output_width +
          w_out;
      output_data[out_index] = acc;
    }
  }
}
// Kernel entry point for the depthwise convolution forward pass.
//
// When c_filter_multiplier is 0 the runtime filter_multiplier, stride_height
// and stride_width are forwarded unchanged. Otherwise c_filter_multiplier and
// c_stride (applied to both spatial dimensions) are forwarded in their place
// as compile-time constants; DepthwiseConvFunctor only selects such a
// specialization when the runtime values match them. The filter's height and
// width are always the runtime values.
template <typename T, int c_filter_multiplier, int c_stride>
__global__ void KernelDepthwiseConvSp(
    const T* const input_data, const T* const filter_data, const int batch_size,
    const int output_channels, const int output_height, const int output_width,
    const int input_channels, const int input_height, const int input_width,
    const int filter_multiplier, const int filter_height,
    const int filter_width, const int stride_height, const int stride_width,
    const int padding_height, const int padding_width, const int dilate_height,
    const int dilate_width, T* const output_data) {
  const bool use_runtime_args = c_filter_multiplier == 0;
  const int multiplier =
      use_runtime_args ? filter_multiplier : c_filter_multiplier;
  const int stride_h = use_runtime_args ? stride_height : c_stride;
  const int stride_w = use_runtime_args ? stride_width : c_stride;
  // filter_width must be forwarded on every path. Substituting filter_height
  // for it, as the specialized path used to do, mis-indexes non-square
  // filters and reads past the channel's taps when height > width.
  KernelDepthwiseConv<T>(input_data, filter_data, batch_size, output_channels,
                         output_height, output_width, input_channels,
                         input_height, input_width, multiplier, filter_height,
                         filter_width, stride_h, stride_w, padding_height,
                         padding_width, dilate_height, dilate_width,
                         output_data);
}

// CUDA device routine computing the depthwise convolution backprop w.r.t.
// the input, for NCHW tensors.
//
// Expected launch layout (as set up by DepthwiseConvInputGradFunctor):
// blockIdx.x is the input channel and blockIdx.y the batch sample; threads
// step over the input plane, threadIdx.x along the width and threadIdx.y
// along the height. Each input pixel gathers output_grad * filter over the
// filter_multiplier output channels c_in * filter_multiplier + k. For every
// channel, the filter taps are walked from last to first while the candidate
// (stride-unscaled) output position moves forward. The output index uses
// gridDim.x as the input channel count.
template <typename T>
__device__ __inline__ void KernelDepthwiseConvInputGrad(
    const T* const output_grad_data, const T* const filter_data,
    const int batch_size, const int output_channels, const int output_height,
    const int output_width, const int input_channels, const int input_height,
    const int input_width, const int filter_multiplier, const int filter_height,
    const int filter_width, const int stride_height, const int stride_width,
    const int padding_height, const int padding_width, const int dilate_height,
    const int dilate_width, T* const input_grad_data) {
  const int c_in = blockIdx.x;
  const int batch = blockIdx.y;
  const int c_out_begin = c_in * filter_multiplier;
  const int c_out_end = c_out_begin + filter_multiplier;
  const int taps_per_channel = filter_height * filter_width;

  for (int w_in = threadIdx.x; w_in < input_width; w_in += blockDim.x) {
    for (int h_in = threadIdx.y; h_in < input_height; h_in += blockDim.y) {
      // Inclusive range of stride-unscaled output coordinates whose dilated
      // receptive field can reach this input pixel.
      const int h_hi = h_in + padding_height;
      const int h_lo = h_hi - (filter_height - 1) * dilate_height;
      const int w_hi = w_in + padding_width;
      const int w_lo = w_hi - (filter_width - 1) * dilate_width;

      T acc = 0;
      for (int c_out = c_out_begin; c_out < c_out_end; ++c_out) {
        // One past this channel's last tap; decremented before each use.
        int tap = (c_out + 1) * taps_per_channel;
        const int grad_plane =
            (batch * output_channels + c_out) * output_height;
        for (int h = h_lo; h <= h_hi; h += dilate_height) {
          for (int w = w_lo; w <= w_hi; w += dilate_width) {
            --tap;
            // Only positions on the stride lattice correspond to an output.
            if (h % stride_height != 0 || w % stride_width != 0) continue;
            const int h_out = h / stride_height;
            const int w_out = w / stride_width;
            if (h_out < 0 || h_out >= output_height || w_out < 0 ||
                w_out >= output_width) {
              continue;
            }
            acc += output_grad_data[(grad_plane + h_out) * output_width +
                                    w_out] *
                   filter_data[tap];
          }
        }
      }

      const int in_index =
          ((batch * gridDim.x + c_in) * input_height + h_in) * input_width +
          w_in;
      input_grad_data[in_index] = acc;
    }
  }
}

// Kernel entry point for the input-gradient pass. With c_filter_multiplier
// equal to 0 the runtime filter_multiplier, stride_height and stride_width
// are forwarded. Otherwise c_filter_multiplier and c_stride (applied to both
// spatial dimensions) are forwarded in their place as compile-time constants.
template <typename T, int c_filter_multiplier, int c_stride>
__global__ void KernelDepthwiseConvInputGradSp(
    const T* const output_grad_data, const T* const filter_data,
    const int batch_size, const int output_channels, const int output_height,
    const int output_width, const int input_channels, const int input_height,
    const int input_width, const int filter_multiplier, const int filter_height,
    const int filter_width, const int stride_height, const int stride_width,
    const int padding_height, const int padding_width, const int dilate_height,
    const int dilate_width, T* const input_grad_data) {
  const bool use_runtime_args = c_filter_multiplier == 0;
  const int multiplier =
      use_runtime_args ? filter_multiplier : c_filter_multiplier;
  const int stride_h = use_runtime_args ? stride_height : c_stride;
  const int stride_w = use_runtime_args ? stride_width : c_stride;
  KernelDepthwiseConvInputGrad<T>(
      output_grad_data, filter_data, batch_size, output_channels,
      output_height, output_width, input_channels, input_height, input_width,
      multiplier, filter_height, filter_width, stride_h, stride_w,
      padding_height, padding_width, dilate_height, dilate_width,
      input_grad_data);
}

// Cuda device routine computing the depthwise convolution backprop w.r.t.
// the filter, for NCHW tensors.
//
// Expected launch layout (as set up by DepthwiseConvFilterGradFunctor): one
// block per filter element, grid = (filter_width, filter_height,
// output_channels), so blockIdx = (tap column, tap row, output channel) and
// gridDim.z is the output channel count. The threads of a block step over the
// output plane (threadIdx.x along the width, threadIdx.y along the height) for
// every sample in [0, num). They sum output_grad * input for this tap and add
// the result into filter_grad_data[gbid] with atomic adds. The buffer is
// accumulated into, not overwritten.
template <typename T>
__device__ __inline__ void KernelDepthwiseConvFilterGrad(
    const T* output_grad_data, const T* input_data, const int num,
    const int output_channels, const int output_height, const int output_width,
    const int input_channels, const int input_height, const int input_width,
    const int filter_multiplier, const int filter_height,
    const int filter_width, const int stride_height, const int stride_width,
    const int padding_height, const int padding_width, const int dilate_height,
    const int dilate_width, T* filter_grad_data) {
  // Linear index of the filter element this block accumulates into.
  const int gbid =
      ((blockIdx.z * gridDim.y) + blockIdx.y) * gridDim.x + blockIdx.x;
  const int kernel_id = blockIdx.z;  // output channel
  // Input offset of this tap relative to stride-scaled output coordinates.
  const int kernel_h = blockIdx.y * dilate_height - padding_height;
  const int kernel_w = blockIdx.x * dilate_width - padding_width;
  // Input channel feeding kernel_id, and the input channel count.
  const int c_in = kernel_id / filter_multiplier;
  const unsigned in_channels = gridDim.z / filter_multiplier;

  T s = 0;
  for (int image_w = threadIdx.x; image_w < output_width;
       image_w += blockDim.x) {
    for (int bid = 0; bid < num; bid++) {
      for (int image_h = threadIdx.y; image_h < output_height;
           image_h += blockDim.y) {
        const int image_hk = image_h * stride_height + kernel_h;
        const int image_wk = image_w * stride_width + kernel_w;
        if (image_hk < 0 || image_hk >= input_height) continue;
        if (image_wk < 0 || image_wk >= input_width) continue;
        const auto grad_index =
            ((bid * gridDim.z + kernel_id) * output_height + image_h) *
                output_width +
            image_w;
        const auto in_index =
            ((bid * in_channels + c_in) * input_height + image_hk) *
                input_width +
            image_wk;
        s += output_grad_data[grad_index] * input_data[in_index];
      }
    }
  }

#if __CUDA_ARCH__ >= 530
  // warpReduceSum shuffles with a full 32-lane mask, which is only valid when
  // all 32 lanes of the warp are live threads. The block size need not be a
  // multiple of 32 (a 7x7 output plane gives 49 threads), so the last warp
  // may be partial. Warps are built from consecutive linear thread ids, so
  // "is my warp full" is uniform within a warp: full warps reduce and issue
  // one atomic, while threads of a partial warp each issue their own.
  const int linear_tid =
      (threadIdx.z * blockDim.y + threadIdx.y) * blockDim.x + threadIdx.x;
  const int block_threads = blockDim.x * blockDim.y * blockDim.z;
  const bool warp_is_full = (linear_tid / 32 + 1) * 32 <= block_threads;
  if (warp_is_full) {
    s = warpReduceSum<T>(s);
    if (lane_id() == 0) {
      paddle::platform::CudaAtomicAdd(&filter_grad_data[gbid], s);
    }
  } else {
    paddle::platform::CudaAtomicAdd(&filter_grad_data[gbid], s);
  }
#else
  paddle::platform::CudaAtomicAdd(&filter_grad_data[gbid], s);
#endif
}

// Kernel entry point for the filter-gradient pass. c_filter_multiplier == 0
// forwards the runtime filter_multiplier; any other value is forwarded in its
// place as a compile-time constant. Strides always come from the runtime
// arguments.
template <typename T, int c_filter_multiplier>
__global__ void KernelDepthwiseConvFilterGradSp(
    const T* output_grad_data, const T* input_data, const int num,
    const int output_channels, const int output_height, const int output_width,
    const int input_channels, const int input_height, const int input_width,
    const int filter_multiplier, const int filter_height,
    const int filter_width, const int stride_height, const int stride_width,
    const int padding_height, const int padding_width, const int dilate_height,
    const int dilate_width, T* filter_grad_data) {
  const int multiplier =
      c_filter_multiplier == 0 ? filter_multiplier : c_filter_multiplier;
  KernelDepthwiseConvFilterGrad<T>(
      output_grad_data, input_data, num, output_channels, output_height,
      output_width, input_channels, input_height, input_width, multiplier,
      filter_height, filter_width, stride_height, stride_width,
      padding_height, padding_width, dilate_height, dilate_width,
      filter_grad_data);
}

/*
 * Host-side launcher for the depthwise convolution forward pass on CUDA.
 *
 * All tensors are in NCHW format. strides, paddings and dilations each hold
 * two elements, {height, width}. The filter's spatial size is read from
 * filter.dims()[2] (height) and filter.dims()[3] (width), and the channel
 * multiplier is output_channels / input_channels.
 *
 * The kernel is enqueued on context.stream() with one block per
 * (output channel, batch sample). No synchronization happens here.
 */
template <class T>
class DepthwiseConvFunctor<platform::CUDADeviceContext, T> {
 public:
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& input,
                  const framework::Tensor& filter,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings,
                  const std::vector<int>& dilations,
                  framework::Tensor* output) {
    // Tensor geometry (NCHW).
    const int batch_size = input.dims()[0];
    const int input_channels = input.dims()[1];
    const int input_height = input.dims()[2];
    const int input_width = input.dims()[3];
    const int output_channels = output->dims()[1];
    const int output_height = output->dims()[2];
    const int output_width = output->dims()[3];
    const int ksize_height = filter.dims()[2];
    const int ksize_width = filter.dims()[3];

    // Convolution hyper-parameters, given as {height, width}.
    const int stride_height = strides[0];
    const int stride_width = strides[1];
    const int padding_height = paddings[0];
    const int padding_width = paddings[1];
    const int dilate_height = dilations[0];
    const int dilate_width = dilations[1];

    const T* input_data = input.data<T>();
    const T* filter_data = filter.data<T>();
    T* output_data = output->mutable_data<T>(context.GetPlace());

    // At most 512 threads per block: up to 512 along the output width, and
    // as many output rows as still fit (at least 1, at most output_height).
    const int kMaxThreads = 512;
    const int rows_per_block =
        std::min(std::max(kMaxThreads / output_width, 1), output_height);
    const dim3 threads(std::min(output_width, kMaxThreads), rows_per_block, 1);
    const dim3 grid(output_channels, batch_size, 1);
    const int filter_multiplier = output_channels / input_channels;
    const bool square_stride = stride_height == stride_width;

#define LAUNCH_DEPTHWISE_CONV(c_multiplier, c_stride_value)                \
  KernelDepthwiseConvSp<T, c_multiplier, c_stride_value>                   \
      <<<grid, threads, 0, context.stream()>>>(                            \
          input_data, filter_data, batch_size, output_channels,            \
          output_height, output_width, input_channels, input_height,       \
          input_width, filter_multiplier, ksize_height, ksize_width,       \
          stride_height, stride_width, padding_height, padding_width,      \
          dilate_height, dilate_width, output_data)

    // Specialized kernels for common configurations; everything else takes
    // the generic <0, 0> kernel. Add more branches here if needed
    // (e.g. multiplier 2^n with stride 1).
    if (filter_multiplier == 1 && square_stride && stride_height == 1) {
      LAUNCH_DEPTHWISE_CONV(1, 1);
    } else if (filter_multiplier == 1 && square_stride &&
               stride_height == 2) {
      LAUNCH_DEPTHWISE_CONV(1, 2);
    } else {
      LAUNCH_DEPTHWISE_CONV(0, 0);
    }
#undef LAUNCH_DEPTHWISE_CONV
  }
};

// Host-side launcher for the depthwise convolution gradient w.r.t. the input
// (NCHW). strides, paddings and dilations are {height, width}. Only the shape
// of `input` is read; the result is written to input_grad.
//
// The kernel is enqueued on context.stream() with one block per
// (input channel, batch sample). No synchronization happens here.
template <typename T>
class DepthwiseConvInputGradFunctor<platform::CUDADeviceContext, T> {
 public:
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& input,
                  const framework::Tensor& filter,
                  const framework::Tensor& output_grad,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings,
                  const std::vector<int>& dilations,
                  framework::Tensor* input_grad) {
    // Tensor geometry (NCHW).
    const int batch_size = input.dims()[0];
    const int input_channels = input.dims()[1];
    const int input_height = input.dims()[2];
    const int input_width = input.dims()[3];
    const int output_channels = output_grad.dims()[1];
    const int output_height = output_grad.dims()[2];
    const int output_width = output_grad.dims()[3];
    const int ksize_height = filter.dims()[2];
    const int ksize_width = filter.dims()[3];

    // Convolution hyper-parameters, given as {height, width}.
    const int stride_height = strides[0];
    const int stride_width = strides[1];
    const int padding_height = paddings[0];
    const int padding_width = paddings[1];
    const int dilate_height = dilations[0];
    const int dilate_width = dilations[1];

    const T* filter_data = filter.data<T>();
    const T* output_grad_data = output_grad.data<T>();
    T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace());

    // At most 512 threads per block: up to 512 along the input width, and as
    // many input rows as still fit (at least 1, at most input_height).
    const int kMaxThreads = 512;
    const int rows_per_block =
        std::min(std::max(kMaxThreads / input_width, 1), input_height);
    const dim3 threads(std::min(input_width, kMaxThreads), rows_per_block, 1);
    const dim3 grid(input_channels, batch_size, 1);
    const int filter_multiplier = output_channels / input_channels;
    const bool square_stride = stride_height == stride_width;

#define LAUNCH_DEPTHWISE_CONV_INPUT_GRAD(c_multiplier, c_stride_value)     \
  KernelDepthwiseConvInputGradSp<T, c_multiplier, c_stride_value>          \
      <<<grid, threads, 0, context.stream()>>>(                            \
          output_grad_data, filter_data, batch_size, output_channels,      \
          output_height, output_width, input_channels, input_height,       \
          input_width, filter_multiplier, ksize_height, ksize_width,       \
          stride_height, stride_width, padding_height, padding_width,      \
          dilate_height, dilate_width, input_grad_data)

    // Specialized kernels for common configurations; everything else takes
    // the generic <0, 0> kernel. Add more branches here if needed
    // (e.g. multiplier 2^n with stride 1).
    if (filter_multiplier == 1 && square_stride && stride_height == 1) {
      LAUNCH_DEPTHWISE_CONV_INPUT_GRAD(1, 1);
    } else if (filter_multiplier == 1 && square_stride &&
               stride_height == 2) {
      LAUNCH_DEPTHWISE_CONV_INPUT_GRAD(1, 2);
    } else {
      LAUNCH_DEPTHWISE_CONV_INPUT_GRAD(0, 0);
    }
#undef LAUNCH_DEPTHWISE_CONV_INPUT_GRAD
  }
};

// Host-side launcher for the depthwise convolution gradient w.r.t. the
// filter (NCHW). strides, paddings and dilations are {height, width}; the
// filter's spatial size is read from filter_grad->dims()[2] and [3].
//
// The kernel is enqueued on context.stream() with one block per filter
// element. The kernel adds into filter_grad with atomic adds, and this
// launcher does not clear the buffer first.
// NOTE(review): presumably the caller zeroes filter_grad; confirm.
template <typename T>
class DepthwiseConvFilterGradFunctor<platform::CUDADeviceContext, T> {
 public:
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& input,
                  const framework::Tensor& output_grad,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings,
                  const std::vector<int>& dilations,
                  framework::Tensor* filter_grad) {
    // Tensor geometry (NCHW).
    const int batch_size = input.dims()[0];
    const int input_channels = input.dims()[1];
    const int input_height = input.dims()[2];
    const int input_width = input.dims()[3];
    const int output_channels = output_grad.dims()[1];
    const int output_height = output_grad.dims()[2];
    const int output_width = output_grad.dims()[3];
    const int ksize_height = filter_grad->dims()[2];
    const int ksize_width = filter_grad->dims()[3];

    // Convolution hyper-parameters, given as {height, width}.
    const int stride_height = strides[0];
    const int stride_width = strides[1];
    const int padding_height = paddings[0];
    const int padding_width = paddings[1];
    const int dilate_height = dilations[0];
    const int dilate_width = dilations[1];

    const T* input_data = input.data<T>();
    const T* output_grad_data = output_grad.data<T>();
    T* filter_grad_data = filter_grad->mutable_data<T>(context.GetPlace());

    // Grid covers every filter element; each block has at most 512 threads:
    // up to 512 along the output width, and as many output rows as still fit
    // (at least 1, at most output_height).
    const int kMaxThreads = 512;
    const int rows_per_block =
        std::min(std::max(kMaxThreads / output_width, 1), output_height);
    const dim3 grid(ksize_width, ksize_height, output_channels);
    const dim3 threads(std::min(output_width, kMaxThreads), rows_per_block, 1);
    const int filter_multiplier = output_channels / input_channels;

#define LAUNCH_DEPTHWISE_CONV_FILTER_GRAD(c_multiplier)                    \
  KernelDepthwiseConvFilterGradSp<T, c_multiplier>                         \
      <<<grid, threads, 0, context.stream()>>>(                            \
          output_grad_data, input_data, batch_size, output_channels,       \
          output_height, output_width, input_channels, input_height,       \
          input_width, filter_multiplier, ksize_height, ksize_width,       \
          stride_height, stride_width, padding_height, padding_width,      \
          dilate_height, dilate_width, filter_grad_data)

    if (filter_multiplier == 1) {
      LAUNCH_DEPTHWISE_CONV_FILTER_GRAD(1);
    } else {
      LAUNCH_DEPTHWISE_CONV_FILTER_GRAD(0);
    }
#undef LAUNCH_DEPTHWISE_CONV_FILTER_GRAD
  }
};

// Explicit instantiations of the CUDA functors for float and double.
template class DepthwiseConvFunctor<platform::CUDADeviceContext, float>;
template class DepthwiseConvFunctor<platform::CUDADeviceContext, double>;

template class DepthwiseConvInputGradFunctor<platform::CUDADeviceContext,
                                             float>;
template class DepthwiseConvInputGradFunctor<platform::CUDADeviceContext,
                                             double>;

template class DepthwiseConvFilterGradFunctor<platform::CUDADeviceContext,
                                              float>;
template class DepthwiseConvFilterGradFunctor<platform::CUDADeviceContext,
                                              double>;

}  // namespace math
}  // namespace operators
}  // namespace paddle