diff --git a/lite/backends/arm/math/funcs.cc b/lite/backends/arm/math/funcs.cc index e4425ade2efebdaad9136f75c39493f2bd3df4ca..8d20e5242e556c86a1283a64ff9ccf51e2efa247 100644 --- a/lite/backends/arm/math/funcs.cc +++ b/lite/backends/arm/math/funcs.cc @@ -21,128 +21,179 @@ namespace arm { namespace math { template <> -void fill_bias_fc(float *out, const float *bias, int num, int channel) { +void fill_bias_fc( + float *out, const float *bias, int num, int channel, bool flag_relu) { int cnt = channel >> 4; int remain = channel & 15; - - for (int j = 0; j < num; ++j) { - const float *ptr_bias = bias; - float *ptr_out = out + j * channel; - - float32x4_t vout1; - float32x4_t vout2; - float32x4_t vout3; - float32x4_t vout4; - - for (int i = 0; i < cnt; ++i) { - float32x4_t vin1 = vld1q_f32(ptr_out); - float32x4_t vb1 = vld1q_f32(ptr_bias); - - float32x4_t vin2 = vld1q_f32(ptr_out + 4); - float32x4_t vb2 = vld1q_f32(ptr_bias + 4); - - float32x4_t vin3 = vld1q_f32(ptr_out + 8); - float32x4_t vb3 = vld1q_f32(ptr_bias + 8); - - float32x4_t vin4 = vld1q_f32(ptr_out + 12); - float32x4_t vb4 = vld1q_f32(ptr_bias + 12); - - vout1 = vaddq_f32(vin1, vb1); - vout2 = vaddq_f32(vin2, vb2); - vout3 = vaddq_f32(vin3, vb3); - vout4 = vaddq_f32(vin4, vb4); - - vst1q_f32(ptr_out, vout1); - vst1q_f32(ptr_out + 4, vout2); - vst1q_f32(ptr_out + 8, vout3); - vst1q_f32(ptr_out + 12, vout4); - - ptr_out += 16; - ptr_bias += 16; + if (flag_relu) { + float32x4_t vzero = vdupq_n_f32(0.f); + for (int j = 0; j < num; ++j) { + const float *ptr_bias = bias; + float *ptr_out = out + j * channel; + + for (int i = 0; i < cnt; ++i) { + float32x4_t vin1 = vld1q_f32(ptr_out); + float32x4_t vb1 = vld1q_f32(ptr_bias); + + float32x4_t vin2 = vld1q_f32(ptr_out + 4); + float32x4_t vb2 = vld1q_f32(ptr_bias + 4); + + float32x4_t vin3 = vld1q_f32(ptr_out + 8); + float32x4_t vb3 = vld1q_f32(ptr_bias + 8); + + float32x4_t vin4 = vld1q_f32(ptr_out + 12); + float32x4_t vb4 = vld1q_f32(ptr_bias + 12); + + float32x4_t vout1 = vaddq_f32(vin1, vb1); + float32x4_t vout2 = vaddq_f32(vin2, vb2); + float32x4_t vout3 = vaddq_f32(vin3, vb3); + float32x4_t vout4 = vaddq_f32(vin4, vb4); + + vout1 = vmaxq_f32(vout1, vzero); + vout2 = vmaxq_f32(vout2, vzero); + vout3 = vmaxq_f32(vout3, vzero); + vout4 = vmaxq_f32(vout4, vzero); + + vst1q_f32(ptr_out, vout1); + vst1q_f32(ptr_out + 4, vout2); + vst1q_f32(ptr_out + 8, vout3); + vst1q_f32(ptr_out + 12, vout4); + + ptr_out += 16; + ptr_bias += 16; + } + for (int i = 0; i < remain; ++i) { + *ptr_out += *(ptr_bias++); + *ptr_out = *ptr_out > 0.f ? *ptr_out : 0.f; + ptr_out++; + } } -#if 0 - if (cnt > 0) { - asm( - "1: \n" - "vld1.32 {d0-d1}, [%[ptr_out]] @ load data\n" - "vld1.32 {d2-d3}, [%[ptr_bias]]! @ load data\n" - "vadd.f32 q2, q0, q1 @ add bias\n" - "vst1.32 {d4-d5}, [%[ptr_out]]! @ store result\n" - "subs %[cnt], #1 @ loop count -1\n" - "bne 1b @ jump to main loop\n" - :[ptr_out] "+r"(ptr_out), [ptr_bias] "+r"(ptr_bias), \ - [cnt] "+r"(cnt) - : - :"q0", "q1", "q2" - ); - } -#endif - for (int i = 0; i < remain; ++i) { - *(ptr_out++) += *(ptr_bias++); + } else { + for (int j = 0; j < num; ++j) { + const float *ptr_bias = bias; + float *ptr_out = out + j * channel; + + for (int i = 0; i < cnt; ++i) { + float32x4_t vin1 = vld1q_f32(ptr_out); + float32x4_t vb1 = vld1q_f32(ptr_bias); + + float32x4_t vin2 = vld1q_f32(ptr_out + 4); + float32x4_t vb2 = vld1q_f32(ptr_bias + 4); + + float32x4_t vin3 = vld1q_f32(ptr_out + 8); + float32x4_t vb3 = vld1q_f32(ptr_bias + 8); + + float32x4_t vin4 = vld1q_f32(ptr_out + 12); + float32x4_t vb4 = vld1q_f32(ptr_bias + 12); + + float32x4_t vout1 = vaddq_f32(vin1, vb1); + float32x4_t vout2 = vaddq_f32(vin2, vb2); + float32x4_t vout3 = vaddq_f32(vin3, vb3); + float32x4_t vout4 = vaddq_f32(vin4, vb4); + + vst1q_f32(ptr_out, vout1); + vst1q_f32(ptr_out + 4, vout2); + vst1q_f32(ptr_out + 8, vout3); + vst1q_f32(ptr_out + 12, vout4); + + ptr_out += 16; + ptr_bias += 16; + } + for (int i = 0; i < remain; ++i) { + *(ptr_out++) += *(ptr_bias++); + } } } } template <> -void fill_bias_fc(int *out, const int *bias, int num, int channel) { +void fill_bias_fc( + int *out, const int *bias, int num, int channel, bool flag_relu) { int cnt = channel >> 4; int remain = channel & 15; - - for (int j = 0; j < num; ++j) { - const int *ptr_bias = bias; - int *ptr_out = out + j * channel; - - int32x4_t vout1; - int32x4_t vout2; - int32x4_t vout3; - int32x4_t vout4; - - for (int i = 0; i < cnt; ++i) { - int32x4_t vin1 = vld1q_s32(ptr_out); - int32x4_t vb1 = vld1q_s32(ptr_bias); - - int32x4_t vin2 = vld1q_s32(ptr_out + 4); - int32x4_t vb2 = vld1q_s32(ptr_bias + 4); - - int32x4_t vin3 = vld1q_s32(ptr_out + 8); - int32x4_t vb3 = vld1q_s32(ptr_bias + 8); - - int32x4_t vin4 = vld1q_s32(ptr_out + 12); - int32x4_t vb4 = vld1q_s32(ptr_bias + 12); - - vout1 = vaddq_s32(vin1, vb1); - vout2 = vaddq_s32(vin2, vb2); - vout3 = vaddq_s32(vin3, vb3); - vout4 = vaddq_s32(vin4, vb4); - - vst1q_s32(ptr_out, vout1); - vst1q_s32(ptr_out + 4, vout2); - vst1q_s32(ptr_out + 8, vout3); - vst1q_s32(ptr_out + 12, vout4); - - ptr_out += 16; - ptr_bias += 16; - } - -#if 0 - if (cnt > 0) { - asm( - "1: \n" - "vld1.32 {d0-d1}, [%[ptr_out]] @ load data\n" - "vld1.32 {d2-d3}, [%[ptr_bias]]! @ load data\n" - "vadd.s32 q2, q0, q1 @ add bias\n" - "vst1.32 {d4-d5}, [%[ptr_out]]! @ store result\n" - "subs %[cnt], #1 @ loop count -1\n" - "bne 1b @ jump to main loop\n" - :[ptr_out] "+r"(ptr_out), [ptr_bias] "+r"(ptr_bias), \ - [cnt] "+r"(cnt) - : - :"q0", "q1", "q2" - ); + if (flag_relu) { + for (int j = 0; j < num; ++j) { + const int *ptr_bias = bias; + int *ptr_out = out + j * channel; + + int32x4_t vzero = vdupq_n_s32(0); + + for (int i = 0; i < cnt; ++i) { + int32x4_t vin1 = vld1q_s32(ptr_out); + int32x4_t vb1 = vld1q_s32(ptr_bias); + + int32x4_t vin2 = vld1q_s32(ptr_out + 4); + int32x4_t vb2 = vld1q_s32(ptr_bias + 4); + + int32x4_t vin3 = vld1q_s32(ptr_out + 8); + int32x4_t vb3 = vld1q_s32(ptr_bias + 8); + + int32x4_t vin4 = vld1q_s32(ptr_out + 12); + int32x4_t vb4 = vld1q_s32(ptr_bias + 12); + + int32x4_t vout1 = vaddq_s32(vin1, vb1); + int32x4_t vout2 = vaddq_s32(vin2, vb2); + int32x4_t vout3 = vaddq_s32(vin3, vb3); + int32x4_t vout4 = vaddq_s32(vin4, vb4); + + vout1 = vmaxq_s32(vout1, vzero); + vout2 = vmaxq_s32(vout2, vzero); + vout3 = vmaxq_s32(vout3, vzero); + vout4 = vmaxq_s32(vout4, vzero); + + vst1q_s32(ptr_out, vout1); + vst1q_s32(ptr_out + 4, vout2); + vst1q_s32(ptr_out + 8, vout3); + vst1q_s32(ptr_out + 12, vout4); + + ptr_out += 16; + ptr_bias += 16; + } + for (int i = 0; i < remain; ++i) { + *ptr_out += *(ptr_bias++); + *ptr_out = *ptr_out > 0 ? *ptr_out : 0; + ptr_out++; + } } -#endif - for (int i = 0; i < remain; ++i) { - *(ptr_out++) += *(ptr_bias++); + } else { + for (int j = 0; j < num; ++j) { + const int *ptr_bias = bias; + int *ptr_out = out + j * channel; + + int32x4_t vout1; + int32x4_t vout2; + int32x4_t vout3; + int32x4_t vout4; + + for (int i = 0; i < cnt; ++i) { + int32x4_t vin1 = vld1q_s32(ptr_out); + int32x4_t vb1 = vld1q_s32(ptr_bias); + + int32x4_t vin2 = vld1q_s32(ptr_out + 4); + int32x4_t vb2 = vld1q_s32(ptr_bias + 4); + + int32x4_t vin3 = vld1q_s32(ptr_out + 8); + int32x4_t vb3 = vld1q_s32(ptr_bias + 8); + + int32x4_t vin4 = vld1q_s32(ptr_out + 12); + int32x4_t vb4 = vld1q_s32(ptr_bias + 12); + + vout1 = vaddq_s32(vin1, vb1); + vout2 = vaddq_s32(vin2, vb2); + vout3 = vaddq_s32(vin3, vb3); + vout4 = vaddq_s32(vin4, vb4); + + vst1q_s32(ptr_out, vout1); + vst1q_s32(ptr_out + 4, vout2); + vst1q_s32(ptr_out + 8, vout3); + vst1q_s32(ptr_out + 12, vout4); + + ptr_out += 16; + ptr_bias += 16; + } + for (int i = 0; i < remain; ++i) { + *(ptr_out++) += *(ptr_bias++); + } } } } diff --git a/lite/backends/arm/math/funcs.h b/lite/backends/arm/math/funcs.h index 6fb64138221ea4ca4d70ddf04f53b5bd4cdf4a92..e975160c97b6e7396ab208805a4d685586ac00c8 100644 --- a/lite/backends/arm/math/funcs.h +++ b/lite/backends/arm/math/funcs.h @@ -356,7 +356,8 @@ inline float32x4_t pow_ps(float32x4_t a, float32x4_t b) { } template -void fill_bias_fc(T* tensor, const T* bias, int num, int channel); +void fill_bias_fc( + T* tensor, const T* bias, int num, int channel, bool flag_relu); template inline float32x4_t vactive_f32(const float32x4_t& x) { diff --git a/lite/kernels/arm/fc_compute.cc b/lite/kernels/arm/fc_compute.cc index 3b61e0b0135fbede8f3322e9ab486d351ad466e1..cc119d3802ef1b3a92002767e96845e4ddfba500 100644 --- a/lite/kernels/arm/fc_compute.cc +++ b/lite/kernels/arm/fc_compute.cc @@ -93,6 +93,10 @@ void FcCompute::Run() { if (flag_trans_bias_) { b_data = bias_.data(); } + bool flag_relu = false; + if (param.activation_type == "relu") { + flag_relu = true; + } if (flag_gemm_) { operators::ActivationParam act_param; act_param.has_active = false; @@ -115,7 +119,7 @@ void FcCompute::Run() { &ctx); if (param.bias) { CHECK_EQ(param.bias->numel(), n_); - lite::arm::math::fill_bias_fc(o_data, b_data, m_, n_); + lite::arm::math::fill_bias_fc(o_data, b_data, m_, n_, flag_relu); } } else { for (int i = 0; i < m_; ++i) { @@ -129,7 +133,7 @@ void FcCompute::Run() { k_, param.bias != nullptr, b_data, - false, + flag_relu, &ctx); } } @@ -148,6 +152,10 @@ void FcCompute::Run() { if (flag_trans_bias_) { b_data = bias_.data(); } + bool flag_relu = false; + if (param.activation_type == "relu") { + flag_relu = true; + } if (flag_gemm_) { lite::arm::math::gemm_s8(false, false, @@ -164,7 +172,7 @@ void FcCompute::Run() { &ctx); if (param.bias) { CHECK_EQ(param.bias->numel(), n_); - lite::arm::math::fill_bias_fc(o_data, b_data, m_, n_); + lite::arm::math::fill_bias_fc(o_data, b_data, m_, n_, flag_relu); } } else { for (int i = 0; i < m_; ++i) { @@ -179,7 +187,7 @@ void FcCompute::Run() { scale_.data(), param.bias != nullptr, b_data, - false, + flag_relu, &ctx); } } @@ -198,6 +206,10 @@ void FcCompute::Run() { if (flag_trans_bias_) { b_data = bias_.data(); } + bool flag_relu = false; + if (param.activation_type == "relu") { + flag_relu = true; + } if (flag_gemm_) { CHECK(!param.bias) << "fc int8 kernel with int8 output using gemm kernel " "must not have bias"; @@ -211,7 +223,7 @@ void FcCompute::Run() { o_data, nullptr, false, - false, + flag_relu, scale_.data(), &ctx); } else { @@ -227,7 +239,7 @@ void FcCompute::Run() { scale_.data(), param.bias != nullptr, b_data, - false, + flag_relu, &ctx); } }