/* Copyright (c) 2016 paddlepaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/math/depthwise_conv.h"
#include "paddle/platform/cuda_helper.h"

namespace paddle {
namespace operators {
namespace math {

// A CUDA kernel to compute the depthwise convolution forward pass
// in NCHW format.
//
// Launch layout: one thread per output element. The flat thread id is
// (blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x, and
// threads whose id is >= nthreads exit without doing any work.
//
// Output channel c_out reads input channel c_out / filter_multiplier and its
// own filter_height x filter_width filter slice. Taps that land in the
// padding region contribute nothing to the sum.
template <typename T>
__global__ void KernelDepthwiseConv(
    const int nthreads, const T* const input_data, const T* const filter_data,
    const int batch_size, const int output_channels, const int output_height,
    const int output_width, const int input_channels, const int input_height,
    const int input_width, const int filter_multiplier, const int filter_height,
    const int filter_width, const int stride_height, const int stride_width,
    const int padding_height, const int padding_width, T* const output_data) {
  const int tid =
      (blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x;
  if (tid >= nthreads) return;

  // Peel the NCHW coordinates of this output element off the flat id,
  // innermost dimension first.
  int rest = tid;
  const int w_out = rest % output_width;
  rest /= output_width;
  const int h_out = rest % output_height;
  rest /= output_height;
  const int c_out = rest % output_channels;
  const int batch = rest / output_channels;

  const int c_in = c_out / filter_multiplier;

  // Top-left corner of the receptive field, in unpadded input coordinates
  // (may be negative when the window starts inside the padding).
  const int h_top = h_out * stride_height - padding_height;
  const int w_left = w_out * stride_width - padding_width;

  // The single input plane this output channel reads, and its filter slice.
  const T* const plane =
      input_data + (batch * input_channels + c_in) * input_height * input_width;
  const T* const weights = filter_data + c_out * filter_height * filter_width;

  const bool interior = h_top >= 0 && w_left >= 0 &&
                        h_top + filter_height <= input_height &&
                        w_left + filter_width <= input_width;

  T value = 0;
  if (interior) {
    // Fast path: the whole window lies inside the image, no per-tap checks.
    for (int kh = 0; kh < filter_height; ++kh) {
      const T* const row = plane + (h_top + kh) * input_width + w_left;
      for (int kw = 0; kw < filter_width; ++kw) {
        value += weights[kh * filter_width + kw] * row[kw];
      }
    }
  } else {
    // Border path: skip every tap that falls outside the image.
    for (int kh = 0; kh < filter_height; ++kh) {
      const int h_in = h_top + kh;
      if (h_in < 0 || h_in >= input_height) continue;
      for (int kw = 0; kw < filter_width; ++kw) {
        const int w_in = w_left + kw;
        if (w_in < 0 || w_in >= input_width) continue;
        value += weights[kh * filter_width + kw] *
                 plane[h_in * input_width + w_in];
      }
    }
  }
  output_data[tid] = value;
}

// CUDA kernel to compute the depthwise convolution backprop w.r.t. input.
//
// One thread per input-gradient element in NCHW order. The flat thread id is
// (blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x, and ids
// >= nthreads do nothing. Input channel c_in feeds the filter_multiplier
// output channels starting at c_in * filter_multiplier. For each of them the
// thread gathers every output position whose filter window covers
// (h_in, w_in) and accumulates output_grad * filter.
//
// Every input-gradient element is produced by exactly one thread, so the
// result is stored directly instead of being added to whatever the buffer
// held before; input_grad_data does not need to be zeroed before the launch.
template <typename T>
__global__ void KernelDepthwiseConvInputGrad(
    const int nthreads, const T* const output_grad_data,
    const T* const filter_data, const int batch_size, const int output_channels,
    const int output_height, const int output_width, const int input_channels,
    const int input_height, const int input_width, const int filter_multiplier,
    const int filter_height, const int filter_width, const int stride_height,
    const int stride_width, const int padding_height, const int padding_width,
    T* const input_grad_data) {
  int index = (blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x;
  if (index < nthreads) {
    const int batch = index / input_channels / input_height / input_width;
    const int c_in = (index / input_height / input_width) % input_channels;
    const int h_in = (index / input_width) % input_height;
    const int w_in = index % input_width;

    const int c_out_start = c_in * filter_multiplier;

    // Output rows whose window covers h_in. The lower bound is the rounded-up
    // division (h_in + pad - filter_height + 1) / stride; whenever that would
    // be negative the clamp to 0 takes over. The upper bound is the
    // rounded-down division (h_in + pad) / stride, clamped to the output.
    int h_out_start =
        (h_in - filter_height + padding_height + stride_height) / stride_height;
    h_out_start = 0 > h_out_start ? 0 : h_out_start;

    int h_out_end = (h_in + padding_height) / stride_height;
    h_out_end = output_height - 1 < h_out_end ? output_height - 1 : h_out_end;

    // Same range computation along the width axis.
    int w_out_start =
        (w_in - filter_width + padding_width + stride_width) / stride_width;
    w_out_start = 0 > w_out_start ? 0 : w_out_start;

    int w_out_end = (w_in + padding_width) / stride_width;
    w_out_end = output_width - 1 < w_out_end ? output_width - 1 : w_out_end;

    T value = 0;

    for (int c_out = c_out_start; c_out < c_out_start + filter_multiplier;
         c_out++) {
      for (int h_out = h_out_start; h_out <= h_out_end; ++h_out) {
        // Filter row that maps output row h_out onto input row h_in.
        const int filter_h = h_in + padding_height - h_out * stride_height;
        for (int w_out = w_out_start; w_out <= w_out_end; ++w_out) {
          const int filter_w = w_in + padding_width - w_out * stride_width;
          const int filter_offset = c_out * filter_height * filter_width +
                                    filter_h * filter_width + filter_w;
          const int output_grad_offset =
              ((batch * output_channels + c_out) * output_height + h_out) *
                  output_width +
              w_out;
          value +=
              output_grad_data[output_grad_offset] * filter_data[filter_offset];
        }
      }
    }
    // Plain store: this thread is the only writer of input_grad_data[index].
    input_grad_data[index] = value;
  }
}

// CUDA kernel to compute the depthwise convolution backprop w.r.t. filter.
//
// One thread per output-gradient element in NCHW order. The flat thread id is
// (blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x, and ids
// >= nthreads do nothing. Each thread multiplies its output gradient by the
// input values under its filter window. It then adds the products into the
// filter_height x filter_width filter-gradient slice of its output channel
// with paddle::platform::CudaAtomicAdd.
//
// The kernel only accumulates, so filter_grad_data must already hold the
// starting value (normally zero) before the launch. All threads of one output
// channel add into the same filter slice, so those atomics contend.
template <typename T>
__global__ void KernelDepthwiseConvFilterGrad(
    const int nthreads, const T* const output_grad_data,
    const T* const input_data, const int batch_size, const int output_channels,
    const int output_height, const int output_width, const int input_channels,
    const int input_height, const int input_width, const int filter_multiplier,
    const int filter_height, const int filter_width, const int stride_height,
    const int stride_width, const int padding_height, const int padding_width,
    T* const filter_grad_data) {
  int index = (blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x;
  if (index < nthreads) {
    const int w_out = index % output_width;
    const int h_out = (index / output_width) % output_height;
    const int c_out = (index / output_width / output_height) % output_channels;
    const int batch = (index / output_width / output_height / output_channels);
    const int c_in = c_out / filter_multiplier;
    const int h_in_start = -padding_height + h_out * stride_height;
    const int w_in_start = -padding_width + w_out * stride_width;
    // Inclusive last row/column of the window (same convention as the forward
    // kernel), so a window ending exactly on the image border still takes the
    // unchecked fast path below.
    const int h_in_end = h_in_start + filter_height - 1;
    const int w_in_end = w_in_start + filter_width - 1;
    const int in_offset =
        (batch * input_channels + c_in) * input_height * input_width;

    T* addr_offset = filter_grad_data + c_out * filter_height * filter_width;
    // The same output gradient scales every tap of this window; load it once.
    const T grad = output_grad_data[index];

    if ((h_in_start >= 0) && (h_in_end < input_height) && (w_in_start >= 0) &&
        (w_in_end < input_width)) {
      // Fast path: the whole window lies inside the image.
      for (int kw = 0; kw < filter_width; kw++) {
        for (int kh = 0; kh < filter_height; kh++) {
          const int h_in = h_in_start + kh;
          const int w_in = w_in_start + kw;
          const int offset = in_offset + h_in * input_width + w_in;
          const T diff_temp = grad * input_data[offset];
          T* addr = addr_offset + kh * filter_width + kw;
          paddle::platform::CudaAtomicAdd(addr, diff_temp);
        }
      }
    } else {
      // Border path: taps in the padding region contribute nothing.
      for (int kw = 0; kw < filter_width; kw++) {
        for (int kh = 0; kh < filter_height; kh++) {
          const int h_in = h_in_start + kh;
          const int w_in = w_in_start + kw;
          if ((h_in >= 0) && (h_in < input_height) && (w_in >= 0) &&
              (w_in < input_width)) {
            const int offset = in_offset + h_in * input_width + w_in;
            const T diff_temp = grad * input_data[offset];
            T* addr = addr_offset + kh * filter_width + kw;
            paddle::platform::CudaAtomicAdd(addr, diff_temp);
          }
        }
      }
    }
  }
}

/*
 * All tensors are in NCHW format.
 * strides and paddings each hold two elements, {height, width}. The filter
 * size is read from the filter tensor's dims()[2] (height) and dims()[3]
 * (width).
 */
