diff --git a/paddle/function/Im2ColOp.cpp b/paddle/function/Im2ColOp.cpp new file mode 100644 index 0000000000000000000000000000000000000000..b7d1eb1eded7a7471fd5833a649916d3ee3e598e --- /dev/null +++ b/paddle/function/Im2ColOp.cpp @@ -0,0 +1,235 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "Im2Col.h" + +namespace paddle { + +/* + * imShape = [inputChannels, inputHeight, inputWidth] + * colShape = + * [inputChannels, filterHeight, filterWidth, outputHeight, outputWidth] + */ +template +class Im2ColFunctor { +public: + void operator()(const T* imData, + const TensorShape& imShape, + T* colData, + const TensorShape& colShape, + int strideHeight, + int strideWidth, + int paddingHeight, + int paddingWidth) { + int inputChannels = imShape[0]; + int inputHeight = imShape[1]; + int inputWidth = imShape[2]; + int filterHeight = colShape[1]; + int filterWidth = colShape[2]; + int outputHeight = colShape[3]; + int outputWidth = colShape[4]; + int channelsCol = inputChannels * filterHeight * filterWidth; + + for (int c = 0; c < channelsCol; ++c) { + int wOffset = c % filterWidth; + int hOffset = (c / filterWidth) % filterHeight; + int c_im = c / filterWidth / filterHeight; + for (int h = 0; h < outputHeight; ++h) { + for (int w = 0; w < outputWidth; ++w) { + int imRowIdx = h * strideHeight + hOffset; + int imColIdx = w * strideWidth + wOffset; + if ((imRowIdx - paddingHeight) < 0 || + (imRowIdx - paddingHeight) >= inputHeight || + (imColIdx - paddingWidth) < 0 || + (imColIdx - paddingWidth) >= inputWidth) { + colData[(c * outputHeight + h) * outputWidth + w] = T(0); + } else { + imRowIdx += c_im * inputHeight - paddingHeight; + imColIdx -= paddingWidth; + colData[(c * outputHeight + h) * outputWidth + w] = + imData[imRowIdx * inputWidth + imColIdx]; + } + } + } + } + } +}; + +/* + * imShape = [inputChannels, inputHeight, inputWidth] + * colShape = + * [inputChannels, filterHeight, filterWidth, outputHeight, outputWidth] + */ +template +class Col2ImFunctor { +public: + void operator()(T* imData, + const TensorShape& imShape, + const T* colData, + const TensorShape& colShape, + int strideHeight, + int strideWidth, + int paddingHeight, + int paddingWidth) { + int inputChannels = imShape[0]; + int inputHeight = imShape[1]; + int inputWidth = imShape[2]; + int filterHeight = colShape[1]; + int filterWidth = colShape[2]; + int outputHeight = colShape[3]; + int outputWidth = colShape[4]; + int channelsCol = inputChannels * filterHeight * filterWidth; + + for (int c = 0; c < channelsCol; ++c) { + int wOffset = c % filterWidth; + int hOffset = (c / filterWidth) % filterHeight; + int c_im = c / filterWidth / filterHeight; + for (int h = 0; h < outputHeight; ++h) { + for (int w = 0; w < outputWidth; ++w) { + int imRowIdx = h * strideHeight + hOffset; + int imColIdx = w * strideWidth + wOffset; + if ((imRowIdx - paddingHeight) >= 0 && + (imRowIdx - paddingHeight) < inputHeight && + (imColIdx - paddingWidth) >= 0 && + (imColIdx - paddingWidth) < inputWidth) { + imRowIdx += c_im * inputHeight - paddingHeight; + imColIdx -= paddingWidth; + imData[imRowIdx * inputWidth + imColIdx] += + colData[(c * outputHeight + h) * outputWidth + w]; + } + } + } + } + } +}; + +template class Im2ColFunctor; +template class Im2ColFunctor; +template class Col2ImFunctor; +template class Col2ImFunctor; + +/* + * imShape = [inputChannels, inputHeight, inputWidth] + * colShape = + * [outputHeight, outputWidth, inputChannels, filterHeight, filterWidth] + */ +template +class Im2ColFunctor { +public: + void operator()(const T* imData, + const TensorShape& imShape, + T* colData, + const TensorShape& colShape, + int strideHeight, + int strideWidth, + int paddingHeight, + int paddingWidth) { + int inputChannels = imShape[0]; + int inputHeight = imShape[1]; + int inputWidth = imShape[2]; + int filterHeight = colShape[3]; + int filterWidth = colShape[4]; + int outputHeight = colShape[0]; + int outputWidth = colShape[1]; + for (int outputH = 0; outputH < outputHeight; ++outputH) { + for (int outputW = 0; outputW < outputWidth; ++outputW) { + for (int channel = 0; channel < inputChannels; ++channel) { + for (int filterH = 0; filterH < filterHeight; ++filterH) { + for (int filterW = 0; filterW < filterWidth; ++filterW) { + int imRowOffset = + outputH * strideHeight + filterH - paddingHeight; + int imColOffset = outputW * strideWidth + filterW - paddingWidth; + int colDataOffset = + (((outputH * outputWidth + outputW) * inputChannels + + channel) * + filterHeight + + filterH) * + filterWidth + + filterW; + if (imRowOffset < 0 || imRowOffset >= inputHeight || + imColOffset < 0 || imColOffset >= inputWidth) { + colData[colDataOffset] = float(0); + } else { + int imDataOffset = + (channel * inputHeight + imRowOffset) * inputWidth + + imColOffset; + colData[colDataOffset] = imData[imDataOffset]; + } + } + } + } + } + } + } +}; + +/* + * imShape = [inputChannels, inputHeight, inputWidth] + * colShape = + * [outputHeight, outputWidth, inputChannels, filterHeight, filterWidth] + */ +template +class Col2ImFunctor { +public: + void operator()(T* imData, + const TensorShape& imShape, + const T* colData, + const TensorShape& colShape, + int strideHeight, + int strideWidth, + int paddingHeight, + int paddingWidth) { + int inputChannels = imShape[0]; + int inputHeight = imShape[1]; + int inputWidth = imShape[2]; + int filterHeight = colShape[3]; + int filterWidth = colShape[4]; + int outputHeight = colShape[0]; + int outputWidth = colShape[1]; + for (int outputH = 0; outputH < outputHeight; ++outputH) { + for (int outputW = 0; outputW < outputWidth; ++outputW) { + for (int channel = 0; channel < inputChannels; ++channel) { + for (int filterH = 0; filterH < filterHeight; ++filterH) { + for (int filterW = 0; filterW < filterWidth; ++filterW) { + int imRowOffset = + outputH * strideHeight + filterH - paddingHeight; + int imColOffset = outputW * strideWidth + filterW - paddingWidth; + int colDataOffset = + (((outputH * outputWidth + outputW) * inputChannels + + channel) * + filterHeight + + filterH) * + filterWidth + + filterW; + if (imRowOffset >= 0 && imRowOffset < inputHeight && + imColOffset >= 0 && imColOffset < inputWidth) { + int imDataOffset = + (channel * inputHeight + imRowOffset) * inputWidth + + imColOffset; + imData[imDataOffset] += colData[colDataOffset]; + } + } + } + } + } + } + } +}; + +template class Im2ColFunctor; +template class Im2ColFunctor; +template class Col2ImFunctor; +template class Col2ImFunctor; + +} // namespace paddle diff --git a/paddle/function/Im2ColOpGpu.cu b/paddle/function/Im2ColOpGpu.cu index 361ecc4401a162d2352a9f478aac8dc88c9dcf94..15ba854009636d027447d104071163100d5e3f4b 100644 --- a/paddle/function/Im2ColOpGpu.cu +++ b/paddle/function/Im2ColOpGpu.cu @@ -57,6 +57,11 @@ void im2col(const T* data_im, int numOuts, int height, int width, } } +/* + * imShape = [inputChannels, inputHeight, inputWidth] + * colShape = + * [inputChannels, filterHeight, filterWidth, outputHeight, outputWidth] + */ template class Im2ColFunctor { public: @@ -71,10 +76,10 @@ public: int inputChannels = imShape[0]; int inputHeight = imShape[1]; int inputWidth = imShape[2]; - int filterHeight = colShape[3]; - int filterWidth = colShape[4]; - int outputHeight = colShape[0]; - int outputWidth = colShape[1]; + int filterHeight = colShape[1]; + int filterWidth = colShape[2]; + int outputHeight = colShape[3]; + int outputWidth = colShape[4]; int numKernels = inputChannels * outputHeight * outputWidth; int blocks = (numKernels + 1024 -1) / 1024; @@ -135,6 +140,11 @@ void col2im(size_t n, const T* data_col, size_t height, } } +/* + * imShape = [inputChannels, inputHeight, inputWidth] + * colShape = + * [inputChannels, filterHeight, filterWidth, outputHeight, outputWidth] + */ template class Col2ImFunctor { public: @@ -149,10 +159,10 @@ public: int inputChannels = imShape[0]; int inputHeight = imShape[1]; int inputWidth = imShape[2]; - int filterHeight = colShape[3]; - int filterWidth = colShape[4]; - int outputHeight = colShape[0]; - int outputWidth = colShape[1]; + int filterHeight = colShape[1]; + int filterWidth = colShape[2]; + int outputHeight = colShape[3]; + int outputWidth = colShape[4]; size_t numKernels = inputChannels * (inputHeight + 2*paddingHeight) * (inputWidth + 2*paddingWidth); diff --git a/paddle/function/ImageExpandOp.cpp b/paddle/function/ImageExpandOp.cpp index f227f6d0e10facac0ac5a667feb509e156c2b027..625bf5b6edf44148b85bbb09da43da6f210e34b7 100644 --- a/paddle/function/ImageExpandOp.cpp +++ b/paddle/function/ImageExpandOp.cpp @@ -17,114 +17,6 @@ limitations under the License. */ namespace paddle { -/* - * imShape = [inputChannels, inputHeight, inputWidth] - * colShape = - * [outputHeight, outputWidth, inputChannels, filterHeight, filterWidth] - */ -template -class Im2ColFunctor { -public: - void operator()(const T* imData, - const TensorShape& imShape, - T* colData, - const TensorShape& colShape, - int strideHeight, - int strideWidth, - int paddingHeight, - int paddingWidth) { - int inputChannels = imShape[0]; - int inputHeight = imShape[1]; - int inputWidth = imShape[2]; - int filterHeight = colShape[3]; - int filterWidth = colShape[4]; - int outputHeight = colShape[0]; - int outputWidth = colShape[1]; - for (int outputH = 0; outputH < outputHeight; ++outputH) { - for (int outputW = 0; outputW < outputWidth; ++outputW) { - for (int channel = 0; channel < inputChannels; ++channel) { - for (int filterH = 0; filterH < filterHeight; ++filterH) { - for (int filterW = 0; filterW < filterWidth; ++filterW) { - int imRowOffset = - outputH * strideHeight + filterH - paddingHeight; - int imColOffset = outputW * strideWidth + filterW - paddingWidth; - int colDataOffset = - (((outputH * outputWidth + outputW) * inputChannels + - channel) * - filterHeight + - filterH) * - filterWidth + - filterW; - if (imRowOffset < 0 || imRowOffset >= inputHeight || - imColOffset < 0 || imColOffset >= inputWidth) { - colData[colDataOffset] = float(0); - } else { - int imDataOffset = - (channel * inputHeight + imRowOffset) * inputWidth + - imColOffset; - colData[colDataOffset] = imData[imDataOffset]; - } - } - } - } - } - } - } -}; - -/* - * imShape = [inputChannels, inputHeight, inputWidth] - * colShape = - * [outputHeight, outputWidth, inputChannels, filterHeight, filterWidth] - */ -template -class Col2ImFunctor { -public: - void operator()(T* imData, - const TensorShape& imShape, - const T* colData, - const TensorShape& colShape, - int strideHeight, - int strideWidth, - int paddingHeight, - int paddingWidth) { - int inputChannels = imShape[0]; - int inputHeight = imShape[1]; - int inputWidth = imShape[2]; - int filterHeight = colShape[3]; - int filterWidth = colShape[4]; - int outputHeight = colShape[0]; - int outputWidth = colShape[1]; - for (int outputH = 0; outputH < outputHeight; ++outputH) { - for (int outputW = 0; outputW < outputWidth; ++outputW) { - for (int channel = 0; channel < inputChannels; ++channel) { - for (int filterH = 0; filterH < filterHeight; ++filterH) { - for (int filterW = 0; filterW < filterWidth; ++filterW) { - int imRowOffset = - outputH * strideHeight + filterH - paddingHeight; - int imColOffset = outputW * strideWidth + filterW - paddingWidth; - int colDataOffset = - (((outputH * outputWidth + outputW) * inputChannels + - channel) * - filterHeight + - filterH) * - filterWidth + - filterW; - if (imRowOffset >= 0 && imRowOffset < inputHeight && - imColOffset >= 0 && imColOffset < inputWidth) { - int imDataOffset = - (channel * inputHeight + imRowOffset) * inputWidth + - imColOffset; - imData[imDataOffset] += colData[colDataOffset]; - } - } - } - } - } - } - } -}; - /* * \brief Converts the image data of four dimensions(NCHW) into * a sequence data of three dimensions(NST) in the forward calculation,