提交 48a6168c 编写于 作者: W wangyang59

following comments from qingqing01

上级 b16c0a84
...@@ -49,49 +49,7 @@ ConvBaseOperator::ConvBaseOperator(const OperatorConfig &config, bool useGpu) ...@@ -49,49 +49,7 @@ ConvBaseOperator::ConvBaseOperator(const OperatorConfig &config, bool useGpu)
isSelectAlgo_ = false; isSelectAlgo_ = false;
} }
void ConvBaseOperator::allocConvWorkSpace(size_t maxWorkSpace) { void ConvBaseOperator::allocConvWorkSpace() {
if (maxWorkSpace > workSpaceInBytes_) {
if (workSpaceInBytes_ != 0) {
hl_free_mem_device(workSpace_);
}
// total amount of storage needed
workSpace_ = hl_malloc_device(maxWorkSpace);
workSpaceInBytes_ = maxWorkSpace;
}
}
void ConvBaseOperator::reshape(int batchSize) {
if (isDeconv_) {
outputH_ = ins_[0]->getFrameHeight();
outputW_ = ins_[0]->getFrameWidth();
if (outputH_ == 0) outputH_ = outputY_;
if (outputW_ == 0) outputW_ = outputX_;
imageH_ =
imageSize(outputH_, filterSizeY_, paddingY_, strideY_, caffeMode_);
imageW_ = imageSize(outputW_, filterSize_, padding_, stride_, caffeMode_);
/// Check that the imageSizes are consistent with config
CHECK_EQ(imageH_, imgSizeY_);
CHECK_EQ(imageW_, imgSize_);
out_->setFrameHeight(imageH_);
out_->setFrameWidth(imageW_);
} else {
imageH_ = ins_[0]->getFrameHeight();
imageW_ = ins_[0]->getFrameWidth();
if (imageH_ == 0) imageH_ = imgSizeY_;
if (imageW_ == 0) imageW_ = imgSize_;
outputH_ =
outputSize(imageH_, filterSizeY_, paddingY_, strideY_, caffeMode_);
outputW_ = outputSize(imageW_, filterSize_, padding_, stride_, caffeMode_);
/// Check that the outputSizes are consistent with config
CHECK_EQ(outputH_, outputY_);
CHECK_EQ(outputW_, outputX_);
out_->setFrameHeight(outputH_);
out_->setFrameWidth(outputW_);
}
reshapeImageDescriptors();
if (!isSelectAlgo_) {
hl_conv_workspace(imageDesc_, hl_conv_workspace(imageDesc_,
outputDesc_, outputDesc_,
filterDesc_, filterDesc_,
...@@ -107,10 +65,14 @@ void ConvBaseOperator::reshape(int batchSize) { ...@@ -107,10 +65,14 @@ void ConvBaseOperator::reshape(int batchSize) {
maxWorkSpace = std::max(fwdLimitBytes_, bwdDataLimitBytes_); maxWorkSpace = std::max(fwdLimitBytes_, bwdDataLimitBytes_);
maxWorkSpace = std::max(maxWorkSpace, bwdFilterLimitBytes_); maxWorkSpace = std::max(maxWorkSpace, bwdFilterLimitBytes_);
allocConvWorkSpace(maxWorkSpace); if (maxWorkSpace > workSpaceInBytes_) {
if (workSpaceInBytes_ != 0) {
hl_free_mem_device(workSpace_);
}
// total amount of storage needed
workSpace_ = hl_malloc_device(maxWorkSpace);
workSpaceInBytes_ = maxWorkSpace;
} }
isSelectAlgo_ = true;
} }
void ConvBaseOperator::computeConvSizes() { void ConvBaseOperator::computeConvSizes() {
...@@ -153,15 +115,6 @@ void ConvBaseOperator::reshapeImageDescriptors() { ...@@ -153,15 +115,6 @@ void ConvBaseOperator::reshapeImageDescriptors() {
padding_, padding_,
strideY_, strideY_,
stride_); stride_);
if (isDeconv_) {
inputOffset_ = numFilters_ * outputH_ * outputW_;
outputOffset_ = channels_ * imageH_ * imageW_;
} else {
inputOffset_ = channels_ * imageH_ * imageW_;
outputOffset_ = numFilters_ * outputH_ * outputW_;
}
weightOffset_ = numFilters_ * channels_ * filterSize_ * filterSizeY_;
} }
void ConvBaseOperator::getConvParams() { void ConvBaseOperator::getConvParams() {
......
...@@ -56,7 +56,7 @@ protected: ...@@ -56,7 +56,7 @@ protected:
/** /**
* Allocate Gpu Memory for cudnn convolution algorithms. * Allocate Gpu Memory for cudnn convolution algorithms.
*/ */
void allocConvWorkSpace(size_t maxWorkSpace); void allocConvWorkSpace();
/** /**
* Create cudnn tensor descriptor for convolution operation. * Create cudnn tensor descriptor for convolution operation.
...@@ -71,7 +71,7 @@ protected: ...@@ -71,7 +71,7 @@ protected:
/** /**
* Reshape cudnn tensor descriptor. * Reshape cudnn tensor descriptor.
*/ */
void reshape(int batchSize); virtual void reshape(int batchSize) = 0;
/** /**
* Check filter size is equal to the size calculated by parameters from * Check filter size is equal to the size calculated by parameters from
......
...@@ -140,19 +140,7 @@ void ConvBaseProjection::reshapeTensorDesc(int batchSize) { ...@@ -140,19 +140,7 @@ void ConvBaseProjection::reshapeTensorDesc(int batchSize) {
void ConvBaseProjection::reshape(int batchSize) { void ConvBaseProjection::reshape(int batchSize) {
size_t width = calOutputSize(); size_t width = calOutputSize();
CHECK_EQ(width, out_->value->getWidth()); CHECK_EQ(width, out_->value->getWidth());
if (isDeconv_) { CHECK_EQ(calInputSize(), in_->value->getWidth());
CHECK_EQ(static_cast<size_t>(configChannels_ * outputH_ * outputW_),
in_->value->getWidth())
<< "Wrong input size for convolution transpose"
<< " channels=" << configChannels_ << " outputH=" << outputH_
<< " outputW=" << outputW_ << " inputSize=" << in_->value->getWidth();
} else {
CHECK_EQ(static_cast<size_t>(configChannels_ * imageH_ * imageW_),
in_->value->getWidth())
<< "Wrong input size for convolution"
<< " channels=" << configChannels_ << " imageH=" << imageH_
<< " imageW=" << imageW_ << " inputSize=" << in_->value->getWidth();
}
isSelectAlgo_ = (batchSize == batchNum_); isSelectAlgo_ = (batchSize == batchNum_);
batchNum_ = batchSize; batchNum_ = batchSize;
......
...@@ -40,54 +40,8 @@ protected: ...@@ -40,54 +40,8 @@ protected:
void reshapeTensorDesc(int batchSize); void reshapeTensorDesc(int batchSize);
void reshape(int batchSize); void reshape(int batchSize);
size_t calOutputSize() { virtual size_t calOutputSize() = 0;
if (isDeconv_) { virtual size_t calInputSize() = 0;
outputH_ = in_->getFrameHeight();
outputW_ = in_->getFrameWidth();
if (outputH_ == 0) outputH_ = configOutH_;
if (outputW_ == 0) outputW_ = configOutW_;
imageH_ = imageSize(outputH_,
filterH_,
paddingH_,
strideH_,
/* caffeMode */ true);
imageW_ = imageSize(outputW_,
filterW_,
paddingW_,
strideW_,
/* caffeMode */ true);
const_cast<Argument*>(out_)->setFrameHeight(imageH_);
const_cast<Argument*>(out_)->setFrameWidth(imageW_);
inputOffset_ = (configChannels_ / groups_) * outputH_ * outputW_;
outputOffset_ = (configNumFilters_ / groups_) * imageH_ * imageW_;
return imageH_ * imageW_ * configNumFilters_;
} else {
imageH_ = in_->getFrameHeight();
imageW_ = in_->getFrameWidth();
if (imageH_ == 0) imageH_ = configImgH_;
if (imageW_ == 0) imageW_ = configImgW_;
outputH_ = outputSize(imageH_,
filterH_,
paddingH_,
strideH_,
/* caffeMode */ true);
outputW_ = outputSize(imageW_,
filterW_,
paddingW_,
strideW_,
/* caffeMode */ true);
const_cast<Argument*>(out_)->setFrameHeight(outputH_);
const_cast<Argument*>(out_)->setFrameWidth(outputW_);
inputOffset_ = (configChannels_ / groups_) * imageH_ * imageW_;
outputOffset_ = (configNumFilters_ / groups_) * outputH_ * outputW_;
return outputH_ * outputW_ * configNumFilters_;
}
}
static void* getSpaceBytes(size_t size); static void* getSpaceBytes(size_t size);
......
...@@ -29,6 +29,32 @@ namespace paddle { ...@@ -29,6 +29,32 @@ namespace paddle {
REGISTER_OPERATOR(conv, ConvOperator); REGISTER_OPERATOR(conv, ConvOperator);
void ConvOperator::reshape(int batchSize) {
imageH_ = ins_[0]->getFrameHeight();
imageW_ = ins_[0]->getFrameWidth();
if (imageH_ == 0) imageH_ = imgSizeY_;
if (imageW_ == 0) imageW_ = imgSize_;
outputH_ = outputSize(imageH_, filterSizeY_, paddingY_, strideY_, caffeMode_);
outputW_ = outputSize(imageW_, filterSize_, padding_, stride_, caffeMode_);
/// Check that the outputSizes are consistent with config
CHECK_EQ(outputH_, outputY_);
CHECK_EQ(outputW_, outputX_);
out_->setFrameHeight(outputH_);
out_->setFrameWidth(outputW_);
reshapeImageDescriptors();
inputOffset_ = channels_ * imageH_ * imageW_;
outputOffset_ = numFilters_ * outputH_ * outputW_;
weightOffset_ = numFilters_ * channels_ * filterSize_ * filterSizeY_;
if (!isSelectAlgo_) {
allocConvWorkSpace();
}
isSelectAlgo_ = true;
}
void ConvOperator::forward() { void ConvOperator::forward() {
size_t batchSize = ins_[0]->value->getHeight(); size_t batchSize = ins_[0]->value->getHeight();
reshape(batchSize); reshape(batchSize);
......
...@@ -38,6 +38,7 @@ public: ...@@ -38,6 +38,7 @@ public:
virtual ~ConvOperator() {} virtual ~ConvOperator() {}
void forward() override; void forward() override;
void backward() override; void backward() override;
void reshape(int batchSize) override;
}; };
} // namespace paddle } // namespace paddle
...@@ -19,6 +19,34 @@ namespace paddle { ...@@ -19,6 +19,34 @@ namespace paddle {
REGISTER_PROJECTION(conv, ConvProjection); REGISTER_PROJECTION(conv, ConvProjection);
size_t ConvProjection::calOutputSize() {
imageH_ = in_->getFrameHeight();
imageW_ = in_->getFrameWidth();
if (imageH_ == 0) imageH_ = configImgH_;
if (imageW_ == 0) imageW_ = configImgW_;
outputH_ = outputSize(imageH_,
filterH_,
paddingH_,
strideH_,
/* caffeMode */ true);
outputW_ = outputSize(imageW_,
filterW_,
paddingW_,
strideW_,
/* caffeMode */ true);
const_cast<Argument *>(out_)->setFrameHeight(outputH_);
const_cast<Argument *>(out_)->setFrameWidth(outputW_);
inputOffset_ = (configChannels_ / groups_) * imageH_ * imageW_;
outputOffset_ = (configNumFilters_ / groups_) * outputH_ * outputW_;
return outputH_ * outputW_ * configNumFilters_;
}
size_t ConvProjection::calInputSize() {
return static_cast<size_t>(configChannels_ * imageH_ * imageW_);
}
void ConvProjection::forward() { void ConvProjection::forward() {
int batchSize = in_->value->getHeight(); int batchSize = in_->value->getHeight();
reshape(batchSize); reshape(batchSize);
......
...@@ -36,6 +36,8 @@ public: ...@@ -36,6 +36,8 @@ public:
virtual void forward(); virtual void forward();
virtual void backward(const UpdateCallback& callback); virtual void backward(const UpdateCallback& callback);
virtual size_t calOutputSize();
virtual size_t calInputSize();
}; };
} // namespace paddle } // namespace paddle
...@@ -29,6 +29,32 @@ namespace paddle { ...@@ -29,6 +29,32 @@ namespace paddle {
REGISTER_OPERATOR(convt, ConvTransOperator); REGISTER_OPERATOR(convt, ConvTransOperator);
void ConvTransOperator::reshape(int batchSize) {
outputH_ = ins_[0]->getFrameHeight();
outputW_ = ins_[0]->getFrameWidth();
if (outputH_ == 0) outputH_ = outputY_;
if (outputW_ == 0) outputW_ = outputX_;
imageH_ = imageSize(outputH_, filterSizeY_, paddingY_, strideY_, caffeMode_);
imageW_ = imageSize(outputW_, filterSize_, padding_, stride_, caffeMode_);
/// Check that the imageSizes are consistent with config
CHECK_EQ(imageH_, imgSizeY_);
CHECK_EQ(imageW_, imgSize_);
out_->setFrameHeight(imageH_);
out_->setFrameWidth(imageW_);
reshapeImageDescriptors();
inputOffset_ = numFilters_ * outputH_ * outputW_;
outputOffset_ = channels_ * imageH_ * imageW_;
weightOffset_ = numFilters_ * channels_ * filterSize_ * filterSizeY_;
if (!isSelectAlgo_) {
allocConvWorkSpace();
}
isSelectAlgo_ = true;
}
void ConvTransOperator::forward() { void ConvTransOperator::forward() {
size_t batchSize = ins_[0]->value->getHeight(); size_t batchSize = ins_[0]->value->getHeight();
reshape(batchSize); reshape(batchSize);
......
...@@ -38,6 +38,7 @@ public: ...@@ -38,6 +38,7 @@ public:
virtual ~ConvTransOperator() {} virtual ~ConvTransOperator() {}
void forward() override; void forward() override;
void backward() override; void backward() override;
void reshape(int batchSize) override;
}; };
} // namespace paddle } // namespace paddle
...@@ -18,6 +18,34 @@ limitations under the License. */ ...@@ -18,6 +18,34 @@ limitations under the License. */
namespace paddle { namespace paddle {
REGISTER_PROJECTION(convt, ConvTransProjection); REGISTER_PROJECTION(convt, ConvTransProjection);
size_t ConvTransProjection::calOutputSize() {
outputH_ = in_->getFrameHeight();
outputW_ = in_->getFrameWidth();
if (outputH_ == 0) outputH_ = configOutH_;
if (outputW_ == 0) outputW_ = configOutW_;
imageH_ = imageSize(outputH_,
filterH_,
paddingH_,
strideH_,
/* caffeMode */ true);
imageW_ = imageSize(outputW_,
filterW_,
paddingW_,
strideW_,
/* caffeMode */ true);
const_cast<Argument *>(out_)->setFrameHeight(imageH_);
const_cast<Argument *>(out_)->setFrameWidth(imageW_);
inputOffset_ = (configChannels_ / groups_) * outputH_ * outputW_;
outputOffset_ = (configNumFilters_ / groups_) * imageH_ * imageW_;
return imageH_ * imageW_ * configNumFilters_;
}
size_t ConvTransProjection::calInputSize() {
return static_cast<size_t>(configChannels_ * outputH_ * outputW_);
}
void ConvTransProjection::forward() { void ConvTransProjection::forward() {
int batchSize = in_->value->getHeight(); int batchSize = in_->value->getHeight();
......
...@@ -36,6 +36,8 @@ public: ...@@ -36,6 +36,8 @@ public:
virtual void forward(); virtual void forward();
virtual void backward(const UpdateCallback& callback); virtual void backward(const UpdateCallback& callback);
virtual size_t calOutputSize();
virtual size_t calInputSize();
}; };
} // namespace paddle } // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册