im2col.cc 13.4 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

Y
Yi Wang 已提交
15
#include "paddle/fluid/operators/math/im2col.h"
W
wanghuancoder 已提交
16

T
tensor-tang 已提交
17
#include "paddle/fluid/operators/math/im2col_cfo_cpu.h"
H
hedaoyuan 已提交
18

W
wanghuancoder 已提交
19 20 21 22 23 24
// Forward declaration of platform::CPUDeviceContext. The functors in this
// file name the type as a template argument and accept it by const reference.
namespace paddle {
namespace platform {
class CPUDeviceContext;
}  // namespace platform
}  // namespace paddle

H
hedaoyuan 已提交
25
namespace paddle {
26
namespace operators {
27
namespace math {
H
hedaoyuan 已提交
28 29

/*
H
hedaoyuan 已提交
30 31 32
 * im = [input_channels, input_height, input_width]
 * col =
 *   [input_channels, filter_height, filter_width, output_height, output_width]
H
hedaoyuan 已提交
33 34
 */
template <class T>
H
hedaoyuan 已提交
35
class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
Q
QI JUN 已提交
36
                    platform::CPUDeviceContext, T> {
H
hedaoyuan 已提交
37
 public:
Q
QI JUN 已提交
38
  void operator()(const platform::CPUDeviceContext& context,
C
chengduoZH 已提交
39 40
                  const framework::Tensor& im, const std::vector<int>& dilation,
                  const std::vector<int>& stride,
41 42
                  const std::vector<int>& padding, framework::Tensor* col,
                  const DataLayout data_layout) {
43 44 45 46 47
    PADDLE_ENFORCE_EQ(im.dims().size(), 3,
                      platform::errors::InvalidArgument(
                          "The dimension of tensor 'im' should be 3. But got "
                          "the dims of tensor 'im' is [%s].",
                          im.dims()));
L
liym27 已提交
48
    PADDLE_ENFORCE_EQ(col->dims().size(), 5,
49 50 51 52
                      platform::errors::InvalidArgument(
                          "The dimension of tensor 'col' should be 5. But got "
                          "the dims of tensor 'col' is [%s].",
                          col->dims()));
H
hedaoyuan 已提交
53

T
tensor-tang 已提交
54
    if (stride[0] == 1 && stride[1] == 1 && dilation[0] == 1 &&
55
        dilation[1] == 1) {
L
liym27 已提交
56 57
      if (padding[0] == 0 && padding[1] == 0 && padding[2] == 0 &&
          padding[3] == 0) {
58
        im2col_sh1sw1dh1dw1ph0pw0<T>(im, col, data_layout);
59
        return;
L
liym27 已提交
60 61
      } else if (padding[0] == 1 && padding[1] == 1 && padding[2] == 1 &&
                 padding[3] == 1) {
62
        im2col_sh1sw1dh1dw1ph1pw1<T>(im, col, data_layout);
63
        return;
H
hedaoyuan 已提交
64
      }
65
      // TODO(TJ): complete padding >=2
H
hedaoyuan 已提交
66
    }
67
    im2col_common<T>(im, dilation, stride, padding, col, data_layout);
H
hedaoyuan 已提交
68 69 70 71
  }
};

/*
H
hedaoyuan 已提交
72 73 74
 * im = [input_channels, input_height, input_width]
 * col =
 *   [input_channels, filter_height, filter_width, output_height, output_width]
H
hedaoyuan 已提交
75 76
 */
template <class T>
H
hedaoyuan 已提交
77
class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
Q
QI JUN 已提交
78
                    platform::CPUDeviceContext, T> {
H
hedaoyuan 已提交
79
 public:
Q
QI JUN 已提交
80
  void operator()(const platform::CPUDeviceContext& context,
C
chengduoZH 已提交
81 82 83
                  const framework::Tensor& col,
                  const std::vector<int>& dilation,
                  const std::vector<int>& stride,
84 85
                  const std::vector<int>& padding, framework::Tensor* im,
                  const DataLayout data_layout) {
86 87 88 89 90
    PADDLE_ENFORCE_EQ(im->dims().size(), 3,
                      platform::errors::InvalidArgument(
                          "The dimension of tensor 'im' should be 3. But got "
                          "the dims of tensor 'im' is [%s].",
                          im->dims()));
L
liym27 已提交
91
    PADDLE_ENFORCE_EQ(col.dims().size(), 5,
92 93 94 95
                      platform::errors::InvalidArgument(
                          "The dimension of tensor 'col' should be 5. But got "
                          "the dims of tensor 'col' is [%s].",
                          col.dims()));
96
    int im_channels =
97
        (data_layout != DataLayout::kNHWC ? im->dims()[0] : im->dims()[2]);
98
    int im_height =
99
        (data_layout != DataLayout::kNHWC ? im->dims()[1] : im->dims()[0]);
100
    int im_width =
101
        (data_layout != DataLayout::kNHWC ? im->dims()[2] : im->dims()[1]);
H
hedaoyuan 已提交
102 103
    int filter_height = col.dims()[1];
    int filter_width = col.dims()[2];
C
chengduoZH 已提交
104 105
    int col_height = col.dims()[3];
    int col_width = col.dims()[4];
C
chengduoZH 已提交
106

C
chengduoZH 已提交
107 108 109
    PADDLE_ENFORCE_EQ((im_height + padding[0] + padding[2] -
                       ((dilation[0] * (filter_height - 1) + 1))) /
                              stride[0] +
C
chengduoZH 已提交
110
                          1,
111 112 113
                      col_height, platform::errors::InvalidArgument(
                                      "Output_height and padding(padding_up, "
                                      "padding_down) are inconsistent."));
C
chengduoZH 已提交
114 115 116
    PADDLE_ENFORCE_EQ((im_width + padding[1] + padding[3] -
                       ((dilation[1] * (filter_width - 1) + 1))) /
                              stride[1] +
C
chengduoZH 已提交
117
                          1,
118 119 120
                      col_width, platform::errors::InvalidArgument(
                                     "Output_height and padding(padding_up, "
                                     "padding_down) are inconsistent."));
C
chengduoZH 已提交
121

C
chengduoZH 已提交
122
    int channels_col = im_channels * filter_height * filter_width;
H
hedaoyuan 已提交
123

C
chengduoZH 已提交
124
    T* im_data = im->data<T>();
H
hedaoyuan 已提交
125 126 127
    const T* col_data = col.data<T>();

    for (int c = 0; c < channels_col; ++c) {
C
chengduoZH 已提交
128 129 130
      int w_offset = c % filter_width;
      int h_offset = (c / filter_width) % filter_height;
      int c_im = c / (filter_width * filter_height);
C
chengduoZH 已提交
131
      for (int h = 0; h < col_height; ++h) {
C
chengduoZH 已提交
132
        int im_row_idx = h * stride[0] - padding[0] + h_offset * dilation[0];
C
chengduoZH 已提交
133
        for (int w = 0; w < col_width; ++w) {
C
chengduoZH 已提交
134
          int im_col_idx = w * stride[1] - padding[1] + w_offset * dilation[1];
C
chengduoZH 已提交
135 136
          if ((im_row_idx) >= 0 && (im_row_idx) < im_height &&
              (im_col_idx) >= 0 && (im_col_idx) < im_width) {
137
            int im_offset;
138
            if (data_layout != DataLayout::kNHWC) {
139 140 141 142 143 144 145
              im_offset =
                  (c_im * im_height + im_row_idx) * im_width + im_col_idx;
            } else {
              im_offset =
                  (im_row_idx * im_width + im_col_idx) * im_channels + c_im;
            }
            im_data[im_offset] +=
C
chengduoZH 已提交
146
                col_data[(c * col_height + h) * col_width + w];
H
hedaoyuan 已提交
147 148 149 150 151 152 153
          }
        }
      }
    }
  }
};

H
hedaoyuan 已提交
154
template class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
Q
QI JUN 已提交
155
                             platform::CPUDeviceContext, float>;
H
hedaoyuan 已提交
156
template class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
Q
QI JUN 已提交
157
                             platform::CPUDeviceContext, double>;
H
hedaoyuan 已提交
158
template class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
Q
QI JUN 已提交
159
                             platform::CPUDeviceContext, float>;
H
hedaoyuan 已提交
160
template class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
Q
QI JUN 已提交
161
                             platform::CPUDeviceContext, double>;
H
hedaoyuan 已提交
162 163

/*
H
hedaoyuan 已提交
164 165 166
 * im = [input_channels, input_height, input_width]
 * col =
 *   [output_height, output_width, input_channels, filter_height, filter_width]
H
hedaoyuan 已提交
167 168
 */
template <class T>
H
hedaoyuan 已提交
169
class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
Q
QI JUN 已提交
170
                    platform::CPUDeviceContext, T> {
H
hedaoyuan 已提交
171
 public:
Q
QI JUN 已提交
172
  void operator()(const platform::CPUDeviceContext& context,
C
chengduoZH 已提交
173 174
                  const framework::Tensor& im, const std::vector<int>& dilation,
                  const std::vector<int>& stride,
175 176
                  const std::vector<int>& padding, framework::Tensor* col,
                  const DataLayout data_layout) {
177 178 179 180 181
    PADDLE_ENFORCE_EQ(im.dims().size(), 3,
                      platform::errors::InvalidArgument(
                          "The dimension of tensor 'im' should be 3. But got "
                          "the dims of tensor 'im' is [%s].",
                          im.dims()));
L
liym27 已提交
182
    PADDLE_ENFORCE_EQ(col->dims().size(), 5,
183 184 185 186
                      platform::errors::InvalidArgument(
                          "The dimension of tensor 'col' should be 5. But got "
                          "the dims of tensor 'col' is [%s].",
                          col->dims()));
C
chengduoZH 已提交
187 188 189
    int im_channels = im.dims()[0];
    int im_height = im.dims()[1];
    int im_width = im.dims()[2];
C
chengduoZH 已提交
190 191 192 193
    int filter_height = col->dims()[3];
    int filter_width = col->dims()[4];
    int col_height = col->dims()[0];
    int col_width = col->dims()[1];
H
hedaoyuan 已提交
194 195

    const T* im_data = im.data<T>();
C
chengduoZH 已提交
196
    T* col_data = col->data<T>();
H
hedaoyuan 已提交
197

C
chengduoZH 已提交
198 199 200
    for (int col_row_idx = 0; col_row_idx < col_height; ++col_row_idx) {
      for (int col_col_idx = 0; col_col_idx < col_width; ++col_col_idx) {
        for (int channel = 0; channel < im_channels; ++channel) {
H
hedaoyuan 已提交
201 202
          for (int filter_row_idx = 0; filter_row_idx < filter_height;
               ++filter_row_idx) {
C
refine  
chengduoZH 已提交
203 204
            int im_row_offset =
                col_row_idx * stride[0] + filter_row_idx - padding[0];
H
hedaoyuan 已提交
205 206 207
            for (int filter_col_idx = 0; filter_col_idx < filter_width;
                 ++filter_col_idx) {
              int im_col_offset =
C
chengduoZH 已提交
208
                  col_col_idx * stride[1] + filter_col_idx - padding[1];
C
refine  
chengduoZH 已提交
209

C
chengduoZH 已提交
210 211 212 213 214 215 216 217 218 219 220 221 222 223 224
              int col_offset =
                  ((((col_row_idx)*col_width + col_col_idx) * im_channels +
                    channel) *
                       filter_height +
                   filter_row_idx) *
                      filter_width +
                  filter_col_idx;

              int im_offset = (channel * im_height + im_row_offset) * im_width +
                              im_col_offset;
              col_data[col_offset] =
                  (im_row_offset < 0 || im_row_offset >= im_height ||
                   im_col_offset < 0 || im_col_offset >= im_width)
                      ? static_cast<T>(0)
                      : im_data[im_offset];
H
hedaoyuan 已提交
225 226 227 228 229 230 231 232 233
            }
          }
        }
      }
    }
  }
};

/*
H
hedaoyuan 已提交
234 235 236
 * im = [input_channels, input_height, input_width]
 * col =
 *   [output_height, output_width, input_channels, filter_height, filter_width]
H
hedaoyuan 已提交
237 238
 */
template <class T>
H
hedaoyuan 已提交
239
class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
Q
QI JUN 已提交
240
                    platform::CPUDeviceContext, T> {
H
hedaoyuan 已提交
241
 public:
Q
QI JUN 已提交
242
  void operator()(const platform::CPUDeviceContext& context,
C
chengduoZH 已提交
243 244 245
                  const framework::Tensor& col,
                  const std::vector<int>& dilation,
                  const std::vector<int>& stride,
246 247
                  const std::vector<int>& padding, framework::Tensor* im,
                  const DataLayout data_layout) {
248 249 250 251 252
    PADDLE_ENFORCE_EQ(im->dims().size(), 3,
                      platform::errors::InvalidArgument(
                          "The dimension of tensor 'im' should be 3. But got "
                          "the dims of tensor 'im' is [%s].",
                          im->dims()));
L
liym27 已提交
253
    PADDLE_ENFORCE_EQ(col.dims().size(), 5,
254 255 256 257
                      platform::errors::InvalidArgument(
                          "The dimension of tensor 'col' should be 5. But got "
                          "the dims of tensor 'col' is [%s].",
                          col.dims()));
C
chengduoZH 已提交
258 259 260
    int im_channels = im->dims()[0];
    int im_height = im->dims()[1];
    int im_width = im->dims()[2];
H
hedaoyuan 已提交
261 262
    int filter_height = col.dims()[3];
    int filter_width = col.dims()[4];
C
chengduoZH 已提交
263 264
    int col_height = col.dims()[0];
    int col_width = col.dims()[1];
H
hedaoyuan 已提交
265

C
chengduoZH 已提交
266 267
    PADDLE_ENFORCE_EQ(
        (im_height + padding[0] + padding[2] - filter_height) / stride[0] + 1,
268 269 270
        col_height, platform::errors::InvalidArgument(
                        "Output_height and padding(padding_up, padding_down) "
                        "are inconsistent."));
C
chengduoZH 已提交
271 272 273
    PADDLE_ENFORCE_EQ(
        (im_width + padding[1] + padding[3] - filter_width) / stride[1] + 1,
        col_width,
274 275
        platform::errors::InvalidArgument("col_width and padding(padding_left, "
                                          "padding_right) are inconsistent."));
276

C
chengduoZH 已提交
277
    T* im_data = im->data<T>();
H
hedaoyuan 已提交
278 279
    const T* col_data = col.data<T>();

C
chengduoZH 已提交
280 281 282
    for (int col_row_idx = 0; col_row_idx < col_height; ++col_row_idx) {
      for (int col_col_idx = 0; col_col_idx < col_width; ++col_col_idx) {
        for (int channel = 0; channel < im_channels; ++channel) {
H
hedaoyuan 已提交
283 284
          for (int filter_row_idx = 0; filter_row_idx < filter_height;
               ++filter_row_idx) {
C
refine  
chengduoZH 已提交
285 286
            int im_row_offset =
                col_row_idx * stride[0] + filter_row_idx - padding[0];
H
hedaoyuan 已提交
287 288 289
            for (int filter_col_idx = 0; filter_col_idx < filter_width;
                 ++filter_col_idx) {
              int im_col_offset =
C
chengduoZH 已提交
290
                  col_col_idx * stride[1] + filter_col_idx - padding[1];
C
refine  
chengduoZH 已提交
291

C
chengduoZH 已提交
292 293 294 295 296 297 298
              int col_offset =
                  (((col_row_idx * col_width + col_col_idx) * im_channels +
                    channel) *
                       filter_height +
                   filter_row_idx) *
                      filter_width +
                  filter_col_idx;
C
refine  
chengduoZH 已提交
299

C
chengduoZH 已提交
300 301
              if (im_row_offset >= 0 && im_row_offset < im_height &&
                  im_col_offset >= 0 && im_col_offset < im_width) {
H
hedaoyuan 已提交
302
                int im_offset =
C
chengduoZH 已提交
303
                    (channel * im_height + im_row_offset) * im_width +
H
hedaoyuan 已提交
304 305
                    im_col_offset;
                im_data[im_offset] += col_data[col_offset];
H
hedaoyuan 已提交
306 307 308 309 310 311 312 313 314
              }
            }
          }
        }
      }
    }
  }
};

H
hedaoyuan 已提交
315
template class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
Q
QI JUN 已提交
316
                             platform::CPUDeviceContext, float>;
H
hedaoyuan 已提交
317
template class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
Q
QI JUN 已提交
318
                             platform::CPUDeviceContext, double>;
H
hedaoyuan 已提交
319
template class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
Q
QI JUN 已提交
320
                             platform::CPUDeviceContext, float>;
H
hedaoyuan 已提交
321
template class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
Q
QI JUN 已提交
322
                             platform::CPUDeviceContext, double>;
H
hedaoyuan 已提交
323

324
}  // namespace math
325
}  // namespace operators
H
hedaoyuan 已提交
326
}  // namespace paddle