/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <algorithm>
#include <vector>

#include "paddle/fluid/operators/math/vol2col.h"
#include "paddle/fluid/platform/cuda_primitives.h"

namespace paddle {
namespace operators {
namespace math {

// Expands a 3-D volume into its column ("col") buffer.
//
// num_kernels == input_channels * output_depth * output_height * output_width.
// Each loop index addresses one (channel, d_out, h_out, w_out) tuple and
// writes the filter_depth * filter_height * filter_width col entries it owns.
// The index advances by blockDim.x * gridDim.x, so the kernel is correct for
// any 1-D launch configuration.
//
// data_vol: [C, D, H, W] unless data_layout == DataLayout::kNHWC, in which
//           case [D, H, W, C].
// data_col: [C, filter_d, filter_h, filter_w, out_d, out_h, out_w].
// Taps that land in the padding region are written as 0.
template <class T>
__global__ void vol2col(int num_kernels, const T* data_vol, int depth,
                        int height, int width, int dilation_d, int dilation_h,
                        int dilation_w, int filter_depth, int filter_height,
                        int filter_width, int stride_depth, int stride_height,
                        int stride_width, int padding_depth, int padding_height,
                        int padding_width, int output_depth, int output_height,
                        int output_width, T* data_col,
                        const DataLayout data_layout) {
  const int input_channels =
      num_kernels / output_depth / output_height / output_width;
  // Distance in data_col between two consecutive filter taps of one output.
  const int col_step = output_depth * output_height * output_width;
  const bool channels_last = (data_layout == DataLayout::kNHWC);

  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < num_kernels;
       index += blockDim.x * gridDim.x) {
    int w_out = index % output_width;
    int h_out = (index / output_width) % output_height;
    int d_out = (index / output_width / output_height) % output_depth;
    int channel_in = index / output_width / output_height / output_depth;
    int channel_out = channel_in * filter_depth * filter_height * filter_width;
    int w_in = w_out * stride_width - padding_width;
    int h_in = h_out * stride_height - padding_height;
    int d_in = d_out * stride_depth - padding_depth;

    // Use a per-iteration pointer. Advancing data_col itself would carry this
    // iteration's offset into the next grid-stride iteration and make it
    // write to the wrong col entries.
    T* col_ptr =
        data_col +
        ((channel_out * output_depth + d_out) * output_height + h_out) *
            output_width +
        w_out;
    for (int k = 0; k < filter_depth; ++k) {
      for (int i = 0; i < filter_height; ++i) {
        for (int j = 0; j < filter_width; ++j) {
          int d = d_in + k * dilation_d;
          int h = h_in + i * dilation_h;
          int w = w_in + j * dilation_w;
          bool in_bounds = d >= 0 && d < depth && h >= 0 && h < height &&
                           w >= 0 && w < width;
          T value = static_cast<T>(0);
          if (in_bounds) {
            int vol_idx =
                channels_last
                    ? ((d * height + h) * width + w) * input_channels +
                          channel_in
                    : ((channel_in * depth + d) * height + h) * width + w;
            value = data_vol[vol_idx];
          }
          *col_ptr = value;
          col_ptr += col_step;
        }
      }
    }
  }
}

/*
 * vol = [input_channels, input_depth, input_height, input_width] for
 *       channels_first (any data_layout other than DataLayout::kNHWC)
 * vol = [input_depth, input_height, input_width, input_channels] for
 *       channels_last (data_layout == DataLayout::kNHWC)
 * col = [input_channels, filter_depth, filter_height, filter_width,
 *        output_depth, output_height, output_width]
 */
template <class T>
class Vol2ColFunctor<platform::CUDADeviceContext, T> {
 public:
  // Expands `vol` into `col` by launching the vol2col kernel on
  // context.stream(). The launch is not synchronized here.
  //
  // dilations / strides: 3 values, {depth, height, width}.
  // paddings: either 3 values {d, h, w}, applied to both sides of each axis,
  //           or 6 values {d_forth, d_back, h_up, h_down, w_left, w_right}.
  // Filter and output sizes are read from col's dims. Each output size is
  // checked against (in + pad_lo + pad_hi - dilated_filter) / stride + 1.
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& vol,
                  const std::vector<int>& dilations,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings, framework::Tensor* col,
                  const DataLayout data_layout) const {
    PADDLE_ENFORCE_EQ(vol.dims().size(), 4,
                      "The dimension of vol should be 4.");
    PADDLE_ENFORCE_EQ(col->dims().size(), 7,
                      "The dimension of col should be 7.");
    // The indexing below reads dilations[2], strides[2] and paddings[2] (or
    // paddings[5]); reject shorter vectors instead of reading out of bounds.
    PADDLE_ENFORCE(dilations.size() == 3, "The size of dilations should be 3.");
    PADDLE_ENFORCE(strides.size() == 3, "The size of strides should be 3.");
    PADDLE_ENFORCE(paddings.size() == 3 || paddings.size() == 6,
                   "The size of paddings should be 3 or 6.");

    const bool channels_last = (data_layout == DataLayout::kNHWC);
    int input_channels = channels_last ? vol.dims()[3] : vol.dims()[0];
    int input_depth = channels_last ? vol.dims()[0] : vol.dims()[1];
    int input_height = channels_last ? vol.dims()[1] : vol.dims()[2];
    int input_width = channels_last ? vol.dims()[2] : vol.dims()[3];

    int filter_depth = col->dims()[1];
    int filter_height = col->dims()[2];
    int filter_width = col->dims()[3];
    int output_depth = col->dims()[4];
    int output_height = col->dims()[5];
    int output_width = col->dims()[6];

    const bool paddings_size_is_6 = (paddings.size() == 6);
    int pad_d_forth = paddings[0];
    int pad_d_back = paddings_size_is_6 ? paddings[1] : paddings[0];
    int pad_h_up = paddings_size_is_6 ? paddings[2] : paddings[1];
    int pad_h_down = paddings_size_is_6 ? paddings[3] : paddings[1];
    int pad_w_left = paddings_size_is_6 ? paddings[4] : paddings[2];
    int pad_w_right = paddings_size_is_6 ? paddings[5] : paddings[2];

    PADDLE_ENFORCE_EQ((input_depth + pad_d_forth + pad_d_back -
                       ((dilations[0] * (filter_depth - 1) + 1))) /
                              strides[0] +
                          1,
                      output_depth,
                      "input_depth and output_depth are "
                      "mismatching.");
    PADDLE_ENFORCE_EQ((input_height + pad_h_up + pad_h_down -
                       ((dilations[1] * (filter_height - 1) + 1))) /
                              strides[1] +
                          1,
                      output_height,
                      "input_height and output_height are "
                      "mismatching.");
    PADDLE_ENFORCE_EQ((input_width + pad_w_left + pad_w_right -
                       ((dilations[2] * (filter_width - 1) + 1))) /
                              strides[2] +
                          1,
                      output_width,
                      "input_width and output_width are "
                      "mismatching.");

    int num_outputs =
        input_channels * output_depth * output_height * output_width;
    // A launch with zero blocks is an invalid configuration, and there is
    // nothing to write anyway.
    if (num_outputs == 0) return;

    const int threads = 1024;
    const int blocks = (num_outputs + threads - 1) / threads;
    // Only the leading-side paddings are needed by the kernel; the trailing
    // ones only participate in the output-size checks above.
    vol2col<T><<<blocks, threads, 0, context.stream()>>>(
        num_outputs, vol.data<T>(), input_depth, input_height, input_width,
        dilations[0], dilations[1], dilations[2], filter_depth, filter_height,
        filter_width, strides[0], strides[1], strides[2], pad_d_forth, pad_h_up,
        pad_w_left, output_depth, output_height, output_width, col->data<T>(),
        data_layout);
  }
};

// Folds a column buffer back into a 3-D volume; this is the gather-form
// counterpart of vol2col.
//
// num_kernels == input_channels * depth * height * width. Each loop index owns
// exactly one volume element and sums every data_col entry whose receptive
// field covers it, so no two threads write the same location. The sum is
// assigned (not added) into data_vol. The index advances by
// blockDim.x * gridDim.x, so any 1-D launch configuration is valid.
//
// data_vol: [C, D, H, W] unless data_layout == DataLayout::kNHWC, in which
//           case [D, H, W, C].
// data_col: [C, filter_d, filter_h, filter_w, out_d, out_h, out_w].
template <class T>
__global__ void col2vol(int num_kernels, const T* data_col, int depth,
                        int height, int width, int dilation_d, int dilation_h,
                        int dilation_w, int filter_depth, int filter_height,
                        int filter_width, int stride_depth, int stride_height,
                        int stride_width, int padding_depth, int padding_height,
                        int padding_width, int output_detph, int output_height,
                        int output_width, T* data_vol,
                        const DataLayout data_layout) {
  // Extent spanned by one dilated filter along each axis.
  const int span_d = dilation_d * (filter_depth - 1) + 1;
  const int span_h = dilation_h * (filter_height - 1) + 1;
  const int span_w = dilation_w * (filter_width - 1) + 1;

  const int channels = num_kernels / depth / height / width;
  const bool channels_last = (data_layout == DataLayout::kNHWC);

  for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < num_kernels;
       idx += blockDim.x * gridDim.x) {
    // Decompose the flat index into padded (pd, ph, pw) coordinates and the
    // channel, following the memory layout of data_vol.
    int pw, ph, pd, ch;
    if (channels_last) {
      pw = (idx / channels) % width + padding_width;
      ph = (idx / channels / width) % height + padding_height;
      pd = idx / channels / width / height + padding_depth;
      ch = idx % channels;
    } else {
      pw = idx % width + padding_width;
      ph = (idx / width) % height + padding_height;
      pd = (idx / width / height) % depth + padding_depth;
      ch = idx / width / height / depth;
    }

    // Half-open range of output positions whose window can reach here.
    const int d_begin = (pd < span_d) ? 0 : (pd - span_d) / stride_depth + 1;
    const int d_end = min(pd / stride_depth + 1, output_detph);
    const int h_begin = (ph < span_h) ? 0 : (ph - span_h) / stride_height + 1;
    const int h_end = min(ph / stride_height + 1, output_height);
    const int w_begin = (pw < span_w) ? 0 : (pw - span_w) / stride_width + 1;
    const int w_end = min(pw / stride_width + 1, output_width);

    T acc = 0;
    for (int od = d_begin; od < d_end; ++od) {
      for (int oh = h_begin; oh < h_end; ++oh) {
        for (int ow = w_begin; ow < w_end; ++ow) {
          int kd = pd - od * stride_depth;
          int kh = ph - oh * stride_height;
          int kw = pw - ow * stride_width;
          // Offsets that fall between dilated filter taps contribute nothing.
          if (kd % dilation_d != 0 || kh % dilation_h != 0 ||
              kw % dilation_w != 0) {
            continue;
          }
          kd /= dilation_d;
          kh /= dilation_h;
          kw /= dilation_w;

          const int tap =
              ((ch * filter_depth + kd) * filter_height + kh) * filter_width +
              kw;
          const int col_idx =
              ((tap * output_detph + od) * output_height + oh) * output_width +
              ow;
          acc += data_col[col_idx];
        }
      }
    }
    data_vol[idx] = acc;
  }
}

