diff --git a/paddle/cuda/src/hl_batch_norm.cu b/paddle/cuda/src/hl_batch_norm.cu
index 57474ee2f741442c262c9caab6f3fd025489d0d3..5828ecb8e049c2f0573ab8547164794bef6db1ca 100644
--- a/paddle/cuda/src/hl_batch_norm.cu
+++ b/paddle/cuda/src/hl_batch_norm.cu
@@ -25,11 +25,11 @@ __global__ void batchNormInference(real* output,
                                    size_t channel,
                                    size_t height,
                                    size_t width) {
-  const int tid = blockIdx.x * blockDim.x + threadIdx.x;
+  const int tid = threadIdx.x;
   const int num = channel * height * width;
-  const int batch = blockIdx.y;
+  const int batch = blockIdx.x;
   for (int i = tid; i < num; i += blockDim.x) {
-    const int c = (i / (height * width)) % channel;
+    const int c = i / (height * width);
     const int id = batch * num + i;
     real val = input[id] - estimatedMean[c];
     val /= sqrt(estimatedVar[c] + epsilon);
@@ -50,19 +50,17 @@ void hl_batch_norm_cuda_inference(const real* input,
                                   size_t channel,
                                   size_t height,
                                   size_t width) {
-  dim3 block(256, 1);
-  dim3 grid(1, batchSize);
-  batchNormInference<<<grid, block, 0, STREAM_DEFAULT>>>(output,
-                                                         input,
-                                                         scale,
-                                                         bias,
-                                                         estimatedMean,
-                                                         estimatedVar,
-                                                         epsilon,
-                                                         batchSize,
-                                                         channel,
-                                                         height,
-                                                         width);
+  batchNormInference<<<batchSize, 256, 0, STREAM_DEFAULT>>>(output,
+                                                            input,
+                                                            scale,
+                                                            bias,
+                                                            estimatedMean,
+                                                            estimatedVar,
+                                                            epsilon,
+                                                            batchSize,
+                                                            channel,
+                                                            height,
+                                                            width);
   CHECK_SYNC("hl_batch_norm_cuda_inference failed!");
 }
 
diff --git a/paddle/gserver/layers/CudnnBatchNormLayer.cpp b/paddle/gserver/layers/CudnnBatchNormLayer.cpp
index cc2cc21cdfd20e26f1b4f708535c3bd863f7871d..44ba2c4b7d1562d2ce839b5f4b4de1af35e6925f 100644
--- a/paddle/gserver/layers/CudnnBatchNormLayer.cpp
+++ b/paddle/gserver/layers/CudnnBatchNormLayer.cpp
@@ -80,9 +80,21 @@ void CudnnBatchNormLayer::forward(PassType passType) {
                                    savedInvVar);
   } else {
     // used movingMean and movingVar in testing
-    if (batchSize > 1024) {
-      // there is a bug in cudnn library when the batch size
-      // is larger than 1024.
+    if (batchSize <= 1024) {
+      hl_batch_norm_forward_inference(ioDesc_,
+                                      input,
+                                      ioDesc_,
+                                      output,
+                                      bnParamDesc_,
+                                      gamma,
+                                      beta,
+                                      movingMean,
+                                      movingVar,
+                                      EPS);
+    } else {
+      // There is a limitation in cudnn library.
+      // When the batch size is larger than 1024 in cuDNN v5.1,
+      // the cudnnBatchNormalizationForwardInference will fail.
       hl_batch_norm_cuda_inference(input,
                                    output,
                                    gamma,
@@ -94,17 +106,6 @@ void CudnnBatchNormLayer::forward(PassType passType) {
                                    channels_,
                                    imageH_,
                                    imageW_);
-    } else {
-      hl_batch_norm_forward_inference(ioDesc_,
-                                      input,
-                                      ioDesc_,
-                                      output,
-                                      bnParamDesc_,
-                                      gamma,
-                                      beta,
-                                      movingMean,
-                                      movingVar,
-                                      EPS);
     }
   }
 