提交 c94b2755 编写于 作者: W wanghaoshuang

fix conv layer reshape

上级 0d26a158
...@@ -87,9 +87,6 @@ void ConvBaseProjection::initCudnn() { ...@@ -87,9 +87,6 @@ void ConvBaseProjection::initCudnn() {
bwdDataLimitBytes_ = 0; bwdDataLimitBytes_ = 0;
bwdFilterLimitBytes_ = 0; bwdFilterLimitBytes_ = 0;
workSpaceInBytes_ = 0; workSpaceInBytes_ = 0;
batchNum_ = 0;
isSelectAlgo_ = false;
} }
void ConvBaseProjection::reshapeTensorDesc(int batchSize) { void ConvBaseProjection::reshapeTensorDesc(int batchSize) {
...@@ -142,10 +139,6 @@ void ConvBaseProjection::reshape(int batchSize) { ...@@ -142,10 +139,6 @@ void ConvBaseProjection::reshape(int batchSize) {
CHECK_EQ(width, out_->value->getWidth()); CHECK_EQ(width, out_->value->getWidth());
CHECK_EQ(calInputSize(), in_->value->getWidth()); CHECK_EQ(calInputSize(), in_->value->getWidth());
isSelectAlgo_ = (batchSize == batchNum_);
batchNum_ = batchSize;
if (!isSelectAlgo_) {
reshapeTensorDesc(batchSize); reshapeTensorDesc(batchSize);
hl_conv_workspace(imageDesc_, hl_conv_workspace(imageDesc_,
outputDesc_, outputDesc_,
...@@ -165,9 +158,6 @@ void ConvBaseProjection::reshape(int batchSize) { ...@@ -165,9 +158,6 @@ void ConvBaseProjection::reshape(int batchSize) {
VLOG(3) << getName() << " Fwd / BwdData / BwdFilter algo: " << fwdAlgo_ VLOG(3) << getName() << " Fwd / BwdData / BwdFilter algo: " << fwdAlgo_
<< " / " << bwdDataAlgo_ << " / " << bwdFilterAlgo_; << " / " << bwdDataAlgo_ << " / " << bwdFilterAlgo_;
}
isSelectAlgo_ = true;
} }
void *ConvBaseProjection::getSpaceBytes(size_t size) { void *ConvBaseProjection::getSpaceBytes(size_t size) {
......
...@@ -101,12 +101,6 @@ protected: ...@@ -101,12 +101,6 @@ protected:
size_t bwdFilterLimitBytes_; size_t bwdFilterLimitBytes_;
/// Size of total work space. /// Size of total work space.
size_t workSpaceInBytes_; size_t workSpaceInBytes_;
/// Whether to call cuDNN api to choose conv algorithm.
bool isSelectAlgo_;
/// batchNum is used to record batch size. If the batch size is changed,
/// the selection algorithm will be called.
int batchNum_;
bool bias_; bool bias_;
std::unique_ptr<Weight> weight_; std::unique_ptr<Weight> weight_;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册