提交 8295eb91 编写于 作者: G gangliao 提交者: GitHub

Merge pull request #287 from gangliao/bilinear

Add bilinear interpolation layer
......@@ -275,6 +275,12 @@ interpolation_layer
:members: interpolation_layer
:noindex:
bilinear_interp_layer
----------------------
.. automodule:: paddle.trainer_config_helpers.layers
:members: bilinear_interp_layer
:noindex:
power_layer
-----------
.. automodule:: paddle.trainer_config_helpers.layers
......
......@@ -240,6 +240,70 @@ extern void hl_CMRNorm_backward(
size_t channels, size_t height, size_t width, size_t sizeX,
real alpha, real beta);
/**
* @brief Bilinear interpolation forward.
*
* Resizes every channel plane of every sample from (inImgH x inImgW) to
* (outImgH x outImgW). Both matrices are row-major with one sample per row.
*
* @param[in] inData input value.
* @param[in] inImgH input image height.
* @param[in] inImgW input image width.
* @param[in] inputH input batchSize (number of rows of inData).
* @param[in] inputW input image data dim (row width of inData; the
* kernel derives the per-channel plane size as
* inputW / numChannels).
* @param[out] outData output value.
* @param[in] outImgH output image height.
* @param[in] outImgW output image width.
* @param[in] outputH output batchSize (number of rows of outData).
* @param[in] outputW output image data dim (row width of outData).
* @param[in] numChannels number of channels.
* @param[in] ratioH factor mapping an output row index to an input
* row index (inRow = ratioH * outRow).
* BilinearInterpLayer passes
* (inImgH - 1) / (outImgH - 1), or 0 when
* outImgH == 1.
* @param[in] ratioW factor mapping an output column index to an
* input column index; BilinearInterpLayer passes
* (inImgW - 1) / (outImgW - 1), or 0 when
* outImgW == 1.
*
*/
extern void hl_bilinear_forward(const real* inData,
const size_t inImgH,
const size_t inImgW,
const size_t inputH,
const size_t inputW,
real* outData,
const size_t outImgH,
const size_t outImgW,
const size_t outputH,
const size_t outputW,
const size_t numChannels,
const real ratioH,
const real ratioW);
/**
* @brief Bilinear interpolation backward.
*
* Scatters each output-gradient element onto the (up to four) input
* positions it was interpolated from, weighted by the interpolation
* coefficients. Contributions are added to the existing content of inGrad.
*
* @param[in,out] inGrad input gradient (accumulated into).
* @param[in] inImgH input image height.
* @param[in] inImgW input image width.
* @param[in] inputH input batchSize (number of rows of inGrad).
* @param[in] inputW input image data dim (row width of inGrad).
* @param[in] outGrad output gradient.
* @param[in] outImgH output image height.
* @param[in] outImgW output image width.
* @param[in] outputH output batchSize (number of rows of outGrad).
* @param[in] outputW output image data dim (row width of outGrad).
* @param[in] numChannels number of channels.
* @param[in] ratioH factor mapping an output row index to an
* input row index; BilinearInterpLayer passes
* (inImgH - 1) / (outImgH - 1), or 0 when
* outImgH == 1.
* @param[in] ratioW factor mapping an output column index to an
* input column index; BilinearInterpLayer passes
* (inImgW - 1) / (outImgW - 1), or 0 when
* outImgW == 1.
*
*/
extern void hl_bilinear_backward(real* inGrad,
const size_t inImgH,
const size_t inImgW,
const size_t inputH,
const size_t inputW,
const real* outGrad,
const size_t outImgH,
const size_t outImgW,
const size_t outputH,
const size_t outputW,
const size_t numChannels,
const real ratioH,
const real ratioW);
/**
* @brief MaxOut forward.
*
......
......@@ -89,6 +89,34 @@ inline void hl_CMRNorm_backward(
size_t channels, size_t height, size_t width, size_t sizeX,
real alpha, real beta) {}
// Empty inline definition of hl_bilinear_forward: accepts the full
// argument list of the real implementation and performs no work.
// NOTE(review): presumably the placeholder used by builds without GPU
// support, like the neighbouring empty hl_* definitions -- confirm.
inline void hl_bilinear_forward(const real* inData,
const size_t inImgH,
const size_t inImgW,
const size_t inputH,
const size_t inputW,
real* outData,
const size_t outImgH,
const size_t outImgW,
const size_t outputH,
const size_t outputW,
const size_t numChannels,
const real ratioH,
const real ratioW) {}
// Empty inline definition of hl_bilinear_backward: accepts the full
// argument list of the real implementation and performs no work
// (inGrad is left untouched).
// NOTE(review): presumably the placeholder used by builds without GPU
// support, like the neighbouring empty hl_* definitions -- confirm.
inline void hl_bilinear_backward(real* inGrad,
const size_t inImgH,
const size_t inImgW,
const size_t inputH,
const size_t inputW,
const real* outGrad,
const size_t outImgH,
const size_t outImgW,
const size_t outputH,
const size_t outputW,
const size_t numChannels,
const real ratioH,
const real ratioW) {}
inline void hl_maxout_forward(
const real* inData, real* outData, int* idData,
size_t batchSize, size_t size, size_t featLen, size_t group) {}
......
......@@ -522,7 +522,7 @@ void hl_CMRNorm_backward(size_t frameCnt, const real* inV,
size_t height, size_t width, size_t sizeX,
real alpha, real beta) {
size_t threadsNum = frameCnt * height * width;
size_t blocksX = (threadsNum + 1024 -1) / 1024;
size_t blocksX = (threadsNum + 1024 - 1) / 1024;
size_t blocksY = 1;
dim3 threads(1024, 1);
dim3 grid(blocksX, blocksY);
......@@ -532,6 +532,138 @@ void hl_CMRNorm_backward(size_t frameCnt, const real* inV,
CHECK_SYNC("hl_CMRNorm_backward");
}
/**
 * Forward bilinear-interpolation kernel: one thread computes one element of
 * the (outputH x outputW) output matrix.
 *
 * The flat thread index is decomposed into (sample row, channel, output y,
 * output x). The matching input position is obtained by scaling the output
 * coordinates with ratioH / ratioW; the fractional parts are the blending
 * weights of the four neighbouring input pixels.
 *
 * Expects a 1-D launch covering at least outputH * outputW threads; surplus
 * threads exit immediately.
 */
