/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <algorithm>
#include <vector>

#include "paddle/fluid/operators/math/vol2col.h"
#include "paddle/fluid/platform/cuda_primitives.h"

namespace paddle {
namespace operators {
namespace math {

/*
 * vol2col kernel: expand a 3-D input volume into column format.
 *
 * data_vol: [input_channels, depth, height, width]
 * data_col: [input_channels, filter_depth, filter_height, filter_width,
 *            output_depth, output_height, output_width]
 *
 * One logical thread handles one (channel, d_out, h_out, w_out) output
 * position and copies its whole filter window; a grid-stride loop makes the
 * kernel correct for any launch configuration. Taps that fall into the
 * padding region are written as 0.
 */
template <class T>
__global__ void vol2col(int num_kernels, const T* data_vol, int depth,
                        int height, int width, int dilation_d, int dilation_h,
                        int dilation_w, int filter_depth, int filter_height,
                        int filter_width, int stride_depth, int stride_height,
                        int stride_width, int padding_depth, int padding_height,
                        int padding_width, int output_depth, int output_height,
                        int output_width, T* data_col) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < num_kernels;
       index += blockDim.x * gridDim.x) {
    // Decompose the flat index into (channel_in, d_out, h_out, w_out).
    int w_out = index % output_width;
    int h_out = (index / output_width) % output_height;
    int d_out = (index / output_width / output_height) % output_depth;
    int channel_in = index / output_width / output_height / output_depth;
    int channel_out = channel_in * filter_depth * filter_height * filter_width;
    // Front-top-left corner of the receptive field in the input volume;
    // may be negative because of padding.
    int w_in = w_out * stride_width - padding_width;
    int h_in = h_out * stride_height - padding_height;
    int d_in = d_out * stride_depth - padding_depth;

    data_col += ((channel_out * output_depth + d_out) * output_height + h_out) *
                    output_width +
                w_out;
    data_vol += ((channel_in * depth + d_in) * height + h_in) * width + w_in;
    for (int k = 0; k < filter_depth; ++k) {
      for (int i = 0; i < filter_height; ++i) {
        for (int j = 0; j < filter_width; ++j) {
          // Absolute coordinates of the dilated filter tap.
          int d = d_in + k * dilation_d;
          int h = h_in + i * dilation_h;
          int w = w_in + j * dilation_w;
          // Offset of the tap relative to (d_in, h_in, w_in), i.e. relative
          // to the already-advanced data_vol pointer.
          int vol_idx = (k * dilation_d * height + i * dilation_h) * width +
                        j * dilation_w;
          // Zero-fill taps that land in the padding region.
          *data_col = (d >= 0 && d < depth && h >= 0 && h < height && w >= 0 &&
                       w < width)
                          ? data_vol[vol_idx]
                          : 0;
          // Consecutive filter taps are output_depth*output_height*
          // output_width elements apart in the column buffer.
          data_col += output_depth * output_height * output_width;
        }
      }
    }
  }
}

/*
 * im = [input_channels, input_depth, input_height, input_width]
 * col =
 *   [input_channels, filter_depth, filter_height, filter_width,
 *                    output_depth, output_height, output_width]
 */
template <class T>
class Vol2ColFunctor<platform::CUDADeviceContext, T> {
 public:
  // Expand `vol` ([input_channels, input_depth, input_height, input_width])
  // into `col` ([input_channels, filter_depth, filter_height, filter_width,
  // output_depth, output_height, output_width]) on the given CUDA stream.
  //
  // `paddings` is either {pad_d, pad_h, pad_w} (symmetric padding) or
  // {pad_d_forth, pad_d_back, pad_h_up, pad_h_down, pad_w_left, pad_w_right}.
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& vol,
                  const std::vector<int>& dilations,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings,
                  framework::Tensor* col) const {
    PADDLE_ENFORCE_EQ(vol.dims().size(), 4);
    PADDLE_ENFORCE_EQ(col->dims().size(), 7);

    int input_channels = vol.dims()[0];
    int input_depth = vol.dims()[1];
    int input_height = vol.dims()[2];
    int input_width = vol.dims()[3];
    int filter_depth = col->dims()[1];
    int filter_height = col->dims()[2];
    int filter_width = col->dims()[3];
    int output_depth = col->dims()[4];
    int output_height = col->dims()[5];
    int output_width = col->dims()[6];

    // Support both 3-element (symmetric) and 6-element (per-side) paddings.
    bool paddings_size_is_6 = (paddings.size() == 6);
    int pad_d_forth = paddings[0];  // same index in both layouts
    int pad_d_back = paddings_size_is_6 ? paddings[1] : paddings[0];
    int pad_h_up = paddings_size_is_6 ? paddings[2] : paddings[1];
    int pad_h_down = paddings_size_is_6 ? paddings[3] : paddings[1];
    int pad_w_left = paddings_size_is_6 ? paddings[4] : paddings[2];
    int pad_w_right = paddings_size_is_6 ? paddings[5] : paddings[2];

    // Sanity-check that the output extents match the standard convolution
    // output-size formula for each spatial dimension.
    PADDLE_ENFORCE_EQ((input_depth + pad_d_forth + pad_d_back -
                       ((dilations[0] * (filter_depth - 1) + 1))) /
                              strides[0] +
                          1,
                      output_depth,
                      "input_depth and output_depth are "
                      "mismatching.");
    PADDLE_ENFORCE_EQ((input_height + pad_h_up + pad_h_down -
                       ((dilations[1] * (filter_height - 1) + 1))) /
                              strides[1] +
                          1,
                      output_height,
                      "input_height and output_height are "
                      "mismatching.");
    PADDLE_ENFORCE_EQ((input_width + pad_w_left + pad_w_right -
                       ((dilations[2] * (filter_width - 1) + 1))) /
                              strides[2] +
                          1,
                      output_width,
                      "input_width and output_width are "
                      "mismatching.");

    // One logical thread per output element.
    int num_outputs =
        input_channels * output_depth * output_height * output_width;

    const int threads = 1024;
    const int blocks = (num_outputs + threads - 1) / threads;  // ceil-div
    // The kernel only needs the leading (forth/up/left) pad of each dim.
    vol2col<T><<<blocks, threads, 0, context.stream()>>>(
        num_outputs, vol.data<T>(), input_depth, input_height, input_width,
        dilations[0], dilations[1], dilations[2], filter_depth, filter_height,
        filter_width, strides[0], strides[1], strides[2], pad_d_forth, pad_h_up,
        pad_w_left, output_depth, output_height, output_width, col->data<T>());
  }
};

