diff --git a/paddle/function/Im2Col.h b/paddle/function/Im2Col.h new file mode 100644 index 0000000000000000000000000000000000000000..d461ec7510b482cd2ce7b7748eb9a9e057f1a8f4 --- /dev/null +++ b/paddle/function/Im2Col.h @@ -0,0 +1,92 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +namespace paddle { + +/* The storage format of the colData in the Im2ColFunctor and Col2ImFunctor. */ +enum ColFormat { kCFO = 0, kOCF = 1 }; + +/* + * \brief Converts the image data of three dimensions (CHW) into a colData of + * five dimensions in the Im2ColFunctor calculation, + * and in the Col2ImFunctor calculation, it is reversed. + * + * \param imData Image data of CHW format. + * The shape of imData is: + * [inputChannels, inputHeight, inputWidth]. + * \param colData Column data. + * + * If the template argument Format is kCFO, the shape of colData is: + * [inputChannels, filterHeight, filterWidth, outputHeight, outputWidth] + * So, it is easy to reshape into a convolution matrix for convolution + * calculation based on matrix multiplication. + * The shape of convolution matrix is [height, width], where the height is equal + * to inputChannels * filterHeight * filterWidth, and the width is equal to + * outputHeight * outputWidth. 
+ * + * Reshape: + * shape of colData shape of convolution matrix + * [inputChannels, + * filterHeight, + * filterWidth, ======> [height, width] + * outputHeight, + * outputWidth] + * + * If the template argument Format is kOCF, the shape of colData is: + * [outputHeight, outputWidth, inputChannels, filterHeight, filterWidth] + * So, it is easy to reshape into a sequence matrix for RNN calculation. + * The shape of sequence matrix is [seqLength, stepSize], where the seqLength + * is equal to outputHeight * outputWidth, and the stepSize is equal to + * inputChannels * filterHeight * filterWidth. + * + * Reshape: + * shape of colData shape of sequence + * [outputHeight, + * outputWidth, + * inputChannels, ======> [seqLength, stepSize] + * filterHeight, + * filterWidth] + * + * \note The caller needs to ensure that imShape.inputChannels is equal to + * colShape.inputChannels. + */ +template +class Im2ColFunctor { +public: + void operator()(const T* imData, + const TensorShape& imShape, + T* colData, + const TensorShape& colShape, + int strideHeight, + int strideWidth, + int paddingHeight, + int paddingWidth); +}; + +template +class Col2ImFunctor { +public: + void operator()(T* imData, + const TensorShape& imShape, + const T* colData, + const TensorShape& colShape, + int strideHeight, + int strideWidth, + int paddingHeight, + int paddingWidth); +}; + +} // namespace paddle diff --git a/paddle/function/ImageExpandOp.cpp b/paddle/function/ImageExpandOp.cpp index 4d8c25ffcdafa3dac0d239fa39b28d9714ebf611..ad34967bd65808361a38d1e5b0cc0042ea2df8c8 100644 --- a/paddle/function/ImageExpandOp.cpp +++ b/paddle/function/ImageExpandOp.cpp @@ -13,31 +13,33 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ #include "Function.h" -#include "GemmConvOp.h" +#include "Im2Col.h" namespace paddle { /* - * imData = [input_channels, input_height, input_width] - * colData = [output_height, output_width, - * input_channels, filter_height, filter_width] + * imShape = [inputChannels, inputHeight, inputWidth] + * colShape = + * [outputHeight, outputWidth, inputChannels, filterHeight, filterWidth] */ template class Im2ColFunctor { public: void operator()(const T* imData, - int inputChannels, - int inputHeight, - int inputWidth, - int filterHeight, - int filterWidth, + const TensorShape& imShape, + T* colData, + const TensorShape& colShape, int strideHeight, int strideWidth, int paddingHeight, - int paddingWidth, - int outputHeight, - int outputWidth, - T* colData) { + int paddingWidth) { + int inputChannels = imShape[0]; + int inputHeight = imShape[1]; + int inputWidth = imShape[2]; + int filterHeight = colShape[3]; + int filterWidth = colShape[4]; + int outputHeight = colShape[0]; + int outputWidth = colShape[1]; for (int outputH = 0; outputH < outputHeight; ++outputH) { for (int outputW = 0; outputW < outputWidth; ++outputW) { for (int channel = 0; channel < inputChannels; ++channel) { @@ -55,7 +57,7 @@ public: filterW; if (imRowOffset < 0 || imRowOffset >= inputHeight || imColOffset < 0 || imColOffset >= inputWidth) { - colData[colDataOffset] = T(0); + colData[colDataOffset] = static_cast<T>(0); } else { int imDataOffset = (channel * inputHeight + imRowOffset) * inputWidth + @@ -70,22 +72,29 @@ public: } }; +/* + * imShape = [inputChannels, inputHeight, inputWidth] + * colShape = + * [outputHeight, outputWidth, inputChannels, filterHeight, filterWidth] + */ template class Col2ImFunctor { public: - void operator()(const T* colData, - int inputChannels, - int inputHeight, - int inputWidth, - int filterHeight, - int filterWidth, + void operator()(T* imData, + const TensorShape& imShape, + const T* colData, + const TensorShape& colShape, int strideHeight, int strideWidth, int 
paddingHeight, - int paddingWidth, - int outputHeight, - int outputWidth, - T* imData) { + int paddingWidth) { + int inputChannels = imShape[0]; + int inputHeight = imShape[1]; + int inputWidth = imShape[2]; + int filterHeight = colShape[3]; + int filterWidth = colShape[4]; + int outputHeight = colShape[0]; + int outputWidth = colShape[1]; for (int outputH = 0; outputH < outputHeight; ++outputH) { for (int outputW = 0; outputW < outputWidth; ++outputW) { for (int channel = 0; channel < inputChannels; ++channel) { @@ -146,7 +155,7 @@ public: virtual void calc(const BufferArgs& inputs, const BufferArgs& outputs) {} - void check(const TensorShape& image, const TensorShape& sequence) { + void check(const TensorShape& image, const TensorShape& sequence) const { // image shape should be 4-dimensional. CHECK_EQ(image.ndims(), (size_t)4); // sequence shape should be 3-dimensional. @@ -159,7 +168,7 @@ public: // Calculate the shape of colData based on the shape of the image // and the shape of the sequence. 
TensorShape getColShape(const TensorShape& image, - const TensorShape& sequence) { + const TensorShape& sequence) const { size_t inputChannels = image[1]; size_t inputHeight = image[2]; size_t inputWidth = image[3]; @@ -174,8 +183,7 @@ public: CHECK_EQ(seqLength, outputHeight * outputWidth); CHECK_EQ(stepSize, inputChannels * blockH() * blockW()); - // [output_height, output_width, - // input_channels, filter_height, filter_width] + // [outputHeight, outputWidth, inputChannels, filterHeight, filterWidth] return TensorShape({outputHeight, outputWidth, inputChannels, @@ -215,40 +223,29 @@ public: const TensorShape& sequence = outputs[0].shape(); check(image, sequence); + TensorShape imShape = TensorShape({image[1], image[2], image[3]}); TensorShape colShape = getColShape(image, sequence); size_t batchSize = image[0]; - size_t inputChannels = image[1]; - size_t inputHeight = image[2]; - size_t inputWidth = image[3]; - size_t seqLength = sequence[1]; - size_t stepSize = sequence[2]; - size_t outputHeight = colShape[0]; - size_t outputWidth = colShape[1]; real* imageData = inputs[0].data(); real* seqData = outputs[0].data(); Im2ColFunctor im2col; for (size_t i = 0; i < batchSize; i++) { - // The result of im2col is [output_height, output_width, - // input_channels, filter_height, filter_width], and it is easy to + // The result of im2col is [outputHeight, outputWidth, + // inputChannels, filterHeight, filterWidth], and it is easy to // reshape into [seqLength, stepSize], where seqLength is equal // output_height * output_width, stepSize is equal // input_channels * filter_height * filter_width im2col(imageData, - inputChannels, - inputHeight, - inputWidth, - blockH(), - blockW(), + imShape, + seqData, + colShape, strideH(), strideW(), paddingH(), - paddingW(), - outputHeight, - outputWidth, - seqData); - imageData += inputChannels * inputHeight * inputWidth; - seqData += seqLength * stepSize; + paddingW()); + imageData += imShape.getElements(); + seqData += 
colShape.getElements(); } } }; @@ -270,35 +267,24 @@ public: const TensorShape& sequence = inputs[0].shape(); check(image, sequence); + TensorShape imShape = TensorShape({image[1], image[2], image[3]}); TensorShape colShape = getColShape(image, sequence); size_t batchSize = image[0]; - size_t inputChannels = image[1]; - size_t inputHeight = image[2]; - size_t inputWidth = image[3]; - size_t seqLength = sequence[1]; - size_t stepSize = sequence[2]; - size_t outputHeight = colShape[0]; - size_t outputWidth = colShape[1]; real* imageData = outputs[0].data(); real* seqData = inputs[0].data(); Col2ImFunctor col2im; for (size_t i = 0; i < batchSize; i++) { - col2im(seqData, - inputChannels, - inputHeight, - inputWidth, - blockH(), - blockW(), + col2im(imageData, + imShape, + seqData, + colShape, strideH(), strideW(), paddingH(), - paddingW(), - outputHeight, - outputWidth, - imageData); - imageData += inputChannels * inputHeight * inputWidth; - seqData += seqLength * stepSize; + paddingW()); + imageData += imShape.getElements(); + seqData += colShape.getElements(); } } };