depthwise_conv.cu 14.8 KB
Newer Older
Z
zlx 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
/* Copyright (c) 2016 paddlepaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

X
xzl 已提交
15
#include "paddle/operators/math/depthwise_conv.h"
Z
zlx 已提交
16 17 18 19 20 21
#include "paddle/platform/cuda_helper.h"

namespace paddle {
namespace operators {
namespace math {

22 23
// Cuda kernel for the forward pass of depthwise convolution on NCHW tensors.
//
// Launch layout: one thread per output element. The flat thread id is
// (blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x, and ids
// >= nthreads return without doing anything.
// Output channel c_out reads input channel c_out / filter_multiplier and the
// filter_height * filter_width weights starting at
// filter_data[c_out * filter_height * filter_width]. Taps that fall outside
// the input (because of padding) are skipped, i.e. treated as zero.
// batch_size is not read.
template <typename T>
__global__ void KernelDepthwiseConv(
    const int nthreads, const T* const input_data, const T* const filter_data,
    const int batch_size, const int output_channels, const int output_height,
    const int output_width, const int input_channels, const int input_height,
    const int input_width, const int filter_multiplier, const int filter_height,
    const int filter_width, const int stride_height, const int stride_width,
    const int padding_height, const int padding_width, T* const output_data) {
  const int tid =
      (blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x;
  if (tid >= nthreads) return;

  // Unflatten the id into (n, oc, oh, ow) of the dense NCHW output.
  const int ow = tid % output_width;
  const int oh = (tid / output_width) % output_height;
  const int oc = (tid / output_width / output_height) % output_channels;
  const int n = tid / output_width / output_height / output_channels;
  const int ic = oc / filter_multiplier;

  // Top-left corner of the receptive field in input coordinates; negative
  // when it starts inside the padding.
  const int h0 = oh * stride_height - padding_height;
  const int w0 = ow * stride_width - padding_width;

  const T* in_plane =
      input_data + (n * input_channels + ic) * input_height * input_width;
  const T* taps = filter_data + oc * filter_height * filter_width;

  const bool fully_inside = h0 >= 0 && w0 >= 0 &&
                            h0 + filter_height <= input_height &&
                            w0 + filter_width <= input_width;

  T acc = 0;
  if (fully_inside) {
    // Fast path: every tap is in bounds, so no per-tap checks are needed.
    for (int kh = 0; kh < filter_height; ++kh) {
      const T* in_row = in_plane + (h0 + kh) * input_width + w0;
      const T* tap_row = taps + kh * filter_width;
      for (int kw = 0; kw < filter_width; ++kw) {
        acc += tap_row[kw] * in_row[kw];
      }
    }
  } else {
    // Border path: skip taps that land in the padding.
    for (int kh = 0; kh < filter_height; ++kh) {
      const int h = h0 + kh;
      const T* tap_row = taps + kh * filter_width;
      for (int kw = 0; kw < filter_width; ++kw) {
        const int w = w0 + kw;
        if (h >= 0 && h < input_height && w >= 0 && w < input_width) {
          acc += tap_row[kw] * in_plane[h * input_width + w];
        }
      }
    }
  }
  output_data[tid] = acc;
}
83

Z
zlx 已提交
84 85
// CUDA kernel for the depthwise convolution backprop w.r.t. the input (NCHW).
//
// Launch layout: one thread per input element. The flat thread id is
// (blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x, and ids
// >= nthreads return immediately.
// For its input pixel, each thread gathers contributions from the
// filter_multiplier output channels [c_in * filter_multiplier,
// (c_in + 1) * filter_multiplier). Within each channel it visits every output
// position whose receptive field covers that pixel. The resulting sum is
// ADDED to input_grad_data (+=), so the caller must initialize that buffer.
// batch_size is not read.
template <typename T>
__global__ void KernelDepthwiseConvInputGrad(
    const int nthreads, const T* const output_grad_data,
    const T* const filter_data, const int batch_size, const int output_channels,
    const int output_height, const int output_width, const int input_channels,
    const int input_height, const int input_width, const int filter_multiplier,
    const int filter_height, const int filter_width, const int stride_height,
    const int stride_width, const int padding_height, const int padding_width,
    T* const input_grad_data) {
  const int tid =
      (blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x;
  if (tid >= nthreads) return;

  // Unflatten the id into (n, ic, ih, iw) of the dense NCHW input.
  const int iw = tid % input_width;
  const int ih = (tid / input_width) % input_height;
  const int ic = (tid / input_width / input_height) % input_channels;
  const int n = tid / input_width / input_height / input_channels;

  // First/last output row whose window contains input row ih. The first-row
  // expression rounds (ih + pad - k + 1) / stride upward when it is
  // non-negative. Both ends are then clamped to [0, output_height - 1].
  // Columns are handled the same way.
  int oh_first =
      (ih - filter_height + padding_height + stride_height) / stride_height;
  if (oh_first < 0) oh_first = 0;
  int oh_last = (ih + padding_height) / stride_height;
  if (oh_last > output_height - 1) oh_last = output_height - 1;

  int ow_first =
      (iw - filter_width + padding_width + stride_width) / stride_width;
  if (ow_first < 0) ow_first = 0;
  int ow_last = (iw + padding_width) / stride_width;
  if (ow_last > output_width - 1) ow_last = output_width - 1;

  const int taps_per_channel = filter_height * filter_width;
  const int out_plane_size = output_height * output_width;

  T acc = 0;
  for (int m = 0; m < filter_multiplier; ++m) {
    const int oc = ic * filter_multiplier + m;
    const T* grad_plane =
        output_grad_data + (n * output_channels + oc) * out_plane_size;
    const T* taps = filter_data + oc * taps_per_channel;
    for (int oh = oh_first; oh <= oh_last; ++oh) {
      // Filter row that maps output row oh onto input row ih.
      const int kh = ih + padding_height - oh * stride_height;
      for (int ow = ow_first; ow <= ow_last; ++ow) {
        const int kw = iw + padding_width - ow * stride_width;
        acc +=
            grad_plane[oh * output_width + ow] * taps[kh * filter_width + kw];
      }
    }
  }
  input_grad_data[tid] += acc;
}

140
// Cuda kernel to compute the depthwise convolution backprop w.r.t. filter.
//
// Launch layout: one thread per output-gradient element. The flat thread id
// is (blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x, and
// ids >= nthreads do nothing.
// Each thread multiplies its output gradient by every in-bounds input pixel
// of its receptive field. It atomically adds the products into the
// filter_height x filter_width slice of filter_grad_data for its output
// channel. The kernel only adds, so filter_grad_data must be initialized by
// the caller. `num` is not read.
template <typename T>
__global__ void KernelDepthwiseConvFilterGrad(
    const int nthreads, const T* const output_grad_data,
    const T* const input_data, const int num, const int output_channels,
    const int output_height, const int output_width, const int input_channels,
    const int input_height, const int input_width, const int filter_multiplier,
    const int filter_height, const int filter_width, const int stride_height,
    const int stride_width, const int padding_height, const int padding_width,
    T* const filter_grad_data) {
  int index = (blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x;
  if (index < nthreads) {
    const int w_out = index % output_width;
    const int h_out = (index / output_width) % output_height;
    const int c_out = (index / output_width / output_height) % output_channels;
    const int batch = (index / output_width / output_height / output_channels);
    const int c_in = c_out / filter_multiplier;

    // The output gradient is the same for every tap of this thread.
    const T grad = output_grad_data[index];
    const T* input_plane =
        input_data + (batch * input_channels + c_in) * input_height * input_width;
    T* filter_grad_slice =
        filter_grad_data + c_out * filter_height * filter_width;

    const int h_in_start = -padding_height + h_out * stride_height;
    const int w_in_start = -padding_width + w_out * stride_width;
    // Inclusive last row/column of the receptive field (same convention as
    // KernelDepthwiseConv), so windows that end exactly on the last input
    // row/column still take the fast path.
    const int h_in_end = h_in_start + filter_height - 1;
    const int w_in_end = w_in_start + filter_width - 1;

    if ((h_in_start >= 0) && (h_in_end < input_height) && (w_in_start >= 0) &&
        (w_in_end < input_width)) {
      // Receptive field entirely inside the input: no per-tap bounds checks.
      for (int kh = 0; kh < filter_height; ++kh) {
        const T* input_row =
            input_plane + (h_in_start + kh) * input_width + w_in_start;
        T* grad_row = filter_grad_slice + kh * filter_width;
        for (int kw = 0; kw < filter_width; ++kw) {
          paddle::platform::CudaAtomicAdd(grad_row + kw, grad * input_row[kw]);
        }
      }
    } else {
      // Border window: skip taps that land in the padding.
      for (int kh = 0; kh < filter_height; ++kh) {
        const int h_in = h_in_start + kh;
        if (h_in < 0 || h_in >= input_height) continue;
        T* grad_row = filter_grad_slice + kh * filter_width;
        for (int kw = 0; kw < filter_width; ++kw) {
          const int w_in = w_in_start + kw;
          if (w_in < 0 || w_in >= input_width) continue;
          paddle::platform::CudaAtomicAdd(
              grad_row + kw, grad * input_plane[h_in * input_width + w_in]);
        }
      }
    }
  }
}

/*
 * All tensors are in NCHW format.
 * Ksize, strides and paddings each have two elements, which give the height
 * and width components, respectively.
 */
