From ddfff3a7fd0e746f8ff232945ed0a05a89b9da1c Mon Sep 17 00:00:00 2001 From: liaogang Date: Sun, 30 Oct 2016 01:04:06 +0800 Subject: [PATCH] Add bilinear interpolation layer --- doc/ui/api/trainer_config_helpers/layers.rst | 6 + paddle/cuda/include/hl_cnn.h | 56 +++++++ paddle/cuda/include/stub/hl_cnn_stub.h | 24 +++ paddle/cuda/src/hl_cuda_cnn.cu | 134 ++++++++++++++- paddle/gserver/layers/BilinearInterpLayer.cpp | 87 ++++++++++ paddle/gserver/layers/BilinearInterpLayer.h | 45 +++++ paddle/gserver/tests/test_LayerGrad.cpp | 20 +++ paddle/math/Matrix.cpp | 154 ++++++++++++++++++ paddle/math/Matrix.h | 44 +++++ paddle/math/tests/test_matrixCompare.cpp | 66 ++++++++ proto/ModelConfig.proto.m4 | 10 ++ python/paddle/trainer/config_parser.py | 35 ++++ .../paddle/trainer_config_helpers/layers.py | 69 +++++++- .../tests/configs/generate_protostr.sh | 2 +- .../tests/configs/test_bilinear_interp.py | 33 ++++ 15 files changed, 781 insertions(+), 4 deletions(-) create mode 100644 paddle/gserver/layers/BilinearInterpLayer.cpp create mode 100644 paddle/gserver/layers/BilinearInterpLayer.h create mode 100644 python/paddle/trainer_config_helpers/tests/configs/test_bilinear_interp.py diff --git a/doc/ui/api/trainer_config_helpers/layers.rst b/doc/ui/api/trainer_config_helpers/layers.rst index c1d7a7ce81..0144346610 100644 --- a/doc/ui/api/trainer_config_helpers/layers.rst +++ b/doc/ui/api/trainer_config_helpers/layers.rst @@ -263,6 +263,12 @@ interpolation_layer :members: interpolation_layer :noindex: +bilinear_interp_layer +------------------- +.. automodule:: paddle.trainer_config_helpers.layers + :members: bilinear_interp_layer + :noindex: + power_layer ----------- .. automodule:: paddle.trainer_config_helpers.layers diff --git a/paddle/cuda/include/hl_cnn.h b/paddle/cuda/include/hl_cnn.h index 5d750333e1..aa4720f6ca 100644 --- a/paddle/cuda/include/hl_cnn.h +++ b/paddle/cuda/include/hl_cnn.h @@ -240,4 +240,60 @@ extern void hl_CMRNorm_backward( size_t channels, size_t height, size_t width, size_t sizeX, real alpha, real beta); +/** + * @brief Bilinear interpolation forward. + * + * @param[in] inData input value. + * @param[in] inImgH input image height. + * @param[in] inImgW input image width. + * @param[in] inputH input batchSize. + * @param[in] inputW input image data dim. + * @param[out] outData output value. + * @param[in] outImgH output image height. + * @param[in] outImgW output image width. + * @param[in] outputH output batchSize. + * @param[in] outputW output image data dim. + * @param[in] numChannels number of channels. + * + */ +extern void hl_bilinear_forward(const real* inData, + const size_t inImgH, + const size_t inImgW, + const size_t inputH, + const size_t inputW, + real* outData, + const size_t outImgH, + const size_t outImgW, + const size_t outputH, + const size_t outputW, + const size_t numChannels); + + /** + * @brief Bilinear interpolation backward. + * + * @param[out] inGrad input gradient. + * @param[in] inImgH input image height. + * @param[in] inImgW input image width. + * @param[in] inputH input batchSize. + * @param[in] inputW input image data dim. + * @param[in] outGrad output gradient. + * @param[in] outImgH output image height. + * @param[in] outImgW output image width. + * @param[in] outputH output batchSize. + * @param[in] outputW output image data dim. + * @param[in] numChannels number of channels. + * + */ +extern void hl_bilinear_backward(real* inGrad, + const size_t inImgH, + const size_t inImgW, + const size_t inputH, + const size_t inputW, + const real* outGrad, + const size_t outImgH, + const size_t outImgW, + const size_t outputH, + const size_t outputW, + const size_t numChannels); + #endif /* HL_CNN_H_ */ diff --git a/paddle/cuda/include/stub/hl_cnn_stub.h b/paddle/cuda/include/stub/hl_cnn_stub.h index 38e359c3eb..aa9442fb80 100644 --- a/paddle/cuda/include/stub/hl_cnn_stub.h +++ b/paddle/cuda/include/stub/hl_cnn_stub.h @@ -89,4 +89,28 @@ inline void hl_CMRNorm_backward( size_t channels, size_t height, size_t width, size_t sizeX, real alpha, real beta) {} +inline void hl_bilinear_forward(const real* inData, + const size_t inImgH, + const size_t inImgW, + const size_t inputH, + const size_t inputW, + real* outData, + const size_t outImgH, + const size_t outImgW, + const size_t outputH, + const size_t outputW, + const size_t numChannels) {} + +inline void hl_bilinear_backward(real* inGrad, + const size_t inImgH, + const size_t inImgW, + const size_t inputH, + const size_t inputW, + const real* outGrad, + const size_t outImgH, + const size_t outImgW, + const size_t outputH, + const size_t outputW, + const size_t numChannels) {} + #endif // HL_CNN_STUB_H_ diff --git a/paddle/cuda/src/hl_cuda_cnn.cu b/paddle/cuda/src/hl_cuda_cnn.cu index abac83a3e0..f965adc135 100644 --- a/paddle/cuda/src/hl_cuda_cnn.cu +++ b/paddle/cuda/src/hl_cuda_cnn.cu @@ -522,7 +522,7 @@ void hl_CMRNorm_backward(size_t frameCnt, const real* inV, size_t height, size_t width, size_t sizeX, real alpha, real beta) { size_t threadsNum = frameCnt * height * width; - size_t blocksX = (threadsNum + 1024 -1) / 1024; + size_t blocksX = (threadsNum + 1024 - 1) / 1024; size_t blocksY = 1; dim3 threads(1024, 1); dim3 grid(blocksX, blocksY); @@ -531,3 +531,135 @@ void hl_CMRNorm_backward(size_t frameCnt, const real* inV, height, width, sizeX, alpha, beta, inDiff); CHECK_SYNC("hl_CMRNorm_backward"); } + +__global__ void KeBilinearInterpFw(const size_t nthreads, + const real* in, + const size_t inImgH, + const size_t inImgW, + const size_t inputH, + const size_t inputW, + real* out, + const size_t outImgH, + const size_t outImgW, + const size_t outputH, + const size_t outputW, + const size_t numChannels, + const real ratioH, + const real ratioW) { + int tid = blockIdx.x * blockDim.x + threadIdx.x; + if(tid < nthreads) { + int outIdH = tid / (outputW / numChannels); + int outIdW = tid % (outputW / numChannels); + + int inIdH = ratioH * (outIdW / outImgW); + int hId = (inIdH < inImgH - 1) ? 1 : 0; + real hlambda = ratioH * (outIdW / outImgW) - inIdH; + + int inIdW = ratioW * (tid % outImgW); + int wId = (inIdW < inImgW - 1) ? 1 : 0; + real wlambda = ratioW * (tid % outImgW) - inIdW; + + const real* inPos = &in[outIdH * inputW + inIdH * inImgW + inIdW]; + real* outPos = &out[outIdH * outputW + outIdW]; + for (int c = 0; c < numChannels; ++c) { + // bilinear interpolation + outPos[0] = (1.f - hlambda) * + ((1.f - wlambda) * inPos[0] + wlambda * inPos[wId]) + + hlambda * ((1.f - wlambda) * inPos[hId * inImgW] + + wlambda * inPos[hId * inImgW + wId]); + inPos += inImgH * inImgW; + outPos += outImgH * outImgW; + } + } +} + +void hl_bilinear_forward(const real* inData, + const size_t inImgH, + const size_t inImgW, + const size_t inputH, + const size_t inputW, + real* outData, + const size_t outImgH, + const size_t outImgW, + const size_t outputH, + const size_t outputW, + const size_t numChannels) { + int threadNum = outputH * outImgH * outImgW; + int blocks = (threadNum + 1024 - 1) / 1024; + + real ratioH = (outImgH > 1) ? + static_cast(inImgH - 1) / (outImgH - 1) : 0.f; + real ratioW = (outImgW > 1) ? + static_cast(inImgW - 1) / (outImgW - 1) : 0.f; + + KeBilinearInterpFw<<< blocks, 1024, 0, STREAM_DEFAULT>>>( + threadNum, inData, inImgH, inImgW, inputH, inputW, outData, + outImgH, outImgW, outputH, outputW, numChannels, ratioH, ratioW); + CHECK_SYNC("hl_bilinear_forward failed"); +} + +__global__ void KeBilinearInterpBw(const size_t nthreads, + real* in, + const size_t inImgH, + const size_t inImgW, + const size_t inputH, + const size_t inputW, + const real* out, + const size_t outImgH, + const size_t outImgW, + const size_t outputH, + const size_t outputW, + const size_t numChannels, + const real ratioH, + const real ratioW) { + int tid = blockIdx.x * blockDim.x + threadIdx.x; + + if(tid < nthreads) { + int outIdH = tid / (outputW / numChannels); + int outIdW = tid % (outputW / numChannels); + + int inIdH = ratioH * (outIdW / outImgW); + int hId = (inIdH < inImgH - 1) ? 1 : 0; + real hlambda = ratioH * (outIdW / outImgW) - inIdH; + + int inIdW = ratioW * (tid % outImgW); + int wId = (inIdW < inImgW - 1) ? 1 : 0; + real wlambda = ratioW * (tid % outImgW) - inIdW; + + const real* outPos = &out[outIdH * outputW + outIdW]; + real* inPos = &in[outIdH * inputW + inIdH * inImgW + inIdW]; + for (int c = 0; c < numChannels; ++c) { + atomicAdd(&inPos[0], (1.f - hlambda) * (1.f - wlambda) * outPos[0]); + atomicAdd(&inPos[wId], (1.f - hlambda) * wlambda * outPos[0]); + atomicAdd(&inPos[hId * inImgW], hlambda * (1.f - wlambda) * outPos[0]); + atomicAdd(&inPos[hId * inImgW + wId], hlambda * wlambda * outPos[0]); + inPos += inImgH * inImgW; + outPos += outImgH * outImgW; + } + } +} + +void hl_bilinear_backward(real* inGrad, + const size_t inImgH, + const size_t inImgW, + const size_t inputH, + const size_t inputW, + const real* outGrad, + const size_t outImgH, + const size_t outImgW, + const size_t outputH, + const size_t outputW, + const size_t numChannels) { + int threadNum = outputH * outImgH * outImgW; + int blocks = (threadNum + 1024 - 1) / 1024; + + real ratioH = (outImgH > 1) ? + static_cast(inImgH - 1) / (outImgH - 1) : 0.f; + real ratioW = (outImgW > 1) ? + static_cast(inImgW - 1) / (outImgW - 1) : 0.f; + + KeBilinearInterpBw<<< blocks, 1024, 0, STREAM_DEFAULT>>>( + threadNum, inGrad, inImgH, inImgW, inputH, inputW, outGrad, + outImgH, outImgW, outputH, outputW, numChannels, ratioH, ratioW); + CHECK_SYNC("hl_bilinear_backward failed"); +} \ No newline at end of file diff --git a/paddle/gserver/layers/BilinearInterpLayer.cpp b/paddle/gserver/layers/BilinearInterpLayer.cpp new file mode 100644 index 0000000000..f43086e585 --- /dev/null +++ b/paddle/gserver/layers/BilinearInterpLayer.cpp @@ -0,0 +1,87 @@ +/* Copyright (c) 2016 Baidu, Inc. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "BilinearInterpLayer.h" +#include "paddle/utils/Logging.h" +#include "paddle/utils/Stat.h" + +namespace paddle { + +REGISTER_LAYER(bilinear_interp, BilinearInterpLayer); + +size_t BilinearInterpLayer::getDataDimSize() { + getOutput().setFrameHeight(outImgH_); + getOutput().setFrameWidth(outImgW_); + return outImgH_ * outImgW_ * numChannels_; +} + +bool BilinearInterpLayer::init(const LayerMap& layerMap, + const ParameterMap& parameterMap) { + /* Initialize the basic parent class */ + Layer::init(layerMap, parameterMap); + + CHECK_EQ(1, config_.inputs_size()); + + const BilinearInterpConfig& conf = config_.inputs(0).bilinear_interp_conf(); + inImgH_ = inputLayers_[0]->getOutput().getFrameHeight(); + inImgW_ = inputLayers_[0]->getOutput().getFrameWidth(); + if (inImgH_ == 0) { + inImgH_ = conf.img_size_y(); + } + if (inImgW_ == 0) { + inImgW_ = conf.img_size_x(); + } + outImgH_ = conf.out_size_y(); + outImgW_ = conf.out_size_x(); + numChannels_ = conf.num_channels(); + + CHECK(outImgH_ > 0 && outImgW_ > 0); + CHECK(inImgH_ > 0 && inImgW_ > 0); + CHECK(numChannels_); + + return true; +} + +void BilinearInterpLayer::forward(PassType passType) { + Layer::forward(passType); + size_t batchSize = getInput(0).getBatchSize(); + size_t size = getDataDimSize(); + { + REGISTER_TIMER_INFO("FwResetTimer", getName().c_str()); + resetOutput(batchSize, size); + } + + MatrixPtr inV = getInputValue(0); + MatrixPtr outV = getOutputValue(); + { + REGISTER_TIMER_INFO("FwBilinearInterpTimer", getName().c_str()); + outV->bilinearForward(*inV, inImgH_, inImgW_, outImgH_, outImgW_, + numChannels_); + } +} + +void BilinearInterpLayer::backward(const UpdateCallback& callback) { + (void) callback; + + MatrixPtr inputG = getInputGrad(0); + MatrixPtr outG = getOutputGrad(); + { + REGISTER_TIMER_INFO("BwBilinearInterpTimer", getName().c_str()); + if (inputG) { + inputG->bilinearBackward(*outG, outImgH_, outImgW_, inImgH_, inImgW_, + numChannels_); + } + } +} +} // namespace paddle diff --git a/paddle/gserver/layers/BilinearInterpLayer.h b/paddle/gserver/layers/BilinearInterpLayer.h new file mode 100644 index 0000000000..24f5b99910 --- /dev/null +++ b/paddle/gserver/layers/BilinearInterpLayer.h @@ -0,0 +1,45 @@ +/* Copyright (c) 2016 Baidu, Inc. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "Layer.h" +#include "paddle/math/Matrix.h" + +namespace paddle { + +/** + * @brief A layer for bilinear interpolation which is + * used on conv layer output. + * + * @note The config file api is bilinear_interp_layer. + */ +class BilinearInterpLayer : public Layer { +protected: + size_t outImgH_, outImgW_; + size_t inImgH_, inImgW_; + size_t numChannels_; + +public: + explicit BilinearInterpLayer(const LayerConfig& config) : Layer(config) {} + + virtual ~BilinearInterpLayer() {} + + size_t getDataDimSize(); + bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); + void forward(PassType passType); + void backward(const UpdateCallback& callback = nullptr); +}; + +} // namespace paddle diff --git a/paddle/gserver/tests/test_LayerGrad.cpp b/paddle/gserver/tests/test_LayerGrad.cpp index c5723f8574..425d669206 100644 --- a/paddle/gserver/tests/test_LayerGrad.cpp +++ b/paddle/gserver/tests/test_LayerGrad.cpp @@ -31,6 +31,26 @@ P_DECLARE_double(checkgrad_eps); P_DECLARE_bool(thread_local_rand_use_global_seed); P_DECLARE_bool(prev_batch_state); +TEST(Layer, BilinearInterpLayer) { + TestConfig config; + config.layerConfig.set_type("bilinear_interp"); + config.biasSize = 0; + + config.inputDefs.push_back({INPUT_DATA, "layer_0", 4096, 0}); + LayerInputConfig* input = config.layerConfig.add_inputs(); + BilinearInterpConfig* bilinear = input->mutable_bilinear_interp_conf(); + + bilinear->set_img_size_x(32); + bilinear->set_img_size_y(32); + bilinear->set_out_size_x(64); + bilinear->set_out_size_y(64); + bilinear->set_num_channels(4); + + for (auto useGpu : {false, true}) { + testLayerGrad(config, "bilinear_interp", 10, false, useGpu); + } +} + TEST(Operator, dot_mul) { TestConfig config; config.layerConfig.set_size(10); diff --git a/paddle/math/Matrix.cpp b/paddle/math/Matrix.cpp index a6ff2f3b35..4692557197 100644 --- a/paddle/math/Matrix.cpp +++ b/paddle/math/Matrix.cpp @@ -23,6 +23,7 @@ limitations under the License. */ #include "paddle/utils/Logging.h" #include +#include "hl_cnn.h" #include "hl_gpu.h" #include "hl_table_apply.h" #include "hl_top_k.h" @@ -1144,6 +1145,56 @@ void GpuMatrix::addColumnVector(const Matrix& b) { BaseMatrix::addColVector(const_cast(b)); } +void GpuMatrix::bilinearForward(const Matrix& in, + const size_t inImgH, + const size_t inImgW, + const size_t outImgH, + const size_t outImgW, + const size_t numChannels) { + CHECK(dynamic_cast(&in)); + + const size_t outputW = getWidth(); + const size_t outputH = getHeight(); + const size_t inputW = in.getWidth(); + const size_t inputH = in.getHeight(); + + real* outData = getData(); + const real* inData = in.getData(); + + if (inImgH == outImgW && inImgW == outImgW) { + this->copyFrom(in); + } else { + hl_bilinear_forward(inData, inImgH, inImgW, + inputH, inputW, outData, outImgH, outImgW, + outputH, outputW, numChannels); + } +} + +void GpuMatrix::bilinearBackward(const Matrix& out, + const size_t outImgH, + const size_t outImgW, + const size_t inImgH, + const size_t inImgW, + const size_t numChannels) { + CHECK(dynamic_cast(&out)); + + const size_t inputW = getWidth(); + const size_t inputH = getHeight(); + const size_t outputW = out.getWidth(); + const size_t outputH = out.getHeight(); + + real* inGrad = getData(); + const real* outGrad = out.getData(); + + if (outImgH == inImgH && outImgW == inImgW) { + this->copyFrom(out); + } else { + hl_bilinear_backward(inGrad, inImgH, inImgW, + inputH, inputW, outGrad, outImgH, outImgW, + outputH, outputW, numChannels); + } +} + /** * CpuMatrix */ @@ -3598,6 +3649,109 @@ void CpuMatrix::classificationErrorMulti(Matrix& output, Matrix& label, } } +void CpuMatrix::bilinearForward(const Matrix& in, + const size_t inImgH, + const size_t inImgW, + const size_t outImgH, + const size_t outImgW, + const size_t numChannels) { + CHECK(dynamic_cast(&in)); + + size_t outputW = getWidth(); + size_t outputH = getHeight(); + size_t inputW = in.getWidth(); + size_t inputH = in.getHeight(); + + real* outData = getData(); + const real* inData = in.getData(); + + const real ratioH = (outImgH > 1) ? + static_cast(inImgH - 1) / (outImgH - 1) : 0.f; + const real ratioW = (outImgW > 1) ? + static_cast(inImgW - 1) / (outImgW - 1) : 0.f; + + if (inImgH == outImgH && inImgW == outImgW) { + this->copyFrom(in); + } else { + for (int k = 0; k < outputH; ++k) { // loop for batches + for (int i = 0; i < outImgH; ++i) { // loop for images + int h = ratioH * i; + int hid = (h < inImgH - 1) ? 1 : 0; + real hlambda = ratioH * i - h; + + for (int j = 0; j < outImgW; ++j) { + int w = ratioW * j; + int wid = (w < inImgW - 1) ? 1 : 0; + real wlambda = ratioW * j - w; + // calculate four position for bilinear interpolation + const real* inPos = &inData[k * inputW + h * inImgW + w]; + real* outPos = &outData[k * outputW + i * outImgW + j]; + for (int c = 0; c < numChannels; ++c) { // loop for channels + // bilinear interpolation + outPos[0] = (1.f - hlambda) * + ((1.f - wlambda) * inPos[0] + wlambda * inPos[wid]) + + hlambda * ((1.f - wlambda) * inPos[hid * inImgW] + + wlambda * inPos[hid * inImgW + wid]); + inPos += inImgH * inImgW; + outPos += outImgH * outImgW; + } + } + } + } + } +} + +void CpuMatrix::bilinearBackward(const Matrix& out, + const size_t outImgH, + const size_t outImgW, + const size_t inImgH, + const size_t inImgW, + const size_t numChannels) { + CHECK(dynamic_cast(&out)); + + size_t inputW = getWidth(); + size_t inputH = getHeight(); + size_t outputW = out.getWidth(); + size_t outputH = out.getHeight(); + + real* inGrad = getData(); + const real* outGrad = out.getData(); + + const real ratioH = (outImgH > 1) ? + static_cast(inImgH - 1) / (outImgH - 1) : 0.f; + const real ratioW = (outImgW > 1) ? + static_cast(inImgW - 1) / (outImgW - 1) : 0.f; + + if (inImgH == outImgH && inImgW == outImgW) { + this->copyFrom(out); + } else { + for (int k = 0; k < outputH; ++k) { // loop for batches + for (int i = 0; i < outImgH; ++i) { // loop for images + int h = ratioH * i; + int hid = (h < inImgH - 1) ? 1 : 0; + real hlambda = ratioH * i - h; + + for (int j = 0; j < outImgW; ++j) { + int w = ratioW * j; + int wid = (w < inImgW - 1) ? 1 : 0; + real wlambda = ratioW * j - w; + + real* inPos = &inGrad[k * inputW + h * inImgW + w]; + const real* outPos = &outGrad[k * outputW + i * outImgW + j]; + for (int c = 0; c < numChannels; ++c) { // loop for channels + inPos[0] += (1.f - hlambda) * (1.f - wlambda) * outPos[0]; + inPos[wid] += (1.f - hlambda) * wlambda * outPos[0]; + inPos[hid * inImgW] += hlambda * (1.f - wlambda) * outPos[0]; + inPos[hid * inImgW + wid] += hlambda * wlambda * outPos[0]; + inPos += inImgH * inImgW; + outPos += outImgH * outImgW; + } + } + } + } + } +} + //////////////////////////////////////////////////////////////// // functions executed via cpu // //////////////////////////////////////////////////////////////// diff --git a/paddle/math/Matrix.h b/paddle/math/Matrix.h index 5c15c94012..b4922d7e6f 100644 --- a/paddle/math/Matrix.h +++ b/paddle/math/Matrix.h @@ -930,6 +930,22 @@ public: virtual void paramReluBackwardDiff(Matrix& oGrad, Matrix& data, Matrix& W) { LOG(FATAL) << "Not implemented"; } + virtual void bilinearForward(const Matrix& in, + const size_t inImgH, + const size_t inImgW, + const size_t outImgH, + const size_t outImgW, + const size_t numChannels) { + LOG(FATAL) << "Not implemented"; + } + virtual void bilinearBackward(const Matrix& out, + const size_t outImgH, + const size_t outImgW, + const size_t inImgH, + const size_t inImgW, + const size_t numChannels) { + LOG(FATAL) << "Not implemented"; + } }; inline std::ostream& operator<<(std::ostream& os, const Matrix& mat) { @@ -1191,6 +1207,20 @@ public: int contextLength, int contextStart, int totalPad, size_t beginPad); + + void bilinearForward(const Matrix& in, + const size_t inImgH, + const size_t inImgW, + const size_t outImgH, + const size_t outImgW, + const size_t numChannels); + + void bilinearBackward(const Matrix& out, + const size_t outImgH, + const size_t outImgW, + const size_t inImgH, + const size_t inImgW, + const size_t numChannels); }; class CpuMatrix : public Matrix { @@ -1469,6 +1499,20 @@ public: void multiBinaryLabelCrossEntropy(Matrix& output, Matrix& label); void multiBinaryLabelCrossEntropyBp(Matrix& output, Matrix& label); void classificationErrorMulti(Matrix& output, Matrix& label, real threshold); + + void bilinearForward(const Matrix& in, + const size_t inImgH, + const size_t inImgW, + const size_t outImgH, + const size_t outImgW, + const size_t numChannels); + + void bilinearBackward(const Matrix& out, + const size_t outImgH, + const size_t outImgW, + const size_t inImgH, + const size_t inImgW, + const size_t numChannels); }; class SharedCpuMatrix : public CpuMatrix { diff --git a/paddle/math/tests/test_matrixCompare.cpp b/paddle/math/tests/test_matrixCompare.cpp index e1bda79a8a..2ff19e7b3f 100644 --- a/paddle/math/tests/test_matrixCompare.cpp +++ b/paddle/math/tests/test_matrixCompare.cpp @@ -88,6 +88,72 @@ void MatrixCheckErr(const Matrix& matrix1, const Matrix& matrix2) { EXPECT_EQ(count, 0) << "There are " << count << " different element."; } +void testBilinearFwdBwd(int numSamples, int imgSizeH, int imgSizeW, + int channels) { + int inWidth = imgSizeH * imgSizeW * channels; + int outWidth = 2 * imgSizeH * 2 * imgSizeW * channels; + + // forward + MatrixPtr input = CpuMatrix::create(numSamples, inWidth, false, false); + MatrixPtr inputGpu = GpuMatrix::create(numSamples, inWidth, false, true); + + MatrixPtr target = CpuMatrix::create(numSamples, outWidth, false, false); + MatrixPtr targetGpu = GpuMatrix::create(numSamples, outWidth, false, true); + MatrixPtr targetCheck = CpuMatrix::create(numSamples, outWidth, false, false); + + input->randomizeUniform(); + inputGpu->copyFrom(*input); + + target->bilinearForward(*input, imgSizeH, imgSizeW, + 2 * imgSizeH, 2 * imgSizeW, channels); + targetGpu->bilinearForward(*inputGpu, imgSizeH, imgSizeW, + 2 * imgSizeH, 2 * imgSizeW, channels); + + // check + targetCheck->copyFrom(*targetGpu); + MatrixCheckErr(*target, *targetCheck); + + // backward + MatrixPtr inputGrad = CpuMatrix::create(numSamples, inWidth, false, false); + MatrixPtr inputGpuGrad = GpuMatrix::create(numSamples, inWidth, false, true); + + MatrixPtr targetGrad = CpuMatrix::create(numSamples, outWidth, false, false); + MatrixPtr targetGpuGrad = GpuMatrix::create(numSamples, outWidth, false, + true); + MatrixPtr targetCheckGrad = + CpuMatrix::create(numSamples, inWidth, false, false); + + inputGrad->randomizeUniform(); + targetGrad->randomizeUniform(); + inputGpuGrad->copyFrom(*inputGrad); + targetGpuGrad->copyFrom(*targetGrad); + + inputGrad->bilinearBackward(*targetGrad, 2 * imgSizeH, 2 * imgSizeW, + imgSizeH, imgSizeW, channels); + inputGpuGrad->bilinearBackward(*targetGpuGrad, 2 * imgSizeH, 2 * imgSizeW, + imgSizeH, imgSizeW, channels); + + // check + targetCheckGrad->copyFrom(*inputGpuGrad); + MatrixCheckErr(*inputGrad, *targetCheckGrad); +} + +TEST(Matrix, BilinearFwdBwd) { + for (auto numSamples : {5, 10}) { + for (auto channels : {8, 16}) { + for (auto imgSizeH : {14, 28}) { + for (auto imgSizeW : {16, 30}) { + VLOG(3) << " numSamples=" << numSamples + << " channels=" << channels + << " imgSizeH=" << imgSizeH + << " imgSizeW=" << imgSizeW; + testBilinearFwdBwd(numSamples, imgSizeH, imgSizeW, channels); + } + } + } + } +} + void testMatrixProjectionForward(int contextStart, int contextLength, bool padding, int batchSize, int inputDim) { MatrixPtr cpuInput = std::make_shared(batchSize, inputDim); diff --git a/proto/ModelConfig.proto.m4 b/proto/ModelConfig.proto.m4 index 25e36f9c4c..8bdcd70a41 100644 --- a/proto/ModelConfig.proto.m4 +++ b/proto/ModelConfig.proto.m4 @@ -203,6 +203,15 @@ message OperatorConfig { optional int32 num_filters = 7; } +message BilinearInterpConfig { + // The size if input feature map. + required uint32 img_size_x = 1; + required uint32 img_size_y = 2; + // The size if output feature map. + required uint32 out_size_x = 3; + required uint32 out_size_y = 4; + required uint32 num_channels = 5; +} message ImageConfig { // The image data dimensionality. @@ -225,6 +234,7 @@ message LayerInputConfig { // If the input layer has multi-output. // Set the argument name. optional string input_layer_argument = 9; + optional BilinearInterpConfig bilinear_interp_conf = 10; } message LayerConfig { diff --git a/python/paddle/trainer/config_parser.py b/python/paddle/trainer/config_parser.py index fb47fd0c6f..82446e980d 100644 --- a/python/paddle/trainer/config_parser.py +++ b/python/paddle/trainer/config_parser.py @@ -461,6 +461,7 @@ class Input(Cfg): sparse_update=None, gradient_clipping_threshold=None, conv=None, + bilinear_interp=None, norm=None, pool=None, image=None, @@ -723,6 +724,18 @@ class Conv(Cfg): if output_x is not None: config_assert(output_x <= 0) +# please refer to the comments in proto/ModelConfig.proto +@config_class +class BilinearInterp(Cfg): + def __init__( + self, + img_size_x = None, + img_size_y=None, + out_size_x = None, + out_size_y = None, + num_channels = None): + self.add_keys(locals()) + # please refer to the comments in proto/ModelConfig.proto @config_class class Pool(Cfg): @@ -953,6 +966,13 @@ def TestData(data_config, async_load_data=None): " Data definition") g_config.test_data_config.async_load_data = async_load_data +def parse_bilinear(bilinear, input_layer_name, bilinear_conf): + bilinear_conf.img_size_x = bilinear.img_size_x; + bilinear_conf.img_size_y = bilinear.img_size_y; + bilinear_conf.out_size_x = bilinear.out_size_x; + bilinear_conf.out_size_y = bilinear.out_size_y; + bilinear_conf.num_channels = bilinear.num_channels; + def parse_pool(pool, input_layer_name, pool_conf): pool_conf.pool_type = pool.pool_type config_assert(pool.pool_type in ['max-projection', 'avg-projection', @@ -2306,6 +2326,21 @@ class InterpolationLayer(LayerBase): config_assert(input_layer1.size == input_layer2.size, 'the two vector inputs should be of the same size') +@config_layer('bilinear_interp') +class BilinearInterpLayer(LayerBase): + def __init__( + self, + name, + inputs, + device=None): + super(BilinearInterpLayer, self).__init__( + name, 'bilinear_interp', 0, inputs=inputs, device=device) + input_layer = self.get_input_layer(0) + self.set_layer_size(input_layer.size) + parse_bilinear(self.inputs[0].bilinear_interp, + input_layer.name, + self.config.inputs[0].bilinear_interp_conf); + @config_layer('sum_to_one_norm') class SumToOneNormLayer(LayerBase): def __init__( diff --git a/python/paddle/trainer_config_helpers/layers.py b/python/paddle/trainer_config_helpers/layers.py index 5e7e66a908..59df4646fa 100644 --- a/python/paddle/trainer_config_helpers/layers.py +++ b/python/paddle/trainer_config_helpers/layers.py @@ -40,8 +40,8 @@ __all__ = ["full_matrix_projection", "AggregateLevel", "ExpandLevel", 'img_cmrnorm_layer', 'addto_layer', 'concat_layer', 'lstm_step_layer', 'recurrent_group', 'memory', 'StaticInput', 'expand_layer', 'scaling_layer', - 'power_layer', 'interpolation_layer', 'trans_layer', - 'sum_to_one_norm_layer', + 'power_layer', 'interpolation_layer', 'bilinear_interp_layer', + 'trans_layer', 'sum_to_one_norm_layer', 'get_output_layer', 'LayerType', 'context_projection', 'beam_search', 'maxid_layer', 'GeneratedInput', 'SubsequenceInput', 'gru_step_layer', 'recurrent_layer', @@ -92,6 +92,7 @@ class LayerType(object): EXPAND_LAYER = 'expand' INTERPOLATION_LAYER = 'interpolation' + BILINEAR_INTERP_LAYER = 'bilinear_interp' POWER_LAYER = 'power' SCALING_LAYER = 'scaling' TRANS_LAYER = 'trans' @@ -1252,6 +1253,70 @@ def interpolation_layer(input, weight, name=None, layer_attr=None): size=input[0].size) +@wrap_name_default() +@layer_support() +def bilinear_interp_layer(input, + img_size_x=None, + img_size_y=None, + out_size_x=None, + out_size_y=None, + num_channels=None, + name=None, + layer_attr=None): + """ + This layer is to implement bilinear interpolation on conv layer output. + + Please refer to Wikipedia: https://en.wikipedia.org/wiki/Bilinear_interpolation + + The simple usage is: + + .. code-block:: python + + bilinear = bilinear_interp_layer(input, + img_size_x, + img_size_y, + out_size_x, + out_size_y, + num_channels) + + :para input: A input layer. + :type input: LayerOutput. + :para img_size_x: previous layer output width. + :type img_size_x: int|None + :para img_size_y: previous layer output height. + :type img_size_y: int|None + :para out_size_x: bilinear interpolation output width. + :type out_size_x: int|None + :para out_size_y: bilinear interpolation output height. + :type out_size_y: int|None + :para num_channels: number of channels of input layer. If None, + it will be set automatically from previous output. + :type num_channels: int|None + :para name: The layer's name, which cna not be specified. + :type name: None|basestring + :para layer_attr: Extra Layer attribute. + :type layer_attr: ExtraLayerAttribute + :return: LayerOutput object. + :rtype: LayerOutput + """ + assert input.layer_type == LayerType.CONV_LAYER + assert isinstance(input.activation, LinearActivation) + assert img_size_x > 0 and img_size_y > 0 + assert out_size_x > 0 and out_size_y > 0 + if num_channels is None: + assert input.numfilters is not None + num_channels = input.num_filters + Layer(name=name, + inputs=Input(input.name, + bilinear_interp=BilinearInterp(img_size_x=img_size_x, + img_size_y=img_size_y, + out_size_x=out_size_x, + out_size_y=out_size_y, + num_channels=num_channels)), + type=LayerType.BILINEAR_INTERP_LAYER, + **ExtraLayerAttribute.to_kwargs(layer_attr)) + return LayerOutput(name, LayerType.BILINEAR_INTERP_LAYER, parents=[input]) + @wrap_name_default() @layer_support() def power_layer(input, weight, name=None, layer_attr=None): diff --git a/python/paddle/trainer_config_helpers/tests/configs/generate_protostr.sh b/python/paddle/trainer_config_helpers/tests/configs/generate_protostr.sh index fc2acbd41e..e8be0023e7 100755 --- a/python/paddle/trainer_config_helpers/tests/configs/generate_protostr.sh +++ b/python/paddle/trainer_config_helpers/tests/configs/generate_protostr.sh @@ -8,7 +8,7 @@ configs=(test_fc layer_activations projections test_print_layer test_sequence_pooling test_lstmemory_layer test_grumemory_layer last_first_seq test_expand_layer test_ntm_layers test_hsigmoid img_layers util_layers simple_rnn_layers unused_layers test_cost_layers -test_rnn_group) +test_rnn_group test_bilinear_interp) for conf in ${configs[*]} diff --git a/python/paddle/trainer_config_helpers/tests/configs/test_bilinear_interp.py b/python/paddle/trainer_config_helpers/tests/configs/test_bilinear_interp.py new file mode 100644 index 0000000000..7815b34abc --- /dev/null +++ b/python/paddle/trainer_config_helpers/tests/configs/test_bilinear_interp.py @@ -0,0 +1,33 @@ +from paddle.trainer_config_helpers import * + +settings( + batch_size=1000, + learning_rate=1e-5 +) + +data = data_layer(name='data', size=2304) + +conv = img_conv_layer(input=data, + filter_size = 3, + num_channels=1, + num_filters=16, + padding=1, + act=LinearActivation(), + bias_attr=True) + +bilinear = bilinear_interp_layer(input=conv, + img_size_x=32, + img_size_y=32, + out_size_x=64, + out_size_y=64, + num_channels=16) + +pool = img_pool_layer(input=bilinear, + num_channels=4, + pool_size=2, + stride=2, + pool_type=MaxPooling()) + +fc = fc_layer(input=pool, size=384, bias_attr=False) + +outputs(fc) \ No newline at end of file -- GitLab