im2col.cc 14.4 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

Y
Yi Wang 已提交
15
#include "paddle/fluid/operators/math/im2col.h"
W
wanghuancoder 已提交
16

T
tensor-tang 已提交
17
#include "paddle/fluid/operators/math/im2col_cfo_cpu.h"
H
hedaoyuan 已提交
18

W
wanghuancoder 已提交
19 20 21 22 23 24
// Forward declaration: this file only needs the name
// platform::CPUDeviceContext so it can be used as the DeviceContext template
// argument of the explicit instantiations further down.
namespace paddle {
namespace platform {
class CPUDeviceContext;
}  // namespace platform
}  // namespace paddle

25 26 27 28
// Forward declaration: phi::CPUContext is named as the DeviceContext template
// argument of the explicit instantiations further down.
namespace phi {
class CPUContext;
}  // namespace phi

H
hedaoyuan 已提交
29
namespace paddle {
30
namespace operators {
31
namespace math {
H
hedaoyuan 已提交
32 33

/*
 * im = [input_channels, input_height, input_width]
 * col =
 *   [input_channels, filter_height, filter_width, output_height, output_width]
 */
template <class T, typename DeviceContext>
class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO, DeviceContext,
                    T> {
H
hedaoyuan 已提交
41
 public:
42 43
  void operator()(const DeviceContext& context, const framework::Tensor& im,
                  const std::vector<int>& dilation,
C
chengduoZH 已提交
44
                  const std::vector<int>& stride,
45 46
                  const std::vector<int>& padding, framework::Tensor* col,
                  const DataLayout data_layout) {
47 48 49 50 51
    PADDLE_ENFORCE_EQ(im.dims().size(), 3,
                      platform::errors::InvalidArgument(
                          "The dimension of tensor 'im' should be 3. But got "
                          "the dims of tensor 'im' is [%s].",
                          im.dims()));
L
liym27 已提交
52
    PADDLE_ENFORCE_EQ(col->dims().size(), 5,
53 54 55 56
                      platform::errors::InvalidArgument(
                          "The dimension of tensor 'col' should be 5. But got "
                          "the dims of tensor 'col' is [%s].",
                          col->dims()));
H
hedaoyuan 已提交
57

T
tensor-tang 已提交
58
    if (stride[0] == 1 && stride[1] == 1 && dilation[0] == 1 &&
59
        dilation[1] == 1) {
L
liym27 已提交
60 61
      if (padding[0] == 0 && padding[1] == 0 && padding[2] == 0 &&
          padding[3] == 0) {
62
        im2col_sh1sw1dh1dw1ph0pw0<T>(im, col, data_layout);
63
        return;
L
liym27 已提交
64 65
      } else if (padding[0] == 1 && padding[1] == 1 && padding[2] == 1 &&
                 padding[3] == 1) {
66
        im2col_sh1sw1dh1dw1ph1pw1<T>(im, col, data_layout);
67
        return;
H
hedaoyuan 已提交
68
      }
69
      // TODO(TJ): complete padding >=2
H
hedaoyuan 已提交
70
    }
71
    im2col_common<T>(im, dilation, stride, padding, col, data_layout);
H
hedaoyuan 已提交
72 73 74 75
  }
};

/*
 * im = [input_channels, input_height, input_width]
 * col =
 *   [input_channels, filter_height, filter_width, output_height, output_width]
 */
template <class T, typename DeviceContext>
class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO, DeviceContext,
                    T> {
H
hedaoyuan 已提交
83
 public:
84
  void operator()(const DeviceContext& context, const framework::Tensor& col,
C
chengduoZH 已提交
85 86
                  const std::vector<int>& dilation,
                  const std::vector<int>& stride,
87 88
                  const std::vector<int>& padding, framework::Tensor* im,
                  const DataLayout data_layout) {
89 90 91 92 93
    PADDLE_ENFORCE_EQ(im->dims().size(), 3,
                      platform::errors::InvalidArgument(
                          "The dimension of tensor 'im' should be 3. But got "
                          "the dims of tensor 'im' is [%s].",
                          im->dims()));
L
liym27 已提交
94
    PADDLE_ENFORCE_EQ(col.dims().size(), 5,
95 96 97 98
                      platform::errors::InvalidArgument(
                          "The dimension of tensor 'col' should be 5. But got "
                          "the dims of tensor 'col' is [%s].",
                          col.dims()));
99
    int im_channels =
100
        (data_layout != DataLayout::kNHWC ? im->dims()[0] : im->dims()[2]);
101
    int im_height =
102
        (data_layout != DataLayout::kNHWC ? im->dims()[1] : im->dims()[0]);
103
    int im_width =
104
        (data_layout != DataLayout::kNHWC ? im->dims()[2] : im->dims()[1]);
H
hedaoyuan 已提交
105 106
    int filter_height = col.dims()[1];
    int filter_width = col.dims()[2];
C
chengduoZH 已提交
107 108
    int col_height = col.dims()[3];
    int col_width = col.dims()[4];
C
chengduoZH 已提交
109

C
chengduoZH 已提交
110 111 112
    PADDLE_ENFORCE_EQ((im_height + padding[0] + padding[2] -
                       ((dilation[0] * (filter_height - 1) + 1))) /
                              stride[0] +
C
chengduoZH 已提交
113
                          1,
114 115 116
                      col_height, platform::errors::InvalidArgument(
                                      "Output_height and padding(padding_up, "
                                      "padding_down) are inconsistent."));
C
chengduoZH 已提交
117 118 119
    PADDLE_ENFORCE_EQ((im_width + padding[1] + padding[3] -
                       ((dilation[1] * (filter_width - 1) + 1))) /
                              stride[1] +
C
chengduoZH 已提交
120
                          1,
121 122 123
                      col_width, platform::errors::InvalidArgument(
                                     "Output_height and padding(padding_up, "
                                     "padding_down) are inconsistent."));
C
chengduoZH 已提交
124

C
chengduoZH 已提交
125
    int channels_col = im_channels * filter_height * filter_width;
H
hedaoyuan 已提交
126

C
chengduoZH 已提交
127
    T* im_data = im->data<T>();
H
hedaoyuan 已提交
128 129 130
    const T* col_data = col.data<T>();

    for (int c = 0; c < channels_col; ++c) {
C
chengduoZH 已提交
131 132 133
      int w_offset = c % filter_width;
      int h_offset = (c / filter_width) % filter_height;
      int c_im = c / (filter_width * filter_height);
C
chengduoZH 已提交
134
      for (int h = 0; h < col_height; ++h) {
C
chengduoZH 已提交
135
        int im_row_idx = h * stride[0] - padding[0] + h_offset * dilation[0];
C
chengduoZH 已提交
136
        for (int w = 0; w < col_width; ++w) {
C
chengduoZH 已提交
137
          int im_col_idx = w * stride[1] - padding[1] + w_offset * dilation[1];
C
chengduoZH 已提交
138 139
          if ((im_row_idx) >= 0 && (im_row_idx) < im_height &&
              (im_col_idx) >= 0 && (im_col_idx) < im_width) {
140
            int im_offset;
141
            if (data_layout != DataLayout::kNHWC) {
142 143 144 145 146 147 148
              im_offset =
                  (c_im * im_height + im_row_idx) * im_width + im_col_idx;
            } else {
              im_offset =
                  (im_row_idx * im_width + im_col_idx) * im_channels + c_im;
            }
            im_data[im_offset] +=
C
chengduoZH 已提交
149
                col_data[(c * col_height + h) * col_width + w];
H
hedaoyuan 已提交
150 151 152 153 154 155 156
          }
        }
      }
    }
  }
};

H
hedaoyuan 已提交
157
template class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
Q
QI JUN 已提交
158
                             platform::CPUDeviceContext, float>;
H
hedaoyuan 已提交
159
template class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
Q
QI JUN 已提交
160
                             platform::CPUDeviceContext, double>;
161 162 163 164
template class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
                             phi::CPUContext, float>;
template class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
                             phi::CPUContext, double>;
H
hedaoyuan 已提交
165
template class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
Q
QI JUN 已提交
166
                             platform::CPUDeviceContext, float>;
H
hedaoyuan 已提交
167
template class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
Q
QI JUN 已提交
168
                             platform::CPUDeviceContext, double>;
169 170 171 172
template class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
                             phi::CPUContext, float>;
template class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
                             phi::CPUContext, double>;
H
hedaoyuan 已提交
173 174

/*
 * im = [input_channels, input_height, input_width]
 * col =
 *   [output_height, output_width, input_channels, filter_height, filter_width]
 */
template <class T, typename DeviceContext>
class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF, DeviceContext,
                    T> {
H
hedaoyuan 已提交
182
 public:
183 184
  void operator()(const DeviceContext& context, const framework::Tensor& im,
                  const std::vector<int>& dilation,
C
chengduoZH 已提交
185
                  const std::vector<int>& stride,
186 187
                  const std::vector<int>& padding, framework::Tensor* col,
                  const DataLayout data_layout) {
188 189 190 191 192
    PADDLE_ENFORCE_EQ(im.dims().size(), 3,
                      platform::errors::InvalidArgument(
                          "The dimension of tensor 'im' should be 3. But got "
                          "the dims of tensor 'im' is [%s].",
                          im.dims()));
L
liym27 已提交
193
    PADDLE_ENFORCE_EQ(col->dims().size(), 5,
194 195 196 197
                      platform::errors::InvalidArgument(
                          "The dimension of tensor 'col' should be 5. But got "
                          "the dims of tensor 'col' is [%s].",
                          col->dims()));
C
chengduoZH 已提交
198 199 200
    int im_channels = im.dims()[0];
    int im_height = im.dims()[1];
    int im_width = im.dims()[2];
C
chengduoZH 已提交
201 202 203 204
    int filter_height = col->dims()[3];
    int filter_width = col->dims()[4];
    int col_height = col->dims()[0];
    int col_width = col->dims()[1];
H
hedaoyuan 已提交
205 206

    const T* im_data = im.data<T>();
C
chengduoZH 已提交
207
    T* col_data = col->data<T>();
H
hedaoyuan 已提交
208

C
chengduoZH 已提交
209 210 211
    for (int col_row_idx = 0; col_row_idx < col_height; ++col_row_idx) {
      for (int col_col_idx = 0; col_col_idx < col_width; ++col_col_idx) {
        for (int channel = 0; channel < im_channels; ++channel) {
H
hedaoyuan 已提交
212 213
          for (int filter_row_idx = 0; filter_row_idx < filter_height;
               ++filter_row_idx) {
C
refine  
chengduoZH 已提交
214 215
            int im_row_offset =
                col_row_idx * stride[0] + filter_row_idx - padding[0];
H
hedaoyuan 已提交
216 217 218
            for (int filter_col_idx = 0; filter_col_idx < filter_width;
                 ++filter_col_idx) {
              int im_col_offset =
C
chengduoZH 已提交
219
                  col_col_idx * stride[1] + filter_col_idx - padding[1];
C
refine  
chengduoZH 已提交
220

C
chengduoZH 已提交
221 222 223 224 225 226 227 228 229 230 231 232 233 234 235
              int col_offset =
                  ((((col_row_idx)*col_width + col_col_idx) * im_channels +
                    channel) *
                       filter_height +
                   filter_row_idx) *
                      filter_width +
                  filter_col_idx;

              int im_offset = (channel * im_height + im_row_offset) * im_width +
                              im_col_offset;
              col_data[col_offset] =
                  (im_row_offset < 0 || im_row_offset >= im_height ||
                   im_col_offset < 0 || im_col_offset >= im_width)
                      ? static_cast<T>(0)
                      : im_data[im_offset];
H
hedaoyuan 已提交
236 237 238 239 240 241 242 243 244
            }
          }
        }
      }
    }
  }
};

/*
 * im = [input_channels, input_height, input_width]
 * col =
 *   [output_height, output_width, input_channels, filter_height, filter_width]
 */
template <class T, typename DeviceContext>
class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF, DeviceContext,
                    T> {
H
hedaoyuan 已提交
252
 public:
253
  void operator()(const DeviceContext& context, const framework::Tensor& col,
C
chengduoZH 已提交
254 255
                  const std::vector<int>& dilation,
                  const std::vector<int>& stride,
256 257
                  const std::vector<int>& padding, framework::Tensor* im,
                  const DataLayout data_layout) {
258 259 260 261 262
    PADDLE_ENFORCE_EQ(im->dims().size(), 3,
                      platform::errors::InvalidArgument(
                          "The dimension of tensor 'im' should be 3. But got "
                          "the dims of tensor 'im' is [%s].",
                          im->dims()));
L
liym27 已提交
263
    PADDLE_ENFORCE_EQ(col.dims().size(), 5,
264 265 266 267
                      platform::errors::InvalidArgument(
                          "The dimension of tensor 'col' should be 5. But got "
                          "the dims of tensor 'col' is [%s].",
                          col.dims()));
C
chengduoZH 已提交
268 269 270
    int im_channels = im->dims()[0];
    int im_height = im->dims()[1];
    int im_width = im->dims()[2];
H
hedaoyuan 已提交
271 272
    int filter_height = col.dims()[3];
    int filter_width = col.dims()[4];
C
chengduoZH 已提交
273 274
    int col_height = col.dims()[0];
    int col_width = col.dims()[1];
H
hedaoyuan 已提交
275

C
chengduoZH 已提交
276 277
    PADDLE_ENFORCE_EQ(
        (im_height + padding[0] + padding[2] - filter_height) / stride[0] + 1,
278 279 280
        col_height, platform::errors::InvalidArgument(
                        "Output_height and padding(padding_up, padding_down) "
                        "are inconsistent."));
C
chengduoZH 已提交
281 282 283
    PADDLE_ENFORCE_EQ(
        (im_width + padding[1] + padding[3] - filter_width) / stride[1] + 1,
        col_width,
284 285
        platform::errors::InvalidArgument("col_width and padding(padding_left, "
                                          "padding_right) are inconsistent."));
286

C
chengduoZH 已提交
287
    T* im_data = im->data<T>();
H
hedaoyuan 已提交
288 289
    const T* col_data = col.data<T>();

C
chengduoZH 已提交
290 291 292
    for (int col_row_idx = 0; col_row_idx < col_height; ++col_row_idx) {
      for (int col_col_idx = 0; col_col_idx < col_width; ++col_col_idx) {
        for (int channel = 0; channel < im_channels; ++channel) {
H
hedaoyuan 已提交
293 294
          for (int filter_row_idx = 0; filter_row_idx < filter_height;
               ++filter_row_idx) {
C
refine  
chengduoZH 已提交
295 296
            int im_row_offset =
                col_row_idx * stride[0] + filter_row_idx - padding[0];
H
hedaoyuan 已提交
297 298 299
            for (int filter_col_idx = 0; filter_col_idx < filter_width;
                 ++filter_col_idx) {
              int im_col_offset =
C
chengduoZH 已提交
300
                  col_col_idx * stride[1] + filter_col_idx - padding[1];
C
refine  
chengduoZH 已提交
301

C
chengduoZH 已提交
302 303 304 305 306 307 308
              int col_offset =
                  (((col_row_idx * col_width + col_col_idx) * im_channels +
                    channel) *
                       filter_height +
                   filter_row_idx) *
                      filter_width +
                  filter_col_idx;
C
refine  
chengduoZH 已提交
309

C
chengduoZH 已提交
310 311
              if (im_row_offset >= 0 && im_row_offset < im_height &&
                  im_col_offset >= 0 && im_col_offset < im_width) {
H
hedaoyuan 已提交
312
                int im_offset =
C
chengduoZH 已提交
313
                    (channel * im_height + im_row_offset) * im_width +
H
hedaoyuan 已提交
314 315
                    im_col_offset;
                im_data[im_offset] += col_data[col_offset];
H
hedaoyuan 已提交
316 317 318 319 320 321 322 323 324
              }
            }
          }
        }
      }
    }
  }
};

H
hedaoyuan 已提交
325
template class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
Q
QI JUN 已提交
326
                             platform::CPUDeviceContext, float>;
H
hedaoyuan 已提交
327
template class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
Q
QI JUN 已提交
328
                             platform::CPUDeviceContext, double>;
329 330 331 332
template class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
                             phi::CPUContext, float>;
template class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
                             phi::CPUContext, double>;
H
hedaoyuan 已提交
333
template class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
Q
QI JUN 已提交
334
                             platform::CPUDeviceContext, float>;
H
hedaoyuan 已提交
335
template class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
Q
QI JUN 已提交
336
                             platform::CPUDeviceContext, double>;
337 338 339 340
template class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
                             phi::CPUContext, float>;
template class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
                             phi::CPUContext, double>;
341
}  // namespace math
342
}  // namespace operators
H
hedaoyuan 已提交
343
}  // namespace paddle