X
xzl 已提交
205
template <class T>
class DepthwiseConvFunctor<platform::CUDADeviceContext, T> {
 public:
  // Launches the depthwise convolution forward kernel on `context`'s stream.
  //
  // `input` and `output` are NCHW. `output` must already carry its final dims,
  // because they are read here rather than computed. The kernel size comes
  // from filter.dims()[2..3], and strides/paddings are {height, width}. The
  // channel multiplier passed to the kernel is output_channels /
  // input_channels.
  // NOTE(review): the kernel indexes the filter as
  // c_out * kh * kw + row * kw + col, which suggests a (C_out, 1, kh, kw)
  // layout. Confirm against the op definition.
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& input,
                  const framework::Tensor& filter, std::vector<int>& strides,
                  std::vector<int>& paddings, framework::Tensor* output) {
    const int batch_size = input.dims()[0];
    const int input_channels = input.dims()[1];
    const int input_height = input.dims()[2];
    const int input_width = input.dims()[3];
    const int output_channels = output->dims()[1];
    const int output_height = output->dims()[2];
    const int output_width = output->dims()[3];
    const int ksize_height = filter.dims()[2];
    const int ksize_width = filter.dims()[3];
    const int stride_height = strides[0];
    const int stride_width = strides[1];
    const int padding_height = paddings[0];
    const int padding_width = paddings[1];

    const T* input_data = input.data<T>();
    const T* filter_data = filter.data<T>();
    T* output_data = output->mutable_data<T>(context.GetPlace());

    // One thread per output element.
    const int nthreads =
        batch_size * output_channels * output_height * output_width;
    // Nothing to compute for an empty output. A zero-block grid would be an
    // invalid launch configuration.
    if (nthreads <= 0) return;

    const int kThreadsPerBlock = 1024;
    const int blocks = (nthreads + kThreadsPerBlock - 1) / kThreadsPerBlock;
    dim3 threads(kThreadsPerBlock, 1);
    dim3 grid(blocks, 1);

    KernelDepthwiseConv<T><<<grid, threads, 0, context.stream()>>>(
        nthreads, input_data, filter_data, batch_size, output_channels,
        output_height, output_width, input_channels, input_height, input_width,
        output_channels / input_channels, ksize_height, ksize_width,
        stride_height, stride_width, padding_height, padding_width,
        output_data);
  }
};

