提交 48a6168c 编写于 作者: W wangyang59

following comments from qingqing01

上级 b16c0a84
......@@ -49,7 +49,22 @@ ConvBaseOperator::ConvBaseOperator(const OperatorConfig &config, bool useGpu)
isSelectAlgo_ = false;
}
void ConvBaseOperator::allocConvWorkSpace(size_t maxWorkSpace) {
void ConvBaseOperator::allocConvWorkSpace() {
hl_conv_workspace(imageDesc_,
outputDesc_,
filterDesc_,
convDesc_,
&fwdAlgo_,
&fwdLimitBytes_,
&bwdDataAlgo_,
&bwdDataLimitBytes_,
&bwdFilterAlgo_,
&bwdFilterLimitBytes_);
size_t maxWorkSpace = 0;
maxWorkSpace = std::max(fwdLimitBytes_, bwdDataLimitBytes_);
maxWorkSpace = std::max(maxWorkSpace, bwdFilterLimitBytes_);
if (maxWorkSpace > workSpaceInBytes_) {
if (workSpaceInBytes_ != 0) {
hl_free_mem_device(workSpace_);
......@@ -60,59 +75,6 @@ void ConvBaseOperator::allocConvWorkSpace(size_t maxWorkSpace) {
}
}
void ConvBaseOperator::reshape(int batchSize) {
if (isDeconv_) {
outputH_ = ins_[0]->getFrameHeight();
outputW_ = ins_[0]->getFrameWidth();
if (outputH_ == 0) outputH_ = outputY_;
if (outputW_ == 0) outputW_ = outputX_;
imageH_ =
imageSize(outputH_, filterSizeY_, paddingY_, strideY_, caffeMode_);
imageW_ = imageSize(outputW_, filterSize_, padding_, stride_, caffeMode_);
/// Check that the imageSizes are consistent with config
CHECK_EQ(imageH_, imgSizeY_);
CHECK_EQ(imageW_, imgSize_);
out_->setFrameHeight(imageH_);
out_->setFrameWidth(imageW_);
} else {
imageH_ = ins_[0]->getFrameHeight();
imageW_ = ins_[0]->getFrameWidth();
if (imageH_ == 0) imageH_ = imgSizeY_;
if (imageW_ == 0) imageW_ = imgSize_;
outputH_ =
outputSize(imageH_, filterSizeY_, paddingY_, strideY_, caffeMode_);
outputW_ = outputSize(imageW_, filterSize_, padding_, stride_, caffeMode_);
/// Check that the outputSizes are consistent with config
CHECK_EQ(outputH_, outputY_);
CHECK_EQ(outputW_, outputX_);
out_->setFrameHeight(outputH_);
out_->setFrameWidth(outputW_);
}
reshapeImageDescriptors();
if (!isSelectAlgo_) {
hl_conv_workspace(imageDesc_,
outputDesc_,
filterDesc_,
convDesc_,
&fwdAlgo_,
&fwdLimitBytes_,
&bwdDataAlgo_,
&bwdDataLimitBytes_,
&bwdFilterAlgo_,
&bwdFilterLimitBytes_);
size_t maxWorkSpace = 0;
maxWorkSpace = std::max(fwdLimitBytes_, bwdDataLimitBytes_);
maxWorkSpace = std::max(maxWorkSpace, bwdFilterLimitBytes_);
allocConvWorkSpace(maxWorkSpace);
}
isSelectAlgo_ = true;
}
void ConvBaseOperator::computeConvSizes() {
hl_create_filter_descriptor(
&filterDesc_, channels_, numFilters_, filterSizeY_, filterSize_);
......@@ -153,15 +115,6 @@ void ConvBaseOperator::reshapeImageDescriptors() {
padding_,
strideY_,
stride_);
if (isDeconv_) {
inputOffset_ = numFilters_ * outputH_ * outputW_;
outputOffset_ = channels_ * imageH_ * imageW_;
} else {
inputOffset_ = channels_ * imageH_ * imageW_;
outputOffset_ = numFilters_ * outputH_ * outputW_;
}
weightOffset_ = numFilters_ * channels_ * filterSize_ * filterSizeY_;
}
void ConvBaseOperator::getConvParams() {
......
......@@ -56,7 +56,7 @@ protected:
/**
* Allocate Gpu Memory for cudnn convolution algorithms.
*/
void allocConvWorkSpace(size_t maxWorkSpace);
void allocConvWorkSpace();
/**
* Create cudnn tensor descriptor for convolution operation.
......@@ -71,7 +71,7 @@ protected:
/**
* Reshape cudnn tensor descriptor.
*/
void reshape(int batchSize);
virtual void reshape(int batchSize) = 0;
/**
* Check filter size is equal to the size calculated by parameters from
......
......@@ -140,19 +140,7 @@ void ConvBaseProjection::reshapeTensorDesc(int batchSize) {
void ConvBaseProjection::reshape(int batchSize) {
size_t width = calOutputSize();
CHECK_EQ(width, out_->value->getWidth());
if (isDeconv_) {
CHECK_EQ(static_cast<size_t>(configChannels_ * outputH_ * outputW_),
in_->value->getWidth())
<< "Wrong input size for convolution transpose"
<< " channels=" << configChannels_ << " outputH=" << outputH_
<< " outputW=" << outputW_ << " inputSize=" << in_->value->getWidth();
} else {
CHECK_EQ(static_cast<size_t>(configChannels_ * imageH_ * imageW_),
in_->value->getWidth())
<< "Wrong input size for convolution"
<< " channels=" << configChannels_ << " imageH=" << imageH_
<< " imageW=" << imageW_ << " inputSize=" << in_->value->getWidth();
}
CHECK_EQ(calInputSize(), in_->value->getWidth());
isSelectAlgo_ = (batchSize == batchNum_);
batchNum_ = batchSize;
......
......@@ -40,54 +40,8 @@ protected:
void reshapeTensorDesc(int batchSize);
void reshape(int batchSize);
size_t calOutputSize() {
if (isDeconv_) {
outputH_ = in_->getFrameHeight();
outputW_ = in_->getFrameWidth();
if (outputH_ == 0) outputH_ = configOutH_;
if (outputW_ == 0) outputW_ = configOutW_;
imageH_ = imageSize(outputH_,
filterH_,
paddingH_,
strideH_,
/* caffeMode */ true);
imageW_ = imageSize(outputW_,
filterW_,
paddingW_,
strideW_,
/* caffeMode */ true);
const_cast<Argument*>(out_)->setFrameHeight(imageH_);
const_cast<Argument*>(out_)->setFrameWidth(imageW_);
inputOffset_ = (configChannels_ / groups_) * outputH_ * outputW_;
outputOffset_ = (configNumFilters_ / groups_) * imageH_ * imageW_;
return imageH_ * imageW_ * configNumFilters_;
} else {
imageH_ = in_->getFrameHeight();
imageW_ = in_->getFrameWidth();
if (imageH_ == 0) imageH_ = configImgH_;
if (imageW_ == 0) imageW_ = configImgW_;
outputH_ = outputSize(imageH_,
filterH_,
paddingH_,
strideH_,
/* caffeMode */ true);
outputW_ = outputSize(imageW_,
filterW_,
paddingW_,
strideW_,
/* caffeMode */ true);
const_cast<Argument*>(out_)->setFrameHeight(outputH_);
const_cast<Argument*>(out_)->setFrameWidth(outputW_);
inputOffset_ = (configChannels_ / groups_) * imageH_ * imageW_;
outputOffset_ = (configNumFilters_ / groups_) * outputH_ * outputW_;
return outputH_ * outputW_ * configNumFilters_;
}
}
virtual size_t calOutputSize() = 0;
virtual size_t calInputSize() = 0;
static void* getSpaceBytes(size_t size);
......
......@@ -29,6 +29,32 @@ namespace paddle {
REGISTER_OPERATOR(conv, ConvOperator);
void ConvOperator::reshape(int batchSize) {
imageH_ = ins_[0]->getFrameHeight();
imageW_ = ins_[0]->getFrameWidth();
if (imageH_ == 0) imageH_ = imgSizeY_;
if (imageW_ == 0) imageW_ = imgSize_;
outputH_ = outputSize(imageH_, filterSizeY_, paddingY_, strideY_, caffeMode_);
outputW_ = outputSize(imageW_, filterSize_, padding_, stride_, caffeMode_);
/// Check that the outputSizes are consistent with config
CHECK_EQ(outputH_, outputY_);
CHECK_EQ(outputW_, outputX_);
out_->setFrameHeight(outputH_);
out_->setFrameWidth(outputW_);
reshapeImageDescriptors();
inputOffset_ = channels_ * imageH_ * imageW_;
outputOffset_ = numFilters_ * outputH_ * outputW_;
weightOffset_ = numFilters_ * channels_ * filterSize_ * filterSizeY_;
if (!isSelectAlgo_) {
allocConvWorkSpace();
}
isSelectAlgo_ = true;
}
void ConvOperator::forward() {
size_t batchSize = ins_[0]->value->getHeight();
reshape(batchSize);
......
......@@ -38,6 +38,7 @@ public:
virtual ~ConvOperator() {}
void forward() override;
void backward() override;
void reshape(int batchSize) override;
};
} // namespace paddle
......@@ -19,6 +19,34 @@ namespace paddle {
REGISTER_PROJECTION(conv, ConvProjection);
size_t ConvProjection::calOutputSize() {
imageH_ = in_->getFrameHeight();
imageW_ = in_->getFrameWidth();
if (imageH_ == 0) imageH_ = configImgH_;
if (imageW_ == 0) imageW_ = configImgW_;
outputH_ = outputSize(imageH_,
filterH_,
paddingH_,
strideH_,
/* caffeMode */ true);
outputW_ = outputSize(imageW_,
filterW_,
paddingW_,
strideW_,
/* caffeMode */ true);
const_cast<Argument *>(out_)->setFrameHeight(outputH_);
const_cast<Argument *>(out_)->setFrameWidth(outputW_);
inputOffset_ = (configChannels_ / groups_) * imageH_ * imageW_;
outputOffset_ = (configNumFilters_ / groups_) * outputH_ * outputW_;
return outputH_ * outputW_ * configNumFilters_;
}
size_t ConvProjection::calInputSize() {
return static_cast<size_t>(configChannels_ * imageH_ * imageW_);
}
void ConvProjection::forward() {
int batchSize = in_->value->getHeight();
reshape(batchSize);
......
......@@ -36,6 +36,8 @@ public:
virtual void forward();
virtual void backward(const UpdateCallback& callback);
virtual size_t calOutputSize();
virtual size_t calInputSize();
};
} // namespace paddle
......@@ -29,6 +29,32 @@ namespace paddle {
REGISTER_OPERATOR(convt, ConvTransOperator);
void ConvTransOperator::reshape(int batchSize) {
outputH_ = ins_[0]->getFrameHeight();
outputW_ = ins_[0]->getFrameWidth();
if (outputH_ == 0) outputH_ = outputY_;
if (outputW_ == 0) outputW_ = outputX_;
imageH_ = imageSize(outputH_, filterSizeY_, paddingY_, strideY_, caffeMode_);
imageW_ = imageSize(outputW_, filterSize_, padding_, stride_, caffeMode_);
/// Check that the imageSizes are consistent with config
CHECK_EQ(imageH_, imgSizeY_);
CHECK_EQ(imageW_, imgSize_);
out_->setFrameHeight(imageH_);
out_->setFrameWidth(imageW_);
reshapeImageDescriptors();
inputOffset_ = numFilters_ * outputH_ * outputW_;
outputOffset_ = channels_ * imageH_ * imageW_;
weightOffset_ = numFilters_ * channels_ * filterSize_ * filterSizeY_;
if (!isSelectAlgo_) {
allocConvWorkSpace();
}
isSelectAlgo_ = true;
}
void ConvTransOperator::forward() {
size_t batchSize = ins_[0]->value->getHeight();
reshape(batchSize);
......
......@@ -38,6 +38,7 @@ public:
virtual ~ConvTransOperator() {}
void forward() override;
void backward() override;
void reshape(int batchSize) override;
};
} // namespace paddle
......@@ -18,6 +18,34 @@ limitations under the License. */
namespace paddle {
REGISTER_PROJECTION(convt, ConvTransProjection);
size_t ConvTransProjection::calOutputSize() {
outputH_ = in_->getFrameHeight();
outputW_ = in_->getFrameWidth();
if (outputH_ == 0) outputH_ = configOutH_;
if (outputW_ == 0) outputW_ = configOutW_;
imageH_ = imageSize(outputH_,
filterH_,
paddingH_,
strideH_,
/* caffeMode */ true);
imageW_ = imageSize(outputW_,
filterW_,
paddingW_,
strideW_,
/* caffeMode */ true);
const_cast<Argument *>(out_)->setFrameHeight(imageH_);
const_cast<Argument *>(out_)->setFrameWidth(imageW_);
inputOffset_ = (configChannels_ / groups_) * outputH_ * outputW_;
outputOffset_ = (configNumFilters_ / groups_) * imageH_ * imageW_;
return imageH_ * imageW_ * configNumFilters_;
}
size_t ConvTransProjection::calInputSize() {
return static_cast<size_t>(configChannels_ * outputH_ * outputW_);
}
void ConvTransProjection::forward() {
int batchSize = in_->value->getHeight();
......
......@@ -36,6 +36,8 @@ public:
virtual void forward();
virtual void backward(const UpdateCallback& callback);
virtual size_t calOutputSize();
virtual size_t calInputSize();
};
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册