__global__ void KeBilinearInterpFw(const real* in,
                                   const size_t inImgH,
                                   const size_t inImgW,
                                   const size_t inputH,
                                   const size_t inputW,
                                   real* out,
                                   const size_t outImgH,
                                   const size_t outImgW,
                                   const size_t outputH,
                                   const size_t outputW,
                                   const size_t numChannels,
                                   const real ratioH,
                                   const real ratioW) {
  const int totalElems = outputH * outputW;
  const int gid = blockIdx.x * blockDim.x + threadIdx.x;
  if (gid >= totalElems) {
    return;  // tail threads of the last block have nothing to do
  }

  // Row of the matrix (sample) and column inside that row.
  const int sample = gid / outputW;
  const int col = gid % outputW;

  // Elements per channel plane on the input and output side.
  const int inPlane = inputW / numChannels;
  const int outPlane = outputW / numChannels;
  const int chan = col / outPlane;

  // Vertical source coordinate: integer part selects the upper row, the
  // fractional part is the weight of the row below it.
  const int outY = (col % outPlane) / outImgW;
  const int inY = ratioH * outY;
  // 0 on the last input row so that the "row below" aliases the same row.
  const int stepY = (inY < inImgH - 1) ? 1 : 0;
  const real fracY = ratioH * outY - inY;
  const real remY = 1.f - fracY;

  // Horizontal source coordinate, same scheme.
  const int outX = gid % outImgW;
  const int inX = ratioW * outX;
  const int stepX = (inX < inImgW - 1) ? 1 : 0;
  const real fracX = ratioW * outX - inX;
  const real remX = 1.f - fracX;

  // Top-left pixel of the 2x2 neighbourhood.
  const real* src =
      &in[sample * inputW + chan * inPlane + inY * inImgW + inX];
  const real topLeft = src[0];
  const real topRight = src[stepX];
  const real bottomLeft = src[stepY * inImgW];
  const real bottomRight = src[stepY * inImgW + stepX];

  // bilinear interpolation
  out[sample * outputW + col] =
      remY * (remX * topLeft + fracX * topRight) +
      fracY * (remX * bottomLeft + fracX * bottomRight);
}
/**
 * Host-side launcher of KeBilinearInterpFw.
 *
 * Launches one thread per output element (outputH * outputW in total) in
 * 1-D blocks of 1024 threads on STREAM_DEFAULT, then runs CHECK_SYNC.
 * All arguments are forwarded unchanged to the kernel.
 */
void hl_bilinear_forward(const real* inData,
                         const size_t inImgH,
                         const size_t inImgW,
                         const size_t inputH,
                         const size_t inputW,
                         real* outData,
                         const size_t outImgH,
                         const size_t outImgW,
                         const size_t outputH,
                         const size_t outputW,
                         const size_t numChannels,
                         const real ratioH,
                         const real ratioW) {
  const int threadsPerBlock = 1024;
  const int totalElems = outputH * outputW;
  // Round up so the final, partially filled block is still launched.
  const int numBlocks = (totalElems + threadsPerBlock - 1) / threadsPerBlock;

  KeBilinearInterpFw<<<numBlocks, threadsPerBlock, 0, STREAM_DEFAULT>>>(
      inData, inImgH, inImgW, inputH, inputW,
      outData, outImgH, outImgW, outputH, outputW,
      numChannels, ratioH, ratioW);
  CHECK_SYNC("hl_bilinear_forward failed");
}
/**
 * Backward bilinear-interpolation kernel: one thread handles one element of
 * the (outputH x outputW) output-gradient matrix and adds its weighted
 * contribution to the four input-gradient positions it was interpolated
 * from.
 *
 * Several output elements can map onto the same input element, so the
 * updates use atomicAdd. Contributions are accumulated on top of whatever
 * `in` already holds.
 *
 * Expects a 1-D launch covering at least outputH * outputW threads.
 * NOTE(review): atomicAdd is applied to `real`; if `real` can be double this
 * needs a double-precision atomicAdd on the target architecture -- confirm.
 */
__global__ void KeBilinearInterpBw(real* in,
                                   const size_t inImgH,
                                   const size_t inImgW,
                                   const size_t inputH,
                                   const size_t inputW,
                                   const real* out,
                                   const size_t outImgH,
                                   const size_t outImgW,
                                   const size_t outputH,
                                   const size_t outputW,
                                   const size_t numChannels,
                                   const real ratioH,
                                   const real ratioW) {
  const int totalElems = outputH * outputW;
  const int gid = blockIdx.x * blockDim.x + threadIdx.x;
  if (gid >= totalElems) {
    return;  // tail threads of the last block have nothing to do
  }

  // Row of the matrix (sample) and column inside that row.
  const int sample = gid / outputW;
  const int col = gid % outputW;

  // Elements per channel plane on the input and output side.
  const int inPlane = inputW / numChannels;
  const int outPlane = outputW / numChannels;
  const int chan = col / outPlane;

  // Vertical source coordinate and blending weights.
  const int outY = (col % outPlane) / outImgW;
  const int inY = ratioH * outY;
  // 0 on the last input row so that the "row below" aliases the same row.
  const int stepY = (inY < inImgH - 1) ? 1 : 0;
  const real fracY = ratioH * outY - inY;
  const real remY = 1.f - fracY;

  // Horizontal source coordinate and blending weights.
  const int outX = gid % outImgW;
  const int inX = ratioW * outX;
  const int stepX = (inX < inImgW - 1) ? 1 : 0;
  const real fracX = ratioW * outX - inX;
  const real remX = 1.f - fracX;

  // Top-left element of the 2x2 neighbourhood in the input gradient.
  real* dst = &in[sample * inputW + chan * inPlane + inY * inImgW + inX];
  const real grad = out[sample * outputW + col];

  atomicAdd(&dst[0], remY * remX * grad);
  atomicAdd(&dst[stepX], remY * fracX * grad);
  atomicAdd(&dst[stepY * inImgW], fracY * remX * grad);
  atomicAdd(&dst[stepY * inImgW + stepX], fracY * fracX * grad);
}
/**
 * Host-side launcher of KeBilinearInterpBw.
 *
 * Launches one thread per output-gradient element (outputH * outputW in
 * total) in 1-D blocks of 1024 threads on STREAM_DEFAULT, then runs
 * CHECK_SYNC. All arguments are forwarded unchanged to the kernel, which
 * accumulates into inGrad.
 */
