未验证 提交 1a64347a 编写于 作者: H HappyAngel 提交者: GitHub

[arm] add scale+relu/relu6/leakyrelu fusion (#3461)

* add scale+relu/relu6/leakyrelu test=develop
* fix format, test=develop
上级 4d495329
......@@ -52,6 +52,7 @@ USE_MIR_PASS(mlu_postprocess_pass);
USE_MIR_PASS(weight_quantization_preprocess_pass);
USE_MIR_PASS(apu_subgraph_pass);
USE_MIR_PASS(quantized_op_attributes_inference_pass);
USE_MIR_PASS(lite_scale_activation_fuse_pass);
USE_MIR_PASS(__xpu__resnet_fuse_pass);
USE_MIR_PASS(__xpu__multi_encoder_fuse_pass);
USE_MIR_PASS(__xpu__embedding_with_eltwise_add_fuse_pass);
......
......@@ -27,29 +27,421 @@ void scale<float>(
int remain = num % 16;
float32x4_t vscale = vdupq_n_f32(scale);
float32x4_t vbias = vdupq_n_f32(bias);
if (cnt > 0) {
#ifdef __aarch64__
asm volatile(
"1: \n"
"ld1 {v4.4s}, [%[din]], #16 \n"
"and v8.16b, %[vbias].16b, %[vbias].16b \n"
"ld1 {v5.4s}, [%[din]], #16 \n"
"and v9.16b, %[vbias].16b, %[vbias].16b \n"
"ld1 {v6.4s}, [%[din]], #16 \n"
"and v10.16b, %[vbias].16b, %[vbias].16b \n"
"ld1 {v7.4s}, [%[din]], #16 \n"
"and v11.16b, %[vbias].16b, %[vbias].16b \n"
"fmla v8.4s, v4.4s, %[vscale].4s \n"
"fmla v9.4s, v5.4s, %[vscale].4s \n"
"fmla v10.4s, v6.4s, %[vscale].4s \n"
"fmla v11.4s, v7.4s, %[vscale].4s \n"
"stp q8, q9, [%[dout]], #32 \n"
"subs %w[cnt], %w[cnt], #1 \n"
"stp q10, q11, [%[dout]], #32 \n"
"bne 1b \n"
"0: \n"
: [dout] "+r"(dout), [din] "+r"(din), [cnt] "+r"(cnt)
: [vscale] "w"(vscale), [vbias] "w"(vbias)
: "cc", "memory", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11");
#else
asm volatile(
"1: @ loop header \n"
"vld1.32 {d8-d11}, [%[din]]! @ load din 0 \n"
"vand.32 q8, %q[vbias], %q[vbias] @ out bias \n"
"vand.32 q9, %q[vbias], %q[vbias] @ out bias \n"
"vld1.32 {d12-d15}, [%[din]]! @ load din 0 \n"
"vand.32 q10, %q[vbias], %q[vbias] @ out bias \n"
"vand.32 q11, %q[vbias], %q[vbias] @ out bias \n"
"vmla.f32 q8, q4, %q[vscale] @ mla \n"
"vmla.f32 q9, q5, %q[vscale] @ mla \n"
"vmla.f32 q10, q6, %q[vscale] @ mla \n"
"vmla.f32 q11, q7, %q[vscale] @ mla \n"
"vst1.32 {d16-d19}, [%[dout]]! @ store result, add pointer\n"
"subs %[cnt], #1 @ loop count minus 1\n"
"vst1.32 {d20-d23}, [%[dout]]! @ store result, add pointer\n"
"bne 1b @ jump to main loop start "
"2: \n"
: [dout] "+r"(dout), [din] "+r"(din), [cnt] "+r"(cnt)
: [vscale] "w"(vscale), [vbias] "w"(vbias)
: "cc", "memory", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11");
#endif
}
if (remain > 0) {
for (int i = 0; i < remain; i++) {
*dout = *din * scale + bias;
dout++;
din++;
}
}
}
template <>
void scale_relu<float>(
const float* din, float* dout, int num, float scale, float bias) {
int cnt = num >> 4;
int remain = num % 16;
float32x4_t vscale = vdupq_n_f32(scale);
float32x4_t vbias = vdupq_n_f32(bias);
float32x4_t vzero = vdupq_n_f32(0.f);
if (cnt > 0) {
#ifdef __aarch64__
asm volatile(
"1: \n"
"ld1 {v4.4s}, [%[din]], #16 \n"
"and v8.16b, %[vbias].16b, %[vbias].16b \n"
"ld1 {v5.4s}, [%[din]], #16 \n"
"and v9.16b, %[vbias].16b, %[vbias].16b \n"
"ld1 {v6.4s}, [%[din]], #16 \n"
"and v10.16b, %[vbias].16b, %[vbias].16b\n"
"ld1 {v7.4s}, [%[din]], #16 \n"
"and v11.16b, %[vbias].16b, %[vbias].16b\n"
"fmla v8.4s, v4.4s, %[vscale].4s \n"
"fmla v9.4s, v5.4s, %[vscale].4s \n"
"fmla v10.4s, v6.4s, %[vscale].4s \n"
"fmla v11.4s, v7.4s, %[vscale].4s \n"
"fmax v8.4s, v8.4s, %[vzero].4s \n"
"fmax v9.4s, v9.4s, %[vzero].4s \n"
"fmax v10.4s, v10.4s, %[vzero].4s \n"
"fmax v11.4s, v11.4s, %[vzero].4s \n"
"stp q8, q9, [%[dout]], #32 \n"
"subs %w[cnt], %w[cnt], #1 \n"
"stp q10, q11, [%[dout]], #32 \n"
"bne 1b \n"
"0: \n"
: [dout] "+r"(dout), [din] "+r"(din), [cnt] "+r"(cnt)
: [vscale] "w"(vscale), [vbias] "w"(vbias), [vzero] "w"(vzero)
: "cc", "memory", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11");
#else
asm volatile(
"1: @ loop header \n"
"vld1.32 {d8-d11}, [%[din]]! @ load din 0 \n"
"vand.32 q8, %q[vbias], %q[vbias] @ out bias \n"
"vand.32 q9, %q[vbias], %q[vbias] @ out bias \n"
"vld1.32 {d12-d15}, [%[din]]! @ load din 0 \n"
"vand.32 q10, %q[vbias], %q[vbias] @ out bias \n"
"vand.32 q11, %q[vbias], %q[vbias] @ out bias \n"
"vmla.f32 q8, q4, %q[vscale] @ mla \n"
"vmla.f32 q9, q5, %q[vscale] @ mla \n"
"vmla.f32 q10, q6, %q[vscale] @ mla \n"
"vmla.f32 q11, q7, %q[vscale] @ mla \n"
"vmax.f32 q8, q8, %q[vzero] @ relu \n"
"vmax.f32 q9, q9, %q[vzero] @ relu \n"
"vmax.f32 q10, q10, %q[vzero] @ relu \n"
"vmax.f32 q11, q11, %q[vzero] @ relu \n"
"vst1.32 {d16-d19}, [%[dout]]! @ store result, add pointer\n"
"subs %[cnt], #1 @ loop count minus 1\n"
"vst1.32 {d20-d23}, [%[dout]]! @ store result, add pointer\n"
"bne 1b @ jump to main loop start "
"2: \n"
: [dout] "+r"(dout), [din] "+r"(din), [cnt] "+r"(cnt)
: [vscale] "w"(vscale), [vbias] "w"(vbias), [vzero] "w"(vzero)
: "cc", "memory", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11");
#endif
}
if (remain > 0) {
for (int i = 0; i < remain; i++) {
*dout = *din * scale + bias;
*dout = *dout > 0.f ? *dout : 0.f;
dout++;
din++;
}
}
}
template <>
void scale_relu6<float>(const float* din,
float* dout,
int num,
float scale,
float bias,
float alpha) {
int cnt = num >> 4;
int remain = num % 16;
float32x4_t vscale = vdupq_n_f32(scale);
float32x4_t vbias = vdupq_n_f32(bias);
float32x4_t vzero = vdupq_n_f32(0.f);
float32x4_t valpha = vdupq_n_f32(alpha);
if (cnt > 0) {
#ifdef __aarch64__
asm volatile(
"1: \n"
"ld1 {v4.4s}, [%[din]], #16 \n"
"and v8.16b, %[vbias].16b, %[vbias].16b \n"
"ld1 {v5.4s}, [%[din]], #16 \n"
"and v9.16b, %[vbias].16b, %[vbias].16b \n"
"ld1 {v6.4s}, [%[din]], #16 \n"
"and v10.16b, %[vbias].16b, %[vbias].16b \n"
"ld1 {v7.4s}, [%[din]], #16 \n"
"and v11.16b, %[vbias].16b, %[vbias].16b \n"
"fmla v8.4s, v4.4s, %[vscale].4s \n"
"fmla v9.4s, v5.4s, %[vscale].4s \n"
"fmla v10.4s, v6.4s, %[vscale].4s \n"
"fmla v11.4s, v7.4s, %[vscale].4s \n"
"fmax v8.4s, v8.4s, %[vzero].4s \n"
"fmax v9.4s, v9.4s, %[vzero].4s \n"
"fmax v10.4s, v10.4s, %[vzero].4s \n"
"fmax v11.4s, v11.4s, %[vzero].4s \n"
"fmin v8.4s, v8.4s, %[valpha].4s \n"
"fmin v9.4s, v9.4s, %[valpha].4s \n"
"fmin v10.4s, v10.4s, %[valpha].4s \n"
"fmin v11.4s, v11.4s, %[valpha].4s \n"
"stp q8, q9, [%[dout]], #32 \n"
"subs %w[cnt], %w[cnt], #1 \n"
"stp q10, q11, [%[dout]], #32 \n"
"bne 1b \n"
"0: \n"
: [dout] "+r"(dout), [din] "+r"(din), [cnt] "+r"(cnt)
: [vscale] "w"(vscale),
[vbias] "w"(vbias),
[vzero] "w"(vzero),
[valpha] "w"(valpha)
: "cc", "memory", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11");
#else
asm volatile(
"1: @ loop header \n"
"vld1.32 {d8-d11}, [%[din]]! @ load din 0 \n"
"vand.32 q8, %q[vbias], %q[vbias] @ out bias \n"
"vand.32 q9, %q[vbias], %q[vbias] @ out bias \n"
"vld1.32 {d12-d15}, [%[din]]! @ load din 0 \n"
"vand.32 q10, %q[vbias], %q[vbias] @ out bias \n"
"vand.32 q11, %q[vbias], %q[vbias] @ out bias \n"
"vmla.f32 q8, q4, %q[vscale] @ mla \n"
"vmla.f32 q9, q5, %q[vscale] @ mla \n"
"vmla.f32 q10, q6, %q[vscale] @ mla \n"
"vmla.f32 q11, q7, %q[vscale] @ mla \n"
"vmax.f32 q8, q8, %q[vzero] @ relu \n"
"vmax.f32 q9, q9, %q[vzero] @ relu \n"
"vmax.f32 q10, q10, %q[vzero] @ relu \n"
"vmax.f32 q11, q11, %q[vzero] @ relu \n"
"vmin.f32 q8, q8, %q[valpha] @ relu \n"
"vmin.f32 q9, q9, %q[valpha] @ relu \n"
"vmin.f32 q10, q10, %q[valpha] @ relu \n"
"vmin.f32 q11, q11, %q[valpha] @ relu \n"
"vst1.32 {d16-d19}, [%[dout]]! @ store result, add pointer\n"
"subs %[cnt], #1 @ loop count minus 1\n"
"vst1.32 {d20-d23}, [%[dout]]! @ store result, add pointer\n"
"bne 1b @ jump to main loop start "
"2: \n"
: [dout] "+r"(dout), [din] "+r"(din), [cnt] "+r"(cnt)
: [vscale] "w"(vscale),
[vbias] "w"(vbias),
[vzero] "w"(vzero),
[valpha] "w"(valpha)
: "cc", "memory", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11");
#endif
}
if (remain > 0) {
for (int i = 0; i < remain; i++) {
*dout = *din * scale + bias;
*dout = *dout > 0.f ? (*dout < alpha ? *dout : alpha) : 0.f;
dout++;
din++;
}
}
}
template <>
void scale_leaky_relu<float>(const float* din,
float* dout,
int num,
float scale,
float bias,
float alpha) {
int cnt = num >> 4;
int remain = num % 16;
float32x4_t vscale = vdupq_n_f32(scale);
float32x4_t vbias = vdupq_n_f32(bias);
float32x4_t vzero = vdupq_n_f32(0.f);
float32x4_t valpha = vdupq_n_f32(alpha);
if (cnt > 0) {
#ifdef __aarch64__
asm volatile(
"1: \n"
"ld1 {v4.4s}, [%[din]], #16 \n"
"and v8.16b, %[vbias].16b, %[vbias].16b \n"
"ld1 {v5.4s}, [%[din]], #16 \n"
"and v9.16b, %[vbias].16b, %[vbias].16b \n"
"ld1 {v6.4s}, [%[din]], #16 \n"
"and v10.16b, %[vbias].16b, %[vbias].16b \n"
"ld1 {v7.4s}, [%[din]], #16 \n"
"and v11.16b, %[vbias].16b, %[vbias].16b \n"
"fmla v8.4s, v4.4s, %[vscale].4s \n"
"fmla v9.4s, v5.4s, %[vscale].4s \n"
"fmla v10.4s, v6.4s, %[vscale].4s \n"
"fmla v11.4s, v7.4s, %[vscale].4s \n"
"fcmge v12.4s, v8.4s, %[vzero].4s \n"
"fmul v16.4s, v8.4s, %[valpha].4s \n"
"fcmge v13.4s, v9.4s, %[vzero].4s \n"
"fmul v17.4s, v9.4s, %[valpha].4s \n"
"fcmge v14.4s, v10.4s, %[vzero].4s \n"
"fmul v18.4s, v10.4s, %[valpha].4s \n"
"fcmge v15.4s, v11.4s, %[vzero].4s \n"
"fmul v19.4s, v11.4s, %[valpha].4s \n"
"bif v8.16b, v16.16b, v12.16b \n" /* choose*/
"bif v9.16b, v17.16b, v13.16b \n" /* choose*/
"bif v10.16b, v18.16b, v14.16b \n" /* choose*/
"bif v11.16b, v19.16b, v15.16b \n" /* choose*/
"stp q8, q9, [%[dout]], #32 \n"
"subs %w[cnt], %w[cnt], #1 \n"
"stp q10, q11, [%[dout]], #32 \n"
"bne 1b \n"
"0: \n"
: [dout] "+r"(dout), [din] "+r"(din), [cnt] "+r"(cnt)
: [vscale] "w"(vscale),
[vbias] "w"(vbias),
[vzero] "w"(vzero),
[valpha] "w"(valpha)
: "cc",
"memory",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15");
#else
asm volatile(
"1: @ loop header \n"
"vld1.32 {d8-d11}, [%[din]]! @ load din 0 \n"
"vand.32 q8, %q[vbias], %q[vbias] @ out bias \n"
"vand.32 q9, %q[vbias], %q[vbias] @ out bias \n"
"vld1.32 {d12-d15}, [%[din]]! @ load din 0 \n"
"vand.32 q10, %q[vbias], %q[vbias] @ out bias \n"
"vand.32 q11, %q[vbias], %q[vbias] @ out bias \n"
"vmla.f32 q8, q4, %q[vscale] @ mla \n"
"vmla.f32 q9, q5, %q[vscale] @ mla \n"
"vmla.f32 q10, q6, %q[vscale] @ mla \n"
"vmla.f32 q11, q7, %q[vscale] @ mla \n"
"vcge.f32 q12, q8, %q[vzero] @ relu \n"
"vmul.f32 q14, q8, %q[valpha] @ mul \n"
"vcge.f32 q13, q9, %q[vzero] @ relu \n"
"vmul.f32 q15, q9, %q[valpha] @ mul \n"
"vbif q8, q14, q12 @ choose \n"
"vbif q9, q15, q13 @ choose \n"
"vcge.f32 q12, q10, %q[vzero] @ relu \n"
"vmul.f32 q14, q10, %q[valpha] @ mul \n"
"vcge.f32 q13, q11, %q[vzero] @ relu \n"
"vmul.f32 q15, q11, %q[valpha] @ mul \n"
"vst1.32 {d16-d19}, [%[dout]]! @ store result, add pointer\n"
"vbif q10, q14, q12 @ choose \n"
"vbif q11, q15, q13 @ choose \n"
"subs %[cnt], #1 @ loop count minus 1\n"
"vst1.32 {d20-d23}, [%[dout]]! @ store result, add pointer\n"
"bne 1b @ jump to main loop start "
"2: \n"
: [dout] "+r"(dout), [din] "+r"(din), [cnt] "+r"(cnt)
: [vscale] "w"(vscale),
[vbias] "w"(vbias),
[vzero] "w"(vzero),
[valpha] "w"(valpha)
: "cc",
"memory",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
#endif
}
if (remain > 0) {
for (int i = 0; i < remain; i++) {
*dout = *din * scale + bias;
*dout = *dout > 0.f ? *dout : (*dout * alpha);
dout++;
din++;
}
}
}
template <>
void scale<int>(const int* din, int* dout, int num, int scale, int bias) {
int cnt = num >> 4;
int remain = num % 16;
int32x4_t vscale = vdupq_n_s32(scale);
int32x4_t vbias = vdupq_n_s32(bias);
#pragma omp parallel for
for (int i = 0; i < cnt; i++) {
const float* din_ptr = din + (i << 4);
float* dout_ptr = dout + (i << 4);
const int* din_ptr = din + (i << 4);
int* dout_ptr = dout + (i << 4);
float32x4_t din0 = vld1q_f32(din_ptr);
float32x4_t din1 = vld1q_f32(din_ptr + 4);
float32x4_t din2 = vld1q_f32(din_ptr + 8);
float32x4_t din3 = vld1q_f32(din_ptr + 12);
int32x4_t din0 = vld1q_s32(din_ptr);
int32x4_t din1 = vld1q_s32(din_ptr + 4);
int32x4_t din2 = vld1q_s32(din_ptr + 8);
int32x4_t din3 = vld1q_s32(din_ptr + 12);
float32x4_t vsum1 = vmlaq_f32(vbias, din0, vscale);
float32x4_t vsum2 = vmlaq_f32(vbias, din1, vscale);
float32x4_t vsum3 = vmlaq_f32(vbias, din2, vscale);
float32x4_t vsum4 = vmlaq_f32(vbias, din3, vscale);
int32x4_t vsum1 = vmlaq_s32(vbias, din0, vscale);
int32x4_t vsum2 = vmlaq_s32(vbias, din1, vscale);
int32x4_t vsum3 = vmlaq_s32(vbias, din2, vscale);
int32x4_t vsum4 = vmlaq_s32(vbias, din3, vscale);
vst1q_f32(dout_ptr, vsum1);
vst1q_f32(dout_ptr + 4, vsum2);
vst1q_f32(dout_ptr + 8, vsum3);
vst1q_f32(dout_ptr + 12, vsum4);
vst1q_s32(dout_ptr, vsum1);
vst1q_s32(dout_ptr + 4, vsum2);
vst1q_s32(dout_ptr + 8, vsum3);
vst1q_s32(dout_ptr + 12, vsum4);
}
if (remain > 0) {
const float* din_ptr = din + (cnt << 4);
float* dout_ptr = dout + (cnt << 4);
const int* din_ptr = din + (cnt << 4);
int* dout_ptr = dout + (cnt << 4);
for (int i = 0; i < remain; i++) {
*dout_ptr = *din_ptr * scale + bias;
dout_ptr++;
......@@ -59,11 +451,110 @@ void scale<float>(
}
template <>
void scale<int>(const int* din, int* dout, int num, int scale, int bias) {
void scale_relu<int>(const int* din, int* dout, int num, int scale, int bias) {
int cnt = num >> 4;
int remain = num % 16;
int32x4_t vscale = vdupq_n_s32(scale);
int32x4_t vbias = vdupq_n_s32(bias);
int32x4_t vzero = vdupq_n_s32(0);
#pragma omp parallel for
for (int i = 0; i < cnt; i++) {
const int* din_ptr = din + (i << 4);
int* dout_ptr = dout + (i << 4);
int32x4_t din0 = vld1q_s32(din_ptr);
int32x4_t din1 = vld1q_s32(din_ptr + 4);
int32x4_t din2 = vld1q_s32(din_ptr + 8);
int32x4_t din3 = vld1q_s32(din_ptr + 12);
int32x4_t vsum1 = vmlaq_s32(vbias, din0, vscale);
int32x4_t vsum2 = vmlaq_s32(vbias, din1, vscale);
int32x4_t vsum3 = vmlaq_s32(vbias, din2, vscale);
int32x4_t vsum4 = vmlaq_s32(vbias, din3, vscale);
vsum1 = vmaxq_s32(vsum1, vzero);
vsum2 = vmaxq_s32(vsum2, vzero);
vsum3 = vmaxq_s32(vsum3, vzero);
vsum4 = vmaxq_s32(vsum4, vzero);
vst1q_s32(dout_ptr, vsum1);
vst1q_s32(dout_ptr + 4, vsum2);
vst1q_s32(dout_ptr + 8, vsum3);
vst1q_s32(dout_ptr + 12, vsum4);
}
if (remain > 0) {
const int* din_ptr = din + (cnt << 4);
int* dout_ptr = dout + (cnt << 4);
for (int i = 0; i < remain; i++) {
*dout_ptr = *din_ptr * scale + bias;
*dout_ptr = *dout_ptr > 0 ? *dout_ptr : 0;
dout_ptr++;
din_ptr++;
}
}
}
template <>
void scale_relu6<int>(
const int* din, int* dout, int num, int scale, int bias, int alpha) {
int cnt = num >> 4;
int remain = num % 16;
int32x4_t vscale = vdupq_n_s32(scale);
int32x4_t vbias = vdupq_n_s32(bias);
int32x4_t vzero = vdupq_n_s32(0);
int32x4_t valpha = vdupq_n_s32(alpha);
#pragma omp parallel for
for (int i = 0; i < cnt; i++) {
const int* din_ptr = din + (i << 4);
int* dout_ptr = dout + (i << 4);
int32x4_t din0 = vld1q_s32(din_ptr);
int32x4_t din1 = vld1q_s32(din_ptr + 4);
int32x4_t din2 = vld1q_s32(din_ptr + 8);
int32x4_t din3 = vld1q_s32(din_ptr + 12);
int32x4_t vsum1 = vmlaq_s32(vbias, din0, vscale);
int32x4_t vsum2 = vmlaq_s32(vbias, din1, vscale);
int32x4_t vsum3 = vmlaq_s32(vbias, din2, vscale);
int32x4_t vsum4 = vmlaq_s32(vbias, din3, vscale);
vsum1 = vmaxq_s32(vsum1, vzero);
vsum2 = vmaxq_s32(vsum2, vzero);
vsum3 = vmaxq_s32(vsum3, vzero);
vsum4 = vmaxq_s32(vsum4, vzero);
vsum1 = vminq_s32(vsum1, valpha);
vsum2 = vminq_s32(vsum2, valpha);
vsum3 = vminq_s32(vsum3, valpha);
vsum4 = vminq_s32(vsum4, valpha);
vst1q_s32(dout_ptr, vsum1);
vst1q_s32(dout_ptr + 4, vsum2);
vst1q_s32(dout_ptr + 8, vsum3);
vst1q_s32(dout_ptr + 12, vsum4);
}
if (remain > 0) {
const int* din_ptr = din + (cnt << 4);
int* dout_ptr = dout + (cnt << 4);
for (int i = 0; i < remain; i++) {
*dout_ptr = *din_ptr * scale + bias;
*dout_ptr = *dout_ptr > 0 ? (*dout_ptr > alpha ? alpha : *dout_ptr) : 0;
dout_ptr++;
din_ptr++;
}
}
}
template <>
void scale_leaky_relu<int>(
const int* din, int* dout, int num, int scale, int bias, int alpha) {
int cnt = num >> 4;
int remain = num % 16;
int32x4_t vscale = vdupq_n_s32(scale);
int32x4_t vbias = vdupq_n_s32(bias);
int32x4_t vzero = vdupq_n_s32(0);
int32x4_t valpha = vdupq_n_s32(alpha);
#pragma omp parallel for
for (int i = 0; i < cnt; i++) {
const int* din_ptr = din + (i << 4);
......@@ -79,16 +570,33 @@ void scale<int>(const int* din, int* dout, int num, int scale, int bias) {
int32x4_t vsum3 = vmlaq_s32(vbias, din2, vscale);
int32x4_t vsum4 = vmlaq_s32(vbias, din3, vscale);
uint32x4_t v1 = vcgeq_s32(vsum1, vzero);
uint32x4_t v2 = vcgeq_s32(vsum2, vzero);
uint32x4_t v3 = vcgeq_s32(vsum3, vzero);
uint32x4_t v4 = vcgeq_s32(vsum4, vzero);
int32x4_t v11 = vmulq_s32(vsum1, valpha);
int32x4_t v21 = vmulq_s32(vsum1, valpha);
int32x4_t v31 = vmulq_s32(vsum1, valpha);
int32x4_t v41 = vmulq_s32(vsum1, valpha);
vsum1 = vbslq_s32(v1, vsum1, v11);
vsum2 = vbslq_s32(v2, vsum2, v21);
vsum3 = vbslq_s32(v3, vsum3, v31);
vsum4 = vbslq_s32(v4, vsum4, v41);
vst1q_s32(dout_ptr, vsum1);
vst1q_s32(dout_ptr + 4, vsum2);
vst1q_s32(dout_ptr + 8, vsum3);
vst1q_s32(dout_ptr + 12, vsum4);
}
if (remain > 0) {
const int* din_ptr = din + (cnt << 4);
int* dout_ptr = dout + (cnt << 4);
for (int i = 0; i < remain; i++) {
*dout_ptr = *din_ptr * scale + bias;
*dout_ptr = *dout_ptr > 0 ? *dout_ptr : (*dout_ptr) * alpha;
dout_ptr++;
din_ptr++;
}
......
......@@ -40,6 +40,15 @@ void scale_compute_basic(const operators::ScaleParam& param) {
template <typename T>
void scale(const T* din, T* dout, int num, T scale, T bias);
template <typename T>
void scale_relu(const T* din, T* dout, int num, T scale, T bias);
template <typename T>
void scale_relu6(const T* din, T* dout, int num, T scale, T bias, T alpha);
template <typename T>
void scale_leaky_relu(const T* din, T* dout, int num, T scale, T bias, T alpha);
template <typename T>
void scale(const T* din,
T* dout,
......
......@@ -21,6 +21,7 @@ lite_cc_library(mir_passes
fusion/elementwise_add_activation_fuse_pass.cc
fusion/quant_dequant_fuse_pass.cc
fusion/sequence_pool_concat_fuse_pass.cc
fusion/scale_activation_fuse_pass.cc
fusion/__xpu__resnet_fuse_pass.cc
fusion/__xpu__multi_encoder_fuse_pass.cc
fusion/__xpu__embedding_with_eltwise_add_fuse_pass.cc
......
......@@ -31,6 +31,9 @@ lite_cc_library(fuse_interpolate
lite_cc_library(fuse_sequence_pool_concat
SRCS sequence_pool_concat_fuser.cc
DEPS pattern_matcher_high_api)
lite_cc_library(fuse_scale_activation
SRCS scale_activation_fuser.cc
DEPS pattern_matcher_high_api)
set(mir_fusers
fuse_fc
......@@ -44,6 +47,7 @@ set(mir_fusers
fuse_transpose_softmax_transpose
fuse_interpolate
fuse_sequence_pool_concat
fuse_scale_activation
CACHE INTERNAL "fusers")
if (LITE_WITH_LIGHT_WEIGHT_FRAMEWORK)
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/core/mir/fusion/scale_activation_fuse_pass.h"
#include <memory>
#include <vector>
#include "lite/core/mir/fusion/scale_activation_fuser.h"
#include "lite/core/mir/pass_registry.h"
namespace paddle {
namespace lite {
namespace mir {
void ScaleActivationFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
for (auto act_type : {"relu", "relu6", "leaky_relu"}) {
fusion::ScaleActivationFuser fuser(act_type);
fuser(graph.get());
}
}
} // namespace mir
} // namespace lite
} // namespace paddle
REGISTER_MIR_PASS(lite_scale_activation_fuse_pass,
paddle::lite::mir::ScaleActivationFusePass)
.BindTargets({TARGET(kARM)})
.BindKernel("scale");
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include "lite/core/mir/pass.h"
namespace paddle {
namespace lite {
namespace mir {
class ScaleActivationFusePass : public ProgramPass {
public:
void Apply(const std::unique_ptr<SSAGraph>& graph) override;
};
} // namespace mir
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/core/mir/fusion/scale_activation_fuser.h"
#include <memory>
#include <vector>
namespace paddle {
namespace lite {
namespace mir {
namespace fusion {
void ScaleActivationFuser::BuildPattern() {
// create input nodes.
auto* x = VarNode("x")->assert_is_op_input("scale", "X")->AsInput();
// create op nodes
auto* scale =
OpNode("scale", "scale")->assert_is_op("scale")->AsIntermediate();
auto* act =
OpNode("act", act_type_)->assert_is_op(act_type_)->AsIntermediate();
// create intermediate nodes
auto* scale_out = VarNode("scale_out")
->assert_is_op_output("scale", "Out")
->assert_is_op_input(act_type_, "X")
->AsIntermediate();
// create output node
auto* out =
VarNode("output")->assert_is_op_output(act_type_, "Out")->AsOutput();
// create topology.
*x >> *scale >> *scale_out;
*scale_out >> *act >> *out;
}
void ScaleActivationFuser::InsertNewNode(SSAGraph* graph,
const key2nodes_t& matched) {
auto op_desc = GenOpDesc(matched);
auto scale_op = LiteOpRegistry::Global().Create("scale");
auto scale = matched.at("scale")->stmt()->op();
auto* scope = scale->scope();
auto& valid_places = scale->valid_places();
scale_op->Attach(op_desc, scope);
auto* new_op_node = graph->GraphCreateInstructNode(scale_op, valid_places);
IR_NODE_LINK_TO(matched.at("x"), new_op_node);
IR_NODE_LINK_TO(new_op_node, matched.at("output"));
}
cpp::OpDesc ScaleActivationFuser::GenOpDesc(const key2nodes_t& matched) {
cpp::OpDesc op_desc = *matched.at("scale")->stmt()->op_info();
op_desc.SetOutput("Out", {matched.at("output")->arg()->name});
cpp::OpDesc act_op_desc = *matched.at("act")->stmt()->op_info();
op_desc.SetAttr("activation_type", act_type_);
if (act_type_ == "relu") {
op_desc.SetAttr("fuse_relu", true);
} else if (act_type_ == "relu6") {
float alpha = act_op_desc.GetAttr<float>("threshold");
op_desc.SetAttr("alpha", alpha);
} else if (act_type_ == "leaky_relu") {
float alpha = act_op_desc.GetAttr<float>("alpha");
op_desc.SetAttr("alpha", alpha);
}
return op_desc;
}
} // namespace fusion
} // namespace mir
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include "lite/core/mir/pattern_matcher_high_api.h"
namespace paddle {
namespace lite {
namespace mir {
namespace fusion {
class ScaleActivationFuser : public FuseBase {
public:
explicit ScaleActivationFuser(const std::string& act_type) {
act_type_ = act_type;
}
void BuildPattern() override;
void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override;
private:
cpp::OpDesc GenOpDesc(const key2nodes_t& matched) override;
std::string act_type_;
};
} // namespace fusion
} // namespace mir
} // namespace lite
} // namespace paddle
......@@ -71,6 +71,7 @@ class Optimizer {
"identity_scale_eliminate_pass", //
"elementwise_mul_constant_eliminate_pass", //
"lite_sequence_pool_concat_fuse_pass", //
"lite_scale_activation_fuse_pass", //
#if (defined LITE_WITH_LIGHT_WEIGHT_FRAMEWORK) || (defined LITE_WITH_CUDA) || \
(defined LITE_WITH_ARM)
"lite_elementwise_add_activation_fuse_pass", //
......
......@@ -31,7 +31,18 @@ void ScaleCompute<T, PType>::Run() {
if (!param.bias_after_scale) {
bias *= scale;
}
T alpha = param.alpha;
if (param.activation_type == "") { // no act
lite::arm::math::scale<T>(x_data, output_data, num, scale, bias);
} else if (param.activation_type == "relu") { // do relu
lite::arm::math::scale_relu<T>(x_data, output_data, num, scale, bias);
} else if (param.activation_type == "relu6") { // do relu6
lite::arm::math::scale_relu6<T>(
x_data, output_data, num, scale, bias, alpha);
} else if (param.activation_type == "leaky_relu") { // do leaky_relu
lite::arm::math::scale_leaky_relu<T>(
x_data, output_data, num, scale, bias, alpha);
}
if (!param.x->lod().empty()) {
param.output->set_lod(param.x->lod());
}
......
......@@ -244,6 +244,9 @@ struct ScaleParam : ParamBase {
float scale{1.};
float bias{};
bool bias_after_scale{true};
std::string activation_type{""};
bool fuse_relu{false};
float alpha{6.};
///////////////////////////////////////////////////////////////////////////////////
// get a vector of input tensors
const std::vector<const Tensor*>* input_tensor_ptrs() override {
......
......@@ -38,6 +38,20 @@ bool ScaleOp::AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) {
param_.scale = op_desc.GetAttr<float>("scale");
param_.bias = op_desc.GetAttr<float>("bias");
param_.bias_after_scale = op_desc.GetAttr<bool>("bias_after_scale");
if (op_desc.HasAttr("activation_type")) {
auto act_type = op_desc.GetAttr<std::string>("activation_type");
param_.activation_type = act_type;
if (act_type == "relu") {
param_.fuse_relu = true;
} else if (act_type == "relu6") {
param_.alpha = op_desc.GetAttr<float>("alpha"); // 6.f
} else if (act_type == "leaky_relu") {
param_.alpha = op_desc.GetAttr<float>("alpha");
} else {
CHECK(false)
<< "The fused conv only supports fuse with relu and leaky relu";
}
}
CHECK(param_.x);
CHECK(param_.output);
return true;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册