/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/math/im2col.h"
#include <vector>
#include "paddle/fluid/operators/math/im2col_cfo_cpu.h"

namespace paddle {
namespace operators {
namespace math {

/*
H
hedaoyuan 已提交
24 25 26
 * im = [input_channels, input_height, input_width]
 * col =
 *   [input_channels, filter_height, filter_width, output_height, output_width]
H
hedaoyuan 已提交
27 28
 */
template <class T>
H
hedaoyuan 已提交
29
class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
Q
QI JUN 已提交
30
                    platform::CPUDeviceContext, T> {
H
hedaoyuan 已提交
31
 public:
Q
QI JUN 已提交
32
  void operator()(const platform::CPUDeviceContext& context,
C
chengduoZH 已提交
33 34
                  const framework::Tensor& im, const std::vector<int>& dilation,
                  const std::vector<int>& stride,
35 36
                  const std::vector<int>& padding, framework::Tensor* col,
                  const DataLayout data_layout) {
37 38 39 40 41
    PADDLE_ENFORCE_EQ(im.dims().size(), 3,
                      platform::errors::InvalidArgument(
                          "The dimension of tensor 'im' should be 3. But got "
                          "the dims of tensor 'im' is [%s].",
                          im.dims()));
L
liym27 已提交
42
    PADDLE_ENFORCE_EQ(col->dims().size(), 5,
43 44 45 46
                      platform::errors::InvalidArgument(
                          "The dimension of tensor 'col' should be 5. But got "
                          "the dims of tensor 'col' is [%s].",
                          col->dims()));
H
hedaoyuan 已提交
47

T
tensor-tang 已提交
48
    if (stride[0] == 1 && stride[1] == 1 && dilation[0] == 1 &&
49
        dilation[1] == 1) {
L
liym27 已提交
50 51
      if (padding[0] == 0 && padding[1] == 0 && padding[2] == 0 &&
          padding[3] == 0) {
52
        im2col_sh1sw1dh1dw1ph0pw0<T>(im, col, data_layout);
53
        return;
L
liym27 已提交
54 55
      } else if (padding[0] == 1 && padding[1] == 1 && padding[2] == 1 &&
                 padding[3] == 1) {
56
        im2col_sh1sw1dh1dw1ph1pw1<T>(im, col, data_layout);
57
        return;
H
hedaoyuan 已提交
58
      }
59
      // TODO(TJ): complete padding >=2
H
hedaoyuan 已提交
60
    }
61
    im2col_common<T>(im, dilation, stride, padding, col, data_layout);
H
hedaoyuan 已提交
62 63 64 65
  }
};

/*
H
hedaoyuan 已提交
66 67 68
 * im = [input_channels, input_height, input_width]
 * col =
 *   [input_channels, filter_height, filter_width, output_height, output_width]
H
hedaoyuan 已提交
69 70
 */
template <class T>
H
hedaoyuan 已提交
71
class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
Q
QI JUN 已提交
72
                    platform::CPUDeviceContext, T> {
H
hedaoyuan 已提交
73
 public:
Q
QI JUN 已提交
74
  void operator()(const platform::CPUDeviceContext& context,
C
chengduoZH 已提交
75 76 77
                  const framework::Tensor& col,
                  const std::vector<int>& dilation,
                  const std::vector<int>& stride,
78 79
                  const std::vector<int>& padding, framework::Tensor* im,
                  const DataLayout data_layout) {
80 81 82 83 84
    PADDLE_ENFORCE_EQ(im->dims().size(), 3,
                      platform::errors::InvalidArgument(
                          "The dimension of tensor 'im' should be 3. But got "
                          "the dims of tensor 'im' is [%s].",
                          im->dims()));
L
liym27 已提交
85
    PADDLE_ENFORCE_EQ(col.dims().size(), 5,
86 87 88 89
                      platform::errors::InvalidArgument(
                          "The dimension of tensor 'col' should be 5. But got "
                          "the dims of tensor 'col' is [%s].",
                          col.dims()));
90
    int im_channels =
91
        (data_layout != DataLayout::kNHWC ? im->dims()[0] : im->dims()[2]);
92
    int im_height =
93
        (data_layout != DataLayout::kNHWC ? im->dims()[1] : im->dims()[0]);
94
    int im_width =
95
        (data_layout != DataLayout::kNHWC ? im->dims()[2] : im->dims()[1]);
H
hedaoyuan 已提交
96 97
    int filter_height = col.dims()[1];
    int filter_width = col.dims()[2];
C
chengduoZH 已提交
98 99
    int col_height = col.dims()[3];
    int col_width = col.dims()[4];
C
chengduoZH 已提交
100

C
chengduoZH 已提交
101 102 103
    PADDLE_ENFORCE_EQ((im_height + padding[0] + padding[2] -
                       ((dilation[0] * (filter_height - 1) + 1))) /
                              stride[0] +
C
chengduoZH 已提交
104
                          1,
105 106 107
                      col_height, platform::errors::InvalidArgument(
                                      "Output_height and padding(padding_up, "
                                      "padding_down) are inconsistent."));
C
chengduoZH 已提交
108 109 110
    PADDLE_ENFORCE_EQ((im_width + padding[1] + padding[3] -
                       ((dilation[1] * (filter_width - 1) + 1))) /
                              stride[1] +
C
chengduoZH 已提交
111
                          1,
112 113 114
                      col_width, platform::errors::InvalidArgument(
                                     "Output_height and padding(padding_up, "
                                     "padding_down) are inconsistent."));
C
chengduoZH 已提交
115

C
chengduoZH 已提交
116
    int channels_col = im_channels * filter_height * filter_width;
H
hedaoyuan 已提交
117

C
chengduoZH 已提交
118
    T* im_data = im->data<T>();
H
hedaoyuan 已提交
119 120 121
    const T* col_data = col.data<T>();

    for (int c = 0; c < channels_col; ++c) {
C
chengduoZH 已提交
122 123 124
      int w_offset = c % filter_width;
      int h_offset = (c / filter_width) % filter_height;
      int c_im = c / (filter_width * filter_height);
C
chengduoZH 已提交
125
      for (int h = 0; h < col_height; ++h) {
C
chengduoZH 已提交
126
        int im_row_idx = h * stride[0] - padding[0] + h_offset * dilation[0];
C
chengduoZH 已提交
127
        for (int w = 0; w < col_width; ++w) {
C
chengduoZH 已提交
128
          int im_col_idx = w * stride[1] - padding[1] + w_offset * dilation[1];
C
chengduoZH 已提交
129 130
          if ((im_row_idx) >= 0 && (im_row_idx) < im_height &&
              (im_col_idx) >= 0 && (im_col_idx) < im_width) {
131
            int im_offset;
132
            if (data_layout != DataLayout::kNHWC) {
133 134 135 136 137 138 139
              im_offset =
                  (c_im * im_height + im_row_idx) * im_width + im_col_idx;
            } else {
              im_offset =
                  (im_row_idx * im_width + im_col_idx) * im_channels + c_im;
            }
            im_data[im_offset] +=
C
chengduoZH 已提交
140
                col_data[(c * col_height + h) * col_width + w];
H
hedaoyuan 已提交
141 142 143 144 145 146 147
          }
        }
      }
    }
  }
};

template class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
Q
QI JUN 已提交
149
                             platform::CPUDeviceContext, float>;
H
hedaoyuan 已提交
150
template class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
Q
QI JUN 已提交
151
                             platform::CPUDeviceContext, double>;
H
hedaoyuan 已提交
152
template class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
Q
QI JUN 已提交
153
                             platform::CPUDeviceContext, float>;
H
hedaoyuan 已提交
154
template class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
Q
QI JUN 已提交
155
                             platform::CPUDeviceContext, double>;

/*
H
hedaoyuan 已提交
158 159 160
 * im = [input_channels, input_height, input_width]
 * col =
 *   [output_height, output_width, input_channels, filter_height, filter_width]
H
hedaoyuan 已提交
161 162
 */
template <class T>
H
hedaoyuan 已提交
163
class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
Q
QI JUN 已提交
164
                    platform::CPUDeviceContext, T> {
H
hedaoyuan 已提交
165
 public:
Q
QI JUN 已提交
166
  void operator()(const platform::CPUDeviceContext& context,
C
chengduoZH 已提交
167 168
                  const framework::Tensor& im, const std::vector<int>& dilation,
                  const std::vector<int>& stride,
169 170
                  const std::vector<int>& padding, framework::Tensor* col,
                  const DataLayout data_layout) {
171 172 173 174 175
    PADDLE_ENFORCE_EQ(im.dims().size(), 3,
                      platform::errors::InvalidArgument(
                          "The dimension of tensor 'im' should be 3. But got "
                          "the dims of tensor 'im' is [%s].",
                          im.dims()));
L
liym27 已提交
176
    PADDLE_ENFORCE_EQ(col->dims().size(), 5,
177 178 179 180
                      platform::errors::InvalidArgument(
                          "The dimension of tensor 'col' should be 5. But got "
                          "the dims of tensor 'col' is [%s].",
                          col->dims()));
C
chengduoZH 已提交
181 182 183
    int im_channels = im.dims()[0];
    int im_height = im.dims()[1];
    int im_width = im.dims()[2];
C
chengduoZH 已提交
184 185 186 187
    int filter_height = col->dims()[3];
    int filter_width = col->dims()[4];
    int col_height = col->dims()[0];
    int col_width = col->dims()[1];
H
hedaoyuan 已提交
188 189

    const T* im_data = im.data<T>();
C
chengduoZH 已提交
190
    T* col_data = col->data<T>();
H
hedaoyuan 已提交
191

C
chengduoZH 已提交
192 193 194
    for (int col_row_idx = 0; col_row_idx < col_height; ++col_row_idx) {
      for (int col_col_idx = 0; col_col_idx < col_width; ++col_col_idx) {
        for (int channel = 0; channel < im_channels; ++channel) {
H
hedaoyuan 已提交
195 196
          for (int filter_row_idx = 0; filter_row_idx < filter_height;
               ++filter_row_idx) {
C
refine  
chengduoZH 已提交
197 198
            int im_row_offset =
                col_row_idx * stride[0] + filter_row_idx - padding[0];
H
hedaoyuan 已提交
199 200 201
            for (int filter_col_idx = 0; filter_col_idx < filter_width;
                 ++filter_col_idx) {
              int im_col_offset =
C
chengduoZH 已提交
202
                  col_col_idx * stride[1] + filter_col_idx - padding[1];
C
refine  
chengduoZH 已提交
203

C
chengduoZH 已提交
204 205 206 207 208 209 210 211 212 213 214 215 216 217 218
              int col_offset =
                  ((((col_row_idx)*col_width + col_col_idx) * im_channels +
                    channel) *
                       filter_height +
                   filter_row_idx) *
                      filter_width +
                  filter_col_idx;

              int im_offset = (channel * im_height + im_row_offset) * im_width +
                              im_col_offset;
              col_data[col_offset] =
                  (im_row_offset < 0 || im_row_offset >= im_height ||
                   im_col_offset < 0 || im_col_offset >= im_width)
                      ? static_cast<T>(0)
                      : im_data[im_offset];
H
hedaoyuan 已提交
219 220 221 222 223 224 225 226 227
            }
          }
        }
      }
    }
  }
};

/*
H
hedaoyuan 已提交
228 229 230
 * im = [input_channels, input_height, input_width]
 * col =
 *   [output_height, output_width, input_channels, filter_height, filter_width]
H
hedaoyuan 已提交
231 232
 */
template <class T>
H
hedaoyuan 已提交
233
class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
Q
QI JUN 已提交
234
                    platform::CPUDeviceContext, T> {
H
hedaoyuan 已提交
235
 public:
Q
QI JUN 已提交
236
  void operator()(const platform::CPUDeviceContext& context,
C
chengduoZH 已提交
237 238 239
                  const framework::Tensor& col,
                  const std::vector<int>& dilation,
                  const std::vector<int>& stride,
240 241
                  const std::vector<int>& padding, framework::Tensor* im,
                  const DataLayout data_layout) {
242 243 244 245 246
    PADDLE_ENFORCE_EQ(im->dims().size(), 3,
                      platform::errors::InvalidArgument(
                          "The dimension of tensor 'im' should be 3. But got "
                          "the dims of tensor 'im' is [%s].",
                          im->dims()));
L
liym27 已提交
247
    PADDLE_ENFORCE_EQ(col.dims().size(), 5,
248 249 250 251
                      platform::errors::InvalidArgument(
                          "The dimension of tensor 'col' should be 5. But got "
                          "the dims of tensor 'col' is [%s].",
                          col.dims()));
C
chengduoZH 已提交
252 253 254
    int im_channels = im->dims()[0];
    int im_height = im->dims()[1];
    int im_width = im->dims()[2];
H
hedaoyuan 已提交
255 256
    int filter_height = col.dims()[3];
    int filter_width = col.dims()[4];
C
chengduoZH 已提交
257 258
    int col_height = col.dims()[0];
    int col_width = col.dims()[1];
H
hedaoyuan 已提交
259

C
chengduoZH 已提交
260 261
    PADDLE_ENFORCE_EQ(
        (im_height + padding[0] + padding[2] - filter_height) / stride[0] + 1,
262 263 264
        col_height, platform::errors::InvalidArgument(
                        "Output_height and padding(padding_up, padding_down) "
                        "are inconsistent."));
C
chengduoZH 已提交
265 266 267
    PADDLE_ENFORCE_EQ(
        (im_width + padding[1] + padding[3] - filter_width) / stride[1] + 1,
        col_width,
268 269
        platform::errors::InvalidArgument("col_width and padding(padding_left, "
                                          "padding_right) are inconsistent."));
270

C
chengduoZH 已提交
271
    T* im_data = im->data<T>();
H
hedaoyuan 已提交
272 273
    const T* col_data = col.data<T>();

C
chengduoZH 已提交
274 275 276
    for (int col_row_idx = 0; col_row_idx < col_height; ++col_row_idx) {
      for (int col_col_idx = 0; col_col_idx < col_width; ++col_col_idx) {
        for (int channel = 0; channel < im_channels; ++channel) {
H
hedaoyuan 已提交
277 278
          for (int filter_row_idx = 0; filter_row_idx < filter_height;
               ++filter_row_idx) {
C
refine  
chengduoZH 已提交
279 280
            int im_row_offset =
                col_row_idx * stride[0] + filter_row_idx - padding[0];
H
hedaoyuan 已提交
281 282 283
            for (int filter_col_idx = 0; filter_col_idx < filter_width;
                 ++filter_col_idx) {
              int im_col_offset =
C
chengduoZH 已提交
284
                  col_col_idx * stride[1] + filter_col_idx - padding[1];
C
refine  
chengduoZH 已提交
285

C
chengduoZH 已提交
286 287 288 289 290 291 292
              int col_offset =
                  (((col_row_idx * col_width + col_col_idx) * im_channels +
                    channel) *
                       filter_height +
                   filter_row_idx) *
                      filter_width +
                  filter_col_idx;
C
refine  
chengduoZH 已提交
293

C
chengduoZH 已提交
294 295
              if (im_row_offset >= 0 && im_row_offset < im_height &&
                  im_col_offset >= 0 && im_col_offset < im_width) {
H
hedaoyuan 已提交
296
                int im_offset =
C
chengduoZH 已提交
297
                    (channel * im_height + im_row_offset) * im_width +
H
hedaoyuan 已提交
298 299
                    im_col_offset;
                im_data[im_offset] += col_data[col_offset];
H
hedaoyuan 已提交
300 301 302 303 304 305 306 307 308
              }
            }
          }
        }
      }
    }
  }
};

template class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
Q
QI JUN 已提交
310
                             platform::CPUDeviceContext, float>;
H
hedaoyuan 已提交
311
template class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
Q
QI JUN 已提交
312
                             platform::CPUDeviceContext, double>;
H
hedaoyuan 已提交
313
template class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
Q
QI JUN 已提交
314
                             platform::CPUDeviceContext, float>;
H
hedaoyuan 已提交
315
template class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
Q
QI JUN 已提交
316
                             platform::CPUDeviceContext, double>;

}  // namespace math
}  // namespace operators
}  // namespace paddle