Commit 92498aba authored by: E eclipsess

optimize lrn kernel

Parent 78fd1e66
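The functor below evaluates local response normalization across neighbouring channels. Per element it computes (with start = -(n - 1) / 2 as in the code, and sqr the windowed sum of squared inputs that sqr_buffer accumulates):

    \mathrm{out}_{a,b,h,w} = \frac{\mathrm{in}_{a,b,h,w}}{\bigl(k + \alpha \, \mathrm{sqr}_{a,b,h,w}\bigr)^{\beta}},
    \qquad
    \mathrm{sqr}_{a,b,h,w} = \sum_{c \,\in\, [\,b+\mathrm{start},\ b+\mathrm{start}+n\,) \,\cap\, [\,0,\ C\,)} \mathrm{in}_{a,c,h,w}^{2}

The NEON path accumulates the squared inputs four lanes at a time with vmlaq_f32, then applies the (k + alpha * sqr)^(-beta) scaling with the vectorised pow_ps helper.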
@@ -15,10 +15,14 @@ limitations under the License. */
 #ifdef LRN_OP
 #pragma once
 #include "framework/operator.h"
 #include "operators/op_param.h"
+#ifdef __ARM_NEON
+#include "arm_neon.h"
+#include "operators/math/math_func_neon.h"
+#endif
 namespace paddle_mobile {
 namespace operators {
@@ -27,42 +31,138 @@ using namespace framework;
 template <typename T>
 struct LRNFunctor {
   void operator()(const framework::Tensor &input, framework::Tensor *out, int N,
-                  int C, int H, int W, int n, T k, T alpha, T beta) {
-    auto input_ptr = input.data<T>();
+                  int C, int H, int W, int n, float k, float alpha,
+                  float beta) {
+    const float *input_ptr = input.data<float>();
     const int start = -(n - 1) / 2;
     const int end = start + n;
-    auto out_ptr = out->data<T>();
     const int stride0 = C * H * W;
     const int stride1 = H * W;
     const int stride2 = W;
     const int stride3 = 1;
     framework::Tensor sqr_buffer;
-    auto sqr_buffer_ptr = sqr_buffer.mutable_data<T>(input.dims());
-    std::fill(sqr_buffer_ptr, sqr_buffer_ptr + sqr_buffer.numel(), k);
+    auto sqr_buffer_ptr = sqr_buffer.mutable_data<float>(input.dims());
+    std::fill(sqr_buffer_ptr, sqr_buffer_ptr + sqr_buffer.numel(), 0.0);
     for (int a = 0; a < N; a++) {
       for (int b = 0; b < C; b++) {
         for (int index = start; index < end; index++) {
           int channel = b + index;
           if (channel >= 0 && channel < C) {
-            int tmp_u = a * stride0 + b * stride1;
-            int tmp_i = a * stride0 + channel * stride1;
-            for (int c = 0; c < H; c++) {
-              for (int d = 0; d < W; d++) {
-                int tmp = c * stride2 + d;
-                int u = tmp_u + tmp;
-                int i = tmp_i + tmp;
-                sqr_buffer_ptr[u] += alpha * input_ptr[i] * input_ptr[i];
-              }
-            }
+            int tmp_s = a * stride0 + b * stride1;
+            int tmp_c = a * stride0 + channel * stride1;
+#ifdef __ARM_NEON
+            int n4 = stride1 / 4;
+            int m4 = stride1 % 4;
+            float32x4_t sqr0;
+            float32x4_t in0;
+            float32x4_t res0;
+            for (int i = 0; i < n4; i++) {
+              sqr0 = vld1q_f32(sqr_buffer_ptr + tmp_s);
+              in0 = vld1q_f32(input_ptr + tmp_c);
+              res0 = vmlaq_f32(sqr0, in0, in0);
+              vst1q_f32(sqr_buffer_ptr + tmp_s, res0);
+              tmp_s += 4;
+              tmp_c += 4;
+            }
+            for (int i = 0; i < m4; i++) {
+              int s_i = tmp_s + i;
+              int c_i = tmp_c + i;
+              sqr_buffer_ptr[s_i] += input_ptr[c_i] * input_ptr[c_i];
+            }
+#else
+            for (int tmp = 0; tmp < stride1; tmp++) {
+              int s_i = tmp_s + tmp;
+              int c_i = tmp_c + tmp;
+              sqr_buffer_ptr[s_i] += input_ptr[c_i] * input_ptr[c_i];
+            }
+#endif
           }
         }
       }
     }
+    auto out_ptr = out->data<T>();
+#ifdef __ARM_NEON
+    float32x4_t sqr1, sqr2, sqr3, sqr4;
+    float32x4_t alpha4;
+    float32x4_t k4;
+    float32x4_t beta4;
+    float32x4_t res1, res2, res3, res4;
+    float32x4_t in1, in2, in3, in4;
+    beta4 = vdupq_n_f32(beta);
+    alpha4 = vdupq_n_f32(alpha);
+    k4 = vdupq_n_f32(k);
+    auto out_tmp_ptr = out_ptr;
+    int n16 = input.numel() / 16;
+    int m16 = input.numel() % 16;
+    int m16n4 = m16 / 4;
+    int m16m4 = m16 % 4;
+    for (int i = 0; i < n16; i++) {
+      sqr1 = vld1q_f32(sqr_buffer_ptr);
+      sqr2 = vld1q_f32(sqr_buffer_ptr + 4);
+      sqr3 = vld1q_f32(sqr_buffer_ptr + 8);
+      sqr4 = vld1q_f32(sqr_buffer_ptr + 12);
+      in1 = vld1q_f32(input_ptr);
+      in2 = vld1q_f32(input_ptr + 4);
+      in3 = vld1q_f32(input_ptr + 8);
+      in4 = vld1q_f32(input_ptr + 12);
+      sqr1 = vmlaq_f32(k4, sqr1, alpha4);
+      sqr2 = vmlaq_f32(k4, sqr2, alpha4);
+      sqr3 = vmlaq_f32(k4, sqr3, alpha4);
+      sqr4 = vmlaq_f32(k4, sqr4, alpha4);
+      sqr1 = pow_ps(sqr1, -beta4);
+      sqr2 = pow_ps(sqr2, -beta4);
+      sqr3 = pow_ps(sqr3, -beta4);
+      sqr4 = pow_ps(sqr4, -beta4);
+      sqr1 = vmulq_f32(sqr1, in1);
+      sqr2 = vmulq_f32(sqr2, in2);
+      sqr3 = vmulq_f32(sqr3, in3);
+      sqr4 = vmulq_f32(sqr4, in4);
+      vst1q_f32(out_tmp_ptr, sqr1);
+      vst1q_f32(out_tmp_ptr + 4, sqr2);
+      vst1q_f32(out_tmp_ptr + 8, sqr3);
+      vst1q_f32(out_tmp_ptr + 12, sqr4);
+      sqr_buffer_ptr += 4 * 4;
+      input_ptr += 4 * 4;
+      out_tmp_ptr += 4 * 4;
+    }
+    for (int i = 0; i < m16n4; i++) {
+      sqr4 = vld1q_f32(sqr_buffer_ptr);
+      in4 = vld1q_f32(input_ptr);
+      sqr4 = vmlaq_f32(k4, sqr4, alpha4);
+      sqr4 = pow_ps(sqr4, -beta4);
+      sqr4 = vmulq_f32(sqr4, in4);
+      vst1q_f32(out_tmp_ptr, sqr4);
+      sqr_buffer_ptr += 4;
+      input_ptr += 4;
+      out_tmp_ptr += 4;
+    }
+    for (int i = 0; i < m16m4; i++) {
+      out_tmp_ptr[i] = input_ptr[i] / pow(k + alpha * sqr_buffer_ptr[i], beta);
+    }
+#else
     for (int i = 0; i < input.numel(); i++) {
-      out_ptr[i] = input_ptr[i] / pow(sqr_buffer_ptr[i], beta);
+      out_ptr[i] = input_ptr[i] / pow(k + alpha * sqr_buffer_ptr[i], beta);
     }
+#endif
   }
 };
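For sanity-checking the vectorised path above against a plain implementation, the following is a minimal scalar sketch of the same computation. It is an illustration only: it assumes raw NCHW float buffers instead of framework::Tensor, and the helper name lrn_reference is made up for this sketch rather than taken from the repository.

// Minimal scalar reference for the LRN computed above (illustration only;
// assumes plain NCHW float buffers, not framework::Tensor).
#include <cmath>
#include <cstdio>
#include <vector>

void lrn_reference(const std::vector<float> &in, std::vector<float> &out,
                   int N, int C, int H, int W, int n, float k, float alpha,
                   float beta) {
  const int start = -(n - 1) / 2;  // same window start as the kernel
  for (int a = 0; a < N; a++) {
    for (int b = 0; b < C; b++) {
      for (int h = 0; h < H; h++) {
        for (int w = 0; w < W; w++) {
          // Sum of squares over the n-channel window, clipped to [0, C).
          float sqr = 0.0f;
          for (int c = b + start; c < b + start + n; c++) {
            if (c >= 0 && c < C) {
              float v = in[((a * C + c) * H + h) * W + w];
              sqr += v * v;
            }
          }
          int idx = ((a * C + b) * H + h) * W + w;
          out[idx] = in[idx] / std::pow(k + alpha * sqr, beta);
        }
      }
    }
  }
}

int main() {
  // Tiny 1x3x2x2 example with typical attributes n=5, k=1, alpha=1e-4, beta=0.75.
  const int N = 1, C = 3, H = 2, W = 2;
  std::vector<float> in(N * C * H * W);
  for (size_t i = 0; i < in.size(); i++) in[i] = 0.1f * static_cast<float>(i);
  std::vector<float> out(in.size());
  lrn_reference(in, out, N, C, H, W, 5, 1.0f, 1e-4f, 0.75f);
  for (float v : out) std::printf("%f\n", v);
  return 0;
}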
......
#!/usr/bin/env sh
push_fn () {
  sh build.sh android googlenet
  MODELS_PATH="../test/models/*"
  MODELS_SRC="../../test/models"
  IMAGE_PATH="../test/images/*"
  EXE_FILE="../test/build/*"
  EXE_DIR="data/local/tmp/bin"
  adb shell mkdir ${EXE_DIR}
  MODELS_DIR="data/local/tmp/models"
  adb shell mkdir ${MODELS_DIR}
  for file in `ls ${MODELS_SRC}`
  do
    adb shell mkdir ${MODELS_DIR}"/"${file}
  done
  IMAGES_DIR="data/local/tmp/images"
  adb shell mkdir ${IMAGES_DIR}
  LIB_PATH="../build/release/arm-v7a/build/*"
  adb push ${EXE_FILE} ${EXE_DIR}
  adb push ${LIB_PATH} ${EXE_DIR}
  adb push ${IMAGE_PATH} ${IMAGES_DIR}
  adb push ${MODELS_PATH} ${MODELS_DIR}
  echo "test-op or test-net below : "
  adb shell ls /data/local/tmp/bin
  echo "**** choose OP or NET to test ****"
  adb shell "cd /data/local/tmp/bin; LD_LIBRARY_PATH=. ./test-lrn-op"
}
push_fn