vol2col.cu 10.9 KB
Newer Older
C
chengduoZH 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/math/vol2col.h"
#include "paddle/platform/cuda_helper.h"

namespace paddle {
namespace operators {
namespace math {

// Rearranges a 3-D volume into column (im2col-style) layout.
//
// data_vol: [channels, depth, height, width]
// data_col: [channels, filter_depth, filter_height, filter_width,
//            output_depth, output_height, output_width]
//
// num_kernels must equal channels * output_depth * output_height *
// output_width. Each loop iteration handles one (channel, d_out, h_out, w_out)
// output position and writes filter_depth * filter_height * filter_width
// values into data_col, one per filter tap. Taps that land in the zero
// padding (outside [0, depth) x [0, height) x [0, width)) are written as 0.
// The index advances by blockDim.x * gridDim.x, so any 1-D launch shape
// covers all num_kernels positions.
template <class T>
__global__ void vol2col(int num_kernels, const T* data_vol, int depth,
                        int height, int width, int dilation_d, int dilation_h,
                        int dilation_w, int filter_depth, int filter_height,
                        int filter_width, int stride_depth, int stride_height,
                        int stride_width, int padding_depth, int padding_height,
                        int padding_width, int output_depth, int output_height,
                        int output_width, T* data_col) {
  // Distance in data_col between consecutive filter taps of one position.
  const int col_plane = output_depth * output_height * output_width;
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < num_kernels;
       index += blockDim.x * gridDim.x) {
    int w_out = index % output_width;
    int h_out = (index / output_width) % output_height;
    int d_out = (index / output_width / output_height) % output_depth;
    int channel_in = index / output_width / output_height / output_depth;
    int channel_out = channel_in * filter_depth * filter_height * filter_width;
    int w_in = w_out * stride_width - padding_width;
    int h_in = h_out * stride_height - padding_height;
    int d_in = d_out * stride_depth - padding_depth;

    // Per-iteration pointers: advancing the kernel parameters themselves
    // would leak these offsets into the next grid-stride iteration.
    T* col_ptr =
        data_col +
        ((channel_out * output_depth + d_out) * output_height + h_out) *
            output_width +
        w_out;
    // Start of this channel in the volume. Voxel offsets are added only for
    // in-bounds taps, so no out-of-range pointer is formed when the window
    // begins inside the padding.
    const T* vol_channel = data_vol + channel_in * depth * height * width;

    for (int k = 0; k < filter_depth; ++k) {
      int d = d_in + k * dilation_d;
      for (int i = 0; i < filter_height; ++i) {
        int h = h_in + i * dilation_h;
        for (int j = 0; j < filter_width; ++j) {
          int w = w_in + j * dilation_w;
          bool inside = d >= 0 && d < depth && h >= 0 && h < height &&
                        w >= 0 && w < width;
          *col_ptr = inside ? vol_channel[(d * height + h) * width + w]
                            : static_cast<T>(0);
          col_ptr += col_plane;
        }
      }
    }
  }
}

/*
 * vol = [input_channels, input_depth, input_height, input_width]
 * col =
 *   [input_channels, filter_depth, filter_height, filter_width,
 *                    output_depth, output_height, output_width]
 *
 * Expands `vol` into the 7-D column tensor `col` by launching the vol2col
 * kernel on the CUDA stream of `context`. `dilations`, `strides` and
 * `paddings` are indexed {depth, height, width} and must hold three entries
 * each. Filter and output extents are read from `col`'s shape and checked
 * against the convolution output-size formula. No host synchronization is
 * performed here.
 */
template <class T>
class Vol2ColFunctor<platform::GPUPlace, T> {
 public:
  void operator()(const platform::DeviceContext& context,
                  const framework::Tensor& vol,
                  const std::vector<int>& dilations,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings,
                  framework::Tensor* col) const {
    PADDLE_ENFORCE(vol.dims().size() == 4);
    PADDLE_ENFORCE(col->dims().size() == 7);
    // The per-axis parameters are read at indices 0..2 below.
    PADDLE_ENFORCE(dilations.size() == 3, "dilations must have 3 elements.");
    PADDLE_ENFORCE(strides.size() == 3, "strides must have 3 elements.");
    PADDLE_ENFORCE(paddings.size() == 3, "paddings must have 3 elements.");

    int input_channels = vol.dims()[0];
    int input_depth = vol.dims()[1];
    int input_height = vol.dims()[2];
    int input_width = vol.dims()[3];
    int filter_depth = col->dims()[1];
    int filter_height = col->dims()[2];
    int filter_width = col->dims()[3];
    int output_depth = col->dims()[4];
    int output_height = col->dims()[5];
    int output_width = col->dims()[6];

    // The kernel writes one group of filter planes per volume channel, so
    // col's leading dimension must match vol's channel count; otherwise it
    // would write past the end of col.
    PADDLE_ENFORCE_EQ(col->dims()[0], input_channels,
                      "input_channels and col channels are "
                      "Mismatching.");
    PADDLE_ENFORCE_EQ((input_depth + 2 * paddings[0] -
                       ((dilations[0] * (filter_depth - 1) + 1))) /
                              strides[0] +
                          1,
                      output_depth,
                      "input_depth and output_depth are "
                      "Mismatching.");
    PADDLE_ENFORCE_EQ((input_height + 2 * paddings[1] -
                       ((dilations[1] * (filter_height - 1) + 1))) /
                              strides[1] +
                          1,
                      output_height,
                      "input_height and output_height are "
                      "Mismatching.");
    PADDLE_ENFORCE_EQ((input_width + 2 * paddings[2] -
                       ((dilations[2] * (filter_width - 1) + 1))) /
                              strides[2] +
                          1,
                      output_width,
                      "input_width and output_width are "
                      "Mismatching.");

    int num_outputs =
        input_channels * output_depth * output_height * output_width;
    // Launching zero blocks is an invalid CUDA configuration; nothing to do.
    if (num_outputs == 0) return;

    const int threads = 1024;
    const int blocks = (num_outputs + threads - 1) / threads;
    // DeviceContext -> CUDADeviceContext is a base-to-derived downcast, which
    // is what static_cast (not reinterpret_cast) is for.
    vol2col<T><<<blocks, threads, 0,
                 static_cast<const platform::CUDADeviceContext&>(context)
                     .stream()>>>(
        num_outputs, vol.data<T>(), input_depth, input_height, input_width,
        dilations[0], dilations[1], dilations[2], filter_depth, filter_height,
        filter_width, strides[0], strides[1], strides[2], paddings[0],
        paddings[1], paddings[2], output_depth, output_height, output_width,
        col->data<T>());
  }
};

// Folds column (im2col-style) data back into a 3-D volume.
//
// data_col: [channels, filter_depth, filter_height, filter_width,
//            output_depth, output_height, output_width]
// data_vol: [channels, depth, height, width]
//
// num_kernels is the number of voxels (channels * depth * height * width).
// Every loop iteration owns a single voxel. It sums every data_col entry
// that sampled that voxel and stores the sum into data_vol, overwriting the
// previous value rather than adding to it.
template <class T>
__global__ void col2vol(int num_kernels, const T* data_col, int depth,
                        int height, int width, int dilation_d, int dilation_h,
                        int dilation_w, int filter_depth, int filter_height,
                        int filter_width, int stride_depth, int stride_height,
                        int stride_width, int padding_depth, int padding_height,
                        int padding_width, int output_depth, int output_height,
                        int output_width, T* data_vol) {
  // Span, in voxels, covered by one dilated filter along each axis.
  const int extent_d = (filter_depth - 1) * dilation_d + 1;
  const int extent_h = (filter_height - 1) * dilation_h + 1;
  const int extent_w = (filter_width - 1) * dilation_w + 1;

  for (int vox = blockIdx.x * blockDim.x + threadIdx.x; vox < num_kernels;
       vox += blockDim.x * gridDim.x) {
    // Decode the flat voxel index, then shift into padded coordinates.
    int rest = vox;
    const int pw = rest % width + padding_width;
    rest /= width;
    const int ph = rest % height + padding_height;
    rest /= height;
    const int pd = rest % depth + padding_depth;
    const int channel = rest / depth;

    // Range of output positions whose dilated window can reach this voxel.
    const int wc_begin = pw < extent_w ? 0 : (pw - extent_w) / stride_width + 1;
    const int wc_stop = min(pw / stride_width + 1, output_width);
    const int hc_begin =
        ph < extent_h ? 0 : (ph - extent_h) / stride_height + 1;
    const int hc_stop = min(ph / stride_height + 1, output_height);
    const int dc_begin = pd < extent_d ? 0 : (pd - extent_d) / stride_depth + 1;
    const int dc_stop = min(pd / stride_depth + 1, output_depth);

    T acc = 0;
    for (int dc = dc_begin; dc < dc_stop; ++dc) {
      const int kd = pd - dc * stride_depth;
      for (int hc = hc_begin; hc < hc_stop; ++hc) {
        const int kh = ph - hc * stride_height;
        for (int wc = wc_begin; wc < wc_stop; ++wc) {
          const int kw = pw - wc * stride_width;
          // Only offsets that fall exactly on a dilated filter tap count.
          if (kd % dilation_d != 0 || kh % dilation_h != 0 ||
              kw % dilation_w != 0) {
            continue;
          }
          const int tap =
              ((channel * filter_depth + kd / dilation_d) * filter_height +
               kh / dilation_h) *
                  filter_width +
              kw / dilation_w;
          acc += data_col[((tap * output_depth + dc) * output_height + hc) *
                              output_width +
                          wc];
        }
      }
    }
    data_vol[vox] = acc;
  }
}

/*
 * vol = [input_channels, input_depth, input_height, input_width]
 * col =
 *   [input_channels, filter_depth, filter_height, filter_width,
 *                    output_depth, output_height, output_width]
 *
 * Folds the 7-D column tensor `col` back into `vol` by launching the col2vol
 * kernel on the CUDA stream of `context`. `dilations`, `strides` and
 * `paddings` are indexed {depth, height, width} and must hold three entries
 * each. Filter and output extents are read from `col`'s shape and checked
 * against the convolution output-size formula. No host synchronization is
 * performed here.
 */
template <class T>
class Col2VolFunctor<platform::GPUPlace, T> {
 public:
  void operator()(const platform::DeviceContext& context,
                  const framework::Tensor& col,
                  const std::vector<int>& dilations,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings,
                  framework::Tensor* vol) const {
    PADDLE_ENFORCE(vol->dims().size() == 4);
    PADDLE_ENFORCE(col.dims().size() == 7);
    // The per-axis parameters are read at indices 0..2 below.
    PADDLE_ENFORCE(dilations.size() == 3, "dilations must have 3 elements.");
    PADDLE_ENFORCE(strides.size() == 3, "strides must have 3 elements.");
    PADDLE_ENFORCE(paddings.size() == 3, "paddings must have 3 elements.");

    int input_channels = vol->dims()[0];
    int input_depth = vol->dims()[1];
    int input_height = vol->dims()[2];
    int input_width = vol->dims()[3];
    int filter_depth = col.dims()[1];
    int filter_height = col.dims()[2];
    int filter_width = col.dims()[3];
    int output_depth = col.dims()[4];
    int output_height = col.dims()[5];
    int output_width = col.dims()[6];

    // The kernel reads one group of filter planes per volume channel, so
    // col's leading dimension must match vol's channel count; otherwise it
    // would read past the end of col.
    PADDLE_ENFORCE_EQ(col.dims()[0], input_channels,
                      "input_channels and col channels are "
                      "Mismatching.");
    PADDLE_ENFORCE_EQ((input_depth + 2 * paddings[0] -
                       ((dilations[0] * (filter_depth - 1) + 1))) /
                              strides[0] +
                          1,
                      output_depth,
                      "input_depth and output_depth are "
                      "Mismatching.");
    PADDLE_ENFORCE_EQ((input_height + 2 * paddings[1] -
                       ((dilations[1] * (filter_height - 1) + 1))) /
                              strides[1] +
                          1,
                      output_height,
                      "input_height and output_height are "
                      "Mismatching.");
    PADDLE_ENFORCE_EQ((input_width + 2 * paddings[2] -
                       ((dilations[2] * (filter_width - 1) + 1))) /
                              strides[2] +
                          1,
                      output_width,
                      "input_width and output_width are "
                      "Mismatching.");

    int num_kernels = input_channels * input_depth * input_height * input_width;
    // Launching zero blocks is an invalid CUDA configuration; nothing to do.
    if (num_kernels == 0) return;

    const int threads = 1024;
    const int blocks = (num_kernels + threads - 1) / threads;

    // DeviceContext -> CUDADeviceContext is a base-to-derived downcast, which
    // is what static_cast (not reinterpret_cast) is for.
    col2vol<T><<<blocks, threads, 0,
                 static_cast<const platform::CUDADeviceContext&>(context)
                     .stream()>>>(
        num_kernels, col.data<T>(), input_depth, input_height, input_width,
        dilations[0], dilations[1], dilations[2], filter_depth, filter_height,
        filter_width, strides[0], strides[1], strides[2], paddings[0],
        paddings[1], paddings[2], output_depth, output_height, output_width,
        vol->data<T>());
  }
};

// Explicit GPU instantiations of both functors, grouped by element type.
template class Vol2ColFunctor<platform::GPUPlace, float>;
template class Col2VolFunctor<platform::GPUPlace, float>;
template class Vol2ColFunctor<platform::GPUPlace, double>;
template class Col2VolFunctor<platform::GPUPlace, double>;

}  // namespace math
}  // namespace operators
}  // namespace paddle