/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/math/vol2col.h"

#include <cstdint>

#include "paddle/phi/backends/cpu/cpu_context.h"

// Forward declaration of the fluid CPU device context. The
// platform::CPUDeviceContext specializations below only take it by const
// reference, so a declaration is sufficient here.
namespace paddle {
namespace platform {
class CPUDeviceContext;
}  // namespace platform
}  // namespace paddle

namespace paddle {
namespace operators {
namespace math {

/*
 * vol = [input_channels, input_depth, input_height, input_width]
 * col =
 *   [input_channels, filter_depth, filter_height, filter_width,
 *                    output_depth, output_height, output_width]
 */
template <class T>
Q
QI JUN 已提交
36
class Vol2ColFunctor<platform::CPUDeviceContext, T> {
C
chengduoZH 已提交
37
 public:
Q
QI JUN 已提交
38
  void operator()(const platform::CPUDeviceContext& context,
C
chengduoZH 已提交
39 40 41
                  const framework::Tensor& vol,
                  const std::vector<int>& dilations,
                  const std::vector<int>& strides,
42 43
                  const std::vector<int>& paddings, framework::Tensor* col,
                  const DataLayout data_layout) const {
44 45 46 47 48 49 50 51 52
    PADDLE_ENFORCE_EQ(vol.dims().size(), 4,
                      platform::errors::InvalidArgument(
                          "The dimension of vol should be 4, but received %d.",
                          vol.dims().size()));

    PADDLE_ENFORCE_EQ(col->dims().size(), 7,
                      platform::errors::InvalidArgument(
                          "The dimension of col should be 7, but received %d.",
                          col->dims().size()));
53 54

    int input_channels =
55
        (data_layout != DataLayout::kNHWC ? vol.dims()[0] : vol.dims()[3]);
56
    int input_depth =
57
        (data_layout != DataLayout::kNHWC ? vol.dims()[1] : vol.dims()[0]);
58
    int input_height =
59
        (data_layout != DataLayout::kNHWC ? vol.dims()[2] : vol.dims()[1]);
60
    int input_width =
61
        (data_layout != DataLayout::kNHWC ? vol.dims()[3] : vol.dims()[2]);
C
chengduoZH 已提交
62 63 64 65 66 67
    int filter_depth = col->dims()[1];
    int filter_height = col->dims()[2];
    int filter_width = col->dims()[3];
    int output_depth = col->dims()[4];
    int output_height = col->dims()[5];
    int output_width = col->dims()[6];
C
chengduoZH 已提交
68 69 70
    int channels_col =
        input_channels * filter_depth * filter_height * filter_width;

L
liym27 已提交
71 72 73 74 75 76 77 78
    // changed
    bool paddings_size_is_6 = (paddings.size() == 6);
    int pad_d_forth = paddings_size_is_6 ? paddings[0] : paddings[0];
    int pad_d_back = paddings_size_is_6 ? paddings[1] : paddings[0];
    int pad_h_up = paddings_size_is_6 ? paddings[2] : paddings[1];
    int pad_h_down = paddings_size_is_6 ? paddings[3] : paddings[1];
    int pad_w_left = paddings_size_is_6 ? paddings[4] : paddings[2];
    int pad_w_right = paddings_size_is_6 ? paddings[5] : paddings[2];
79

80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106
    auto input_depth_tmp = (input_depth + pad_d_forth + pad_d_back -
                            ((dilations[0] * (filter_depth - 1) + 1))) /
                               strides[0] +
                           1;
    PADDLE_ENFORCE_EQ(
        input_depth_tmp, output_depth,
        platform::errors::InvalidArgument(
            "input_depth(%d) and output_depth(%d) are mismatching.",
            input_depth_tmp, output_depth));
    auto input_height_tmp = (input_height + pad_h_up + pad_h_down -
                             ((dilations[1] * (filter_height - 1) + 1))) /
                                strides[1] +
                            1;
    PADDLE_ENFORCE_EQ(
        input_height_tmp, output_height,
        platform::errors::InvalidArgument(
            "input_height(%d) and output_height(%d) are mismatching.",
            input_height_tmp, output_height));
    auto input_width_tmp = (input_width + pad_w_left + pad_w_right -
                            ((dilations[2] * (filter_width - 1) + 1))) /
                               strides[2] +
                           1;
    PADDLE_ENFORCE_EQ(
        input_width_tmp, output_width,
        platform::errors::InvalidArgument(
            "input_width(%d) and output_width(%d) are mismatching.",
            input_width_tmp, output_width));
C
chengduoZH 已提交
107
    const T* vol_data = vol.data<T>();
C
chengduoZH 已提交
108
    T* col_data = col->data<T>();
C
chengduoZH 已提交
109 110 111 112 113 114 115

    for (int c = 0; c < channels_col; ++c) {
      int w_offset = c % filter_width;
      int h_offset = (c / filter_width) % filter_height;
      int d_offset = (c / filter_width / filter_height) % filter_depth;
      int c_in = c / filter_width / filter_height / filter_depth;
      for (int d = 0; d < output_depth; ++d) {
L
liym27 已提交
116
        int d_pad = d * strides[0] - pad_d_forth + d_offset * dilations[0];
C
chengduoZH 已提交
117
        for (int h = 0; h < output_height; ++h) {
L
liym27 已提交
118
          int h_pad = h * strides[1] - pad_h_up + h_offset * dilations[1];
C
chengduoZH 已提交
119
          for (int w = 0; w < output_width; ++w) {
L
liym27 已提交
120
            int w_pad = w * strides[2] - pad_w_left + w_offset * dilations[2];
C
chengduoZH 已提交
121 122 123

            int col_idx =
                ((c * output_depth + d) * output_height + h) * output_width + w;
124
            int vol_idx;
125
            if (data_layout != DataLayout::kNHWC) {
126 127 128 129 130 131 132 133
              vol_idx = ((c_in * input_depth + d_pad) * input_height + h_pad) *
                            input_width +
                        w_pad;
            } else {
              vol_idx = ((d_pad * input_height + h_pad) * input_width + w_pad) *
                            input_channels +
                        c_in;
            }
C
chengduoZH 已提交
134 135 136 137 138
            col_data[col_idx] =
                (h_pad < 0 || h_pad >= input_height || w_pad < 0 ||
                 w_pad >= input_width || d_pad < 0 || d_pad >= input_depth)
                    ? static_cast<T>(0)
                    : vol_data[vol_idx];
C
chengduoZH 已提交
139 140 141 142 143 144 145
          }
        }
      }
    }
  }
};

H
hong 已提交
146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255
/*
 * Vol2ColFunctor for phi::CPUContext: unfolds a 3-D volume into the column
 * buffer used by GEMM-based 3-D convolution.
 *
 * Shapes (as read from the tensor dims below):
 *   vol = [input_channels, input_depth, input_height, input_width]
 *         when data_layout != kNHWC, or
 *         [input_depth, input_height, input_width, input_channels]
 *         when data_layout == kNHWC
 *   col = [input_channels, filter_depth, filter_height, filter_width,
 *                          output_depth, output_height, output_width]
 *
 * dilations and strides hold one value per spatial axis {d, h, w}. paddings
 * holds either 3 symmetric values {d, h, w} or 6 values
 * {d_front, d_back, h_up, h_down, w_left, w_right}. Column entries whose
 * source position falls in the padding are written as zero.
 */
template <class T>
class Vol2ColFunctor<phi::CPUContext, T> {
 public:
  void operator()(const phi::CPUContext& context, const framework::Tensor& vol,
                  const std::vector<int>& dilations,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings, framework::Tensor* col,
                  const DataLayout data_layout) const {
    PADDLE_ENFORCE_EQ(vol.dims().size(), 4,
                      platform::errors::InvalidArgument(
                          "The dimension of vol should be 4, but received %d.",
                          vol.dims().size()));

    PADDLE_ENFORCE_EQ(col->dims().size(), 7,
                      platform::errors::InvalidArgument(
                          "The dimension of col should be 7, but received %d.",
                          col->dims().size()));

    // The code below reads dilations[0..2], strides[0..2] and up to
    // paddings[5]; reject malformed vectors instead of indexing past the end.
    PADDLE_ENFORCE_EQ(static_cast<int>(dilations.size()), 3,
                      platform::errors::InvalidArgument(
                          "The size of dilations should be 3, but received %d.",
                          dilations.size()));
    PADDLE_ENFORCE_EQ(static_cast<int>(strides.size()), 3,
                      platform::errors::InvalidArgument(
                          "The size of strides should be 3, but received %d.",
                          strides.size()));
    PADDLE_ENFORCE_EQ(
        paddings.size() == 3 || paddings.size() == 6, true,
        platform::errors::InvalidArgument(
            "The size of paddings should be 3 or 6, but received %d.",
            paddings.size()));

    int input_channels =
        (data_layout != DataLayout::kNHWC ? vol.dims()[0] : vol.dims()[3]);
    int input_depth =
        (data_layout != DataLayout::kNHWC ? vol.dims()[1] : vol.dims()[0]);
    int input_height =
        (data_layout != DataLayout::kNHWC ? vol.dims()[2] : vol.dims()[1]);
    int input_width =
        (data_layout != DataLayout::kNHWC ? vol.dims()[3] : vol.dims()[2]);
    int filter_depth = col->dims()[1];
    int filter_height = col->dims()[2];
    int filter_width = col->dims()[3];
    int output_depth = col->dims()[4];
    int output_height = col->dims()[5];
    int output_width = col->dims()[6];
    int channels_col =
        input_channels * filter_depth * filter_height * filter_width;

    // 3 paddings are symmetric per axis; 6 paddings give the leading and
    // trailing amount of each axis separately.
    bool paddings_size_is_6 = (paddings.size() == 6);
    int pad_d_forth = paddings[0];
    int pad_d_back = paddings_size_is_6 ? paddings[1] : paddings[0];
    int pad_h_up = paddings_size_is_6 ? paddings[2] : paddings[1];
    int pad_h_down = paddings_size_is_6 ? paddings[3] : paddings[1];
    int pad_w_left = paddings_size_is_6 ? paddings[4] : paddings[2];
    int pad_w_right = paddings_size_is_6 ? paddings[5] : paddings[2];

    // Cross-check col's spatial extents against the convolution output-size
    // formula (in + pad_lo + pad_hi - (dilation * (k - 1) + 1)) / stride + 1.
    auto input_depth_tmp = (input_depth + pad_d_forth + pad_d_back -
                            ((dilations[0] * (filter_depth - 1) + 1))) /
                               strides[0] +
                           1;
    PADDLE_ENFORCE_EQ(
        input_depth_tmp, output_depth,
        platform::errors::InvalidArgument(
            "input_depth(%d) and output_depth(%d) are mismatching.",
            input_depth_tmp, output_depth));
    auto input_height_tmp = (input_height + pad_h_up + pad_h_down -
                             ((dilations[1] * (filter_height - 1) + 1))) /
                                strides[1] +
                            1;
    PADDLE_ENFORCE_EQ(
        input_height_tmp, output_height,
        platform::errors::InvalidArgument(
            "input_height(%d) and output_height(%d) are mismatching.",
            input_height_tmp, output_height));
    auto input_width_tmp = (input_width + pad_w_left + pad_w_right -
                            ((dilations[2] * (filter_width - 1) + 1))) /
                               strides[2] +
                           1;
    PADDLE_ENFORCE_EQ(
        input_width_tmp, output_width,
        platform::errors::InvalidArgument(
            "input_width(%d) and output_width(%d) are mismatching.",
            input_width_tmp, output_width));

    const T* vol_data = vol.data<T>();
    T* col_data = col->data<T>();

    // Flat offset of vol[ch][d][h][w] (vol[d][h][w][ch] for NHWC), computed in
    // 64-bit arithmetic because the products can exceed INT_MAX.
    auto vol_offset = [&](int64_t ch, int64_t d, int64_t h,
                          int64_t w) -> int64_t {
      if (data_layout != DataLayout::kNHWC) {
        return ((ch * input_depth + d) * input_height + h) * input_width + w;
      }
      return ((d * input_height + h) * input_width + w) * input_channels + ch;
    };

    for (int c = 0; c < channels_col; ++c) {
      // Split the col channel into (c_in, d_offset, h_offset, w_offset).
      int w_offset = c % filter_width;
      int h_offset = (c / filter_width) % filter_height;
      int d_offset = (c / filter_width / filter_height) % filter_depth;
      int c_in = c / filter_width / filter_height / filter_depth;
      for (int d = 0; d < output_depth; ++d) {
        int d_pad = d * strides[0] - pad_d_forth + d_offset * dilations[0];
        for (int h = 0; h < output_height; ++h) {
          int h_pad = h * strides[1] - pad_h_up + h_offset * dilations[1];
          // Start of the row col[c][d][h][:], hoisted out of the w loop.
          int64_t col_row =
              ((static_cast<int64_t>(c) * output_depth + d) * output_height +
               h) *
              output_width;
          for (int w = 0; w < output_width; ++w) {
            int w_pad = w * strides[2] - pad_w_left + w_offset * dilations[2];
            bool in_bounds = d_pad >= 0 && d_pad < input_depth && h_pad >= 0 &&
                             h_pad < input_height && w_pad >= 0 &&
                             w_pad < input_width;
            // Positions that land in the padding read as zero.
            col_data[col_row + w] =
                in_bounds ? vol_data[vol_offset(c_in, d_pad, h_pad, w_pad)]
                          : static_cast<T>(0);
          }
        }
      }
    }
  }
};

C
chengduoZH 已提交
256 257 258 259 260 261 262
/*
 * vol = [input_channels,input_depth, input_height, input_width]
 * col =
 *   [input_channels, filter_depth, filter_height, filter_width,
 *                    output_depth, output_height, output_width]
 */
template <class T>
Q
QI JUN 已提交
263
class Col2VolFunctor<platform::CPUDeviceContext, T> {
C
chengduoZH 已提交
264
 public:
Q
QI JUN 已提交
265
  void operator()(const platform::CPUDeviceContext& context,
C
chengduoZH 已提交
266 267 268
                  const framework::Tensor& col,
                  const std::vector<int>& dilations,
                  const std::vector<int>& strides,
269 270
                  const std::vector<int>& paddings, framework::Tensor* vol,
                  const DataLayout data_layout) const {
271 272 273 274 275 276 277 278 279
    PADDLE_ENFORCE_EQ(vol->dims().size(), 4,
                      platform::errors::InvalidArgument(
                          "The dimension of vol should be 4, but received %d.",
                          vol->dims().size()));

    PADDLE_ENFORCE_EQ(col.dims().size(), 7,
                      platform::errors::InvalidArgument(
                          "The dimension of col  should be 7, but received %d.",
                          col.dims().size()));
280 281

    int input_channels =
282
        (data_layout != DataLayout::kNHWC ? vol->dims()[0] : vol->dims()[3]);
283
    int input_depth =
284
        (data_layout != DataLayout::kNHWC ? vol->dims()[1] : vol->dims()[0]);
285
    int input_height =
286
        (data_layout != DataLayout::kNHWC ? vol->dims()[2] : vol->dims()[1]);
287
    int input_width =
288
        (data_layout != DataLayout::kNHWC ? vol->dims()[3] : vol->dims()[2]);
C
chengduoZH 已提交
289 290 291 292 293 294 295 296 297
    int filter_depth = col.dims()[1];
    int filter_height = col.dims()[2];
    int filter_width = col.dims()[3];
    int output_depth = col.dims()[4];
    int output_height = col.dims()[5];
    int output_width = col.dims()[6];
    int channels_col =
        input_channels * filter_depth * filter_height * filter_width;

L
liym27 已提交
298 299 300 301 302 303 304 305
    bool paddings_size_is_6 = (paddings.size() == 6);
    int pad_d_forth = paddings_size_is_6 ? paddings[0] : paddings[0];
    int pad_d_back = paddings_size_is_6 ? paddings[1] : paddings[0];
    int pad_h_up = paddings_size_is_6 ? paddings[2] : paddings[1];
    int pad_h_down = paddings_size_is_6 ? paddings[3] : paddings[1];
    int pad_w_left = paddings_size_is_6 ? paddings[4] : paddings[2];
    int pad_w_right = paddings_size_is_6 ? paddings[5] : paddings[2];

306 307 308 309
    auto input_depth_tmp = (input_depth + pad_d_forth + pad_d_back -
                            ((dilations[0] * (filter_depth - 1) + 1))) /
                               strides[0] +
                           1;
310 311 312 313 314
    PADDLE_ENFORCE_EQ(
        input_depth_tmp, output_depth,
        platform::errors::InvalidArgument(
            "input_depth(%d) and output_depth(%d) are mismatching.",
            input_depth_tmp, output_depth));
315 316 317 318
    auto input_height_tmp = (input_height + pad_h_up + pad_h_down -
                             ((dilations[1] * (filter_height - 1) + 1))) /
                                strides[1] +
                            1;
319 320 321 322 323
    PADDLE_ENFORCE_EQ(
        input_height_tmp, output_height,
        platform::errors::InvalidArgument(
            "input_height(%d) and output_height(%d) are mismatching.",
            input_height_tmp, output_height));
324 325 326 327
    auto input_width_tmp = (input_width + pad_w_left + pad_w_right -
                            ((dilations[2] * (filter_width - 1) + 1))) /
                               strides[2] +
                           1;
328 329 330 331 332
    PADDLE_ENFORCE_EQ(
        input_width_tmp, output_width,
        platform::errors::InvalidArgument(
            "input_width(%d)  and output_width(%d) are mismatching.",
            input_width_tmp, output_width));
C
chengduoZH 已提交
333
    T* vol_data = vol->data<T>();
C
chengduoZH 已提交
334 335 336 337 338 339 340 341
    const T* col_data = col.data<T>();

    for (int c = 0; c < channels_col; ++c) {
      int w_offset = c % filter_width;
      int h_offset = (c / filter_width) % filter_height;
      int d_offset = (c / filter_width / filter_height) % filter_depth;
      int cIm = c / filter_width / filter_height / filter_depth;
      for (int d = 0; d < output_depth; ++d) {
L
liym27 已提交
342
        int d_pad = d * strides[0] - pad_d_forth + d_offset * dilations[0];
C
chengduoZH 已提交
343
        for (int h = 0; h < output_height; ++h) {
L
liym27 已提交
344
          int h_pad = h * strides[1] - pad_h_up + h_offset * dilations[1];
C
chengduoZH 已提交
345
          for (int w = 0; w < output_width; ++w) {
L
liym27 已提交
346
            int w_pad = w * strides[2] - pad_w_left + w_offset * dilations[2];
C
chengduoZH 已提交
347 348 349

            if (h_pad >= 0 && h_pad < input_height && w_pad >= 0 &&
                w_pad < input_width && d_pad >= 0 && d_pad < input_depth) {
350
              int vol_idx;
351
              if (data_layout != DataLayout::kNHWC) {
352 353 354 355 356 357 358 359 360
                vol_idx = ((cIm * input_depth + d_pad) * input_height + h_pad) *
                              input_width +
                          w_pad;
              } else {
                vol_idx =
                    ((d_pad * input_height + h_pad) * input_width + w_pad) *
                        input_channels +
                    cIm;
              }
C
chengduoZH 已提交
361 362 363 364 365 366 367 368 369 370 371 372
              int col_idx =
                  ((c * output_depth + d) * output_height + h) * output_width +
                  w;
              vol_data[vol_idx] += col_data[col_idx];
            }
          }
        }
      }
    }
  }
};

H
hong 已提交
373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482
/*
 * Col2VolFunctor for phi::CPUContext: folds a column buffer back into a 3-D
 * volume (the inverse scatter of Vol2ColFunctor).
 *
 * Shapes (as read from the tensor dims below):
 *   vol = [input_channels, input_depth, input_height, input_width]
 *         when data_layout != kNHWC, or
 *         [input_depth, input_height, input_width, input_channels]
 *         when data_layout == kNHWC
 *   col = [input_channels, filter_depth, filter_height, filter_width,
 *                          output_depth, output_height, output_width]
 *
 * dilations and strides hold one value per spatial axis {d, h, w}. paddings
 * holds either 3 symmetric values {d, h, w} or 6 values
 * {d_front, d_back, h_up, h_down, w_left, w_right}.
 *
 * Every col entry whose source position lies inside vol is ADDED (+=) into
 * vol; entries that map into the padding are dropped. vol is not cleared
 * here, so callers presumably zero it beforehand (confirm at call sites).
 */
template <class T>
class Col2VolFunctor<phi::CPUContext, T> {
 public:
  void operator()(const phi::CPUContext& context, const framework::Tensor& col,
                  const std::vector<int>& dilations,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings, framework::Tensor* vol,
                  const DataLayout data_layout) const {
    PADDLE_ENFORCE_EQ(vol->dims().size(), 4,
                      platform::errors::InvalidArgument(
                          "The dimension of vol should be 4, but received %d.",
                          vol->dims().size()));

    PADDLE_ENFORCE_EQ(col.dims().size(), 7,
                      platform::errors::InvalidArgument(
                          "The dimension of col  should be 7, but received %d.",
                          col.dims().size()));

    // The code below reads dilations[0..2], strides[0..2] and up to
    // paddings[5]; reject malformed vectors instead of indexing past the end.
    PADDLE_ENFORCE_EQ(static_cast<int>(dilations.size()), 3,
                      platform::errors::InvalidArgument(
                          "The size of dilations should be 3, but received %d.",
                          dilations.size()));
    PADDLE_ENFORCE_EQ(static_cast<int>(strides.size()), 3,
                      platform::errors::InvalidArgument(
                          "The size of strides should be 3, but received %d.",
                          strides.size()));
    PADDLE_ENFORCE_EQ(
        paddings.size() == 3 || paddings.size() == 6, true,
        platform::errors::InvalidArgument(
            "The size of paddings should be 3 or 6, but received %d.",
            paddings.size()));

    int input_channels =
        (data_layout != DataLayout::kNHWC ? vol->dims()[0] : vol->dims()[3]);
    int input_depth =
        (data_layout != DataLayout::kNHWC ? vol->dims()[1] : vol->dims()[0]);
    int input_height =
        (data_layout != DataLayout::kNHWC ? vol->dims()[2] : vol->dims()[1]);
    int input_width =
        (data_layout != DataLayout::kNHWC ? vol->dims()[3] : vol->dims()[2]);
    int filter_depth = col.dims()[1];
    int filter_height = col.dims()[2];
    int filter_width = col.dims()[3];
    int output_depth = col.dims()[4];
    int output_height = col.dims()[5];
    int output_width = col.dims()[6];
    int channels_col =
        input_channels * filter_depth * filter_height * filter_width;

    // 3 paddings are symmetric per axis; 6 paddings give the leading and
    // trailing amount of each axis separately.
    bool paddings_size_is_6 = (paddings.size() == 6);
    int pad_d_forth = paddings[0];
    int pad_d_back = paddings_size_is_6 ? paddings[1] : paddings[0];
    int pad_h_up = paddings_size_is_6 ? paddings[2] : paddings[1];
    int pad_h_down = paddings_size_is_6 ? paddings[3] : paddings[1];
    int pad_w_left = paddings_size_is_6 ? paddings[4] : paddings[2];
    int pad_w_right = paddings_size_is_6 ? paddings[5] : paddings[2];

    // Cross-check col's spatial extents against the convolution output-size
    // formula (in + pad_lo + pad_hi - (dilation * (k - 1) + 1)) / stride + 1.
    auto input_depth_tmp = (input_depth + pad_d_forth + pad_d_back -
                            ((dilations[0] * (filter_depth - 1) + 1))) /
                               strides[0] +
                           1;
    PADDLE_ENFORCE_EQ(
        input_depth_tmp, output_depth,
        platform::errors::InvalidArgument(
            "input_depth(%d) and output_depth(%d) are mismatching.",
            input_depth_tmp, output_depth));
    auto input_height_tmp = (input_height + pad_h_up + pad_h_down -
                             ((dilations[1] * (filter_height - 1) + 1))) /
                                strides[1] +
                            1;
    PADDLE_ENFORCE_EQ(
        input_height_tmp, output_height,
        platform::errors::InvalidArgument(
            "input_height(%d) and output_height(%d) are mismatching.",
            input_height_tmp, output_height));
    auto input_width_tmp = (input_width + pad_w_left + pad_w_right -
                            ((dilations[2] * (filter_width - 1) + 1))) /
                               strides[2] +
                           1;
    PADDLE_ENFORCE_EQ(
        input_width_tmp, output_width,
        platform::errors::InvalidArgument(
            "input_width(%d)  and output_width(%d) are mismatching.",
            input_width_tmp, output_width));

    T* vol_data = vol->data<T>();
    const T* col_data = col.data<T>();

    // Flat offset of vol[ch][d][h][w] (vol[d][h][w][ch] for NHWC), computed in
    // 64-bit arithmetic because the products can exceed INT_MAX.
    auto vol_offset = [&](int64_t ch, int64_t d, int64_t h,
                          int64_t w) -> int64_t {
      if (data_layout != DataLayout::kNHWC) {
        return ((ch * input_depth + d) * input_height + h) * input_width + w;
      }
      return ((d * input_height + h) * input_width + w) * input_channels + ch;
    };

    for (int c = 0; c < channels_col; ++c) {
      // Split the col channel into (c_im, d_offset, h_offset, w_offset).
      int w_offset = c % filter_width;
      int h_offset = (c / filter_width) % filter_height;
      int d_offset = (c / filter_width / filter_height) % filter_depth;
      int c_im = c / filter_width / filter_height / filter_depth;
      for (int d = 0; d < output_depth; ++d) {
        int d_pad = d * strides[0] - pad_d_forth + d_offset * dilations[0];
        for (int h = 0; h < output_height; ++h) {
          int h_pad = h * strides[1] - pad_h_up + h_offset * dilations[1];
          // Start of the row col[c][d][h][:], hoisted out of the w loop.
          int64_t col_row =
              ((static_cast<int64_t>(c) * output_depth + d) * output_height +
               h) *
              output_width;
          for (int w = 0; w < output_width; ++w) {
            int w_pad = w * strides[2] - pad_w_left + w_offset * dilations[2];
            // Entries mapping into the padding have no vol cell to add into.
            if (d_pad >= 0 && d_pad < input_depth && h_pad >= 0 &&
                h_pad < input_height && w_pad >= 0 && w_pad < input_width) {
              vol_data[vol_offset(c_im, d_pad, h_pad, w_pad)] +=
                  col_data[col_row + w];
            }
          }
        }
      }
    }
  }
};

// Explicit instantiations: float and double, for both the fluid
// CPUDeviceContext and the phi CPUContext specializations defined above.
template class Vol2ColFunctor<platform::CPUDeviceContext, float>;
template class Vol2ColFunctor<platform::CPUDeviceContext, double>;
template class Vol2ColFunctor<phi::CPUContext, float>;
template class Vol2ColFunctor<phi::CPUContext, double>;

template class Col2VolFunctor<platform::CPUDeviceContext, float>;
template class Col2VolFunctor<platform::CPUDeviceContext, double>;
template class Col2VolFunctor<phi::CPUContext, float>;
template class Col2VolFunctor<phi::CPUContext, double>;

}  // namespace math
}  // namespace operators
}  // namespace paddle