diff --git a/mace/kernels/batch_norm.h b/mace/kernels/batch_norm.h index 84312a03d2e59d10fd76eec93f9e4cff2199696a..be50df0fc1172fb413d2954c2cee6e49efbe2d53 100644 --- a/mace/kernels/batch_norm.h +++ b/mace/kernels/batch_norm.h @@ -33,6 +33,7 @@ struct BatchNormFunctor { // new_offset = \offset - mean * common_val; // Y = new_scale * X + new_offset; T new_scale, new_offset; +#pragma omp parallel for private(new_scale, new_offset) for (index_t c = 0; c < channel; ++c) { new_scale = scale[c] / std::sqrt(var[c] + variance_epsilon); new_offset = offset[c] - mean[c] * new_scale; diff --git a/mace/kernels/neon/batch_norm_neon.cc b/mace/kernels/neon/batch_norm_neon.cc index 0b121a70dc9b3b99892e086b525830364c7040a8..cba69533648499e19abb99886e26f06110d2c187 100644 --- a/mace/kernels/neon/batch_norm_neon.cc +++ b/mace/kernels/neon/batch_norm_neon.cc @@ -31,6 +31,7 @@ void BatchNormFunctor::operator()( float new_scale, new_offset; index_t count = sample_size >> 2; index_t remain_count = sample_size - (count << 2); +#pragma omp parallel for private(new_scale, new_offset) for (index_t c = 0; c < channel; ++c) { new_scale = scale[c] / std::sqrt(var[c] + variance_epsilon); new_offset = offset[c] - mean[c] * new_scale; diff --git a/mace/ops/batch_norm_benchmark.cc b/mace/ops/batch_norm_benchmark.cc index 079ad6f1a15c82b98487ec3850b21ee29accb19e..16763322c0418ef5cf4618ed0492402fdc08ec4b 100644 --- a/mace/ops/batch_norm_benchmark.cc +++ b/mace/ops/batch_norm_benchmark.cc @@ -57,14 +57,16 @@ static void BatchNorm( BM_BATCH_NORM_MACRO(N, C, H, W, TYPE, NEON); BM_BATCH_NORM(1, 1, 512, 512, float); -BM_BATCH_NORM(1, 1, 1024, 1024, float); BM_BATCH_NORM(1, 3, 128, 128, float); BM_BATCH_NORM(1, 3, 512, 512, float); -BM_BATCH_NORM(1, 3, 1024, 1024, float); +BM_BATCH_NORM(1, 32, 112, 112, float); BM_BATCH_NORM(1, 64, 256, 256, float); BM_BATCH_NORM(1, 64, 512, 512, float); +BM_BATCH_NORM(1, 128, 56, 56, float); BM_BATCH_NORM(1, 128, 256, 256, float); -BM_BATCH_NORM(1, 128, 512, 512, float); +BM_BATCH_NORM(1, 256, 14, 14, float); +BM_BATCH_NORM(1, 512, 
14, 14, float); +BM_BATCH_NORM(1, 1024, 7, 7, float); BM_BATCH_NORM(32, 1, 256, 256, float); BM_BATCH_NORM(32, 3, 256, 256, float); } // namespace mace \ No newline at end of file