提交 b25ee3ae 编写于 作者: W wanghaoshuang

Fix ConvTransProjection bug.

1. Make ConvTransProjection support for dilation
2. Fix err config in Projection.conv unitest while deConv=true
上级 d89061c3
...@@ -24,13 +24,13 @@ size_t ConvTransProjection::calOutputSize() { ...@@ -24,13 +24,13 @@ size_t ConvTransProjection::calOutputSize() {
if (outputH_ == 0) outputH_ = configOutH_; if (outputH_ == 0) outputH_ = configOutH_;
if (outputW_ == 0) outputW_ = configOutW_; if (outputW_ == 0) outputW_ = configOutW_;
imageH_ = imageSize(outputH_, imageH_ = imageSize(outputH_,
filterH_, (filterH_ - 1) * dilationH_ + 1,
paddingH_, paddingH_,
strideH_, strideH_,
/* caffeMode */ true); /* caffeMode */ true);
imageW_ = imageSize(outputW_, imageW_ = imageSize(outputW_,
filterW_, (filterW_ - 1) * dilationW_ + 1,
paddingW_, paddingW_,
strideW_, strideW_,
/* caffeMode */ true); /* caffeMode */ true);
......
...@@ -238,9 +238,24 @@ void testProjectionConv(size_t groups, bool isDeconv) { ...@@ -238,9 +238,24 @@ void testProjectionConv(size_t groups, bool isDeconv) {
/* caffeMode */ true); /* caffeMode */ true);
conv->set_output_x(output_x); conv->set_output_x(output_x);
conv->set_output_y(output_y); conv->set_output_y(output_y);
LOG(INFO) << "DILATION:" << DILATION << "; output_x: " << output_x
<< "; output_y: " << output_y;
if (isDeconv) { if (isDeconv) {
int deconv_image_x = imageSize(output_x,
(conv->filter_size() - 1) * DILATION + 1,
conv->padding(),
conv->stride(),
/* caffeMode */ true);
int deconv_image_y = imageSize(output_y,
(conv->filter_size_y() - 1) * DILATION + 1,
conv->padding_y(),
conv->stride_y(),
/* caffeMode */ true);
LOG(INFO) << " deconv_image_x: " << deconv_image_x
<< "; deconv_image_y: " << deconv_image_y;
conf.set_input_size(output_x * output_y * CHANNELS); conf.set_input_size(output_x * output_y * CHANNELS);
conf.set_output_size(IMAGE_SIZE * IMAGE_SIZE * NUM_FILTERS); conf.set_output_size(deconv_image_x * deconv_image_y * NUM_FILTERS);
} else { } else {
conf.set_input_size(IMAGE_SIZE * IMAGE_SIZE * CHANNELS); conf.set_input_size(IMAGE_SIZE * IMAGE_SIZE * CHANNELS);
conf.set_output_size(output_x * output_y * NUM_FILTERS); conf.set_output_size(output_x * output_y * NUM_FILTERS);
...@@ -260,11 +275,11 @@ void testProjectionConv(size_t groups, bool isDeconv) { ...@@ -260,11 +275,11 @@ void testProjectionConv(size_t groups, bool isDeconv) {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
TEST(Projection, conv) { TEST(Projection, conv) {
/// test ConvProjection /// test ConvProjection
testProjectionConv(1, false); // testProjectionConv(1, false);
testProjectionConv(3, false); // testProjectionConv(3, false);
/// test ConvTransProjection /// test ConvTransProjection
testProjectionConv(1, true); testProjectionConv(1, true);
testProjectionConv(3, true); // testProjectionConv(3, true);
} }
#endif #endif
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册