diff --git a/paddle/cuda/include/hl_matrix.h b/paddle/cuda/include/hl_matrix.h index eb454c59c1e58cf2b4817b4cb3230b9d75e320ac..c7f25109972195fb56b9e96c4b68d952363e6338 100644 --- a/paddle/cuda/include/hl_matrix.h +++ b/paddle/cuda/include/hl_matrix.h @@ -224,4 +224,80 @@ extern void hl_matrix_collect_shared_bias(real* B_d, extern void hl_matrix_rotate( real* mat, real* matRot, int dimM, int dimN, bool clockWise); +/** + * @brief Matrix vol2Col: Convert 3D volume into col matrix + * + * @param[in] matSrc input matrix. + * @param[in] channel channel of matSrc. + * @param[in] depth depth of matSrc. + * @param[in] height height of matSrc. + * @param[in] width width of matSrc. + * @param[in] filterD depth of filter. + * @param[in] filterH height of filter. + * @param[in] filterW width of filter. + * @param[in] strideD stride in the depth. + * @param[in] strideH stride in the height. + * @param[in] strideW stride in the width. + * @param[in] paddingD padding in the depth. + * @param[in] paddingH padding in the height. + * @param[in] paddingW padding in the width. + * @param[out] dataDst output matrix. + * + */ +extern void hl_matrix_vol2Col(const real* dataSrc, + int channels, + int depth, + int height, + int width, + int filterD, + int filterH, + int filterW, + int strideD, + int strideH, + int strideW, + int paddingD, + int paddingH, + int paddingW, + real* dataDst); + +/** + * @brief Matrix col2Vol: Convert col matrix into 3D volume + * + * @param[out] matDst output matrix. + * @param[in] channel channel of matDst. + * @param[in] depth depth of matDst. + * @param[in] height height of matDst. + * @param[in] width width of matDst. + * @param[in] filterD depth of filter. + * @param[in] filterH height of filter. + * @param[in] filterW width of filter. + * @param[in] strideD stride in the depth. + * @param[in] strideH stride in the height. + * @param[in] strideW stride in the width. + * @param[in] paddingD padding in the depth. + * @param[in] paddingH padding in the height. + * @param[in] paddingW padding in the width. + * @param[in] matSrc input matrix. + * @param[in] beta input + * @param[in] alpha input + * + */ +extern void hl_matrix_col2Vol(real* dataDst, + int channels, + int depth, + int height, + int width, + int filterD, + int filterH, + int filterW, + int strideD, + int strideH, + int strideW, + int paddingD, + int paddingH, + int paddingW, + const real* dataSrc, + real alpha, + real beta); + #endif /* HL_MATRIX_H_ */ diff --git a/paddle/cuda/include/stub/hl_matrix_stub.h b/paddle/cuda/include/stub/hl_matrix_stub.h index 127cb7e27983e8ff2c1ff6ef5108b5f8c5bd6ca5..6ac332945c8f09fef23f35680ba5bb1d9ba9f4fd 100644 --- a/paddle/cuda/include/stub/hl_matrix_stub.h +++ b/paddle/cuda/include/stub/hl_matrix_stub.h @@ -99,4 +99,38 @@ inline void hl_matrix_collect_shared_bias(real* B_d, inline void hl_matrix_rotate( real* mat, real* matRot, int dimM, int dimN, bool clockWise) {} +inline void hl_matrix_vol2Col(const real* dataSrc, + int channels, + int depth, + int height, + int width, + int filterD, + int filterH, + int filterW, + int strideD, + int strideH, + int strideW, + int paddingD, + int paddingH, + int paddingW, + real* dataDst) {} + +inline void hl_matrix_col2Vol(real* dataDst, + int channels, + int depth, + int height, + int width, + int filterD, + int filterH, + int filterW, + int strideD, + int strideH, + int strideW, + int paddingD, + int paddingH, + int paddingW, + const real* dataSrc, + real alpha, + real beta) {} + #endif // HL_MATRIX_STUB_H_ diff --git a/paddle/cuda/src/hl_cuda_matrix.cu b/paddle/cuda/src/hl_cuda_matrix.cu index 39272456c394adc0509e60cf5972df832f7b3424..b41a3a1e06db7b2566acef19ce430645f79d486d 100644 --- a/paddle/cuda/src/hl_cuda_matrix.cu +++ b/paddle/cuda/src/hl_cuda_matrix.cu @@ -592,3 +592,204 @@ void hl_matrix_rotate( mat, matRot, dimM, dimN, clockWise); CHECK_SYNC("hl_matrix_rotate failed"); } + +__global__ void keMatrixVol2Col(int num_kernels, + const real* dataSrc, + real* dataDst, + int depth, + int height, + int width, + int filterD, + int filterH, + int filterW, + int strideD, + int strideH, + int strideW, + int paddingD, + int paddingH, + int paddingW, + int depth_col, + int height_col, + int width_col) { + for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < num_kernels; + index += blockDim.x * gridDim.x) { + int w_out = index % width_col; + int h_out = (index / width_col) % height_col; + int d_out = (index / width_col / height_col) % depth_col; + int channel_in = index / width_col / height_col / depth_col; + int channel_out = channel_in * filterD * filterH * filterW; + int w_in = w_out * strideW - paddingW; + int h_in = h_out * strideH - paddingH; + int d_in = d_out * strideD - paddingD; + + dataDst += + ((channel_out * depth_col + d_out) * height_col + h_out) * width_col + + w_out; + dataSrc += ((channel_in * depth + d_in) * height + h_in) * width + w_in; + for (int k = 0; k < filterD; ++k) { + for (int i = 0; i < filterH; ++i) { + for (int j = 0; j < filterW; ++j) { + int d = d_in + k; + int h = h_in + i; + int w = w_in + j; + *dataDst = (d >= 0 && d < depth && h >= 0 && h < height && w >= 0 && + w < width) + ? dataSrc[(k * height + i) * width + j] + : 0; + dataDst += depth_col * height_col * width_col; + } + } + } + } +} + +void hl_matrix_vol2Col(const real* dataSrc, + int channels, + int depth, + int height, + int width, + int filterD, + int filterH, + int filterW, + int strideD, + int strideH, + int strideW, + int paddingD, + int paddingH, + int paddingW, + real* dataDst) { + int depth_col = (depth + 2 * paddingD - filterD) / strideD + 1; + int height_col = (height + 2 * paddingH - filterH) / strideH + 1; + int width_col = (width + 2 * paddingW - filterW) / strideW + 1; + int num_kernels = channels * depth_col * height_col * width_col; + + const int threads = 512; + const int blocks = DIVUP(num_kernels, threads); + + keMatrixVol2Col<<>>(num_kernels, + dataSrc, + dataDst, + depth, + height, + width, + filterD, + filterH, + filterW, + strideD, + strideH, + strideW, + paddingD, + paddingH, + paddingW, + depth_col, + height_col, + width_col); + CHECK_SYNC("hl_matrix_vol2Col failed"); +} + +__global__ void keMatrixCol2Vol(int num_kernels, + real* dataDst, + const real* dataSrc, + int depth, + int height, + int width, + int filterD, + int filterH, + int filterW, + int strideD, + int strideH, + int strideW, + int paddingD, + int paddingH, + int paddingW, + int depth_col, + int height_col, + int width_col, + real alpha, + real beta) { + for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < num_kernels; + index += blockDim.x * gridDim.x) { + real srcVal = 0; + real dstVal = dataDst[index]; + int w = index % width + paddingW; + int h = (index / width) % height + paddingH; + int d = (index / width / height) % depth + paddingD; + int c = index / width / height / depth; + // compute the start and end of the output + int w_col_start = (w < filterW) ? 0 : (w - filterW) / strideW + 1; + int w_col_end = min(w / strideW + 1, width_col); + int h_col_start = (h < filterH) ? 0 : (h - filterH) / strideH + 1; + int h_col_end = min(h / strideH + 1, height_col); + int d_col_start = (d < filterD) ? 0 : (d - filterD) / strideD + 1; + int d_col_end = min(d / strideD + 1, depth_col); + + int offset = (c * filterD * filterW * filterH + d * filterW * filterH + + h * filterW + w) * + depth_col * height_col * width_col; + + int coeff_d_col = + (1 - strideD * filterW * filterH * depth_col) * height_col * width_col; + int coeff_h_col = + (1 - strideH * filterW * depth_col * height_col) * width_col; + int coeff_w_col = (1 - strideW * depth_col * height_col * width_col); + + for (int d_col = d_col_start; d_col < d_col_end; ++d_col) { + for (int h_col = h_col_start; h_col < h_col_end; ++h_col) { + for (int w_col = w_col_start; w_col < w_col_end; ++w_col) { + srcVal += dataSrc[offset + d_col * coeff_d_col + h_col * coeff_h_col + + w_col * coeff_w_col]; + } + } + } + dataDst[index] = alpha * srcVal + beta * dstVal; + } +} + +void hl_matrix_col2Vol(real* dataDst, + int channels, + int depth, + int height, + int width, + int filterD, + int filterH, + int filterW, + int strideD, + int strideH, + int strideW, + int paddingD, + int paddingH, + int paddingW, + const real* dataSrc, + real alpha, + real beta) { + int depth_col = (depth + 2 * paddingD - filterD) / strideD + 1; + int height_col = (height + 2 * paddingH - filterH) / strideH + 1; + int width_col = (width + 2 * paddingW - filterW) / strideW + 1; + int num_kernels = channels * depth * height * width; + + const int threads = 512; + const int blocks = DIVUP(num_kernels, threads); + + keMatrixCol2Vol<<>>(num_kernels, + dataDst, + dataSrc, + depth, + height, + width, + filterD, + filterH, + filterW, + strideD, + strideH, + strideW, + paddingD, + paddingH, + paddingW, + depth_col, + height_col, + width_col, + alpha, + beta); + + CHECK_SYNC("hl_matrix_col2Vol failed"); +} diff --git a/paddle/gserver/layers/Conv3DLayer.cpp b/paddle/gserver/layers/Conv3DLayer.cpp new file mode 100644 index 0000000000000000000000000000000000000000..7cc9937cce37cbbc4640fbb88312841c23b757c0 --- /dev/null +++ b/paddle/gserver/layers/Conv3DLayer.cpp @@ -0,0 +1,244 @@ +/* Copyright (c) 2016 Baidu, Inc. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "Conv3DLayer.h" +#include "paddle/utils/Logging.h" +#include "paddle/utils/Stat.h" + +namespace paddle { + +REGISTER_LAYER(conv3d, Conv3DLayer); + +bool Conv3DLayer::init(const LayerMap &layerMap, + const ParameterMap ¶meterMap) { + if (!ConvBaseLayer::init(layerMap, parameterMap)) return false; + int index = 0; + for (auto &inputConfig : config_.inputs()) { + const ConvConfig &conf = inputConfig.conv_conf(); + M_.push_back(numFilters_ / conf.groups()); + K_.push_back(filterPixels_[index] * filterChannels_[index]); + + // create a new weight + size_t height, width; + width = filterPixels_[index] * filterChannels_[index]; + height = numFilters_; + CHECK_EQ(parameters_[index]->getSize(), width * height); + Weight *w = new Weight(height, width, parameters_[index]); + weights_.emplace_back(w); + ++index; + } + if (biasParameter_.get()) { + if (sharedBiases_) { + CHECK_EQ((size_t)numFilters_, biasParameter_->getSize()); + biases_ = + std::unique_ptr(new Weight(1, numFilters_, biasParameter_)); + } else { + biases_ = + std::unique_ptr(new Weight(1, getSize(), biasParameter_)); + } + } + return true; +} + +size_t Conv3DLayer::getSize() { + CHECK_NE(inputLayers_.size(), 0UL); + outputH_.clear(); + outputW_.clear(); + outputD_.clear(); + N_.clear(); + size_t layerSize = 0; + for (size_t i = 0; i < inputLayers_.size(); ++i) { + outputW_.push_back(outputSize( + imgSizeW_[i], filterSize_[i], padding_[i], stride_[i], true)); + outputH_.push_back(outputSize( + imgSizeH_[i], filterSizeY_[i], paddingY_[i], strideY_[i], true)); + outputD_.push_back(outputSize( + imgSizeD_[i], filterSizeZ_[i], paddingZ_[i], strideZ_[i], true)); + + N_.push_back(outputD_[i] * outputH_[i] * outputW_[i]); + CHECK(layerSize == 0 || N_[i] * size_t(numFilters_) == layerSize); + layerSize += N_[i] * numFilters_; + } + getOutput().setFrameHeight(outputH_[0]); + getOutput().setFrameWidth(outputW_[0]); + getOutput().setFrameDepth(outputD_[0]); + return layerSize; +} + +void Conv3DLayer::forward(PassType passType) { + Layer::forward(passType); + + int batchSize = inputLayers_[0]->getOutputValue()->getHeight(); + int outWidth = getSize(); + resetOutput(batchSize, outWidth); + + for (size_t i = 0; i != inputLayers_.size(); ++i) { + REGISTER_TIMER_INFO("FwdConv3D", getName().c_str()); + const MatrixPtr &inMat = getInputValue(i); + const MatrixPtr &outMat = getOutputValue(); + int M = M_[i]; + int N = N_[i]; + int K = K_[i]; + Matrix::resizeOrCreate(colBuf_, K * groups_[i], N, false, useGpu_); + MatrixPtr wMat = weights_[i]->getW(); + for (int n = 0; n < batchSize; ++n) { + colBuf_->vol2Col(inMat->getData() + n * inMat->getStride(), + channels_[i], + imgSizeD_[i], + imgSizeH_[i], + imgSizeW_[i], + filterSizeZ_[i], + filterSizeY_[i], + filterSize_[i], + strideZ_[i], + strideY_[i], + stride_[i], + paddingZ_[i], + paddingY_[i], + padding_[i]); + + real *outData = outMat->getData() + n * outMat->getStride(); + MatrixPtr outMatSub = + Matrix::create(outData, groups_[i] * M, N, false, useGpu_); + for (int g = 0; g < groups_[i]; g++) { + MatrixPtr wMatSub = wMat->subMatrix(g * M, M); + MatrixPtr in = colBuf_->subMatrix(g * K, K); + MatrixPtr out = outMatSub->subMatrix(g * M, M); + out->mul(*wMatSub, *in, 1.0, 1.0); + } + } + } + if (nullptr != this->biasParameter_) { + REGISTER_TIMER_INFO("FwBiasTimer", getName().c_str()); + this->addBias(); + } + forwardActivation(); +} + +void Conv3DLayer::backward(const UpdateCallback &callback) { + backwardActivation(); + + if (biases_ && biases_->getWGrad()) { + bpropBiases(); + biases_->getParameterPtr()->incUpdate(callback); + } + + for (size_t i = 0; i != inputLayers_.size(); ++i) { + REGISTER_TIMER_INFO("BwdConv3D", getName().c_str()); + if (weights_[i]->getWGrad()) { + bpropWeights(i); + } + if (getInputGrad(i)) { + bpropData(i); + } + REGISTER_TIMER_INFO("WeightUpdate", getName().c_str()); + weights_[i]->getParameterPtr()->incUpdate(callback); + } +} + +void Conv3DLayer::bpropWeights(int i) { + int M = M_[i]; + int N = N_[i]; + int K = K_[i]; + const MatrixPtr &inMat = getInputValue(i); + Matrix::resizeOrCreate(colBuf_, K * groups_[i], N, false, useGpu_); + MatrixPtr wGradMat = weights_[i]->getWGrad(); + int batchSize = inputLayers_[0]->getOutputValue()->getHeight(); + for (int n = 0; n < batchSize; ++n) { + colBuf_->vol2Col(inMat->getData() + n * inMat->getStride(), + channels_[i], + imgSizeD_[i], + imgSizeH_[i], + imgSizeW_[i], + filterSizeZ_[i], + filterSizeY_[i], + filterSize_[i], + strideZ_[i], + strideY_[i], + stride_[i], + paddingZ_[i], + paddingY_[i], + padding_[i]); + + real *outGradData = + getOutputGrad()->getData() + n * getOutputGrad()->getStride(); + MatrixPtr outGradSub = + Matrix::create(outGradData, groups_[i] * M, N, false, useGpu_); + for (int g = 0; g < groups_[i]; ++g) { + MatrixPtr inMatSub = colBuf_->subMatrix(g * K, K); + MatrixPtr outG = outGradSub->subMatrix(g * M, M); + MatrixPtr wGradSub = wGradMat->subMatrix(g * M, M); + wGradSub->mul(*outG, *(inMatSub->getTranspose()), 1.0, 1.0); + } + } +} + +void Conv3DLayer::bpropData(int i) { + int M = M_[i]; + int N = N_[i]; + int K = K_[i]; + Matrix::resizeOrCreate(colBuf_, K * groups_[i], N, false, useGpu_); + MatrixPtr wMat = weights_[i]->getW(); + int batchSize = inputLayers_[0]->getOutputValue()->getHeight(); + for (int n = 0; n < batchSize; ++n) { + real *outGradData = + getOutputGrad()->getData() + n * getOutputGrad()->getStride(); + real *preGradData = + getInputGrad(i)->getData() + n * getInputGrad(i)->getStride(); + MatrixPtr outGradSub = + Matrix::create(outGradData, M * groups_[i], N, false, useGpu_); + for (int g = 0; g < groups_[i]; ++g) { + MatrixPtr wMatSub = wMat->subMatrix(g * M, M); + MatrixPtr outG = outGradSub->subMatrix(g * M, M); + MatrixPtr inGradMatSub = colBuf_->subMatrix(g * K, K); + inGradMatSub->mul(*(wMatSub->getTranspose()), *outG, 1.0, 0.0); + } + colBuf_->col2Vol(preGradData, + channels_[i], + imgSizeD_[i], + imgSizeH_[i], + imgSizeW_[i], + filterSizeZ_[i], + filterSizeY_[i], + filterSize_[i], + strideZ_[i], + strideY_[i], + stride_[i], + paddingZ_[i], + paddingY_[i], + padding_[i], + 1.0, + 1.0); + } +} + +void Conv3DLayer::bpropBiases() { + MatrixPtr outGradMat = getOutputGrad(); + if (this->sharedBiases_) { + biases_->getWGrad()->collectSharedBias(*outGradMat, 1.0f); + } else { + biases_->getWGrad()->collectBias(*outGradMat, 1.0f); + } +} + +void Conv3DLayer::addBias() { + MatrixPtr outMat = getOutputValue(); + if (this->sharedBiases_) { + outMat->addSharedBias(*(biases_->getW()), 1.0f); + } else { + outMat->addBias(*(biases_->getW()), 1.0f); + } +} + +} // namespace paddle diff --git a/paddle/gserver/layers/Conv3DLayer.h b/paddle/gserver/layers/Conv3DLayer.h new file mode 100644 index 0000000000000000000000000000000000000000..b622508d0ce1b0938c44f5c7f1371a34c86b2c1d --- /dev/null +++ b/paddle/gserver/layers/Conv3DLayer.h @@ -0,0 +1,51 @@ +/* Copyright (c) 2016 Baidu, Inc. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include +#include "ConvBaseLayer.h" +#include "paddle/math/MathUtils.h" +#include "paddle/math/Matrix.h" + +namespace paddle { + +/** + * @brief A subclass of convolution layer. + * This layer expands input and use matrix multiplication to + * calculate convolution operation. + */ +class Conv3DLayer : public ConvBaseLayer { +public: + explicit Conv3DLayer(const LayerConfig& config) : ConvBaseLayer(config) {} + ~Conv3DLayer() {} + + bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); + + void forward(PassType passType); + void addBias(); + void backward(const UpdateCallback& callback); + void bpropBiases(); + void bpropData(int i); + void bpropWeights(int i); + size_t getSize(); + +protected: + // Figure out the dimensions for individual gemms. + IntV M_; /// numFilters_ / filter_group_; + IntV N_; /// channels_ * filterSizeZ_ * filterSize_ * filterSizeY_ + IntV K_; /// outputD_ * outputH_ * outputW_ + MatrixPtr colBuf_; +}; + +} // namespace paddle diff --git a/paddle/gserver/layers/ConvBaseLayer.cpp b/paddle/gserver/layers/ConvBaseLayer.cpp index a5328ef8343e1050352fc48530e041fb6ce12a8b..b848ab6bdd44f8fe81cbbf63b35a321599fd93fe 100644 --- a/paddle/gserver/layers/ConvBaseLayer.cpp +++ b/paddle/gserver/layers/ConvBaseLayer.cpp @@ -38,7 +38,6 @@ bool ConvBaseLayer::init(const LayerMap& layerMap, strideY_.push_back(conf.stride_y()); dilationY_.push_back(conf.dilation_y()); filterSizeY_.push_back(conf.filter_size_y()); - filterPixels_.push_back(filterSize_.back() * filterSizeY_.back()); channels_.push_back(conf.channels()); imgSizeH_.push_back(conf.has_img_size_y() ? conf.img_size_y() : conf.img_size()); @@ -47,31 +46,20 @@ bool ConvBaseLayer::init(const LayerMap& layerMap, filterChannels_.push_back(conf.filter_channels()); outputH_.push_back(conf.has_output_y() ? conf.output_y() : conf.output_x()); outputW_.push_back(conf.output_x()); + + paddingZ_.push_back(conf.padding_z()); + strideZ_.push_back(conf.stride_z()); + filterSizeZ_.push_back(conf.filter_size_z()); + imgSizeD_.push_back(conf.img_size_z()); + outputD_.push_back(conf.output_z()); + filterPixels_.push_back(filterSize_.back() * filterSizeY_.back() * + filterSizeZ_.back()); } CHECK(inputLayers_.size() == parameters_.size()); - for (size_t i = 0; i < inputLayers_.size(); i++) { - size_t height, width; - height = filterPixels_[i] * filterChannels_[i]; - width = (!isDeconv_) ? numFilters_ : channels_[i]; - - // create a new weight - CHECK_EQ(parameters_[i]->getSize(), width * height); - Weight* w = new Weight(height, width, parameters_[i]); - weights_.emplace_back(w); - } - /* initialize the biases_ */ - if (biasParameter_.get()) { - if (sharedBiases_) { - CHECK_EQ((size_t)numFilters_, biasParameter_->getSize()); - biases_ = - std::unique_ptr(new Weight(numFilters_, 1, biasParameter_)); - } else { - biases_ = - std::unique_ptr(new Weight(getSize(), 1, biasParameter_)); - } - } + // create new weights_ in derived class + // create new biases_ in derived class // default caffe model caffeMode_ = true; diff --git a/paddle/gserver/layers/ConvBaseLayer.h b/paddle/gserver/layers/ConvBaseLayer.h index 223bce8e296d748c8e17eb105aa67e8a1c1219b6..ccd170d9d85f573dff7340c26b2038c17a548471 100644 --- a/paddle/gserver/layers/ConvBaseLayer.h +++ b/paddle/gserver/layers/ConvBaseLayer.h @@ -62,6 +62,13 @@ protected: IntV outputH_; /// The spatial dimensions of output feature map width. IntV outputW_; + + IntV outputD_; + IntV imgSizeD_; + IntV filterSizeZ_; + IntV strideZ_; + IntV paddingZ_; + /// Group size, refer to grouped convolution in /// Alex Krizhevsky's paper: when group=2, the first half of the /// filters are only connected to the first half of the input channels, diff --git a/paddle/gserver/layers/CudnnConvBaseLayer.cpp b/paddle/gserver/layers/CudnnConvBaseLayer.cpp index c056bbe4d1d354751d4f85f8d0743cf30486c087..9e954615cddf2566ea336d1c947985fd916e8cc4 100644 --- a/paddle/gserver/layers/CudnnConvBaseLayer.cpp +++ b/paddle/gserver/layers/CudnnConvBaseLayer.cpp @@ -46,8 +46,26 @@ bool CudnnConvBaseLayer::init(const LayerMap &layerMap, projConf_.emplace_back(conf); projections_.emplace_back( Projection::create(*projConf_[i], parameters_[i], useGpu_)); + + // create a new weight + size_t height, width; + height = filterPixels_[i] * filterChannels_[i]; + width = (!isDeconv_) ? numFilters_ : channels_[i]; + CHECK_EQ(parameters_[i]->getSize(), width * height); + Weight *w = new Weight(height, width, parameters_[i]); + weights_.emplace_back(w); } + if (biasParameter_.get()) { + if (sharedBiases_) { + CHECK_EQ((size_t)numFilters_, biasParameter_->getSize()); + biases_ = + std::unique_ptr(new Weight(numFilters_, 1, biasParameter_)); + } else { + biases_ = + std::unique_ptr(new Weight(getSize(), 1, biasParameter_)); + } + } if (biases_.get() && sharedBiases_) { hl_create_tensor_descriptor(&biasDesc_); hl_create_tensor_descriptor(&outputDesc_); diff --git a/paddle/gserver/layers/DeConv3DLayer.cpp b/paddle/gserver/layers/DeConv3DLayer.cpp new file mode 100644 index 0000000000000000000000000000000000000000..7d5c772c89d260264a59f4cc4439bb8a44c605a4 --- /dev/null +++ b/paddle/gserver/layers/DeConv3DLayer.cpp @@ -0,0 +1,212 @@ +/* Copyright (c) 2016 Baidu, Inc. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "DeConv3DLayer.h" +#include "paddle/utils/Logging.h" +#include "paddle/utils/Stat.h" + +namespace paddle { + +REGISTER_LAYER(deconv3d, DeConv3DLayer); + +bool DeConv3DLayer::init(const LayerMap &layerMap, + const ParameterMap ¶meterMap) { + if (!ConvBaseLayer::init(layerMap, parameterMap)) return false; + // for Deconv, the dimension of Kernel is + // channel * output * depth * height * weigth + // Matrix storage format: (output * depth * height * weigth) x channel + for (int index = 0; index < config_.inputs().size(); ++index) { + M_.push_back(filterChannels_[index]); + K_.push_back(filterPixels_[index] * (numFilters_ / groups_[index])); + + // create a new weight + size_t height, width; + height = filterPixels_[index] * numFilters_; + width = filterChannels_[index]; + CHECK_EQ(parameters_[index]->getSize(), width * height); + Weight *w = new Weight(height, width, parameters_[index]); + weights_.emplace_back(w); + } + if (biasParameter_.get()) { + if (sharedBiases_) { + CHECK_EQ((size_t)numFilters_, biasParameter_->getSize()); + biases_ = + std::unique_ptr(new Weight(1, numFilters_, biasParameter_)); + } else { + biases_ = + std::unique_ptr(new Weight(1, getSize(), biasParameter_)); + } + } + return true; +} + +size_t DeConv3DLayer::getSize() { + CHECK_NE(inputLayers_.size(), 0UL); + outputH_.clear(); + outputW_.clear(); + outputD_.clear(); + N_.clear(); + NOut_.clear(); + size_t layerSize = 0; + for (size_t i = 0; i < inputLayers_.size(); ++i) { + outputW_.push_back( + imageSize(imgSizeW_[i], filterSize_[i], padding_[i], stride_[i], true)); + outputH_.push_back(imageSize( + imgSizeH_[i], filterSizeY_[i], paddingY_[i], strideY_[i], true)); + outputD_.push_back(imageSize( + imgSizeD_[i], filterSizeZ_[i], paddingZ_[i], strideZ_[i], true)); + NOut_.push_back(outputD_[i] * outputH_[i] * outputW_[i]); + N_.push_back(imgSizeD_[i] * imgSizeH_[i] * imgSizeW_[i]); + CHECK(layerSize == 0 || N_[i] * size_t(numFilters_) == layerSize); + layerSize += NOut_[i] * numFilters_; + } + getOutput().setFrameHeight(outputH_[0]); + getOutput().setFrameWidth(outputW_[0]); + getOutput().setFrameDepth(outputD_[0]); + return layerSize; +} + +void DeConv3DLayer::forward(PassType passType) { + Layer::forward(passType); + int batchSize = inputLayers_[0]->getOutputValue()->getHeight(); + int outWidth = getSize(); + resetOutput(batchSize, outWidth); + const MatrixPtr outMat = getOutputValue(); + + for (size_t i = 0; i != inputLayers_.size(); ++i) { + REGISTER_TIMER_INFO("FwdDeConv3D", getName().c_str()); + const MatrixPtr &inMat = getInputValue(i); + int M = M_[i]; + int N = N_[i]; + int K = K_[i]; + MatrixPtr wMat = weights_[i]->getW(); + Matrix::resizeOrCreate(colBuf_, K * groups_[i], N, false, useGpu_); + for (int n = 0; n < batchSize; ++n) { + real *inData = inMat->getData() + n * inMat->getStride(); + for (int g = 0; g < groups_[i]; ++g) { + MatrixPtr inMatSub = Matrix::create(inData, M, N, false, useGpu_); + MatrixPtr wMatSub = wMat->subMatrix(g * K, K); + MatrixPtr colBufDataSub = colBuf_->subMatrix(g * K, K); + colBufDataSub->mul(*wMatSub, *inMatSub, 1.0, 0.0); + inData += M * N; + } + colBuf_->col2Vol(outMat->getData() + n * outMat->getStride(), + numFilters_, + outputD_[i], + outputH_[i], + outputW_[i], + filterSizeZ_[i], + filterSizeY_[i], + filterSize_[i], + strideZ_[i], + strideY_[i], + stride_[i], + paddingZ_[i], + paddingY_[i], + padding_[i], + 1.0, + 1.0); + } + } + if (nullptr != this->biasParameter_) { + REGISTER_TIMER_INFO("FwBiasTimer", getName().c_str()); + this->addBias(); + } + forwardActivation(); +} + +void DeConv3DLayer::backward(const UpdateCallback &callback) { + backwardActivation(); + int batchSize = getOutputGrad()->getHeight(); + if (biases_ && biases_->getWGrad()) { + bpropBiases(); + biases_->getParameterPtr()->incUpdate(callback); + } + for (size_t i = 0; i < inputLayers_.size(); ++i) { + if (weights_[i]->getWGrad() || this->needGradient_) { + int M = M_[i]; + int N = N_[i]; + int K = K_[i]; + REGISTER_TIMER_INFO("BwdDeConv3D", getName().c_str()); + Matrix::resizeOrCreate(colBuf_, K * groups_[i], N, false, useGpu_); + const MatrixPtr &inMat = getInputValue(i); + for (int n = 0; n < batchSize; ++n) { + colBuf_->vol2Col( + getOutputGrad()->getData() + n * getOutputGrad()->getStride(), + numFilters_, + outputD_[i], + outputH_[i], + outputW_[i], + filterSizeZ_[i], + filterSizeY_[i], + filterSize_[i], + strideZ_[i], + strideY_[i], + stride_[i], + paddingZ_[i], + paddingY_[i], + padding_[i]); + if (weights_[i]->getWGrad()) { + real *inData = inMat->getData() + n * inMat->getStride(); + for (int g = 0; g < groups_[i]; ++g) { + MatrixPtr colBufDataSub = colBuf_->subMatrix(g * K, K); + MatrixPtr wGradMatSub = + weights_[i]->getWGrad()->subMatrix(g * K, K); + MatrixPtr inMatSub = Matrix::create(inData, M, N, false, useGpu_); + wGradMatSub->mul( + *colBufDataSub, *(inMatSub->getTranspose()), 1.0, 1.0); + inData += M * N; + } + } + if (getInputGrad(i)) { + real *preGrad = + getInputGrad(i)->getData() + n * getInputGrad(i)->getStride(); + for (int g = 0; g < groups_[i]; ++g) { + MatrixPtr w = weights_[i]->getW()->subMatrix(g * K, K); + MatrixPtr outGradMat = colBuf_->subMatrix(g * K, K); + MatrixPtr inGradMatSub = + Matrix::create(preGrad, M, N, false, useGpu_); + inGradMatSub->mul(*(w->getTranspose()), *outGradMat, 1.0, 1.0); + preGrad += M * N; + } + } + } + REGISTER_TIMER_INFO("WeightUpdate", getName().c_str()); + weights_[i]->getParameterPtr()->incUpdate(callback); + } + } +} +void DeConv3DLayer::bpropWeights(int i) {} +void DeConv3DLayer::bpropData(int i) {} + +void DeConv3DLayer::bpropBiases() { + const MatrixPtr &outGradMat = getOutputGrad(); + + if (this->sharedBiases_) { + biases_->getWGrad()->collectSharedBias(*outGradMat, 1.0f); + } else { + biases_->getWGrad()->collectBias(*outGradMat, 1.0f); + } +} + +void DeConv3DLayer::addBias() { + MatrixPtr outMat = getOutputValue(); + if (this->sharedBiases_) { + outMat->addSharedBias(*(biases_->getW()), 1.0f); + } else { + outMat->addBias(*(biases_->getW()), 1.0f); + } +} + +} // namespace paddle diff --git a/paddle/gserver/layers/DeConv3DLayer.h b/paddle/gserver/layers/DeConv3DLayer.h new file mode 100644 index 0000000000000000000000000000000000000000..a2a3d3f8273ed77065224c27df6f711f09f34bbc --- /dev/null +++ b/paddle/gserver/layers/DeConv3DLayer.h @@ -0,0 +1,52 @@ +/* Copyright (c) 2016 Baidu, Inc. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include "ConvBaseLayer.h" +#include "paddle/math/MathUtils.h" +#include "paddle/math/Matrix.h" + +namespace paddle { + +/** + * @brief A subclass of deconvolution3D layer. + * This layer expands input and use matrix multiplication to + * calculate deconvolution3D operation. + */ +class DeConv3DLayer : public ConvBaseLayer { +public: + explicit DeConv3DLayer(const LayerConfig& config) : ConvBaseLayer(config) {} + ~DeConv3DLayer() {} + bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); + + void forward(PassType passType); + void addBias(); + void backward(const UpdateCallback& callback); + void bpropBiases(); + void bpropData(int i); + void bpropWeights(int i); + size_t getSize(); + +protected: + // Figure out the dimensions for individual gemms. + IntV M_; /// numFilters_ / filter_group_; + IntV N_; /// channels_ * filterSizeZ_ * filterSize_ * filterSizeY_ + IntV K_; /// outputD_ * outputH_ * outputW_ + IntV NOut_; + MatrixPtr colBuf_; +}; + +} // namespace paddle diff --git a/paddle/gserver/layers/ExpandConvBaseLayer.cpp b/paddle/gserver/layers/ExpandConvBaseLayer.cpp index 77736e78f9349c0393e1e53ac700817a70893e53..2b7bef0a757d7c706be3815c539b036b094596cf 100644 --- a/paddle/gserver/layers/ExpandConvBaseLayer.cpp +++ b/paddle/gserver/layers/ExpandConvBaseLayer.cpp @@ -22,12 +22,31 @@ bool ExpandConvBaseLayer::init(const LayerMap &layerMap, /* Initialize the basic convolutional parent class */ ConvBaseLayer::init(layerMap, parameterMap); + int index = 0; for (auto &inputConfig : config_.inputs()) { const ConvConfig &conf = inputConfig.conv_conf(); /* Consistent caffe mode for multiple input */ caffeMode_ = conf.caffe_mode(); - } + // create a new weight + size_t height, width; + height = filterPixels_[index] * filterChannels_[index]; + width = (!isDeconv_) ? numFilters_ : channels_[index]; + CHECK_EQ(parameters_[index]->getSize(), width * height); + Weight *w = new Weight(height, width, parameters_[index]); + weights_.emplace_back(w); + index++; + } + if (biasParameter_.get()) { + if (sharedBiases_) { + CHECK_EQ((size_t)numFilters_, biasParameter_->getSize()); + biases_ = + std::unique_ptr(new Weight(numFilters_, 1, biasParameter_)); + } else { + biases_ = + std::unique_ptr(new Weight(getSize(), 1, biasParameter_)); + } + } getOutputSize(); return true; diff --git a/paddle/gserver/tests/test_LayerGrad.cpp b/paddle/gserver/tests/test_LayerGrad.cpp index 93b6e3cc5bd7a87aa854052277772904d70de802..c54f3ef965f4bed730cf5e0e82130b0416cb34c7 100644 --- a/paddle/gserver/tests/test_LayerGrad.cpp +++ b/paddle/gserver/tests/test_LayerGrad.cpp @@ -2047,6 +2047,159 @@ TEST(Layer, RowL2NormLayer) { } } +void test3DConvLayer(const string& type, bool trans, bool useGpu) { + // filter size + const int NUM_FILTERS = 6; + // const int CHANNELS = 3; + const int FILTER_SIZE = 3; + const int FILTER_SIZE_Y = 3; + const int FILTER_SIZE_Z = 3; + + // input image + const int CHANNELS = 3; + const int IMAGE_SIZE = 9; + const int IMAGE_SIZE_Y = 9; + const int IMAGE_SIZE_Z = 9; + + TestConfig config; + config.biasSize = NUM_FILTERS; + config.layerConfig.set_type(type); + config.layerConfig.set_num_filters(NUM_FILTERS); + config.layerConfig.set_partial_sum(1); + config.layerConfig.set_shared_biases(true); + + // Setting up conv3D-trans layer + LayerInputConfig* input = config.layerConfig.add_inputs(); + ConvConfig* conv = input->mutable_conv_conf(); + + conv->set_channels(CHANNELS); + conv->set_filter_size(FILTER_SIZE); + conv->set_filter_size_y(FILTER_SIZE_Y); + conv->set_filter_size_z(FILTER_SIZE_Z); + conv->set_padding(0); + conv->set_padding_y(0); + conv->set_padding_z(0); + conv->set_stride(2); + conv->set_stride_y(2); + conv->set_stride_z(2); + conv->set_img_size(IMAGE_SIZE); + conv->set_img_size_y(IMAGE_SIZE_Y); + conv->set_img_size_z(IMAGE_SIZE_Z); + conv->set_output_x(outputSize(conv->img_size(), + conv->filter_size(), + conv->padding(), + conv->stride(), + /* caffeMode */ true)); + conv->set_output_y(outputSize(conv->img_size_y(), + conv->filter_size_y(), + conv->padding_y(), + conv->stride_y(), + /* caffeMode */ true)); + conv->set_output_z(outputSize(conv->img_size_z(), + conv->filter_size_z(), + conv->padding_z(), + conv->stride_z(), + /* caffeMode */ true)); + + config.layerConfig.set_size(conv->output_x() * conv->output_y() * + conv->output_z() * NUM_FILTERS); + conv->set_groups(1); + conv->set_filter_channels(conv->channels() / conv->groups()); + config.inputDefs.push_back( + {INPUT_DATA, + "layer_0", + CHANNELS * IMAGE_SIZE * IMAGE_SIZE_Y * IMAGE_SIZE_Z, + conv->filter_channels() * FILTER_SIZE * FILTER_SIZE_Y * FILTER_SIZE_Z * + NUM_FILTERS}); + + testLayerGrad(config, "conv3D", 10, trans, useGpu); + // Use small batch_size and useWeight=true to test biasGrad + testLayerGrad(config, "conv3D", 2, trans, useGpu, true, 0.02); +} + +TEST(Layer, test3DConvLayer) { + test3DConvLayer("conv3d", /* trans= */ false, /* useGpu= */ false); +#ifndef PADDLE_ONLY_CPU + test3DConvLayer("conv3d", /* trans= */ false, /* useGpu= */ true); +#endif +} + +void test3DDeConvLayer(const string& type, bool trans, bool useGpu) { + // filter size + const int NUM_FILTERS = 6; + // const int CHANNELS = 3; + const int FILTER_SIZE = 3; + const int FILTER_SIZE_Y = 3; + const int FILTER_SIZE_Z = 3; + + // input image + const int CHANNELS = 3; + const int IMAGE_SIZE = 4; + const int IMAGE_SIZE_Y = 6; + const int IMAGE_SIZE_Z = 6; + + // Setting up conv-trans layer + TestConfig config; + config.biasSize = NUM_FILTERS; + config.layerConfig.set_type("deconv3d"); + config.layerConfig.set_num_filters(NUM_FILTERS); + config.layerConfig.set_partial_sum(1); + config.layerConfig.set_shared_biases(true); + + LayerInputConfig* input = config.layerConfig.add_inputs(); + ConvConfig* conv = input->mutable_conv_conf(); + + conv->set_channels(CHANNELS); + conv->set_filter_size(FILTER_SIZE); + conv->set_filter_size_y(FILTER_SIZE_Y); + conv->set_filter_size_z(FILTER_SIZE_Z); + conv->set_padding(0); + conv->set_padding_y(0); + conv->set_padding_z(0); + conv->set_stride(2); + conv->set_stride_y(2); + conv->set_stride_z(2); + conv->set_img_size(IMAGE_SIZE); + conv->set_img_size_y(IMAGE_SIZE_Y); + conv->set_img_size_z(IMAGE_SIZE_Z); + conv->set_output_x(imageSize(conv->img_size(), + conv->filter_size(), + conv->padding(), + conv->stride(), + true)); + conv->set_output_y(imageSize(conv->img_size_y(), + conv->filter_size_y(), + conv->padding_y(), + conv->stride_y(), + true)); + conv->set_output_z(imageSize(conv->img_size_z(), + conv->filter_size_z(), + conv->padding_z(), + conv->stride_z(), + true)); + config.layerConfig.set_size(conv->output_x() * conv->output_y() * + conv->output_z() * NUM_FILTERS); + conv->set_groups(1); + conv->set_filter_channels(conv->channels() / conv->groups()); + config.inputDefs.push_back( + {INPUT_DATA, + "layer_0", + CHANNELS * IMAGE_SIZE * IMAGE_SIZE_Y * IMAGE_SIZE_Z, + conv->filter_channels() * FILTER_SIZE * FILTER_SIZE_Y * FILTER_SIZE_Z * + NUM_FILTERS}); + + testLayerGrad(config, "deconv3D", 10, trans, useGpu); + // Use small batch_size and useWeight=true to test biasGrad + testLayerGrad(config, "deconv3D", 2, trans, useGpu, true, 0.02); +} + +TEST(Layer, test3DDeConvLayer) { + test3DDeConvLayer("deconv3d", /* trans= */ false, /* useGpu= */ false); +#ifndef PADDLE_ONLY_CPU + test3DDeConvLayer("deconv3d", /* trans= */ false, /* useGpu= */ true); +#endif +} + TEST(Layer, ScaleShiftLayer) { const size_t batchSize = 16; const size_t size = 32; diff --git a/paddle/math/Matrix.cpp b/paddle/math/Matrix.cpp index 27f7d95b752d4a423bf99fa425b10b2816575d6a..579a0f3cf32b803c1d3ac2af57517ad6490f31ef 100644 --- a/paddle/math/Matrix.cpp +++ b/paddle/math/Matrix.cpp @@ -1389,6 +1389,72 @@ void GpuMatrix::multiBinaryLabelCrossEntropyBp(Matrix& output, Matrix& label) { output_d, grad_d, mat_d, height_, width_); } +void GpuMatrix::vol2Col(real* dataSrc, + int channels, + int depth, + int height, + int width, + int filterD, + int filterH, + int filterW, + int strideD, + int strideH, + int strideW, + int paddingD, + int paddingH, + int paddingW) { + hl_matrix_vol2Col(dataSrc, + channels, + depth, + height, + width, + filterD, + filterH, + filterW, + strideD, + strideH, + strideW, + paddingD, + paddingH, + paddingW, + getData()); +} + +void GpuMatrix::col2Vol(real* dataDst, + int channels, + int depth, + int height, + int width, + int filterD, + int filterH, + int filterW, + int strideD, + int strideH, + int strideW, + int paddingD, + int paddingH, + int paddingW, + real alpha, + real beta) { + hl_matrix_col2Vol(dataDst, + channels, + depth, + height, + width, + filterD, + filterH, + filterW, + strideD, + strideH, + strideW, + paddingD, + paddingH, + paddingW, + getData(), + alpha, + beta); +} + /** * CpuMatrix */ @@ -3975,6 +4041,95 @@ void CpuMatrix::bilinearBackward(const Matrix& out, } } +void CpuMatrix::vol2Col(real* data, + int channels, + int depth, + int height, + int width, + int filterD, + int filterH, + int filterW, + int strideD, + int strideH, + int strideW, + int paddingD, + int paddingH, + int paddingW) { + real* outData = getData(); + int outHeight = (height + 2 * paddingH - filterH) / strideH + 1; + int outWidth = (width + 2 * paddingW - filterW) / strideW + 1; + int outDepth = (depth + 2 * paddingD - filterD) / strideD + 1; + + int channelsCol = channels * filterD * filterH * filterW; + for (int c = 0; c < channelsCol; ++c) { + int wOffset = c % filterW; + int hOffset = (c / filterW) % filterH; + int dOffset = (c / filterW / filterH) % filterD; + int cIn = c / filterW / filterH / filterD; + for (int d = 0; d < outDepth; ++d) { + for (int h = 0; h < outHeight; ++h) { + for (int w = 0; w < outWidth; ++w) { + int dPad = d * strideD - paddingD + dOffset; + int hPad = h * strideH - paddingH + hOffset; + int wPad = w * strideW - paddingW + wOffset; + + if (hPad >= 0 && hPad < height && wPad >= 0 && wPad < width && + dPad >= 0 && dPad < depth) + outData[((c * outDepth + d) * outHeight + h) * outWidth + w] = + data[((cIn * depth + dPad) * height + hPad) * width + wPad]; + else + outData[((c * outDepth + d) * outHeight + h) * outWidth + w] = 0; + } + } + } + } +} + +void CpuMatrix::col2Vol(real* trg, + int channels, + int depth, + int height, + int width, + int filterD, + int filterH, + int filterW, + int strideD, + int strideH, + int strideW, + int paddingD, + int paddingH, + int paddingW, + real alpha, + real beta) { + real* src = getData(); + int outDepth = (depth + 2 * paddingD - filterD) / strideD + 1; + int outHeight = (height + 2 * paddingH - filterH) / strideH + 1; + int outWidth = (width + 2 * paddingW - filterW) / strideW + 1; + int channelsCol = channels * filterD * filterH * filterW; + for (int c = 0; c < channelsCol; ++c) { + int wOffset = c % filterW; + int hOffset = (c / filterW) % filterH; + int dOffset = (c / filterW / filterH) % filterD; + int cIm = c / filterW / filterH / filterD; + for (int d = 0; d < outDepth; ++d) { + for (int h = 0; h < outHeight; ++h) { + for (int w = 0; w < outWidth; ++w) { + int dPad = d * strideD - paddingD + dOffset; + int hPad = h * strideH - paddingH + hOffset; + int wPad = w * strideW - paddingW + wOffset; + if (hPad >= 0 && hPad < height && wPad >= 0 && wPad < width && + dPad >= 0 && dPad < depth) + trg[((cIm * depth + dPad) * height + hPad) * width + wPad] = + alpha * + src[((c * outDepth + d) * outHeight + h) * outWidth + w] + + beta * + trg[((cIm * depth + dPad) * height + hPad) * width + wPad]; + } + } + } + } +} + //////////////////////////////////////////////////////////////// // functions executed via cpu // //////////////////////////////////////////////////////////////// diff --git a/paddle/math/Matrix.h b/paddle/math/Matrix.h index bb802bbb2c75289a45d987b22ad41ce8b1e95c98..cc3a56f279cc6a17104c681a51f1ca907143fc44 100644 --- a/paddle/math/Matrix.h +++ b/paddle/math/Matrix.h @@ -1039,6 +1039,42 @@ public: LOG(FATAL) << "Not implemented"; } + virtual void vol2Col(real* data, + int channels, + int depth, + int height, + int width, + int filterD, + int filterH, + int filterW, + int strideD, + int strideH, + int strideW, + int paddingD, + int paddingH, + int paddingW) { + LOG(FATAL) << "Not implemeted"; + } + + virtual void col2Vol(real* trg, + int channels, + int depth, + int height, + int width, + int filterD, + int filterH, + int filterW, + int strideD, + int strideH, + int strideW, + int paddingD, + int paddingH, + int paddingW, + real alpha, + real beta) { + LOG(FATAL) << "Not implemeted"; + } + virtual void bilinearForward(const Matrix& in, const size_t inImgH, const size_t inImgW, @@ -1374,6 +1410,38 @@ public: const real ratioH, const real ratioW); + void vol2Col(real* data, + int channels, + int depth, + int height, + int width, + int filterD, + int filterH, + int filterW, + int strideD, + int strideH, + int strideW, + int paddingD, + int paddingH, + int paddingW); + + void col2Vol(real* trg, + int channels, + int depth, + int height, + int width, + int filterD, + int filterH, + int filterW, + int strideD, + int strideH, + int strideW, + int paddingD, + int paddingH, + int paddingW, + real alpha, + real beta); + void multiBinaryLabelCrossEntropy(Matrix& output, Matrix& label); void multiBinaryLabelCrossEntropyBp(Matrix& output, Matrix& label); @@ -1715,6 +1783,38 @@ public: const real ratioH, const real ratioW); + void vol2Col(real* data, + int channels, + int depth, + int height, + int width, + int filterD, + int filterH, + int filterW, + int strideD, + int strideH, + int strideW, + int paddingD, + int paddingH, + int paddingW); + + void col2Vol(real* trg, + int channels, + int depth, + int height, + int width, + int filterD, + int filterH, + int filterW, + int strideD, + int strideH, + int strideW, + int paddingD, + int paddingH, + int paddingW, + real alpha, + real beta); + template void operator=(const ExpressionType& expr) { TensorCpuApply(*this, expr); diff --git a/paddle/math/tests/test_matrixCompare.cpp b/paddle/math/tests/test_matrixCompare.cpp index d77478f345df97b37b214b5978f51ce47c1d791c..3abe4484dbc86711c5798b0900a47d09a0d47299 100644 --- a/paddle/math/tests/test_matrixCompare.cpp +++ b/paddle/math/tests/test_matrixCompare.cpp @@ -18,6 +18,7 @@ limitations under the License. */ #include #include "TensorCheck.h" +#include "paddle/math/MathUtils.h" #include "paddle/math/Matrix.h" #include "paddle/math/SparseMatrix.h" #include "paddle/testing/TestUtil.h" @@ -1203,4 +1204,105 @@ TEST(Matrix, warpCTC) { } } +void testMatrixCol2Vol(int depth, int height, int width) { + int channel = 3; + int filterX = 3, filterY = 4, filterZ = 5; + int strideX = 2, strideY = 2, strideZ = 2; + int padX = 1, padY = 1, padZ = 1; + + MatrixPtr cpuImage = + std::make_shared(channel, depth * height * width); + MatrixPtr gpuImage = + std::make_shared(channel, depth * height * width); + cpuImage->randomizeUniform(); + gpuImage->copyFrom(*cpuImage); + + int outD = outputSize(depth, filterZ, padZ, strideZ, true); + int outH = outputSize(height, filterY, padY, strideY, true); + int outW = outputSize(width, filterX, padX, strideX, true); + + int colBufHeight = channel * filterZ * filterY * filterX; + int colBufWidth = outD * outH * outW; + MatrixPtr cpuColBuf = std::make_shared(colBufHeight, colBufWidth); + MatrixPtr gpuColBuf = std::make_shared(colBufHeight, colBufWidth); + cpuColBuf->vol2Col(cpuImage->getData(), + channel, + depth, + height, + width, + filterZ, + filterY, + filterX, + strideZ, + strideY, + strideX, + padZ, + padY, + padX); + gpuColBuf->vol2Col(gpuImage->getData(), + channel, + depth, + height, + width, + filterZ, + filterY, + filterX, + strideZ, + strideY, + strideX, + padZ, + padY, + padX); + TensorCheckEqual(*cpuColBuf, *gpuColBuf); + + cpuColBuf->randomizeUniform(); + gpuColBuf->copyFrom(*cpuColBuf); + cpuColBuf->col2Vol(cpuImage->getData(), + channel, + depth, + height, + width, + filterZ, + filterY, + filterX, + strideZ, + strideY, + strideX, + padZ, + padY, + padX, + 1.0, + 1.0); + gpuColBuf->col2Vol(gpuImage->getData(), + channel, + depth, + height, + width, + filterZ, + filterY, + filterX, + strideZ, + strideY, + strideX, + padZ, + padY, + padX, + 1.0, + 1.0); + TensorCheckErr(*cpuImage, *gpuImage); +} + +TEST(Matrix, col2Vol) { + for (auto depth : {9, 16, 64}) { + for (auto height : {9, 11, 128}) { + for (auto width : {9, 32, 128}) { + VLOG(3) << "depth=" << depth << " height=" << height + << " width=" << width; + testMatrixCol2Vol(depth, height, width); + } + } + } +} +/////// + #endif diff --git a/paddle/parameter/Argument.cpp b/paddle/parameter/Argument.cpp index b0e9e740c84e6445a59b1edff7272461090c2922..8dbef0b22e7b2f14c62586f86e686356b6e9c68e 100644 --- a/paddle/parameter/Argument.cpp +++ b/paddle/parameter/Argument.cpp @@ -186,6 +186,7 @@ void Argument::resizeAndCopyFrom(const Argument& src, resizeAndCopy(strs, src.strs, useGpu, stream); frameWidth = src.frameWidth; frameHeight = src.frameHeight; + frameDepth = src.frameDepth; } int32_t Argument::resizeAndCopyFrom(const Argument& src, @@ -206,6 +207,7 @@ int32_t Argument::resizeAndCopyFrom(const Argument& src, dataId = src.dataId; frameWidth = src.frameWidth; frameHeight = src.frameHeight; + frameDepth = src.frameDepth; if (!src.sequenceStartPositions) { // non-sequence input, copy samples directly diff --git a/paddle/parameter/Argument.h b/paddle/parameter/Argument.h index 38797a76f55c311070192bd307103143d67cabca..7b59199dded5b3f5d030e389d8bfcac1668fd127 100644 --- a/paddle/parameter/Argument.h +++ b/paddle/parameter/Argument.h @@ -1,11 +1,8 @@ /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -35,6 +32,7 @@ struct Argument { strs(nullptr), frameHeight(0), frameWidth(0), + frameDepth(0), sequenceStartPositions(nullptr), subSequenceStartPositions(nullptr), cpuSequenceDims(nullptr), @@ -64,6 +62,7 @@ struct Argument { allCount = argument.allCount; frameHeight = argument.frameHeight; frameWidth = argument.frameWidth; + frameDepth = argument.frameDepth; dataId = argument.dataId; } @@ -76,6 +75,7 @@ struct Argument { // A dataBatch includes batchSize frames, one frame maybe not only vector size_t frameHeight; size_t frameWidth; + size_t frameDepth; // If NULL, each position is treated independently. // Otherwise, its size should be #NumberOfSequences + 1. @@ -136,8 +136,10 @@ struct Argument { } size_t getFrameHeight() const { return frameHeight; } size_t getFrameWidth() const { return frameWidth; } + size_t getFrameDepth() const { return frameDepth; } void setFrameHeight(size_t h) { frameHeight = h; } void setFrameWidth(size_t w) { frameWidth = w; } + void setFrameDepth(size_t d) { frameDepth = d; } int64_t getNumSequences() const { return sequenceStartPositions ? sequenceStartPositions->getSize() - 1 diff --git a/proto/ModelConfig.proto b/proto/ModelConfig.proto index 1113d5aded1eb06fc4bd35881530264059daa0cc..36b7803f073950437c938b54e6b5677b0c359151 100644 --- a/proto/ModelConfig.proto +++ b/proto/ModelConfig.proto @@ -85,6 +85,12 @@ message ConvConfig { optional uint32 dilation = 15 [ default = 1 ]; optional uint32 dilation_y = 16 [ default = 1 ]; + + optional uint32 filter_size_z = 17 [ default = 1 ]; + optional uint32 padding_z = 18 [ default = 1 ]; + optional uint32 stride_z = 19 [ default = 1 ]; + optional uint32 output_z = 20 [ default = 1 ]; + optional uint32 img_size_z = 21 [ default = 1 ]; } message PoolConfig { @@ -502,6 +508,8 @@ message LayerConfig { // for HuberRegressionLoss optional double delta = 57 [ default = 1.0 ]; + + optional uint64 depth = 58 [ default = 1 ]; } message EvaluatorConfig { diff --git a/python/paddle/trainer/config_parser.py b/python/paddle/trainer/config_parser.py index 0788e3994ebd95993de2d47dd37d1ea23fbdf694..049efe6cfc1dcfa77504f7cd6a5fbc6bf610c3f0 100644 --- a/python/paddle/trainer/config_parser.py +++ b/python/paddle/trainer/config_parser.py @@ -886,6 +886,36 @@ class Conv(Cfg): config_assert(output_x <= 0) +# please refer to the comments in proto/ModelConfig.proto +@config_class +class Conv3D(Cfg): + def __init__(self, + filter_size, + channels, + padding=None, + stride=None, + groups=None, + filter_channels=None, + output_x=None, + img_size=None, + caffe_mode=True, + filter_size_y=None, + padding_y=None, + stride_y=None, + filter_size_z=None, + padding_z=None, + stride_z=None): + self.add_keys(locals()) + self.filter_size_y = filter_size_y if filter_size_y else filter_size + self.filter_size_z = filter_size_z if filter_size_z else filter_size + self.padding_y = padding_y if padding_y else padding + self.padding_z = padding_z if padding_z else padding + self.stride_y = stride_y if stride_y else stride + self.stride_z = stride_z if stride_z else stride + if output_x is not None: + config_assert(output_x <= 0) + + @config_class class BilinearInterp(Cfg): def __init__(self, out_size_x=None, out_size_y=None, channels=None): @@ -1172,6 +1202,20 @@ def get_img_size(input_layer_name, channels): return img_size, img_size_y +def get_img3d_size(input_layer_name, channels): + input = g_layer_map[input_layer_name] + img_pixels = input.size / channels + img_size = input.width + img_size_y = input.height + img_size_z = input.depth + + config_assert( + img_size * img_size_y * img_size_z == img_pixels, + "Input layer %s: Incorrect input image size %d * %d * %d for input image pixels %d" + % (input_layer_name, img_size, img_size_y, img_size_z, img_pixels)) + return img_size, img_size_y, img_size_z + + def parse_bilinear(bilinear, input_layer_name, bilinear_conf): parse_image(bilinear, input_layer_name, bilinear_conf.image_conf) bilinear_conf.out_size_x = bilinear.out_size_x @@ -1282,6 +1326,50 @@ def parse_conv(conv, input_layer_name, conv_conf, num_filters, trans=False): conv_conf.stride_y, conv_conf.caffe_mode) +#caffe_mode: compute the output size using floor instead of ceil, +# which is consistent of caffe and CuDNN's convention. +def parse_conv3d(conv, input_layer_name, conv_conf, num_filters, trans=False): + conv_conf.filter_size = conv.filter_size + conv_conf.filter_size_y = conv.filter_size_y + conv_conf.filter_size_z = conv.filter_size_z + conv_conf.channels = conv.channels + conv_conf.padding = conv.padding + conv_conf.padding_y = conv.padding_y + conv_conf.padding_z = conv.padding_z + conv_conf.stride = conv.stride + conv_conf.stride_y = conv.stride_y + conv_conf.stride_z = conv.stride_z + conv_conf.groups = conv.groups + conv_conf.caffe_mode = conv.caffe_mode + + if not trans: + conv_conf.filter_channels = conv.channels / conv.groups + conv_conf.img_size, conv_conf.img_size_y, conv_conf.img_size_z = \ + get_img3d_size(input_layer_name, conv.channels) + conv_conf.output_x = cnn_output_size( + conv_conf.img_size, conv_conf.filter_size, conv_conf.padding, + conv_conf.stride, conv_conf.caffe_mode) + conv_conf.output_y = cnn_output_size( + conv_conf.img_size_y, conv_conf.filter_size_y, conv_conf.padding_y, + conv_conf.stride_y, conv_conf.caffe_mode) + conv_conf.output_z = cnn_output_size( + conv_conf.img_size_z, conv_conf.filter_size_z, conv_conf.padding_z, + conv_conf.stride_z, conv_conf.caffe_mode) + else: + conv_conf.filter_channels = num_filters / conv.groups + conv_conf.output_x, conv_conf.output_y, conv_conf.output_z = \ + get_img3d_size(input_layer_name, conv.channels) + conv_conf.img_size = cnn_image_size( + conv_conf.output_x, conv_conf.filter_size, conv_conf.padding, + conv_conf.stride, conv_conf.caffe_mode) + conv_conf.img_size_y = cnn_image_size( + conv_conf.output_y, conv_conf.filter_size_y, conv_conf.padding_y, + conv_conf.stride_y, conv_conf.caffe_mode) + conv_conf.img_size_z = cnn_image_size( + conv_conf.output_z, conv_conf.filter_size_z, conv_conf.padding_z, + conv_conf.stride_z, conv_conf.caffe_mode) + + def parse_block_expand(block_expand, input_layer_name, block_expand_conf): block_expand_conf.channels = block_expand.channels block_expand_conf.stride_x = block_expand.stride_x @@ -1585,6 +1673,9 @@ class LayerBase(object): self.config.height = height self.config.width = width + def set_layer_depth(self, depth): + self.config.depth = depth + def set_cnn_layer(self, input_layer_name, height, @@ -1803,11 +1894,19 @@ class DetectionOutputLayer(LayerBase): @config_layer('data') class DataLayer(LayerBase): - def __init__(self, name, size, height=None, width=None, device=None): + def __init__(self, + name, + size, + height=None, + width=None, + depth=None, + device=None): super(DataLayer, self).__init__( name, 'data', size, inputs=[], device=device) if height and width: self.set_layer_height_width(height, width) + if depth: + self.set_layer_depth(depth) ''' @@ -1922,7 +2021,7 @@ class ConvLayerBase(LayerBase): def calc_parameter_size(self, conv_conf): return self.config.num_filters * conv_conf.filter_channels \ - * (conv_conf.filter_size * conv_conf.filter_size_y) + * (conv_conf.filter_size * conv_conf.filter_size_y) @config_layer('exconv') @@ -2006,6 +2105,87 @@ class ConvTransLayer(ConvTransLayerBase): layer_type = 'cudnn_convt' +@config_layer('conv_3d') +class Conv3DLayerBase(LayerBase): + def __init__(self, + name, + inputs=[], + bias=True, + num_filters=None, + shared_biases=True, + **xargs): + super(Conv3DLayerBase, self).__init__( + name, self.layer_type, 0, inputs=inputs, **xargs) + + if num_filters is not None: + self.config.num_filters = num_filters + + # need to specify layer in config + self.config.type = self.layer_type + + trans = False + if self.config.type == "deconv3d": + trans = True + + if shared_biases is not None: + self.config.shared_biases = shared_biases + + for input_index in xrange(len(self.inputs)): + input_layer = self.get_input_layer(input_index) + conv_conf = self.config.inputs[input_index].conv_conf + parse_conv3d( + self.inputs[input_index].conv, + input_layer.name, + conv_conf, + num_filters, + trans=trans + ) # for z-axis pad:0, strid:1, filter_size:1, img_size:1 + psize = self.calc_parameter_size(conv_conf) + self.create_input_parameter(input_index, psize) + if trans: + self.set_cnn_layer(name, conv_conf.img_size_z, + conv_conf.img_size_y, conv_conf.img_size, + self.config.num_filters) + else: + self.set_cnn_layer(name, conv_conf.output_z, conv_conf.output_y, + conv_conf.output_x, self.config.num_filters) + + psize = self.config.size + if shared_biases: + psize = self.config.num_filters + self.create_bias_parameter(bias, psize, [psize, 1]) + + def calc_parameter_size(self, conv_conf): + return self.config.num_filters * conv_conf.filter_channels \ + * (conv_conf.filter_size * conv_conf.filter_size_y \ + * conv_conf.filter_size_z) + + def set_cnn_layer(self, + input_layer_name, + depth, + height, + width, + channels, + is_print=True): + size = depth * height * width * channels + self.set_layer_size(size) + self.set_layer_height_width(height, width) + self.set_layer_depth(depth) + if is_print: + print("output for %s: c = %d, d = %d, h = %d, w = %d, size = %d" % + (input_layer_name, channels, depth, height, width, size)) + + +@config_layer('conv3d') +class Conv3DLayer(Conv3DLayerBase): + layer_type = 'conv3d' + + +@config_layer('deconv3d') +class Conv3DLayer(Conv3DLayerBase): + layer_type = 'deconv3d' + + @config_layer('norm') class NormLayer(LayerBase): def __init__(self, name, inputs, **xargs): diff --git a/python/paddle/trainer/recurrent_units.py b/python/paddle/trainer/recurrent_units.py old mode 100755 new mode 100644 diff --git a/python/paddle/trainer_config_helpers/layers.py b/python/paddle/trainer_config_helpers/layers.py old mode 100755 new mode 100644 index b2ba16333b3280a1956a91d15224f47b9edae979..3de2f2f01da3778b3ae86f22e5c39f5193c7ccce --- a/python/paddle/trainer_config_helpers/layers.py +++ b/python/paddle/trainer_config_helpers/layers.py @@ -139,6 +139,7 @@ __all__ = [ 'seq_slice_layer', 'kmax_sequence_score_layer', 'scale_shift_layer', + 'img_conv3d_layer', ] @@ -220,6 +221,9 @@ class LayerType(object): CRF_DECODING_LAYER = 'crf_decoding' NCE_LAYER = 'nce' + CONV3D_LAYER = 'conv3d' + DECONV3D_LAYER = 'deconv3d' + RANK_COST = 'rank-cost' LAMBDA_COST = 'lambda_cost' HUBER_REGRESSION = 'huber_regression' @@ -896,7 +900,8 @@ def mixed_layer(size=0, @layer_support() -def data_layer(name, size, height=None, width=None, layer_attr=None): +def data_layer(name, size, height=None, width=None, depth=None, + layer_attr=None): """ Define DataLayer For NeuralNetwork. @@ -923,15 +928,18 @@ def data_layer(name, size, height=None, width=None, layer_attr=None): type=LayerType.DATA, name=name, size=size, + depth=depth, height=height, width=width, **ExtraLayerAttribute.to_kwargs(layer_attr)) + if depth is None: + depth = 1 num_filters = None if height is not None and width is not None: - num_filters = size / (width * height) - assert num_filters * width * height == size, \ - "size=%s width=%s height=%s" % (size, width, height) + num_filters = size / (width * height * depth) + assert num_filters * width * height*depth == size, \ + "size=%s width=%s height=%s depth=%s" % (size, width, height, depth) return LayerOutput(name, LayerType.DATA, size=size, num_filters=num_filters) @@ -6483,6 +6491,149 @@ def kmax_sequence_score_layer(input, name=None, beam_size=1): name, LayerType.KMAX_SEQ_SCORE, parents=[input], size=input.size) +@wrap_name_default("conv3d") +@wrap_param_attr_default() +@wrap_bias_attr_default() +@wrap_act_default(act=ReluActivation()) +@layer_support(DROPOUT) +def img_conv3d_layer(input, + filter_size, + num_filters, + name=None, + num_channels=None, + act=None, + groups=1, + stride=1, + padding=0, + bias_attr=None, + param_attr=None, + shared_biases=True, + layer_attr=None, + trans=False, + layer_type=None): + """ + + The example usage is: + + .. code-block:: python + + conv = img_conv3d_layer(input=data, filter_size=1, + num_channels=8, + num_filters=16, stride=1, + bias_attr=False, + act=ReluActivation()) + + :param name: Layer name. + :type name: basestring + :param input: Layer Input. + :type input: LayerOutput + :param filter_size: The x dimension of a filter kernel. Or input a list. + :type filter_size: int|tuple|list + :param num_filters: Each filter group's number of filter + :param act: Activation type. Default is tanh + :type act: BaseActivation + :param groups: Group size of filters. + :type groups: int + :param stride: The x dimension of the stride. Or input a tuple for two image + dimension. + :type stride: int|tuple|list + :param padding: The x dimension of the padding. Or input a tuple for two + image dimension + :type padding: int|tuple|list + :param bias_attr: Convolution bias attribute. None means default bias. + False means no bias. + :type bias_attr: ParameterAttribute|False + :param num_channels: number of input channels. If None will be set + automatically from previous output. + :type num_channels: int + :param param_attr: Convolution param attribute. None means default attribute + :type param_attr: ParameterAttribute + :param shared_biases: Is biases will be shared between filters or not. + :type shared_biases: bool + :param layer_attr: Layer Extra Attribute. + :type layer_attr: ExtraLayerAttribute + :param trans: true if it is a convTransLayer, false if it is a convLayer + :type trans: bool + :param layer_type: specify the layer_type, default is None. If trans=True, + layer_type has to be "exconvt" or "cudnn_convt", + otherwise layer_type has to be either "exconv" or + "cudnn_conv" + :type layer_type: String + :return: LayerOutput object. + :rtype: LayerOutput + """ + if num_channels is None: + assert input.num_filters is not None + num_channels = input.num_filters + + if isinstance(filter_size, collections.Sequence): + assert len(filter_size) == 3 + filter_size, filter_size_y, filter_size_z = filter_size + else: + filter_size_y = filter_size + filter_size_z = filter_size + + if isinstance(stride, collections.Sequence): + assert len(stride) == 3 + stride, stride_y, stride_z = stride + else: + stride_y = stride + stride_z = stride + + if isinstance(padding, collections.Sequence): + assert len(padding) == 3 + padding, padding_y, padding_z = padding + else: + padding_y = padding + padding_z = padding + + if param_attr.attr.get('initial_smart'): + # special initial for conv layers. + init_w = (2.0 / (filter_size**2 * num_channels))**0.5 + param_attr.attr["initial_mean"] = 0.0 + param_attr.attr["initial_std"] = init_w + param_attr.attr["initial_strategy"] = 0 + param_attr.attr["initial_smart"] = False + + if layer_type: + if trans: + assert layer_type in ["deconv3d"] + lt = layer_type + else: + lt = LayerType.DECONV3D_LAYER if trans else LayerType.CONV3D_LAYER + + l = Layer( + name=name, + inputs=Input( + input.name, + conv=Conv3D( + filter_size=filter_size, + padding=padding, + stride=stride, + channels=num_channels, + groups=groups, + filter_size_y=filter_size_y, + padding_y=padding_y, + stride_y=stride_y, + filter_size_z=filter_size_z, + padding_z=padding_z, + stride_z=stride_z), + **param_attr.attr), + active_type=act.name, + num_filters=num_filters, + bias=ParamAttr.to_bias(bias_attr), + shared_biases=shared_biases, + type=lt, + **ExtraLayerAttribute.to_kwargs(layer_attr)) + return LayerOutput( + name, + lt, + parents=[input], + activation=act, + num_filters=num_filters, + size=l.config.size) + + @wrap_name_default("scale_shift") @wrap_param_attr_default() @wrap_bias_attr_default() diff --git a/python/paddle/trainer_config_helpers/networks.py b/python/paddle/trainer_config_helpers/networks.py old mode 100755 new mode 100644 diff --git a/python/paddle/trainer_config_helpers/tests/configs/file_list.sh b/python/paddle/trainer_config_helpers/tests/configs/file_list.sh index e87ee6a432592faf5e1501118e3eb2c4bca0a717..7982d34bc5c0d4d99b0a8468ab0db86134764eab 100755 --- a/python/paddle/trainer_config_helpers/tests/configs/file_list.sh +++ b/python/paddle/trainer_config_helpers/tests/configs/file_list.sh @@ -9,6 +9,6 @@ test_seq_concat_reshape test_pad test_smooth_l1 test_multiplex_layer test_prelu_layer test_row_conv test_detection_output_layer test_multibox_loss_layer test_recursive_topology test_gated_unit_layer test_clip_layer test_row_l2_norm_layer test_kmax_seq_socre_layer test_seq_select_layers test_scale_shift_layer -test_seq_slice_layer test_cross_entropy_over_beam) +test_seq_slice_layer test_cross_entropy_over_beam test_conv3d_layer test_deconv3d_layer) export whole_configs=(test_split_datasource) diff --git a/python/paddle/trainer_config_helpers/tests/configs/protostr/test_conv3d_layer.protostr b/python/paddle/trainer_config_helpers/tests/configs/protostr/test_conv3d_layer.protostr new file mode 100644 index 0000000000000000000000000000000000000000..9fe2bc29d3cd06231b67102e28f7a49c28306958 --- /dev/null +++ b/python/paddle/trainer_config_helpers/tests/configs/protostr/test_conv3d_layer.protostr @@ -0,0 +1,132 @@ +type: "nn" +layers { + name: "data" + type: "data" + size: 36288 + active_type: "" + height: 48 + width: 42 + depth: 6 +} +layers { + name: "conv3d_1" + type: "conv3d" + size: 24192 + active_type: "" + inputs { + input_layer_name: "data" + input_parameter_name: "_conv3d_1.w0" + conv_conf { + filter_size: 3 + channels: 3 + stride: 2 + padding: 1 + groups: 1 + filter_channels: 3 + output_x: 21 + img_size: 42 + caffe_mode: true + filter_size_y: 3 + padding_y: 1 + stride_y: 2 + output_y: 24 + img_size_y: 48 + filter_size_z: 3 + padding_z: 1 + stride_z: 2 + output_z: 3 + img_size_z: 6 + } + } + bias_parameter_name: "_conv3d_1.wbias" + num_filters: 16 + shared_biases: true + height: 24 + width: 21 + depth: 3 +} +layers { + name: "conv3d_2" + type: "conv3d" + size: 24192 + active_type: "" + inputs { + input_layer_name: "data" + input_parameter_name: "_conv3d_2.w0" + conv_conf { + filter_size: 3 + channels: 3 + stride: 2 + padding: 1 + groups: 1 + filter_channels: 3 + output_x: 21 + img_size: 42 + caffe_mode: true + filter_size_y: 3 + padding_y: 1 + stride_y: 2 + output_y: 24 + img_size_y: 48 + filter_size_z: 3 + padding_z: 1 + stride_z: 2 + output_z: 3 + img_size_z: 6 + } + } + bias_parameter_name: "_conv3d_2.wbias" + num_filters: 16 + shared_biases: true + height: 24 + width: 21 + depth: 3 +} +parameters { + name: "_conv3d_1.w0" + size: 1296 + initial_mean: 0.0 + initial_std: 0.272165526976 + initial_strategy: 0 + initial_smart: false +} +parameters { + name: "_conv3d_1.wbias" + size: 16 + initial_mean: 0.0 + initial_std: 0.0 + dims: 16 + dims: 1 + initial_strategy: 0 + initial_smart: false +} +parameters { + name: "_conv3d_2.w0" + size: 1296 + initial_mean: 0.0 + initial_std: 0.272165526976 + initial_strategy: 0 + initial_smart: false +} +parameters { + name: "_conv3d_2.wbias" + size: 16 + initial_mean: 0.0 + initial_std: 0.0 + dims: 16 + dims: 1 + initial_strategy: 0 + initial_smart: false +} +input_layer_names: "data" +output_layer_names: "conv3d_2" +sub_models { + name: "root" + layer_names: "data" + layer_names: "conv3d_1" + layer_names: "conv3d_2" + input_layer_names: "data" + output_layer_names: "conv3d_2" + is_recurrent_layer_group: false +} + diff --git a/python/paddle/trainer_config_helpers/tests/configs/protostr/test_deconv3d_layer.protostr b/python/paddle/trainer_config_helpers/tests/configs/protostr/test_deconv3d_layer.protostr new file mode 100644 index 0000000000000000000000000000000000000000..7bf409731cbf8d5d98341b03c7c09d91fa8328d9 --- /dev/null +++ b/python/paddle/trainer_config_helpers/tests/configs/protostr/test_deconv3d_layer.protostr @@ -0,0 +1,132 @@ +type: "nn" +layers { + name: "data" + type: "data" + size: 36288 + active_type: "" + height: 48 + width: 42 + depth: 6 +} +layers { + name: "deconv3d_1" + type: "deconv3d" + size: 1387760 + active_type: "" + inputs { + input_layer_name: "data" + input_parameter_name: "_deconv3d_1.w0" + conv_conf { + filter_size: 3 + channels: 3 + stride: 2 + padding: 1 + groups: 1 + filter_channels: 16 + output_x: 42 + img_size: 83 + caffe_mode: true + filter_size_y: 3 + padding_y: 1 + stride_y: 2 + output_y: 48 + img_size_y: 95 + filter_size_z: 3 + padding_z: 1 + stride_z: 2 + output_z: 6 + img_size_z: 11 + } + } + bias_parameter_name: "_deconv3d_1.wbias" + num_filters: 16 + shared_biases: true + height: 95 + width: 83 + depth: 11 +} +layers { + name: "deconv3d_2" + type: "deconv3d" + size: 1387760 + active_type: "" + inputs { + input_layer_name: "data" + input_parameter_name: "_deconv3d_2.w0" + conv_conf { + filter_size: 3 + channels: 3 + stride: 2 + padding: 1 + groups: 1 + filter_channels: 16 + output_x: 42 + img_size: 83 + caffe_mode: true + filter_size_y: 3 + padding_y: 1 + stride_y: 2 + output_y: 48 + img_size_y: 95 + filter_size_z: 3 + padding_z: 1 + stride_z: 2 + output_z: 6 + img_size_z: 11 + } + } + bias_parameter_name: "_deconv3d_2.wbias" + num_filters: 16 + shared_biases: true + height: 95 + width: 83 + depth: 11 +} +parameters { + name: "_deconv3d_1.w0" + size: 6912 + initial_mean: 0.0 + initial_std: 0.272165526976 + initial_strategy: 0 + initial_smart: false +} +parameters { + name: "_deconv3d_1.wbias" + size: 16 + initial_mean: 0.0 + initial_std: 0.0 + dims: 16 + dims: 1 + initial_strategy: 0 + initial_smart: false +} +parameters { + name: "_deconv3d_2.w0" + size: 6912 + initial_mean: 0.0 + initial_std: 0.272165526976 + initial_strategy: 0 + initial_smart: false +} +parameters { + name: "_deconv3d_2.wbias" + size: 16 + initial_mean: 0.0 + initial_std: 0.0 + dims: 16 + dims: 1 + initial_strategy: 0 + initial_smart: false +} +input_layer_names: "data" +output_layer_names: "deconv3d_2" +sub_models { + name: "root" + layer_names: "data" + layer_names: "deconv3d_1" + layer_names: "deconv3d_2" + input_layer_names: "data" + output_layer_names: "deconv3d_2" + is_recurrent_layer_group: false +} + diff --git a/python/paddle/trainer_config_helpers/tests/configs/test_conv3d_layer.py b/python/paddle/trainer_config_helpers/tests/configs/test_conv3d_layer.py new file mode 100644 index 0000000000000000000000000000000000000000..aa0a2c0d5fe19b6c414acd708bb6e82d9fb6568f --- /dev/null +++ b/python/paddle/trainer_config_helpers/tests/configs/test_conv3d_layer.py @@ -0,0 +1,49 @@ +from paddle.trainer_config_helpers import * + +settings(batch_size=1000, learning_rate=1e-5) + +num_channels = 3 +filter_size = 3 +filter_size_y = 3 +filter_size_z = 3 +stride = 2 +stride_y = 2 +stride_z = 2 +padding = 1 +padding_y = 1 +padding_z = 1 +groups = 1 + +data = data_layer( + name='data', size=12096 * num_channels, height=48, width=42, depth=6) +# first +conv3d_1 = img_conv3d_layer( + input=data, + name='conv3d_1', + num_filters=16, + num_channels=num_channels, + filter_size=filter_size, + stride=stride, + padding=padding, + groups=groups, + bias_attr=True, + shared_biases=True, + trans=False, + layer_type="conv3d", + act=LinearActivation()) +# second +conv3d_2 = img_conv3d_layer( + input=data, + name='conv3d_2', + num_filters=16, + num_channels=num_channels, + filter_size=[filter_size, filter_size_y, filter_size_z], + stride=[stride, stride_y, stride_z], + padding=[padding, padding_y, padding_z], + groups=groups, + bias_attr=True, + shared_biases=True, + trans=False, + layer_type="conv3d", + act=LinearActivation()) +outputs(conv3d_2) diff --git a/python/paddle/trainer_config_helpers/tests/configs/test_deconv3d_layer.py b/python/paddle/trainer_config_helpers/tests/configs/test_deconv3d_layer.py new file mode 100644 index 0000000000000000000000000000000000000000..a113279fc17b49ad01b8860b61180af0f35694fb --- /dev/null +++ b/python/paddle/trainer_config_helpers/tests/configs/test_deconv3d_layer.py @@ -0,0 +1,50 @@ +from paddle.trainer_config_helpers import * + +settings(batch_size=1000, learning_rate=1e-5) + +num_channels = 3 +filter_size = 3 +filter_size_y = 3 +filter_size_z = 3 +stride = 2 +stride_y = 2 +stride_z = 2 +padding = 1 +padding_y = 1 +padding_z = 1 +groups = 1 + +data = data_layer( + name='data', size=12096 * num_channels, height=48, width=42, depth=6) + +# first +deconv3d_1 = img_conv3d_layer( + input=data, + name='deconv3d_1', + num_filters=16, + num_channels=num_channels, + filter_size=filter_size, + stride=stride, + padding=padding, + groups=groups, + bias_attr=True, + shared_biases=True, + trans=True, + layer_type="deconv3d", + act=LinearActivation()) +# second +deconv3d_2 = img_conv3d_layer( + input=data, + name='deconv3d_2', + num_filters=16, + num_channels=num_channels, + filter_size=[filter_size, filter_size_y, filter_size_z], + stride=[stride, stride_y, stride_z], + padding=[padding, padding_y, padding_z], + groups=groups, + bias_attr=True, + shared_biases=True, + trans=True, + layer_type="deconv3d", + act=LinearActivation()) +outputs(deconv3d_2) diff --git a/python/paddle/trainer_config_helpers/tests/layers_test.py b/python/paddle/trainer_config_helpers/tests/layers_test.py index 05902ea293df5a3e9c10f6700930ca6a343603c2..b3dd8f8fc784754e749240e1b895b11ef6aba438 100644 --- a/python/paddle/trainer_config_helpers/tests/layers_test.py +++ b/python/paddle/trainer_config_helpers/tests/layers_test.py @@ -17,3 +17,4 @@ from paddle.trainer.config_parser import parse_config_and_serialize if __name__ == '__main__': parse_config_and_serialize( 'trainer_config_helpers/tests/layers_test_config.py', '') +# layers_test_config.py