/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/math/im2col.h"

#include "paddle/fluid/operators/math/im2col_cfo_cpu.h"

// Forward declaration only: within this file phi::CPUContext appears solely as
// a template argument of the explicit instantiations, and every functor below
// receives the device context by const reference without reading it.
namespace phi {
class CPUContext;
}  // namespace phi

namespace paddle {
namespace operators {
namespace math {

/*
H
hedaoyuan 已提交
28 29 30
 * im = [input_channels, input_height, input_width]
 * col =
 *   [input_channels, filter_height, filter_width, output_height, output_width]
H
hedaoyuan 已提交
31
 */
32
template <class T, typename DeviceContext>
33 34
class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
                    DeviceContext,
35
                    T> {
H
hedaoyuan 已提交
36
 public:
37 38
  void operator()(const DeviceContext& context,
                  const framework::Tensor& im,
39
                  const std::vector<int>& dilation,
C
chengduoZH 已提交
40
                  const std::vector<int>& stride,
41 42
                  const std::vector<int>& padding,
                  framework::Tensor* col,
43
                  const DataLayout data_layout) {
44 45
    PADDLE_ENFORCE_EQ(im.dims().size(),
                      3,
46 47 48 49
                      platform::errors::InvalidArgument(
                          "The dimension of tensor 'im' should be 3. But got "
                          "the dims of tensor 'im' is [%s].",
                          im.dims()));
50 51
    PADDLE_ENFORCE_EQ(col->dims().size(),
                      5,
52 53 54 55
                      platform::errors::InvalidArgument(
                          "The dimension of tensor 'col' should be 5. But got "
                          "the dims of tensor 'col' is [%s].",
                          col->dims()));
H
hedaoyuan 已提交
56

T
tensor-tang 已提交
57
    if (stride[0] == 1 && stride[1] == 1 && dilation[0] == 1 &&
58
        dilation[1] == 1) {
L
liym27 已提交
59 60
      if (padding[0] == 0 && padding[1] == 0 && padding[2] == 0 &&
          padding[3] == 0) {
61
        im2col_sh1sw1dh1dw1ph0pw0<T>(im, col, data_layout);
62
        return;
L
liym27 已提交
63 64
      } else if (padding[0] == 1 && padding[1] == 1 && padding[2] == 1 &&
                 padding[3] == 1) {
65
        im2col_sh1sw1dh1dw1ph1pw1<T>(im, col, data_layout);
66
        return;
H
hedaoyuan 已提交
67
      }
68
      // TODO(TJ): complete padding >=2
H
hedaoyuan 已提交
69
    }
70
    im2col_common<T>(im, dilation, stride, padding, col, data_layout);
H
hedaoyuan 已提交
71 72 73 74
  }
};

/*
H
hedaoyuan 已提交
75 76 77
 * im = [input_channels, input_height, input_width]
 * col =
 *   [input_channels, filter_height, filter_width, output_height, output_width]
H
hedaoyuan 已提交
78
 */
79
template <class T, typename DeviceContext>
80 81
class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
                    DeviceContext,
82
                    T> {
H
hedaoyuan 已提交
83
 public:
84 85
  void operator()(const DeviceContext& context,
                  const framework::Tensor& col,
C
chengduoZH 已提交
86 87
                  const std::vector<int>& dilation,
                  const std::vector<int>& stride,
88 89
                  const std::vector<int>& padding,
                  framework::Tensor* im,
90
                  const DataLayout data_layout) {
91 92
    PADDLE_ENFORCE_EQ(im->dims().size(),
                      3,
93 94 95 96
                      platform::errors::InvalidArgument(
                          "The dimension of tensor 'im' should be 3. But got "
                          "the dims of tensor 'im' is [%s].",
                          im->dims()));
97 98
    PADDLE_ENFORCE_EQ(col.dims().size(),
                      5,
99 100 101 102
                      platform::errors::InvalidArgument(
                          "The dimension of tensor 'col' should be 5. But got "
                          "the dims of tensor 'col' is [%s].",
                          col.dims()));
103
    int im_channels =
104
        (data_layout != DataLayout::kNHWC ? im->dims()[0] : im->dims()[2]);
105
    int im_height =
106
        (data_layout != DataLayout::kNHWC ? im->dims()[1] : im->dims()[0]);
107
    int im_width =
108
        (data_layout != DataLayout::kNHWC ? im->dims()[2] : im->dims()[1]);
H
hedaoyuan 已提交
109 110
    int filter_height = col.dims()[1];
    int filter_width = col.dims()[2];
C
chengduoZH 已提交
111 112
    int col_height = col.dims()[3];
    int col_width = col.dims()[4];
C
chengduoZH 已提交
113

C
chengduoZH 已提交
114 115 116
    PADDLE_ENFORCE_EQ((im_height + padding[0] + padding[2] -
                       ((dilation[0] * (filter_height - 1) + 1))) /
                              stride[0] +
C
chengduoZH 已提交
117
                          1,
118 119 120 121
                      col_height,
                      platform::errors::InvalidArgument(
                          "Output_height and padding(padding_up, "
                          "padding_down) are inconsistent."));
C
chengduoZH 已提交
122 123 124
    PADDLE_ENFORCE_EQ((im_width + padding[1] + padding[3] -
                       ((dilation[1] * (filter_width - 1) + 1))) /
                              stride[1] +
C
chengduoZH 已提交
125
                          1,
126 127 128 129
                      col_width,
                      platform::errors::InvalidArgument(
                          "Output_height and padding(padding_up, "
                          "padding_down) are inconsistent."));
C
chengduoZH 已提交
130

C
chengduoZH 已提交
131
    int channels_col = im_channels * filter_height * filter_width;
H
hedaoyuan 已提交
132

C
chengduoZH 已提交
133
    T* im_data = im->data<T>();
H
hedaoyuan 已提交
134 135 136
    const T* col_data = col.data<T>();

    for (int c = 0; c < channels_col; ++c) {
C
chengduoZH 已提交
137 138 139
      int w_offset = c % filter_width;
      int h_offset = (c / filter_width) % filter_height;
      int c_im = c / (filter_width * filter_height);
C
chengduoZH 已提交
140
      for (int h = 0; h < col_height; ++h) {
C
chengduoZH 已提交
141
        int im_row_idx = h * stride[0] - padding[0] + h_offset * dilation[0];
C
chengduoZH 已提交
142
        for (int w = 0; w < col_width; ++w) {
C
chengduoZH 已提交
143
          int im_col_idx = w * stride[1] - padding[1] + w_offset * dilation[1];
C
chengduoZH 已提交
144 145
          if ((im_row_idx) >= 0 && (im_row_idx) < im_height &&
              (im_col_idx) >= 0 && (im_col_idx) < im_width) {
146
            int im_offset;
147
            if (data_layout != DataLayout::kNHWC) {
148 149 150 151 152 153 154
              im_offset =
                  (c_im * im_height + im_row_idx) * im_width + im_col_idx;
            } else {
              im_offset =
                  (im_row_idx * im_width + im_col_idx) * im_channels + c_im;
            }
            im_data[im_offset] +=
C
chengduoZH 已提交
155
                col_data[(c * col_height + h) * col_width + w];
H
hedaoyuan 已提交
156 157 158 159 160 161 162
          }
        }
      }
    }
  }
};

163
template class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
164 165
                             phi::CPUContext,
                             float>;
166
template class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
167 168
                             phi::CPUContext,
                             double>;
169
template class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
170 171
                             phi::CPUContext,
                             float>;
172
template class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
173 174
                             phi::CPUContext,
                             double>;
H
hedaoyuan 已提交
175 176

/*
H
hedaoyuan 已提交
177 178 179
 * im = [input_channels, input_height, input_width]
 * col =
 *   [output_height, output_width, input_channels, filter_height, filter_width]
H
hedaoyuan 已提交
180
 */
181
template <class T, typename DeviceContext>
182 183
class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
                    DeviceContext,
184
                    T> {
H
hedaoyuan 已提交
185
 public:
186 187
  void operator()(const DeviceContext& context,
                  const framework::Tensor& im,
188
                  const std::vector<int>& dilation,
C
chengduoZH 已提交
189
                  const std::vector<int>& stride,
190 191
                  const std::vector<int>& padding,
                  framework::Tensor* col,
192
                  const DataLayout data_layout) {
193 194
    PADDLE_ENFORCE_EQ(im.dims().size(),
                      3,
195 196 197 198
                      platform::errors::InvalidArgument(
                          "The dimension of tensor 'im' should be 3. But got "
                          "the dims of tensor 'im' is [%s].",
                          im.dims()));
199 200
    PADDLE_ENFORCE_EQ(col->dims().size(),
                      5,
201 202 203 204
                      platform::errors::InvalidArgument(
                          "The dimension of tensor 'col' should be 5. But got "
                          "the dims of tensor 'col' is [%s].",
                          col->dims()));
C
chengduoZH 已提交
205 206 207
    int im_channels = im.dims()[0];
    int im_height = im.dims()[1];
    int im_width = im.dims()[2];
C
chengduoZH 已提交
208 209 210 211
    int filter_height = col->dims()[3];
    int filter_width = col->dims()[4];
    int col_height = col->dims()[0];
    int col_width = col->dims()[1];
H
hedaoyuan 已提交
212 213

    const T* im_data = im.data<T>();
C
chengduoZH 已提交
214
    T* col_data = col->data<T>();
H
hedaoyuan 已提交
215

C
chengduoZH 已提交
216 217 218
    for (int col_row_idx = 0; col_row_idx < col_height; ++col_row_idx) {
      for (int col_col_idx = 0; col_col_idx < col_width; ++col_col_idx) {
        for (int channel = 0; channel < im_channels; ++channel) {
H
hedaoyuan 已提交
219 220
          for (int filter_row_idx = 0; filter_row_idx < filter_height;
               ++filter_row_idx) {
C
refine  
chengduoZH 已提交
221 222
            int im_row_offset =
                col_row_idx * stride[0] + filter_row_idx - padding[0];
H
hedaoyuan 已提交
223 224 225
            for (int filter_col_idx = 0; filter_col_idx < filter_width;
                 ++filter_col_idx) {
              int im_col_offset =
C
chengduoZH 已提交
226
                  col_col_idx * stride[1] + filter_col_idx - padding[1];
C
refine  
chengduoZH 已提交
227

C
chengduoZH 已提交
228 229 230 231 232 233 234 235 236 237 238 239 240 241 242
              int col_offset =
                  ((((col_row_idx)*col_width + col_col_idx) * im_channels +
                    channel) *
                       filter_height +
                   filter_row_idx) *
                      filter_width +
                  filter_col_idx;

              int im_offset = (channel * im_height + im_row_offset) * im_width +
                              im_col_offset;
              col_data[col_offset] =
                  (im_row_offset < 0 || im_row_offset >= im_height ||
                   im_col_offset < 0 || im_col_offset >= im_width)
                      ? static_cast<T>(0)
                      : im_data[im_offset];
H
hedaoyuan 已提交
243 244 245 246 247 248 249 250 251
            }
          }
        }
      }
    }
  }
};

/*
H
hedaoyuan 已提交
252 253 254
 * im = [input_channels, input_height, input_width]
 * col =
 *   [output_height, output_width, input_channels, filter_height, filter_width]
H
hedaoyuan 已提交
255
 */
256
template <class T, typename DeviceContext>
257 258
class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
                    DeviceContext,
259
                    T> {
H
hedaoyuan 已提交
260
 public:
261 262
  void operator()(const DeviceContext& context,
                  const framework::Tensor& col,
C
chengduoZH 已提交
263 264
                  const std::vector<int>& dilation,
                  const std::vector<int>& stride,
265 266
                  const std::vector<int>& padding,
                  framework::Tensor* im,
267
                  const DataLayout data_layout) {
268 269
    PADDLE_ENFORCE_EQ(im->dims().size(),
                      3,
270 271 272 273
                      platform::errors::InvalidArgument(
                          "The dimension of tensor 'im' should be 3. But got "
                          "the dims of tensor 'im' is [%s].",
                          im->dims()));
274 275
    PADDLE_ENFORCE_EQ(col.dims().size(),
                      5,
276 277 278 279
                      platform::errors::InvalidArgument(
                          "The dimension of tensor 'col' should be 5. But got "
                          "the dims of tensor 'col' is [%s].",
                          col.dims()));
C
chengduoZH 已提交
280 281 282
    int im_channels = im->dims()[0];
    int im_height = im->dims()[1];
    int im_width = im->dims()[2];
H
hedaoyuan 已提交
283 284
    int filter_height = col.dims()[3];
    int filter_width = col.dims()[4];
C
chengduoZH 已提交
285 286
    int col_height = col.dims()[0];
    int col_width = col.dims()[1];
H
hedaoyuan 已提交
287

C
chengduoZH 已提交
288 289
    PADDLE_ENFORCE_EQ(
        (im_height + padding[0] + padding[2] - filter_height) / stride[0] + 1,
290 291 292 293
        col_height,
        platform::errors::InvalidArgument(
            "Output_height and padding(padding_up, padding_down) "
            "are inconsistent."));
C
chengduoZH 已提交
294 295 296
    PADDLE_ENFORCE_EQ(
        (im_width + padding[1] + padding[3] - filter_width) / stride[1] + 1,
        col_width,
297 298
        platform::errors::InvalidArgument("col_width and padding(padding_left, "
                                          "padding_right) are inconsistent."));
299

C
chengduoZH 已提交
300
    T* im_data = im->data<T>();
H
hedaoyuan 已提交
301 302
    const T* col_data = col.data<T>();

C
chengduoZH 已提交
303 304 305
    for (int col_row_idx = 0; col_row_idx < col_height; ++col_row_idx) {
      for (int col_col_idx = 0; col_col_idx < col_width; ++col_col_idx) {
        for (int channel = 0; channel < im_channels; ++channel) {
H
hedaoyuan 已提交
306 307
          for (int filter_row_idx = 0; filter_row_idx < filter_height;
               ++filter_row_idx) {
C
refine  
chengduoZH 已提交
308 309
            int im_row_offset =
                col_row_idx * stride[0] + filter_row_idx - padding[0];
H
hedaoyuan 已提交
310 311 312
            for (int filter_col_idx = 0; filter_col_idx < filter_width;
                 ++filter_col_idx) {
              int im_col_offset =
C
chengduoZH 已提交
313
                  col_col_idx * stride[1] + filter_col_idx - padding[1];
C
refine  
chengduoZH 已提交
314

C
chengduoZH 已提交
315 316 317 318 319 320 321
              int col_offset =
                  (((col_row_idx * col_width + col_col_idx) * im_channels +
                    channel) *
                       filter_height +
                   filter_row_idx) *
                      filter_width +
                  filter_col_idx;
C
refine  
chengduoZH 已提交
322

C
chengduoZH 已提交
323 324
              if (im_row_offset >= 0 && im_row_offset < im_height &&
                  im_col_offset >= 0 && im_col_offset < im_width) {
H
hedaoyuan 已提交
325
                int im_offset =
C
chengduoZH 已提交
326
                    (channel * im_height + im_row_offset) * im_width +
H
hedaoyuan 已提交
327 328
                    im_col_offset;
                im_data[im_offset] += col_data[col_offset];
H
hedaoyuan 已提交
329 330 331 332 333 334 335 336 337
              }
            }
          }
        }
      }
    }
  }
};

338
template class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
339 340
                             phi::CPUContext,
                             float>;
341
template class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
342 343
                             phi::CPUContext,
                             double>;
344
template class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
345 346
                             phi::CPUContext,
                             float>;
347
template class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
348 349
                             phi::CPUContext,
                             double>;
}  // namespace math
}  // namespace operators
}  // namespace paddle