提交 96f42d8e 编写于 作者: Q qingqing01 提交者: GitHub

Merge pull request #1455 from wangyang59/cudnnDeconv

Cudnn deconv
......@@ -25,12 +25,16 @@ filter_test(GSERVER_HEADER)
filter_test(GSERVER_SOURCES)
if(NOT WITH_GPU)
list(REMOVE_ITEM GSERVER_HEADER
layers/CudnnConvBaseLayer.h
layers/CudnnConvLayer.h
layers/CudnnConvTransLayer.h
layers/CudnnPoolLayer.h
layers/CudnnBatchNormLayer.h)
list(REMOVE_ITEM GSERVER_SOURCES
layers/CudnnConvBaseLayer.cpp
layers/CudnnConvLayer.cpp
layers/CudnnConvTransLayer.cpp
layers/CudnnPoolLayer.cpp
layers/CudnnBatchNormLayer.cpp)
compile_cu_as_cpp(layers/LstmCompute.cu)
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "ConvBaseOperator.h"
#include "paddle/math/MathUtils.h"
#include "paddle/math/Matrix.h"
namespace paddle {
/**
* @brief ConvBaseOperator takes two inputs to perform the convolution.
* The first input is the image, and the second input is the convolution kernel.
* The height of data for two inputs are the same. Each data of the first input
* is convolved with each data of the second input indepedently.
*
* The config file api is conv_operator.
*/
ConvBaseOperator::ConvBaseOperator(const OperatorConfig &config, bool useGpu)
: Operator(config, useGpu) {
CHECK(useGpu);
CHECK_EQ(config_.input_indices_size(), 2L);
caffeMode_ = true;
getConvParams();
computeConvSizes();
// initialize all to default algorithms
fwdAlgo_ = 0;
bwdFilterAlgo_ = 0;
bwdDataAlgo_ = 0;
fwdLimitBytes_ = 0;
bwdDataLimitBytes_ = 0;
bwdFilterLimitBytes_ = 0;
workSpaceInBytes_ = 0;
workSpace_ = nullptr;
isSelectAlgo_ = false;
}
void ConvBaseOperator::allocConvWorkSpace() {
hl_conv_workspace(imageDesc_,
outputDesc_,
filterDesc_,
convDesc_,
&fwdAlgo_,
&fwdLimitBytes_,
&bwdDataAlgo_,
&bwdDataLimitBytes_,
&bwdFilterAlgo_,
&bwdFilterLimitBytes_);
size_t maxWorkSpace = 0;
maxWorkSpace = std::max(fwdLimitBytes_, bwdDataLimitBytes_);
maxWorkSpace = std::max(maxWorkSpace, bwdFilterLimitBytes_);
if (maxWorkSpace > workSpaceInBytes_) {
if (workSpaceInBytes_ != 0) {
hl_free_mem_device(workSpace_);
}
// total amount of storage needed
workSpace_ = hl_malloc_device(maxWorkSpace);
workSpaceInBytes_ = maxWorkSpace;
}
}
void ConvBaseOperator::computeConvSizes() {
hl_create_filter_descriptor(
&filterDesc_, channels_, numFilters_, filterSizeY_, filterSize_);
hl_create_tensor_descriptor(&imageDesc_);
hl_create_tensor_descriptor(&outputDesc_);
hl_create_convolution_descriptor(&convDesc_,
imageDesc_,
filterDesc_,
paddingY_,
padding_,
strideY_,
stride_);
}
void ConvBaseOperator::reshapeImageDescriptors() {
hl_tensor_reshape(imageDesc_,
1,
channels_,
imageH_,
imageW_,
channels_ * imageH_ * imageW_,
imageH_ * imageW_,
imageW_,
1);
hl_tensor_reshape(outputDesc_,
1,
numFilters_,
outputH_,
outputW_,
numFilters_ * outputH_ * outputW_,
outputH_ * outputW_,
outputW_,
1);
hl_reset_convolution_descriptor(convDesc_,
imageDesc_,
filterDesc_,
paddingY_,
padding_,
strideY_,
stride_);
}
void ConvBaseOperator::getConvParams() {
configNumFilters_ = config_.num_filters();
const ConvConfig &conf = config_.conv_conf();
padding_ = conf.padding();
stride_ = conf.stride();
filterSize_ = conf.filter_size();
paddingY_ = conf.padding_y();
strideY_ = conf.stride_y();
filterSizeY_ = conf.filter_size_y();
filterPixels_ = filterSize_ * filterSizeY_;
configChannels_ = conf.channels();
imgSize_ = conf.img_size();
imgSizeY_ = conf.has_img_size_y() ? conf.img_size_y() : conf.img_size();
imgPixels_ = imgSize_ * imgSizeY_;
CHECK_EQ(conf.groups(), 1U);
filterChannels_ = conf.filter_channels();
outputX_ = conf.output_x();
outputY_ = conf.has_output_y() ? conf.output_y() : conf.output_x();
outputs_ = outputX_ * outputX_;
isDeconv_ = (config_.type() == "conv") ? false : true;
if (isDeconv_) {
channels_ = configNumFilters_;
numFilters_ = configChannels_;
} else {
channels_ = configChannels_;
numFilters_ = configNumFilters_;
}
}
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "Operator.h"
#include "paddle/math/MathUtils.h"
#include "paddle/math/Matrix.h"
namespace paddle {
/**
* @brief ConvOperator takes two inputs to perform the convolution.
* The first input is the image, and the second input is the convolution kernel.
* The height of data for two inputs are the same. Each data of the first input
* is convolved with each data of the second input indepedently.
*
* The config file api is conv_operator.
*/
class ConvBaseOperator : public Operator {
public:
ConvBaseOperator(const OperatorConfig &config, bool useGpu);
/**
* Free workspace in device and destroy cudnn tensor descriptor.
*/
virtual ~ConvBaseOperator() {
if (workSpaceInBytes_ != 0) {
hl_free_mem_device(workSpace_);
workSpaceInBytes_ = 0;
}
hl_destroy_tensor_descriptor(imageDesc_);
hl_destroy_tensor_descriptor(outputDesc_);
hl_destroy_filter_descriptor(filterDesc_);
hl_destroy_convolution_descriptor(convDesc_);
}
protected:
/**
* Get convolution parameters from layer config and
* initialize member variables.
*/
void getConvParams();
/**
* Allocate Gpu Memory for cudnn convolution algorithms.
*/
void allocConvWorkSpace();
/**
* Create cudnn tensor descriptor for convolution operation.
*/
void computeConvSizes();
/**
* Reshape cudnn tensor descriptor.
*/
void reshapeImageDescriptors();
/**
* Reshape cudnn tensor descriptor.
*/
virtual void reshape(int batchSize) = 0;
/**
* Check filter size is equal to the size calculated by parameters from
* layer config.
*/
void checkFilterSize(const MatrixPtr &filter) {
CHECK_EQ(static_cast<int>(filter->getWidth()),
filterSize_ * filterSizeY_ * channels_ * numFilters_);
}
/// Most of member variables are same with CudnnConvLayer.
/// There is no explanation here.
bool isDeconv_;
int imageH_, imageW_, outputH_, outputW_;
hl_tensor_descriptor imageDesc_;
hl_tensor_descriptor outputDesc_;
hl_filter_descriptor filterDesc_;
hl_convolution_descriptor convDesc_;
bool caffeMode_;
int inputOffset_, outputOffset_, weightOffset_;
int numFilters_, channels_;
/// from parsing config
int configNumFilters_, configChannels_;
int padding_, stride_, filterSize_, imgSize_, imgSizeY_;
int paddingY_, strideY_, filterSizeY_;
int imgPixels_, filterPixels_, filterChannels_, outputX_, outputY_, outputs_;
/// Following member variables are same with CudnnConvLayer.
/// There is no explanation here.
int fwdAlgo_, bwdFilterAlgo_, bwdDataAlgo_;
size_t fwdLimitBytes_, bwdDataLimitBytes_, bwdFilterLimitBytes_;
size_t workSpaceInBytes_;
void *workSpace_;
bool isSelectAlgo_;
};
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "ConvBaseProjection.h"
#include "paddle/utils/Stat.h"
namespace paddle {
ThreadLocalD<std::vector<MemoryHandle *>> ConvBaseProjection::convMem_;
ConvBaseProjection::ConvBaseProjection(const ProjectionConfig &config,
ParameterPtr parameter,
bool useGpu)
: Projection(config, parameter, useGpu) {
CHECK(useGpu); // only support GPU
getConvParams();
initCudnn();
size_t height = filterH_ * filterW_ * channels_ / groups_;
size_t width = numFilters_;
weight_.reset(new Weight(height, width, parameter));
weightOffset_ = height * width / groups_;
}
void ConvBaseProjection::getConvParams() {
const ConvConfig &conf = config_.conv_conf();
paddingH_ = conf.padding_y();
paddingW_ = conf.padding();
strideH_ = conf.stride_y();
strideW_ = conf.stride();
filterH_ = conf.filter_size_y();
filterW_ = conf.filter_size();
configImgH_ = conf.has_img_size_y() ? conf.img_size_y() : conf.img_size();
configImgW_ = conf.img_size();
configOutH_ = conf.has_output_y() ? conf.output_y() : conf.output_x();
configOutW_ = conf.output_x();
configChannels_ = conf.channels();
configNumFilters_ = config_.num_filters();
isDeconv_ = (config_.type() == "conv") ? false : true;
channels_ = (isDeconv_) ? configNumFilters_ : configChannels_;
numFilters_ = (isDeconv_) ? configChannels_ : configNumFilters_;
groups_ = conf.groups();
CHECK_EQ(channels_ % groups_, 0);
CHECK_EQ(numFilters_ % groups_, 0);
}
void ConvBaseProjection::initCudnn() {
hl_create_filter_descriptor(&filterDesc_,
channels_ / groups_,
numFilters_ / groups_,
filterH_,
filterW_);
hl_create_tensor_descriptor(&imageDesc_);
hl_create_tensor_descriptor(&outputDesc_);
hl_create_convolution_descriptor(&convDesc_,
imageDesc_,
filterDesc_,
paddingH_,
paddingW_,
strideH_,
strideW_);
// initialize all to default algorithms
fwdAlgo_ = 0;
bwdFilterAlgo_ = 0;
bwdDataAlgo_ = 0;
fwdLimitBytes_ = 0;
bwdDataLimitBytes_ = 0;
bwdFilterLimitBytes_ = 0;
workSpaceInBytes_ = 0;
batchNum_ = 0;
isSelectAlgo_ = false;
}
void ConvBaseProjection::reshapeTensorDesc(int batchSize) {
// The stride between two consecutive samples in the output of ConvProjection
// may not be numFilters_ * outputH_ * outputW_ (conv) or
// channels_ * imageH_ * imageW_ (deconv)
// for example, in the case of layer ConcatenateLayer2 with two
// ConvProjection, the stride is the output_size of layer ConcatenateLayer2.
// So the calculation of nStride is different from CudnnConvLayer.
size_t nStrideImage, nStrideOutput;
if (isDeconv_) {
nStrideImage = out_->value->getStride();
nStrideOutput = numFilters_ * outputH_ * outputW_;
} else {
nStrideImage = channels_ * imageH_ * imageW_;
nStrideOutput = out_->value->getStride();
}
hl_tensor_reshape(imageDesc_,
batchSize,
channels_ / groups_,
imageH_,
imageW_,
nStrideImage,
imageH_ * imageW_,
imageW_,
1);
hl_tensor_reshape(outputDesc_,
batchSize,
numFilters_ / groups_,
outputH_,
outputW_,
nStrideOutput,
outputH_ * outputW_,
outputW_,
1);
hl_reset_convolution_descriptor(convDesc_,
imageDesc_,
filterDesc_,
paddingH_,
paddingW_,
strideH_,
strideW_);
}
void ConvBaseProjection::reshape(int batchSize) {
size_t width = calOutputSize();
CHECK_EQ(width, out_->value->getWidth());
CHECK_EQ(calInputSize(), in_->value->getWidth());
isSelectAlgo_ = (batchSize == batchNum_);
batchNum_ = batchSize;
if (!isSelectAlgo_) {
reshapeTensorDesc(batchSize);
hl_conv_workspace(imageDesc_,
outputDesc_,
filterDesc_,
convDesc_,
&fwdAlgo_,
&fwdLimitBytes_,
&bwdDataAlgo_,
&bwdDataLimitBytes_,
&bwdFilterAlgo_,
&bwdFilterLimitBytes_);
size_t maxWorkSpace = 0;
maxWorkSpace = std::max(fwdLimitBytes_, bwdDataLimitBytes_);
maxWorkSpace = std::max(maxWorkSpace, bwdFilterLimitBytes_);
workSpaceInBytes_ = maxWorkSpace;
VLOG(3) << getName() << " Fwd / BwdData / BwdFilter algo: " << fwdAlgo_
<< " / " << bwdDataAlgo_ << " / " << bwdFilterAlgo_;
}
isSelectAlgo_ = true;
}
void *ConvBaseProjection::getSpaceBytes(size_t size) {
std::vector<MemoryHandle *> &convMem = *convMem_;
if (convMem.empty()) {
int numDevices = hl_get_device_count();
convMem.resize(numDevices);
}
int devId = hl_get_device();
MemoryHandle **localMem = &(convMem[devId]);
if (NULL == *localMem || size > (*localMem)->getAllocSize()) {
*localMem = new GpuMemoryHandle(size);
}
return (*localMem)->getBuf();
}
ConvBaseProjection::~ConvBaseProjection() {
hl_destroy_tensor_descriptor(imageDesc_);
hl_destroy_tensor_descriptor(outputDesc_);
hl_destroy_filter_descriptor(filterDesc_);
hl_destroy_convolution_descriptor(convDesc_);
}
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "Projection.h"
#include "paddle/math/MathUtils.h"
namespace paddle {
/**
* @brief Base class for ConvProjection and ConvTransProjection.
*/
class ConvBaseProjection : public Projection {
public:
/**
* Constructor.
*/
ConvBaseProjection(const ProjectionConfig& config,
ParameterPtr parameter,
bool useGpu);
~ConvBaseProjection();
protected:
void getConvParams();
void initCudnn();
void reshapeTensorDesc(int batchSize);
void reshape(int batchSize);
virtual size_t calOutputSize() = 0;
virtual size_t calInputSize() = 0;
static void* getSpaceBytes(size_t size);
/// True if it's deconv projection layer, false if it's ConvProjection layer
bool isDeconv_;
/// imageH_ and imageW_ / outputH_ and outputW_
/// is calculated from the input layer.
int imageH_, imageW_;
int outputH_, outputW_;
/// configImgH_ and configImgW_ / configOutH_ and configOutW_
/// is obtained from config.
int configImgH_, configImgW_;
int configOutH_, configOutW_;
/// channels_ and numFilters_ are defined in terms of convolution semantics
int channels_, numFilters_;
/// configChannels and configNumFilters_ are obtained from config
/// For Conv they are the same as channels_ and numFilters
/// For ConvTrans they are opposite to channels_ and numFilters
int configChannels_, configNumFilters_;
int paddingH_, paddingW_;
int strideH_, strideW_;
int filterH_, filterW_;
/// One group offset of input data.
int inputOffset_;
/// One group offset of output data.
int outputOffset_;
/// One group offset of weight.
int weightOffset_;
int groups_;
/// Cudnn tensor descriptor for input.
hl_tensor_descriptor imageDesc_;
/// Cudnn tensor descriptor for output.
hl_tensor_descriptor outputDesc_;
/// Cudnn tensor descriptor for filter.
hl_filter_descriptor filterDesc_;
/// Cudnn tensor descriptor for a convolution operation.
hl_convolution_descriptor convDesc_;
/// Record the algorithm for forward convolution, which is obtained by cudnn
/// api to search the best suited algorithm.
int fwdAlgo_;
/// Record the algorithm for computing convolution gradient with respect to
/// filter coefficients.
int bwdFilterAlgo_;
/// Record the algorithm for computing convolution gradient with respect to
/// the output.
int bwdDataAlgo_;
/// Amount of GPU memory needed as workspace to be able to execute a
/// forward convolution with the specified algo.
size_t fwdLimitBytes_;
/// Amount of GPU memory needed as workspace to be able to execute a
/// backwardFilter with the specified algo.
size_t bwdDataLimitBytes_;
/// Amount of GPU memory needed as workspace to be able to execute a
/// backwardData with the specified algo.
size_t bwdFilterLimitBytes_;
/// Size of total work space.
size_t workSpaceInBytes_;
/// Whether to call cuDNN api to choose conv algorithm.
bool isSelectAlgo_;
/// batchNum is used to record batch size. If the batch size is changed,
/// the selection algorithm will be called.
int batchNum_;
bool bias_;
std::unique_ptr<Weight> weight_;
static ThreadLocalD<std::vector<MemoryHandle*>> convMem_;
};
} // namespace paddle
......@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "Operator.h"
#include "ConvOperator.h"
#include "paddle/math/MathUtils.h"
#include "paddle/math/Matrix.h"
......@@ -27,120 +27,8 @@ namespace paddle {
* The config file api is conv_operator.
*/
class ConvOperator : public Operator {
public:
ConvOperator(const OperatorConfig &config, bool useGpu);
/**
* Free workspace in device and destroy cudnn tensor descriptor.
*/
virtual ~ConvOperator() {
if (workSpaceInBytes_ != 0) {
hl_free_mem_device(workSpace_);
workSpaceInBytes_ = 0;
}
hl_destroy_tensor_descriptor(inputDesc_);
hl_destroy_tensor_descriptor(outputDesc_);
hl_destroy_filter_descriptor(filterDesc_);
hl_destroy_convolution_descriptor(convDesc_);
}
virtual void forward();
virtual void backward();
private:
/**
* Get convolution parameters from layer config and
* initialize member variables.
*/
void getConvParams();
/**
* Allocate Gpu Memory for cudnn convolution algorithms.
*/
void allocConvWorkSpace(size_t maxWorkSpace);
/**
* Create cudnn tensor descriptor for convolution operation.
*/
void computeConvSizes();
/**
* Reshape cudnn tensor descriptor.
*/
void reshapeImageDescriptors();
/**
* Reshape cudnn tensor descriptor.
*/
void reshape(int batchSize);
/**
* Check filter size is equal to the size calculated by parameters from
* layer config.
*/
void checkFilterSize(const MatrixPtr &filter) {
CHECK_EQ(static_cast<int>(filter->getWidth()),
filterSize_ * filterSizeY_ * channels_ * numFilters_);
}
/// Most of member variables are same with CudnnConvLayer.
/// There is no explanation here.
int imageH_, imageW_, outputH_, outputW_;
hl_tensor_descriptor inputDesc_;
hl_tensor_descriptor outputDesc_;
hl_filter_descriptor filterDesc_;
hl_convolution_descriptor convDesc_;
bool caffeMode_;
int inputOffset_, outputOffset_, weightOffset_;
int numFilters_;
int padding_, stride_, filterSize_, channels_, imgSize_, imgSizeY_;
int paddingY_, strideY_, filterSizeY_;
int imgPixels_, filterPixels_, filterChannels_, outputX_, outputY_, outputs_;
/// Following member variables are same with CudnnConvLayer.
/// There is no explanation here.
int fwdAlgo_, bwdFilterAlgo_, bwdDataAlgo_;
size_t fwdLimitBytes_, bwdDataLimitBytes_, bwdFilterLimitBytes_;
size_t workSpaceInBytes_;
void *workSpace_;
bool isSelectAlgo_;
};
REGISTER_OPERATOR(conv, ConvOperator);
ConvOperator::ConvOperator(const OperatorConfig &config, bool useGpu)
: Operator(config, useGpu) {
CHECK(useGpu);
CHECK_EQ(config_.input_indices_size(), 2L);
caffeMode_ = true;
getConvParams();
computeConvSizes();
// initialize all to default algorithms
fwdAlgo_ = 0;
bwdFilterAlgo_ = 0;
bwdDataAlgo_ = 0;
fwdLimitBytes_ = 0;
bwdDataLimitBytes_ = 0;
bwdFilterLimitBytes_ = 0;
workSpaceInBytes_ = 0;
workSpace_ = nullptr;
isSelectAlgo_ = false;
}
void ConvOperator::allocConvWorkSpace(size_t maxWorkSpace) {
if (maxWorkSpace > workSpaceInBytes_) {
if (workSpaceInBytes_ != 0) {
hl_free_mem_device(workSpace_);
}
// total amount of storage needed
workSpace_ = hl_malloc_device(maxWorkSpace);
workSpaceInBytes_ = maxWorkSpace;
}
}
void ConvOperator::reshape(int batchSize) {
imageH_ = ins_[0]->getFrameHeight();
imageW_ = ins_[0]->getFrameWidth();
......@@ -148,106 +36,25 @@ void ConvOperator::reshape(int batchSize) {
if (imageW_ == 0) imageW_ = imgSize_;
outputH_ = outputSize(imageH_, filterSizeY_, paddingY_, strideY_, caffeMode_);
outputW_ = outputSize(imageW_, filterSize_, padding_, stride_, caffeMode_);
/// Check that the outputSizes are consistent with config
CHECK_EQ(outputH_, outputY_);
CHECK_EQ(outputW_, outputX_);
out_->setFrameHeight(outputH_);
out_->setFrameWidth(outputW_);
reshapeImageDescriptors();
if (!isSelectAlgo_) {
hl_conv_workspace(inputDesc_,
outputDesc_,
filterDesc_,
convDesc_,
&fwdAlgo_,
&fwdLimitBytes_,
&bwdDataAlgo_,
&bwdDataLimitBytes_,
&bwdFilterAlgo_,
&bwdFilterLimitBytes_);
size_t maxWorkSpace = 0;
maxWorkSpace = std::max(fwdLimitBytes_, bwdDataLimitBytes_);
maxWorkSpace = std::max(maxWorkSpace, bwdFilterLimitBytes_);
inputOffset_ = channels_ * imageH_ * imageW_;
outputOffset_ = numFilters_ * outputH_ * outputW_;
weightOffset_ = numFilters_ * channels_ * filterSize_ * filterSizeY_;
allocConvWorkSpace(maxWorkSpace);
if (!isSelectAlgo_) {
allocConvWorkSpace();
}
isSelectAlgo_ = true;
}
void ConvOperator::computeConvSizes() {
hl_create_filter_descriptor(
&filterDesc_, channels_, numFilters_, filterSizeY_, filterSize_);
hl_create_tensor_descriptor(&inputDesc_);
int outputX =
outputSize(imgSize_, filterSize_, padding_, stride_, caffeMode_);
int outputY =
outputSize(imgSizeY_, filterSizeY_, paddingY_, strideY_, caffeMode_);
CHECK_EQ(outputX, outputX_);
CHECK_EQ(outputY, outputY_);
hl_create_tensor_descriptor(&outputDesc_);
hl_create_convolution_descriptor(&convDesc_,
inputDesc_,
filterDesc_,
paddingY_,
padding_,
strideY_,
stride_);
}
void ConvOperator::reshapeImageDescriptors() {
hl_tensor_reshape(inputDesc_,
1,
channels_,
imageH_,
imageW_,
channels_ * imageH_ * imageW_,
imageH_ * imageW_,
imageW_,
1);
hl_tensor_reshape(outputDesc_,
1,
numFilters_,
outputH_,
outputW_,
numFilters_ * outputH_ * outputW_,
outputH_ * outputW_,
outputW_,
1);
hl_reset_convolution_descriptor(convDesc_,
inputDesc_,
filterDesc_,
paddingY_,
padding_,
strideY_,
stride_);
inputOffset_ = channels_ * imageH_ * imageW_;
outputOffset_ = numFilters_ * outputH_ * outputW_;
weightOffset_ = numFilters_ * channels_ * filterSize_ * filterSize_;
}
void ConvOperator::getConvParams() {
numFilters_ = config_.num_filters();
const ConvConfig &conf = config_.conv_conf();
padding_ = conf.padding();
stride_ = conf.stride();
filterSize_ = conf.filter_size();
paddingY_ = conf.padding_y();
strideY_ = conf.stride_y();
filterSizeY_ = conf.filter_size_y();
filterPixels_ = filterSize_ * filterSizeY_;
channels_ = conf.channels();
imgSize_ = conf.img_size();
imgSizeY_ = conf.has_img_size_y() ? conf.img_size_y() : conf.img_size();
imgPixels_ = imgSize_ * imgSizeY_;
CHECK_EQ(conf.groups(), 1U);
filterChannels_ = conf.filter_channels();
outputX_ = conf.output_x();
outputY_ = conf.has_output_y() ? conf.output_y() : conf.output_x();
outputs_ = outputX_ * outputX_;
}
void ConvOperator::forward() {
size_t batchSize = ins_[0]->value->getHeight();
reshape(batchSize);
......@@ -264,7 +71,7 @@ void ConvOperator::forward() {
real *inputData = ins_[0]->value->getData() + inputOffset_ * batchId;
real *wgtData = ins_[1]->value->getData() + weightOffset_ * batchId;
real *outData = out_->value->getData() + outputOffset_ * batchId;
hl_convolution_forward(inputDesc_,
hl_convolution_forward(imageDesc_,
inputData,
outputDesc_,
outData,
......@@ -287,7 +94,7 @@ void ConvOperator::backward() {
if (ins_[1]->grad) {
real *inputData = ins_[0]->value->getData() + inputOffset_ * batchId;
real *weightGrad = ins_[1]->grad->getData() + weightOffset_ * batchId;
hl_convolution_backward_filter(inputDesc_,
hl_convolution_backward_filter(imageDesc_,
inputData,
outputDesc_,
outGrad,
......@@ -303,7 +110,7 @@ void ConvOperator::backward() {
if (NULL != preGrad) {
real *inputGrad = preGrad->getData() + inputOffset_ * batchId;
real *wgtData = ins_[1]->value->getData() + weightOffset_ * batchId;
hl_convolution_backward_data(inputDesc_,
hl_convolution_backward_data(imageDesc_,
inputGrad,
outputDesc_,
outGrad,
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "ConvBaseOperator.h"
#include "paddle/math/MathUtils.h"
#include "paddle/math/Matrix.h"
namespace paddle {
/**
* @brief ConvOperator takes two inputs to perform the convolution.
* The first input is the image, and the second input is the convolution kernel.
* The height of data for two inputs are the same. Each data of the first input
* is convolved with each data of the second input indepedently.
*
* The config file api is conv_operator.
*/
class ConvOperator : public ConvBaseOperator {
public:
ConvOperator(const OperatorConfig &config, bool useGpu)
: ConvBaseOperator(config, useGpu) {}
/**
* Free workspace in device and destroy cudnn tensor descriptor.
*/
virtual ~ConvOperator() {}
void forward() override;
void backward() override;
void reshape(int batchSize) override;
};
} // namespace paddle
......@@ -19,149 +19,32 @@ namespace paddle {
REGISTER_PROJECTION(conv, ConvProjection);
ThreadLocalD<std::vector<MemoryHandle *>> ConvProjection::convMem_;
ConvProjection::ConvProjection(const ProjectionConfig &config,
ParameterPtr parameter,
bool useGpu)
: Projection(config, parameter, useGpu) {
CHECK(useGpu); // only support GPU
getConvParams();
initCudnn();
size_t height = filterH_ * filterW_ * channels_ / groups_;
size_t width = numFilters_;
weight_.reset(new Weight(height, width, parameter));
weightOffset_ = height * width / groups_;
}
void ConvProjection::getConvParams() {
const ConvConfig &conf = config_.conv_conf();
paddingH_ = conf.padding_y();
paddingW_ = conf.padding();
strideH_ = conf.stride_y();
strideW_ = conf.stride();
filterH_ = conf.filter_size_y();
filterW_ = conf.filter_size();
configImgH_ = conf.has_img_size_y() ? conf.img_size_y() : conf.img_size();
configImgW_ = conf.img_size();
channels_ = conf.channels();
numFilters_ = config_.num_filters();
groups_ = conf.groups();
CHECK_EQ(channels_ % groups_, 0);
CHECK_EQ(numFilters_ % groups_, 0);
}
void ConvProjection::initCudnn() {
hl_create_filter_descriptor(&filterDesc_,
channels_ / groups_,
numFilters_ / groups_,
filterH_,
filterW_);
hl_create_tensor_descriptor(&inputDesc_);
hl_create_tensor_descriptor(&outputDesc_);
hl_create_convolution_descriptor(&convDesc_,
inputDesc_,
filterDesc_,
paddingH_,
paddingW_,
strideH_,
strideW_);
// initialize all to default algorithms
fwdAlgo_ = 0;
bwdFilterAlgo_ = 0;
bwdDataAlgo_ = 0;
fwdLimitBytes_ = 0;
bwdDataLimitBytes_ = 0;
bwdFilterLimitBytes_ = 0;
workSpaceInBytes_ = 0;
batchNum_ = 0;
isSelectAlgo_ = false;
}
void ConvProjection::reshapeTensorDesc(int batchSize) {
hl_tensor_reshape(inputDesc_,
batchSize,
channels_ / groups_,
imageH_,
imageW_,
channels_ * imageH_ * imageW_,
imageH_ * imageW_,
imageW_,
1);
hl_reset_convolution_descriptor(convDesc_,
inputDesc_,
filterDesc_,
paddingH_,
paddingW_,
strideH_,
strideW_);
// The stride between two consecutive images in ConvProjection may not be 1,
// for example, in the case of layer ConcatenateLayer2 with two
// ConvProjection, the stride is the output_size of layer ConcatenateLayer2.
// So the calculation of nStride is different from CudnnConvLayer.
// In fact, only "nStride = out_->value->getStride()" is ok.
size_t nStride = numFilters_ * outputH_ * outputW_;
if (out_->value->isContiguous()) {
CHECK_EQ(nStride, out_->value->getWidth());
} else {
nStride = out_->value->getStride();
}
hl_tensor_reshape(outputDesc_,
batchSize,
numFilters_ / groups_,
outputH_,
outputW_,
nStride,
outputH_ * outputW_,
outputW_,
1);
size_t ConvProjection::calOutputSize() {
imageH_ = in_->getFrameHeight();
imageW_ = in_->getFrameWidth();
if (imageH_ == 0) imageH_ = configImgH_;
if (imageW_ == 0) imageW_ = configImgW_;
outputH_ = outputSize(imageH_,
filterH_,
paddingH_,
strideH_,
/* caffeMode */ true);
outputW_ = outputSize(imageW_,
filterW_,
paddingW_,
strideW_,
/* caffeMode */ true);
const_cast<Argument *>(out_)->setFrameHeight(outputH_);
const_cast<Argument *>(out_)->setFrameWidth(outputW_);
inputOffset_ = (configChannels_ / groups_) * imageH_ * imageW_;
outputOffset_ = (configNumFilters_ / groups_) * outputH_ * outputW_;
return outputH_ * outputW_ * configNumFilters_;
}
void ConvProjection::reshape(int batchSize) {
size_t width = calOutputSize();
CHECK_EQ(width, out_->value->getWidth());
CHECK_EQ(static_cast<size_t>(channels_ * imageH_ * imageW_),
in_->value->getWidth())
<< "Wrong input size for convolution"
<< " channels=" << channels_ << " imageH=" << imageH_
<< " imageW=" << imageW_ << " inputSize=" << in_->value->getWidth();
isSelectAlgo_ = (batchSize == batchNum_);
batchNum_ = batchSize;
if (!isSelectAlgo_) {
reshapeTensorDesc(batchSize);
hl_conv_workspace(inputDesc_,
outputDesc_,
filterDesc_,
convDesc_,
&fwdAlgo_,
&fwdLimitBytes_,
&bwdDataAlgo_,
&bwdDataLimitBytes_,
&bwdFilterAlgo_,
&bwdFilterLimitBytes_);
size_t maxWorkSpace = 0;
maxWorkSpace = std::max(fwdLimitBytes_, bwdDataLimitBytes_);
maxWorkSpace = std::max(maxWorkSpace, bwdFilterLimitBytes_);
workSpaceInBytes_ = maxWorkSpace;
VLOG(3) << getName() << " Fwd / BwdData / BwdFilter algo: " << fwdAlgo_
<< " / " << bwdDataAlgo_ << " / " << bwdFilterAlgo_;
}
isSelectAlgo_ = true;
size_t ConvProjection::calInputSize() {
return static_cast<size_t>(configChannels_ * imageH_ * imageW_);
}
void ConvProjection::forward() {
......@@ -179,7 +62,7 @@ void ConvProjection::forward() {
real *inputData = in_->value->getData() + g * inputOffset_;
real *wgtData = weight_->getW()->getData() + g * weightOffset_;
real *outData = out_->value->getData() + g * outputOffset_;
hl_convolution_forward(inputDesc_,
hl_convolution_forward(imageDesc_,
inputData,
outputDesc_,
outData,
......@@ -205,7 +88,7 @@ void ConvProjection::backward(const UpdateCallback &callback) {
if (weight_->getWGrad()) {
real *inputData = in_->value->getData() + g * inputOffset_;
real *weightGrad = weight_->getWGrad()->getData() + g * weightOffset_;
hl_convolution_backward_filter(inputDesc_,
hl_convolution_backward_filter(imageDesc_,
inputData,
outputDesc_,
outGrad,
......@@ -221,7 +104,7 @@ void ConvProjection::backward(const UpdateCallback &callback) {
if (NULL != preGrad) {
real *inputGrad = preGrad->getData() + g * inputOffset_;
real *wgtData = weight_->getW()->getData() + g * weightOffset_;
hl_convolution_backward_data(inputDesc_,
hl_convolution_backward_data(imageDesc_,
inputGrad,
outputDesc_,
outGrad,
......@@ -237,26 +120,4 @@ void ConvProjection::backward(const UpdateCallback &callback) {
weight_->getParameterPtr()->incUpdate(callback);
}
void *ConvProjection::getSpaceBytes(size_t size) {
std::vector<MemoryHandle *> &convMem = *convMem_;
if (convMem.empty()) {
int numDevices = hl_get_device_count();
convMem.resize(numDevices);
}
int devId = hl_get_device();
MemoryHandle **localMem = &(convMem[devId]);
if (NULL == *localMem || size > (*localMem)->getAllocSize()) {
*localMem = new GpuMemoryHandle(size);
}
return (*localMem)->getBuf();
}
ConvProjection::~ConvProjection() {
hl_destroy_tensor_descriptor(inputDesc_);
hl_destroy_tensor_descriptor(outputDesc_);
hl_destroy_filter_descriptor(filterDesc_);
hl_destroy_convolution_descriptor(convDesc_);
}
} // namespace paddle
......@@ -14,7 +14,7 @@ limitations under the License. */
#pragma once
#include "Projection.h"
#include "ConvBaseProjection.h"
#include "paddle/math/MathUtils.h"
namespace paddle {
......@@ -22,109 +22,22 @@ namespace paddle {
/**
* @brief Convolution projection do the same calculation with CudnnConvLayer.
*/
class ConvProjection : public Projection {
class ConvProjection : public ConvBaseProjection {
public:
/**
* Constructor.
*/
ConvProjection(const ProjectionConfig& config,
ParameterPtr parameter,
bool useGpu);
bool useGpu)
: ConvBaseProjection(config, parameter, useGpu) {}
~ConvProjection();
~ConvProjection() {}
virtual void forward();
virtual void backward(const UpdateCallback& callback);
protected:
void getConvParams();
void initCudnn();
void reshapeTensorDesc(int batchSize);
void reshape(int batchSize);
size_t calOutputSize() {
imageH_ = in_->getFrameHeight();
imageW_ = in_->getFrameWidth();
if (imageH_ == 0) imageH_ = configImgH_;
if (imageW_ == 0) imageW_ = configImgW_;
outputH_ = outputSize(imageH_,
filterH_,
paddingH_,
strideH_,
/* caffeMode */ true);
outputW_ = outputSize(imageW_,
filterW_,
paddingW_,
strideW_,
/* caffeMode */ true);
const_cast<Argument*>(out_)->setFrameHeight(outputH_);
const_cast<Argument*>(out_)->setFrameWidth(outputW_);
inputOffset_ = (channels_ / groups_) * imageH_ * imageW_;
outputOffset_ = (numFilters_ / groups_) * outputH_ * outputW_;
return outputH_ * outputW_ * numFilters_;
}
static void* getSpaceBytes(size_t size);
/// imageH_ and imageW_ is calculated from the input layer.
int imageH_, imageW_;
/// configImgH_ and configImgW_ is obtained from config.
int configImgH_, configImgW_;
int outputH_, outputW_;
int channels_, numFilters_;
int paddingH_, paddingW_;
int strideH_, strideW_;
int filterH_, filterW_;
/// One group offset of input data.
int inputOffset_;
/// One group offset of output data.
int outputOffset_;
/// One group offset of weight.
int weightOffset_;
int groups_;
/// Cudnn tensor descriptor for input.
hl_tensor_descriptor inputDesc_;
/// Cudnn tensor descriptor for output.
hl_tensor_descriptor outputDesc_;
/// Cudnn tensor descriptor for filter.
hl_filter_descriptor filterDesc_;
/// Cudnn tensor descriptor for a convolution operation.
hl_convolution_descriptor convDesc_;
/// Record the algorithm for forward convolution, which is obtained by cudnn
/// api to search the best suited algorithm.
int fwdAlgo_;
/// Record the algorithm for computing convolution gradient with respect to
/// filter coefficients.
int bwdFilterAlgo_;
/// Record the algorithm for computing convolution gradient with respect to
/// the output.
int bwdDataAlgo_;
/// Amount of GPU memory needed as workspace to be able to execute a
/// forward convolution with the specified algo.
size_t fwdLimitBytes_;
/// Amount of GPU memory needed as workspace to be able to execute a
/// backwardFilter with the specified algo.
size_t bwdDataLimitBytes_;
/// Amount of GPU memory needed as workspace to be able to execute a
/// backwardData with the specified algo.
size_t bwdFilterLimitBytes_;
/// Size of total work space.
size_t workSpaceInBytes_;
/// Whether to call cuDNN api to choose conv algorithm.
bool isSelectAlgo_;
/// batchNum is used to record batch size. If the batch size is changed,
/// the selection algorithm will be called.
int batchNum_;
bool bias_;
std::unique_ptr<Weight> weight_;
static ThreadLocalD<std::vector<MemoryHandle*>> convMem_;
virtual size_t calOutputSize();
virtual size_t calInputSize();
};
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "ConvTransOperator.h"
#include "paddle/math/MathUtils.h"
#include "paddle/math/Matrix.h"
namespace paddle {
/**
* @brief ConvTransOperator takes two inputs to perform the convolution.
* The first input is the image, and the second input is the convolution kernel.
* The height of data for two inputs are the same. Each data of the first input
* is convolved with each data of the second input indepedently.
*
* The config file api is conv_operator.
*/
REGISTER_OPERATOR(convt, ConvTransOperator);
void ConvTransOperator::reshape(int batchSize) {
outputH_ = ins_[0]->getFrameHeight();
outputW_ = ins_[0]->getFrameWidth();
if (outputH_ == 0) outputH_ = outputY_;
if (outputW_ == 0) outputW_ = outputX_;
imageH_ = imageSize(outputH_, filterSizeY_, paddingY_, strideY_, caffeMode_);
imageW_ = imageSize(outputW_, filterSize_, padding_, stride_, caffeMode_);
/// Check that the imageSizes are consistent with config
CHECK_EQ(imageH_, imgSizeY_);
CHECK_EQ(imageW_, imgSize_);
out_->setFrameHeight(imageH_);
out_->setFrameWidth(imageW_);
reshapeImageDescriptors();
inputOffset_ = numFilters_ * outputH_ * outputW_;
outputOffset_ = channels_ * imageH_ * imageW_;
weightOffset_ = numFilters_ * channels_ * filterSize_ * filterSizeY_;
if (!isSelectAlgo_) {
allocConvWorkSpace();
}
isSelectAlgo_ = true;
}
void ConvTransOperator::forward() {
size_t batchSize = ins_[0]->value->getHeight();
reshape(batchSize);
CHECK_EQ(ins_[1]->value->getHeight(), batchSize);
checkFilterSize(ins_[1]->value);
Matrix::resizeOrCreate(
out_->value, batchSize, imageH_ * imageW_ * channels_, false, useGpu_);
{
AsyncGpuBlock block;
for (size_t batchId = 0; batchId < batchSize; ++batchId) {
real *inputData = ins_[0]->value->getData() + inputOffset_ * batchId;
real *wgtData = ins_[1]->value->getData() + weightOffset_ * batchId;
real *outData = out_->value->getData() + outputOffset_ * batchId;
hl_convolution_backward_data(imageDesc_,
outData,
outputDesc_,
inputData,
filterDesc_,
wgtData,
convDesc_,
workSpace_,
workSpaceInBytes_,
bwdDataAlgo_);
}
}
}
void ConvTransOperator::backward() {
size_t batchSize = ins_[0]->value->getHeight();
{
AsyncGpuBlock block;
for (size_t batchId = 0; batchId < batchSize; ++batchId) {
real *outGrad = out_->grad->getData() + outputOffset_ * batchId;
if (ins_[1]->grad) {
real *inputData = ins_[0]->value->getData() + inputOffset_ * batchId;
real *weightGrad = ins_[1]->grad->getData() + weightOffset_ * batchId;
hl_convolution_backward_filter(imageDesc_,
outGrad,
outputDesc_,
inputData,
filterDesc_,
weightGrad,
convDesc_,
workSpace_,
workSpaceInBytes_,
bwdFilterAlgo_);
}
MatrixPtr preGrad = ins_[0]->grad;
if (NULL != preGrad) {
real *inputGrad = preGrad->getData() + inputOffset_ * batchId;
real *wgtData = ins_[1]->value->getData() + weightOffset_ * batchId;
hl_convolution_forward(imageDesc_,
outGrad,
outputDesc_,
inputGrad,
filterDesc_,
wgtData,
convDesc_,
workSpace_,
workSpaceInBytes_,
fwdAlgo_);
}
}
}
}
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "ConvBaseOperator.h"
#include "paddle/math/MathUtils.h"
#include "paddle/math/Matrix.h"
namespace paddle {
/**
* @brief ConvTransOperator takes two inputs to perform the convolution.
* The first input is the image, and the second input is the convolution kernel.
* The height of data for two inputs are the same. Each data of the first input
* is convolved with each data of the second input indepedently.
*
* The config file api is conv_operator.
*/
class ConvTransOperator : public ConvBaseOperator {
public:
ConvTransOperator(const OperatorConfig &config, bool useGpu)
: ConvBaseOperator(config, useGpu) {}
/**
* Free workspace in device and destroy cudnn tensor descriptor.
*/
virtual ~ConvTransOperator() {}
void forward() override;
void backward() override;
void reshape(int batchSize) override;
};
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "ConvTransProjection.h"
#include "paddle/utils/Stat.h"
namespace paddle {
REGISTER_PROJECTION(convt, ConvTransProjection);
size_t ConvTransProjection::calOutputSize() {
outputH_ = in_->getFrameHeight();
outputW_ = in_->getFrameWidth();
if (outputH_ == 0) outputH_ = configOutH_;
if (outputW_ == 0) outputW_ = configOutW_;
imageH_ = imageSize(outputH_,
filterH_,
paddingH_,
strideH_,
/* caffeMode */ true);
imageW_ = imageSize(outputW_,
filterW_,
paddingW_,
strideW_,
/* caffeMode */ true);
const_cast<Argument *>(out_)->setFrameHeight(imageH_);
const_cast<Argument *>(out_)->setFrameWidth(imageW_);
inputOffset_ = (configChannels_ / groups_) * outputH_ * outputW_;
outputOffset_ = (configNumFilters_ / groups_) * imageH_ * imageW_;
return imageH_ * imageW_ * configNumFilters_;
}
size_t ConvTransProjection::calInputSize() {
return static_cast<size_t>(configChannels_ * outputH_ * outputW_);
}
void ConvTransProjection::forward() {
int batchSize = in_->value->getHeight();
reshape(batchSize);
void *workSpace = NULL;
if (workSpaceInBytes_ > 0) {
workSpace = getSpaceBytes(workSpaceInBytes_);
}
for (int g = 0; g < groups_; ++g) {
REGISTER_TIMER_INFO("CudnnConvTransFwTimer", getName().c_str());
real *inData = in_->value->getData() + g * inputOffset_;
real *wgtData = weight_->getW()->getData() + g * weightOffset_;
real *outData = out_->value->getData() + g * outputOffset_;
hl_convolution_backward_data(imageDesc_,
outData,
outputDesc_,
inData,
filterDesc_,
wgtData,
convDesc_,
workSpace,
bwdDataLimitBytes_,
bwdDataAlgo_);
}
}
void ConvTransProjection::backward(const UpdateCallback &callback) {
REGISTER_TIMER_INFO("CudnnConvTransBpTimer", getName().c_str());
void *workSpace = NULL;
if (workSpaceInBytes_ > 0) {
workSpace = getSpaceBytes(workSpaceInBytes_);
}
for (int g = 0; g < groups_; ++g) {
real *outGrad = out_->grad->getData() + g * outputOffset_;
if (weight_->getWGrad()) {
real *inData = in_->value->getData() + g * inputOffset_;
real *weightGrad = weight_->getWGrad()->getData() + g * weightOffset_;
hl_convolution_backward_filter(imageDesc_,
outGrad,
outputDesc_,
inData,
filterDesc_,
weightGrad,
convDesc_,
workSpace,
bwdFilterLimitBytes_,
bwdFilterAlgo_);
}
MatrixPtr preGrad = in_->grad;
if (NULL != preGrad) {
real *inGrad = preGrad->getData() + g * inputOffset_;
real *wgtData = weight_->getW()->getData() + g * weightOffset_;
hl_convolution_forward(imageDesc_,
outGrad,
outputDesc_,
inGrad,
filterDesc_,
wgtData,
convDesc_,
workSpace,
fwdLimitBytes_,
fwdAlgo_);
}
}
weight_->getParameterPtr()->incUpdate(callback);
}
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "ConvBaseProjection.h"
#include "paddle/math/MathUtils.h"
namespace paddle {
/**
* @brief Convolution projection do the same calculation with CudnnConvLayer.
*/
class ConvTransProjection : public ConvBaseProjection {
public:
/**
* Constructor.
*/
ConvTransProjection(const ProjectionConfig& config,
ParameterPtr parameter,
bool useGpu)
: ConvBaseProjection(config, parameter, useGpu) {}
~ConvTransProjection() {}
virtual void forward();
virtual void backward(const UpdateCallback& callback);
virtual size_t calOutputSize();
virtual size_t calInputSize();
};
} // namespace paddle
......@@ -12,16 +12,16 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "CudnnConvLayer.h"
#include "CudnnConvBaseLayer.h"
#include "paddle/utils/Logging.h"
#include "paddle/utils/Stat.h"
namespace paddle {
REGISTER_LAYER(cudnn_conv, CudnnConvBaseLayer);
REGISTER_LAYER(cudnn_convt, CudnnConvBaseLayer);
REGISTER_LAYER(cudnn_conv, CudnnConvLayer);
bool CudnnConvLayer::init(const LayerMap &layerMap,
const ParameterMap &parameterMap) {
bool CudnnConvBaseLayer::init(const LayerMap &layerMap,
const ParameterMap &parameterMap) {
if (!ConvBaseLayer::init(layerMap, parameterMap)) return false;
CHECK(useGpu_) << "CudnnConvLayer only support gpu";
......@@ -33,7 +33,11 @@ bool CudnnConvLayer::init(const LayerMap &layerMap,
CHECK(config_.shared_biases());
for (size_t i = 0; i < inputLayers_.size(); i++) {
ProjectionConfig *conf = new ProjectionConfig();
conf->set_type("conv");
if (isDeconv_) {
conf->set_type("convt");
} else {
conf->set_type("conv");
}
conf->set_num_filters(numFilters_);
ConvConfig *convConf = conf->mutable_conv_conf();
*convConf = *(config_.mutable_inputs(i)->mutable_conv_conf());
......@@ -47,14 +51,13 @@ bool CudnnConvLayer::init(const LayerMap &layerMap,
if (biases_.get() && sharedBiases_) {
hl_create_tensor_descriptor(&biasDesc_);
hl_create_tensor_descriptor(&outputDesc_);
hl_tensor_reshape(biasDesc_, 1, numFilters_ / groups_[0], 1, 1);
biasOffset_ = numFilters_ / groups_[0];
hl_tensor_reshape(biasDesc_, 1, numFilters_, 1, 1);
}
return true;
}
void CudnnConvLayer::forward(PassType passType) {
void CudnnConvBaseLayer::forward(PassType passType) {
Layer::forward(passType);
int batchSize = getInput(0).getBatchSize();
......@@ -67,37 +70,41 @@ void CudnnConvLayer::forward(PassType passType) {
if (biases_) {
REGISTER_TIMER_INFO("CudnnConvBiasTimer", getName().c_str());
int batchSize = inputLayers_[0]->getOutputValue()->getHeight();
int outH, outW;
if (isDeconv_) {
outH = imgSizeH_[0];
outW = imgSizeW_[0];
} else {
outH = outputH_[0];
outW = outputW_[0];
}
hl_tensor_reshape(outputDesc_,
batchSize,
numFilters_ / groups_[0],
outputH_[0],
outputW_[0],
numFilters_ * outputH_[0] * outputW_[0],
outputH_[0] * outputW_[0],
outputW_[0],
numFilters_,
outH,
outW,
numFilters_ * outH * outW,
outH * outW,
outW,
1);
outputOffset_ = getOutputValue()->getWidth() / groups_[0];
for (int g = 0; g < groups_[0]; ++g) {
real *biasData = biases_->getW()->getData() + biasOffset_ * g;
real *outData = getOutputValue()->getData() + outputOffset_ * g;
hl_convolution_forward_add_bias(
biasDesc_, biasData, outputDesc_, outData);
}
real *outData = getOutputValue()->getData();
real *biasData = biases_->getW()->getData();
hl_convolution_forward_add_bias(biasDesc_, biasData, outputDesc_, outData);
}
forwardActivation();
}
void CudnnConvLayer::backward(const UpdateCallback &callback) {
void CudnnConvBaseLayer::backward(const UpdateCallback &callback) {
backwardActivation();
if (biases_ && biases_->getWGrad()) {
REGISTER_TIMER_INFO("CudnnConvBpBiasTimer", getName().c_str());
for (int g = 0; g < groups_[0]; ++g) {
real *biasGrad = biases_->getWGrad()->getData() + biasOffset_ * g;
real *outGrad = getOutputGrad()->getData() + outputOffset_ * g;
hl_convolution_backward_bias(biasDesc_, biasGrad, outputDesc_, outGrad);
}
real *biasGrad = biases_->getWGrad()->getData();
real *outGrad = getOutputGrad()->getData();
hl_convolution_backward_bias(biasDesc_, biasGrad, outputDesc_, outGrad);
biases_->getParameterPtr()->incUpdate(callback);
}
......@@ -106,7 +113,7 @@ void CudnnConvLayer::backward(const UpdateCallback &callback) {
}
}
CudnnConvLayer::~CudnnConvLayer() {
CudnnConvBaseLayer::~CudnnConvBaseLayer() {
if (biases_) {
hl_destroy_tensor_descriptor(biasDesc_);
hl_destroy_tensor_descriptor(outputDesc_);
......
......@@ -30,27 +30,24 @@ namespace paddle {
*
* The config file api is img_conv_layer.
*/
class CudnnConvLayer : public ConvBaseLayer {
class CudnnConvBaseLayer : public ConvBaseLayer {
protected:
std::vector<std::unique_ptr<ProjectionConfig>> projConf_;
std::vector<std::unique_ptr<Projection>> projections_;
hl_tensor_descriptor biasDesc_;
hl_tensor_descriptor outputDesc_;
int biasOffset_;
int outputOffset_;
public:
explicit CudnnConvLayer(const LayerConfig& config) : ConvBaseLayer(config) {}
explicit CudnnConvBaseLayer(const LayerConfig& config)
: ConvBaseLayer(config) {}
~CudnnConvLayer();
~CudnnConvBaseLayer();
void forward(PassType passType) override;
void backward(const UpdateCallback& callback) override;
bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
void forward(PassType passType) override;
void backward(const UpdateCallback& callback) override;
void addBiases();
void bpropBiases();
};
} // namespace paddle
......@@ -34,8 +34,7 @@ DECLARE_double(checkgrad_eps);
DECLARE_bool(thread_local_rand_use_global_seed);
DECLARE_bool(prev_batch_state);
// Do one forward pass of convTrans layer and check to see if its output
// matches the given result
// Do one forward pass of ConvLayer using either exconv or cudnn_conv
MatrixPtr doOneConvTest(size_t imgSize,
size_t output_x,
size_t stride,
......@@ -46,22 +45,35 @@ MatrixPtr doOneConvTest(size_t imgSize,
size_t groups,
MatrixPtr& inputData,
real* param,
bool useGpu) {
bool useGpu,
bool isDeconv = false) {
TestConfig config;
config.biasSize = numfilters;
string layerType;
if (useGpu) {
config.layerConfig.set_type("cudnn_conv");
layerType = (isDeconv) ? "cudnn_convt" : "cudnn_conv";
} else {
config.layerConfig.set_type("exconv");
layerType = (isDeconv) ? "exconvt" : "exconv";
}
config.layerConfig.set_type(layerType);
config.layerConfig.set_num_filters(numfilters);
config.layerConfig.set_partial_sum(1);
config.layerConfig.set_shared_biases(true);
size_t weightSize = channel * filter_size * filter_size *
config.layerConfig.num_filters() / groups;
config.inputDefs.push_back(
{INPUT_DATA, "layer_0", imgSize * imgSize * channel, weightSize});
if (isDeconv) {
config.inputDefs.push_back(
{INPUT_DATA, "layer_0", output_x * output_x * channel, weightSize});
config.layerConfig.set_size(imgSize * imgSize *
config.layerConfig.num_filters());
} else {
config.inputDefs.push_back(
{INPUT_DATA, "layer_0", imgSize * imgSize * channel, weightSize});
config.layerConfig.set_size(output_x * output_x *
config.layerConfig.num_filters());
}
LayerInputConfig* input = config.layerConfig.add_inputs();
ConvConfig* conv = input->mutable_conv_conf();
conv->set_filter_size(filter_size);
......@@ -72,12 +84,15 @@ MatrixPtr doOneConvTest(size_t imgSize,
conv->set_stride(stride);
conv->set_stride_y(stride);
conv->set_groups(groups);
conv->set_filter_channels(channel / groups);
conv->set_img_size(imgSize);
conv->set_output_x(output_x);
config.layerConfig.set_size(conv->output_x() * conv->output_x() *
config.layerConfig.num_filters());
if (isDeconv) {
conv->set_filter_channels(numfilters / groups);
} else {
conv->set_filter_channels(channel / groups);
}
config.layerConfig.set_name("conv");
std::vector<DataLayerPtr> dataLayers;
......@@ -105,6 +120,8 @@ MatrixPtr doOneConvTest(size_t imgSize,
TEST(Layer, convParaUnified) {
#ifndef PADDLE_ONLY_CPU
MatrixPtr input, resultCpu, resultGpu;
/// TEST1 for conv ///
input = Matrix::create(1, 4 * 4, false, false);
real inputData[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
real param[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 9, 8, 7, 6, 5, 4, 3, 2, 1};
......@@ -121,7 +138,7 @@ TEST(Layer, convParaUnified) {
/*groups*/ 1,
input,
param,
false);
/*useGpu*/ false);
resultGpu = doOneConvTest(/* imgSize */ 4,
/* output_x */ 2,
......@@ -133,9 +150,42 @@ TEST(Layer, convParaUnified) {
/*groups*/ 1,
input,
param,
true);
/*useGpu*/ true);
checkMatrixEqual(resultCpu, resultGpu);
/// TEST1 for deconv ///
input = Matrix::create(1, 2 * 2, false, false);
real inputDataT[] = {1, 2, 3, 4};
input->setData(inputDataT);
resultCpu = doOneConvTest(/* imgSize */ 4,
/* output_x */ 2,
/* stride */ 1,
/* padding */ 0,
/* filter_size */ 3,
/*channel*/ 1,
/*numfilters*/ 2,
/*groups*/ 1,
input,
param,
/*useGpu*/ false,
/*isDeconv*/ true);
resultGpu = doOneConvTest(/* imgSize */ 4,
/* output_x */ 2,
/* stride */ 1,
/* padding */ 0,
/* filter_size */ 3,
/*channel*/ 1,
/*numfilters*/ 2,
/*groups*/ 1,
input,
param,
/*useGpu*/ true,
/*isDeconv*/ true);
checkMatrixEqual(resultCpu, resultGpu);
/// TEST2 for conv ///
input = Matrix::create(1, 3 * 3 * 2, false, false);
real inputData2[] = {
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18};
......@@ -153,7 +203,7 @@ TEST(Layer, convParaUnified) {
/*groups*/ 1,
input,
param2,
false);
/*useGpu*/ false);
resultGpu = doOneConvTest(/* imgSize */ 3,
/* output_x */ 2,
......@@ -165,9 +215,10 @@ TEST(Layer, convParaUnified) {
/*groups*/ 1,
input,
param2,
true);
/*useGpu*/ true);
checkMatrixEqual(resultCpu, resultGpu);
/// TEST3 for conv ///
real param3[] = {1, 2, 3, 4, 4, 3, 2, 1};
resultCpu = doOneConvTest(/* imgSize */ 3,
......@@ -180,7 +231,66 @@ TEST(Layer, convParaUnified) {
/*groups*/ 2,
input,
param3,
false);
/*useGpu*/ false);
resultGpu = doOneConvTest(/* imgSize */ 3,
/* output_x */ 2,
/* stride */ 1,
/* padding */ 0,
/* filter_size */ 2,
/*channel*/ 2,
/*numfilters*/ 2,
/*groups*/ 2,
input,
param3,
/*useGpu*/ true);
checkMatrixEqual(resultCpu, resultGpu);
/// TEST2 for deconv ///
input = Matrix::create(1, 2 * 2 * 2, false, false);
real inputData2T[] = {1, 2, 3, 4, 5, 6, 7, 8};
input->setData(inputData2T);
resultCpu = doOneConvTest(/* imgSize */ 3,
/* output_x */ 2,
/* stride */ 1,
/* padding */ 0,
/* filter_size */ 2,
/*channel*/ 2,
/*numfilters*/ 2,
/*groups*/ 1,
input,
param2,
/*useGpu*/ false,
/*isDeconv*/ true);
resultGpu = doOneConvTest(/* imgSize */ 3,
/* output_x */ 2,
/* stride */ 1,
/* padding */ 0,
/* filter_size */ 2,
/*channel*/ 2,
/*numfilters*/ 2,
/*groups*/ 1,
input,
param2,
/*useGpu*/ true,
/*isDeconv*/ true);
checkMatrixEqual(resultCpu, resultGpu);
/// TEST3 for deconv ///
resultCpu = doOneConvTest(/* imgSize */ 3,
/* output_x */ 2,
/* stride */ 1,
/* padding */ 0,
/* filter_size */ 2,
/*channel*/ 2,
/*numfilters*/ 2,
/*groups*/ 2,
input,
param3,
/*useGpu*/ false,
/*isDeconv*/ true);
resultGpu = doOneConvTest(/* imgSize */ 3,
/* output_x */ 2,
......@@ -192,7 +302,8 @@ TEST(Layer, convParaUnified) {
/*groups*/ 2,
input,
param3,
true);
/*useGpu*/ true,
/*isDeconv*/ true);
checkMatrixEqual(resultCpu, resultGpu);
#endif
}
......
......@@ -166,15 +166,19 @@ TEST(Projection, scaling) {
}
}
void testProjectionConv(size_t groups) {
void testProjectionConv(size_t groups, bool isDeconv) {
const int NUM_FILTERS = 18;
const int FILTER_SIZE = 2;
const int FILTER_SIZE_Y = 3;
const int FILTER_SIZE_Y = 4;
const int CHANNELS = 3;
const int IMAGE_SIZE = 16;
ProjectionConfig conf;
conf.set_type("conv");
if (isDeconv) {
conf.set_type("convt");
} else {
conf.set_type("conv");
}
conf.set_num_filters(NUM_FILTERS);
ConvConfig* conv = conf.mutable_conv_conf();
......@@ -186,7 +190,11 @@ void testProjectionConv(size_t groups) {
conv->set_stride(2);
conv->set_stride_y(2);
conv->set_groups(groups);
conv->set_filter_channels(conv->channels() / conv->groups());
if (isDeconv) {
conv->set_filter_channels(NUM_FILTERS / conv->groups());
} else {
conv->set_filter_channels(conv->channels() / conv->groups());
}
conv->set_img_size(IMAGE_SIZE);
int output_x = outputSize(conv->img_size(),
conv->filter_size(),
......@@ -199,8 +207,14 @@ void testProjectionConv(size_t groups) {
conv->stride_y(),
/* caffeMode */ true);
conv->set_output_x(output_x);
conf.set_input_size(IMAGE_SIZE * IMAGE_SIZE * CHANNELS);
conf.set_output_size(output_x * output_y * NUM_FILTERS);
conv->set_output_y(output_y);
if (isDeconv) {
conf.set_input_size(output_x * output_y * CHANNELS);
conf.set_output_size(IMAGE_SIZE * IMAGE_SIZE * NUM_FILTERS);
} else {
conf.set_input_size(IMAGE_SIZE * IMAGE_SIZE * CHANNELS);
conf.set_output_size(output_x * output_y * NUM_FILTERS);
}
testProjectionGrad(conf,
INPUT_DATA,
......@@ -215,8 +229,12 @@ void testProjectionConv(size_t groups) {
#ifndef PADDLE_ONLY_CPU
TEST(Projection, conv) {
testProjectionConv(1);
testProjectionConv(3);
/// test ConvProjection
testProjectionConv(1, false);
testProjectionConv(3, false);
/// test ConvTransProjection
testProjectionConv(1, true);
testProjectionConv(3, true);
}
#endif
......@@ -385,11 +403,11 @@ void testConvTransLayer(const string& type, bool trans, bool useGpu) {
config.layerConfig.set_partial_sum(1);
config.layerConfig.set_shared_biases(true);
config.inputDefs.push_back({INPUT_DATA, "layer_0", 1024, 288});
config.inputDefs.push_back({INPUT_DATA, "layer_0", 1024, 384});
LayerInputConfig* input = config.layerConfig.add_inputs();
ConvConfig* conv = input->mutable_conv_conf();
conv->set_filter_size(2);
conv->set_filter_size_y(3);
conv->set_filter_size_y(4);
conv->set_channels(16);
conv->set_padding(0);
conv->set_padding_y(1);
......@@ -416,6 +434,9 @@ TEST(Layer, convTransLayer) {
for (auto useGpu : {false, true}) {
testConvTransLayer("exconvt", /* trans= */ false, /* useGpu= */ useGpu);
}
#ifndef PADDLE_ONLY_CPU
testConvTransLayer("cudnn_convt", /* trans= */ false, /* useGpu= */ true);
#endif
}
TEST(Layer, blockExpandLayer) {
......@@ -1482,16 +1503,20 @@ TEST(Layer, BatchNormalizationLayer) {
#endif
}
TEST(Operator, conv) {
void testConvOperator(bool isDeconv) {
TestConfig config;
const int NUM_FILTERS = 16;
const int FILTER_SIZE = 2;
const int FILTER_SIZE_Y = 3;
const int CHANNELS = 3;
const int IMAGE_SIZE = 16;
const int IMAGE_SIZE_Y = 8;
const int IMAGE_SIZE_Y = 9;
OperatorConfig& operatorConf = *config.layerConfig.add_operator_confs();
operatorConf.set_type("conv");
if (isDeconv) {
operatorConf.set_type("convt");
} else {
operatorConf.set_type("conv");
}
ConvConfig* conv = operatorConf.mutable_conv_conf();
operatorConf.set_num_filters(NUM_FILTERS);
conv->set_filter_size(FILTER_SIZE);
......@@ -1502,7 +1527,6 @@ TEST(Operator, conv) {
conv->set_stride(2);
conv->set_stride_y(2);
conv->set_groups(1);
conv->set_filter_channels(conv->channels() / conv->groups());
conv->set_img_size(IMAGE_SIZE);
conv->set_img_size_y(IMAGE_SIZE_Y);
conv->set_output_x(outputSize(conv->img_size(),
......@@ -1515,11 +1539,22 @@ TEST(Operator, conv) {
conv->padding_y(),
conv->stride_y(),
/* caffeMode */ true));
config.layerConfig.set_size(conv->output_x() * conv->output_y() *
NUM_FILTERS);
config.inputDefs.push_back(
{INPUT_DATA, "layer_0", IMAGE_SIZE * IMAGE_SIZE_Y * CHANNELS, 0});
if (isDeconv) {
conv->set_filter_channels(NUM_FILTERS / conv->groups());
config.inputDefs.push_back({INPUT_DATA,
"layer_0",
conv->output_x() * conv->output_y() * CHANNELS,
0});
config.layerConfig.set_size(IMAGE_SIZE * IMAGE_SIZE_Y * NUM_FILTERS);
} else {
conv->set_filter_channels(conv->channels() / conv->groups());
config.inputDefs.push_back(
{INPUT_DATA, "layer_0", IMAGE_SIZE * IMAGE_SIZE_Y * CHANNELS, 0});
config.layerConfig.set_size(conv->output_x() * conv->output_y() *
NUM_FILTERS);
}
config.inputDefs.push_back(
{INPUT_DATA,
"layer_1",
......@@ -1531,6 +1566,11 @@ TEST(Operator, conv) {
testOperatorGrad(config, operatorConf, 100, /*useGpu*/ true, false);
}
TEST(Operator, conv) {
testConvOperator(/*isDeconv*/ true);
testConvOperator(/*isDeconv*/ false);
}
TEST(Layer, FeatureMapExpandLayer) {
TestConfig config;
config.layerConfig.set_type("featmap_expand");
......
......@@ -686,25 +686,17 @@ class ContextProjection(Projection):
@config_class
class ConvProjection(Projection):
type = 'conv'
class ConvBaseProjection(Projection):
def __init__(self,
input_layer_name,
num_filters=None,
conv_conf=None,
**xargs):
super(ConvProjection, self).__init__(input_layer_name, **xargs)
super(ConvBaseProjection, self).__init__(input_layer_name, **xargs)
if num_filters is not None:
self.proj_conf.num_filters = num_filters
parse_conv(conv_conf, input_layer_name, self.proj_conf.conv_conf,
num_filters)
self.proj_conf.output_size = self.proj_conf.conv_conf.output_x * \
self.proj_conf.conv_conf.output_y * \
num_filters
def calc_output_size(self, input_layer_config):
return self.proj_conf.output_size
......@@ -723,6 +715,46 @@ class ConvProjection(Projection):
return None
@config_class
class ConvProjection(ConvBaseProjection):
type = 'conv'
def __init__(self,
input_layer_name,
num_filters=None,
conv_conf=None,
**xargs):
super(ConvProjection, self).__init__(input_layer_name, **xargs)
parse_conv(conv_conf, self.input_layer_name, self.proj_conf.conv_conf,
num_filters)
self.proj_conf.output_size = self.proj_conf.conv_conf.output_x * \
self.proj_conf.conv_conf.output_y * \
num_filters
@config_class
class ConvTransProjection(ConvBaseProjection):
type = 'convt'
def __init__(self,
input_layer_name,
num_filters=None,
conv_conf=None,
**xargs):
super(ConvTransProjection, self).__init__(input_layer_name, **xargs)
parse_conv(
conv_conf,
self.input_layer_name,
self.proj_conf.conv_conf,
num_filters,
trans=True)
self.proj_conf.output_size = self.proj_conf.conv_conf.img_size_y * \
self.proj_conf.conv_conf.img_size * \
num_filters
# Define a operator for mixed layer
@config_class
class Operator(Cfg):
......@@ -789,6 +821,36 @@ class ConvOperator(Operator):
return self.operator_conf.output_size
@config_class
class ConvTransOperator(Operator):
type = 'convt'
def __init__(self,
input_layer_names,
num_filters=None,
conv_conf=None,
**xargs):
super(ConvTransOperator, self).__init__(input_layer_names, **xargs)
if num_filters is not None:
self.operator_conf.num_filters = num_filters
parse_conv(
conv_conf,
MakeLayerNameInSubmodel(input_layer_names[0]),
self.operator_conf.conv_conf,
num_filters,
trans=True)
self.operator_conf.output_size = \
self.operator_conf.conv_conf.img_size * \
self.operator_conf.conv_conf.img_size_y * \
num_filters
config_assert(len(input_layer_names) == 2, "Conv is binary operator")
def calc_output_size(self, input_sizes):
return self.operator_conf.output_size
# please refer to the comments in proto/ModelConfig.proto
@config_class
class Conv(Cfg):
......@@ -1772,8 +1834,17 @@ class ConvTransLayerBase(LayerBase):
use_gpu = int(g_command_config_args.get("use_gpu", 0))
parallel_nn = int(g_command_config_args.get("parallel_nn", 0))
# cudnn_convt has not been implemented so use exconvt only
self.layer_type = "exconvt"
# Automatically select cudnn_type for GPU and exconvt for CPU
# if set type=exconvt, but still reserve the way user specify
# exconvt or cudnn_convt manually.
if self.layer_type == "cudnn_convt":
config_assert(use_gpu, "cudnn_convt only support GPU")
if (use_gpu == 1 and self.layer_type != "exconvt" and
(parallel_nn == 0 or self.config.device > -1)):
self.layer_type = "cudnn_convt"
else:
self.layer_type = "exconvt"
# need to specify layer in config
self.config.type = self.layer_type
......@@ -1790,10 +1861,9 @@ class ConvTransLayerBase(LayerBase):
trans=True)
conv_conf = self.config.inputs[input_index].conv_conf
psize = self.calc_parameter_size(conv_conf)
print("output size for %s is %d " % (name, conv_conf.output_x))
self.create_input_parameter(input_index, psize)
self.set_layer_size(
(conv_conf.img_size**2) * self.config.num_filters)
self.set_cnn_layer(name, conv_conf.img_size_y, conv_conf.img_size,
self.config.num_filters)
psize = self.config.size
if shared_biases:
......@@ -1810,6 +1880,11 @@ class ConvTransLayer(ConvTransLayerBase):
layer_type = 'exconvt'
@config_layer('cudnn_convt')
class ConvTransLayer(ConvTransLayerBase):
layer_type = 'cudnn_convt'
@config_layer('norm')
class NormLayer(LayerBase):
def __init__(self, name, inputs, **xargs):
......
......@@ -712,8 +712,9 @@ class MixedLayerType(LayerOutput):
assert len(self.inputs) == 0
return self
def __exit__(self, *args, **kwargs):
del args, kwargs # unused parameter to suppress warning
def __exit__(self, exc_type, exc_value, tb):
if exc_value is not None:
raise exc_value
assert len(self.inputs) != 0
ml = MixedLayer(
name=self.name,
......@@ -2044,8 +2045,9 @@ def img_conv_layer(input,
:param trans: true if it is a convTransLayer, false if it is a convLayer
:type trans: bool
:param layer_type: specify the layer_type, default is None. If trans=True,
layer_type has to be "exconvt", otherwise layer_type
has to be either "exconv" or "cudnn_conv"
layer_type has to be "exconvt" or "cudnn_convt",
otherwise layer_type has to be either "exconv" or
"cudnn_conv"
:type layer_type: String
:return: LayerOutput object.
:rtype: LayerOutput
......@@ -2085,7 +2087,7 @@ def img_conv_layer(input,
if layer_type:
if trans:
assert layer_type in ["exconvt"]
assert layer_type in ["exconvt", "cudnn_convt"]
else:
assert layer_type in ["exconv", "cudnn_conv"]
lt = layer_type
......@@ -3715,7 +3717,8 @@ def conv_operator(img,
padding=0,
filter_size_y=None,
stride_y=None,
padding_y=None):
padding_y=None,
trans=False):
"""
Different from img_conv_layer, conv_op is an Operator, which can be used
in mixed_layer. And conv_op takes two inputs to perform convolution.
......@@ -3771,7 +3774,9 @@ def conv_operator(img,
if filter.size is not None:
filter.size = filter_size * filter_size_y * num_filters * num_channels
op = ConvOperator(
opCls = ConvTransOperator if trans else ConvOperator
op = opCls(
input_layer_names=[img.name, filter.name],
num_filters=num_filters,
conv_conf=Conv(
......@@ -3783,6 +3788,7 @@ def conv_operator(img,
padding_y=padding_y,
stride_y=stride_y,
groups=1))
op.origin = [img, filter]
return op
......@@ -3798,7 +3804,8 @@ def conv_projection(input,
stride_y=None,
padding_y=None,
groups=1,
param_attr=None):
param_attr=None,
trans=False):
"""
Different from img_conv_layer and conv_op, conv_projection is an Projection,
which can be used in mixed_layer and conat_layer. It use cudnn to implement
......@@ -3837,6 +3844,8 @@ def conv_projection(input,
:type groups: int
:param param_attr: Convolution param attribute. None means default attribute
:type param_attr: ParameterAttribute
:param trans: whether it is convTrans or conv
:type trans: boolean
:return: A DotMulProjection Object.
:rtype: DotMulProjection
"""
......@@ -3873,7 +3882,9 @@ def conv_projection(input,
param_attr.attr["initial_strategy"] = 0
param_attr.attr["initial_smart"] = False
proj = ConvProjection(
projCls = ConvTransProjection if trans else ConvProjection
proj = projCls(
input_layer_name=input.name,
num_filters=num_filters,
conv_conf=Conv(
......
......@@ -34,11 +34,31 @@ flt = data_layer(name='filter', size=3 * 3 * 1 * 64)
with mixed_layer() as m7:
m7 += conv_operator(
img=img, filter=flt, num_filters=64, num_channels=1, filter_size=3)
m7 += conv_projection(img, filter_size=3, num_filters=64, num_channels=1)
with mixed_layer() as m8:
m8 += conv_operator(
img=img,
filter=flt,
num_filters=64,
num_channels=1,
filter_size=3,
stride=2,
padding=1,
trans=True)
m8 += conv_projection(
img,
filter_size=3,
num_filters=64,
num_channels=1,
stride=2,
padding=1,
trans=True)
end = mixed_layer(
input=[
full_matrix_projection(input=m5),
trans_full_matrix_projection(input=m6), full_matrix_projection(input=m7)
trans_full_matrix_projection(input=m6),
full_matrix_projection(input=m7), full_matrix_projection(input=m8)
],
size=100,
layer_attr=ExtraAttr(
......
......@@ -33,6 +33,8 @@ layers {
bias_parameter_name: "___conv_0__.wbias"
num_filters: 64
shared_biases: true
height: 256
width: 256
}
layers {
name: "__batch_norm_0__"
......@@ -58,6 +60,8 @@ layers {
}
bias_parameter_name: "___batch_norm_0__.wbias"
moving_average_fraction: 0.9
height: 256
width: 256
}
layers {
name: "__crmnorm_0__"
......
......@@ -154,13 +154,38 @@ layers {
inputs {
input_layer_name: "img"
}
inputs {
input_layer_name: "img"
proj_conf {
type: "conv"
name: "___mixed_6__.w1"
input_size: 1024
output_size: 57600
conv_conf {
filter_size: 3
channels: 1
stride: 1
padding: 0
groups: 1
filter_channels: 1
output_x: 30
img_size: 32
caffe_mode: true
filter_size_y: 3
padding_y: 0
stride_y: 1
output_y: 30
img_size_y: 32
}
}
}
inputs {
input_layer_name: "filter"
}
operator_confs {
type: "conv"
input_indices: 0
input_indices: 1
input_indices: 2
input_sizes: 1024
input_sizes: 576
output_size: 57600
......@@ -186,38 +211,110 @@ layers {
layers {
name: "__mixed_7__"
type: "mixed"
size: 254016
active_type: ""
inputs {
input_layer_name: "img"
}
inputs {
input_layer_name: "img"
proj_conf {
type: "convt"
name: "___mixed_7__.w1"
input_size: 1024
output_size: 254016
conv_conf {
filter_size: 3
channels: 1
stride: 2
padding: 1
groups: 1
filter_channels: 64
output_x: 32
img_size: 63
caffe_mode: true
filter_size_y: 3
padding_y: 1
stride_y: 2
output_y: 32
img_size_y: 63
}
}
}
inputs {
input_layer_name: "filter"
}
operator_confs {
type: "convt"
input_indices: 0
input_indices: 2
input_sizes: 1024
input_sizes: 576
output_size: 254016
conv_conf {
filter_size: 3
channels: 1
stride: 2
padding: 1
groups: 1
filter_channels: 64
output_x: 32
img_size: 63
caffe_mode: true
filter_size_y: 3
padding_y: 1
stride_y: 2
output_y: 32
img_size_y: 63
}
num_filters: 64
}
}
layers {
name: "__mixed_8__"
type: "mixed"
size: 100
active_type: ""
inputs {
input_layer_name: "__mixed_4__"
input_parameter_name: "___mixed_7__.w0"
input_parameter_name: "___mixed_8__.w0"
proj_conf {
type: "fc"
name: "___mixed_7__.w0"
name: "___mixed_8__.w0"
input_size: 300
output_size: 100
}
}
inputs {
input_layer_name: "__mixed_5__"
input_parameter_name: "___mixed_7__.w1"
input_parameter_name: "___mixed_8__.w1"
proj_conf {
type: "trans_fc"
name: "___mixed_7__.w1"
name: "___mixed_8__.w1"
input_size: 100
output_size: 100
}
}
inputs {
input_layer_name: "__mixed_6__"
input_parameter_name: "___mixed_7__.w2"
input_parameter_name: "___mixed_8__.w2"
proj_conf {
type: "fc"
name: "___mixed_7__.w2"
name: "___mixed_8__.w2"
input_size: 57600
output_size: 100
}
}
inputs {
input_layer_name: "__mixed_7__"
input_parameter_name: "___mixed_8__.w3"
proj_conf {
type: "fc"
name: "___mixed_8__.w3"
input_size: 254016
output_size: 100
}
}
drop_rate: 0.5
}
parameters {
......@@ -281,7 +378,7 @@ parameters {
initial_smart: true
}
parameters {
name: "___mixed_7__.w0"
name: "___mixed_8__.w0"
size: 30000
initial_mean: 0.0
initial_std: 0.057735026919
......@@ -291,7 +388,7 @@ parameters {
initial_smart: true
}
parameters {
name: "___mixed_7__.w1"
name: "___mixed_8__.w1"
size: 10000
initial_mean: 0.0
initial_std: 0.1
......@@ -301,7 +398,7 @@ parameters {
initial_smart: true
}
parameters {
name: "___mixed_7__.w2"
name: "___mixed_8__.w2"
size: 5760000
initial_mean: 0.0
initial_std: 0.00416666666667
......@@ -310,10 +407,20 @@ parameters {
initial_strategy: 0
initial_smart: true
}
parameters {
name: "___mixed_8__.w3"
size: 25401600
initial_mean: 0.0
initial_std: 0.00198412698413
dims: 254016
dims: 100
initial_strategy: 0
initial_smart: true
}
input_layer_names: "test"
input_layer_names: "img"
input_layer_names: "filter"
output_layer_names: "__mixed_7__"
output_layer_names: "__mixed_8__"
sub_models {
name: "root"
layer_names: "test"
......@@ -328,10 +435,11 @@ sub_models {
layer_names: "filter"
layer_names: "__mixed_6__"
layer_names: "__mixed_7__"
layer_names: "__mixed_8__"
input_layer_names: "test"
input_layer_names: "img"
input_layer_names: "filter"
output_layer_names: "__mixed_7__"
output_layer_names: "__mixed_8__"
is_recurrent_layer_group: false
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册