/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/math/vol2col.h"
#include <vector>

namespace paddle {
namespace operators {
namespace math {

/*
 * vol = [input_channels, input_depth, input_height, input_width]
 * col =
 *   [input_channels, filter_depth, filter_height, filter_width,
 *                    output_depth, output_height, output_width]
 */
template <class T>
Q
QI JUN 已提交
29
class Vol2ColFunctor<platform::CPUDeviceContext, T> {
C
chengduoZH 已提交
30
 public:
Q
QI JUN 已提交
31
  void operator()(const platform::CPUDeviceContext& context,
C
chengduoZH 已提交
32 33 34 35 36
                  const framework::Tensor& vol,
                  const std::vector<int>& dilations,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings,
                  framework::Tensor* col) const {
L
liym27 已提交
37 38 39 40
    PADDLE_ENFORCE_EQ(vol.dims().size(), 4,
                      "The dimension of vol should be 4.");
    PADDLE_ENFORCE_EQ(col->dims().size(), 7,
                      "The dimension of col should be 7.");
C
chengduoZH 已提交
41 42 43 44
    int input_channels = vol.dims()[0];
    int input_depth = vol.dims()[1];
    int input_height = vol.dims()[2];
    int input_width = vol.dims()[3];
C
chengduoZH 已提交
45 46 47 48 49 50
    int filter_depth = col->dims()[1];
    int filter_height = col->dims()[2];
    int filter_width = col->dims()[3];
    int output_depth = col->dims()[4];
    int output_height = col->dims()[5];
    int output_width = col->dims()[6];
C
chengduoZH 已提交
51 52 53
    int channels_col =
        input_channels * filter_depth * filter_height * filter_width;

L
liym27 已提交
54 55 56 57 58 59 60 61 62
    // changed
    bool paddings_size_is_6 = (paddings.size() == 6);
    int pad_d_forth = paddings_size_is_6 ? paddings[0] : paddings[0];
    int pad_d_back = paddings_size_is_6 ? paddings[1] : paddings[0];
    int pad_h_up = paddings_size_is_6 ? paddings[2] : paddings[1];
    int pad_h_down = paddings_size_is_6 ? paddings[3] : paddings[1];
    int pad_w_left = paddings_size_is_6 ? paddings[4] : paddings[2];
    int pad_w_right = paddings_size_is_6 ? paddings[5] : paddings[2];
    PADDLE_ENFORCE_EQ((input_depth + pad_d_forth + pad_d_back -
C
chengduoZH 已提交
63 64
                       ((dilations[0] * (filter_depth - 1) + 1))) /
                              strides[0] +
C
chengduoZH 已提交
65 66 67
                          1,
                      output_depth,
                      "input_depth and output_depth are "
C
chengduoZH 已提交
68
                      "mismatching.");
L
liym27 已提交
69
    PADDLE_ENFORCE_EQ((input_height + pad_h_up + pad_h_down -
C
chengduoZH 已提交
70 71
                       ((dilations[1] * (filter_height - 1) + 1))) /
                              strides[1] +
C
chengduoZH 已提交
72 73 74
                          1,
                      output_height,
                      "input_height and output_height are "
C
chengduoZH 已提交
75
                      "mismatching.");
L
liym27 已提交
76
    PADDLE_ENFORCE_EQ((input_width + pad_w_left + pad_w_right -
C
chengduoZH 已提交
77 78
                       ((dilations[2] * (filter_width - 1) + 1))) /
                              strides[2] +
C
chengduoZH 已提交
79 80 81
                          1,
                      output_width,
                      "input_width and output_width are "
C
chengduoZH 已提交
82
                      "mismatching.");
C
chengduoZH 已提交
83
    const T* vol_data = vol.data<T>();
C
chengduoZH 已提交
84
    T* col_data = col->data<T>();
C
chengduoZH 已提交
85 86 87 88 89 90 91

    for (int c = 0; c < channels_col; ++c) {
      int w_offset = c % filter_width;
      int h_offset = (c / filter_width) % filter_height;
      int d_offset = (c / filter_width / filter_height) % filter_depth;
      int c_in = c / filter_width / filter_height / filter_depth;
      for (int d = 0; d < output_depth; ++d) {
L
liym27 已提交
92
        int d_pad = d * strides[0] - pad_d_forth + d_offset * dilations[0];
C
chengduoZH 已提交
93
        for (int h = 0; h < output_height; ++h) {
L
liym27 已提交
94
          int h_pad = h * strides[1] - pad_h_up + h_offset * dilations[1];
C
chengduoZH 已提交
95
          for (int w = 0; w < output_width; ++w) {
L
liym27 已提交
96
            int w_pad = w * strides[2] - pad_w_left + w_offset * dilations[2];
C
chengduoZH 已提交
97 98 99

            int col_idx =
                ((c * output_depth + d) * output_height + h) * output_width + w;
C
chengduoZH 已提交
100 101 102 103 104 105 106 107 108
            int vol_idx =
                ((c_in * input_depth + d_pad) * input_height + h_pad) *
                    input_width +
                w_pad;
            col_data[col_idx] =
                (h_pad < 0 || h_pad >= input_height || w_pad < 0 ||
                 w_pad >= input_width || d_pad < 0 || d_pad >= input_depth)
                    ? static_cast<T>(0)
                    : vol_data[vol_idx];
C
chengduoZH 已提交
109 110 111 112 113 114 115 116 117 118 119 120 121 122
          }
        }
      }
    }
  }
};

/*
 * vol = [input_channels,input_depth, input_height, input_width]
 * col =
 *   [input_channels, filter_depth, filter_height, filter_width,
 *                    output_depth, output_height, output_width]
 */
template <class T>
Q
QI JUN 已提交
123
class Col2VolFunctor<platform::CPUDeviceContext, T> {
C
chengduoZH 已提交
124
 public:
Q
QI JUN 已提交
125
  void operator()(const platform::CPUDeviceContext& context,
C
chengduoZH 已提交
126 127 128 129 130
                  const framework::Tensor& col,
                  const std::vector<int>& dilations,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings,
                  framework::Tensor* vol) const {
L
liym27 已提交
131 132 133 134
    PADDLE_ENFORCE_EQ(vol->dims().size(), 4,
                      "The dimension of vol should be 4.");
    PADDLE_ENFORCE_EQ(col.dims().size(), 7,
                      "The dimension of col should be 7.");
C
chengduoZH 已提交
135 136 137 138
    int input_channels = vol->dims()[0];
    int input_depth = vol->dims()[1];
    int input_height = vol->dims()[2];
    int input_width = vol->dims()[3];
C
chengduoZH 已提交
139 140 141 142 143 144 145 146 147
    int filter_depth = col.dims()[1];
    int filter_height = col.dims()[2];
    int filter_width = col.dims()[3];
    int output_depth = col.dims()[4];
    int output_height = col.dims()[5];
    int output_width = col.dims()[6];
    int channels_col =
        input_channels * filter_depth * filter_height * filter_width;

L
liym27 已提交
148 149 150 151 152 153 154 155 156
    bool paddings_size_is_6 = (paddings.size() == 6);
    int pad_d_forth = paddings_size_is_6 ? paddings[0] : paddings[0];
    int pad_d_back = paddings_size_is_6 ? paddings[1] : paddings[0];
    int pad_h_up = paddings_size_is_6 ? paddings[2] : paddings[1];
    int pad_h_down = paddings_size_is_6 ? paddings[3] : paddings[1];
    int pad_w_left = paddings_size_is_6 ? paddings[4] : paddings[2];
    int pad_w_right = paddings_size_is_6 ? paddings[5] : paddings[2];

    PADDLE_ENFORCE_EQ((input_depth + pad_d_forth + pad_d_back -
C
chengduoZH 已提交
157 158
                       ((dilations[0] * (filter_depth - 1) + 1))) /
                              strides[0] +
C
chengduoZH 已提交
159 160 161
                          1,
                      output_depth,
                      "input_depth and output_depth are "
C
chengduoZH 已提交
162
                      "mismatching.");
L
liym27 已提交
163
    PADDLE_ENFORCE_EQ((input_height + pad_h_up + pad_h_down -
C
chengduoZH 已提交
164 165
                       ((dilations[1] * (filter_height - 1) + 1))) /
                              strides[1] +
C
chengduoZH 已提交
166 167 168
                          1,
                      output_height,
                      "input_height and output_height are "
C
chengduoZH 已提交
169
                      "mismatching.");
L
liym27 已提交
170
    PADDLE_ENFORCE_EQ((input_width + pad_w_left + pad_w_right -
C
chengduoZH 已提交
171 172
                       ((dilations[2] * (filter_width - 1) + 1))) /
                              strides[2] +
C
chengduoZH 已提交
173 174 175
                          1,
                      output_width,
                      "input_width and output_width are "
C
chengduoZH 已提交
176 177
                      "mismatching.");
    T* vol_data = vol->data<T>();
C
chengduoZH 已提交
178 179 180 181 182 183 184 185
    const T* col_data = col.data<T>();

    for (int c = 0; c < channels_col; ++c) {
      int w_offset = c % filter_width;
      int h_offset = (c / filter_width) % filter_height;
      int d_offset = (c / filter_width / filter_height) % filter_depth;
      int cIm = c / filter_width / filter_height / filter_depth;
      for (int d = 0; d < output_depth; ++d) {
L
liym27 已提交
186
        int d_pad = d * strides[0] - pad_d_forth + d_offset * dilations[0];
C
chengduoZH 已提交
187
        for (int h = 0; h < output_height; ++h) {
L
liym27 已提交
188
          int h_pad = h * strides[1] - pad_h_up + h_offset * dilations[1];
C
chengduoZH 已提交
189
          for (int w = 0; w < output_width; ++w) {
L
liym27 已提交
190
            int w_pad = w * strides[2] - pad_w_left + w_offset * dilations[2];
C
chengduoZH 已提交
191 192 193 194 195 196 197

            if (h_pad >= 0 && h_pad < input_height && w_pad >= 0 &&
                w_pad < input_width && d_pad >= 0 && d_pad < input_depth) {
              int vol_idx =
                  ((cIm * input_depth + d_pad) * input_height + h_pad) *
                      input_width +
                  w_pad;
C
chengduoZH 已提交
198

C
chengduoZH 已提交
199 200 201 202 203 204 205 206 207 208 209 210
              int col_idx =
                  ((c * output_depth + d) * output_height + h) * output_width +
                  w;
              vol_data[vol_idx] += col_data[col_idx];
            }
          }
        }
      }
    }
  }
};

Q
QI JUN 已提交
211 212 213 214
// Explicit instantiations of the CPU functors, grouped by element type.
// float
template class Vol2ColFunctor<platform::CPUDeviceContext, float>;
template class Col2VolFunctor<platform::CPUDeviceContext, float>;
// double
template class Vol2ColFunctor<platform::CPUDeviceContext, double>;
template class Col2VolFunctor<platform::CPUDeviceContext, double>;
C
chengduoZH 已提交
215 216 217 218

}  // namespace math
}  // namespace operators
}  // namespace paddle