提交 a5e90f3e 编写于 作者: L liuqi

Add omp to batch norm kernel

上级 a9fa945d
......@@ -33,6 +33,7 @@ struct BatchNormFunctor {
// new_offset = \offset - mean * new_scale;
// Y = new_scale * X + new_offset;
// Per-channel temporaries. They are declared outside the loop, so OpenMP
// would treat them as shared and concurrent iterations would race on them;
// the private clause gives every thread its own copy.
// NOTE(review): assumes neither value is read after the loop — confirm.
T new_scale, new_offset;
#pragma omp parallel for private(new_scale, new_offset)
for (index_t c = 0; c < channel; ++c) {
new_scale = scale[c] / std::sqrt(var[c] + variance_epsilon);
new_offset = offset[c] - mean[c] * new_scale;
......
......@@ -31,6 +31,7 @@ void BatchNormFunctor<DeviceType::NEON, float>::operator()(
// Per-channel temporaries. They are declared outside the loop, so OpenMP
// would treat them as shared and concurrent iterations would race on them;
// the private clause on the pragma below gives every thread its own copy.
// NOTE(review): assumes neither value is read after the loop — confirm.
float new_scale, new_offset;
// Number of complete groups of four samples, and the samples left over
// after the last complete group.
index_t count = sample_size >> 2;
index_t remain_count = sample_size - (count << 2);
#pragma omp parallel for private(new_scale, new_offset)
for (index_t c = 0; c < channel; ++c) {
new_scale = scale[c] / std::sqrt(var[c] + variance_epsilon);
new_offset = offset[c] - mean[c] * new_scale;
......
......@@ -57,14 +57,16 @@ static void BatchNorm(
BM_BATCH_NORM_MACRO(N, C, H, W, TYPE, NEON);
// Benchmark registrations for float inputs.
// NOTE(review): the four numeric arguments appear to be N, C, H, W, in the
// order used by the BM_BATCH_NORM_MACRO call above — confirm against the
// BM_BATCH_NORM definition.
BM_BATCH_NORM(1, 1, 512, 512, float);
BM_BATCH_NORM(1, 1, 1024, 1024, float);
BM_BATCH_NORM(1, 3, 128, 128, float);
BM_BATCH_NORM(1, 3, 512, 512, float);
BM_BATCH_NORM(1, 3, 1024, 1024, float);
BM_BATCH_NORM(1, 32, 112, 112, float);
BM_BATCH_NORM(1, 64, 256, 256, float);
BM_BATCH_NORM(1, 64, 512, 512, float);
BM_BATCH_NORM(1, 128, 56, 56, float);
BM_BATCH_NORM(1, 128, 256, 256, float);
BM_BATCH_NORM(1, 128, 512, 512, float);
BM_BATCH_NORM(1, 256, 14, 14, float);
BM_BATCH_NORM(1, 512, 14, 14, float);
BM_BATCH_NORM(1, 1024, 7, 7, float);
BM_BATCH_NORM(32, 1, 256, 256, float);
BM_BATCH_NORM(32, 3, 256, 256, float);
} // namespace mace
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册