提交 7da1db05 编写于 作者: D dangqingqing

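Update CUDA kernel: batchNormInference now uses one thread block per batch sample (blockIdx.x), drops the redundant `% channel` in the channel index, and is launched as <<<batchSize, 256>>>.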

上级 da7b9a5e
...@@ -25,11 +25,11 @@ __global__ void batchNormInference(real* output, ...@@ -25,11 +25,11 @@ __global__ void batchNormInference(real* output,
size_t channel, size_t channel,
size_t height, size_t height,
size_t width) { size_t width) {
const int tid = blockIdx.x * blockDim.x + threadIdx.x; const int tid = threadIdx.x;
const int num = channel * height * width; const int num = channel * height * width;
const int batch = blockIdx.y; const int batch = blockIdx.x;
for (int i = tid; i < num; i += blockDim.x) { for (int i = tid; i < num; i += blockDim.x) {
const int c = (i / (height * width)) % channel; const int c = i / (height * width);
const int id = batch * num + i; const int id = batch * num + i;
real val = input[id] - estimatedMean[c]; real val = input[id] - estimatedMean[c];
val /= sqrt(estimatedVar[c] + epsilon); val /= sqrt(estimatedVar[c] + epsilon);
...@@ -50,9 +50,7 @@ void hl_batch_norm_cuda_inference(const real* input, ...@@ -50,9 +50,7 @@ void hl_batch_norm_cuda_inference(const real* input,
size_t channel, size_t channel,
size_t height, size_t height,
size_t width) { size_t width) {
dim3 block(256, 1); batchNormInference<<<batchSize, 256, 0, STREAM_DEFAULT>>>(output,
dim3 grid(1, batchSize);
batchNormInference<<<grid, block, 0, STREAM_DEFAULT>>>(output,
input, input,
scale, scale,
bias, bias,
......
...@@ -80,31 +80,32 @@ void CudnnBatchNormLayer::forward(PassType passType) { ...@@ -80,31 +80,32 @@ void CudnnBatchNormLayer::forward(PassType passType) {
savedInvVar); savedInvVar);
} else { } else {
// used movingMean and movingVar in testing // used movingMean and movingVar in testing
if (batchSize > 1024) { if (batchSize <= 1024) {
// there is a bug in cudnn library when the batch size hl_batch_norm_forward_inference(ioDesc_,
// is larger than 1024. input,
hl_batch_norm_cuda_inference(input, ioDesc_,
output, output,
bnParamDesc_,
gamma, gamma,
beta, beta,
movingMean, movingMean,
movingVar, movingVar,
EPS, EPS);
batchSize,
channels_,
imageH_,
imageW_);
} else { } else {
hl_batch_norm_forward_inference(ioDesc_, // There is a limitation in cudnn library.
input, // When the batch size is larger than 1024 in cuDNN v5.1,
ioDesc_, // the cudnnBatchNormalizationForwardInference will fail.
hl_batch_norm_cuda_inference(input,
output, output,
bnParamDesc_,
gamma, gamma,
beta, beta,
movingMean, movingMean,
movingVar, movingVar,
EPS); EPS,
batchSize,
channels_,
imageH_,
imageW_);
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册