From 3b65bc7a2631fe2326149cd228d2ecde3a86081b Mon Sep 17 00:00:00 2001 From: hedaoyuan Date: Fri, 26 May 2017 14:17:25 +0800 Subject: [PATCH] Add a naive convolution implement --- paddle/function/ConvOp.cpp | 128 +++++++++++++++++++++++++++++++++++++ paddle/function/ConvOp.h | 67 +++++++++++++++++++ 2 files changed, 195 insertions(+) create mode 100644 paddle/function/ConvOp.cpp create mode 100644 paddle/function/ConvOp.h diff --git a/paddle/function/ConvOp.cpp b/paddle/function/ConvOp.cpp new file mode 100644 index 000000000..50f030585 --- /dev/null +++ b/paddle/function/ConvOp.cpp @@ -0,0 +1,128 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "ConvFunc.h" + +namespace paddle { + +/* + * The three arguments are stored in memory in row major order. + * inputData = [batchSize, inputChannels, inputHeight, inputWidth] + * filterData = [outputChannels, inputChannels, filterHeight, filterWidth] + * outputData = [batchSize, outputChannels, outputHeight, outputWidth] + */ +template +class NaiveConvFunctor { +public: + void operator()(const T* inputData, + size_t batchSize, + size_t inputChannels, + size_t inputHeight, + size_t inputWidth, + const T* filterData, + size_t filterHeight, + size_t filterWidth, + T* outputData, + size_t outputChannels, + size_t outputHeight, + size_t outputWidth, + size_t padding, + size_t stride) { + for (size_t batch = 0; batch < batchSize; batch++) { + for (size_t outC = 0; outC < outputChannels; outC++) { + for (size_t outH = 0; outH < outputHeight; outH++) { + for (size_t outW = 0; outW < outputWidth; outW++) { + const int inStartH = (outH * stride) - padding; + const int inStartW = (outW * stride) - padding; + T outValue = (T)0; + for (size_t inC = 0; inC < inputChannels; inC++) { + for (size_t fH = 0; fH < filterHeight; fH++) { + for (size_t fW = 0; fW < filterWidth; fW++) { + T inValue; + const int inH = inStartH + fH; + const int inW = inStartW + fW; + if ((inH >= 0 && inH < inputHeight) && + (inW >= 0 && inW < inputWidth)) { + size_t offsetInput = + batch * inputChannels * inputHeight * inputWidth + + inC * inputHeight * inputWidth + inH * inputWidth + inW; + inValue = inputData[offsetInput]; + } else { + inValue = (T)0; + } + size_t offsetFilter = + outC * inputChannels * filterHeight * filterWidth + + inC * filterHeight * filterWidth + fH * filterWidth + fW; + T filterValue = filterData[offsetFilter]; + outValue += (inValue * filterValue); + } + } + } + + size_t offset = + batch * outputChannels * outputHeight * outputWidth + + outC * outputHeight * outputWidth + outH * outputWidth + outW; + outputData[offset] = outValue; + } + } + } + } + } +}; + +template +class NaiveConvFunction : public ConvFunctionBase { +public: + void init(const FuncConfig& config) override { + ConvFunctionBase::init(config); + } + + void calc(const BufferArgs& inputs, const BufferArgs& outputs) override { + check(inputs, outputs); + CHECK_EQ(outputs[0].getArgType(), ASSIGN_TO); + + size_t batchSize = inputs[0].shape()[0]; + size_t inputChannels = inputs[0].shape()[1]; + size_t inputHeight = inputs[0].shape()[2]; + size_t inputWidth = inputs[0].shape()[3]; + size_t filterHeight = inputs[1].shape()[2]; + size_t filterWidth = inputs[1].shape()[2]; + size_t outputChannels = outputs[0].shape()[1]; + size_t outputHeight = outputs[0].shape()[2]; + size_t outputWidth = outputs[0].shape()[3]; + + float* inputData = inputs[0].data(); + float* filterData = inputs[1].data(); + float* outputData = outputs[0].data(); + NaiveConvFunctor conv; + conv(inputData, + batchSize, + inputChannels, + inputHeight, + inputWidth, + filterData, + filterHeight, + filterWidth, + outputData, + outputChannels, + outputHeight, + outputWidth, + padding_, + stride_); + } +}; + +REGISTER_TYPED_FUNC(NaiveConv, CPU, NaiveConvFunction); + +} // namespace paddle diff --git a/paddle/function/ConvOp.h b/paddle/function/ConvOp.h new file mode 100644 index 000000000..4d678cfe2 --- /dev/null +++ b/paddle/function/ConvOp.h @@ -0,0 +1,67 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "Function.h" + +namespace paddle { + +/* + * Function Arguments: + * + * \param inputs[0] Input image data, is NCHW format, where N is batch size, + * C is the number of channels, H and W is the height and + * width of input image. + * \param inputs[1] Filter data, is MCHW, where M is the number of output + * channels, C is the number of input channels, H and W + * is height and width of filter. + * \param outputs[0] Output image data, is NCHW format, where N is batch size, + * C is the number of channels, H and W is the height and + * width of output image. + * + * \note Implemented based on the ConvFunctionBase class only supports + * input data in the NCHW format. + */ +class ConvFunctionBase : public FunctionBase { +public: + void init(const FuncConfig& config) override { + // function arguments + stride_ = config.get("stride"); + padding_ = config.get("padding"); + + // number of inputs and outputs + numInputs_ = 2; + numOutputs_ = 1; + } + + virtual void calc(const BufferArgs& inputs, const BufferArgs& outputs) {} + + void check(const BufferArgs& inputs, const BufferArgs& outputs) override { + CHECK_EQ(numInputs_, inputs.size()); + CHECK_EQ(numOutputs_, outputs.size()); + + CHECK_EQ(inputs[0].shape().ndims(), (size_t)4); + CHECK_EQ(inputs[1].shape().ndims(), (size_t)4); + CHECK_EQ(outputs[0].shape().ndims(), (size_t)4); + + CHECK(inputs[0].shape()[0] == outputs[0].shape()[0]); + CHECK(inputs[0].shape()[1] == inputs[1].shape()[1]); + CHECK(outputs[0].shape()[1] == inputs[1].shape()[0]); + } + +protected: + size_t padding_; + size_t stride_; +}; + +} // namespace paddle -- GitLab