/*
 * col2vol kernel: fold a column buffer back into a 3-D volume, accumulating
 * every column entry that was sampled from a given input position (the
 * adjoint of vol2col, used for the convolution backward pass).
 *
 * data_col: [input_channels, filter_depth, filter_height, filter_width,
 *            output_depth, output_height, output_width]
 * data_vol: [input_channels, depth, height, width]
 *
 * One logical thread accumulates one input element; a grid-stride loop makes
 * the kernel correct for any launch configuration.
 */
template <class T>
__global__ void col2vol(int num_kernels, const T* data_col, int depth,
                        int height, int width, int dilation_d, int dilation_h,
                        int dilation_w, int filter_depth, int filter_height,
                        int filter_width, int stride_depth, int stride_height,
                        int stride_width, int padding_depth, int padding_height,
                        int padding_width, int output_depth, int output_height,
                        int output_width, T* data_vol) {
  // Effective (dilated) filter extents.
  const int d_filter_depth = dilation_d * (filter_depth - 1) + 1;
  const int d_filter_height = dilation_h * (filter_height - 1) + 1;
  const int d_filter_width = dilation_w * (filter_width - 1) + 1;

  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < num_kernels;
       index += blockDim.x * gridDim.x) {
    T src_val = 0;
    // Padded coordinates of this input element.
    int w = index % width + padding_width;
    int h = (index / width) % height + padding_height;
    int d = (index / width / height) % depth + padding_depth;
    int c = index / width / height / depth;

    // Range of output positions whose (dilated) filter window covers
    // (d, h, w).
    int w_col_start =
        (w < d_filter_width) ? 0 : (w - d_filter_width) / stride_width + 1;
    int w_col_end = min(w / stride_width + 1, output_width);
    int h_col_start =
        (h < d_filter_height) ? 0 : (h - d_filter_height) / stride_height + 1;
    int h_col_end = min(h / stride_height + 1, output_height);
    int d_col_start =
        (d < d_filter_depth) ? 0 : (d - d_filter_depth) / stride_depth + 1;
    int d_col_end = min(d / stride_depth + 1, output_depth);

    for (int d_col = d_col_start; d_col < d_col_end; ++d_col) {
      for (int h_col = h_col_start; h_col < h_col_end; ++h_col) {
        for (int w_col = w_col_start; w_col < w_col_end; ++w_col) {
          // Offset of (d, h, w) inside this output position's window.
          int d_off = (d - d_col * stride_depth);
          int h_off = (h - h_col * stride_height);
          int w_off = (w - w_col * stride_width);
          // With dilation, only offsets on the dilation grid correspond to
          // an actual filter tap.
          if (d_off % dilation_d == 0 && h_off % dilation_h == 0 &&
              w_off % dilation_w == 0) {
            d_off /= dilation_d;
            h_off /= dilation_h;
            w_off /= dilation_w;

            int data_col_index =
                (((((c * filter_depth + d_off) * filter_height + h_off) *
                       filter_width +
                   w_off)));
            data_col_index =
                ((data_col_index * output_depth + d_col) * output_height +
                 h_col) *
                    output_width +
                w_col;
            src_val += data_col[data_col_index];
          }
        }
      }
    }
    data_vol[index] = src_val;
  }
}

