提交 53e1629a 编写于 作者: W wangyang59

Refactored imageSize in ConvBaseLayer to MathUtil

上级 03f4b1d4
...@@ -14,6 +14,7 @@ limitations under the License. */ ...@@ -14,6 +14,7 @@ limitations under the License. */
#include "paddle/utils/Logging.h" #include "paddle/utils/Logging.h"
#include "ConvBaseLayer.h" #include "ConvBaseLayer.h"
#include "paddle/math/MathUtils.h"
namespace paddle { namespace paddle {
bool ConvBaseLayer::init(const LayerMap& layerMap, bool ConvBaseLayer::init(const LayerMap& layerMap,
...@@ -95,18 +96,22 @@ size_t ConvBaseLayer::calOutputSize() { ...@@ -95,18 +96,22 @@ size_t ConvBaseLayer::calOutputSize() {
if (inW[i] == 0) if (inW[i] == 0)
inW[i] = config_.inputs(i).conv_conf().output_x(); inW[i] = config_.inputs(i).conv_conf().output_x();
outH.push_back( outH.push_back(
imageSize(inH[i], filterSizeY_[i], paddingY_[i], strideY_[i])); imageSize(inH[i], filterSizeY_[i], paddingY_[i], strideY_[i],
caffeMode_));
outW.push_back( outW.push_back(
imageSize(inW[i], filterSize_[i], padding_[i], stride_[i])); imageSize(inW[i], filterSize_[i], padding_[i], stride_[i],
caffeMode_));
} else { } else {
if (inH[i] == 0) if (inH[i] == 0)
inH[i] = config_.inputs(i).conv_conf().img_size(); inH[i] = config_.inputs(i).conv_conf().img_size();
if (inW[i] == 0) if (inW[i] == 0)
inW[i] = config_.inputs(i).conv_conf().img_size(); inW[i] = config_.inputs(i).conv_conf().img_size();
outH.push_back( outH.push_back(
outputSize(inH[i], filterSizeY_[i], paddingY_[i], strideY_[i])); outputSize(inH[i], filterSizeY_[i], paddingY_[i], strideY_[i],
caffeMode_));
outW.push_back( outW.push_back(
outputSize(inW[i], filterSize_[i], padding_[i], stride_[i])); outputSize(inW[i], filterSize_[i], padding_[i], stride_[i],
caffeMode_));
} }
CHECK_EQ(outH[i], outH[0]); CHECK_EQ(outH[i], outH[0]);
CHECK_EQ(outW[i], outW[0]); CHECK_EQ(outW[i], outW[0]);
......
...@@ -91,43 +91,6 @@ public: ...@@ -91,43 +91,6 @@ public:
virtual size_t calOutputSize(); virtual size_t calOutputSize();
Weight& getWeight(int idx) { return *weights_[idx]; } Weight& getWeight(int idx) { return *weights_[idx]; }
/**
* Calculate output size based on caffeMode_.
* - input(+padding): 0123456789
* - imageSize(+padding) = 10;
* - filterSize = 3;
* - stride = 2;
* - caffeMode_ is true:
- output: (012), (234), (456), (678)
- outputSize = 4;
* - caffeMode_ is false:
* - output: (012), (234), (456), (678), (9)
* - outputSize = 5;
*/
int outputSize(int imageSize, int filterSize, int padding, int stride) {
int outputSize;
if (!caffeMode_) {
outputSize =
(imageSize - filterSize + 2 * padding + stride - 1) / stride + 1;
} else {
outputSize = (imageSize - filterSize + 2 * padding) / stride + 1;
}
CHECK_GE(outputSize, 1);
return outputSize;
}
int imageSize(int outputSize, int filterSize, int padding, int stride) {
int imageSize;
if (!caffeMode_) {
imageSize =
(outputSize - 1) * stride + filterSize - 2 * padding - stride + 1;
} else {
imageSize = (outputSize - 1) * stride + filterSize - 2 * padding;
}
CHECK_GE(imageSize, 1);
return imageSize;
}
}; };
} // namespace paddle } // namespace paddle
...@@ -20,6 +20,7 @@ limitations under the License. */ ...@@ -20,6 +20,7 @@ limitations under the License. */
#include "paddle/trainer/Trainer.h" #include "paddle/trainer/Trainer.h"
#include "paddle/utils/GlobalConstants.h" #include "paddle/utils/GlobalConstants.h"
#include "paddle/gserver/layers/ExpandConvTransLayer.h" #include "paddle/gserver/layers/ExpandConvTransLayer.h"
#include "paddle/math/MathUtils.h"
#include "TestUtil.h" #include "TestUtil.h"
#include "LayerGradUtil.h" #include "LayerGradUtil.h"
...@@ -56,11 +57,9 @@ TEST(Layer, convTransLayerFwd) { ...@@ -56,11 +57,9 @@ TEST(Layer, convTransLayerFwd) {
conv->set_groups(1); conv->set_groups(1);
conv->set_filter_channels(3 / conv->groups()); conv->set_filter_channels(3 / conv->groups());
conv->set_img_size(16); conv->set_img_size(16);
conv->set_output_x( conv->set_output_x(outputSize(conv->img_size(), conv->filter_size(),
(2 * conv->padding() + conv->img_size() - conv->filter_size()) / conv->padding(), conv->stride(),
((float)conv->stride()) + /* caffeMode */ true));
1.5);
configt.layerConfig.set_size(conv->img_size() * conv->img_size() * configt.layerConfig.set_size(conv->img_size() * conv->img_size() *
configt.layerConfig.num_filters()); configt.layerConfig.num_filters());
configt.layerConfig.set_name("convTrans"); configt.layerConfig.set_name("convTrans");
...@@ -99,10 +98,9 @@ TEST(Layer, convTransLayerFwd) { ...@@ -99,10 +98,9 @@ TEST(Layer, convTransLayerFwd) {
conv->set_groups(1); conv->set_groups(1);
conv->set_filter_channels(conv->channels() / conv->groups()); conv->set_filter_channels(conv->channels() / conv->groups());
conv->set_img_size(16); conv->set_img_size(16);
conv->set_output_x( conv->set_output_x(outputSize(conv->img_size(), conv->filter_size(),
(2 * conv->padding() + conv->img_size() - conv->filter_size()) / conv->padding(), conv->stride(),
((float)conv->stride()) + /* caffeMode */ true));
1.5);
config.layerConfig.set_size(conv->output_x() * conv->output_x() * config.layerConfig.set_size(conv->output_x() * conv->output_x() *
config.layerConfig.num_filters()); config.layerConfig.num_filters());
config.layerConfig.set_name("conv"); config.layerConfig.set_name("conv");
......
...@@ -336,10 +336,9 @@ void testConvTransLayer(const string& type, bool trans, bool useGpu) { ...@@ -336,10 +336,9 @@ void testConvTransLayer(const string& type, bool trans, bool useGpu) {
conv->set_groups(1); conv->set_groups(1);
conv->set_filter_channels(3 / conv->groups()); conv->set_filter_channels(3 / conv->groups());
conv->set_img_size(16); conv->set_img_size(16);
conv->set_output_x( conv->set_output_x(outputSize(conv->img_size(), conv->filter_size(),
(2 * conv->padding() + conv->img_size() - conv->filter_size()) / conv->padding(), conv->stride(),
((float)conv->stride()) + /* caffeMode */ true));
1.5);
config.layerConfig.set_size(conv->img_size() * conv->img_size() * config.layerConfig.set_size(conv->img_size() * conv->img_size() *
config.layerConfig.num_filters()); config.layerConfig.num_filters());
......
...@@ -80,4 +80,17 @@ int outputSize(int imageSize, int filterSize, int padding, int stride, ...@@ -80,4 +80,17 @@ int outputSize(int imageSize, int filterSize, int padding, int stride,
return outputSize; return outputSize;
} }
int imageSize(int outputSize, int filterSize, int padding, int stride,
bool caffeMode) {
int imageSize;
if (!caffeMode) {
imageSize =
(outputSize - 1) * stride + filterSize - 2 * padding - stride + 1;
} else {
imageSize = (outputSize - 1) * stride + filterSize - 2 * padding;
}
CHECK_GE(imageSize, 1);
return imageSize;
}
} // namespace paddle } // namespace paddle
...@@ -60,4 +60,7 @@ void sparseRand(int* major, int* minor, int nnz, int majorLen, int minorMax, ...@@ -60,4 +60,7 @@ void sparseRand(int* major, int* minor, int nnz, int majorLen, int minorMax,
int outputSize(int imageSize, int filterSize, int padding, int stride, int outputSize(int imageSize, int filterSize, int padding, int stride,
bool caffeMode); bool caffeMode);
int imageSize(int outputSize, int filterSize, int padding, int stride,
bool caffeMode);
} // namespace paddle } // namespace paddle
...@@ -1107,14 +1107,10 @@ def parse_conv(conv, input_layer_name, conv_conf, trans=False): ...@@ -1107,14 +1107,10 @@ def parse_conv(conv, input_layer_name, conv_conf, trans=False):
("Input layer %s: Incorrect input image size %d for input " ("Input layer %s: Incorrect input image size %d for input "
+ "image pixels %d") + "image pixels %d")
% (input_layer_name, conv_conf.img_size, img_pixels)) % (input_layer_name, conv_conf.img_size, img_pixels))
if conv.caffe_mode:
conv_conf.output_x = \ conv_conf.output_x = cnn_output_size(
1 + int(math.floor((2 * conv.padding + conv_conf.img_size \ conv_conf.img_size, conv_conf.filter_size,
- conv.filter_size) / float(conv.stride))) conv_conf.padding, conv_conf.stride, conv_conf.caffe_mode)
else:
conv_conf.output_x = \
1 + int(math.ceil((2 * conv.padding + conv_conf.img_size \
- conv.filter_size) / float(conv.stride)))
else: else:
outputSize = g_layer_map[input_layer_name].size / conv.channels outputSize = g_layer_map[input_layer_name].size / conv.channels
print('channels=%d size=%d'%(conv.channels, print('channels=%d size=%d'%(conv.channels,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册