/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

A
Abhinav Arora 已提交
15 16
#include <algorithm>
#include <vector>
17

18
#include "paddle/phi/backends/gpu/gpu_context.h"
19
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
20
#include "paddle/phi/backends/gpu/gpu_primitives.h"
21
#include "paddle/phi/kernels/funcs/vol2col.h"
C
chengduoZH 已提交
22

23 24
namespace phi {
namespace funcs {
C
chengduoZH 已提交
25 26

/*
 * Unrolls a 3-D volume into a column buffer.
 *
 * Work is distributed with a grid-stride loop over num_kernels flat indices;
 * each index identifies one (channel_in, d_out, h_out, w_out) output position
 * and the owning thread writes every filter tap belonging to that position.
 *
 *   num_kernels  input_channels * output_detph * output_height * output_width
 *   data_vol     source volume, indexed as [C, D, H, W] unless data_layout is
 *                DataLayout::kNHWC, in which case it is indexed as [D, H, W, C]
 *   data_col     destination, indexed as
 *                [C, filter_depth, filter_height, filter_width,
 *                 output_detph, output_height, output_width]
 *   padding_*    leading padding of each axis; taps that fall outside the
 *                volume are written as 0
 */
template <class T>
__global__ void vol2col(int num_kernels,
                        const T* data_vol,
                        int depth,
                        int height,
                        int width,
                        int dilation_d,
                        int dilation_h,
                        int dilation_w,
                        int filter_depth,
                        int filter_height,
                        int filter_width,
                        int stride_depth,
                        int stride_height,
                        int stride_width,
                        int padding_depth,
                        int padding_height,
                        int padding_width,
                        int output_detph,
                        int output_height,
                        int output_width,
                        T* data_col,
                        const DataLayout data_layout) {
  const int input_channels =
      num_kernels / output_detph / output_height / output_width;
  // Element distance in data_col between two consecutive filter taps of the
  // same output position.
  const int col_tap_stride = output_detph * output_height * output_width;

  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < num_kernels;
       index += blockDim.x * gridDim.x) {
    const int w_out = index % output_width;
    const int h_out = (index / output_width) % output_height;
    const int d_out = (index / output_width / output_height) % output_detph;
    const int channel_in = index / output_width / output_height / output_detph;
    const int channel_out =
        channel_in * filter_depth * filter_height * filter_width;
    // Coordinates of the first filter tap; negative inside the leading pad.
    const int w_in = w_out * stride_width - padding_width;
    const int h_in = h_out * stride_height - padding_height;
    const int d_in = d_out * stride_depth - padding_depth;

    // Per-iteration write cursor. The data_col argument itself must not be
    // advanced: every pass of the grid-stride loop has to start from the
    // buffer base, otherwise a second pass would write at a shifted offset.
    T* col_ptr =
        data_col +
        ((channel_out * output_detph + d_out) * output_height + h_out) *
            output_width +
        w_out;

    for (int k = 0; k < filter_depth; ++k) {
      for (int i = 0; i < filter_height; ++i) {
        for (int j = 0; j < filter_width; ++j) {
          const int d = d_in + k * dilation_d;
          const int h = h_in + i * dilation_h;
          const int w = w_in + j * dilation_w;
          const bool inside = d >= 0 && d < depth && h >= 0 && h < height &&
                              w >= 0 && w < width;

          T value = static_cast<T>(0);
          if (inside) {
            // The source index is only formed for taps inside the volume.
            const int vol_idx =
                (data_layout != DataLayout::kNHWC)
                    ? ((channel_in * depth + d) * height + h) * width + w
                    : ((d * height + h) * width + w) * input_channels +
                          channel_in;
            value = data_vol[vol_idx];
          }
          *col_ptr = value;
          col_ptr += col_tap_stride;
        }
      }
    }
  }
}

/*
 * vol = [input_channels, input_depth, input_height, input_width] for
 * channels_first
 * vol = [input_depth, input_height, input_width, input_channels] for
 * channels_last
 * col =
 *   [input_channels, filter_depth, filter_height, filter_width,
 *                    output_depth, output_height, output_width]
 */
W
Wilber 已提交
100 101 102 103 104
// template <class DeviceContext, class T>
// class Vol2ColFunctor {
//  public:
/*
 * Validates the tensor shapes and launches the vol2col kernel on
 * context.stream().
 *
 *   vol        4-D source tensor; read as [C, D, H, W] unless data_layout is
 *              DataLayout::kNHWC, in which case it is read as [D, H, W, C]
 *   dilations  indexed at [0..2] as depth, height, width
 *   strides    indexed at [0..2] as depth, height, width
 *   paddings   with 6 entries: {d_front, d_back, h_up, h_down, w_left,
 *              w_right}; otherwise entries [0..2] are applied to both sides
 *              of the depth, height and width axes
 *   col        7-D destination tensor; dims [1..3] are taken as the filter
 *              extents and dims [4..6] as the output extents
 *
 * Raises InvalidArgument when the ranks are wrong or when the output extents
 * stored in col disagree with the ones derived from vol, paddings, dilations
 * and strides.
 */
template <class DeviceContext, class T>
void Vol2ColFunctor<DeviceContext, T>::operator()(
    const DeviceContext& context,
    const phi::DenseTensor& vol,
    const std::vector<int>& dilations,
    const std::vector<int>& strides,
    const std::vector<int>& paddings,
    phi::DenseTensor* col,
    const DataLayout data_layout) const {
  PADDLE_ENFORCE_EQ(vol.dims().size(),
                    4,
                    phi::errors::InvalidArgument(
                        "The dimension of  vol should be 4, but received %d.",
                        vol.dims().size()));
  PADDLE_ENFORCE_EQ(col->dims().size(),
                    7,
                    phi::errors::InvalidArgument(
                        "The dimension of col should be 7, but received %d.",
                        col->dims().size()));

  int input_channels =
      (data_layout != DataLayout::kNHWC ? vol.dims()[0] : vol.dims()[3]);
  int input_depth =
      (data_layout != DataLayout::kNHWC ? vol.dims()[1] : vol.dims()[0]);
  int input_height =
      (data_layout != DataLayout::kNHWC ? vol.dims()[2] : vol.dims()[1]);
  int input_width =
      (data_layout != DataLayout::kNHWC ? vol.dims()[3] : vol.dims()[2]);
  int filter_depth = col->dims()[1];
  int filter_height = col->dims()[2];
  int filter_width = col->dims()[3];
  int output_depth = col->dims()[4];
  int output_height = col->dims()[5];
  int output_width = col->dims()[6];

  bool paddings_size_is_6 = (paddings.size() == 6);
  // paddings[0] is the leading depth pad in both the 3- and 6-entry forms.
  int pad_d_forth = paddings[0];
  int pad_d_back = paddings_size_is_6 ? paddings[1] : paddings[0];
  int pad_h_up = paddings_size_is_6 ? paddings[2] : paddings[1];
  int pad_h_down = paddings_size_is_6 ? paddings[3] : paddings[1];
  int pad_w_left = paddings_size_is_6 ? paddings[4] : paddings[2];
  int pad_w_right = paddings_size_is_6 ? paddings[5] : paddings[2];

  // Recompute each output extent from the input extent and compare it with
  // the extent stored in col.
  auto input_depth_tmp = (input_depth + pad_d_forth + pad_d_back -
                          ((dilations[0] * (filter_depth - 1) + 1))) /
                             strides[0] +
                         1;
  PADDLE_ENFORCE_EQ(input_depth_tmp,
                    output_depth,
                    phi::errors::InvalidArgument(
                        "input_depth(%d) and output_depth(%d) are mismatching.",
                        input_depth_tmp,
                        output_depth));
  auto input_height_tmp = (input_height + pad_h_up + pad_h_down -
                           ((dilations[1] * (filter_height - 1) + 1))) /
                              strides[1] +
                          1;
  PADDLE_ENFORCE_EQ(
      input_height_tmp,
      output_height,
      phi::errors::InvalidArgument(
          "input_height(%d) and output_height(%d) are mismatching.",
          input_height_tmp,
          output_height));
  auto input_width_tmp = (input_width + pad_w_left + pad_w_right -
                          ((dilations[2] * (filter_width - 1) + 1))) /
                             strides[2] +
                         1;
  PADDLE_ENFORCE_EQ(input_width_tmp,
                    output_width,
                    phi::errors::InvalidArgument(
                        "input_width(%d) and output_width(%d) are mismatching.",
                        input_width_tmp,
                        output_width));

  // One kernel work item per (channel, d_out, h_out, w_out) position.
  int num_outputs =
      input_channels * output_depth * output_height * output_width;

  // Nothing to unroll. Returning here also avoids a launch with a zero-sized
  // grid, which the CUDA runtime rejects as an invalid configuration.
  if (num_outputs == 0) {
    return;
  }

  int max_threads = 1024;
#ifdef WITH_NV_JETSON
  // Lets the backend adjust the per-block thread count for this device.
  phi::backends::gpu::ChangeThreadNum(context, &max_threads);
#endif

  const int threads = max_threads;
  const int blocks = (num_outputs + max_threads - 1) / max_threads;

  // Only the leading pad of each axis is forwarded to the kernel.
  vol2col<T><<<blocks, threads, 0, context.stream()>>>(num_outputs,
                                                       vol.data<T>(),
                                                       input_depth,
                                                       input_height,
                                                       input_width,
                                                       dilations[0],
                                                       dilations[1],
                                                       dilations[2],
                                                       filter_depth,
                                                       filter_height,
                                                       filter_width,
                                                       strides[0],
                                                       strides[1],
                                                       strides[2],
                                                       pad_d_forth,
                                                       pad_h_up,
                                                       pad_w_left,
                                                       output_depth,
                                                       output_height,
                                                       output_width,
                                                       col->data<T>(),
                                                       data_layout);
}
// };
C
chengduoZH 已提交
212 213

/*
 * Accumulates a column buffer back into a 3-D volume.
 *
 * Work is distributed with a grid-stride loop over num_kernels flat indices of
 * data_vol. For each volume element the owning thread sums every data_col
 * entry whose filter tap lands on that element and stores the total.
 *
 *   num_kernels  input_channels * depth * height * width
 *   data_col     source, indexed as
 *                [C, filter_depth, filter_height, filter_width,
 *                 output_detph, output_height, output_width]
 *   data_vol     destination, indexed as [C, D, H, W] unless data_layout is
 *                DataLayout::kNHWC, in which case it is indexed as [D, H, W, C]
 *   padding_*    leading padding of each axis
 */
template <class T>
__global__ void col2vol(int num_kernels,
                        const T* data_col,
                        int depth,
                        int height,
                        int width,
                        int dilation_d,
                        int dilation_h,
                        int dilation_w,
                        int filter_depth,
                        int filter_height,
                        int filter_width,
                        int stride_depth,
                        int stride_height,
                        int stride_width,
                        int padding_depth,
                        int padding_height,
                        int padding_width,
                        int output_detph,
                        int output_height,
                        int output_width,
                        T* data_vol,
                        const DataLayout data_layout) {
  // Extent covered by the filter along each axis once dilation is applied.
  const int span_d = dilation_d * (filter_depth - 1) + 1;
  const int span_h = dilation_h * (filter_height - 1) + 1;
  const int span_w = dilation_w * (filter_width - 1) + 1;

  int input_channels = num_kernels / depth / height / width;
  const bool channels_first = (data_layout != DataLayout::kNHWC);

  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < num_kernels;
       index += blockDim.x * gridDim.x) {
    // Decompose the flat volume index. d/h/w are shifted by the leading pad,
    // i.e. expressed in padded-volume coordinates.
    int c;
    int d;
    int h;
    int w;
    if (channels_first) {
      w = index % width + padding_width;
      h = (index / width) % height + padding_height;
      d = (index / width / height) % depth + padding_depth;
      c = index / width / height / depth;
    } else {
      w = (index / input_channels) % width + padding_width;
      h = (index / input_channels / width) % height + padding_height;
      d = index / input_channels / width / height + padding_depth;
      c = index % input_channels;
    }

    // Half-open ranges of output positions whose window can reach this
    // element.
    const int w_first = (w < span_w) ? 0 : (w - span_w) / stride_width + 1;
    const int w_last = min(w / stride_width + 1, output_width);
    const int h_first = (h < span_h) ? 0 : (h - span_h) / stride_height + 1;
    const int h_last = min(h / stride_height + 1, output_height);
    const int d_first = (d < span_d) ? 0 : (d - span_d) / stride_depth + 1;
    const int d_last = min(d / stride_depth + 1, output_detph);

    T acc = 0;
    for (int dc = d_first; dc < d_last; ++dc) {
      for (int hc = h_first; hc < h_last; ++hc) {
        for (int wc = w_first; wc < w_last; ++wc) {
          const int d_rem = d - dc * stride_depth;
          const int h_rem = h - hc * stride_height;
          const int w_rem = w - wc * stride_width;
          // Offsets that are not a multiple of the dilation do not coincide
          // with any filter tap.
          if (d_rem % dilation_d != 0 || h_rem % dilation_h != 0 ||
              w_rem % dilation_w != 0) {
            continue;
          }

          const int tap =
              ((c * filter_depth + d_rem / dilation_d) * filter_height +
               h_rem / dilation_h) *
                  filter_width +
              w_rem / dilation_w;
          const int col_index =
              ((tap * output_detph + dc) * output_height + hc) * output_width +
              wc;
          acc += data_col[col_index];
        }
      }
    }
    data_vol[index] = acc;
  }
}

