im2col.cc 12.2 KB
Newer Older
H
hedaoyuan 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

H
hedaoyuan 已提交
15
#include "paddle/operators/math/im2col.h"
H
hedaoyuan 已提交
16 17

namespace paddle {
18
namespace operators {
19
namespace math {
H
hedaoyuan 已提交
20 21

/*
H
hedaoyuan 已提交
22 23 24
 * im = [input_channels, input_height, input_width]
 * col =
 *   [input_channels, filter_height, filter_width, output_height, output_width]
H
hedaoyuan 已提交
25 26
 */
template <class T>
H
hedaoyuan 已提交
27 28
class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
                    platform::CPUPlace, T> {
H
hedaoyuan 已提交
29
 public:
H
hedaoyuan 已提交
30
  void operator()(const platform::DeviceContext& context,
C
chengduoZH 已提交
31 32 33
                  const framework::Tensor& im, const std::vector<int>& dilation,
                  const std::vector<int>& stride,
                  const std::vector<int>& padding, framework::Tensor* col) {
H
hedaoyuan 已提交
34
    PADDLE_ENFORCE(im.dims().size() == 3);
C
chengduoZH 已提交
35
    PADDLE_ENFORCE(col->dims().size() == 5);
H
hedaoyuan 已提交
36

C
chengduoZH 已提交
37 38 39
    int im_channels = im.dims()[0];
    int im_height = im.dims()[1];
    int im_width = im.dims()[2];
C
chengduoZH 已提交
40 41 42 43
    int filter_height = col->dims()[1];
    int filter_width = col->dims()[2];
    int col_height = col->dims()[3];
    int col_width = col->dims()[4];
C
chengduoZH 已提交
44

C
chengduoZH 已提交
45 46 47
    PADDLE_ENFORCE_EQ((im_height + padding[0] + padding[2] -
                       ((dilation[0] * (filter_height - 1) + 1))) /
                              stride[0] +
C
chengduoZH 已提交
48 49 50 51
                          1,
                      col_height,
                      "Output_height and padding(padding_up, padding_down) are "
                      "inconsistent.");
C
chengduoZH 已提交
52 53 54
    PADDLE_ENFORCE_EQ((im_width + padding[1] + padding[3] -
                       ((dilation[1] * (filter_width - 1) + 1))) /
                              stride[1] +
C
chengduoZH 已提交
55 56
                          1,
                      col_width,
C
chengduoZH 已提交
57
                      "Output_height and padding(padding_up, padding_down) are "
C
chengduoZH 已提交
58
                      "inconsistent.");
C
chengduoZH 已提交
59

C
chengduoZH 已提交
60
    int channels_col = im_channels * filter_height * filter_width;
H
hedaoyuan 已提交
61 62

    const T* im_data = im.data<T>();
C
chengduoZH 已提交
63
    T* col_data = col->data<T>();
H
hedaoyuan 已提交
64 65 66 67 68

    for (int c = 0; c < channels_col; ++c) {
      int w_offset = c % filter_width;
      int h_offset = (c / filter_width) % filter_height;
      int c_im = c / filter_width / filter_height;
C
chengduoZH 已提交
69 70
      for (int h = 0; h < col_height; ++h) {
        for (int w = 0; w < col_width; ++w) {
C
chengduoZH 已提交
71 72
          int im_row_idx = h * stride[0] - padding[0] + h_offset * dilation[0];
          int im_col_idx = w * stride[1] - padding[1] + w_offset * dilation[1];
C
chengduoZH 已提交
73 74
          int col_idx = (c * col_height + h) * col_width + w;
          int im_idx = (im_row_idx + c_im * im_height) * im_width + im_col_idx;
C
chengduoZH 已提交
75

C
chengduoZH 已提交
76 77 78 79
          col_data[col_idx] = (im_row_idx < 0 || im_row_idx >= im_height ||
                               im_col_idx < 0 || im_col_idx >= im_width)
                                  ? static_cast<T>(0)
                                  : im_data[im_idx];
H
hedaoyuan 已提交
80 81 82 83 84 85 86
        }
      }
    }
  }
};

/*
H
hedaoyuan 已提交
87 88 89
 * im = [input_channels, input_height, input_width]
 * col =
 *   [input_channels, filter_height, filter_width, output_height, output_width]
H
hedaoyuan 已提交
90 91
 */
template <class T>
H
hedaoyuan 已提交
92 93
class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
                    platform::CPUPlace, T> {
H
hedaoyuan 已提交
94
 public:
C
chengduoZH 已提交
95 96 97 98 99 100
  void operator()(const platform::DeviceContext& context,
                  const framework::Tensor& col,
                  const std::vector<int>& dilation,
                  const std::vector<int>& stride,
                  const std::vector<int>& padding, framework::Tensor* im) {
    PADDLE_ENFORCE(im->dims().size() == 3);
H
hedaoyuan 已提交
101
    PADDLE_ENFORCE(col.dims().size() == 5);
C
chengduoZH 已提交
102 103 104
    int im_channels = im->dims()[0];
    int im_height = im->dims()[1];
    int im_width = im->dims()[2];
H
hedaoyuan 已提交
105 106
    int filter_height = col.dims()[1];
    int filter_width = col.dims()[2];
C
chengduoZH 已提交
107 108
    int col_height = col.dims()[3];
    int col_width = col.dims()[4];
C
chengduoZH 已提交
109

C
chengduoZH 已提交
110 111 112
    PADDLE_ENFORCE_EQ((im_height + padding[0] + padding[2] -
                       ((dilation[0] * (filter_height - 1) + 1))) /
                              stride[0] +
C
chengduoZH 已提交
113 114 115 116
                          1,
                      col_height,
                      "Output_height and padding(padding_up, padding_down) are "
                      "inconsistent.");
C
chengduoZH 已提交
117 118 119
    PADDLE_ENFORCE_EQ((im_width + padding[1] + padding[3] -
                       ((dilation[1] * (filter_width - 1) + 1))) /
                              stride[1] +
C
chengduoZH 已提交
120 121
                          1,
                      col_width,
C
chengduoZH 已提交
122
                      "Output_height and padding(padding_up, padding_down) are "
C
chengduoZH 已提交
123
                      "inconsistent.");
C
chengduoZH 已提交
124

C
chengduoZH 已提交
125
    int channels_col = im_channels * filter_height * filter_width;
H
hedaoyuan 已提交
126

C
chengduoZH 已提交
127
    T* im_data = im->data<T>();
H
hedaoyuan 已提交
128 129 130 131 132 133
    const T* col_data = col.data<T>();

    for (int c = 0; c < channels_col; ++c) {
      int w_offset = c % filter_width;
      int h_offset = (c / filter_width) % filter_height;
      int c_im = c / filter_width / filter_height;
C
chengduoZH 已提交
134 135
      for (int h = 0; h < col_height; ++h) {
        for (int w = 0; w < col_width; ++w) {
C
chengduoZH 已提交
136 137
          int im_row_idx = h * stride[0] - padding[0] + h_offset * dilation[0];
          int im_col_idx = w * stride[1] - padding[1] + w_offset * dilation[1];
C
chengduoZH 已提交
138

C
chengduoZH 已提交
139 140 141 142 143
          if ((im_row_idx) >= 0 && (im_row_idx) < im_height &&
              (im_col_idx) >= 0 && (im_col_idx) < im_width) {
            im_row_idx += c_im * im_height;
            im_data[im_row_idx * im_width + im_col_idx] +=
                col_data[(c * col_height + h) * col_width + w];
H
hedaoyuan 已提交
144 145 146 147 148 149 150
          }
        }
      }
    }
  }
};

H
hedaoyuan 已提交
151 152 153 154 155 156 157 158
// Explicit instantiations of the CPU kCFO functors for float and double.
template class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
                             platform::CPUPlace, float>;
template class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
                             platform::CPUPlace, float>;
template class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
                             platform::CPUPlace, double>;
template class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
                             platform::CPUPlace, double>;
H
hedaoyuan 已提交
159 160

/*
H
hedaoyuan 已提交
161 162 163
 * im = [input_channels, input_height, input_width]
 * col =
 *   [output_height, output_width, input_channels, filter_height, filter_width]
H
hedaoyuan 已提交
164 165
 */
template <class T>
H
hedaoyuan 已提交
166 167
class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
                    platform::CPUPlace, T> {
H
hedaoyuan 已提交
168
 public:
H
hedaoyuan 已提交
169
  void operator()(const platform::DeviceContext& context,
C
chengduoZH 已提交
170 171 172
                  const framework::Tensor& im, const std::vector<int>& dilation,
                  const std::vector<int>& stride,
                  const std::vector<int>& padding, framework::Tensor* col) {
H
hedaoyuan 已提交
173
    PADDLE_ENFORCE(im.dims().size() == 3);
C
chengduoZH 已提交
174
    PADDLE_ENFORCE(col->dims().size() == 5);
C
chengduoZH 已提交
175 176 177
    int im_channels = im.dims()[0];
    int im_height = im.dims()[1];
    int im_width = im.dims()[2];
C
chengduoZH 已提交
178 179 180 181
    int filter_height = col->dims()[3];
    int filter_width = col->dims()[4];
    int col_height = col->dims()[0];
    int col_width = col->dims()[1];
H
hedaoyuan 已提交
182

C
chengduoZH 已提交
183 184 185 186 187 188 189 190 191 192
    PADDLE_ENFORCE_EQ(
        (im_height + padding[0] + padding[2] - filter_height) / stride[0] + 1,
        col_height,
        "Output_height and padding(padding_up, padding_down) are "
        "inconsistent.");
    PADDLE_ENFORCE_EQ(
        (im_width + padding[1] + padding[3] - filter_width) / stride[1] + 1,
        col_width,
        "col_width and padding(padding_left, padding_right) are "
        "inconsistent.");
193

H
hedaoyuan 已提交
194
    const T* im_data = im.data<T>();
C
chengduoZH 已提交
195
    T* col_data = col->data<T>();
H
hedaoyuan 已提交
196

C
chengduoZH 已提交
197 198 199
    for (int col_row_idx = 0; col_row_idx < col_height; ++col_row_idx) {
      for (int col_col_idx = 0; col_col_idx < col_width; ++col_col_idx) {
        for (int channel = 0; channel < im_channels; ++channel) {
H
hedaoyuan 已提交
200 201 202 203 204
          for (int filter_row_idx = 0; filter_row_idx < filter_height;
               ++filter_row_idx) {
            for (int filter_col_idx = 0; filter_col_idx < filter_width;
                 ++filter_col_idx) {
              int im_row_offset =
C
chengduoZH 已提交
205
                  col_row_idx * stride[0] + filter_row_idx - padding[0];
H
hedaoyuan 已提交
206
              int im_col_offset =
C
chengduoZH 已提交
207
                  col_col_idx * stride[1] + filter_col_idx - padding[1];
C
chengduoZH 已提交
208 209 210 211 212 213 214 215 216 217 218 219 220 221 222
              int col_offset =
                  ((((col_row_idx)*col_width + col_col_idx) * im_channels +
                    channel) *
                       filter_height +
                   filter_row_idx) *
                      filter_width +
                  filter_col_idx;

              int im_offset = (channel * im_height + im_row_offset) * im_width +
                              im_col_offset;
              col_data[col_offset] =
                  (im_row_offset < 0 || im_row_offset >= im_height ||
                   im_col_offset < 0 || im_col_offset >= im_width)
                      ? static_cast<T>(0)
                      : im_data[im_offset];
H
hedaoyuan 已提交
223 224 225 226 227 228 229 230 231
            }
          }
        }
      }
    }
  }
};

/*
H
hedaoyuan 已提交
232 233 234
 * im = [input_channels, input_height, input_width]
 * col =
 *   [output_height, output_width, input_channels, filter_height, filter_width]
H
hedaoyuan 已提交
235 236
 */
template <class T>
H
hedaoyuan 已提交
237 238
class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
                    platform::CPUPlace, T> {
H
hedaoyuan 已提交
239
 public:
C
chengduoZH 已提交
240 241 242 243 244 245
  void operator()(const platform::DeviceContext& context,
                  const framework::Tensor& col,
                  const std::vector<int>& dilation,
                  const std::vector<int>& stride,
                  const std::vector<int>& padding, framework::Tensor* im) {
    PADDLE_ENFORCE(im->dims().size() == 3);
H
hedaoyuan 已提交
246
    PADDLE_ENFORCE(col.dims().size() == 5);
C
chengduoZH 已提交
247 248 249
    int im_channels = im->dims()[0];
    int im_height = im->dims()[1];
    int im_width = im->dims()[2];
H
hedaoyuan 已提交
250 251
    int filter_height = col.dims()[3];
    int filter_width = col.dims()[4];
C
chengduoZH 已提交
252 253
    int col_height = col.dims()[0];
    int col_width = col.dims()[1];
H
hedaoyuan 已提交
254

C
chengduoZH 已提交
255 256 257 258 259 260 261 262 263 264
    PADDLE_ENFORCE_EQ(
        (im_height + padding[0] + padding[2] - filter_height) / stride[0] + 1,
        col_height,
        "Output_height and padding(padding_up, padding_down) are "
        "inconsistent.");
    PADDLE_ENFORCE_EQ(
        (im_width + padding[1] + padding[3] - filter_width) / stride[1] + 1,
        col_width,
        "col_width and padding(padding_left, padding_right) are "
        "inconsistent.");
265

C
chengduoZH 已提交
266
    T* im_data = im->data<T>();
H
hedaoyuan 已提交
267 268
    const T* col_data = col.data<T>();

C
chengduoZH 已提交
269 270 271
    for (int col_row_idx = 0; col_row_idx < col_height; ++col_row_idx) {
      for (int col_col_idx = 0; col_col_idx < col_width; ++col_col_idx) {
        for (int channel = 0; channel < im_channels; ++channel) {
H
hedaoyuan 已提交
272 273 274 275
          for (int filter_row_idx = 0; filter_row_idx < filter_height;
               ++filter_row_idx) {
            for (int filter_col_idx = 0; filter_col_idx < filter_width;
                 ++filter_col_idx) {
C
chengduoZH 已提交
276
              int im_row_offset =
C
chengduoZH 已提交
277
                  col_row_idx * stride[0] + filter_row_idx - padding[0];
H
hedaoyuan 已提交
278
              int im_col_offset =
C
chengduoZH 已提交
279
                  col_col_idx * stride[1] + filter_col_idx - padding[1];
C
chengduoZH 已提交
280 281 282 283 284 285 286 287 288
              int col_offset =
                  (((col_row_idx * col_width + col_col_idx) * im_channels +
                    channel) *
                       filter_height +
                   filter_row_idx) *
                      filter_width +
                  filter_col_idx;
              if (im_row_offset >= 0 && im_row_offset < im_height &&
                  im_col_offset >= 0 && im_col_offset < im_width) {
H
hedaoyuan 已提交
289
                int im_offset =
C
chengduoZH 已提交
290
                    (channel * im_height + im_row_offset) * im_width +
H
hedaoyuan 已提交
291 292
                    im_col_offset;
                im_data[im_offset] += col_data[col_offset];
H
hedaoyuan 已提交
293 294 295 296 297 298 299 300 301
              }
            }
          }
        }
      }
    }
  }
};

H
hedaoyuan 已提交
302 303 304 305 306 307 308 309
// Explicit instantiations of the CPU kOCF functors for float and double.
template class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
                             platform::CPUPlace, float>;
template class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
                             platform::CPUPlace, float>;
template class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
                             platform::CPUPlace, double>;
template class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
                             platform::CPUPlace, double>;
H
hedaoyuan 已提交
310

311
}  // namespace math
312
}  // namespace operators
H
hedaoyuan 已提交
313
}  // namespace paddle