提交 7da1db05 编写于 作者: D dangqingqing

Update the CUDA batch-norm inference kernel and fall back to it when the batch size exceeds 1024, working around a cuDNN v5.1 limitation in cudnnBatchNormalizationForwardInference.

上级 da7b9a5e
......@@ -25,11 +25,11 @@ __global__ void batchNormInference(real* output,
size_t channel,
size_t height,
size_t width) {
const int tid = blockIdx.x * blockDim.x + threadIdx.x;
const int tid = threadIdx.x;
const int num = channel * height * width;
const int batch = blockIdx.y;
const int batch = blockIdx.x;
for (int i = tid; i < num; i += blockDim.x) {
const int c = (i / (height * width)) % channel;
const int c = i / (height * width);
const int id = batch * num + i;
real val = input[id] - estimatedMean[c];
val /= sqrt(estimatedVar[c] + epsilon);
......@@ -50,19 +50,17 @@ void hl_batch_norm_cuda_inference(const real* input,
size_t channel,
size_t height,
size_t width) {
dim3 block(256, 1);
dim3 grid(1, batchSize);
batchNormInference<<<grid, block, 0, STREAM_DEFAULT>>>(output,
input,
scale,
bias,
estimatedMean,
estimatedVar,
epsilon,
batchSize,
channel,
height,
width);
batchNormInference<<<batchSize, 256, 0, STREAM_DEFAULT>>>(output,
input,
scale,
bias,
estimatedMean,
estimatedVar,
epsilon,
batchSize,
channel,
height,
width);
CHECK_SYNC("hl_batch_norm_cuda_inference failed!");
}
......@@ -80,9 +80,21 @@ void CudnnBatchNormLayer::forward(PassType passType) {
savedInvVar);
} else {
// used movingMean and movingVar in testing
if (batchSize > 1024) {
// there is a bug in cudnn library when the batch size
// is larger than 1024.
if (batchSize <= 1024) {
hl_batch_norm_forward_inference(ioDesc_,
input,
ioDesc_,
output,
bnParamDesc_,
gamma,
beta,
movingMean,
movingVar,
EPS);
} else {
// There is a limitation in cudnn library.
// When the batch size is larger than 1024 in cuDNN v5.1,
// the cudnnBatchNormalizationForwardInference will fail.
hl_batch_norm_cuda_inference(input,
output,
gamma,
......@@ -94,17 +106,6 @@ void CudnnBatchNormLayer::forward(PassType passType) {
channels_,
imageH_,
imageW_);
} else {
hl_batch_norm_forward_inference(ioDesc_,
input,
ioDesc_,
output,
bnParamDesc_,
gamma,
beta,
movingMean,
movingVar,
EPS);
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册