提交 c94b2755 编写于 作者: W wanghaoshuang

fix conv layer reshape

上级 0d26a158
...@@ -87,9 +87,6 @@ void ConvBaseProjection::initCudnn() { ...@@ -87,9 +87,6 @@ void ConvBaseProjection::initCudnn() {
bwdDataLimitBytes_ = 0; bwdDataLimitBytes_ = 0;
bwdFilterLimitBytes_ = 0; bwdFilterLimitBytes_ = 0;
workSpaceInBytes_ = 0; workSpaceInBytes_ = 0;
batchNum_ = 0;
isSelectAlgo_ = false;
} }
void ConvBaseProjection::reshapeTensorDesc(int batchSize) { void ConvBaseProjection::reshapeTensorDesc(int batchSize) {
...@@ -142,32 +139,25 @@ void ConvBaseProjection::reshape(int batchSize) { ...@@ -142,32 +139,25 @@ void ConvBaseProjection::reshape(int batchSize) {
CHECK_EQ(width, out_->value->getWidth()); CHECK_EQ(width, out_->value->getWidth());
CHECK_EQ(calInputSize(), in_->value->getWidth()); CHECK_EQ(calInputSize(), in_->value->getWidth());
isSelectAlgo_ = (batchSize == batchNum_); reshapeTensorDesc(batchSize);
batchNum_ = batchSize; hl_conv_workspace(imageDesc_,
outputDesc_,
if (!isSelectAlgo_) { filterDesc_,
reshapeTensorDesc(batchSize); convDesc_,
hl_conv_workspace(imageDesc_, &fwdAlgo_,
outputDesc_, &fwdLimitBytes_,
filterDesc_, &bwdDataAlgo_,
convDesc_, &bwdDataLimitBytes_,
&fwdAlgo_, &bwdFilterAlgo_,
&fwdLimitBytes_, &bwdFilterLimitBytes_);
&bwdDataAlgo_,
&bwdDataLimitBytes_, size_t maxWorkSpace = 0;
&bwdFilterAlgo_, maxWorkSpace = std::max(fwdLimitBytes_, bwdDataLimitBytes_);
&bwdFilterLimitBytes_); maxWorkSpace = std::max(maxWorkSpace, bwdFilterLimitBytes_);
workSpaceInBytes_ = maxWorkSpace;
size_t maxWorkSpace = 0;
maxWorkSpace = std::max(fwdLimitBytes_, bwdDataLimitBytes_); VLOG(3) << getName() << " Fwd / BwdData / BwdFilter algo: " << fwdAlgo_
maxWorkSpace = std::max(maxWorkSpace, bwdFilterLimitBytes_); << " / " << bwdDataAlgo_ << " / " << bwdFilterAlgo_;
workSpaceInBytes_ = maxWorkSpace;
VLOG(3) << getName() << " Fwd / BwdData / BwdFilter algo: " << fwdAlgo_
<< " / " << bwdDataAlgo_ << " / " << bwdFilterAlgo_;
}
isSelectAlgo_ = true;
} }
void *ConvBaseProjection::getSpaceBytes(size_t size) { void *ConvBaseProjection::getSpaceBytes(size_t size) {
......
...@@ -101,12 +101,6 @@ protected: ...@@ -101,12 +101,6 @@ protected:
size_t bwdFilterLimitBytes_; size_t bwdFilterLimitBytes_;
/// Size of total work space. /// Size of total work space.
size_t workSpaceInBytes_; size_t workSpaceInBytes_;
/// Whether to call cuDNN api to choose conv algorithm.
bool isSelectAlgo_;
/// batchNum is used to record batch size. If the batch size is changed,
/// the selection algorithm will be called.
int batchNum_;
bool bias_; bool bias_;
std::unique_ptr<Weight> weight_; std::unique_ptr<Weight> weight_;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册