From 3fe1d6a917edf2f909b422795db6012a4db650d4 Mon Sep 17 00:00:00 2001 From: zhaojiaying01 Date: Fri, 11 Jan 2019 20:22:04 +0800 Subject: [PATCH] fix norm_op when batch != 1 --- .../kernel/central-arm-func/norm_arm_func.h | 53 +++++++++---------- 1 file changed, 26 insertions(+), 27 deletions(-) diff --git a/src/operators/kernel/central-arm-func/norm_arm_func.h b/src/operators/kernel/central-arm-func/norm_arm_func.h index 6f3cf055f6..71b4c5515e 100644 --- a/src/operators/kernel/central-arm-func/norm_arm_func.h +++ b/src/operators/kernel/central-arm-func/norm_arm_func.h @@ -56,47 +56,46 @@ void NormCompute(const NormParam &param) { float *norm_ptr = norm->mutable_data(); float *out_ptr = out->mutable_data(); - // in_ch = 0; norm = epsilon + x * x - const float *in_tmp = input_ptr; - float *norm_tmp = norm_ptr; - for (int i = 0; i < post; ++i) { - *norm_tmp = epsilon; - *norm_tmp += (*in_tmp) * (*in_tmp); - norm_tmp++; - in_tmp++; - } + for (int p = 0; p < pre; ++p) { + const float *in_tmp = input_ptr + p * n * post; + float *norm_tmp = norm_ptr + p * post; - // in_ch >= 1; norm += x * x - for (int j = 1; j < n; ++j) { - norm_tmp = norm_ptr; + // in_ch = 0; norm = epsilon + x * x for (int i = 0; i < post; ++i) { + *norm_tmp = epsilon; *norm_tmp += (*in_tmp) * (*in_tmp); norm_tmp++; in_tmp++; } - } - // norm = sqart(norm) - norm_tmp = norm_ptr; - for (int i = 0; i < post; ++i) { - float sqrt = sqrtf(*norm_tmp); - *norm_tmp = sqrt; - norm_tmp++; - } + // in_ch >= 1; norm += x * x + for (int c = 1; c < n; ++c) { + norm_tmp = norm_ptr + p * post; + for (int i = 0; i < post; ++i) { + *norm_tmp += (*in_tmp) * (*in_tmp); + norm_tmp++; + in_tmp++; + } + } + + // norm = sqart(norm) + norm_tmp = norm_ptr + p * post; + for (int i = 0; i < post; ++i) { + *norm_tmp = sqrtf(*norm_tmp); + norm_tmp++; + } - // out = input / norm - in_tmp = input_ptr; - norm_tmp = norm_ptr; - float *out_tmp = out_ptr; - for (int i = 0; i < pre; ++i) { - for (int k = 0; k < n; ++k) { + // out = input / 
norm + in_tmp = input_ptr + p * n * post; + float *out_tmp = out_ptr + p * n * post; + for (int c = 0; c < n; ++c) { + norm_tmp = norm_ptr + p * post; for (int j = 0; j < post; ++j) { *out_tmp = *in_tmp / *norm_tmp; in_tmp++; norm_tmp++; out_tmp++; } - norm_tmp = norm_ptr + i * post; } } } -- GitLab