diff --git a/paddle/gserver/layers/MKLDNNBatchNormLayer.cpp b/paddle/gserver/layers/MKLDNNBatchNormLayer.cpp index 9b0ae20f089e34a719883bc65e88e33ab9334e39..ed3887cbf653878623764a310c9f364f4d8be27f 100644 --- a/paddle/gserver/layers/MKLDNNBatchNormLayer.cpp +++ b/paddle/gserver/layers/MKLDNNBatchNormLayer.cpp @@ -119,7 +119,7 @@ void MKLDNNBatchNormLayer::reshape( int& bs, int& ic, int& ih, int& iw, int oc, int& oh, int& ow) { reshapeInput(bs, ih, iw); oh = ih; - ow = ow; + ow = iw; // ic_ and oc can not be changed CHECK_EQ(inputElemenCnt_ / bs / ih / iw, (size_t)ic) << "Input channel can not be changed"; diff --git a/paddle/gserver/tests/test_MKLDNN.cpp b/paddle/gserver/tests/test_MKLDNN.cpp index 3960d699ac8dc08316ee413116878ee3eda65793..a0e039c2a33b586e21775ad06c1278a10804d654 100644 --- a/paddle/gserver/tests/test_MKLDNN.cpp +++ b/paddle/gserver/tests/test_MKLDNN.cpp @@ -269,6 +269,7 @@ void testBatchNormLayer(const testBatchNormDesc& pm) { TEST(MKLDNNLayer, BatchNormLayer) { testBatchNormLayer({4, 10, 6, 6}); testBatchNormLayer({16, 32, 16, 16}); + testBatchNormLayer({4, 16, 8, 10}); } struct testImageDesc {