/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <vector>

#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/platform/device_context.h"

namespace paddle {
23
namespace operators {
24
namespace math {
H
hedaoyuan 已提交
25 26

/* The storage format of the coldata in the Im2ColFunctor and Col2ImFunctor. */
H
hedaoyuan 已提交
27
enum class ColFormat { kCFO = 0, kOCF = 1 };
H
hedaoyuan 已提交
28 29 30 31 32 33 34 35

/*
 * \brief Converts the image data of three dimensions(CHW) into a colData of
 *        five dimensions in the Im2ColFunctor calculation,
 *        And in the Col2ImFunctor calculation, it is reversed.
 *
 * \param imData   Image data.
 * \param imShape  The shape of imData,
H
hedaoyuan 已提交
36
 *                 [input_channels, input_height, input_width].
H
hedaoyuan 已提交
37 38 39
 * \param colData  Column data.
 * \param colShape The shape of colData.
 *
C
chengduoZH 已提交
40 41 42 43 44 45 46 47 48
 * \param dilations    dilation data.
 * \param 2-dimension  [dilation_height, dilation_width].
 *
 * \param strides      stride data.
 * \param 2-dimension  [stride_height, stride_width].
 *
 * \param paddings     padding data.
 * \param 4-dimension  [up_pad, left_pad, down_pad, right_pad].
 *
H
hedaoyuan 已提交
49
 * If the template argument Format is kCFO, the shape of colData is:
H
hedaoyuan 已提交
50
 * [input_channels, filter_height, filter_width, output_height, output_width]
H
hedaoyuan 已提交
51 52 53
 * So, it is easy to reshape into a convolution matrix for convolution
 * calculation based on matrix multiplication.
 * The shape of convolution matrix is [height, width], where the height is equal
H
hedaoyuan 已提交
54 55
 * input_channels * filter_height * filter_width, and the width is equal
 * output_height * output_width.
H
hedaoyuan 已提交
56 57 58
 *
 * Reshape:
 *     shape of colData           shape of convolution matrix
H
hedaoyuan 已提交
59 60 61 62 63
 *     [input_channels,
 *      filter_height,
 *      filter_width,      ======>      [height, width]
 *      output_height,
 *      output_width]
H
hedaoyuan 已提交
64 65
 *
 * If the template argument Format is kOCF, the shape of colData is:
H
hedaoyuan 已提交
66
 * [output_height, output_width, input_channels, filter_height, filter_width]
H
hedaoyuan 已提交
67
 * So, it is easy to reshape into a sequence matrix for rnn calculation.
H
hedaoyuan 已提交
68 69 70
 * The shape of sequence matrix is [seq_length, step_size], where the seq_length
 * is equal output_height * output_width, and the step_size is equal
 * input_channels * filter_height * filter_width.
H
hedaoyuan 已提交
71 72 73
 *
 * Reshape:
 *     shape of colData             shape of sequence matrix
H
hedaoyuan 已提交
74 75 76 77 78
 *     [output_height,
 *      output_width,
 *      input_channels,    ======>    [seqLength, stepSize]
 *      filter_height,
 *      filter_width]
H
hedaoyuan 已提交
79 80 81 82
 *
 * \note The caller needs to ensure that imShape.inputChannels is equal to
 *       colShape.inputChannels.
 */
Q
QI JUN 已提交
83
template <ColFormat Format, typename DeviceContext, typename T>
H
hedaoyuan 已提交
84 85
class Im2ColFunctor {
 public:
Q
QI JUN 已提交
86 87
  void operator()(const DeviceContext& context, const framework::Tensor& im,
                  const std::vector<int>& dilation,
C
chengduoZH 已提交
88 89
                  const std::vector<int>& stride,
                  const std::vector<int>& padding, framework::Tensor* col);
H
hedaoyuan 已提交
90 91
};

Q
QI JUN 已提交
92
template <ColFormat Format, typename DeviceContext, typename T>
H
hedaoyuan 已提交
93 94
class Col2ImFunctor {
 public:
Q
QI JUN 已提交
95
  void operator()(const DeviceContext& context, const framework::Tensor& col,
C
chengduoZH 已提交
96 97 98
                  const std::vector<int>& dilation,
                  const std::vector<int>& stride,
                  const std::vector<int>& padding, framework::Tensor* im);
H
hedaoyuan 已提交
99 100
};

101
}  // namespace math
102
}  // namespace operators
H
hedaoyuan 已提交
103
}  // namespace paddle