提交 ad98be44 编写于 作者: Z zhaojiaying01

fix norm_op when batch != 1

上级 172e03a8
...@@ -56,47 +56,46 @@ void NormCompute(const NormParam<CPU> &param) { ...@@ -56,47 +56,46 @@ void NormCompute(const NormParam<CPU> &param) {
float *norm_ptr = norm->mutable_data<float>(); float *norm_ptr = norm->mutable_data<float>();
float *out_ptr = out->mutable_data<float>(); float *out_ptr = out->mutable_data<float>();
// in_ch = 0; norm = epsilon + x * x for (int p = 0; p < pre; ++p) {
const float *in_tmp = input_ptr; const float *in_tmp = input_ptr + p * n * post;
float *norm_tmp = norm_ptr; float *norm_tmp = norm_ptr + p * post;
for (int i = 0; i < post; ++i) {
*norm_tmp = epsilon;
*norm_tmp += (*in_tmp) * (*in_tmp);
norm_tmp++;
in_tmp++;
}
// in_ch >= 1; norm += x * x // in_ch = 0; norm = epsilon + x * x
for (int j = 1; j < n; ++j) {
norm_tmp = norm_ptr;
for (int i = 0; i < post; ++i) { for (int i = 0; i < post; ++i) {
*norm_tmp = epsilon;
*norm_tmp += (*in_tmp) * (*in_tmp); *norm_tmp += (*in_tmp) * (*in_tmp);
norm_tmp++; norm_tmp++;
in_tmp++; in_tmp++;
} }
}
// norm = sqart(norm) // in_ch >= 1; norm += x * x
norm_tmp = norm_ptr; for (int c = 1; c < n; ++c) {
for (int i = 0; i < post; ++i) { norm_tmp = norm_ptr + p * post;
float sqrt = sqrtf(*norm_tmp); for (int i = 0; i < post; ++i) {
*norm_tmp = sqrt; *norm_tmp += (*in_tmp) * (*in_tmp);
norm_tmp++; norm_tmp++;
} in_tmp++;
}
}
// norm = sqart(norm)
norm_tmp = norm_ptr + p * post;
for (int i = 0; i < post; ++i) {
*norm_tmp = sqrtf(*norm_tmp);
norm_tmp++;
}
// out = input / norm // out = input / norm
in_tmp = input_ptr; in_tmp = input_ptr + p * n * post;
norm_tmp = norm_ptr; float *out_tmp = out_ptr + p * n * post;
float *out_tmp = out_ptr; for (int c = 0; c < n; ++c) {
for (int i = 0; i < pre; ++i) { norm_tmp = norm_ptr + p * post;
for (int k = 0; k < n; ++k) {
for (int j = 0; j < post; ++j) { for (int j = 0; j < post; ++j) {
*out_tmp = *in_tmp / *norm_tmp; *out_tmp = *in_tmp / *norm_tmp;
in_tmp++; in_tmp++;
norm_tmp++; norm_tmp++;
out_tmp++; out_tmp++;
} }
norm_tmp = norm_ptr + i * post;
} }
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册