template <typename T>
class DepthwiseConvFunctor<platform::CUDADeviceContext, T> {
 public:
  // Computes the depthwise convolution of `input` with `filter` into
  // `output` on the context's stream. output's dims must already be set; its
  // data is allocated here. The filter multiplier passed to the kernel is
  // output_channels / input_channels (NOTE(review): divisibility is not
  // checked here).
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& input,
                  const framework::Tensor& filter,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings, framework::Tensor* output) {
    const int batch_size = input.dims()[0];
    const int input_channels = input.dims()[1];
    const int input_height = input.dims()[2];
    const int input_width = input.dims()[3];
    const int output_channels = output->dims()[1];
    const int output_height = output->dims()[2];
    const int output_width = output->dims()[3];
    const int ksize_height = filter.dims()[2];
    const int ksize_width = filter.dims()[3];
    const int stride_height = strides[0];
    const int stride_width = strides[1];
    const int padding_height = paddings[0];
    const int padding_width = paddings[1];

    const T* input_data = input.data<T>();
    const T* filter_data = filter.data<T>();
    T* output_data = output->mutable_data<T>(context.GetPlace());

    int nthreads = batch_size * output_channels * output_height * output_width;
    // A launch with a zero-sized grid is an invalid configuration; an empty
    // output has nothing to compute once it has been allocated.
    if (nthreads <= 0) return;

    const int kThreadsPerBlock = 1024;
    const int blocks = (nthreads + kThreadsPerBlock - 1) / kThreadsPerBlock;
    dim3 threads(kThreadsPerBlock, 1);
    dim3 grid(blocks, 1);

    KernelDepthwiseConv<T><<<grid, threads, 0, context.stream()>>>(
        nthreads, input_data, filter_data, batch_size, output_channels,
        output_height, output_width, input_channels, input_height, input_width,
        output_channels / input_channels, ksize_height, ksize_width,
        stride_height, stride_width, padding_height, padding_width,
        output_data);
  }
};

