提交 0c2d8666 编写于 作者: Z zhaojiaying01

optimize norm_op

上级 443ee03b
...@@ -41,7 +41,6 @@ void NormCompute(const NormParam<CPU> &param) { ...@@ -41,7 +41,6 @@ void NormCompute(const NormParam<CPU> &param) {
int axis = param.Axis(); int axis = param.Axis();
const framework::Tensor *input = param.InputX(); const framework::Tensor *input = param.InputX();
framework::Tensor square;
framework::Tensor *norm = param.OutputNorm(); framework::Tensor *norm = param.OutputNorm();
framework::Tensor *out = param.Out(); framework::Tensor *out = param.Out();
...@@ -52,46 +51,40 @@ void NormCompute(const NormParam<CPU> &param) { ...@@ -52,46 +51,40 @@ void NormCompute(const NormParam<CPU> &param) {
int pre, n, post; int pre, n, post;
GetDims(x_dims, axis, &pre, &n, &post); GetDims(x_dims, axis, &pre, &n, &post);
square.Resize(input->dims());
const float *input_ptr = input->data<float>(); const float *input_ptr = input->data<float>();
float *square_ptr = square.mutable_data<float>();
float *norm_ptr = norm->mutable_data<float>(); float *norm_ptr = norm->mutable_data<float>();
float *out_ptr = out->mutable_data<float>(); float *out_ptr = out->mutable_data<float>();
// in_ch = 0; norm = epsilon + x * x
const float *in_tmp = input_ptr; const float *in_tmp = input_ptr;
float *square_tmp = square_ptr; float *norm_tmp = norm_ptr;
for (int i = 0; i < input->numel(); ++i) { for (int i = 0; i < post; ++i) {
float element = *in_tmp; *norm_tmp = epsilon;
*square_tmp = element * element; *norm_tmp += (*in_tmp) * (*in_tmp);
square_tmp++; norm_tmp++;
in_tmp++; in_tmp++;
} }
// const float *norm_tmp = norm_ptr; // in_ch >= 1; norm += x * x
// for (int i = 0; i < norm->numel(); ++i) { for (int j = 1; j < n; ++j) {
// *norm_tmp = 0; norm_tmp = norm_ptr;
// norm_tmp++; for (int i = 0; i < post; ++i) {
// } *norm_tmp += (*in_tmp) * (*in_tmp);
square_tmp = square_ptr;
float *norm_tmp = norm_ptr;
for (int i = 0; i < pre; ++i) {
for (int j = 0; j < post; ++j) {
for (int k = 0; k < n; ++k) {
if (k == 0) {
*norm_tmp = *square_tmp;
} else {
*norm_tmp += *(square_tmp + k * post);
}
}
float sum = *norm_tmp + epsilon;
*norm_tmp = sqrtf(sum);
norm_tmp++; norm_tmp++;
square_tmp++; in_tmp++;
} }
} }
// norm = sqart(norm)
norm_tmp = norm_ptr;
for (int i = 0; i < post; ++i) {
float sqrt = sqrtf(*norm_tmp);
*norm_tmp = sqrt;
norm_tmp++;
}
// out = input / norm
in_tmp = input_ptr; in_tmp = input_ptr;
norm_tmp = norm_ptr; norm_tmp = norm_ptr;
float *out_tmp = out_ptr; float *out_tmp = out_ptr;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册