提交 ddfff3a7 编写于 作者: L liaogang

Add bilinear interpolation layer

上级 86bb5ef1
...@@ -263,6 +263,12 @@ interpolation_layer ...@@ -263,6 +263,12 @@ interpolation_layer
:members: interpolation_layer :members: interpolation_layer
:noindex: :noindex:
bilinear_interp_layer
-------------------
.. automodule:: paddle.trainer_config_helpers.layers
:members: bilinear_interp_layer
:noindex:
power_layer power_layer
----------- -----------
.. automodule:: paddle.trainer_config_helpers.layers .. automodule:: paddle.trainer_config_helpers.layers
......
...@@ -240,4 +240,60 @@ extern void hl_CMRNorm_backward( ...@@ -240,4 +240,60 @@ extern void hl_CMRNorm_backward(
size_t channels, size_t height, size_t width, size_t sizeX, size_t channels, size_t height, size_t width, size_t sizeX,
real alpha, real beta); real alpha, real beta);
/**
* @brief Bilinear interpolation forward.
*
* @param[in] inData input value.
* @param[in] inImgH input image height.
* @param[in] inImgW input image width.
* @param[in] inputH input batchSize.
* @param[in] inputW input image data dim.
* @param[out] outData output value.
* @param[in] outImgH output image height.
* @param[in] outImgW output image width.
* @param[in] outputH output batchSize.
* @param[in] outputW output image data dim.
* @param[in] numChannels number of channels.
*
*/
extern void hl_bilinear_forward(const real* inData,
const size_t inImgH,
const size_t inImgW,
const size_t inputH,
const size_t inputW,
real* outData,
const size_t outImgH,
const size_t outImgW,
const size_t outputH,
const size_t outputW,
const size_t numChannels);
/**
* @brief Bilinear interpolation backward.
*
* @param[out] inGrad input gradient.
* @param[in] inImgH input image height.
* @param[in] inImgW input image width.
* @param[in] inputH input batchSize.
* @param[in] inputW input image data dim.
* @param[in] outGrad output gradient.
* @param[in] outImgH output image height.
* @param[in] outImgW output image width.
* @param[in] outputH output batchSize.
* @param[in] outputW output image data dim.
* @param[in] numChannels number of channels.
*
*/
extern void hl_bilinear_backward(real* inGrad,
const size_t inImgH,
const size_t inImgW,
const size_t inputH,
const size_t inputW,
const real* outGrad,
const size_t outImgH,
const size_t outImgW,
const size_t outputH,
const size_t outputW,
const size_t numChannels);
#endif /* HL_CNN_H_ */ #endif /* HL_CNN_H_ */
...@@ -89,4 +89,28 @@ inline void hl_CMRNorm_backward( ...@@ -89,4 +89,28 @@ inline void hl_CMRNorm_backward(
size_t channels, size_t height, size_t width, size_t sizeX, size_t channels, size_t height, size_t width, size_t sizeX,
real alpha, real beta) {} real alpha, real beta) {}
inline void hl_bilinear_forward(const real* inData,
const size_t inImgH,
const size_t inImgW,
const size_t inputH,
const size_t inputW,
real* outData,
const size_t outImgH,
const size_t outImgW,
const size_t outputH,
const size_t outputW,
const size_t numChannels) {}
inline void hl_bilinear_backward(real* inGrad,
const size_t inImgH,
const size_t inImgW,
const size_t inputH,
const size_t inputW,
const real* outGrad,
const size_t outImgH,
const size_t outImgW,
const size_t outputH,
const size_t outputW,
const size_t numChannels) {}
#endif // HL_CNN_STUB_H_ #endif // HL_CNN_STUB_H_
...@@ -522,7 +522,7 @@ void hl_CMRNorm_backward(size_t frameCnt, const real* inV, ...@@ -522,7 +522,7 @@ void hl_CMRNorm_backward(size_t frameCnt, const real* inV,
size_t height, size_t width, size_t sizeX, size_t height, size_t width, size_t sizeX,
real alpha, real beta) { real alpha, real beta) {
size_t threadsNum = frameCnt * height * width; size_t threadsNum = frameCnt * height * width;
size_t blocksX = (threadsNum + 1024 -1) / 1024; size_t blocksX = (threadsNum + 1024 - 1) / 1024;
size_t blocksY = 1; size_t blocksY = 1;
dim3 threads(1024, 1); dim3 threads(1024, 1);
dim3 grid(blocksX, blocksY); dim3 grid(blocksX, blocksY);
...@@ -531,3 +531,135 @@ void hl_CMRNorm_backward(size_t frameCnt, const real* inV, ...@@ -531,3 +531,135 @@ void hl_CMRNorm_backward(size_t frameCnt, const real* inV,
height, width, sizeX, alpha, beta, inDiff); height, width, sizeX, alpha, beta, inDiff);
CHECK_SYNC("hl_CMRNorm_backward"); CHECK_SYNC("hl_CMRNorm_backward");
} }
__global__ void KeBilinearInterpFw(const size_t nthreads,
const real* in,
const size_t inImgH,
const size_t inImgW,
const size_t inputH,
const size_t inputW,
real* out,
const size_t outImgH,
const size_t outImgW,
const size_t outputH,
const size_t outputW,
const size_t numChannels,
const real ratioH,
const real ratioW) {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
if(tid < nthreads) {
int outIdH = tid / (outputW / numChannels);
int outIdW = tid % (outputW / numChannels);
int inIdH = ratioH * (outIdW / outImgW);
int hId = (inIdH < inImgH - 1) ? 1 : 0;
real hlambda = ratioH * (outIdW / outImgW) - inIdH;
int inIdW = ratioW * (tid % outImgW);
int wId = (inIdW < inImgW - 1) ? 1 : 0;
real wlambda = ratioW * (tid % outImgW) - inIdW;
const real* inPos = &in[outIdH * inputW + inIdH * inImgW + inIdW];
real* outPos = &out[outIdH * outputW + outIdW];
for (int c = 0; c < numChannels; ++c) {
// bilinear interpolation
outPos[0] = (1.f - hlambda) *
((1.f - wlambda) * inPos[0] + wlambda * inPos[wId]) +
hlambda * ((1.f - wlambda) * inPos[hId * inImgW] +
wlambda * inPos[hId * inImgW + wId]);
inPos += inImgH * inImgW;
outPos += outImgH * outImgW;
}
}
}
void hl_bilinear_forward(const real* inData,
const size_t inImgH,
const size_t inImgW,
const size_t inputH,
const size_t inputW,
real* outData,
const size_t outImgH,
const size_t outImgW,
const size_t outputH,
const size_t outputW,
const size_t numChannels) {
int threadNum = outputH * outImgH * outImgW;
int blocks = (threadNum + 1024 - 1) / 1024;
real ratioH = (outImgH > 1) ?
static_cast<float>(inImgH - 1) / (outImgH - 1) : 0.f;
real ratioW = (outImgW > 1) ?
static_cast<float>(inImgW - 1) / (outImgW - 1) : 0.f;
KeBilinearInterpFw<<< blocks, 1024, 0, STREAM_DEFAULT>>>(
threadNum, inData, inImgH, inImgW, inputH, inputW, outData,
outImgH, outImgW, outputH, outputW, numChannels, ratioH, ratioW);
CHECK_SYNC("hl_bilinear_forward failed");
}
__global__ void KeBilinearInterpBw(const size_t nthreads,
real* in,
const size_t inImgH,
const size_t inImgW,
const size_t inputH,
const size_t inputW,
const real* out,
const size_t outImgH,
const size_t outImgW,
const size_t outputH,
const size_t outputW,
const size_t numChannels,
const real ratioH,
const real ratioW) {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
if(tid < nthreads) {
int outIdH = tid / (outputW / numChannels);
int outIdW = tid % (outputW / numChannels);
int inIdH = ratioH * (outIdW / outImgW);
int hId = (inIdH < inImgH - 1) ? 1 : 0;
real hlambda = ratioH * (outIdW / outImgW) - inIdH;
int inIdW = ratioW * (tid % outImgW);
int wId = (inIdW < inImgW - 1) ? 1 : 0;
real wlambda = ratioW * (tid % outImgW) - inIdW;
const real* outPos = &out[outIdH * outputW + outIdW];
real* inPos = &in[outIdH * inputW + inIdH * inImgW + inIdW];
for (int c = 0; c < numChannels; ++c) {
atomicAdd(&inPos[0], (1.f - hlambda) * (1.f - wlambda) * outPos[0]);
atomicAdd(&inPos[wId], (1.f - hlambda) * wlambda * outPos[0]);
atomicAdd(&inPos[hId * inImgW], hlambda * (1.f - wlambda) * outPos[0]);
atomicAdd(&inPos[hId * inImgW + wId], hlambda * wlambda * outPos[0]);
inPos += inImgH * inImgW;
outPos += outImgH * outImgW;
}
}
}
void hl_bilinear_backward(real* inGrad,
const size_t inImgH,
const size_t inImgW,
const size_t inputH,
const size_t inputW,
const real* outGrad,
const size_t outImgH,
const size_t outImgW,
const size_t outputH,
const size_t outputW,
const size_t numChannels) {
int threadNum = outputH * outImgH * outImgW;
int blocks = (threadNum + 1024 - 1) / 1024;
real ratioH = (outImgH > 1) ?
static_cast<float>(inImgH - 1) / (outImgH - 1) : 0.f;
real ratioW = (outImgW > 1) ?
static_cast<float>(inImgW - 1) / (outImgW - 1) : 0.f;
KeBilinearInterpBw<<< blocks, 1024, 0, STREAM_DEFAULT>>>(
threadNum, inGrad, inImgH, inImgW, inputH, inputW, outGrad,
outImgH, outImgW, outputH, outputW, numChannels, ratioH, ratioW);
CHECK_SYNC("hl_bilinear_backward failed");
}
\ No newline at end of file
/* Copyright (c) 2016 Baidu, Inc. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "BilinearInterpLayer.h"
#include "paddle/utils/Logging.h"
#include "paddle/utils/Stat.h"
namespace paddle {
REGISTER_LAYER(bilinear_interp, BilinearInterpLayer);
size_t BilinearInterpLayer::getDataDimSize() {
getOutput().setFrameHeight(outImgH_);
getOutput().setFrameWidth(outImgW_);
return outImgH_ * outImgW_ * numChannels_;
}
bool BilinearInterpLayer::init(const LayerMap& layerMap,
const ParameterMap& parameterMap) {
/* Initialize the basic parent class */
Layer::init(layerMap, parameterMap);
CHECK_EQ(1, config_.inputs_size());
const BilinearInterpConfig& conf = config_.inputs(0).bilinear_interp_conf();
inImgH_ = inputLayers_[0]->getOutput().getFrameHeight();
inImgW_ = inputLayers_[0]->getOutput().getFrameWidth();
if (inImgH_ == 0) {
inImgH_ = conf.img_size_y();
}
if (inImgW_ == 0) {
inImgW_ = conf.img_size_x();
}
outImgH_ = conf.out_size_y();
outImgW_ = conf.out_size_x();
numChannels_ = conf.num_channels();
CHECK(outImgH_ > 0 && outImgW_ > 0);
CHECK(inImgH_ > 0 && inImgW_ > 0);
CHECK(numChannels_);
return true;
}
void BilinearInterpLayer::forward(PassType passType) {
Layer::forward(passType);
size_t batchSize = getInput(0).getBatchSize();
size_t size = getDataDimSize();
{
REGISTER_TIMER_INFO("FwResetTimer", getName().c_str());
resetOutput(batchSize, size);
}
MatrixPtr inV = getInputValue(0);
MatrixPtr outV = getOutputValue();
{
REGISTER_TIMER_INFO("FwBilinearInterpTimer", getName().c_str());
outV->bilinearForward(*inV, inImgH_, inImgW_, outImgH_, outImgW_,
numChannels_);
}
}
void BilinearInterpLayer::backward(const UpdateCallback& callback) {
(void) callback;
MatrixPtr inputG = getInputGrad(0);
MatrixPtr outG = getOutputGrad();
{
REGISTER_TIMER_INFO("BwBilinearInterpTimer", getName().c_str());
if (inputG) {
inputG->bilinearBackward(*outG, outImgH_, outImgW_, inImgH_, inImgW_,
numChannels_);
}
}
}
} // namespace paddle
/* Copyright (c) 2016 Baidu, Inc. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "Layer.h"
#include "paddle/math/Matrix.h"
namespace paddle {
/**
* @brief A layer for bilinear interpolation which is
* used on conv layer output.
*
* @note The config file api is bilinear_interp_layer.
*/
class BilinearInterpLayer : public Layer {
protected:
size_t outImgH_, outImgW_;
size_t inImgH_, inImgW_;
size_t numChannels_;
public:
explicit BilinearInterpLayer(const LayerConfig& config) : Layer(config) {}
virtual ~BilinearInterpLayer() {}
size_t getDataDimSize();
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap);
void forward(PassType passType);
void backward(const UpdateCallback& callback = nullptr);
};
} // namespace paddle
...@@ -31,6 +31,26 @@ P_DECLARE_double(checkgrad_eps); ...@@ -31,6 +31,26 @@ P_DECLARE_double(checkgrad_eps);
P_DECLARE_bool(thread_local_rand_use_global_seed); P_DECLARE_bool(thread_local_rand_use_global_seed);
P_DECLARE_bool(prev_batch_state); P_DECLARE_bool(prev_batch_state);
TEST(Layer, BilinearInterpLayer) {
TestConfig config;
config.layerConfig.set_type("bilinear_interp");
config.biasSize = 0;
config.inputDefs.push_back({INPUT_DATA, "layer_0", 4096, 0});
LayerInputConfig* input = config.layerConfig.add_inputs();
BilinearInterpConfig* bilinear = input->mutable_bilinear_interp_conf();
bilinear->set_img_size_x(32);
bilinear->set_img_size_y(32);
bilinear->set_out_size_x(64);
bilinear->set_out_size_y(64);
bilinear->set_num_channels(4);
for (auto useGpu : {false, true}) {
testLayerGrad(config, "bilinear_interp", 10, false, useGpu);
}
}
TEST(Operator, dot_mul) { TEST(Operator, dot_mul) {
TestConfig config; TestConfig config;
config.layerConfig.set_size(10); config.layerConfig.set_size(10);
......
...@@ -23,6 +23,7 @@ limitations under the License. */ ...@@ -23,6 +23,7 @@ limitations under the License. */
#include "paddle/utils/Logging.h" #include "paddle/utils/Logging.h"
#include <string.h> #include <string.h>
#include "hl_cnn.h"
#include "hl_gpu.h" #include "hl_gpu.h"
#include "hl_table_apply.h" #include "hl_table_apply.h"
#include "hl_top_k.h" #include "hl_top_k.h"
...@@ -1144,6 +1145,56 @@ void GpuMatrix::addColumnVector(const Matrix& b) { ...@@ -1144,6 +1145,56 @@ void GpuMatrix::addColumnVector(const Matrix& b) {
BaseMatrix::addColVector(const_cast<Matrix&>(b)); BaseMatrix::addColVector(const_cast<Matrix&>(b));
} }
void GpuMatrix::bilinearForward(const Matrix& in,
const size_t inImgH,
const size_t inImgW,
const size_t outImgH,
const size_t outImgW,
const size_t numChannels) {
CHECK(dynamic_cast<const GpuMatrix*>(&in));
const size_t outputW = getWidth();
const size_t outputH = getHeight();
const size_t inputW = in.getWidth();
const size_t inputH = in.getHeight();
real* outData = getData();
const real* inData = in.getData();
if (inImgH == outImgW && inImgW == outImgW) {
this->copyFrom(in);
} else {
hl_bilinear_forward(inData, inImgH, inImgW,
inputH, inputW, outData, outImgH, outImgW,
outputH, outputW, numChannels);
}
}
void GpuMatrix::bilinearBackward(const Matrix& out,
const size_t outImgH,
const size_t outImgW,
const size_t inImgH,
const size_t inImgW,
const size_t numChannels) {
CHECK(dynamic_cast<const GpuMatrix*>(&out));
const size_t inputW = getWidth();
const size_t inputH = getHeight();
const size_t outputW = out.getWidth();
const size_t outputH = out.getHeight();
real* inGrad = getData();
const real* outGrad = out.getData();
if (outImgH == inImgH && outImgW == inImgW) {
this->copyFrom(out);
} else {
hl_bilinear_backward(inGrad, inImgH, inImgW,
inputH, inputW, outGrad, outImgH, outImgW,
outputH, outputW, numChannels);
}
}
/** /**
* CpuMatrix * CpuMatrix
*/ */
...@@ -3598,6 +3649,109 @@ void CpuMatrix::classificationErrorMulti(Matrix& output, Matrix& label, ...@@ -3598,6 +3649,109 @@ void CpuMatrix::classificationErrorMulti(Matrix& output, Matrix& label,
} }
} }
void CpuMatrix::bilinearForward(const Matrix& in,
const size_t inImgH,
const size_t inImgW,
const size_t outImgH,
const size_t outImgW,
const size_t numChannels) {
CHECK(dynamic_cast<const CpuMatrix*>(&in));
size_t outputW = getWidth();
size_t outputH = getHeight();
size_t inputW = in.getWidth();
size_t inputH = in.getHeight();
real* outData = getData();
const real* inData = in.getData();
const real ratioH = (outImgH > 1) ?
static_cast<real>(inImgH - 1) / (outImgH - 1) : 0.f;
const real ratioW = (outImgW > 1) ?
static_cast<real>(inImgW - 1) / (outImgW - 1) : 0.f;
if (inImgH == outImgH && inImgW == outImgW) {
this->copyFrom(in);
} else {
for (int k = 0; k < outputH; ++k) { // loop for batches
for (int i = 0; i < outImgH; ++i) { // loop for images
int h = ratioH * i;
int hid = (h < inImgH - 1) ? 1 : 0;
real hlambda = ratioH * i - h;
for (int j = 0; j < outImgW; ++j) {
int w = ratioW * j;
int wid = (w < inImgW - 1) ? 1 : 0;
real wlambda = ratioW * j - w;
// calculate four position for bilinear interpolation
const real* inPos = &inData[k * inputW + h * inImgW + w];
real* outPos = &outData[k * outputW + i * outImgW + j];
for (int c = 0; c < numChannels; ++c) { // loop for channels
// bilinear interpolation
outPos[0] = (1.f - hlambda) *
((1.f - wlambda) * inPos[0] + wlambda * inPos[wid]) +
hlambda * ((1.f - wlambda) * inPos[hid * inImgW] +
wlambda * inPos[hid * inImgW + wid]);
inPos += inImgH * inImgW;
outPos += outImgH * outImgW;
}
}
}
}
}
}
void CpuMatrix::bilinearBackward(const Matrix& out,
const size_t outImgH,
const size_t outImgW,
const size_t inImgH,
const size_t inImgW,
const size_t numChannels) {
CHECK(dynamic_cast<const CpuMatrix*>(&out));
size_t inputW = getWidth();
size_t inputH = getHeight();
size_t outputW = out.getWidth();
size_t outputH = out.getHeight();
real* inGrad = getData();
const real* outGrad = out.getData();
const real ratioH = (outImgH > 1) ?
static_cast<real>(inImgH - 1) / (outImgH - 1) : 0.f;
const real ratioW = (outImgW > 1) ?
static_cast<real>(inImgW - 1) / (outImgW - 1) : 0.f;
if (inImgH == outImgH && inImgW == outImgW) {
this->copyFrom(out);
} else {
for (int k = 0; k < outputH; ++k) { // loop for batches
for (int i = 0; i < outImgH; ++i) { // loop for images
int h = ratioH * i;
int hid = (h < inImgH - 1) ? 1 : 0;
real hlambda = ratioH * i - h;
for (int j = 0; j < outImgW; ++j) {
int w = ratioW * j;
int wid = (w < inImgW - 1) ? 1 : 0;
real wlambda = ratioW * j - w;
real* inPos = &inGrad[k * inputW + h * inImgW + w];
const real* outPos = &outGrad[k * outputW + i * outImgW + j];
for (int c = 0; c < numChannels; ++c) { // loop for channels
inPos[0] += (1.f - hlambda) * (1.f - wlambda) * outPos[0];
inPos[wid] += (1.f - hlambda) * wlambda * outPos[0];
inPos[hid * inImgW] += hlambda * (1.f - wlambda) * outPos[0];
inPos[hid * inImgW + wid] += hlambda * wlambda * outPos[0];
inPos += inImgH * inImgW;
outPos += outImgH * outImgW;
}
}
}
}
}
}
//////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////
// functions executed via cpu // // functions executed via cpu //
//////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////
......
...@@ -930,6 +930,22 @@ public: ...@@ -930,6 +930,22 @@ public:
virtual void paramReluBackwardDiff(Matrix& oGrad, Matrix& data, Matrix& W) { virtual void paramReluBackwardDiff(Matrix& oGrad, Matrix& data, Matrix& W) {
LOG(FATAL) << "Not implemented"; LOG(FATAL) << "Not implemented";
} }
virtual void bilinearForward(const Matrix& in,
const size_t inImgH,
const size_t inImgW,
const size_t outImgH,
const size_t outImgW,
const size_t numChannels) {
LOG(FATAL) << "Not implemented";
}
virtual void bilinearBackward(const Matrix& out,
const size_t outImgH,
const size_t outImgW,
const size_t inImgH,
const size_t inImgW,
const size_t numChannels) {
LOG(FATAL) << "Not implemented";
}
}; };
inline std::ostream& operator<<(std::ostream& os, const Matrix& mat) { inline std::ostream& operator<<(std::ostream& os, const Matrix& mat) {
...@@ -1191,6 +1207,20 @@ public: ...@@ -1191,6 +1207,20 @@ public:
int contextLength, int contextLength,
int contextStart, int totalPad, int contextStart, int totalPad,
size_t beginPad); size_t beginPad);
void bilinearForward(const Matrix& in,
const size_t inImgH,
const size_t inImgW,
const size_t outImgH,
const size_t outImgW,
const size_t numChannels);
void bilinearBackward(const Matrix& out,
const size_t outImgH,
const size_t outImgW,
const size_t inImgH,
const size_t inImgW,
const size_t numChannels);
}; };
class CpuMatrix : public Matrix { class CpuMatrix : public Matrix {
...@@ -1469,6 +1499,20 @@ public: ...@@ -1469,6 +1499,20 @@ public:
void multiBinaryLabelCrossEntropy(Matrix& output, Matrix& label); void multiBinaryLabelCrossEntropy(Matrix& output, Matrix& label);
void multiBinaryLabelCrossEntropyBp(Matrix& output, Matrix& label); void multiBinaryLabelCrossEntropyBp(Matrix& output, Matrix& label);
void classificationErrorMulti(Matrix& output, Matrix& label, real threshold); void classificationErrorMulti(Matrix& output, Matrix& label, real threshold);
void bilinearForward(const Matrix& in,
const size_t inImgH,
const size_t inImgW,
const size_t outImgH,
const size_t outImgW,
const size_t numChannels);
void bilinearBackward(const Matrix& out,
const size_t outImgH,
const size_t outImgW,
const size_t inImgH,
const size_t inImgW,
const size_t numChannels);
}; };
class SharedCpuMatrix : public CpuMatrix { class SharedCpuMatrix : public CpuMatrix {
......
...@@ -88,6 +88,72 @@ void MatrixCheckErr(const Matrix& matrix1, const Matrix& matrix2) { ...@@ -88,6 +88,72 @@ void MatrixCheckErr(const Matrix& matrix1, const Matrix& matrix2) {
EXPECT_EQ(count, 0) << "There are " << count << " different element."; EXPECT_EQ(count, 0) << "There are " << count << " different element.";
} }
void testBilinearFwdBwd(int numSamples, int imgSizeH, int imgSizeW,
int channels) {
int inWidth = imgSizeH * imgSizeW * channels;
int outWidth = 2 * imgSizeH * 2 * imgSizeW * channels;
// forward
MatrixPtr input = CpuMatrix::create(numSamples, inWidth, false, false);
MatrixPtr inputGpu = GpuMatrix::create(numSamples, inWidth, false, true);
MatrixPtr target = CpuMatrix::create(numSamples, outWidth, false, false);
MatrixPtr targetGpu = GpuMatrix::create(numSamples, outWidth, false, true);
MatrixPtr targetCheck = CpuMatrix::create(numSamples, outWidth, false, false);
input->randomizeUniform();
inputGpu->copyFrom(*input);
target->bilinearForward(*input, imgSizeH, imgSizeW,
2 * imgSizeH, 2 * imgSizeW, channels);
targetGpu->bilinearForward(*inputGpu, imgSizeH, imgSizeW,
2 * imgSizeH, 2 * imgSizeW, channels);
// check
targetCheck->copyFrom(*targetGpu);
MatrixCheckErr(*target, *targetCheck);
// backward
MatrixPtr inputGrad = CpuMatrix::create(numSamples, inWidth, false, false);
MatrixPtr inputGpuGrad = GpuMatrix::create(numSamples, inWidth, false, true);
MatrixPtr targetGrad = CpuMatrix::create(numSamples, outWidth, false, false);
MatrixPtr targetGpuGrad = GpuMatrix::create(numSamples, outWidth, false,
true);
MatrixPtr targetCheckGrad =
CpuMatrix::create(numSamples, inWidth, false, false);
inputGrad->randomizeUniform();
targetGrad->randomizeUniform();
inputGpuGrad->copyFrom(*inputGrad);
targetGpuGrad->copyFrom(*targetGrad);
inputGrad->bilinearBackward(*targetGrad, 2 * imgSizeH, 2 * imgSizeW,
imgSizeH, imgSizeW, channels);
inputGpuGrad->bilinearBackward(*targetGpuGrad, 2 * imgSizeH, 2 * imgSizeW,
imgSizeH, imgSizeW, channels);
// check
targetCheckGrad->copyFrom(*inputGpuGrad);
MatrixCheckErr(*inputGrad, *targetCheckGrad);
}
TEST(Matrix, BilinearFwdBwd) {
for (auto numSamples : {5, 10}) {
for (auto channels : {8, 16}) {
for (auto imgSizeH : {14, 28}) {
for (auto imgSizeW : {16, 30}) {
VLOG(3) << " numSamples=" << numSamples
<< " channels=" << channels
<< " imgSizeH=" << imgSizeH
<< " imgSizeW=" << imgSizeW;
testBilinearFwdBwd(numSamples, imgSizeH, imgSizeW, channels);
}
}
}
}
}
void testMatrixProjectionForward(int contextStart, int contextLength, void testMatrixProjectionForward(int contextStart, int contextLength,
bool padding, int batchSize, int inputDim) { bool padding, int batchSize, int inputDim) {
MatrixPtr cpuInput = std::make_shared<CpuMatrix>(batchSize, inputDim); MatrixPtr cpuInput = std::make_shared<CpuMatrix>(batchSize, inputDim);
......
...@@ -203,6 +203,15 @@ message OperatorConfig { ...@@ -203,6 +203,15 @@ message OperatorConfig {
optional int32 num_filters = 7; optional int32 num_filters = 7;
} }
message BilinearInterpConfig {
// The size if input feature map.
required uint32 img_size_x = 1;
required uint32 img_size_y = 2;
// The size if output feature map.
required uint32 out_size_x = 3;
required uint32 out_size_y = 4;
required uint32 num_channels = 5;
}
message ImageConfig { message ImageConfig {
// The image data dimensionality. // The image data dimensionality.
...@@ -225,6 +234,7 @@ message LayerInputConfig { ...@@ -225,6 +234,7 @@ message LayerInputConfig {
// If the input layer has multi-output. // If the input layer has multi-output.
// Set the argument name. // Set the argument name.
optional string input_layer_argument = 9; optional string input_layer_argument = 9;
optional BilinearInterpConfig bilinear_interp_conf = 10;
} }
message LayerConfig { message LayerConfig {
......
...@@ -461,6 +461,7 @@ class Input(Cfg): ...@@ -461,6 +461,7 @@ class Input(Cfg):
sparse_update=None, sparse_update=None,
gradient_clipping_threshold=None, gradient_clipping_threshold=None,
conv=None, conv=None,
bilinear_interp=None,
norm=None, norm=None,
pool=None, pool=None,
image=None, image=None,
...@@ -723,6 +724,18 @@ class Conv(Cfg): ...@@ -723,6 +724,18 @@ class Conv(Cfg):
if output_x is not None: if output_x is not None:
config_assert(output_x <= 0) config_assert(output_x <= 0)
# please refer to the comments in proto/ModelConfig.proto
@config_class
class BilinearInterp(Cfg):
def __init__(
self,
img_size_x = None,
img_size_y=None,
out_size_x = None,
out_size_y = None,
num_channels = None):
self.add_keys(locals())
# please refer to the comments in proto/ModelConfig.proto # please refer to the comments in proto/ModelConfig.proto
@config_class @config_class
class Pool(Cfg): class Pool(Cfg):
...@@ -953,6 +966,13 @@ def TestData(data_config, async_load_data=None): ...@@ -953,6 +966,13 @@ def TestData(data_config, async_load_data=None):
" Data definition") " Data definition")
g_config.test_data_config.async_load_data = async_load_data g_config.test_data_config.async_load_data = async_load_data
def parse_bilinear(bilinear, input_layer_name, bilinear_conf):
bilinear_conf.img_size_x = bilinear.img_size_x;
bilinear_conf.img_size_y = bilinear.img_size_y;
bilinear_conf.out_size_x = bilinear.out_size_x;
bilinear_conf.out_size_y = bilinear.out_size_y;
bilinear_conf.num_channels = bilinear.num_channels;
def parse_pool(pool, input_layer_name, pool_conf): def parse_pool(pool, input_layer_name, pool_conf):
pool_conf.pool_type = pool.pool_type pool_conf.pool_type = pool.pool_type
config_assert(pool.pool_type in ['max-projection', 'avg-projection', config_assert(pool.pool_type in ['max-projection', 'avg-projection',
...@@ -2306,6 +2326,21 @@ class InterpolationLayer(LayerBase): ...@@ -2306,6 +2326,21 @@ class InterpolationLayer(LayerBase):
config_assert(input_layer1.size == input_layer2.size, config_assert(input_layer1.size == input_layer2.size,
'the two vector inputs should be of the same size') 'the two vector inputs should be of the same size')
@config_layer('bilinear_interp')
class BilinearInterpLayer(LayerBase):
def __init__(
self,
name,
inputs,
device=None):
super(BilinearInterpLayer, self).__init__(
name, 'bilinear_interp', 0, inputs=inputs, device=device)
input_layer = self.get_input_layer(0)
self.set_layer_size(input_layer.size)
parse_bilinear(self.inputs[0].bilinear_interp,
input_layer.name,
self.config.inputs[0].bilinear_interp_conf);
@config_layer('sum_to_one_norm') @config_layer('sum_to_one_norm')
class SumToOneNormLayer(LayerBase): class SumToOneNormLayer(LayerBase):
def __init__( def __init__(
......
...@@ -40,8 +40,8 @@ __all__ = ["full_matrix_projection", "AggregateLevel", "ExpandLevel", ...@@ -40,8 +40,8 @@ __all__ = ["full_matrix_projection", "AggregateLevel", "ExpandLevel",
'img_cmrnorm_layer', 'addto_layer', 'img_cmrnorm_layer', 'addto_layer',
'concat_layer', 'lstm_step_layer', 'recurrent_group', 'concat_layer', 'lstm_step_layer', 'recurrent_group',
'memory', 'StaticInput', 'expand_layer', 'scaling_layer', 'memory', 'StaticInput', 'expand_layer', 'scaling_layer',
'power_layer', 'interpolation_layer', 'trans_layer', 'power_layer', 'interpolation_layer', 'bilinear_interp_layer',
'sum_to_one_norm_layer', 'trans_layer', 'sum_to_one_norm_layer',
'get_output_layer', 'LayerType', 'context_projection', 'get_output_layer', 'LayerType', 'context_projection',
'beam_search', 'maxid_layer', 'GeneratedInput', 'SubsequenceInput', 'beam_search', 'maxid_layer', 'GeneratedInput', 'SubsequenceInput',
'gru_step_layer', 'recurrent_layer', 'gru_step_layer', 'recurrent_layer',
...@@ -92,6 +92,7 @@ class LayerType(object): ...@@ -92,6 +92,7 @@ class LayerType(object):
EXPAND_LAYER = 'expand' EXPAND_LAYER = 'expand'
INTERPOLATION_LAYER = 'interpolation' INTERPOLATION_LAYER = 'interpolation'
BILINEAR_INTERP_LAYER = 'bilinear_interp'
POWER_LAYER = 'power' POWER_LAYER = 'power'
SCALING_LAYER = 'scaling' SCALING_LAYER = 'scaling'
TRANS_LAYER = 'trans' TRANS_LAYER = 'trans'
...@@ -1252,6 +1253,70 @@ def interpolation_layer(input, weight, name=None, layer_attr=None): ...@@ -1252,6 +1253,70 @@ def interpolation_layer(input, weight, name=None, layer_attr=None):
size=input[0].size) size=input[0].size)
@wrap_name_default()
@layer_support()
def bilinear_interp_layer(input,
img_size_x=None,
img_size_y=None,
out_size_x=None,
out_size_y=None,
num_channels=None,
name=None,
layer_attr=None):
"""
This layer is to implement bilinear interpolation on conv layer output.
Please refer to Wikipedia: https://en.wikipedia.org/wiki/Bilinear_interpolation
The simple usage is:
.. code-block:: python
bilinear = bilinear_interp_layer(input,
img_size_x,
img_size_y,
out_size_x,
out_size_y,
num_channels)
:para input: A input layer.
:type input: LayerOutput.
:para img_size_x: previous layer output width.
:type img_size_x: int|None
:para img_size_y: previous layer output height.
:type img_size_y: int|None
:para out_size_x: bilinear interpolation output width.
:type out_size_x: int|None
:para out_size_y: bilinear interpolation output height.
:type out_size_y: int|None
:para num_channels: number of channels of input layer. If None,
it will be set automatically from previous output.
:type num_channels: int|None
:para name: The layer's name, which cna not be specified.
:type name: None|basestring
:para layer_attr: Extra Layer attribute.
:type layer_attr: ExtraLayerAttribute
:return: LayerOutput object.
:rtype: LayerOutput
"""
assert input.layer_type == LayerType.CONV_LAYER
assert isinstance(input.activation, LinearActivation)
assert img_size_x > 0 and img_size_y > 0
assert out_size_x > 0 and out_size_y > 0
if num_channels is None:
assert input.numfilters is not None
num_channels = input.num_filters
Layer(name=name,
inputs=Input(input.name,
bilinear_interp=BilinearInterp(img_size_x=img_size_x,
img_size_y=img_size_y,
out_size_x=out_size_x,
out_size_y=out_size_y,
num_channels=num_channels)),
type=LayerType.BILINEAR_INTERP_LAYER,
**ExtraLayerAttribute.to_kwargs(layer_attr))
return LayerOutput(name, LayerType.BILINEAR_INTERP_LAYER, parents=[input])
@wrap_name_default() @wrap_name_default()
@layer_support() @layer_support()
def power_layer(input, weight, name=None, layer_attr=None): def power_layer(input, weight, name=None, layer_attr=None):
......
...@@ -8,7 +8,7 @@ configs=(test_fc layer_activations projections test_print_layer ...@@ -8,7 +8,7 @@ configs=(test_fc layer_activations projections test_print_layer
test_sequence_pooling test_lstmemory_layer test_grumemory_layer test_sequence_pooling test_lstmemory_layer test_grumemory_layer
last_first_seq test_expand_layer test_ntm_layers test_hsigmoid last_first_seq test_expand_layer test_ntm_layers test_hsigmoid
img_layers util_layers simple_rnn_layers unused_layers test_cost_layers img_layers util_layers simple_rnn_layers unused_layers test_cost_layers
test_rnn_group) test_rnn_group test_bilinear_interp)
for conf in ${configs[*]} for conf in ${configs[*]}
......
from paddle.trainer_config_helpers import *
settings(
batch_size=1000,
learning_rate=1e-5
)
data = data_layer(name='data', size=2304)
conv = img_conv_layer(input=data,
filter_size = 3,
num_channels=1,
num_filters=16,
padding=1,
act=LinearActivation(),
bias_attr=True)
bilinear = bilinear_interp_layer(input=conv,
img_size_x=32,
img_size_y=32,
out_size_x=64,
out_size_y=64,
num_channels=16)
pool = img_pool_layer(input=bilinear,
num_channels=4,
pool_size=2,
stride=2,
pool_type=MaxPooling())
fc = fc_layer(input=pool, size=384, bias_attr=False)
outputs(fc)
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册