From 48e0f432537a97e915306601a8b5c8b72a77b6d1 Mon Sep 17 00:00:00 2001 From: hedaoyuan Date: Mon, 12 Jun 2017 21:22:15 +0800 Subject: [PATCH] Add ImageExpandFunction. --- paddle/function/GemmConvOp.h | 84 +++++++++++++++ paddle/function/ImageExpandOp.cpp | 164 ++++++++++++++++++++++++++++++ 2 files changed, 248 insertions(+) create mode 100644 paddle/function/GemmConvOp.h create mode 100644 paddle/function/ImageExpandOp.cpp diff --git a/paddle/function/GemmConvOp.h b/paddle/function/GemmConvOp.h new file mode 100644 index 000000000..25d2e220b --- /dev/null +++ b/paddle/function/GemmConvOp.h @@ -0,0 +1,84 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +// #include "ConvOp.h" + +namespace paddle { + +/* The storage format of the coldata in the Im2ColFunctor and Col2ImFunctor. */ +enum ColFormat { kCFO = 0, kOCF = 1 }; + +/* + * \brief Converts the image data of four dimensions(NCHW) into a colData. + * Then you can reshape colData to a convolution matrix for + * convolution calculation based on matrix multiplication. + * + * \param imData Image data of NCHW format. + * The format of imData is: + * [input_channels, input_height, input_width]. + * \param colData colData data. + * If the template argument Format is kCFO, + * the format of colData is: + * [input_channels, + * filter_height, + * filter_width, + * output_height, + * output_width] + * If the template argument Format is kOCF, + * the format of colData is: + * [output_height, + * output_width, + * input_channels, + * filter_height, + * filter_width] + */ +template +class Im2ColFunctor { +public: + void operator()(const T* imData, + int inputChannels, + int inputHeight, + int inputWidth, + int filterHeight, + int filterWidth, + int strideHeight, + int strideWidth, + int paddingHeight, + int paddingWidth, + int outputHeight, + int outputWidth, + T* colData); +}; + +template +class Col2ImFunctor { +public: + void operator()(const T* colData, + int inputChannels, + int inputHeight, + int inputWidth, + int filterHeight, + int filterWidth, + int strideHeight, + int strideWidth, + int paddingHeight, + int paddingWidth, + int outputHeight, + int outputWidth, + T* imData); +}; + +} // namespace paddle diff --git a/paddle/function/ImageExpandOp.cpp b/paddle/function/ImageExpandOp.cpp new file mode 100644 index 000000000..426b6c8e3 --- /dev/null +++ b/paddle/function/ImageExpandOp.cpp @@ -0,0 +1,164 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "Function.h" +#include "GemmConvOp.h" + +namespace paddle { + +/* + * imData = [input_channels, input_height, input_width] + * colData = [output_height, output_width, + * input_channels, filter_height, filter_width] + */ +template +class Im2ColFunctor { +public: + void operator()(const T* imData, + int inputChannels, + int inputHeight, + int inputWidth, + int filterHeight, + int filterWidth, + int strideHeight, + int strideWidth, + int paddingHeight, + int paddingWidth, + int outputHeight, + int outputWidth, + T* colData) { + for (int outputH = 0; outputH < outputHeight; ++outputH) { + for (int outputW = 0; outputW < outputWidth; ++outputW) { + for (int channel = 0; channel < inputChannels; ++channel) { + for (int filterH = 0; filterH < filterHeight; ++filterH) { + for (int filterW = 0; filterW < filterWidth; ++filterW) { + int imRowOffset = + outputH * strideHeight + filterH - paddingHeight; + int imColOffset = outputW * strideWidth + filterW - paddingWidth; + int colDataOffset = + (((outputH * outputWidth + outputW) * inputChannels + + channel) * + filterHeight + + filterH) * + filterWidth + + filterW; + if (imRowOffset < 0 || imRowOffset >= inputHeight || + imColOffset < 0 || imColOffset >= inputWidth) { + colData[colDataOffset] = T(0); + } else { + int imDataOffset = + (channel * inputHeight + imRowOffset) * inputWidth + + imColOffset; + colData[colDataOffset] = imData[imDataOffset]; + } + } + } + } + } + } + } +}; + +/* + * \brief Converts the image data of four dimensions(NCHW) into + * a sequence data of three dimensions(NST). Where N is batch size, + * S is the length of the sequence after each image is expanded, + * T is the size of each time step in the sequence. + * + * \param inputs[0] Image data of NCHW format. + * \param outputs[0] Sequence data of NST format. + */ +template +class ImageExpandFunction : public FunctionBase { +public: + void init(const FuncConfig& config) override { + // function arguments + strides_ = config.get>("strides"); + paddings_ = config.get>("paddings"); + blocks_ = config.get>("blocks"); + + // number of inputs and outputs + numInputs_ = 1; + numOutputs_ = 1; + } + + void calc(const BufferArgs& inputs, const BufferArgs& outputs) override { + CHECK_EQ(numInputs_, inputs.size()); + CHECK_EQ(numOutputs_, outputs.size()); + const TensorShape& input = inputs[0].shape(); + const TensorShape& output = outputs[0].shape(); + // input argument should be 4-dimensional. + CHECK_EQ(input.ndims(), (size_t)4); + // output argument should be 3-dimensional. + CHECK_EQ(output.ndims(), (size_t)3); + // The batchSize of the input needs to be equal to + // the batchSize of the output. + CHECK_EQ(input[0], output[0]); + + size_t batchSize = input[0]; + size_t inputChannels = input[1]; + size_t inputHeight = input[2]; + size_t inputWidth = input[3]; + size_t seqLength = output[1]; + size_t stepSize = output[2]; + size_t outputHeight = + 1 + + (inputHeight + 2 * paddingH() - blockH() + strideH() - 1) / strideH(); + size_t outputWidth = + 1 + + (inputWidth + 2 * paddingW() - blockW() + strideW() - 1) / strideW(); + CHECK_EQ(seqLength, outputHeight * outputWidth); + CHECK_EQ(stepSize, inputChannels * blockH() * blockH()); + + real* inputData = inputs[0].data(); + real* outputData = outputs[0].data(); + Im2ColFunctor im2col; + for (size_t i = 0; i < batchSize; i++) { + im2col(inputData, + inputChannels, + inputHeight, + inputWidth, + blockH(), + blockW(), + strideH(), + strideW(), + paddingH(), + paddingW(), + outputHeight, + outputWidth, + outputData); + inputData += inputChannels * inputHeight * inputWidth; + outputData += seqLength * stepSize; + } + } + +protected: + std::vector strides_; + std::vector paddings_; + std::vector blocks_; + + inline int strideH() const { return strides_[0]; } + + inline int strideW() const { return strides_[1]; } + + inline int paddingH() const { return paddings_[0]; } + + inline int paddingW() const { return paddings_[1]; } + + inline int blockH() const { return blocks_[0]; } + + inline int blockW() const { return blocks_[1]; } +}; + +} // namespace paddle -- GitLab