提交 1a5f1239 编写于 作者: H HappyAngel 提交者: GitHub

[arm]add fc_relu implement, test=develop (#2765)

* fix, test=develop

* add fc_relu, test=develop
上级 9b84dc91
......@@ -21,128 +21,179 @@ namespace arm {
namespace math {
template <>
void fill_bias_fc<float>(float *out, const float *bias, int num, int channel) {
void fill_bias_fc<float>(
float *out, const float *bias, int num, int channel, bool flag_relu) {
int cnt = channel >> 4;
int remain = channel & 15;
for (int j = 0; j < num; ++j) {
const float *ptr_bias = bias;
float *ptr_out = out + j * channel;
float32x4_t vout1;
float32x4_t vout2;
float32x4_t vout3;
float32x4_t vout4;
for (int i = 0; i < cnt; ++i) {
float32x4_t vin1 = vld1q_f32(ptr_out);
float32x4_t vb1 = vld1q_f32(ptr_bias);
float32x4_t vin2 = vld1q_f32(ptr_out + 4);
float32x4_t vb2 = vld1q_f32(ptr_bias + 4);
float32x4_t vin3 = vld1q_f32(ptr_out + 8);
float32x4_t vb3 = vld1q_f32(ptr_bias + 8);
float32x4_t vin4 = vld1q_f32(ptr_out + 12);
float32x4_t vb4 = vld1q_f32(ptr_bias + 12);
vout1 = vaddq_f32(vin1, vb1);
vout2 = vaddq_f32(vin2, vb2);
vout3 = vaddq_f32(vin3, vb3);
vout4 = vaddq_f32(vin4, vb4);
vst1q_f32(ptr_out, vout1);
vst1q_f32(ptr_out + 4, vout2);
vst1q_f32(ptr_out + 8, vout3);
vst1q_f32(ptr_out + 12, vout4);
ptr_out += 16;
ptr_bias += 16;
if (flag_relu) {
float32x4_t vzero = vdupq_n_f32(0.f);
for (int j = 0; j < num; ++j) {
const float *ptr_bias = bias;
float *ptr_out = out + j * channel;
for (int i = 0; i < cnt; ++i) {
float32x4_t vin1 = vld1q_f32(ptr_out);
float32x4_t vb1 = vld1q_f32(ptr_bias);
float32x4_t vin2 = vld1q_f32(ptr_out + 4);
float32x4_t vb2 = vld1q_f32(ptr_bias + 4);
float32x4_t vin3 = vld1q_f32(ptr_out + 8);
float32x4_t vb3 = vld1q_f32(ptr_bias + 8);
float32x4_t vin4 = vld1q_f32(ptr_out + 12);
float32x4_t vb4 = vld1q_f32(ptr_bias + 12);
float32x4_t vout1 = vaddq_f32(vin1, vb1);
float32x4_t vout2 = vaddq_f32(vin2, vb2);
float32x4_t vout3 = vaddq_f32(vin3, vb3);
float32x4_t vout4 = vaddq_f32(vin4, vb4);
vout1 = vmaxq_f32(vout1, vzero);
vout2 = vmaxq_f32(vout2, vzero);
vout3 = vmaxq_f32(vout3, vzero);
vout4 = vmaxq_f32(vout4, vzero);
vst1q_f32(ptr_out, vout1);
vst1q_f32(ptr_out + 4, vout2);
vst1q_f32(ptr_out + 8, vout3);
vst1q_f32(ptr_out + 12, vout4);
ptr_out += 16;
ptr_bias += 16;
}
for (int i = 0; i < remain; ++i) {
*ptr_out += *(ptr_bias++);
*ptr_out = *ptr_out > 0.f ? *ptr_out : 0.f;
ptr_out++;
}
}
#if 0
if (cnt > 0) {
asm(
"1: \n"
"vld1.32 {d0-d1}, [%[ptr_out]] @ load data\n"
"vld1.32 {d2-d3}, [%[ptr_bias]]! @ load data\n"
"vadd.f32 q2, q0, q1 @ add bias\n"
"vst1.32 {d4-d5}, [%[ptr_out]]! @ store result\n"
"subs %[cnt], #1 @ loop count -1\n"
"bne 1b @ jump to main loop\n"
:[ptr_out] "+r"(ptr_out), [ptr_bias] "+r"(ptr_bias), \
[cnt] "+r"(cnt)
:
:"q0", "q1", "q2"
);
}
#endif
for (int i = 0; i < remain; ++i) {
*(ptr_out++) += *(ptr_bias++);
} else {
for (int j = 0; j < num; ++j) {
const float *ptr_bias = bias;
float *ptr_out = out + j * channel;
for (int i = 0; i < cnt; ++i) {
float32x4_t vin1 = vld1q_f32(ptr_out);
float32x4_t vb1 = vld1q_f32(ptr_bias);
float32x4_t vin2 = vld1q_f32(ptr_out + 4);
float32x4_t vb2 = vld1q_f32(ptr_bias + 4);
float32x4_t vin3 = vld1q_f32(ptr_out + 8);
float32x4_t vb3 = vld1q_f32(ptr_bias + 8);
float32x4_t vin4 = vld1q_f32(ptr_out + 12);
float32x4_t vb4 = vld1q_f32(ptr_bias + 12);
float32x4_t vout1 = vaddq_f32(vin1, vb1);
float32x4_t vout2 = vaddq_f32(vin2, vb2);
float32x4_t vout3 = vaddq_f32(vin3, vb3);
float32x4_t vout4 = vaddq_f32(vin4, vb4);
vst1q_f32(ptr_out, vout1);
vst1q_f32(ptr_out + 4, vout2);
vst1q_f32(ptr_out + 8, vout3);
vst1q_f32(ptr_out + 12, vout4);
ptr_out += 16;
ptr_bias += 16;
}
for (int i = 0; i < remain; ++i) {
*(ptr_out++) += *(ptr_bias++);
}
}
}
}
template <>
void fill_bias_fc<int>(int *out, const int *bias, int num, int channel) {
void fill_bias_fc<int>(
int *out, const int *bias, int num, int channel, bool flag_relu) {
int cnt = channel >> 4;
int remain = channel & 15;
for (int j = 0; j < num; ++j) {
const int *ptr_bias = bias;
int *ptr_out = out + j * channel;
int32x4_t vout1;
int32x4_t vout2;
int32x4_t vout3;
int32x4_t vout4;
for (int i = 0; i < cnt; ++i) {
int32x4_t vin1 = vld1q_s32(ptr_out);
int32x4_t vb1 = vld1q_s32(ptr_bias);
int32x4_t vin2 = vld1q_s32(ptr_out + 4);
int32x4_t vb2 = vld1q_s32(ptr_bias + 4);
int32x4_t vin3 = vld1q_s32(ptr_out + 8);
int32x4_t vb3 = vld1q_s32(ptr_bias + 8);
int32x4_t vin4 = vld1q_s32(ptr_out + 12);
int32x4_t vb4 = vld1q_s32(ptr_bias + 12);
vout1 = vaddq_s32(vin1, vb1);
vout2 = vaddq_s32(vin2, vb2);
vout3 = vaddq_s32(vin3, vb3);
vout4 = vaddq_s32(vin4, vb4);
vst1q_s32(ptr_out, vout1);
vst1q_s32(ptr_out + 4, vout2);
vst1q_s32(ptr_out + 8, vout3);
vst1q_s32(ptr_out + 12, vout4);
ptr_out += 16;
ptr_bias += 16;
}
#if 0
if (cnt > 0) {
asm(
"1: \n"
"vld1.32 {d0-d1}, [%[ptr_out]] @ load data\n"
"vld1.32 {d2-d3}, [%[ptr_bias]]! @ load data\n"
"vadd.s32 q2, q0, q1 @ add bias\n"
"vst1.32 {d4-d5}, [%[ptr_out]]! @ store result\n"
"subs %[cnt], #1 @ loop count -1\n"
"bne 1b @ jump to main loop\n"
:[ptr_out] "+r"(ptr_out), [ptr_bias] "+r"(ptr_bias), \
[cnt] "+r"(cnt)
:
:"q0", "q1", "q2"
);
if (flag_relu) {
for (int j = 0; j < num; ++j) {
const int *ptr_bias = bias;
int *ptr_out = out + j * channel;
int32x4_t vzero = vdupq_n_s32(0);
for (int i = 0; i < cnt; ++i) {
int32x4_t vin1 = vld1q_s32(ptr_out);
int32x4_t vb1 = vld1q_s32(ptr_bias);
int32x4_t vin2 = vld1q_s32(ptr_out + 4);
int32x4_t vb2 = vld1q_s32(ptr_bias + 4);
int32x4_t vin3 = vld1q_s32(ptr_out + 8);
int32x4_t vb3 = vld1q_s32(ptr_bias + 8);
int32x4_t vin4 = vld1q_s32(ptr_out + 12);
int32x4_t vb4 = vld1q_s32(ptr_bias + 12);
int32x4_t vout1 = vaddq_s32(vin1, vb1);
int32x4_t vout2 = vaddq_s32(vin2, vb2);
int32x4_t vout3 = vaddq_s32(vin3, vb3);
int32x4_t vout4 = vaddq_s32(vin4, vb4);
vout1 = vmaxq_s32(vout1, vzero);
vout2 = vmaxq_s32(vout2, vzero);
vout3 = vmaxq_s32(vout3, vzero);
vout4 = vmaxq_s32(vout4, vzero);
vst1q_s32(ptr_out, vout1);
vst1q_s32(ptr_out + 4, vout2);
vst1q_s32(ptr_out + 8, vout3);
vst1q_s32(ptr_out + 12, vout4);
ptr_out += 16;
ptr_bias += 16;
}
for (int i = 0; i < remain; ++i) {
*ptr_out += *(ptr_bias++);
*ptr_out = *ptr_out > 0 ? *ptr_out : 0;
ptr_out++;
}
}
#endif
for (int i = 0; i < remain; ++i) {
*(ptr_out++) += *(ptr_bias++);
} else {
for (int j = 0; j < num; ++j) {
const int *ptr_bias = bias;
int *ptr_out = out + j * channel;
int32x4_t vout1;
int32x4_t vout2;
int32x4_t vout3;
int32x4_t vout4;
for (int i = 0; i < cnt; ++i) {
int32x4_t vin1 = vld1q_s32(ptr_out);
int32x4_t vb1 = vld1q_s32(ptr_bias);
int32x4_t vin2 = vld1q_s32(ptr_out + 4);
int32x4_t vb2 = vld1q_s32(ptr_bias + 4);
int32x4_t vin3 = vld1q_s32(ptr_out + 8);
int32x4_t vb3 = vld1q_s32(ptr_bias + 8);
int32x4_t vin4 = vld1q_s32(ptr_out + 12);
int32x4_t vb4 = vld1q_s32(ptr_bias + 12);
vout1 = vaddq_s32(vin1, vb1);
vout2 = vaddq_s32(vin2, vb2);
vout3 = vaddq_s32(vin3, vb3);
vout4 = vaddq_s32(vin4, vb4);
vst1q_s32(ptr_out, vout1);
vst1q_s32(ptr_out + 4, vout2);
vst1q_s32(ptr_out + 8, vout3);
vst1q_s32(ptr_out + 12, vout4);
ptr_out += 16;
ptr_bias += 16;
}
for (int i = 0; i < remain; ++i) {
*(ptr_out++) += *(ptr_bias++);
}
}
}
}
......
......@@ -356,7 +356,8 @@ inline float32x4_t pow_ps(float32x4_t a, float32x4_t b) {
}
template <typename T>
void fill_bias_fc(T* tensor, const T* bias, int num, int channel);
void fill_bias_fc(
T* tensor, const T* bias, int num, int channel, bool flag_relu);
template <lite_api::ActivationType Act = lite_api::ActivationType::kIndentity>
inline float32x4_t vactive_f32(const float32x4_t& x) {
......
......@@ -93,6 +93,10 @@ void FcCompute<PRECISION(kFloat), PRECISION(kFloat)>::Run() {
if (flag_trans_bias_) {
b_data = bias_.data<float>();
}
bool flag_relu = false;
if (param.activation_type == "relu") {
flag_relu = true;
}
if (flag_gemm_) {
operators::ActivationParam act_param;
act_param.has_active = false;
......@@ -115,7 +119,7 @@ void FcCompute<PRECISION(kFloat), PRECISION(kFloat)>::Run() {
&ctx);
if (param.bias) {
CHECK_EQ(param.bias->numel(), n_);
lite::arm::math::fill_bias_fc(o_data, b_data, m_, n_);
lite::arm::math::fill_bias_fc(o_data, b_data, m_, n_, flag_relu);
}
} else {
for (int i = 0; i < m_; ++i) {
......@@ -129,7 +133,7 @@ void FcCompute<PRECISION(kFloat), PRECISION(kFloat)>::Run() {
k_,
param.bias != nullptr,
b_data,
false,
flag_relu,
&ctx);
}
}
......@@ -148,6 +152,10 @@ void FcCompute<PRECISION(kInt8), PRECISION(kFloat)>::Run() {
if (flag_trans_bias_) {
b_data = bias_.data<float>();
}
bool flag_relu = false;
if (param.activation_type == "relu") {
flag_relu = true;
}
if (flag_gemm_) {
lite::arm::math::gemm_s8(false,
false,
......@@ -164,7 +172,7 @@ void FcCompute<PRECISION(kInt8), PRECISION(kFloat)>::Run() {
&ctx);
if (param.bias) {
CHECK_EQ(param.bias->numel(), n_);
lite::arm::math::fill_bias_fc(o_data, b_data, m_, n_);
lite::arm::math::fill_bias_fc(o_data, b_data, m_, n_, flag_relu);
}
} else {
for (int i = 0; i < m_; ++i) {
......@@ -179,7 +187,7 @@ void FcCompute<PRECISION(kInt8), PRECISION(kFloat)>::Run() {
scale_.data(),
param.bias != nullptr,
b_data,
false,
flag_relu,
&ctx);
}
}
......@@ -198,6 +206,10 @@ void FcCompute<PRECISION(kInt8), PRECISION(kInt8)>::Run() {
if (flag_trans_bias_) {
b_data = bias_.data<float>();
}
bool flag_relu = false;
if (param.activation_type == "relu") {
flag_relu = true;
}
if (flag_gemm_) {
CHECK(!param.bias) << "fc int8 kernel with int8 output using gemm kernel "
"must not have bias";
......@@ -211,7 +223,7 @@ void FcCompute<PRECISION(kInt8), PRECISION(kInt8)>::Run() {
o_data,
nullptr,
false,
false,
flag_relu,
scale_.data(),
&ctx);
} else {
......@@ -227,7 +239,7 @@ void FcCompute<PRECISION(kInt8), PRECISION(kInt8)>::Run() {
scale_.data(),
param.bias != nullptr,
b_data,
false,
flag_relu,
&ctx);
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册