提交 9dd588b4 编写于 作者: Q qijun

fix merge conflicts

......@@ -5,4 +5,6 @@ build/
.vscode
.idea
.project
.cproject
.pydevproject
Makefile
......@@ -287,6 +287,12 @@ interpolation_layer
:members: interpolation_layer
:noindex:
bilinear_interp_layer
----------------------
.. automodule:: paddle.trainer_config_helpers.layers
:members: bilinear_interp_layer
:noindex:
power_layer
-----------
.. automodule:: paddle.trainer_config_helpers.layers
......
......@@ -246,6 +246,70 @@ extern void hl_CMRNorm_backward(
size_t channels, size_t height, size_t width, size_t sizeX,
real alpha, real beta);
/**
* @brief Bilinear interpolation forward.
*
* @param[in] inData input value.
* @param[in] inImgH input image height.
* @param[in] inImgW input image width.
* @param[in] inputH input batchSize.
* @param[in] inputW input image data dim.
* @param[out] outData output value.
* @param[in] outImgH output image height.
* @param[in] outImgW output image width.
* @param[in] outputH output batchSize.
* @param[in] outputW output image data dim.
* @param[in] numChannels number of channels.
* @param[in] ratioH inImgH / outImgH.
* @param[in] ratioW inImgW / outImgW.
*
*/
extern void hl_bilinear_forward(const real* inData,
const size_t inImgH,
const size_t inImgW,
const size_t inputH,
const size_t inputW,
real* outData,
const size_t outImgH,
const size_t outImgW,
const size_t outputH,
const size_t outputW,
const size_t numChannels,
const real ratioH,
const real ratioW);
/**
* @brief Bilinear interpolation backward.
*
* @param[out] inGrad input gradient.
* @param[in] inImgH input image height.
* @param[in] inImgW input image width.
* @param[in] inputH input batchSize.
* @param[in] inputW input image data dim.
* @param[in] outGrad output gradient.
* @param[in] outImgH output image height.
* @param[in] outImgW output image width.
* @param[in] outputH output batchSize.
* @param[in] outputW output image data dim.
* @param[in] numChannels number of channels.
* @param[in] ratioH inImgH / outImgH.
* @param[in] ratioW inImgW / outImgW.
*
*/
extern void hl_bilinear_backward(real* inGrad,
const size_t inImgH,
const size_t inImgW,
const size_t inputH,
const size_t inputW,
const real* outGrad,
const size_t outImgH,
const size_t outImgW,
const size_t outputH,
const size_t outputW,
const size_t numChannels,
const real ratioH,
const real ratioW);
/**
* @brief MaxOut forward.
*
......
......@@ -91,6 +91,34 @@ inline void hl_CMRNorm_backward(
size_t channels, size_t height, size_t width, size_t sizeX,
real alpha, real beta) {}
inline void hl_bilinear_forward(const real* inData,
const size_t inImgH,
const size_t inImgW,
const size_t inputH,
const size_t inputW,
real* outData,
const size_t outImgH,
const size_t outImgW,
const size_t outputH,
const size_t outputW,
const size_t numChannels,
const real ratioH,
const real ratioW) {}
inline void hl_bilinear_backward(real* inGrad,
const size_t inImgH,
const size_t inImgW,
const size_t inputH,
const size_t inputW,
const real* outGrad,
const size_t outImgH,
const size_t outImgW,
const size_t outputH,
const size_t outputW,
const size_t numChannels,
const real ratioH,
const real ratioW) {}
inline void hl_maxout_forward(
const real* inData, real* outData, int* idData,
size_t batchSize, size_t size, size_t featLen, size_t group) {}
......
......@@ -528,7 +528,7 @@ void hl_CMRNorm_backward(size_t frameCnt, const real* inV,
size_t height, size_t width, size_t sizeX,
real alpha, real beta) {
size_t threadsNum = frameCnt * height * width;
size_t blocksX = (threadsNum + 1024 -1) / 1024;
size_t blocksX = (threadsNum + 1024 - 1) / 1024;
size_t blocksY = 1;
dim3 threads(1024, 1);
dim3 grid(blocksX, blocksY);
......@@ -538,6 +538,138 @@ void hl_CMRNorm_backward(size_t frameCnt, const real* inV,
CHECK_SYNC("hl_CMRNorm_backward");
}
__global__ void KeBilinearInterpFw(const real* in,
const size_t inImgH,
const size_t inImgW,
const size_t inputH,
const size_t inputW,
real* out,
const size_t outImgH,
const size_t outImgW,
const size_t outputH,
const size_t outputW,
const size_t numChannels,
const real ratioH,
const real ratioW) {
int nthreads = outputH * outputW;
int tid = blockIdx.x * blockDim.x + threadIdx.x;
if (tid < nthreads) {
int outIdH = tid / outputW;
int outIdW = tid % outputW;
int inImgSize = inputW / numChannels;
int outImgSize = outputW / numChannels;
int channelId = outIdW / outImgSize;
int outImgIdy = (outIdW % outImgSize) / outImgW;
int inImgIdy = ratioH * outImgIdy;
int hId = (inImgIdy < inImgH - 1) ? 1 : 0;
real h1lambda = ratioH * outImgIdy - inImgIdy;
real h2lambda = 1.f - h1lambda;
int outImgIdx = tid % outImgW;
int inImgIdx = ratioW * outImgIdx;
int wId = (inImgIdx < inImgW - 1) ? 1 : 0;
real w1lambda = ratioW * outImgIdx - inImgIdx;
real w2lambda = 1.f - w1lambda;
const real* inPos =
&in[outIdH * inputW + channelId * inImgSize + inImgIdy * inImgW + inImgIdx];
// bilinear interpolation
out[outIdH * outputW + outIdW] =
h2lambda * (w2lambda * inPos[0] + w1lambda * inPos[wId]) +
h1lambda * (w2lambda * inPos[hId * inImgW] + w1lambda * inPos[hId * inImgW + wId]);
}
}
void hl_bilinear_forward(const real* inData,
const size_t inImgH,
const size_t inImgW,
const size_t inputH,
const size_t inputW,
real* outData,
const size_t outImgH,
const size_t outImgW,
const size_t outputH,
const size_t outputW,
const size_t numChannels,
const real ratioH,
const real ratioW) {
int threadNum = outputH * outputW;
int blocks = (threadNum + 1024 - 1) / 1024;
KeBilinearInterpFw<<< blocks, 1024, 0, STREAM_DEFAULT>>>(
inData, inImgH, inImgW, inputH, inputW, outData, outImgH,
outImgW, outputH, outputW, numChannels, ratioH, ratioW);
CHECK_SYNC("hl_bilinear_forward failed");
}
__global__ void KeBilinearInterpBw(real* in,
const size_t inImgH,
const size_t inImgW,
const size_t inputH,
const size_t inputW,
const real* out,
const size_t outImgH,
const size_t outImgW,
const size_t outputH,
const size_t outputW,
const size_t numChannels,
const real ratioH,
const real ratioW) {
int nthreads = outputH * outputW;
int tid = blockIdx.x * blockDim.x + threadIdx.x;
if (tid < nthreads) {
int outIdH = tid / outputW;
int outIdW = tid % outputW;
int inImgSize = inputW / numChannels;
int outImgSize = outputW / numChannels;
int channelId = outIdW / outImgSize;
int outImgIdy = (outIdW % outImgSize) / outImgW;
int inImgIdy = ratioH * outImgIdy;
int hId = (inImgIdy < inImgH - 1) ? 1 : 0;
real h1lambda = ratioH * outImgIdy - inImgIdy;
real h2lambda = 1.f - h1lambda;
int outImgIdx = tid % outImgW;
int inImgIdx = ratioW * outImgIdx;
int wId = (inImgIdx < inImgW - 1) ? 1 : 0;
real w1lambda = ratioW * outImgIdx - inImgIdx;
real w2lambda = 1.f - w1lambda;
real* inPos =
&in[outIdH * inputW + channelId * inImgSize + inImgIdy * inImgW + inImgIdx];
const real* outPos = &out[outIdH * outputW + outIdW];
atomicAdd(&inPos[0], h2lambda * w2lambda * outPos[0]);
atomicAdd(&inPos[wId], h2lambda * w1lambda * outPos[0]);
atomicAdd(&inPos[hId * inImgW], h1lambda * w2lambda * outPos[0]);
atomicAdd(&inPos[hId * inImgW + wId], h1lambda * w1lambda * outPos[0]);
}
}
void hl_bilinear_backward(real* inGrad,
const size_t inImgH,
const size_t inImgW,
const size_t inputH,
const size_t inputW,
const real* outGrad,
const size_t outImgH,
const size_t outImgW,
const size_t outputH,
const size_t outputW,
const size_t numChannels,
const real ratioH,
const real ratioW) {
int threadNum = outputH * outputW;
int blocks = (threadNum + 1024 - 1) / 1024;
KeBilinearInterpBw<<< blocks, 1024, 0, STREAM_DEFAULT>>>(
inGrad, inImgH, inImgW, inputH, inputW, outGrad, outImgH,
outImgW, outputH, outputW, numChannels, ratioH, ratioW);
CHECK_SYNC("hl_bilinear_backward failed");
}
__global__ void maxoutFpCompute(size_t nthreads, const real * inData,
real * outData, int* idData,
size_t size, size_t featLen, size_t groups) {
......
/* Copyright (c) 2016 Baidu, Inc. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "BilinearInterpLayer.h"
#include "paddle/utils/Logging.h"
#include "paddle/utils/Stat.h"
namespace paddle {
REGISTER_LAYER(bilinear_interp, BilinearInterpLayer);
size_t BilinearInterpLayer::getSize() {
inImgH_ = inputLayers_[0]->getOutput().getFrameHeight();
inImgW_ = inputLayers_[0]->getOutput().getFrameWidth();
const BilinearInterpConfig& conf = config_.inputs(0).bilinear_interp_conf();
if (inImgH_ == 0) {
inImgH_ = conf.img_size_y();
}
if (inImgW_ == 0) {
inImgW_ = conf.img_size_x();
}
outImgH_ = conf.out_size_y();
outImgW_ = conf.out_size_x();
numChannels_ = conf.num_channels();
CHECK(outImgH_ > 0 && outImgW_ > 0);
CHECK(inImgH_ > 0 && inImgW_ > 0);
CHECK(numChannels_);
ratioH_ = (outImgH_ > 1) ?
static_cast<real>(inImgH_ - 1) / (outImgH_ - 1) : 0.f;
ratioW_ = (outImgW_ > 1) ?
static_cast<real>(inImgW_ - 1) / (outImgW_ - 1) : 0.f;
getOutput().setFrameHeight(outImgH_);
getOutput().setFrameWidth(outImgW_);
return outImgH_ * outImgW_ * numChannels_;
}
bool BilinearInterpLayer::init(const LayerMap& layerMap,
const ParameterMap& parameterMap) {
/* Initialize the basic parent class */
Layer::init(layerMap, parameterMap);
CHECK_EQ(1, config_.inputs_size());
return true;
}
void BilinearInterpLayer::forward(PassType passType) {
Layer::forward(passType);
size_t batchSize = getInput(0).getBatchSize();
size_t size = getSize();
{
REGISTER_TIMER_INFO("FwResetTimer", getName().c_str());
resetOutput(batchSize, size);
}
MatrixPtr inV = getInputValue(0);
MatrixPtr outV = getOutputValue();
{
REGISTER_TIMER_INFO("FwBilinearInterpTimer", getName().c_str());
outV->bilinearForward(*inV, inImgH_, inImgW_, outImgH_, outImgW_,
numChannels_, ratioH_, ratioW_);
}
}
void BilinearInterpLayer::backward(const UpdateCallback& callback) {
(void) callback;
MatrixPtr inputG = getInputGrad(0);
MatrixPtr outG = getOutputGrad();
{
REGISTER_TIMER_INFO("BwBilinearInterpTimer", getName().c_str());
if (inputG) {
inputG->bilinearBackward(*outG, outImgH_, outImgW_, inImgH_, inImgW_,
numChannels_, ratioH_, ratioW_);
}
}
}
} // namespace paddle
/* Copyright (c) 2016 Baidu, Inc. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "Layer.h"
#include "paddle/math/Matrix.h"
namespace paddle {
/**
* @brief A layer for bilinear interpolation which is
* used on conv layer output.
*
* @note The config file api is bilinear_interp_layer.
*/
class BilinearInterpLayer : public Layer {
protected:
size_t outImgH_, outImgW_;
size_t inImgH_, inImgW_;
real ratioH_, ratioW_;
size_t numChannels_;
public:
explicit BilinearInterpLayer(const LayerConfig& config) : Layer(config) {}
virtual ~BilinearInterpLayer() {}
size_t getSize();
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap);
void forward(PassType passType);
void backward(const UpdateCallback& callback = nullptr);
};
} // namespace paddle
......@@ -14,12 +14,15 @@ limitations under the License. */
#include "paddle/utils/Logging.h"
#include "ConvBaseLayer.h"
#include "paddle/math/MathUtils.h"
namespace paddle {
bool ConvBaseLayer::init(const LayerMap& layerMap,
const ParameterMap& parameterMap) {
/* Initialize the basic parent class */
Layer::init(layerMap, parameterMap);
isDeconv_ = (config_.type() == "exconv" || config_.type() == "cudnn_conv")
? false : true;
/* Initialize the convolutional layer parameter */
numFilters_ = config_.num_filters();
......@@ -42,8 +45,20 @@ bool ConvBaseLayer::init(const LayerMap& layerMap,
outputW_.push_back(conf.output_x());
}
CHECK(inputLayers_.size() == parameters_.size());
for (size_t i = 0; i < inputLayers_.size(); i++) {
size_t height, width;
height = filterPixels_[i] * filterChannels_[i];
width = (!isDeconv_) ? numFilters_ : channels_[i];
// create a new weight
CHECK_EQ(parameters_[i]->getSize(), width * height);
Weight* w = new Weight(height, width, parameters_[i]);
weights_.emplace_back(w);
}
/* initialize the biases_ */
if (biasParameter_.get() != NULL) {
if (biasParameter_.get()) {
if (sharedBiases_) {
CHECK_EQ((size_t)numFilters_, biasParameter_->getSize());
biases_ =
......@@ -70,23 +85,48 @@ size_t ConvBaseLayer::calOutputSize() {
clearAndReserve(&outputH_);
clearAndReserve(&outputW_);
size_t layerSize = 0;
for (size_t i = 0; i < inputLayers_.size(); i++) {
imgSizeH_.push_back(inputLayers_[i]->getOutput().getFrameHeight());
imgSizeW_.push_back(inputLayers_[i]->getOutput().getFrameWidth());
if (imgSizeH_[i] == 0)
imgSizeH_[i] = config_.inputs(i).conv_conf().img_size();
if (imgSizeW_[i] == 0)
imgSizeW_[i] = config_.inputs(i).conv_conf().img_size();
outputH_.push_back(outputSize(imgSizeH_[i], filterSizeY_[i], paddingY_[i],
strideY_[i], caffeMode_));
outputW_.push_back(outputSize(imgSizeW_[i], filterSize_[i], padding_[i],
stride_[i], caffeMode_));
CHECK_EQ(outputH_[i], outputH_[0]);
CHECK_EQ(outputW_[i], outputW_[0]);
auto setLayerSize = [&](IntV& inH, IntV& inW, IntV& outH, IntV& outW) {
for (size_t i = 0; i < inputLayers_.size(); i++) {
inH.push_back(inputLayers_[i]->getOutput().getFrameHeight());
inW.push_back(inputLayers_[i]->getOutput().getFrameWidth());
if (isDeconv_) {
if (inH[i] == 0)
inH[i] = config_.inputs(i).conv_conf().output_x();
if (inW[i] == 0)
inW[i] = config_.inputs(i).conv_conf().output_x();
outH.push_back(
imageSize(inH[i], filterSizeY_[i], paddingY_[i], strideY_[i],
caffeMode_));
outW.push_back(
imageSize(inW[i], filterSize_[i], padding_[i], stride_[i],
caffeMode_));
} else {
if (inH[i] == 0)
inH[i] = config_.inputs(i).conv_conf().img_size();
if (inW[i] == 0)
inW[i] = config_.inputs(i).conv_conf().img_size();
outH.push_back(
outputSize(inH[i], filterSizeY_[i], paddingY_[i], strideY_[i],
caffeMode_));
outW.push_back(
outputSize(inW[i], filterSize_[i], padding_[i], stride_[i],
caffeMode_));
}
CHECK_EQ(outH[i], outH[0]);
CHECK_EQ(outW[i], outW[0]);
}
getOutput().setFrameHeight(outH[0]);
getOutput().setFrameWidth(outW[0]);
layerSize = outH[0] * outW[0] * size_t(numFilters_);
};
if (isDeconv_) {
setLayerSize(outputH_, outputW_, imgSizeH_, imgSizeW_);
} else {
setLayerSize(imgSizeH_, imgSizeW_, outputH_, outputW_);
}
getOutput().setFrameHeight(outputH_[0]);
getOutput().setFrameWidth(outputW_[0]);
layerSize = outputH_[0] * outputW_[0] * size_t(numFilters_);
return layerSize;
}
......
......@@ -28,6 +28,9 @@ class ConvBaseLayer : public Layer {
protected:
typedef std::vector<int> IntV;
/// True if it's deconv layer, false if it's convolution layer
bool isDeconv_;
/// The number of filters.
int numFilters_;
/// The x dimension of the padding.
......
/* Copyright (c) 2016 Baidu, Inc. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "ExpandConvBaseLayer.h"
#include "paddle/utils/Logging.h"
namespace paddle {
bool ExpandConvBaseLayer::init(const LayerMap &layerMap,
const ParameterMap &parameterMap) {
/* Initialize the basic convolutional parent class */
ConvBaseLayer::init(layerMap, parameterMap);
/* The class fields channels_ and numFilters_ are the same as in the config
* i.e., channels_ is the for the input and numFilters_ is for the output
*
* But in order for the variables in convTrans having the same semantic
* meaning as in conv, we need to swap channels_ and numFilters here for
* convTrans, and in other functions too.
* */
int channel;
int numFilters;
/* Initialize the projection */
for (auto &inputConfig : config_.inputs()) {
const ConvConfig &conf = inputConfig.conv_conf();
numFilters = isDeconv_ ? conf.channels() : numFilters_;
subM_.push_back(numFilters / conf.groups());
subN_.push_back(conf.output_x() * conf.output_x());
channel = isDeconv_ ? numFilters_ : conf.channels();
subK_.push_back(channel * conf.filter_size() * conf.filter_size() /
conf.groups());
/* Consistent caffe mode for multiple input */
caffeMode_ = conf.caffe_mode();
}
getOutputSize();
return true;
}
size_t ExpandConvBaseLayer::getOutputSize() {
CHECK_NE(inputLayers_.size(), 0UL);
size_t layerSize = ConvBaseLayer::calOutputSize();
subN_.clear();
for (size_t i = 0; i < inputLayers_.size(); i++) {
subN_.push_back(outputH_[i] * outputW_[i]);
}
return layerSize;
}
void ExpandConvBaseLayer::resetExpandInput(size_t height, size_t width) {
Matrix::resizeOrCreate(expandInput_, height, width, false, useGpu_);
}
void ExpandConvBaseLayer::addSharedBias() {
size_t mapW = getOutputSize() / numFilters_;
size_t mapH = getOutputValue()->getElementCnt() / mapW;
MatrixPtr out =
Matrix::create(getOutputValue()->getData(), mapH, mapW, false, useGpu_);
Matrix::resizeOrCreate(transOutValue_, mapW, mapH, false, useGpu_);
out->transpose(transOutValue_, false); // false means no memory allocation
transOutValue_->reshape(transOutValue_->getElementCnt() / numFilters_,
numFilters_);
MatrixPtr bias =
Matrix::create(biases_->getW()->getData(), 1,
biases_->getW()->getElementCnt(), false, useGpu_);
transOutValue_->addBias(*bias, 1.0f);
transOutValue_->reshape(mapW, mapH);
transOutValue_->transpose(out, false); // false means no memory allocation
out->clear();
bias->clear();
}
void ExpandConvBaseLayer::addUnsharedBias() {
MatrixPtr outValue = getOutputValue();
MatrixPtr bias =
Matrix::create(biases_->getW()->getData(), 1,
biases_->getW()->getElementCnt(), false, useGpu_);
outValue->addBias(*bias, 1.0f);
}
void ExpandConvBaseLayer::expandOneFrame(MatrixPtr image, size_t startIdx,
int inIdx) {
int channel = isDeconv_ ? numFilters_ : channels_[inIdx];
resetExpandInput(subK_[inIdx] * groups_[inIdx], subN_[inIdx]);
real *imgData = image->getData() + startIdx * image->getWidth();
MatrixPtr imageTmp = Matrix::create(
imgData, 1, imgSizeH_[inIdx] * imgSizeW_[inIdx] * channel, false,
useGpu_);
expandInput_->convExpand(*imageTmp, imgSizeH_[inIdx], imgSizeW_[inIdx],
channel, filterSize_[inIdx],
filterSize_[inIdx], stride_[inIdx], stride_[inIdx],
padding_[inIdx], padding_[inIdx],
outputH_[inIdx], outputW_[inIdx]);
imageTmp->clear();
}
void ExpandConvBaseLayer::expandFwdOnce(MatrixPtr image, MatrixPtr out,
int inIdx, int startIdx) {
int subM = subM_[inIdx];
int subN = subN_[inIdx];
int subK = subK_[inIdx];
expandOneFrame(image, startIdx, inIdx);
int numFilters = isDeconv_ ? channels_[inIdx] : numFilters_;
real *outData =
out->getData() + startIdx * subN * numFilters;
real *wgtData = weights_[inIdx]->getW()->getData();
real *expInData = expandInput_->getData();
for (int g = 0; g < groups_[inIdx]; ++g) {
MatrixPtr A =
Matrix::create(wgtData, subK, subM, true, useGpu_); // mark transpose
MatrixPtr B = Matrix::create(expInData, subK, subN, false, useGpu_);
MatrixPtr C = Matrix::create(outData, subM, subN, false, useGpu_);
C->mul(A, B, 1, 1);
A->clear();
B->clear();
C->clear();
wgtData += subK * subM;
expInData += subK * subN;
outData += subM * subN;
}
}
void ExpandConvBaseLayer::bpropActs(MatrixPtr out, MatrixPtr image,
int inpIdx) {
int channel = isDeconv_ ? numFilters_ : channels_[inpIdx];
int subM = subM_[inpIdx];
int subN = subN_[inpIdx];
int subK = subK_[inpIdx];
size_t batchSize = image->getHeight();
/* reset the expand-grad memory */
resetExpandInput(subK * groups_[inpIdx], subN);
real *localGradData = out->getData();
real *tgtGradData = image->getData();
for (size_t n = 0; n < batchSize; n++) {
real *wgtData = weights_[inpIdx]->getW()->getData();
real *expandInData = expandInput_->getData();
for (int g = 0; g < groups_[inpIdx]; g++) {
// create temporary matrix
MatrixPtr C = Matrix::create(expandInData, subK, subN, false, useGpu_);
MatrixPtr B = Matrix::create(localGradData, subM, subN, false, useGpu_);
MatrixPtr A = Matrix::create(wgtData, subK, subM, false, useGpu_);
C->mul(A, B); // mul
// clear the temporary matrix
A->clear();
B->clear();
C->clear();
expandInData += subK * subN;
localGradData += subM * subN;
wgtData += subK * subM;
}
// shrink one frame outGrad
MatrixPtr oneGradTmp = Matrix::create(
expandInput_->getData(), subK * groups_[inpIdx], subN, false, useGpu_);
MatrixPtr vTmp = Matrix::create(
tgtGradData, 1,
imgSizeH_[inpIdx] * imgSizeW_[inpIdx] * channel, false,
useGpu_);
vTmp->convShrink(*oneGradTmp, imgSizeH_[inpIdx], imgSizeW_[inpIdx],
channel, filterSize_[inpIdx],
filterSize_[inpIdx], stride_[inpIdx], stride_[inpIdx],
padding_[inpIdx], padding_[inpIdx],
outputH_[inpIdx], outputW_[inpIdx], 1.0f, 1.0f);
vTmp->clear();
oneGradTmp->clear();
// move the data-pointer
tgtGradData += imgSizeH_[inpIdx] * imgSizeW_[inpIdx] * channel;
}
}
void ExpandConvBaseLayer::bpropWeights(MatrixPtr image, MatrixPtr out,
int inpIdx) {
MatrixPtr weightGrad = weights_[inpIdx]->getWGrad();
int subM = subM_[inpIdx];
int subN = subN_[inpIdx];
int subK = subK_[inpIdx];
size_t batchSize = image->getHeight();
resetExpandInput(subK * groups_[inpIdx], subN);
real *gradData = out->getData();
for (size_t n = 0; n < batchSize; n++) { // frame by frame
// expand
expandOneFrame(image, n, inpIdx);
real *wGradData = weightGrad->getData();
real *expandInData = expandInput_->getData();
// expand-mul one-group by one
for (int g = 0; g < groups_[inpIdx]; g++) {
MatrixPtr A = Matrix::create(expandInData, subK, subN, false, useGpu_);
MatrixPtr B = Matrix::create(gradData, subM, subN, true, useGpu_);
MatrixPtr C = Matrix::create(wGradData, subK, subM, false, useGpu_);
C->mul(A, B, 1, 1);
A->clear();
B->clear();
C->clear();
gradData += subM * subN;
wGradData += subK * subM;
expandInData += subK * subN;
}
}
}
void ExpandConvBaseLayer::bpropSharedBias(MatrixPtr biases, MatrixPtr v) {
size_t mapW = getOutputSize() / numFilters_;
size_t mapH = v->getElementCnt() / mapW;
MatrixPtr vTmp = Matrix::create(v->getData(), mapH, mapW, false, useGpu_);
Matrix::resizeOrCreate(transOutValue_, mapW, mapH, false, useGpu_);
vTmp->transpose(transOutValue_, false); // false means no memory allocation
transOutValue_->reshape(transOutValue_->getElementCnt() / numFilters_,
numFilters_);
biases->collectBias(*transOutValue_, 1.0f);
}
void ExpandConvBaseLayer::bpropBiases(MatrixPtr v) {
MatrixPtr biases =
Matrix::create(biases_->getWGrad()->getData(), 1,
biases_->getWGrad()->getElementCnt(), false, useGpu_);
if (sharedBiases_) {
bpropSharedBias(biases, v);
} else {
biases->collectBias(*v, 1.0f);
}
biases->clear();
}
} // namespace paddle
/* Copyright (c) 2016 Baidu, Inc. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "ConvBaseLayer.h"
#include "paddle/math/Matrix.h"
#include <vector>
namespace paddle {
/**
* @brief A subclass of ConvBaseLayer that is a superclass of both
* ExpandConvLayer and ExpandConvTransLayer
*/
class ExpandConvBaseLayer : public ConvBaseLayer {
protected:
/// For expand convolution.
/// subM_ = numFilters_ / groups_.
IntV subM_;
/// subN_ = outputH_ * outputW_.
IntV subN_;
/// subK_ = channels_ * filterPixels_ * groups_.
IntV subK_;
/*The expandInput_ and transOutValue_ are used for CPU expand conv calc
* Expand one sample at a time. shape:
* (numChannels * filterPixels_, outputSizeH * outputSizeW)
* */
MatrixPtr expandInput_;
/// The transpose of output, which is an auxiliary matrix.
MatrixPtr transOutValue_;
public:
explicit ExpandConvBaseLayer(const LayerConfig& config)
: ConvBaseLayer(config) {}
~ExpandConvBaseLayer() {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap);
size_t getOutputSize();
/**
* Create or resize expandInput_.
*/
void resetExpandInput(size_t height, size_t width);
/**
* Add shared bias.
*/
void addSharedBias();
/**
* Add unshared bias.
*/
void addUnsharedBias();
/**
* Expand one input sample.
*/
void expandOneFrame(MatrixPtr image, size_t startIdx, int inIdx);
/**
* Expand one input sample and perform matrix multiplication.
*/
void expandFwdOnce(MatrixPtr image, MatrixPtr out, int inIdx, int startIdx);
void bpropSharedBias(MatrixPtr biases, MatrixPtr v);
void bpropBiases(MatrixPtr v);
void bpropWeights(MatrixPtr image, MatrixPtr out, int inpIdx);
void bpropActs(MatrixPtr image, MatrixPtr out, int inpIdx);
};
} // namespace paddle
......@@ -24,150 +24,29 @@ REGISTER_LAYER(exconv, ExpandConvLayer);
bool ExpandConvLayer::init(const LayerMap &layerMap,
const ParameterMap &parameterMap) {
/* Initialize the basic convolutional parent class */
ConvBaseLayer::init(layerMap, parameterMap);
/* Initialize the projection */
for (auto &inputConfig : config_.inputs()) {
const ConvConfig &conf = inputConfig.conv_conf();
subM_.push_back(numFilters_ / conf.groups());
subN_.push_back(conf.output_x() * conf.output_x());
subK_.push_back(conf.channels() * conf.filter_size() * conf.filter_size() /
conf.groups());
/* Consistent caffe mode for multiple input */
caffeMode_ = conf.caffe_mode();
}
/* initialize the weightList */
CHECK(inputLayers_.size() == parameters_.size());
for (size_t i = 0; i < inputLayers_.size(); i++) {
size_t height, width;
height = filterPixels_[i] * filterChannels_[i];
width = numFilters_;
// create a new weight
CHECK_EQ(parameters_[i]->getSize(), width * height);
Weight* w = new Weight(height, width, parameters_[i]);
weights_.emplace_back(w);
}
ExpandConvBaseLayer::init(layerMap, parameterMap);
return true;
}
size_t ExpandConvLayer::getOutputSize() {
CHECK_NE(inputLayers_.size(), 0UL);
size_t layerSize = ConvBaseLayer::calOutputSize();
subN_.clear();
for (size_t i = 0; i < inputLayers_.size(); i++) {
subN_.push_back(outputH_[i] * outputW_[i]);
}
return layerSize;
}
void ExpandConvLayer::resetExpandInput(size_t height, size_t width) {
Matrix::resizeOrCreate(expandInput_, height, width, false, useGpu_);
}
void ExpandConvLayer::resetConvOutput(size_t batchSize, int inIdx) {
Matrix::resizeOrCreate(transOutValue_, batchSize * numFilters_, subN_[inIdx],
false, useGpu_);
}
void ExpandConvLayer::expandOneFrame(MatrixPtr image, size_t startIdx,
int inIdx) {
resetExpandInput(subK_[inIdx] * groups_[inIdx], subN_[inIdx]);
real *imgData = image->getData() + startIdx * image->getWidth();
MatrixPtr imageTmp = Matrix::create(
imgData, 1, imgSizeH_[inIdx] * imgSizeW_[inIdx] * channels_[inIdx], false,
useGpu_);
expandInput_->convExpand(*imageTmp, imgSizeH_[inIdx], imgSizeW_[inIdx],
channels_[inIdx], filterSize_[inIdx],
filterSize_[inIdx], stride_[inIdx], stride_[inIdx],
padding_[inIdx], padding_[inIdx],
outputH_[inIdx], outputW_[inIdx]);
imageTmp->clear();
}
void ExpandConvLayer::expandFwdOnce(MatrixPtr image, int inIdx, int startIdx) {
int subM = subM_[inIdx];
int subN = subN_[inIdx];
int subK = subK_[inIdx];
expandOneFrame(image, startIdx, inIdx);
real *outData =
getOutputValue()->getData() + startIdx * subN * numFilters_;
real *wgtData = weights_[inIdx]->getW()->getData();
real *expInData = expandInput_->getData();
for (int g = 0; g < groups_[inIdx]; ++g) {
MatrixPtr A =
Matrix::create(wgtData, subK, subM, true, useGpu_); // mark transpose
MatrixPtr B = Matrix::create(expInData, subK, subN, false, useGpu_);
MatrixPtr C = Matrix::create(outData, subM, subN, false, useGpu_);
C->mul(A, B, 1, 1);
A->clear();
B->clear();
C->clear();
wgtData += subK * subM;
expInData += subK * subN;
outData += subM * subN;
}
}
void ExpandConvLayer::addSharedBias() {
size_t mapW = getOutputValue()->getWidth() / numFilters_;
size_t mapH = getOutputValue()->getElementCnt() / mapW;
MatrixPtr out =
Matrix::create(getOutputValue()->getData(), mapH, mapW, false, useGpu_);
Matrix::resizeOrCreate(transOutValue_, mapW, mapH, false, useGpu_);
out->transpose(transOutValue_, false); // false means no memory allocation
transOutValue_->reshape(transOutValue_->getElementCnt() / numFilters_,
numFilters_);
MatrixPtr bias =
Matrix::create(biases_->getW()->getData(), 1,
biases_->getW()->getElementCnt(), false, useGpu_);
transOutValue_->addBias(*bias, 1.0f);
transOutValue_->reshape(mapW, mapH);
transOutValue_->transpose(out, false); // false means no memory allocation
out->clear();
bias->clear();
}
void ExpandConvLayer::addUnsharedBias() {
MatrixPtr outValue = getOutputValue();
MatrixPtr bias =
Matrix::create(biases_->getW()->getData(), 1,
biases_->getW()->getElementCnt(), false, useGpu_);
outValue->addBias(*bias, 1.0f);
}
void ExpandConvLayer::forward(PassType passType) {
Layer::forward(passType);
/* malloc memory for the output_ if necessary */
/* note: one sample correspond to one colum, and the
* transOutValue correspond sample to one row */
int batchSize = inputLayers_[0]->getOutputValue()->getWidth();
batchSize = inputLayers_[0]->getOutputValue()->getHeight();
int batchSize = inputLayers_[0]->getOutputValue()->getHeight();
resetOutput(batchSize, getOutputSize());
MatrixPtr image = nullptr;
for (size_t i = 0; i != inputLayers_.size(); ++i) {
MatrixPtr outV = getOutputValue();
for (size_t i = 0; i < inputLayers_.size(); ++i) {
LayerPtr prevLayer = getPrev(i);
image = prevLayer->getOutputValue();
for (size_t off = 0; off < image->getHeight(); off++) {
REGISTER_TIMER_INFO("expandFwdOnce", getName().c_str());
expandFwdOnce(image, i, off);
expandFwdOnce(image, outV, i, off);
}
}
/* add the bias-vector */
if (biases_.get() != NULL) {
if (biases_.get()) {
if (sharedBiases_) {
addSharedBias();
} else {
......@@ -179,29 +58,6 @@ void ExpandConvLayer::forward(PassType passType) {
forwardActivation();
}
void ExpandConvLayer::bpropSharedBias(MatrixPtr biases, MatrixPtr v) {
size_t mapW = v->getWidth() / numFilters_;
size_t mapH = v->getElementCnt() / mapW;
MatrixPtr vTmp = Matrix::create(v->getData(), mapH, mapW, false, useGpu_);
Matrix::resizeOrCreate(transOutValue_, mapW, mapH, false, useGpu_);
vTmp->transpose(transOutValue_, false); // false means no memory allocation
vTmp->reshape(transOutValue_->getElementCnt() / numFilters_, numFilters_);
biases->collectBias(*vTmp, 1.0f);
}
void ExpandConvLayer::bpropBiases(MatrixPtr v) {
MatrixPtr biases =
Matrix::create(biases_->getWGrad()->getData(), 1,
biases_->getWGrad()->getElementCnt(), false, useGpu_);
if (sharedBiases_) {
bpropSharedBias(biases, v);
} else {
biases->collectBias(*v, 1.0f);
}
biases->clear();
}
void ExpandConvLayer::backward(const UpdateCallback &callback) {
backwardActivation();
......@@ -213,111 +69,18 @@ void ExpandConvLayer::backward(const UpdateCallback &callback) {
biases_->getParameterPtr()->incUpdate(callback);
}
for (size_t i = 0; i != inputLayers_.size(); ++i) {
for (size_t i = 0; i < inputLayers_.size(); ++i) {
/* First, calculate the input layers error */
bpropActs(outGrad, i);
if (getPrev(i)->getOutputGrad()) {
bpropActs(outGrad, getPrev(i)->getOutputGrad(), i);
}
if (weights_[i]->getWGrad()) {
/* Then, calculate the W-gradient for the current layer */
bpropWeights(outGrad, i);
bpropWeights(getPrev(i)->getOutputValue(), outGrad, i);
/* Increasing the number of gradient */
weights_[i]->getParameterPtr()->incUpdate(callback);
}
}
}
void ExpandConvLayer::bpropWeights(MatrixPtr v, int inpIdx) {
MatrixPtr weightGrad = weights_[inpIdx]->getWGrad();
MatrixPtr inputV = getPrev(inpIdx)->getOutputValue();
int subM = subM_[inpIdx];
int subN = subN_[inpIdx];
int subK = subK_[inpIdx];
size_t batchSize = inputV->getHeight();
resetExpandInput(subK * groups_[inpIdx], subN);
resetConvOutput(batchSize, inpIdx);
real *gradData = v->getData();
for (size_t n = 0; n < batchSize; n++) { // frame by frame
// expand
expandOneFrame(inputV, n, inpIdx);
real *wGradData = weightGrad->getData();
real *expandInData = expandInput_->getData();
// expand-mul one-group by one
for (int g = 0; g < groups_[inpIdx]; g++) {
MatrixPtr A = Matrix::create(expandInData, subK, subN, false, useGpu_);
MatrixPtr B = Matrix::create(gradData, subM, subN, true, useGpu_);
MatrixPtr C = Matrix::create(wGradData, subK, subM, false, useGpu_);
C->mul(A, B, 1, 1);
A->clear();
B->clear();
C->clear();
gradData += subM * subN;
wGradData += subK * subM;
expandInData += subK * subN;
}
}
}
void ExpandConvLayer::bpropActs(MatrixPtr v, int inpIdx) {
LayerPtr prevLayer = getPrev(inpIdx);
if (NULL == prevLayer->getOutputGrad()) {
return;
}
int subM = subM_[inpIdx];
int subN = subN_[inpIdx];
int subK = subK_[inpIdx];
size_t batchSize = v->getHeight();
MatrixPtr tgtGrad = prevLayer->getOutputGrad();
/* reset the expand-grad memory */
resetExpandInput(subK * groups_[inpIdx], subN);
resetConvOutput(batchSize, inpIdx);
real *localGradData = v->getData();
real *tgtGradData = tgtGrad->getData();
for (size_t n = 0; n < batchSize; n++) {
real *wgtData = weights_[inpIdx]->getW()->getData();
real *expandInData = expandInput_->getData();
for (int g = 0; g < groups_[inpIdx]; g++) {
// create temporary matrix
MatrixPtr C = Matrix::create(expandInData, subK, subN, false, useGpu_);
MatrixPtr B = Matrix::create(localGradData, subM, subN, false, useGpu_);
MatrixPtr A = Matrix::create(wgtData, subK, subM, false, useGpu_);
C->mul(A, B); // mul
// clear the temporary matrix
A->clear();
B->clear();
C->clear();
expandInData += subK * subN;
localGradData += subM * subN;
wgtData += subK * subM;
}
// shrink one frame outGrad
MatrixPtr oneGradTmp = Matrix::create(
expandInput_->getData(), subK * groups_[inpIdx], subN, false, useGpu_);
MatrixPtr vTmp = Matrix::create(
tgtGradData, 1,
imgSizeH_[inpIdx] * imgSizeW_[inpIdx] * channels_[inpIdx], false,
useGpu_);
vTmp->convShrink(*oneGradTmp, imgSizeH_[inpIdx], imgSizeW_[inpIdx],
channels_[inpIdx], filterSize_[inpIdx],
filterSize_[inpIdx], stride_[inpIdx], stride_[inpIdx],
padding_[inpIdx], padding_[inpIdx],
outputH_[inpIdx], outputW_[inpIdx], 1.0f, 1.0f);
vTmp->clear();
oneGradTmp->clear();
// move the data-pointer
tgtGradData += imgSizeH_[inpIdx] * imgSizeW_[inpIdx] * channels_[inpIdx];
}
}
} // namespace paddle
......@@ -15,9 +15,9 @@ limitations under the License. */
#pragma once
#include "ConvBaseLayer.h"
#include "paddle/math/Matrix.h"
#include <vector>
#include "ExpandConvBaseLayer.h"
namespace paddle {
......@@ -28,65 +28,18 @@ namespace paddle {
*
* The config file api is img_conv_layer.
*/
class ExpandConvLayer : public ConvBaseLayer {
protected:
/// For expand convolution.
/// subM_ = numFilters_ / groups_.
IntV subM_;
/// subN_ = outputH_ * outputW_.
IntV subN_;
/// subK_ = channels_ * filterPixels_ * groups_.
IntV subK_;
/// Expand one sample at a time. shape:
/// (numChannels * filterPixels_, outputSizeH * outputSizeW)
MatrixPtr expandInput_;
/// The transpose of output, which is an auxiliary matrix.
MatrixPtr transOutValue_;
class ExpandConvLayer : public ExpandConvBaseLayer {
public:
explicit ExpandConvLayer(const LayerConfig& config) : ConvBaseLayer(config) {}
explicit ExpandConvLayer(const LayerConfig& config) :
ExpandConvBaseLayer(config) {}
~ExpandConvLayer() {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap);
size_t getOutputSize();
/**
* Create or resize expandInput_.
*/
void resetExpandInput(size_t height, size_t width);
/**
* Create or resize transOutValue_.
*/
void resetConvOutput(size_t batchSize, int inIdx);
/**
* Expand one input sample.
*/
void expandOneFrame(MatrixPtr image, size_t startIdx, int inIdx);
/**
* Expand one input sample and perform matrix multiplication.
*/
void expandFwdOnce(MatrixPtr image, int inIdx, int startIdx);
/**
* Add shared bias.
*/
void addSharedBias();
/**
* Add unshared bias.
*/
void addUnsharedBias();
void forward(PassType passType);
void bpropSharedBias(MatrixPtr biases, MatrixPtr v);
void bpropBiases(MatrixPtr v);
void backward(const UpdateCallback& callback);
void bpropWeights(MatrixPtr v, int inpIdx);
void bpropActs(MatrixPtr v, int inpIdx);
};
} // namespace paddle
/* Copyright (c) 2016 Baidu, Inc. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/utils/Logging.h"
#include "paddle/utils/Stat.h"
#include "ExpandConvTransLayer.h"
/* The implementation of the convTransLayer is basically a swap of forward and
* backward of the original convLayer.
* The variable naming follows the convention of the convLayer.
* */
namespace paddle {
REGISTER_LAYER(exconvt, ExpandConvTransLayer);
bool ExpandConvTransLayer::init(const LayerMap &layerMap,
const ParameterMap &parameterMap) {
/* Initialize the basic convolutional parent class */
ExpandConvBaseLayer::init(layerMap, parameterMap);
return true;
}
void ExpandConvTransLayer::forward(PassType passType) {
Layer::forward(passType);
/* malloc memory for the output_ if necessary */
int batchSize = inputLayers_[0]->getOutputValue()->getHeight();
resetOutput(batchSize, getOutputSize());
MatrixPtr output = nullptr;
for (size_t i = 0; i < inputLayers_.size(); ++i) {
LayerPtr prevLayer = getPrev(i);
output = prevLayer->getOutputValue();
REGISTER_TIMER_INFO("shrinkFwd", getName().c_str());
bpropActs(output, getOutputValue(), i);
}
/* add the bias-vector */
if (biases_.get()) {
if (sharedBiases_) {
addSharedBias();
} else {
addUnsharedBias();
}
}
/* activation */
forwardActivation();
}
void ExpandConvTransLayer::backward(const UpdateCallback &callback) {
backwardActivation();
MatrixPtr imageGrad = getOutputGrad();
if (biases_ && biases_->getWGrad()) {
bpropBiases(imageGrad);
/* Increasing the number of gradient */
biases_->getParameterPtr()->incUpdate(callback);
}
for (size_t i = 0; i < inputLayers_.size(); ++i) {
/* First, calculate the input layers error */
for (size_t off = 0; off < imageGrad->getHeight(); off++) {
if (getPrev(i)->getOutputGrad()) {
expandFwdOnce(imageGrad, getPrev(i)->getOutputGrad(), i, off);
}
}
if (weights_[i]->getWGrad()) {
/* Then, calculate the W-gradient for the current layer */
bpropWeights(imageGrad, getPrev(i)->getOutputValue(), i);
/* Increasing the number of gradient */
weights_[i]->getParameterPtr()->incUpdate(callback);
}
}
}
} // namespace paddle
/* Copyright (c) 2016 Baidu, Inc. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/math/Matrix.h"
#include <vector>
#include "ExpandConvBaseLayer.h"
namespace paddle {
/**
* @brief A subclass of convolution layer.
* This layer expands input and use matrix multiplication to
* calculate convolution transpose (deconv) operation.
*
* The config file api is img_conv_layer with flag trans=True.
*/
class ExpandConvTransLayer : public ExpandConvBaseLayer {
public:
explicit ExpandConvTransLayer(const LayerConfig& config) :
ExpandConvBaseLayer(config) {}
~ExpandConvTransLayer() {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap);
void forward(PassType passType);
void backward(const UpdateCallback& callback);
};
} // namespace paddle
......@@ -26,6 +26,14 @@ add_unittest_without_exec(test_ActivationGrad
TestUtil.cpp)
add_test(NAME test_ActivationGrad
COMMAND test_ActivationGrad)
################# test_ConvTrans #######################
add_unittest_without_exec(test_ConvTrans
test_ConvTrans.cpp
LayerGradUtil.cpp
TestUtil.cpp)
add_test(NAME test_ConvTrans
COMMAND test_ConvTrans)
################## test_Evaluator #######################
add_unittest(test_Evaluator
......
/* Copyright (c) 2016 Baidu, Inc. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gtest/gtest.h>
#include <vector>
#include <string>
#include "paddle/gserver/layers/DataLayer.h"
#include "ModelConfig.pb.h"
#include "paddle/trainer/Trainer.h"
#include "paddle/utils/GlobalConstants.h"
#include "paddle/gserver/layers/ExpandConvTransLayer.h"
#include "paddle/math/MathUtils.h"
#include "TestUtil.h"
#include "LayerGradUtil.h"
using namespace paddle; // NOLINT
using namespace std; // NOLINT
P_DECLARE_bool(use_gpu);
P_DECLARE_int32(gpu_id);
P_DECLARE_double(checkgrad_eps);
P_DECLARE_bool(thread_local_rand_use_global_seed);
P_DECLARE_bool(prev_batch_state);
// Test that the convTrans forward is the same as conv backward
TEST(Layer, convTransLayerFwd) {
// Setting up conv-trans layer
TestConfig configt;
configt.biasSize = 3;
configt.layerConfig.set_type("exconvt");
configt.layerConfig.set_num_filters(3);
configt.layerConfig.set_partial_sum(1);
configt.layerConfig.set_shared_biases(true);
configt.inputDefs.push_back({INPUT_DATA, "layer_0", 1024, 384});
LayerInputConfig* input = configt.layerConfig.add_inputs();
ConvConfig* conv = input->mutable_conv_conf();
conv->set_filter_size(2);
conv->set_filter_size_y(4);
conv->set_channels(16);
conv->set_padding(0);
conv->set_padding_y(1);
conv->set_stride(2);
conv->set_stride_y(2);
conv->set_groups(1);
conv->set_filter_channels(3 / conv->groups());
conv->set_img_size(16);
conv->set_output_x(outputSize(conv->img_size(), conv->filter_size(),
conv->padding(), conv->stride(),
/* caffeMode */ true));
configt.layerConfig.set_size(conv->img_size() * conv->img_size() *
configt.layerConfig.num_filters());
configt.layerConfig.set_name("convTrans");
// data layer initialize
std::vector<DataLayerPtr> dataLayers;
LayerMap layerMap;
vector<Argument> datas;
initDataLayer(configt, &dataLayers, &datas, &layerMap, "convTrans",
100, false, false);
// test layer initialize
std::vector<ParameterPtr> parameters;
LayerPtr convtLayer;
initTestLayer(configt, &layerMap, &parameters, &convtLayer);
convtLayer->getBiasParameter()->zeroMem();
convtLayer->forward(PASS_GC);
// Setting up conv-layer config
TestConfig config;
config.biasSize = 16;
config.layerConfig.set_type("exconv");
config.layerConfig.set_num_filters(16);
config.layerConfig.set_partial_sum(1);
config.layerConfig.set_shared_biases(true);
config.inputDefs.push_back({INPUT_DATA, "layer_1", 768, 384});
input = config.layerConfig.add_inputs();
conv = input->mutable_conv_conf();
conv->set_filter_size(2);
conv->set_filter_size_y(4);
conv->set_channels(3);
conv->set_padding(0);
conv->set_padding_y(1);
conv->set_stride(2);
conv->set_stride_y(2);
conv->set_groups(1);
conv->set_filter_channels(conv->channels() / conv->groups());
conv->set_img_size(16);
conv->set_output_x(outputSize(conv->img_size(), conv->filter_size(),
conv->padding(), conv->stride(),
/* caffeMode */ true));
config.layerConfig.set_size(conv->output_x() * conv->output_x() *
config.layerConfig.num_filters());
config.layerConfig.set_name("conv");
// data layer initialize
std::vector<DataLayerPtr> dataLayers2;
LayerMap layerMap2;
vector<Argument> datas2;
initDataLayer(config, &dataLayers2, &datas2, &layerMap2, "conv",
100, false, false);
// test layer initialize
std::vector<ParameterPtr> parameters2;
LayerPtr convLayer;
initTestLayer(config, &layerMap2, &parameters2, &convLayer);
// Sync convLayer and convtLayer parameter
convLayer->getBiasParameter()->zeroMem();
convLayer->getParameters()[0]->getBuf(PARAMETER_VALUE)->copyFrom(
*(convtLayer->getParameters()[0]->getBuf(PARAMETER_VALUE)));
// Set convLayer outputGrad as convTransLayer input value
convLayer->forward(PASS_GC);
convLayer->getOutput().grad->copyFrom(*(dataLayers[0]->getOutputValue()));
vector<int> callbackFlags(parameters2.size(), 0);
auto callback = [&](Parameter* para) { ++callbackFlags[para->getID()]; };
convLayer->backward(callback);
// Check that the convLayer backward is the same as convTransLayer forward
checkMatrixEqual(convtLayer->getOutputValue(),
dataLayers2[0]->getOutputGrad());
}
// Do one forward pass of convTrans layer and check to see if its output
// matches the given result
void doOneConvtTest(size_t imgSize, size_t output_x, size_t stride,
size_t padding, size_t filter_size, MatrixPtr& result) {
TestConfig configt;
configt.biasSize = 1;
configt.layerConfig.set_type("exconvt");
configt.layerConfig.set_num_filters(1);
configt.layerConfig.set_partial_sum(1);
configt.layerConfig.set_shared_biases(true);
configt.inputDefs.push_back({INPUT_DATA, "layer_0", output_x * output_x,
filter_size * filter_size});
LayerInputConfig* input = configt.layerConfig.add_inputs();
ConvConfig* conv = input->mutable_conv_conf();
conv->set_filter_size(filter_size);
conv->set_filter_size_y(filter_size);
conv->set_channels(1);
conv->set_padding(padding);
conv->set_padding_y(padding);
conv->set_stride(stride);
conv->set_stride_y(stride);
conv->set_groups(1);
conv->set_filter_channels(1);
conv->set_img_size(imgSize);
conv->set_output_x(output_x);
configt.layerConfig.set_size(conv->img_size() * conv->img_size() *
configt.layerConfig.num_filters());
configt.layerConfig.set_name("convTrans");
std::vector<DataLayerPtr> dataLayers;
LayerMap layerMap;
vector<Argument> datas;
initDataLayer(configt, &dataLayers, &datas, &layerMap, "convTrans",
1, false, false);
dataLayers[0]->getOutputValue()->zeroMem();
dataLayers[0]->getOutputValue()->add(1.0);
// test layer initialize
std::vector<ParameterPtr> parameters;
LayerPtr convtLayer;
initTestLayer(configt, &layerMap, &parameters, &convtLayer);
convtLayer->getBiasParameter()->zeroMem();
convtLayer->getParameters()[0]->zeroMem();
convtLayer->getParameters()[0]->getBuf(PARAMETER_VALUE)->add(1.0);
convtLayer->forward(PASS_GC);
checkMatrixEqual(convtLayer->getOutputValue(), result);
}
TEST(Layer, convTransLayerFwd2) {
MatrixPtr result;
result = Matrix::create(1, 5 * 5, false, false);
result->zeroMem();
result->add(1.0);
doOneConvtTest(/* imgSize */ 5,
/* output_x */ 1,
/* stride */ 1,
/* padding */ 0,
/* filter_size */ 5,
result);
float resultData[] = {1, 2, 2, 2, 1,
2, 4, 4, 4, 2,
2, 4, 4, 4, 2,
2, 4, 4, 4, 2,
1, 2, 2, 2, 1};
result->setData(resultData);
doOneConvtTest(/* imgSize */ 5,
/* output_x */ 2,
/* stride */ 1,
/* padding */ 0,
/* filter_size */ 4,
result);
float resultData2[] = {1, 2, 2, 2, 1,
2, 4, 4, 4, 2,
2, 4, 4, 4, 2,
2, 4, 4, 4, 2,
1, 2, 2, 2, 1};
result->setData(resultData2);
doOneConvtTest(/* imgSize */ 5,
/* output_x */ 2,
/* stride */ 2,
/* padding */ 1,
/* filter_size */ 5,
result);
float resultData3[] = {1, 1, 2, 1, 1,
1, 1, 2, 1, 1,
2, 2, 4, 2, 2,
1, 1, 2, 1, 1,
1, 1, 2, 1, 1};
result->setData(resultData3);
doOneConvtTest(/* imgSize */ 5,
/* output_x */ 2,
/* stride */ 2,
/* padding */ 0,
/* filter_size */ 3,
result);}
int main(int argc, char** argv) {
testing::InitGoogleTest(&argc, argv);
initMain(argc, argv);
FLAGS_thread_local_rand_use_global_seed = true;
srand(1);
return RUN_ALL_TESTS();
}
......@@ -175,6 +175,27 @@ TEST(Projection, conv) {
}
#endif
TEST(Layer, BilinearInterpLayer) {
TestConfig config;
config.layerConfig.set_type("bilinear_interp");
config.biasSize = 0;
config.inputDefs.push_back({INPUT_DATA, "layer_0", 4096, 0});
LayerInputConfig* input = config.layerConfig.add_inputs();
BilinearInterpConfig* bilinear = input->mutable_bilinear_interp_conf();
bilinear->set_img_size_x(32);
bilinear->set_img_size_y(32);
bilinear->set_num_channels(4);
for (auto useGpu : {false, true}) {
for (auto outSize : {32, 64}) {
bilinear->set_out_size_x(outSize);
bilinear->set_out_size_y(outSize);
testLayerGrad(config, "bilinear_interp", 10, false, useGpu);
}
}
}
TEST(Layer, concat) {
TestConfig config;
config.biasSize = 0;
......@@ -302,6 +323,8 @@ void testConvLayer(const string& type, bool trans, bool useGpu) {
config.layerConfig.num_filters());
testLayerGrad(config, "conv", 100, trans, useGpu);
// Use small batch_size and useWeight=true to test biasGrad
testLayerGrad(config, "conv", 2, trans, useGpu, true, 0.02);
}
TEST(Layer, convLayer) {
......@@ -312,6 +335,46 @@ TEST(Layer, convLayer) {
#endif
}
void testConvTransLayer(const string& type, bool trans, bool useGpu) {
TestConfig config;
config.biasSize = 3;
config.layerConfig.set_type(type);
config.layerConfig.set_num_filters(3);
config.layerConfig.set_partial_sum(1);
config.layerConfig.set_shared_biases(true);
config.inputDefs.push_back({INPUT_DATA, "layer_0", 1024, 288});
LayerInputConfig* input = config.layerConfig.add_inputs();
ConvConfig* conv = input->mutable_conv_conf();
conv->set_filter_size(2);
conv->set_filter_size_y(3);
conv->set_channels(16);
conv->set_padding(0);
conv->set_padding_y(1);
conv->set_stride(2);
conv->set_stride_y(2);
conv->set_groups(1);
conv->set_filter_channels(3 / conv->groups());
conv->set_img_size(16);
conv->set_output_x(outputSize(conv->img_size(), conv->filter_size(),
conv->padding(), conv->stride(),
/* caffeMode */ true));
config.layerConfig.set_size(conv->img_size() * conv->img_size() *
config.layerConfig.num_filters());
testLayerGrad(config, "convTrans", 100, trans, useGpu);
// Use small batch_size and useWeight=true to test biasGrad
testLayerGrad(config, "convTrans", 2, trans, useGpu, true, 0.02);
}
TEST(Layer, convTransLayer) {
for (auto useGpu : {false, true}) {
testConvTransLayer("exconvt", /* trans= */ false, /* useGpu= */ useGpu);
}
}
TEST(Layer, blockExpandLayer) {
TestConfig config;
config.biasSize = 0;
......
......@@ -80,4 +80,17 @@ int outputSize(int imageSize, int filterSize, int padding, int stride,
return outputSize;
}
int imageSize(int outputSize, int filterSize, int padding, int stride,
bool caffeMode) {
int imageSize;
if (!caffeMode) {
imageSize =
(outputSize - 1) * stride + filterSize - 2 * padding - stride + 1;
} else {
imageSize = (outputSize - 1) * stride + filterSize - 2 * padding;
}
CHECK_GE(imageSize, 1);
return imageSize;
}
} // namespace paddle
......@@ -60,4 +60,11 @@ void sparseRand(int* major, int* minor, int nnz, int majorLen, int minorMax,
int outputSize(int imageSize, int filterSize, int padding, int stride,
bool caffeMode);
/**
* Calculate image size based on output size and caffeMode_.
* It is the reverse function of outputSize()
*/
int imageSize(int outputSize, int filterSize, int padding, int stride,
bool caffeMode);
} // namespace paddle
......@@ -22,6 +22,7 @@ limitations under the License. */
#include <cmath>
#include <string.h>
#include "hl_cnn.h"
#include "hl_gpu.h"
#include "hl_table_apply.h"
#include "hl_top_k.h"
......@@ -1211,6 +1212,62 @@ void GpuMatrix::addColumnVector(const Matrix& b) {
BaseMatrix::addColVector(const_cast<Matrix&>(b));
}
void GpuMatrix::bilinearForward(const Matrix& in,
const size_t inImgH,
const size_t inImgW,
const size_t outImgH,
const size_t outImgW,
const size_t numChannels,
const real ratioH,
const real ratioW) {
CHECK(dynamic_cast<const GpuMatrix*>(&in));
const size_t outputW = getWidth();
const size_t outputH = getHeight();
const size_t inputW = in.getWidth();
const size_t inputH = in.getHeight();
real* outData = getData();
const real* inData = in.getData();
if (inImgH == outImgW && inImgW == outImgW) {
this->copyFrom(in);
} else {
hl_bilinear_forward(
inData, inImgH, inImgW, inputH, inputW, outData,
outImgH, outImgW, outputH, outputW, numChannels,
ratioH, ratioW);
}
}
void GpuMatrix::bilinearBackward(const Matrix& out,
const size_t outImgH,
const size_t outImgW,
const size_t inImgH,
const size_t inImgW,
const size_t numChannels,
const real ratioH,
const real ratioW) {
CHECK(dynamic_cast<const GpuMatrix*>(&out));
const size_t inputW = getWidth();
const size_t inputH = getHeight();
const size_t outputW = out.getWidth();
const size_t outputH = out.getHeight();
real* inGrad = getData();
const real* outGrad = out.getData();
if (outImgH == inImgH && outImgW == inImgW) {
this->add(const_cast<Matrix&>(out));
} else {
hl_bilinear_backward(
inGrad, inImgH, inImgW, inputH, inputW, outGrad,
outImgH, outImgW, outputH, outputW, numChannels,
ratioH, ratioW);
}
}
/**
* CpuMatrix
*/
......@@ -3838,6 +3895,112 @@ void CpuMatrix::classificationErrorMulti(Matrix& output, Matrix& label,
}
}
void CpuMatrix::bilinearForward(const Matrix& in,
const size_t inImgH,
const size_t inImgW,
const size_t outImgH,
const size_t outImgW,
const size_t numChannels,
const real ratioH,
const real ratioW) {
CHECK(dynamic_cast<const CpuMatrix*>(&in));
size_t outputW = getWidth();
size_t batchSize = getHeight();
size_t inputW = in.getWidth();
size_t inputH = in.getHeight();
size_t inPosOffset = inImgH * inImgW;
size_t outPosOffset = outImgH * outImgW;
(void)(inputH);
real* outData = getData();
const real* inData = in.getData();
if (inImgH == outImgH && inImgW == outImgW) {
this->copyFrom(in);
} else {
for (size_t k = 0; k < batchSize; ++k) { // loop for batches
for (size_t i = 0; i < outImgH; ++i) { // loop for images
size_t h = ratioH * i;
size_t hid = (h < inImgH - 1) ? 1 : 0;
real h1lambda = ratioH * i - h;
real h2lambda = 1 - h1lambda;
for (size_t j = 0; j < outImgW; ++j) {
size_t w = ratioW * j;
size_t wid = (w < inImgW - 1) ? 1 : 0;
real w1lambda = ratioW * j - w;
real w2lambda = 1 - w1lambda;
// calculate four position for bilinear interpolation
const real* inPos = &inData[k * inputW + h * inImgW + w];
real* outPos = &outData[k * outputW + i * outImgW + j];
for (size_t c = 0; c < numChannels; ++c) { // loop for channels
// bilinear interpolation
outPos[0] =
h2lambda * (w2lambda * inPos[0] + w1lambda * inPos[wid]) +
h1lambda * (w2lambda * inPos[hid * inImgW] +
w1lambda * inPos[hid * inImgW + wid]);
inPos += inPosOffset;
outPos += outPosOffset;
}
}
}
}
}
}
void CpuMatrix::bilinearBackward(const Matrix& out,
const size_t outImgH,
const size_t outImgW,
const size_t inImgH,
const size_t inImgW,
const size_t numChannels,
const real ratioH,
const real ratioW) {
CHECK(dynamic_cast<const CpuMatrix*>(&out));
size_t inputW = getWidth();
size_t inputH = getHeight();
size_t outputW = out.getWidth();
size_t batchSize = out.getHeight();
size_t inPosOffset = inImgH * inImgW;
size_t outPosOffset = outImgH * outImgW;
(void)(inputH);
real* inGrad = getData();
const real* outGrad = out.getData();
if (inImgH == outImgH && inImgW == outImgW) {
this->add(const_cast<Matrix&>(out));
} else {
for (size_t k = 0; k < batchSize; ++k) { // loop for batches
for (size_t i = 0; i < outImgH; ++i) { // loop for images
size_t h = ratioH * i;
size_t hid = (h < inImgH - 1) ? 1 : 0;
real h1lambda = ratioH * i - h;
real h2lambda = 1 - h1lambda;
for (size_t j = 0; j < outImgW; ++j) {
size_t w = ratioW * j;
size_t wid = (w < inImgW - 1) ? 1 : 0;
real w1lambda = ratioW * j - w;
real w2lambda = 1 - w1lambda;
real* inPos = &inGrad[k * inputW + h * inImgW + w];
const real* outPos = &outGrad[k * outputW + i * outImgW + j];
for (size_t c = 0; c < numChannels; ++c) { // loop for channels
inPos[0] += h2lambda * w2lambda * outPos[0];
inPos[wid] += h2lambda * w1lambda * outPos[0];
inPos[hid * inImgW] += h1lambda * w2lambda * outPos[0];
inPos[hid * inImgW + wid] += h1lambda * w1lambda * outPos[0];
inPos += inPosOffset;
outPos += outPosOffset;
}
}
}
}
}
}
////////////////////////////////////////////////////////////////
// functions executed via cpu //
////////////////////////////////////////////////////////////////
......
......@@ -995,6 +995,26 @@ public:
virtual void paramReluBackwardDiff(Matrix& oGrad, Matrix& data, Matrix& W) {
LOG(FATAL) << "Not implemented";
}
virtual void bilinearForward(const Matrix& in,
const size_t inImgH,
const size_t inImgW,
const size_t outImgH,
const size_t outImgW,
const size_t numChannels,
const real ratioH,
const real ratioW) {
LOG(FATAL) << "Not implemented";
}
virtual void bilinearBackward(const Matrix& out,
const size_t outImgH,
const size_t outImgW,
const size_t inImgH,
const size_t inImgW,
const size_t numChannels,
const real ratioH,
const real ratioW) {
LOG(FATAL) << "Not implemented";
}
};
inline std::ostream& operator<<(std::ostream& os, const Matrix& mat) {
......@@ -1265,6 +1285,24 @@ public:
int contextLength,
int contextStart, int totalPad,
size_t beginPad);
void bilinearForward(const Matrix& in,
const size_t inImgH,
const size_t inImgW,
const size_t outImgH,
const size_t outImgW,
const size_t numChannels,
const real ratioH,
const real ratioW);
void bilinearBackward(const Matrix& out,
const size_t outImgH,
const size_t outImgW,
const size_t inImgH,
const size_t inImgW,
const size_t numChannels,
const real ratioH,
const real ratioW);
};
class CpuMatrix : public Matrix {
......@@ -1553,6 +1591,24 @@ public:
void multiBinaryLabelCrossEntropy(Matrix& output, Matrix& label);
void multiBinaryLabelCrossEntropyBp(Matrix& output, Matrix& label);
void classificationErrorMulti(Matrix& output, Matrix& label, real threshold);
void bilinearForward(const Matrix& in,
const size_t inImgH,
const size_t inImgW,
const size_t outImgH,
const size_t outImgW,
const size_t numChannels,
const real ratioH,
const real ratioW);
void bilinearBackward(const Matrix& out,
const size_t outImgH,
const size_t outImgW,
const size_t inImgH,
const size_t inImgW,
const size_t numChannels,
const real ratioH,
const real ratioW);
};
class SharedCpuMatrix : public CpuMatrix {
......
......@@ -90,6 +90,73 @@ void MatrixCheckErr(const Matrix& matrix1, const Matrix& matrix2) {
EXPECT_EQ(count, 0) << "There are " << count << " different element.";
}
void testBilinearFwdBwd(int numSamples, int imgSizeH, int imgSizeW,
int channels) {
int inWidth = imgSizeH * imgSizeW * channels;
int outWidth = 2 * imgSizeH * 2 * imgSizeW * channels;
real ratioH = 0.5;
real ratioW = 0.5;
// forward
MatrixPtr input = CpuMatrix::create(numSamples, inWidth, false, false);
MatrixPtr inputGpu = GpuMatrix::create(numSamples, inWidth, false, true);
MatrixPtr target = CpuMatrix::create(numSamples, outWidth, false, false);
MatrixPtr targetGpu = GpuMatrix::create(numSamples, outWidth, false, true);
MatrixPtr targetCheck = CpuMatrix::create(numSamples, outWidth, false, false);
input->randomizeUniform();
inputGpu->copyFrom(*input);
target->bilinearForward(*input, imgSizeH, imgSizeW,
2 * imgSizeH, 2 * imgSizeW, channels, ratioH, ratioW);
targetGpu->bilinearForward(*inputGpu, imgSizeH, imgSizeW,
2 * imgSizeH, 2 * imgSizeW, channels, ratioH, ratioW);
// check
targetCheck->copyFrom(*targetGpu);
MatrixCheckErr(*target, *targetCheck);
// backward
MatrixPtr inputGrad = CpuMatrix::create(numSamples, inWidth, false, false);
MatrixPtr inputGpuGrad = GpuMatrix::create(numSamples, inWidth, false, true);
MatrixPtr targetGrad = CpuMatrix::create(numSamples, outWidth, false, false);
MatrixPtr targetGpuGrad = GpuMatrix::create(numSamples, outWidth, false,
true);
MatrixPtr targetCheckGrad =
CpuMatrix::create(numSamples, inWidth, false, false);
inputGrad->randomizeUniform();
targetGrad->randomizeUniform();
inputGpuGrad->copyFrom(*inputGrad);
targetGpuGrad->copyFrom(*targetGrad);
inputGrad->bilinearBackward(*targetGrad, 2 * imgSizeH, 2 * imgSizeW,
imgSizeH, imgSizeW, channels, ratioH, ratioW);
inputGpuGrad->bilinearBackward(*targetGpuGrad, 2 * imgSizeH, 2 * imgSizeW,
imgSizeH, imgSizeW, channels, ratioH, ratioW);
// check
targetCheckGrad->copyFrom(*inputGpuGrad);
MatrixCheckErr(*inputGrad, *targetCheckGrad);
}
TEST(Matrix, BilinearFwdBwd) {
for (auto numSamples : {5, 10}) {
for (auto channels : {8, 16}) {
for (auto imgSizeH : {14, 28}) {
for (auto imgSizeW : {16, 30}) {
VLOG(3) << " numSamples=" << numSamples
<< " channels=" << channels
<< " imgSizeH=" << imgSizeH
<< " imgSizeW=" << imgSizeW;
testBilinearFwdBwd(numSamples, imgSizeH, imgSizeW, channels);
}
}
}
}
}
void testMatrixProjectionForward(int contextStart, int contextLength,
bool padding, int batchSize, int inputDim) {
MatrixPtr cpuInput = std::make_shared<CpuMatrix>(batchSize, inputDim);
......
......@@ -223,6 +223,15 @@ message OperatorConfig {
optional int32 num_filters = 7;
}
message BilinearInterpConfig {
// The size of input feature map.
optional uint32 img_size_x = 1;
optional uint32 img_size_y = 2;
// The size of output feature map.
required uint32 out_size_x = 3;
required uint32 out_size_y = 4;
required uint32 num_channels = 5;
}
message ImageConfig {
// The image data dimensionality.
......@@ -245,8 +254,9 @@ message LayerInputConfig {
// If the input layer has multi-output.
// Set the argument name.
optional string input_layer_argument = 9;
optional MaxOutConfig maxout_conf = 10;
optional SppConfig spp_conf = 11;
optional BilinearInterpConfig bilinear_interp_conf = 10;
optional MaxOutConfig maxout_conf = 11;
optional SppConfig spp_conf = 12;
}
message LayerConfig {
......
......@@ -465,6 +465,7 @@ class Input(Cfg):
sparse_update=None,
gradient_clipping_threshold=None,
conv=None,
bilinear_interp=None,
norm=None,
pool=None,
image=None,
......@@ -650,7 +651,8 @@ class ConvProjection(Projection):
parse_conv(conv_conf,
input_layer_name,
self.proj_conf.conv_conf)
self.proj_conf.conv_conf,
num_filters)
# TODO: support rectangle input
self.proj_conf.output_size = (self.proj_conf.conv_conf.output_x ** 2) * num_filters
......@@ -730,7 +732,8 @@ class ConvOperator(Operator):
parse_conv(conv_conf,
MakeLayerNameInSubmodel(input_layer_names[0]),
self.operator_conf.conv_conf)
self.operator_conf.conv_conf,
num_filters)
self.operator_conf.output_size = (self.operator_conf.conv_conf.output_x ** 2) * num_filters
config_assert(len(input_layer_names) == 2, "Conv is binary operator")
......@@ -766,6 +769,16 @@ class Conv(Cfg):
if output_x is not None:
config_assert(output_x <= 0)
# please refer to the comments in proto/ModelConfig.proto
@config_class
class BilinearInterp(Cfg):
def __init__(
self,
out_size_x = None,
out_size_y = None,
num_channels = None):
self.add_keys(locals())
# please refer to the comments in proto/ModelConfig.proto
@config_class
class Pool(Cfg):
......@@ -1017,6 +1030,11 @@ def TestData(data_config, async_load_data=None):
" Data definition")
g_config.test_data_config.async_load_data = async_load_data
def parse_bilinear(bilinear, input_layer_name, bilinear_conf):
bilinear_conf.out_size_x = bilinear.out_size_x;
bilinear_conf.out_size_y = bilinear.out_size_y;
bilinear_conf.num_channels = bilinear.num_channels;
'''
caffe_mode: compute the output size using floor instead of ceil,
which is consistent of caffe and CuDNN's convention.
......@@ -1028,6 +1046,17 @@ def cnn_output_size(img_size, filter_size, padding, stride, caffe_mode):
else:
return 1 + int(math.ceil(output))
'''
calcualte image_size based on output_size for convolution.
It is the reverse function of cnn_output_size
'''
def cnn_image_size(output_size, filter_size, padding, stride, caffe_mode):
if caffe_mode:
img_size = (output_size - 1) * stride + filter_size - 2 * padding
else:
img_size = (output_size - 2) * stride + filter_size - 2 * padding + 1
return img_size
def parse_pool(pool, input_layer_name, pool_conf):
pool_conf.pool_type = pool.pool_type
config_assert(pool.pool_type in ['max-projection', 'avg-projection',
......@@ -1109,7 +1138,11 @@ def parse_norm(norm, input_layer_name, norm_conf):
else:
norm_conf.scale /= norm.size ** 2
def parse_conv(conv, input_layer_name, conv_conf):
'''
caffe_mode: compute the output size using floor instead of ceil,
which is consistent of caffe and CuDNN's convention.
'''
def parse_conv(conv, input_layer_name, conv_conf, num_filters, trans=False):
conv_conf.filter_size = conv.filter_size
conv_conf.filter_size_y = conv.filter_size_y
conv_conf.channels = conv.channels
......@@ -1118,20 +1151,37 @@ def parse_conv(conv, input_layer_name, conv_conf):
conv_conf.stride = conv.stride
conv_conf.stride_y = conv.stride_y
conv_conf.groups = conv.groups
conv_conf.filter_channels = conv.channels / conv.groups
conv_conf.caffe_mode = conv.caffe_mode
img_pixels = g_layer_map[input_layer_name].size / conv.channels
print('channels=%d size=%d'%(conv.channels,
g_layer_map[input_layer_name].size))
conv_conf.img_size = int(img_pixels ** 0.5)
config_assert((conv_conf.img_size ** 2) == img_pixels,
("Input layer %s: Incorrect input image size %d for input "
+ "image pixels %d")
% (input_layer_name, conv_conf.img_size, img_pixels))
conv_conf.output_x = cnn_output_size(conv_conf.img_size, conv_conf.filter_size,
conv_conf.padding, conv_conf.stride,
conv_conf.caffe_mode)
if not trans:
conv_conf.filter_channels = conv.channels / conv.groups
img_pixels = g_layer_map[input_layer_name].size / conv.channels
print('channels=%d size=%d'%(conv.channels,
g_layer_map[input_layer_name].size))
conv_conf.img_size = int(img_pixels ** 0.5)
config_assert((conv_conf.img_size ** 2) == img_pixels,
("Input layer %s: Incorrect input image size %d for input "
+ "image pixels %d")
% (input_layer_name, conv_conf.img_size, img_pixels))
conv_conf.output_x = cnn_output_size(
conv_conf.img_size, conv_conf.filter_size,
conv_conf.padding, conv_conf.stride, conv_conf.caffe_mode)
else:
conv_conf.filter_channels = num_filters / conv.groups
outputSize = g_layer_map[input_layer_name].size / conv.channels
print('channels=%d size=%d'%(conv.channels,
g_layer_map[input_layer_name].size))
conv_conf.output_x = int(outputSize ** 0.5)
config_assert((conv_conf.output_x ** 2) == outputSize,
("Input layer %s: Incorrect input image size %d for input "
+ "image pixels %d")
% (input_layer_name, conv_conf.output_x, outputSize))
conv_conf.img_size = cnn_image_size(
conv_conf.output_x, conv_conf.filter_size,
conv_conf.padding, conv_conf.stride, conv_conf.caffe_mode)
def parse_block_expand(block_expand, input_layer_name, block_expand_conf):
block_expand_conf.channels = block_expand.channels
......@@ -1614,7 +1664,8 @@ class ConvLayerBase(LayerBase):
parse_conv(
self.inputs[input_index].conv,
input_layer.name,
self.config.inputs[input_index].conv_conf)
self.config.inputs[input_index].conv_conf,
num_filters)
conv_conf = self.config.inputs[input_index].conv_conf
psize = self.calc_parameter_size(conv_conf)
print("output size for %s is %d " % (name, conv_conf.output_x))
......@@ -1639,6 +1690,63 @@ class ConvLayer(ConvLayerBase):
class ConvLayer(ConvLayerBase):
layer_type = 'cudnn_conv'
@config_layer('convt')
class ConvTransLayerBase(LayerBase):
layer_type = 'convt'
def __init__(
self,
name,
inputs=[],
bias=True,
num_filters=None,
shared_biases=False,
**xargs):
super(ConvTransLayerBase, self).__init__(
name, self.layer_type, 0, inputs=inputs, **xargs)
if num_filters is not None:
self.config.num_filters = num_filters
use_gpu = int(g_command_config_args.get("use_gpu", 0))
parallel_nn = int(g_command_config_args.get("parallel_nn", 0))
# cudnn_convt has not been implemented so use exconvt only
self.layer_type = "exconvt"
# need to specify layer in config
self.config.type = self.layer_type
if shared_biases is not None:
self.config.shared_biases = shared_biases
for input_index in xrange(len(self.inputs)):
input_layer = self.get_input_layer(input_index)
parse_conv(
self.inputs[input_index].conv,
input_layer.name,
self.config.inputs[input_index].conv_conf,
num_filters,
trans=True)
conv_conf = self.config.inputs[input_index].conv_conf
psize = self.calc_parameter_size(conv_conf)
print("output size for %s is %d " % (name, conv_conf.output_x))
self.create_input_parameter(input_index, psize)
self.set_layer_size(
(conv_conf.img_size ** 2) * self.config.num_filters)
psize = self.config.size
if shared_biases:
psize = self.config.num_filters
self.create_bias_parameter(bias, psize, [psize, 1])
def calc_parameter_size(self, conv_conf):
return conv_conf.channels * conv_conf.filter_channels \
* (conv_conf.filter_size * conv_conf.filter_size_y)
@config_layer('exconvt')
class ConvTransLayer(ConvTransLayerBase):
layer_type = 'exconvt'
@config_layer('norm')
class NormLayer(LayerBase):
def __init__(
......@@ -2424,6 +2532,22 @@ class InterpolationLayer(LayerBase):
config_assert(input_layer1.size == input_layer2.size,
'the two vector inputs should be of the same size')
@config_layer('bilinear_interp')
class BilinearInterpLayer(LayerBase):
def __init__(
self,
name,
inputs,
**xargs):
super(BilinearInterpLayer, self).__init__(
name, 'bilinear_interp', 0, inputs=inputs, **xargs)
input_layer = self.get_input_layer(0)
parse_bilinear(self.inputs[0].bilinear_interp,
input_layer.name,
self.config.inputs[0].bilinear_interp_conf);
conf = self.inputs[0].bilinear_interp
self.set_layer_size(conf.out_size_x * conf.out_size_y * conf.num_channels)
@config_layer('sum_to_one_norm')
class SumToOneNormLayer(LayerBase):
def __init__(
......
......@@ -40,8 +40,8 @@ __all__ = ["full_matrix_projection", "AggregateLevel", "ExpandLevel",
'img_cmrnorm_layer', 'addto_layer',
'concat_layer', 'lstm_step_layer', 'recurrent_group',
'memory', 'StaticInput', 'expand_layer', 'scaling_layer',
'power_layer', 'interpolation_layer', 'trans_layer',
'sum_to_one_norm_layer',
'power_layer', 'interpolation_layer', 'bilinear_interp_layer',
'trans_layer', 'sum_to_one_norm_layer',
'get_output_layer', 'LayerType', 'context_projection',
'beam_search', 'maxid_layer', 'GeneratedInput', 'SubsequenceInput',
'gru_step_layer', 'recurrent_layer',
......@@ -79,6 +79,7 @@ class LayerType(object):
COSINE_SIM = 'cos'
HSIGMOID = 'hsigmoid'
CONV_LAYER = "conv"
CONVTRANS_LAYER = "convt"
POOL_LAYER = "pool"
BATCH_NORM_LAYER = 'batch_norm'
NORM_LAYER = 'norm'
......@@ -94,6 +95,7 @@ class LayerType(object):
EXPAND_LAYER = 'expand'
INTERPOLATION_LAYER = 'interpolation'
BILINEAR_INTERP_LAYER = 'bilinear_interp'
POWER_LAYER = 'power'
SCALING_LAYER = 'scaling'
TRANS_LAYER = 'trans'
......@@ -1261,6 +1263,52 @@ def interpolation_layer(input, weight, name=None, layer_attr=None):
size=input[0].size)
@wrap_name_default()
@layer_support()
def bilinear_interp_layer(input,
out_size_x=None,
out_size_y=None,
name=None,
layer_attr=None):
"""
This layer is to implement bilinear interpolation on conv layer output.
Please refer to Wikipedia: https://en.wikipedia.org/wiki/Bilinear_interpolation
The simple usage is:
.. code-block:: python
bilinear = bilinear_interp_layer(input=layer1, out_size_x=64, out_size_y=64)
:param input: A input layer.
:type input: LayerOutput.
:param out_size_x: bilinear interpolation output width.
:type out_size_x: int|None
:param out_size_y: bilinear interpolation output height.
:type out_size_y: int|None
:param name: The layer's name, which cna not be specified.
:type name: None|basestring
:param layer_attr: Extra Layer attribute.
:type layer_attr: ExtraLayerAttribute
:return: LayerOutput object.
:rtype: LayerOutput
"""
assert input.layer_type == LayerType.CONV_LAYER
assert isinstance(input.activation, LinearActivation)
assert out_size_x > 0 and out_size_y > 0
assert input.num_filters is not None
num_channels = input.num_filters
Layer(name=name,
inputs=Input(input.name,
bilinear_interp=BilinearInterp(out_size_x=out_size_x,
out_size_y=out_size_y,
num_channels=num_channels)),
type=LayerType.BILINEAR_INTERP_LAYER,
**ExtraLayerAttribute.to_kwargs(layer_attr))
return LayerOutput(name, LayerType.BILINEAR_INTERP_LAYER, parents=[input],
num_filters=num_channels)
@wrap_name_default()
@layer_support()
def power_layer(input, weight, name=None, layer_attr=None):
......@@ -1520,7 +1568,8 @@ def img_conv_layer(input, filter_size, num_filters,
name=None, num_channels=None,
act=None, groups=1, stride=1, padding=0, bias_attr=None,
param_attr=None, shared_biases=True, layer_attr=None,
filter_size_y=None, stride_y=None, padding_y=None):
filter_size_y=None, stride_y=None, padding_y=None,
trans=False):
"""
Convolution layer for image. Paddle only support square input currently and
thus input image's width equals height.
......@@ -1528,7 +1577,14 @@ def img_conv_layer(input, filter_size, num_filters,
The details of convolution layer, please refer UFLDL's `convolution
<http://ufldl.stanford.edu/tutorial/supervised/
FeatureExtractionUsingConvolution/>`_ .
Convolution Transpose (deconv) layer for image. Paddle only support square
input currently and thus input image's width equals height.
The details of convolution transpose layer,
please refer to the following explanation and references therein
<http://datascience.stackexchange.com/questions/6107/
what-are-deconvolutional-layers/>`_ .
The num_channel means input image's channel number. It may be 1 or 3 when
input is raw pixels of image(mono or RGB), or it may be the previous layer's
num_filters * num_group.
......@@ -1578,6 +1634,8 @@ def img_conv_layer(input, filter_size, num_filters,
:type shared_biases: bool
:param layer_attr: Layer Extra Attribute.
:type layer_attr: ExtraLayerAttribute
:param trans: true if it is a convTransLayer, false if it is a convLayer
:type trans: bool
:return: LayerOutput object.
:rtype: LayerOutput
"""
......@@ -1613,6 +1671,9 @@ def img_conv_layer(input, filter_size, num_filters,
param_attr.attr["initial_std"] = init_w
param_attr.attr["initial_strategy"] = 0
param_attr.attr["initial_smart"] = False
lt = LayerType.CONVTRANS_LAYER if trans else LayerType.CONV_LAYER
Layer(
name=name,
inputs=Input(input.name, conv=Conv(
......@@ -1625,10 +1686,10 @@ def img_conv_layer(input, filter_size, num_filters,
num_filters=num_filters,
bias=ParamAttr.to_bias(bias_attr),
shared_biases=shared_biases,
type=LayerType.CONV_LAYER,
type=lt,
**ExtraLayerAttribute.to_kwargs(layer_attr)
)
return LayerOutput(name, LayerType.CONV_LAYER, parents=[input],
return LayerOutput(name, lt, parents=[input],
activation=act, num_filters=num_filters)
......
......@@ -9,9 +9,9 @@ protostr=$PWD/protostr
configs=(test_fc layer_activations projections test_print_layer
test_sequence_pooling test_lstmemory_layer test_grumemory_layer
last_first_seq test_expand_layer test_ntm_layers test_hsigmoid
img_layers util_layers simple_rnn_layers unused_layers test_cost_layers
img_layers img_trans_layers util_layers simple_rnn_layers unused_layers test_cost_layers
test_rnn_group shared_fc shared_lstm test_cost_layers_with_weight
test_maxout test_bi_grumemory math_ops test_spp_layer)
test_spp_layer test_bilinear_interp test_maxout test_bi_grumemory math_ops)
for conf in ${configs[*]}
......
from paddle.trainer_config_helpers import *
settings(
learning_rate=1e-3,
batch_size=1000
)
img = data_layer(name='image', size=227*227)
# the parse_conv in config_parse.py is not strictly accurate when filter_size
# is not square. So here set square filter_size.
img_conv = img_conv_layer(input=img, num_channels=1, num_filters=64,
filter_size=(32, 32), padding=(1, 1), stride=(1, 1),
act=LinearActivation(), trans=True)
img_bn = batch_norm_layer(input=img_conv, act=ReluActivation())
img_norm = img_cmrnorm_layer(input=img_bn, size=32)
img_pool = img_pool_layer(input=img_conv, pool_size=32, pool_type=MaxPooling())
outputs(img_pool, img_norm)
type: "nn"
layers {
name: "image"
type: "data"
size: 51529
active_type: ""
}
layers {
name: "__conv_0__"
type: "exconvt"
size: 4194304
active_type: ""
inputs {
input_layer_name: "image"
input_parameter_name: "___conv_0__.w0"
conv_conf {
filter_size: 32
channels: 1
stride: 1
padding: 1
groups: 1
filter_channels: 64
output_x: 227
img_size: 256
caffe_mode: true
filter_size_y: 32
padding_y: 1
stride_y: 1
}
}
bias_parameter_name: "___conv_0__.wbias"
num_filters: 64
shared_biases: true
}
layers {
name: "__batch_norm_0__"
type: "batch_norm"
size: 4194304
active_type: "relu"
inputs {
input_layer_name: "__conv_0__"
input_parameter_name: "___batch_norm_0__.w0"
image_conf {
channels: 64
img_size: 256
}
}
inputs {
input_layer_name: "__conv_0__"
input_parameter_name: "___batch_norm_0__.w1"
}
inputs {
input_layer_name: "__conv_0__"
input_parameter_name: "___batch_norm_0__.w2"
}
bias_parameter_name: "___batch_norm_0__.wbias"
moving_average_fraction: 0.9
}
layers {
name: "__crmnorm_0__"
type: "norm"
size: 4194304
active_type: ""
inputs {
input_layer_name: "__batch_norm_0__"
norm_conf {
norm_type: "cmrnorm-projection"
channels: 64
size: 32
scale: 0.0004
pow: 0.75
output_x: 256
img_size: 256
blocked: false
}
}
}
layers {
name: "__pool_0__"
type: "pool"
size: 3240000
active_type: ""
inputs {
input_layer_name: "__conv_0__"
pool_conf {
pool_type: "max-projection"
channels: 64
size_x: 32
stride: 1
output_x: 225
img_size: 256
padding: 0
size_y: 32
stride_y: 1
output_y: 225
img_size_y: 256
padding_y: 0
}
}
}
parameters {
name: "___conv_0__.w0"
size: 65536
initial_mean: 0.0
initial_std: 0.0441941738242
initial_strategy: 0
initial_smart: false
}
parameters {
name: "___conv_0__.wbias"
size: 64
initial_mean: 0.0
initial_std: 0.0
dims: 64
dims: 1
initial_strategy: 0
initial_smart: false
}
parameters {
name: "___batch_norm_0__.w0"
size: 64
initial_mean: 1.0
initial_std: 0.0
initial_strategy: 0
initial_smart: false
}
parameters {
name: "___batch_norm_0__.w1"
size: 64
initial_mean: 0.0
initial_std: 0.0
dims: 1
dims: 64
initial_strategy: 0
initial_smart: false
is_static: true
is_shared: true
}
parameters {
name: "___batch_norm_0__.w2"
size: 64
initial_mean: 0.0
initial_std: 0.0
dims: 1
dims: 64
initial_strategy: 0
initial_smart: false
is_static: true
is_shared: true
}
parameters {
name: "___batch_norm_0__.wbias"
size: 64
initial_mean: 0.0
initial_std: 0.0
dims: 1
dims: 64
initial_strategy: 0
initial_smart: false
}
input_layer_names: "image"
output_layer_names: "__pool_0__"
output_layer_names: "__crmnorm_0__"
sub_models {
name: "root"
layer_names: "image"
layer_names: "__conv_0__"
layer_names: "__batch_norm_0__"
layer_names: "__crmnorm_0__"
layer_names: "__pool_0__"
input_layer_names: "image"
output_layer_names: "__pool_0__"
output_layer_names: "__crmnorm_0__"
is_recurrent_layer_group: false
}
type: "nn"
layers {
name: "data"
type: "data"
size: 2304
active_type: ""
}
layers {
name: "__conv_0__"
type: "exconv"
size: 36864
active_type: ""
inputs {
input_layer_name: "data"
input_parameter_name: "___conv_0__.w0"
conv_conf {
filter_size: 3
channels: 1
stride: 1
padding: 1
groups: 1
filter_channels: 1
output_x: 48
img_size: 48
caffe_mode: true
filter_size_y: 3
padding_y: 1
stride_y: 1
}
}
bias_parameter_name: "___conv_0__.wbias"
num_filters: 16
shared_biases: true
}
layers {
name: "__bilinear_interp_layer_0__"
type: "bilinear_interp"
size: 65536
active_type: ""
inputs {
input_layer_name: "__conv_0__"
bilinear_interp_conf {
out_size_x: 64
out_size_y: 64
num_channels: 16
}
}
}
layers {
name: "__pool_0__"
type: "pool"
size: 16384
active_type: ""
inputs {
input_layer_name: "__bilinear_interp_layer_0__"
pool_conf {
pool_type: "max-projection"
channels: 4
size_x: 2
stride: 2
output_x: 64
img_size: 128
padding: 0
size_y: 2
stride_y: 2
output_y: 64
img_size_y: 128
padding_y: 0
}
}
}
layers {
name: "__fc_layer_0__"
type: "fc"
size: 384
active_type: "tanh"
inputs {
input_layer_name: "__pool_0__"
input_parameter_name: "___fc_layer_0__.w0"
}
}
parameters {
name: "___conv_0__.w0"
size: 144
initial_mean: 0.0
initial_std: 0.471404520791
initial_strategy: 0
initial_smart: false
}
parameters {
name: "___conv_0__.wbias"
size: 16
initial_mean: 0.0
initial_std: 0.0
dims: 16
dims: 1
initial_strategy: 0
initial_smart: false
}
parameters {
name: "___fc_layer_0__.w0"
size: 6291456
initial_mean: 0.0
initial_std: 0.0078125
dims: 16384
dims: 384
initial_strategy: 0
initial_smart: true
}
input_layer_names: "data"
output_layer_names: "__fc_layer_0__"
sub_models {
name: "root"
layer_names: "data"
layer_names: "__conv_0__"
layer_names: "__bilinear_interp_layer_0__"
layer_names: "__pool_0__"
layer_names: "__fc_layer_0__"
input_layer_names: "data"
output_layer_names: "__fc_layer_0__"
is_recurrent_layer_group: false
}
from paddle.trainer_config_helpers import *
settings(
batch_size=1000,
learning_rate=1e-5
)
data = data_layer(name='data', size=2304)
conv = img_conv_layer(input=data,
filter_size = 3,
num_channels=1,
num_filters=16,
padding=1,
act=LinearActivation(),
bias_attr=True)
bilinear = bilinear_interp_layer(input=conv,
out_size_x=64,
out_size_y=64)
pool = img_pool_layer(input=bilinear,
num_channels=4,
pool_size=2,
stride=2,
pool_type=MaxPooling())
fc = fc_layer(input=pool, size=384, bias_attr=False)
outputs(fc)
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册