/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/math/im2col.h"

namespace paddle {
namespace operators {
namespace math {

/*
H
hedaoyuan 已提交
22 23 24
 * im = [input_channels, input_height, input_width]
 * col =
 *   [input_channels, filter_height, filter_width, output_height, output_width]
H
hedaoyuan 已提交
25 26
 */
template <class T>
H
hedaoyuan 已提交
27 28
class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
                    platform::CPUPlace, T> {
H
hedaoyuan 已提交
29
 public:
H
hedaoyuan 已提交
30 31
  void operator()(const platform::DeviceContext& context,
                  const framework::Tensor& im, framework::Tensor& col,
C
chengduoZH 已提交
32 33
                  int stride_height, int stride_width, int padding_up,
                  int padding_down, int padding_left, int padding_right) {
H
hedaoyuan 已提交
34 35 36 37 38 39 40 41 42 43
    PADDLE_ENFORCE(im.dims().size() == 3);
    PADDLE_ENFORCE(col.dims().size() == 5);

    int input_channels = im.dims()[0];
    int input_height = im.dims()[1];
    int input_width = im.dims()[2];
    int filter_height = col.dims()[1];
    int filter_width = col.dims()[2];
    int output_height = col.dims()[3];
    int output_width = col.dims()[4];
C
chengduoZH 已提交
44 45 46 47 48 49 50 51 52 53

    PADDLE_ENFORCE((input_height + padding_up + padding_down - filter_height) /
                           stride_height +
                       1 ==
                   output_height);
    PADDLE_ENFORCE((input_width + padding_left + padding_right - filter_width) /
                           stride_width +
                       1 ==
                   output_width);

H
hedaoyuan 已提交
54 55 56 57 58 59 60 61 62 63 64 65 66
    int channels_col = input_channels * filter_height * filter_width;

    const T* im_data = im.data<T>();
    T* col_data = col.data<T>();

    for (int c = 0; c < channels_col; ++c) {
      int w_offset = c % filter_width;
      int h_offset = (c / filter_width) % filter_height;
      int c_im = c / filter_width / filter_height;
      for (int h = 0; h < output_height; ++h) {
        for (int w = 0; w < output_width; ++w) {
          int im_row_idx = h * stride_height + h_offset;
          int im_col_idx = w * stride_width + w_offset;
C
chengduoZH 已提交
67 68 69 70
          if ((im_row_idx - padding_up) < 0 ||
              (im_row_idx - padding_up) >= input_height ||
              (im_col_idx - padding_left) < 0 ||
              (im_col_idx - padding_left) >= input_width) {
H
hedaoyuan 已提交
71
            col_data[(c * output_height + h) * output_width + w] = T(0);
H
hedaoyuan 已提交
72
          } else {
C
chengduoZH 已提交
73 74
            im_row_idx += c_im * input_height - padding_up;
            im_col_idx -= padding_left;
H
hedaoyuan 已提交
75 76
            col_data[(c * output_height + h) * output_width + w] =
                im_data[im_row_idx * input_width + im_col_idx];
H
hedaoyuan 已提交
77 78 79 80 81 82 83 84
          }
        }
      }
    }
  }
};

/*
H
hedaoyuan 已提交
85 86 87
 * im = [input_channels, input_height, input_width]
 * col =
 *   [input_channels, filter_height, filter_width, output_height, output_width]
H
hedaoyuan 已提交
88 89
 */
template <class T>
H
hedaoyuan 已提交
90 91
class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
                    platform::CPUPlace, T> {
H
hedaoyuan 已提交
92
 public:
H
hedaoyuan 已提交
93 94
  void operator()(const platform::DeviceContext& context, framework::Tensor& im,
                  const framework::Tensor& col, int stride_height,
C
chengduoZH 已提交
95 96
                  int stride_width, int padding_up, int padding_down,
                  int padding_left, int padding_right) {
H
hedaoyuan 已提交
97 98 99 100 101 102 103 104 105
    PADDLE_ENFORCE(im.dims().size() == 3);
    PADDLE_ENFORCE(col.dims().size() == 5);
    int input_channels = im.dims()[0];
    int input_height = im.dims()[1];
    int input_width = im.dims()[2];
    int filter_height = col.dims()[1];
    int filter_width = col.dims()[2];
    int output_height = col.dims()[3];
    int output_width = col.dims()[4];
C
chengduoZH 已提交
106 107 108 109 110 111 112 113 114 115

    PADDLE_ENFORCE((input_height + padding_up + padding_down - filter_height) /
                           stride_height +
                       1 ==
                   output_height);
    PADDLE_ENFORCE((input_width + padding_left + padding_right - filter_width) /
                           stride_width +
                       1 ==
                   output_width);

H
hedaoyuan 已提交
116 117 118 119 120 121 122 123 124 125 126 127 128
    int channels_col = input_channels * filter_height * filter_width;

    T* im_data = im.data<T>();
    const T* col_data = col.data<T>();

    for (int c = 0; c < channels_col; ++c) {
      int w_offset = c % filter_width;
      int h_offset = (c / filter_width) % filter_height;
      int c_im = c / filter_width / filter_height;
      for (int h = 0; h < output_height; ++h) {
        for (int w = 0; w < output_width; ++w) {
          int im_row_idx = h * stride_height + h_offset;
          int im_col_idx = w * stride_width + w_offset;
C
chengduoZH 已提交
129 130 131 132 133 134
          if ((im_row_idx - padding_up) >= 0 &&
              (im_row_idx - padding_up) < input_height &&
              (im_col_idx - padding_left) >= 0 &&
              (im_col_idx - padding_left) < input_width) {
            im_row_idx += c_im * input_height - padding_up;
            im_col_idx -= padding_left;
H
hedaoyuan 已提交
135 136
            im_data[im_row_idx * input_width + im_col_idx] +=
                col_data[(c * output_height + h) * output_width + w];
H
hedaoyuan 已提交
137 138 139 140 141 142 143
          }
        }
      }
    }
  }
};

/* Explicit instantiations of the kCFO CPU functors: float, then double. */
template class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
                             platform::CPUPlace, float>;
template class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
                             platform::CPUPlace, float>;

template class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
                             platform::CPUPlace, double>;
template class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
                             platform::CPUPlace, double>;

/*
H
hedaoyuan 已提交
154 155 156
 * im = [input_channels, input_height, input_width]
 * col =
 *   [output_height, output_width, input_channels, filter_height, filter_width]
H
hedaoyuan 已提交
157 158
 */
template <class T>
H
hedaoyuan 已提交
159 160
class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
                    platform::CPUPlace, T> {
H
hedaoyuan 已提交
161
 public:
H
hedaoyuan 已提交
162 163
  void operator()(const platform::DeviceContext& context,
                  const framework::Tensor& im, framework::Tensor& col,
C
chengduoZH 已提交
164 165
                  int stride_height, int stride_width, int padding_up,
                  int padding_down, int padding_left, int padding_right) {
H
hedaoyuan 已提交
166 167 168 169 170 171 172
    PADDLE_ENFORCE(im.dims().size() == 3);
    PADDLE_ENFORCE(col.dims().size() == 5);
    int input_channels = im.dims()[0];
    int input_height = im.dims()[1];
    int input_width = im.dims()[2];
    int filter_height = col.dims()[3];
    int filter_width = col.dims()[4];
C
chengduoZH 已提交
173
    int output_height = col.dims()[0];
H
hedaoyuan 已提交
174 175
    int output_width = col.dims()[1];

C
chengduoZH 已提交
176 177 178 179 180 181 182 183
    PADDLE_ENFORCE((input_height + padding_up + padding_down - filter_height) /
                           stride_height +
                       1 ==
                   output_height);
    PADDLE_ENFORCE((input_width + padding_left + padding_right - filter_width) /
                           stride_width +
                       1 ==
                   output_width);
184

H
hedaoyuan 已提交
185 186 187
    const T* im_data = im.data<T>();
    T* col_data = col.data<T>();

C
chengduoZH 已提交
188
    for (int col_row_idx = 0; col_row_idx < output_height; ++col_row_idx) {
H
hedaoyuan 已提交
189 190 191 192 193 194 195
      for (int col_col_idx = 0; col_col_idx < output_width; ++col_col_idx) {
        for (int channel = 0; channel < input_channels; ++channel) {
          for (int filter_row_idx = 0; filter_row_idx < filter_height;
               ++filter_row_idx) {
            for (int filter_col_idx = 0; filter_col_idx < filter_width;
                 ++filter_col_idx) {
              int im_row_offset =
C
chengduoZH 已提交
196
                  col_row_idx * stride_height + filter_row_idx - padding_up;
H
hedaoyuan 已提交
197
              int im_col_offset =
C
chengduoZH 已提交
198 199 200 201 202 203 204 205
                  col_col_idx * stride_width + filter_col_idx - padding_left;
              int col_offset = ((((col_row_idx)*output_width + col_col_idx) *
                                     input_channels +
                                 channel) *
                                    filter_height +
                                filter_row_idx) *
                                   filter_width +
                               filter_col_idx;
H
hedaoyuan 已提交
206 207 208
              if (im_row_offset < 0 || im_row_offset >= input_height ||
                  im_col_offset < 0 || im_col_offset >= input_width) {
                col_data[col_offset] = T(0);
H
hedaoyuan 已提交
209
              } else {
H
hedaoyuan 已提交
210 211 212 213
                int im_offset =
                    (channel * input_height + im_row_offset) * input_width +
                    im_col_offset;
                col_data[col_offset] = im_data[im_offset];
H
hedaoyuan 已提交
214 215 216 217 218 219 220 221 222 223
              }
            }
          }
        }
      }
    }
  }
};

/*
H
hedaoyuan 已提交
224 225 226
 * im = [input_channels, input_height, input_width]
 * col =
 *   [output_height, output_width, input_channels, filter_height, filter_width]
H
hedaoyuan 已提交
227 228
 */
template <class T>
H
hedaoyuan 已提交
229 230
class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
                    platform::CPUPlace, T> {
H
hedaoyuan 已提交
231
 public:
H
hedaoyuan 已提交
232 233
  void operator()(const platform::DeviceContext& context, framework::Tensor& im,
                  const framework::Tensor& col, int stride_height,
C
chengduoZH 已提交
234 235
                  int stride_width, int padding_up, int padding_down,
                  int padding_left, int padding_right) {
H
hedaoyuan 已提交
236 237 238 239 240 241 242
    PADDLE_ENFORCE(im.dims().size() == 3);
    PADDLE_ENFORCE(col.dims().size() == 5);
    int input_channels = im.dims()[0];
    int input_height = im.dims()[1];
    int input_width = im.dims()[2];
    int filter_height = col.dims()[3];
    int filter_width = col.dims()[4];
C
chengduoZH 已提交
243
    int output_height = col.dims()[0];
H
hedaoyuan 已提交
244 245
    int output_width = col.dims()[1];

C
chengduoZH 已提交
246 247 248 249 250 251 252 253
    PADDLE_ENFORCE((input_height + padding_up + padding_down - filter_height) /
                           stride_height +
                       1 ==
                   output_height);
    PADDLE_ENFORCE((input_width + padding_left + padding_right - filter_width) /
                           stride_width +
                       1 ==
                   output_width);
254

H
hedaoyuan 已提交
255 256 257
    T* im_data = im.data<T>();
    const T* col_data = col.data<T>();

C
chengduoZH 已提交
258
    for (int col_row_idx = 0; col_row_idx < output_height; ++col_row_idx) {
H
hedaoyuan 已提交
259 260 261 262 263 264
      for (int col_col_idx = 0; col_col_idx < output_width; ++col_col_idx) {
        for (int channel = 0; channel < input_channels; ++channel) {
          for (int filter_row_idx = 0; filter_row_idx < filter_height;
               ++filter_row_idx) {
            for (int filter_col_idx = 0; filter_col_idx < filter_width;
                 ++filter_col_idx) {
265
              int im_row_offset =  // change or not ???
C
chengduoZH 已提交
266
                  col_row_idx * stride_height + filter_row_idx - padding_up;
H
hedaoyuan 已提交
267
              int im_col_offset =
C
chengduoZH 已提交
268 269 270 271 272 273 274 275
                  col_col_idx * stride_width + filter_col_idx - padding_left;
              int col_offset = (((col_row_idx * output_width + col_col_idx) *
                                     input_channels +
                                 channel) *
                                    filter_height +
                                filter_row_idx) *
                                   filter_width +
                               filter_col_idx;
H
hedaoyuan 已提交
276 277 278 279 280 281
              if (im_row_offset >= 0 && im_row_offset < input_height &&
                  im_col_offset >= 0 && im_col_offset < input_width) {
                int im_offset =
                    (channel * input_height + im_row_offset) * input_width +
                    im_col_offset;
                im_data[im_offset] += col_data[col_offset];
H
hedaoyuan 已提交
282 283 284 285 286 287 288 289 290
              }
            }
          }
        }
      }
    }
  }
};

/* Explicit instantiations of the kOCF CPU functors: float, then double. */
template class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
                             platform::CPUPlace, float>;
template class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
                             platform::CPUPlace, float>;

template class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
                             platform::CPUPlace, double>;
template class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
                             platform::CPUPlace, double>;

}  // namespace math
}  // namespace operators
}  // namespace paddle