template <typename T>
class DepthwiseConvInputGradFunctor<platform::CUDADeviceContext, T> {
 public:
  // Launches the depthwise convolution input-gradient kernel on `context`'s
  // stream.
  //
  // `input` is used only for its NCHW dims. `output_grad` supplies the output
  // dims and `filter` the kernel size (dims()[2..3]). strides/paddings are
  // {height, width}. KernelDepthwiseConvInputGrad ADDS into `input_grad`
  // (`+=`), and this functor does not clear the buffer, so the caller must
  // initialize `input_grad` beforehand (presumably to zero — confirm at the
  // call site).
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& input,
                  const framework::Tensor& filter,
                  const framework::Tensor& output_grad,
                  std::vector<int>& strides, std::vector<int>& paddings,
                  framework::Tensor* input_grad) {
    const int batch_size = input.dims()[0];
    const int input_channels = input.dims()[1];
    const int input_height = input.dims()[2];
    const int input_width = input.dims()[3];
    const int output_channels = output_grad.dims()[1];
    const int output_height = output_grad.dims()[2];
    const int output_width = output_grad.dims()[3];
    const int ksize_height = filter.dims()[2];
    const int ksize_width = filter.dims()[3];
    const int stride_height = strides[0];
    const int stride_width = strides[1];
    const int padding_height = paddings[0];
    const int padding_width = paddings[1];

    const T* filter_data = filter.data<T>();
    const T* output_grad_data = output_grad.data<T>();
    T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace());

    // One thread per input-gradient element.
    const int nthreads =
        batch_size * input_channels * input_height * input_width;
    // Nothing to compute for an empty input. A zero-block grid would be an
    // invalid launch configuration.
    if (nthreads <= 0) return;

    const int kThreadsPerBlock = 1024;
    const int blocks = (nthreads + kThreadsPerBlock - 1) / kThreadsPerBlock;
    dim3 threads(kThreadsPerBlock, 1);
    dim3 grid(blocks, 1);

    KernelDepthwiseConvInputGrad<T><<<grid, threads, 0, context.stream()>>>(
        nthreads, output_grad_data, filter_data, batch_size, output_channels,
        output_height, output_width, input_channels, input_height, input_width,
        output_channels / input_channels, ksize_height, ksize_width,
        stride_height, stride_width, padding_height, padding_width,
        input_grad_data);
  }
};

