Unverified commit 7cea4e9e, authored by HappyAngel, committed via GitHub

[arm] add fc_relu implementation, test=develop (#2765)

* fix, test=develop

* add fc_relu, test=develop

Parent commit: d8143103
...@@ -21,19 +21,16 @@ namespace arm { ...@@ -21,19 +21,16 @@ namespace arm {
namespace math { namespace math {
template <> template <>
void fill_bias_fc<float>(float *out, const float *bias, int num, int channel) { void fill_bias_fc<float>(
float *out, const float *bias, int num, int channel, bool flag_relu) {
int cnt = channel >> 4; int cnt = channel >> 4;
int remain = channel & 15; int remain = channel & 15;
if (flag_relu) {
float32x4_t vzero = vdupq_n_f32(0.f);
for (int j = 0; j < num; ++j) { for (int j = 0; j < num; ++j) {
const float *ptr_bias = bias; const float *ptr_bias = bias;
float *ptr_out = out + j * channel; float *ptr_out = out + j * channel;
float32x4_t vout1;
float32x4_t vout2;
float32x4_t vout3;
float32x4_t vout4;
for (int i = 0; i < cnt; ++i) { for (int i = 0; i < cnt; ++i) {
float32x4_t vin1 = vld1q_f32(ptr_out); float32x4_t vin1 = vld1q_f32(ptr_out);
float32x4_t vb1 = vld1q_f32(ptr_bias); float32x4_t vb1 = vld1q_f32(ptr_bias);
...@@ -47,10 +44,15 @@ void fill_bias_fc<float>(float *out, const float *bias, int num, int channel) { ...@@ -47,10 +44,15 @@ void fill_bias_fc<float>(float *out, const float *bias, int num, int channel) {
float32x4_t vin4 = vld1q_f32(ptr_out + 12); float32x4_t vin4 = vld1q_f32(ptr_out + 12);
float32x4_t vb4 = vld1q_f32(ptr_bias + 12); float32x4_t vb4 = vld1q_f32(ptr_bias + 12);
vout1 = vaddq_f32(vin1, vb1); float32x4_t vout1 = vaddq_f32(vin1, vb1);
vout2 = vaddq_f32(vin2, vb2); float32x4_t vout2 = vaddq_f32(vin2, vb2);
vout3 = vaddq_f32(vin3, vb3); float32x4_t vout3 = vaddq_f32(vin3, vb3);
vout4 = vaddq_f32(vin4, vb4); float32x4_t vout4 = vaddq_f32(vin4, vb4);
vout1 = vmaxq_f32(vout1, vzero);
vout2 = vmaxq_f32(vout2, vzero);
vout3 = vmaxq_f32(vout3, vzero);
vout4 = vmaxq_f32(vout4, vzero);
vst1q_f32(ptr_out, vout1); vst1q_f32(ptr_out, vout1);
vst1q_f32(ptr_out + 4, vout2); vst1q_f32(ptr_out + 4, vout2);
...@@ -60,34 +62,100 @@ void fill_bias_fc<float>(float *out, const float *bias, int num, int channel) { ...@@ -60,34 +62,100 @@ void fill_bias_fc<float>(float *out, const float *bias, int num, int channel) {
ptr_out += 16; ptr_out += 16;
ptr_bias += 16; ptr_bias += 16;
} }
#if 0 for (int i = 0; i < remain; ++i) {
if (cnt > 0) { *ptr_out += *(ptr_bias++);
asm( *ptr_out = *ptr_out > 0.f ? *ptr_out : 0.f;
"1: \n" ptr_out++;
"vld1.32 {d0-d1}, [%[ptr_out]] @ load data\n" }
"vld1.32 {d2-d3}, [%[ptr_bias]]! @ load data\n" }
"vadd.f32 q2, q0, q1 @ add bias\n" } else {
"vst1.32 {d4-d5}, [%[ptr_out]]! @ store result\n" for (int j = 0; j < num; ++j) {
"subs %[cnt], #1 @ loop count -1\n" const float *ptr_bias = bias;
"bne 1b @ jump to main loop\n" float *ptr_out = out + j * channel;
:[ptr_out] "+r"(ptr_out), [ptr_bias] "+r"(ptr_bias), \
[cnt] "+r"(cnt) for (int i = 0; i < cnt; ++i) {
: float32x4_t vin1 = vld1q_f32(ptr_out);
:"q0", "q1", "q2" float32x4_t vb1 = vld1q_f32(ptr_bias);
);
float32x4_t vin2 = vld1q_f32(ptr_out + 4);
float32x4_t vb2 = vld1q_f32(ptr_bias + 4);
float32x4_t vin3 = vld1q_f32(ptr_out + 8);
float32x4_t vb3 = vld1q_f32(ptr_bias + 8);
float32x4_t vin4 = vld1q_f32(ptr_out + 12);
float32x4_t vb4 = vld1q_f32(ptr_bias + 12);
float32x4_t vout1 = vaddq_f32(vin1, vb1);
float32x4_t vout2 = vaddq_f32(vin2, vb2);
float32x4_t vout3 = vaddq_f32(vin3, vb3);
float32x4_t vout4 = vaddq_f32(vin4, vb4);
vst1q_f32(ptr_out, vout1);
vst1q_f32(ptr_out + 4, vout2);
vst1q_f32(ptr_out + 8, vout3);
vst1q_f32(ptr_out + 12, vout4);
ptr_out += 16;
ptr_bias += 16;
} }
#endif
for (int i = 0; i < remain; ++i) { for (int i = 0; i < remain; ++i) {
*(ptr_out++) += *(ptr_bias++); *(ptr_out++) += *(ptr_bias++);
} }
} }
}
} }
template <> template <>
void fill_bias_fc<int>(int *out, const int *bias, int num, int channel) { void fill_bias_fc<int>(
int *out, const int *bias, int num, int channel, bool flag_relu) {
int cnt = channel >> 4; int cnt = channel >> 4;
int remain = channel & 15; int remain = channel & 15;
if (flag_relu) {
for (int j = 0; j < num; ++j) {
const int *ptr_bias = bias;
int *ptr_out = out + j * channel;
int32x4_t vzero = vdupq_n_s32(0);
for (int i = 0; i < cnt; ++i) {
int32x4_t vin1 = vld1q_s32(ptr_out);
int32x4_t vb1 = vld1q_s32(ptr_bias);
int32x4_t vin2 = vld1q_s32(ptr_out + 4);
int32x4_t vb2 = vld1q_s32(ptr_bias + 4);
int32x4_t vin3 = vld1q_s32(ptr_out + 8);
int32x4_t vb3 = vld1q_s32(ptr_bias + 8);
int32x4_t vin4 = vld1q_s32(ptr_out + 12);
int32x4_t vb4 = vld1q_s32(ptr_bias + 12);
int32x4_t vout1 = vaddq_s32(vin1, vb1);
int32x4_t vout2 = vaddq_s32(vin2, vb2);
int32x4_t vout3 = vaddq_s32(vin3, vb3);
int32x4_t vout4 = vaddq_s32(vin4, vb4);
vout1 = vmaxq_s32(vout1, vzero);
vout2 = vmaxq_s32(vout2, vzero);
vout3 = vmaxq_s32(vout3, vzero);
vout4 = vmaxq_s32(vout4, vzero);
vst1q_s32(ptr_out, vout1);
vst1q_s32(ptr_out + 4, vout2);
vst1q_s32(ptr_out + 8, vout3);
vst1q_s32(ptr_out + 12, vout4);
ptr_out += 16;
ptr_bias += 16;
}
for (int i = 0; i < remain; ++i) {
*ptr_out += *(ptr_bias++);
*ptr_out = *ptr_out > 0 ? *ptr_out : 0;
ptr_out++;
}
}
} else {
for (int j = 0; j < num; ++j) { for (int j = 0; j < num; ++j) {
const int *ptr_bias = bias; const int *ptr_bias = bias;
int *ptr_out = out + j * channel; int *ptr_out = out + j * channel;
...@@ -123,28 +191,11 @@ void fill_bias_fc<int>(int *out, const int *bias, int num, int channel) { ...@@ -123,28 +191,11 @@ void fill_bias_fc<int>(int *out, const int *bias, int num, int channel) {
ptr_out += 16; ptr_out += 16;
ptr_bias += 16; ptr_bias += 16;
} }
#if 0
if (cnt > 0) {
asm(
"1: \n"
"vld1.32 {d0-d1}, [%[ptr_out]] @ load data\n"
"vld1.32 {d2-d3}, [%[ptr_bias]]! @ load data\n"
"vadd.s32 q2, q0, q1 @ add bias\n"
"vst1.32 {d4-d5}, [%[ptr_out]]! @ store result\n"
"subs %[cnt], #1 @ loop count -1\n"
"bne 1b @ jump to main loop\n"
:[ptr_out] "+r"(ptr_out), [ptr_bias] "+r"(ptr_bias), \
[cnt] "+r"(cnt)
:
:"q0", "q1", "q2"
);
}
#endif
for (int i = 0; i < remain; ++i) { for (int i = 0; i < remain; ++i) {
*(ptr_out++) += *(ptr_bias++); *(ptr_out++) += *(ptr_bias++);
} }
} }
}
} }
} // namespace math } // namespace math
......
...@@ -356,7 +356,8 @@ inline float32x4_t pow_ps(float32x4_t a, float32x4_t b) { ...@@ -356,7 +356,8 @@ inline float32x4_t pow_ps(float32x4_t a, float32x4_t b) {
} }
// Adds `bias` (channel elements) to each of the `num` rows of `tensor`
// (row-major, num x channel) in place; when `flag_relu` is true a ReLU is
// fused into the same pass. Specialized for float and int in the .cc file.
template <typename T>
void fill_bias_fc(
    T* tensor, const T* bias, int num, int channel, bool flag_relu);
template <lite_api::ActivationType Act = lite_api::ActivationType::kIndentity> template <lite_api::ActivationType Act = lite_api::ActivationType::kIndentity>
inline float32x4_t vactive_f32(const float32x4_t& x) { inline float32x4_t vactive_f32(const float32x4_t& x) {
......
...@@ -93,6 +93,10 @@ void FcCompute<PRECISION(kFloat), PRECISION(kFloat)>::Run() { ...@@ -93,6 +93,10 @@ void FcCompute<PRECISION(kFloat), PRECISION(kFloat)>::Run() {
if (flag_trans_bias_) { if (flag_trans_bias_) {
b_data = bias_.data<float>(); b_data = bias_.data<float>();
} }
bool flag_relu = false;
if (param.activation_type == "relu") {
flag_relu = true;
}
if (flag_gemm_) { if (flag_gemm_) {
operators::ActivationParam act_param; operators::ActivationParam act_param;
act_param.has_active = false; act_param.has_active = false;
...@@ -115,7 +119,7 @@ void FcCompute<PRECISION(kFloat), PRECISION(kFloat)>::Run() { ...@@ -115,7 +119,7 @@ void FcCompute<PRECISION(kFloat), PRECISION(kFloat)>::Run() {
&ctx); &ctx);
if (param.bias) { if (param.bias) {
CHECK_EQ(param.bias->numel(), n_); CHECK_EQ(param.bias->numel(), n_);
lite::arm::math::fill_bias_fc(o_data, b_data, m_, n_); lite::arm::math::fill_bias_fc(o_data, b_data, m_, n_, flag_relu);
} }
} else { } else {
for (int i = 0; i < m_; ++i) { for (int i = 0; i < m_; ++i) {
...@@ -129,7 +133,7 @@ void FcCompute<PRECISION(kFloat), PRECISION(kFloat)>::Run() { ...@@ -129,7 +133,7 @@ void FcCompute<PRECISION(kFloat), PRECISION(kFloat)>::Run() {
k_, k_,
param.bias != nullptr, param.bias != nullptr,
b_data, b_data,
false, flag_relu,
&ctx); &ctx);
} }
} }
...@@ -148,6 +152,10 @@ void FcCompute<PRECISION(kInt8), PRECISION(kFloat)>::Run() { ...@@ -148,6 +152,10 @@ void FcCompute<PRECISION(kInt8), PRECISION(kFloat)>::Run() {
if (flag_trans_bias_) { if (flag_trans_bias_) {
b_data = bias_.data<float>(); b_data = bias_.data<float>();
} }
bool flag_relu = false;
if (param.activation_type == "relu") {
flag_relu = true;
}
if (flag_gemm_) { if (flag_gemm_) {
lite::arm::math::gemm_s8(false, lite::arm::math::gemm_s8(false,
false, false,
...@@ -164,7 +172,7 @@ void FcCompute<PRECISION(kInt8), PRECISION(kFloat)>::Run() { ...@@ -164,7 +172,7 @@ void FcCompute<PRECISION(kInt8), PRECISION(kFloat)>::Run() {
&ctx); &ctx);
if (param.bias) { if (param.bias) {
CHECK_EQ(param.bias->numel(), n_); CHECK_EQ(param.bias->numel(), n_);
lite::arm::math::fill_bias_fc(o_data, b_data, m_, n_); lite::arm::math::fill_bias_fc(o_data, b_data, m_, n_, flag_relu);
} }
} else { } else {
for (int i = 0; i < m_; ++i) { for (int i = 0; i < m_; ++i) {
...@@ -179,7 +187,7 @@ void FcCompute<PRECISION(kInt8), PRECISION(kFloat)>::Run() { ...@@ -179,7 +187,7 @@ void FcCompute<PRECISION(kInt8), PRECISION(kFloat)>::Run() {
scale_.data(), scale_.data(),
param.bias != nullptr, param.bias != nullptr,
b_data, b_data,
false, flag_relu,
&ctx); &ctx);
} }
} }
...@@ -198,6 +206,10 @@ void FcCompute<PRECISION(kInt8), PRECISION(kInt8)>::Run() { ...@@ -198,6 +206,10 @@ void FcCompute<PRECISION(kInt8), PRECISION(kInt8)>::Run() {
if (flag_trans_bias_) { if (flag_trans_bias_) {
b_data = bias_.data<float>(); b_data = bias_.data<float>();
} }
bool flag_relu = false;
if (param.activation_type == "relu") {
flag_relu = true;
}
if (flag_gemm_) { if (flag_gemm_) {
CHECK(!param.bias) << "fc int8 kernel with int8 output using gemm kernel " CHECK(!param.bias) << "fc int8 kernel with int8 output using gemm kernel "
"must not have bias"; "must not have bias";
...@@ -211,7 +223,7 @@ void FcCompute<PRECISION(kInt8), PRECISION(kInt8)>::Run() { ...@@ -211,7 +223,7 @@ void FcCompute<PRECISION(kInt8), PRECISION(kInt8)>::Run() {
o_data, o_data,
nullptr, nullptr,
false, false,
false, flag_relu,
scale_.data(), scale_.data(),
&ctx); &ctx);
} else { } else {
...@@ -227,7 +239,7 @@ void FcCompute<PRECISION(kInt8), PRECISION(kInt8)>::Run() { ...@@ -227,7 +239,7 @@ void FcCompute<PRECISION(kInt8), PRECISION(kInt8)>::Run() {
scale_.data(), scale_.data(),
param.bias != nullptr, param.bias != nullptr,
b_data, b_data,
false, flag_relu,
&ctx); &ctx);
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册