提交 c0af965c 编写于 作者: H HappyAngel 提交者: xiaogang

[arm]add gemm + relu6/leakyrelu fusion (#2674)

add gemm + relu6/leakyrelu fusion
上级 7a8118b0
...@@ -924,58 +924,58 @@ void conv_depthwise_3x3s1_fp32(const float *din, ...@@ -924,58 +924,58 @@ void conv_depthwise_3x3s1_fp32(const float *din,
\ \
"st1 {v15.4s}, [%[doutr3]], #16 \n" "st1 {v15.4s}, [%[doutr3]], #16 \n"
#define RIGHT_RESULT_S1_RELU6 \ #define RIGHT_RESULT_S1_RELU6 \
"fmax v12.4s, v12.4s, v20.4s \n" /*relu*/ \ "fmax v12.4s, v12.4s, v20.4s \n" /*relu*/ \
\ \
"fmla v15.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ "fmla v15.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \
"fmla v14.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ "fmla v14.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \
"fmla v13.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ "fmla v13.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \
\ \
"fmin v12.4s, v12.4s, %[vsix].4s \n" /*relu6*/ \ "fmin v12.4s, v12.4s, %[vsix].4s \n" /*relu6*/ \
\ \
"fmla v15.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ "fmla v15.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \
"fmla v14.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ "fmla v14.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \
"fmla v13.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ "fmla v13.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \
\ \
"ext v16.16b, v8.16b, v9.16b, #4 \n" /* v16 = 1234*/ \ "ext v16.16b, v8.16b, v9.16b, #4 \n" /* v16 = 1234*/ \
"ext v17.16b, v8.16b, v9.16b, #8 \n" /* v16 = 2345 */ /* r3 */ \ "ext v17.16b, v8.16b, v9.16b, #8 \n" /* v16 = 2345 */ /* r3 */ \
"bif v12.16b, v22.16b, v18.16b \n" \ "bif v12.16b, v22.16b, v18.16b \n" \
"fmla v15.4s , v8.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ "fmla v15.4s , v8.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \
"fmla v14.4s , v8.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ "fmla v14.4s , v8.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \
"fmax v13.4s, v13.4s, v20.4s \n" /*relu*/ \ "fmax v13.4s, v13.4s, v20.4s \n" /*relu*/ \
\ \
"fmla v15.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ "fmla v15.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \
"fmla v14.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ "fmla v14.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \
"st1 {v12.4s}, [%[doutr0]], #16 \n" \ "st1 {v12.4s}, [%[doutr0]], #16 \n" \
\ \
"fmin v13.4s, v13.4s, %[vsix].4s \n" /*relu6*/ \ "fmin v13.4s, v13.4s, %[vsix].4s \n" /*relu6*/ \
\ \
"fmla v15.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ "fmla v15.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \
"fmla v14.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ "fmla v14.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \
\ \
"ext v16.16b, v10.16b, v11.16b, #4 \n" /* v16 = 1234*/ \ "ext v16.16b, v10.16b, v11.16b, #4 \n" /* v16 = 1234*/ \
"ext v17.16b, v10.16b, v11.16b, #8 \n" /* v16 = 2345 */ \ "ext v17.16b, v10.16b, v11.16b, #8 \n" /* v16 = 2345 */ \
"bif v13.16b, v23.16b, v18.16b \n" \ "bif v13.16b, v23.16b, v18.16b \n" \
\ \
"fmla v15.4s , v10.4s, v20.s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ "fmla v15.4s , v10.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \
\ \
"fmax v14.4s, v14.4s, v20.4s \n" /*relu*/ \ "fmax v14.4s, v14.4s, v20.4s \n" /*relu*/ \
"st1 {v13.4s}, [%[doutr1]], #16 \n" /* r3 */ \ "st1 {v13.4s}, [%[doutr1]], #16 \n" /* r3 */ \
\ \
"fmla v15.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ "fmla v15.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \
\ \
"fmin v14.4s, v14.4s, %[vsix].4s \n" /*relu6*/ \ "fmin v14.4s, v14.4s, %[vsix].4s \n" /*relu6*/ \
\ \
"fmla v15.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ "fmla v15.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \
\ \
"bif v14.16b, v24.16b, v18.16b \n" \ "bif v14.16b, v24.16b, v18.16b \n" \
"fmax v15.4s, v15.4s, v20.4s \n" /*relu*/ \ "fmax v15.4s, v15.4s, v20.4s \n" /*relu*/ \
\ \
"st1 {v14.4s}, [%[doutr2]], #16 \n" \ "st1 {v14.4s}, [%[doutr2]], #16 \n" \
\ \
"fmin v15.4s, v15.4s, %[vsix].4s \n" /*relu6*/ \ "fmin v15.4s, v15.4s, %[vsix].4s \n" /*relu6*/ \
"bif v15.16b, v25.16b, v18.16b \n" \ "bif v15.16b, v25.16b, v18.16b \n" \
\ \
"st1 {v15.4s}, [%[doutr3]], #16 \n" "st1 {v15.4s}, [%[doutr3]], #16 \n"
#define RIGHT_RESULT_S1_LEAKY_RELU \ #define RIGHT_RESULT_S1_LEAKY_RELU \
...@@ -1586,7 +1586,7 @@ void conv_depthwise_3x3s1_fp32(const float *din, ...@@ -1586,7 +1586,7 @@ void conv_depthwise_3x3s1_fp32(const float *din,
/* r3 */ \ /* r3 */ \
"vmla.f32 q5, q14, %e[wr2][0] @ q4 += 0123 * wr0[0]\n" \ "vmla.f32 q5, q14, %e[wr2][0] @ q4 += 0123 * wr0[0]\n" \
\ \
"vld1.32 {d28-d29}, [%[six_ptr]]! @ load din r0\n" \ "vld1.32 {d28-d29}, [%[six_ptr]] @ load din r0\n" \
"vmax.f32 q4, q4, %q[vzero] @ relu \n" \ "vmax.f32 q4, q4, %q[vzero] @ relu \n" \
\ \
"vmla.f32 q5, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" \ "vmla.f32 q5, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" \
...@@ -1617,7 +1617,7 @@ void conv_depthwise_3x3s1_fp32(const float *din, ...@@ -1617,7 +1617,7 @@ void conv_depthwise_3x3s1_fp32(const float *din,
/* r3 */ \ /* r3 */ \
"vmla.f32 q5, q14, %e[wr2][0] @ q4 += 0123 * wr0[0]\n" \ "vmla.f32 q5, q14, %e[wr2][0] @ q4 += 0123 * wr0[0]\n" \
\ \
"vld1.32 {d28-d29}, [%[scale_ptr]]! @ load din r0\n" \ "vld1.32 {d28-d29}, [%[scale_ptr]] @ load din r0\n" \
\ \
"vmla.f32 q5, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" \ "vmla.f32 q5, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" \
\ \
...@@ -1694,7 +1694,7 @@ void conv_depthwise_3x3s1_fp32(const float *din, ...@@ -1694,7 +1694,7 @@ void conv_depthwise_3x3s1_fp32(const float *din,
/* r3 */ \ /* r3 */ \
"vmla.f32 q5, q14, %e[wr2][0] @ q4 += 0123 * wr0[0]\n" \ "vmla.f32 q5, q14, %e[wr2][0] @ q4 += 0123 * wr0[0]\n" \
\ \
"vld1.32 {d28-d29}, [%[scale_ptr]]! @ load din r0\n" \ "vld1.32 {d28-d29}, [%[scale_ptr]] @ load din r0\n" \
\ \
"vmla.f32 q5, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" \ "vmla.f32 q5, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" \
\ \
......
...@@ -2237,7 +2237,7 @@ inline void act_switch_process(float* src, ...@@ -2237,7 +2237,7 @@ inline void act_switch_process(float* src,
int cnt = size >> 4; int cnt = size >> 4;
int remain = size % 16; int remain = size % 16;
float32x4_t vzero = vdupq_n_f32(0.f); float32x4_t vzero = vdupq_n_f32(0.f);
if (act_param != nullptr && act_param->has_active) { if (act_param != nullptr) {
float32x4_t vsix = vdupq_n_f32(act_param->Relu_clipped_coef); float32x4_t vsix = vdupq_n_f32(act_param->Relu_clipped_coef);
float32x4_t vscale = vdupq_n_f32(act_param->Leaky_relu_alpha); float32x4_t vscale = vdupq_n_f32(act_param->Leaky_relu_alpha);
if (cnt > 0) { if (cnt > 0) {
...@@ -2327,6 +2327,7 @@ inline void act_switch_process(float* src, ...@@ -2327,6 +2327,7 @@ inline void act_switch_process(float* src,
src++; src++;
dst++; dst++;
} }
break;
case lite_api::ActivationType::kRelu6: case lite_api::ActivationType::kRelu6:
for (int i = 0; i < remain; i++) { for (int i = 0; i < remain; i++) {
float tmp = *src >= 0.f ? *src : 0.f; float tmp = *src >= 0.f ? *src : 0.f;
...@@ -2336,6 +2337,7 @@ inline void act_switch_process(float* src, ...@@ -2336,6 +2337,7 @@ inline void act_switch_process(float* src,
src++; src++;
dst++; dst++;
} }
break;
case lite_api::ActivationType::kLeakyRelu: case lite_api::ActivationType::kLeakyRelu:
for (int i = 0; i < remain; i++) { for (int i = 0; i < remain; i++) {
if (*src >= 0.f) { if (*src >= 0.f) {
......
...@@ -180,6 +180,8 @@ void conv1x1s1_gemm(const float* i_data, ...@@ -180,6 +180,8 @@ void conv1x1s1_gemm(const float* i_data,
bool flag_relu = param.fuse_relu; bool flag_relu = param.fuse_relu;
bool flag_bias = param.bias != nullptr; bool flag_bias = param.bias != nullptr;
auto act_param = param.activation_param;
int hblock = get_hblock(ctx); int hblock = get_hblock(ctx);
int m_roundup = hblock * ((m + hblock - 1) / hblock); int m_roundup = hblock * ((m + hblock - 1) / hblock);
int weights_size_per_group = m * k; int weights_size_per_group = m * k;
...@@ -223,7 +225,7 @@ void conv1x1s1_gemm(const float* i_data, ...@@ -223,7 +225,7 @@ void conv1x1s1_gemm(const float* i_data,
n, n,
bias_group, bias_group,
flag_bias, flag_bias,
flag_relu, act_param,
ctx); ctx);
} }
} }
...@@ -361,6 +363,8 @@ void conv_im2col_gemm(const float* i_data, ...@@ -361,6 +363,8 @@ void conv_im2col_gemm(const float* i_data,
int hblock = get_hblock(ctx); int hblock = get_hblock(ctx);
int m_roundup = hblock * ((m + hblock - 1) / hblock); int m_roundup = hblock * ((m + hblock - 1) / hblock);
int weights_size_per_group = m * k; int weights_size_per_group = m * k;
auto act_param = param.activation_param;
if (n > 1) { if (n > 1) {
weights_size_per_group = ((m_roundup * k + 15) / 16) * 16; weights_size_per_group = ((m_roundup * k + 15) / 16) * 16;
} }
...@@ -422,7 +426,7 @@ void conv_im2col_gemm(const float* i_data, ...@@ -422,7 +426,7 @@ void conv_im2col_gemm(const float* i_data,
n, n,
bias_group, bias_group,
flag_bias, flag_bias,
flag_relu, act_param,
ctx); ctx);
} }
} }
......
...@@ -44,6 +44,8 @@ void conv_winograd3x3(const float* din, ...@@ -44,6 +44,8 @@ void conv_winograd3x3(const float* din,
int size_out_channel = wout * hout; int size_out_channel = wout * hout;
bool flag_relu = param.fuse_relu; bool flag_relu = param.fuse_relu;
bool flag_bias = param.bias != nullptr; bool flag_bias = param.bias != nullptr;
auto act_param = param.activation_param;
act_param.has_active = false;
//! transform input //! transform input
int tile_w = (wout + 5) / 6; int tile_w = (wout + 5) / 6;
...@@ -127,7 +129,7 @@ void conv_winograd3x3(const float* din, ...@@ -127,7 +129,7 @@ void conv_winograd3x3(const float* din,
size_tile, size_tile,
nullptr, nullptr,
false, false,
false, act_param,
ctx); ctx);
} }
......
...@@ -115,7 +115,241 @@ void fill_bias_relu<int>(int* tensor, ...@@ -115,7 +115,241 @@ void fill_bias_relu<int>(int* tensor,
} }
} }
} }
#ifdef __aarch64__
#define FILL_BIAS \
"1: \n" \
"ld1 {v0.4s}, [%[din_ptr]], #16 \n" /*vld1q_f32(din_ptr0)*/ \
"ld1 {v1.4s}, [%[din_ptr]], #16 \n" /*vld1q_f32(din_ptr0)*/ \
"ld1 {v2.4s}, [%[din_ptr]], #16 \n" /*vld1q_f32(din_ptr0)*/ \
"ld1 {v3.4s}, [%[din_ptr]], #16 \n" /*vld1q_f32(din_ptr0)*/ \
"add v0.4s, v0.4s, %[vbias].4s \n" \
"add v1.4s, v1.4s, %[vbias].4s \n" \
"add v2.4s, v2.4s, %[vbias].4s \n" \
"add v3.4s, v3.4s, %[vbias].4s \n"
#define FILL_RELU \
"fmax v0.4s, v0.4s, %[vzero].4s \n" /* vmaxq_f32() */ \
"fmax v1.4s, v1.4s, %[vzero].4s \n" /* vmaxq_f32() */ \
"fmax v2.4s, v2.4s, %[vzero].4s \n" /* vmaxq_f32() */ \
"fmax v3.4s, v3.4s, %[vzero].4s \n" /* vmaxq_f32() */
#define FILL_RELU6 \
"fmin v0.4s, v0.4s, %[vsix].4s \n" /* vmaxq_f32() */ \
"fmin v1.4s, v1.4s, %[vsix].4s \n" /* vmaxq_f32() */ \
"fmin v2.4s, v2.4s, %[vsix].4s \n" /* vmaxq_f32() */ \
"fmin v3.4s, v3.4s, %[vsix].4s \n" /* vmaxq_f32() */
#define FILL_LEAKY_RELU \
"cmhs v4.4s, v0.4s, %[vzero].4s \n" /* vcgeq_u32 */ \
"fmul v5.4s, v0.4s, %[vscale].4s \n" /* vmulq_f32 */ \
"cmhs v6.4s, v1.4s, %[vzero].4s \n" /* vcgeq_u32 */ \
"fmul v7.4s, v1.4s, %[vscale].4s \n" /* vmulq_f32 */ \
"cmhs v8.4s, v2.4s, %[vzero].4s \n" /* vcgeq_u32 */ \
"fmul v9.4s, v2.4s, %[vscale].4s \n" /* vmulq_f32 */ \
"cmhs v10.4s, v3.4s, %[vzero].4s \n" /* vcgeq_u32 */ \
"fmul v11.4s, v3.4s, %[vscale].4s \n" /* vmulq_f32 */ \
"bif v0.16b, v5.16b, v4.16b \n" /* choose*/ \
"bif v1.16b, v7.16b, v6.16b \n" /* choose*/ \
"bif v2.16b, v9.16b, v8.16b \n" /* choose*/ \
"bif v3.16b, v11.16b, v10.16b \n" /* choose*/
#define FILL_STORE \
"subs %w[cnt], %w[cnt], #1 \n" \
"st1 {v0.4s}, [%[dout_ptr]], #16 \n" /* vst1q_f32() */ \
"st1 {v1.4s}, [%[dout_ptr]], #16 \n" /* vst1q_f32() */ \
"st1 {v2.4s}, [%[dout_ptr]], #16 \n" /* vst1q_f32() */ \
"st1 {v3.4s}, [%[dout_ptr]], #16 \n" /* vst1q_f32() */ \
"bne 1b \n"
#else
#define FILL_BIAS \
"1: \n" \
"vld1.32 {d6-d7}, [%[din_ptr]]! @ vld1q_f32(din_ptr) \n" \
"vld1.32 {d8-d9}, [%[din_ptr]]! @ vld1q_f32(din_ptr) \n" \
"vld1.32 {d10-d11}, [%[din_ptr]]! @ vld1q_f32(din_ptr) \n" \
"vld1.32 {d12-d13}, [%[din_ptr]]! @ vld1q_f32(din_ptr) \n" \
"vadd.f32 q3, q3, %q[vbias] @ add \n" \
"vadd.f32 q4, q4, %q[vbias] @ add \n" \
"vadd.f32 q5, q5, %q[vbias] @ add \n" \
"vadd.f32 q6, q6, %q[vbias] @ add \n"
#define FILL_RELU \
"vmax.f32 q3, q3, %q[vzero] @ vmaxq_f32() \n" \
"vmax.f32 q4, q4, %q[vzero] @ vmaxq_f32() \n" \
"vmax.f32 q5, q5, %q[vzero] @ vmaxq_f32() \n" \
"vmax.f32 q6, q6, %q[vzero] @ vmaxq_f32() \n"
#define FILL_RELU6 \
"vmin.f32 q3, q3, %q[vsix] @ vminq_f32() \n" \
"vmin.f32 q4, q4, %q[vsix] @ vmaxq_f32() \n" \
"vmin.f32 q5, q5, %q[vsix] @ vmaxq_f32() \n" \
"vmin.f32 q6, q6, %q[vsix] @ vmaxq_f32() \n"
#define FILL_LEAKY_RELU \
"vcge.f32 q7, q3, %q[vzero] @ vcgeq_u32 \n" \
"vmul.f32 q8, q3, %q[vscale] @ vmulq_f32 \n" \
"vcge.f32 q9, q4, %q[vzero] @ vcgeq_u32 \n" \
"vmul.f32 q10, q4, %q[vscale] @ vmulq_f32 \n" \
"vcge.f32 q11, q5, %q[vzero] @ vcgeq_u32 \n" \
"vmul.f32 q12, q5, %q[vscale] @ vmulq_f32 \n" \
"vcge.f32 q13, q6, %q[vzero] @ vcgeq_u32 \n" \
"vmul.f32 q14, q6, %q[vscale] @ vmulq_f32 \n" \
"vbif q3, q8, q7 @ choose \n" \
"vbif q4, q10, q9 @ choose \n" \
"vbif q5, q12, q11 @ choose \n" \
"vbif q6, q14, q13 @ choose \n"
#define FILL_STORE \
"subs %[cnt], #1 \n" \
"vst1.32 {d6-d7}, [%[dout_ptr]]! @ vst1q_f32() \n" \
"vst1.32 {d8-d9}, [%[dout_ptr]]! @ vst1q_f32() \n" \
"vst1.32 {d10-d11}, [%[dout_ptr]]! @ vst1q_f32() \n" \
"vst1.32 {d12-d13}, [%[dout_ptr]]! @ vst1q_f32() \n" \
"bne 1b \n"
#endif
template <>
void fill_bias_act<float>(float* tensor,
const float* bias,
int channel,
int channel_size,
bool flag_bias,
const operators::ActivationParam* act_param) {
float* data = tensor;
int cnt = channel_size >> 4;
int remain = channel_size % 16;
float32x4_t vzero = vdupq_n_f32(0.f);
if (act_param != nullptr && act_param->has_active) {
float32x4_t vsix = vdupq_n_f32(act_param->Relu_clipped_coef);
float32x4_t vscale = vdupq_n_f32(act_param->Leaky_relu_alpha);
for (int j = 0; j < channel; j++) {
float bias_data = flag_bias ? bias[j] : 0.f;
float* src = data + j * channel_size;
float* dst = data + j * channel_size;
float32x4_t vbias = vdupq_n_f32(bias_data);
if (cnt > 0) {
switch (act_param->active_type) {
case lite_api::ActivationType::kRelu:
#ifdef __aarch64__
asm volatile(
FILL_BIAS FILL_RELU FILL_STORE
: [din_ptr] "+r"(src), [dout_ptr] "+r"(dst), [cnt] "+r"(cnt)
: [vzero] "w"(vzero), [vbias] "w"(vbias)
: "memory", "cc", "v0", "v1", "v2", "v3");
#else
asm volatile(
FILL_BIAS FILL_RELU FILL_STORE
: [din_ptr] "+r"(src), [dout_ptr] "+r"(dst), [cnt] "+r"(cnt)
: [vzero] "w"(vzero), [vbias] "w"(vbias)
: "memory", "cc", "q3", "q4", "q5", "q6");
#endif
break;
case lite_api::ActivationType::kRelu6:
#ifdef __aarch64__
asm volatile(
FILL_BIAS FILL_RELU FILL_RELU6 FILL_STORE
: [din_ptr] "+r"(src), [dout_ptr] "+r"(dst), [cnt] "+r"(cnt)
: [vzero] "w"(vzero), [vsix] "w"(vsix), [vbias] "w"(vbias)
: "memory", "cc", "v0", "v1", "v2", "v3");
#else
asm volatile(
FILL_BIAS FILL_RELU FILL_RELU6 FILL_STORE
: [din_ptr] "+r"(src), [dout_ptr] "+r"(dst), [cnt] "+r"(cnt)
: [vzero] "w"(vzero), [vsix] "w"(vsix), [vbias] "w"(vbias)
: "memory", "cc", "q3", "q4", "q5", "q6");
#endif
break;
case lite_api::ActivationType::kLeakyRelu:
#ifdef __aarch64__
asm volatile(
FILL_BIAS FILL_LEAKY_RELU FILL_STORE
: [din_ptr] "+r"(src), [dout_ptr] "+r"(dst), [cnt] "+r"(cnt)
: [vzero] "w"(vzero), [vscale] "w"(vscale), [vbias] "w"(vbias)
: "memory",
"cc",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11");
#else
asm volatile(
FILL_BIAS FILL_LEAKY_RELU FILL_STORE
: [din_ptr] "+r"(src), [dout_ptr] "+r"(dst), [cnt] "+r"(cnt)
: [vzero] "w"(vzero), [vscale] "w"(vscale), [vbias] "w"(vbias)
: "memory",
"cc",
"q3",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14");
#endif
break;
default:
LOG(FATAL) << "this act_type: "
<< static_cast<int>(act_param->active_type)
<< " fuse not support";
}
}
// remain
switch (act_param->active_type) {
case lite_api::ActivationType::kRelu:
for (int i = 0; i < remain; i++) {
*dst = *src >= 0.f ? *src : 0.f;
src++;
dst++;
}
case lite_api::ActivationType::kRelu6:
for (int i = 0; i < remain; i++) {
float tmp = *src >= 0.f ? *src : 0.f;
*dst = tmp <= act_param->Relu_clipped_coef
? tmp
: act_param->Relu_clipped_coef;
src++;
dst++;
}
case lite_api::ActivationType::kLeakyRelu:
for (int i = 0; i < remain; i++) {
if (*src >= 0.f) {
*dst = *src;
} else {
*dst = *src * act_param->Leaky_relu_alpha;
}
src++;
dst++;
}
break;
default:
LOG(FATAL) << "this act_type: "
<< static_cast<int>(act_param->active_type)
<< " fuse not support";
}
}
} else {
for (int j = 0; j < channel; ++j) {
float bias_data = flag_bias ? bias[j] : 0.f;
float32x4_t vbias = vdupq_n_f32(bias_data);
float* src = data + j * channel_size;
float* dst = data + j * channel_size;
#ifdef __aarch64__
asm volatile(FILL_BIAS FILL_STORE
: [din_ptr] "+r"(src), [dout_ptr] "+r"(dst), [cnt] "+r"(cnt)
: [vbias] "w"(vbias)
: "memory", "cc", "v0", "v1", "v2", "v3");
#else
asm volatile(FILL_BIAS FILL_STORE
: [din_ptr] "+r"(src), [dout_ptr] "+r"(dst), [cnt] "+r"(cnt)
: [vbias] "w"(vbias)
: "memory", "cc", "q3", "q4", "q5", "q6");
#endif
}
}
}
} // namespace math } // namespace math
} // namespace arm } // namespace arm
} // namespace lite } // namespace lite
......
...@@ -37,7 +37,22 @@ void fill_bias_relu(Dtype* tensor, ...@@ -37,7 +37,22 @@ void fill_bias_relu(Dtype* tensor,
int channel_size, int channel_size,
bool flag_bias, bool flag_bias,
bool flag_relu); bool flag_relu);
/**
* * \brief neon implementation to add bias and activation(relu, relu6,
* leakyrelu)
* * @param tensor
* * @param bias
* * @param channel
* * @param channel_size
*
*/
template <typename Dtype>
void fill_bias_act(Dtype* tensor,
const Dtype* bias,
int channel,
int channel_size,
bool flag_bias,
const operators::ActivationParam* act_param);
} // namespace math } // namespace math
} // namespace arm } // namespace arm
} // namespace lite } // namespace lite
......
...@@ -383,6 +383,8 @@ struct GRUUnitFunctor { ...@@ -383,6 +383,8 @@ struct GRUUnitFunctor {
const lite_api::ActivationType active_gate, const lite_api::ActivationType active_gate,
bool origin_mode, bool origin_mode,
ARMContext* ctx) { ARMContext* ctx) {
operators::ActivationParam act_param;
act_param.has_active = false;
if (value.prev_out_value) { if (value.prev_out_value) {
sgemm(false, sgemm(false,
false, false,
...@@ -399,7 +401,7 @@ struct GRUUnitFunctor { ...@@ -399,7 +401,7 @@ struct GRUUnitFunctor {
frame_size * 3, frame_size * 3,
nullptr, nullptr,
false, false,
false, act_param,
ctx); ctx);
} }
gru_unit_reset_act(active_gate, value, frame_size, batch_size); gru_unit_reset_act(active_gate, value, frame_size, batch_size);
...@@ -420,7 +422,7 @@ struct GRUUnitFunctor { ...@@ -420,7 +422,7 @@ struct GRUUnitFunctor {
frame_size * 3, frame_size * 3,
nullptr, nullptr,
false, false,
false, act_param,
ctx); ctx);
} }
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#include "lite/backends/arm/math/packed_sgemm.h" #include "lite/backends/arm/math/packed_sgemm.h"
#include <arm_neon.h> #include <arm_neon.h>
#include "lite/backends/arm/math/conv_block_utils.h"
namespace paddle { namespace paddle {
namespace lite { namespace lite {
...@@ -51,7 +52,7 @@ void sgemm_prepacked_8x12(bool is_transB, ...@@ -51,7 +52,7 @@ void sgemm_prepacked_8x12(bool is_transB,
int ldc, int ldc,
const float *bias, const float *bias,
bool has_bias, bool has_bias,
bool has_relu, const operators::ActivationParam act_param,
ARMContext *ctx); ARMContext *ctx);
void pack_m4(float *out, void pack_m4(float *out,
...@@ -83,7 +84,7 @@ void sgemm_prepacked_4x4(bool is_transB, ...@@ -83,7 +84,7 @@ void sgemm_prepacked_4x4(bool is_transB,
int ldc, int ldc,
const float *bias, const float *bias,
bool has_bias, bool has_bias,
bool has_relu, const operators::ActivationParam act_param,
ARMContext *ctx); ARMContext *ctx);
#else #else
// for kA72 // for kA72
...@@ -136,7 +137,7 @@ void sgemm_prepacked_6x8(bool is_transB, ...@@ -136,7 +137,7 @@ void sgemm_prepacked_6x8(bool is_transB,
int ldc, int ldc,
const float *bias, const float *bias,
bool has_bias, bool has_bias,
bool has_relu, const operators::ActivationParam act_param,
ARMContext *ctx); ARMContext *ctx);
// for kA73, 4x8 // for kA73, 4x8
void sgemm_prepacked_4x8(bool is_transB, void sgemm_prepacked_4x8(bool is_transB,
...@@ -151,7 +152,7 @@ void sgemm_prepacked_4x8(bool is_transB, ...@@ -151,7 +152,7 @@ void sgemm_prepacked_4x8(bool is_transB,
int ldc, int ldc,
const float *bias, const float *bias,
bool has_bias, bool has_bias,
bool has_relu, const operators::ActivationParam act_param,
ARMContext *ctx); ARMContext *ctx);
#endif // __aarch64__ #endif // __aarch64__
...@@ -249,7 +250,7 @@ void sgemm_prepack(bool is_transB, ...@@ -249,7 +250,7 @@ void sgemm_prepack(bool is_transB,
int ldc, int ldc,
const float *bias, const float *bias,
bool has_bias, bool has_bias,
bool has_relu, const operators::ActivationParam act_param,
ARMContext *ctx) { ARMContext *ctx) {
#ifdef __aarch64__ #ifdef __aarch64__
if (M <= 4) { if (M <= 4) {
...@@ -265,7 +266,7 @@ void sgemm_prepack(bool is_transB, ...@@ -265,7 +266,7 @@ void sgemm_prepack(bool is_transB,
ldc, ldc,
bias, bias,
has_bias, has_bias,
has_relu, act_param,
ctx); ctx);
} else { } else {
sgemm_prepacked_8x12(is_transB, sgemm_prepacked_8x12(is_transB,
...@@ -280,7 +281,7 @@ void sgemm_prepack(bool is_transB, ...@@ -280,7 +281,7 @@ void sgemm_prepack(bool is_transB,
ldc, ldc,
bias, bias,
has_bias, has_bias,
has_relu, act_param,
ctx); ctx);
} }
#else // armv7 #else // armv7
...@@ -297,7 +298,7 @@ void sgemm_prepack(bool is_transB, ...@@ -297,7 +298,7 @@ void sgemm_prepack(bool is_transB,
ldc, ldc,
bias, bias,
has_bias, has_bias,
has_relu, act_param,
ctx); ctx);
} else { } else {
sgemm_prepacked_6x8(is_transB, sgemm_prepacked_6x8(is_transB,
...@@ -312,7 +313,7 @@ void sgemm_prepack(bool is_transB, ...@@ -312,7 +313,7 @@ void sgemm_prepack(bool is_transB,
ldc, ldc,
bias, bias,
has_bias, has_bias,
has_relu, act_param,
ctx); ctx);
} }
#endif // arm64 #endif // arm64
...@@ -2283,7 +2284,7 @@ void sgemm_prepacked_8x12(bool is_transB, ...@@ -2283,7 +2284,7 @@ void sgemm_prepacked_8x12(bool is_transB,
int ldc, int ldc,
const float *bias, const float *bias,
bool has_bias, bool has_bias,
bool has_relu, const operators::ActivationParam act_param,
ARMContext *ctx) { ARMContext *ctx) {
size_t l2_cache = ctx->llc_size() > 0 ? ctx->llc_size() : 512 * 1024; size_t l2_cache = ctx->llc_size() > 0 ? ctx->llc_size() : 512 * 1024;
auto workspace = ctx->workspace_data<float>(); auto workspace = ctx->workspace_data<float>();
...@@ -2837,33 +2838,6 @@ void sgemm_prepacked_8x12(bool is_transB, ...@@ -2837,33 +2838,6 @@ void sgemm_prepacked_8x12(bool is_transB,
"fmla v28.4s, v4.4s, v1.s[2]\n" /* out22 = b2 * a10[0], b2 =q7*/ "fmla v28.4s, v4.4s, v1.s[2]\n" /* out22 = b2 * a10[0], b2 =q7*/
"fmla v31.4s, v4.4s, v1.s[3]\n" /* out23 = b2 * a10[0], b2 =q7*/ "fmla v31.4s, v4.4s, v1.s[3]\n" /* out23 = b2 * a10[0], b2 =q7*/
"11: \n" /* check if relu */ "11: \n" /* check if relu */
"cbz %w[relu], 12f\n" /* skip relu */
"movi v2.4s, #0\n" /* for relu*/
"fmax v8.4s, v8.4s, v2.4s\n" /* relu*/
"fmax v9.4s, v9.4s, v2.4s\n" /* relu*/
"fmax v10.4s, v10.4s, v2.4s\n" /* relu*/
"fmax v11.4s, v11.4s, v2.4s\n" /* relu*/
"fmax v12.4s, v12.4s, v2.4s\n" /* relu*/
"fmax v13.4s, v13.4s, v2.4s\n" /* relu*/
"fmax v14.4s, v14.4s, v2.4s\n" /* relu*/
"fmax v15.4s, v15.4s, v2.4s\n" /* relu*/
"fmax v16.4s,v16.4s,v2.4s\n" /* relu*/
"fmax v17.4s,v17.4s,v2.4s\n" /* relu*/
"fmax v18.4s, v18.4s, v2.4s\n" /* relu*/
"fmax v19.4s, v19.4s, v2.4s\n" /* relu*/
"fmax v20.4s, v20.4s, v2.4s\n" /* relu*/
"fmax v21.4s, v21.4s, v2.4s\n" /* relu*/
"fmax v22.4s, v22.4s, v2.4s\n" /* relu*/
"fmax v23.4s, v23.4s, v2.4s\n" /* relu*/
"fmax v24.4s,v24.4s,v2.4s\n" /* relu*/
"fmax v25.4s,v25.4s,v2.4s\n" /* relu*/
"fmax v26.4s, v26.4s, v2.4s\n" /* relu*/
"fmax v27.4s, v27.4s, v2.4s\n" /* relu*/
"fmax v28.4s, v28.4s, v2.4s\n" /* relu*/
"fmax v29.4s, v29.4s, v2.4s\n" /* relu*/
"fmax v30.4s, v30.4s, v2.4s\n" /* relu*/
"fmax v31.4s, v31.4s, v2.4s\n" /* relu*/
"12: \n"
"st1 {v8.4s, v9.4s, v10.4s},[%[c_ptr0]], #48\n" /* store r0 */ "st1 {v8.4s, v9.4s, v10.4s},[%[c_ptr0]], #48\n" /* store r0 */
"st1 {v11.4s, v12.4s, v13.4s},[%[c_ptr1]], #48\n" /* store r1 */ "st1 {v11.4s, v12.4s, v13.4s},[%[c_ptr1]], #48\n" /* store r1 */
"st1 {v14.4s, v15.4s, v16.4s},[%[c_ptr2]], #48\n" /* store r2 */ "st1 {v14.4s, v15.4s, v16.4s},[%[c_ptr2]], #48\n" /* store r2 */
...@@ -2886,7 +2860,6 @@ void sgemm_prepacked_8x12(bool is_transB, ...@@ -2886,7 +2860,6 @@ void sgemm_prepacked_8x12(bool is_transB,
[c_ptr6] "+r"(c_ptr6), [c_ptr6] "+r"(c_ptr6),
[c_ptr7] "+r"(c_ptr7) [c_ptr7] "+r"(c_ptr7)
: [bias_ptr] "r"(bias_local), : [bias_ptr] "r"(bias_local),
[relu] "r"(has_relu),
[has_beta] "r"(has_beta), [has_beta] "r"(has_beta),
[beta] "r"(beta) [beta] "r"(beta)
: "cc","memory", : "cc","memory",
...@@ -2911,6 +2884,13 @@ void sgemm_prepacked_8x12(bool is_transB, ...@@ -2911,6 +2884,13 @@ void sgemm_prepacked_8x12(bool is_transB,
} }
} }
} }
if (act_param.has_active) {
#pragma omp parallel for num_threads(threads)
for (unsigned int x = 0; x < M; x++) {
float *dst = C + x * ldc;
act_switch_process(dst, dst, N, &act_param);
}
}
} }
void sgemm_prepacked_4x4(bool is_transB, void sgemm_prepacked_4x4(bool is_transB,
...@@ -2925,7 +2905,7 @@ void sgemm_prepacked_4x4(bool is_transB, ...@@ -2925,7 +2905,7 @@ void sgemm_prepacked_4x4(bool is_transB,
int ldc, int ldc,
const float *bias, const float *bias,
bool has_bias, bool has_bias,
bool has_relu, const operators::ActivationParam act_param,
ARMContext *ctx) { ARMContext *ctx) {
size_t l2_cache = ctx->llc_size() > 0 ? ctx->llc_size() : 512 * 1024; size_t l2_cache = ctx->llc_size() > 0 ? ctx->llc_size() : 512 * 1024;
auto workspace = ctx->workspace_data<float>(); auto workspace = ctx->workspace_data<float>();
...@@ -3158,13 +3138,6 @@ void sgemm_prepacked_4x4(bool is_transB, ...@@ -3158,13 +3138,6 @@ void sgemm_prepacked_4x4(bool is_transB,
"fmla v11.4s, v6.4s, v2.s[3]\n" /* out3 = b2 * a20[3], b1 =q6*/ "fmla v11.4s, v6.4s, v2.s[3]\n" /* out3 = b2 * a20[3], b1 =q6*/
"11: \n" /* check if relu */ "11: \n" /* check if relu */
"cbz %w[relu], 12f\n" /* skip relu */
"movi v2.4s, #0\n" /* for relu*/
"fmax v8.4s, v8.4s, v2.4s\n" /* relu*/
"fmax v9.4s, v9.4s, v2.4s\n" /* relu*/
"fmax v10.4s, v10.4s, v2.4s\n" /* relu*/
"fmax v11.4s, v11.4s, v2.4s\n" /* relu*/
"12: \n"
"st1 {v8.4s}, [%[c_ptr0]], #16\n" /* store r0 */ "st1 {v8.4s}, [%[c_ptr0]], #16\n" /* store r0 */
"st1 {v9.4s}, [%[c_ptr1]], #16\n" /* store r1 */ "st1 {v9.4s}, [%[c_ptr1]], #16\n" /* store r1 */
"st1 {v10.4s}, [%[c_ptr2]], #16\n" /* store r2 */ "st1 {v10.4s}, [%[c_ptr2]], #16\n" /* store r2 */
...@@ -3179,7 +3152,6 @@ void sgemm_prepacked_4x4(bool is_transB, ...@@ -3179,7 +3152,6 @@ void sgemm_prepacked_4x4(bool is_transB,
[c_ptr2] "+r"(c_ptr2), [c_ptr2] "+r"(c_ptr2),
[c_ptr3] "+r"(c_ptr3) [c_ptr3] "+r"(c_ptr3)
: [bias_ptr] "r"(bias_local), : [bias_ptr] "r"(bias_local),
[relu] "r"(has_relu),
[has_beta] "r"(has_beta), [has_beta] "r"(has_beta),
[beta] "r"(beta) [beta] "r"(beta)
: "cc","memory", : "cc","memory",
...@@ -3197,6 +3169,13 @@ void sgemm_prepacked_4x4(bool is_transB, ...@@ -3197,6 +3169,13 @@ void sgemm_prepacked_4x4(bool is_transB,
} }
} }
} }
if (act_param.has_active) {
#pragma omp parallel for num_threads(threads)
for (unsigned int x = 0; x < M; x++) {
float *dst = C + x * ldc;
act_switch_process(dst, dst, N, &act_param);
}
}
} }
#else // __aarch64__ #else // __aarch64__
/** /**
...@@ -3222,7 +3201,7 @@ void sgemm_prepacked_6x8(bool is_transB, ...@@ -3222,7 +3201,7 @@ void sgemm_prepacked_6x8(bool is_transB,
int ldc, int ldc,
const float* bias, const float* bias,
bool has_bias, bool has_bias,
bool has_relu, const operators::ActivationParam act_param,
ARMContext* ctx) { ARMContext* ctx) {
size_t l2_cache = ctx->llc_size() > 0 ? ctx->llc_size() : 512 * 1024; size_t l2_cache = ctx->llc_size() > 0 ? ctx->llc_size() : 512 * 1024;
auto* workspace = ctx->workspace_data<float>(); auto* workspace = ctx->workspace_data<float>();
...@@ -3601,22 +3580,6 @@ void sgemm_prepacked_6x8(bool is_transB, ...@@ -3601,22 +3580,6 @@ void sgemm_prepacked_6x8(bool is_transB,
"vmla.f32 q13, q3, d0[0] @ out10 += b2 * a4\n" "vmla.f32 q13, q3, d0[0] @ out10 += b2 * a4\n"
"vmla.f32 q15, q3, d0[1] @ out11 += b2 * a5\n" "vmla.f32 q15, q3, d0[1] @ out11 += b2 * a5\n"
"2: @ check relu\n" "2: @ check relu\n"
"cmp %[relu], #0 @ check if has relu\n"
"ble 6f @ skip relu if relu <= 0\n"
"vmov.u32 q0, #0 @ for relu\n"
"vmax.f32 q4, q4, q0 @ for relu\n"
"vmax.f32 q5, q5, q0 @ for relu\n"
"vmax.f32 q6, q6, q0 @ for relu\n"
"vmax.f32 q7, q7, q0 @ for relu\n"
"vmax.f32 q8, q8, q0 @ for relu\n"
"vmax.f32 q9, q9, q0 @ for relu\n"
"vmax.f32 q10, q10, q0 @ for relu\n"
"vmax.f32 q11, q11, q0 @ for relu\n"
"vmax.f32 q12, q12, q0 @ for relu\n"
"vmax.f32 q13, q13, q0 @ for relu\n"
"vmax.f32 q14, q14, q0 @ for relu\n"
"vmax.f32 q15, q15, q0 @ for relu\n"
"6: @ store result\n"
"vst1.32 {d8-d11}, [%[c_ptr0]]! @ store r0\n" "vst1.32 {d8-d11}, [%[c_ptr0]]! @ store r0\n"
"vst1.32 {d12-d15}, [%[c_ptr1]]! @ store r1\n" "vst1.32 {d12-d15}, [%[c_ptr1]]! @ store r1\n"
"vst1.32 {d16-d19}, [%[c_ptr2]]! @ store r2\n" "vst1.32 {d16-d19}, [%[c_ptr2]]! @ store r2\n"
...@@ -3634,7 +3597,6 @@ void sgemm_prepacked_6x8(bool is_transB, ...@@ -3634,7 +3597,6 @@ void sgemm_prepacked_6x8(bool is_transB,
[k] "+r"(k), [k] "+r"(k),
[tails] "+r"(tails) [tails] "+r"(tails)
: [bias_ptr] "r"(bias_local), : [bias_ptr] "r"(bias_local),
[relu] "r"(has_relu),
[beta] "r"(beta) [beta] "r"(beta)
: "q0","q1","q2","q3","q4", : "q0","q1","q2","q3","q4",
"q5","q6","q7","q8","q9","q10","q11", "q5","q6","q7","q8","q9","q10","q11",
...@@ -3654,6 +3616,13 @@ void sgemm_prepacked_6x8(bool is_transB, ...@@ -3654,6 +3616,13 @@ void sgemm_prepacked_6x8(bool is_transB,
} }
} }
} }
if (act_param.has_active) {
#pragma omp parallel for num_threads(threads)
for (unsigned int x = 0; x < M; x++) {
float* dst = C + x * ldc;
act_switch_process(dst, dst, N, &act_param);
}
}
} }
void sgemm_prepacked_4x8(bool is_transB, void sgemm_prepacked_4x8(bool is_transB,
...@@ -3668,7 +3637,7 @@ void sgemm_prepacked_4x8(bool is_transB, ...@@ -3668,7 +3637,7 @@ void sgemm_prepacked_4x8(bool is_transB,
int ldc, int ldc,
const float* bias, const float* bias,
bool has_bias, bool has_bias,
bool has_relu, const operators::ActivationParam act_param,
ARMContext* ctx) { ARMContext* ctx) {
size_t l2_cache = ctx->llc_size() > 0 ? ctx->llc_size() : 512 * 1024; size_t l2_cache = ctx->llc_size() > 0 ? ctx->llc_size() : 512 * 1024;
auto* workspace = ctx->workspace_data<float>(); auto* workspace = ctx->workspace_data<float>();
...@@ -3953,18 +3922,6 @@ void sgemm_prepacked_4x8(bool is_transB, ...@@ -3953,18 +3922,6 @@ void sgemm_prepacked_4x8(bool is_transB,
/*aptr - 16*/ /*aptr - 16*/
"sub %[a_ptr], %[a_ptr], #16 @ tail--\n" "sub %[a_ptr], %[a_ptr], #16 @ tail--\n"
"2: @ check relu\n" "2: @ check relu\n"
"cmp %[relu], #0 @ check if has relu\n"
"ble 6f @ skip relu if relu <= 0\n"
"vmov.u32 q0, #0 @ for relu\n"
"vmax.f32 q8, q8, q0 @ for relu\n"
"vmax.f32 q9, q9, q0 @ for relu\n"
"vmax.f32 q10, q10, q0 @ for relu\n"
"vmax.f32 q11, q11, q0 @ for relu\n"
"vmax.f32 q12, q12, q0 @ for relu\n"
"vmax.f32 q13, q13, q0 @ for relu\n"
"vmax.f32 q14, q14, q0 @ for relu\n"
"vmax.f32 q15, q15, q0 @ for relu\n"
"6: @ store result\n"
"vst1.32 {d16-d19}, [%[c_ptr0]]! @ store r0\n" "vst1.32 {d16-d19}, [%[c_ptr0]]! @ store r0\n"
"vst1.32 {d20-d23}, [%[c_ptr1]]! @ store r1\n" "vst1.32 {d20-d23}, [%[c_ptr1]]! @ store r1\n"
"vst1.32 {d24-d27}, [%[c_ptr2]]! @ store r2\n" "vst1.32 {d24-d27}, [%[c_ptr2]]! @ store r2\n"
...@@ -3978,7 +3935,6 @@ void sgemm_prepacked_4x8(bool is_transB, ...@@ -3978,7 +3935,6 @@ void sgemm_prepacked_4x8(bool is_transB,
[k] "+r"(k), [k] "+r"(k),
[tails] "+r"(tails) [tails] "+r"(tails)
: [bias_ptr] "r"(bias_local), : [bias_ptr] "r"(bias_local),
[relu] "r"(has_relu),
[beta] "r"(beta) [beta] "r"(beta)
: "q0","q1","q2","q3", : "q0","q1","q2","q3",
"q4","q5","q6","q7","q8","q9","q10", "q4","q5","q6","q7","q8","q9","q10",
...@@ -3995,6 +3951,13 @@ void sgemm_prepacked_4x8(bool is_transB, ...@@ -3995,6 +3951,13 @@ void sgemm_prepacked_4x8(bool is_transB,
} }
} }
} }
if (act_param.has_active) {
#pragma omp parallel for num_threads(threads)
for (unsigned int x = 0; x < M; x++) {
float* dst = C + x * ldc;
act_switch_process(dst, dst, N, &act_param);
}
}
} }
#endif // __aarch64__ #endif // __aarch64__
......
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#include <cmath> #include <cmath>
#include "lite/core/context.h" #include "lite/core/context.h"
#include "lite/core/tensor.h" #include "lite/core/tensor.h"
#include "lite/operators/op_params.h"
namespace paddle { namespace paddle {
namespace lite { namespace lite {
...@@ -74,7 +75,7 @@ void sgemm_prepack(bool is_transB, ...@@ -74,7 +75,7 @@ void sgemm_prepack(bool is_transB,
int ldc, int ldc,
const float* bias, const float* bias,
bool has_bias, bool has_bias,
bool has_relu, const operators::ActivationParam act_param,
ARMContext* ctx); ARMContext* ctx);
} // namespace math } // namespace math
......
...@@ -34,7 +34,7 @@ void sgemm(bool is_transA, ...@@ -34,7 +34,7 @@ void sgemm(bool is_transA,
int ldc, int ldc,
const float* bias, const float* bias,
bool is_bias, bool is_bias,
bool is_relu, const operators::ActivationParam act_param,
ARMContext* ctx) { ARMContext* ctx) {
int hblock = get_hblock(ctx); int hblock = get_hblock(ctx);
int m_roundup = hblock * ((M + hblock - 1) / hblock); int m_roundup = hblock * ((M + hblock - 1) / hblock);
...@@ -56,7 +56,7 @@ void sgemm(bool is_transA, ...@@ -56,7 +56,7 @@ void sgemm(bool is_transA,
ldc, ldc,
bias, bias,
is_bias, is_bias,
is_relu, act_param,
ctx); ctx);
TargetFree(TargetType::kARM, packed_A); TargetFree(TargetType::kARM, packed_A);
} }
......
...@@ -39,7 +39,7 @@ void sgemm(bool is_transA, ...@@ -39,7 +39,7 @@ void sgemm(bool is_transA,
int ldc, int ldc,
const float* bias, const float* bias,
bool is_bias, bool is_bias,
bool is_relu, const operators::ActivationParam act_param,
ARMContext* ctx); ARMContext* ctx);
} // namespace math } // namespace math
......
...@@ -103,6 +103,7 @@ void Conv2DTransposeCompute::Run() { ...@@ -103,6 +103,7 @@ void Conv2DTransposeCompute::Run() {
auto din = param.x->data<float>(); auto din = param.x->data<float>();
auto dout = param.output->mutable_data<float>(); auto dout = param.output->mutable_data<float>();
auto weights = param.filter->data<float>(); auto weights = param.filter->data<float>();
auto act_param = param.activation_param;
for (int i = 0; i < num; i++) { for (int i = 0; i < num; i++) {
const float* din_batch = din + i * chin * hin * win; const float* din_batch = din + i * chin * hin * win;
float* dout_batch = dout + i * chout * hout * wout; float* dout_batch = dout + i * chout * hout * wout;
...@@ -115,7 +116,9 @@ void Conv2DTransposeCompute::Run() { ...@@ -115,7 +116,9 @@ void Conv2DTransposeCompute::Run() {
const float* din_group = din_batch + g * group_size_in; const float* din_group = din_batch + g * group_size_in;
const float* weights_group = weights + g * group_size_weights; const float* weights_group = weights + g * group_size_weights;
float* coldata_group = col_data + g * group_size_coldata; float* coldata_group = col_data + g * group_size_coldata;
if (flag_bias) {
act_param.has_active = false;
}
lite::arm::math::sgemm_prepack(false, lite::arm::math::sgemm_prepack(false,
m, m,
n, n,
...@@ -128,7 +131,7 @@ void Conv2DTransposeCompute::Run() { ...@@ -128,7 +131,7 @@ void Conv2DTransposeCompute::Run() {
n, n,
nullptr, nullptr,
false, false,
fuse_relu && (!flag_bias), act_param,
&ctx); &ctx);
} }
if (!flag_1x1s1p1) { if (!flag_1x1s1p1) {
......
...@@ -94,6 +94,8 @@ void FcCompute<PRECISION(kFloat), PRECISION(kFloat)>::Run() { ...@@ -94,6 +94,8 @@ void FcCompute<PRECISION(kFloat), PRECISION(kFloat)>::Run() {
b_data = bias_.data<float>(); b_data = bias_.data<float>();
} }
if (flag_gemm_) { if (flag_gemm_) {
operators::ActivationParam act_param;
act_param.has_active = false;
lite::arm::math::sgemm(false, lite::arm::math::sgemm(false,
false, false,
m_, m_,
...@@ -109,7 +111,7 @@ void FcCompute<PRECISION(kFloat), PRECISION(kFloat)>::Run() { ...@@ -109,7 +111,7 @@ void FcCompute<PRECISION(kFloat), PRECISION(kFloat)>::Run() {
n_, n_,
nullptr, nullptr,
false, false,
false, act_param,
&ctx); &ctx);
if (param.bias) { if (param.bias) {
CHECK_EQ(param.bias->numel(), n_); CHECK_EQ(param.bias->numel(), n_);
......
...@@ -42,6 +42,9 @@ void MatMulCompute::Run() { ...@@ -42,6 +42,9 @@ void MatMulCompute::Run() {
float alpha = param.alpha; float alpha = param.alpha;
auto& ctx = this->ctx_->template As<ARMContext>(); auto& ctx = this->ctx_->template As<ARMContext>();
operators::ActivationParam act_param;
act_param.has_active = false;
if (x_dims.size() > 2 && y_dims.size() >= 2) { if (x_dims.size() > 2 && y_dims.size() >= 2) {
// x: [B, ..., M, K], y: [B, ..., K, N], out: [B, ..., M, N] // x: [B, ..., M, K], y: [B, ..., K, N], out: [B, ..., M, N]
// x: [B, M, K], y: [K, N], out: [B, M, N] // x: [B, M, K], y: [K, N], out: [B, M, N]
...@@ -97,7 +100,6 @@ void MatMulCompute::Run() { ...@@ -97,7 +100,6 @@ void MatMulCompute::Run() {
if (x_transpose) { if (x_transpose) {
x_data_trans = static_cast<float*>(malloc(sizeof(float) * x_inner)); x_data_trans = static_cast<float*>(malloc(sizeof(float) * x_inner));
} }
if (y_dims.size() > 2) { if (y_dims.size() > 2) {
for (size_t i = 0; i < x_dims.count(0, x_dims.size() - 2); ++i) { for (size_t i = 0; i < x_dims.count(0, x_dims.size() - 2); ++i) {
lite::arm::math::sgemm(x_transpose, lite::arm::math::sgemm(x_transpose,
...@@ -115,7 +117,7 @@ void MatMulCompute::Run() { ...@@ -115,7 +117,7 @@ void MatMulCompute::Run() {
ldc, ldc,
nullptr, nullptr,
false, false,
false, act_param,
&ctx); &ctx);
} }
} else { } else {
...@@ -135,7 +137,7 @@ void MatMulCompute::Run() { ...@@ -135,7 +137,7 @@ void MatMulCompute::Run() {
ldc, ldc,
nullptr, nullptr,
false, false,
false, act_param,
&ctx); &ctx);
} }
} }
...@@ -200,7 +202,7 @@ void MatMulCompute::Run() { ...@@ -200,7 +202,7 @@ void MatMulCompute::Run() {
ldc, ldc,
nullptr, nullptr,
false, false,
false, act_param,
&ctx); &ctx);
} else if (x_dims.size() > 2 && y_dims.size() == 1) { } else if (x_dims.size() > 2 && y_dims.size() == 1) {
// x: [B, M, K], y: [K], out: [B, M] // x: [B, M, K], y: [K], out: [B, M]
...@@ -254,7 +256,7 @@ void MatMulCompute::Run() { ...@@ -254,7 +256,7 @@ void MatMulCompute::Run() {
ldc, ldc,
nullptr, nullptr,
false, false,
false, act_param,
&ctx); &ctx);
} }
} }
......
...@@ -67,6 +67,8 @@ void MulCompute::Run() { ...@@ -67,6 +67,8 @@ void MulCompute::Run() {
if (is_tranposed_y) { if (is_tranposed_y) {
ldb = k_; ldb = k_;
} }
operators::ActivationParam act_param;
act_param.has_active = false;
lite::arm::math::sgemm_prepack(is_tranposed_y, lite::arm::math::sgemm_prepack(is_tranposed_y,
m_, m_,
n_, n_,
...@@ -79,7 +81,7 @@ void MulCompute::Run() { ...@@ -79,7 +81,7 @@ void MulCompute::Run() {
n_, n_,
nullptr, nullptr,
false, false,
false, act_param,
&ctx); &ctx);
} }
} }
......
...@@ -11,8 +11,6 @@ if((NOT LITE_WITH_OPENCL AND NOT LITE_WITH_FPGA AND NOT LITE_WITH_BM) AND (LITE_ ...@@ -11,8 +11,6 @@ if((NOT LITE_WITH_OPENCL AND NOT LITE_WITH_FPGA AND NOT LITE_WITH_BM) AND (LITE_
lite_cc_test(test_kernel_activation_compute SRCS activation_compute_test.cc DEPS arena_framework ${npu_kernels} ${xpu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_activation_compute SRCS activation_compute_test.cc DEPS arena_framework ${npu_kernels} ${xpu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_argmax_compute SRCS argmax_compute_test.cc DEPS arena_framework ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_argmax_compute SRCS argmax_compute_test.cc DEPS arena_framework ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_axpy_compute SRCS axpy_compute_test.cc DEPS arena_framework ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_axpy_compute SRCS axpy_compute_test.cc DEPS arena_framework ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_conv_compute SRCS conv_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_conv2d_transpose_compute SRCS conv2d_transpose_compute_test.cc DEPS arena_framework ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_norm_compute SRCS norm_compute_test.cc DEPS arena_framework ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_norm_compute SRCS norm_compute_test.cc DEPS arena_framework ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_cast_compute SRCS cast_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_cast_compute SRCS cast_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_instance_norm_compute SRCS instance_norm_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_instance_norm_compute SRCS instance_norm_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include "lite/api/paddle_use_kernels.h"
#include "lite/api/paddle_use_ops.h"
#include "lite/core/arena/framework.h"
namespace paddle {
namespace lite {
inline bool is_a_ge_zero_and_a_lt_b(int a, int b) {
return static_cast<unsigned>(a) < static_cast<unsigned>(b);
}
template <typename Dtype>
void col2im(const Dtype* data_col,
const int channels,
const int height,
const int width,
const int kernel_h,
const int kernel_w,
const int pad_h0,
const int pad_h1,
const int pad_w0,
const int pad_w1,
const int stride_h,
const int stride_w,
const int dilation_h,
const int dilation_w,
Dtype* data_im) {
memset(data_im, 0, height * width * channels * sizeof(float));
const int output_h =
(height + pad_h0 + pad_h1 - (dilation_h * (kernel_h - 1) + 1)) /
stride_h +
1;
const int output_w =
(width + pad_w0 + pad_w1 - (dilation_w * (kernel_w - 1) + 1)) / stride_w +
1;
const int channel_size = height * width;
for (int channel = channels; channel--; data_im += channel_size) {
for (int kernel_row = 0; kernel_row < kernel_h; kernel_row++) {
for (int kernel_col = 0; kernel_col < kernel_w; kernel_col++) {
int input_row = -pad_h0 + kernel_row * dilation_h;
for (int output_rows = output_h; output_rows; output_rows--) {
if (!is_a_ge_zero_and_a_lt_b(input_row, height)) {
data_col += output_w;
} else {
int input_col = -pad_w0 + kernel_col * dilation_w;
for (int output_col = output_w; output_col; output_col--) {
if (is_a_ge_zero_and_a_lt_b(input_col, width)) {
data_im[input_row * width + input_col] += *data_col;
}
data_col++;
input_col += stride_w;
}
}
input_row += stride_h;
}
}
}
}
}
template <typename Dtype>
void fill_bias_relu(Dtype* tensor,
const Dtype* bias,
int channel,
int channel_size,
bool flag_bias,
bool flag_relu);
template <>
void fill_bias_relu<float>(float* tensor,
const float* bias,
int channel,
int channel_size,
bool flag_bias,
bool flag_relu) {
float* data = tensor;
if (flag_relu) {
for (int j = 0; j < channel; ++j) {
float bias_data = flag_bias ? bias[j] : 0.f;
for (int i = 0; i < channel_size; i++) {
data[i] += bias_data;
data[i] = data[i] > 0 ? data[i] : 0.f;
}
data += channel_size;
}
} else {
for (int j = 0; j < channel; ++j) {
float bias_data = flag_bias ? bias[j] : 0.f;
for (int i = 0; i < channel_size; i++) {
data[i] += bias_data;
}
data += channel_size;
}
}
}
inline void UpdatePaddingAndDilation(std::vector<int>* paddings,
std::vector<int>* dilations,
const std::vector<int>& strides,
const std::string padding_algorithm,
const DDim data_dims,
const std::vector<int>& ksize) {
// when padding_desc is "VALID" or "SAME"
if (padding_algorithm == "SAME") {
for (size_t i = 0; i < strides.size(); ++i) {
int out_size = (data_dims[i + 2] + strides[i] - 1) / strides[i];
int pad_sum = std::max(
(out_size - 1) * strides[i] + ksize[i + 2] - data_dims[i + 2],
(int64_t)0);
int pad_0 = pad_sum / 2;
int pad_1 = pad_sum - pad_0;
// pad
*(paddings->begin() + i * 2) = pad_0;
*(paddings->begin() + i * 2 + 1) = pad_1;
// dilation
*(dilations->begin() + i) = 1;
}
} else if (padding_algorithm == "VALID") {
for (auto& it : *paddings) {
it = 0;
}
}
}
template <typename type, typename type2>
static void basic_gemm(int m,
int n,
int k,
const type* a,
const type* b,
const type2* bias,
type2* c,
type2 alpha,
type2 beta,
bool trans_a = false,
bool trans_b = false,
bool flag_bias = false,
bool flag_relu = false) {
#pragma omp parallel for
for (int i = 0; i < m; ++i) {
type2 bias_data = (type2)0;
if (flag_bias) {
bias_data = bias[i];
}
for (int j = 0; j < n; ++j) {
type2 sum = static_cast<type2>(0);
for (int l = 0; l < k; ++l) {
type av;
type bv;
if (trans_a) {
av = a[l * m + i];
} else {
av = a[i * k + l];
}
if (trans_b) {
bv = b[j * k + l];
} else {
bv = b[l * n + j];
}
sum += av * bv;
}
type2 tmp = alpha * sum + beta * c[i * n + j] + bias_data;
if (flag_relu) {
c[i * n + j] = tmp > (type2)0 ? tmp : (type2)0;
} else {
c[i * n + j] = tmp;
}
}
}
}
//! for float, dtype1 and type2 is float
//! for int8, dytpe1 is char, dtype2 is int
template <typename Dtype1, typename Dtype2>
bool deconv_basic(const Dtype1* din,
Dtype2* dout,
int num,
int chout,
int hout,
int wout,
int chin,
int hin,
int win,
const Dtype1* weights,
const Dtype2* bias,
int group,
int kernel_w,
int kernel_h,
int stride_w,
int stride_h,
int dila_w,
int dila_h,
int pad_w0,
int pad_w1,
int pad_h0,
int pad_h1,
bool flag_bias,
bool flag_relu) {
int m = chout * kernel_w * kernel_h / group;
int n = hin * win;
int k = chin / group;
if (chin != chout || group != chin) {
CHECK_OR_FALSE(chin % group == 0);
CHECK_OR_FALSE(chout % group == 0);
}
lite::Tensor workspace_tensor;
std::vector<int64_t> wt_shape = {1, 1, 1, group * m * n};
workspace_tensor.Resize(wt_shape);
auto* workspace_ptr = workspace_tensor.mutable_data<Dtype2>();
int group_size_in = win * hin * chin / group;
int group_size_coldata = m * n;
int group_size_weights = chin * chout * kernel_w * kernel_h / (group * group);
bool flag_1x1s1p1 = (kernel_w == 1) && (kernel_h == 1) && (stride_h == 1) &&
(stride_w == 1) && (pad_w0 == 0) && (pad_h0 == 0) &&
(pad_w1 == 0) && (pad_h1 == 0) && (dila_w == 1) &&
(dila_h == 1);
for (int i = 0; i < num; ++i) {
const Dtype1* din_batch = din + i * chin * hin * win;
Dtype2* dout_batch = dout + i * chout * hout * wout;
Dtype2* col_data = workspace_ptr;
if (flag_1x1s1p1) {
col_data = dout_batch;
}
memset(col_data, 0, sizeof(Dtype2) * group_size_coldata * group);
for (int g = 0; g < group; ++g) {
const Dtype1* din_group = din_batch + g * group_size_in;
const Dtype1* weights_group = weights + g * group_size_weights;
Dtype2* coldata_group = col_data + g * group_size_coldata;
basic_gemm<Dtype1, Dtype2>(m,
n,
k,
weights_group,
din_group,
nullptr,
coldata_group,
(Dtype2)1,
(Dtype2)0,
true,
false,
false,
(!flag_bias && flag_relu));
}
if (!flag_1x1s1p1) {
col2im(col_data,
chout,
hout,
wout,
kernel_h,
kernel_w,
pad_h0,
pad_h1,
pad_w0,
pad_w1,
stride_h,
stride_w,
dila_h,
dila_w,
dout_batch);
}
if (flag_bias) {
fill_bias_relu(
dout_batch, bias, chout, wout * hout, flag_bias, flag_relu);
}
}
return true;
}
class Conv2DTransposeComputeTester : public arena::TestCase {
protected:
// common attributes for this op.
std::string x_ = "x";
std::string output_ = "out";
std::string filter_ = "filter";
std::string bias_ = "bias";
std::string padding_algorithm_ = "";
std::vector<int> strides_{1, 1};
std::vector<int> paddings_{0, 0, 0, 0};
int groups_{1};
std::vector<int> dilations_{1, 1};
bool flag_relu_{false};
int n_ = 1;
int ic_ = 1;
int oc_ = 1;
int ih_ = 9;
int iw_ = 9;
bool flag_bias_ = false;
int ks_ = 1;
public:
Conv2DTransposeComputeTester(const Place& place,
const std::string& alias,
int n,
int ic,
int oc,
int ih,
int iw,
bool flag_bias,
bool flag_relu,
int dilation,
int stride,
int pad_h0,
int pad_h1,
int pad_w0,
int pad_w1,
int ks,
int groups,
std::string padding_algorithm)
: TestCase(place, alias) {
n_ = n;
ic_ = ic;
oc_ = oc;
ih_ = ih;
iw_ = iw;
ks_ = ks;
flag_bias_ = flag_bias;
padding_algorithm_ = padding_algorithm;
strides_ = std::vector<int>({stride, stride});
paddings_ = std::vector<int>({pad_h0, pad_h1, pad_w0, pad_w1});
dilations_ = std::vector<int>({dilation, dilation});
groups_ = groups;
flag_relu_ = flag_relu;
}
void RunBaseline(Scope* scope) override {
auto* out = scope->NewTensor(output_);
CHECK(out);
auto* x = scope->FindTensor(x_);
auto input_dim = x->dims();
std::vector<int> ksize({1, 1, ks_, ks_});
UpdatePaddingAndDilation(&paddings_,
&dilations_,
strides_,
padding_algorithm_,
input_dim,
ksize);
int oh = (ih_ - 1) * strides_[0] - paddings_[0] - paddings_[1] +
dilations_[0] * (ks_ - 1) + 1;
int ow = (iw_ - 1) * strides_[1] - paddings_[2] - paddings_[3] +
dilations_[1] * (ks_ - 1) + 1;
CHECK(oh > 0 || ow > 0);
std::vector<int64_t> output_shape = {n_, oc_, oh, ow};
DDim output_dims(output_shape);
out->Resize(output_dims);
auto* output_data = out->mutable_data<float>();
const auto* x_data = x->data<float>();
auto* filter = scope->FindTensor(filter_);
const auto* filter_data = filter->data<float>();
const float* bias_data = nullptr;
if (flag_bias_) {
auto* bias = scope->FindTensor(bias_);
bias_data = bias->data<float>();
}
deconv_basic<float, float>(x_data,
output_data,
n_,
oc_,
oh,
ow,
ic_,
ih_,
iw_,
filter_data,
bias_data,
groups_,
ks_,
ks_,
strides_[1],
strides_[0],
dilations_[1],
dilations_[0],
paddings_[2],
paddings_[3],
paddings_[0],
paddings_[1],
flag_bias_,
flag_relu_);
}
void PrepareOpDesc(cpp::OpDesc* op_desc) {
op_desc->SetType("conv2d_transpose");
op_desc->SetInput("Input", {x_});
op_desc->SetInput("Filter", {filter_});
op_desc->SetOutput("Output", {output_});
op_desc->SetAttr("strides", strides_);
op_desc->SetAttr("paddings", paddings_);
op_desc->SetAttr("groups", groups_);
op_desc->SetAttr("dilations", dilations_);
if (flag_bias_) {
op_desc->SetInput("Bias", {bias_});
}
op_desc->SetAttr("fuse_relu", flag_relu_);
op_desc->SetAttr("padding_algorithm", padding_algorithm_);
}
void PrepareData() override {
std::vector<int64_t> input_shape = {n_, ic_, ih_, iw_};
std::vector<int64_t> filter_shape = {ic_, oc_ / groups_, ks_, ks_};
std::vector<int64_t> bias_shape = {1, oc_, 1, 1};
// x tensor
DDim x_dims(input_shape);
std::vector<float> x_data(x_dims.production());
for (int i = 0; i < x_dims.production(); i++) {
float sign = i % 3 == 0 ? -1.0f : 1.0f;
x_data[i] = sign * static_cast<float>(i % 128) * 0.013f + 0.001;
}
SetCommonTensor(x_, x_dims, x_data.data());
// filter tensor
DDim filter_dims(filter_shape);
std::vector<float> filter_data(filter_dims.production());
for (int i = 0; i < filter_dims.production(); i++) {
float sign = i % 3 == 0 ? -1.0f : 1.0f;
filter_data[i] = sign * static_cast<float>(i % 128) * 0.01f + 0.001;
}
SetCommonTensor(filter_, filter_dims, filter_data.data());
// bias tensor
if (flag_bias_) {
DDim bias_dims(bias_shape);
std::vector<float> bias_data(bias_dims.production());
for (int i = 0; i < bias_dims.production(); i++) {
float sign = i % 3 == 0 ? -1.0f : 1.0f;
bias_data[i] = sign * static_cast<float>(i % 128) * 0.01f + 0.001;
}
SetCommonTensor(bias_, bias_dims, bias_data.data());
}
}
};
TEST(conv2d_transpose, precision) {
LOG(INFO) << "test conv2d_transpose op";
#ifdef LITE_WITH_ARM
Place place(TARGET(kARM));
for (auto n : {2}) {
for (auto ic : {1, 4 /*, 128*/}) {
for (auto oc : {1, 4 /*, 128*/}) {
LOG(INFO) << "n:" << n << ",ic:" << ic << ",oc:" << oc;
for (auto ih : {8, 8 /*, 56 , 112, 224, 512*/}) {
for (auto iw : {8, 16 /*, 56, 112, 224, 512*/}) {
for (auto flag_bias : {false, true}) {
for (auto flag_relu : {false, true}) {
for (auto dilation : {1, 2}) {
for (auto stride : {1, 2}) {
for (auto pad_h0 : {0, 1}) {
for (auto pad_h1 : {0, 1}) {
for (auto pad_w0 : {0, 1}) {
for (auto pad_w1 : {0, 1}) {
for (auto ks : {1, 4}) {
for (auto group : {1, 2}) {
for (auto padding_algorithm :
{"", "SAME", "VALID"}) {
// obtain shape
// LOG(INFO) << "n:" << n << ",ic:" << ic <<
// ",oc:" <<
// oc
// << ",ih:" << ih << ",iw:" << iw
// << ",flag_bias:" << flag_bias
// << ",flag_relu:" << flag_relu
// << ",dila:" << dilation
// << ",stride:" << stride
// << ",padding:" << padding <<
// ",ks:" << ks
// << ",group:" << group;
if (ic % group != 0 || oc % group != 0) {
group = 1;
}
std::unique_ptr<arena::TestCase> tester(
new Conv2DTransposeComputeTester(
place,
"def",
n,
ic,
oc,
ih,
iw,
flag_bias,
flag_relu,
dilation,
stride,
pad_h0,
pad_h1,
pad_w0,
pad_w1,
ks,
group,
padding_algorithm));
arena::Arena arena(
std::move(tester), place, 2e-5);
arena.TestPrecision();
}
}
}
}
}
}
}
}
}
}
}
}
}
}
}
}
#endif
}
} // namespace lite
} // namespace paddle
...@@ -59,6 +59,7 @@ DEFINE_bool(flag_bias, false, "with bias"); ...@@ -59,6 +59,7 @@ DEFINE_bool(flag_bias, false, "with bias");
typedef paddle::lite::DDim DDim; typedef paddle::lite::DDim DDim;
typedef paddle::lite::Tensor Tensor; typedef paddle::lite::Tensor Tensor;
typedef paddle::lite::operators::ConvParam ConvParam; typedef paddle::lite::operators::ConvParam ConvParam;
typedef paddle::lite::operators::ActivationParam ActivationParam;
using paddle::lite::profile::Timer; using paddle::lite::profile::Timer;
DDim compute_out_dim(const DDim& dim_in, DDim compute_out_dim(const DDim& dim_in,
...@@ -117,6 +118,13 @@ void test_conv_transpose_fp32(const std::vector<DDim>& input_dims, ...@@ -117,6 +118,13 @@ void test_conv_transpose_fp32(const std::vector<DDim>& input_dims,
paddle::lite::fill_tensor_rand(*param.bias, -1.f, 1.f); paddle::lite::fill_tensor_rand(*param.bias, -1.f, 1.f);
// paddle::lite::fill_tensor_const(*param.bias, 1.f); // paddle::lite::fill_tensor_const(*param.bias, 1.f);
} }
if (flag_relu) {
ActivationParam act_param;
act_param.has_active = true;
act_param.active_type =
(paddle::lite_api::ActivationType)1; // 2-relu6 4-leakyrelu
param.activation_param = act_param;
}
Tensor tmp_weights; Tensor tmp_weights;
tmp_weights.Resize(weight_dim); tmp_weights.Resize(weight_dim);
tmp_weights.CopyDataFrom(*param.filter); tmp_weights.CopyDataFrom(*param.filter);
......
...@@ -22,9 +22,11 @@ ...@@ -22,9 +22,11 @@
#include "lite/core/context.h" #include "lite/core/context.h"
#include "lite/core/profile/timer.h" #include "lite/core/profile/timer.h"
#include "lite/core/tensor.h" #include "lite/core/tensor.h"
#include "lite/operators/op_params.h"
#include "lite/tests/utils/tensor_utils.h" #include "lite/tests/utils/tensor_utils.h"
typedef paddle::lite::Tensor Tensor; typedef paddle::lite::Tensor Tensor;
typedef paddle::lite::operators::ActivationParam ActivationParam;
using paddle::lite::profile::Timer; using paddle::lite::profile::Timer;
DEFINE_int32(power_mode, DEFINE_int32(power_mode,
...@@ -136,6 +138,12 @@ bool test_sgemm(bool tra, ...@@ -136,6 +138,12 @@ bool test_sgemm(bool tra,
has_relu); has_relu);
} }
Timer t0; Timer t0;
ActivationParam act_param;
if (has_relu) {
act_param.has_active = true;
act_param.active_type =
(paddle::lite_api::ActivationType)1; // 2-relu6 4-leakyrelu
}
#ifdef LITE_WITH_ARM #ifdef LITE_WITH_ARM
//! compute //! compute
double ops = 2.0 * m * n * k; double ops = 2.0 * m * n * k;
...@@ -163,7 +171,7 @@ bool test_sgemm(bool tra, ...@@ -163,7 +171,7 @@ bool test_sgemm(bool tra,
ldc, ldc,
dbias, dbias,
has_bias, has_bias,
has_relu, act_param,
&ctx); &ctx);
} }
...@@ -184,7 +192,7 @@ bool test_sgemm(bool tra, ...@@ -184,7 +192,7 @@ bool test_sgemm(bool tra,
ldc, ldc,
dbias, dbias,
has_bias, has_bias,
has_relu, act_param,
&ctx); &ctx);
t0.Stop(); t0.Stop();
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册