提交 95da095d 编写于 作者: D dangqingqing

Fix cudnn conv bug which occurs in the image classification demo on GTX GPUs

上级 7eb29f26
......@@ -85,6 +85,7 @@ bool CudnnConvLayer::init(const LayerMap &layerMap,
biasOffset_ = numFilters_ / groups_[0];
}
batchNum_ = 0;
isSelectAlgo_ = false;
return true;
}
......@@ -132,6 +133,9 @@ void CudnnConvLayer::reshape(int batchSize) {
getOutput().setFrameHeight(outputH_);
getOutput().setFrameWidth(outputW_);
isSelectAlgo_ = (batchSize == batchNum_);
batchNum_ = batchSize;
size_t maxWorkSpace = 0;
for (size_t i = 0; i < inputLayers_.size(); i++) {
CHECK_EQ(inputLayers_[i]->getOutput().value->getWidth(),
......@@ -160,6 +164,10 @@ void CudnnConvLayer::reshape(int batchSize) {
maxWorkSpace = std::max(fwdLimitBytes_[i], bwdDataLimitBytes_[i]);
maxWorkSpace = std::max(maxWorkSpace, bwdFilterLimitBytes_[i]);
VLOG(3) << getName() << " Fwd / BwdData / BwdFilter algo: " << fwdAlgo_[i]
<< " / " << bwdDataAlgo_[i]
<< " / " << bwdFilterAlgo_[i];
}
}
......
......@@ -87,6 +87,10 @@ protected:
/// Whether the conv algorithm is selected or not.
bool isSelectAlgo_;
/// batchNum is used to record batch size. If the batch size is changed,
/// the selection algorithm will be called.
int batchNum_;
public:
/// Constructs the layer from its configuration. The body is empty: the
/// config is simply forwarded to the ConvBaseLayer constructor.
explicit CudnnConvLayer(const LayerConfig& config)
    : ConvBaseLayer(config) {}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册