template <typename T>
class DepthwiseConvInputGradFunctor<platform::CUDADeviceContext, T> {
 public:
  // Computes the gradient of the depthwise convolution w.r.t. its input and
  // writes it into `input_grad` on the context's stream. `input` is used only
  // for its dims; input_grad's dims must already be set, and its data is
  // allocated here. strides/paddings are {height, width}; the filter size
  // comes from filter.dims()[2..3].
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& input,
                  const framework::Tensor& filter,
                  const framework::Tensor& output_grad,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings,
                  framework::Tensor* input_grad) {
    const int batch_size = input.dims()[0];
    const int input_channels = input.dims()[1];
    const int input_height = input.dims()[2];
    const int input_width = input.dims()[3];
    const int output_channels = output_grad.dims()[1];
    const int output_height = output_grad.dims()[2];
    const int output_width = output_grad.dims()[3];
    const int ksize_height = filter.dims()[2];
    const int ksize_width = filter.dims()[3];
    const int stride_height = strides[0];
    const int stride_width = strides[1];
    const int padding_height = paddings[0];
    const int padding_width = paddings[1];

    const T* filter_data = filter.data<T>();
    const T* output_grad_data = output_grad.data<T>();
    T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace());

    int nthreads = batch_size * input_channels * input_height * input_width;
    // A launch with a zero-sized grid is an invalid configuration; an empty
    // input gradient has nothing to compute once it has been allocated.
    if (nthreads <= 0) return;

    const int kThreadsPerBlock = 1024;
    const int blocks = (nthreads + kThreadsPerBlock - 1) / kThreadsPerBlock;
    dim3 threads(kThreadsPerBlock, 1);
    dim3 grid(blocks, 1);

    KernelDepthwiseConvInputGrad<T><<<grid, threads, 0, context.stream()>>>(
        nthreads, output_grad_data, filter_data, batch_size, output_channels,
        output_height, output_width, input_channels, input_height, input_width,
        output_channels / input_channels, ksize_height, ksize_width,
        stride_height, stride_width, padding_height, padding_width,
        input_grad_data);
  }
};