/*
 * vol = [input_channels, input_depth, input_height, input_width] for
 * channels_first
 * vol = [input_depth, input_height, input_width, input_channels] for
 * channels_last
 * col =
 *   [input_channels, filter_depth, filter_height, filter_width,
 *                    output_depth, output_height, output_width]
 */
W
Wilber 已提交
306 307 308 309 310
// template <class DeviceContext, class T>
// class Col2VolFunctor<DeviceContext, T> {
//  public:
/*
 * Validates the tensor shapes and launches the col2vol kernel on
 * context.stream().
 *
 *   col        7-D source tensor; dims [1..3] are taken as the filter extents
 *              and dims [4..6] as the output extents
 *   dilations  indexed at [0..2] as depth, height, width
 *   strides    indexed at [0..2] as depth, height, width
 *   paddings   with 6 entries: {d_front, d_back, h_up, h_down, w_left,
 *              w_right}; otherwise entries [0..2] are applied to both sides
 *              of the depth, height and width axes
 *   vol        4-D destination tensor; read as [C, D, H, W] unless
 *              data_layout is DataLayout::kNHWC, in which case it is read as
 *              [D, H, W, C]
 *
 * Raises InvalidArgument when the ranks are wrong or when the output extents
 * stored in col disagree with the ones derived from vol, paddings, dilations
 * and strides.
 */
