提交 5d7f6dde 编写于 作者: C chengduoZH

Add depth dimension information to ConvBaseLayer

上级 8cc0eb9c
...@@ -21,9 +21,11 @@ bool ConvBaseLayer::init(const LayerMap& layerMap, ...@@ -21,9 +21,11 @@ bool ConvBaseLayer::init(const LayerMap& layerMap,
const ParameterMap& parameterMap) { const ParameterMap& parameterMap) {
/* Initialize the basic parent class */ /* Initialize the basic parent class */
Layer::init(layerMap, parameterMap); Layer::init(layerMap, parameterMap);
isDeconv_ = (config_.type() == "exconv" || config_.type() == "cudnn_conv") isDeconv_ = (config_.type() == "exconv" ||
? false config_.type() == "cudnn_conv" ||
: true; config_.type() == "conv3d" ||
config_.type() == "deconv3d" )
? false : true;
/* Initialize the convolutional layer parameter */ /* Initialize the convolutional layer parameter */
numFilters_ = config_.num_filters(); numFilters_ = config_.num_filters();
...@@ -36,7 +38,6 @@ bool ConvBaseLayer::init(const LayerMap& layerMap, ...@@ -36,7 +38,6 @@ bool ConvBaseLayer::init(const LayerMap& layerMap,
paddingY_.push_back(conf.padding_y()); paddingY_.push_back(conf.padding_y());
strideY_.push_back(conf.stride_y()); strideY_.push_back(conf.stride_y());
filterSizeY_.push_back(conf.filter_size_y()); filterSizeY_.push_back(conf.filter_size_y());
filterPixels_.push_back(filterSize_.back() * filterSizeY_.back());
channels_.push_back(conf.channels()); channels_.push_back(conf.channels());
imgSizeH_.push_back(conf.has_img_size_y() ? conf.img_size_y() imgSizeH_.push_back(conf.has_img_size_y() ? conf.img_size_y()
: conf.img_size()); : conf.img_size());
...@@ -45,6 +46,14 @@ bool ConvBaseLayer::init(const LayerMap& layerMap, ...@@ -45,6 +46,14 @@ bool ConvBaseLayer::init(const LayerMap& layerMap,
filterChannels_.push_back(conf.filter_channels()); filterChannels_.push_back(conf.filter_channels());
outputH_.push_back(conf.has_output_y() ? conf.output_y() : conf.output_x()); outputH_.push_back(conf.has_output_y() ? conf.output_y() : conf.output_x());
outputW_.push_back(conf.output_x()); outputW_.push_back(conf.output_x());
paddingZ_.push_back(conf.padding_z());
strideZ_.push_back(conf.stride_z());
filterSizeZ_.push_back(conf.filter_size_z());
imgSizeD_.push_back(conf.img_size_z());
outputD_.push_back(conf.output_z());
filterPixels_.push_back(
filterSize_.back() * filterSizeY_.back() * filterSizeZ_.back());
} }
CHECK(inputLayers_.size() == parameters_.size()); CHECK(inputLayers_.size() == parameters_.size());
......
...@@ -23,6 +23,7 @@ namespace paddle { ...@@ -23,6 +23,7 @@ namespace paddle {
* with learned filters and (optionally) adds biases. * with learned filters and (optionally) adds biases.
*/ */
class ConvBaseLayer : public Layer { class ConvBaseLayer : public Layer {
protected: protected:
typedef std::vector<int> IntV; typedef std::vector<int> IntV;
...@@ -58,6 +59,13 @@ protected: ...@@ -58,6 +59,13 @@ protected:
IntV outputH_; IntV outputH_;
/// The spatial dimensions of output feature map width. /// The spatial dimensions of output feature map width.
IntV outputW_; IntV outputW_;
IntV outputD_;
IntV imgSizeD_;
IntV filterSizeZ_;
IntV strideZ_;
IntV paddingZ_;
/// Group size, refer to grouped convolution in /// Group size, refer to grouped convolution in
/// Alex Krizhevsky's paper: when group=2, the first half of the /// Alex Krizhevsky's paper: when group=2, the first half of the
/// filters are only connected to the first half of the input channels, /// filters are only connected to the first half of the input channels,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册