提交 5e4cc241 编写于 作者: W wangyang59

Revised deconv implementations according to luotao1

上级 5fff96f5
......@@ -22,9 +22,9 @@ bool ConvBaseLayer::init(const LayerMap& layerMap,
Layer::init(layerMap, parameterMap);
if (config_.type() == "exconv" || config_.type() == "cudnn_conv") {
isConv_ = true;
isDeconv_ = false;
} else {
isConv_ = false;
isDeconv_ = true;
}
/* Initialize the convolutional layer parameter */
......
......@@ -28,8 +28,8 @@ class ConvBaseLayer : public Layer {
protected:
typedef std::vector<int> IntV;
/// True if it's convolution layer, false if it's deconv layer
bool isConv_;
/// True if it's deconv layer, false if it's convolution layer
bool isDeconv_;
/// The number of filters.
int numFilters_;
......
......@@ -13,11 +13,12 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "ExpandConvBaseLayer.h"
#include "paddle/utils/Logging.h"
#include "ConvBaseLayerCpu.h"
namespace paddle {
bool ConvBaseLayerCpu::init(const LayerMap &layerMap,
bool ExpandConvBaseLayer::init(const LayerMap &layerMap,
const ParameterMap &parameterMap) {
/* Initialize the basic convolutional parent class */
ConvBaseLayer::init(layerMap, parameterMap);
......@@ -34,10 +35,10 @@ bool ConvBaseLayerCpu::init(const LayerMap &layerMap,
/* Initialize the projection */
for (auto &inputConfig : config_.inputs()) {
const ConvConfig &conf = inputConfig.conv_conf();
nf = isConv_ ? numFilters_ : conf.channels();
nf = (!isDeconv_) ? numFilters_ : conf.channels();
subM_.push_back(nf / conf.groups());
subN_.push_back(conf.output_x() * conf.output_x());
channel = isConv_ ? conf.channels() : numFilters_;
channel = (!isDeconv_) ? conf.channels() : numFilters_;
subK_.push_back(channel * conf.filter_size() * conf.filter_size() /
conf.groups());
/* Consistent caffe mode for multiple input */
......@@ -47,11 +48,11 @@ bool ConvBaseLayerCpu::init(const LayerMap &layerMap,
return true;
}
void ConvBaseLayerCpu::resetExpandInput(size_t height, size_t width) {
void ExpandConvBaseLayer::resetExpandInput(size_t height, size_t width) {
Matrix::resizeOrCreate(expandInput_, height, width, false, useGpu_);
}
void ConvBaseLayerCpu::addSharedBias() {
void ExpandConvBaseLayer::addSharedBias() {
size_t mapW = getSize() / numFilters_;
size_t mapH = getOutputValue()->getElementCnt() / mapW;
MatrixPtr out =
......@@ -75,7 +76,7 @@ void ConvBaseLayerCpu::addSharedBias() {
bias->clear();
}
void ConvBaseLayerCpu::addUnsharedBias() {
void ExpandConvBaseLayer::addUnsharedBias() {
MatrixPtr outValue = getOutputValue();
MatrixPtr bias =
Matrix::create(biases_->getW()->getData(), 1,
......@@ -84,9 +85,9 @@ void ConvBaseLayerCpu::addUnsharedBias() {
}
void ConvBaseLayerCpu::expandOneFrame(MatrixPtr image, size_t startIdx,
void ExpandConvBaseLayer::expandOneFrame(MatrixPtr image, size_t startIdx,
int inIdx) {
int channel = isConv_ ? channels_[inIdx] : numFilters_;
int channel = (!isDeconv_) ? channels_[inIdx] : numFilters_;
resetExpandInput(subK_[inIdx] * groups_[inIdx], subN_[inIdx]);
real *imgData = image->getData() + startIdx * image->getWidth();
......@@ -101,7 +102,7 @@ void ConvBaseLayerCpu::expandOneFrame(MatrixPtr image, size_t startIdx,
imageTmp->clear();
}
void ConvBaseLayerCpu::expandFwdOnce(MatrixPtr image, MatrixPtr out,
void ExpandConvBaseLayer::expandFwdOnce(MatrixPtr image, MatrixPtr out,
int inIdx, int startIdx) {
int subM = subM_[inIdx];
int subN = subN_[inIdx];
......@@ -109,7 +110,7 @@ void ConvBaseLayerCpu::expandFwdOnce(MatrixPtr image, MatrixPtr out,
expandOneFrame(image, startIdx, inIdx);
int nf = isConv_ ? numFilters_ : channels_[inIdx];
int nf = (!isDeconv_) ? numFilters_ : channels_[inIdx];
real *outData =
out->getData() + startIdx * subN * nf;
......@@ -132,8 +133,9 @@ void ConvBaseLayerCpu::expandFwdOnce(MatrixPtr image, MatrixPtr out,
}
}
void ConvBaseLayerCpu::bpropActs(MatrixPtr out, MatrixPtr image, int inpIdx) {
int channel = isConv_ ? channels_[inpIdx] : numFilters_;
void ExpandConvBaseLayer::bpropActs(MatrixPtr out, MatrixPtr image,
int inpIdx) {
int channel = (!isDeconv_) ? channels_[inpIdx] : numFilters_;
int subM = subM_[inpIdx];
int subN = subN_[inpIdx];
......@@ -186,7 +188,7 @@ void ConvBaseLayerCpu::bpropActs(MatrixPtr out, MatrixPtr image, int inpIdx) {
}
}
void ConvBaseLayerCpu::bpropWeights(MatrixPtr image, MatrixPtr out,
void ExpandConvBaseLayer::bpropWeights(MatrixPtr image, MatrixPtr out,
int inpIdx) {
MatrixPtr weightGrad = weights_[inpIdx]->getWGrad();
......@@ -221,7 +223,7 @@ void ConvBaseLayerCpu::bpropWeights(MatrixPtr image, MatrixPtr out,
}
}
void ConvBaseLayerCpu::bpropSharedBias(MatrixPtr biases, MatrixPtr v) {
void ExpandConvBaseLayer::bpropSharedBias(MatrixPtr biases, MatrixPtr v) {
size_t mapW = getSize() / numFilters_;
size_t mapH = v->getElementCnt() / mapW;
MatrixPtr vTmp = Matrix::create(v->getData(), mapH, mapW, false, useGpu_);
......@@ -234,7 +236,7 @@ void ConvBaseLayerCpu::bpropSharedBias(MatrixPtr biases, MatrixPtr v) {
biases->collectBias(*transOutValue_, 1.0f);
}
void ConvBaseLayerCpu::bpropBiases(MatrixPtr v) {
void ExpandConvBaseLayer::bpropBiases(MatrixPtr v) {
MatrixPtr biases =
Matrix::create(biases_->getWGrad()->getData(), 1,
biases_->getWGrad()->getElementCnt(), false, useGpu_);
......
......@@ -25,7 +25,7 @@ namespace paddle {
* @brief A subclass of ConvBaseLayer that is a superclass of both
* ExpandConvLayer and ExpandConvTransLayer
*/
class ConvBaseLayerCpu : public ConvBaseLayer {
class ExpandConvBaseLayer : public ConvBaseLayer {
protected:
/// For expand convolution.
/// subM_ = numFilters_ / groups_.
......@@ -43,18 +43,19 @@ protected:
/// The spatial dimensions of width of output feature map.
IntV outputW_;
/*The expandInput_ and transOutValue_ are used for CPU expand conv calc*/
/// Expand one sample at a time. shape:
/// (numChannels * filterPixels_, outputSizeH * outputSizeW)
/*The expandInput_ and transOutValue_ are used for CPU expand conv calc
* Expand one sample at a time. shape:
* (numChannels * filterPixels_, outputSizeH * outputSizeW)
* */
MatrixPtr expandInput_;
/// The transpose of output, which is an auxiliary matrix.
MatrixPtr transOutValue_;
public:
explicit ConvBaseLayerCpu(const LayerConfig& config)
explicit ExpandConvBaseLayer(const LayerConfig& config)
: ConvBaseLayer(config) {}
~ConvBaseLayerCpu() {}
~ExpandConvBaseLayer() {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap);
......
......@@ -24,7 +24,7 @@ REGISTER_LAYER(exconv, ExpandConvLayer);
bool ExpandConvLayer::init(const LayerMap &layerMap,
const ParameterMap &parameterMap) {
/* Initialize the basic convolutional parent class */
ConvBaseLayerCpu::init(layerMap, parameterMap);
ExpandConvBaseLayer::init(layerMap, parameterMap);
return true;
}
......@@ -49,16 +49,17 @@ void ExpandConvLayer::forward(PassType passType) {
resetOutput(batchSize, getOutputSize());
MatrixPtr image = nullptr;
for (size_t i = 0; i != inputLayers_.size(); ++i) {
MatrixPtr outV = getOutputValue();
for (size_t i = 0; i < inputLayers_.size(); ++i) {
LayerPtr prevLayer = getPrev(i);
image = prevLayer->getOutputValue();
for (size_t off = 0; off < image->getHeight(); off++) {
REGISTER_TIMER_INFO("expandFwdOnce", getName().c_str());
expandFwdOnce(image, getOutputValue(), i, off);
expandFwdOnce(image, outV, i, off);
}
}
/* add the bias-vector */
if (biases_.get() != NULL) {
if (biases_.get()) {
if (sharedBiases_) {
addSharedBias();
} else {
......@@ -81,9 +82,9 @@ void ExpandConvLayer::backward(const UpdateCallback &callback) {
biases_->getParameterPtr()->incUpdate(callback);
}
for (size_t i = 0; i != inputLayers_.size(); ++i) {
for (size_t i = 0; i < inputLayers_.size(); ++i) {
/* First, calculate the input layers error */
if (NULL != getPrev(i)->getOutputGrad()) {
if (getPrev(i)->getOutputGrad()) {
bpropActs(outGrad, getPrev(i)->getOutputGrad(), i);
}
if (weights_[i]->getWGrad()) {
......
......@@ -15,9 +15,9 @@ limitations under the License. */
#pragma once
#include "ConvBaseLayerCpu.h"
#include "paddle/math/Matrix.h"
#include <vector>
#include "ExpandConvBaseLayer.h"
namespace paddle {
......@@ -29,10 +29,10 @@ namespace paddle {
* The config file api is img_conv_layer.
*/
class ExpandConvLayer : public ConvBaseLayerCpu {
class ExpandConvLayer : public ExpandConvBaseLayer {
public:
explicit ExpandConvLayer(const LayerConfig& config) :
ConvBaseLayerCpu(config) {}
ExpandConvBaseLayer(config) {}
~ExpandConvLayer() {}
......
......@@ -29,7 +29,7 @@ REGISTER_LAYER(exconvt, ExpandConvTransLayer);
bool ExpandConvTransLayer::init(const LayerMap &layerMap,
const ParameterMap &parameterMap) {
/* Initialize the basic convolutional parent class */
ConvBaseLayerCpu::init(layerMap, parameterMap);
ExpandConvBaseLayer::init(layerMap, parameterMap);
return true;
}
......@@ -72,7 +72,7 @@ void ExpandConvTransLayer::forward(PassType passType) {
resetOutput(batchSize, getSize());
MatrixPtr output = nullptr;
for (size_t i = 0; i != inputLayers_.size(); ++i) {
for (size_t i = 0; i < inputLayers_.size(); ++i) {
LayerPtr prevLayer = getPrev(i);
output = prevLayer->getOutputValue();
REGISTER_TIMER_INFO("shrinkFwd", getName().c_str());
......@@ -80,7 +80,7 @@ void ExpandConvTransLayer::forward(PassType passType) {
}
/* add the bias-vector */
if (biases_.get() != NULL) {
if (biases_.get()) {
if (sharedBiases_) {
addSharedBias();
} else {
......@@ -102,10 +102,10 @@ void ExpandConvTransLayer::backward(const UpdateCallback &callback) {
biases_->getParameterPtr()->incUpdate(callback);
}
for (size_t i = 0; i != inputLayers_.size(); ++i) {
for (size_t i = 0; i < inputLayers_.size(); ++i) {
/* First, calculate the input layers error */
for (size_t off = 0; off < imageGrad->getHeight(); off++) {
if (NULL != getPrev(i)->getOutputGrad()) {
if (getPrev(i)->getOutputGrad()) {
expandFwdOnce(imageGrad, getPrev(i)->getOutputGrad(), i, off);
}
}
......
......@@ -15,9 +15,9 @@ limitations under the License. */
#pragma once
#include "ConvBaseLayerCpu.h"
#include "paddle/math/Matrix.h"
#include <vector>
#include "ExpandConvBaseLayer.h"
namespace paddle {
......@@ -28,10 +28,10 @@ namespace paddle {
*
* The config file api is img_convTrans_layer.
*/
class ExpandConvTransLayer : public ConvBaseLayerCpu {
class ExpandConvTransLayer : public ExpandConvBaseLayer {
public:
explicit ExpandConvTransLayer(const LayerConfig& config) :
ConvBaseLayerCpu(config) {}
ExpandConvBaseLayer(config) {}
~ExpandConvTransLayer() {}
......
......@@ -1107,7 +1107,7 @@ def parse_conv(conv, input_layer_name, conv_conf):
conv_conf.caffe_mode)
def parse_convt(conv, input_layer_name, conv_conf, num_filters):
def parse_conv_trans(conv, input_layer_name, conv_conf, num_filters):
conv_conf.filter_size = conv.filter_size
conv_conf.filter_size_y = conv.filter_size_y
conv_conf.channels = conv.channels
......@@ -1683,7 +1683,7 @@ class ConvTransLayerBase(LayerBase):
for input_index in xrange(len(self.inputs)):
input_layer = self.get_input_layer(input_index)
parse_convt(
parse_conv_trans(
self.inputs[input_index].conv,
input_layer.name,
self.config.inputs[input_index].conv_conf, num_filters)
......
......@@ -1515,7 +1515,8 @@ def img_conv_layer(input, filter_size, num_filters,
name=None, num_channels=None,
act=None, groups=1, stride=1, padding=0, bias_attr=None,
param_attr=None, shared_biases=True, layer_attr=None,
filter_size_y=None, stride_y=None, padding_y=None):
filter_size_y=None, stride_y=None, padding_y=None,
trans=False):
"""
Convolution layer for image. Paddle only support square input currently and
thus input image's width equals height.
......@@ -1523,120 +1524,7 @@ def img_conv_layer(input, filter_size, num_filters,
The details of convolution layer, please refer UFLDL's `convolution
<http://ufldl.stanford.edu/tutorial/supervised/
FeatureExtractionUsingConvolution/>`_ .
The num_channel means input image's channel number. It may be 1 or 3 when
input is raw pixels of image(mono or RGB), or it may be the previous layer's
num_filters * num_group.
There are several group of filter in PaddlePaddle implementation.
Each group will process some channel of the inputs. For example, if an input
num_channel = 256, group = 4, num_filter=32, the PaddlePaddle will create
32*4 = 128 filters to process inputs. The channels will be split into 4
pieces. First 256/4 = 64 channels will process by first 32 filters. The
rest channels will be processed by rest group of filters.
:param name: Layer name.
:type name: basestring
:param input: Layer Input.
:type input: LayerOutput
:param filter_size: The x dimension of a filter kernel. Or input a tuple for
two image dimension.
:type filter_size: int|tuple|list
:param filter_size_y: The y dimension of a filter kernel. Since PaddlePaddle
currently supports rectangular filters, the filter's
shape will be (filter_size, filter_size_y).
:type filter_size_y: int|None
:param num_filters: Each filter group's number of filter
:param act: Activation type. Default is tanh
:type act: BaseActivation
:param groups: Group size of filters.
:type groups: int
:param stride: The x dimension of the stride. Or input a tuple for two image
dimension.
:type stride: int|tuple|list
:param stride_y: The y dimension of the stride.
:type stride_y: int
:param padding: The x dimension of the padding. Or input a tuple for two
image dimension
:type padding: int|tuple|list
:param padding_y: The y dimension of the padding.
:type padding_y: int
:param bias_attr: Convolution bias attribute. None means default bias.
False means no bias.
:type bias_attr: ParameterAttribute|False
:param num_channels: number of input channels. If None will be set
automatically from previous output.
:type num_channels: int
:param param_attr: Convolution param attribute. None means default attribute
:type param_attr: ParameterAttribute
:param shared_biases: Is biases will be shared between filters or not.
:type shared_biases: bool
:param layer_attr: Layer Extra Attribute.
:type layer_attr: ExtraLayerAttribute
:return: LayerOutput object.
:rtype: LayerOutput
"""
if num_channels is None:
assert input.num_filters is not None
num_channels = input.num_filters
if filter_size_y is None:
if isinstance(filter_size, collections.Sequence):
assert len(filter_size) == 2
filter_size, filter_size_y = filter_size
else:
filter_size_y = filter_size
if stride_y is None:
if isinstance(stride, collections.Sequence):
assert len(stride) == 2
stride, stride_y = stride
else:
stride_y = stride
if padding_y is None:
if isinstance(padding, collections.Sequence):
assert len(padding) == 2
padding, padding_y = padding
else:
padding_y = padding
if param_attr.attr.get('initial_smart'):
# special initial for conv layers.
init_w = (2.0 / (filter_size ** 2 * num_channels)) ** 0.5
param_attr.attr["initial_mean"] = 0.0
param_attr.attr["initial_std"] = init_w
param_attr.attr["initial_strategy"] = 0
param_attr.attr["initial_smart"] = False
Layer(
name=name,
inputs=Input(input.name, conv=Conv(
filter_size=filter_size, padding=padding, stride=stride,
channels=num_channels, groups=groups,
filter_size_y=filter_size_y, padding_y=padding_y,
stride_y=stride_y),
**param_attr.attr),
active_type=act.name,
num_filters=num_filters,
bias=ParamAttr.to_bias(bias_attr),
shared_biases=shared_biases,
type=LayerType.CONV_LAYER,
**ExtraLayerAttribute.to_kwargs(layer_attr)
)
return LayerOutput(name, LayerType.CONV_LAYER, parents=[input],
activation=act, num_filters=num_filters)
@wrap_name_default("convt")
@wrap_param_attr_default()
@wrap_bias_attr_default()
@wrap_act_default(act=ReluActivation())
@layer_support(DROPOUT)
def img_convTrans_layer(input, filter_size, num_filters,
name=None, num_channels=None,
act=None, groups=1, stride=1, padding=0, bias_attr=None,
param_attr=None, shared_biases=True, layer_attr=None,
filter_size_y=None, stride_y=None, padding_y=None):
"""
Convolution Transpose (deconv) layer for image. Paddle only support square
input currently and thus input image's width equals height.
......@@ -1644,7 +1532,6 @@ def img_convTrans_layer(input, filter_size, num_filters,
please refer to the following explanation and references therein
<http://datascience.stackexchange.com/questions/6107/
what-are-deconvolutional-layers/>`_ .
The num_channel means input image's channel number. It may be 1 or 3 when
input is raw pixels of image(mono or RGB), or it may be the previous layer's
num_filters * num_group.
......@@ -1694,6 +1581,8 @@ def img_convTrans_layer(input, filter_size, num_filters,
:type shared_biases: bool
:param layer_attr: Layer Extra Attribute.
:type layer_attr: ExtraLayerAttribute
:param trans: true if it is a convTransLayer, false if it is a convLayer
:type trans: bool
:return: LayerOutput object.
:rtype: LayerOutput
"""
......@@ -1729,6 +1618,12 @@ def img_convTrans_layer(input, filter_size, num_filters,
param_attr.attr["initial_std"] = init_w
param_attr.attr["initial_strategy"] = 0
param_attr.attr["initial_smart"] = False
if trans:
lt = LayerType.CONVTRANS_LAYER
else:
lt = LayerType.CONV_LAYER
Layer(
name=name,
inputs=Input(input.name, conv=Conv(
......@@ -1741,14 +1636,13 @@ def img_convTrans_layer(input, filter_size, num_filters,
num_filters=num_filters,
bias=ParamAttr.to_bias(bias_attr),
shared_biases=shared_biases,
type=LayerType.CONVTRANS_LAYER,
type=lt,
**ExtraLayerAttribute.to_kwargs(layer_attr)
)
return LayerOutput(name, LayerType.CONVTRANS_LAYER, parents=[input],
return LayerOutput(name, lt, parents=[input],
activation=act, num_filters=num_filters)
@wrap_name_default("pool")
@layer_support()
def img_pool_layer(input, pool_size, name=None,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册