From a5e90f3ee20326daa69a03e1bd8299d6db1373a8 Mon Sep 17 00:00:00 2001
From: liuqi
Date: Wed, 11 Oct 2017 20:11:05 +0800
Subject: [PATCH] Add omp to batch norm kernel

Parallelize the per-channel loop of the batch norm kernels with OpenMP.

new_scale and new_offset are declared before the loop and assigned on
every iteration, so they are shared across threads by default. They are
listed in a private clause so that each thread works on its own copy
and channels do not overwrite each other's coefficients.

Also adjust the benchmark shapes.
---
Review note: only the variables visible in the hunks below were checked
for sharing; confirm that nothing else declared outside the loop is
written inside it. The index hashes were not regenerated after the
pragma lines were amended.

 mace/kernels/batch_norm.h            | 1 +
 mace/kernels/neon/batch_norm_neon.cc | 1 +
 mace/ops/batch_norm_benchmark.cc     | 8 +++++---
 3 files changed, 7 insertions(+), 3 deletions(-)

diff --git a/mace/kernels/batch_norm.h b/mace/kernels/batch_norm.h
index 84312a03..be50df0f 100644
--- a/mace/kernels/batch_norm.h
+++ b/mace/kernels/batch_norm.h
@@ -33,6 +33,7 @@ struct BatchNormFunctor {
     // new_offset = \offset - mean * common_val;
     // Y = new_scale * X + new_offset;
     T new_scale, new_offset;
+#pragma omp parallel for private(new_scale, new_offset)
     for (index_t c = 0; c < channel; ++c) {
       new_scale = scale[c] / std::sqrt(var[c] + variance_epsilon);
       new_offset = offset[c] - mean[c] * new_scale;
diff --git a/mace/kernels/neon/batch_norm_neon.cc b/mace/kernels/neon/batch_norm_neon.cc
index 0b121a70..cba69533 100644
--- a/mace/kernels/neon/batch_norm_neon.cc
+++ b/mace/kernels/neon/batch_norm_neon.cc
@@ -31,6 +31,7 @@ void BatchNormFunctor::operator()(
   float new_scale, new_offset;
   index_t count = sample_size >> 2;
   index_t remain_count = sample_size - (count << 2);
+#pragma omp parallel for private(new_scale, new_offset)
   for (index_t c = 0; c < channel; ++c) {
     new_scale = scale[c] / std::sqrt(var[c] + variance_epsilon);
     new_offset = offset[c] - mean[c] * new_scale;
diff --git a/mace/ops/batch_norm_benchmark.cc b/mace/ops/batch_norm_benchmark.cc
index 079ad6f1..16763322 100644
--- a/mace/ops/batch_norm_benchmark.cc
+++ b/mace/ops/batch_norm_benchmark.cc
@@ -57,14 +57,16 @@ static void BatchNorm(
   BM_BATCH_NORM_MACRO(N, C, H, W, TYPE, NEON);

 BM_BATCH_NORM(1, 1, 512, 512, float);
-BM_BATCH_NORM(1, 1, 1024, 1024, float);
 BM_BATCH_NORM(1, 3, 128, 128, float);
 BM_BATCH_NORM(1, 3, 512, 512, float);
-BM_BATCH_NORM(1, 3, 1024, 1024, float);
+BM_BATCH_NORM(1, 32, 112, 112, float);
 BM_BATCH_NORM(1, 64, 256, 256, float);
 BM_BATCH_NORM(1, 64, 512, 512, float);
+BM_BATCH_NORM(1, 128, 56, 56, float);
 BM_BATCH_NORM(1, 128, 256, 256, float);
-BM_BATCH_NORM(1, 128, 512, 512, float);
+BM_BATCH_NORM(1, 256, 14, 14, float);
+BM_BATCH_NORM(1, 512, 14, 14, float);
+BM_BATCH_NORM(1, 1024, 7, 7, float);
 BM_BATCH_NORM(32, 1, 256, 256, float);
 BM_BATCH_NORM(32, 3, 256, 256, float);
 } //  namespace mace
\ No newline at end of file
-- 
GitLab