/*
 * im = [input_channels, input_depth, input_height, input_width]
 * col =
 *   [input_channels, filter_depth, filter_height, filter_width,
 *                    output_depth, output_height, output_width]
 */
template <class T>
class Col2VolFunctor<platform::CUDADeviceContext, T> {
 public:
  // Fold `col` ([input_channels, filter_depth, filter_height, filter_width,
  // output_depth, output_height, output_width]) back into `vol`
  // ([input_channels, input_depth, input_height, input_width]) on the given
  // CUDA stream, accumulating overlapping contributions.
  //
  // `paddings` is either {pad_d, pad_h, pad_w} (symmetric padding) or
  // {pad_d_forth, pad_d_back, pad_h_up, pad_h_down, pad_w_left, pad_w_right}.
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& col,
                  const std::vector<int>& dilations,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings,
                  framework::Tensor* vol) const {
    PADDLE_ENFORCE_EQ(vol->dims().size(), 4);
    PADDLE_ENFORCE_EQ(col.dims().size(), 7);

    int input_channels = vol->dims()[0];
    int input_depth = vol->dims()[1];
    int input_height = vol->dims()[2];
    int input_width = vol->dims()[3];
    int filter_depth = col.dims()[1];
    int filter_height = col.dims()[2];
    int filter_width = col.dims()[3];
    int output_depth = col.dims()[4];
    int output_height = col.dims()[5];
    int output_width = col.dims()[6];

    // Support both 3-element (symmetric) and 6-element (per-side) paddings.
    bool paddings_size_is_6 = (paddings.size() == 6);
    int pad_d_forth = paddings[0];  // same index in both layouts
    int pad_d_back = paddings_size_is_6 ? paddings[1] : paddings[0];
    int pad_h_up = paddings_size_is_6 ? paddings[2] : paddings[1];
    int pad_h_down = paddings_size_is_6 ? paddings[3] : paddings[1];
    int pad_w_left = paddings_size_is_6 ? paddings[4] : paddings[2];
    int pad_w_right = paddings_size_is_6 ? paddings[5] : paddings[2];

    // Sanity-check that the output extents match the standard convolution
    // output-size formula for each spatial dimension.
    PADDLE_ENFORCE_EQ((input_depth + pad_d_forth + pad_d_back -
                       ((dilations[0] * (filter_depth - 1) + 1))) /
                              strides[0] +
                          1,
                      output_depth,
                      "input_depth and output_depth are "
                      "mismatching.");
    PADDLE_ENFORCE_EQ((input_height + pad_h_up + pad_h_down -
                       ((dilations[1] * (filter_height - 1) + 1))) /
                              strides[1] +
                          1,
                      output_height,
                      "input_height and output_height are "
                      "mismatching.");
    PADDLE_ENFORCE_EQ((input_width + pad_w_left + pad_w_right -
                       ((dilations[2] * (filter_width - 1) + 1))) /
                              strides[2] +
                          1,
                      output_width,
                      "input_width and output_width are "
                      "mismatching.");

    // One logical thread per input element.
    int num_kernels = input_channels * input_depth * input_height * input_width;

    const int threads = 1024;
    const int blocks = (num_kernels + threads - 1) / threads;  // ceil-div

    // The kernel only needs the leading (forth/up/left) pad of each dim.
    col2vol<T><<<blocks, threads, 0, context.stream()>>>(
        num_kernels, col.data<T>(), input_depth, input_height, input_width,
        dilations[0], dilations[1], dilations[2], filter_depth, filter_height,
        filter_width, strides[0], strides[1], strides[2], pad_d_forth, pad_h_up,
        pad_w_left, output_depth, output_height, output_width, vol->data<T>());
  }
};

// Explicit instantiations for the element types supported on CUDA.
template class Vol2ColFunctor<platform::CUDADeviceContext, float>;
template class Vol2ColFunctor<platform::CUDADeviceContext, double>;
template class Col2VolFunctor<platform::CUDADeviceContext, float>;
template class Col2VolFunctor<platform::CUDADeviceContext, double>;

}  // namespace math
}  // namespace operators
}  // namespace paddle