提交 5500153a 编写于 作者: C chengduoZH

fix cudnnBatchNorm for 3D data

上级 4cb2966d
......@@ -37,7 +37,7 @@ bool CudnnBatchNormLayer::init(const LayerMap& layerMap,
}
void CudnnBatchNormLayer::reshape(int batchSize) {
hl_tensor_reshape(ioDesc_, batchSize, channels_, imageH_, imageW_);
hl_tensor_reshape(ioDesc_, batchSize, channels_, imageH_ * imageD_, imageW_);
}
void CudnnBatchNormLayer::forward(PassType passType) {
......@@ -104,7 +104,7 @@ void CudnnBatchNormLayer::forward(PassType passType) {
EPS,
batchSize,
channels_,
imageH_,
imageH_ * imageD_,
imageW_);
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册