/*
 * vol = [input_channels, input_depth, input_height, input_width] for
 *       channels_first (any data_layout other than DataLayout::kNHWC)
 * vol = [input_depth, input_height, input_width, input_channels] for
 *       channels_last (data_layout == DataLayout::kNHWC)
 * col = [input_channels, filter_depth, filter_height, filter_width,
 *        output_depth, output_height, output_width]
 */
template <class T>
class Col2VolFunctor<platform::CUDADeviceContext, T> {
 public:
  // Folds `col` back into `vol` by launching the col2vol kernel on
  // context.stream(). The launch is not synchronized here.
  //
  // dilations / strides: 3 values, {depth, height, width}.
  // paddings: either 3 values {d, h, w}, applied to both sides of each axis,
  //           or 6 values {d_forth, d_back, h_up, h_down, w_left, w_right}.
  // Filter and output sizes are read from col's dims. Each output size is
  // checked against (in + pad_lo + pad_hi - dilated_filter) / stride + 1.
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& col,
                  const std::vector<int>& dilations,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings, framework::Tensor* vol,
                  const DataLayout data_layout) const {
    PADDLE_ENFORCE_EQ(vol->dims().size(), 4,
                      "The dimension of vol should be 4.");
    PADDLE_ENFORCE_EQ(col.dims().size(), 7,
                      "The dimension of col should be 7.");
    // The indexing below reads dilations[2], strides[2] and paddings[2] (or
    // paddings[5]); reject shorter vectors instead of reading out of bounds.
    PADDLE_ENFORCE(dilations.size() == 3, "The size of dilations should be 3.");
    PADDLE_ENFORCE(strides.size() == 3, "The size of strides should be 3.");
    PADDLE_ENFORCE(paddings.size() == 3 || paddings.size() == 6,
                   "The size of paddings should be 3 or 6.");

    const bool channels_last = (data_layout == DataLayout::kNHWC);
    int input_channels = channels_last ? vol->dims()[3] : vol->dims()[0];
    int input_depth = channels_last ? vol->dims()[0] : vol->dims()[1];
    int input_height = channels_last ? vol->dims()[1] : vol->dims()[2];
    int input_width = channels_last ? vol->dims()[2] : vol->dims()[3];

    int filter_depth = col.dims()[1];
    int filter_height = col.dims()[2];
    int filter_width = col.dims()[3];
    int output_depth = col.dims()[4];
    int output_height = col.dims()[5];
    int output_width = col.dims()[6];

    const bool paddings_size_is_6 = (paddings.size() == 6);
    int pad_d_forth = paddings[0];
    int pad_d_back = paddings_size_is_6 ? paddings[1] : paddings[0];
    int pad_h_up = paddings_size_is_6 ? paddings[2] : paddings[1];
    int pad_h_down = paddings_size_is_6 ? paddings[3] : paddings[1];
    int pad_w_left = paddings_size_is_6 ? paddings[4] : paddings[2];
    int pad_w_right = paddings_size_is_6 ? paddings[5] : paddings[2];

    PADDLE_ENFORCE_EQ((input_depth + pad_d_forth + pad_d_back -
                       ((dilations[0] * (filter_depth - 1) + 1))) /
                              strides[0] +
                          1,
                      output_depth,
                      "input_depth and output_depth are "
                      "mismatching.");
    PADDLE_ENFORCE_EQ((input_height + pad_h_up + pad_h_down -
                       ((dilations[1] * (filter_height - 1) + 1))) /
                              strides[1] +
                          1,
                      output_height,
                      "input_height and output_height are "
                      "mismatching.");
    PADDLE_ENFORCE_EQ((input_width + pad_w_left + pad_w_right -
                       ((dilations[2] * (filter_width - 1) + 1))) /
                              strides[2] +
                          1,
                      output_width,
                      "input_width and output_width are "
                      "mismatching.");

    int num_kernels = input_channels * input_depth * input_height * input_width;
    // A launch with zero blocks is an invalid configuration, and there is
    // nothing to write anyway.
    if (num_kernels == 0) return;

    const int threads = 1024;
    const int blocks = (num_kernels + threads - 1) / threads;
    // Only the leading-side paddings are needed by the kernel; the trailing
    // ones only participate in the output-size checks above.
    col2vol<T><<<blocks, threads, 0, context.stream()>>>(
        num_kernels, col.data<T>(), input_depth, input_height, input_width,
        dilations[0], dilations[1], dilations[2], filter_depth, filter_height,
        filter_width, strides[0], strides[1], strides[2], pad_d_forth, pad_h_up,
        pad_w_left, output_depth, output_height, output_width, vol->data<T>(),
        data_layout);
  }
};

// Explicit instantiations of the CUDA functors, grouped by element type:
// float, then double.
template class Vol2ColFunctor<platform::CUDADeviceContext, float>;
template class Col2VolFunctor<platform::CUDADeviceContext, float>;
template class Vol2ColFunctor<platform::CUDADeviceContext, double>;
template class Col2VolFunctor<platform::CUDADeviceContext, double>;


}  // namespace math
}  // namespace operators
}  // namespace paddle