Commit 256e69ec authored by HappyAngel, committed by GitHub

[arm] add conv+relu6/leakyRelu in conv_activation (#2833)

* fix conv+relu6/leakyRelu fusion in Fp32, test=develop

* comment out m=397 in sgemv_int8 ut, test=develop

* fix ios build error. test=develop
Parent 37ddb91a
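For context, the two activations this commit makes fusable with conv on ARM float targets compute the following, applied element-wise to the convolution output. This is only a minimal scalar sketch for reference; the actual kernels use the NEON paths touched below, and alpha is the leaky_relu slope attribute.

    #include <algorithm>

    // Element-wise activations newly handled by the conv+activation fusion.
    inline float relu6(float x) { return std::min(std::max(x, 0.f), 6.f); }
    inline float leaky_relu(float x, float alpha) {
      return x >= 0.f ? x : alpha * x;  // alpha: slope for negative inputs
    }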
@@ -109,7 +109,7 @@ void conv_depthwise_5x5s1_fp32(float* dout,
tmp_din + omp_get_thread_num() * (pre_in_size + pre_out_size);
float* pre_out = pre_din + pre_in_size;
#else
- float pre_din = tmp_din;
+ float* pre_din = tmp_din;
float* pre_out = pre_din + pre_in_size;
#endif
prepack_input_nxwc4_dw(
......
@@ -46,6 +46,7 @@ void fp32_to_int8(const float* din,
float inv_scale = 1.f / scale[j % axis_size];
float32x4_t vzero = vdupq_n_f32(0.f);
float32x4_t vscale = vdupq_n_f32(inv_scale);
+ float32x4_t vmax = vdupq_n_f32(-127.f);
float32x4_t vpoff = vdupq_n_f32(0.5f);
float32x4_t vnoff = vdupq_n_f32(-0.5f);
const float* din_c = din + j * inner_size;
@@ -63,6 +64,14 @@ void fp32_to_int8(const float* din,
"fmul v5.4s, v1.4s, %[scale].4s \n"
"fmul v6.4s, v2.4s, %[scale].4s \n"
"fmul v7.4s, v3.4s, %[scale].4s \n"
"fcmge v8.4s, v4.4s, %[vmax].4s \n"
"fcmge v9.4s, v5.4s, %[vmax].4s \n"
"fcmge v10.4s, v6.4s, %[vmax].4s \n"
"fcmge v11.4s, v7.4s, %[vmax].4s \n"
"bif v4.16b, %[vmax].16b, v8.16b \n"
"bif v5.16b, %[vmax].16b, v9.16b \n"
"bif v6.16b, %[vmax].16b, v10.16b \n"
"bif v7.16b, %[vmax].16b, v11.16b \n"
"ldp q0, q1, [%[in]], #32 \n"
"subs %[cnt], %[cnt], #1 \n"
"FCVTAS v8.4s, v4.4s \n"
@@ -79,7 +88,7 @@ void fp32_to_int8(const float* din,
"str q8, [%[out]], #16 \n"
"bne 0b \n"
: [in] "+r"(din_ptr), [out] "+r"(dout_ptr), [cnt] "+r"(cnt_loop)
- : [scale] "w"(vscale)
+ : [scale] "w"(vscale), [vmax] "w"(vmax)
: "v0",
"v1",
"v2",
@@ -104,15 +113,23 @@ void fp32_to_int8(const float* din,
"vcgt.f32 q8, q0, %q[vzero] @ get mask > 0, in0\n"
"vcgt.f32 q9, q1, %q[vzero] @ get mask > 0, in1\n"
"vcgt.f32 q10, q2, %q[vzero] @ get mask > 0, in2\n"
"vcgt.f32 q11, q3, %q[vzero] @ get mask > 0, in3\n"
"vbif.f32 q4, %q[vnoff], q8 @ get right offset\n"
"vcgt.f32 q8, q3, %q[vzero] @ get mask > 0, in3\n"
"vbif.f32 q5, %q[vnoff], q9 @ get right offset\n"
"vbif.f32 q6, %q[vnoff], q10 @ get right offset\n"
"vbif.f32 q7, %q[vnoff], q11 @ get right offset\n"
"vbif.f32 q7, %q[vnoff], q8 @ get right offset\n"
"vmla.f32 q4, q0, %q[vscale] @ mul scale\n"
"vmla.f32 q5, q1, %q[vscale] @ mul scale\n"
"vmla.f32 q6, q2, %q[vscale] @ mul scale\n"
"vmla.f32 q7, q3, %q[vscale] @ mul scale\n"
"vcge.f32 q8, q4, %q[vmax] @ q4 >= vmax \n"
"vcge.f32 q9, q5, %q[vmax] @ q4 >= vmax \n"
"vcge.f32 q10, q6, %q[vmax] @ q4 >= vmax \n"
"vbif q4, %q[vmax], q8 @ choose \n"
"vcge.f32 q8, q7, %q[vmax] @ q4 >= vmax \n"
"vbif q5, %q[vmax], q9 @ choose \n"
"vbif q6, %q[vmax], q10 @ choose \n"
"vbif q7, %q[vmax], q8 @ choose \n"
"vcvt.s32.f32 q0, q4 @ cvt to int32\n"
"vcvt.s32.f32 q1, q5 @ cvt to int32\n"
"vcvt.s32.f32 q2, q6 @ cvt to int32\n"
@@ -133,25 +150,16 @@ void fp32_to_int8(const float* din,
: [vscale] "w"(vscale),
[vpoff] "w"(vpoff),
[vnoff] "w"(vnoff),
- [vzero] "w"(vzero)
- : "q0",
- "q1",
- "q2",
- "q3",
- "q4",
- "q5",
- "q6",
- "q7",
- "q8",
- "q9",
- "q10",
- "q11");
+ [vzero] "w"(vzero),
+ [vmax] "w"(vmax)
+ : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10");
#endif
}
const float* din_r = din_c + 16 * cnt;
signed char* dout_r = dout_c + 16 * cnt;
for (int i = 0; i < remain; ++i) {
dout_r[i] = saturate_cast<int8_t>(roundf(inv_scale * din_r[i]));
+ dout_r[i] = dout_r[i] < -127 ? -127 : dout_r[i];
}
}
}
......
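The aarch64 fcmge/bif sequence, the armv7 vcge/vbif sequence, and the extra line in the scalar tail loop above all enforce the same rule: after scaling and rounding, the quantized value is clamped so it never reaches -128, keeping the int8 range symmetric. A minimal scalar restatement of that rule, as a sketch rather than the library's actual helper:

    #include <algorithm>
    #include <cmath>
    #include <cstdint>

    // Symmetric fp32 -> int8 quantization as enforced by the changes above.
    inline int8_t quantize_symmetric(float x, float scale) {
      float v = std::round(x / scale);  // same rounding intent as FCVTAS / roundf
      v = std::max(v, -127.f);          // lower clamp added here (vmax = -127)
      v = std::min(v, 127.f);           // upper bound already given by saturate_cast<int8_t>
      return static_cast<int8_t>(v);
    }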
@@ -29,6 +29,11 @@ void ConvActivationFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
act_types.push_back("leaky_relu");
break;
}
+ if (place.target == TARGET(kARM) && place.precision == PRECISION(kFloat)) {
+ act_types.push_back("relu6");
+ act_types.push_back("leaky_relu");
+ break;
+ }
}
for (auto conv_type : {"conv2d", "depthwise_conv2d", "conv2d_transpose"}) {
for (auto act_type : act_types) {
......
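With the new block above, ARM float places also try relu6 and leaky_relu when pairing conv2d, depthwise_conv2d, and conv2d_transpose with a trailing activation. Schematically, the fusion replaces a conv -> activation pair with a single conv node that records the activation type; the struct and attribute names in this sketch are illustrative placeholders, not PaddleLite's actual API:

    #include <map>
    #include <string>

    struct OpNode {
      std::string type;                          // e.g. "conv2d", "relu6"
      std::map<std::string, std::string> attrs;  // simplified attribute map
    };

    // Illustrative only: fold a trailing activation into the conv op so the
    // graph keeps one node; the standalone activation node and its
    // intermediate tensor are then dropped from the graph.
    OpNode FoldActivation(const OpNode& conv, const OpNode& act) {
      OpNode fused = conv;
      fused.attrs["with_act"] = "true";    // placeholder attribute name
      fused.attrs["act_type"] = act.type;  // "relu6" or "leaky_relu"
      return fused;
    }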
@@ -285,7 +285,7 @@ TEST(TestLiteGemvInt8, gemv_prepacked_int8) {
paddle::lite::DeviceInfo::Init();
#endif
LOG(INFO) << "run basic sgemm test";
- for (auto& m : {1, 3, 8, 32, 397}) {
+ for (auto& m : {1, 3, 8, 32}) { // ,397
for (auto& n : {1, 3, 13, 141, 512, 789}) {
for (auto& tra : {false}) {
for (auto& has_bias : {false, true}) {
......