From d5768ebc89868431040e47e3db126263da385d70 Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Fri, 18 Aug 2017 20:49:35 +0800 Subject: [PATCH] fix above comments --- paddle/cuda/include/hl_matrix.h | 58 ++++++++----- paddle/cuda/include/stub/hl_matrix_stub.h | 47 +++++++---- paddle/cuda/src/hl_cuda_matrix.cu | 84 +++++++++---------- paddle/gserver/layers/Conv3DLayer.cpp | 26 ++++-- paddle/gserver/layers/Conv3DLayer.h | 14 +--- paddle/gserver/layers/ConvBaseLayer.cpp | 26 +----- paddle/gserver/layers/ConvBaseLayer.h | 1 - paddle/gserver/layers/CudnnConvBaseLayer.cpp | 18 ++++ paddle/gserver/layers/DeConv3DLayer.cpp | 46 +++++----- paddle/gserver/layers/DeConv3DLayer.h | 44 +++++----- paddle/gserver/layers/ExpandConvBaseLayer.cpp | 21 ++++- paddle/gserver/tests/test_LayerGrad.cpp | 31 +++---- paddle/math/tests/test_matrixCompare.cpp | 28 ++----- proto/ModelConfig.proto | 4 +- 14 files changed, 247 insertions(+), 201 deletions(-) diff --git a/paddle/cuda/include/hl_matrix.h b/paddle/cuda/include/hl_matrix.h index da2ed8cabb7..a37921b7493 100644 --- a/paddle/cuda/include/hl_matrix.h +++ b/paddle/cuda/include/hl_matrix.h @@ -240,16 +240,25 @@ extern void hl_matrix_rotate( * @param[in] strideW stride in the width. * @param[in] paddingD padding in the depth. * @param[in] paddingH padding in the height. - * @param[in] paddingW padding in the width. + * @param[in] paddingW padding in the width. * @param[out] matDst output matrix. - * + * */ -extern void hl_matrix_vol2Col(real* matSrc, - int channel, int depth, int height, int width, - int filterD, int filterH, int filterW, - int strideD, int strideH, int strideW, - int paddingD, int paddingH, int paddingW, - real* matDst); +extern void hl_matrix_vol2Col(const real* dataSrc, + int channels, + int depth, + int height, + int width, + int filterD, + int filterH, + int filterW, + int strideD, + int strideH, + int strideW, + int paddingD, + int paddingH, + int paddingW, + real* dataDst); /** * @brief Matrix col2Vol: Convert col matrix into 3D volume @@ -267,19 +276,28 @@ extern void hl_matrix_vol2Col(real* matSrc, * @param[in] strideW stride in the width. * @param[in] paddingD padding in the depth. * @param[in] paddingH padding in the height. - * @param[in] paddingW padding in the width. + * @param[in] paddingW padding in the width. * @param[in] matSrc input matrix. - * @param[in] beta input - * @param[in] alpha input - * + * @param[in] beta input + * @param[in] alpha input + * */ -extern void hl_matrix_col2Vol(real* matDst, - int channels, int depth, int height, int width, - int filterD, int filterH, int filterW, - int strideD, int strideH, int strideW, - int paddingD, int paddingH, int paddingW, - real* matSrc, - real alpha, real beta); - +extern void hl_matrix_col2Vol(real* dataDst, + int channels, + int depth, + int height, + int width, + int filterD, + int filterH, + int filterW, + int strideD, + int strideH, + int strideW, + int paddingD, + int paddingH, + int paddingW, + const real* dataSrc, + real alpha, + real beta); #endif /* HL_MATRIX_H_ */ diff --git a/paddle/cuda/include/stub/hl_matrix_stub.h b/paddle/cuda/include/stub/hl_matrix_stub.h index 0b73777812a..6ac332945c8 100644 --- a/paddle/cuda/include/stub/hl_matrix_stub.h +++ b/paddle/cuda/include/stub/hl_matrix_stub.h @@ -99,19 +99,38 @@ inline void hl_matrix_collect_shared_bias(real* B_d, inline void hl_matrix_rotate( real* mat, real* matRot, int dimM, int dimN, bool clockWise) {} -inline void hl_matrix_vol2Col(real* data, - int channels, int depth, int height, int width, - int filterD, int filterH, int filterW, - int strideD, int strideH, int strideW, - int paddingD, int paddingH, int paddingW, - real* data_col) {} - -inline void hl_matrix_col2Vol(real* data, - int channels, int depth, int height, int width, - int filterD, int filterH, int filterW, - int strideD, int strideH, int strideW, - int paddingD, int paddingH, int paddingW, - real* data_Im, - real alpha, real beta) {} +inline void hl_matrix_vol2Col(const real* dataSrc, + int channels, + int depth, + int height, + int width, + int filterD, + int filterH, + int filterW, + int strideD, + int strideH, + int strideW, + int paddingD, + int paddingH, + int paddingW, + real* dataDst) {} + +inline void hl_matrix_col2Vol(real* dataDst, + int channels, + int depth, + int height, + int width, + int filterD, + int filterH, + int filterW, + int strideD, + int strideH, + int strideW, + int paddingD, + int paddingH, + int paddingW, + const real* dataSrc, + real alpha, + real beta) {} #endif // HL_MATRIX_STUB_H_ diff --git a/paddle/cuda/src/hl_cuda_matrix.cu b/paddle/cuda/src/hl_cuda_matrix.cu index 3bf1b0251f3..b41a3a1e06d 100644 --- a/paddle/cuda/src/hl_cuda_matrix.cu +++ b/paddle/cuda/src/hl_cuda_matrix.cu @@ -594,7 +594,7 @@ void hl_matrix_rotate( } __global__ void keMatrixVol2Col(int num_kernels, - real* dataSrc, + const real* dataSrc, real* dataDst, int depth, int height, @@ -643,7 +643,7 @@ __global__ void keMatrixVol2Col(int num_kernels, } } -void hl_matrix_vol2Col(real* dataSrc, +void hl_matrix_vol2Col(const real* dataSrc, int channels, int depth, int height, @@ -666,30 +666,30 @@ void hl_matrix_vol2Col(real* dataSrc, const int threads = 512; const int blocks = DIVUP(num_kernels, threads); - keMatrixVol2Col<<>>(num_kernels, - dataSrc, - dataDst, - depth, - height, - width, - filterD, - filterH, - filterW, - strideD, - strideH, - strideW, - paddingD, - paddingH, - paddingW, - depth_col, - height_col, - width_col); + keMatrixVol2Col<<>>(num_kernels, + dataSrc, + dataDst, + depth, + height, + width, + filterD, + filterH, + filterW, + strideD, + strideH, + strideW, + paddingD, + paddingH, + paddingW, + depth_col, + height_col, + width_col); CHECK_SYNC("hl_matrix_vol2Col failed"); } __global__ void keMatrixCol2Vol(int num_kernels, real* dataDst, - real* dataSrc, + const real* dataSrc, int depth, int height, int width, @@ -759,7 +759,7 @@ void hl_matrix_col2Vol(real* dataDst, int paddingD, int paddingH, int paddingW, - real* dataSrc, + const real* dataSrc, real alpha, real beta) { int depth_col = (depth + 2 * paddingD - filterD) / strideD + 1; @@ -770,26 +770,26 @@ void hl_matrix_col2Vol(real* dataDst, const int threads = 512; const int blocks = DIVUP(num_kernels, threads); - keMatrixCol2Vol<<>>(num_kernels, - dataDst, - dataSrc, - depth, - height, - width, - filterD, - filterH, - filterW, - strideD, - strideH, - strideW, - paddingD, - paddingH, - paddingW, - depth_col, - height_col, - width_col, - alpha, - beta); + keMatrixCol2Vol<<>>(num_kernels, + dataDst, + dataSrc, + depth, + height, + width, + filterD, + filterH, + filterW, + strideD, + strideH, + strideW, + paddingD, + paddingH, + paddingW, + depth_col, + height_col, + width_col, + alpha, + beta); CHECK_SYNC("hl_matrix_col2Vol failed"); } diff --git a/paddle/gserver/layers/Conv3DLayer.cpp b/paddle/gserver/layers/Conv3DLayer.cpp index 106909824df..db907bbab1c 100644 --- a/paddle/gserver/layers/Conv3DLayer.cpp +++ b/paddle/gserver/layers/Conv3DLayer.cpp @@ -28,16 +28,26 @@ bool Conv3DLayer::init(const LayerMap &layerMap, const ConvConfig &conf = inputConfig.conv_conf(); M_.push_back(numFilters_ / conf.groups()); K_.push_back(filterPixels_[index] * filterChannels_[index]); - if (nullptr != weights_[index]->getW()) - weights_[index]->getW()->reshape(weights_[index]->getW()->getWidth(), - weights_[index]->getW()->getHeight()); - if (nullptr != weights_[index]->getWGrad()) - weights_[index]->getWGrad()->reshape( - weights_[index]->getWGrad()->getWidth(), - weights_[index]->getWGrad()->getHeight()); + + // create a new weight + size_t height, width; + width = filterPixels_[index] * filterChannels_[index]; + height = numFilters_; + CHECK_EQ(parameters_[index]->getSize(), width * height); + Weight *w = new Weight(height, width, parameters_[index]); + weights_.emplace_back(w); ++index; } - CHECK(inputLayers_.size() == parameters_.size()); + if (biasParameter_.get()) { + if (sharedBiases_) { + CHECK_EQ((size_t)numFilters_, biasParameter_->getSize()); + biases_ = + std::unique_ptr(new Weight(1, numFilters_, biasParameter_)); + } else { + biases_ = + std::unique_ptr(new Weight(1, getSize(), biasParameter_)); + } + } return true; } diff --git a/paddle/gserver/layers/Conv3DLayer.h b/paddle/gserver/layers/Conv3DLayer.h index 703671e5d0d..b622508d0ce 100644 --- a/paddle/gserver/layers/Conv3DLayer.h +++ b/paddle/gserver/layers/Conv3DLayer.h @@ -12,13 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ - #pragma once - +#include #include "ConvBaseLayer.h" -#include "paddle/math/Matrix.h" #include "paddle/math/MathUtils.h" -#include +#include "paddle/math/Matrix.h" namespace paddle { @@ -30,21 +28,17 @@ namespace paddle { class Conv3DLayer : public ConvBaseLayer { public: explicit Conv3DLayer(const LayerConfig& config) : ConvBaseLayer(config) {} - ~Conv3DLayer() {} - bool init(const LayerMap &layerMap, const ParameterMap ¶meterMap); - - size_t getSize(); + bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); void forward(PassType passType); void addBias(); - void backward(const UpdateCallback& callback); - void bpropBiases(); void bpropData(int i); void bpropWeights(int i); + size_t getSize(); protected: // Figure out the dimensions for individual gemms. diff --git a/paddle/gserver/layers/ConvBaseLayer.cpp b/paddle/gserver/layers/ConvBaseLayer.cpp index 6bcbe0ddb2d..8c637eaec93 100644 --- a/paddle/gserver/layers/ConvBaseLayer.cpp +++ b/paddle/gserver/layers/ConvBaseLayer.cpp @@ -21,8 +21,7 @@ bool ConvBaseLayer::init(const LayerMap& layerMap, const ParameterMap& parameterMap) { /* Initialize the basic parent class */ Layer::init(layerMap, parameterMap); - isDeconv_ = (config_.type() == "exconv" || config_.type() == "cudnn_conv" || - config_.type() == "conv3d" || config_.type() == "deconv3d") + isDeconv_ = (config_.type() == "exconv" || config_.type() == "cudnn_conv") ? false : true; @@ -56,28 +55,9 @@ bool ConvBaseLayer::init(const LayerMap& layerMap, } CHECK(inputLayers_.size() == parameters_.size()); - for (size_t i = 0; i < inputLayers_.size(); i++) { - size_t height, width; - height = filterPixels_[i] * filterChannels_[i]; - width = (!isDeconv_) ? numFilters_ : channels_[i]; - - // create a new weight - CHECK_EQ(parameters_[i]->getSize(), width * height); - Weight* w = new Weight(height, width, parameters_[i]); - weights_.emplace_back(w); - } - /* initialize the biases_ */ - if (biasParameter_.get()) { - if (sharedBiases_) { - CHECK_EQ((size_t)numFilters_, biasParameter_->getSize()); - biases_ = - std::unique_ptr(new Weight(1, numFilters_, biasParameter_)); - } else { - biases_ = - std::unique_ptr(new Weight(1, getSize(), biasParameter_)); - } - } + // create new weights_ in derived class + // create new biases_ in derived class // default caffe model caffeMode_ = true; diff --git a/paddle/gserver/layers/ConvBaseLayer.h b/paddle/gserver/layers/ConvBaseLayer.h index 8d1fd989e83..629c462776d 100644 --- a/paddle/gserver/layers/ConvBaseLayer.h +++ b/paddle/gserver/layers/ConvBaseLayer.h @@ -23,7 +23,6 @@ namespace paddle { * with learned filters and (optionally) adds biases. */ - class ConvBaseLayer : public Layer { protected: typedef std::vector IntV; diff --git a/paddle/gserver/layers/CudnnConvBaseLayer.cpp b/paddle/gserver/layers/CudnnConvBaseLayer.cpp index c056bbe4d1d..9e954615cdd 100644 --- a/paddle/gserver/layers/CudnnConvBaseLayer.cpp +++ b/paddle/gserver/layers/CudnnConvBaseLayer.cpp @@ -46,8 +46,26 @@ bool CudnnConvBaseLayer::init(const LayerMap &layerMap, projConf_.emplace_back(conf); projections_.emplace_back( Projection::create(*projConf_[i], parameters_[i], useGpu_)); + + // create a new weight + size_t height, width; + height = filterPixels_[i] * filterChannels_[i]; + width = (!isDeconv_) ? numFilters_ : channels_[i]; + CHECK_EQ(parameters_[i]->getSize(), width * height); + Weight *w = new Weight(height, width, parameters_[i]); + weights_.emplace_back(w); } + if (biasParameter_.get()) { + if (sharedBiases_) { + CHECK_EQ((size_t)numFilters_, biasParameter_->getSize()); + biases_ = + std::unique_ptr(new Weight(numFilters_, 1, biasParameter_)); + } else { + biases_ = + std::unique_ptr(new Weight(getSize(), 1, biasParameter_)); + } + } if (biases_.get() && sharedBiases_) { hl_create_tensor_descriptor(&biasDesc_); hl_create_tensor_descriptor(&outputDesc_); diff --git a/paddle/gserver/layers/DeConv3DLayer.cpp b/paddle/gserver/layers/DeConv3DLayer.cpp index 5a54a684471..b18c06e36c8 100644 --- a/paddle/gserver/layers/DeConv3DLayer.cpp +++ b/paddle/gserver/layers/DeConv3DLayer.cpp @@ -20,9 +20,6 @@ namespace paddle { REGISTER_LAYER(deconv3d, DeConv3DLayer); -#define DECONV_OUTPUT_SIZE(IN_SIZE, STRID, PAD, KSIZE) \ - (((IN_SIZE)-1) * (STRID)-2 * (PAD) + (KSIZE)) - bool DeConv3DLayer::init(const LayerMap &layerMap, const ParameterMap ¶meterMap) { if (!ConvBaseLayer::init(layerMap, parameterMap)) return false; @@ -32,14 +29,25 @@ bool DeConv3DLayer::init(const LayerMap &layerMap, for (int index = 0; index < config_.inputs().size(); ++index) { M_.push_back(filterChannels_[index]); K_.push_back(filterPixels_[index] * (numFilters_ / groups_[index])); - if (weights_[index]->getW()) - weights_[index]->getW()->reshape(filterPixels_[index] * numFilters_, - filterChannels_[index]); - if (weights_[index]->getWGrad()) - weights_[index]->getWGrad()->reshape(filterPixels_[index] * numFilters_, - filterChannels_[index]); + + // create a new weight + size_t height, width; + height = filterPixels_[index] * numFilters_; + width = filterChannels_[index]; + CHECK_EQ(parameters_[index]->getSize(), width * height); + Weight *w = new Weight(height, width, parameters_[index]); + weights_.emplace_back(w); + } + if (biasParameter_.get()) { + if (sharedBiases_) { + CHECK_EQ((size_t)numFilters_, biasParameter_->getSize()); + biases_ = + std::unique_ptr(new Weight(1, numFilters_, biasParameter_)); + } else { + biases_ = + std::unique_ptr(new Weight(1, getSize(), biasParameter_)); + } } - CHECK(inputLayers_.size() == parameters_.size()); return true; } @@ -52,22 +60,22 @@ size_t DeConv3DLayer::getSize() { outputW_.clear(); outputD_.clear(); N_.clear(); - No_.clear(); + NOut_.clear(); size_t layerSize = 0; for (size_t i = 0; i < inputLayers_.size(); ++i) { // imgSizeH_.push_back(inputLayers_[i]->getOutput().getFrameHeight()); // imgSizeW_.push_back(inputLayers_[i]->getOutput().getFrameWidth()); // imgSizeD_.push_back(inputLayers_[i]->getOutput().getFrameDepth()); - outputW_.push_back(DECONV_OUTPUT_SIZE( - imgSizeW_[i], stride_[i], padding_[i], filterSize_[i])); - outputH_.push_back(DECONV_OUTPUT_SIZE( - imgSizeH_[i], strideY_[i], paddingY_[i], filterSizeY_[i])); - outputD_.push_back(DECONV_OUTPUT_SIZE( - imgSizeD_[i], strideZ_[i], paddingZ_[i], filterSizeZ_[i])); - No_.push_back(outputD_[i] * outputH_[i] * outputW_[i]); + outputW_.push_back( + imageSize(imgSizeW_[i], filterSize_[i], padding_[i], stride_[i], true)); + outputH_.push_back(imageSize( + imgSizeH_[i], filterSizeY_[i], paddingY_[i], strideY_[i], true)); + outputD_.push_back(imageSize( + imgSizeD_[i], filterSizeZ_[i], paddingZ_[i], strideZ_[i], true)); + NOut_.push_back(outputD_[i] * outputH_[i] * outputW_[i]); N_.push_back(imgSizeD_[i] * imgSizeH_[i] * imgSizeW_[i]); CHECK(layerSize == 0 || N_[i] * size_t(numFilters_) == layerSize); - layerSize += No_[i] * numFilters_; + layerSize += NOut_[i] * numFilters_; } getOutput().setFrameHeight(outputH_[0]); getOutput().setFrameWidth(outputW_[0]); diff --git a/paddle/gserver/layers/DeConv3DLayer.h b/paddle/gserver/layers/DeConv3DLayer.h index 435807fe5de..a2a3d3f8273 100644 --- a/paddle/gserver/layers/DeConv3DLayer.h +++ b/paddle/gserver/layers/DeConv3DLayer.h @@ -12,13 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ - #pragma once +#include #include "ConvBaseLayer.h" -#include "paddle/math/Matrix.h" #include "paddle/math/MathUtils.h" -#include +#include "paddle/math/Matrix.h" namespace paddle { @@ -29,30 +28,25 @@ namespace paddle { */ class DeConv3DLayer : public ConvBaseLayer { public: - explicit DeConv3DLayer(const LayerConfig& config) : ConvBaseLayer(config) {} - - ~DeConv3DLayer() {} - - bool init(const LayerMap &layerMap, const ParameterMap ¶meterMap); - - size_t getSize(); - - void forward(PassType passType); - void addBias(); - - void backward(const UpdateCallback& callback); - - void bpropBiases(); - void bpropData(int i); - void bpropWeights(int i); + explicit DeConv3DLayer(const LayerConfig& config) : ConvBaseLayer(config) {} + ~DeConv3DLayer() {} + bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); + + void forward(PassType passType); + void addBias(); + void backward(const UpdateCallback& callback); + void bpropBiases(); + void bpropData(int i); + void bpropWeights(int i); + size_t getSize(); protected: - // Figure out the dimensions for individual gemms. - IntV M_; /// numFilters_ / filter_group_; - IntV N_; /// channels_ * filterSizeZ_ * filterSize_ * filterSizeY_ - IntV K_; /// outputD_ * outputH_ * outputW_ - IntV No_; - MatrixPtr colBuf_; + // Figure out the dimensions for individual gemms. + IntV M_; /// numFilters_ / filter_group_; + IntV N_; /// channels_ * filterSizeZ_ * filterSize_ * filterSizeY_ + IntV K_; /// outputD_ * outputH_ * outputW_ + IntV NOut_; + MatrixPtr colBuf_; }; } // namespace paddle diff --git a/paddle/gserver/layers/ExpandConvBaseLayer.cpp b/paddle/gserver/layers/ExpandConvBaseLayer.cpp index 77736e78f93..2b7bef0a757 100644 --- a/paddle/gserver/layers/ExpandConvBaseLayer.cpp +++ b/paddle/gserver/layers/ExpandConvBaseLayer.cpp @@ -22,12 +22,31 @@ bool ExpandConvBaseLayer::init(const LayerMap &layerMap, /* Initialize the basic convolutional parent class */ ConvBaseLayer::init(layerMap, parameterMap); + int index = 0; for (auto &inputConfig : config_.inputs()) { const ConvConfig &conf = inputConfig.conv_conf(); /* Consistent caffe mode for multiple input */ caffeMode_ = conf.caffe_mode(); - } + // create a new weight + size_t height, width; + height = filterPixels_[index] * filterChannels_[index]; + width = (!isDeconv_) ? numFilters_ : channels_[index]; + CHECK_EQ(parameters_[index]->getSize(), width * height); + Weight *w = new Weight(height, width, parameters_[index]); + weights_.emplace_back(w); + index++; + } + if (biasParameter_.get()) { + if (sharedBiases_) { + CHECK_EQ((size_t)numFilters_, biasParameter_->getSize()); + biases_ = + std::unique_ptr(new Weight(numFilters_, 1, biasParameter_)); + } else { + biases_ = + std::unique_ptr(new Weight(getSize(), 1, biasParameter_)); + } + } getOutputSize(); return true; diff --git a/paddle/gserver/tests/test_LayerGrad.cpp b/paddle/gserver/tests/test_LayerGrad.cpp index 1e80e2c0ee0..d5724293bf8 100644 --- a/paddle/gserver/tests/test_LayerGrad.cpp +++ b/paddle/gserver/tests/test_LayerGrad.cpp @@ -2019,7 +2019,7 @@ void test3DConvLayer(const string& type, bool trans, bool useGpu) { const int CHANNELS = 3; const int IMAGE_SIZE = 9; const int IMAGE_SIZE_Y = 9; - const int IMAGE_SIZE_Z = 9; // 2, 3, 5, 5, 5 + const int IMAGE_SIZE_Z = 9; TestConfig config; config.biasSize = NUM_FILTERS; @@ -2084,10 +2084,6 @@ TEST(Layer, test3DConvLayer) { #endif } -int deConvOutputSize(int inSize, int kSize, int pad, int stride) { - return (inSize - 1) * stride - 2 * pad + kSize; -} - void test3DDeConvLayer(const string& type, bool trans, bool useGpu) { // filter size const int NUM_FILTERS = 6; @@ -2126,16 +2122,21 @@ void test3DDeConvLayer(const string& type, bool trans, bool useGpu) { conv->set_img_size(IMAGE_SIZE); conv->set_img_size_y(IMAGE_SIZE_Y); conv->set_img_size_z(IMAGE_SIZE_Z); - conv->set_output_x(deConvOutputSize( - conv->img_size(), conv->filter_size(), conv->padding(), conv->stride())); - conv->set_output_y(deConvOutputSize(conv->img_size_y(), - conv->filter_size_y(), - conv->padding_y(), - conv->stride_y())); - conv->set_output_z(deConvOutputSize(conv->img_size_z(), - conv->filter_size_z(), - conv->padding_z(), - conv->stride_z())); + conv->set_output_x(imageSize(conv->img_size(), + conv->filter_size(), + conv->padding(), + conv->stride(), + true)); + conv->set_output_y(imageSize(conv->img_size_y(), + conv->filter_size_y(), + conv->padding_y(), + conv->stride_y(), + true)); + conv->set_output_z(imageSize(conv->img_size_z(), + conv->filter_size_z(), + conv->padding_z(), + conv->stride_z(), + true)); config.layerConfig.set_size(conv->output_x() * conv->output_y() * conv->output_z() * NUM_FILTERS); conv->set_groups(1); diff --git a/paddle/math/tests/test_matrixCompare.cpp b/paddle/math/tests/test_matrixCompare.cpp index 1d41ec08702..3abe4484dbc 100644 --- a/paddle/math/tests/test_matrixCompare.cpp +++ b/paddle/math/tests/test_matrixCompare.cpp @@ -18,6 +18,7 @@ limitations under the License. */ #include #include "TensorCheck.h" +#include "paddle/math/MathUtils.h" #include "paddle/math/Matrix.h" #include "paddle/math/SparseMatrix.h" #include "paddle/testing/TestUtil.h" @@ -1203,19 +1204,6 @@ TEST(Matrix, warpCTC) { } } -int outputSizeCol2Vol( - int imageSize, int filterSize, int padding, int stride, bool caffeMode) { - int outputSize; - if (!caffeMode) { - outputSize = - (imageSize - filterSize + 2 * padding + stride - 1) / stride + 1; - } else { - outputSize = (imageSize - filterSize + 2 * padding) / stride + 1; - } - CHECK_GE(outputSize, 1); - return outputSize; -} - void testMatrixCol2Vol(int depth, int height, int width) { int channel = 3; int filterX = 3, filterY = 4, filterZ = 5; @@ -1229,9 +1217,9 @@ void testMatrixCol2Vol(int depth, int height, int width) { cpuImage->randomizeUniform(); gpuImage->copyFrom(*cpuImage); - int outD = outputSizeCol2Vol(depth, filterZ, padZ, strideZ, true); - int outH = outputSizeCol2Vol(height, filterY, padZ, strideY, true); - int outW = outputSizeCol2Vol(width, filterX, padZ, strideX, true); + int outD = outputSize(depth, filterZ, padZ, strideZ, true); + int outH = outputSize(height, filterY, padY, strideY, true); + int outW = outputSize(width, filterX, padX, strideX, true); int colBufHeight = channel * filterZ * filterY * filterX; int colBufWidth = outD * outH * outW; @@ -1305,11 +1293,9 @@ void testMatrixCol2Vol(int depth, int height, int width) { } TEST(Matrix, col2Vol) { - for (auto depth : {9, 16, 64, 128}) { - for (auto height : {9, 11, 73, 128, 256}) { - for (auto width : { - 9, 32, 100, 512, - }) { + for (auto depth : {9, 16, 64}) { + for (auto height : {9, 11, 128}) { + for (auto width : {9, 32, 128}) { VLOG(3) << "depth=" << depth << " height=" << height << " width=" << width; testMatrixCol2Vol(depth, height, width); diff --git a/proto/ModelConfig.proto b/proto/ModelConfig.proto index 043ae502b02..8c6eb5b7e17 100644 --- a/proto/ModelConfig.proto +++ b/proto/ModelConfig.proto @@ -82,7 +82,7 @@ message ConvConfig { // if not set, use img_size optional uint32 img_size_y = 14; - + optional uint32 filter_size_z = 15 [ default = 1 ]; optional uint32 padding_z = 16 [ default = 1 ]; optional uint32 stride_z = 17 [ default = 1 ]; @@ -637,4 +637,4 @@ message ModelConfig { // For External Machine, defining how to split a neural network // into multiple parts. optional ExternalConfig external_config = 9; -}; \ No newline at end of file +}; -- GitLab