diff --git a/paddle/gserver/layers/DeConv3DLayer.cpp b/paddle/gserver/layers/DeConv3DLayer.cpp index 1b59ed60c57fe3bbfa814befa8a63408a2621715..3eea638649e8ebfdd7efa18615977a9e1344c695 100644 --- a/paddle/gserver/layers/DeConv3DLayer.cpp +++ b/paddle/gserver/layers/DeConv3DLayer.cpp @@ -53,27 +53,27 @@ bool DeConv3DLayer::init(const LayerMap &layerMap, size_t DeConv3DLayer::getSize() { CHECK_NE(inputLayers_.size(), 0UL); - outputH_.clear(); - outputW_.clear(); - outputD_.clear(); + imgSizeW_.clear(); + imgSizeH_.clear(); + imgSizeD_.clear(); N_.clear(); NOut_.clear(); size_t layerSize = 0; for (size_t i = 0; i < inputLayers_.size(); ++i) { - outputW_.push_back( - imageSize(imgSizeW_[i], filterSize_[i], padding_[i], stride_[i], true)); - outputH_.push_back(imageSize( - imgSizeH_[i], filterSizeY_[i], paddingY_[i], strideY_[i], true)); - outputD_.push_back(imageSize( - imgSizeD_[i], filterSizeZ_[i], paddingZ_[i], strideZ_[i], true)); - NOut_.push_back(outputD_[i] * outputH_[i] * outputW_[i]); - N_.push_back(imgSizeD_[i] * imgSizeH_[i] * imgSizeW_[i]); + imgSizeW_.push_back( + imageSize(outputW_[i], filterSize_[i], padding_[i], stride_[i], true)); + imgSizeH_.push_back(imageSize( + outputH_[i], filterSizeY_[i], paddingY_[i], strideY_[i], true)); + imgSizeD_.push_back(imageSize( + outputD_[i], filterSizeZ_[i], paddingZ_[i], strideZ_[i], true)); + NOut_.push_back(imgSizeD_[i] * imgSizeH_[i] * imgSizeW_[i]); + N_.push_back(outputD_[i] * outputH_[i] * outputW_[i]); CHECK(layerSize == 0 || N_[i] * size_t(numFilters_) == layerSize); layerSize += NOut_[i] * numFilters_; } - getOutput().setFrameHeight(outputH_[0]); - getOutput().setFrameWidth(outputW_[0]); - getOutput().setFrameDepth(outputD_[0]); + getOutput().setFrameHeight(imgSizeH_[0]); + getOutput().setFrameWidth(imgSizeW_[0]); + getOutput().setFrameDepth(imgSizeD_[0]); return layerSize; } @@ -103,9 +103,9 @@ void DeConv3DLayer::forward(PassType passType) { } colBuf_->col2Vol(outMat->getData() + n * outMat->getStride(), numFilters_, - outputD_[i], - outputH_[i], - outputW_[i], + imgSizeD_[i], + imgSizeH_[i], + imgSizeW_[i], filterSizeZ_[i], filterSizeY_[i], filterSize_[i], @@ -144,9 +144,9 @@ void DeConv3DLayer::backward(const UpdateCallback &callback) { colBuf_->vol2Col( getOutputGrad()->getData() + n * getOutputGrad()->getStride(), numFilters_, - outputD_[i], - outputH_[i], - outputW_[i], + imgSizeD_[i], + imgSizeH_[i], + imgSizeW_[i], filterSizeZ_[i], filterSizeY_[i], filterSize_[i], diff --git a/paddle/gserver/tests/test_LayerGrad.cpp b/paddle/gserver/tests/test_LayerGrad.cpp index 0e6be2df9ef5f0fae8ed2b0c65ac6c032fe45ab1..090bde7b203652e3ffb1662b8f5b8937885d2608 100644 --- a/paddle/gserver/tests/test_LayerGrad.cpp +++ b/paddle/gserver/tests/test_LayerGrad.cpp @@ -2302,26 +2302,27 @@ void test3DDeConvLayer(const string& type, bool trans, bool useGpu) { conv->set_stride(2); conv->set_stride_y(2); conv->set_stride_z(2); - conv->set_img_size(IMAGE_SIZE); - conv->set_img_size_y(IMAGE_SIZE_Y); - conv->set_img_size_z(IMAGE_SIZE_Z); - conv->set_output_x(imageSize(conv->img_size(), + conv->set_output_x(IMAGE_SIZE); + conv->set_output_y(IMAGE_SIZE_Y); + conv->set_output_z(IMAGE_SIZE_Z); + + conv->set_img_size(imageSize(conv->output_x(), conv->filter_size(), conv->padding(), conv->stride(), true)); - conv->set_output_y(imageSize(conv->img_size_y(), - conv->filter_size_y(), - conv->padding_y(), - conv->stride_y(), - true)); - conv->set_output_z(imageSize(conv->img_size_z(), - conv->filter_size_z(), - conv->padding_z(), - conv->stride_z(), - true)); - config.layerConfig.set_size(conv->output_x() * conv->output_y() * - conv->output_z() * NUM_FILTERS); + conv->set_img_size_y(imageSize(conv->output_y(), + conv->filter_size_y(), + conv->padding_y(), + conv->stride_y(), + true)); + conv->set_img_size_z(imageSize(conv->output_z(), + conv->filter_size_z(), + conv->padding_z(), + conv->stride_z(), + true)); + config.layerConfig.set_size(conv->img_size() * conv->img_size_y() * + conv->img_size_z() * NUM_FILTERS); conv->set_groups(1); conv->set_filter_channels(conv->channels() / conv->groups()); config.inputDefs.push_back(