void hl_bilinear_backward(real* inGrad,
                          const size_t inImgH,
                          const size_t inImgW,
                          const size_t inputH,
                          const size_t inputW,
                          const real* outGrad,
                          const size_t outImgH,
                          const size_t outImgW,
                          const size_t outputH,
                          const size_t outputW,
                          const size_t numChannels,
                          const real ratioH,
                          const real ratioW) {
  const int threadsPerBlock = 1024;
  const int totalElems = outputH * outputW;
  // Round up so the final, partially filled block is still launched.
  const int numBlocks = (totalElems + threadsPerBlock - 1) / threadsPerBlock;

  KeBilinearInterpBw<<<numBlocks, threadsPerBlock, 0, STREAM_DEFAULT>>>(
      inGrad, inImgH, inImgW, inputH, inputW,
      outGrad, outImgH, outImgW, outputH, outputW,
      numChannels, ratioH, ratioW);
  CHECK_SYNC("hl_bilinear_backward failed");
}
__global__ void maxoutFpCompute(size_t nthreads, const real * inData,
real * outData, int* idData,
size_t size, size_t featLen, size_t groups) {
......
/* Copyright (c) 2016 Baidu, Inc. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "BilinearInterpLayer.h"
#include "paddle/utils/Logging.h"
#include "paddle/utils/Stat.h"
namespace paddle {
REGISTER_LAYER(bilinear_interp, BilinearInterpLayer);
/**
 * Refreshes the cached geometry (input/output image size, channel count and
 * the two scaling ratios), stamps the output frame size and returns the
 * output row width outImgH * outImgW * numChannels.
 *
 * The input image size is taken from the input layer's output frame; a
 * dimension reported as 0 falls back to the value in the layer config.
 */
size_t BilinearInterpLayer::getSize() {
  const BilinearInterpConfig& conf = config_.inputs(0).bilinear_interp_conf();
  auto& inputArg = inputLayers_[0]->getOutput();

  // Prefer the frame size carried by the input; use the config as fallback.
  inImgH_ = inputArg.getFrameHeight();
  if (!inImgH_) {
    inImgH_ = conf.img_size_y();
  }
  inImgW_ = inputArg.getFrameWidth();
  if (!inImgW_) {
    inImgW_ = conf.img_size_x();
  }

  outImgH_ = conf.out_size_y();
  outImgW_ = conf.out_size_x();
  numChannels_ = conf.num_channels();

  CHECK(outImgH_ > 0 && outImgW_ > 0);
  CHECK(inImgH_ > 0 && inImgW_ > 0);
  CHECK(numChannels_);

  // Map the last output index onto the last input index; a one-pixel output
  // has no extent to scale, so its ratio is 0.
  if (outImgH_ > 1) {
    ratioH_ = static_cast<real>(inImgH_ - 1) / (outImgH_ - 1);
  } else {
    ratioH_ = 0.f;
  }
  if (outImgW_ > 1) {
    ratioW_ = static_cast<real>(inImgW_ - 1) / (outImgW_ - 1);
  } else {
    ratioW_ = 0.f;
  }

  auto& outputArg = getOutput();
  outputArg.setFrameHeight(outImgH_);
  outputArg.setFrameWidth(outImgW_);

  return outImgH_ * outImgW_ * numChannels_;
}
/**
 * Initializes the layer.
 *
 * Runs the base-class initialization first and propagates its failure
 * instead of discarding the result, then verifies that the layer was
 * configured with exactly one input.
 *
 * @return false if Layer::init fails, true otherwise.
 */
bool BilinearInterpLayer::init(const LayerMap& layerMap,
                               const ParameterMap& parameterMap) {
  /* Initialize the basic parent class */
  if (!Layer::init(layerMap, parameterMap)) {
    return false;
  }

  // Bilinear interpolation resizes a single feature map.
  CHECK_EQ(1, config_.inputs_size());

  return true;
}
/**
 * Forward pass: sizes the output to (batchSize x getSize()) and fills it
 * with the bilinearly resized input. getSize() also refreshes the cached
 * image geometry and ratios used below.
 */
void BilinearInterpLayer::forward(PassType passType) {
  Layer::forward(passType);

  const size_t numSamples = getInput(0).getBatchSize();
  const size_t outWidth = getSize();

  {
    REGISTER_TIMER_INFO("FwResetTimer", getName().c_str());
    resetOutput(numSamples, outWidth);
  }

  MatrixPtr inputValue = getInputValue(0);
  MatrixPtr outputValue = getOutputValue();

  {
    REGISTER_TIMER_INFO("FwBilinearInterpTimer", getName().c_str());
    outputValue->bilinearForward(*inputValue,
                                 inImgH_, inImgW_,
                                 outImgH_, outImgW_,
                                 numChannels_,
                                 ratioH_, ratioW_);
  }
}
void BilinearInterpLayer::backward(const UpdateCallback& callback) {
(void) callback;
MatrixPtr inputG = getInputGrad(0);
MatrixPtr outG = getOutputGrad();
{
REGISTER_TIMER_INFO("BwBilinearInterpTimer", getName().c_str());
if (inputG) {
inputG->bilinearBackward(*outG, outImgH_, outImgW_, inImgH_, inImgW_,
numChannels_, ratioH_, ratioW_);
}
}
}
} // namespace paddle
/* Copyright (c) 2016 Baidu, Inc. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "Layer.h"
#include "paddle/math/Matrix.h"
namespace paddle {
/**
* @brief A layer for bilinear interpolation which is
* used on conv layer output.
*
* Resizes each channel plane of its single input from
* (inImgH_ x inImgW_) to (outImgH_ x outImgW_).
*
* @note The config file api is bilinear_interp_layer.
*/
class BilinearInterpLayer : public Layer {
protected:
// Output image height / width, read from the layer config in getSize().
size_t outImgH_, outImgW_;
// Input image height / width, taken from the input frame size or, when
// that is 0, from the layer config (see getSize()).
size_t inImgH_, inImgW_;
// Output-to-input index scale factors: (in - 1) / (out - 1), or 0 when the
// output extent is 1.
real ratioH_, ratioW_;
// Number of channel planes per sample.
size_t numChannels_;
public:
explicit BilinearInterpLayer(const LayerConfig& config) : Layer(config) {}
virtual ~BilinearInterpLayer() {}
// Recomputes the cached geometry above, sets the output frame size and
// returns outImgH_ * outImgW_ * numChannels_.
size_t getSize();
// Initializes the base layer and checks that there is exactly one input.
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap);
// Resizes the input value into the output value.
void forward(PassType passType);
// Accumulates the back-propagated gradient into the input gradient;
// `callback` is not used.
void backward(const UpdateCallback& callback = nullptr);
};
} // namespace paddle
......@@ -175,6 +175,27 @@ TEST(Projection, conv) {
}
#endif
// Gradient check of the bilinear_interp layer on CPU and GPU, for an output
// of the same size as the input (32) and for a 2x upsampled output (64).
TEST(Layer, BilinearInterpLayer) {
TestConfig config;
config.layerConfig.set_type("bilinear_interp");
config.biasSize = 0;
// One data input of width 4096 = 32 x 32 pixels x 4 channels.
config.inputDefs.push_back({INPUT_DATA, "layer_0", 4096, 0});
LayerInputConfig* input = config.layerConfig.add_inputs();
BilinearInterpConfig* bilinear = input->mutable_bilinear_interp_conf();
bilinear->set_img_size_x(32);
bilinear->set_img_size_y(32);
bilinear->set_num_channels(4);
for (auto useGpu : {false, true}) {
for (auto outSize : {32, 64}) {
// The same config object is reused; only the output size changes.
bilinear->set_out_size_x(outSize);
bilinear->set_out_size_y(outSize);
// NOTE(review): 10 is presumably the batch size and `false` the
// transposed-weight flag of testLayerGrad -- confirm against its
// declaration.
testLayerGrad(config, "bilinear_interp", 10, false, useGpu);
}
}
}
TEST(Layer, concat) {
TestConfig config;
config.biasSize = 0;
......
......@@ -23,6 +23,7 @@ limitations under the License. */
#include "paddle/utils/Logging.h"
#include <string.h>
#include "hl_cnn.h"
#include "hl_gpu.h"
#include "hl_table_apply.h"
#include "hl_top_k.h"
......@@ -1231,6 +1232,62 @@ void GpuMatrix::addColumnVector(const Matrix& b) {
BaseMatrix::addColVector(const_cast<Matrix&>(b));
}
/**
 * GPU bilinear resize of `in` into this matrix.
 *
 * Each row of `in` holds numChannels planes of (inImgH x inImgW) values and
 * each row of this matrix holds numChannels planes of (outImgH x outImgW)
 * values. When the input and output image sizes are identical the data is
 * simply copied; otherwise hl_bilinear_forward performs the interpolation.
 *
 * @param in          input matrix, must be a GpuMatrix.
 * @param inImgH      input image height.
 * @param inImgW      input image width.
 * @param outImgH     output image height.
 * @param outImgW     output image width.
 * @param numChannels number of channel planes per row.
 * @param ratioH      output-to-input row index scale factor.
 * @param ratioW      output-to-input column index scale factor.
 */
void GpuMatrix::bilinearForward(const Matrix& in,
                                const size_t inImgH,
                                const size_t inImgW,
                                const size_t outImgH,
                                const size_t outImgW,
                                const size_t numChannels,
                                const real ratioH,
                                const real ratioW) {
  CHECK(dynamic_cast<const GpuMatrix*>(&in));

  const size_t outputW = getWidth();
  const size_t outputH = getHeight();
  const size_t inputW = in.getWidth();
  const size_t inputH = in.getHeight();

  real* outData = getData();
  const real* inData = in.getData();

  // Heights must be compared with heights: the previous test compared
  // inImgH against outImgW, which took the copy path for non-square
  // resizes such as 64x32 -> 128x64 and skipped it for genuine identity
  // resizes of non-square images.
  if (inImgH == outImgH && inImgW == outImgW) {
    this->copyFrom(in);
  } else {
    hl_bilinear_forward(
      inData, inImgH, inImgW, inputH, inputW, outData,
      outImgH, outImgW, outputH, outputW, numChannels,
      ratioH, ratioW);
  }
}
void GpuMatrix::bilinearBackward(const Matrix& out,
const size_t outImgH,
const size_t outImgW,
const size_t inImgH,
const size_t inImgW,
const size_t numChannels,
const real ratioH,
const real ratioW) {
CHECK(dynamic_cast<const GpuMatrix*>(&out));
const size_t inputW = getWidth();
const size_t inputH = getHeight();
const size_t outputW = out.getWidth();
const size_t outputH = out.getHeight();
real* inGrad = getData();
const real* outGrad = out.getData();
if (outImgH == inImgH && outImgW == inImgW) {
this->add(const_cast<Matrix&>(out));
} else {
hl_bilinear_backward(
inGrad, inImgH, inImgW, inputH, inputW, outGrad,
outImgH, outImgW, outputH, outputW, numChannels,
ratioH, ratioW);
}
}
/**
* CpuMatrix
*/
......@@ -3841,6 +3898,112 @@ void CpuMatrix::classificationErrorMulti(Matrix& output, Matrix& label,
}
}
void CpuMatrix::bilinearForward(const Matrix& in,
const size_t inImgH,
const size_t inImgW,
const size_t outImgH,
const size_t outImgW,
const size_t numChannels,
const real ratioH,
const real ratioW) {
CHECK(dynamic_cast<const CpuMatrix*>(&in));
size_t outputW = getWidth();
size_t batchSize = getHeight();
size_t inputW = in.getWidth();
size_t inputH = in.getHeight();
size_t inPosOffset = inImgH * inImgW;
size_t outPosOffset = outImgH * outImgW;
(void)(inputH);
real* outData = getData();
const real* inData = in.getData();
if (inImgH == outImgH && inImgW == outImgW) {
this->copyFrom(in);
} else {
for (size_t k = 0; k < batchSize; ++k) { // loop for batches
for (size_t i = 0; i < outImgH; ++i) { // loop for images
size_t h = ratioH * i;
size_t hid = (h < inImgH - 1) ? 1 : 0;
real h1lambda = ratioH * i - h;
real h2lambda = 1 - h1lambda;
for (size_t j = 0; j < outImgW; ++j) {
size_t w = ratioW * j;
size_t wid = (w < inImgW - 1) ? 1 : 0;
real w1lambda = ratioW * j - w;
real w2lambda = 1 - w1lambda;
// calculate four position for bilinear interpolation
const real* inPos = &inData[k * inputW + h * inImgW + w];
real* outPos = &outData[k * outputW + i * outImgW + j];
for (size_t c = 0; c < numChannels; ++c) { // loop for channels
// bilinear interpolation
outPos[0] =
h2lambda * (w2lambda * inPos[0] + w1lambda * inPos[wid]) +
h1lambda * (w2lambda * inPos[hid * inImgW] +
w1lambda * inPos[hid * inImgW + wid]);
inPos += inPosOffset;
outPos += outPosOffset;
}
}
}
}
}
}
void CpuMatrix::bilinearBackward(const Matrix& out,
const size_t outImgH,
const size_t outImgW,
const size_t inImgH,
const size_t inImgW,
const size_t numChannels,
const real ratioH,
const real ratioW) {
CHECK(dynamic_cast<const CpuMatrix*>(&out));
size_t inputW = getWidth();
size_t inputH = getHeight();
size_t outputW = out.getWidth();
size_t batchSize = out.getHeight();
size_t inPosOffset = inImgH * inImgW;
size_t outPosOffset = outImgH * outImgW;
(void)(inputH);
real* inGrad = getData();
const real* outGrad = out.getData();
if (inImgH == outImgH && inImgW == outImgW) {
this->add(const_cast<Matrix&>(out));
} else {
for (size_t k = 0; k < batchSize; ++k) { // loop for batches
for (size_t i = 0; i < outImgH; ++i) { // loop for images
size_t h = ratioH * i;
size_t hid = (h < inImgH - 1) ? 1 : 0;
real h1lambda = ratioH * i - h;
real h2lambda = 1 - h1lambda;
for (size_t j = 0; j < outImgW; ++j) {
size_t w = ratioW * j;
size_t wid = (w < inImgW - 1) ? 1 : 0;
real w1lambda = ratioW * j - w;
real w2lambda = 1 - w1lambda;
real* inPos = &inGrad[k * inputW + h * inImgW + w];
const real* outPos = &outGrad[k * outputW + i * outImgW + j];
for (size_t c = 0; c < numChannels; ++c) { // loop for channels
inPos[0] += h2lambda * w2lambda * outPos[0];
inPos[wid] += h2lambda * w1lambda * outPos[0];
inPos[hid * inImgW] += h1lambda * w2lambda * outPos[0];
inPos[hid * inImgW + wid] += h1lambda * w1lambda * outPos[0];
inPos += inPosOffset;
outPos += outPosOffset;
}
}
}
}
}
}
////////////////////////////////////////////////////////////////
// functions executed via cpu //
////////////////////////////////////////////////////////////////
......
......@@ -995,6 +995,26 @@ public:
virtual void paramReluBackwardDiff(Matrix& oGrad, Matrix& data, Matrix& W) {
LOG(FATAL) << "Not implemented";
}
// Bilinear resize of `in` (numChannels planes of inImgH x inImgW per row)
// into this matrix (planes of outImgH x outImgW). ratioH / ratioW scale an
// output index to an input index. The base implementation is a placeholder
// that aborts with "Not implemented"; CpuMatrix and GpuMatrix provide the
// real ones.
virtual void bilinearForward(const Matrix& in,
const size_t inImgH,
const size_t inImgW,
const size_t outImgH,
const size_t outImgW,
const size_t numChannels,
const real ratioH,
const real ratioW) {
LOG(FATAL) << "Not implemented";
}
// Backward pass of the bilinear resize: accumulates the output gradient
// `out` into this matrix, which holds the input gradient. Note that the
// output image size comes before the input image size in the argument
// list, the reverse of bilinearForward. The base implementation is a
// placeholder that aborts with "Not implemented".
virtual void bilinearBackward(const Matrix& out,
const size_t outImgH,
const size_t outImgW,
const size_t inImgH,
const size_t inImgW,
const size_t numChannels,
const real ratioH,
const real ratioW) {
LOG(FATAL) << "Not implemented";
}
};
inline std::ostream& operator<<(std::ostream& os, const Matrix& mat) {
......@@ -1265,6 +1285,24 @@ public:
int contextLength,
int contextStart, int totalPad,
size_t beginPad);
// GPU bilinear resize of `in` into this matrix; `in` must be a GpuMatrix.
// Dispatches to hl_bilinear_forward, or to a plain copy on the shortcut
// path.
void bilinearForward(const Matrix& in,
const size_t inImgH,
const size_t inImgW,
const size_t outImgH,
const size_t outImgW,
const size_t numChannels,
const real ratioH,
const real ratioW);
// GPU backward pass of the bilinear resize; `out` must be a GpuMatrix.
// Accumulates the output gradient into this matrix via
// hl_bilinear_backward, or via an element-wise add when the input and
// output image sizes are equal.
void bilinearBackward(const Matrix& out,
const size_t outImgH,
const size_t outImgW,
const size_t inImgH,
const size_t inImgW,
const size_t numChannels,
const real ratioH,
const real ratioW);
};
class CpuMatrix : public Matrix {
......@@ -1553,6 +1591,24 @@ public:
void multiBinaryLabelCrossEntropy(Matrix& output, Matrix& label);
void multiBinaryLabelCrossEntropyBp(Matrix& output, Matrix& label);
void classificationErrorMulti(Matrix& output, Matrix& label, real threshold);
// CPU bilinear resize of `in` into this matrix; `in` must be a CpuMatrix.
// Copies the data when the input and output image sizes are equal.
void bilinearForward(const Matrix& in,
const size_t inImgH,
const size_t inImgW,
const size_t outImgH,
const size_t outImgW,
const size_t numChannels,
const real ratioH,
const real ratioW);
// CPU backward pass of the bilinear resize; `out` must be a CpuMatrix.
// Accumulates the output gradient into this matrix (element-wise add when
// the input and output image sizes are equal).
void bilinearBackward(const Matrix& out,
const size_t outImgH,
const size_t outImgW,
const size_t inImgH,
const size_t inImgW,
const size_t numChannels,
const real ratioH,
const real ratioW);
};
class SharedCpuMatrix : public CpuMatrix {
......
......@@ -90,6 +90,73 @@ void MatrixCheckErr(const Matrix& matrix1, const Matrix& matrix2) {
EXPECT_EQ(count, 0) << "There are " << count << " different element.";
}
void testBilinearFwdBwd(int numSamples, int imgSizeH, int imgSizeW,
int channels) {
int inWidth = imgSizeH * imgSizeW * channels;
int outWidth = 2 * imgSizeH * 2 * imgSizeW * channels;
real ratioH = 0.5;
real ratioW = 0.5;
// forward
MatrixPtr input = CpuMatrix::create(numSamples, inWidth, false, false);
MatrixPtr inputGpu = GpuMatrix::create(numSamples, inWidth, false, true);
MatrixPtr target = CpuMatrix::create(numSamples, outWidth, false, false);
MatrixPtr targetGpu = GpuMatrix::create(numSamples, outWidth, false, true);
MatrixPtr targetCheck = CpuMatrix::create(numSamples, outWidth, false, false);
input->randomizeUniform();
inputGpu->copyFrom(*input);
target->bilinearForward(*input, imgSizeH, imgSizeW,
2 * imgSizeH, 2 * imgSizeW, channels, ratioH, ratioW);
targetGpu->bilinearForward(*inputGpu, imgSizeH, imgSizeW,
2 * imgSizeH, 2 * imgSizeW, channels, ratioH, ratioW);
// check
targetCheck->copyFrom(*targetGpu);
MatrixCheckErr(*target, *targetCheck);
// backward
MatrixPtr inputGrad = CpuMatrix::create(numSamples, inWidth, false, false);
MatrixPtr inputGpuGrad = GpuMatrix::create(numSamples, inWidth, false, true);
MatrixPtr targetGrad = CpuMatrix::create(numSamples, outWidth, false, false);
MatrixPtr targetGpuGrad = GpuMatrix::create(numSamples, outWidth, false,
true);
MatrixPtr targetCheckGrad =
CpuMatrix::create(numSamples, inWidth, false, false);
inputGrad->randomizeUniform();
targetGrad->randomizeUniform();
inputGpuGrad->copyFrom(*inputGrad);
targetGpuGrad->copyFrom(*targetGrad);
inputGrad->bilinearBackward(*targetGrad, 2 * imgSizeH, 2 * imgSizeW,
imgSizeH, imgSizeW, channels, ratioH, ratioW);
inputGpuGrad->bilinearBackward(*targetGpuGrad, 2 * imgSizeH, 2 * imgSizeW,
imgSizeH, imgSizeW, channels, ratioH, ratioW);
// check
targetCheckGrad->copyFrom(*inputGpuGrad);
MatrixCheckErr(*inputGrad, *targetCheckGrad);
}
// Sweeps testBilinearFwdBwd over every combination of batch size, channel
// count and (non-square) input image size listed below.
TEST(Matrix, BilinearFwdBwd) {
  for (auto batch : {5, 10}) {
    for (auto numChan : {8, 16}) {
      for (auto height : {14, 28}) {
        for (auto width : {16, 30}) {
          VLOG(3) << " numSamples=" << batch
                  << " channels=" << numChan
                  << " imgSizeH=" << height
                  << " imgSizeW=" << width;
          testBilinearFwdBwd(batch, height, width, numChan);
        }
      }
    }
  }
}
void testMatrixProjectionForward(int contextStart, int contextLength,
bool padding, int batchSize, int inputDim) {
MatrixPtr cpuInput = std::make_shared<CpuMatrix>(batchSize, inputDim);
......
......@@ -212,6 +212,15 @@ message OperatorConfig {
optional int32 num_filters = 7;
}
// Per-input configuration of the bilinear_interp layer.
message BilinearInterpConfig {
// The size of input feature map.
// Optional: BilinearInterpLayer::getSize() only falls back to these values
// when the input layer reports a frame height / width of 0.
optional uint32 img_size_x = 1;
optional uint32 img_size_y = 2;
// The size of output feature map.
// x is the width and y the height; both must be greater than 0.
required uint32 out_size_x = 3;
required uint32 out_size_y = 4;
// Number of channel planes in the input (and output) feature map.
required uint32 num_channels = 5;
}
message ImageConfig {
// The image data dimensionality.
......@@ -234,7 +243,8 @@ message LayerInputConfig {
// If the input layer has multi-output.
// Set the argument name.
optional string input_layer_argument = 9;
optional MaxOutConfig maxout_conf = 10;
optional BilinearInterpConfig bilinear_interp_conf = 10;
optional MaxOutConfig maxout_conf = 11;
}
message LayerConfig {
......
......@@ -465,6 +465,7 @@ class Input(Cfg):
sparse_update=None,
gradient_clipping_threshold=None,
conv=None,
bilinear_interp=None,
norm=None,
pool=None,
image=None,
......@@ -768,6 +769,16 @@ class Conv(Cfg):
if output_x is not None:
config_assert(output_x <= 0)
# please refer to the comments in proto/ModelConfig.proto
@config_class
class BilinearInterp(Cfg):
    """Input-level settings of the bilinear_interp layer.

    Stores the requested output width (out_size_x), output height
    (out_size_y) and the channel count as config keys.
    """

    def __init__(self, out_size_x=None, out_size_y=None, num_channels=None):
        # locals() is passed as-is, so no other local variable may be
        # introduced before this call.
        self.add_keys(locals())
# please refer to the comments in proto/ModelConfig.proto
@config_class
class Pool(Cfg):
......@@ -1008,6 +1019,11 @@ def TestData(data_config, async_load_data=None):
" Data definition")
g_config.test_data_config.async_load_data = async_load_data
def parse_bilinear(bilinear, input_layer_name, bilinear_conf):
    """Copy the bilinear interpolation settings into the proto config.

    :param bilinear: object exposing out_size_x, out_size_y, num_channels.
    :param input_layer_name: name of the input layer (not used here).
    :param bilinear_conf: config object that receives the three fields.
    """
    for field in ('out_size_x', 'out_size_y', 'num_channels'):
        setattr(bilinear_conf, field, getattr(bilinear, field))
'''
caffe_mode: compute the output size using floor instead of ceil,
which is consistent of caffe and CuDNN's convention.
......@@ -2470,6 +2486,22 @@ class InterpolationLayer(LayerBase):
config_assert(input_layer1.size == input_layer2.size,
'the two vector inputs should be of the same size')
@config_layer('bilinear_interp')
class BilinearInterpLayer(LayerBase):
    """Config-side description of the bilinear_interp layer.

    Fills the input's bilinear_interp_conf from the Input settings and sets
    the layer size to out_size_x * out_size_y * num_channels.
    """

    def __init__(self, name, inputs, **xargs):
        # Size 0 is a placeholder; the real size is set at the end.
        super(BilinearInterpLayer, self).__init__(
            name, 'bilinear_interp', 0, inputs=inputs, **xargs)

        source_layer = self.get_input_layer(0)
        bilinear = self.inputs[0].bilinear_interp
        parse_bilinear(bilinear,
                       source_layer.name,
                       self.config.inputs[0].bilinear_interp_conf)

        self.set_layer_size(
            bilinear.out_size_x * bilinear.out_size_y * bilinear.num_channels)
@config_layer('sum_to_one_norm')
class SumToOneNormLayer(LayerBase):
def __init__(
......
......@@ -40,8 +40,8 @@ __all__ = ["full_matrix_projection", "AggregateLevel", "ExpandLevel",
'img_cmrnorm_layer', 'addto_layer',
'concat_layer', 'lstm_step_layer', 'recurrent_group',
'memory', 'StaticInput', 'expand_layer', 'scaling_layer',
'power_layer', 'interpolation_layer', 'trans_layer',
'sum_to_one_norm_layer',
'power_layer', 'interpolation_layer', 'bilinear_interp_layer',
'trans_layer', 'sum_to_one_norm_layer',
'get_output_layer', 'LayerType', 'context_projection',
'beam_search', 'maxid_layer', 'GeneratedInput', 'SubsequenceInput',
'gru_step_layer', 'recurrent_layer',
......@@ -94,6 +94,7 @@ class LayerType(object):
EXPAND_LAYER = 'expand'
INTERPOLATION_LAYER = 'interpolation'
BILINEAR_INTERP_LAYER = 'bilinear_interp'
POWER_LAYER = 'power'
SCALING_LAYER = 'scaling'
TRANS_LAYER = 'trans'
......@@ -1259,6 +1260,52 @@ def interpolation_layer(input, weight, name=None, layer_attr=None):
size=input[0].size)
@wrap_name_default()
@layer_support()
def bilinear_interp_layer(input,
                          out_size_x=None,
                          out_size_y=None,
                          name=None,
                          layer_attr=None):
    """
    This layer implements bilinear interpolation on a conv layer output.

    Please refer to Wikipedia: https://en.wikipedia.org/wiki/Bilinear_interpolation

    The simple usage is:

    .. code-block:: python

       bilinear = bilinear_interp_layer(input=layer1, out_size_x=64, out_size_y=64)

    :param input: The input layer. It must be a conv layer with linear
                  activation and a known number of filters.
    :type input: LayerOutput.
    :param out_size_x: bilinear interpolation output width. Required, must
                       be greater than 0.
    :type out_size_x: int|None
    :param out_size_y: bilinear interpolation output height. Required, must
                       be greater than 0.
    :type out_size_y: int|None
    :param name: The layer's name, which can be left unspecified.
    :type name: None|basestring
    :param layer_attr: Extra Layer attribute.
    :type layer_attr: ExtraLayerAttribute
    :return: LayerOutput object.
    :rtype: LayerOutput
    """
    assert input.layer_type == LayerType.CONV_LAYER
    assert isinstance(input.activation, LinearActivation)
    # The None defaults only keep the keyword-style signature; both sizes
    # have to be supplied. Check that explicitly instead of relying on how
    # None compares with an integer.
    assert out_size_x is not None and out_size_y is not None
    assert out_size_x > 0 and out_size_y > 0
    assert input.num_filters is not None
    num_channels = input.num_filters
    Layer(name=name,
          inputs=Input(input.name,
                       bilinear_interp=BilinearInterp(out_size_x=out_size_x,
                                                      out_size_y=out_size_y,
                                                      num_channels=num_channels)),
          type=LayerType.BILINEAR_INTERP_LAYER,
          **ExtraLayerAttribute.to_kwargs(layer_attr))
    # Report the output size as well, so that layers stacked on top can read
    # it; it equals the size the config parser assigns to this layer.
    return LayerOutput(name, LayerType.BILINEAR_INTERP_LAYER, parents=[input],
                       num_filters=num_channels,
                       size=out_size_x * out_size_y * num_channels)
@wrap_name_default()
@layer_support()
def power_layer(input, weight, name=None, layer_attr=None):
......
......@@ -11,7 +11,7 @@ test_sequence_pooling test_lstmemory_layer test_grumemory_layer
last_first_seq test_expand_layer test_ntm_layers test_hsigmoid
img_layers img_trans_layers util_layers simple_rnn_layers unused_layers test_cost_layers
test_rnn_group shared_fc shared_lstm test_cost_layers_with_weight
test_maxout test_bi_grumemory math_ops)
test_bilinear_interp test_maxout test_bi_grumemory math_ops)
for conf in ${configs[*]}
......
type: "nn"
layers {
name: "data"
type: "data"
size: 2304
active_type: ""
}
layers {
name: "__conv_0__"
type: "exconv"
size: 36864
active_type: ""
inputs {
input_layer_name: "data"
input_parameter_name: "___conv_0__.w0"
conv_conf {
filter_size: 3
channels: 1
stride: 1
padding: 1
groups: 1
filter_channels: 1
output_x: 48
img_size: 48
caffe_mode: true
filter_size_y: 3
padding_y: 1
stride_y: 1
}
}
bias_parameter_name: "___conv_0__.wbias"
num_filters: 16
shared_biases: true
}
layers {
name: "__bilinear_interp_layer_0__"
type: "bilinear_interp"
size: 65536
active_type: ""
inputs {
input_layer_name: "__conv_0__"
bilinear_interp_conf {
out_size_x: 64
out_size_y: 64
num_channels: 16
}
}
}
layers {
name: "__pool_0__"
type: "pool"
size: 16384
active_type: ""
inputs {
input_layer_name: "__bilinear_interp_layer_0__"
pool_conf {
pool_type: "max-projection"
channels: 4
size_x: 2
stride: 2
output_x: 64
img_size: 128
padding: 0
size_y: 2
stride_y: 2
output_y: 64
img_size_y: 128
padding_y: 0
}
}
}
layers {
name: "__fc_layer_0__"
type: "fc"
size: 384
active_type: "tanh"
inputs {
input_layer_name: "__pool_0__"
input_parameter_name: "___fc_layer_0__.w0"
}
}
parameters {
name: "___conv_0__.w0"
size: 144
initial_mean: 0.0
initial_std: 0.471404520791
initial_strategy: 0
initial_smart: false
}
parameters {
name: "___conv_0__.wbias"
size: 16
initial_mean: 0.0
initial_std: 0.0
dims: 16
dims: 1
initial_strategy: 0
initial_smart: false
}
parameters {
name: "___fc_layer_0__.w0"
size: 6291456
initial_mean: 0.0
initial_std: 0.0078125
dims: 16384
dims: 384
initial_strategy: 0
initial_smart: true
}
input_layer_names: "data"
output_layer_names: "__fc_layer_0__"
sub_models {
name: "root"
layer_names: "data"
layer_names: "__conv_0__"
layer_names: "__bilinear_interp_layer_0__"
layer_names: "__pool_0__"
layer_names: "__fc_layer_0__"
input_layer_names: "data"
output_layer_names: "__fc_layer_0__"
is_recurrent_layer_group: false
}
# Config used to generate the expected protostr of the bilinear_interp
# layer: data -> conv -> bilinear_interp -> max pool -> fc.
from paddle.trainer_config_helpers import *
settings(
batch_size=1000,
learning_rate=1e-5
)
# 2304 input values; the generated config treats them as a 48 x 48 image
# with one channel.
data = data_layer(name='data', size=2304)
# bilinear_interp_layer requires a conv input with linear activation.
conv = img_conv_layer(input=data,
filter_size = 3,
num_channels=1,
num_filters=16,
padding=1,
act=LinearActivation(),
bias_attr=True)
# Resize every conv feature map to 64 x 64.
bilinear = bilinear_interp_layer(input=conv,
out_size_x=64,
out_size_y=64)
# NOTE(review): the bilinear layer carries 16 channels (the conv's
# num_filters) but the pool is declared with num_channels=4, which makes the
# generated config describe a 128 x 128 image -- confirm this is intended.
pool = img_pool_layer(input=bilinear,
num_channels=4,
pool_size=2,
stride=2,
pool_type=MaxPooling())
fc = fc_layer(input=pool, size=384, bias_attr=False)
outputs(fc)
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册