Commit 7da1db05 authored by dangqingqing

update cuda kernel.

Parent da7b9a5e
@@ -25,11 +25,11 @@ __global__ void batchNormInference(real* output,
                                    size_t channel,
                                    size_t height,
                                    size_t width) {
-  const int tid = blockIdx.x * blockDim.x + threadIdx.x;
+  const int tid = threadIdx.x;
   const int num = channel * height * width;
-  const int batch = blockIdx.y;
+  const int batch = blockIdx.x;
   for (int i = tid; i < num; i += blockDim.x) {
-    const int c = (i / (height * width)) % channel;
+    const int c = i / (height * width);
     const int id = batch * num + i;
     real val = input[id] - estimatedMean[c];
     val /= sqrt(estimatedVar[c] + epsilon);
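The kernel change above flattens the thread mapping: each sample gets one block whose threads stride over that sample's elements, and the "% channel" becomes redundant because i never reaches channel * height * width. A minimal standalone sketch of the resulting scheme, assuming real is float and inferring the final scale-and-shift line (the hunk truncates the kernel body):

// Sketch: one block per sample; blockIdx.x selects the sample and the
// block's threads stride over its channel*height*width elements (NCHW).
__global__ void bnInferenceSketch(float* output,
                                  const float* input,
                                  const float* scale,
                                  const float* bias,
                                  const float* estimatedMean,
                                  const float* estimatedVar,
                                  const double epsilon,
                                  size_t channel,
                                  size_t height,
                                  size_t width) {
  const int tid = threadIdx.x;               // offset within this sample
  const int num = channel * height * width;  // elements per sample
  const int batch = blockIdx.x;              // one block per sample
  for (int i = tid; i < num; i += blockDim.x) {
    // Since i < channel * height * width, i / (height * width) is already
    // a valid channel index; no modulo needed.
    const int c = i / (height * width);
    const int id = batch * num + i;
    float val = input[id] - estimatedMean[c];
    val /= sqrt(estimatedVar[c] + epsilon);
    output[id] = scale[c] * val + bias[c];  // assumed affine step
  }
}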
@@ -50,9 +50,7 @@ void hl_batch_norm_cuda_inference(const real* input,
                                   size_t channel,
                                   size_t height,
                                   size_t width) {
-  dim3 block(256, 1);
-  dim3 grid(1, batchSize);
-  batchNormInference<<<grid, block, 0, STREAM_DEFAULT>>>(output,
+  batchNormInference<<<batchSize, 256, 0, STREAM_DEFAULT>>>(output,
                                                             input,
                                                             scale,
                                                             bias,
......
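On the host side, the dim3 block/grid pair is replaced by a plain 1-D launch, since blockIdx.x now identifies the sample directly. A hypothetical launcher for the sketch kernel above; the real call's trailing arguments are truncated in the hunk, so these parameter names are only indicative, and the default stream stands in for STREAM_DEFAULT:

// Hypothetical launcher mirroring the new configuration:
// batchSize blocks of 256 threads, one block per sample.
void launchBnInference(float* output, const float* input,
                       const float* scale, const float* bias,
                       const float* estimatedMean, const float* estimatedVar,
                       double epsilon, size_t batchSize,
                       size_t channel, size_t height, size_t width) {
  bnInferenceSketch<<<batchSize, 256, 0>>>(output, input, scale, bias,
                                           estimatedMean, estimatedVar,
                                           epsilon, channel, height, width);
}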
@@ -80,31 +80,32 @@ void CudnnBatchNormLayer::forward(PassType passType) {
                                 savedInvVar);
   } else {
     // used movingMean and movingVar in testing
-    if (batchSize > 1024) {
-      // there is a bug in cudnn library when the batch size
-      // is larger than 1024.
-      hl_batch_norm_cuda_inference(input,
+    if (batchSize <= 1024) {
+      hl_batch_norm_forward_inference(ioDesc_,
+                                      input,
+                                      ioDesc_,
                                    output,
+                                      bnParamDesc_,
                                    gamma,
                                    beta,
                                    movingMean,
                                    movingVar,
-                                   EPS,
-                                   batchSize,
-                                   channels_,
-                                   imageH_,
-                                   imageW_);
+                                      EPS);
     } else {
-      hl_batch_norm_forward_inference(ioDesc_,
-                                      input,
-                                      ioDesc_,
+      // There is a limitation in cudnn library.
+      // When the batch size is larger than 1024 in cuDNN v5.1,
+      // the cudnnBatchNormalizationForwardInference will fail.
+      hl_batch_norm_cuda_inference(input,
                                    output,
-                                      bnParamDesc_,
                                    gamma,
                                    beta,
                                    movingMean,
                                    movingVar,
-                                      EPS);
+                                   EPS,
+                                   batchSize,
+                                   channels_,
+                                   imageH_,
+                                   imageW_);
     }
   }
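Assembled from the hunk above, the inference branch after this commit reads as follows (context reconstructed from the diff only; hl_batch_norm_forward_inference wraps cuDNN's inference call, while hl_batch_norm_cuda_inference is the hand-written kernel from the first file):

  } else {
    // used movingMean and movingVar in testing
    if (batchSize <= 1024) {
      // Normal path: cuDNN's fused inference kernel.
      hl_batch_norm_forward_inference(ioDesc_, input, ioDesc_, output,
                                      bnParamDesc_, gamma, beta,
                                      movingMean, movingVar, EPS);
    } else {
      // cuDNN v5.1's cudnnBatchNormalizationForwardInference fails for
      // batch sizes above 1024, so fall back to the custom kernel.
      hl_batch_norm_cuda_inference(input, output, gamma, beta,
                                   movingMean, movingVar, EPS, batchSize,
                                   channels_, imageH_, imageW_);
    }
  }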