template <typename T>
class DepthwiseConvFilterGradFunctor<platform::CUDADeviceContext, T> {
 public:
  // Launches the depthwise convolution filter-gradient kernel on `context`'s
  // stream.
  //
  // `input` and `output_grad` are NCHW. The kernel size comes from
  // filter_grad->dims()[2..3], and strides/paddings are {height, width}.
  // KernelDepthwiseConvFilterGrad accumulates into `filter_grad` with atomic
  // adds, and this functor does not clear the buffer. The caller must
  // therefore initialize `filter_grad` beforehand (presumably to zero —
  // confirm at the call site).
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& input,
                  const framework::Tensor& output_grad,
                  std::vector<int>& strides, std::vector<int>& paddings,
                  framework::Tensor* filter_grad) {
    const int batch_size = input.dims()[0];
    const int input_channels = input.dims()[1];
    const int input_height = input.dims()[2];
    const int input_width = input.dims()[3];
    const int output_channels = output_grad.dims()[1];
    const int output_height = output_grad.dims()[2];
    const int output_width = output_grad.dims()[3];
    const int ksize_height = filter_grad->dims()[2];
    const int ksize_width = filter_grad->dims()[3];
    const int stride_height = strides[0];
    const int stride_width = strides[1];
    const int padding_height = paddings[0];
    const int padding_width = paddings[1];

    const T* input_data = input.data<T>();
    const T* output_grad_data = output_grad.data<T>();
    T* filter_grad_data = filter_grad->mutable_data<T>(context.GetPlace());

    // One thread per output-gradient element.
    const int nthreads =
        batch_size * output_channels * output_height * output_width;
    // No output gradient means no contribution to the filter gradient. A
    // zero-block grid would also be an invalid launch configuration.
    if (nthreads <= 0) return;

    const int kThreadsPerBlock = 1024;
    const int blocks = (nthreads + kThreadsPerBlock - 1) / kThreadsPerBlock;
    dim3 threads(kThreadsPerBlock, 1);
    dim3 grid(blocks, 1);

    KernelDepthwiseConvFilterGrad<T><<<grid, threads, 0, context.stream()>>>(
        nthreads, output_grad_data, input_data, batch_size, output_channels,
        output_height, output_width, input_channels, input_height, input_width,
        output_channels / input_channels, ksize_height, ksize_width,
        stride_height, stride_width, padding_height, padding_width,
        filter_grad_data);
  }
};

326 327
// Explicit instantiations of the CUDA functors, grouped by element type.

// float
template class DepthwiseConvFunctor<platform::CUDADeviceContext, float>;
template class DepthwiseConvInputGradFunctor<platform::CUDADeviceContext,
                                             float>;
template class DepthwiseConvFilterGradFunctor<platform::CUDADeviceContext,
                                              float>;

// double
template class DepthwiseConvFunctor<platform::CUDADeviceContext, double>;
template class DepthwiseConvInputGradFunctor<platform::CUDADeviceContext,
                                             double>;
template class DepthwiseConvFilterGradFunctor<platform::CUDADeviceContext,
                                              double>;
Z
zlx 已提交
338 339 340 341

}  // namespace math
}  // namespace operators
}  // namespace paddle