/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

17
#include <vector>
18

19 20 21 22
#include "paddle/phi/common/layout.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/errors.h"
H
hedaoyuan 已提交
23

24 25
namespace phi {
namespace funcs {
H
hedaoyuan 已提交
26

27
// Alias so that declarations in phi::funcs can spell phi::DataLayout simply as
// DataLayout (used for the data_layout parameter of the functors below).
using DataLayout = phi::DataLayout;
28

H
hedaoyuan 已提交
29
/*
 * The storage format of the column data (colData) written by Im2ColFunctor
 * and read by Col2ImFunctor.
 */
enum class ColFormat {
  /* colData is laid out as
   * [input_channels, filter_height, filter_width, output_height, output_width].
   */
  kCFO = 0,
  /* colData is laid out as
   * [output_height, output_width, input_channels, filter_height, filter_width].
   */
  kOCF = 1
};
H
hedaoyuan 已提交
31 32 33 34 35 36 37 38

/*
 * \brief Im2ColFunctor converts three-dimensional image data (CHW) into
 *        five-dimensional column data (colData); Col2ImFunctor performs the
 *        reverse conversion.
 *
 * \param im        Image data, with shape
 *                  [input_channels, input_height, input_width].
 * \param col       Column data (colData); its shape depends on the template
 *                  argument Format, as described below.
 *
 * \param dilation  Dilation data, 2 elements:
 *                  [dilation_height, dilation_width].
 *
 * \param stride    Stride data, 2 elements:
 *                  [stride_height, stride_width].
 *
 * \param padding   Padding data, 4 elements:
 *                  [up_pad, left_pad, down_pad, right_pad].
 *
 * If the template argument Format is kCFO, the shape of colData is:
 * [input_channels, filter_height, filter_width, output_height, output_width]
 * So, it is easy to reshape into a convolution matrix for convolution
 * calculation based on matrix multiplication.
 * The shape of the convolution matrix is [height, width], where the height is
 * equal to input_channels * filter_height * filter_width, and the width is
 * equal to output_height * output_width.
 *
 * Reshape:
 *     shape of colData           shape of convolution matrix
 *     [input_channels,
 *      filter_height,
 *      filter_width,      ======>      [height, width]
 *      output_height,
 *      output_width]
 *
 * If the template argument Format is kOCF, the shape of colData is:
 * [output_height, output_width, input_channels, filter_height, filter_width]
 * So, it is easy to reshape into a sequence matrix for rnn calculation.
 * The shape of the sequence matrix is [seq_length, step_size], where the
 * seq_length is equal to output_height * output_width, and the step_size is
 * equal to input_channels * filter_height * filter_width.
 *
 * Reshape:
 *     shape of colData             shape of sequence matrix
 *     [output_height,
 *      output_width,
 *      input_channels,    ======>    [seq_length, step_size]
 *      filter_height,
 *      filter_width]
 *
 * \note The caller needs to ensure that the input_channels of the image data
 *       is equal to the input_channels of the column data.
 */
Q
QI JUN 已提交
86
template <ColFormat Format, typename DeviceContext, typename T>
H
hedaoyuan 已提交
87 88
class Im2ColFunctor {
 public:
89
  void operator()(const DeviceContext& context,
90
                  const phi::DenseTensor& im,
Q
QI JUN 已提交
91
                  const std::vector<int>& dilation,
C
chengduoZH 已提交
92
                  const std::vector<int>& stride,
93
                  const std::vector<int>& padding,
94
                  phi::DenseTensor* col,
95
                  const DataLayout data_layout = DataLayout::kNCHW);
H
hedaoyuan 已提交
96 97
};

Q
QI JUN 已提交
98
template <ColFormat Format, typename DeviceContext, typename T>
H
hedaoyuan 已提交
99 100
class Col2ImFunctor {
 public:
101
  void operator()(const DeviceContext& context,
102
                  const phi::DenseTensor& col,
C
chengduoZH 已提交
103 104
                  const std::vector<int>& dilation,
                  const std::vector<int>& stride,
105
                  const std::vector<int>& padding,
106
                  phi::DenseTensor* im,
107
                  const DataLayout data_layout = DataLayout::kNCHW);
H
hedaoyuan 已提交
108 109
};

110 111
}  // namespace funcs
}  // namespace phi