template <typename T>
class DepthwiseConvFilterGradFunctor<platform::CUDADeviceContext, T> {
 public:
  // Computes the gradient of the depthwise convolution w.r.t. its filter on
  // the context's stream. The kernel adds into filter_grad with atomics and
  // never clears it, so the caller is expected to zero filter_grad first
  // (NOTE(review): confirm every caller does). strides/paddings are
  // {height, width}; the filter size comes from filter_grad->dims()[2..3].
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& input,
                  const framework::Tensor& output_grad,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings,
                  framework::Tensor* filter_grad) {
    const int batch_size = input.dims()[0];
    const int input_channels = input.dims()[1];
    const int input_height = input.dims()[2];
    const int input_width = input.dims()[3];
    const int output_channels = output_grad.dims()[1];
    const int output_height = output_grad.dims()[2];
    const int output_width = output_grad.dims()[3];
    const int ksize_height = filter_grad->dims()[2];
    const int ksize_width = filter_grad->dims()[3];
    const int stride_height = strides[0];
    const int stride_width = strides[1];
    const int padding_height = paddings[0];
    const int padding_width = paddings[1];

    const T* input_data = input.data<T>();
    const T* output_grad_data = output_grad.data<T>();
    T* filter_grad_data = filter_grad->mutable_data<T>(context.GetPlace());

    int nthreads = batch_size * output_channels * output_height * output_width;
    // A launch with a zero-sized grid is an invalid configuration. With no
    // output-gradient elements there is nothing to add to filter_grad.
    if (nthreads <= 0) return;

    const int kThreadsPerBlock = 1024;
    const int blocks = (nthreads + kThreadsPerBlock - 1) / kThreadsPerBlock;
    dim3 threads(kThreadsPerBlock, 1);
    dim3 grid(blocks, 1);

    KernelDepthwiseConvFilterGrad<T><<<grid, threads, 0, context.stream()>>>(
        nthreads, output_grad_data, input_data, batch_size, output_channels,
        output_height, output_width, input_channels, input_height, input_width,
        output_channels / input_channels, ksize_height, ksize_width,
        stride_height, stride_width, padding_height, padding_width,
        filter_grad_data);
  }
};

// Explicit instantiations of the CUDA functors, grouped by element type.

// float
template class DepthwiseConvFunctor<platform::CUDADeviceContext, float>;
template class DepthwiseConvInputGradFunctor<platform::CUDADeviceContext,
                                             float>;
template class DepthwiseConvFilterGradFunctor<platform::CUDADeviceContext,
                                              float>;

// double
template class DepthwiseConvFunctor<platform::CUDADeviceContext, double>;
template class DepthwiseConvInputGradFunctor<platform::CUDADeviceContext,
                                             double>;
template class DepthwiseConvFilterGradFunctor<platform::CUDADeviceContext,
                                              double>;

}  // namespace math
}  // namespace operators
}  // namespace paddle