template <class DeviceContext, class T>
void Col2VolFunctor<DeviceContext, T>::operator()(
    const DeviceContext& context,
    const phi::DenseTensor& col,
    const std::vector<int>& dilations,
    const std::vector<int>& strides,
    const std::vector<int>& paddings,
    phi::DenseTensor* vol,
    const DataLayout data_layout) const {
  PADDLE_ENFORCE_EQ(vol->dims().size(),
                    4,
                    phi::errors::InvalidArgument(
                        "The dimension of vol  should be 4, but received %d.",
                        vol->dims().size()));
  PADDLE_ENFORCE_EQ(col.dims().size(),
                    7,
                    phi::errors::InvalidArgument(
                        "The dimension of col  should be 7, but received %d.",
                        col.dims().size()));

  int input_channels =
      (data_layout != DataLayout::kNHWC ? vol->dims()[0] : vol->dims()[3]);
  int input_depth =
      (data_layout != DataLayout::kNHWC ? vol->dims()[1] : vol->dims()[0]);
  int input_height =
      (data_layout != DataLayout::kNHWC ? vol->dims()[2] : vol->dims()[1]);
  int input_width =
      (data_layout != DataLayout::kNHWC ? vol->dims()[3] : vol->dims()[2]);
  int filter_depth = col.dims()[1];
  int filter_height = col.dims()[2];
  int filter_width = col.dims()[3];
  int output_depth = col.dims()[4];
  int output_height = col.dims()[5];
  int output_width = col.dims()[6];

  bool paddings_size_is_6 = (paddings.size() == 6);
  // paddings[0] is the leading depth pad in both the 3- and 6-entry forms.
  int pad_d_forth = paddings[0];
  int pad_d_back = paddings_size_is_6 ? paddings[1] : paddings[0];
  int pad_h_up = paddings_size_is_6 ? paddings[2] : paddings[1];
  int pad_h_down = paddings_size_is_6 ? paddings[3] : paddings[1];
  int pad_w_left = paddings_size_is_6 ? paddings[4] : paddings[2];
  int pad_w_right = paddings_size_is_6 ? paddings[5] : paddings[2];

  // Recompute each output extent from the input extent and compare it with
  // the extent stored in col.
  auto input_depth_tmp = (input_depth + pad_d_forth + pad_d_back -
                          ((dilations[0] * (filter_depth - 1) + 1))) /
                             strides[0] +
                         1;
  PADDLE_ENFORCE_EQ(input_depth_tmp,
                    output_depth,
                    phi::errors::InvalidArgument(
                        "input_depth(%d) and output_depth(%d) are mismatching.",
                        input_depth_tmp,
                        output_depth));
  auto input_height_tmp = (input_height + pad_h_up + pad_h_down -
                           ((dilations[1] * (filter_height - 1) + 1))) /
                              strides[1] +
                          1;
  PADDLE_ENFORCE_EQ(
      input_height_tmp,
      output_height,
      phi::errors::InvalidArgument(
          "input_height(%d) and output_height(%d) are mismatching.",
          input_height_tmp,
          output_height));
  auto input_width_tmp = (input_width + pad_w_left + pad_w_right -
                          ((dilations[2] * (filter_width - 1) + 1))) /
                             strides[2] +
                         1;
  PADDLE_ENFORCE_EQ(input_width_tmp,
                    output_width,
                    phi::errors::InvalidArgument(
                        "input_width(%d) and output_width(%d) are mismatching.",
                        input_width_tmp,
                        output_width));

  // One kernel work item per element of vol.
  int num_kernels = input_channels * input_depth * input_height * input_width;

  // Nothing to accumulate. Returning here also avoids a launch with a
  // zero-sized grid, which the CUDA runtime rejects as an invalid
  // configuration.
  if (num_kernels == 0) {
    return;
  }

  int max_threads = 1024;
#ifdef WITH_NV_JETSON
  // Lets the backend adjust the per-block thread count for this device.
  phi::backends::gpu::ChangeThreadNum(context, &max_threads);
#endif

  const int threads = max_threads;
  const int blocks = (num_kernels + max_threads - 1) / max_threads;

  // Only the leading pad of each axis is forwarded to the kernel.
  col2vol<T><<<blocks, threads, 0, context.stream()>>>(num_kernels,
                                                       col.data<T>(),
                                                       input_depth,
                                                       input_height,
                                                       input_width,
                                                       dilations[0],
                                                       dilations[1],
                                                       dilations[2],
                                                       filter_depth,
                                                       filter_height,
                                                       filter_width,
                                                       strides[0],
                                                       strides[1],
                                                       strides[2],
                                                       pad_d_forth,
                                                       pad_h_up,
                                                       pad_w_left,
                                                       output_depth,
                                                       output_height,
                                                       output_width,
                                                       vol->data<T>(),
                                                       data_layout);
}
// };
C
chengduoZH 已提交
418

419 420
// Explicit instantiations of Vol2ColFunctor for phi::GPUContext with float
// and double element types.
template class Vol2ColFunctor<phi::GPUContext, float>;
template class Vol2ColFunctor<phi::GPUContext, double>;
W
Wilber 已提交
421

422 423
// Explicit instantiations of Col2VolFunctor for phi::GPUContext with float
// and double element types.
template class Col2VolFunctor<phi::GPUContext, float>;
template class Col2VolFunctor<phi::GPUContext, double>;
C
chengduoZH 已提交
424

425 426
}  // namespace funcs
}  // namespace phi