From 5500153a6d4fd4a70f0aa3adc773131b2d539d7a Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Wed, 6 Sep 2017 11:34:02 +0800 Subject: [PATCH] fix cudnnBatchNorm for 3D data --- paddle/gserver/layers/CudnnBatchNormLayer.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/paddle/gserver/layers/CudnnBatchNormLayer.cpp b/paddle/gserver/layers/CudnnBatchNormLayer.cpp index 44ba2c4b7..49a9540c0 100644 --- a/paddle/gserver/layers/CudnnBatchNormLayer.cpp +++ b/paddle/gserver/layers/CudnnBatchNormLayer.cpp @@ -37,7 +37,7 @@ bool CudnnBatchNormLayer::init(const LayerMap& layerMap, } void CudnnBatchNormLayer::reshape(int batchSize) { - hl_tensor_reshape(ioDesc_, batchSize, channels_, imageH_, imageW_); + hl_tensor_reshape(ioDesc_, batchSize, channels_, imageH_ * imageD_, imageW_); } void CudnnBatchNormLayer::forward(PassType passType) { @@ -104,7 +104,7 @@ void CudnnBatchNormLayer::forward(PassType passType) { EPS, batchSize, channels_, - imageH_, + imageH_ * imageD_, imageW_); } } -- GitLab