diff --git a/src/operators/kernel/central-arm-func/norm_arm_func.h b/src/operators/kernel/central-arm-func/norm_arm_func.h index e43c03484712cfb1f8baf96a9ca8cccc062672ca..6f3cf055f67e8f5106cde5499ecb045ec8ed7d6c 100644 --- a/src/operators/kernel/central-arm-func/norm_arm_func.h +++ b/src/operators/kernel/central-arm-func/norm_arm_func.h @@ -41,7 +41,6 @@ void NormCompute(const NormParam ¶m) { int axis = param.Axis(); const framework::Tensor *input = param.InputX(); - framework::Tensor square; framework::Tensor *norm = param.OutputNorm(); framework::Tensor *out = param.Out(); @@ -52,46 +51,40 @@ void NormCompute(const NormParam ¶m) { int pre, n, post; GetDims(x_dims, axis, &pre, &n, &post); - square.Resize(input->dims()); const float *input_ptr = input->data(); - float *square_ptr = square.mutable_data(); float *norm_ptr = norm->mutable_data(); float *out_ptr = out->mutable_data(); + // in_ch = 0; norm = epsilon + x * x const float *in_tmp = input_ptr; - float *square_tmp = square_ptr; - for (int i = 0; i < input->numel(); ++i) { - float element = *in_tmp; - *square_tmp = element * element; - square_tmp++; + float *norm_tmp = norm_ptr; + for (int i = 0; i < post; ++i) { + *norm_tmp = epsilon; + *norm_tmp += (*in_tmp) * (*in_tmp); + norm_tmp++; in_tmp++; } - // const float *norm_tmp = norm_ptr; - // for (int i = 0; i < norm->numel(); ++i) { - // *norm_tmp = 0; - // norm_tmp++; - // } - - square_tmp = square_ptr; - float *norm_tmp = norm_ptr; - for (int i = 0; i < pre; ++i) { - for (int j = 0; j < post; ++j) { - for (int k = 0; k < n; ++k) { - if (k == 0) { - *norm_tmp = *square_tmp; - } else { - *norm_tmp += *(square_tmp + k * post); - } - } - float sum = *norm_tmp + epsilon; - *norm_tmp = sqrtf(sum); + // in_ch >= 1; norm += x * x + for (int j = 1; j < n; ++j) { + norm_tmp = norm_ptr; + for (int i = 0; i < post; ++i) { + *norm_tmp += (*in_tmp) * (*in_tmp); norm_tmp++; - square_tmp++; + in_tmp++; } } + // norm = sqart(norm) + norm_tmp = norm_ptr; + for (int i = 0; i < post; ++i) { + float sqrt = sqrtf(*norm_tmp); + *norm_tmp = sqrt; + norm_tmp++; + } + + // out = input / norm in_tmp = input_ptr; norm_tmp = norm_ptr; float *out_tmp = out_ptr;